import torch.utils
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torchdyn
import torchdiffeq
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons
from torchcfm.conditional_flow_matching import *
from torchcfm.utils import *


class NodeWrapper(nn.Module):
    """Adapt a ``model(t, x, context)`` vector field to torchdyn's ``f(t, x)`` call.

    The wrapped model is always called with the fixed ``context`` given at
    construction and with ``t`` repeated to one value per batch row.

    Args:
        model: Module called as ``model(t, x, context)``.
        context: Optional fixed context tensor. A 1-D context (a single context
            vector) gets a leading dimension of size 1 so it can broadcast
            against the batch; any other context is stored unchanged.
    """

    def __init__(
            self,
            model,
            context=None,
        ) -> None:
        super().__init__()
        self.model = model
        if context is not None and context.dim() == 1:
            # Single context vector -> (1, d). Already-batched contexts are left
            # alone, matching how the torchdiffeq path passes context through.
            context = context[None, :]
        self.context = context

    def forward(self, t, x, *args, **kwargs):
        """Evaluate the wrapped model at time ``t`` and state ``x``.

        Extra positional/keyword arguments supplied by the solver are ignored.
        A 1-D ``x`` is treated as a single unbatched sample and the leading
        dimension added for the model call is removed again from the output;
        batched input is returned with its batch dimension intact.
        """
        unbatched = x.dim() == 1
        if unbatched:
            x = x[None, :]
        # One time value per batch row.
        t = t.repeat(x.shape[0])
        out = self.model(t, x, self.context)
        # Only drop the dimension we added ourselves: a real batch of size 1
        # must keep its shape so the solver state shape stays consistent.
        return out.squeeze(0) if unbatched else out


class OptimalTransportConditionalFlowMatching():
    """
    PyTorch implementation of the Optimal Transport Conditional Flow Matching algorithm.
    Learns an ODE flow from source to target trajectories, given context.
    Wrapper class for the torchcfm library.
    (Tong, Alexander, et al. 2023)

    The input model should be a neural network that takes input (t, y, x) and outputs v:
        t: time (during ODE integration of a batched y, one value per row)
        y: target trajectory
        x: context
        v: vector field

    In case of unconditional flow matching, the context is not used:
    the model is still called as model(t, y, x), but with x=None, so it must
    accept (and ignore) a None context.
    For such a model, set conditional=False in the fit method.
    """
    def __init__(self, model, device='cpu'):
        # The model is moved to `device` once; inputs are moved per call.
        self.device = device
        self.model = model.to(self.device)

    def __call__(self, source, context=None, **kwargs):
        """Shorthand for :meth:`compute_target`."""
        return self.compute_target(source, context=context, **kwargs)

    def fit(
            self,
            dataset: torch.utils.data.Dataset,
            num_epochs: int = 100,
            batch_size: int = 64,
            learning_rate: float = 1e-3,
            sigma: float = 0.1,
            conditional: bool = False,
            loss_scale=1,
            collate_fn=None,
            save_path=None,
            save_losses_path=None,
            save_interval=10,
        ):
        """Train the vector-field model with exact-OT conditional flow matching.

        Each batch produced by the DataLoader must unpack as ``(x, y1, y0)``:
        context, target sample and source sample respectively.

        Args:
            dataset: Dataset whose batches unpack as ``(x, y1, y0)``.
            num_epochs: Number of passes over the dataset.
            batch_size: DataLoader batch size (batches are shuffled).
            learning_rate: Adam learning rate.
            sigma: Passed to ``ExactOptimalTransportConditionalFlowMatcher``.
            conditional: If True the context ``x`` goes through torchcfm's
                guided sampler and the returned context is given to the model;
                otherwise the model receives ``None`` as context.
            loss_scale: Multiplier applied to the squared error before the mean.
            collate_fn: Optional DataLoader ``collate_fn``.
            save_path: If set, the whole model is written with ``torch.save``.
            save_losses_path: If set, the per-step loss list is written with
                ``np.save``.
            save_interval: Checkpoint whenever the within-epoch batch index is
                a multiple of this value (including batch 0 of every epoch).
                A final checkpoint is always written after training. Must be
                positive when a save path is given.

        Returns:
            ``(model, losses)``: the trained model (left in eval mode) and the
            list of per-step loss values.

        Raises:
            ValueError: If a save path is set and ``save_interval <= 0``.
        """
        saving = save_path is not None or save_losses_path is not None
        if saving and save_interval <= 0:
            raise ValueError(f"save_interval must be positive, got {save_interval}")

        dataloader = DataLoader(
            dataset, batch_size=batch_size,
            shuffle=True, collate_fn=collate_fn,
        )
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        flow = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)

        # training loop
        self.model.train()
        losses = []
        # Context manager guarantees the progress bar is closed, even on error.
        with tqdm(
            total=num_epochs * len(dataloader),
            desc=f'Epoch 0 / {num_epochs}, Loss: ---',
        ) as pbar:
            for epoch in range(num_epochs):
                for i, (x, y1, y0) in enumerate(dataloader):
                    optimizer.zero_grad()
                    y1 = y1.to(self.device)
                    y0 = y0.to(self.device)

                    if conditional:
                        x = x.to(self.device)
                        # Context is passed as the y1 labels; the 5th return value
                        # (presumably re-ordered to match the OT pairing by torchcfm)
                        # is what the model sees as context.
                        t, yt, ut, _, x1 = flow.guided_sample_location_and_conditional_flow(
                            y0, y1, y1=x
                        )
                        x1 = x1.to(self.device)
                    else:
                        t, yt, ut = flow.sample_location_and_conditional_flow(y0, y1)
                        x1 = None

                    t = t.to(self.device)
                    yt = yt.to(self.device)
                    ut = ut.to(self.device)

                    vt = self.model(t, yt, x1)

                    loss = torch.mean((vt - ut) ** 2 * loss_scale)
                    losses.append(loss.item())
                    loss.backward()
                    optimizer.step()

                    # Diagnostics only; detach so no graph is kept for them.
                    vt_sq = vt.detach() ** 2
                    ut_sq = ut.detach() ** 2
                    pbar.update(1)
                    pbar.set_description(
                        f'Epoch {epoch + 1} / {num_epochs}, Loss: {losses[-1]:.5f}, '
                        f'vt^2 (mean): {vt_sq.mean().item():.3f}, '
                        f'vt^2 (max): {vt_sq.max().item():.3f}, '
                        f'ut^2 (mean): {ut_sq.mean().item():.3f}, '
                        f'ut^2 (max): {ut_sq.max().item():.3f}'
                    )

                    # `saving` is checked first so the modulo never runs when
                    # nothing is being saved.
                    if saving and i % save_interval == 0:
                        self._save_checkpoint(save_path, save_losses_path, losses)

        if saving:
            # The periodic check may not fire on the last step; always persist
            # the final state.
            self._save_checkpoint(save_path, save_losses_path, losses)

        self.model.eval()
        return self.model, losses

    def _save_checkpoint(self, save_path, save_losses_path, losses):
        """Write the model and/or the loss history to whichever paths are set."""
        if save_path is not None:
            torch.save(self.model, save_path)
        if save_losses_path is not None:
            np.save(save_losses_path, np.array(losses))

    @staticmethod
    def _expand_time(t, y):
        """Repeat a scalar ODE time ``t`` to one value per row of a batched ``y``.

        For a 1-D ``y`` (a single unbatched state) ``t`` is returned unchanged.
        Mirrors ``NodeWrapper.forward`` so both ODE backends feed the model
        the same time shape.
        """
        if y.dim() > 1:
            return t.repeat(y.shape[0])
        return t

    def _compute_trajectory(
            self,
            source,
            context=None,
            use_torchdiffeq=True,
            num_steps=2,
        ):
        """Integrate the learned ODE from t=0 to t=1 starting at ``source``.

        Args:
            source: Initial state(s) at t=0; moved to ``self.device``.
            context: Optional context for the model; moved to ``self.device``.
            use_torchdiffeq: Use ``torchdiffeq.odeint`` if True, otherwise
                torchdyn's ``NeuralODE``.
            num_steps: Number of evenly spaced output times in [0, 1]
                (endpoints included). Must be at least 2.

        Returns:
            The solver output, with dim 0 indexing the output times (callers
            use ``traj[0]`` as the source and ``traj[-1]`` as the t=1 state).

        Raises:
            ValueError: If ``num_steps < 2``.
        """
        if num_steps < 2:
            raise ValueError(f"num_steps must be at least 2, got {num_steps}")
        source = source.to(self.device)
        if context is not None:
            context = context.to(self.device)
        t_span = torch.linspace(0, 1, num_steps, device=self.device)

        self.model.eval()
        with torch.no_grad():
            if use_torchdiffeq:
                def vector_field(t, y):
                    # Give the model one time value per row, as NodeWrapper
                    # does for the torchdyn path.
                    return self.model(self._expand_time(t, y), y, context)

                return torchdiffeq.odeint(
                    vector_field,
                    source,
                    t_span,
                    atol=1e-4,
                    rtol=1e-4,
                    method='dopri5',
                )
            node = NeuralODE(
                NodeWrapper(self.model, context=context),
                solver='dopri5', sensitivity='adjoint',
                atol=1e-4, rtol=1e-4,
            )
            return node.trajectory(source, t_span=t_span)

    def compute_target(
            self,
            source,
            context=None,
            use_torchdiffeq=True
        ):
        """Flow ``source`` to t=1 and return the final state (last trajectory point)."""
        traj = self._compute_trajectory(
            source, context=context,
            use_torchdiffeq=use_torchdiffeq,
        )
        return traj[-1]

    def plot_trajectories(
            self,
            source,
            context=None,
            use_torchdiffeq=True,
            num_steps=100,
            max_points=2000,
        ):
        """Integrate the flow from ``source`` and scatter-plot dims 0 and 1.

        Args:
            source: Initial states; assumed batched with at least two feature
                dimensions (only the first two are plotted).
            context: Optional context for the model.
            use_torchdiffeq: ODE backend selector (see ``_compute_trajectory``).
            num_steps: Time points per trajectory. With only 2, the "Flow"
                layer would coincide with the source and target points.
            max_points: Maximum number of trajectories drawn.
        """
        traj = self._compute_trajectory(
            source, context=context,
            use_torchdiffeq=use_torchdiffeq,
            num_steps=num_steps,
        ).cpu().numpy()

        n = max_points
        plt.figure(figsize=(6, 6))
        plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black")
        plt.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.2, alpha=0.2, c="olive")
        plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue")
        plt.legend(["Source p0", "Flow", "Target p1"])
        plt.xticks([])
        plt.yticks([])
        plt.show()