import torch
import torch.nn as nn
import torch.nn.functional as F
import tensorly as tl
from tensorly.decomposition import matrix_product_state
from timm.models import register_model


class TTMLP(nn.Module):
    """Three-layer MLP whose middle weight matrix is kept in tensor-train form.

    The first and last layers use dense weight matrices. The middle
    ``hidden_size1 x hidden_size2`` matrix is drawn densely once at
    construction time, reshaped to ``shape``, decomposed with
    ``matrix_product_state`` and then stored only as its tensor-train cores.
    On every forward pass the cores are contracted back into a full matrix.

    Args:
        input_size: Number of features per sample after flattening.
        hidden_size1: Width of the first hidden layer.
        hidden_size2: Width of the second hidden layer.
        output_size: Number of output features.
        ranks: Tensor-train ranks forwarded to ``matrix_product_state``.
        shape: Tensorized shape of the middle weight; its element count must
            equal ``hidden_size1 * hidden_size2`` or the reshape fails.

    NOTE(review): the decomposition is called on a torch tensor, so tensorly's
    backend presumably has to be 'pytorch' before constructing this module
    (the ``tt_mlp`` factory in this file sets it) -- confirm for other callers.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2, output_size, ranks, shape):
        super().__init__()

        self.w1 = nn.Parameter(torch.randn(input_size, hidden_size1) * 0.01)
        self.b1 = nn.Parameter(torch.zeros(hidden_size1))

        # Dense initial weight, only used to seed the tensor-train cores.
        w2 = torch.randn(hidden_size1, hidden_size2) * 0.01
        self.orig_shape = (hidden_size1, hidden_size2)
        self.shape = shape
        factors = matrix_product_state(w2.reshape(self.shape), rank=ranks)
        # The cores must live in an nn.ParameterList: a plain Python list is
        # invisible to nn.Module, so the cores would be absent from
        # parameters()/state_dict(), ignored by .to(device), and never
        # updated by an optimizer built from model.parameters().
        self.factors = nn.ParameterList([nn.Parameter(fac) for fac in factors])
        self.b2 = nn.Parameter(torch.zeros(hidden_size2))

        self.w3 = nn.Parameter(torch.randn(hidden_size2, output_size) * 0.01)
        self.b3 = nn.Parameter(torch.zeros(output_size))

    def forward(self, x):
        """Flatten ``x`` per sample and apply the three affine layers.

        Args:
            x: Input whose first dimension is the batch; the remaining
                dimensions are flattened and must total ``input_size``.

        Returns:
            Tensor of shape ``(batch, output_size)`` (no final activation).
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        x_flat = x.reshape(x.shape[0], -1)
        a1 = F.relu(torch.mm(x_flat, self.w1) + self.b1)

        # Contract the cores back into the full tensor, then fold it to the
        # (hidden_size1, hidden_size2) matrix used by the middle layer.
        # list(...) hands tensorly an ordinary sequence of tensors.
        full = tl.tt_to_tensor(list(self.factors))
        w2 = full.reshape(self.orig_shape)
        a2 = F.relu(a1 @ w2 + self.b2)

        return torch.mm(a2, self.w3) + self.b3


@register_model
def tt_mlp(hidden_size, **_):
    """Build a ``TTMLP`` with 784 inputs, 10 outputs and two equal hidden layers.

    Switches tensorly to its PyTorch backend before constructing the model.
    The tensor-train shape and ranks are fixed here; the shape holds
    10 * 2 * 10 * 2 * 25 = 10000 elements, so ``hidden_size * hidden_size``
    must equal 10000 (i.e. ``hidden_size == 100``) for the reshape inside
    ``TTMLP`` to succeed.

    Args:
        hidden_size: Width used for both hidden layers.
        **_: Any further keyword arguments are accepted and ignored.

    Returns:
        The constructed ``TTMLP`` instance.
    """
    tt_shape = (10, 2, 10, 2, 25)
    tt_ranks = [1, 5, 5, 5, 10, 1]

    # The backend must be set before TTMLP runs its decomposition.
    tl.set_backend('pytorch')

    return TTMLP(
        input_size=784,
        hidden_size1=hidden_size,
        hidden_size2=hidden_size,
        output_size=10,
        ranks=tt_ranks,
        shape=tt_shape,
    )
