import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
from torch.utils.data import TensorDataset, DataLoader, Dataset
from collections import defaultdict
import gc
import tqdm


NUM_OF_GPUS = torch.cuda.device_count()
device = 'cuda' if torch.cuda.is_available() else "cpu"
print(device)

# The reference model: Pythia-2.8B architecture loaded with the weights stored
# under the 'state' entry of the Step-1 SFT checkpoint.
# map_location='cpu' stops torch.load from restoring the tensors onto the GPU
# they were saved from (if they were saved from one). That would park an extra
# copy of the weights on that GPU and fail outright on CPU-only hosts. The
# model itself is moved to `device` further below.
checkpoint = torch.load('/net/scratch/dpo/Step1_SFT/Step1_SFT_Antrophic_Pythia28/policy.pt',
                        map_location='cpu')
ref_model = AutoModelForCausalLM.from_pretrained('EleutherAI/pythia-2.8b')
ref_model.load_state_dict(checkpoint['state'])
# load_state_dict copies the tensors into the model's own parameters. Release
# the raw checkpoint instead of keeping it alive for the whole run.
del checkpoint
gc.collect()
ref_model.eval()

# Tokenizer for the reference model. Prompts are left-padded so each one ends
# exactly where generation starts; EOS is reused as the pad token.
tokenizer = AutoTokenizer.from_pretrained('/net/scratch/dpo/Step1_SFT/Step1_SFT_Antrophic_Pythia28',
                                          padding=True, padding_side='left')
tokenizer.pad_token = tokenizer.eos_token

# The reward model that scores (prompt, response) pairs, plus its own
# right-padding tokenizer.
# Alternative hub checkpoint: "OpenAssistant/reward-model-deberta-v3-large-v2"
reward_name = '/net/scratch/dpo/reward-model'
rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name)
reward_tokenizer = AutoTokenizer.from_pretrained(reward_name, padding_side='right')

# Move all the models to gpus. With more than one GPU both models are wrapped
# in nn.DataParallel, so the underlying modules are reachable via `.module`.
if NUM_OF_GPUS > 1:
    print(f"Using {NUM_OF_GPUS} GPUs!")
    ref_model = nn.DataParallel(ref_model)
    rank_model = nn.DataParallel(rank_model)
ref_model.to(device)
rank_model.to(device)
print('Model loaded.')


# Extra keyword arguments for ref_model.generate: sampling with top_k=0 and
# top_p=1.0 (no top-k / nucleus filtering, so only the temperature shapes the
# distribution), no minimum length, and EOS as the pad id.
gen_kwargs = {"min_length": -1, "top_k": 0.0, "top_p": 1.0,
              "do_sample": True, "pad_token_id": tokenizer.eos_token_id}



# Dataset of tokenized prompts, and the data loader built on top of it
class EmbeddingDataset(Dataset):
    """Map-style dataset pairing tokenized prompts with their attention masks.

    Item ``idx`` is the tuple ``(embeddings[idx], attention_masks[idx])``, and
    the dataset length is ``len(embeddings)``. Despite the attribute name,
    ``embeddings`` is just the indexable sequence passed in; build_dataloader
    passes the tokenizer's ``input_ids``.
    """

    def __init__(self, embeddings, attention_masks):
        self.embeddings, self.attention_masks = embeddings, attention_masks

    def __len__(self):
        n_items = len(self.embeddings)
        return n_items

    def __getitem__(self, idx):
        sample = self.embeddings[idx]
        mask = self.attention_masks[idx]
        return sample, mask

def _collate_trim_padding(samples):
    """Collate (input_ids, attention_mask) pairs, dropping all-padding columns.

    build_dataloader pads every prompt once, to the longest prompt in the whole
    query list. Most batches therefore carry many columns that are padding in
    every row. Keeping only the columns where at least one row has a real token
    (mask != 0) pads each batch to its own longest prompt instead. With
    one-sided padding (the tokenizer pads on the left), the removed columns are
    exactly that shared padding run. A batch with no real tokens at all is
    returned untrimmed.
    """
    input_ids = torch.stack([ids for ids, _ in samples])
    attention_mask = torch.stack([mask for _, mask in samples])
    keep = attention_mask.bool().any(dim=0)
    if not bool(keep.any()):
        return input_ids, attention_mask
    return input_ids[:, keep], attention_mask[:, keep]


def build_dataloader(queries, batch_size=4, sampler='sequential'):
    """Tokenize ``queries`` and return a DataLoader of (input_ids, attention_mask).

    Each batch is padded only to its own longest prompt rather than to the
    longest prompt in ``queries``. When the queries are sorted by length, this
    keeps generation inputs short.

    Args:
        queries: list of prompt strings.
        batch_size: number of prompts per batch.
        sampler: accepted for API compatibility but currently ignored. Batches
            are always produced in order, because best_and_worst_of_n_sampler
            maps batch ``step`` back to
            ``queries[step*batch_size:(step+1)*batch_size]``.

    Returns:
        A non-shuffling DataLoader yielding ``(input_ids, attention_mask)``
        tensor pairs.
    """
    query_tensor = tokenizer(queries, return_tensors='pt', padding=True)
    data = EmbeddingDataset(query_tensor['input_ids'], query_tensor['attention_mask'])
    return DataLoader(data, batch_size=batch_size, shuffle=False,
                      collate_fn=_collate_trim_padding)


# best and worst sampler
def best_and_worst_of_n_sampler(query_list, batch_size, max_len, temperature=0.5,
                                n_seq=(3, 6, 8), responses_best_and_worst=None):
    '''Sample best-of-n / worst-of-n response pairs for every query.

    For each batch of queries, ``max(n_seq)`` responses per query are sampled
    from ``ref_model`` in a single ``generate`` call and scored with
    ``rank_model``. For every ``k`` in ``n_seq``, the first ``k`` samples of a
    query are compared. The highest- and lowest-scoring ones form a
    (best, worst) = (chosen, rejected) pair.

    Args:
        query_list: prompt strings. Batch ``step`` of the data loader is matched
            to ``query_list[step*batch_size:(step+1)*batch_size]``.
        batch_size: number of distinct queries per generation batch.
        max_len: ``max_new_tokens`` passed to ``generate``.
        temperature: sampling temperature.
        n_seq: the values of n to collect. All of them reuse the same samples.
        responses_best_and_worst: optional result dict to extend in place.

    Returns:
        dict mapping each k in ``n_seq`` to
        ``{prompt: {'responses': [...], 'pairs': [(chosen_idx, rejected_idx), ...],
        'sft_target': <best response of the latest pair>}}``.

    Raises:
        ValueError: if ``n_seq`` is empty.
    '''
    n_seq = list(n_seq)
    if not n_seq:
        raise ValueError('n_seq must contain at least one value of n')
    # One generation pass with the largest n serves every smaller n as well.
    n_max = max(n_seq)

    dataloader = build_dataloader(query_list, batch_size=batch_size)
    print('Built the data loader!')

    # Initialize the response dictionary, keeping anything the caller passed.
    if responses_best_and_worst is None:
        responses_best_and_worst = {}
    for k in n_seq:
        responses_best_and_worst.setdefault(k, defaultdict(lambda: defaultdict(list)))

    # With several GPUs the model is wrapped in nn.DataParallel, which has no
    # `generate`. A plain module (single GPU / CPU) has no `.module` attribute.
    generator = getattr(ref_model, 'module', ref_model)

    pbar = tqdm.tqdm(enumerate(dataloader), total=len(dataloader), desc='Generate data')
    for step, (emb_batch, mask_batch) in pbar:
        curr_batch_size = len(emb_batch)

        # original queries in this batch (relies on the loader being sequential)
        query_batch_list = query_list[step * batch_size:(step + 1) * batch_size]

        # Tile the whole batch n_max times: row j*curr_batch_size + i holds
        # sample j of query i.
        embs_batch = emb_batch.repeat((n_max, 1)).to(device)
        masks_batch = mask_batch.repeat((n_max, 1)).to(device)
        print(embs_batch.shape)

        queries_len = embs_batch.shape[1]  # prompts are left-padded to this length

        # Sample responses and keep only the newly generated tokens.
        output = generator.generate(embs_batch, attention_mask=masks_batch,
                                    max_new_tokens=max_len,
                                    temperature=temperature, **gen_kwargs)[:, queries_len:]
        responses = tokenizer.batch_decode(output, skip_special_tokens=True)
        print('Responses done.')

        # Score every (query, response) pair. The query list is tiled the same
        # way as the token batch, so row order matches `responses`.
        inputs = reward_tokenizer(query_batch_list * n_max, responses,
                                  return_tensors='pt', padding=True).to(device)
        with torch.no_grad():  # inference only; don't build an autograd graph
            scores_all = rank_model(**inputs).logits.cpu()
        scores_all = scores_all.reshape((n_max, curr_batch_size))
        print('Rewards done.')

        # get response pairs for each n
        for k in n_seq:
            scores_sub = scores_all[:k, :]
            max_indices = scores_sub.argmax(dim=0)
            min_indices = scores_sub.argmin(dim=0)
            # [best response, worst response] for each query in the batch
            response_bw_pairs = [
                [responses[i + curr_batch_size * int(max_indices[i])],
                 responses[i + curr_batch_size * int(min_indices[i])]]
                for i in range(curr_batch_size)
            ]

            for i, p in enumerate(query_batch_list):
                entry = responses_best_and_worst[k][p]
                n_responses = len(entry['responses'])
                # (chosen, rejected) indices into entry['responses']
                entry['pairs'].append((n_responses, n_responses + 1))
                entry['responses'].extend(response_bw_pairs[i])
                # The SFT target is the preferred response, i.e. the best one.
                entry['sft_target'] = response_bw_pairs[i][0]

        torch.cuda.empty_cache()
        gc.collect()
    return responses_best_and_worst


if __name__ == '__main__':
    from argparse import ArgumentParser
    import gzip
    import json
    import os

    import pandas as pd

    # Values of n for best-of-n; one output file is written per value.
    N_SEQ = [3, 6, 8]

    parser = ArgumentParser()
    parser.add_argument('--maxlen', type=int, default=30, help='The max number of new tokens allowed')
    parser.add_argument('--start', type=int, default=0, help='start index of the large dataset')
    parser.add_argument('--end', type=int, default=0, help='end index of the large dataset')
    parser.add_argument('--temperature', type=float, default=0.5, help='The sampling temperature for best of n')
    parser.add_argument('--batch_size', type=int, default=4, help='The sampling batch size')

    # Both paths are needed below; without them the script could only fail
    # later with an obscure error, so require them up front.
    parser.add_argument('--prompt_set', type=str, required=True, help='The path of the prompt set')
    parser.add_argument('--save_data_path', type=str, required=True, help='The path to save the best and worst data')

    args = parser.parse_args()

    # load the dataset
    df = pd.read_csv(args.prompt_set)

    # Sort prompts by length so batches hold similarly sized prompts, then take
    # the [start, end) slice. --end 0 means "up to the end of the dataset";
    # --start applies in either case.
    prompts = df['Prompt'].tolist()
    prompts.sort(key=len)
    prompts = prompts[args.start:(args.end or None)]

    # run the experiment
    # NOTE(review): --maxlen's help says "max number of new tokens", but
    # generation is given maxlen + 30; confirm the extra 30 is intended.
    best_worst_ds = best_and_worst_of_n_sampler(prompts, batch_size=args.batch_size,
                                                max_len=args.maxlen + 30,
                                                temperature=args.temperature,
                                                n_seq=N_SEQ)

    # Save one gzipped JSON-lines file per n. Each line is {prompt: record}.
    for pn in N_SEQ:
        filename = f'/best-of-{pn}/bon_len_{args.maxlen}_temperature_{args.temperature}'
        if args.start != 0 or args.end != 0:
            end_label = args.end - 1 if args.end != 0 else 'end'
            filename += f'_from_{args.start}_to_{end_label}'
        filename = filename.replace('.', '_') + '.jsonl.gz'

        # best and worst
        dict_to_save = dict(best_worst_ds[pn])
        full_filename = args.save_data_path + filename
        os.makedirs(os.path.dirname(full_filename), exist_ok=True)
        with gzip.open(full_filename, 'wt', encoding='utf-8') as f:
            for key, value in dict_to_save.items():
                json_record = json.dumps({key: value})
                f.write(json_record + '\n')