import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset
import argparse
import numpy as np
import random
import os
import sys
import pickle
from glob import glob
import math

from ccc_model import OriginalCCC
from sklearn.metrics import accuracy_score, r2_score


def set_seed(seed=0):
    """Seed PyTorch's random generators and configure cuDNN flags.

    Args:
        seed: Value passed to each of the PyTorch seeding calls below.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Request deterministic cuDNN behaviour and switch off its benchmark mode.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def load_train_data(data_name, data_seed, batch_size, target_transform, data_dir, early_stopping=False):
    """Build the shuffled training loader (and optional validation loader).

    Reads ``{data_dir}/{data_name}/{data_seed}.pkl`` and wraps ``x_train``
    with either ``y_train`` or ``y_train_transform`` in a ``TensorDataset``.

    Args:
        data_name: Dataset sub-directory name under ``data_dir``.
        data_seed: Stem of the pickle file to read.
        batch_size: Batch size used for every returned loader.
        target_transform: When truthy, use ``y_train_transform`` as targets.
        data_dir: Root directory that holds the dataset folders.
        early_stopping: When truthy, the rows after the first 80% are moved
            into a separate validation loader.

    Returns:
        ``(train_loader, valid_loader_or_None, col_cat_count, label_cat_count)``
        where the last two values are taken verbatim from the pickle.

    Note:
        The process exits with status 0 when a non-regression dataset is
        combined with ``target_transform``.
    """
    # NOTE: pickle.load runs arbitrary code from the file; only read trusted data.
    with open(f'{data_dir}/{data_name}/{data_seed}.pkl', 'rb') as f:
        data_dict = pickle.load(f)

    is_classification = not data_dict['dataset_config']['regression']
    if is_classification and target_transform:
        print('!! pass classification with target transform')
        sys.exit(0)

    target_key = 'y_train_transform' if target_transform else 'y_train'
    features = torch.tensor(data_dict['x_train'])
    targets = torch.tensor(data_dict[target_key])

    valid_loader = None
    if early_stopping:
        print('!!! split validation set !!!')
        cut = int(features.shape[0]*0.8)
        valid_loader = DataLoader(
            TensorDataset(features[cut:], targets[cut:]),
            batch_size=batch_size,
            shuffle=True,
            drop_last=False,
        )
        features, targets = features[:cut], targets[:cut]

    train_loader = DataLoader(
        TensorDataset(features, targets),
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
    )
    return train_loader, valid_loader, data_dict['col_cat_count'], data_dict['label_cat_count']

    
def load_valid_test_data(data_name, data_seed, batch_size, target_transform, data_dir, early_stopping=False):
    """Build the unshuffled held-out, validation and test loaders.

    Reads ``{data_dir}/{data_name}/{data_seed}.pkl``. Targets come from the
    ``*_transform`` entries when ``target_transform`` is truthy.

    Args:
        data_name: Dataset sub-directory name under ``data_dir``.
        data_seed: Stem of the pickle file to read.
        batch_size: Batch size used for every returned loader.
        target_transform: When truthy, use the transformed target arrays.
        data_dir: Root directory that holds the dataset folders.
        early_stopping: When truthy, also return a loader over the training
            rows that follow the first 80%.

    Returns:
        ``(heldout_loader_or_None, valid_loader, test_loader,
        target_transformer, is_regression)``; the last two values are taken
        verbatim from the pickle.
    """
    # NOTE: pickle.load runs arbitrary code from the file; only read trusted data.
    with open(f'{data_dir}/{data_name}/{data_seed}.pkl', 'rb') as f:
        data_dict = pickle.load(f)

    def _ordered_loader(features, targets):
        # Every loader returned here keeps dataset order and all samples.
        return DataLoader(
            TensorDataset(features, targets),
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
        )

    heldout_loader = None
    if early_stopping:
        train_x = torch.tensor(data_dict['x_train'])
        train_y = torch.tensor(data_dict['y_train_transform' if target_transform else 'y_train'])
        split_at = int(train_x.shape[0]*0.8)
        heldout_loader = _ordered_loader(train_x[split_at:], train_y[split_at:])

    valid_loader = _ordered_loader(
        torch.tensor(data_dict['x_val']),
        torch.tensor(data_dict['y_val_transform' if target_transform else 'y_val']),
    )
    test_loader = _ordered_loader(
        torch.tensor(data_dict['x_test']),
        torch.tensor(data_dict['y_test_transform' if target_transform else 'y_test']),
    )

    regression_flag = data_dict['dataset_config']['regression']
    return heldout_loader, valid_loader, test_loader, data_dict['target_transformer'], regression_flag

    
def build_model(
    col_cat_count,
    label_cat_count,
    num_cond_per_column_scale=16,
    num_cond_per_subtree=4,
    num_subtree_per_condset=1,
    num_subtree_per_estimator=-1,
    dropout=0.0,
    shuffle_condition=True,
    condition_shuffle_type='random',
    device=torch.device('cpu')
):
    """Construct an ``OriginalCCC`` model and move it to ``device``.

    The caller-supplied hyper-parameters are forwarded unchanged. The
    remaining constructor arguments are pinned here: ``num_cond_per_column``
    is ``None``, both estimator counts are 100 and ``subtree_hidden_dim`` is
    128.

    Args:
        col_cat_count: Forwarded as the first positional argument.
        label_cat_count: Forwarded as the second positional argument.
        num_cond_per_column_scale: Forwarded by keyword.
        num_cond_per_subtree: Forwarded by keyword.
        num_subtree_per_condset: Forwarded by keyword.
        num_subtree_per_estimator: Forwarded by keyword.
        dropout: Forwarded by keyword.
        shuffle_condition: Forwarded by keyword.
        condition_shuffle_type: Forwarded by keyword.
        device: Passed to the constructor and used for the final ``.to()``.

    Returns:
        The constructed model after ``.to(device)``.
    """
    pinned_settings = dict(
        num_cond_per_column=None,
        train_num_estimator=100,
        test_num_estimator=100,
        subtree_hidden_dim=128,
    )
    caller_settings = dict(
        num_cond_per_column_scale=num_cond_per_column_scale,
        num_cond_per_subtree=num_cond_per_subtree,
        num_subtree_per_condset=num_subtree_per_condset,
        num_subtree_per_estimator=num_subtree_per_estimator,
        dropout=dropout,
        shuffle_condition=shuffle_condition,
        condition_shuffle_type=condition_shuffle_type,
        device=device,
    )
    network = OriginalCCC(col_cat_count, label_cat_count, **pinned_settings, **caller_settings)
    return network.to(device)


def load_model(model_path, device=torch.device('cpu')):
    """Load a fully pickled model checkpoint and switch it to eval mode.

    The checkpoint is expected to hold a whole model object (as written by
    ``torch.save(model, path)``), not just a ``state_dict``.

    Args:
        model_path: Path of the checkpoint file.
        device: Target passed to ``torch.load`` as ``map_location``.

    Returns:
        The deserialized model with ``eval()`` applied.
    """
    import inspect

    load_kwargs = {'map_location': device}
    # Recent PyTorch releases default ``weights_only`` to True, which refuses
    # to unpickle a complete model object. Request full unpickling explicitly,
    # but only when the installed torch.load knows the argument.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False

    # NOTE: full unpickling can execute arbitrary code; only load checkpoints
    # that come from a trusted source.
    model = torch.load(model_path, **load_kwargs)
    model.eval()
    return model


def train(model, optimizer, dataloader, max_epoch, save_per_epoch, save_dir):
    """Run the optimisation loop and write whole-model checkpoints.

    A checkpoint ``{save_dir}/e{N}.ckpt`` is written after every
    ``save_per_epoch``-th epoch and always after the last epoch, so
    ``e{max_epoch}.ckpt`` exists even when ``max_epoch`` is not a multiple of
    ``save_per_epoch``.

    Args:
        model: Module called as ``model(x, y)``; the second element of its
            return value is used as the loss. Must expose ``model.device``.
        optimizer: Optimizer stepped once per batch.
        dataloader: Iterable yielding ``(x, y)`` batches.
        max_epoch: Number of passes over ``dataloader``.
        save_per_epoch: Checkpoint interval in epochs; a falsy value (0 or
            ``None``) disables periodic saves and keeps only the final one.
        save_dir: Existing directory that receives the checkpoint files.
    """
    model.train()
    for epoch in range(1, max_epoch + 1):
        for x, y in dataloader:
            _, loss = model(x.to(model.device), y.to(model.device))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Guard the modulo so an interval of 0 does not raise ZeroDivisionError.
        on_schedule = bool(save_per_epoch) and epoch % save_per_epoch == 0
        # Always persist the last epoch so callers can rely on e{max_epoch}.ckpt.
        if on_schedule or epoch == max_epoch:
            print(f'!! save {epoch} ...')
            save_path = f'{save_dir}/e{epoch}.ckpt'
            torch.save(model, save_path)
            

def evaluation(model, dataloader, is_rgr, seed=0, transform_target=False, target_transformer=None):
    """Score ``model`` on ``dataloader`` with R^2 (regression) or accuracy.

    Args:
        model: Module called as ``model(x, y)``; the first element of its
            return value is used as the prediction. Must expose
            ``model.device``.
        dataloader: Iterable yielding ``(x, y)`` batches.
        is_rgr: Truthy for regression (R^2), falsy for classification
            (accuracy over ``argmax(-1)`` of the predictions).
        seed: Seed re-applied through ``set_seed`` before every batch.
        transform_target: When truthy together with ``is_rgr``, targets and
            predictions are passed through
            ``target_transformer.inverse_transform`` before scoring.
        target_transformer: Object providing ``inverse_transform``; only used
            in the case described above.

    Returns:
        ``{'r2_score': value}`` or ``{'accuracy': value}``.
    """
    model.eval()
    pred_chunks, target_chunks = [], []
    with torch.no_grad():
        for x, y in dataloader:
            # Re-seed before each forward pass so every batch starts from the same RNG state.
            set_seed(seed)
            batch_pred, _ = model(x.to(model.device), y.to(model.device))
            pred_chunks.append(batch_pred.detach().cpu())
            target_chunks.append(y)
    all_preds = torch.cat(pred_chunks, dim=0)
    all_targets = torch.cat(target_chunks, dim=0)

    if not is_rgr:
        labels = all_targets.numpy().squeeze()
        scores = all_preds.numpy().squeeze()
        return {'accuracy': accuracy_score(labels, scores.argmax(-1))}

    if transform_target:
        # Targets get a trailing axis before inverse_transform; predictions are passed as-is.
        truth = target_transformer.inverse_transform(all_targets.unsqueeze(-1).numpy()).squeeze()
        estimate = target_transformer.inverse_transform(all_preds.numpy()).squeeze()
    else:
        truth = all_targets.numpy().squeeze()
        estimate = all_preds.numpy().squeeze()
    return {'r2_score': r2_score(truth, estimate)}


if __name__ == '__main__':
    # Command-line entry point: parse options, train the model, then score the
    # final checkpoint on the validation and test splits.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--data_name', type=str, required=True)
    parser.add_argument('--data_seed', type=int, required=True)
    parser.add_argument('--target_transform', action='store_true')
    parser.add_argument('--torch_seed', type=int, default=0)

    # Model hyper-parameters forwarded to build_model.
    parser.add_argument('--num_cond_per_column_scale', type=int, default=16)
    parser.add_argument('--num_cond_per_subtree', type=int, default=4)
    parser.add_argument('--num_subtree_per_condset', type=int, default=1)
    parser.add_argument('--num_subtree_per_estimator', type=int, default=-1)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--no_shuffle_condition', action='store_true')
    parser.add_argument('--condition_shuffle_type', type=str, default='random')

    # Optimisation settings.
    parser.add_argument('--max_epoch', type=int, default=500)
    parser.add_argument('--save_per_epoch', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=0.0)

    # Paths and runtime options.
    parser.add_argument('--ckpt_dir', type=str, default='tabular_benchmark_ckpt')
    parser.add_argument('--data_dir', type=str, default='tabular_benchmark_data')
    parser.add_argument('--early_stopping', action='store_true')
    parser.add_argument('--cuda_device', type=int, default=0)

    opt = parser.parse_args()
    print(opt)

    save_dir = f'{opt.ckpt_dir}/{opt.model_name}/{opt.data_name}/{opt.data_seed}'
    # exist_ok avoids a crash when another run creates the directory between
    # an existence check and the creation call.
    os.makedirs(save_dir, exist_ok=True)

    ### set seed ###
    np.random.seed(opt.data_seed)
    random.seed(opt.data_seed)
    set_seed(opt.torch_seed)
    # Use the requested GPU when CUDA is usable; otherwise run on CPU instead
    # of failing when the model is moved to a missing device.
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{opt.cuda_device}')
    else:
        print('!! CUDA is not available, falling back to CPU')
        device = torch.device('cpu')

    ### load data ###
    # NOTE: va_dataloader is not used below; with --early_stopping the only
    # effect in this script is the reduced training split.
    tr_dataloader, va_dataloader, col_cat_count, label_cat_count = load_train_data(
        opt.data_name,
        opt.data_seed,
        opt.batch_size,
        opt.target_transform,
        opt.data_dir,
        early_stopping=opt.early_stopping,
    )

    ### build model ###
    model = build_model(
        col_cat_count,
        label_cat_count,
        num_cond_per_column_scale=opt.num_cond_per_column_scale,
        num_cond_per_subtree=opt.num_cond_per_subtree,
        num_subtree_per_condset=opt.num_subtree_per_condset,
        num_subtree_per_estimator=opt.num_subtree_per_estimator,
        dropout=opt.dropout,
        shuffle_condition=not opt.no_shuffle_condition,
        condition_shuffle_type=opt.condition_shuffle_type,
        device=device
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)

    ### train model ###
    train(model, optimizer, tr_dataloader, opt.max_epoch, opt.save_per_epoch, save_dir)

    ### eval model ###
    # The held-out loader (first return value) is discarded here.
    _, valid_dataloader, test_dataloader, target_transformer, is_rgr = load_valid_test_data(
        opt.data_name,
        opt.data_seed,
        opt.batch_size,
        opt.target_transform,
        opt.data_dir,
        early_stopping=opt.early_stopping,
    )
    # Evaluate the checkpoint written for the final epoch, reloaded from disk.
    model = load_model(f'{save_dir}/e{opt.max_epoch}.ckpt', device=device)

    eval_kwargs = dict(
        seed=opt.torch_seed,
        transform_target=opt.target_transform,
        target_transformer=target_transformer,
    )
    val_perf = evaluation(model, valid_dataloader, is_rgr, **eval_kwargs)
    print(f'valid_score = {val_perf}')

    test_perf = evaluation(model, test_dataloader, is_rgr, **eval_kwargs)
    print(f'test_score = {test_perf}')
