require 'nn'
require 'model'
require 'groundTruth'
require 'cutorch'
require 'cunn'

require 'math'

-- Experiment configuration and network handles.
-- get_model / get_GT are globals not defined in this file (presumably provided
-- by the 'model' and 'groundTruth' modules required above -- TODO confirm).
local input_dim=100
local model=get_model(input_dim)
local gt=get_GT(input_dim)

-- Mean-squared-error criterion, moved to the GPU.
local cri=nn.MSECriterion():cuda()

-- Number of samples per mini-batch.
local batch=200
-- First submodule of the first submodule of each network; these are the
-- layers whose weights are compared and updated by the rest of the script.
local l1=model:get(1):get(1)
local gt_l1=gt:get(1):get(1)
print(model)
print('linear layer=',l1)
--- Spectral (induced 2-) norm of a matrix, i.e. its largest singular value.
-- The previous implementation took torch.norm of the first row of the
-- torch.eig result, which is only the modulus of the first eigenvalue and
-- not a matrix norm.
-- @param M 2D tensor; it is cloned and converted with :float() before the
--          decomposition, so M itself is not modified.
-- @return number: the largest singular value of M.
local function get_2_norm(M)
   local tmpM=M:clone():float()
   -- torch.svd returns U, S, V; only the vector of singular values S is needed.
   local _,singular_values=torch.svd(tmpM)
   return singular_values:max()
end

require 'distributions'
-- Parameters handed to distributions.mvn.rnd below: zero mean, identity covariance.
local mean=torch.Tensor(input_dim):zero()
local cov=torch.eye(input_dim)
-- Initialise the trainable layer with uniform values in [-1,1] scaled by 1/10.
l1.weight:uniform(-1,1):div(10)
--l1.weight:zero()
-- Replace the ground-truth layer's weight with the tensor stored in the file
-- "ground_truth". NOTE(review): assumes the stored tensor matches the layer's
-- expected type and shape -- confirm against whatever produced the file.
gt_l1.weight=torch.load("ground_truth")
print('gt_l1 norm=',get_2_norm(gt_l1.weight))
print('l1 norm=',get_2_norm(l1.weight))

local train_size=100000
local test_size=10000
local train_book=torch.Tensor(train_size,input_dim):zero()
local test_book=torch.Tensor(test_size,input_dim):zero()
-- Presumably fills every row with a sample drawn from N(mean, cov); see the
-- 'distributions' package for the exact contract.
distributions.mvn.rnd(train_book,mean,cov)
distributions.mvn.rnd(test_book,mean,cov)
-- :cuda() returns a converted copy instead of converting in place, so the
-- result has to be assigned back; otherwise the books silently stay on the CPU.
train_book=train_book:cuda()
test_book=test_book:cuda()

-- Reusable GPU buffers for one mini-batch of inputs and its targets.
local input=torch.CudaTensor(batch,input_dim):zero()
local real_ans=torch.CudaTensor(batch,1):zero()

--- Scalar "potential" comparing the ground-truth layer with the trained layer.
-- For every row index i in 1..input_dim, a copy of row i of each weight matrix
-- gets 1 added to its i-th entry and its norm is taken; the result is the
-- accumulated (ground-truth norm - trained norm) over all rows.
-- Reads gt_l1.weight, l1.weight and input_dim; modifies nothing.
-- @return number
local function potential()
   -- Norm of `row` after adding 1 to its idx-th entry (works on a clone).
   local function shifted_norm(row,idx)
      local shifted=row:clone()
      shifted[idx]=shifted[idx]+1
      return torch.norm(shifted)
   end

   local total=0
   local i=1
   while i<=input_dim do
      local gt_term=shifted_norm(gt_l1.weight[i],i)
      local model_term=shifted_norm(l1.weight[i],i)
      total=total+gt_term-model_term
      i=i+1
   end
   return total
end
-- History of all runs, persisted in the file 'rlt_1'. The file is written by
-- this script itself, so it does not exist before the first run; start with an
-- empty history in that case instead of failing inside torch.load.
local grand
do
   local existing=io.open('rlt_1','r')
   if existing then
      existing:close()
      grand=torch.load('rlt_1')
   else
      grand={}
   end
end
-- Per-epoch records of the current run, appended to the history.
local rlt={}
grand[#grand+1]=rlt


-- Main loop: 200 epochs of mini-batch gradient descent on l1.weight against
-- targets produced by the ground-truth network, followed by an evaluation pass
-- over the test book. Per-epoch statistics go into rlt[epoch] and the whole
-- history is saved to 'rlt_1' after every epoch.

-- Step size of the manual update applied to l1.weight.
local learning_rate=0.001

for epoch=1,200 do
   -- Fresh random visiting order of the training rows for this epoch.
   local loc=torch.randperm(train_size)
   local iter=1
   local tot_train_loss=0
   rlt[epoch]={}
   for j=1,train_size/batch do
      -- Gather the next `batch` shuffled training rows into the input buffer.
      for k=1,batch do
         input[k]:copy(train_book[loc[iter]])
         iter=iter+1
      end

      -- Targets come from the ground-truth network on the same batch.
      real_ans=gt:forward(input)

      cri:forward(model:forward(input),real_ans)
      model:zeroGradParameters()
      cri:backward(model.output,real_ans)
      model:backward(input,cri.gradInput)

      -- Inner product <gradWeight, weight - gt weight>, evaluated once per step
      -- (before the update) and reused for the check, the log line and the record.
      local inner=torch.dot(l1.gradWeight,l1.weight-gt_l1.weight)
      if inner<0 then
         print('---->',potential(),inner)
      end
      if j==1 then
         rlt[epoch].potential=potential()
         rlt[epoch].inner=inner
      end
      -- Only l1.weight is updated; no other parameter is touched here.
      l1.weight:add(-learning_rate,l1.gradWeight)
      tot_train_loss=tot_train_loss+cri.output
   end

   -- Mean per-batch training loss and distance statistics, computed once.
   local avg_train_loss=tot_train_loss/train_size*batch
   local dist_2_gt=torch.norm(l1.weight-gt_l1.weight)
   local weight_norm=torch.norm(l1.weight)
   print(epoch,"tot train loss=",avg_train_loss,dist_2_gt,weight_norm)
   rlt[epoch].train_loss=avg_train_loss
   rlt[epoch].dist_2_gt=dist_2_gt
   rlt[epoch].norm=weight_norm

   -- Evaluation pass: test rows are visited in order, so each batch is a
   -- contiguous slice that can be copied into the buffer in one call.
   iter=1
   local tot_loss=0
   for j=1,test_size/batch do
      input:copy(test_book:narrow(1,iter,batch))
      iter=iter+batch
      real_ans=gt:forward(input)
      cri:forward(model:forward(input),real_ans)
      tot_loss=tot_loss+cri.output
   end
   local avg_test_loss=tot_loss/test_size*batch
   print("tot loss=\t\t\t\t\t\t\t\t\t\t",avg_test_loss)
   rlt[epoch].test_loss=avg_test_loss
   torch.save('rlt_1',grand)
end

