"""Main script to test a pretrained model"""
import argparse
import json
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np

from svrss.utils.paths import Paths
from svrss.utils.functions import count_params
from svrss.learners.tester import Tester
from svrss.models import FCN8s, UNet
# from kurad.loaders.dataset import PDRdata
from svrss.loaders.dataloaders import SequenceDataset
from svrss.utils.distributed_utils import init_distributed_mode

class PDRdata:
    """Load the PDRdata annotations and split every sequence into train/validation/test.

    Each sequence's frame list is sliced in the order it appears in the
    annotation file (no shuffling): the first ``ratio_split[0]`` fraction goes to
    ``train``, the next ``ratio_split[1]`` fraction to ``validation`` and all
    remaining frames to ``test``.

    Args:
        ratio_split: Three fractions ``(train, validation, test)``. Only the first
            two are used to place the slice boundaries; the test split receives
            everything after them. Defaults to ``(0.8, 0.1, 0.1)``.

    Raises:
        ValueError: if ``ratio_split`` does not contain exactly three values.
    """

    def __init__(self, ratio_split=(0.8, 0.1, 0.1)):
        ratio_split = list(ratio_split)
        # Validate before touching the filesystem so a bad argument fails fast.
        if len(ratio_split) != 3:
            raise ValueError(
                'ratio_split must contain exactly 3 values (train, validation, test), '
                'got {}.'.format(ratio_split))
        self.paths = Paths().get()
        # Used with the `/` operator below, so presumably a pathlib.Path — confirm in Paths.
        self.pdrdata = self.paths['PDRdata']
        self.data_seq_ref = self._load_data_seq_ref()
        self.annotations = self._load_dataset_ids()
        self.train = {}
        self.validation = {}
        self.test = {}
        # Attribute keeps its historical spelling so existing readers keep working.
        self.ratio_spilt = ratio_split
        self._split()

    def _load_data_seq_ref(self):
        """Return the parsed contents of ``data_seq_ref.json`` (read as UTF-8)."""
        path = self.pdrdata / 'data_seq_ref.json'
        with open(path, 'r', encoding='UTF-8') as fp:
            return json.load(fp)

    def _load_dataset_ids(self):
        """Return the parsed ``light_dataset_frame_oriented.json`` annotations.

        The file is decoded with the GBK codec. ``_split`` expects the result to
        map each sequence name to a sliceable list of frames.
        """
        path = self.pdrdata / 'light_dataset_frame_oriented.json'
        with open(path, 'r', encoding='GBK') as fp:
            return json.load(fp)

    def _split(self):
        """Slice each sequence's frames into the ``train``/``validation``/``test`` dicts."""
        cum_ratio = np.cumsum(self.ratio_spilt)
        for sequence, frames in self.annotations.items():
            num_frames = len(frames)
            # int() truncates, so any rounding remainder ends up in the later splits.
            train_end = int(num_frames * cum_ratio[0])
            val_end = int(num_frames * cum_ratio[1])
            self.train[sequence] = frames[:train_end]
            self.validation[sequence] = frames[train_end:val_end]
            self.test[sequence] = frames[val_end:]

    def get(self, split):
        """Return the ``{sequence: frames}`` dict for the requested split.

        Args:
            split: One of ``'Train'``, ``'Validation'`` or ``'Test'``.

        Raises:
            TypeError: if ``split`` is not a supported name (TypeError is kept
                for backward compatibility with existing callers).
        """
        splits = {'Train': self.train, 'Validation': self.validation, 'Test': self.test}
        if split in splits:
            return splits[split]
        raise TypeError('Type {} is not supported for splits.'.format(split))

def _build_model(cfg):
    """Instantiate the (untrained) network named by ``cfg['model']``.

    Args:
        cfg: Parsed config dict; reads ``'model'``, ``'nb_classes'`` and
            ``'nb_input_channels'``.

    Returns:
        An ``FCN8s`` or ``UNet`` instance.

    Raises:
        ValueError: if ``cfg['model']`` is neither ``'fcn8s'`` nor ``'unet'``.
    """
    if cfg['model'] == 'fcn8s':
        return FCN8s(n_classes=cfg['nb_classes'], n_frames=cfg['nb_input_channels'])
    if cfg['model'] == 'unet':
        return UNet(n_classes=cfg['nb_classes'], n_frames=cfg['nb_input_channels'])
    raise ValueError('model {} is not supported in test.py yet.'.format(cfg['model']))


def _print_test_metrics(test_metrics):
    """Print the ``'range_doppler'`` metrics returned by ``Tester.predict``."""
    rd_metrics = test_metrics['range_doppler']
    print('Test Prec: '
          'RD={}'.format(rd_metrics['prec']))
    print('Test mIoU: '
          'RD={}'.format(rd_metrics['miou']))
    print('Test mIoU by class: '
          'RD={}'.format(rd_metrics['miou_by_class'][0:2]))
    print(rd_metrics['miou_by_class'][2:])
    print('Test Dice: '
          'RD={}'.format(rd_metrics['dice']))
    print('Test Dice by class: '
          'RD={}'.format(rd_metrics['dice_by_class'][0:2]))
    # Remaining per-class Dice scores (must match the header printed just above).
    print(rd_metrics['dice_by_class'][2:])
    print('Confusion matrix: RD=')
    # Print every row instead of assuming a fixed number of classes.
    for row in rd_metrics['confusion_matrix']:
        print('{}'.format(row))


def test_model():
    """Evaluate a pretrained model on one PDRdata split and print its metrics.

    Command-line options:
        --cfg         JSON config of the model to test (default: ``config.json``).
        --model-path  Saved ``state_dict`` to load (defaults to the previously
                      hard-coded checkpoint path).
        --split       Split to evaluate: ``Train``, ``Validation`` or ``Test``
                      (default: ``Test``).
        --dist-url    URL used to set up distributed evaluation.
        --sync-bn     Convert BatchNorm layers to SyncBatchNorm when distributed.

    Raises:
        ValueError: if the configured model name is not supported.
    """
    default_model_path = ('./logs/PDRdata/unet/unet_e300_lr0.0001_s42_0_2024-03-10-15:55:40'
                          '/results/test_doppler_model.pt')
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', help='Path to config file of the model to test.',
                        default='config.json')
    parser.add_argument('--model-path', dest='model_path', default=default_model_path,
                        help='Path to the saved state_dict of the model to test.')
    parser.add_argument('--split', default='Test', choices=['Train', 'Validation', 'Test'],
                        help='Dataset split to evaluate on.')
    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
    parser.add_argument("--sync-bn", dest="sync_bn", help="Use sync batch norm", action='store_true')
    args = parser.parse_args()

    # Expected to set args.distributed (and args.gpu when distributed); both are read below.
    init_distributed_mode(args)
    with open(args.cfg, 'r', encoding='utf-8') as fp:
        cfg = json.load(fp)
    device = torch.device(cfg['device'])
    cfg['distributed'] = args.distributed

    model = _build_model(cfg)
    print('Number of trainable parameters in the model: %s' % str(count_params(model)))
    # Map to CPU first so the checkpoint loads regardless of the device it was saved on.
    saved_model = torch.load(args.model_path, map_location=torch.device('cpu'))
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.load_state_dict(saved_model)
    model.to(device)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

    tester = Tester(cfg)
    print(tester)
    eval_split = PDRdata().get(args.split)
    seq_testloader = DataLoader(SequenceDataset(eval_split), batch_size=1,
                                shuffle=False, num_workers=0)
    tester.set_annot_type(cfg['annot_type'])
    # These architectures are run with add_temp=False, every other model with
    # add_temp=True (the flag's meaning is defined by Tester.predict).
    add_temp = cfg['model'] not in ['mvnet', 'fcn8s', 'unet', 'deeplabv3plus']
    test_metrics = tester.predict(model, seq_testloader, get_quali=True, add_temp=add_temp)
    _print_test_metrics(test_metrics)

if "__main__" == __name__:
    # Script entry point: evaluation only runs when executed directly, never on import.
    test_model()
