import argparse
import os
from time import time

import numpy as np
import pandas as pd
import torch
import wandb
# from tqdm import tqdm

from config import settings
from sensitivity.mamba_clf import MambaClf
from sensitivity.other_utils import sample_bstr, flip_i, Vocab, sents_to_idx, allbins


parser = argparse.ArgumentParser(
    description='Estimate the average sensitivity of randomly initialised Mamba classifiers')
parser.add_argument('-gpu', type=int, default=0, help='Specify the gpu to use (negative value: use the CPU)')
parser.add_argument('-sample_size', type=int, default=2000,
                    help='Sample size for estimation (only used when -len > 12; shorter inputs are enumerated)')

parser.add_argument('-len', type=int, default=10, help='Length of inputs')
# `choices` makes argparse reject unknown init types up front instead of after the
# model has been built and the samples generated.
parser.add_argument('-init_type', type=str, default='uniform',
                    choices=['uniform', 'gaussian', 'xavier_norm', 'xavier_uniform'],
                    help='uniform / gaussian / xavier_norm / xavier_uniform')
parser.add_argument('-w_init', type=float, default=10.0, help='Range of weight values [-B,+B] (uniform init)')
parser.add_argument('-std_init', type=float, default=10.0, help='weight init st deviation (gaussian init)')
parser.add_argument('-i_init', type=float, default=1.0, help='Range of input values [-B,+B] (gaussian init)')
parser.add_argument('-d_channels', type=int, default=64, help='Mamba d_channels')
parser.add_argument('-d_state', type=int, default=64, help='Mamba d_state')
parser.add_argument('-num_layers', type=int, default=1, help='Layers')
parser.add_argument('-trials', type=int, default=1000, help='Number of trials or randomly initialized models')

parser.add_argument('-wandb', dest='wandb', action='store_true', default=False, help='Store wandb')

args = parser.parse_args()

init_type, num_layers, num_channels = args.init_type, args.num_layers, args.d_channels
proj_name = 'RandomSensi-' + init_type + '_' + str(num_layers) + '_' + str(num_channels)

# NOTE: there is no separator between init_type and 'exp'/'group'
# (e.g. out/sensi_uniformexp_A.csv). Kept as-is so new runs keep appending to the
# existing results files.
out_file = 'out/sensi_' + init_type + 'exp_A.csv'
out_grouped = 'out/sensi_' + init_type + 'group_A.csv'

run_name = 'h_{}layer_{}t_{}_len_{}'.format(args.d_channels, args.num_layers, args.trials, args.len)

# A negative -gpu selects the CPU; otherwise use that CUDA device index.
if args.gpu >= 0:
    device = torch.device('cuda:{}'.format(args.gpu))
else:
    device = torch.device('cpu')

########## Wandb Init ###############

if args.wandb:
    wandb.login(key=settings.WANDB_API_KEY)
    # Record every CLI hyper-parameter that changes the experiment, so runs that differ
    # only in d_state / init scheme / init scales can still be told apart in W&B.
    metrics = dict(
        num_layers=args.num_layers,
        d_channels=args.d_channels,
        d_state=args.d_state,
        length=args.len,
        # Requested value; when -len <= 12 the script enumerates all inputs instead.
        sample_size=args.sample_size,
        trials=args.trials,
        init_type=args.init_type,
        weight_range=args.w_init,
        std_init=args.std_init,
        inp_init=args.i_init,
    )

    wandb.init(
        project=proj_name,
        entity=settings.WANDB_TEAM,
        name=run_name,
        config=metrics,
    )

########################################

def test_model(str_list, clf):
    """Return hard 0/1 predictions of ``clf`` for each string in ``str_list``.

    Every string gets an ``'s'`` appended before being tokenised with the
    module-level ``voc`` (via ``sents_to_idx``); the lengths handed to the model
    count that extra character. The last column of the id batch is dropped and the
    batch is transposed (presumably to (seq_len, batch) - confirm against MambaClf)
    before being moved to ``device``. Model outputs are flattened and thresholded
    at 0.5.
    """
    clf.eval()
    suffixed = [s + 's' for s in str_list]
    lengths = torch.tensor([len(s) for s in suffixed])
    token_ids = sents_to_idx(voc, suffixed)[:, :-1].transpose(0, 1).to(device)
    lengths = lengths.to(device)

    with torch.no_grad():
        scores = clf(token_ids, lengths).cpu().numpy().reshape(-1)
        return np.array(scores >= 0.5, dtype=int)


### Model Def #####
print('Loading Model')
# Module-level copies of the CLI hyper-parameters; later sections read these names.
inp_len, d_channels, d_state = args.len, args.d_channels, args.d_state
num_layers = args.num_layers
weight_init, std_init, inp_init = args.w_init, args.std_init, args.i_init

clf = MambaClf(
    n_tokens=4,
    d_channels=d_channels,
    d_state=d_state,
    layers=num_layers,
    d_conv=4,
    d_expand=2,
    noutputs=1,
).to(device)
voc = Vocab()

### Sensitivity Experiment ####

# Inputs no longer than this are enumerated exhaustively with allbins(); longer
# inputs are estimated from args.sample_size strings drawn with sample_bstr().
EXHAUSTIVE_MAX_LEN = 12

print('Generating Samples')
if inp_len > EXHAUSTIVE_MAX_LEN:
    sample_size = args.sample_size
    sample_points = sample_bstr(sample_size, length=inp_len)
else:
    sample_points = allbins(inp_len)
    sample_size = len(sample_points)

# Bit positions to flip when measuring sensitivity: every position of the input.
check_idx = list(range(inp_len))


### Pandas Data setup
# Built only after sampling, so the recorded sample_size is the number of inputs
# actually evaluated (previously the requested -sample_size was stored even when
# all strings were enumerated).

cols = ['run_name', 'sample_size', 'len', 'd_channels', 'num_layers', 'is_bias', 'weight_init', 'trial', 'sensi_count', 'avg_sensi']
df = pd.DataFrame(data=[], columns=cols)
fixed_data = [run_name, sample_size, inp_len, d_channels, num_layers, False, weight_init]

cols_group = ['run_name', 'sample_size', 'num_trials', 'len', 'd_channels', 'num_layers', 'is_bias', 'weight_init', 'avg_sensi', 'std_sensi']
big_df = pd.DataFrame(data=[], columns=cols_group)
group_data = [run_name, sample_size, args.trials, inp_len, d_channels, num_layers, False, weight_init]

avg_list = []

# Dispatch table: each supported init_type maps to the call that re-draws clf's weights.
initializers = {
    'uniform': lambda: clf.init_weights(weight_init=weight_init),
    'xavier_norm': lambda: clf.init_xavnormal(),
    'xavier_uniform': lambda: clf.init_xavuni(),
    'gaussian': lambda: clf.init_gauss_weights(std_init=std_init, inp_init=inp_init, dec_init=1.0),
}
if init_type not in initializers:
    print(" incorrect initialisation type")
    # Non-zero status (the old exit(0) reported success), checked once before any trial.
    raise SystemExit(1)
reinit_weights = initializers[init_type]

# Strings sent through the model per forward pass. Lower it if the device runs out
# of memory; it does not change which predictions are compared.
EVAL_BATCH_SIZE = 512


def _predict_all(strings):
    """Return test_model's 0/1 predictions for ``strings``, evaluated in chunks.

    Raises RuntimeError if test_model does not return exactly one prediction per
    input string, since the chunked comparison below relies on that.
    """
    parts = []
    for start in range(0, len(strings), EVAL_BATCH_SIZE):
        chunk = strings[start:start + EVAL_BATCH_SIZE]
        preds = test_model(chunk, clf)
        if len(preds) != len(chunk):
            raise RuntimeError(
                'test_model returned {} predictions for {} inputs'.format(len(preds), len(chunk)))
        parts.append(preds)
    return np.concatenate(parts) if parts else np.zeros(0, dtype=int)


# The inputs and their single-bit flips are identical in every trial; build them once.
base_strings = [sample_points[j] for j in range(sample_size)]
flipped_strings = {i: [flip_i(bstr, i) for bstr in base_strings] for i in check_idx}

for trial in range(args.trials):
    print('Running for trial : {}'.format(trial))
    start_time = time()
    reinit_weights()

    # Predictions on the unflipped inputs are the reference each flip is compared to.
    sample_preds = _predict_all(base_strings)
    assert sample_size == len(sample_preds)

    sensi_count = 0  # positions whose flip changes the prediction on at least one input
    avg_sensi = 0.0  # sum over positions of the fraction of inputs whose prediction changes
    for i in check_idx:
        mis_counter = int(np.count_nonzero(_predict_all(flipped_strings[i]) != sample_preds))
        if mis_counter > 0:
            sensi_count += 1
        avg_sensi += mis_counter / sample_size

    # Mean over bit positions (check_idx covers every position).
    avg_sensi = avg_sensi / inp_len
    avg_list.append(avg_sensi)

    time_taken = time() - start_time
    time_mins = int(time_taken / 60)
    time_secs = time_taken % 60

    print('Trial {} completed with {} samples...\nTime Taken: {} mins and {} secs'.format(trial, sample_size, time_mins, time_secs))

    print('\n-------------Trial {} Done--------------'.format(trial))
    print(sensi_count)
    print(avg_sensi)

    if args.wandb:
        wandb.log({
            'Avg_sensi': avg_sensi,
            'Sensi_count': sensi_count,
        }, step=trial)

    df.loc[trial] = fixed_data + [trial, sensi_count, avg_sensi]

avg_list = np.array(avg_list)
mean_sensi = avg_list.mean()
std_sensi = avg_list.std()
print('Mean Sensi: {}'.format(mean_sensi))


def _append_csv(frame, path):
    """Append the rows of ``frame`` to the CSV at ``path``.

    The file and its parent directory are created if missing. Existing rows are
    kept, and columns are aligned by name via ``pd.concat``. A missing or empty
    file is simply written with ``frame``. If the existing file cannot be read or
    merged for any other reason, it is left untouched. This run's rows then go to a
    timestamped side file, so accumulated results are never overwritten (the old
    bare ``except:`` replaced the whole file with this run's rows).
    """
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    try:
        merged = pd.concat([pd.read_csv(path), frame], ignore_index=True)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        merged = frame
    except Exception as err:  # best effort: keep this run's results, keep going
        fallback = '{}.unmerged_{}.csv'.format(os.path.splitext(path)[0], int(time()))
        print('Could not merge results into {} ({}); writing this run to {}'.format(path, err, fallback))
        frame.to_csv(fallback, index=False)
        return
    merged.to_csv(path, index=False)


_append_csv(df, out_file)

big_df.loc[0] = group_data + [mean_sensi, std_sensi]
_append_csv(big_df, out_grouped)


if args.wandb:
    wandb.log({
        'mean_sensi': mean_sensi,
        'mean+std': mean_sensi + std_sensi
    })

    data = [[s] for s in avg_list]
    table = wandb.Table(data=data, columns=["sensitivity"])
    wandb.log({'histogram': wandb.plot.histogram(table, "sensitivity",
                                                 title="Histogram")})
