from omegaconf import DictConfig, OmegaConf
import hydra
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm as tqdm
import argparse
import logging
import json
import sys

import models

logging.basicConfig(level=logging.INFO)

log = logging.getLogger(__name__)

# Where predicted features are written, chosen per $USER. Unknown users fail
# fast at import time so nothing is written to an unintended location.
USER = os.getenv('USER')
if USER == "user1":
    SAVE_ROOT_PATH = Path('/storage/user1/BrainBitsWIP/data/predicted_features/')
    # NOTE(review): a second `elif USER == "user1":` branch pointing at
    # '/storage/user1/projects/brainbits/BrainBitsWIP/data/predicted_features/'
    # was removed because its condition duplicated this one and it could never
    # run. Re-add it under the correct username if that layout is still needed.
else:
    raise ValueError(f"Unknown user {USER}")
# Root of the brain-diffuser checkout that holds the fMRI inputs and extracted features.
BD_ROOT_PATH = Path('/storage/user1/brain-diffuser')


class fMRI2latent(Dataset):
    """Map-style dataset pairing fMRI rows with latent targets.

    Each item is a dict ``{"inputs": ..., "targets": ...}`` whose values are
    ``torch.FloatTensor`` built from ``fmri_data[idx]`` and ``latents_data[idx]``.
    The dataset length is the length of ``fmri_data``.
    """

    def __init__(self, fmri_data, latents_data):
        self.fmri_data = fmri_data
        self.latents_data = latents_data

    def __len__(self):
        return len(self.fmri_data)

    def __getitem__(self, idx):
        inputs = torch.FloatTensor(self.fmri_data[idx])
        targets = torch.FloatTensor(self.latents_data[idx])
        return {"inputs": inputs, "targets": targets}

class NoBottleneck(nn.Module):
    """Single linear layer mapping ``input_size`` features straight to ``output_size``.

    ``multi_gpu`` is accepted for signature parity with ``BottleneckLinear`` but
    is ignored here.
    """

    def __init__(self, input_size, output_size, multi_gpu=False):
        super().__init__()
        # ``fc1`` names the state_dict keys ("fc1.weight", "fc1.bias"); keep it stable.
        self.fc1 = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.fc1(x)

class BottleneckLinear(nn.Module):
    """Linear map squeezed through a batch-normalised bottleneck layer.

    Computes ``fc5(bn1(fc1(x)))`` with no non-linearity:
    ``input_size -> bottleneck_size`` (BatchNorm1d) ``-> output_size``.

    Args:
        input_size: number of input features.
        bottleneck_size: width of the bottleneck layer.
        output_size: number of output features.
        multi_gpu: when True, ``forward`` moves its input to ``cuda:0`` first.
    """

    def __init__(self, input_size, bottleneck_size, output_size, multi_gpu=False):
        super().__init__()
        # Attribute names (fc1 / bn1 / fc5) define the state_dict keys, and the
        # creation order fixes the parameter-init RNG sequence; keep both stable.
        self.fc1 = nn.Linear(input_size, bottleneck_size)
        self.bn1 = nn.BatchNorm1d(bottleneck_size)
        self.fc5 = nn.Linear(bottleneck_size, output_size)
        self.multi_gpu = multi_gpu

    def forward(self, x):
        if self.multi_gpu:
            # NOTE(review): only the input is moved; the layers stay wherever
            # .to() put them, so this presumably assumes the model is on cuda:0.
            x = x.cuda(0)
        bottleneck = self.bn1(self.fc1(x))
        return self.fc5(bottleneck)

def train_linear_mapping(model, train_loader, reg_cfg):
    """Fit ``model`` to map fMRI inputs to latent targets with an MSE objective.

    Args:
        model: torch module. Its forward returns either a tensor or, for the
            VQ-VAE, a tuple ``(outputs, vq_loss, perplexity)``; in that case
            ``vq_loss`` is added to the MSE.
        train_loader: sized iterable of dicts with ``"inputs"`` and
            ``"targets"`` tensors (``len()`` is used for the progress bar and
            the epoch-average loss).
        reg_cfg: config providing ``optim`` ("SGD" or "Adam"), ``lr``,
            ``n_epochs`` and ``device``.

    Returns:
        The same ``model`` instance, trained in place.

    Raises:
        ValueError: if ``reg_cfg.optim`` is neither "SGD" nor "Adam".
    """
    if reg_cfg.optim == "SGD":
        optimizer = optim.SGD(model.parameters(), lr=reg_cfg.lr, momentum=0.9, weight_decay=0.001)
    elif reg_cfg.optim == "Adam":
        # "Adam" selects AdamW (decoupled weight decay).
        optimizer = optim.AdamW(model.parameters(), lr=reg_cfg.lr, weight_decay=0.001)
    else:
        # Previously this dropped into pdb and then failed with a NameError on
        # `optimizer`; fail loudly instead so batch jobs don't hang.
        raise ValueError(f"Unknown optimizer {reg_cfg.optim!r}; expected 'SGD' or 'Adam'")
    criterion = nn.MSELoss()
    # Lower the LR when the epoch-average training loss stops improving for 20 epochs.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=20)

    # Ensure training-mode behaviour (e.g. BatchNorm batch statistics) even if
    # the caller previously put the model in eval mode.
    model.train()
    for epoch in range(reg_cfg.n_epochs):
        with tqdm(total=len(train_loader)) as bar:
            bar.set_description(f"Epoch {epoch}")
            train_loss = 0.0
            for batch in train_loader:
                inputs = batch["inputs"].to(reg_cfg.device)
                targets = batch["targets"].to(reg_cfg.device)
                optimizer.zero_grad()
                outputs = model(inputs)
                if isinstance(outputs, tuple):  # VQ-VAE: (outputs, vq_loss, perplexity)
                    outputs, vq_loss, _perplexity = outputs
                    loss = criterion(outputs, targets) + vq_loss
                else:
                    loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                loss_value = loss.item()
                bar.set_postfix(mse=loss_value)
                bar.update()
                train_loss += loss_value

            avg_loss = train_loss / len(train_loader)
            bar.set_postfix(mse=avg_loss)
        scheduler.step(avg_loss)
    return model

def eval_model(model, test_loader, device):
    """Run ``model`` over ``test_loader`` in eval mode and return all predictions.

    Args:
        model: torch module. If it returns a tuple (VQ-VAE), only the first
            element is kept as the prediction.
        test_loader: iterable of dicts with an ``"inputs"`` tensor.
        device: device the inputs are moved to before the forward pass.

    Returns:
        numpy array of the per-batch predictions concatenated along dim 0.
    """
    model.eval()
    with torch.no_grad():
        batch_preds = []
        for batch in tqdm(test_loader):
            out = model(batch["inputs"].to(device))
            batch_preds.append(out[0] if isinstance(out, tuple) else out)
        stacked = torch.cat(batch_preds)
    return stacked.cpu().detach().numpy()

def scale_latents(pred_test_latent, train_latents):
    """Match predicted latents to the training latents' per-column statistics.

    Every column (statistics taken along axis 0) of ``pred_test_latent`` is
    z-scored with its own mean/std, then rescaled with the mean/std of the
    matching column of ``train_latents``.

    Columns whose predicted std is exactly zero (constant predictions) would
    otherwise divide by zero and produce NaN. They are mapped to the training
    mean instead, i.e. their standardised value is taken as 0.

    Args:
        pred_test_latent: array of predictions, samples along axis 0.
        train_latents: array of training latents with matching trailing shape.

    Returns:
        Array with the same shape as ``pred_test_latent``.
    """
    pred_mean = np.mean(pred_test_latent, axis=0)
    pred_std = np.std(pred_test_latent, axis=0)
    zero_std = pred_std == 0
    safe_std = np.where(zero_std, 1.0, pred_std)
    std_norm_test_latent = np.where(zero_std, 0.0, (pred_test_latent - pred_mean) / safe_std)
    return std_norm_test_latent * np.std(train_latents, axis=0) + np.mean(train_latents, axis=0)

def vdvae_regression(sub, train_fmri, test_fmri, bottleneck_size, reg_cfg):
    """Regress fMRI onto brain-diffuser's VDVAE features for subject ``sub`` and save them.

    Loads ``train_latents``/``test_latents`` from ``nsd_vdvae_features_31l.npz``
    and fits the fMRI->latent mapping. Writes ``vdvae_preds.npy`` and the test
    errors under ``SAVE_ROOT_PATH/subj_<sub>/bbits_<bottleneck_size>/``.
    """
    log.info("Getting VDVAE targets")

    # Latent targets. The NpzFile is used as a context manager so its file
    # handle is closed once the arrays have been read into memory.
    nsd_path = 'data/extracted_features/subj{:02d}/nsd_vdvae_features_31l.npz'.format(sub)
    with np.load(BD_ROOT_PATH / nsd_path) as nsd_features:
        train_latents = nsd_features['train_latents']
        test_latents = nsd_features['test_latents']

    pred_latents, results = fmri2latent_regression(sub, train_fmri, test_fmri, train_latents, test_latents, bottleneck_size, reg_cfg)

    save_path_dir = SAVE_ROOT_PATH / f'subj_{sub}/bbits_{bottleneck_size}/'
    save_path_dir.mkdir(parents=True, exist_ok=True)
    np.save(save_path_dir / 'vdvae_preds.npy', pred_latents)

    # The CLIP regressions also write "error.json" into this same directory,
    # so it does not reliably keep the VDVAE errors. Write them to a
    # target-specific file too; "error.json" is kept for existing readers.
    for error_name in ("vdvae_preds_error.json", "error.json"):
        with open(save_path_dir / error_name, "w") as f:
            json.dump(results, f)

def fmri2latent_regression(sub, train_fmri, test_fmri, train_latents, test_latents, bottleneck_size, reg_cfg):
    """Train an fMRI->latent model, predict test latents and report test error.

    Args:
        sub: subject id (not used here; kept for call-site symmetry).
        train_fmri, test_fmri: fMRI arrays; the last axis is the input feature size.
        train_latents, test_latents: target latent arrays; the last axis is the output size.
        bottleneck_size: bottleneck width, or -1 for no bottleneck. For "vqvae",
            -1 is replaced by ``reg_cfg.vqvae.mlp_dim``.
        reg_cfg: regression config (``model``, ``batch_size``, ``device``,
            optimiser settings and, for "vqvae", the ``vqvae`` sub-config).

    Returns:
        ``(pred_latents, results)``. ``pred_latents`` are the test predictions
        rescaled by ``scale_latents``. ``results`` has ``"mse"`` (raw
        predictions) and ``"scaled_mse"`` (rescaled predictions), both
        measured against ``test_latents``.

    Raises:
        ValueError: if ``reg_cfg.model`` is neither "linear" nor "vqvae".
    """
    train_data = fMRI2latent(train_fmri, train_latents)
    train_loader = DataLoader(train_data, batch_size=reg_cfg.batch_size, shuffle=True)

    # Test targets are not needed for prediction, so feed placeholders.
    blank_data = [0 for _ in range(len(test_fmri))]
    test_data = fMRI2latent(test_fmri, blank_data)
    test_loader = DataLoader(test_data, batch_size=reg_cfg.batch_size)

    input_dim = train_fmri.shape[-1]
    output_dim = train_latents.shape[-1]
    if reg_cfg.model == "linear":
        if bottleneck_size == -1:
            model = NoBottleneck(input_dim, output_dim)
        else:
            model = BottleneckLinear(input_dim, bottleneck_size, output_dim)
    elif reg_cfg.model == "vqvae":
        if bottleneck_size == -1:
            log.warning(f'VQVAE requires a bottleneck size, cannot be -1. Setting to {reg_cfg.vqvae.mlp_dim}')
            bottleneck_size = reg_cfg.vqvae.mlp_dim
        model = models.VQVAE(inp_dim=input_dim, mlp_dim=reg_cfg.vqvae.mlp_dim, out_dim=output_dim,
                             codebook_size=reg_cfg.vqvae.codebook_size, tokens=reg_cfg.vqvae.tokens,
                             bottleneck_dim=bottleneck_size,
                             commitment_cost=reg_cfg.vqvae.commitment_cost, decay=reg_cfg.vqvae.decay)
    else:
        # Previously fell through and failed with UnboundLocalError on `model`.
        raise ValueError(f"Unknown regression model {reg_cfg.model!r}; expected 'linear' or 'vqvae'")

    model = model.to(reg_cfg.device)
    log.info("Training fMRI2latent mapping")
    model = train_linear_mapping(model, train_loader, reg_cfg)

    log.info("fMRI2latent test evaluation")
    pred_test_latent = eval_model(model, test_loader, reg_cfg.device)
    pred_latents = scale_latents(pred_test_latent, train_latents)
    results = {"mse": ((pred_test_latent - test_latents)**2).mean().item(),
               "scaled_mse": ((pred_latents - test_latents)**2).mean().item()}
    log.info("fMRI2latent test results: {}".format(results))
    # The unused full-train/test forward passes that used to follow were removed:
    # they pushed the whole training set through the model in a single batch.
    return pred_latents, results

def get_fmri_inputs(sub):
    """Load the train/test fMRI inputs for subject ``sub`` and z-score them.

    Both splits are divided by 300, then centred and scaled with the training
    split's mean and sample standard deviation (``ddof=1``) along axis 0. The
    test split is therefore normalised with training statistics only.

    Returns:
        tuple ``(train_fmri, test_fmri)`` of normalised numpy arrays.
    """
    log.info("Getting fMRI inputs")
    subj_dir = 'data/processed_data/subj{:02d}'.format(sub)

    def _load_split(split):
        # e.g. nsd_train_fmriavg_nsdgeneral_sub1.npy
        fname = 'nsd_{}_fmriavg_nsdgeneral_sub{}.npy'.format(split, sub)
        return np.load(BD_ROOT_PATH / subj_dir / fname) / 300

    train_fmri = _load_split('train')
    test_fmri = _load_split('test')

    centre = np.mean(train_fmri, axis=0)
    spread = np.std(train_fmri, axis=0, ddof=1)
    return (train_fmri - centre) / spread, (test_fmri - centre) / spread

def clip_text_regression(sub, train_fmri, test_fmri, bottleneck_size, reg_cfg):
    """Regress fMRI onto brain-diffuser's CLIP-text features for subject ``sub``.

    Loads ``nsd_cliptext_train.npy`` / ``nsd_cliptext_test.npy`` and delegates
    to ``clip_regression``, which saves the predictions as ``clip_text_preds``.
    """
    log.info("Getting CLIP text targets")

    feature_dir = BD_ROOT_PATH / 'data/extracted_features/subj{:02d}'.format(sub)
    train_clip = np.load(feature_dir / 'nsd_cliptext_train.npy')
    test_clip = np.load(feature_dir / 'nsd_cliptext_test.npy')

    clip_regression(sub, train_fmri, test_fmri, train_clip, test_clip,
                    "clip_text_preds", bottleneck_size, reg_cfg)

def clip_vision_regression(sub, train_fmri, test_fmri, bottleneck_size, reg_cfg):
    """Regress fMRI onto brain-diffuser's CLIP-vision features for subject ``sub``.

    Loads ``nsd_clipvision_train.npy`` / ``nsd_clipvision_test.npy`` and
    delegates to ``clip_regression``, which saves the predictions as
    ``clip_vision_preds``.
    """
    log.info("Getting CLIP vision targets")

    feature_dir = BD_ROOT_PATH / 'data/extracted_features/subj{:02d}'.format(sub)
    train_clip = np.load(feature_dir / 'nsd_clipvision_train.npy')
    test_clip = np.load(feature_dir / 'nsd_clipvision_test.npy')

    clip_regression(sub, train_fmri, test_fmri, train_clip, test_clip,
                    "clip_vision_preds", bottleneck_size, reg_cfg)

def clip_regression(sub, train_fmri, test_fmri, train_clip, test_clip, out_name, bottleneck_size, reg_cfg):
    """Fit one fMRI->latent model per CLIP embedding and save the stacked predictions.

    ``train_clip`` must be 3-D. A separate regression is trained for every
    index along axis 1 (``train_clip[:, i]`` / ``test_clip[:, i]``). The
    per-index predictions are re-stacked along axis 1 and saved to
    ``SAVE_ROOT_PATH/subj_<sub>/bbits_<bottleneck_size>/<out_name>.npy``.
    Per-index test errors are written as JSON keyed by the index.
    """
    # Unpacking also enforces the 3-D shape expected by the per-index loop.
    _, num_embed, _ = train_clip.shape
    per_embed_preds = []
    all_results = {}
    for i in range(num_embed):
        pred_latents, results = fmri2latent_regression(
            sub, train_fmri, test_fmri, train_clip[:, i], test_clip[:, i], bottleneck_size, reg_cfg)
        per_embed_preds.append(pred_latents)
        all_results[i] = results
    pred_clip = np.stack(per_embed_preds, axis=1)

    save_path_dir = SAVE_ROOT_PATH / f'subj_{sub}/bbits_{bottleneck_size}'
    save_path_dir.mkdir(exist_ok=True, parents=True)
    np.save(save_path_dir / f"{out_name}.npy", pred_clip)  # cf. brain-diffuser's nsd_cliptext_predtest_nsdgeneral.npy

    # Every target writes "error.json" into this same directory, so that file
    # only keeps the last target's errors. Also write a target-specific file so
    # no results are lost; "error.json" is kept for existing readers.
    for error_name in (f"{out_name}_error.json", "error.json"):
        with open(save_path_dir / error_name, "w") as f:
            json.dump(all_results, f)

@hydra.main(config_path="conf")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: regress fMRI onto VDVAE, CLIP-text and CLIP-vision latents.

    For subject ``cfg.exp["sub"]`` and each size in ``cfg.exp["bottlenecks"]``,
    runs the three regressions with ``cfg.exp.reg`` as the regression config.
    """
    log.info("Run fMRI-to-latent regression (VDVAE, CLIP text, CLIP vision) for each bottleneck size")
    log.info(OmegaConf.to_yaml(cfg, resolve=True))
    out_dir = os.getcwd()
    log.info(f'Working directory {out_dir}')
    if "out_dir" in cfg.exp:
        out_dir = cfg.exp.out_dir
    # NOTE(review): out_dir is only logged; the regressions write their outputs
    # under SAVE_ROOT_PATH regardless of this value.
    log.info(f'Output directory {out_dir}')

    sub = cfg.exp["sub"]

    train_fmri, test_fmri = get_fmri_inputs(sub)

    bottleneck_sizes = cfg.exp["bottlenecks"]
    reg_cfg = cfg.exp.reg
    for bottleneck_size in bottleneck_sizes:
        vdvae_regression(sub, train_fmri, test_fmri, bottleneck_size, reg_cfg)
        clip_text_regression(sub, train_fmri, test_fmri, bottleneck_size, reg_cfg)
        clip_vision_regression(sub, train_fmri, test_fmri, bottleneck_size, reg_cfg)

if __name__ == "__main__":
    # Local debugging: replace sys.argv with Hydra overrides before calling main(), e.g.
    #   train.py +exp=latent_reg ++exp.bottlenecks=[5] ++exp.reg.batch_size=128
    #            ++exp.reg.n_epochs=1 ++exp.reg.optim="SGD" ++exp.reg.device="cpu"
    main()
