import numpy as np
import matplotlib.pyplot as plt

from numpy import linalg


class Offline:
    """Offline baseline: re-estimate the top eigenvector from all samples seen so far.

    At regular intervals the running (uncentered) second-moment matrix of the
    observations is eigendecomposed and its leading eigenvector is compared
    against ``markov_chain.largest_eigenvector``.
    """

    def __init__(self):
        pass

    def run_simulation(self, data, markov_chain, eval_interval=100):
        """Run the offline estimator over every repetition in ``data``.

        Args:
            data: Sequence of repetitions. Each repetition is a sequence of
                samples, and the observation vector is read from position 2
                of each sample (``sample[2]``). Other positions are not used
                here.
            markov_chain: Object exposing ``num_dimensions`` (size of each
                observation vector) and ``largest_eigenvector`` (reference
                direction). NOTE(review): the error formula presumes the
                reference direction has unit norm -- confirm with the caller.
            eval_interval: Evaluate the estimate every ``eval_interval``
                samples, starting with the first one. Defaults to 100.

        Returns:
            A tuple ``(mean_errors, std_errors, indices)`` where the first two
            are arrays holding the mean / standard deviation across
            repetitions of ``1 - <estimate, reference>**2`` at each evaluation
            step, and ``indices`` is the list of sample indices at which the
            evaluations happened. For empty ``data`` two empty arrays and an
            empty list are returned.

        Raises:
            ValueError: If ``eval_interval`` is smaller than 1, or if the
                repetitions do not all produce the same evaluation indices
                (i.e. they have incompatible lengths).
        """
        if eval_interval < 1:
            raise ValueError("eval_interval must be a positive integer")

        num_dimensions = markov_chain.num_dimensions
        reference_direction = markov_chain.largest_eigenvector

        all_sine_squared_errors = []
        indices = []
        for r, repetition in enumerate(data):
            sine_squared_errors = []
            rep_indices = []
            cov_eigengap = 0
            lambda1 = 0
            lambda2 = 0
            second_moment_sum = np.zeros((num_dimensions, num_dimensions))
            for t, sample in enumerate(repetition):
                observation = sample[2]
                second_moment_sum += np.outer(observation, observation)
                if t % eval_interval != 0:
                    continue

                # The average is only needed when we actually evaluate.
                sample_covariance_matrix = second_moment_sum / (t + 1)
                # The matrix is real symmetric, so use the symmetric solver:
                # eigenvalues come back real and in ascending order, and the
                # eigenvectors are already unit-norm.
                eigenvalues, eigenvectors = linalg.eigh(sample_covariance_matrix)
                largest_eigenvector = eigenvectors[:, -1]
                lambda1 = eigenvalues[-1]
                lambda2 = eigenvalues[-2]
                cov_eigengap = lambda1 - lambda2
                # Squaring the inner product makes the error independent of
                # the (arbitrary) sign of the computed eigenvector.
                alignment = np.dot(largest_eigenvector, reference_direction)
                sine_squared_errors.append(1 - alignment ** 2)
                rep_indices.append(t)

            print("Offline Algorithm Results")
            print("lambda1 : ", lambda1)
            print("lambda2 : ", lambda2)
            print("Eigengap : ", cov_eigengap)
            print("=============================================")

            if r == 0:
                indices = rep_indices
            elif rep_indices != indices:
                # Averaging across repetitions is only meaningful when every
                # repetition was evaluated at the same sample indices.
                raise ValueError(
                    "all repetitions must yield the same evaluation indices"
                )
            all_sine_squared_errors.append(sine_squared_errors)

        if not all_sine_squared_errors:
            # No repetitions at all: nothing to average.
            return np.array([]), np.array([]), []

        all_sine_squared_errors = np.array(all_sine_squared_errors)
        mean_sine_squared_errors = np.mean(all_sine_squared_errors, axis=0)
        std_sine_squared_errors = np.std(all_sine_squared_errors, axis=0)
        return mean_sine_squared_errors, std_sine_squared_errors, indices
