import abc
from collections import OrderedDict

from typing import Iterable
from torch import nn as nn

from rlkit.core.batch_rl_algorithm import EvalAlgorithm, BatchRLAlgorithm, BatchRLAlgorithmA, BatchRLAlgorithm2,\
    BatchRLAlgorithm3, BatchRLAlgorithm4, BatchRLAlgorithm4b, BatchRLAlgorithm4_2, BatchRLAlgorithm4_3,\
    BatchRLAlgorithm4sbp, BatchRLAlgorithm4sbpt, BatchRLAlgorithm4sbptt,  BatchRLAlgorithm4bt, BatchRLAlgorithm4sbp2, BatchRLAlgorithm4sbp3,\
    BatchRLAlgorithm4sbpq, BatchRLAlgorithmv4, BatchRLAlgorithmv4b

from rlkit.core.online_rl_algorithm import OnlineRLAlgorithm
from rlkit.core.trainer import Trainer, CustomTrainer
from rlkit.torch.core import np_to_pytorch_batch


class TorchOnlineRLAlgorithm(OnlineRLAlgorithm):
    """OnlineRLAlgorithm that forwards device/mode changes to ``self.trainer.networks``."""

    def to(self, device):
        """Call ``.to(device)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.train(mode)


class TorchEvalAlgorithm(EvalAlgorithm):
    """EvalAlgorithm that forwards device/mode changes to ``self.trainer.networks``."""

    def to(self, device):
        """Call ``.to(device)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.train(mode)


class TorchBatchRLAlgorithm(BatchRLAlgorithm):
    """BatchRLAlgorithm that forwards device/mode changes to ``self.trainer.networks``."""

    def to(self, device):
        """Call ``.to(device)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.train(mode)

class TorchBatchRLAlgorithm2(BatchRLAlgorithm2):
    """BatchRLAlgorithm2 that manages the networks of ``pi_trainer`` and ``beta_trainer``.

    Each entry of a trainer's ``networks`` may be a single network or a
    ``list`` of networks; list entries are expanded one level.
    """

    def to(self, device):
        """Call ``.to(device)`` on every network of both trainers (pi first, then beta)."""
        # Trainers are looked up by name so each is only accessed when its turn comes.
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.to(device)
                    continue
                for module in entry:
                    module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every network of both trainers (pi first, then beta)."""
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.train(mode)
                    continue
                for module in entry:
                    module.train(mode)

class TorchBatchRLAlgorithm3(BatchRLAlgorithm3):
    """BatchRLAlgorithm3 that forwards device/mode changes to ``self.trainer.networks``."""

    def to(self, device):
        """Call ``.to(device)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.train(mode)

class TorchBatchRLAlgorithm4(BatchRLAlgorithm4):
    """BatchRLAlgorithm4 that forwards device/mode changes to ``self.trainer.networks``."""

    def to(self, device):
        """Call ``.to(device)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.train(mode)

class TorchBatchRLAlgorithm4b(BatchRLAlgorithm4b):
    """BatchRLAlgorithm4b that forwards device/mode changes to ``self.trainer.networks``."""

    def to(self, device):
        """Call ``.to(device)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.train(mode)

class TorchBatchRLAlgorithm4bt(BatchRLAlgorithm4bt):
    """BatchRLAlgorithm4bt that forwards device/mode changes to ``self.trainer.networks``."""

    def to(self, device):
        """Call ``.to(device)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.train(mode)

class TorchBatchRLAlgorithm4sbp(BatchRLAlgorithm4sbp):
    """BatchRLAlgorithm4sbp that manages the networks of ``pi_trainer`` and ``beta_trainer``.

    Each entry of a trainer's ``networks`` may be a single network or a
    ``list`` of networks; list entries are expanded one level.
    """

    def to(self, device):
        """Call ``.to(device)`` on every network of both trainers (pi first, then beta)."""
        # Trainers are looked up by name so each is only accessed when its turn comes.
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.to(device)
                    continue
                for module in entry:
                    module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every network of both trainers (pi first, then beta)."""
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.train(mode)
                    continue
                for module in entry:
                    module.train(mode)

class TorchBatchRLAlgorithm4sbpq(BatchRLAlgorithm4sbpq):
    """BatchRLAlgorithm4sbpq that manages the networks of ``pi_trainer`` and ``beta_trainer``.

    Each entry of a trainer's ``networks`` may be a single network or a
    ``list`` of networks; list entries are expanded one level.
    """

    def to(self, device):
        """Call ``.to(device)`` on every network of both trainers (pi first, then beta)."""
        # Trainers are looked up by name so each is only accessed when its turn comes.
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.to(device)
                    continue
                for module in entry:
                    module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every network of both trainers (pi first, then beta)."""
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.train(mode)
                    continue
                for module in entry:
                    module.train(mode)

class TorchBatchRLAlgorithm4sbpt(BatchRLAlgorithm4sbpt):
    """BatchRLAlgorithm4sbpt that manages the networks of ``pi_trainer`` and ``beta_trainer``.

    Each entry of a trainer's ``networks`` may be a single network or a
    ``list`` of networks; list entries are expanded one level.
    """

    def to(self, device):
        """Call ``.to(device)`` on every network of both trainers (pi first, then beta)."""
        # Trainers are looked up by name so each is only accessed when its turn comes.
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.to(device)
                    continue
                for module in entry:
                    module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every network of both trainers (pi first, then beta)."""
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.train(mode)
                    continue
                for module in entry:
                    module.train(mode)

class TorchBatchRLAlgorithm4sbptt(BatchRLAlgorithm4sbptt):
    """BatchRLAlgorithm4sbptt that manages the networks of ``pi_trainer`` and ``beta_trainer``.

    Each entry of a trainer's ``networks`` may be a single network or a
    ``list`` of networks; list entries are expanded one level.
    """

    def to(self, device):
        """Call ``.to(device)`` on every network of both trainers (pi first, then beta)."""
        # Trainers are looked up by name so each is only accessed when its turn comes.
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.to(device)
                    continue
                for module in entry:
                    module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every network of both trainers (pi first, then beta)."""
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.train(mode)
                    continue
                for module in entry:
                    module.train(mode)

class TorchBatchRLAlgorithm4sbp2(BatchRLAlgorithm4sbp2):
    """BatchRLAlgorithm4sbp2 that manages the networks of ``pi_trainer`` and ``beta_trainer``.

    Each entry of a trainer's ``networks`` may be a single network or a
    ``list`` of networks; list entries are expanded one level.
    """

    def to(self, device):
        """Call ``.to(device)`` on every network of both trainers (pi first, then beta)."""
        # Trainers are looked up by name so each is only accessed when its turn comes.
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.to(device)
                    continue
                for module in entry:
                    module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every network of both trainers (pi first, then beta)."""
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.train(mode)
                    continue
                for module in entry:
                    module.train(mode)

class TorchBatchRLAlgorithm4sbp3(BatchRLAlgorithm4sbp3):
    """BatchRLAlgorithm4sbp3 that manages the networks of ``pi_trainer`` and ``beta_trainer``.

    Each entry of a trainer's ``networks`` may be a single network or a
    ``list`` of networks; list entries are expanded one level.
    """

    def to(self, device):
        """Call ``.to(device)`` on every network of both trainers (pi first, then beta)."""
        # Trainers are looked up by name so each is only accessed when its turn comes.
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.to(device)
                    continue
                for module in entry:
                    module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every network of both trainers (pi first, then beta)."""
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.train(mode)
                    continue
                for module in entry:
                    module.train(mode)

class TorchBatchRLAlgorithmv4(BatchRLAlgorithmv4):
    """BatchRLAlgorithmv4 that manages the networks of ``pi_trainer`` and ``beta_trainer``.

    Each entry of a trainer's ``networks`` may be a single network or a
    ``list`` of networks; list entries are expanded one level.
    """

    def to(self, device):
        """Call ``.to(device)`` on every network of both trainers (pi first, then beta)."""
        # Trainers are looked up by name so each is only accessed when its turn comes.
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.to(device)
                    continue
                for module in entry:
                    module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every network of both trainers (pi first, then beta)."""
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.train(mode)
                    continue
                for module in entry:
                    module.train(mode)

class TorchBatchRLAlgorithmv4b(BatchRLAlgorithmv4b):
    """BatchRLAlgorithmv4b that manages the networks of ``pi_trainer`` and ``beta_trainer``.

    Each entry of a trainer's ``networks`` may be a single network or a
    ``list`` of networks; list entries are expanded one level.
    """

    def to(self, device):
        """Call ``.to(device)`` on every network of both trainers (pi first, then beta)."""
        # Trainers are looked up by name so each is only accessed when its turn comes.
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.to(device)
                    continue
                for module in entry:
                    module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every network of both trainers (pi first, then beta)."""
        for trainer_attr in ("pi_trainer", "beta_trainer"):
            for entry in getattr(self, trainer_attr).networks:
                if not isinstance(entry, list):
                    entry.train(mode)
                    continue
                for module in entry:
                    module.train(mode)


class TorchBatchRLAlgorithm4_2(BatchRLAlgorithm4_2):
    """BatchRLAlgorithm4_2 that forwards device/mode changes to ``self.trainer.networks``."""

    def to(self, device):
        """Call ``.to(device)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.train(mode)

class TorchBatchRLAlgorithm4_3(BatchRLAlgorithm4_3):
    """BatchRLAlgorithm4_3 that forwards device/mode changes to ``self.trainer.networks``."""

    def to(self, device):
        """Call ``.to(device)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.train(mode)


class TorchBatchRLAlgorithmA(BatchRLAlgorithmA):
    """BatchRLAlgorithmA that forwards device/mode changes to ``self.trainer.networks``."""

    def to(self, device):
        """Call ``.to(device)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every entry of ``self.trainer.networks``."""
        networks = self.trainer.networks
        for module in networks:
            module.train(mode)

class TorchTrainer(Trainer, metaclass=abc.ABCMeta):
    """Abstract trainer that converts each batch before delegating to ``train_from_torch``.

    Subclasses must implement ``train_from_torch`` and the ``networks``
    property.
    """

    def __init__(self):
        # Counts how many times ``train`` has been invoked.
        self._num_train_steps = 0

    def train(self, np_batch):
        """Count the call, convert ``np_batch`` and run one training step.

        The batch is passed through ``np_to_pytorch_batch`` before being
        handed to ``train_from_torch``; that method's result is discarded.
        """
        self._num_train_steps += 1
        self.train_from_torch(np_to_pytorch_batch(np_batch))

    def get_diagnostics(self):
        """Return an ``OrderedDict`` holding the number of ``train`` calls."""
        stats = OrderedDict()
        stats['num train calls'] = self._num_train_steps
        return stats

    @abc.abstractmethod
    def train_from_torch(self, batch):
        """Run one training step on the converted ``batch`` (subclass hook)."""

    @property
    @abc.abstractmethod
    def networks(self) -> Iterable[nn.Module]:
        """Networks managed by this trainer (subclass hook)."""

class BetaTorchTrainer(Trainer, metaclass=abc.ABCMeta):
    """Abstract trainer whose step is split into ``policy_train`` then ``q_train``.

    Subclasses must implement ``policy_train``, ``q_train`` and the
    ``networks`` property.
    """

    def __init__(self):
        # Counts how many times ``train`` has been invoked.
        self._num_train_steps = 0

    def train(self, np_batch):
        """Count the call, convert ``np_batch`` and run both update phases.

        The same converted batch (output of ``np_to_pytorch_batch``) is given
        first to ``policy_train`` and then to ``q_train``.
        """
        self._num_train_steps += 1
        torch_batch = np_to_pytorch_batch(np_batch)
        for phase in (self.policy_train, self.q_train):
            phase(torch_batch)

    def get_diagnostics(self):
        """Return an ``OrderedDict`` holding the number of ``train`` calls."""
        stats = OrderedDict()
        stats['num train calls'] = self._num_train_steps
        return stats

    @abc.abstractmethod
    def policy_train(self, batch):
        """Policy update on the converted ``batch`` (subclass hook)."""

    @abc.abstractmethod
    def q_train(self, batch):
        """Q update on the converted ``batch`` (subclass hook)."""

    @property
    @abc.abstractmethod
    def networks(self) -> Iterable[nn.Module]:
        """Networks managed by this trainer (subclass hook)."""

class RTorchTrainer(Trainer, metaclass=abc.ABCMeta):
    """Abstract trainer whose ``train`` returns the result of ``train_from_torch``.

    Subclasses must implement ``train_from_torch`` and the ``networks``
    property.
    """

    def __init__(self):
        # Counts how many times ``train`` has been invoked.
        self._num_train_steps = 0

    def train(self, np_batch):
        """Count the call, convert ``np_batch`` and run one training step.

        Returns whatever ``train_from_torch`` returns (``None`` if it
        returns nothing).
        """
        self._num_train_steps += 1
        return self.train_from_torch(np_to_pytorch_batch(np_batch))

    def get_diagnostics(self):
        """Return an ``OrderedDict`` holding the number of ``train`` calls."""
        stats = OrderedDict()
        stats['num train calls'] = self._num_train_steps
        return stats

    @abc.abstractmethod
    def train_from_torch(self, batch):
        """Run one training step on the converted ``batch`` (subclass hook)."""

    @property
    @abc.abstractmethod
    def networks(self) -> Iterable[nn.Module]:
        """Networks managed by this trainer (subclass hook)."""

class R2TorchTrainer(Trainer, metaclass=abc.ABCMeta):
    """Abstract trainer whose ``train_from_torch`` must yield exactly two values.

    Subclasses must implement ``train_from_torch`` and the ``networks``
    property.
    """

    def __init__(self):
        # Counts how many times ``train`` has been invoked.
        self._num_train_steps = 0

    def train(self, np_batch):
        """Count the call, convert ``np_batch`` and run one training step.

        The result of ``train_from_torch`` is unpacked into two values, which
        are returned as a 2-tuple.
        """
        self._num_train_steps += 1
        # Unpacking enforces that the hook produced exactly two items.
        first, second = self.train_from_torch(np_to_pytorch_batch(np_batch))
        return first, second

    def get_diagnostics(self):
        """Return an ``OrderedDict`` holding the number of ``train`` calls."""
        stats = OrderedDict()
        stats['num train calls'] = self._num_train_steps
        return stats

    @abc.abstractmethod
    def train_from_torch(self, batch):
        """Run one training step on the converted ``batch`` (subclass hook)."""

    @property
    @abc.abstractmethod
    def networks(self) -> Iterable[nn.Module]:
        """Networks managed by this trainer (subclass hook)."""

class CustomTorchTrainer(CustomTrainer, metaclass=abc.ABCMeta):
    """Abstract trainer whose training step also receives the replay buffer.

    Subclasses must implement ``train_from_torch`` and the ``networks``
    property.
    """

    def __init__(self):
        # Counts how many times ``train`` has been invoked.
        self._num_train_steps = 0

    def train(self, np_batch, replay_buffer):
        """Count the call, convert ``np_batch`` and run one training step.

        The converted batch (output of ``np_to_pytorch_batch``) and the
        untouched ``replay_buffer`` are handed to ``train_from_torch``; its
        result is discarded.
        """
        self._num_train_steps += 1
        torch_batch = np_to_pytorch_batch(np_batch)
        self.train_from_torch(torch_batch, replay_buffer)

    def get_diagnostics(self):
        """Return an ``OrderedDict`` holding the number of ``train`` calls."""
        stats = OrderedDict()
        stats['num train calls'] = self._num_train_steps
        return stats

    @abc.abstractmethod
    def train_from_torch(self, batch, replay_buffer):
        """Run one training step on ``batch`` with ``replay_buffer`` (subclass hook)."""

    @property
    @abc.abstractmethod
    def networks(self) -> Iterable[nn.Module]:
        """Networks managed by this trainer (subclass hook)."""
