import torch
from torch.nn import functional as F
import numpy as np

from GaussianConv import GaussianConv
from WeightMatrix import WeightMatrix

import matplotlib.pyplot as plt
import seaborn as sns
# apply seaborn's "dark" plotting style globally for the figures drawn by the visualize_* methods below
sns.set_style("dark")

class NeuralSheet:
    """2-D sheet of stochastic integrate-and-fire neurons with STDP learning.

    All per-neuron buffers are flat vectors of length ``prod(shape)``.  The
    input stimulus occupies the first rows of the sheet (see ``_set_input``).
    Recurrent input comes from ``self.graph`` (a project ``WeightMatrix``).
    Lateral excitation/inhibition comes from two ``GaussianConv`` kernels.

    Modes (as used in this class):
      * ``"SNN"``  -- no lateral interaction.
      * ``"SONS"`` -- lateral interaction, and the STDP learning rate decays
        with grid distance between neurons.
      * anything else -- lateral interaction with a constant learning rate.
    """

    def __init__(self, shape, patch_size, n_lateral, mode="SONS", dtype=torch.float32, device='cuda:0'):
        """Create the sheet.

        Args:
            shape: grid shape of the sheet; indexed as ``shape[1]`` (width)
                elsewhere, so it is expected to be 2-D (rows, cols).
            patch_size: accepted for interface compatibility; not used by
                this class.
            n_lateral: sigma of the inhibitory Gaussian; its kernel size is
                ``2 * n_lateral + 1``.
            mode: ``"SONS"``, ``"SNN"`` or another string (see class doc).
            dtype: dtype for the weight matrix and convolution kernels.
                Neuron state buffers are always float32.
            device: torch device for all buffers.
        """
        # global seeding so runs are reproducible (affects the whole process)
        torch.random.manual_seed(0)
        np.random.seed(0)
        self.device = device
        # parameters
        self.mode = mode
        self.shape = shape
        self.n_neurons = int(np.prod(self.shape))
        # recurrent connectivity between all neurons
        self.graph = WeightMatrix(self.n_neurons, dtype=dtype, device=device)
        # membrane potential of every neuron
        self.neuron_states = torch.zeros(size=(self.n_neurons, ), dtype=torch.float32, device=device)
        # whether each neuron spiked in the current step
        self.neuron_spiked = torch.zeros(size=(self.n_neurons, ), dtype=torch.bool, device=device)
        # time step of each neuron's most recent spike (used by STDP)
        self.neuron_last_spike_time = torch.zeros(size=(self.n_neurons, ), dtype=torch.long, device=device)
        # mask of neurons that receive the external stimulus
        self.is_input_neuron = torch.zeros(size=(self.n_neurons,), dtype=torch.bool, device=device)
        # lateral excitation / inhibition kernels (combined into a DoG in _lateral_forward)
        self.lateral_conv_excitation = GaussianConv(kernel_size=11, sigma=3, dtype=dtype, device=device)
        self.lateral_conv_inhibition = GaussianConv(kernel_size=n_lateral * 2 + 1, sigma=n_lateral, dtype=dtype, device=device)
        self.lateral_traces = torch.zeros_like(self.neuron_states)
        # simulation step counter; step 0 triggers tracking/graph initialization
        self.t = 0
        self.reset_states()

    def reset_states(self):
        """Zero membrane potentials, spikes, input mask and lateral traces.

        ``self.t`` and ``self.neuron_last_spike_time`` are not reset.
        """
        self.neuron_states[:] = 0
        self.neuron_spiked[:] = False
        self.is_input_neuron[:] = 0
        self.lateral_traces[:] = 0

    def reset_tracking(self, x_raw):
        """Allocate per-neuron accumulators for the spike-triggered input average.

        ``input_activation_mean[k]`` accumulates (float64) the ``x_raw`` seen each
        time neuron ``k`` spiked.  ``input_activation_count[k]`` counts those spikes.
        """
        self.input_activation_mean = torch.zeros(size=(self.neuron_spiked.shape[0], *x_raw.shape), dtype=torch.float64, device=self.neuron_spiked.device)
        self.input_activation_count = torch.zeros(size=(self.neuron_spiked.shape[0],), dtype=torch.long, device=self.neuron_spiked.device)

    def unravel_index(self, index, shape):
        """Convert flat (row-major) indices into a tuple of per-dimension indices.

        Works on ints or integer tensors; returns one entry per dimension of
        ``shape``, outermost first.
        """
        out = []
        for dim in reversed(shape):
            out.append(index % dim)
            index = index // dim
        return tuple(reversed(out))

    def forward(self, x, x_raw=None):
        """Advance the sheet by one time step with stimulus ``x``.

        Args:
            x: stimulus injected into the input neurons (see ``_set_input``).
            x_raw: tensor accumulated into the spike-triggered input average.
                Defaults to ``x`` when omitted.  Previously ``None`` crashed
                here on ``x_raw.shape``.

        On the first step (``self.t == 0``) the tracking buffers are
        allocated and ``self.graph.initialize`` is called with the input mask.
        """
        if x_raw is None:
            x_raw = x
        self._set_input(x)
        self._neuron_forward()
        # keep track of which inputs drive which neurons
        if self.t == 0:
            self.reset_tracking(x_raw)
            self.graph.initialize(self.is_input_neuron)
        spiked = self.neuron_spiked
        self.input_activation_mean[spiked] = self.input_activation_mean[spiked] + x_raw.unsqueeze(0)
        self.input_activation_count[spiked] += 1
        self.t += 1
        return

    def _set_input(self, x):
        """Add stimulus ``x`` to the membrane potentials of the first sheet rows.

        ``x`` is reshaped to ``(x.shape[0], -1)`` and zero-padded left and right
        so its width equals ``self.shape[1]``.  When the width difference is odd,
        the extra column goes on the right.  The padded values are *added* to
        the leading neurons' states, and those neurons are marked in
        ``self.is_input_neuron``.  The mask is 1 on stimulus columns and 0 on
        padding columns.

        Side effects: sets ``self.input_shape``, ``self.padding`` and ``self.input``
        (the padded 2-D stimulus).
        """
        self.input_shape = x.shape
        # flatten trailing dims; reshape (not view) also accepts non-contiguous input
        x_2d = x.reshape(x.shape[0], -1)
        pad_total = self.shape[1] - x_2d.shape[1]
        pad_left = pad_total // 2
        # right side takes the remainder so left + right == pad_total for even and
        # odd differences.  The old `pad_left + 1` made even cases one column too
        # wide, shearing the flattened rows across the grid.
        self.padding = (pad_left, pad_total - pad_left, 0, 0)
        x_2d_padded = F.pad(x_2d, self.padding, mode='constant')
        self.input = x_2d_padded
        # flatten to match the flat neuron buffers (row-major == grid layout)
        x_ravel = x_2d_padded.reshape(-1)
        n_in = x_ravel.shape[0]
        self.is_input_neuron[:n_in] = F.pad(torch.ones_like(x_2d), self.padding, mode='constant').reshape(-1)
        self.neuron_states[:n_in] += x_ravel
        return

    def _neuron_forward(self):
        """Integrate recurrent input, sample spikes, then add lateral traces."""
        # recurrent input from last step's spikes
        self.neuron_states += self.graph.forward(self.neuron_spiked)
        self._stochastic_IAF()
        if self.mode != "SNN":
            self._lateral_forward()
            self.neuron_states += self.lateral_traces
        self.neuron_last_spike_time[self.neuron_spiked] = self.t

    def _stochastic_IAF(self):
        """Sample spikes: a neuron fires when ``v**5 + 1e-6 >= U(0, 1)``.

        Neurons that fire are set to -1000 (hard refractory reset).
        """
        base_noise = 1e-6
        self.neuron_spiked = self.neuron_states**5 + base_noise >= torch.rand(size=self.neuron_states.shape, device=self.device)
        self.neuron_states[self.neuron_spiked] = -1000

    def _lateral_forward(self):
        """Compute lateral (difference-of-Gaussians) traces from non-input spikes.

        Returns ``self.lateral_traces`` (updated in place).  It is all zeros when
        no non-input neuron spiked; otherwise it is normalized to a max of 1.
        """
        # only non-input neurons take part in lateral interaction
        grid_neuron_spiked = self.neuron_spiked & ~self.is_input_neuron
        if not grid_neuron_spiked.any():
            self.lateral_traces[:] = 0
            return self.lateral_traces
        grid_neuron_spiked = grid_neuron_spiked.view(self.shape)
        grid_neuron_blurred_exc = self.lateral_conv_excitation(grid_neuron_spiked)
        grid_neuron_blurred_inh = self.lateral_conv_inhibition(grid_neuron_spiked)
        # DoG: the inhibitory blur is rescaled to half the excitatory peak before subtracting
        grid_neuron_blurred = grid_neuron_blurred_exc - (grid_neuron_blurred_inh / grid_neuron_blurred_inh.max()) * grid_neuron_blurred_exc.max() * 0.5
        # normalize so the traces are on the scale of membrane potentials
        grid_neuron_blurred = grid_neuron_blurred / grid_neuron_blurred.max()
        self.lateral_traces[:] = grid_neuron_blurred.view(-1)
        return self.lateral_traces

    def update(self):
        """Apply one STDP weight update for the rows of neurons that spiked.

        For every spiked neuron (row) and every neuron (column),
        ``stdp_func(t_last[row] - t_last[col])`` is scaled by a learning rate.
        The result is passed to ``self.graph.update_values(dw, index=...)``, and
        ``self.graph.initialize`` is then re-run.  In ``"SONS"`` mode the
        learning rate is ``0.1 / sqrt(distance)`` (Euclidean grid distance);
        otherwise it is ``0.01``.  Nothing happens if no neuron spiked or all
        time differences are zero.
        """
        neuron_spiked_idx = torch.where(self.neuron_spiked)[0]
        if neuron_spiked_idx.numel() == 0:
            return
        # signed spike-time differences, only the rows of spiked neurons
        # (avoids materializing the full n x n matrix)
        last_spike = self.neuron_last_spike_time
        synapse_dt = last_spike[neuron_spiked_idx].view(-1, 1) - last_spike.view(1, -1)
        if synapse_dt.abs().sum() == 0:
            return
        dw = self.stdp_func(synapse_dt)
        lr = 0.01
        if self.mode == "SONS":
            # grid coordinates of every neuron, computed once
            xy = torch.stack(self.unravel_index(torch.arange(0, self.n_neurons, device=self.device), self.shape), dim=-1)
            coord_diff = xy[neuron_spiked_idx].view(-1, 1, 2) - xy.view(1, -1, 2)
            synapses_in_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1).float())
            # epsilon avoids division by zero on the self-synapse (whose dw is 0 anyway)
            lr = 0.1 / (synapses_in_dists**0.5 + 1e-15)
        dw *= lr
        self.graph.update_values(dw, index=neuron_spiked_idx)
        # renormalize values
        self.graph.initialize(self.is_input_neuron)

    def stdp_func(self, dt):
        """STDP kernel: ``dw = A * 2**(-|dt|/tau) * dt``.

        ``A = 1.0`` for ``dt > 0`` and ``0.75`` otherwise; ``dt == 0`` gives 0.
        """
        A_pos, A_neg = 1.0, 0.75
        tau_pos, tau_neg = 1.0, 1.0
        dw = torch.where(dt>0, A_pos*2**(-dt/tau_pos), A_neg*2**(dt/tau_neg)) * dt
        return dw

    def visualize_features(self, i, mode):
        """Plot and save the spike-triggered average input of each non-input neuron.

        The per-neuron averages are tiled into one image in grid order and
        saved to ``./results/<mode>/<mode>_topographic_map.png``.  Neurons
        that never spiked produce NaN (0/0).  ``i`` is unused.
        """
        plt.figure(1)
        plt.clf()

        neuron_idxs = torch.arange(0, self.neuron_states.shape[0], device=self.device)[~self.is_input_neuron]
        grid_idx = self.unravel_index(neuron_idxs, shape=self.shape)
        feature_shape = tuple(self.input_activation_mean.shape[1:])
        feature_grid = torch.zeros(size=tuple(self.shape) + feature_shape, device=self.device)
        # broadcast each neuron's spike count over its feature dims
        counts = self.input_activation_count[neuron_idxs].view(-1, *([1] * len(feature_shape)))
        mean_features = self.input_activation_mean[neuron_idxs] / counts
        # single vectorized scatter instead of one scalar-indexed write per neuron
        feature_grid[grid_idx] = mean_features.to(feature_grid.dtype)
        feature_grid = feature_grid.cpu().numpy()
        # (H, W, h, w) -> (H*h, W*w) tiled image
        feature_grid = np.hstack(np.column_stack(feature_grid))
        plt.imshow(feature_grid, vmin=0, vmax=1)
        plt.savefig("./results/{}/{}_topographic_map.png".format(mode, mode))
        plt.pause(0.0001)

    def visualize_spiking(self, i, t, mode):
        """Plot the sheet state (spiked neurons shown as 1) and save it for step ``t``.

        Saved to ``./results/<mode>/rollout/<mode>_spikes_<t>.png``.  ``i`` is unused.
        """
        plt.figure(3)
        plt.clf()

        state_grid = torch.where(self.neuron_spiked, torch.ones_like(self.neuron_states), self.neuron_states)
        plt.imshow(state_grid.reshape(self.shape).cpu().numpy(), vmin=-1, vmax=1)
        plt.savefig("./results/{}/rollout/{}_spikes_{}.png".format(mode, mode, t))

        plt.pause(0.1)