require 'nn'
require 'model3'
require 'groundTruth'
require 'cutorch'
require 'cunn'

require 'math'

-- Problem size and mini-batch size shared by the whole script.
local input_dim = 100
local batch = 200

-- Trainable network and the reference ("ground truth") network it is fitted
-- to; both are built by project helpers from the input dimensionality.
local model = get_model3(input_dim)
local gt = get_GT(input_dim)

-- Mean-squared-error loss between model output and reference output (GPU).
local cri = nn.MSECriterion():cuda()

-- Innermost first layer of the reference network; presumably an nn.Linear
-- (the script later reads and replaces its .weight) -- TODO confirm.
local gt_l1 = gt:get(1):get(1)

print(model)

-- Input distribution: multivariate normal with zero mean and identity
-- covariance in input_dim dimensions.
require 'distributions'
local mean = torch.Tensor(input_dim):zero()
local cov = torch.eye(input_dim)
--l1.weight:uniform(-1,1):div(torch.sqrt(input_dim)):add(torch.eye(input_dim):cuda())

-- Replace the reference network's first-layer weight with a tensor saved on
-- disk (presumably so that repeated runs share the same target -- confirm).
gt_l1.weight = torch.load("ground_truth")
print('gt_l1 norm=', torch.norm(gt_l1.weight))

-- Pre-sample fixed training and test inputs from N(mean, cov).
local train_size = 100000
local test_size = 10000
local train_book = torch.Tensor(train_size, input_dim):zero()
local test_book = torch.Tensor(test_size, input_dim):zero()
distributions.mvn.rnd(train_book, mean, cov)
distributions.mvn.rnd(test_book, mean, cov)

-- :cuda() returns a NEW tensor; it does not convert in place. Assign the
-- result back, otherwise the books silently stay on the CPU and every
-- mini-batch row is copied host->device.
train_book = train_book:cuda()
test_book = test_book:cuda()

-- Reusable GPU buffers for one mini-batch of inputs and targets.
local input = torch.CudaTensor(batch, input_dim):zero()
local real_ans = torch.CudaTensor(batch, 1):zero()

-- Results history persisted in 'rlt_3': each execution of this script appends
-- one new table (`rlt`) to `grand`. On a first run the file does not exist
-- yet, so start an empty history instead of failing inside torch.load.
local grand
local rlt_probe = io.open('rlt_3', 'rb')
if rlt_probe then
   rlt_probe:close()
   grand = torch.load('rlt_3')
else
   grand = {}
end
local rlt = {}
grand[#grand + 1] = rlt

-- Plain SGD on a subset of the parameters. Only the weights of square
-- nn.Linear layers and the weights of nn.BatchNormalization layers are
-- updated; biases and non-square Linear layers are not touched by this loop.
local stepSize = 0.001
local n_epochs = 200

-- The module graph does not change during training, so collect the modules
-- to update once instead of calling findModules on every mini-batch.
local square_linears = {}
for _, m in ipairs(model:findModules('nn.Linear')) do
   if m.weight:size(1) == m.weight:size(2) then
      square_linears[#square_linears + 1] = m
   end
end
local batch_norms = model:findModules('nn.BatchNormalization')

local n_train_batches = math.floor(train_size / batch)
local n_test_batches = math.floor(test_size / batch)

for epoch = 1, n_epochs do
   rlt[epoch] = {}

   -- Training pass over a fresh random permutation of the training set.
   -- training() undoes the evaluate() switch made for the previous test pass.
   model:training()
   local loc = torch.randperm(train_size)
   local iter = 1
   local tot_train_loss = 0
   for _ = 1, n_train_batches do
      for k = 1, batch do
         input[k]:copy(train_book[loc[iter]])
         iter = iter + 1
      end

      real_ans = gt:forward(input)

      cri:forward(model:forward(input), real_ans)
      model:zeroGradParameters()
      cri:backward(model.output, real_ans)
      model:backward(input, cri.gradInput)

      -- Use the same step size for every updated parameter (the original
      -- hard-coded 0.001 here, duplicating stepSize).
      for _, m in ipairs(square_linears) do
         m.weight:add(-stepSize, m.gradWeight)
      end
      for _, m in ipairs(batch_norms) do
         m.weight:add(-stepSize, m.gradWeight)
      end
      tot_train_loss = tot_train_loss + cri.output
   end
   -- Mean of the per-mini-batch criterion values.
   local train_loss = tot_train_loss / n_train_batches
   print("tot train loss=", train_loss)
   rlt[epoch].train_loss = train_loss

   -- Evaluation pass over the test set in order. evaluate() makes
   -- BatchNormalization use its running statistics rather than the test
   -- batch's own statistics, and stops test data from updating them.
   -- NOTE(review): this changes reported test_loss relative to earlier runs
   -- stored in 'rlt_3', which were measured in training mode.
   model:evaluate()
   local tot_loss = 0
   for j = 1, n_test_batches do
      input:copy(test_book:narrow(1, (j - 1) * batch + 1, batch))
      real_ans = gt:forward(input)
      cri:forward(model:forward(input), real_ans)
      tot_loss = tot_loss + cri.output
   end
   local test_loss = tot_loss / n_test_batches
   print("tot loss=\t\t\t\t\t\t\t\t\t\t", test_loss)
   rlt[epoch].test_loss = test_loss

   -- Persist after every epoch so that an interrupted run keeps its results.
   torch.save('rlt_3', grand)
end
