import copy
from tqdm import tqdm as tqdm
import torch.optim as optim
import torch.nn as nn
from sklearn.model_selection import train_test_split
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader, Subset
from omegaconf import DictConfig, OmegaConf
import hydra
import logging
import os
import json
from pathlib import Path

# Configure the root logger at import time with INFO as the threshold, then
# create the module-level logger used by the rest of this script.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

class LinearMap(nn.Module):
    """A single affine layer mapping ``input_size`` features to ``output_size``."""

    def __init__(self, input_size, output_size):
        """Create the underlying ``nn.Linear`` layer.

        Args:
            input_size: Size of the last dimension of the inputs.
            output_size: Size of the last dimension of the outputs.
        """
        super().__init__()
        # The attribute name is part of the state-dict keys; keep it stable.
        self.linear_1 = nn.Linear(
            in_features=input_size,
            out_features=output_size,
        )

    def forward(self, inputs):
        """Apply the linear layer to ``inputs`` and return the result."""
        projected = self.linear_1(inputs)
        return projected

class Bottleneck2Feature(Dataset):
    """Dataset pairing intermediate representations with their target labels.

    Item ``idx`` is a dict holding ``intermediates[idx]`` under ``"inputs"``
    and ``labels[idx]`` under ``"targets"``, both as ``float32`` tensors.
    """

    def __init__(self, intermediates, labels):
        """Store the two aligned, indexable collections.

        Args:
            intermediates: Indexable collection of input features.
            labels: Indexable collection of targets, aligned with
                ``intermediates`` by index.

        Raises:
            ValueError: If the two collections differ in length.
        """
        if len(intermediates) != len(labels):
            raise ValueError(
                "intermediates and labels must have the same length, got "
                f"{len(intermediates)} and {len(labels)}"
            )
        self.intermediates = intermediates
        self.labels = labels

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.intermediates)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as a dict of ``float32`` tensors."""
        # torch.tensor(..., dtype=...) converts the *data* whatever its shape
        # or dtype. The legacy torch.FloatTensor(...) constructor interprets an
        # integer scalar as a size (returning uninitialised memory) and rejects
        # float scalars, which breaks 1-D label arrays.
        return {
            "inputs": torch.tensor(self.intermediates[idx], dtype=torch.float32),
            "targets": torch.tensor(self.labels[idx], dtype=torch.float32),
        }

def get_loss(preds, targets, reg_cfg):
    """Return the mean squared error between ``preds`` and ``targets``.

    ``reg_cfg`` is accepted as part of the signature but is currently unused.
    """
    return nn.functional.mse_loss(preds, targets)

def get_eval_loss(model, val_loader, reg_cfg):
    """Return the batch-size-weighted mean loss of ``model`` over ``val_loader``.

    The model is switched to eval mode (and left in it) and gradients are
    disabled while iterating. Each batch loss from ``get_loss`` is weighted by
    the number of samples in the batch, so a smaller final batch does not
    count as much as a full one.

    Args:
        model: Module called on each batch's ``"inputs"``.
        val_loader: Iterable of dict batches with ``"inputs"`` and
            ``"targets"`` tensors.
        reg_cfg: Config providing ``device``; also forwarded to ``get_loss``.

    Returns:
        The weighted mean loss as a Python float.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()
    weighted_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for batch in tqdm(val_loader):
            inputs = batch["inputs"].to(reg_cfg.device)
            targets = batch["targets"].to(reg_cfg.device)
            n_batch = inputs.shape[0]
            preds = model(inputs)
            loss = get_loss(preds, targets, reg_cfg)
            # get_loss returns a per-batch mean; scale it back up by the batch
            # size so the final division yields a mean over all samples.
            weighted_loss += loss.item() * n_batch
            n_samples += n_batch
    if n_samples == 0:
        raise ValueError("cannot compute an evaluation loss on an empty data loader")
    return weighted_loss / n_samples


def train_linear_mapping(model, train_loader, val_loader, reg_cfg):
    """Train ``model`` and return the copy with the lowest validation loss.

    Runs ``reg_cfg.n_epochs`` epochs with a cosine-annealed learning rate.
    After every epoch the model is scored with ``get_eval_loss`` on
    ``val_loader``; a deep copy is kept whenever the score improves.

    Args:
        model: Module to optimise; expected to already live on
            ``reg_cfg.device``.
        train_loader: Iterable of dict batches with ``"inputs"`` and
            ``"targets"`` tensors, used for optimisation.
        val_loader: Loader passed to ``get_eval_loss`` for model selection.
        reg_cfg: Config providing ``optim`` (``"SGD"`` or ``"Adam"``), ``lr``,
            ``weight_decay``, ``n_epochs`` and ``device``.

    Returns:
        A deep copy of the model from the best validation epoch, or ``model``
        itself if no epoch produced an improvement (for example
        ``n_epochs == 0`` or every validation loss was NaN).

    Raises:
        ValueError: If ``reg_cfg.optim`` names an unsupported optimizer.
    """
    if reg_cfg.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=reg_cfg.lr,
            momentum=0.9,
            weight_decay=reg_cfg.weight_decay,
        )
    elif reg_cfg.optim in ("Adam", "AdamW"):
        # "Adam" has always selected AdamW (decoupled weight decay) here; the
        # explicit "AdamW" spelling is accepted as an alias.
        optimizer = optim.AdamW(
            model.parameters(),
            lr=reg_cfg.lr,
            weight_decay=reg_cfg.weight_decay,
        )
    else:
        # Fail loudly instead of hitting an unbound `optimizer` further down.
        raise ValueError(
            f"Unsupported optimizer {reg_cfg.optim!r}; expected 'SGD' or 'Adam'"
        )

    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=reg_cfg.n_epochs)
    min_eval_loss = float("inf")
    best_model = None
    for epoch in range(reg_cfg.n_epochs):
        # get_eval_loss puts the model in eval mode; switch back every epoch.
        model.train()
        with tqdm(total=len(train_loader)) as bar:
            bar.set_description(f"Epoch {epoch}")
            train_loss = 0.0
            for batch in train_loader:
                inputs = batch["inputs"].to(reg_cfg.device)
                targets = batch["targets"].to(reg_cfg.device)
                optimizer.zero_grad()
                preds = model(inputs)
                loss = get_loss(preds, targets, reg_cfg)
                loss.backward()
                optimizer.step()
                batch_loss = loss.item()
                bar.set_postfix({"loss": batch_loss})
                bar.update()
                train_loss += batch_loss
            avg_loss = train_loss / len(train_loader)

            eval_loss = get_eval_loss(model, val_loader, reg_cfg)
            bar.set_postfix({"eval": eval_loss, "train_loss": avg_loss})
        if eval_loss < min_eval_loss:
            min_eval_loss = eval_loss
            best_model = copy.deepcopy(model)
        scheduler.step()

    if best_model is None:
        # No epoch improved on +inf; hand back the model rather than None.
        return model
    return best_model

def _load_array(path):
    """Load one array from ``path`` with ``np.load``, closing the file afterwards."""
    with open(path, "rb") as f:
        return np.load(f)


@hydra.main(config_path="conf")
def main(cfg: DictConfig) -> None:
    """Fit a linear map from bottleneck features to labels and save the test loss.

    Reads the four array files named in ``cfg.exp`` (train/test representations
    and labels), holds out part of the training set for validation, trains a
    ``LinearMap`` with the settings in ``cfg.exp.reg_cfg``, evaluates the
    selected model on the test set, and writes ``{"loss": <test loss>}`` to
    ``results.json`` under ``cfg.exp.output_path``.

    Args:
        cfg: Hydra configuration; only the ``exp`` node is read.
    """
    log.info("Run decoding on bottleneck features")
    log.info(OmegaConf.to_yaml(cfg, resolve=True))
    out_dir = os.getcwd()
    log.info(f'Working directory {os.getcwd()}')
    if "out_dir" in cfg.exp:
        out_dir = cfg.exp.out_dir
    log.info(f'Output directory {out_dir}')

    train_intermediates = _load_array(cfg.exp.train_rep_path)
    train_labels = _load_array(cfg.exp.train_labels_path)
    test_intermediates = _load_array(cfg.exp.test_rep_path)
    test_labels = _load_array(cfg.exp.test_labels_path)

    all_train_dataset = Bottleneck2Feature(train_intermediates, train_labels)
    test_dataset = Bottleneck2Feature(test_intermediates, test_labels)

    # Fraction of the training data held out for validation. It can be set
    # with an optional `exp.val_split` key; 0.15 is the historical default.
    val_split = cfg.exp.val_split if "val_split" in cfg.exp else 0.15
    train_idx, val_idx = train_test_split(
        list(range(len(all_train_dataset))), test_size=val_split
    )

    train_data = Subset(all_train_dataset, train_idx)
    val_data = Subset(all_train_dataset, val_idx)

    reg_cfg = cfg.exp.reg_cfg
    train_loader = DataLoader(train_data, batch_size=reg_cfg.batch_size, shuffle=True)
    # Evaluation does not depend on sample order, so skip shuffling.
    val_loader = DataLoader(val_data, batch_size=reg_cfg.batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=reg_cfg.batch_size, shuffle=False)

    # Size the model from the data it is trained on.
    model = LinearMap(train_intermediates.shape[-1], train_labels.shape[-1])
    model = model.to(reg_cfg.device)

    model = train_linear_mapping(model, train_loader, val_loader, reg_cfg)
    # Report the held-out *test* loss. The validation split was already used
    # for model selection, so scoring on it would be optimistically biased.
    test_loss = get_eval_loss(model, test_loader, reg_cfg)
    print(test_loss)

    Path(cfg.exp.output_path).mkdir(exist_ok=True, parents=True)
    output_path = os.path.join(cfg.exp.output_path, "results.json")
    results = {"loss": test_loss}
    with open(output_path, "w") as f:
        json.dump(results, f)

# Run the Hydra entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()


