from __future__ import annotations
import numpy as np
import pandas as pd
import statsmodels.api as sm
# import trmf

from abc import ABC, abstractmethod
from numpy.typing import NDArray
from typing import Optional, Union, List
from tqdm import tqdm
from prophet import Prophet as ProphetBase

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Flatten
from keras.layers import LSTM, TimeDistributed, RepeatVector

from sklearn.preprocessing import StandardScaler

from gluonts.torch.model.deepar import DeepAREstimator
from gluonts.dataset.common import ListDataset
from gluonts.evaluation import make_evaluation_predictions
from prophet import Prophet as ProphetBase
from .util import (
    learnAR,
    leastSquares,
    truncatedSVD,
    lowestMultiple,
    donohoRank,
    energyRank,
)


class TimeSeriesModel(ABC):
    """Interface shared by every forecaster defined in this module.

    A model is fitted on an initial history with ``fit``, may be given further
    observations with ``update`` (when ``updatable()`` is true), and produces
    forecasts with ``predict``.
    """

    @staticmethod
    @abstractmethod
    def updatable() -> bool:
        """Return whether ``update`` is supported after ``fit``."""

    @staticmethod
    @abstractmethod
    def oneShot() -> bool:
        """Return the model's one-shot flag.

        NOTE(review): in this module only models without ``update`` support
        return True; presumably callers refit from scratch for those — confirm.
        """

    @abstractmethod
    def fit(self, series: NDArray) -> TimeSeriesModel:
        """Fit the model on ``series``; implementations return ``self``."""

    @abstractmethod
    def update(self, series: NDArray) -> TimeSeriesModel:
        """Append newly observed ``series`` values to the model's history."""

    @abstractmethod
    def predict(self, numSteps: int) -> NDArray:
        """Forecast ``numSteps`` steps past the most recent observation."""


                
class LSTM_(TimeSeriesModel):
    """Keras LSTM forecaster that predicts all series jointly.

    The network maps the last ``lags`` standardized rows of a ``T x N`` matrix
    to the next ``predictionLength`` rows. Its layout is: one LSTM layer over
    the input window, a ``RepeatVector`` to the output horizon, ``num_layers``
    sequence-returning LSTM layers, and a ``TimeDistributed`` dense layer
    with one output per series.
    """

    def __init__(
        self,
        predictionLength: int = 1,
        lags: int = 30,
        epochs: int = 100,
        batch_size: int = 128,
        num_layers: int = 2,
        hidden_units: int = 45,
        **kwargs,
    ) -> None:
        """Configure the forecaster; no model is built until ``fit``.

        Args:
            predictionLength: Forecast horizon; values below 2 are raised to 2.
            lags: Number of past time steps fed to the network.
            epochs: Training epochs passed to ``model.fit``.
            batch_size: Training batch size passed to ``model.fit``.
            num_layers: Number of LSTM layers after the ``RepeatVector``.
            hidden_units: Units in every LSTM layer (previously fixed at 45).
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.fitted = False
        # Not used by this model; these mirror the attributes DeepAR sets.
        self.freq = "D"
        self.start = pd.Period("01-01-2019", freq=self.freq)
        self.predictionLength = max(
            2, predictionLength
        )  # errors occur when predictionLength = 1
        self.h = self.predictionLength
        self.lags = lags
        self.epochs = epochs
        self.batch_size = batch_size
        self.num_layers = num_layers
        self.hidden_units = hidden_units

    def to_supervised(self, data):
        """Slice a ``T x N`` array into every complete (input, target) window.

        Window ``w`` uses ``data[w : w + lags]`` as input and
        ``data[w + lags : w + lags + h]`` as target. There are
        ``W = T - lags - h + 1`` such windows.

        Returns:
            ``(X, y)`` with shapes ``(W, lags, N)`` and ``(W, h, N)``.

        Raises:
            ValueError: if ``T < lags + h``, i.e. no complete window exists.
        """
        numRows, numCols = data.shape
        numWindows = numRows - self.lags - self.h + 1
        if numWindows <= 0:
            raise ValueError(
                f"Need at least lags + predictionLength = {self.lags + self.h} "
                f"time steps to train, got {numRows}."
            )
        # Sized exactly to the number of windows, so no zero-filled samples
        # end up in the training set.
        feature_mat = np.empty((numWindows, self.lags, numCols))
        output = np.empty((numWindows, self.h, numCols))
        for w in range(numWindows):
            feature_mat[w] = data[w : w + self.lags]
            output[w] = data[w + self.lags : w + self.lags + self.h]
        return feature_mat, output

    def test_prep(self, alldata):
        """Return the most recent ``lags`` rows of ``alldata`` (the network input)."""
        return alldata[-self.lags :, :]

    @staticmethod
    def updatable() -> bool:
        return True

    @staticmethod
    def oneShot() -> bool:
        return False

    def fit(self, series: NDArray) -> TimeSeriesModel:
        """Standardize ``series``, build the network and train it.

        Args:
            series: ``T x N`` array without NaNs, with ``T >= lags + predictionLength``.

        Returns:
            ``self``.
        """
        assert series.ndim == 2, "Expected T x N matrix!"
        assert not np.isnan(series).any(), "Series must not contain NaNs!"
        self.scaler = StandardScaler()
        data = self.scaler.fit_transform(series)
        X, y = self.to_supervised(data)
        n_timesteps, n_features = X.shape[1], X.shape[2]
        n_steps_out, n_outputs = y.shape[1], y.shape[2]

        self.model = Sequential()
        self.model.add(LSTM(self.hidden_units, input_shape=(n_timesteps, n_features)))
        self.model.add(RepeatVector(n_steps_out))
        for _ in range(self.num_layers):
            self.model.add(LSTM(self.hidden_units, return_sequences=True))
        self.model.add(TimeDistributed(Dense(units=n_outputs)))
        self.model.compile(loss="mse", optimizer="adam")
        self.model.fit(
            X, y, epochs=self.epochs, batch_size=self.batch_size, verbose=True
        )

        self.len_train = len(data)
        # Stored in standardized units; predict() reads its input window from here.
        self.train_data = data
        self.fitted = True
        return self

    def update(self, series: NDArray) -> TimeSeriesModel:
        """Append new observations to the prediction context.

        The network is not retrained. New rows are standardized with the
        scaler fitted in ``fit``, and ``predict`` uses them as input.
        """
        assert self.fitted, "Model not yet fitted!"
        assert series.ndim == 2, "Expected T x N matrix!"
        series_norm = self.scaler.transform(series)
        self.train_data = np.concatenate([self.train_data, series_norm], axis=0)
        return self

    def predict(self, numSteps: int) -> NDArray:
        """Forecast the next ``numSteps`` rows (at most ``predictionLength``).

        Returns:
            ``numSteps x N`` array in the original (unstandardized) units.
        """
        assert self.fitted, "Model not yet fitted!"
        assert (
            numSteps <= self.predictionLength
        ), "Cannot predict more than `self.predictionLength` steps!"
        x_test = self.test_prep(self.train_data)[None, ...]  # add batch axis
        yhat = self.model.predict(x_test, verbose=1)
        predictions = self.scaler.inverse_transform(np.array(yhat[0, :, :]))
        return predictions[:numSteps, :]



class DeepAR(TimeSeriesModel):
    """GluonTS (PyTorch) DeepAR forecaster over the columns of a ``T x N`` matrix.

    Each column is one target series, stamped with a fixed daily start period.
    Forecasts are the mean over ``numSamples`` sample paths.
    """

    def __init__(
        self, predictionLength: int, numEpochs: int, numSamples: int, **kwargs
    ) -> None:
        """Configure the underlying ``DeepAREstimator``.

        Args:
            predictionLength: Forecast horizon; values below 2 are raised to 2.
            numEpochs: Passed to the trainer as ``max_epochs``.
            numSamples: Sample paths drawn per series when predicting.
            **kwargs: Forwarded verbatim to ``DeepAREstimator``.
        """
        super().__init__()
        self.fitted = False
        self.freq = "D"
        self.start = pd.Period("01-01-2019", freq=self.freq)
        self.predictionLength = max(
            2, predictionLength
        )  # errors occur when predictionLength = 1
        self.numSamples = numSamples
        self.modelSpec = DeepAREstimator(
            self.freq,
            self.predictionLength,
            trainer_kwargs=dict(max_epochs=numEpochs),
            **kwargs,
        )

    @staticmethod
    def updatable() -> bool:
        return True

    @staticmethod
    def oneShot() -> bool:
        return False

    def fit(self, series: NDArray) -> TimeSeriesModel:
        """Train the estimator on every column of ``series`` and return ``self``."""
        assert series.ndim == 2, "Expected T x N matrix!"
        train_ds = ListDataset(
            [{"target": x, "start": self.start} for x in series.T],
            freq=self.freq,
        )
        self.trainData = series
        self.model = self.modelSpec.train(train_ds)
        self.fitted = True
        return self

    def update(self, series: NDArray) -> TimeSeriesModel:
        """Extend the conditioning history with new rows; the network is not retrained."""
        assert self.fitted, "Model not yet fitted!"
        assert series.ndim == 2, "Expected T x N matrix!"
        self.trainData = np.concatenate([self.trainData, series], axis=0)
        return self

    def predict(self, numSteps: int) -> NDArray:
        """Return a ``numSteps x N`` array of mean forecasts (``numSteps <= predictionLength``)."""
        assert self.fitted, "Model not yet fitted!"
        assert (
            numSteps <= self.predictionLength
        ), "Cannot predict more than `self.predictionLength` steps!"
        # make_evaluation_predictions holds out the last prediction_length
        # values of each target and forecasts them. Padding with zeros makes
        # the held-out window the steps right after the observed history.
        test_ds = ListDataset(
            [
                {
                    "target": np.concatenate([x, np.zeros(self.predictionLength)]),
                    "start": self.start,
                }
                for x in self.trainData.T
            ],
            freq=self.freq,
        )
        forecast_it, _ = make_evaluation_predictions(
            dataset=test_ds,
            predictor=self.model,
            num_samples=self.numSamples,
        )
        return np.column_stack([forecast.mean[:numSteps] for forecast in forecast_it])


class MSSA(TimeSeriesModel):
    """Multivariate SSA forecaster with optional per-series AR residual models.

    All series are stacked into one page matrix with ``numCoefs + 1`` rows. The
    matrix is denoised with ``truncatedSVD``, and ``leastSquares`` fits one
    coefficient vector that predicts the last page row from the preceding
    ``numCoefs`` rows. That vector is shared by every series.

    If ``arOrder`` is given, an AR model per series is fitted (via ``learnAR``)
    on the residual between the observations and the denoised page. Its
    forecast is added to the linear forecast.
    """

    def __init__(
        self,
        numSeries: int,
        numCoefs: int,
        rank: Optional[int] = None,
        rankEst: str = "donoho",
        arOrder: Optional[Union[int, List[int]]] = None,
    ) -> None:
        """Configure the model.

        Args:
            numSeries: Number of columns expected in every input matrix.
            numCoefs: Number of lagged values used by the linear recurrence.
            rank: Fixed SVD rank; if None it is chosen in ``fit`` via ``rankEst``.
            rankEst: ``"donoho"`` (``donohoRank``), ``"energy"`` (``energyRank``),
                or ``"fixed"`` (rank 5). Only consulted when ``rank`` is None.
            arOrder: AR order for the residuals, either one int for all series
                or one int per series. None disables the AR correction.
        """
        super().__init__()
        if rank is None:
            assert rankEst in ("donoho", "energy", "fixed")
        self.numSeries = numSeries
        self.numCoefs = numCoefs
        self.numPageRows = numCoefs + 1
        self.rank = rank
        self.rankEst = rankEst
        if arOrder is not None:
            self.arOrder = (
                arOrder if isinstance(arOrder, list) else [arOrder] * numSeries
            )
        else:
            self.arOrder = None
        # Common AR window length across series; at least 1 so arCoefs is never empty.
        self.maxOrder = max(1, max(self.arOrder)) if self.arOrder is not None else None
        self.fitted = False

    @staticmethod
    def updatable() -> bool:
        return True

    @staticmethod
    def oneShot() -> bool:
        return False

    def _check_dims(self, series: NDArray) -> None:
        assert series.ndim == 2, "Expected T x N matrix!"
        assert (
            series.shape[1] == self.numSeries
        ), f"Expected {self.numSeries} time series!"

    def fit(self, series: NDArray) -> TimeSeriesModel:
        """Fit the shared linear coefficients (and AR residual models) on ``series``.

        Args:
            series: ``T x numSeries`` array of observations.

        Returns:
            ``self``.
        """
        assert not self.fitted, "Model already fitted!"
        self._check_dims(series)

        # Truncate time series to multiple of L, then form page matrix.
        # Column-major reshape: each page column is a contiguous chunk of one
        # series (assumes lowestMultiple returns a multiple of numPageRows
        # that is <= numSteps — TODO confirm in util).
        numSteps = series.shape[0]
        truncatedSteps = lowestMultiple(numSteps, self.numPageRows)
        self.page = series[:truncatedSteps].reshape(self.numPageRows, -1, order="F")

        # Denoise the page matrix, then fit the betas
        if self.rank is None:
            if self.rankEst == "donoho":
                self.rank = donohoRank(self.page)
            elif self.rankEst == "energy":
                self.rank = energyRank(self.page)
            else:
                self.rank = 5
        self.denoisedPage = truncatedSVD(self.page, self.rank)
        self.coefs = leastSquares(self.denoisedPage[:-1].T, self.denoisedPage[-1])

        if self.arOrder is not None:
            assert self.maxOrder is not None
            # Recover the stationary process, then fit AR coefficients for each series.
            # Coefficients are right-aligned in a maxOrder-long column, so that
            # predict() can multiply the last maxOrder rows elementwise.
            self.arCoefs = np.zeros((self.maxOrder, self.numSeries))
            extractedNoise = series[:truncatedSteps] - self.denoisedPage.reshape(
                truncatedSteps, -1, order="F"
            )
            for i in range(self.numSeries):
                if self.arOrder[i] > 0:
                    self.arCoefs[-self.arOrder[i] :, i] = learnAR(
                        extractedNoise[:, i], self.arOrder[i]
                    )

        # Store dataset, with extra space for future time steps
        self.history = np.empty((2 * numSteps, self.numSeries))
        self.history[:numSteps] = series
        self.historyLength = numSteps

        self.fitted = True

        return self

    def update(self, series: NDArray) -> TimeSeriesModel:
        """Append new rows to the stored history (coefficients are not refitted)."""
        assert self.fitted, "Model not yet fitted!"
        self._check_dims(series)
        numSteps = series.shape[0]
        if numSteps + self.historyLength <= self.history.shape[0]:
            # If there is enough space, just store the new data
            self.history[self.historyLength : self.historyLength + numSteps] = series
        else:
            # Allocate more space (doubling) to keep appends amortized cheap
            expandedLength = 2 * (self.historyLength + numSteps)
            oldHistory, self.history = self.history, np.empty(
                (expandedLength, self.numSeries)
            )
            self.history[: self.historyLength] = oldHistory
            self.history[self.historyLength : self.historyLength + numSteps] = series
        self.historyLength += numSteps
        return self

    def predict(self, numSteps: int) -> NDArray:
        """Forecast ``numSteps`` steps past the last stored observation.

        Returns:
            ``numSteps x numSeries`` array.
        """
        assert self.fitted, "Model not yet fitted!"
        # The linear recurrence needs numCoefs past values. The AR recurrence
        # needs maxOrder residuals, each of which needs numCoefs values of its own.
        contextLength = (
            self.numCoefs + self.maxOrder
            if self.maxOrder is not None
            else self.numCoefs
        )
        assert (
            self.historyLength >= contextLength
        ), f"Need at least {contextLength} observed steps to forecast!"
        context = self.history[self.historyLength - contextLength : self.historyLength]

        # Linear recurrence, fed with its own outputs beyond the observed context.
        fForecastWithContext = np.empty((contextLength + numSteps, self.numSeries))
        fForecastWithContext[:contextLength] = context
        for i in range(contextLength, contextLength + numSteps):
            fForecastWithContext[i] = (
                fForecastWithContext[i - self.numCoefs : i].T @ self.coefs
            )
        fForecast = fForecastWithContext[contextLength:]

        if self.arOrder is not None:
            assert self.maxOrder is not None
            # Residuals (observation minus one-step linear fit) for the last
            # maxOrder context rows seed the AR recurrence. Earlier rows are
            # never read, so they stay zero instead of uninitialized memory.
            xForecastWithContext = np.zeros((contextLength + numSteps, self.numSeries))
            for idx in range(self.numCoefs, contextLength):
                xForecastWithContext[idx] = (
                    context[idx] - context[idx - self.numCoefs : idx].T @ self.coefs
                )
            for idx in range(contextLength, contextLength + numSteps):
                xForecastWithContext[idx] = (
                    xForecastWithContext[idx - self.maxOrder : idx] * self.arCoefs
                ).sum(axis=0)
            fForecast += xForecastWithContext[contextLength:]

        return fForecast


class UnivariateARIMA(TimeSeriesModel):
    """statsmodels ARIMA(p, d, q) forecaster for a single 1-D series."""

    def __init__(self, arOrder: int, diffOrder: int, maOrder: int) -> None:
        """Store the ARIMA order ``(arOrder, diffOrder, maOrder)``; fitting happens in ``fit``."""
        super().__init__()
        self.arOrder = arOrder
        self.diffOrder = diffOrder
        self.maOrder = maOrder
        self.fitted = False

    @staticmethod
    def updatable() -> bool:
        return True

    @staticmethod
    def oneShot() -> bool:
        return False

    def fit(self, series: NDArray) -> TimeSeriesModel:
        """Fit ARIMA on ``series``; on failure retry once without stationarity enforcement."""
        assert not self.fitted, "Model already fitted!"
        assert series.ndim == 1, "Expected 1D time series!"
        order = (self.arOrder, self.diffOrder, self.maOrder)
        try:
            self.model = sm.tsa.ARIMA(series, order=order).fit()
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # propagate. The message assumes the failure came from the
            # stationarity constraint; any other Exception is also retried once.
            print("Error fitting while enforcing stationarity, trying without...")
            self.model = sm.tsa.ARIMA(
                series,
                order=order,
                enforce_stationarity=False,
            ).fit()
        self.fitted = True
        return self

    def update(self, series: NDArray) -> TimeSeriesModel:
        """Append new observations, keeping the already-estimated parameters (``refit=False``)."""
        assert self.fitted, "Model not yet fitted!"
        assert series.ndim == 1, "Expected 1D time series!"
        self.model = self.model.append(series, refit=False)
        return self

    def predict(self, numSteps: int) -> NDArray:
        """Return the ``numSteps``-ahead out-of-sample forecast."""
        assert self.fitted, "Model not yet fitted!"
        return self.model.forecast(numSteps)


class MultivariateARIMA(TimeSeriesModel):
    """Independent ``UnivariateARIMA`` model for every column of a ``T x N`` matrix.

    Each order argument may be one int applied to all series, or a list with
    one entry per series.
    """

    def __init__(
        self,
        numSeries: int,
        arOrder: Union[int, List[int]],
        diffOrder: Union[int, List[int]],
        maOrder: Union[int, List[int]],
    ) -> None:
        super().__init__()
        self.numSeries = numSeries
        self.arOrder = self._perSeries(arOrder, numSeries)
        self.diffOrder = self._perSeries(diffOrder, numSeries)
        self.maOrder = self._perSeries(maOrder, numSeries)
        self.fitted = False

    @staticmethod
    def _perSeries(value, count):
        """Use ``value`` as-is if it is a list, else repeat it ``count`` times."""
        if isinstance(value, list):
            return value
        return [value] * count

    @staticmethod
    def updatable() -> bool:
        return True

    @staticmethod
    def oneShot() -> bool:
        return False

    def _check_dims(self, series: NDArray) -> None:
        assert series.ndim == 2, "Expected T x N matrix!"
        assert (
            series.shape[1] == self.numSeries
        ), f"Expected {self.numSeries} time series!"

    def fit(self, series: NDArray) -> TimeSeriesModel:
        """Fit one ARIMA per column; ``self.models`` is set only once all succeed."""
        assert not self.fitted, "Model already fitted!"
        self._check_dims(series)
        print("Fitting individual ARIMA models...")
        fittedModels = []
        for col in tqdm(range(self.numSeries)):
            submodel = UnivariateARIMA(
                self.arOrder[col], self.diffOrder[col], self.maOrder[col]
            )
            fittedModels.append(submodel.fit(series[:, col]))
        self.models = fittedModels
        self.fitted = True
        return self

    def update(self, series: NDArray) -> TimeSeriesModel:
        """Forward each column of ``series`` to its per-series model."""
        assert self.fitted, "Model not yet fitted!"
        self._check_dims(series)
        for col, submodel in enumerate(self.models):
            submodel.update(series[:, col])
        return self

    def predict(self, numSteps: int) -> NDArray:
        """Return a ``numSteps x numSeries`` array of per-series forecasts."""
        assert self.fitted, "Model not yet fitted!"
        columns = [submodel.predict(numSteps) for submodel in self.models]
        return np.column_stack(columns)


class Prophet(TimeSeriesModel):
    """One Prophet model per column of a ``T x N`` matrix (no online updates).

    Rows are given synthetic timestamps starting at 2019-01-01 with spacing
    ``freq``. Each hyperparameter may be a scalar shared by all series, or a
    list with one entry per series.
    """

    def __init__(
        self,
        numSeries: int,
        changepointPriorScale: Union[float, List[float]] = 1e-3,
        seasonalityPriorScale: Union[float, List[float]] = 1e-2,
        seasonalityMode: Union[str, List[str]] = "additive",
        freq: str = "H",
    ) -> None:
        super().__init__()
        self.freq = freq
        self.numSeries = numSeries
        self.changepointPriorScale = (
            changepointPriorScale
            if isinstance(changepointPriorScale, list)
            else [changepointPriorScale] * numSeries
        )
        self.seasonalityPriorScale = (
            seasonalityPriorScale
            if isinstance(seasonalityPriorScale, list)
            else [seasonalityPriorScale] * numSeries
        )
        self.seasonalityMode = (
            seasonalityMode
            if isinstance(seasonalityMode, list)
            else [seasonalityMode] * numSeries
        )
        self.fitted = False

    @staticmethod
    def updatable() -> bool:
        return False

    @staticmethod
    def oneShot() -> bool:
        return True

    def fit(self, series: NDArray) -> TimeSeriesModel:
        """Fit one Prophet model per column of ``series`` and return ``self``."""
        assert not self.fitted, "Model already fitted!"
        assert series.ndim == 2, "Expected T x N matrix!"
        T, N = series.shape
        dates = pd.date_range(start="01-01-2019", periods=T, freq=self.freq)
        dfs = [pd.DataFrame({"ds": dates, "y": series[:, i]}) for i in range(N)]
        # (A leftover debug print of the first DataFrame was removed here.)
        print("Fitting Prophet models...")
        self.models = [
            ProphetBase(
                changepoint_prior_scale=self.changepointPriorScale[i],
                seasonality_prior_scale=self.seasonalityPriorScale[i],
                seasonality_mode=self.seasonalityMode[i],
            ).fit(dfs[i])
            for i in tqdm(range(len(dfs)))
        ]
        self.fitted = True
        return self

    def update(self, series: NDArray) -> TimeSeriesModel:
        raise NotImplementedError("Prophet does not support online updates!")

    def predict(self, numSteps: int) -> NDArray:
        """Return a ``numSteps x N`` array of ``yhat`` values following the training range."""
        assert self.fitted, "Model not yet fitted!"
        return np.column_stack(
            [
                model.predict(
                    model.make_future_dataframe(
                        numSteps, include_history=False, freq=self.freq
                    )
                ).yhat.values
                for model in self.models
            ]
        )
