/*
This code generates an instance of the spiked matrix-tensor model and runs gradient descent to minimise the negative log-likelihood.
The code is associated with the paper: Who is Afraid of Big Bad Minima? Analysis of Gradient-Flow in a Spiked Matrix-Tensor Model
submitted to NeurIPS 2019.

The program admits a parallel implementation: to use it, uncomment the line that follows each "// Uncomment to parallelise" marker.
*/


//  SIMULATION PARAMETERS TO PASS:

// 1 => N : size of the system ;
// 2 => p  : order of the tensor (fixed to 3 in the current implementation);
// 3 => delta2inv : inverse of the variance for the matrix noise ;
// 4 => deltap : variance for the tensor noise ;
// 5 => N_steps : number of steps for the gradient descent ;
// 6 => precision : precision required before stopping ;
// 7 => N_simulations : how many simulations to run for a given realisation of the noise ;
// 8 => n_group : unique number to be associated to this realisation of the noise ;
// 9 => gradient rate : time step, dt, used to discretise the gradient flow .

#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <time.h>
#include <string.h>
#include <float.h>
#include <random>
// Uncomment to parallelise
//#include "omp.h"

// Forward declarations: the definitions follow main() further down in this file.

// Small utilities (printing, integer helpers, vector helpers).
void print_matrix_int(int** A, int rows, int columns);
int generate_random_integer(int Nmax);
int factorial(int n);
void normalize_sqrtN(double* vettore, int size);
double dot_product(double *v, double *u, int size);

// Random number generation, built on rand().
double generate_random_unif();
double generate_random_gaussian();

// One gradient descent run on a given instance; writes its trajectory to a text file.
void gradient_descent(double* x_star, unsigned int N, unsigned int p, float* T, float* Y, unsigned short int** spinsNNN, unsigned short int** spinsNN, 
	unsigned int nn_NNN_interactions, unsigned int nn_NN_interactions, unsigned int N_steps, double rate, double precision,
	char directory[100], double deltap, double delta2inv, unsigned short int n_group, unsigned short int idx_simulation);

/*
Entry point: reads the simulation parameters (or falls back to the built-in defaults), draws one
realisation of the diluted matrix-tensor instance and runs N_simulations gradient descents on it.

Command line: either no argument at all (defaults are used) or the nine parameters
	N p delta2inv deltap N_steps precision N_simulations n_group rate
in that order. Extra arguments are ignored.

Returns 0 on success and 1 when the arguments are missing/invalid or memory cannot be allocated.
*/
int main(int argc, char *argv[]){

	// Uncomment to parallelise
	//omp_set_num_threads( 10 );	// SET NUMBER OF WORKERS 

	srand(time(NULL));

	// Site indices are stored as unsigned short and N*N as unsigned int (assumed to be at least
	// 32 bits wide): both can accommodate sizes up to 65535.
	const int N_max = 65535;

	unsigned int N, p, nn_NNN_interactions, nn_NN_interactions;
	unsigned short int **spinsNNN = NULL, **spinsNN = NULL;
	unsigned int idx, idx_2;
	unsigned short int tmp_int;
	unsigned int N_steps;
	unsigned short int N_simulations, n_group, idx_simulation; 

	double delta2inv, delta2, deltap, rho_NNN, rho_NN, tmp_double;
	float *T = NULL, *Y = NULL;
	double *x_star = NULL;
	double precision, rate;

	char directory[100]="simulations/";
	// char directory[100]="";

	////////////////////////////////////// Read input parameters or set defaults //////////////////////////////////////

	if(argc==1){
		N = 1000; p = 3; 
		delta2inv = 2.7; deltap = 1.;

		N_steps = 50001; precision = 1e-8;
		N_simulations = 10; n_group = 1;

		rate = 1.;
	}else if(argc>=10){
		// Parse the integers into signed temporaries first, so that negative or oversized values
		// are rejected instead of silently wrapping around in the unsigned variables.
		const int N_in = atoi(argv[1]);
		const int p_in = atoi(argv[2]);
		const int N_steps_in = atoi(argv[5]);
		const int N_simulations_in = atoi(argv[7]);
		const int n_group_in = atoi(argv[8]);

		delta2inv = atof(argv[3]); deltap = atof(argv[4]);
		precision = atof(argv[6]);
		rate = atof(argv[9]);

		// Every interaction needs three distinct sites, hence N >= 3.
		if(N_in<3 || N_in>N_max){
			fprintf(stderr,"Error: N must lie between 3 and %d (got %s).\n",N_max,argv[1]);
			return 1;
		}
		// The generation below and the gradient both hard-code three-body interactions.
		if(p_in!=3){
			fprintf(stderr,"Error: only p = 3 is implemented (got %s).\n",argv[2]);
			return 1;
		}
		if(N_steps_in<0){
			fprintf(stderr,"Error: N_steps must be non-negative (got %s).\n",argv[5]);
			return 1;
		}
		// The simulation counter is an unsigned short running up to N_simulations included:
		// 65535 would make the loop condition always true.
		if(N_simulations_in<0 || N_simulations_in>=65535){
			fprintf(stderr,"Error: N_simulations must lie between 0 and 65534 (got %s).\n",argv[7]);
			return 1;
		}
		if(n_group_in<0 || n_group_in>65535){
			fprintf(stderr,"Error: n_group must lie between 0 and 65535 (got %s).\n",argv[8]);
			return 1;
		}

		N = N_in; p = p_in;
		N_steps = N_steps_in;
		N_simulations = N_simulations_in; n_group = n_group_in;
	}else{
		fprintf(stderr,"Usage: %s N p delta2inv deltap N_steps precision N_simulations n_group rate\n",argv[0]);
		fprintf(stderr,"       (run without arguments to use the default parameters)\n");
		return 1;
	}

	// Both quantities are used as divisors and under a square root below.
	if(!(delta2inv>0.) || !(deltap>0.)){
		fprintf(stderr,"Error: delta2inv and deltap must be strictly positive.\n");
		return 1;
	}

	nn_NNN_interactions = N*N; nn_NN_interactions = int(N*sqrt(N));
	delta2 = 1/delta2inv; precision = precision*rate;

	////////////////////////////////////// Allocate memory //////////////////////////////////////

	// Releases everything allocated below; safe on partially allocated state because the
	// pointer tables are zero-initialised and free(NULL) is a no-op.
	auto release_all = [&](){
		if(spinsNNN!=NULL){
			for (unsigned int k=0; k<p; k++) free(spinsNNN[k]);
		}
		if(spinsNN!=NULL){
			for (unsigned int k=0; k<2; k++) free(spinsNN[k]);
		}
		free(spinsNNN); free(spinsNN);
		free(Y); free(T); free(x_star);
	};

	x_star = (double*)malloc((N)*sizeof(double));
	T = (float*)malloc((nn_NNN_interactions)*sizeof(float));
	Y = (float*)malloc((nn_NN_interactions)*sizeof(float));
	spinsNNN = (unsigned short int**)calloc(p,sizeof(unsigned short int*));
	spinsNN = (unsigned short int**)calloc(2,sizeof(unsigned short int*));

	bool allocation_ok = (x_star!=NULL && T!=NULL && Y!=NULL && spinsNNN!=NULL && spinsNN!=NULL);
	for (idx=0; allocation_ok && idx<p; idx++){
		spinsNNN[idx] = (unsigned short int*)malloc((nn_NNN_interactions)*sizeof(unsigned short int));
		allocation_ok = (spinsNNN[idx]!=NULL);
	}
	for (idx=0; allocation_ok && idx<2; idx++){
		spinsNN[idx] = (unsigned short int*)malloc((nn_NN_interactions)*sizeof(unsigned short int));
		allocation_ok = (spinsNN[idx]!=NULL);
	}
	if(!allocation_ok){
		fprintf(stderr,"Error: not enough memory for an instance of size N = %u.\n",N);
		release_all();
		return 1;
	}

	////////////////////////////////////// Generate instance //////////////////////////////////////

	printf("Creating interacting spins..\n");

	// Each three-body interaction couples three distinct sites drawn uniformly at random.
	for (idx=0;idx<nn_NNN_interactions;idx++){
		spinsNNN[0][idx] = generate_random_integer(N);
		tmp_int = generate_random_integer(N);
		while(spinsNNN[0][idx] == tmp_int){
			tmp_int = generate_random_integer(N);   
		}
		spinsNNN[1][idx] = tmp_int;

		tmp_int = generate_random_integer(N);
		while(spinsNNN[0][idx] == tmp_int || spinsNNN[1][idx] == tmp_int){
			tmp_int = generate_random_integer(N);   
		}
		spinsNNN[2][idx] = tmp_int;
	}

	// Each two-body interaction couples two distinct sites drawn uniformly at random.
	for (idx=0;idx<nn_NN_interactions;idx++){
		spinsNN[0][idx] = generate_random_integer(N);
		tmp_int = generate_random_integer(N);
		while(spinsNN[0][idx] == tmp_int){
			tmp_int = generate_random_integer(N);   
		}
		spinsNN[1][idx] = tmp_int;
	}

	printf("Creating signal..\n");
	for (idx=0;idx<N;idx++){
		x_star[idx] = generate_random_gaussian();
	}
	normalize_sqrtN(x_star,N);

	printf("Creating interactions..\n");

	// T first holds the noisy observation (Gaussian noise of variance deltap plus the rescaled
	// product of the signal entries), then is rescaled in place into the coupling used by the descent.
	rho_NNN = pow(N,.5*double(p))/sqrt(double(factorial(p))*nn_NNN_interactions);
	for (idx=0;idx<nn_NNN_interactions;idx++){
		T[idx] = float(sqrt(deltap)*generate_random_gaussian());
		tmp_double = x_star[spinsNNN[0][idx]];
		for (idx_2 = 1; idx_2<p; idx_2++){
			tmp_double *= x_star[spinsNNN[idx_2][idx]];
		}
		T[idx] += float(rho_NNN*sqrt(2.)*tmp_double/double(N));
		T[idx] = float(rho_NNN*sqrt(2.)*T[idx]/(double(N)*deltap));	
	}

	// Same construction for the two-body couplings, with noise variance delta2 = 1/delta2inv.
	rho_NN = double(N)/sqrt(2.*nn_NN_interactions);
	for (idx=0;idx<nn_NN_interactions;idx++){
		Y[idx] = float(sqrt(delta2)*generate_random_gaussian() + rho_NN*x_star[spinsNN[0][idx]]*x_star[spinsNN[1][idx]]/sqrt(N));
		Y[idx] = float(rho_NN*Y[idx]/(sqrt(N)*delta2));	
	}

	printf("Instance created!\n\n");

	////////////////////////////////////// GD //////////////////////////////////////

	// Uncomment to parallelise
	//#pragma omp parallel for
		for(idx_simulation=1;idx_simulation<=N_simulations;idx_simulation++){

			printf("Starting descent n. %d..\n\n",idx_simulation);

			gradient_descent(x_star, N, p, T, Y, spinsNNN, spinsNN, nn_NNN_interactions, nn_NN_interactions, N_steps, rate, precision, 
				directory, deltap, delta2inv, n_group, idx_simulation);
			printf("Finished descent n. %d..\n\n",idx_simulation);

		}

	printf("Done.\n");

	release_all();
	return 0;
}

/////////////////////////////////////////////////// GD function ///////////////////////////////////////////////////

/*
Runs one gradient descent from a random starting point on the sphere |x|^2 = N and logs the
trajectory to a text file.

Parameters:
	x_star                : reference vector of size N, only used to compute the overlap with the estimate ;
	N                     : number of components of the vectors ;
	p                     : order of the tensor; only enters the output file name and the normalisation
	                        of energy_p, the interaction loop itself is written for three-body terms ;
	T, Y                  : couplings of the three-body and two-body interactions ;
	spinsNNN, spinsNN     : site indices of each interaction (3 rows and 2 rows respectively) ;
	nn_NNN_interactions,
	nn_NN_interactions    : number of three-body and two-body interactions ;
	N_steps               : maximum number of iterations ;
	rate                  : step size of the update ;
	precision             : the descent stops once the total energy changes by less than this between
	                        two consecutive iterations ;
	directory             : prefix prepended to the output file name ;
	deltap, delta2inv,
	n_group,
	idx_simulation        : only used to build the output file name.

Output file: one line per iteration with the columns
	iteration, magnetization, energy_2+energy_p, energy_2, energy_p
(line 0 is the starting point, for which the energies are reported as 0).
On allocation or file-opening failure an error is printed on stderr and the function returns
without running the descent.
*/
void gradient_descent(double* x_star, unsigned int N, unsigned int p, float* T, float* Y, unsigned short int** spinsNNN, unsigned short int** spinsNN, 
	unsigned int nn_NNN_interactions, unsigned int nn_NN_interactions, unsigned int N_steps, double rate, double precision, 
	char directory[100], double deltap, double delta2inv, unsigned short int n_group, unsigned short int idx_simulation){

	double *x_estim, *neg_gradient_2, *neg_gradient_p;
	double tmp_double;
	long double energy_2, energy_p, energy_old, magnetization;
	unsigned int idx, iteration;

	// With the default parameters the file name is exactly 100 characters long, which together
	// with its terminator does not fit in 100 bytes: use a larger buffer and a bounded write.
	char simulation_out_name[512];

	FILE *simulation_out;

	x_estim = (double*)malloc((N)*sizeof(double));
	neg_gradient_p = (double*)malloc((N)*sizeof(double));
	neg_gradient_2 = (double*)malloc((N)*sizeof(double));
	if(x_estim==NULL || neg_gradient_p==NULL || neg_gradient_2==NULL){
		fprintf(stderr,"Error: not enough memory for descent n. %d\n",idx_simulation);
		free(x_estim); free(neg_gradient_p); free(neg_gradient_2);
		return;
	}

	snprintf(simulation_out_name, sizeof(simulation_out_name), "%s2+%u_C-code_GradientDescent_diluted_nok_deltap%.2f_delta2inv%.2f_dt%.2e_N%u-%d-%d.txt",directory,p,deltap,delta2inv,rate,N,idx_simulation,n_group);

	// Random starting point, rescaled to squared norm N.
	for (idx=0;idx<N;idx++){
		x_estim[idx] = generate_random_gaussian();
	}
	normalize_sqrtN(x_estim,N);

	energy_2 = 0.; energy_p = 0.; energy_old = 0.; magnetization = dot_product(x_estim,x_star,N);

	// The file is opened once and flushed after every line, so partial trajectories remain
	// readable while the descent is still running.
	simulation_out=fopen(simulation_out_name,"w");
	if(simulation_out==NULL){
		fprintf(stderr,"Error: cannot open output file %s\n",simulation_out_name);
		free(x_estim); free(neg_gradient_p); free(neg_gradient_2);
		return;
	}
	fprintf(simulation_out,"0 \t %.16Lf \t %.16Lf \t %.16Lf \t %.16Lf\n", magnetization, energy_2+energy_p, energy_2, energy_p);
	fflush(simulation_out);

	for (iteration=1;iteration<=N_steps;iteration++){
		// Energies are recomputed from scratch at every iteration; without this reset the
		// previous (already normalised) values would leak into the new sums.
		energy_2 = 0.; energy_p = 0.;

		for (idx=0;idx<N;idx++){
			neg_gradient_p[idx] = 0.; neg_gradient_2[idx] = 0.;
		}

		// Each interaction adds its full product (including x_i itself) to every site it touches;
		// the division by x_i in the update below turns it into the derivative with respect to x_i.
		for (idx=0;idx<nn_NNN_interactions;idx++){
			tmp_double = T[idx]*x_estim[spinsNNN[0][idx]]*x_estim[spinsNNN[1][idx]]*x_estim[spinsNNN[2][idx]];
			neg_gradient_p[spinsNNN[0][idx]] += tmp_double;
			neg_gradient_p[spinsNNN[1][idx]] += tmp_double;
			neg_gradient_p[spinsNNN[2][idx]] += tmp_double;
		}

		for (idx=0;idx<nn_NN_interactions;idx++){
			tmp_double = Y[idx]*x_estim[spinsNN[0][idx]]*x_estim[spinsNN[1][idx]];
			neg_gradient_2[spinsNN[0][idx]] += tmp_double;
			neg_gradient_2[spinsNN[1][idx]] += tmp_double;
		}

		// The energies are those of the configuration before the update: the sums were
		// accumulated from the current x_estim above.
		for (idx=0;idx<N;idx++){
			energy_2 -= neg_gradient_2[idx];
			energy_p -= neg_gradient_p[idx];

			x_estim[idx] = x_estim[idx] + rate*(neg_gradient_2[idx]+neg_gradient_p[idx])/x_estim[idx];
		}
		// Every interaction was counted once per site it touches, hence the factors 2 and p.
		energy_2 /= (2.*double(N)); energy_p /= (double(p)*double(N));
		// Rescale the estimate back to squared norm N after the unconstrained step.
		normalize_sqrtN(x_estim,N); magnetization = dot_product(x_estim,x_star,N);

		fprintf(simulation_out,"%u \t %.16Lf \t %.16Lf \t %.16Lf \t %.16Lf\n", iteration, magnetization, energy_2+energy_p, energy_2, energy_p);	
		fflush(simulation_out);

		if(fabs(energy_old-energy_2-energy_p)<precision){
			break;
		}else{
			energy_old = energy_2 + energy_p;
		}
	}
	fclose(simulation_out);

	// Report the largest component (in absolute value) of the final estimate.
	tmp_double = -1;
	for(idx=0;idx<N;idx++){
		if(tmp_double<fabs(x_estim[idx])) tmp_double = fabs(x_estim[idx]);
	}
	printf("Maximum value of estimator of %s : %f\n", simulation_out_name, tmp_double);

	free(x_estim); free(neg_gradient_p); free(neg_gradient_2);
}

///////////////////////////////////////////////// Auxiliary functions /////////////////////////////////////////////////

// Prints a rows x columns integer matrix on stdout, one matrix row per line,
// each entry followed by " \t".
void print_matrix_int(int** A, int rows, int columns){
	int r = 0;
	while (r < rows){
		int c = 0;
		while (c < columns){
			printf("%d \t",A[r][c]);
			++c;
		}
		printf("\n");
		++r;
	}
}

// Draws an integer by scaling a uniform deviate by Nmax and truncating the result towards zero.
int generate_random_integer(int Nmax){
	const double scaled = generate_random_unif() * (Nmax);
	return int(scaled);
}

// Returns n! as an int; any n below 2 (including negative values) yields 1.
int factorial(int n){
	int product = 1;
	int factor = 2;	// multiplying by 1 is a no-op, so start from 2

	while (factor <= n){
		product *= factor;
		++factor;
	}
	return product;
}

// Rescales vettore in place so that its Euclidean norm becomes sqrt(size).
void normalize_sqrtN(double* vettore, int size){
	const double target_norm = sqrt(size);

	// Sum of squares of the components, accumulated in index order.
	double sum_squares = 0.;
	int k = 0;
	while (k < size){
		sum_squares += vettore[k]*vettore[k];
		++k;
	}

	const double current_norm = sqrt(sum_squares);
	for (k = 0; k < size; ++k){
		vettore[k] = target_norm*vettore[k]/current_norm;
	}
}

// Returns the scalar product of v and u divided by size, i.e. the average of v[i]*u[i].
double dot_product(double *v, double *u, int size){
	double accumulator = 0.0;
	int k = 0;
	while (k < size){
		accumulator += v[k]*u[k];
		++k;
	}
	const double n_components = double(size);
	return accumulator/n_components;
}

// Returns rand() divided by the first double above RAND_MAX, so the result is
// non-negative and stays below 1.
double generate_random_unif(){
	const double just_above_max = std::nextafter(double(RAND_MAX), DBL_MAX);
	const int draw = rand();
	return draw / just_above_max;
}

// Draws a Gaussian deviate: picks points uniformly in the square [-1,1)x[-1,1) until one falls
// strictly inside the unit disc (origin excluded), then rescales its second coordinate.
// The deviate that could be built from the first coordinate is not used.
double generate_random_gaussian(){
	double a, b, radius_sq;
	do{
		a = 2.0*generate_random_unif()-1.0;
		b = 2.0*generate_random_unif()-1.0;
		radius_sq = a*a+b*b;
	}while(!(radius_sq<1.0 && radius_sq>0.0));

	const double scale = sqrt(-2.*log(radius_sq)/radius_sq);
	return b*scale;
}
