import sys

import argparse
import os
import shutil
import pickle, random
import numpy as np
import torch
import torch.utils.tensorboard
from torch.utils.data import Subset
from sklearn.metrics import roc_auc_score
from torch.nn.utils import clip_grad_norm_
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import Compose
from tqdm.auto import tqdm

import utils.misc as misc
import utils.train as utils_train
import utils.transforms as trans
from datasets import get_dataset
from datasets.merge_dataset import MergedProteinLigandData, FOLLOW_BATCH2
from datasets.pl_data import FOLLOW_BATCH
from models.molopt_score_model import ScorePosNet3D
from models.ipo import DPO3D
from graphbap.bapnet import BAPNet



def get_auroc(y_true, y_pred, feat_mode):
    """Print the one-vs-rest AUROC of every class present and return their weighted mean.

    Args:
        y_true: Array-like of integer class labels; each label is also used as
            a column index into ``y_pred``.
        y_pred: 2-D array-like of per-class scores; column ``c`` is scored
            against the binary target ``y_true == c``.
        feat_mode: ``'basic'``, ``'add_aromatic'`` or ``'full'``; selects the
            index -> atom-type table used only for the printed label.

    Returns:
        The per-class AUROCs averaged with weights equal to the number of
        samples of each class (sum of ``auroc * count`` over ``len(y_true)``).

    Raises:
        KeyError: If ``feat_mode`` is not one of the three supported modes.
    """
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    # Loop-invariant: resolve the label table once, which also rejects an
    # unknown ``feat_mode`` before any AUROC is computed.
    index_to_atom = {
        'basic': trans.MAP_INDEX_TO_ATOM_TYPE_ONLY,
        'add_aromatic': trans.MAP_INDEX_TO_ATOM_TYPE_AROMATIC,
        'full': trans.MAP_INDEX_TO_ATOM_TYPE_FULL
    }[feat_mode]

    avg_auroc = 0.
    # np.unique yields the present classes in sorted order, so the printed
    # report is deterministic (iterating a set follows hash order).
    for c in np.unique(y_true):
        is_class = y_true == c
        auroc = roc_auc_score(is_class, y_pred[:, c])
        avg_auroc += auroc * np.sum(is_class)
        print(f'atom: {index_to_atom[c]} \t auc roc: {auroc:.4f}')
    return avg_auroc / len(y_true)


if __name__ == '__main__':
    # Command-line interface: a positional config path plus run options.
    parser = argparse.ArgumentParser()
    parser.add_argument('config', type=str)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--logdir', type=str, default='./logs_diffusion')
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--train_report_iter', type=int, default=200)
    args = parser.parse_args()

    # Load configs
    config = misc.load_config(args.config)
    print(config)
    # Use splitext rather than slicing at rfind('.'): for a config path without
    # an extension rfind returns -1 and the slice would drop the last character.
    config_name = os.path.splitext(os.path.basename(args.config))[0]
    misc.seed_all(config.train.seed)

    # Logging: every run gets a fresh directory holding checkpoints,
    # visualisations, the text log and the tensorboard events.
    log_dir = misc.get_new_log_dir(args.logdir, prefix=config_name, tag=args.tag)
    ckpt_dir = os.path.join(log_dir, 'checkpoints')
    os.makedirs(ckpt_dir, exist_ok=True)
    vis_dir = os.path.join(log_dir, 'vis')
    os.makedirs(vis_dir, exist_ok=True)
    logger = misc.get_logger('train', log_dir)
    writer = torch.utils.tensorboard.SummaryWriter(log_dir)
    logger.info(args)
    logger.info(config)
    # Snapshot the config file and the model sources next to the logs.
    # './models' is resolved against the current working directory.
    shutil.copyfile(args.config, os.path.join(log_dir, os.path.basename(args.config)))
    shutil.copytree('./models', os.path.join(log_dir, 'models'))

    # Transforms
    protein_featurizer = trans.FeaturizeProteinAtom()
    ligand_featurizer = trans.FeaturizeLigandAtom(config.data.transform.ligand_atom_mode)
    transform_list = [
        protein_featurizer,
        ligand_featurizer,
        trans.FeaturizeLigandBond(),
    ]
    if config.data.transform.random_rot:
        transform_list.append(trans.RandomRotation())
    transform = Compose(transform_list)

    # Datasets and loaders
    logger.info('Loading dataset...')
    dataset, subsets = get_dataset(
        config=config.data,
        transform=transform
    )
    # The 'test' subset is the one used for validation in this script.
    train_set, val_set = subsets['train'], subsets['test']
    

    def get_dpo_data(subset, split='train', fullset=dataset):
        """Build the preference-pair dataset for one split.

        Reads ``dpo_idx.pkl`` from the working directory, a mapping from a
        dataset index to a list of candidate partner indices. Every index of
        ``subset`` that has at least one candidate also contained in ``subset``
        is paired with the first such candidate.

        Args:
            subset: Object exposing ``.indices`` (indices into ``fullset``).
            split: Split name; only used in the log and progress-bar messages.
            fullset: Dataset the indices refer to.

        Returns:
            A ``MergedProteinLigandData`` built from two ``Subset`` views of
            ``fullset``: the "winning" indices (the mapping keys) and, in the
            same order, their paired "losing" indices.
        """
        logger.info('Getting DPO {} data...'.format(split))

        # NOTE(security): pickle.load can execute arbitrary code; only use a
        # dpo_idx.pkl that comes from a trusted source.
        with open('dpo_idx.pkl', 'rb') as file:
            dpo_idx = pickle.load(file)

        subset_ids = set(subset.indices)

        # Only indices of this subset are ever looked up, so resolve each one
        # directly instead of first filtering the whole mapping.
        new_pair = {}
        progress_desc = 'Creating Preference {} Set'.format(split.capitalize())
        for idx in tqdm(subset.indices, progress_desc):
            candidates = dpo_idx.get(idx)
            if not candidates:
                continue
            # Partners outside this split are skipped so pairs never straddle
            # the train/validation boundary.
            losing_id = next((item for item in candidates if item in subset_ids), None)
            if losing_id is not None:
                new_pair[idx] = losing_id

        winning_idx = list(new_pair.keys())
        losing_idx = list(new_pair.values())
        subset_1 = Subset(fullset, winning_idx)
        subset_2 = Subset(fullset, losing_idx)

        return MergedProteinLigandData(subset_1, subset_2)


    # Turn both splits into preference-pair datasets.
    dpo_train_set = get_dpo_data(train_set, split='train', fullset=dataset)
    dpo_val_set = get_dpo_data(val_set, split='val', fullset=dataset)
    del dataset

    # Report the sizes of the sets that are actually iterated below.
    logger.info(f'Training: {len(dpo_train_set)} Validation: {len(dpo_val_set)}')

    # follow_batch = ['protein_element', 'ligand_element']
    collate_exclude_keys2 = ['ligand_nbh_list', 'ligand_nbh_list2']
    collate_exclude_keys = ['ligand_nbh_list']
    # Training draws batches forever; the number of steps is bounded by
    # config.train.max_iters in the main loop.
    train_iterator = utils_train.inf_iterator(DataLoader(
        dpo_train_set,
        batch_size=config.train.batch_size,
        shuffle=True,
        num_workers=config.train.num_workers,
        follow_batch=FOLLOW_BATCH2,
        exclude_keys=collate_exclude_keys2
    ))
    # Validation needs a finite loader under its own name: validate() iterates
    # `val_loader` once per call. Binding it to `train_iterator` would overwrite
    # the training iterator and leave `val_loader` undefined.
    val_loader = DataLoader(
        dpo_val_set,
        batch_size=config.train.batch_size,
        shuffle=False,
        num_workers=config.train.num_workers,
        follow_batch=FOLLOW_BATCH2,
        exclude_keys=collate_exclude_keys2
    )
    # Model
    logger.info('Building model...')

    def _freeze(module):
        """Switch ``module`` to eval mode and turn off gradients for its parameters."""
        module.eval()
        for frozen_param in module.parameters():
            frozen_param.requires_grad = False
        return module

    def _load_policy(state):
        """Create a DPO3D network on the target device and load ``state`` into it."""
        policy = DPO3D(
            config.model,
            protein_atom_feature_dim=protein_featurizer.feature_dim,
            ligand_atom_feature_dim=ligand_featurizer.feature_dim,
        ).to(args.device)
        policy.load_state_dict(state)
        return policy

    # Conditioning network: restored from its own checkpoint and kept frozen.
    net_cond = _freeze(
        BAPNet(ckpt_path=config.net_cond.ckpt_path, hidden_nf=config.net_cond.hidden_dim).to(args.device)
    )

    ckpt = torch.load(config.model.checkpoint, map_location=args.device)

    # Trainable model and frozen reference start from the same checkpoint weights.
    model = _load_policy(ckpt['model'])
    ref_model = _freeze(_load_policy(ckpt['model']))

    print(f'protein feature dim: {protein_featurizer.feature_dim} ligand feature dim: {ligand_featurizer.feature_dim}')
    n_params_m = misc.count_parameters(model) / 1e6
    logger.info(f'# trainable parameters: {n_params_m:.4f} M')

    # Optimizer and scheduler
    optimizer = utils_train.get_optimizer(config.train.optimizer, model)
    scheduler = utils_train.get_scheduler(config.train.scheduler, optimizer)


    def train(it):
        """Perform one optimizer step with gradient accumulation; log every few iterations.

        Gradients are accumulated over ``config.train.n_acc_batch`` batches
        drawn from ``train_iterator``, clipped to ``config.train.max_grad_norm``
        and applied with a single ``optimizer.step()``.

        Args:
            it: Current iteration number; used for the reporting schedule
                (``args.train_report_iter``) and as the tensorboard step.
        """
        model.train()
        optimizer.zero_grad()

        n_acc = config.train.n_acc_batch
        noise_std = config.train.pos_noise_std
        for _micro_step in range(n_acc):
            batch = next(train_iterator).to(args.device)

            # Perturb the protein coordinates of both pair members with
            # Gaussian noise (first member's noise is sampled first).
            noisy_protein_pos = batch.protein_pos + torch.randn_like(batch.protein_pos) * noise_std
            noisy_protein_pos2 = batch.protein_pos2 + torch.randn_like(batch.protein_pos2) * noise_std

            results = model.get_diffusion_loss(
                dpo_beta=config.train.dpo_beta,
                ref_model=ref_model,
                net_cond=net_cond,
                # first member of the pair
                protein_pos=noisy_protein_pos,
                protein_v=batch.protein_atom_feature.float(),
                batch_protein=batch.protein_element_batch,
                ligand_pos=batch.ligand_pos,
                ligand_v=batch.ligand_atom_feature_full,
                batch_ligand=batch.ligand_element_batch,
                # second member of the pair
                protein_pos2=noisy_protein_pos2,
                protein_v2=batch.protein_atom_feature2.float(),
                batch_protein2=batch.protein_element2_batch,
                ligand_pos2=batch.ligand_pos2,
                ligand_v2=batch.ligand_atom_feature_full2,
                batch_ligand2=batch.ligand_element2_batch,
                # rewards are the negated affinities
                reward=-batch.affinity,
                reward2=-batch.affinity2,
            )
            loss_pos = results['loss_pos']
            loss_v = results['loss_v']
            # Scale so that the accumulated gradient is an average over batches.
            loss = results['loss'] / n_acc
            loss.backward()

        orig_grad_norm = clip_grad_norm_(model.parameters(), config.train.max_grad_norm)
        optimizer.step()

        if it % args.train_report_iter != 0:
            return

        # The values reported below come from the last accumulated batch only;
        # `loss` is the scaled one.
        current_lr = optimizer.param_groups[0]['lr']
        logger.info(
            '[Train] Iter %d | Loss %.6f (pos %.6f | v %.6f) | Lr: %.6f | Grad Norm: %.6f' % (
                it, loss, loss_pos, loss_v, current_lr, orig_grad_norm
            )
        )
        for key, value in results.items():
            # Only scalar tensors are written to tensorboard.
            if torch.is_tensor(value) and value.squeeze().ndim == 0:
                writer.add_scalar(f'train/{key}', value, it)
        writer.add_scalar('train/lr', current_lr, it)
        writer.add_scalar('train/grad', orig_grad_norm, it)
        writer.flush()


    def validate(it):
        """Evaluate the model on ``val_loader``, step the LR scheduler and log the results.

        Each validation batch is evaluated ten times (once per entry of an
        evenly spaced timestep grid). Losses are averaged weighted by the
        number of graphs, and the atom-type AUROC is computed over all
        collected predictions.

        Args:
            it: Current training iteration; used in the log line and as the
                tensorboard step.

        Returns:
            The average validation loss (a Python float).
        """
        sum_loss, sum_loss_pos, sum_loss_v, sum_n = 0, 0, 0, 0
        all_pred_v, all_true_v = [], []
        with torch.no_grad():
            model.eval()
            for batch in tqdm(val_loader, desc='Validate'):
                batch = batch.to(args.device)
                batch_size = batch.num_graphs
                for t in np.linspace(0, model.num_timesteps - 1, 10).astype(int):
                    # NOTE(review): this fixed time step is built but never
                    # handed to get_diffusion_loss, so the model presumably
                    # samples its own. Confirm whether
                    # DPO3D.get_diffusion_loss accepts a `time_step` argument
                    # and pass it if so.
                    time_step = torch.tensor([t] * batch_size).to(args.device)
                    results = model.get_diffusion_loss(
                        dpo_beta=config.train.dpo_beta,
                        ref_model=ref_model,
                        net_cond=net_cond,

                        # Un-noised protein coordinates. The noisy
                        # `gt_protein_pos*` tensors are locals of train() and
                        # do not exist in this scope.
                        protein_pos=batch.protein_pos,
                        protein_v=batch.protein_atom_feature.float(),
                        batch_protein=batch.protein_element_batch,

                        ligand_pos=batch.ligand_pos,
                        ligand_v=batch.ligand_atom_feature_full,
                        batch_ligand=batch.ligand_element_batch,

                        protein_pos2=batch.protein_pos2,
                        protein_v2=batch.protein_atom_feature2.float(),
                        batch_protein2=batch.protein_element2_batch,

                        ligand_pos2=batch.ligand_pos2,
                        ligand_v2=batch.ligand_atom_feature_full2,
                        batch_ligand2=batch.ligand_element2_batch,

                        reward=-batch.affinity,
                        reward2=-batch.affinity2
                    )
                    loss, loss_pos, loss_v = results['loss'], results['loss_pos'], results['loss_v']

                    # Weight by batch size so the final average is per graph.
                    sum_loss += float(loss) * batch_size
                    sum_loss_pos += float(loss_pos) * batch_size
                    sum_loss_v += float(loss_v) * batch_size
                    sum_n += batch_size
                    all_pred_v.append(results['ligand_v_recon'].detach().cpu().numpy())
                    all_true_v.append(batch.ligand_atom_feature_full.detach().cpu().numpy())

        avg_loss = sum_loss / sum_n
        avg_loss_pos = sum_loss_pos / sum_n
        avg_loss_v = sum_loss_v / sum_n
        atom_auroc = get_auroc(np.concatenate(all_true_v), np.concatenate(all_pred_v, axis=0),
                               feat_mode=config.data.transform.ligand_atom_mode)

        # Plateau-type schedulers are driven by the validation loss; any other
        # scheduler is simply advanced by one step.
        if config.train.scheduler.type == 'plateau':
            scheduler.step(avg_loss)
        elif config.train.scheduler.type == 'warmup_plateau':
            scheduler.step_ReduceLROnPlateau(avg_loss)
        else:
            scheduler.step()

        logger.info(
            '[Validate] Iter %05d | Loss %.6f | Loss pos %.6f | Loss v %.6f e-3 | Avg atom auroc %.6f' % (
                it, avg_loss, avg_loss_pos, avg_loss_v * 1000, atom_auroc
            )
        )
        writer.add_scalar('val/loss', avg_loss, it)
        writer.add_scalar('val/loss_pos', avg_loss_pos, it)
        writer.add_scalar('val/loss_v', avg_loss_v, it)
        writer.flush()
        return avg_loss


    # Main loop: train every iteration, validate on schedule, and keep a
    # checkpoint whenever the validation loss improves. Ctrl-C stops cleanly.
    try:
        best_loss, best_iter = None, None
        last_iter = config.train.max_iters
        for it in tqdm(range(1, last_iter + 1), position=0, leave=True, desc='Train Iter'):
            # with torch.autograd.detect_anomaly():
            train(it)

            # Validate every `val_freq` iterations and always on the final one.
            if it % config.train.val_freq != 0 and it != last_iter:
                continue

            val_loss = validate(it)
            improved = best_loss is None or val_loss < best_loss
            if not improved:
                logger.info(f'[Validate] Val loss is not improved. '
                            f'Best val loss: {best_loss:.6f} at iter {best_iter}')
                continue

            logger.info(f'[Validate] Best val loss achieved: {val_loss:.6f}')
            best_loss, best_iter = val_loss, it
            ckpt_path = os.path.join(ckpt_dir, '%d.pt' % it)
            snapshot = {
                'config': config,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'iteration': it,
            }
            torch.save(snapshot, ckpt_path)
    except KeyboardInterrupt:
        logger.info('Terminating...')
