import math
import numpy as np
import torch
import torch.nn as nn
import random


def squared_error(ys_pred, ys):
    """Return the elementwise squared difference between targets and predictions.

    No reduction is applied: the result has the broadcast shape of
    ``ys - ys_pred``.
    """
    residual = ys - ys_pred
    return residual.square()


class Task:
    """Common base for the batched tasks in this module.

    The base class only records the shared configuration; subclasses are
    expected to override `evaluate` and the static hooks below, all of which
    raise `NotImplementedError` here.
    """

    def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None):
        """Record the task configuration.

        Args:
            n_dims: stored as ``self.n_dims``.
            batch_size: stored as ``self.b_size``.
            pool_dict: optional, stored as ``self.pool_dict``.
            seeds: optional, stored as ``self.seeds``.

        Raises:
            AssertionError: if both ``pool_dict`` and ``seeds`` are given.
        """
        self.n_dims, self.b_size = n_dims, batch_size
        self.pool_dict, self.seeds = pool_dict, seeds
        # pool_dict and seeds are mutually exclusive.
        assert not (pool_dict is not None and seeds is not None)

    def evaluate(self, xs):
        """Compute the task outputs for ``xs``; must be overridden."""
        raise NotImplementedError

    @staticmethod
    def generate_pool_dict(n_dims, num_tasks):
        """Build a pool of pre-sampled task parameters; must be overridden."""
        raise NotImplementedError

    @staticmethod
    def get_metric():
        """Return the evaluation metric callable; must be overridden."""
        raise NotImplementedError

    @staticmethod
    def get_training_metric():
        """Return the training metric callable; must be overridden."""
        raise NotImplementedError


def get_task_sampler(
    task_name, n_dims, batch_size, pool_dict=None, num_tasks=None, **kwargs
):
    """Return a factory that instantiates the task class registered as `task_name`.

    Args:
        task_name: key into the registry of task classes below.
        n_dims: forwarded to the task constructor (and to
            ``generate_pool_dict`` when ``num_tasks`` is given).
        batch_size: forwarded to the task constructor.
        pool_dict: optional pool forwarded to the task constructor.
        num_tasks: when not None, a pool is generated through the task
            class's ``generate_pool_dict``; mutually exclusive with
            ``pool_dict``.
        **kwargs: forwarded both to ``generate_pool_dict`` (if called) and to
            the task constructor.

    Returns:
        A callable accepting keyword arguments; each call builds a new task
        instance with those arguments merged with ``kwargs``.

    Raises:
        NotImplementedError: if ``task_name`` is not registered.
        ValueError: if both ``pool_dict`` and ``num_tasks`` are given.
    """
    registry = {
        "relu_nn_regression": ReluNNRegression,
        "cot_skill_chain": CoTSkillChain,
        "relu_nn_regression_asymmetric": ReluNNRegressionAsymmetric
    }

    # Guard clause: bail out early on unregistered names.
    if task_name not in registry:
        print("Unknown task")
        raise NotImplementedError

    task_cls = registry[task_name]
    if num_tasks is not None:
        if pool_dict is not None:
            raise ValueError("Either pool_dict or num_tasks should be None.")
        pool_dict = task_cls.generate_pool_dict(n_dims, num_tasks, **kwargs)

    def sampler(**args):
        return task_cls(n_dims, batch_size, pool_dict, **args, **kwargs)

    return sampler

# TODO
class CoTSkillChain(Task):
    """Task that pushes inputs through a chain of fixed random linear maps.

    A bank ``func_set`` of shape (n_skills, n_funcs, n_dims, n_dims) is built
    once in ``__init__`` from the left singular vectors of Gaussian matrices.
    ``evaluate`` samples, per batch element, a sequence of (skill, func)
    indices, applies the selected matrices one after another (with additive
    Gaussian noise), and returns a selection of the intermediate states
    together with the skill id associated with each returned position.
    """
    def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None,
        n_skills=10, n_funcs=20, noise=0.1, seed=0, scale=1,  
        min_chain_length=1, max_chain_length=2, chain_length=5, ordered_chain=False, mode='linear'):
        """Build the function bank and record the chain configuration.

        Args:
            n_dims, batch_size, pool_dict, seeds: forwarded to ``Task.__init__``.
            n_skills: number of skills (first axis of ``func_set``).
            n_funcs: number of functions per skill (second axis of ``func_set``).
            noise: multiplier on the Gaussian noise added at every chain step.
            seed: seed of the NumPy generator used to build ``func_set`` only.
            scale: stored on the instance; not read by the methods of this class.
            min_chain_length: currently ignored (see the note below).
            max_chain_length: drives both ``self.max_chain_length`` and
                ``self.chain_length`` (see below).
            chain_length: currently ignored; ``self.chain_length`` is taken
                from ``max_chain_length`` instead.
            ordered_chain: stored on the instance; only referenced by
                commented-out code.
            mode: only 'linear' (identity activation) is supported.

        Raises:
            ValueError: if ``mode`` is not 'linear'.
        """
        super(CoTSkillChain, self).__init__(n_dims, batch_size, pool_dict, seeds)
        self.scale = scale
        self.n_skills = n_skills
        self.n_funcs = n_funcs
        # NOTE(review): the min_chain_length argument is overridden by a
        # hard-coded 7 -- confirm whether this is a leftover experiment setting.
        self.min_chain_length = 7#min_chain_length
        # Stored value is one larger than the constructor argument.
        self.max_chain_length = max_chain_length+1
        # self.chain_length = chain_length
        # Number of chain steps actually applied in __get_chain.
        self.chain_length = max_chain_length
        self.ordered_chain = ordered_chain
        self.mode = mode
        self.noise = noise

        # Dedicated generator so the function bank depends only on `seed`.
        rng = np.random.default_rng(seed)
        # self.func_set = torch.zeros(n_skills, n_funcs, n_dims, n_dims)
        # for i in range(n_skills):
        #     for j in range(n_funcs):
        #         self.func_set[i,j] = torch.linalg.svd(torch.randn(n_dims, n_dims, generator=generator))[0]
        # Batched SVD of Gaussian matrices; element [0] keeps the left singular
        # vectors, giving one (n_dims, n_dims) matrix per (skill, func) pair.
        self.func_set = np.linalg.svd(rng.standard_normal((n_skills, n_funcs, n_dims, n_dims)))[0]
        # print(self.func_set[0,1].dot( self.func_set[0,1].T))
        
        if mode == 'linear':
            # Identity activation: each chain step is purely linear plus noise.
            def func(x):
                return x
            self.act_func = func
        else:
            raise ValueError('No vaild mode found')


    def __get_chain(self, xs, func_ids, skill_ids, eval):
        """Apply the sampled chain to ``xs`` and gather per-position outputs.

        Args:
            xs: NumPy array unpacked as (bsize, n_points, n_dims).
            func_ids: integer array indexed as ``func_ids[:, step]``.
            skill_ids: integer array indexed as ``skill_ids[:, step]``; its
                second dimension defines the local ``chain_length``.
            eval: unused by the active code (only referenced in the
                commented-out block below).

        Returns:
            A pair ``(ys_, ids_skill)``: the gathered states as an array built
            from one ``ys[i, idx]`` lookup per batch element, and a list with
            one array of skill ids per batch element.
        """
        # Local chain_length comes from the skill_ids width, not self.chain_length.
        chain_length = skill_ids.shape[1] 

        bsize, n_points, n_dims = xs.shape
        # Append a -1 column so that position `chain_length` maps to skill id -1.
        skill_ids = np.concatenate([skill_ids, -np.ones((bsize,1), dtype=int)], axis=1)
        # if not eval:
        #     n_sample = list(np.random.randint(self.min_chain_length, self.max_chain_length+1, (bsize, n_points)))
        #     start_ids = list(map(lambda n:  np.random.randint(chain_length-n+1), n_sample))
        # else:
        #     n_sample = list(np.ones((bsize, n_points), dtype=int)*2)
        #     start_ids = list(np.arange(chain_length-1).reshape(1,-1).repeat(n_points,axis=0).reshape(1,-1)[:,:n_points].repeat(bsize,axis=0))
        # end_ids = list(map(lambda id, n: id + n, start_ids, n_sample))
        # Every sub-chain starts at step 0; one length is drawn per batch element
        # from the global NumPy RNG.
        start_ids = list(np.zeros((bsize, n_points), dtype=int))
        # NOTE(review): when called from evaluate() with the constructor default
        # max_chain_length=2, this looks like low (7) >= high (5), which
        # np.random.randint rejects -- confirm the intended configuration.
        n_sample = np.random.randint(self.min_chain_length, chain_length+1,size=bsize)
        end_ids = list(map(lambda id, n: id + n, start_ids, n_sample))

        # ids[b][p] is the list of step indices range(start, end) for point p.
        ids = [list(map(lambda sid, eid: list(range(sid, eid)), sid, eid)) for sid, eid in zip(start_ids, end_ids)]
        # Flat indices into the reshaped ys below: point j is offset by
        # (chain_length+1)*j, each sub-chain gets one extra trailing step, and
        # the concatenation is truncated to n_points entries.
        # NOTE(review): ys holds self.chain_length + 2 states per point, while
        # the stride used here is skill_ids.shape[1] + 1; from evaluate() these
        # appear to differ by one -- verify the offsets line up as intended.
        ids_func = np.array([np.concatenate(list(map(lambda i, j: (chain_length+1)*j+np.array(i + [i[-1] + 1]), id, range(n_points))))[:n_points] for id in ids])
        random_noise = np.random.randn(chain_length,bsize,n_points,n_dims)
        # ys collects the input followed by the state after every chain step.
        ys = [xs]
        for i in range(self.chain_length):
            # One (n_dims, n_dims) matrix per batch element for this step.
            temp = self.func_set[skill_ids[:,i],func_ids[:,i]]
            xs = self.act_func(xs @ temp + self.noise * random_noise[i])
            ys += [xs]
        # ys += [np.zeros_like(xs)]
        # The final state is appended a second time.
        ys += [xs]
        ys = np.stack(ys, axis=2)
        # Flatten (point, step) into a single axis so ids_func can index it.
        ys = ys.reshape(bsize, -1, n_dims)
        ys_ = np.array(list(map(lambda i, id: ys[i,id], range(bsize), ids_func)))
        # Step positions per returned entry, with `chain_length` marking the end
        # of each sub-chain, truncated to n_points ...
        ids_skill = [np.concatenate(list(map(lambda i: list(i) + [chain_length], id)))[:n_points] for id in ids]
        # ... then translated into skill ids (end marker -> -1 column).
        ids_skill = [oid[id] for oid,id in zip(skill_ids, ids_skill)]

        return ys_, ids_skill


    def evaluate(self, xs_b, eval=False):
        """Sample a skill/function chain per batch element and evaluate it.

        Args:
            xs_b: tensor converted with ``.numpy()`` and unpacked as
                (bsize, n_points, dim) -- presumably a CPU tensor; confirm
                against callers.
            eval: forwarded to ``__get_chain``, where it is currently unused.
                Note that the name shadows the builtin ``eval``.

        Returns:
            A 4-tuple: the gathered outputs as a float tensor, the per-position
            skill ids as a tensor, and the sampled ``skill_ids`` and
            ``func_ids`` NumPy arrays.
        """
        xs_b = xs_b.numpy()
        n_skills = self.n_skills
        n_funcs = self.n_funcs
        chain_length = self.chain_length
        bsize, n_points, dim = xs_b.shape

        # if self.ordered_chain:
        #     if n_skills != chain_length:
        #         raise ValueError(f"number of skills ({n_skills}) != chain length ({chain_length})")
        #     skill_ids = np.array([np.arange(n_skills+1) for _ in range(bsize)])
        # else:
        #     skill_ids = np.array([np.random.permutation(n_skills)[:chain_length] for _ in range(bsize)])
        #     skill_ids = np.concatenate([skill_ids, np.ones((bsize,1), dtype=int)*n_skills], axis=1)
        # func_ids = np.random.randint(n_funcs, size=(bsize, chain_length,))

        # Skills are drawn with replacement from the global NumPy RNG; a final
        # column holding the out-of-range id `n_skills` is appended.
        skill_ids = np.random.randint(n_skills, size=(bsize,self.max_chain_length))
        skill_ids = np.concatenate([skill_ids, np.ones((bsize,1), dtype=int)*n_skills], axis=1)
        func_ids = np.random.randint(n_funcs, size=(bsize, self.max_chain_length,))


        # Private method; Python name-mangles this to _CoTSkillChain__get_chain.
        ys_b, ids_b = self.__get_chain(xs_b, func_ids, skill_ids, eval)   

        return torch.from_numpy(ys_b).float(), torch.from_numpy(np.array(ids_b)), skill_ids, func_ids

    @staticmethod
    def get_metric():
        """Return the evaluation metric (module-level ``squared_error``)."""
        return squared_error

    @staticmethod
    def get_training_metric():
        """Return the training metric (module-level ``squared_error``)."""
        return squared_error


class ReluNNRegression(Task):
    """Regression targets produced by a randomly initialised MLP per batch element.

    Each batch element owns its own weights: an input layer ``W_init``,
    ``n_layers - 2`` square hidden layers ``Ws`` and a scalar read-out ``v``.
    """

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        hidden_layer_size=4,
        n_layers=5,
        mode='relu'
    ):
        """Sample the network weights.

        Args:
            n_dims, batch_size, pool_dict, seeds: forwarded to ``Task.__init__``;
                ``pool_dict`` and ``seeds`` are not otherwise used here.
            scale: a constant by which every layer's activations are scaled.
            hidden_layer_size: width of every hidden layer.
            n_layers: total layer count; must be at least 2.
            mode: activation, either 'relu' or 'tanh'.

        Raises:
            ValueError: if ``n_layers`` is smaller than 2.
            NotImplementedError: if ``mode`` is not a supported activation.
        """
        super().__init__(n_dims, batch_size, pool_dict, seeds)
        self.scale = scale
        self.hidden_layer_size = hidden_layer_size
        self.n_layers = n_layers
        if n_layers < 2:
            raise ValueError("Number of layers should not be smaller than 2.")

        # Draw order (W_init, Ws, v) is kept fixed so seeded runs reproduce.
        width = hidden_layer_size
        n_hidden = self.n_layers - 2
        self.W_init = torch.randn(self.b_size, self.n_dims, width)
        self.Ws = torch.randn(n_hidden, self.b_size, width, width)
        self.v = torch.randn(self.b_size, width, 1)

        if mode not in ('relu', 'tanh'):
            raise NotImplementedError
        self.act_func = torch.nn.ReLU() if mode == 'relu' else torch.nn.Tanh()

    def evaluate(self, xs_b):
        """Run the network on ``xs_b``.

        ``xs_b`` must be batch-matmul compatible with ``W_init`` of shape
        (b_size, n_dims, hidden_layer_size). Weights are moved to the device
        of ``xs_b`` on every call.

        Returns:
            A pair ``(ys_b_nn, layer_activations)``: the read-out with its
            trailing singleton dimension removed, and the list of activations
            after every layer.
        """
        device = xs_b.device
        first_weight = self.W_init.to(device)
        hidden_weights = self.Ws.to(device)
        readout = self.v.to(device)

        norm = math.sqrt(2 / self.hidden_layer_size)

        hidden = self.act_func(xs_b @ first_weight) * norm * self.scale
        layer_activations = [hidden]
        for layer_idx in range(self.n_layers - 2):
            pre_activation = hidden @ hidden_weights[layer_idx]
            hidden = self.act_func(pre_activation) * norm * self.scale
            layer_activations.append(hidden)

        ys_b_nn = (hidden @ readout)[:, :, 0]
        return ys_b_nn, layer_activations

    @staticmethod
    def get_metric():
        """Return the evaluation metric (module-level ``squared_error``)."""
        return squared_error

    @staticmethod
    def get_training_metric():
        """Return the training metric (module-level ``squared_error``)."""
        return squared_error


class ReluNNRegressionAsymmetric(Task):
    """Random-MLP regression task whose last hidden layer is twice as wide.

    Each batch element owns its own weights. All hidden layers have
    ``hidden_layer_size`` units except the last one, which has
    ``2 * hidden_layer_size``; a scalar read-out ``v`` follows.
    """

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        hidden_layer_size=4,
        n_layers=4,
        mode='relu'
    ):
        """Sample the network weights.

        Args:
            n_dims, batch_size, pool_dict, seeds: forwarded to ``Task.__init__``;
                ``pool_dict`` and ``seeds`` are not otherwise used here.
            scale: a constant by which every layer's activations are scaled.
            hidden_layer_size: base width of the hidden layers.
            n_layers: total layer count; must be at least 2.
            mode: activation, either 'relu' or 'tanh'.

        Raises:
            ValueError: if ``n_layers`` is smaller than 2.
            NotImplementedError: if ``mode`` is not a supported activation.
        """
        # The original passed ReluNNRegression to super(), which raises
        # TypeError because this class does not derive from it.
        super().__init__(n_dims, batch_size, pool_dict, seeds)
        self.scale = scale
        self.hidden_layer_size = hidden_layer_size
        self.n_layers = n_layers
        if n_layers < 2:
            raise ValueError("Number of layers should not be smaller than 2.")

        hidden_layer_sizes = [hidden_layer_size] * (n_layers - 1)
        # make it asymmetric by making the last layer twice as wide
        hidden_layer_sizes[-1] = 2 * hidden_layer_size
        self.hidden_layer_sizes = hidden_layer_sizes

        # Layer k maps from the previous layer's width (n_dims for k == 0) to
        # hidden_layer_sizes[k]; weights are drawn in layer order, then v.
        input_sizes = [self.n_dims] + hidden_layer_sizes[:-1]
        self.Ws = [
            torch.randn(self.b_size, fan_in, fan_out)
            for fan_in, fan_out in zip(input_sizes, hidden_layer_sizes)
        ]
        self.v = torch.randn(self.b_size, hidden_layer_sizes[-1], 1)

        if mode == 'relu':
            self.act_func = torch.nn.ReLU()
        elif mode == 'tanh':
            self.act_func = torch.nn.Tanh()
        else:
            raise NotImplementedError

    def evaluate(self, xs_b):
        """Run the network on ``xs_b``.

        ``xs_b`` must be batch-matmul compatible with the first weight of
        shape (b_size, n_dims, hidden_layer_sizes[0]).

        Returns:
            A pair ``(ys_b_nn, layer_activations)``: the read-out with its
            trailing singleton dimension removed, and the list of activations
            after every layer.
        """
        device = xs_b.device
        # Tensor.to() returns a new tensor rather than moving in place, so the
        # results must be kept; the original discarded them and the weights
        # stayed on their creation device.
        Ws = [W.to(device) for W in self.Ws]
        v = self.v.to(device)

        layer_activations = []
        activ = xs_b
        for W, width in zip(Ws, self.hidden_layer_sizes):
            # Scale by sqrt(2 / width) of the layer being produced.
            activ = self.act_func(activ @ W) * math.sqrt(2 / width) * self.scale
            layer_activations.append(activ)
        ys_b_nn = (activ @ v)[:, :, 0]

        return ys_b_nn, layer_activations

    @staticmethod
    def get_metric():
        """Return the evaluation metric (module-level ``squared_error``)."""
        return squared_error

    @staticmethod
    def get_training_metric():
        """Return the training metric (module-level ``squared_error``)."""
        return squared_error


# TODO: @ks: need to check with @yingcong if we need this
class TaskFamily:
    """Registry of tasks addressable by integer id, evaluated as a chain."""

    def __init__(self):
        super().__init__()
        # TODO: we need to define tasks that can take inputs of different
        # dims (they do not have to take 1-dim input only).
        # Then we can include out_dim in task_kwargs.
        # NOTE(review): 'linear_regression' is not a name registered in
        # get_task_sampler in this file, so id 0 currently ends in
        # NotImplementedError -- confirm the intended task name.
        self.task_mapping = {
            0: {
                'task_name' : 'linear_regression',
                'n_dims' : 10,
                'task_kwargs' : None
            }
        }

    # Given identifiers, input x, then output y
    # For an example: (x,1,2,3)->(f1(x),f2(f1(x)),f3(f2(f1(x))))
    def evaluate(self, xs_b, ids):
        """Apply the tasks named by ``ids`` one after another.

        Args:
            xs_b: input tensor; its first three dimensions are read as
                (batch, points, dims).
            ids: iterable of keys into ``self.task_mapping``.

        Returns:
            A list with the output of each task, in order. The output of one
            task is fed as input to the next.
        """
        bsize = xs_b.shape[0]
        points = xs_b.shape[1]
        # NOTE(review): dims is measured once on the initial input and is not
        # refreshed after a task replaces xs_b -- confirm this is intended for
        # chains longer than one task.
        dims = xs_b.shape[2]
        ys_b = []
        for task_id in ids:
            spec = self.task_mapping[task_id]
            task_name = spec['task_name']
            n_dims = spec['n_dims']
            # A registry entry may store None for "no extra kwargs"; unpacking
            # None with ** raises TypeError, so fall back to an empty dict.
            task_kwargs = spec['task_kwargs'] or {}
            # To adapt the output to its next input
            # Another option: fail when it does not fit
            if n_dims > dims:
                # Zero-pad the feature dimension up to n_dims.
                padding = torch.zeros(
                    bsize, points, n_dims - dims, device=xs_b.device
                )
                xs_b = torch.cat((xs_b, padding), dim=2)
            elif n_dims < dims:
                # Truncate the feature dimension down to n_dims.
                xs_b = xs_b[:, :, :n_dims]
            task_sampler = get_task_sampler(task_name, n_dims, bsize, **task_kwargs)
            task = task_sampler()
            xs_b = task.evaluate(xs_b)
            ys_b.append(xs_b)

        return ys_b

    @staticmethod
    def get_metric():
        """Return a metric giving the squared error averaged over the last dim."""
        def squared_error(ys_pred, ys):
            return (ys - ys_pred).square().mean(-1)
        return squared_error

    @staticmethod
    def get_training_metric():
        """Return a metric giving the squared error averaged over all elements."""
        # `mean_squared_error` is neither defined nor imported in this module,
        # so the original raised NameError; define the reduction locally.
        def mean_squared_error(ys_pred, ys):
            return (ys - ys_pred).square().mean()
        return mean_squared_error

