import numpy as np

class CANN_v2:
    """Rate network of ``N`` units with divisive normalisation.

    The name suggests a continuous attractor neural network (CANN); that
    reading is an assumption -- confirm against the calling code.

    State held on the instance:

    * ``U`` -- per-unit input variable, kept non-negative by ``update``.
    * ``r`` -- per-unit rate, ``U**2 / (1 + k * sum(U**2))``.
    * ``cor`` -- ``N`` evenly spaced coordinates covering ``[-pi, pi)``.
    """

    def __init__(self, N: int, k: float, J: np.ndarray, tau: float,
                 trans: bool) -> None:
        """Store the parameters and start from an all-zero state.

        Args:
            N: Number of units.
            k: Strength of the divisive normalisation applied to ``r``.
            J: Coupling. With ``trans`` true it is Fourier-transformed and
                multiplied element-wise with the transform of ``r`` (so it
                presumably is a 1-D kernel of length ``N``); otherwise it is
                passed to ``np.matmul(J, r)`` (presumably an ``(N, N)``
                matrix).
            tau: Time constant dividing ``dt`` in ``update``.
            trans: Selects the FFT path (true) or the matrix path (false)
                in ``interact``.
        """
        self.N = N
        self.k = k
        self.J = J
        self.tau = tau
        # linspace guarantees exactly N samples; arange with a float step
        # can produce an inconsistent length because of rounding.
        self.cor = np.linspace(-np.pi, np.pi, N, endpoint=False)
        self.U = np.zeros(N)
        self.r = np.zeros(N)
        self.trans = trans

    def interact(self) -> np.ndarray:
        """Return the recurrent input produced by the current rates ``r``.

        With ``trans`` true this is the circular convolution of ``J`` with
        ``r`` evaluated through the FFT; otherwise it is ``J @ r``. Only the
        real part of the result is returned.
        """
        if self.trans:
            r_hat = np.fft.fft(self.r)
            j_hat = np.fft.fft(self.J)
            recurrent = np.fft.ifft(r_hat * j_hat)
        else:
            recurrent = np.matmul(self.J, self.r)
        return np.real(recurrent)

    def update(self, I_ext, dt: float) -> None:
        """Advance ``U`` and ``r`` by one explicit Euler step of size ``dt``.

        Args:
            I_ext: External input added to the recurrent input; a scalar or
                anything that broadcasts against ``U``.
            dt: Step size.
        """
        drive = self.interact() + I_ext
        stepped = self.U + dt / self.tau * (drive - self.U)
        # Rectify BEFORE computing the rate: otherwise a negative U is
        # squared into a positive rate and also inflates the normalisation
        # sum, leaving r inconsistent with the stored U.
        self.U = np.maximum(stepped, 0)
        squared = self.U ** 2
        self.r = squared / (1 + self.k * squared.sum())

    def reset(self) -> None:
        """Set ``U`` and ``r`` back to zero vectors of length ``N``."""
        self.U = np.zeros(self.N)
        self.r = np.zeros_like(self.U)