import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import argparse
from torch.utils.data import DataLoader, TensorDataset
import math, random
from data_utils.ModelNetDataLoader40 import ModelNetDataLoader40
from data_utils.ModelNetDataLoader10 import ModelNetDataLoader10
from data_utils.ShapeNetDataLoader import PartNormalDataset
from data_utils.KITTIDataLoader import KITTIDataLoader
from data_utils.ScanObjectNNDataLoader import ScanObjectNNDataLoader
from utils.utils import set_seed,jitter_point_cloud,scale_point_cloud,rotate_point_cloud,class_wise_rot_sca, SRSDefense, SORDefense
from utils import *
from PointNN.models import Point_NN, Point_PN_mn40
 



def cls_acc(output, target, topk=1):
    """Return the top-k classification accuracy as a percentage.

    Args:
        output: score tensor. ``topk`` is taken along dimension 1, so it
            needs one row per sample and one column per class.
        target: tensor holding one class index per sample; it is flattened
            with ``view(1, -1)`` before the comparison.
        topk: how many of the highest-scoring classes count as a hit.

    Returns:
        The accuracy as a Python ``float`` between 0 and 100.

    Raises:
        ZeroDivisionError: if ``target`` has no samples.
    """
    # After the transpose, pred is [topk, num_samples]: row i holds every
    # sample's (i + 1)-th best class.
    pred = output.topk(topk, 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # .item() extracts the scalar directly. The earlier float(ndarray) call on
    # a one-element array is deprecated in recent NumPy releases and needed an
    # extra tensor -> NumPy conversion.
    hits = correct.sum().item()
    return 100 * hits / target.shape[0]

def data_preprocess(data, device='cuda'):
    """Split a loader batch into points and class labels on ``device``.

    Args:
        data: a ``(points, target)`` pair of tensors. ``target`` is indexed
            as ``target[:, 0]``, so it must have at least two dimensions;
            only its first column is kept. The original author annotated
            ``points`` as ``[B, N, C]``.
        device: where both tensors are moved. The default ``'cuda'`` keeps
            the previous GPU-only behaviour; pass ``'cpu'`` to run without
            a GPU.

    Returns:
        ``(points, target)``: ``points`` is unchanged apart from its device,
        ``target`` is the first label column converted to ``int64``.
    """
    points, target = data
    points = points.to(device)
    target = target[:, 0].to(device).long()
    return points, target


def load_txt(args, data_path):
    """Build a shuffled DataLoader from point clouds stored as text files.

    Every entry of ``data_path`` whose name does not contain ``'origin'`` is
    read with ``np.loadtxt``. Its class label is the text between the last
    underscore and the first dot of the file name.

    Args:
        args: namespace providing ``batch_size`` and ``num_workers``.
        data_path: directory that holds the text files.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` of ``(points, labels)``;
        ``labels`` is float32 with a trailing dimension of size 1.

    Raises:
        AssertionError: if the directory contains no usable file.
        ValueError: if the point clouds do not all share one shape (raised
            by ``np.stack``) or a label is not numeric.
    """
    print('Start Loading Dataset...')
    file_names = os.listdir(data_path)
    print("Data_path=\"{}\"".format(data_path))
    # Filter before checking for emptiness: a directory that only holds
    # 'origin' files previously passed the check and failed later with an
    # unrelated error.
    file_names = sorted(fn for fn in file_names if 'origin' not in fn)
    if not file_names:
        # Raised explicitly (same type and message as the former assert) so
        # the check still runs under ``python -O``.
        raise AssertionError('No UD found! please generate UDs!')

    clouds, label_strings = [], []
    for fn in tqdm(file_names):
        file_path = os.path.join(data_path, fn)
        clouds.append(np.loadtxt(file_path).astype(np.float32))
        label_strings.append(fn.split('.')[0].split('_')[-1])

    # np.stack rejects clouds of differing shapes with a clear error.
    dataset = torch.from_numpy(np.stack(clouds))
    labels = torch.from_numpy(
        np.array(label_strings).astype(np.float32)
    ).unsqueeze(1)
    dataloader = DataLoader(
        TensorDataset(dataset, labels),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers
    )
    print(f'Finish Loading Dataset...{len(dataset)}')
    return dataloader


def load_clean_train_data(args, data_path):
    """Create the shuffled training DataLoader for ``args.dataset``.

    Args:
        args: namespace providing ``dataset``, ``batch_size``,
            ``num_workers`` and, for every dataset except ``'kitti'``,
            ``input_point_nums``.
        data_path: root directory handed to the dataset class.

    Returns:
        A ``DataLoader`` over the ``'train'`` split of the chosen dataset.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a supported name.
    """
    name = args.dataset
    shared = dict(root=data_path, split='train')

    if name == 'kitti':
        # KITTI uses a fixed point count and a differently named keyword.
        train_set = KITTIDataLoader(npoints=256, **shared)
    elif name == 'ScanObjectNN':
        train_set = ScanObjectNNDataLoader(
            npoint=args.input_point_nums, **shared
        )
    else:
        # These three datasets share one constructor signature.
        with_normal_flag = {
            'ModelNet40': ModelNetDataLoader40,
            'ModelNet10': ModelNetDataLoader10,
            'ShapeNetPart': PartNormalDataset,
        }
        if name not in with_normal_flag:
            raise NotImplementedError
        train_set = with_normal_flag[name](
            npoint=args.input_point_nums, normal_channel=False, **shared
        )

    return torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers
    )
def get_arguments(argv=None):
    """Parse the command-line options of the Point-NN evaluation script.

    Args:
        argv: optional list of argument strings. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, as before.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='ModelNet10')
    parser.add_argument('--split', type=int, default=3)

    parser.add_argument('--bz', type=int, default=16)  # Freeze as 16

    # Point-NN encoder hyper-parameters.
    parser.add_argument('--points', type=int, default=1024)
    parser.add_argument('--stages', type=int, default=4)
    parser.add_argument('--dim', type=int, default=72)
    parser.add_argument('--k', type=int, default=90)
    parser.add_argument('--alpha', type=int, default=1000)
    parser.add_argument('--beta', type=int, default=100)

    # Data loading.
    parser.add_argument('--batch_size', type=int, default=16, metavar='N',
                        help='input batch size for training (default: 16)')
    parser.add_argument('--input_point_nums', type=int, default=1024,
                        help='Point nums of each point cloud')
    parser.add_argument('--normal', action='store_true', default=False,
                        help='Whether to use normal information [default: False]')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Worker nums of data loading.')

    # These four values name the directory the training text files are
    # read from.
    parser.add_argument('--slight_range', type=int, default=15,
                        help='x,y angle range [para 1]')
    parser.add_argument('--main_range', type=int, default=120,
                        help='z angle range [para 2]')
    parser.add_argument('--sca_min', type=float, default=0.6,
                        help='scale min bound [para 3]')
    parser.add_argument('--sca_max', type=float, default=0.8,
                        help='scale max bound [para 4]')

    parser.add_argument('--clean_train', action='store_true')
    parser.add_argument('--aug', action='store_true',
                        help='using data augmentations')
    parser.add_argument('--aug_type', type=str, default='rot')

    parser.add_argument('--seed', type=int, default=2023, metavar='S',
                        help='random seed (default: 2023)')
    return parser.parse_args(argv)
    

# Values of --aug_type that _augment_batch knows how to apply.
_AUG_TYPES = ('rot', 'jit', 'sca', 'sor', 'srs', 'rotsca', 'rotjit', 'scajit')


def _build_test_dataset(args):
    """Set args.NUM_CLASSES / args.data_path and return the test split."""
    if args.dataset == 'ModelNet40':
        args.NUM_CLASSES = 40
        args.data_path = "data/modelnet40_normal_resampled"
        test_dataset = ModelNetDataLoader40(root=args.data_path, npoint=args.input_point_nums, split='test', normal_channel=False)
    elif args.dataset == 'ModelNet10':
        args.NUM_CLASSES = 10
        args.data_path = "data/modelnet40_normal_resampled"
        test_dataset = ModelNetDataLoader10(root=args.data_path, npoint=args.input_point_nums, split='test', normal_channel=False)
    elif args.dataset == 'ShapeNetPart':
        args.NUM_CLASSES = 16
        args.data_path = "data/shapenetcore_partanno_segmentation_benchmark_v0_normal"
        test_dataset = PartNormalDataset(root=args.data_path, npoint=args.input_point_nums, split='test', normal_channel=False)
    elif args.dataset == 'kitti':
        args.NUM_CLASSES = 2
        # NOTE(review): machine-specific absolute path; adjust for your setup.
        args.data_path = '/HARD-DATA/LW/DATA/KITTI/training/object_cloud'
        test_dataset = KITTIDataLoader(root=args.data_path, npoints=256, split='test')
    elif args.dataset == 'ScanObjectNN':
        args.NUM_CLASSES = 15
        args.data_path = 'data/h5_files'
        test_dataset = ScanObjectNNDataLoader(root=args.data_path, npoint=args.input_point_nums, split='test')
    else:
        # Without this branch an unknown name surfaced later as a NameError.
        raise NotImplementedError(f"Unsupported dataset: {args.dataset!r}")
    return test_dataset


def _augment_batch(points, aug_type):
    """Apply the augmentation named ``aug_type``; return a float32 tensor."""
    if aug_type not in _AUG_TYPES:
        raise ValueError(f"Unknown augmentation type: {aug_type!r}")

    if aug_type == 'sor':
        augmented = SORDefense(k=2, alpha=1.1)(points)
    elif aug_type == 'srs':
        augmented = SRSDefense(drop_num=500)(points)
    else:
        # The remaining augmentations are fed a NumPy copy of the batch.
        array = points.clone().detach().cpu().numpy()
        if aug_type == 'rot':
            augmented = rotate_point_cloud(array)
        elif aug_type == 'jit':
            augmented = jitter_point_cloud(array)
        elif aug_type == 'sca':
            augmented = scale_point_cloud(array)
        elif aug_type == 'rotsca':
            augmented = scale_point_cloud(rotate_point_cloud(array))
        elif aug_type == 'rotjit':
            augmented = jitter_point_cloud(rotate_point_cloud(array))
        else:  # 'scajit'
            augmented = jitter_point_cloud(scale_point_cloud(array))

    # Only build a tensor when the result is not one already; calling
    # torch.tensor() on a tensor copies it and emits a UserWarning.
    if not torch.is_tensor(augmented):
        augmented = torch.tensor(augmented)
    return augmented.to(torch.float32)


@torch.no_grad()
def main():
    """Evaluate Point-NN with a feature memory bank built from training data.

    Steps:
      1. Parse the arguments, seed the RNGs and build the encoder on the GPU.
      2. Load the test split and the training data: the clean train split
         with ``--clean_train``, otherwise text files under ``./UDs/...``.
      3. Encode every (optionally augmented) training batch into an
         L2-normalised feature memory with one-hot labels.
      4. Encode and L2-normalise the test set.
      5. Sweep ``gamma`` and print the best accuracy found.

    Raises:
        NotImplementedError: if ``--dataset`` is not supported.
        SystemExit: with code -1 if ``--aug`` is set and ``--aug_type`` is
            unknown.
    """
    print('==> Loading args..')
    args = get_arguments()
    print(args)

    # Reject a bad augmentation name before any model or data is loaded,
    # rather than inside the first training batch.
    if args.aug and args.aug_type not in _AUG_TYPES:
        print("Wrong data augmentation type!")
        raise SystemExit(-1)

    set_seed(args.seed)

    print('==> Preparing model..')
    # NOTE(review): the encoder is sized by --points while the datasets use
    # --input_point_nums; presumably the two should match -- confirm.
    point_nn = Point_NN(input_points=args.points, num_stages=args.stages,
                        embed_dim=args.dim, k_neighbors=args.k,
                        alpha=args.alpha, beta=args.beta).cuda()
    point_nn.eval()

    print('==> Preparing data..')
    test_dataset = _build_test_dataset(args)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers
    )

    transform_tag = '_'.join(
        str(value)
        for value in (args.slight_range, args.main_range, args.sca_min, args.sca_max)
    )
    data_path = os.path.join("./UDs", args.dataset, transform_tag, "example")
    print(data_path)
    if args.clean_train:
        train_loader = load_clean_train_data(args, args.data_path)
    else:
        train_loader = load_txt(args, data_path)

    print('==> Constructing Point-Memory Bank..')
    feature_memory, label_memory = [], []
    for batch in tqdm(train_loader):
        if args.dataset == 'ShapeNetPart':
            # Only the first two items of a ShapeNetPart batch are used.
            batch = batch[:2]
        points, labels = data_preprocess(batch)

        if args.aug:
            points = _augment_batch(points, args.aug_type)

        # Pass through the Non-Parametric Encoder. The augmentation may have
        # produced a CPU tensor, so move to the GPU again before permuting.
        feature_memory.append(point_nn(points.cuda().permute(0, 2, 1)))
        # data_preprocess already returned CUDA int64 labels.
        label_memory.append(labels)

    # Feature Memory: normalise each row, then transpose so that
    # ``test_features @ feature_memory`` yields one column per training sample.
    feature_memory = torch.cat(feature_memory, dim=0)
    feature_memory /= feature_memory.norm(dim=-1, keepdim=True)
    feature_memory = feature_memory.permute(1, 0)
    # Label Memory. The class count is inferred from the largest label
    # present. No squeeze(): it would collapse the matrix when there is a
    # single sample or a single class.
    label_memory = torch.cat(label_memory, dim=0)
    label_memory = F.one_hot(label_memory).float()

    print('==> Saving Test Point Cloud Features..')
    test_features, test_labels = [], []
    for batch in tqdm(test_loader):
        if args.dataset == 'ShapeNetPart':
            batch = batch[:2]
        points, labels = data_preprocess(batch)

        # Pass through the Non-Parametric Encoder
        test_features.append(point_nn(points.cuda().permute(0, 2, 1)))
        test_labels.append(labels)

    test_features = torch.cat(test_features)
    test_features /= test_features.norm(dim=-1, keepdim=True)
    test_labels = torch.cat(test_labels)

    print('==> Starting Point-NN..')
    # Similarity Matching does not depend on gamma, so it is computed once
    # instead of on each of the 5000 search steps.
    distance = 1 - test_features @ feature_memory

    # Search the best hyperparameter gamma
    best_acc = 0
    for step in range(5000):
        gamma = step * 10000 / 5000
        # Label Integrate
        logits = (-gamma * distance).exp() @ label_memory
        acc = cls_acc(logits, test_labels)
        if acc > best_acc:
            best_acc = acc

    print(f"Point-NN's classification accuracy: {best_acc:.2f}")


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
