import numpy as np
from p_stars import GaussMix
from math import cos,sin

def get_2drot_in_dimen(dimen, ang, nd1, nd2):
    """Return a ``dimen`` x ``dimen`` identity matrix with a 2-D rotation embedded.

    The rotation acts on the coordinate pair (``nd1``, ``nd2``); every other
    entry is left as in the identity matrix.

    Args:
        dimen: Size of the square matrix to build.
        ang: Rotation angle in degrees.
        nd1: Index of the first coordinate of the rotated pair.
        nd2: Index of the second coordinate of the rotated pair.

    Returns:
        A ``(dimen, dimen)`` array whose (nd1, nd2) sub-block is
        ``[[cos, -sin], [sin, cos]]`` of the given angle.
    """
    theta = ang / 180 * np.pi  # degrees -> radians
    c = cos(theta)
    s = sin(theta)

    rot = np.eye(dimen)
    # The four entries are written in a fixed order; if nd1 == nd2 they all
    # hit the same cell and the last write (cos) is the one that remains.
    rot[nd1, nd1] = c
    rot[nd1, nd2] = -s
    rot[nd2, nd1] = s
    rot[nd2, nd2] = c
    return rot
    
def get_std_gauss_4mix(dimen, crnr=1., rotMtx=None, shift=None):
    """Build a 4-component Gaussian mixture with means on four corners.

    Before rotation/shift the four means are:
      * row 0: ``-crnr`` in every coordinate,
      * row 1: ``-crnr`` on even coordinates, ``+crnr`` on odd coordinates,
      * row 2: ``+crnr`` on even coordinates, ``-crnr`` on odd coordinates,
      * row 3: ``+crnr`` in every coordinate.

    All components get weight 0.25 and an all-ones ``(4, dimen)`` array as the
    third ``GaussMix`` argument (named stdevs here).

    Args:
        dimen: Number of coordinates of each mean.
        crnr: Magnitude of every coordinate of the un-transformed means.
        rotMtx: Optional matrix; each mean ``m`` is replaced by ``rotMtx @ m``.
        shift: Optional offset added to every mean after the rotation.

    Returns:
        The object produced by ``GaussMix(gmixprobs, gmixmeans, gmixstdevs)``.
    """
    # empr set
    gmixprobs = np.full((4,), 0.25)  # [.25,.5,.1,.1]
    gmixstdevs = np.ones((4, dimen))

    gmixmeans = np.ones((4, dimen))
    gmixmeans[0, :] *= -crnr
    gmixmeans[3, :] *= crnr
    gmixmeans[1, 0::2] *= -crnr
    gmixmeans[1, 1::2] *= crnr
    # Row 2 mirrors row 1. Previously it repeated row 1's signs exactly, which
    # made two of the four components share the same mean.
    gmixmeans[2, 0::2] *= crnr
    gmixmeans[2, 1::2] *= -crnr

    if rotMtx is not None:
        # Row-wise equivalent of ``rotMtx @ mean`` for every mean; assigning
        # in place keeps the (4, dimen) shape of the means array.
        gmixmeans[:] = gmixmeans @ np.asarray(rotMtx).T

    if shift is not None:
        gmixmeans += np.asarray(shift)

    print ("Mean: \n", gmixmeans)
    gmix = GaussMix(gmixprobs, gmixmeans, gmixstdevs)

    return gmix


def get_theta_set(n_theta, dimen):
    """Sample ``n_theta`` random vectors whose Gram matrix is well away from singular.

    Rows are drawn uniformly from ``[0, 1)`` with ``np.random.random`` and the
    whole set is redrawn until no eigenvalue of ``thetas @ thetas.T`` has
    absolute value below ``1e-2``.

    Args:
        n_theta: Number of vectors (rows) to sample.
        dimen: Number of coordinates of each vector.

    Returns:
        An ``(n_theta, dimen)`` array that passed the eigenvalue check.

    Raises:
        ValueError: If ``n_theta > dimen``. The Gram matrix then has rank at
            most ``dimen`` and always has (numerically) zero eigenvalues, so
            the rejection loop could never terminate.
    """
    if n_theta > dimen:
        raise ValueError(
            "n_theta (%d) must not exceed dimen (%d): the Gram matrix would "
            "always be singular" % (n_theta, dimen)
        )

    while True:
        thetas = np.random.random((n_theta, dimen))
        # The Gram matrix is symmetric, so the symmetric eigenvalue solver
        # applies and returns real eigenvalues.
        eigs = np.linalg.eigvalsh(thetas @ thetas.transpose())
        if not np.any(np.abs(eigs) < 1e-2):
            return thetas


def sample_rotation_mtx(dimen):
    """Sample a random ``dimen`` x ``dimen`` matrix as a product of plane rotations.

    ``2 * dimen`` rotations are composed. Each one uses an angle drawn
    uniformly from [-180, 180) degrees and acts on a randomly chosen pair of
    distinct coordinate indices; the individual factors come from
    ``get_2drot_in_dimen``.

    Args:
        dimen: Size of the square matrix to build.

    Returns:
        A ``(dimen, dimen)`` array. For ``dimen < 2`` there is no pair of
        distinct coordinates to rotate, so the identity matrix is returned.
    """
    if dimen < 2:
        # np.random.randint(1, dimen) below has an empty range for dimen == 1
        # and would raise; return the identity instead.
        return np.eye(dimen)

    n_rots = dimen * 2

    angs = np.random.uniform(-180., 180., size=(n_rots,))
    from_ndx = np.random.randint(0, dimen, size=(n_rots,))
    # An offset in [1, dimen) guarantees the second index differs from the first.
    to_delta = np.random.randint(1, dimen, size=(n_rots,))

    rotm = np.eye(dimen)
    for ang, fro, delta in zip(angs, from_ndx, to_delta):
        to = (fro + delta) % dimen
        rotm = rotm @ get_2drot_in_dimen(dimen, ang, fro, to)

    return rotm


