import random
from tqdm import tqdm
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists
from torch.utils.data import DataLoader

from utils import Gaussian
from models import MLP


def pretrain_gaussian(num_components=8, num_epochs=1000, learning_rate=0.01, batch_size=10):
    """Create a Gaussian mixture and pull its samples toward the origin.

    Each iteration draws ``batch_size`` reparameterized samples from the mixture
    and takes one Adam step on the sum of their squared coordinates. Because the
    component index is sampled (not differentiable), only the means and log
    standard deviations receive gradients from this objective.

    Args:
        num_components: Number of mixture components.
        num_epochs: Number of optimization steps (one batch per step).
        learning_rate: Adam learning rate.
        batch_size: Samples drawn per step.

    Returns:
        The optimized ``MixtureOfGaussians2DReparam`` instance.
    """
    mixture = MixtureOfGaussians2DReparam(num_components=num_components)
    opt = optim.Adam(mixture.parameters(), lr=learning_rate)

    for _ in range(num_epochs):
        opt.zero_grad()
        squared_norm_total = mixture(batch_size).pow(2).sum()
        squared_norm_total.backward()
        opt.step()

    return mixture


class MixtureOfGaussians2DReparam(nn.Module):
    """Learnable mixture of axis-aligned (diagonal) Gaussians over 2-D points.

    ``forward``/``sample`` draw points with the reparameterization trick for the
    component means and standard deviations. The component index itself is
    drawn from a categorical distribution, which is not differentiable, so
    ``logits`` receive gradients only through :meth:`log_prob`.
    """

    def __init__(self, num_components=8):
        super(MixtureOfGaussians2DReparam, self).__init__()
        self.num_components = num_components
        # (num_components, 2) component means, initialized from a standard normal.
        self.means = nn.Parameter(torch.randn(num_components, 2))
        # (num_components, 2) log standard deviations; optimizing in log space
        # keeps the std positive. Initialized to 1, i.e. std = e at start.
        self.log_stds = nn.Parameter(torch.ones(num_components, 2))
        # Unnormalized mixture weights; equal initial values give a uniform mixture.
        self.logits = nn.Parameter(torch.ones(num_components))

    def forward(self, num_samples):
        """Draw ``num_samples`` reparameterized samples, shape ``(num_samples, 2)``."""
        stds = torch.exp(self.log_stds)
        categorical = dists.Categorical(logits=self.logits)
        indices = categorical.sample((num_samples,))  # non-differentiable component choice

        selected_means = self.means[indices]
        selected_stds = stds[indices]
        # randn_like creates the noise directly with the parameters' device and
        # dtype (no host allocation + copy, no silent float32 noise for a
        # float64 model).
        eps = torch.randn_like(selected_means)
        return selected_means + eps * selected_stds  # reparameterization trick

    def log_prob(self, x):
        """Return the mixture log-density of ``x``.

        ``x`` must have a trailing dimension of size 2; the result has shape
        ``x.shape[:-1]``. Differentiable w.r.t. means, log_stds, logits and ``x``.
        """
        stds = torch.exp(self.log_stds)
        categorical = dists.Categorical(logits=self.logits)
        # Independent(..., 1) treats the 2 coordinates as one event (diagonal covariance).
        components = dists.Independent(dists.Normal(self.means, stds), 1)
        mixture = dists.MixtureSameFamily(categorical, components)
        return mixture.log_prob(x)

    def sample(self, num_samples=1):
        """Alias for :meth:`forward` (gradients are NOT disabled here)."""
        return self.forward(num_samples)


class RewardModel:
    """Scalar reward network trained from pairwise preferences (Bradley-Terry loss)."""

    def __init__(
        self, input_dim=2, hidden_dim=64, output_dim=1,
        num_layers=3, context_dim=None, device='cpu'
        ) -> None:
        """Build an ``MLP`` with the given sizes and move it to ``device``.

        All size arguments are forwarded unchanged to ``MLP``.
        """
        self.device = device
        self.model = MLP(
            input_dim=input_dim, output_dim=output_dim,
            hidden_dim=hidden_dim, num_layers=num_layers,
            context_dim=context_dim,
        ).to(self.device)

    def __call__(self, x):
        """Return ``self.model(x)``; ``x`` is not moved to ``self.device`` here."""
        return self.model(x)

    def fit(self, dataset, num_epochs=100, learning_rate=1e-3, batch_size=64):
        """Train the reward network on preference triples ``(x, y1, y0)``.

        The loss is binary cross-entropy on ``r(y1) - r(y0)`` with target 1, i.e.
        it pushes the reward of ``y1`` above that of ``y0``. The first element
        ``x`` of each triple is not used.
        NOTE(review): ``context_dim`` is accepted by ``__init__`` but no context
        is passed to the model here -- confirm this is intended.

        Args:
            dataset: Map- or iterable-style dataset yielding ``(x, y1, y0)``.
            num_epochs: Passes over the dataset.
            learning_rate: Adam learning rate.
            batch_size: DataLoader batch size.

        Returns:
            The trained network, left in eval mode.

        Raises:
            ValueError: If ``dataset`` reports a length of 0 (previously this
                crashed with an unbound ``loss`` after the first epoch).
        """
        self.model.train()
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        dataloader = DataLoader(
            dataset, batch_size=batch_size, shuffle=True
        )

        try:
            is_empty = len(dataset) == 0
        except TypeError:  # iterable-style dataset without __len__
            is_empty = False
        if is_empty:
            raise ValueError("RewardModel.fit received an empty dataset")

        with tqdm(range(num_epochs)) as pbar:
            for _ in pbar:
                total_loss = 0.0
                num_batches = 0
                for _context, y_preferred, y_rejected in dataloader:
                    r_preferred = self.model(y_preferred.to(self.device))
                    r_rejected = self.model(y_rejected.to(self.device))
                    logits = r_preferred - r_rejected
                    # ones_like already matches logits' device and dtype.
                    loss = criterion(logits, torch.ones_like(logits))

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    # Accumulate on-device; one host sync per epoch, not per batch.
                    total_loss = total_loss + loss.detach()
                    num_batches += 1

                if num_batches:
                    pbar.set_description(f"loss {float(total_loss) / num_batches:0.3f}")

        self.model.eval()
        return self.model


class RL:
    """Fine-tune a copy of ``actor`` to maximize a fixed ``critic``'s reward.

    The actor must be callable as ``actor(batch_size)``, returning samples that
    are differentiable w.r.t. its parameters, and must provide ``log_prob`` and
    ``sample``. The objective is ``-mean(reward) + beta * KL(actor || reference)``,
    where ``reference`` is a frozen copy of the initial actor.
    """

    def __init__(self, actor, critic, device='cpu'):
        self.actor = deepcopy(actor)
        self.critic = deepcopy(critic)
        # Frozen snapshot of the initial actor: its parameters never get
        # gradients, but gradients still flow through its log_prob into the
        # actor's samples (needed for the pathwise KL gradient).
        self.reference = deepcopy(actor).requires_grad_(False)
        self.device = device

    def train(self, num_epochs=200, batch_size=64, learning_rate=1e-3, beta=0.):
        """Run ``num_epochs`` optimization steps on freshly drawn actor samples.

        Args:
            num_epochs: Number of steps (one batch of samples per step).
            batch_size: Samples drawn from the actor per step.
            learning_rate: Adam learning rate for the actor.
            beta: Weight of the KL(actor || reference) penalty.

        Returns:
            ``(actor, losses)``: the trained actor (in eval mode) and the
            per-step loss values.
        """
        self.actor.to(self.device)
        self.critic.model.to(self.device)
        # The critic is fixed during RL: keep gradients out of its parameters
        # while still letting them flow through its input (the actions).
        self.critic.model.requires_grad_(False)
        self.reference.to(self.device)
        self.actor.train()
        optimizer = optim.Adam(self.actor.parameters(), lr=learning_rate)

        losses = []
        with tqdm(range(num_epochs)) as pbar:
            for _ in pbar:
                action = self.actor(batch_size)
                reward = self.critic(action)

                # Monte-Carlo estimate of KL(actor || reference) = E_{a~actor}[log pi(a) - log ref(a)].
                # Both densities must be evaluated at the SAME actor samples; scoring
                # independent reference samples would make the reference term a
                # constant and reduce the penalty to an entropy bonus.
                kl = (self.actor.log_prob(action) - self.reference.log_prob(action)).mean()
                loss = -torch.mean(reward) + beta * kl

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                losses.append(loss_value)
                pbar.set_description(f"loss {loss_value:0.3f}")

        self.actor.eval()
        return self.actor, losses

    def sample(self, num_samples=1):
        """Draw ``num_samples`` samples from the actor in eval mode without autograd."""
        self.actor.eval()
        with torch.no_grad():
            samples = self.actor.sample(num_samples=num_samples)
        return samples


class DPO:
    """Direct Preference Optimization of a density model against a fixed reference.

    ``policy`` and ``reference`` must provide ``log_prob``; ``policy`` must also
    provide ``sample``. Both are deep-copied, so the caller's models are untouched.
    """

    def __init__(self, policy, reference, device='cpu'):
        self.device = device
        self.policy = deepcopy(policy)
        self.reference = deepcopy(reference)

    @staticmethod
    def _preference_loss(policy_logratios, reference_logratios, beta, label_smoothing):
        """Return the batch-mean (label-smoothed) DPO loss.

        With ``logits = policy_logratios - reference_logratios`` the loss is
        ``-(1 - ls) * logsigmoid(beta * logits) - ls * logsigmoid(-beta * logits)``.
        """
        logits = policy_logratios - reference_logratios
        return (
            -F.logsigmoid(beta * logits) * (1 - label_smoothing)
            - F.logsigmoid(-beta * logits) * label_smoothing
        ).mean()

    def fit(self, dataset, num_epochs=100, batch_size=64, learning_rate=1e-3, beta=0.01, label_smoothing=0.):
        """Optimize the policy on preference triples ``(x, y1, y0)`` (``y1`` preferred).

        The first element ``x`` of each triple is not used.

        Args:
            dataset: Dataset yielding ``(x, y1, y0)``; must support ``len``.
            num_epochs: Passes over the dataset.
            batch_size: DataLoader batch size.
            learning_rate: Adam learning rate for the policy.
            beta: DPO inverse temperature applied to the log-ratio difference.
            label_smoothing: Probability mass assigned to the flipped preference.

        Returns:
            ``(policy, losses)``: the trained policy (in eval mode) and the
            per-batch loss values.
        """
        self.policy.to(self.device)
        self.reference.to(self.device)
        # fit()/sample() leave the policy in eval mode; re-enter train mode so
        # repeated fit() calls behave the same as the first one.
        self.policy.train()
        self.reference.eval()

        optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
        dataloader = DataLoader(
            dataset, batch_size=batch_size, shuffle=True
        )
        losses = []
        with tqdm(total=num_epochs * len(dataloader)) as pbar:
            for _ in range(num_epochs):
                for _context, y_chosen, y_rejected in dataloader:
                    y_chosen = y_chosen.to(self.device)
                    y_rejected = y_rejected.to(self.device)

                    policy_logratios = self.policy.log_prob(y_chosen) - self.policy.log_prob(y_rejected)
                    # The reference is fixed: no autograd graph is needed, which saves
                    # memory and keeps gradients from accumulating in its parameters.
                    with torch.no_grad():
                        reference_logratios = (
                            self.reference.log_prob(y_chosen) - self.reference.log_prob(y_rejected)
                        )

                    loss = self._preference_loss(
                        policy_logratios, reference_logratios, beta, label_smoothing
                    )

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    loss_value = loss.item()
                    losses.append(loss_value)
                    pbar.update(1)
                    pbar.set_description(f"loss {loss_value:0.3f}")

        self.policy.eval()
        return self.policy, losses

    def sample(self, num_samples=1):
        """Draw ``num_samples`` samples from the policy in eval mode without autograd."""
        self.policy.eval()
        with torch.no_grad():
            samples = self.policy.sample(num_samples=num_samples)
        return samples