import os

import cv2
import numpy as np
import torch
from PIL import Image


def convertTOImage(fname,total=1000):
    """Dump the frames of a video to numbered JPEG files.

    Reads ``../big-lowrank/raw/videos/<fname>.mp4`` and writes its frames to
    ``../big-lowrank/raw/videos/<fname>/frame<k>.jpg`` for k = 0, 1, 2, ...
    The output directory is created if it does not exist.

    Args:
        fname: Base name of the video file, without the ``.mp4`` extension.
        total: Index of the last frame to write. Frames ``0..total``
            (inclusive, i.e. up to ``total + 1`` frames) are saved, or fewer
            if the video is shorter.
    """
    base = "../big-lowrank/raw/videos/"
    path = base + fname
    os.makedirs(path, exist_ok=True)
    vidcap = cv2.VideoCapture(base + fname + ".mp4")
    count = 0
    try:
        success, image = vidcap.read()
        while success:
            cv2.imwrite(path + "/frame%d.jpg" % count, image)     # save frame as JPEG file
            count += 1
            # Stop once frame ``total`` has been written; checking before the
            # next read avoids decoding a frame that would be discarded.
            if count > total:
                break
            success, image = vidcap.read()
    finally:
        # Always free the underlying decoder / file handle.
        vidcap.release()
    if count == 0:
        # The first read failed, e.g. a missing or unreadable video file.
        print("Warning: no frames read from " + base + fname + ".mp4")
    else:
        print("Success!")

def computeImage(total=1000):
    """Extract frames for every bundled video (eagle, friends, mit).

    Args:
        total: Forwarded to ``convertTOImage`` for each video.
    """
    for video_name in ('eagle', 'friends', 'mit'):
        convertTOImage(video_name, total)



def processRaw(fname,N,rawdir,start_frame=1000):
    """Build and cache normalized frame matrices for a video.

    Loads frames ``start_frame .. start_frame + N - 1`` from
    ``<rawdir>raw/videos/<fname>/frame<k>.jpg`` in a random order. Each frame
    is reshaped into a float matrix with ``3 * height`` rows and rescaled so
    that its largest singular value equals 100. It is then randomly assigned
    to the training split (probability 0.8) or the test split. Frames whose
    largest singular value is below 1 are skipped.

    The two lists are saved with ``torch.save`` as ``[A_train, A_test]`` to
    ``<rawdir>raw/videos/<fname>_<N>.dat``.

    Args:
        fname: Video name, i.e. the sub-directory holding the extracted frames.
        N: Number of frames to process.
        rawdir: Prefix prepended to ``"raw/videos/..."``. It is concatenated
            directly, so it must end with a path separator (or be empty).
        start_frame: Index of the first frame used. Defaults to 1000, the
            value that was previously hard-coded.
    """
    perm = torch.randperm(N)
    frame_prefix = rawdir + "raw/videos/" + fname + "/frame"
    A_train, A_test = [], []
    for i, idx in enumerate(perm.tolist()):
        if i % 100 == 0:
            print(i)
        # Context manager closes the underlying file once pixels are copied.
        with Image.open(frame_prefix + str(idx + start_frame) + ".jpg") as image:
            im = np.array(image)
        # Flattens the pixel buffer into a (3 * height, -1) matrix; presumably
        # frames are RGB (H, W, 3) so this yields (3H, W) -- TODO confirm.
        cur = torch.from_numpy(im).view(im.shape[0] * 3, -1).float()
        # Only the singular values are needed; skip computing U and V.
        _, S, _ = cur.svd(compute_uv=False)
        div = S[0].item()  # largest singular value (singular values are >= 0)
        if div < 1:
            # Near-zero spectral norm (e.g. a blank frame): drop it rather
            # than blow it up during normalization.
            print("Catch!")
            continue
        # Dividing by S[0] / 100 makes the stored matrix's top singular value 100.
        div /= 100
        if np.random.random() < 0.8:
            A_train.append(cur / div)
        else:
            A_test.append(cur / div)
    torch.save([A_train, A_test], rawdir + "raw/videos/" + fname + "_" + str(N) + ".dat")

def getVideos(videoName,raw, N,rawdir):
    """Load the cached train/test frame matrices of a video, optionally rebuilding them.

    Args:
        videoName: One of ``'mit'``, ``'eagle'`` or ``'friends'``.
        raw: If true, regenerate the cache with ``processRaw`` before loading.
        N: Number of frames the cache was built from. Negative values select
            the default of 200.
        rawdir: Prefix prepended to ``"raw/videos/..."``. It is concatenated
            directly, so it must end with a path separator (or be empty).

    Returns:
        ``(A_train, A_test, n, d)``: the two stored lists of matrices, and
        ``(n, d)``, the size of the first stored matrix.

    Raises:
        AssertionError: if ``videoName`` is not one of the known videos.
        IndexError: if the cache holds no matrices at all.
    """
    if videoName not in ('mit', 'eagle', 'friends'):
        print("Wrong video name!")
        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``; the exception type is kept for compatibility.
        raise AssertionError("Wrong video name!")
    if N < 0:
        N = 200

    if raw:
        processRaw(videoName, N, rawdir)
    # NOTE(review): torch.load unpickles the file; only load caches you produced.
    A_train, A_test = torch.load(rawdir + "raw/videos/" + videoName + "_" + str(N) + ".dat")
    # The random train/test split can leave A_train empty for small N, so take
    # the size from whichever split has data. If both are empty,
    # A_test[0] raises IndexError, as before.
    first = A_train[0] if A_train else A_test[0]
    n = first.size(0)
    d = first.size(1)

    return A_train, A_test, n, d
