import torch
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# Seed PyTorch's RNG before any model or adapter objects are created.
torch.manual_seed(0)

# Single checkpoint name shared by the tokenizer and the base model.
model_checkpoint = "mistralai/Mistral-7B-v0.1"

# NOTE(review): trust_remote_code=True is passed for both the tokenizer and the
# model below — confirm it is actually required for this checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    model_checkpoint,
    trust_remote_code=True,
)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# LoRA adapter settings: rank 8, alpha 16, dropout 0.05, no bias training,
# applied to the listed projection modules.
# NOTE(review): `spectral_top` does not look like an upstream peft LoraConfig
# argument — presumably this script runs against a modified peft build; confirm.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
    ],
    spectral_top=True,
)

# Base model in bfloat16, with device placement delegated to device_map='auto'.
model = AutoModelForCausalLM.from_pretrained(
    model_checkpoint,
    torch_dtype=torch.bfloat16,
    device_map='auto',
    trust_remote_code=True,
)

# Wrap the base model with the adapter, then report what is trainable and dump
# every parameter name together with its shape.
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
for param_name, param in model.named_parameters():
    print(param_name, param.shape)

# GSM8K train split: build one "text" column of the form
# "question: <question> answer: <answer>" and convert back to a Dataset.
data = load_dataset("gsm8k", "main", split="train").to_pandas()
data["text"] = "question: " + data["question"] + " answer: " + data["answer"]
data = Dataset.from_pandas(data)

def tokenize(sample):
    """Tokenize the ``"text"`` field of a batch of examples.

    Args:
        sample: Mapping with a ``"text"`` entry holding the string(s) to
            encode (a list of strings when called through a batched ``map``).

    Returns:
        The encoding produced by the module-level ``tokenizer``, truncated to
        at most 512 tokens per sequence and left unpadded.
    """
    # Deliberately no padding here: padding at map time stretches every row to
    # the longest sequence of its map batch and stores those pad tokens in the
    # dataset. The data collator given to the Trainer pads each training batch
    # to that batch's own longest sequence instead.
    return tokenizer(sample["text"], truncation=True, max_length=512)

# Tokenize the dataset in batches; every original column is removed so only
# the values returned by `tokenize` remain.
tokenized_data = data.map(
    tokenize,
    batched=True,
    remove_columns=data.column_names,
    desc="Tokenizing data",
)

batch_size = 4
training_arguments = TrainingArguments(
    output_dir="mistralv1_spectral_r8_35e5_e05",
    seed=0,
    num_train_epochs=5,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=1,
    learning_rate=3.5e-5,
    weight_decay=0.01,
    logging_steps=10,
    save_strategy="epoch",
    push_to_hub=False,
)

# mlm=False: collate for causal language modelling rather than masked LM.
lm_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=tokenized_data,
    data_collator=lm_collator,
)
trainer.train()

