import os
import functools

from evaluation.experiments import separate_slakh_msdm
from main.data import ChunkedSupervisedDataset
from main.module_base import Model
from audio_diffusion_pytorch import KarrasSchedule
from main.separation import separate_dataset, MSDMSeparator
# from try_sep_DDIM import MSDMSeparator_DDIM
# from try_sep_DDIM_correction import MSDMSeparator_DDIM
from try_sep_opt_dirac import MSDMSeparator_DDIM
# from main.separation import MSDMSeparator
import torch


def differential_with_dirac(x, sigma, denoise_fn, mixture, source_id=0):
    """Return per-source score differences under a Dirac mixture constraint.

    Before scoring, the ``source_id`` slice of ``x`` (along dim 1) is
    overwritten IN PLACE with ``mixture`` minus the sum of every other source,
    so that the sources in ``x`` add up to ``mixture``. The caller's tensor is
    therefore modified.

    Args:
        x: Stacked sources; dim 1 indexes the source. Presumably shaped
            (batch, num_sources, length) — TODO confirm against the caller.
        sigma: Noise level, passed to ``denoise_fn`` and used to scale the score.
        denoise_fn: Callable invoked as ``denoise_fn(x, sigma=sigma)``; its
            result is subtracted from ``x``, so it must broadcast against ``x``.
        mixture: Tensor broadcastable against ``x[:, [source_id], :]``.
        source_id: Index (along dim 1) of the source that is constrained.

    Returns:
        A tensor stacked along dim 1 whose entry ``i`` is
        ``score[:, i] - score[:, source_id]``, where
        ``score = (x - denoise_fn(x, sigma=sigma)) / sigma`` is computed on the
        constrained ``x``. The ``source_id`` entry is therefore all zeros.
    """
    constrained = [source_id]
    # Sum of every source except the constrained one; computed before the
    # in-place write below so it reflects the incoming values of x.
    others = x.sum(dim=1, keepdim=True) - x[:, constrained, :]
    x[:, constrained, :] = mixture - others

    score = (x - denoise_fn(x, sigma=sigma)) / sigma
    reference = score[:, source_id]
    differences = [score[:, idx] - reference for idx in range(x.shape[1])]
    return torch.stack(differences, dim=1)

def main(
    *,
    dataset_path='data/slakh2100/test',
    model_path='ckpts/avid-darkness-164/epoch=419-valid_loss=0.014.ckpt',
    output_dir='output/separations/opt_50epochs',
    stems=("bass", "drums", "guitar", "piano"),
    source_id=0,
    s_churn=20.0,
    num_resamples=5,
    sigma_min=1e-4,
    sigma_max=1.0,
    num_steps=150,
    batch_size=32,
    resume=True,
):
    """Separate a dataset with ``MSDMSeparator_DDIM`` and the Dirac differential.

    Every keyword defaults to the value this script previously hard-coded, so a
    bare ``main()`` call runs the same job as before. The overrides let the
    same pipeline be reused with other checkpoints, data, or sampler settings.

    Args:
        dataset_path: Passed to ``ChunkedSupervisedDataset`` as ``audio_dir``.
        model_path: Checkpoint loaded with ``Model.load_from_checkpoint``.
        output_dir: Passed to ``separate_dataset`` as ``save_path``; the
            directory is created if it does not exist.
        stems: Stem names, used for both the dataset and the separator.
        source_id: Index along the source axis that ``differential_with_dirac``
            constrains to match the mixture.
        s_churn: Forwarded to ``MSDMSeparator_DDIM``.
        num_resamples: Forwarded to ``MSDMSeparator_DDIM``.
        sigma_min: Lower bound of the ``KarrasSchedule`` (``rho=7.0``).
        sigma_max: Upper bound of the ``KarrasSchedule`` (``rho=7.0``).
        num_steps: Forwarded to ``separate_dataset``.
        batch_size: Forwarded to ``separate_dataset``.
        resume: Forwarded to ``separate_dataset``.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): sample rate and chunk sizes are left fixed; presumably they
    # must match the configuration the checkpoint was trained with — confirm.
    dataset = ChunkedSupervisedDataset(
        audio_dir=dataset_path,
        stems=list(stems),
        sample_rate=22050,
        max_chunk_size=262144,
        min_chunk_size=262144,
    )

    # Pre-bind source_id; the remaining arguments of differential_with_dirac
    # are left for the separator (which receives it as differential_fn).
    diff_fn = functools.partial(differential_with_dirac, source_id=source_id)

    model = Model.load_from_checkpoint(model_path).cuda()
    # Inference only: disable training-mode behaviour (e.g. dropout), as the
    # Lightning docs recommend after load_from_checkpoint.
    model.eval()

    # The baseline sampler main.separation.MSDMSeparator (imported at the top
    # of this file) accepts the same keyword arguments if a comparison is needed.
    separator = MSDMSeparator_DDIM(
        model=model,
        stems=list(stems),
        sigma_schedule=KarrasSchedule(sigma_min=sigma_min, sigma_max=sigma_max, rho=7.0),
        differential_fn=diff_fn,
        s_churn=s_churn,
        num_resamples=num_resamples,
        use_tqdm=True,
    )

    separate_dataset(
        dataset=dataset,
        separator=separator,
        save_path=output_dir,
        num_steps=num_steps,
        batch_size=batch_size,
        resume=resume,
    )





if __name__ == '__main__':
    # Script entry point: run the separation job configured in main().
    main()