#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 25 11:44:29 2024

@author: anonymous
"""

import numpy as np
import torch 
import time
import tqdm
import pickle
import matplotlib.pyplot as plt 



###############################################################################



class CNN(torch.nn.Module):
    """ Convolutional neural network.

    Three blocks of (convolution -> ReLU -> 2x2 max pooling) followed by
    three dense layers. Dropout is applied before each dense layer and,
    optionally, batch normalization after the first two dense layers.

    Parameters
    ----------
    dim : Int
            takes images of size num_channels x dim x dim as input
    num_channels : Int
            number of input channels
    num_classes : Int
            number of classes (size of the output layer)
    ks : Int
            kernel size in convolutional layers
    conv_channels : Int
            number of channels in conv layers
    lin1 : Int
            number of output features of the first dense layer
    lin2 : Int
            number of output features of the second dense layer
    reg : Float
            dropout probability used by all dropout layers
    bn : Bool
            if True, batch normalization is used after the first two
            dense layers
    """
    def __init__(self, dim, num_channels, num_classes, ks, conv_channels, lin1, lin2, reg, bn):
        super().__init__()

        # 'same' padding keeps the spatial size, so only the pooling layers
        # shrink the feature maps.
        self.conv1=torch.nn.Conv2d(in_channels=num_channels, out_channels=conv_channels, kernel_size=(ks, ks), padding='same')
        self.maxpool1=torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv2=torch.nn.Conv2d(in_channels=conv_channels, out_channels=conv_channels, kernel_size=(ks, ks), padding='same')
        self.maxpool2=torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv3=torch.nn.Conv2d(in_channels=conv_channels, out_channels=conv_channels, kernel_size=(ks, ks), padding='same')
        self.maxpool3=torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        # Each of the three pooling layers halves the side length (rounding down).
        pooled_dim=int(dim)//2//2//2
        self.fc1=torch.nn.Linear(in_features=conv_channels*pooled_dim**2, out_features=lin1)
        self.fc2=torch.nn.Linear(in_features=lin1, out_features=lin2)
        self.fc3=torch.nn.Linear(in_features=lin2, out_features=num_classes)

        # The batch-norm layers are always created (even if bn is False) so
        # that the set of registered sub-modules does not depend on bn.
        self.bn1=torch.nn.BatchNorm1d(lin1)
        self.bn2=torch.nn.BatchNorm1d(lin2)

        self.relu=torch.nn.ReLU()
        self.reg=reg
        self.bn=bn

    def forward(self, x):
        """Return the class scores (logits) for a batch of images x."""
        x=self.maxpool1(self.relu(self.conv1(x)))
        x=self.maxpool2(self.relu(self.conv2(x)))
        x=self.maxpool3(self.relu(self.conv3(x)))

        # Flatten everything except the batch dimension.
        x=torch.flatten(x, 1)

        # Use the dropout probability stored on the instance; relying on a
        # module-level variable named `reg` would ignore the constructor
        # argument and fail when no such global exists.
        x=torch.nn.functional.dropout(x, p=self.reg, training=self.training)
        x=self.fc1(x)
        if self.bn:
            x=self.bn1(x)
        x=self.relu(x)

        x=torch.nn.functional.dropout(x, p=self.reg, training=self.training)
        x=self.fc2(x)
        if self.bn:
            x=self.bn2(x)
        x=self.relu(x)

        x=torch.nn.functional.dropout(x, p=self.reg, training=self.training)
        x=self.fc3(x)

        return x
    
    
    
###############################################################################



def _epoch_metrics(model, loader, loss_function, device):
    """Run one gradient-free pass over loader and return
    (sum of batch losses, number of correct predictions, number of samples)."""
    loss_sum=0
    correct=0
    seen=0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every batch.
    with torch.no_grad():
        for batch in tqdm.tqdm(loader):
            x=batch[0].to(device)
            y=batch[1].to(device)
            seen+=len(x)
            output=model(x)
            loss_sum+=loss_function(output, y).item()
            pred=output.argmax(dim=1)
            correct+=int((pred==y).sum())
    return loss_sum, correct, seen


def fit_verbose(model, train_loader, test_loader, optimizer, loss_function, Epochs):
    """ Trains model on training set and tests it on validation/test
        set after every epoch. Plots the progression of training and
        validation accuracy at the end and prints the mean test accuracy
        over the last (up to) 10 epochs.

    Parameters
    ----------
    model : Torch neural network
                the model to train
    train_loader : Torch data loader
                training dataset
    test_loader : Torch data loader
                validation dataset
    optimizer : Torch optimizer
                optimizer for training
    loss_function : Torch loss function
                loss function for training
    Epochs : Int
                number of training epochs
    """
    l1=len(train_loader)
    l2=len(test_loader)
    Train_Acc=[]
    Test_Acc=[]
    torch.backends.cudnn.benchmark=True
    # Send the batches to wherever the model lives; fall back to CUDA (the
    # previous fixed choice) for a model without parameters.
    first_param=next(model.parameters(), None)
    device=first_param.device if first_param is not None else torch.device("cuda")

    for epoch in range(Epochs):
        print('\n')
        print('Epoch', epoch+1,'/',Epochs,'\n')
        print('Training:','\n')

        model.train()

        for batch in tqdm.tqdm(train_loader):
            x=batch[0].to(device)
            y=batch[1].to(device)
            model.zero_grad()
            output=model(x)
            loss=loss_function(output, y)
            loss.backward()
            optimizer.step()

        # Evaluate on both sets in eval mode (dropout off, batch norm uses
        # running statistics).
        model.eval()

        print('\n')
        print('Evaluating train data:','\n')
        train_loss, train_acc, c1=_epoch_metrics(model, train_loader, loss_function, device)

        print('\n')
        print('Evaluating test data:','\n')
        test_loss, test_acc, c2=_epoch_metrics(model, test_loader, loss_function, device)

        # Losses are averaged over batches, accuracies over samples.
        print('\n')
        print('Train_Loss: ',round(train_loss/l1,3),' Train_Acc: ',round(train_acc/c1,3),' Test_Loss: ',round(test_loss/l2,3),' Test_Acc: ',round(test_acc/c2,3))
        Train_Acc.append(round(train_acc/c1,3))
        Test_Acc.append(round(test_acc/c2,3))

    x_axis=np.arange(1,Epochs+1,1)
    fig, ax = plt.subplots(figsize=(5, 2.7), layout='constrained')
    ax.plot(x_axis,Train_Acc, label='Train Accuracy')
    ax.plot(x_axis,Test_Acc, label='Test Accuracy')
    # Thin horizontal guide lines, excluded from the legend.
    for level in (0.5, 0.6, 0.7):
        ax.axhline(y=level, color='k', linewidth=0.5, label='_nolegend_')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Accuracy')
    ax.legend()
    print('\n')
    print('Mean test accuracy: ', round(np.mean(np.array(Test_Acc)[-10:]),3))
        
    
    
###############################################################################   
    


def fit(model, train_loader, optimizer, loss_function, Epochs):
    """ Trains model on training set.

    The model is put into training mode and optimized for the given number
    of passes over train_loader. Progress is shown per epoch.

    Parameters
    ----------
    model : Torch neural network
                the model to train
    train_loader : Torch data loader
                training dataset
    optimizer : Torch optimizer
                optimizer for training
    loss_function : Torch loss function
                loss function for training
    Epochs : Int
                number of training epochs
    """
    torch.backends.cudnn.benchmark=True
    # Send the batches to wherever the model lives; fall back to CUDA (the
    # previous fixed choice) for a model without parameters.
    first_param=next(model.parameters(), None)
    device=first_param.device if first_param is not None else torch.device("cuda")

    print('\n')
    print('Training model:','\n')

    model.train()

    for _ in tqdm.tqdm(range(Epochs)):
        for batch in train_loader:
            x=batch[0].to(device)
            y=batch[1].to(device)
            model.zero_grad()
            output=model(x)
            loss=loss_function(output, y)
            loss.backward()
            optimizer.step()

  

###############################################################################   



def evaluate(model, test_loader, loss_function):
    """ Evaluates model on test set.

    Prints the loss (averaged over batches) and the accuracy (averaged
    over samples), both rounded to three decimals.

    Parameters
    ----------
    model : Torch neural network
                the model to evaluate
    test_loader : Torch data loader
                test dataset
    loss_function : Torch loss function
                loss function used to report the test loss

    Returns
    -------
    Float
                test accuracy, rounded to three decimals
    """
    l=len(test_loader)
    c=0
    test_loss=0
    test_acc=0
    torch.backends.cudnn.benchmark=True
    # Send the batches to wherever the model lives; fall back to CUDA (the
    # previous fixed choice) for a model without parameters.
    first_param=next(model.parameters(), None)
    device=first_param.device if first_param is not None else torch.device("cuda")

    print('\n')
    print('Evaluating test data','\n')

    model.eval()

    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every batch.
    with torch.no_grad():
        for batch in tqdm.tqdm(test_loader):
            x=batch[0].to(device)
            y=batch[1].to(device)
            c+=len(x)
            output=model(x)
            loss=loss_function(output, y)
            test_loss+=loss.item()
            pred=output.argmax(dim=1)
            test_acc+=int((pred==y).sum())

    print('\n')
    print(' Test_Loss: ',round(test_loss/l,3),' Test_Acc: ',round(test_acc/c,3))
    return round(test_acc/c,3)



###############################################################################


def load_pickle_array(path):
    """Load the pickled object stored at path and return it as a NumPy array.

    The file is closed as soon as it has been read.
    NOTE: unpickling can execute arbitrary code; only load trusted,
    locally generated data files.
    """
    with open(path,'rb') as f:
        return np.array(pickle.load(f))


# Exactly one data_x line should be active; the matching hyperparameters
# for each input type are listed at the end of this file.
# data_x=load_pickle_array('Data/persistence_images.txt')
# Keep only index 1 along the second axis (the H1 images) and restore an
# explicit channel axis; -1 lets the number of samples follow the data.
data_x=load_pickle_array('Data/multipers_persistence_images.txt')[:,1,:,:].reshape(-1,1,100,100)
# data_x=load_pickle_array('Data/multipers_persistence_landscapes.txt')
# data_x=load_pickle_array('Data/multipers_signed_measure_convs.txt')
# data_x=load_pickle_array('Data/gril_landscapes.txt')
data_y=load_pickle_array('Data/pointcloud_labels.txt')


l=len(data_x) #Size of dataset
s=0.8  # Train/Test split
Batchsize=100 # Batch size

print('Data: ',data_x.shape,'\n')

torch.backends.cudnn.benchmark=True
# Prefer the GPU, but stay runnable on machines without CUDA.
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")


num_classes=5 #Number of classes
num_channels=1 #Number of input channels
dim=100 # Size of images
conv_channels=20 # Number of convolution channels
ks=5 # Kernel size
lin1=2000 #Number of neurons in dense layer 1
lin2=1000 #Number of neurons in dense layer 2
reg=0.5 #Parameter for dropout regularization
bn=True #If True uses batch normalization



###############################################################################

"""Train and evaluate model in a tensorflow style using the test 
   set as validation set and plot the progression of train and test accuracy"""
   
   
# perm=np.random.permutation(len(data_y))

# data_x=data_x[perm]
# data_y=data_y[perm]

# train_dataset=torch.utils.data.TensorDataset(torch.from_numpy(data_x[:int(l*s)]).float(),torch.from_numpy(data_y[:int(l*s)]))
# test_dataset=torch.utils.data.TensorDataset(torch.from_numpy(data_x[int(l*s):]).float(),torch.from_numpy(data_y[int(l*s):]))

# train_loader=torch.utils.data.DataLoader(train_dataset, batch_size=Batchsize, shuffle=True)
# test_loader=torch.utils.data.DataLoader(test_dataset, batch_size=Batchsize, shuffle=False)

# model=CNN(dim, num_channels, num_classes, ks, conv_channels, lin1, lin2, reg, bn)
# print(model,'\n')

# model.to(device)

# optimizer=torch.optim.Adam(model.parameters(), lr=0.001)
# loss_function=torch.nn.CrossEntropyLoss()

# fit_verbose(model, train_loader, test_loader, optimizer, loss_function, 100)


###############################################################################

"""Train and evaluate model N times on training and test set respectively
   and compute average test set accuracy"""

# Repeat the experiment N times, each time with a fresh random train/test
# split and a freshly initialized model, and collect the test accuracies.
accs=[]
N=20

for i in range(N):
    print('\n')
    print('Run '+str(i+1))

    # Shuffle samples and labels with the same random permutation.
    perm=np.random.permutation(len(data_y))
    data_perm_x=data_x[perm]
    data_perm_y=data_y[perm]

    # The first fraction s of the shuffled data is used for training,
    # the remainder for testing.
    n_train=int(l*s)
    train_x=torch.from_numpy(data_perm_x[:n_train]).float()
    train_y=torch.from_numpy(data_perm_y[:n_train])
    test_x=torch.from_numpy(data_perm_x[n_train:]).float()
    test_y=torch.from_numpy(data_perm_y[n_train:])

    train_dataset=torch.utils.data.TensorDataset(train_x, train_y)
    test_dataset=torch.utils.data.TensorDataset(test_x, test_y)

    train_loader=torch.utils.data.DataLoader(
        train_dataset,
        batch_size=Batchsize,
        shuffle=True,
    )
    test_loader=torch.utils.data.DataLoader(
        test_dataset,
        batch_size=Batchsize,
        shuffle=False,
    )

    # New model, optimizer and loss for every run.
    model=CNN(dim, num_channels, num_classes, ks, conv_channels, lin1, lin2, reg, bn)
    model.to(device)

    optimizer=torch.optim.Adam(model.parameters(), lr=0.001)
    loss_function=torch.nn.CrossEntropyLoss()

    fit(model, train_loader, optimizer, loss_function, 100)
    accs.append(evaluate(model, test_loader, loss_function))

# Summary statistics over all runs.
accs=np.array(accs)
mean_acc=round(np.mean(accs),3)
std_acc=round(np.std(accs),3)
print('\n')
print('Mean test accuracy over '+str(N)+' runs: ',mean_acc,'\n')
print('Standard deviation of test accuracy over '+str(N)+' runs: ',std_acc,'\n')



###############################################################################

"""Parameters: persistence images"""

# data_x=np.array(pickle.load(open('Data/persistence_images.txt','rb')))
# data_y=np.array(pickle.load(open('Data/pointcloud_labels.txt','rb')))

# num_classes=5
# num_channels=1
# dim=100
# conv_channels=20
# ks=5
# lin1=2000
# lin2=1000
# reg=0.5
# bn=True


"""Parameters: multiparameter persistence images"""

# data_x=np.array(pickle.load(open('Data/multipers_persistence_images.txt','rb')))[:,1,:,:].reshape(5000,1,100,100) #to get only H1 images
# data_y=np.array(pickle.load(open('Data/pointcloud_labels.txt','rb')))


# num_classes=5
# num_channels=1
# dim=100
# conv_channels=20
# ks=5
# lin1=2000
# lin2=1000
# reg=0.5
# bn=True


"""Parameters: multiparameter persistence landscapes"""

# data_x=np.array(pickle.load(open('Data/multipers_persistence_landscapes.txt','rb')))
# data_y=np.array(pickle.load(open('Data/pointcloud_labels.txt','rb')))

# num_classes=5
# num_channels=5
# dim=100
# conv_channels=20
# ks=5
# lin1=2000
# lin2=1000
# reg=0.5
# bn=True
 
 
"""Parameters: multiparameter signed measure convolutions"""

# data_x=np.array(pickle.load(open('Data/multipers_signed_measure_convs.txt','rb')))
# data_y=np.array(pickle.load(open('Data/pointcloud_labels.txt','rb')))

# num_classes=5
# num_channels=1
# dim=100
# conv_channels=20
# ks=5
# lin1=2000
# lin2=1000
# reg=0.5
# bn=True


"""Parameters: GRIL landscapes"""

# data_x=np.array(pickle.load(open('Data/gril_landscapes.txt','rb')))
# data_y=np.array(pickle.load(open('Data/pointcloud_labels.txt','rb')))


# num_classes=5
# num_channels=5
# dim=17
# conv_channels=20
# ks=4
# lin1=2000
# lin2=1000
# reg=0.3
# bn=False
