from typing import Any
import csv

# import config as CFG
import glob, os, re
from transformers import AutoTokenizer, default_data_collator
from datasets import load_dataset, concatenate_datasets, DatasetDict, Dataset

from utility import interleave_map_style_datasets_batchwise
from collections import Counter
from train_tokenizer import train_tokenizer
from langs_keyword import langs_keywords


def _removing_repo_address(examples):
    clean = []
    for example in examples["text"]:
        clean.append(example.split("|", 1)[1].strip())
    return {"text": clean}


def _align_labels_with_tokens(labels, word_ids):
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word!
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            # Special token
            new_labels.append(-100)
        else:
            # Same word as previous token
            label = labels[word_id]
            # If the label is B-XXX we change it to I-XXX
            if label % 2 == 1:
                label += 1
            new_labels.append(label)

    return new_labels


def _shifting_labels(examples):
    # entity_groups = [
    #         'O'
    #     'B-VAR',
    #     'I-VAR',
    #     'B-FUNC',
    #     'I-FUNC'
    #     ]
    new_labels = []
    for label in examples["tags"]:
        new_labels.append([3 if l == 2 else l for l in label])
    return {"tags": new_labels}


class CUDAizerDataset:

    def __init__(self, args) -> None:
        """Load the per-language datasets and the tokenizer.

        When ``args.train_mode`` is ``"eval"`` the per-language datasets are
        additionally interleaved batch-wise and their ``"test"`` and
        ``"valid"`` splits are tokenized; the result is stored in
        ``self.dataset``.

        NOTE: ``self.dataset`` is only assigned in ``"eval"`` mode.

        Args:
            args: Configuration object. This method reads ``logger``,
                ``tokenizer_dir``, ``chunk_size``, ``train_mode``,
                ``batch_size``, ``tokenizer_batch_size`` and
                ``tokenizer_num_process`` from it.
        """
        self.args = args

        # Read the dataset files (one DatasetDict per language).
        per_language = self._load_datasets_from_directory()

        self.args.logger.info("Loading Tokenizer")
        # Prefix spaces are only requested in 'aer' mode.
        wants_prefix_space = self.args.train_mode == 'aer'
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer_dir,
            model_max_length=self.args.chunk_size,
            trust_remote_code=True,
            add_prefix_space=wants_prefix_space,
        )

        if args.train_mode != "eval":
            return

        self.args.logger.info("Tokenizing for DAE and BT")
        merged = interleave_map_style_datasets_batchwise(
            list(per_language), batch_size=self.args.batch_size
        )
        # Same tokenization for both evaluation splits; the raw columns are
        # dropped so only the tokenizer outputs remain.
        for split_name in ("test", "valid"):
            merged[split_name] = merged[split_name].map(
                self._tokenize_for_dae_bt_test_valid,
                batched=True,
                batch_size=self.args.tokenizer_batch_size,
                remove_columns=merged[split_name].column_names,
                num_proc=self.args.tokenizer_num_process,
            )
        self.dataset = merged

    def __call__(self, split=None):
        """Return the stored dataset, or one split of it.

        Args:
            split: Key of the split to return; ``None`` returns the whole
                ``self.dataset`` object.

        Returns:
            ``self.dataset`` when ``split`` is ``None``, otherwise
            ``self.dataset[split]``.
        """
        return self.dataset if split is None else self.dataset[split]

    def _load_datasets_from_directory(self):
        """Load the ``valid`` and ``test`` text files of every language.

        For each language in ``self.args.langs`` the files
        ``<lang>.para.valid.tok`` and ``<lang>.para.test.tok`` inside
        ``self.args.dataset_path`` are loaded with the ``"text"`` builder, and
        a constant ``"lang"`` column holding the language name is added to
        every split.

        Returns:
            A list with one loaded dataset object per language, in the order
            of ``self.args.langs``.
        """
        self.args.logger.info("Loading Dataset")
        file_extension = 'tok'
        valid_infix = ".para.valid."
        # TODO: support arbitrary file naming conventions.
        # Currently only the "<lang>.para.<split>.<extension>" layout works.
        split_infixes = {"valid": valid_infix, "test": ".para.test."}

        loaded = []
        for lang in self.args.langs:
            print(f"{lang}{valid_infix}{file_extension}")
            split_files = {
                split_name: os.path.join(
                    self.args.dataset_path, f"{lang}{infix}{file_extension}"
                )
                for split_name, infix in split_infixes.items()
            }
            lang_dataset = load_dataset(
                "text",
                data_files=split_files,
                keep_in_memory=self.args.keep_in_memory,
            )
            # Tag every row with the language it came from.
            for split_name in lang_dataset:
                rows = lang_dataset[split_name]
                lang_dataset[split_name] = rows.add_column(
                    "lang", [lang] * len(rows)
                )
            loaded.append(lang_dataset)

        return loaded

    def _group_text(self, examples):
        """Concatenate a batch of sequences and re-split it into fixed chunks.

        Every column of ``examples`` is flattened into one long list, which is
        then cut into consecutive pieces of ``self.args.chunk_size`` items.
        A trailing remainder shorter than ``chunk_size`` is dropped, unless
        the whole batch is shorter than ``chunk_size``, in which case it is
        kept as a single short chunk.

        Args:
            examples: Batch mapping of column name -> list of sequences.
                Must contain an ``"input_ids"`` column.

        Returns:
            Mapping with the same columns, each a list of chunks, plus a
            ``"labels"`` column that is a shallow copy of ``"input_ids"``.

        Raises:
            KeyError: If ``examples`` has no ``"input_ids"`` column.
        """
        chunk_size = self.args.chunk_size
        # Flatten each column in linear time; ``sum(list_of_lists, [])``
        # re-copies the accumulated list on every addition (quadratic).
        concatenated_examples = {
            key: [item for sequence in examples[key] for item in sequence]
            for key in examples.keys()
        }
        # Length of the concatenated batch, measured on the first column.
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # Drop the last chunk if it is smaller than chunk_size.
        if total_length >= chunk_size:
            total_length = (total_length // chunk_size) * chunk_size
        # Split into chunks of chunk_size items.
        result = {
            key: [
                flat[start : start + chunk_size]
                for start in range(0, total_length, chunk_size)
            ]
            for key, flat in concatenated_examples.items()
        }

        result["labels"] = result["input_ids"].copy()
        return result
    
    def _cuda_filter(self, example):
        """Tell whether an example mentions at least one CUDA keyword.

        The keywords come from ``langs_keywords['cuda_keywords']``. The
        example's ``'text'`` field is searched when present; otherwise its
        ``'tokens'`` field is searched.

        Args:
            example: Mapping holding either a ``'text'`` or a ``'tokens'``
                entry.

        Returns:
            ``True`` if any keyword is contained in the searched field,
            ``False`` otherwise (including when neither field exists).
        """
        if 'text' in example:
            searched = example['text']
        elif 'tokens' in example:
            searched = example['tokens']
        else:
            return False

        # Every keyword is tried before giving up. Previously the
        # ``return False`` sat inside the loop body, so only the first
        # keyword was ever tested.
        return any(
            keyword in searched for keyword in langs_keywords['cuda_keywords']
        )
                
    def _stack_cuda_filter(self,example):
        """Tell whether ``example['text']`` contains any CUDA keyword.

        Args:
            example: Mapping with a ``'text'`` entry.

        Returns:
            ``True`` as soon as one entry of
            ``langs_keywords['cuda_keywords']`` is found in the text,
            ``False`` when none is.
        """
        return any(
            keyword in example['text']
            for keyword in langs_keywords['cuda_keywords']
        )

    def _tokenize_for_mlm(self, examples):
        """Tokenize ``examples["text"]`` with truncation enabled.

        When the tokenizer reports ``is_fast``, a ``"word_ids"`` entry is
        added that holds, per encoded sequence, the result of
        ``word_ids(i)`` for that sequence.

        Args:
            examples: Batch mapping with a ``"text"`` list of strings.

        Returns:
            The tokenizer output, optionally extended with ``"word_ids"``.
        """
        encoded = self.tokenizer(examples["text"], truncation=True)
        if not self.tokenizer.is_fast:
            return encoded

        sequence_count = len(encoded["input_ids"])
        encoded["word_ids"] = [
            encoded.word_ids(index) for index in range(sequence_count)
        ]
        return encoded

    def _tokenize_and_align_labels(self, examples):
        """Tokenize pre-split words and align their tags to the tokens.

        ``examples["tokens"]`` is tokenized as already-split words, truncated
        to ``self.args.chunk_size``. The word-level ``examples["tags"]`` are
        then expanded to token level with ``_align_labels_with_tokens`` and
        stored under ``"labels"``.

        Args:
            examples: Batch mapping with ``"tokens"`` (lists of words) and
                ``"tags"`` (lists of integer labels, one per word).

        Returns:
            The tokenizer output with an added ``"labels"`` entry.
        """
        encoded = self.tokenizer(
            examples["tokens"],
            truncation=True,
            is_split_into_words=True,
            max_length=self.args.chunk_size,
        )
        encoded["labels"] = [
            _align_labels_with_tokens(word_tags, encoded.word_ids(index))
            for index, word_tags in enumerate(examples["tags"])
        ]
        return encoded

    def _tokenize_for_dae_bt_train(self, examples):
        """Tokenize a training batch and attach the language token ids.

        ``examples["text"]`` is tokenized with truncation to
        ``self.args.chunk_size`` and ``padding=True``. The language names in
        ``examples["lang"]`` are converted to ids with
        ``convert_tokens_to_ids`` and stored under ``"lang"``.

        Args:
            examples: Batch mapping with ``"text"`` and ``"lang"`` lists.

        Returns:
            The tokenizer output with an added ``"lang"`` entry.
        """
        tokenizer_options = {
            "truncation": True,
            "max_length": self.args.chunk_size,
            "padding": True,
        }
        encoded = self.tokenizer(examples["text"], **tokenizer_options)
        encoded["lang"] = self.tokenizer.convert_tokens_to_ids(examples["lang"])
        return encoded

    def _tokenize_for_dae_bt_test_valid(self, examples):
        """Tokenize an evaluation batch: source text, language id and targets.

        Source (``examples["text"]``) and target (``examples["labels"]``)
        strings are tokenized with the same settings: truncation to
        ``self.args.chunk_size`` and ``padding='max_length'``. In the target
        ids every occurrence of the tokenizer's ``pad_token_id`` is replaced
        by ``-100``.

        Args:
            examples: Batch mapping with ``"text"``, ``"lang"`` and
                ``"labels"`` lists.

        Returns:
            The source tokenizer output extended with ``"lang"`` (language
            token ids) and ``"labels"`` (masked target ids).
        """
        tokenizer_options = {
            "truncation": True,
            "max_length": self.args.chunk_size,
            "padding": 'max_length',
        }

        # Source language
        encoded = self.tokenizer(examples["text"], **tokenizer_options)
        encoded["lang"] = self.tokenizer.convert_tokens_to_ids(examples["lang"])

        # Target language
        target_ids = self.tokenizer(examples["labels"], **tokenizer_options)[
            "input_ids"
        ]

        # Padding positions in the targets are replaced by -100.
        pad_id = self.tokenizer.pad_token_id
        encoded["labels"] = [
            [token if token != pad_id else -100 for token in sequence]
            for sequence in target_ids
        ]
        return encoded
    
    def comment_remover(self, batch):
        """Replace ``//`` and ``/* */`` comments in source code with a space.

        The pattern also matches quoted literals and ``#`` lines (including
        backslash-continued ones) so that comment markers inside them are
        not touched; those matches are put back unchanged.

        Args:
            batch: Batch mapping with a ``'content'`` list of source strings.

        Returns:
            ``{"text": [...]}`` with the processed strings, in order.
        """
        comment_or_literal = re.compile(
            r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"|#[^\r\n]*(?:\\\r?\n[^\r\n]*)*',
            re.DOTALL | re.MULTILINE
        )

        def keep_or_blank(match):
            matched = match.group(0)
            # Only comment matches start with '/'. They become a single
            # space (not an empty string) so adjacent tokens stay separated.
            return " " if matched.startswith('/') else matched

        stripped = [
            comment_or_literal.sub(keep_or_blank, code)
            for code in batch['content']
        ]
        return {"text": stripped}
    

    def calculate_token_frequency(self):
        """Write the most frequent tokens of each language to two CSV files.

        Token ids of ``self.dataset['train']`` are counted separately for
        examples whose ``'lang'`` equals the token id of
        ``self.args.langs[0]`` and for all remaining examples. For each of
        the two counters the ``self.args.top_k_tokens`` most common ids are
        taken and ids listed in ``tokenizer.added_tokens_decoder`` are then
        removed, so fewer than ``top_k_tokens`` entries may remain.

        Two ``|``-delimited files are written into ``self.args.dataset_path``:

        * ``dataset_frequent_token_<l1>_<l2>.csv`` with the decoded tokens,
        * ``dataset_frequent_token_ids_<l1>_<l2>.csv`` with
          ``(token_id, frequency)`` pairs.

        Each file starts with a header row of ``self.args.langs``; the data
        rows pair the i-th entry of both languages and stop at the shorter
        list.

        Requires ``self.dataset`` to provide a ``'train'`` split with
        ``'lang'`` and ``'input_ids'`` fields.
        """
        # The language id is the same for every example: look it up once
        # instead of once per row.
        first_lang_id = self.tokenizer.convert_tokens_to_ids(self.args.langs[0])

        lang1_token_frequency = Counter()
        lang2_token_frequency = Counter()
        for example in self.dataset['train']:
            if example['lang'] == first_lang_id:
                lang1_token_frequency.update(example['input_ids'])
            else:
                lang2_token_frequency.update(example['input_ids'])

        # Special tokens are not reported; a set gives O(1) membership tests.
        special_token_ids = set(self.tokenizer.added_tokens_decoder.keys())
        top_k = self.args.top_k_tokens

        def top_regular_tokens(counter):
            # Rank once, then drop special tokens from the top-k selection.
            return [
                (token_id, frequency)
                for token_id, frequency in counter.most_common(top_k)
                if token_id not in special_token_ids
            ]

        def decode_all(ranked):
            return [
                self.tokenizer.decode(token_id, skip_special_tokens=True)
                for token_id, _ in ranked
            ]

        lang1_top_ids = top_regular_tokens(lang1_token_frequency)
        lang2_top_ids = top_regular_tokens(lang2_token_frequency)
        lang1_top_words = decode_all(lang1_top_ids)
        lang2_top_words = decode_all(lang2_top_ids)

        # One row per rank, one column per language.
        freq_word_data = list(zip(lang1_top_words, lang2_top_words))
        freq_token_id_data = list(zip(lang1_top_ids, lang2_top_ids))

        csv_file_top_words = os.path.join(self.args.dataset_path, f'dataset_frequent_token_{self.args.langs[0]}_{self.args.langs[1]}.csv')
        csv_file_top_token_ids = os.path.join(self.args.dataset_path, f'dataset_frequent_token_ids_{self.args.langs[0]}_{self.args.langs[1]}.csv')

        for csv_path, rows in (
            (csv_file_top_words, freq_word_data),
            (csv_file_top_token_ids, freq_token_id_data),
        ):
            # Explicit UTF-8: decoded tokens may contain non-ASCII characters
            # that the platform default encoding cannot represent.
            with open(csv_path, mode='w', newline='', encoding='utf-8') as file:
                writer = csv.writer(file, delimiter='|')
                writer.writerow(self.args.langs)
                writer.writerows(rows)
        self.args.logger.info("Frequent tokens and token_ids are saved.")