import logging

from ...constants import BINARY, MULTICLASS, REGRESSION
from ...models.tabular_nn.tabular_nn_model import TabularNeuralNetModel
from ....metrics import mean_squared_error, log_loss
from ...models.rf.rf_model import RFModel
from .presets import get_preset_models_regression, get_preset_models_softclass

# Module-level logger named after this module, so its output is governed by the
# package's logging configuration rather than going straight to stdout.
logger = logging.getLogger(__name__)


def get_preset_models_distillation(path, problem_type, objective_func, stopping_metric=None, num_classes=None,
                                   hyperparameters=None, hyperparameter_tune=False, distill_level=0, name_suffix='_d'):
    """Build the list of student models used for distillation.

    MULTICLASS problems use ``get_preset_models_softclass``. BINARY problems are
    converted to REGRESSION (scored with ``mean_squared_error``, with the NN output
    range restricted to [0, 1]) and use ``get_preset_models_regression``, as do
    REGRESSION problems.

    Args:
        path: Passed through to the preset model factory.
        problem_type: One of BINARY, MULTICLASS or REGRESSION.
        objective_func: Passed to ``get_preset_models_regression``; replaced by
            ``mean_squared_error`` for BINARY and not used for MULTICLASS.
        stopping_metric: Passed to ``get_preset_models_regression``; replaced by
            ``mean_squared_error`` for BINARY and not used for MULTICLASS.
        num_classes: Passed to ``get_preset_models_softclass`` (MULTICLASS only).
        hyperparameters: Dict keyed by model type (e.g. 'GBM', 'NN', 'RF') whose
            values are per-model hyperparameter dicts. Defaults to
            ``{'GBM': {}, 'NN': {}, 'RF': {}}``. This function never mutates it.
        hyperparameter_tune: Passed through to the preset model factory.
        distill_level: Currently unused by this function.
        name_suffix: Passed through to the preset model factory.

    Returns:
        The list of models produced by the preset factory. For BINARY and
        MULTICLASS each model has ``normalize_predprobs`` set to True.

    Raises:
        ValueError: If ``problem_type`` is not BINARY, MULTICLASS or REGRESSION.
    """
    # A fresh default per call: the old dict-literal default was shared across calls
    # and got mutated by the BINARY branch below.
    if hyperparameters is None:
        hyperparameters = {'GBM': {}, 'NN': {}, 'RF': {}}

    normalize_predprobs = False
    if problem_type == MULTICLASS:
        normalize_predprobs = True
        models = get_preset_models_softclass(path=path, num_classes=num_classes, hyperparameters=hyperparameters,
                                             hyperparameter_tune=hyperparameter_tune, name_suffix=name_suffix)
    elif problem_type in (BINARY, REGRESSION):
        if problem_type == BINARY:  # convert to regression in distillation
            normalize_predprobs = True
            if 'NN' in hyperparameters:
                # Build new dicts instead of calling .update() so the caller's
                # hyperparameters (and their nested 'NN' dict) are left untouched.
                hyperparameters = dict(hyperparameters)
                hyperparameters['NN'] = {**hyperparameters['NN'], 'y_range': (0.0, 1.0), 'y_range_extend': 0.0}
            objective_func = mean_squared_error
            stopping_metric = mean_squared_error  # or = None
            problem_type = REGRESSION
        models = get_preset_models_regression(path=path, problem_type=problem_type, objective_func=objective_func,
                                              stopping_metric=stopping_metric, hyperparameters=hyperparameters,
                                              hyperparameter_tune=hyperparameter_tune, name_suffix=name_suffix)
    else:
        # Previously this fell through and failed with UnboundLocalError on `models`.
        raise ValueError(f"Unsupported problem_type for distillation: {problem_type!r}")

    if normalize_predprobs:
        for model in models:
            model.normalize_predprobs = True

    logger.info("Distilling with each of these student models: %s", [model.name for model in models])
    return models
