import math
import numpy as np
import torch
import matplotlib.pyplot as plt

# for debug
import sys
def add_path(path):
    """Put ``path`` at the front of ``sys.path`` if it is not there yet.

    Calling this repeatedly with the same path is harmless: the path is only
    inserted the first time.
    """
    already_present = path in sys.path
    if not already_present:
        sys.path.insert(0, path)


# Make the project's `libs` package importable when running this file directly.
project_path = '/mnt/neural_acoustic_field'
add_path(project_path)

from libs.utils.utils import sequence_generation, OctaveBandsFactory


# bounce intersection
def naive_dirs_sampling(mesh_points, points, height, n_rays=None, sample_ratio=1.0, h_range=1.5,
                        space_height=150):
    """Sample ray directions in the 2D plane at ``height``.

    Mesh points whose z coordinate lies strictly inside a slab centred on
    ``height`` are projected onto the xy-plane, rounded to integers and
    de-duplicated. The resulting bounce points are then subsampled with a
    fixed stride, and a direction from every listener to every kept bounce
    point is returned.

    Args:
        mesh_points: `(M, 3)`, mesh point coordinates; column 2 is the height (z).
        points: `(N, 2)`, listener points.
        height: float, z coordinate of the sampling plane.
        n_rays: int or None. If given, it overrides ``sample_ratio`` with
            ``min(n_rays / K_all, 1)``, where ``K_all`` is the number of unique
            bounce points. Strided sampling may keep slightly more than
            ``n_rays`` points.
        sample_ratio: float > 0. Every ``int(1 / sample_ratio)``-th bounce
            point is kept; ratios >= 1 keep every point.
        h_range: float (cm), the height range length for the valid sampling slab.
            The slab thickness is ``h_range`` scaled by
            ``(height - min z) / space_height``.
        space_height: float (cm), the height of the room.
    Return:
        bounces: `(K, 2)`, valid bounce points.
        dirs: `(N, K, 2)`, directions from listener to bounce points
            (``bounces[None] - points[:, None]``).
    Raises:
        ValueError: if ``n_rays`` is given and not positive, or if ``n_rays`` is
            None and ``sample_ratio`` is not positive.
    """
    if n_rays is not None and n_rays <= 0:
        raise ValueError(f'n_rays must be positive, got {n_rays}')
    if n_rays is None and sample_ratio <= 0:
        raise ValueError(f'sample_ratio must be positive, got {sample_ratio}')

    z = mesh_points[..., 2]
    # Slab thickness around `height` (named `z_range` so the builtin `range`
    # is not shadowed).
    z_range = h_range / (space_height / (height - z.min()))
    in_slab = (z > height - z_range / 2) & (z < height + z_range / 2)
    bounces = mesh_points[in_slab][..., :2]
    # Remove redundant bounce points (unique after rounding to integers). The
    # set-based de-duplication (and therefore its ordering, which the strided
    # subsampling below depends on) is kept as before.
    bounces = np.array(list(set((round(x, 0), round(y, 0)) for x, y in bounces))).reshape(-1, 2)

    if n_rays is not None:
        # An empty slab would otherwise divide by zero.
        sample_ratio = min(n_rays / len(bounces), 1.0) if len(bounces) else 1.0
    # Naive strided sampling. Clamp the stride to 1 so a ratio > 1 keeps every
    # point instead of producing an invalid zero slice step.
    step = max(int(1.0 / sample_ratio), 1)
    bounces = bounces[::step, :]
    # Directions are built only for the kept points. This is identical to
    # slicing the full direction array with the same step.
    dirs = bounces[None, ...] - points[:, None, :]
    return bounces, dirs



# Time-domain IR histogram --(inverse filtering for each frequency band)--> discrete-time IR samples for each frequency band
def Hist_to_IR(band, energy_value, seq, fs):
    """Scale one band/bin sequence segment by a histogram energy value.

    Computes ``seq * sqrt(energy_value) * sqrt(band / fs * 2.0)``.

    Args:
        band: band value for this segment. It is presumably the band width in Hz,
            as produced by ``OctaveBandsFactory.get_bw()`` in the test below;
            TODO confirm.
        energy_value: torch tensor holding the histogram energy. Its device is
            used for the output.
        seq: numpy array or torch tensor with the sequence samples (j-th band,
            i-th bin, hbss samples).
        fs: sampling rate.
    Returns:
        torch tensor on ``energy_value.device``. ``seq`` is not modified.
    """
    # `as_tensor` accepts ndarrays (zero-copy on CPU, like `from_numpy`) as well
    # as tensors; `from_numpy` raised TypeError for tensor input.
    seq = torch.as_tensor(seq, device=energy_value.device)
    result = seq * torch.sqrt(energy_value)
    # Keep the float64 scale (via np.asarray) as before, but place it on the
    # result's device so a non-scalar `band` also works on GPU.
    band_scale = torch.as_tensor(np.asarray(band / fs * 2.0), device=result.device)
    result *= torch.sqrt(band_scale)
    return result


def Hist_to_IR_test(n_bins=100, fs=16000, bin_size=0.003):
    """Cross-check the per-bin ``Hist_to_IR`` against an all-bins NumPy computation.

    Generates a sequence with ``sequence_generation`` and band-passes it with
    ``OctaveBandsFactory``. Each ``hbss``-sample row of every band is normalised
    to unit L2 norm (all-zero rows are left as they are). An IR is then
    synthesised from an all-ones histogram twice:
    - bin by bin with ``Hist_to_IR`` (torch);
    - for all bins at once (NumPy).

    Args:
        n_bins: number of time bins in the histogram.
        fs: sampling rate (N / fs is passed as the sequence duration).
        bin_size: histogram bin duration (same time unit as ``N / fs``).
    Returns:
        (ir, np_ir): a torch tensor and a NumPy array, each of length ``n_bins * hbss``.
    """
    # samples per bin
    hbss = int(math.floor(fs * bin_size))
    # Exactly one row of `hbss` samples per histogram bin, so the reshape below
    # lines up with `hist`. `ceil(n_bins*bin_size*fs/hbss)*hbss` equals this
    # when fs*bin_size is an integer (as for the defaults), but can add an
    # extra row otherwise.
    N = n_bins * hbss
    # sound speed
    c = 340
    volume_room = 100  # cubic unit: GIVEN
    print(N, N / fs)
    # generate dirac sequence
    seq = sequence_generation(volume_room, N / fs, c, fs)
    seq = seq[:N]
    # generate octave band
    octave_band = OctaveBandsFactory(fs=fs)
    bws = octave_band.get_bw()

    # Histogram: JUST FOR VERIFICATION
    hist = torch.ones(len(bws), n_bins)

    # Pre-compute each band-passed sequence as (-1, hbss) rows. There are n_bins
    # rows if `analysis` preserves the input length (assumed; TODO confirm).
    seq_all_band = []
    for j in range(len(bws)):
        seq_bp_rot = octave_band.analysis(seq, band=j).reshape((-1, hbss))
        normalization = np.linalg.norm(seq_bp_rot, axis=1)
        indices = normalization > 0.0  # skip all-zero rows: avoids 0/0
        seq_bp_rot[indices, :] /= normalization[indices, None]
        seq_all_band.append(seq_bp_rot)

    # compute IR bin by bin, summing the contribution of every band
    ir = []
    for i in range(n_bins):
        ir_single_bin = torch.zeros(hbss)
        for j, band in enumerate(bws):
            # IR for j-th band, i-th bin
            ir_single_bin += Hist_to_IR(band, hist[j, i], seq_all_band[j][i, :], fs)
        ir.append(ir_single_bin)
    ir = torch.cat(ir)

    # verify: compute all bins together
    hist_np = hist.numpy()
    np_ir = np.zeros(N)
    for j, band in enumerate(bws):
        result = seq_all_band[j] * np.sqrt(hist_np[j][:, None])
        result = result * np.sqrt(band / fs * 2.0)
        print(result.shape)
        np_ir += result.reshape(-1)
    return ir, np_ir


if __name__ == '__main__':
    import os

    # Example usage of naive_dirs_sampling (needs the room_0 data files):
    # mesh_points = np.loadtxt('data/room_0/ori_mesh.xyz')
    # meta_points = np.loadtxt('data/room_0/points.txt')
    # height = meta_points[0, 3]
    # points = meta_points[..., 1:3]
    # bounces, dirs = naive_dirs_sampling(mesh_points, points, height, sample_ratio=0.9)
    # print(dirs.shape)

    ir, np_ir = Hist_to_IR_test()
    ir = ir.numpy()
    # savefig fails if the output directory does not exist.
    os.makedirs('data', exist_ok=True)
    # Plot each IR on its own figure. Reusing the implicit current figure made
    # np_ir.jpg contain both curves overlaid.
    for curve, out_path in ((ir, 'data/ir.jpg'), (np_ir, 'data/np_ir.jpg')):
        fig, ax = plt.subplots()
        ax.plot(curve)
        fig.savefig(out_path)
        plt.close(fig)