import torch
import numpy as np
import torch.nn as nn
import trimesh
import os
import shutil
from scipy import signal


def generate_human_mesh(pose, save_path, vposer, smplx_model, gt=True):
    """Decode a batch of poses into SMPL-X meshes and export each one as an OBJ file.

    Args:
        pose: Tensor whose last dimension is split into 3 (global orientation),
            3 (translation) and 32 (VPoser latent) entries.
        save_path: Directory the ``.obj`` files are written to. It is created
            if it does not exist yet.
        vposer: Object exposing ``decode(latent, output_type='aa')``.
        smplx_model: Callable body model returning an object with ``vertices``;
            must also expose ``faces``.
        gt: If truthy the files are named ``gt_<i>.obj``, otherwise ``pred_<i>.obj``.

    Returns:
        List of ``trimesh.Trimesh`` objects, one per batch element.
    """
    ori, trans, latent = torch.split(pose.clone(), [3, 3, 32], dim=-1)

    body_pose = vposer.decode(latent, output_type='aa')
    body_pose = body_pose.view(body_pose.shape[0], -1)  # bs * 21 * 3 -> bs * 63

    smplx_output = smplx_model(
        return_verts=True,
        body_pose=body_pose,
        pose_embedding=latent.view(-1, 32),
        global_orient=ori.view(-1, 3),
        transl=trans.view(-1, 3),
    )

    # detach() first: numpy() refuses tensors that are still attached to the
    # autograd graph, which happens whenever the caller did not disable grad.
    body_verts_batch = smplx_output.vertices.detach().cpu().numpy()
    smplx_faces = smplx_model.faces

    prefix = 'gt' if gt else 'pred'
    if save_path:  # an empty path means "current directory"; nothing to create
        os.makedirs(save_path, exist_ok=True)

    out_meshs = []
    for i, verts in enumerate(body_verts_batch):
        out_mesh = trimesh.Trimesh(verts, smplx_faces, process=False)
        out_mesh.export(os.path.join(save_path, '{}_{}.obj'.format(prefix, i)))
        out_meshs.append(out_mesh)

    return out_meshs


def generate_human_mesh_from_joints(joints, gt_joints, save_path, pose_init, vposer, smplx_model, gt_factor=0.0,
                                    num_meshes=16):
    """Fit SMPL-X parameters to target joints with SGD and export the resulting meshes.

    The target is ``joints * (1 - gt_factor) + gt_joints * gt_factor``; the
    first 23 joints of the body model output are matched against it with a
    mean-squared error for 1000 SGD steps (lr=20, momentum=0.9).

    Args:
        joints: Predicted joints. ``joints.shape[0] * joints.shape[1]`` poses are
            fitted. NOTE(review): the fitted joints are given a leading axis of
            size 1, so this presumably expects ``joints.shape[0] == 1`` - confirm.
        gt_joints: Ground-truth joints blended in with weight ``gt_factor``.
        save_path: Directory receiving ``pred_<i>.obj``; created if missing.
        pose_init: Initial parameters; after ``squeeze(0)`` the last dimension is
            split into 3 (orientation), 3 (translation) and 32 (latent).
        vposer: Pose prior handed to ``SMPLX_fitter``.
        smplx_model: Body model handed to ``SMPLX_fitter``; must expose ``faces``.
        gt_factor: Blend weight of ``gt_joints`` in the fitting target.
        num_meshes: Maximum number of meshes to export (default 16, the count
            that used to be hard-coded). ``None`` exports every fitted pose.

    Side effects:
        Moves the fitter (and any registered sub-modules) to CUDA, prints the
        losses every 100 steps, and leaves autograd globally *disabled* on
        return, exactly as before.
    """
    pose_init = pose_init.squeeze(0)
    fitter = SMPLX_fitter(vposer, smplx_model, bs=joints.shape[0] * joints.shape[1]).cuda()
    fitter.ori.data, fitter.trans.data, fitter.latent.data = torch.split(pose_init.clone(), [3, 3, 32], dim=-1)
    # NOTE(review): if vposer / smplx_model are nn.Modules they are registered as
    # sub-modules of the fitter, so any of their parameters that require grad
    # are optimised here too - confirm that this is intended.
    opt = torch.optim.SGD(fitter.parameters(), lr=20.0, momentum=0.9)
    fit_smplx = None

    torch.set_grad_enabled(True)
    try:
        for step in range(1000):
            fit_smplx = fitter()
            fit_joints = fit_smplx.joints[None, :, :23]
            target_joints = joints * (1 - gt_factor) + gt_joints * gt_factor
            loss = ((fit_joints - target_joints) ** 2).mean()
            if step % 100 == 99:
                # Drift from the initialisation is only reported, never
                # optimised, so compute it only on the steps that print it.
                current = torch.cat([fitter.ori.data, fitter.trans.data, fitter.latent.data], dim=-1)
                loss_l2 = ((current - pose_init) ** 2).mean()
                print(loss.item(), loss_l2.item())
            opt.zero_grad()
            loss.backward()
            opt.step()
    finally:
        # Historical contract: grad mode is off after this function. The
        # finally-clause makes that hold on the error path as well.
        torch.set_grad_enabled(False)

    # NOTE: fit_smplx comes from the forward pass *before* the last optimiser
    # step, so the exported meshes lag the final parameters by one update.
    body_verts_batch = fit_smplx.vertices
    total = body_verts_batch.shape[0]
    # Clamp so that fewer fitted poses than requested no longer raise IndexError.
    count = total if num_meshes is None else min(num_meshes, total)
    body_verts_np = body_verts_batch[:count].detach().cpu().numpy()
    smplx_faces = smplx_model.faces

    if save_path:  # an empty path means "current directory"; nothing to create
        os.makedirs(save_path, exist_ok=True)
    for i in range(count):
        out_mesh = trimesh.Trimesh(body_verts_np[i], smplx_faces, process=False)
        out_mesh.export(os.path.join(save_path, 'pred_{}.obj'.format(i)))


def lpf_1d(data, kernel_size=5):
    """Low-pass filter a 1-D signal with the kernel produced by ``gaussian_kernel``.

    The signal is padded on both ends by repeating its first / last sample
    ``kernel_size // 2`` times, so for odd ``kernel_size`` the output has the
    same length as the input (an even ``kernel_size`` yields one extra sample).

    Args:
        data: 1-D tensor with at least one element.
        kernel_size: Number of kernel taps. The standard deviation is
            ``kernel_size / 4 - 0.25``, which is zero for ``kernel_size == 1``.

    Returns:
        The filtered signal as a detached 1-D tensor.
    """
    assert len(data.shape) == 1

    half = kernel_size // 2
    # The kernel is created on the CPU in the default dtype; conv1d needs it on
    # the same device (and floating dtype) as the signal.
    lp_kernel = gaussian_kernel(kernel_size, std_dev=kernel_size / 4 - 0.25).to(data.device)
    if data.is_floating_point():
        lp_kernel = lp_kernel.to(data.dtype)

    padded = torch.cat([data[0].repeat(half), data, data[-1].repeat(half)])[None, None, :]
    filtered = torch.nn.functional.conv1d(padded, lp_kernel)

    return filtered[0, 0].detach()


def lpf_1d_batch(data, kernel_size=5):
    """Low-pass filter every row of a 2-D tensor with the ``gaussian_kernel`` kernel.

    Each row is padded on both ends by repeating its first / last sample
    ``kernel_size // 2`` times, so for odd ``kernel_size`` the output keeps the
    input shape (an even ``kernel_size`` yields one extra column).

    Args:
        data: Tensor of shape ``(batch, length)``.
        kernel_size: Number of kernel taps. The standard deviation is
            ``kernel_size / 4 - 0.25``, which is zero for ``kernel_size == 1``.

    Returns:
        The filtered rows as a detached 2-D tensor.
    """
    assert len(data.shape) == 2

    half = kernel_size // 2
    # The kernel is created on the CPU in the default dtype; conv1d needs it on
    # the same device (and floating dtype) as the signal.
    lp_kernel = gaussian_kernel(kernel_size, std_dev=kernel_size / 4 - 0.25).to(data.device)
    if data.is_floating_point():
        lp_kernel = lp_kernel.to(data.dtype)

    # Slice with ``:1`` / ``-1:`` to keep the batch axis. Indexing with a plain
    # ``0`` drops it, and repeat(1, k) then yields a (1, batch * k) tensor that
    # cannot be concatenated with ``data`` for batch sizes above one.
    front_pad = data[:, :1].repeat(1, half)
    back_pad = data[:, -1:].repeat(1, half)
    padded_data = torch.cat([front_pad, data, back_pad], dim=1).unsqueeze(1)
    filtered = torch.nn.functional.conv1d(padded_data, lp_kernel)

    return filtered.squeeze(1).detach()


def gaussian_kernel(size, std_dev=1.0):
    """Create a normalized 1D Gaussian kernel centered at zero.

    Args:
        size: Number of taps in the kernel.
        std_dev: Standard deviation of the Gaussian, in the same units as the
            sample positions; must be non-zero.

    Returns:
        Tensor of shape ``(1, 1, size)`` whose entries sum to 1, i.e. laid out
        as a single-channel ``conv1d`` weight.
    """
    # The negation must be applied *after* the floor division: ``-size // 2``
    # parses as ``(-size) // 2`` (-3 for size 5), which shifted the sample grid
    # to [-3, 2] and produced a lopsided kernel.
    half_width = size // 2
    x = torch.linspace(-half_width, half_width, steps=size)
    # Evaluate the (unnormalized) Gaussian at the sample positions
    gaussian = torch.exp(-x.pow(2) / (2 * std_dev**2))
    # Normalize the kernel to ensure the sum is 1
    gaussian /= gaussian.sum()
    return gaussian.view(1, 1, -1)


def latent_to_joints(pose_latent, vposer, smplx_model):
    """Run the body model on a batch of pose sequences and return its joints.

    Args:
        pose_latent: Tensor of shape ``(bs, T, C)`` whose last axis holds the
            global orientation in ``[:3]``, the translation in ``[3:6]`` and
            the VPoser latent (viewed as 32 values per frame) in ``[6:]``.
        vposer: Object exposing ``decode(latent, output_type='aa')``.
        smplx_model: Callable body model returning an object with ``joints``.

    Returns:
        Tensor of shape ``(bs, T, J, 3)`` where ``J`` is however many joints the
        body model returns per frame.
    """
    bs, T, _ = pose_latent.shape
    latent = pose_latent[:, :, 6:]
    body_pose = vposer.decode(latent, output_type='aa').view(-1, 63)
    smplx_output = smplx_model(return_verts=True, body_pose=body_pose,
                               global_orient=pose_latent[:, :, :3].view(-1, 3),
                               transl=pose_latent[:, :, 3:6].view(-1, 3),
                               pose_embedding=latent.view(-1, 32))
    # Infer the joint count instead of hard-coding 127 so that body models with
    # a different joint set reshape correctly as well.
    return smplx_output.joints.reshape(bs, T, -1, 3)


def values_to_colors(values):
    """Map a 1-D array of scalars to RGB colours by piecewise-linear interpolation.

    Five colour stops are placed at ``min``, ``max - 7``, ``max - 4.5``,
    ``max - 2`` and ``max`` (with ``min`` / ``max`` widened by 0.001), coloured
    blue, blue, green, orange and red respectively.

    Args:
        values: 1-D NumPy array; its value range must exceed roughly 9 units.

    Returns:
        Integer array of shape ``(len(values), 3)`` holding truncated RGB values.
    """
    # Widen the range slightly so every value lies strictly inside the stops.
    lo = np.min(values) - 0.001
    hi = np.max(values) + 0.001
    assert hi - lo > 9

    stops = np.array([lo, hi - 7, hi - 4.5, hi - 2, hi])
    palette = np.array([[0, 0, 255], [0, 0, 255], [0, 255, 0], [255, 165, 0], [255, 0, 0]])

    # Index of the stop at or below each value, and of the next stop above it.
    seg_start = np.digitize(values, stops) - 1
    seg_end = np.minimum(seg_start + 1, len(stops) - 1)

    # Fractional position of each value inside its segment.
    t = (values - stops[seg_start]) / (stops[seg_end] - stops[seg_start])
    weight = t[:, np.newaxis]
    blended = (1 - weight) * palette[seg_start] + weight * palette[seg_end]

    return blended.astype(int)


def write_ply(filename, points, attention=None, clip_high=10000):
    """Write a point cloud to an ASCII PLY file, optionally coloured by attention.

    Only points whose second coordinate is below ``clip_high`` are written.

    Args:
        filename: Destination path of the PLY file.
        points: Tensor of shape ``(N, 3)``.
        attention: Optional tensor with one scalar per point. When given, it is
            converted with ``values_to_colors`` and each vertex gets
            ``red green blue`` properties.
        clip_high: Exclusive upper bound on the second coordinate of a point.
    """
    # A single mask drives both the header count and the body, and the data is
    # moved to the host once instead of once per coordinate.
    keep = points[:, 1] < clip_high
    rows = points[keep].detach().cpu().tolist()

    kept_colors = None
    if attention is not None:
        colors = values_to_colors(attention.detach().cpu().numpy())
        kept_colors = colors[keep.cpu().numpy()].tolist()

    lines = [
        "ply\n",
        "format ascii 1.0\n",
        f"element vertex {len(rows)}\n",
        "property float x\n",
        "property float y\n",
        "property float z\n",
    ]
    if kept_colors is not None:
        lines += [
            "property uchar red\n",
            "property uchar green\n",
            "property uchar blue\n",
        ]
    lines.append("end_header\n")

    if kept_colors is None:
        lines.extend(f"{x} {y} {z}\n" for x, y, z in rows)
    else:
        lines.extend(
            f"{x} {y} {z} {r} {g} {b}\n"
            for (x, y, z), (r, g, b) in zip(rows, kept_colors)
        )

    with open(filename, 'w') as f:
        f.writelines(lines)


class SMPLX_fitter(nn.Module):
    """Hold per-sample SMPL-X parameters as learnable tensors.

    Three zero-initialised parameters are created: ``latent`` (bs x 32),
    ``trans`` (bs x 3) and ``ori`` (bs x 3). Calling the module decodes the
    latent with ``vposer`` and evaluates ``smplx`` on the current parameters.
    """

    def __init__(self, vposer, smplx, bs=1):
        """Store the two models and allocate ``bs`` rows for every parameter."""
        super().__init__()
        self.vposer = vposer
        self.smplx = smplx

        # Registration order (latent, trans, ori) determines the order of
        # ``parameters()`` / ``state_dict()`` and is kept on purpose.
        for name, width in (('latent', 32), ('trans', 3), ('ori', 3)):
            setattr(self, name, nn.Parameter(torch.zeros(bs, width)))

    def forward(self):
        """Return the body-model output for the current parameters."""
        decoded = self.vposer.decode(self.latent, output_type='aa')
        return self.smplx(
            return_verts=True,
            body_pose=decoded.view(decoded.shape[0], -1),  # bs * 21 * 3 -> bs * 63
            pose_embedding=self.latent.view(-1, 32),
            global_orient=self.ori.view(-1, 3),
            transl=self.trans.view(-1, 3),
        )
