import os
import torch
import numpy as np
import logging
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import geoopt.optim
from functools import partial
from multiprocessing import Pool

import geoopt

import numpy as np
import os
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
from transformers import set_seed
import datasets
import transformers.utils.logging as trfl


import logging
import sys
# Logging setup: root logging writes timestamped records to stdout, and this
# module's own logger is pinned to INFO.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
log_level = logging.INFO
logger = logging.getLogger(__name__)
logger.setLevel(log_level)


# Define the model
class LinearMap(nn.Module):
    """Single affine layer (weight plus bias) mapping ``input_size`` to ``output_size`` features."""

    def __init__(self, input_size, output_size):
        super().__init__()
        # The attribute name is part of the saved state_dict keys
        # ("linear.weight", "linear.bias"), so it must stay "linear".
        self.linear = nn.Linear(
            in_features=input_size,
            out_features=output_size,
            bias=True,
        )

    def forward(self, x):
        """Apply the affine layer to ``x`` along its last dimension."""
        projected = self.linear(x)
        return projected


def _fit_linear_map(dataloader, num_samples, input_size, output_size, lr, num_epochs, device):
    """Train a fresh ``LinearMap`` with RiemannianSGD and return it with its last-epoch losses.

    The quantity that is back-propagated is the largest value of the
    element-wise MSE averaged over axis 1 within each batch (the "max loss"),
    not the batch mean; the mean MSE is only accumulated for reporting.

    Args:
        dataloader: Iterable yielding ``(batch_x, batch_y)`` tensor pairs.
        num_samples: Total number of samples, used to average the epoch loss.
        input_size: Feature size of the inputs (last dimension).
        output_size: Feature size of the targets (last dimension).
        lr: Learning rate passed to ``geoopt.optim.RiemannianSGD``.
        num_epochs: Number of passes over ``dataloader``.
        device: Device the model is moved to.

    Returns:
        Tuple ``(model, epoch_loss, epoch_max_loss)`` where the two losses are
        Python floats from the final epoch: the sample-weighted mean MSE and
        the largest per-batch max loss.
    """
    model = LinearMap(input_size, output_size)
    model.to(device)
    optimizer = geoopt.optim.RiemannianSGD(model.parameters(), lr=lr)
    # reduction='none' keeps element-wise errors so both the mean and the
    # worst case can be derived from the same forward pass.
    criterion = nn.MSELoss(reduction='none')

    epoch_loss = 0.0
    epoch_max_loss = 0.0
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0  # running sum of batch-mean losses, weighted by batch size
        epoch_max_loss = 0.0
        for batch_x, batch_y in dataloader:
            optimizer.zero_grad()
            loss_vec = criterion(model(batch_x), batch_y)
            loss = loss_vec.mean()
            # assumes 2-D (samples, features) batches, making this a per-sample MSE — TODO confirm
            max_loss = loss_vec.mean(axis=1).max()
            max_loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * len(batch_x)
            # Track plain floats so no autograd-attached tensor is kept around
            # and the log line below prints a number rather than a tensor repr.
            epoch_max_loss = max(epoch_max_loss, max_loss.item())
        epoch_loss /= num_samples

        logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Max-Loss: {epoch_max_loss}")
    return model, epoch_loss, epoch_max_loss


def train_model(task_name, seed_x):
    """Fit linear maps from seed ``seed_x``'s representations to every other seed's.

    For each ``seed_y`` in ``range(25)`` other than ``seed_x`` and for each
    learning rate in a fixed grid, a ``LinearMap`` is trained to map the last
    entry of the stored ``seed_x`` representations onto the last entry of the
    stored ``seed_y`` representations. Per (seed_y, lr) the predictions on the
    full input and the model ``state_dict`` are saved under ``ROOT_PATH``; at
    the end the final-epoch ``[mean loss, max loss]`` pairs are saved as a
    float32 array of shape ``(num seed_y, num learning rates, 2)``.

    Note: the results file is named "...accuracies..." but contains losses.

    Args:
        task_name: Task identifier used in the input and output file names.
        seed_x: Seed whose representations are the inputs of the maps.
    """
    ROOT_PATH = "representations-large"
    num_epochs = 20
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    logger.info(f'TASK {task_name}')
    logger.info(f'SEED X: {seed_x}')
    X = torch.load(os.path.join(ROOT_PATH, f"seed-{seed_x}-task-{task_name}-train"), map_location=DEVICE)
    X = X[-1]  # only consider final layer

    losses_seed = []
    for seed_y in (s for s in range(25) if s != seed_x):
        logger.info(f'SEED Y: {seed_y}')
        Y = torch.load(os.path.join(ROOT_PATH, f"seed-{seed_y}-task-{task_name}-train"), map_location=DEVICE)
        Y = Y[-1]

        # X and Y are already tensors on DEVICE, so they are used directly
        # instead of going through the legacy torch.Tensor(...) constructor.
        # The data does not depend on the learning rate, so build it once.
        dataset = TensorDataset(X, Y)
        dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

        losses = []
        for lr in [0.0001, 0.001, 0.01, 0.1]:
            model, epoch_loss, epoch_max_loss = _fit_linear_map(
                dataloader,
                len(dataset),
                X.shape[-1],
                Y.shape[-1],
                lr,
                num_epochs,
                DEVICE,
            )
            # Final-epoch mean MSE and max loss for this (seed_y, lr).
            losses.append([epoch_loss, epoch_max_loss])

            model.eval()
            with torch.no_grad():
                train_preds = model(X)
            torch.save(train_preds, os.path.join(ROOT_PATH, f"v2-intrinsic-maxloss-predictions-{task_name}-small-seedx-{seed_x}-seedy-{seed_y}-lr{lr}"))
            torch.save(model.state_dict(), os.path.join(ROOT_PATH,  f"v2-intrinsic-maxloss-model-{task_name}-small-seedx-{seed_x}-seedy-{seed_y}-lr{lr}"))
        losses_seed.append(losses)

    results_path = os.path.join(
        ROOT_PATH,
        f"v2-intrinsic-maxloss-accuracies-{task_name}-small-{seed_x}.npy",
    )
    with open(results_path, "wb") as f:
        # float32 matches the dtype of the arrays written by earlier runs.
        np.save(f, np.array(losses_seed, dtype=np.float32))

if __name__ == '__main__':
    # Bring this module's logger and the datasets / transformers loggers to
    # the same verbosity.
    log_level = logging.INFO
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    trfl.set_verbosity(log_level)
    trfl.enable_default_handler()
    trfl.enable_explicit_format()

    set_seed(42)

    # One job per (task, seed) pair; a previous configuration used ["mnli"].
    tasks = ["sst2", "mrpc"]
    seeds = range(25)
    jobs = [(task, seed) for task in tasks for seed in seeds]

    with Pool(processes=16) as pool:
        pool.starmap(train_model, jobs)
