# This code is adapted from the repository: https://github.com/facebookresearch/three_bricks
# The original code is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License.

"""
python main_reed_wm.py \
    --model_name "Llama-2-7b-chat-hf" \
    --dataset_path2 "data/other_maryland.jsonl" \
    --method_detect maryland \
    --degree 0 \
    --prop 1 \
    --nsamples 1000 \
    --batch_size 16 \
    --output_dir output/
"""

import random
import argparse
import os
import time
import json

import tqdm
import pandas as pd
import numpy as np

import torch
from utils import HiddenPrints 
with HiddenPrints():
    from peft import PeftModel    
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from sentence_transformers import SentenceTransformer

from wm import (WmGenerator, OpenaiGenerator, MarylandGenerator, StanfordGenerator,
                WmDetector,  OpenaiDetector, MarylandDetector, StanfordDetector, 
                MarylandDetectorZ, OpenaiDetectorZ)
from wm.paths import model_names, adapters_names
import utils


def get_args_parser():
    """Build the command-line parser for the detection script.

    The parser is created with ``add_help=False``. Options are declared in a
    single table (flag, keyword arguments for ``add_argument``) and registered
    in the order they are listed.

    Returns:
        argparse.ArgumentParser: parser exposing model, dataset, watermark,
        experiment and distributed options.
    """
    option_table = [
        # model parameters
        ('--model_name', dict(type=str)),
        ('--adapters_name', dict(type=str)),

        # prompts parameters
        ('--dataset_path', dict(type=str, default="")),
        ('--dataset_path2', dict(type=str, default=None)),
        ('--degree', dict(type=float, default=1)),
        ('--prop', dict(type=float, default=1)),
        ('--fake_seed', dict(type=int, default=0)),

        # generation parameters
        ('--seq_length', dict(type=int, default=256)),

        # watermark parameters
        ('--seed', dict(type=int, default=0)),
        ('--seeding', dict(type=str, default='hash', help='seeding method for rng key generation as introduced in https://github.com/jwkirchenbauer/lm-watermarking')),
        ('--hash_key', dict(type=int, default=35317, help='hash key for rng key generation')),
        ('--method_detect', dict(type=str, default='openai', help='Statistical test to detect watermark. Choose between: same (same as method), openai, openaiz, maryland, marylandz')),
        ('--ngram', dict(type=int, default=2, help='n-gram size for rng key generation')),
        ('--topk', dict(type=int, default=1, help='top-logits to look at')),
        ('--gamma', dict(type=float, default=0.25, help='gamma for maryland: proportion of greenlist tokens')),
        ('--delta', dict(type=float, default=2.0, help='delta for maryland: bias to add to greenlist tokens')),
        ('--v1', dict(type=int, default=0, help='wheter to use deduplication of watermarked windows + current token instead of just watermarked window')),
        ('--no_dedup', dict(type=int, default=0, help='SETTING TO 1 will make results false!')),

        # expe parameters
        ('--nsamples', dict(type=int, default=100000)),
        ('--batch_size', dict(type=int, default=16)),
        ('--do_eval', dict(type=utils.bool_inst, default=True)),
        ('--output_dir', dict(type=str, default='output')),
        ('--split', dict(type=int, default=None)),
        ('--nsplits', dict(type=int, default=None)),

        # distributed parameters
        ('--ngpus', dict(type=int, default=None)),
    ]

    parser = argparse.ArgumentParser('Args', add_help=False)
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser


import json
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
class TextDataset(Dataset):
    """Fixed-length token chunks built from one JSONL file.

    Every line of ``filename`` is parsed as a JSON object, the parsed lines are
    shuffled with ``random.Random(seed)``, and the text of the lines whose
    shuffled position ``i`` satisfies ``indicies_lines[0] <= i < indicies_lines[1]``
    is tokenized. All tokens are concatenated into one stream that is served in
    consecutive, non-overlapping chunks of ``seq_length`` tokens; a trailing
    partial chunk is never returned.

    Args:
        filename: path to the JSONL file.
        tokenizer: object providing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``.
        seed: seed for the shuffle of the lines.
        seq_length: number of tokens per item.
        indicies_lines: ``(start, stop)`` positions (after shuffling) of the
            lines to keep. ``None`` keeps every line (a warning is printed).
    """

    def __init__(self, filename, tokenizer, seed, seq_length=256, indicies_lines=None):
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self.texts = []
        self.indicies_lines = indicies_lines
        if self.indicies_lines is None:
            print("carreful: indicies_lines is None")
        with open(filename, 'r') as file:
            # Each line is decoded exactly once, here; the entries of `lines`
            # are already dicts and must not be passed to json.loads again.
            lines = [json.loads(line) for line in file]
        random.Random(seed).shuffle(lines)

        # Without explicit bounds, fall back to the whole file instead of
        # failing when indexing None.
        if self.indicies_lines is None:
            first, last = 0, len(lines)
        else:
            first, last = self.indicies_lines[0], self.indicies_lines[1]

        for i, data in enumerate(lines):
            if first <= i < last:
                tokens = tokenizer.tokenize(self._text_of(data))
                self.texts.extend(tokens)

    @staticmethod
    def _text_of(data):
        """Return the text of a record: 'output', else 'result', else 'text'."""
        if "output" in data:
            return data['output']
        if 'result' in data:
            return data['result']
        return data['text']

    def __len__(self):
        # Number of complete chunks available in the token stream.
        return len(self.texts) // self.seq_length

    def __getitem__(self, idx):
        start = idx * self.seq_length
        end = (idx + 1) * self.seq_length
        tokens = self.texts[start:end]
        return self.tokenizer.convert_tokens_to_ids(tokens)

class TwoTextDatasets(Dataset):
    """
    Dataset for two text files
    Used to perform radioactivity detection on a dataset made of two text files, one of which was used at training time and one of which was not.
    It simulates the degree of supervision in the paper.

    Construction steps:
      1. Read ``filename1`` (JSONL) and keep its first ``int(prop * n)`` lines.
      2. Read ``filename2`` (JSONL).
         - ``degree == 0``: only the lines of ``filename2`` are used.
         - otherwise: ``int((1/degree - 1) * n_kept)`` lines of ``filename2``
           are appended to the kept lines of ``filename1``; if ``filename2`` is
           too short its lines are repeated until there are enough.
      3. Shuffle the selected lines with ``random.Random(seed)``.
      4. If ``split = (num_split, nb_split)`` is given, keep only that
         contiguous share of the shuffled lines.
      5. Tokenize the text of each kept line ('output', else 'result', else
         'text') and concatenate all tokens into one stream.

    Items are consecutive, non-overlapping chunks of ``seq_length`` tokens,
    returned as token ids; a trailing partial chunk is never returned.

    Raises:
        ValueError: if lines from ``filename2`` are required but the file
            contains none.
    """

    def __init__(self, filename1, filename2, degree, tokenizer, seed, seq_length=256, split=None, prop=1):
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self.texts = []
        self.prop = prop

        with open(filename1, 'r') as file:
            lines1 = [json.loads(line) for line in file]
        nb_train = len(lines1)
        print("nb_trains", nb_train)
        lines1_prop = [lines1[i] for i in range(int(prop * nb_train))]
        print("nb_trains_after_prop", len(lines1_prop))
        nb_train = len(lines1_prop)

        with open(filename2, 'r') as file:
            lines2 = [json.loads(line) for line in file]

        if degree == 0:
            lines = lines2
        else:
            # Number of held-out lines needed so that the kept lines of
            # filename1 make up a fraction `degree` of the mix.
            needed = (1 / degree - 1) * nb_train
            if needed > 0 and not lines2:
                # Repeating an empty list can never reach `needed`.
                raise ValueError(
                    f"{filename2} contains no lines but {int(needed)} are required for degree={degree}"
                )
            pool = list(lines2)
            while len(pool) < needed:
                print("be careful, not enough validation samples. we use deduplication")
                pool += lines2
            # The loop guarantees len(pool) >= needed, which is all the slice
            # below requires (having exactly `needed` lines is valid).
            lines = lines1_prop + pool[:int(needed)]

        random.Random(seed).shuffle(lines)
        nb_lines = len(lines)
        print("nb_lines_after", nb_lines)

        if split is not None:
            (num_split, nb_split) = split
            left = nb_lines * num_split // nb_split
            # The last split absorbs the remainder.
            right = nb_lines * (num_split + 1) // nb_split if (num_split != nb_split - 1) else nb_lines
            print(f"Creating prompts from {left} to {right}")
        else:
            (left, right) = (0, nb_lines)
            print(f"Creating prompts from {left} to {right} (all prompts)")

        for i, data in enumerate(lines):
            if left <= i < right:
                tokens = tokenizer.tokenize(self._text_of(data))
                self.texts.extend(tokens)
        # Note: the number printed here is the total token count.
        print("double dataset initiated. Nsamples: ", len(self.texts))

    @staticmethod
    def _text_of(data):
        """Return the text of a record: 'output', else 'result', else 'text'."""
        if "output" in data:
            return data['output']
        if 'result' in data:
            return data['result']
        return data['text']

    def __len__(self):
        # Number of complete chunks available in the token stream.
        return len(self.texts) // self.seq_length

    def __getitem__(self, idx):
        start = idx * self.seq_length
        end = (idx + 1) * self.seq_length
        tokens = self.texts[start:end]
        return self.tokenizer.convert_tokens_to_ids(tokens)


def _score_candidate(detector, method, window, token):
    """Score one candidate token under the watermark keyed by ``window``.

    The detector's generator is re-seeded from the window through
    ``detector.get_seed_rng`` before any random draw.

    Args:
        detector: detector exposing ``get_seed_rng``, ``rng``, ``vocab_size``
            (and ``gamma`` for the maryland method).
        method: "openai" or "maryland".
        window: list of token ids forming the watermark window.
        token: candidate token id to score.

    Returns:
        tuple: ``(score, raw)``. For "openai", ``raw`` is the uniform draw
        ``rs[token]`` and ``score`` is ``-log(1 - rs)[token]``. For "maryland",
        both are 1 if ``token`` is in the greenlist and 0 otherwise.
    """
    seed = detector.get_seed_rng(window)
    detector.rng.manual_seed(seed)
    if method == "openai":
        rs = torch.rand(detector.vocab_size, generator=detector.rng)  # n
        score = (-(1 - rs).log()[token]).item()
        return score, rs[token].item()
    # maryland
    vocab_permutation = torch.randperm(detector.vocab_size, generator=detector.rng)
    greenlist = vocab_permutation[:int(detector.gamma * detector.vocab_size)]  # gamma * n are in greenlist
    score = 1 if token in greenlist else 0
    return score, score


def main(args):
    """Run watermark detection on the predictions of a suspected model.

    The model is fed fixed-length token chunks from the dataset(s). For every
    position, the watermark window taken from the input is paired with each of
    the model's top-k predicted next tokens, and each new (window, token) key
    is scored once with the selected test. After every batch, one JSON line
    with the cumulative statistics is appended to
    ``<output_dir>/results_kgrams.jsonl``.

    Args:
        args: namespace produced by ``get_args_parser().parse_args()``.

    Raises:
        ValueError: if ``args.method_detect`` is not "openai" or "maryland",
            if no dataset path is usable, or if ``--dataset_path`` is omitted
            while ``--degree`` is not 0.
    """
    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Only these two tests have a scoring rule in _score_candidate. Validate
    # explicitly (an assert would vanish under `python -O`) and before the
    # expensive model load.
    if args.method_detect not in ("openai", "maryland"):
        raise ValueError(
            f"Unsupported --method_detect {args.method_detect!r}: expected 'openai' or 'maryland'"
        )

    # build model
    # Llama-2-7b-chat-hf
    model_name = args.model_name.lower()
    model_name = model_names[model_name] if model_name in model_names else model_name
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    args.ngpus = torch.cuda.device_count() if args.ngpus is None else args.ngpus
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        max_memory={i: '32000MB' for i in range(args.ngpus)}, # automatically handles the number of gpus
        offload_folder="offload",
    ).cuda()
    # An adapter given on the command line takes precedence over the one
    # registered for the model in wm.paths.
    if args.adapters_name is None:
        adapters_name = adapters_names[model_name] if model_name in adapters_names else None
    else:
        adapters_name = args.adapters_name
    if adapters_name is not None:
        print(f"Loading adapter {adapters_name}")
        model = PeftModel.from_pretrained(model, adapters_name)
    model = model.cuda()
    model = model.eval()
    for param in model.parameters():
        param.requires_grad = False
    print(f"Using {args.ngpus}/{torch.cuda.device_count()} GPUs - {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated per GPU")

    # load wm detector
    if args.method_detect == "openai":
        detector = OpenaiDetector(tokenizer, args.ngram, args.seed, args.seeding, args.hash_key)
    else:
        detector = MarylandDetector(tokenizer, args.ngram, args.seed, args.seeding, args.hash_key, gamma=args.gamma, delta=args.delta)

    # do splits
    if args.dataset_path == "":
        # With no training file, only the held-out file can be analysed.
        if args.degree != 0:
            raise ValueError("--dataset_path may only be omitted when --degree is 0")
        args.dataset_path = args.dataset_path2
    if not args.dataset_path:
        raise ValueError("No dataset given: set --dataset_path and/or --dataset_path2")
    with open(args.dataset_path, "r") as f:
        nprompts = sum(1 for _ in f)
    if args.split is not None:
        left = nprompts * args.split // args.nsplits
        right = nprompts * (args.split + 1) // args.nsplits if (args.split != args.nsplits - 1) else nprompts
        print(f"Creating prompts from {left} to {right}")
    else:
        (left, right) = (0, nprompts)
        print(f"Creating prompts from {left} to {right} (all prompts)")

    # (re)start experiment
    os.makedirs(args.output_dir, exist_ok=True)
    results_path = os.path.join(args.output_dir, "results_kgrams.jsonl")
    # NOTE(review): the count of existing result lines is only reported; no
    # batch is skipped and new lines are appended after the existing ones.
    start_point = 0
    if os.path.exists(results_path):
        with open(results_path, "r") as f:
            start_point = sum(1 for _ in f)
    print(f"Starting from {start_point}")

    if args.dataset_path2 is None:
        dataset = TextDataset(args.dataset_path, tokenizer, args.fake_seed, seq_length=args.seq_length, indicies_lines=(left, right))
    else:
        dataset = TwoTextDatasets(args.dataset_path,args.dataset_path2, args.degree, tokenizer, args.fake_seed, seq_length=args.seq_length, split=(args.split, args.nsplits) if args.split is not None else None, prop=args.prop)
    dataloader = DataLoader(dataset, batch_size=args.batch_size)

    tokenizer.pad_token_id = 0
    num_iter = args.nsamples//args.batch_size
    dic = {}    # key -> test score, accumulated over all batches
    dic_r = {}  # key -> raw value (uniform draw for openai, same 0/1 score for maryland)
    with open(results_path, "a") as f:
        for i, x in enumerate(dataloader):
            if i %10 == 0:
                print(f"iteraton: {i}")
            # NOTE(review): with `>` the loop handles batches 0..num_iter,
            # i.e. num_iter + 1 batches; confirm this bound is intended.
            if i >num_iter:
                break
            with torch.inference_mode():
                # The dataloader yields one tensor per position; stacking on
                # dim 1 gives one row of token ids per sample.
                x = torch.stack(x, 1).cuda()
                fwd = model(x)
                logits = fwd[0][:,:-1,:]
                target = x[:, 1:]
                _, outputs = torch.topk(logits,args.topk, dim=-1)
                k = args.ngram
                # Move ids to Python lists once per batch rather than once per
                # position.
                target_ids = target.tolist()
                predicted_ids = outputs.tolist()
                # Watermark windows already seen for each sample of this batch.
                seen_windows = [set() for _ in target_ids]
                for j in range(target.size()[-1]-k): # deduplication is done at the batch level
                    for row, (row_targets, row_preds) in enumerate(zip(target_ids, predicted_ids)):
                        # Watermark window from the input + each token
                        # predicted by the suspected model right after it.
                        window = row_targets[j:j+k]
                        window_key = str(window)
                        candidates = row_preds[j+k]
                        last = len(candidates) - 1
                        for t, token in enumerate(candidates):
                            # This is where dedup happens: we check that the watermark window is not part of the batch
                            # (if deduplication is off, we add all the tuples).
                            if window_key in seen_windows[row] and not args.no_dedup:
                                continue
                            # The key is either window + current token, or just the window.
                            key = str(window + [token]) if args.v1==0 else window_key
                            # Mark the window as seen only on the last
                            # candidate so every top-k token gets considered.
                            if t == last:
                                seen_windows[row].add(window_key)
                            if key not in dic:
                                score, raw = _score_candidate(detector, args.method_detect, window, token)
                                dic[key] = score
                                dic_r[key] = raw
                mean_r = np.mean(list(dic_r.values()))
                nb_disinct = len(dic.values())
                pvalues = detector.get_pvalues_by_t(dic.values())
                log_pvalues = [np.log10(pvalue) if pvalue > 0 else -0.43 for pvalue in pvalues]
            # log
            f.write(json.dumps({
                "iteration": i, 
                'p_value': pvalues[-1] if len(pvalues)>0 else 0,
                "mean_r":mean_r,
                # Same fallback as for a non-positive p-value; avoids an
                # IndexError when no p-value was produced.
                'log10_pvalue': log_pvalues[-1] if len(log_pvalues)>0 else -0.43,
                "nb_disinct":nb_disinct,
                "nb":i*args.batch_size*args.seq_length
                }) + "\n")
            f.flush()




# Script entry point: parse the command line, then run the detection.
if __name__ == "__main__":
    cli_parser = get_args_parser()
    args = cli_parser.parse_args()
    main(args)
