import torch
import numpy as np
import torch.nn.functional as F
from tools import builder
from utils.logger import *
from utils import misc


def run_net(args, config, train_writer=None, val_writer=None):
    """Run part generation / completion on a single point-cloud file and save the result.

    Loads the array stored at ``args.pts_path``, resamples it to
    ``config.model.npoints`` points, runs ``base_model.part_generation`` on it
    without gradients, and writes the partial input, the completed output and
    the resampled ground truth to ``completion_data.npy`` in the current
    working directory.

    Args:
        args: Run options. Reads ``log_name``, ``pts_path``, ``use_gpu`` and
            ``local_rank`` (the CUDA device index used when ``use_gpu`` is set).
        config: Experiment config. ``config.model`` is passed to the model
            builder and ``config.model.npoints`` is the target point count.
        train_writer: Not used here; presumably kept so all runners share one
            signature.
        val_writer: Not used here; same as ``train_writer``.

    Raises:
        ValueError: If the loaded array is neither rank 2 ``(N, C)`` nor
            rank 3 ``(B, N, C)``.

    Notes:
        ``np.save`` stores the result dict as a pickled 0-d object array, so it
        must be read back with ``np.load(path, allow_pickle=True).item()``.
    """
    # NOTE(review): the logger is created but nothing is logged through it yet.
    logger = get_logger(args.log_name)
    pts_path = args.pts_path

    # Use one device for both the model and its input so they cannot disagree
    # (an int passed to ``Module.to`` means ``cuda:<int>``).
    if args.use_gpu:
        device = torch.device('cuda', args.local_rank)
    else:
        device = torch.device('cpu')

    # build model
    base_model = builder.model_builder(config.model)
    base_model.to(device)
    base_model.zero_grad()
    base_model.eval()

    # Fixed masking settings forwarded to part_generation; their exact
    # meaning is defined by the model implementation.
    mask_axis = 1
    mask_direct = True
    mask_pos = 0

    # Cast once, up front, so both resampling branches below (and the model)
    # receive float32. np.load commonly yields float64, and F.interpolate
    # rejects integer tensors.
    gt_points = torch.from_numpy(np.load(pts_path)).float().to(device)
    if gt_points.dim() not in (2, 3):
        raise ValueError(
            f'expected a point array of rank 2 (N, C) or 3 (B, N, C), '
            f'got shape {tuple(gt_points.shape)}'
        )
    if gt_points.dim() == 2:
        gt_points = gt_points.unsqueeze(0)  # add a leading batch dimension

    npoints = config.model.npoints
    num_points = gt_points.shape[1]
    if num_points > npoints:
        # Downsample with farthest point sampling.
        # NOTE(review): misc.fps may be CUDA-only; confirm before running with use_gpu off.
        gt_points = misc.fps(gt_points, npoints)
    elif num_points < npoints:
        # Upsample by linear interpolation along the point index axis.
        # interpolate works on (B, C, L), hence the transposes.
        # NOTE(review): this blends consecutive points by index order, which is
        # only meaningful if the stored ordering is spatially coherent.
        gt_points = F.interpolate(
            gt_points.transpose(1, 2), size=npoints, mode='linear'
        ).transpose(1, 2).contiguous()
    # When num_points == npoints the cloud is already the right size.

    with torch.no_grad():
        part_points, completion_points = base_model.part_generation(
            gt_points, mask_axis, mask_direct, mask_pos
        )

    completion_data = {
        'part_points': part_points.cpu().numpy(),
        'completion_points': completion_points.cpu().numpy(),
        'gt_points': gt_points.cpu().numpy()
    }
    np.save('completion_data.npy', completion_data)
    print('Successfully save completion data to completion_data.npy')
