import copy

import torch
from loguru import logger
from models.nn_1 import NeuralNetworkOne
from models.nn_2 import NeuralNetworkTwo
from models.nn_3 import NeuralNetworkThree
from models.nn_4 import NeuralNetworkFour
from models.sin_embedding_layer import VectorToContinuousSin
from models.transformer import TransformerEncoder
from models.vae import VariationalAutoencoder
from utils_folder.optim import select_optimizer


def init_model_and_optimizer(
    cfg,
    library_pgm,
    device,
    fabric,
    num_ip_features,
    num_pgm_feature,
    num_variables,
    run_type="train",
    **kwargs,
):
    """Create the student model and optimizer, plus a teacher pair if requested.

    The student model is built with ``cfg.student_layers`` layers. Its
    optimizer settings are read from the ``cfg.train_*`` options when
    ``run_type == "train"`` and from the ``cfg.test_*`` options for any other
    value. Model and optimizer are passed through ``fabric.setup`` before
    being returned.

    When ``cfg.dual_network`` is truthy, a teacher model with
    ``cfg.teacher_layers`` layers is built as well, paired with an optimizer
    configured from ``cfg.teacher_optimizer`` and ``cfg.teacher_lr``, and set
    up through ``fabric.setup`` in the same way.

    Args:
        cfg: Configuration object, read through attribute access.
        library_pgm: Forwarded unchanged to the model constructors.
        device: Device the models are moved to.
        fabric: Object exposing ``setup(model, optimizer)`` (presumably a
            Lightning Fabric instance -- confirm against the caller).
        num_ip_features: Number of input features.
        num_pgm_feature: Number of PGM features.
        num_variables: Number of variables handed to the model constructors.
        run_type: ``"train"`` selects the training optimizer settings; any
            other value selects the test settings.
        **kwargs: Extra options, forwarded as a single dict to
            ``init_model_and_embeddings``.

    Returns:
        ``(model, optimizer)``, or
        ``(model, optimizer, teacher_model, teacher_optimizer)`` when
        ``cfg.dual_network`` is truthy.

    Raises:
        ValueError: If ``cfg.copy_student_to_teacher_dn`` is set while
            ``cfg.teacher_layers`` differs from ``cfg.student_layers``.
    """
    model = init_model_and_embeddings(
        cfg,
        library_pgm,
        device,
        num_ip_features,
        num_pgm_feature,
        num_variables,
        kwargs,
        cfg.student_layers,
        is_teacher=False,
    )

    # Pick the optimizer hyper-parameters that match the run mode.
    if run_type == "train":
        optimizer_name = cfg.train_optimizer
        lr = cfg.train_lr
        weight_decay = cfg.train_weight_decay
    else:
        optimizer_name = cfg.test_optimizer
        lr = cfg.test_lr
        weight_decay = cfg.test_weight_decay
    logger.info(f"We are in {run_type} mode")
    logger.info(f"Optimizer: {optimizer_name}")
    logger.info(f"Learning rate: {lr}")
    logger.info(f"Weight decay: {weight_decay}")
    optimizer = select_optimizer(model, optimizer_name, lr, weight_decay)
    model, optimizer = fabric.setup(model, optimizer)

    if not cfg.dual_network:
        return model, optimizer

    # Dual-network experiments: build a separate teacher next to the student.
    # Initialising the teacher from the student requires identical layer
    # counts. The copy itself is not performed in this function. An explicit
    # exception is used instead of ``assert`` so the check survives
    # ``python -O``.
    if cfg.copy_student_to_teacher_dn and cfg.teacher_layers != cfg.student_layers:
        raise ValueError(
            "Teacher and student layers should be the same if we want to "
            "use student as initialization for teacher"
        )

    teacher_model = init_model_and_embeddings(
        cfg,
        library_pgm,
        device,
        num_ip_features,
        num_pgm_feature,
        num_variables,
        kwargs,
        cfg.teacher_layers,
        is_teacher=True,
    )
    # NOTE(review): the teacher optimizer is always created with
    # weight_decay=0; cfg.teacher_weight_decay is not applied. This keeps the
    # existing behaviour -- confirm whether it is intentional.
    teacher_optimizer = select_optimizer(
        teacher_model, cfg.teacher_optimizer, cfg.teacher_lr, weight_decay=0
    )
    teacher_model, teacher_optimizer = fabric.setup(
        teacher_model, teacher_optimizer
    )
    return model, optimizer, teacher_model, teacher_optimizer


def init_model_and_embeddings(
    cfg,
    library_pgm,
    device,
    num_ip_features,
    num_pgm_feature,
    num_variables,
    kwargs,
    num_layers,
    is_teacher=False,
):
    """Build the model selected by ``cfg.model`` and attach its embedding layer.

    Supported values of ``cfg.model`` are ``"nn"``, ``"vae"`` and
    ``"transformer"``. For ``"nn"`` and ``"vae"`` the input embedding is
    created by ``init_embeddings``; for ``"transformer"`` only the
    ``"continuousEmbed"`` embedding type is accepted.

    Args:
        cfg: Configuration object, read through attribute access.
        library_pgm: Forwarded unchanged to the model constructors.
        device: Device the model is moved to.
        num_ip_features: Number of input features.
        num_pgm_feature: Number of PGM features; added to the input embedding
            width to form the network input size for ``"nn"`` / ``"vae"``.
        num_variables: Number of variables handed to the model constructors.
        kwargs: Dict of extra options, forwarded to ``init_nn_model``.
        num_layers: Number of hidden layers (``"nn"`` / ``"vae"``) or encoder
            layers (``"transformer"``).
        is_teacher: Forwarded to ``init_nn_model``; not used by the other
            model kinds.

    Returns:
        The constructed model. When a learnable embedding is used, it also
        carries ``embedding_layer`` and ``ip_embedding_size`` attributes.

    Raises:
        ValueError: If ``cfg.model`` is not one of the supported values, or
            if the transformer is requested with an embedding type other
            than ``"continuousEmbed"``.
    """
    # Hidden widths double from one layer to the next: 128 * 2**i. The "mn"
    # PGM combined with cfg.only_test_train_on_test_set starts at exponent 3
    # (i.e. 1024) instead of 0 (i.e. 128).
    first_exponent = (
        3 if (cfg.pgm == "mn" and cfg.only_test_train_on_test_set) else 0
    )
    hidden_size = [
        128 * (2**i) for i in range(first_exponent, first_exponent + num_layers)
    ]

    if cfg.model in ["nn", "vae"]:
        ip_embedding_size, embedding_layer = init_embeddings(
            cfg, device, num_ip_features
        )
        # The caller passes zero here when no PGM features are used.
        pgm_embedding_size = num_pgm_feature
        if cfg.model == "nn":
            model = init_nn_model(
                cfg,
                library_pgm,
                device,
                num_variables,
                kwargs,
                hidden_size,
                ip_embedding_size,
                pgm_embedding_size,
                is_teacher=is_teacher,
            )
        else:
            model = VariationalAutoencoder(
                cfg,
                input_size=ip_embedding_size + pgm_embedding_size,
                hidden_sizes=hidden_size,
                num_variables=num_variables,
                latent_dim=cfg.latent_dim,
                library_pgm=library_pgm,
            ).to(device)
        if cfg.embedding_type in ["continuousEmbed", "continuousSin"]:
            # Attach the embedding layer to the model so that (assuming both
            # are torch.nn.Module instances) it is saved with the model state.
            model.embedding_layer = embedding_layer
            model.ip_embedding_size = ip_embedding_size
    elif cfg.model == "transformer":
        if cfg.embedding_type != "continuousEmbed":
            raise ValueError("Invalid embedding type for transformer")
        num_states = cfg.num_states
        embedding_size = cfg.embedding_size
        embedding_layer = torch.nn.Embedding(num_states, embedding_size)
        # NOTE(review): nhead is set to cfg.num_states, as in the original
        # code -- confirm that the number of attention heads is really meant
        # to track the number of states.
        model = TransformerEncoder(
            cfg=cfg,
            feature_size=num_ip_features,
            d_model=embedding_size,
            nhead=num_states,
            num_layers=num_layers,
            library_pgm=library_pgm,
            dropout=cfg.dropout_rate,
        ).to(device)
        model.embedding_layer = embedding_layer
        model.ip_embedding_size = embedding_size * num_ip_features
    else:
        # Previously an unknown cfg.model fell through to ``return model``
        # and failed with an UnboundLocalError.
        raise ValueError(
            f"Invalid model {cfg.model!r}; expected 'nn', 'vae' or 'transformer'"
        )
    return model


def init_embeddings(cfg, device, num_ip_features):
    """Return the input embedding width and layer for ``cfg.embedding_type``.

    Args:
        cfg: Configuration object; ``embedding_type`` selects the scheme and
            the scheme-specific options are read from it.
        device: Passed to ``VectorToContinuousSin`` for ``"continuousSin"``;
            unused by the other schemes.
        num_ip_features: Number of input features.

    Returns:
        A ``(ip_embedding_size, embedding_layer)`` tuple. ``embedding_layer``
        is ``None`` for ``"discrete"`` and ``"continuousConst"``.

    Raises:
        ValueError: If ``cfg.embedding_type`` is not a known scheme.
    """
    scheme = cfg.embedding_type

    # Schemes without a learnable layer: only the width changes.
    if scheme == "discrete":
        return num_ip_features * 2, None
    if scheme == "continuousConst":
        return num_ip_features, None

    if scheme == "continuousEmbed":
        # One embedding vector of width cfg.embedding_size per state.
        state_count = cfg.num_states
        vector_width = cfg.embedding_size
        lookup = torch.nn.Embedding(state_count, vector_width)
        return vector_width * num_ip_features, lookup

    if scheme == "continuousSin":
        out_width = cfg.sin_embedding_size
        sin_layer = VectorToContinuousSin(
            num_ip_features * 2,
            cfg.sin_embedding_size,
            cfg.amplitude,
            cfg.phase,
            device,
        )
        return out_width, sin_layer

    raise ValueError("Invalid embedding type")


def init_nn_model(
    cfg,
    library_pgm,
    device,
    num_variables,
    kwargs,
    hidden_size,
    ip_embedding_size,
    pgm_embedding_size,
    is_teacher
):
    """Instantiate the network variant named by ``cfg.model_type``.

    Args:
        cfg: Configuration object; ``model_type`` must be one of the strings
            ``"1"``, ``"2"``, ``"3"`` or ``"4"``.
        library_pgm: Forwarded to variants two, three and four.
        device: Device the constructed model is moved to.
        num_variables: Number of variables handed to the constructor.
        kwargs: Dict of extra keyword arguments; only expanded into the
            variant-four constructor.
        hidden_size: Hidden layer sizes handed to the constructor.
        ip_embedding_size: Width of the input embedding.
        pgm_embedding_size: Width of the PGM embedding.
        is_teacher: Forwarded to variant two only.

    Returns:
        The constructed model after ``.to(device)``.

    Raises:
        ValueError: If ``cfg.model_type`` is not a supported variant.
    """
    variant = cfg.model_type
    if variant not in ("1", "2", "3", "4"):
        raise ValueError("Invalid model type")

    if variant == "3":
        # Variant three takes the two embedding widths separately instead of
        # their sum, plus the path from cfg.embedding_model_path.
        network = NeuralNetworkThree(
            cfg,
            ip_embedding_size,
            pgm_embedding_size,
            hidden_size,
            num_variables,
            cfg.embedding_model_path,
            library_pgm=library_pgm,
        )
        return network.to(device)

    # The remaining variants consume a single combined input width.
    input_width = ip_embedding_size + pgm_embedding_size
    if variant == "1":
        network = NeuralNetworkOne(cfg, input_width, hidden_size, num_variables)
    elif variant == "2":
        network = NeuralNetworkTwo(
            cfg,
            input_width,
            hidden_size,
            num_variables,
            supervised_loss_lambda=cfg.supervised_loss_lambda,
            library_pgm=library_pgm,
            is_teacher=is_teacher,
        )
    else:
        network = NeuralNetworkFour(
            cfg,
            input_width,
            hidden_size,
            num_variables=num_variables,
            library_pgm=library_pgm,
            **kwargs,
        )
    return network.to(device)
