import torch
import yaml
from lightning import seed_everything
from typing import *

from dataset.evidence import *
from script.counterfactual.config_dataset import EvidenceDatasetConfig
from script.counterfactual.config_evidence import EvidenceConfig
from script.counterfactual.config_model import ModelConfig
from script.counterfactual.config_scm import SCMConfig
from script.config_dataloader import DefaultDataloaderConfig
from script.config_trainer import DefaultTrainerConfig
from script.counterfactual.parallel_test import parallel_test


def test_from_config(scm_config: SCMConfig,
                     evidence_config: EvidenceConfig,
                     test_dataset_config: EvidenceDatasetConfig,
                     test_dataloader_config: DefaultDataloaderConfig,
                     model_config: ModelConfig,
                     trainer_config: DefaultTrainerConfig,
                     seed: int = 1,
                     use_cpu: bool = False,
                     callback: Optional[Callable] = None,
                     ):
    """Build the SCM, test dataset and model from configs and run the test.

    Two execution paths are supported:

    * ``use_cpu=False`` (default): the model comes from
      ``model_config.get_model`` and is evaluated with the trainer returned
      by ``trainer_config.get_trainer()``. If ``model_config.checkpoint_path``
      is set it is forwarded to the trainer as ``ckpt_path``.
    * ``use_cpu=True``: the model comes from ``model_config.load_model`` with
      ``map_location='cpu'`` and is evaluated by ``parallel_test``.

    Args:
        scm_config: Provides the SCM via ``get_scm()``.
        evidence_config: Evidence description passed to the dataset and the
            model; also the source of the returned evidence type/kwargs.
        test_dataset_config: Builds the test dataset from the SCM and the
            evidence config.
        test_dataloader_config: Builds the test dataloader. NOTE: when
            ``use_cpu`` is true, its ``num_workers`` attribute is set to 0
            in place, so the caller's object is modified.
        model_config: Builds or loads the model; ``checkpoint_path`` is read
            from it.
        trainer_config: Builds the trainer (only used when ``use_cpu`` is
            false).
        seed: Seed handed to ``seed_everything``.
        use_cpu: Select the ``parallel_test`` path instead of the trainer.
        callback: Optional callable invoked once with the test buffer.

    Returns:
        A dict with the keys ``'scm'``, ``'evidence_type'``,
        ``'evidence_kwargs'``, ``'test_dataset'`` and ``'model'``. The test
        buffer itself is only handed to ``callback``; it is not returned.
    """
    seed_everything(seed)
    torch.set_float32_matmul_precision('medium')

    scm = scm_config.get_scm()
    test_dataset = test_dataset_config.get_dataset(scm, evidence_config)

    # Test
    if not use_cpu:
        # Loader, Model and Trainer
        test_dataloader = test_dataloader_config.get_datasetloader(
            test_dataset, collate_fn=evidence_collate_fn,
        )
        model = model_config.get_model(
            scm, evidence_config, test_dataset.mean, test_dataset.std
        )
        trainer = trainer_config.get_trainer()

        # Test config. ``Trainer.test`` takes the loader as ``dataloaders``;
        # ``train_dataloaders`` / ``val_dataloaders`` are ``Trainer.fit``
        # arguments and are rejected by ``test`` with a TypeError.
        test_kwargs = {
            'model': model,
            'dataloaders': test_dataloader,
        }
        if model_config.checkpoint_path is not None:
            test_kwargs['ckpt_path'] = model_config.checkpoint_path
        trainer.test(**test_kwargs)
        # The buffer is read back from the trainer after the run
        # (presumably attached during the test loop -- confirm in the model).
        test_buffer = trainer.test_buffer
    else:
        # Loader and Model
        # Force in-process data loading; this mutates the caller's config.
        test_dataloader_config.num_workers = 0
        test_dataloader = test_dataloader_config.get_datasetloader(
            test_dataset, collate_fn=evidence_collate_fn,
        )
        model = model_config.load_model(
            scm, evidence_config,
            test_dataset.mean, test_dataset.std,
            checkpoint_path=model_config.checkpoint_path,
            map_location='cpu',
        )
        test_buffer = parallel_test(model, test_dataloader)

    # Callback
    if callback is not None:
        callback(test_buffer)

    return {
        'scm': scm,
        'evidence_type': evidence_config.get_evidence_type(),
        'evidence_kwargs': evidence_config.get_evidence_kwargs(),
        'test_dataset': test_dataset,
        'model': model,
    }


def test_from_config_string(config_str: str,
                            callback: Callable = None,
                            ):
    """Parse a YAML configuration string and run ``test_from_config``.

    Required top-level sections (each checked with ``assert``): ``scm``,
    ``evidence``, ``test_dataset``, ``test_dataloader`` and ``model``.
    Optional keys: ``trainer`` (deserialized from ``{}`` when absent),
    ``seed`` (0 when absent) and ``use_cpu`` (false when absent, otherwise
    true iff the value equals 1).

    Checkpointing and logging are always switched off on the trainer config.

    NOTE(review): the example configuration string at the bottom of this
    file has no ``test_dataloader`` section, so it would fail the assertion
    here as written.

    Args:
        config_str: YAML document, parsed with ``yaml.safe_load``.
        callback: Optional callable forwarded to ``test_from_config``.

    Returns:
        Whatever ``test_from_config`` returns.
    """
    cfg = yaml.safe_load(config_str)

    # Structural causal model
    assert 'scm' in cfg
    scm_cfg = SCMConfig.deserialize(cfg['scm'])

    # Evidence
    assert 'evidence' in cfg
    evidence_cfg: EvidenceConfig = EvidenceConfig.deserialize(cfg['evidence'])

    # Test dataset and its dataloader
    assert 'test_dataset' in cfg
    dataset_cfg = EvidenceDatasetConfig.deserialize(cfg['test_dataset'])
    assert 'test_dataloader' in cfg
    dataloader_cfg = DefaultDataloaderConfig.deserialize(
        cfg['test_dataloader']
    )

    # Model
    assert 'model' in cfg
    model_cfg = ModelConfig.deserialize(cfg['model'])

    # Trainer: fall back to an empty mapping when the section is missing
    trainer_cfg = DefaultTrainerConfig.deserialize(cfg.get('trainer', {}))
    # No logging or checkpointing during testing
    trainer_cfg.checkpoint_enable = False
    trainer_cfg.logger_enable = False

    # Seed defaults to 0 when not given
    seed = cfg.get('seed', 0)

    # Parallel test on CPU only when ``use_cpu`` is present and equals 1
    use_cpu = cfg.get('use_cpu', False) == 1

    # Delegate to the config-object entry point
    return test_from_config(
        scm_config=scm_cfg,
        evidence_config=evidence_cfg,
        test_dataset_config=dataset_cfg,
        test_dataloader_config=dataloader_cfg,
        model_config=model_cfg,
        trainer_config=trainer_cfg,
        seed=seed,
        use_cpu=use_cpu,
        callback=callback,
    )


def test_from_config_file(config_path: str,
                          callback: Callable = None,
                          ):
    """Read a UTF-8 YAML file and run ``test_from_config_string`` on it.

    Args:
        config_path: Path of the configuration file to read.
        callback: Optional callable forwarded to ``test_from_config_string``.

    Returns:
        Whatever ``test_from_config_string`` returns.
    """
    with open(config_path, 'r', encoding='utf-8') as config_file:
        config_text = config_file.read()
    # The file is closed before the (potentially long) test run starts.
    return test_from_config_string(config_text, callback)


"""
scm:
  scm_type: synthetic
  scm_kwargs:
    name: chain_nlin_3

evidence:
  evidence_type: context_masked
  evidence_kwargs:
    context_mode: 
    - e
    - t
    - w_e
    - w_t
    mask_mode:
    - fc
    - fc
    - fc
    - fc
  max_len_joint: 1

test_dataset:
  sampler:
    sampler_type: mcar_bernoulli
    sampler_kwargs:
      joint_number_low: 1
      joint_number_high: 1
      prob_intervened: 0.2
      prob_observed: 0.75
      prob_feature_observed: 1
  size: 16384

model:
  density_estimator_type: maf
  density_estimator_kwargs:
    transforms: 5
    hidden_features:
      - 64
      - 64
    reduce: attn
  base_distribution_type: gaussian
  indicator_type: l1
  learning_rate: 0.001
  eval_sample_size: 1000
  eval_sample_every_n_epoch: 10

seed: 0
"""
