import torch
import PIL.Image as Image
import numpy as np
from .base import BaseFlow
import agentpy as ap


def normalize(v):
    """Scale *v* to unit Euclidean length.

    A zero-length vector has no direction to preserve, so it is handed
    back untouched (the very same object, not a copy).
    """
    magnitude = np.linalg.norm(v)
    return v if magnitude == 0 else v / magnitude


class Boid(ap.Agent):
    """An agent with a position and velocity in a continuous space,
    who follows Craig Reynolds' three rules of flocking behavior,
    plus a fourth rule to avoid the edges of the simulation space.

    Reads these entries from the model parameters ``self.p``: ``ndim``,
    ``inner_radius``, ``outer_radius``, ``border_distance``,
    ``cohesion_strength``, ``seperation_strength`` (sic; the parameter key
    is spelled this way), ``alignment_strength`` and ``border_strength``.
    """

    def setup(self):
        """Give the boid a random unit-length initial velocity."""
        # Centre the random components around zero before normalising so
        # headings in every direction are possible.
        self.velocity = normalize(self.model.nprandom.random(self.p.ndim) - 0.5)

    def setup_pos(self, space):
        """Attach the boid to ``space`` and cache handles into it.

        ``self.pos`` is the array stored in ``space.positions`` itself, not a
        copy. NOTE(review): the flocking rules read ``self.pos`` after moves,
        so this relies on agentpy updating that array in place — confirm.
        """
        self.space = space
        self.neighbors = space.neighbors
        self.pos = space.positions[self]

    def update_velocity(self, agent_id=None):
        """Steer the boid by combining the four flocking rules.

        The summed steering vector is added to the current velocity, which
        is then renormalised to unit length (a zero vector stays zero), so
        the rules change the heading rather than the speed.

        Args:
            agent_id: Unused. Kept (now optional) so callers that pass it,
                such as ``BoidsModel.step``, keep working.
        """
        pos = self.pos
        ndim = self.p.ndim

        # Rule 1 - Cohesion: steer towards the centre of mass of the
        # neighbours within the outer radius.
        nbs = self.neighbors(self, distance=self.p.outer_radius)
        nbs_len = len(nbs)
        nbs_pos_array = np.array(nbs.pos)
        nbs_vec_array = np.array(nbs.velocity)
        if nbs_len > 0:
            center = np.sum(nbs_pos_array, 0) / nbs_len
            v1 = (center - pos) * self.p.cohesion_strength
        else:
            v1 = np.zeros(ndim)

        # Rule 2 - Separation: push away from neighbours that are closer
        # than the inner radius.
        v2 = np.zeros(ndim)
        for nb in self.neighbors(self, distance=self.p.inner_radius):
            v2 -= nb.pos - pos
        v2 *= self.p.seperation_strength

        # Rule 3 - Alignment: steer towards the mean velocity of the
        # outer-radius neighbours.
        if nbs_len > 0:
            average_v = np.sum(nbs_vec_array, 0) / nbs_len
            v3 = (average_v - self.velocity) * self.p.alignment_strength
        else:
            v3 = np.zeros(ndim)

        # Rule 4 - Borders: along every axis where the boid is within
        # `border_distance` of 0 or of the space's extent, push it back in.
        v4 = np.zeros(ndim)
        d = self.p.border_distance
        s = self.p.border_strength
        for i in range(ndim):
            if pos[i] < d:
                v4[i] += s
            elif pos[i] > self.space.shape[i] - d:
                v4[i] -= s

        # Combine the rules and keep only the resulting direction.
        self.velocity += v1 + v2 + v3 + v4
        self.velocity = normalize(self.velocity)

    def update_position(self):
        """Move the boid through its space by its current velocity."""
        self.space.move_by(self, self.velocity)


class BoidsModel(ap.Model):
    """
    An agent-based model of animals' flocking behavior,
    based on Craig Reynolds' Boids Model [1]
    and Conrad Parker's Boids Pseudocode [2].

    Besides moving the boids, every step rasterises each agent onto a
    ``size x size`` grid and appends the results to
    ``my_positions_history`` (one-hot position maps) and
    ``my_flow_history`` (``size x size x 2`` displacement maps).

    [1] http://www.red3d.com/cwr/boids/
    [2] http://www.vergenet.net/~conrad/boids/pseudocode.html
    """

    def setup(self):
        """Create the space, place the boids randomly and init the histories."""
        self.space = ap.Space(self, shape=[self.p.size] * self.p.ndim)
        self.agents = ap.AgentList(self, self.p.population, Boid)
        self.space.add_agents(self.agents, random=True)
        self.agents.setup_pos(self.space)

        # NOTE(review): this report runs before the history lists below are
        # created, so they are likely not among the reported variables;
        # keep this ordering unless that is meant to change.
        self.report(self.vars)

        # One entry per step; each entry holds one grid per agent.
        self.my_positions_history = []
        self.my_flow_history = []

    def step(self):
        """Advance all boids one step and rasterise their positions and flows.

        For every agent, a one-hot ``(size, size)`` position map and a
        ``(size, size, 2)`` flow map are built at the raster cell of its
        position *before* the move. The flow stores the change in the
        rounded x and y coordinates caused by the move.
        """
        size = self.p.size

        # `+ 0` yields fresh arrays, presumably because the move below
        # updates the agents' position arrays in place.
        old_pos = self.agents.pos + 0

        self.agents.update_velocity(self.agents.id)
        self.agents.update_position()

        new_pos = self.agents.pos

        my_positions = []
        my_flows = []
        for old, new in zip(old_pos, new_pos):
            # Raster cell: rows run against world y (flipped), columns follow
            # world x shifted by one.
            # NOTE(review): x < 0.5 gives col == -1 and y > size - 0.5 gives
            # row == -1, which numpy wraps to the opposite edge. Kept as-is
            # because Birds.perform_agents_simulation selects agents by these
            # exact cells.
            row = round(size - 1 - old[1])
            col = round(old[0] - 1)

            pos_t = np.zeros((size, size))
            pos_t[row, col] = 1

            # NOTE(review): the second component is a world-space y delta
            # (positive = up) while the row axis is flipped — confirm intent.
            opt_flow = np.zeros((size, size, 2))
            opt_flow[row, col, 0] = round(new[0]) - round(old[0])
            opt_flow[row, col, 1] = round(new[1]) - round(old[1])

            my_positions.append(pos_t)
            my_flows.append(opt_flow)

        self.my_positions_history.append(my_positions)
        self.my_flow_history.append(my_flows)

    def update(self):
        """Record and report the variables named by ``self.vars``."""
        self.record(self.vars)
        self.report(self.vars)

    def end(self):
        """Report the sum of all agent ids under the key ``"sum_id"``."""
        self.report("sum_id", sum(self.agents.id))


class Birds(BaseFlow):
    """Bird-flock scene whose motion comes from a 2-D boids simulation.

    The simulation is run once in ``__init__``. The per-step position and
    flow grids of a hand-picked subset of agents are kept in
    ``self.allag_pos_t`` and ``self.allag_flow_t``.
    """

    def __init__(self, N) -> None:
        super().__init__(N=N)

        self.pos_prompt = "a small flock bird flying in the sky at the sunset"

        self.perform_agents_simulation()

    def get_spatial_eta(self, t):
        raise NotImplementedError

    def get_default_image(self) -> Image.Image:
        """Load the base frame ``base_images/birds_noGS.png``.

        The path is resolved relative to the current working directory. A
        PIL image (not a tensor) is returned. Its pixel data is loaded
        eagerly so the underlying file is closed before returning.
        """
        with Image.open("base_images/birds_noGS.png") as image:
            image.load()
        return image

    def get_default_framesteps(self) -> torch.Tensor:
        """Return the frame indices ``0 .. T-1``, one per recorded step."""
        return torch.arange(len(self.allag_pos_t))

    def get_flow(self, t) -> torch.Tensor:
        raise NotImplementedError

    def perform_agents_simulation(self):
        """Run the boids model and keep a hand-picked subset of its agents.

        Sets ``self.allag_pos_t`` and ``self.allag_flow_t``. Both are indexed
        ``[step][kept_agent]`` and hold the grids produced by
        ``BoidsModel.step``.
        """
        SIMU_TIMESTEPS = 30  # 80
        NUM_AGENTS = 60

        parameters2D = {
            "size": 100,
            "seed": 123,
            "steps": SIMU_TIMESTEPS,
            "ndim": 2,
            "population": NUM_AGENTS,
            "inner_radius": 3,
            "outer_radius": 10,
            "border_distance": 10,
            "cohesion_strength": 0.0005,
            "seperation_strength": 0.1,
            "alignment_strength": 0.3,
            "border_strength": 0.5,
        }

        mymodel = BoidsModel(parameters2D)
        mymodel.run()

        allag_pos_t = mymodel.my_positions_history
        allag_flow_t = mymodel.my_flow_history

        # Select the agents to KEEP: those whose first recorded cell lies in
        # raster rows 2..33 and columns >= 2.
        agent_to_get = []
        for agent_idx, grid in enumerate(allag_pos_t[0]):
            # Each position grid is one-hot, so argwhere yields one (row, col).
            row, col = np.argwhere(grid)[0]
            if 1 < row < 34 and col > 1:
                agent_to_get.append(agent_idx)

        # Hand-tuned: also drop the fifth *selected* agent (index 4 of
        # agent_to_get, not the agent with index 4). Raises IndexError if
        # fewer than five agents were selected.
        agent_to_get.pop(4)

        # Keep only the selected agents at every recorded step.
        self.allag_pos_t = [[step[i] for i in agent_to_get] for step in allag_pos_t]
        self.allag_flow_t = [
            [step[i] for i in agent_to_get] for step in allag_flow_t
        ]
