import sys
import numpy as np
from math import sqrt
from timeit import default_timer as timer
from continualobservation import continual_observation_noise, binary_mechanism_noise
from scipy.linalg import toeplitz, matmul_toeplitz
from sys import stdout
import matplotlib.pyplot as plt
from tueplots import bundles

import pandas as pd

# Parameters

dimensions = 10000  # Number of noise values generated in each step
repetitions = 5  # Number of times the simulation is run
MIN_T = 100  # Benchmarked values of T start at the first power of base >= MIN_T
MAX_T = 65000  # Benchmarked values of T stay strictly below MAX_T
# MAX_T = 6500
base = 1.1  # Values of T are set to integer powers of the base
DO_SAVE_DATA = True  # Dump freshly generated timings as .csv under ./dumps/

# File name(s) in ./dumps/ with '[our|toeplitz|binary]_' trimmed away.
# LOAD_FILE = None signifies that we will generate the data instead of
# loading it. An empty string given on the command line is mapped to None,
# so running the script with '' as its first argument regenerates the data.
if len(sys.argv) > 1:
    LOAD_FILE = sys.argv[1] or None
else:
    # By default, running this file will load the following .csv files
    # Set LOAD_FILE = None to regenerate the data
    LOAD_FILE = 'performance_result_dim10000_reps5_minT100_max65000_2023-04-12 12:25:42.csv'

# Run tests on a given number of steps
# Simulate the time expended for each method

def timing(T):
    """Time the three noise-generation methods for a run of T steps.

    Appends T to the module-level list ``t_values`` and appends one new row
    to each of ``our_results``, ``binary_results`` and ``toeplitz_results``.
    Each row receives ``repetitions`` wall-clock durations, measured with
    ``timeit.default_timer`` (seconds).

    Parameters
    ----------
    T : int
        Number of steps passed to the generators; the convolution-based
        method works on T+1 rows.
    """
    t_values.append(T)
    our_times = []
    binary_times = []
    toeplitz_times = []
    our_results.append(our_times)
    binary_results.append(binary_times)
    toeplitz_results.append(toeplitz_times)

    # Setup for convolution-based method. It depends only on T, so it is
    # built once here instead of once per repetition (it is not timed).
    r = np.array([1] + T*[0])  # First row of L: 1 followed by T zeros
    # First column of L: c[0] = 1 and c[k] = c[k-1] * (1 - 0.5/k)
    c = np.concatenate(([1.0], np.cumprod(1 - 0.5 / np.arange(1, T + 1))))

    for _ in range(repetitions):
        # Progress marker: one '*' per repetition
        print("*", end='', flush=True)

        # Timing our method (the generated values are discarded)
        start = timer()
        for _noise in continual_observation_noise(T, 1, dimensions):
            pass
        our_times.append(timer() - start)

        # Timing binary method (the generated values are discarded)
        start = timer()
        for _noise in binary_mechanism_noise(T, 1, dimensions):
            pass
        binary_times.append(timer() - start)

        # Timing convolution-based method: drawing the Gaussian matrix is
        # part of the measured work, as in the two generators above
        start = timer()
        x = np.random.normal(0, 1, size=(T + 1, dimensions))
        toeplitz_noise = matmul_toeplitz((c, r), x)
        toeplitz_times.append(timer() - start)


# Run performance test

if LOAD_FILE is None:
    # Advance to the first integer power of `base` that is at least MIN_T
    i = 0
    T = 1
    while T < MIN_T:
        i += 1
        T = int(base**i)

    # Accumulators filled in place by timing(): one entry / row per value of T
    t_values = []
    our_results = []
    binary_results = []
    toeplitz_results = []

    while T < MAX_T:
        print(f"\nT={T}")
        timing(T)
        i += 1
        T = int(base**i)

    # Store everything in DataFrames to allow for dumping as .csv
    # (rows are indexed by T, columns by repetition number)
    run_columns = list(range(repetitions))
    df_our_results = pd.DataFrame(np.array(our_results),
                                  index=t_values,
                                  columns=run_columns)
    df_binary_results = pd.DataFrame(np.array(binary_results),
                                     index=t_values,
                                     columns=run_columns)
    df_toeplitz_results = pd.DataFrame(np.array(toeplitz_results),
                                       index=t_values,
                                       columns=run_columns)
    if DO_SAVE_DATA:
        # Take the timestamp once so all three dumps share the same suffix;
        # LOAD_FILE relies on a single suffix identifying the whole run.
        # pd.Timestamp is used because the pd.datetime alias no longer
        # exists in pandas >= 2.0.
        run_stamp = pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S")
        dump_suffix = f'performance_result_dim{dimensions}_reps{repetitions}_minT{MIN_T}_max{MAX_T}_{run_stamp}.csv'
        df_our_results.to_csv('dumps/our_' + dump_suffix)
        df_binary_results.to_csv('dumps/binary_' + dump_suffix)
        df_toeplitz_results.to_csv('dumps/toeplitz_' + dump_suffix)
else:
    # Reload a previous run: first column is the index (T), first row the header
    df_our_results = pd.read_csv('dumps/our_' + LOAD_FILE, index_col=0, header=0)
    df_binary_results = pd.read_csv('dumps/binary_' + LOAD_FILE, index_col=0, header=0)
    df_toeplitz_results = pd.read_csv('dumps/toeplitz_' + LOAD_FILE, index_col=0, header=0)

# Plotting

def performance_test_plot(save_plot=True, show_plot=True):
    """Scatter the per-output cost of the three methods against T.

    Reads the module-level DataFrames ``df_toeplitz_results``,
    ``df_binary_results`` and ``df_our_results``. Every column is drawn as
    dots at ``1000 * timing / T`` on a logarithmic x-axis, with the y-axis
    starting at 0.

    Parameters
    ----------
    save_plot : bool
        If true, write the figure to '../../figures/performance_test_v2.pdf'.
    show_plot : bool
        If true, call ``plt.show()``; otherwise clear the current figure.
    """
    dot_size = 2

    with plt.rc_context(bundles.icml2022()):
        plt.figure()
        plt.xscale('log')

        steps = np.array(df_our_results.index)
        run_columns = list(df_our_results.columns)

        def draw_run(column, with_labels):
            # Draw one repetition for every method, always in the same order
            methods = (
                (df_toeplitz_results, 'C0', 'Henzinger et al. w/ matmul_toeplitz'),
                (df_binary_results, 'C1', 'Binary Mechanism'),
                (df_our_results, 'C2', 'Our Mechanism'),
            )
            for frame, colour, legend_label in methods:
                label_kwargs = {'label': legend_label} if with_labels else {}
                per_output = 1000 * frame[column] / steps
                plt.plot(steps, per_output, 'o', color=colour,
                         markersize=dot_size, **label_kwargs)

        # Hard-coded 'scatter plot': only the first repetition carries the
        # legend labels, so each method is listed once in the legend
        draw_run(run_columns[0], True)
        for column in run_columns[1:]:
            draw_run(column, False)

        # Keep the automatic limits except for the lower y bound, set to 0
        left, right, _, top = plt.axis()
        plt.axis((left, right, 0, top))

        plt.xlabel('number of d-dimensional vectors')
        plt.ylabel('ms/output')
        plt.legend()

        if save_plot:
            plt.savefig('../../figures/performance_test_v2.pdf',
                        bbox_inches='tight')
        if not show_plot:
            plt.clf()
        else:
            plt.show()

# Display the figure via plt.show() (show_plot defaults to True) without
# writing the PDF to disk
performance_test_plot(save_plot=False)
