import sklearn
import sklearn.cluster
import torch
import autograd

def cluster(data, k, temp, num_iter, init = None, cluster_temp=5):
    """Differentiable soft k-means on the unit sphere.

    Every row of ``data`` is scaled to unit L2 norm. Then ``num_iter`` update
    steps are applied: similarities are dot products between the normalized
    rows and the current centers, responsibilities are a softmax over the
    clusters of ``cluster_temp * similarity``, and each center is replaced by
    the responsibility-weighted mean of the normalized rows. All steps are
    torch operations, so gradients flow back to ``data`` and ``init``.

    Args:
        data: 2-D tensor with one point per row.
        k: number of clusters. Used for the k-means++ initialization; when
            ``init`` is supplied the number of centers is its row count.
        temp: unused; kept so existing call sites keep working.
        num_iter: number of soft k-means update steps.
        init: optional 2-D tensor of initial centers (one per row). When
            ``None``, centers are drawn with scikit-learn's k-means++ seeding
            on the normalized data (requires ``data`` to be convertible with
            ``.detach().numpy()``).
        cluster_temp: multiplier applied to the similarities before the
            softmax; larger values give harder assignments.

    Returns:
        ``(mu, r, dist)`` where ``mu`` holds the final centers, ``dist`` the
        dot-product similarities between normalized rows and ``mu``, and ``r``
        the softmax responsibilities computed from ``dist``.

        NOTE: for backward compatibility, when ``init`` is ``None`` *and*
        ``num_iter == 0`` only the freshly sampled center tensor is returned
        (not a tuple).
    """
    # Normalize rows so they lie on the unit sphere. Broadcasting scales each
    # row by its own norm without materializing an n x n diagonal matrix.
    data = data / torch.norm(data, p=2, dim=1, keepdim=True)

    # Use k-means++ initialization if nothing is provided.
    if init is None:
        data_np = data.detach().numpy()
        sq_norms = (data_np**2).sum(axis=1)
        if hasattr(sklearn.cluster, 'kmeans_plusplus'):
            # Public seeding API; returns (centers, indices).
            centers, _ = sklearn.cluster.kmeans_plusplus(
                data_np, k, x_squared_norms=sq_norms, random_state=None)
        else:
            # Older scikit-learn releases only expose the private helper.
            centers = sklearn.cluster.k_means_._k_init(
                data_np, k, sq_norms, sklearn.utils.check_random_state(None))
        init = torch.tensor(centers, requires_grad=True)
        if num_iter == 0:
            return init

    mu = init
    for _ in range(num_iter):
        # Similarity between every data point and every cluster center.
        dist = data @ mu.t()
        # Cluster responsibilities via softmax over the cluster axis.
        r = torch.softmax(cluster_temp*dist, 1)
        # Total responsibility of each cluster.
        cluster_r = r.sum(dim=0)
        # Responsibility-weighted sum of the points in each cluster.
        cluster_mean = r.t() @ data
        # New centers: weighted sum divided by the total responsibility.
        mu = cluster_mean / cluster_r.unsqueeze(1)

    # Recompute similarities and responsibilities for the final centers.
    dist = data @ mu.t()
    r = torch.softmax(cluster_temp*dist, 1)
    return mu, r, dist


def f(x, mu, cluster_temp):
    """Residual of one soft k-means update, written with autograd's numpy.

    Computes the softmax responsibilities of the rows of ``x`` for the centers
    ``mu`` (using ``cluster_temp * x @ mu.T`` as logits), forms the
    responsibility-weighted mean of ``x`` for each cluster, and returns the
    difference between ``mu`` and those means. The result is zero exactly when
    ``mu`` is unchanged by the update. ``x`` is used as given; no row
    normalization is applied here.

    Args:
        x: 2-D array with one point per row.
        mu: 2-D array with one cluster center per row.
        cluster_temp: multiplier applied to the similarities before the
            softmax.

    Returns:
        Array with the same shape as ``mu`` holding ``mu - updated_mu``.
    """
    # Imported locally so the function can be traced by autograd.jacobian.
    import autograd.numpy as np
    sim = x @ np.transpose(mu)
    # Softmax over clusters, normalized row-wise by broadcasting instead of
    # multiplying with a dense n x n diagonal matrix.
    r = np.exp(cluster_temp*sim)
    r = r / np.sum(r, 1, keepdims=True)
    # Total responsibility of each cluster.
    cluster_r = np.sum(r, 0)
    # Responsibility-weighted sum of the points in each cluster.
    cluster_mean = np.transpose(r) @ x
    # Centers produced by one update step.
    actual_mu = cluster_mean / np.expand_dims(cluster_r, 1)
    return mu - actual_mu


class SoftCluster(torch.autograd.Function):
    """Autograd wrapper around ``cluster`` with a hand-written backward pass.

    The forward pass delegates to ``cluster``. The backward pass uses the
    ``autograd`` package to differentiate the residual function ``f`` with
    respect to the data and the centers, and combines the two Jacobians into
    a gradient with respect to the input data.

    NOTE(review): ``forward``/``backward`` are instance methods and the
    configuration lives on ``self``. This looks like the legacy (non-static)
    ``torch.autograd.Function`` style, which recent PyTorch releases reject
    -- confirm the PyTorch version this is meant to run on.
    """
    
    def __init__(self, k, temp, num_iter, cluster_temp):
        """Store the arguments that ``forward`` passes on to ``cluster``."""
        self.k = k
        self.temp = temp
        self.num_iter = num_iter
        self.cluster_temp = cluster_temp
        
    def forward(self, data):
        """Run ``cluster`` on ``data`` and cache the inputs/outputs on ``self``.

        Returns the ``(mu, r, sim)`` triple produced by ``cluster``.
        """
        # Kept for the backward pass.
        self.data = data
        mu, r, sim =  cluster(data, self.k, self.temp, self.num_iter, cluster_temp = self.cluster_temp) 
        self.mu = mu
        self.r = r
        self.sim = sim
        return mu, r, sim
#        return mu
    
    def backward(self, grad_mu, grad_r, grad_sim):
        """Return the gradient with respect to ``data`` that flows through ``mu``.

        Only ``grad_mu`` is used; ``grad_r`` and ``grad_sim`` are accepted but
        ignored by the active code (the path through ``r`` exists only in the
        disabled code at the bottom of this method).
        """
        r = self.r
        # NOTE(review): this is the raw tensor saved in forward(), whereas
        # cluster() normalizes the rows before clustering -- confirm that f
        # should be evaluated on the unnormalized data.
        x = self.data
        mu = self.mu
        k = self.k
        p = x.shape[1]
        n = x.shape[0]
        
        # Jacobians of the residual f with respect to its argument 0 (x) and
        # argument 1 (mu), built by the autograd package.
        jacobian_f_x = autograd.jacobian(f, 0)
        jacobian_f_mu = autograd.jacobian(f, 1)
        
        # Evaluate on numpy copies and flatten each Jacobian into a matrix:
        # rows index the k*p entries of f, columns the entries of x or mu.
        dfdx = torch.tensor(jacobian_f_x(x.detach().numpy(), mu.detach().numpy(), self.cluster_temp))
        dfdx = dfdx.view(k*p, n*p)
        dfdmu = torch.tensor(jacobian_f_mu(x.detach().numpy(), mu.detach().numpy(), self.cluster_temp))
        dfdmu = dfdmu.view(k*p, k*p)
        
        
        #TODO: figure out where the softmax temperature figures into this
        
        # --- Disabled: hand-derived Jacobians, superseded by the autograd
        # --- calls above. Kept for reference.
        #compute df/dx
#        dsoftmax = -r.unsqueeze(2) @ r.unsqueeze(1)
#        for j in range(n):
#            dsoftmax[j] += torch.diag(r[j])
#        dsim = self.mu.repeat(n, 1, 1)
#        dr = dsoftmax @ dsim
#        x_expand = x.unsqueeze(1).repeat(1, k, 1)
#        R = r.sum(dim = 0)
#        R_expand = R.unsqueeze(1).repeat(n, 1, p)
#        C = (r.t().unsqueeze(1) @ x.expand(k, *x.shape)).squeeze(1)
#        C_expand = C.repeat(n, 1, 1)
#        i_term = (R_expand*x_expand - C_expand)/(R_expand**2)
#        dfdx = i_term.unsqueeze(3).flatten(0,1) @  dr.unsqueeze(2).flatten(0,1)
#        dfdx = dfdx.view(n, k, p, p)
#        diag_term = r/R.repeat(n, 1)
#        for l in range(x.shape[1]):
#            dfdx[:, :, l, l] += diag_term
#        dfdx = -dfdx
#        dfdx = dfdx.permute(1, 2, 0, 3)
#        dfdx = dfdx.flatten(0, 1).flatten(1,2)
#        
#        #compute df/dmu
#        dsoftmax_p = dsoftmax.unsqueeze(3).repeat(1, 1, 1, p)
#        x_expand = x.unsqueeze(1).unsqueeze(1).repeat(1, k, k, 1)
#        dsim = x_expand
#        dr = dsoftmax_p * dsim
#        
#        first_term = x_expand.flatten(0, 2).unsqueeze(2) @ dr.flatten(0, 2).unsqueeze(1)
#        first_term = first_term.view(n, k, k, p, p)
#        first_term = first_term.sum(dim = 0)
#        R_expand = R.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, k, p, p)
#        first_term = R_expand * first_term
#        
#        dr_sum = dr.sum(dim=0)
#        C_expand = C.unsqueeze(1).repeat(1, k, 1)
#        second_term = C_expand.flatten(0, 1).unsqueeze(2) @ dr_sum.flatten(0, 1).unsqueeze(1)
#        second_term = second_term.view(k, k, p, p)
#        denom = R_expand**2
#        
##        print(first_term.shape, second_term.shape, denom.shape)
#        dfdmu = -(first_term - second_term)/denom
#        for i in range(k):
#            for l in range(p):
#                dfdmu[i, i, l, l] += 1
#        dfdmu = dfdmu.transpose(1,2)
#        dfdmu = dfdmu.flatten(0,1).flatten(1,2)
        
        #solve linear system
#        print('do inversion')
        # NOTE(review): the implicit function theorem for f(x, mu(x)) = 0
        # gives dmu/dx = -(df/dmu)^-1 df/dx; no negation is applied here --
        # confirm the intended sign.
        dmudx = dfdmu.inverse() @ dfdx
        dmudx = dmudx.view(k, p, n, p)
        
        # Chain rule through mu only: (1 x k*p) row vector times the
        # (k*p x n*p) matrix, reshaped to the layout of x.
        return (grad_mu.view(1, k*p) @ dmudx.view(k*p, n*p)).view(n, p)
        # --- Disabled: earlier version that also added the gradient path
        # --- through r on a random subset of 100 rows. Unreachable/commented.
##        print('done inversion')
#        #use dmudx to compute the backward pass vector
#        dLdx = torch.zeros_like(x)
#        #path that goes through r (sum so that we don't allocate n^2 memory)
#        batch = np.random.choice(range(n), 100, replace=False)
#        for j in batch:
##            print(j)
##            dsimdx = torch.zeros(k, n, p)
##            for i in range(k):
##                dsimdx[i] = (x[j].unsqueeze(0) @ dmudx[i].view(p, n*p)).view(n, p)
#            dsimdx = (x[j].repeat(k, 1).unsqueeze(1) @ dmudx.view(k, p, n*p)).view(k, n, p)
#            dsimdx[:, j, :] += mu
#            drdx = dsoftmax[j] @ dsimdx.flatten(1,2)
#            dLdx += (grad_r[j].unsqueeze(0) @ drdx).view(n, p)
#        #path that goes through mu
#        dLdx += (grad_mu.view(1, k*p) @ dmudx.view(k*p, n*p)).view(n, p)
        
#        return dLdx
  
if __name__ == '__main__':
    # Demo: optimize the positions of 2-D points with Adam so that the soft
    # cluster centers returned by cluster() move towards random targets.
    # These imports are only needed by the demo, so they live here.
    import sklearn.datasets
    import matplotlib.pyplot as plt

    k = 5
    blobs, _ = sklearn.datasets.make_blobs(n_samples=100, n_features=2, centers=5)
    # Leaf tensor holding the points; this is what the optimizer updates.
    data = torch.tensor(blobs, dtype=torch.float32, requires_grad=True)
    targets = torch.rand(k, 2)*5
    optimizer = torch.optim.Adam([data])
    loss_fn = torch.nn.MSELoss()
    mu = None

    # Fixed axis limits covering both the initial points and the targets.
    xmin = min(data[:, 0].min().item(), targets[:, 0].min().item())
    xmax = max(data[:, 0].max().item(), targets[:, 0].max().item())
    ymin = min(data[:, 1].min().item(), targets[:, 1].min().item())
    ymax = max(data[:, 1].max().item(), targets[:, 1].max().item())

    for t in range(10000):
        # Warm-start from the previous centers, detached so that gradients do
        # not flow across optimization steps.
        init = mu.detach() if mu is not None else None
        # cluster() returns (mu, r, dist); only the centers are needed here.
        mu, _, _ = cluster(data, k, None, 10, init)
        loss = loss_fn(mu, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if t % 500 == 0:
            print(loss.item())
            X = data.detach().numpy()
            Y = targets.detach().numpy()
            Z = mu.detach().numpy()
            # A new figure per snapshot; nothing here calls plt.show(), so
            # display depends on the environment (e.g. an interactive backend).
            plt.figure()
            plt.scatter(X[:, 0], X[:, 1])
            plt.scatter(Y[:, 0], Y[:, 1], s=200)
            plt.scatter(Z[:, 0], Z[:, 1], s=200)
            plt.xlim(xmin, xmax)
            plt.ylim(ymin, ymax)
        
        