import os
import sys
from argparse import ArgumentParser
from pprint import pprint

import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from ae import Autoencoder
from dataloader import get_custom_dataloader
from loguru import logger
from tqdm import tqdm

# Make `project_utils` importable when this file is run directly as a script.
# Defaults to the parent of this file's directory; adjust the path if the
# anympe directory lives elsewhere.
sys.path.append(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)  # Adds the parent directory to the system path
from project_utils.data_utils import load_dataset
from project_utils.logging_utils import init_logger_and_wandb


def _build_scheduler(optimizer, schedular_name, num_epochs, steps_per_epoch):
    """Create the learning-rate scheduler selected by ``schedular_name``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        schedular_name: One of ``"plateau"``, ``"step"`` or ``"cycle"``.
        num_epochs: Total number of training epochs (used by ``"cycle"``).
        steps_per_epoch: Number of batches per epoch (used by ``"cycle"``).

    Returns:
        The configured scheduler instance.

    Raises:
        ValueError: If ``schedular_name`` is not a supported name.
    """
    if schedular_name == "plateau":
        return optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.9, patience=10, verbose=True
        )
    if schedular_name == "step":
        return optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
    if schedular_name == "cycle":
        return optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=0.005,  # Peak learning rate reached during the cycle
            epochs=num_epochs,
            steps_per_epoch=steps_per_epoch,
            anneal_strategy="linear",  # Can be 'cos' for cosine annealing
            div_factor=25,  # initial_lr = max_lr / div_factor
            final_div_factor=1e4,  # min_lr = initial_lr / final_div_factor
        )
    raise ValueError("Invalid scheduler name")


def train(
    cfg,
    train_data,
    val_data,
    val_buckets,
    dataset_name,
    val_freq=None,
    schedular_name="cycle",
):
    """Train an autoencoder on ``train_data`` and periodically validate it.

    Args:
        cfg: Configuration object; the attributes read here are ``num_epochs``,
            ``batch_size``, ``encoding_size``, ``num_hidden_layers``, ``lr``
            and ``device``.
        train_data: Training data; ``train_data.shape[1]`` is taken as the
            number of variables, and the model input size is twice that.
        val_data: Validation data passed to ``get_custom_dataloader``.
        val_buckets: Buckets passed to the validation data loader.
        dataset_name: Name used in progress bars, log messages and wandb keys.
        val_freq: Validate every ``val_freq`` epochs. Defaults to
            ``num_epochs // 10`` (falling back to ``num_epochs // 2``), but
            never less than 1.
        schedular_name: Learning-rate scheduler: ``"plateau"``, ``"step"`` or
            ``"cycle"``.

    Returns:
        The trained ``Autoencoder`` model.

    Raises:
        ValueError: If ``schedular_name`` is unknown or ``val_freq`` is 0.
    """
    num_epochs = cfg.num_epochs
    if val_freq is None:
        val_freq = num_epochs // 10
        if val_freq == 0:
            val_freq = num_epochs // 2
        # With a single epoch both divisions yield 0, which would make the
        # ``epoch % val_freq`` check below divide by zero.
        val_freq = max(1, val_freq)
    elif val_freq == 0:
        raise ValueError("val_freq must be a non-zero number of epochs")

    num_vars = train_data.shape[1]
    train_loader = get_custom_dataloader(train_data, batch_size=cfg.batch_size)
    val_loader = get_custom_dataloader(
        val_data, buckets=val_buckets, batch_size=cfg.batch_size
    )
    input_size = 2 * num_vars
    model = Autoencoder(
        input_size,
        encoding_size=cfg.encoding_size,
        hidden_layers=cfg.num_hidden_layers,
    ).to(cfg.device)
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=cfg.lr)
    scheduler = _build_scheduler(
        optimizer, schedular_name, num_epochs, len(train_loader)
    )
    # OneCycleLR is sized for epochs * steps_per_epoch updates, so it must be
    # advanced after every batch; the other schedulers advance once per epoch.
    step_per_batch = schedular_name == "cycle"

    # Outer tqdm bar for epochs
    epoch_pbar = tqdm(
        range(num_epochs), desc=f"Training {dataset_name} Autoencoder", position=0
    )
    for epoch in epoch_pbar:
        model.train()
        train_loss = 0.0

        # Inner tqdm bar for iterations within an epoch
        iteration_pbar = tqdm(
            train_loader,
            desc=f"Epoch {epoch + 1}/{num_epochs}",
            leave=False,
            position=1,
        )
        for data in iteration_pbar:
            data = data.to(cfg.device)
            # Flatten each sample exactly as the validation loop does; this is
            # a no-op when the batch is already two-dimensional.
            data = data.view(data.size(0), -1)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, data)
            loss.backward()
            optimizer.step()
            if step_per_batch:
                scheduler.step()
            batch_loss = loss.item()
            train_loss += batch_loss

            # Update iteration tqdm bar
            iteration_pbar.set_postfix({"Batch Loss": batch_loss})

        # Guard against an empty loader so the average never divides by zero.
        train_loss = train_loss / max(len(train_loader), 1)

        if schedular_name == "plateau":
            # ReduceLROnPlateau is the only scheduler that consumes a metric.
            scheduler.step(train_loss)
        elif schedular_name == "step":
            # Passing the loss here would be interpreted as an epoch index.
            scheduler.step()

        # Update epoch tqdm bar
        epoch_pbar.set_postfix({"Epoch Loss": train_loss})
        logger.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}")
        wandb.log({f"train_loss_{dataset_name}": train_loss})
        if epoch % val_freq == 0:
            # Validate the model every val_freq epochs
            model.eval()  # Set the model to evaluation mode
            val_loss = 0.0
            with torch.no_grad():  # No gradients needed for validation
                for data in tqdm(val_loader):
                    inputs = data.to(cfg.device)
                    inputs = inputs.view(inputs.size(0), -1)
                    outputs = model(inputs)
                    loss = criterion(outputs, inputs)
                    val_loss += loss.item()
            # Calculate the average loss over all of the validation batches
            val_loss = val_loss / max(len(val_loader), 1)
            logger.info(f"Validation Loss: {val_loss:.4f}")
            wandb.log({f"val_loss_{dataset_name}": val_loss})
            logger.info(f"Epoch {epoch+1}/{num_epochs}, Val Loss: {val_loss:.4f}")

    logger.info(f"Finished training {dataset_name} Autoencoder")
    return model


def test_model(
    model, test_loader, test_buckets, dataset_name, batch_size=32, device=None
):
    """Evaluate ``model`` on the test set and log the mean reconstruction loss.

    Args:
        model: Trained autoencoder to evaluate.
        test_loader: Despite its name, this is the raw test dataset; it is
            wrapped in a data loader via ``get_custom_dataloader`` here.
        test_buckets: Buckets passed to the test data loader.
        dataset_name: Name used in log messages and the wandb key.
        batch_size: Batch size for the test data loader.
        device: Device to run the evaluation on. When ``None``, the device of
            the model's first parameter is used (CPU if it has none).

    Returns:
        The MSE loss averaged over the test batches.
    """
    if device is None:
        # Follow the model instead of relying on a module-level ``cfg``.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Use the argument that was passed in rather than a global ``test_data``.
    loader = get_custom_dataloader(
        test_loader, buckets=test_buckets, batch_size=batch_size, shuffle=False
    )
    criterion = nn.MSELoss()
    model.eval()  # Set the model to evaluation mode
    test_loss = 0.0
    logger.info(f"Testing {dataset_name} Autoencoder")
    with torch.no_grad():  # No gradients needed for testing
        for data in tqdm(loader):
            inputs = data.to(device)
            inputs = inputs.view(inputs.size(0), -1)  # Flatten each sample
            outputs = model(inputs)
            loss = criterion(outputs, inputs)
            test_loss += loss.item()

    # Guard against an empty loader so the average never divides by zero.
    test_loss = test_loss / max(len(loader), 1)
    logger.info(f"Test Loss: {test_loss:.4f}")
    wandb.log({f"test_loss_{dataset_name}": test_loss})
    logger.info(f"Finished testing {dataset_name} Autoencoder")
    return test_loss


# Script entry point: parse CLI options, train, evaluate, and save the autoencoder.


if __name__ == "__main__":
    # Command-line interface.
    parser = ArgumentParser()
    parser.add_argument("--dataset_name", type=str)
    parser.add_argument(
        "--dataset_directory",
        type=str,
        required=True,
        help="path to the dataset; its last component is used as the dataset name",
    )
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--num-hidden-layers", type=int, default=4)
    parser.add_argument("--encoding-size", type=int, default=512)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--num-epochs", type=int, default=5)
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--no-mps",
        action="store_true",
        default=False,
        help="disables macOS GPU training",
    )
    cfg = parser.parse_args()

    # Split the given path into "<parent directory>/<dataset name>". Normalising
    # first drops a trailing separator, which would otherwise produce an empty
    # dataset name.
    dataset_path = os.path.normpath(cfg.dataset_directory)
    cfg.dataset_name = os.path.basename(dataset_path)
    cfg.dataset_directory = os.path.dirname(dataset_path)

    # Extra arguments
    cfg.no_extra_data = True
    # These are used to provide enough data for the model to train on
    cfg.query_prob = 0.4
    cfg.evidence_prob = 0.3
    cfg.unobserved_prob = 0.3
    cfg.task = "mpe"

    project_name = "ae_training"
    device, _use_cuda, _use_mps = init_logger_and_wandb(project_name, cfg)
    cfg.device = device
    pprint(vars(cfg))

    (
        train_data,
        test_data,
        val_data,
        _extra_data,
        test_buckets,
        val_buckets,
    ) = load_dataset(
        cfg.dataset_name,
        cfg,
        load_test_buckets=True,
    )

    model = train(cfg, train_data, val_data, val_buckets, cfg.dataset_name)
    test_model(model, test_data, test_buckets, cfg.dataset_name)

    # Save the weights under a directory tree that mirrors the dataset tree,
    # with "datasets" replaced by "autoencoder_models" in the path.
    model_root = cfg.dataset_directory.replace("datasets", "autoencoder_models")
    model_path = os.path.join(
        model_root, cfg.dataset_name, f"model_num_hidden_{cfg.num_hidden_layers}.pt"
    )
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(model.state_dict(), model_path)
