import torch
import json
import time
from tqdm import tqdm
from pprint import pprint
from torch.multiprocessing import Pool, Process, set_start_method, current_process, freeze_support
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
from csd import csd, csd_wide
from model import CSDraftingEncoderDecoderModel, CSDraftingMaGModel, get_mag_model_wide, \
CSDraftingDecoderModelKVCacheWide, CountedCSDraftingCachedEncoderDecoderModel, \
get_mag_model, DummyModel, CSDraftingDecoderModel, CountedCSDraftingDecoderModel, CountedCSDraftingCachedDecoderModel, CountedCSDraftingDecoderModelKVCache
from csd_datasets import get_test_set, format_vicuna_input



def start_from_config(config):
    """Load a seq2seq draft model for every name in ``config['draft_names']``.

    Args:
        config: Mapping with a ``'draft_names'`` list of model identifiers
            accepted by ``AutoModelForSeq2SeqLM.from_pretrained``.

    Returns:
        The loaded models, in ``config['draft_names']`` order. Previously the
        list was built and then discarded (the function returned ``None``).
    """
    draft_model_list = []
    for draft_name in config['draft_names']:
        draft_model_list.append(AutoModelForSeq2SeqLM.from_pretrained(draft_name))
    return draft_model_list



def get_total_time_cost(draft_list, target_model, print_count=False):
    """Optionally print model call counts and return the aggregated time cost.

    Args:
        draft_list: Draft models. When printing, entries lacking a ``name`` or
            ``forward_count`` attribute are skipped.
        target_model: Target model; must expose ``forward_count`` when
            ``print_count`` is true.
        print_count: When true, print the forward-call counter of the target
            model and of every draft model that has one.

    Returns:
        Always ``0``: the per-model ``calculate_time_cost`` aggregation below
        is commented out, so no cost is accumulated.
    """
    if print_count:
        print('target count: ')
        print('Target model call: {}'.format(target_model.forward_count))
        for draft_model in draft_list:
            try:
                print('Draft model {} call: {}'.format(draft_model.name, draft_model.forward_count))
            except AttributeError:
                # Not every draft model carries a name/counter; skip those.
                # Other errors (and KeyboardInterrupt) are no longer swallowed.
                pass
    total_cost = 0
    # for model in draft_list:
    #     if not isinstance(model, CSDraftingMaGModel):
    #         total_cost += model.calculate_time_cost()
    # total_cost += target_model.calculate_time_cost()
    return total_cost



def get_stats(draft_list, target_model, print_count=False):
    """Read and reset the forward-call counters of the target and draft models.

    Args:
        draft_list: Draft models; entries without a ``forward_count``
            attribute are skipped and contribute nothing to the result.
        target_model: Model exposing a ``forward_count`` attribute.
        print_count: Unused; kept for signature parity with
            ``get_total_time_cost``.

    Returns:
        ``(target_count, draft_count)``: the target's counter value before the
        reset, and a list of the counters of the draft models that have one,
        in ``draft_list`` order. Every counter read is reset to ``0``.
    """
    target_count = target_model.forward_count
    target_model.forward_count = 0
    draft_count = []
    for draft_model in draft_list:
        try:
            draft_count.append(draft_model.forward_count)
            draft_model.forward_count = 0
        except AttributeError:
            # Models without a counter are skipped; other errors propagate.
            pass
    # for model in draft_list:
    #     if not isinstance(model, CSDraftingMaGModel):
    #         total_cost += model.calculate_time_cost()
    # total_cost += target_model.calculate_time_cost()
    return target_count, draft_count




def format_test(test, dataset_name):
    """Tokenize every item of ``test`` into a prompt tensor for the target model.

    Each item is rendered with ``format_vicuna_input(item, dataset_name)`` and
    tokenized with the module-level ``tokenizer`` (truncation on, no padding,
    PyTorch tensors). Only ``'input_ids'`` is kept, moved to the module-level
    ``target_model.device``.

    Returns:
        A list with one ``input_ids`` tensor per item, in input order
        (presumably shaped (1, seq_len) — confirm against the tokenizer).
    """
    return [
        tokenizer(
            format_vicuna_input(item, dataset_name),
            truncation=True,
            padding=False,
            return_tensors="pt",
        )['input_ids'].to(target_model.device)
        for item in test
    ]



def recursive_parameter_construction(all_params_set):
    """Return the Cartesian product of a sequence of option lists.

    Each combination is a list holding one element from every option list, in
    the same order as ``all_params_set``. The first option list varies
    slowest. An empty ``all_params_set`` yields ``[[]]``, and any empty option
    list yields ``[]``.
    """
    combos = [[]]
    # Build from the last option list backwards, prepending one choice at a
    # time, so the outer iteration is over the earlier list's choices.
    for options in reversed(all_params_set):
        combos = [[choice] + tail for choice in options for tail in combos]
    return combos



# Benchmark configuration.
# The k-matrix parameterization differs slightly from the paper: the paper's
# k_{12} corresponds to k[1, 1] + k[1, 2] here (indices look 1-based — confirm).
# Candidate draft model names: eqhylxx/full-vicuna-160m, double7/vicuna-68m
# NOTE(review): 'k_matrix' and 'candidate_number' are overridden by the
# parameter sweep at the end of the script; 'lenience' and 'counter_version'
# are only referenced from commented-out code.
config = {
    'draft_names': ['double7/vicuna-68m'],
    'target_name': 'llama_7b',
    'is_decoder_only': True,
    'use_mag': True,
    'k_matrix': [[2, 4], [0, 2]],
    'lenience': 3,
    'candidate_number': 15,
    'dataset': 'sampled_mmlu',
    'counter_version': 'model_parameters',
    'sample': False
}

# Module-level placeholders; rebound once the models are loaded further below.
draft_list = []
target_model = None
tokenizer = None
dataset_name = config['dataset']



def aggregate_result(generation_result_for_all_params):
    """Return the arithmetic mean of a sized collection of time costs.

    Raises:
        ZeroDivisionError: if the collection is empty.
    """
    return sum(generation_result_for_all_params) / len(generation_result_for_all_params)


def aggregate_result_and_print(generation_result_for_all_params):
    """Average time costs per key.

    Args:
        generation_result_for_all_params: Iterable of ``(key, time_cost)``
            pairs; keys must be hashable.

    Returns:
        Dict mapping each key to the mean of its time costs, with keys in
        first-seen order. Despite the name, nothing is printed.
    """
    grouped = {}
    for key, time_cost in generation_result_for_all_params:
        grouped.setdefault(key, []).append(time_cost)
    return {key: sum(costs) / len(costs) for key, costs in grouped.items()}


# NOTE(review): this progress bar is created (and drawn) at import time but is
# never advanced or closed anywhere in the visible script — candidate for removal.
pbar = tqdm(total=100000)
# NOTE(review): _CACHE_DIR is not referenced anywhere in the visible script.
_CACHE_DIR = './cache/'


import re
from datasets import load_dataset
from nltk.tokenize import word_tokenize
import torch
import json
from tqdm import tqdm

from torch.utils.data import Dataset, DataLoader


# validate_set = data_set['validation']

# Root of a local LLaMA checkout; a placeholder path to be set per machine.
LLAMA_PATH = '/scratch/your_dir/llama/llama/'


# Local HF-format weights (presumably 7B chat, per the folder name). Only the
# commented-out loading code below refers to it.
LLAMA_HF_PATH = LLAMA_PATH + 'hf_7b_chat'
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaModel


# Hard-coded GPU on which both the target and the draft models are placed.
device = torch.device('cuda:2')


# Commented-out shell commands kept from an earlier checkpoint setup:
# export MODEL_REPO=double7/vicuna-68m
# ./scripts/prepare.sh $MODEL_REPO


# export MODEL_REPO=lmsys/vicuna-13b-v1.5
# ./scripts/prepare.sh $MODEL_REPO


# tokenizer = LlamaTokenizer.from_pretrained(LLAMA_HF_PATH)
# hf_model = LlamaForCausalLM.from_pretrained(LLAMA_HF_PATH)


# Target model: vicuna-7b-v1.5, cast to fp16 and placed on `device`.
# NOTE(review): the tokenizer comes from the 13B repo while the model is the
# 7B one; this relies on both repos shipping the same Llama tokenizer
# (vocab_size=32000 is hard-coded below) — confirm.
tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-13b-v1.5')
target_hf_model = LlamaForCausalLM.from_pretrained('lmsys/vicuna-7b-v1.5')
# Cast once, before moving to the GPU. A second .half() after .cuda() was a
# no-op on an already-fp16 model and has been dropped.
target_hf_model.half()
target_hf_model.cuda(device)


# hf_model = hf_model.to(torch.float)
target_model = CSDraftingDecoderModelKVCacheWide(target_hf_model, sample=config['sample'], name='llama', vocab_size=32000, is_final_target_model=True)


# Load every configured draft model and wrap it in
# CSDraftingDecoderModelKVCacheWide. Draft models not moved to the GPU here
# are moved by the device-placement loop further below.
draft_list = []
for draft_name in config['draft_names']:
    hf_model = AutoModelForCausalLM.from_pretrained(draft_name)
    if draft_name == 'Felladrin/Llama-68M-Chat-v1':
        # Replace the pretrained weights with a local checkpoint. The folder
        # name suggests distillation from vicuna-13b — unverified.
        # NOTE(review): torch.load unpickles the file; only use trusted paths.
        folder = '60m_vicuna_13b_kd_t1/'
        load_dir = '/scratch/your_dir/checkpoints/' + folder
        load_path = load_dir + 'checkpoint_best.pt'
        hf_model.load_state_dict(torch.load(load_path)['model_state_dict'])
        hf_model.cuda(device)
        hf_model.eval()
        print('loaded')
    hf_model.eval()
    # model = CountedCSDraftingDecoderModel(hf_model, sample=config['sample'], name=draft_name, vocab_size=32000, counter_version=config['counter_version'])
    # NOTE(review): unlike the target model, config['sample'] is not passed here.
    model = CSDraftingDecoderModelKVCacheWide(hf_model, name=draft_name, vocab_size=32000, is_final_target_model=False)
    draft_list.append(model)

# target_model = CSDraftingDecoderModel(hf_model, name='llama', vocab_size=32000)
# tokenizer = AutoTokenizer.from_pretrained(draft_name)
# Same value as the earlier assignment from config['dataset'].
dataset_name = config['dataset']

# Optionally append the MaG drafter as the last entry of draft_list. It is
# built from a precomputed top-k JSON table (a bigram table, going by the
# file name — confirm).
if config['use_mag']:
    # _BIGRAM_DIR =  '/home/your_dir/'
    topk_dir_llama = '/scratch/your_dir/wiki_bigram_naive_bayers_llama_topk.json'
    # if 't5' in config['target_name'].lower():
    #     bi_gram_path = _BIGRAM_DIR + 'wiki_bigram_naive_bayers_greedy_next_token.json'
    # else:
    #     bi_gram_path = _BIGRAM_DIR + 'wiki_bigram_naive_bayers_greedy_llama_next_token.json'
    mag_model = get_mag_model_wide(topk_dir_llama, config['is_decoder_only'])
    draft_list.append(mag_model)


# Make sure every wrapped model (drafts first, then the target) sits on
# `device`. NOTE(review): relies on the wrapper classes from model.py exposing
# `.device` and `.cuda(device)` — not visible here.
for draft_model in draft_list:
    if draft_model.device != device:
        draft_model.cuda(device)

if target_model.device != device:
    target_model.cuda(device)



def true_work(package):
    """Run one ``csd_wide`` generation for a single prompt and parameter set.

    Reads the module-level ``draft_list`` and ``target_model``.

    Args:
        package: ``(initial_input, k_matrix, candidate_num)``.
            ``initial_input`` is a tokenized prompt tensor, moved to the
            target model's device here. ``k_matrix`` is a nested list,
            converted to a tensor here. ``candidate_num`` is forwarded to
            ``csd_wide`` unchanged.

    Returns:
        ``(None, result_ids, total_token_generated)``. ``result_ids`` is what
        ``csd_wide`` returns. ``total_token_generated`` is how much longer
        ``result_ids`` is than the prompt along dim 1. The caller unpacks the
        leading ``None`` as its ``text_input``.
    """
    initial_input, k_matrix, candidate_num = package
    k_matrix = torch.tensor(k_matrix)
    input_ids = initial_input.to(target_model.device)
    result_ids = csd_wide(draft_list, target_model, input_ids, k_matrix, candidate_num)
    # assumes both tensors are (batch, seq_len) — TODO confirm in csd_wide
    total_token_generated = result_ids.shape[1] - input_ids.shape[1]
    return None, result_ids, total_token_generated

# The benchmark prompts are loaded just below (GSM8K -> `test`, sampled MMLU ->
# `test2`) and tokenized by the parameter sweep that follows. Reading
# `dataset['test']` at this point, before `dataset` is assigned, raised a
# NameError, so no prompt loading happens here.



# Load the benchmark prompts: GSM8K test split -> `test`, sampled MMLU -> `test2`.
dataset_name = 'gsm8k'

full_gsm8k_dir = '/home/your_dir/gsm8k_dataset.json'
# JSON text is UTF-8; do not depend on the platform's locale encoding.
with open(full_gsm8k_dir, encoding='utf-8') as f:
    dataset = json.load(f)

# test = get_test_set(dataset_name)
test = dataset['test']

sample_file = '/home/your_dir/sampled_mmlu.json'
with open(sample_file, encoding='utf-8') as f:
    res = json.load(f)

test2 = res
dataset_name = 'sampled_mmlu'

print(len(test))
print(len(test2))
# Parameter sweep. For every (candidate_number, k3) pair, run csd_wide on the
# first 20 prompts of each dataset and record throughput.
# k_matrix = [[k1, k3], [0, k2]]; k1 and k2 are held fixed and k3 is swept.
k1 = 1
k2 = 1

dataset_names = ['gsm8k', 'sampled_mmlu']
result_in_str = []
result_simple = []

# Tokenize each dataset once, outside the sweep. format_test maps items one
# to one and does not depend on the sweep parameters, so this gives the same
# first-20 prompts as re-tokenizing the whole dataset on each of the 72
# sweep iterations.
# NOTE(review): the same prompt tensors are now reused across sweep points;
# this assumes csd_wide does not modify its input_ids in place — confirm.
formatted_tests = [
    format_test(raw_test, name)[:20]
    for raw_test, name in zip([test, test2], dataset_names)
]

for candidate_number in range(2, 16, 4):
    for k3 in range(1, 10):
        for to_go_test, dataset_name in zip(formatted_tests, dataset_names):
            k_matrix = [[k1, k3], [0, k2]]
            cur_params = [[k_matrix, candidate_number]]
            # One package per prompt: [input_ids, k_matrix, candidate_number].
            to_go = recursive_parameter_construction([to_go_test, cur_params])
            to_go = [[p[0]] + p[1] for p in to_go]
            start = time.perf_counter()
            all_res = []
            all_token_generated = 0
            for package in tqdm(to_go):
                text_input, result_ids, total_token_generated = true_work(package)
                all_token_generated += total_token_generated
                all_res.append((text_input, result_ids))
            # CUDA kernels run asynchronously; wait for them before stopping
            # the clock so the measured time covers all queued GPU work.
            torch.cuda.synchronize(device)
            end = time.perf_counter()
            elapsed = end - start
            summary = [
                'Draft names:' + str(config['draft_names']),
                'Dataset: {}'.format(dataset_name),
                'K: {}'.format(str(k_matrix)),
                'Candidate number: {}'.format(candidate_number),
                'Time cost: {}'.format(elapsed),
                'Total token generated: {}'.format(all_token_generated),
                'Token/s: {}'.format(all_token_generated / elapsed),
            ]
            for line in summary:
                print(line)
            result_in_str.extend(summary)
            result_simple.append((
                all_token_generated / elapsed,
                elapsed,
                k_matrix,
                candidate_number,
                dataset_name,
            ))




# Dump every per-run summary line, then the compact result tuples ordered
# from lowest to highest throughput (tokens/s is element 0 of each tuple).
from pprint import pprint
pprint(result_in_str)
sorted_res = sorted(result_simple, key=lambda row: row[0])
pprint(sorted_res)

# Dataset: sampled_mmlu
# K: [[2, 4], [0, 2]]
# Candidate number: 2
# Time cost: 175.46371579170227
# Total token generated: 3445
# Token/s: 19.63368884818131





#  (59.5462779087672, 57.165621757507324, [[1, 4], [0, 2]], 5, 'sampled_mmlu'),
#  (59.74503375561485, 64.84220957756042, [[2, 4], [0, 1]], 13, 'gsm8k'),
#  (59.82762663798348, 64.7025499343872, [[2, 4], [0, 1]], 12, 'gsm8k'),
#  (59.853279067140804, 64.72494173049927, [[1, 4], [0, 3]], 15, 'gsm8k'),
#  (59.89597388200442, 64.66210913658142, [[2, 4], [0, 1]], 14, 'gsm8k'),
#  (60.09207592521725, 56.71296834945679, [[1, 4], [0, 1]], 5, 'sampled_mmlu'),
#  (60.16382931850191, 64.22463536262512, [[1, 4], [0, 2]], 13, 'gsm8k'),
#  (60.27580842006757, 64.10532021522522, [[1, 4], [0, 1]], 4, 'gsm8k'),
#  (60.293734843571244, 64.25211524963379, [[1, 4], [0, 3]], 14, 'gsm8k'),
#  (60.354235747433144, 64.05515623092651, [[1, 4], [0, 1]], 5, 'gsm8k'),
#  (60.408905215607, 63.8813099861145, [[1, 4], [0, 1]], 12, 'gsm8k'),
#  (60.49562470836264, 63.98809862136841, [[1, 4], [0, 3]], 12, 'gsm8k'),
#  (60.84034061225379, 63.51049256324768, [[1, 4], [0, 2]], 15, 'gsm8k'),
#  (60.86481729679431, 63.46852207183838, [[1, 4], [0, 2]], 12, 'gsm8k'),
#  (60.88788375317546, 63.37878346443176, [[1, 4], [0, 1]], 15, 'gsm8k'),
#  (60.92568321436202, 63.35587549209595, [[1, 4], [0, 2]], 5, 'gsm8k'),
#  (60.94829419328539, 63.31596398353577, [[1, 4], [0, 2]], 4, 'gsm8k'),
#  (61.161266565508654, 63.27534103393555, [[1, 4], [0, 3]], 10, 'gsm8k'),
#  (61.18152739624084, 63.10728359222412, [[1, 4], [0, 1]], 14, 'gsm8k'),
#  (61.3693441130961, 62.930442810058594, [[1, 4], [0, 2]], 14, 'gsm8k')]













#  (61.81937507796385, 62.43997120857239, [[1, 4], [0, 1]], 14, 'gsm8k'),
#  (62.17069123975242, 54.76857233047485, [[1, 2], [0, 1]], 6, 'sampled_mmlu'),
#  (62.56180242164191, 54.458149671554565, [[1, 2], [0, 1]], 10, 'sampled_mmlu'),
#  (62.71827329853568, 54.43389654159546, [[1, 3], [0, 1]], 10, 'sampled_mmlu'),
#  (63.10432248777297, 61.23193860054016, [[1, 3], [0, 1]], 6, 'gsm8k'),
#  (63.278978824704986, 60.652053356170654, [[1, 1], [0, 1]], 10, 'gsm8k'),
#  (63.41201154182254, 60.871748208999634, [[1, 4], [0, 1]], 10, 'gsm8k'),
#  (63.91781240006978, 59.99892449378967, [[1, 1], [0, 1]], 6, 'gsm8k'),
#  (64.47325670254494, 59.76121258735657, [[1, 2], [0, 1]], 6, 'gsm8k'),
#  (64.53083915900253, 59.83185791969299, [[1, 2], [0, 1]], 14, 'gsm8k'),
#  (65.14970553022161, 58.9565212726593, [[1, 1], [0, 1]], 14, 'gsm8k'),
#  (65.35293962016297, 59.079209327697754, [[1, 3], [0, 1]], 10, 'gsm8k'),
#  (65.71680835076067, 58.46297311782837, [[1, 2], [0, 1]], 10, 'gsm8k')]