import torch
from torch import nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
# import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt
import os
import sys
sys.path.append('../')
from AGNES import AGNES

# Neural network class

class ffn(nn.Module):
    """Fully connected feed-forward network with LeakyReLU activations.

    The network is a stack of ``depth`` linear layers:
    ``dim_in -> width``, then ``depth - 2`` layers ``width -> width``, then
    ``width -> dim_out``. For ``depth == 1`` it is a single linear map
    ``dim_in -> dim_out``. The activation is applied after every layer
    except the last one.

    Parameters
    ----------
    dim_in : int
        Number of input features.
    depth : int
        Number of linear layers; must be at least 1.
    width : int
        Number of features of every hidden layer (ignored when depth == 1).
    dim_out : int
        Number of output features.
    lrelu_slope : float
        Negative slope of the LeakyReLU activation (0.0 gives a plain ReLU).
    gain : float
        Currently unused; kept so that existing callers keep working.

    Raises
    ------
    ValueError
        If ``depth`` is smaller than 1.
    """

    def __init__(self, dim_in, depth, width, dim_out, lrelu_slope = 0.0, gain = np.sqrt(2.0)):
        super(ffn, self).__init__()
        if depth < 1:
            raise ValueError(f"depth must be at least 1, got {depth}")
        self.depth = depth

        if depth == 1:
            # A one-layer network has no hidden layer, so it must map the
            # input straight to dim_out (otherwise forward() would return
            # `width` features).
            shapes = [(dim_in, dim_out)]
        else:
            shapes = [(dim_in, width)] + [(width, width)] * (depth - 2) + [(width, dim_out)]

        # Layers are created (and therefore randomly initialised) in order
        # from input to output.
        self.layers = nn.ModuleList([nn.Linear(n_in, n_out) for n_in, n_out in shapes])
        self.activation = torch.nn.LeakyReLU(negative_slope = lrelu_slope)

        # Re-initialise the weights with Kaiming-uniform (default arguments);
        # the biases keep the nn.Linear default initialisation.
        for layer in self.layers:
            torch.nn.init.kaiming_uniform_(layer.weight)

    def forward(self, x):
        """Apply the network to ``x`` and return the output of the last layer."""
        last = self.depth - 1
        for hidden in self.layers[:last]:
            x = self.activation(hidden(x))
        # No activation after the output layer.
        return self.layers[last](x)

def generate_data(depth, width, dim_out, n=100000, lrelu_slope = .5):
    """Create a synthetic regression dataset labelled by a random teacher network.

    Draws ``n`` inputs with 12 features, labels them with a freshly
    initialised ``ffn`` of the given ``depth``, ``width`` and ``dim_out``
    (LeakyReLU slope ``lrelu_slope``), splits the samples 90% / 10% at random
    into a training and a test set, and saves the pair ``(train, test)`` with
    ``torch.save`` under ``teacher_d{depth}w{width}o{dim_out}/``.
    Nothing is returned.
    """
    dim = 12  # input dimension
    total = dim * n

    # Start with every feature drawn from a normal distribution (mean 0,
    # std 3); only the last 4 columns keep these values.
    features = torch.normal(mean = 2.0*torch.zeros(total), std = 3*torch.ones(total))
    features = features.reshape(n, dim)
    # Columns 0-3: random integers in [0, 10), stored as floats.
    features[:, :4] = torch.randint(high = 10, size = [n,4], dtype = torch.float32)
    # Columns 4-7: uniform samples in [0, 1).
    features[:, 4:8] = torch.rand(size = [n,4], dtype = torch.float32)

    # Labels come from a small randomly initialised teacher network.
    teacher = ffn(dim, depth, width, dim_out, lrelu_slope)
    with torch.no_grad():
        labels = teacher(features)

    full_set = torch.utils.data.TensorDataset(features, labels)
    n_train = int(.9*n)
    n_test = n - n_train
    splits = torch.utils.data.random_split(full_set, [n_train, n_test])

    os.makedirs(f'teacher_d{depth}w{width}o{dim_out}', exist_ok = True)
    torch.save((splits[0], splits[1]), f'teacher_d{depth}w{width}o{dim_out}/data_d{depth}w{width}o{dim_out}')

class trainer:
    """Training and evaluation harness for a regression network (MSE loss).

    Records the loss and R² score of every training batch and of every
    evaluation on the test data, and can save / restore the model, the
    optimizer and these histories.
    """

    def __init__(self, model, opt_name, train_loader, test_loader):
        """Store the model and data loaders and build the optimizer.

        Parameters
        ----------
        model : torch.nn.Module
            The network to train. Batches are cast to the dtype of
            ``model.layers[0].weight``, so the model must expose ``layers``.
        opt_name : str
            Python expression that constructs the optimizer. It is executed
            inside this method, so it may refer to ``self.net``, e.g.
            ``'torch.optim.SGD(self.net.parameters(), lr=1e-3)'`` or
            ``'AGNES(self.net.parameters(), weight_decay=1e-5)'``.
            The string is also printed by ``train`` and stored in checkpoints.
        train_loader, test_loader : iterable
            Iterables yielding ``(input, targets)`` batches.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.use_cuda = torch.cuda.is_available()
        self.net = model.cuda() if self.use_cuda else model
        self.opt_name = opt_name
        self.train_accuracies = []
        self.train_losses = []
        self.test_accuracies = []
        self.test_losses = []
        self.start_epoch = 0
        self.train_loader = train_loader
        self.test_loader = test_loader
        # NOTE(security): exec runs arbitrary code. Only pass trusted,
        # hard-coded strings as opt_name, never user-supplied input.
        exec('self.optimizer =' + opt_name)

    def _match_model(self, tensor):
        """Move ``tensor`` to the training device with the dtype of the first layer's weight."""
        return tensor.to(self.device, dtype=self.net.layers[0].weight.dtype)

    def r2_score(self, y_true, y_pred):
        """
        Calculate the R-squared score.

        Parameters:
        - y_true: torch.Tensor, actual values.
        - y_pred: torch.Tensor, predicted values.

        Returns:
        - r2: torch.Tensor, the R-squared score
          ``1 - sum((y_true - y_pred)^2) / sum((y_true - mean(y_true))^2)``,
          where the mean is taken over all elements of ``y_true``.
        """
        ss_res = torch.sum((y_true - y_pred) ** 2)
        ss_tot = torch.sum((y_true - torch.mean(y_true)) ** 2)
        return 1 - ss_res / ss_tot

    def train(self, save_dir, num_epochs=100, schedule_lr_epochs=0, lr_factor=1, test_each_epoch=True, verbose=False, seed=False):
        """Trains the network with the MSE loss.

        Training resumes after ``self.start_epoch`` (set by
        ``load_parameters``) and runs up to and including ``num_epochs``.
        The loss and R² of every batch are appended to ``train_losses`` and
        ``train_accuracies``. If a batch loss is NaN, training stops
        immediately and no further checkpoint is written.

        Parameters
        ----------
        save_dir : str
            Passed as ``directory`` to ``save_parameters`` every 10 epochs
            and after the last epoch.
        num_epochs : int
            The index of the last epoch to run.
        schedule_lr_epochs : int
            Number of epochs after which the learning rate (and the
            'correction' step size, if the optimizer's parameter groups have
            one, as for AGNES) is multiplied by ``lr_factor``.
            If 0, a constant learning rate is used.
        lr_factor : float
            Multiplicative factor used by the learning-rate schedule.
        test_each_epoch : boolean
            True: Test the network after every training epoch, False: no testing
            (the initial test before the first epoch is always run).
        verbose : boolean
            Currently unused.
        seed : boolean
            True: call ``torch.manual_seed(0)`` before training.
        """
        if seed:
            torch.manual_seed(0)

        loss_function = torch.nn.MSELoss().cuda() if self.use_cuda else torch.nn.MSELoss()

        if self.start_epoch == 0:
            # Record test loss and R² before any training step.
            self.test(loss_function)
        print(self.opt_name)

        for epoch in range(self.start_epoch + 1, num_epochs + 1):
            print('Epoch {}/{}'.format(epoch, num_epochs))

            for inputs, targets in self.train_loader:
                inputs = self._match_model(inputs)
                targets = self._match_model(targets)

                self.optimizer.zero_grad()
                outputs = self.net(inputs)
                loss = loss_function(outputs, targets)
                loss.backward()
                self.optimizer.step()

                # Metrics refer to the parameters before this step; detach so
                # that computing R² does not extend the autograd graph.
                self.train_losses.append(loss.item())
                self.train_accuracies.append(self.r2_score(targets, outputs.detach()).item())
                if torch.isnan(loss):
                    # Training has diverged; give up without saving.
                    return

            if test_each_epoch:
                self.test(loss_function)

            if schedule_lr_epochs and epoch % schedule_lr_epochs == 0:
                # Update the step sizes every schedule_lr_epochs epochs.
                for group in self.optimizer.param_groups:
                    group['lr'] *= lr_factor
                    if 'correction' in group:
                        group['correction'] *= lr_factor  # correction step for AGNES

            # Save parameters after every 10 epochs and at the very end.
            if epoch % 10 == 0 or epoch == num_epochs:
                self.save_parameters(epoch, directory=save_dir)

    def test(self, loss_function):
        """Tests the network on ``self.test_loader``.

        Appends exactly one entry to ``test_losses`` and one to
        ``test_accuracies`` (the R² score) per call. If the loader yields
        several batches, their outputs and targets are concatenated and the
        metrics are computed on the whole set, which gives the same values as
        a loader that yields the data in a single batch. Nothing is appended
        if the loader yields no batch. The model is put back into training
        mode afterwards.

        Parameters
        ----------
        loss_function : callable
            Called as ``loss_function(outputs, targets)``; must return a
            scalar tensor.
        """
        self.net.eval()
        try:
            all_outputs = []
            all_targets = []
            with torch.no_grad():
                for inputs, targets in self.test_loader:
                    inputs = self._match_model(inputs)
                    targets = self._match_model(targets)
                    all_outputs.append(self.net(inputs))
                    all_targets.append(targets)

                if all_outputs:
                    outputs = torch.cat(all_outputs)
                    targets = torch.cat(all_targets)
                    self.test_losses.append(loss_function(outputs, targets).item())
                    self.test_accuracies.append(self.r2_score(targets, outputs).item())
        finally:
            # Never leave the model in eval mode, even if evaluation failed.
            self.net.train()

    def save_parameters(self, epoch, directory):
        """Saves a checkpoint of the network, optimizer and metric histories.

        The checkpoint is written to the file ``directory + '_' + str(epoch)
        + '.pth'``, i.e. ``directory`` acts as a path prefix. The path
        ``directory`` itself is also created as a directory if missing, which
        guarantees that the parent of the checkpoint file exists.

        Parameters
        ----------
        epoch : int
            The current epoch
        directory : str
            Path prefix of the checkpoint file
        """
        os.makedirs(directory, exist_ok=True)
        checkpoint = {
            'opt_name': self.opt_name,
            'epoch': epoch,
            'model_state_dict': self.net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_accuracies': self.train_accuracies,
            'train_losses': self.train_losses,
            'test_accuracies': self.test_accuracies,
            'test_losses': self.test_losses,
        }
        torch.save(checkpoint, directory + '_' + str(epoch) + '.pth')

    def load_parameters(self, path):
        """Loads a checkpoint written by ``save_parameters``.

        Restores the model and optimizer state, the metric histories and the
        epoch counter, so that a following ``train`` call resumes after the
        saved epoch. ``self.opt_name`` is not changed; the optimizer built in
        ``__init__`` must be compatible with the saved optimizer state.

        Parameters
        ----------
        path : str
            The file path pointing to the file containing the parameters
        """
        checkpoint = torch.load(path, map_location=self.device)

        self.net.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.train_accuracies = checkpoint['train_accuracies']
        self.train_losses = checkpoint['train_losses']
        self.test_accuracies = checkpoint['test_accuracies']
        self.test_losses = checkpoint['test_losses']
        self.start_epoch = checkpoint['epoch']



def _running_average(values, decay):
    """Return the exponential running average of ``values`` as an array.

    The result has ``len(values) + 1`` entries: it starts at ``values[0]``
    and is then updated once for every element of ``values``.
    """
    averages = [values[0]]
    for value in values:
        averages.append(decay*averages[-1] + (1-decay)*value)
    return np.array(averages)


def _plot_metric(data, names, metric, title, dir_name, x_step=1, log_y=False, ylim=None):
    """Plot mean ± std over runs of one metric for every name and save the figure."""
    plt.figure()
    draw = plt.semilogy if log_y else plt.plot
    for name in names:
        mean = np.mean(data[name][metric], axis = 0)
        std = np.std(data[name][metric], axis = 0)
        # Build the x-axis from the data itself so that the curve and the
        # shaded band always have matching lengths.
        x = x_step * np.arange(len(mean))

        draw(x, mean, label = name)
        plt.fill_between(x, mean+std, mean-std, alpha = 0.2)

    plt.title(title+metric)
    plt.legend()
    if ylim is not None:
        plt.ylim(ylim)
    plt.savefig(os.path.join(dir_name,title+metric))


def plot_results(batch, epochs, dir_name, train_size=9e4, decay=.99, title="Batch", runs=1):
    """Plot test/training R² and loss curves from saved checkpoints.

    Reads every file in ``dir_name`` whose name starts with a run index
    ``0 .. runs-1`` and ends with ``_{epochs}.pth``. The text between the
    first character and that suffix is used as the configuration name (so the
    run index is assumed to be a single digit), and all runs with the same
    name are averaged. Four figures are saved into ``dir_name`` as
    ``title + metric``: test R², training R², test loss and training loss,
    each showing the mean over runs with a ± one standard deviation band.

    Parameters
    ----------
    batch : int
        Batch size used for training; only used to scale the x-axis of the
        test curves.
    epochs : int
        Epoch number of the checkpoints to read.
    dir_name : str
        Directory containing the checkpoints; the figures are saved here too.
    train_size : float
        Number of training samples; ``train_size/batch`` is the spacing, in
        optimizer steps, between consecutive test evaluations on the x-axis.
    decay : float
        Decay factor of the exponential running average applied to the
        per-batch training loss and R².
    title : str
        Prefix of the figure titles and file names.
    runs : int
        Number of runs per configuration.
    """
    epoch_step = train_size/batch

    metrics = ['Test R²', 'Training R²', 'Test Loss', 'Training Loss']
    suffix = f"_{epochs}.pth"
    run_prefixes = tuple(str(i) for i in range(runs))

    data = {}
    names = []
    # Sorted so that the legend order (and colours) do not depend on the
    # arbitrary order returned by os.listdir.
    for filename in sorted(os.listdir(dir_name)):
        if not (filename.startswith(run_prefixes) and filename.endswith(suffix)):
            continue
        name = filename[1:-len(suffix)]
        if name not in data:
            names.append(name)
            data[name] = {metric: [] for metric in metrics}
        with open(os.path.join(dir_name,filename), 'rb') as file:
            temp = torch.load(file, map_location=torch.device('cpu'))
        data[name]['Test Loss'].append(np.array(temp['test_losses']))
        data[name]['Test R²'].append(np.array(temp['test_accuracies']))
        data[name]['Training Loss'].append(_running_average(temp['train_losses'], decay))
        data[name]['Training R²'].append(_running_average(temp['train_accuracies'], decay))

    # Test metrics have one point per evaluation, placed every epoch_step
    # optimizer steps; training metrics have one point per optimizer step.
    _plot_metric(data, names, 'Test R²', title, dir_name, x_step=epoch_step, ylim=[.9,1])
    _plot_metric(data, names, 'Training R²', title, dir_name, ylim=[.9,1])
    _plot_metric(data, names, 'Test Loss', title, dir_name, x_step=epoch_step, log_y=True)
    _plot_metric(data, names, 'Training Loss', title, dir_name, log_y=True)
