import torch
from tqdm import tqdm, trange

def make_batchwise_system_matrix(targets, eps=1e-2):
    """ Create a batch of system matrices for solving for latent in softmax.

    For each row ``t`` of ``targets`` the matrix ``A`` has entries
    ``A[j, k] = t[j] / (1 - t[j] + eps)`` for ``k != j`` and ``A[j, j] = 0``.
    Row ``j`` of ``A @ exp(z)`` is therefore ``t[j] / (1 - t[j] + eps)`` times
    the sum of ``exp(z[k])`` over ``k != j``.

    Args:
        targets: (batch_size, num_classes) target probabilities
        eps: (float) small value to avoid division by zero

    Returns:
        (batch_size, num_classes, num_classes) tensor on the same device as
        ``targets``.
    """
    assert targets.ndim == 2

    num_classes = targets.shape[1]
    # One ratio t_j / (1 - t_j + eps) per (batch, class).
    ratios = targets / (1 - targets + eps)
    # Copy each ratio along its row: A[b, j, k] = ratios[b, j].
    A = ratios.unsqueeze(-1).repeat(1, 1, num_classes)
    # Zero the diagonals in place. Unlike subtracting ``torch.eye(N) * A``,
    # this needs no CPU identity matrix, so it works on any device/dtype.
    # It also does not turn an infinite off-diagonal ratio into NaN via 0 * inf.
    A.diagonal(dim1=-2, dim2=-1).zero_()

    return A

def solve_x_softmax(y, alpha=1., eps=1e-8):
    """ Solve for latent in softmax in closed form via least squares.

    Builds the per-sample system matrix with ``make_batchwise_system_matrix``.
    It then solves ``A @ x = y`` in the least-squares sense for every sample
    and returns ``alpha * log(x)``.

    Args:
        y: (batch_size, num_classes) target softmax outputs
        alpha: (float) scaling factor for latent
        eps: (float) small value to avoid division by zero

    Returns:
        (batch_size, num_classes) latent. Entries are NaN wherever the
        least-squares solution is negative, and -inf where it is zero.
    """
    assert y.ndim == 2

    # Create system matrix
    A = make_batchwise_system_matrix(y, eps=eps)
    # Solve each (num_classes x num_classes) system against y as an explicit
    # column (the documented rhs shape is (*, m, k)), then drop that column.
    x = torch.linalg.lstsq(A, y.unsqueeze(-1)).solution.squeeze(-1)
    return alpha * torch.log(x)

def solve_x_softmax_iteratively(y, prev_z, alpha=1., iters=100, batch=300, eps=1e-8):
    """ Solve for latent in softmax by coordinate-wise fixed-point updates.

    Each update sets column ``j`` of the latent so that
    ``softmax(z / alpha)[:, j] == y[:, j]`` holds given the current values of
    the other columns:

        z_j = alpha * (log y_j + logsumexp_{k != j}(z_k / alpha) - log(1 - y_j))

    Columns are updated in order ``j = 0 .. num_classes - 1``. Each update sees
    the updates made before it. One iteration sweeps every row, processing
    ``batch`` rows at a time.

    Args:
        y: (batch_size, num_classes) target softmax outputs
        prev_z: (batch_size, num_classes) initial latent; it is not modified
        alpha: (float) scaling factor for latent
        iters: (int) number of sweeps over all rows
        batch: (int) number of rows updated together in one chunk
        eps: (float) lower clip applied to y and 1 - y before taking logs

    Returns:
        (batch_size, num_classes) refined latent.

    Raises:
        ValueError: if ``batch`` is smaller than 1.
    """
    assert y.ndim == 2
    assert prev_z.shape == y.shape
    if batch < 1:
        raise ValueError(f"batch must be a positive integer, got {batch}")

    num_rows, z_dim = prev_z.shape
    # The log(y) and log(1 - y) terms do not depend on z; compute them once.
    log_y = torch.log(torch.clip(y, eps))
    log_1m_y = torch.log(torch.clip(1 - y, eps))
    # Column indices k != j for every j, built once instead of on every update.
    all_cols = torch.arange(z_dim, device=prev_z.device)
    other_cols = [all_cols[all_cols != j] for j in range(z_dim)]

    z = prev_z.clone()
    for _ in trange(iters):
        # Every iteration visits all rows; chunking only bounds the size of the
        # temporary (chunk, num_classes - 1) gather below.
        for start in range(0, num_rows, batch):
            rows = slice(start, start + batch)
            for j in range(z_dim):
                z[rows, j] = alpha * (
                    log_y[rows, j] +
                    torch.logsumexp(z[rows, other_cols[j]] / alpha, dim=-1) -
                    log_1m_y[rows, j])

    return z

if __name__ == "__main__":
    # Derivation behind the solvers above.
    # Sigmoid: y = sigmoid(x) inverts to x = log(y) - log(1 - y).
    # Softmax: y_j = e^{z_j} / sum_k e^{z_k} rearranges to
    #     y_j * sum_{k != j} e^{z_k} = (1 - y_j) * e^{z_j}
    # which gives
    #     z_j = log(y_j) - log(1 - y_j) + logsumexp_{k != j} z_k.
    # That is the coordinate update in solve_x_softmax_iteratively, applied
    # with a temperature alpha, i.e. softmax(z / alpha).

    ## Batchwise iterative solve
    batch_size = 1000
    alpha = 0.001
    num = 11
    z_og = torch.rand(batch_size, num) * 5
    z = z_og / alpha
    # Softmax over the class dimension, matching the after-solve check below.
    out = torch.nn.functional.softmax(z, dim=-1)
    # Random, very peaked (nearly one-hot) target distributions.
    target = torch.exp(torch.rand(batch_size, num) * 50)
    target = torch.softmax(target, dim=-1)
    print('Pre-solve', ((target - out)**2).sum(dim=-1).sqrt().sum())

    z = solve_x_softmax_iteratively(target, z_og, alpha=alpha, iters=1000)
    out = torch.nn.functional.softmax(z / alpha, dim=-1)

    print('After-solve', ((target - out)**2).sum(dim=-1).sqrt().sum())