import abc
from pathlib import Path
from typing import List, Optional, Callable, Mapping

import torch
import torchaudio
import tqdm
from math import sqrt, ceil

from audio_diffusion_pytorch.diffusion import Schedule
from torch.utils.data import DataLoader

from main.data import assert_is_audio, SeparationDataset
from main.module_base import Model

class Separator(torch.nn.Module, abc.ABC):
    """Abstract base class for modules that split a mixture into named stems.

    Subclasses must implement :meth:`separate`; the class cannot be
    instantiated directly because of the abstract method.
    """

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def separate(self, mixture, num_steps) -> Mapping[str, torch.Tensor]:
        """Separate ``mixture`` and return a mapping from stem name to tensor.

        Args:
            mixture: The mixture to separate.
            num_steps: Number of sampling steps the implementation should use.

        Returns:
            A mapping whose keys are stem names and whose values are the
            separated tensors.
        """
        ...
    
    
class MSDMSeparator_DDIM(Separator):
    """Separator that delegates to ``separate_mixture_DDIM``.

    Extra keyword arguments given to the constructor are stored and forwarded
    unchanged to ``separate_mixture_DDIM`` on every :meth:`separate` call.
    """

    def __init__(self, model: Model, stems: List[str], sigma_schedule: Schedule, **kwargs):
        super().__init__()
        self.model = model
        self.stems = stems
        self.sigma_schedule = sigma_schedule
        self.separation_kwargs = kwargs

    def separate(self, mixture: torch.Tensor, num_steps:int = 100):
        """Separate ``mixture`` into one single-channel tensor per stem.

        Args:
            mixture: A 3-D tensor; its first and last dimensions are used as
                the batch size and the length of the generated noise.
            num_steps: Value passed to ``sigma_schedule`` to build the sigmas.

        Returns:
            A dict mapping each name in ``self.stems`` to the slice
            ``[:, i:i+1, :]`` of the sampler output, in stem order.
        """
        target_device = self.model.device
        mixture = mixture.to(target_device)
        num_items, _, num_samples = mixture.shape

        # Arguments are built in the same order as they are passed below so
        # that any random-number consumption happens in an unchanged order.
        denoiser = self.model.model.diffusion.denoise_fn
        sigmas = self.sigma_schedule(num_steps, target_device)
        # One noise channel per stem, drawn on the CPU and then moved.
        start_noise = torch.randn(num_items, len(self.stems), num_samples).to(target_device)

        separated = separate_mixture_DDIM(
            mixture=mixture,
            denoise_fn=denoiser,
            sigmas=sigmas,
            noises=start_noise,
            **self.separation_kwargs,
        )

        tracks = {}
        for index, stem in enumerate(self.stems):
            # Slice (rather than index) to keep the channel dimension.
            tracks[stem] = separated[:, index:index + 1, :]
        return tracks


def differential_with_dirac(x, sigma, denoise_fn, mixture, source_id=0):
    """Return the mixture-constrained differential for one sampler update.

    Channel ``source_id`` of ``x`` is overwritten IN PLACE with ``mixture``
    minus the sum of all other channels, so that the channels of ``x`` sum to
    ``mixture`` afterwards. The score ``(x - denoise_fn(x, sigma)) / sigma``
    is then computed and the score of channel ``source_id`` is subtracted from
    every channel (so the returned channel ``source_id`` is zero).

    Args:
        x: Current sample; dimension 1 indexes the sources. Mutated in place.
        sigma: Noise level passed to ``denoise_fn`` and used as the divisor.
        denoise_fn: Callable invoked as ``denoise_fn(x, sigma=sigma)``.
        mixture: Tensor broadcastable to a single channel of ``x``
            (presumably ``[batch, 1, samples]`` -- confirm against callers).
        source_id: Index of the channel that absorbs the mixture constraint.

    Returns:
        A new tensor with the same shape as the score.
    """
    # Sum of every channel except the constrained one.
    others = x.sum(dim=1, keepdim=True) - x[:, [source_id], :]
    x[:, [source_id], :] = mixture - others
    score = (x - denoise_fn(x, sigma=sigma)) / sigma
    # Broadcasting the constrained channel's score over dimension 1 is
    # equivalent to stacking ``score[:, i] - score[:, source_id]`` per source.
    return score - score[:, [source_id]]


@torch.no_grad()
def separate_mixture_DDIM(
    mixture: torch.Tensor,
    denoise_fn: Callable,
    sigmas: torch.Tensor,
    noises: Optional[torch.Tensor],
    differential_fn: Callable = differential_with_dirac,
    s_churn: float = 0.0,  # > 0 to add randomness
    num_resamples: int = 1,
    use_tqdm: bool = False,
    refine_sigma_threshold: float = 0.9,
    refine_steps: int = 40,
    refine_lr: float = 0.01,
    refine_lmbd: float = 1.0,
    refine_beta: float = 0.9,
    verbose: bool = False,
):
    """Separate ``mixture`` by sampling, then refining a clean estimate.

    For every consecutive pair ``(sigma, sigma_next)`` of ``sigmas`` two
    passes are run. While ``sigma > refine_sigma_threshold`` a pass performs a
    guided update ``x += differential_fn(...) * (sigma_next - sigma_hat)``.
    Once ``sigma`` is at or below the threshold, a pass instead runs
    ``refine_steps`` momentum-averaged steps on a clean estimate ``x_0``,
    pulling its channel sum towards ``mixture`` and ``x_0`` towards the
    denoiser-derived prediction. ``x_0`` and the momentum persist across
    passes and sigma steps.

    Args:
        mixture: Mixture tensor, broadcastable against one channel of the
            sample (presumably ``[batch, 1, samples]`` -- confirm).
        denoise_fn: Callable invoked as ``denoise_fn(x, sigma=...)``.
        sigmas: 1-D sequence of noise levels; ``len(sigmas) - 1`` steps run.
        noises: Initial noise, scaled by ``sigmas[0]`` to form the first
            sample. Despite the ``Optional`` annotation, ``None`` is not
            supported.
        differential_fn: Called with keywords ``mixture``, ``x``, ``sigma``
            and ``denoise_fn`` during the guided updates.
        s_churn: Controls the extra noise injected before each pass.
        num_resamples: A guided pass with index ``r`` re-noises the sample
            when ``r < num_resamples - 1``.
        use_tqdm: Wrap the outer loop in a ``tqdm`` progress bar.
        refine_sigma_threshold: Sigma at or below which refinement is used.
        refine_steps: Number of refinement steps per pass.
        refine_lr: Step size of the refinement update.
        refine_lmbd: Weight of the ``x_0_pred - x_0`` consistency term.
        refine_beta: Momentum averaging coefficient.
        verbose: Print the observation/consistency losses every refinement
            step (forces a host sync per step).

    Returns:
        A detached CPU tensor: the refined estimate ``x_0`` if any refinement
        pass ran, otherwise the current sample ``x``.

    Raises:
        ValueError: From ``math.sqrt`` if a refinement pass sees
            ``sigma_hat > 1`` (only possible when ``s_churn > 0``).
    """
    # Set initial noise
    x = sigmas[0] * noises  # presumably [batch_size, num_sources, sample_length]
    progress = tqdm.tqdm if use_tqdm else (lambda iterable: iterable)
    num_intervals = len(sigmas) - 1
    # NOTE(review): the pass count is fixed at 2 and is independent of
    # ``num_resamples`` (which only gates re-noising below). Kept as-is to
    # preserve existing results -- confirm whether this is intentional.
    passes_per_step = 2

    x_0 = None  # clean estimate, created by the first refinement step
    momentum = None
    for i in progress(range(num_intervals)):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        for r in range(passes_per_step):
            # Inject randomness (no-op when s_churn == 0).
            gamma = min(s_churn / num_intervals, 2 ** 0.5 - 1)
            sigma_hat = sigma * (gamma + 1)
            x = x + torch.randn_like(x) * (sigma_hat ** 2 - sigma ** 2) ** 0.5

            if sigma > refine_sigma_threshold:
                # Guided update; ``differential_fn`` may modify ``x`` in place
                # (the default enforces the mixture constraint that way).
                d = differential_fn(
                    mixture=mixture, x=x, sigma=sigma_hat, denoise_fn=denoise_fn
                )
                x = x + d * (sigma_next - sigma_hat)
                # Renoise if not last resample step
                if r < num_resamples - 1:
                    x = x + sqrt(sigma ** 2 - sigma_next ** 2) * torch.randn_like(x)
                continue

            # Refinement of the clean estimate at a fixed noise level.
            signal_scale = sqrt(1 - sigma_hat ** 2)
            for _ in range(refine_steps):
                if x_0 is not None:
                    # Re-noise the current clean estimate to level sigma_hat.
                    x = signal_scale * x_0 + sigma_hat * torch.randn_like(x_0)
                score = (x - denoise_fn(x, sigma=sigma_hat)) / sigma_hat
                x_0_pred = 1 / signal_scale * (x - sigma_hat * score)
                if x_0 is None:
                    x_0 = x_0_pred  # initialise the clean estimate

                # The mixture residual has one channel and broadcasts over
                # however many sources there are.
                residual = -(torch.sum(x_0, dim=1, keepdim=True) - mixture)
                diff = residual + refine_lmbd * (x_0_pred - x_0)
                if momentum is None:
                    momentum = diff
                else:
                    momentum = refine_beta * momentum + (1 - refine_beta) * diff
                # Out-of-place so ``x_0_pred`` is never modified through an
                # alias of ``x_0``.
                x_0 = x_0 + refine_lr * momentum

                if verbose:
                    loss_obs = torch.mean(
                        torch.norm(torch.sum(x_0, dim=1, keepdim=True) - mixture, dim=2)
                    ).item()
                    loss_cons = torch.mean(torch.norm(x_0 - x_0_pred, dim=[1, 2])).item()
                    print('obs loss:{}, cons loss:{}, lr:{}'.format(loss_obs, loss_cons, refine_lr))

    # Without any refinement pass there is no ``x_0``; return the sample.
    result = x if x_0 is None else x_0
    return result.cpu().detach()