import torch
from models.gnns import GCN
from models.base_models import NCModel
from models.GEO import FusionModel


def get_model_dict(args, euc_g):
    """Build the teacher, student and fusion models used for training.

    Two GCNs are built on ``euc_g`` (a Euclidean teacher and a student) and
    one ``NCModel`` (the hyperbolic teacher) is built from ``args``. Each of
    these three is moved to ``args.device``. Only the student gets an
    optimizer (Adam).

    Args:
        args: Config namespace. This function reads ``device``,
            ``n_classes``, ``et_hidden_num``, ``et_layer_num``,
            ``s_hidden_num``, ``s_layer_num``, ``lr`` and ``weight_decay``.
            It is also passed whole to ``NCModel``.
        euc_g: Graph whose ``ndata["feat"]`` holds node features. The size of
            dimension 1 is used as the GCN input width (presumably a DGL
            graph with a 2-D feature tensor -- TODO confirm).

    Returns:
        dict: Maps ``'et_model'``, ``'ht_model'``, ``'s_model'`` and ``'GEO'``
        to dicts holding the ``'model'``. The ``'s_model'`` entry also holds
        its ``'optimizer'``.
    """
    target = args.device
    in_dim = euc_g.ndata["feat"].shape[1]

    # Construction order is kept as-is. The constructors presumably draw
    # from the global torch RNG for weight init, so reordering them could
    # change the initial parameters.
    et_net = GCN(euc_g, in_dim, args.n_classes,
                 args.et_hidden_num, args.et_layer_num).to(target)
    ht_net = NCModel(args).to(target)
    s_net = GCN(euc_g, in_dim, args.n_classes,
                args.s_hidden_num, args.s_layer_num).to(target)

    # The optimizer is created after the move so it holds the on-device
    # parameters.
    s_opt = torch.optim.Adam(s_net.parameters(),
                             lr=args.lr,
                             weight_decay=args.weight_decay)

    # NOTE(review): unlike the other models, the fusion model is not moved
    # to args.device here. Confirm FusionModel has no parameters/buffers or
    # handles device placement itself.
    fusion = FusionModel()

    return {
        'et_model': {'model': et_net},
        'ht_model': {'model': ht_net},
        's_model': {'model': s_net, 'optimizer': s_opt},
        'GEO': {'model': fusion},
    }
