import argparse
import os

import torch

from data.hyperspectra import getHyper
from data.tech import getTech
from data.videos import getVideos
from evaluate import evaluate,evaluate_both,getbest,evaluate_dense,evaluate_extra
from pathlib import Path
import sys

def get_hostname():
    """Return the first line of ``/etc/hostname``.

    The file is read in full and everything from the first newline onward
    is discarded.
    """
    with open("/etc/hostname") as hostname_file:
        contents = hostname_file.read()
    first_line, _, _ = contents.partition('\n')
    return first_line

def mysvd(init_A, k, n_iter=300):
    """Approximate the top-``k`` singular triplets of ``init_A`` by power iteration.

    For each triplet a random start vector is repeatedly multiplied by
    ``A^T A`` and re-normalized; the resulting right singular direction is
    then removed from ``A`` (deflation) before the next triplet is computed.
    Only out-of-place tensor operations are used.

    Args:
        init_A: 2-D tensor of shape ``(n, d)``.
        k: number of triplets requested; clamped to ``min(n, d)``. Must be
            at least 1, otherwise the final ``torch.cat`` has nothing to
            concatenate.
        n_iter: power-iteration steps per singular vector. Defaults to 300,
            the value that used to be hard-coded.

    Returns:
        A tuple ``(U, S, V)`` of tensors with shapes ``(n, k)``, ``(k,)`` and
        ``(d, k)``; column ``i`` of ``U`` / ``V`` and entry ``i`` of ``S``
        belong to the ``i``-th triplet found.
    """
    n, d = init_A.size(0), init_A.size(1)
    k = min(k, n, d)

    # Start vectors live on the same device and use the same dtype as the
    # input, so the routine is not tied to CUDA / float32.
    starts = [
        torch.rand(d, dtype=init_A.dtype, device=init_A.device)
        for _ in range(k)
    ]

    def power_step(vec, mat):
        # One normalized application of mat^T mat.
        nxt = mat.t().mv(mat.mv(vec))
        return nxt.div(torch.norm(nxt))

    U, S, V = [], [], []
    A = init_A
    for start in starts:  # one pass per singular triplet
        # Only the latest iterate is needed, so no history is retained.
        v_raw = start
        for _ in range(n_iter):
            v_raw = power_step(v_raw, A)

        v = v_raw / torch.norm(v_raw)
        Av = A.mv(v)
        sigma = torch.norm(Av)

        V.append(v.view(1, d))
        S.append(sigma.view(1))
        U.append((Av / sigma).view(1, n))

        # Deflate: subtract the component of A along the direction just found.
        A = A - torch.ger(A.mv(v_raw), v_raw)

    return torch.cat(U, 0).t(), torch.cat(S, 0), torch.cat(V, 0).t()


if __name__ == '__main__':
    # Script entry point. For every saved checkpoint of a run it loads the
    # stored sketch, passes it to evaluate_extra together with one freshly
    # drawn random sketch, and saves the train/test results in a file with
    # the same name plus a "_sep" suffix.
    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        """Shorthand for ``parser.add_argument``."""
        parser.add_argument(*args, **kwargs)

    aa("--data", type=str, default="tech", help="tech|video|hyper")
    aa("--dataname", type=str, default="mit", help="transformer|mit|friends")
    aa("--m", type=int, default=10, help="m for S")
    aa("--k", type=int, default=10, help="target: rank k approximation")
    aa("--iter", type=int, default=5000, help="total iterations")
    aa("--size", type=int, default=-1, help="dataset size")

    args = parser.parse_args()

    # Raw data and results share one root directory, which depends on the
    # machine; the hostname file is read once.
    root_dir = "/git/big-lowrank/" if get_hostname() == "Dragon" else "/big-lowrank/"
    rawdir = root_dir
    rltdir = root_dir

    print(args)
    # Half of the requested sketch size; the same value is passed twice to
    # evaluate_extra below (presumably one half per sketch -- confirm
    # against evaluate_extra).
    m = args.m // 2
    k = args.k

    def result_name(step):
        """File name (without directory) of the checkpoint saved at ``step``."""
        return 'm=' + str(m) + '_k=' + str(k) + '_iter=' + str(step) + '_N=' + str(args.size)

    if args.data == 'tech':
        save_dir = rltdir + 'rlt/tech/'
    elif args.data == 'hyper':
        save_dir = rltdir + 'rlt/hyper/'
    elif args.data == 'video':
        save_dir = rltdir + 'rlt/video/' + args.dataname + '/'
    else:
        print("Wrong data option!")
        # Non-zero status so callers can tell this apart from a clean exit.
        sys.exit(1)

    # The "_sep" file of the final iteration marks a completed run.
    if os.path.isfile(save_dir + result_name(args.iter) + '_sep'):
        print("This one is already done! Now exiting.")
        sys.exit()

    if args.data == 'tech':
        A_train, A_test, n, d = getTech(False, rawdir)
    elif args.data == 'hyper':
        A_train, A_test, n, d = getHyper(False, args.size, rawdir)
    else:
        A_train, A_test, n, d = getVideos(args.dataname, False, args.size, rawdir)

    print("Running sep eval!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

    print("Working on data ", args.data)

    Path(save_dir).mkdir(parents=True, exist_ok=True)

    N_train = len(A_train)
    N_test = len(A_test)
    print("Dim= ", n, d)
    print("N train=", N_train, "N test=", N_test)

    # Reference values that the per-sample results are compared against in
    # the final print of each evaluated checkpoint.
    # NOTE(review): torch.load unpickles its input; only point this script
    # at result files produced by trusted runs.
    best_file = save_dir + "N=" + str(args.size) + "_k=" + str(k) + '_best'
    best_train, best_test = torch.load(best_file)

    rlt_dic = {}

    sparse = (args.data == 'tech')

    print_freq = 50

    # Extra random sketch: n integer entries drawn from [0, m) and n random
    # +1/-1 values.
    sketch_vector2 = torch.randint(m, [n]).int()  # m*n
    sketch_vector2.requires_grad = False
    sketch_value2 = ((torch.randint(2, [n]).float() - 0.5) * 2).cuda()
    sketch_value2.requires_grad = False

    # evaluate_extra is given -1 for both dimensions when the data is sparse.
    if sparse:
        Ad = -1
        An = -1
    else:
        Ad = d
        An = n

    for bigstep in range(args.iter + 1):
        # Checkpoints are looked up every 50 steps up to step 200 and every
        # 200 steps afterwards.
        if bigstep > 200:
            print_freq = 200

        if bigstep % print_freq != 0:
            continue

        print(bigstep, '.')
        f_name = result_name(bigstep)
        if not os.path.isfile(save_dir + f_name):
            # No checkpoint was saved at this step.
            continue

        sketch_vector, sketch_value, _, _, _ = torch.load(save_dir + f_name)

        f_name = result_name(bigstep) + "_sep"
        rlt_dic[f_name] = (
            evaluate_extra(sparse, A_train, sketch_vector, sketch_value,
                           sketch_vector2, sketch_value2, m, m, k, An, Ad),
            evaluate_extra(sparse, A_test, sketch_vector, sketch_value,
                           sketch_vector2, sketch_value2, m, m, k, An, Ad),
        )
        torch.save([sketch_vector2, sketch_value2, rlt_dic[f_name], N_train, N_test], save_dir + f_name)
        print(f_name, rlt_dic[f_name][0] / N_train - best_train, rlt_dic[f_name][1] / N_test - best_test)


