from functools import partial
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch import nn, einsum
from random import random
from einops.layers.torch import Rearrange
import torch.nn.functional as F
from einops import rearrange, reduce, repeat


from algorithms.diffusion_forcing.models.unet import Unet
from algorithms.diffusion_forcing.models.gru import Resnet2dGRUCell, Conv2dGRUCell
from algorithms.diffusion_forcing.models.resnet import ResBlock2d


class DFUnet(Unet):
    """Unet conditioned on a learned initial latent, with an optional GRU update.

    The parent ``Unet`` is configured with ``channels=x_channel``,
    ``out_dim=z_channel`` and ``z_cond_dim=z_channel``. A learnable tensor
    ``init_h`` of shape ``(z_channel, 64, 64)`` is used as the latent
    conditioning for every sample in the batch.
    """

    def __init__(
        self,
        z_channel=16,
        x_channel=3,
        external_cond_dim=None,
        network_size=32,
        num_gru_layers=1,
        self_condition=False,
    ):
        """Build the Unet, the optional GRU cell and the learned initial latent.

        Args:
            z_channel: Number of latent channels; used as the parent's
                ``out_dim`` and ``z_cond_dim`` and as the GRU cell width.
            x_channel: Number of input channels (parent's ``channels``).
            external_cond_dim: Forwarded unchanged to the parent ``Unet``.
            network_size: Forwarded as the parent's first positional argument.
            num_gru_layers: ``0``/``None`` disables the GRU update, ``1``
                enables a single ``Conv2dGRUCell``. Larger values are not
                supported.
            self_condition: Forwarded unchanged to the parent ``Unet``.

        Raises:
            NotImplementedError: If ``num_gru_layers`` is greater than 1.
        """
        # Validate up front so an unsupported configuration fails before any
        # sub-module is allocated. The truthiness guard keeps this check
        # consistent with the other ``num_gru_layers`` tests below, which
        # already treat ``None``/``0`` as "no GRU".
        if num_gru_layers and num_gru_layers > 1:
            raise NotImplementedError("num_gru_layers > 1 is not implemented yet for DFUnet.")

        super().__init__(
            network_size,
            channels=x_channel,
            out_dim=z_channel,
            external_cond_dim=external_cond_dim,
            z_cond_dim=z_channel,
            self_condition=self_condition,
        )
        self.z_channel = z_channel
        self.x_channel = x_channel
        self.num_gru_layers = num_gru_layers
        self.self_condition = self_condition
        self.gru = Conv2dGRUCell(z_channel, z_channel) if num_gru_layers else None
        # Learned initial latent shared by all batch elements.
        # NOTE(review): the 64x64 spatial size is hard-coded; presumably it
        # must match the spatial size of the inputs - confirm against Unet.
        self.init_h = nn.Parameter(torch.randn([self.z_channel, 64, 64]), requires_grad=True)

    def forward(self, x, t, x_self_cond=None):
        """Run the parent Unet conditioned on ``init_h``, then the optional GRU.

        Args:
            x: Input tensor; only ``x.shape[0]`` (batch size) is read here.
            t: Forwarded unchanged to the parent ``forward``.
            x_self_cond: Optional value forwarded unchanged to the parent
                ``forward``.

        Returns:
            The parent Unet output, passed through ``self.gru`` together with
            the expanded ``init_h`` when a GRU cell is enabled.
        """
        # Broadcast the learned latent across the batch without copying:
        # (z_channel, 64, 64) -> (batch, z_channel, 64, 64).
        z_cond = self.init_h.expand(x.shape[0], -1, -1, -1)
        z_next = super().forward(x, t, z_cond, x_self_cond)
        if self.num_gru_layers:
            # The Unet output is the first argument and the initial latent the
            # second; presumably (input, hidden) - confirm in Conv2dGRUCell.
            z_next = self.gru(z_next, z_cond)

        return z_next


class DFUnetWrapper(nn.Module):
    """Pair a ``DFUnet`` with a small head mapping its output to ``x_channel`` channels.

    The head is a ``ResBlock2d(z_channel, x_channel)`` followed by a 1x1
    convolution. ``channels``, ``out_dim`` and ``self_condition`` are
    mirrored from the wrapped ``DFUnet``.
    """

    def __init__(
        self,
        z_channel=16,
        x_channel=3,
        external_cond_dim=None,
        network_size=32,
        num_gru_layers=1,
        self_condition=False,
    ):
        """Construct the wrapped ``DFUnet`` and the latent-to-``x`` head.

        All arguments are forwarded to ``DFUnet``; ``z_channel`` and
        ``x_channel`` additionally size the output head.
        """
        super().__init__()

        # The backbone is built first, then the head, so module registration
        # and parameter initialisation happen in that order.
        self.unet = DFUnet(
            z_channel=z_channel,
            x_channel=x_channel,
            external_cond_dim=external_cond_dim,
            network_size=network_size,
            num_gru_layers=num_gru_layers,
            self_condition=self_condition,
        )

        head_block = ResBlock2d(z_channel, x_channel)
        head_proj = nn.Conv2d(x_channel, x_channel, kernel_size=1, padding=0)
        self.x_from_z = nn.Sequential(head_block, head_proj)

        # Expose the backbone's attributes on the wrapper. ``out_dim`` is the
        # backbone's *input* channel count because the head maps back to it.
        backbone = self.unet
        self.channels = backbone.channels
        self.out_dim = backbone.channels
        self.self_condition = backbone.self_condition

    def forward(self, x, t, x_self_cond=None):
        """Return ``x_from_z(unet(x, t, x_self_cond))``."""
        return self.x_from_z(self.unet(x, t, x_self_cond))
