import platform
# print('python_version ==', platform.python_version())
import torch
# print('torch.__version__ ==', torch.__version__)
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import time
import argparse
import numpy as np
from renhd import *
# from sgnht import *
from evaluation import *
import os
import random
from model_zoo_ext import *
import torch.multiprocessing as torch_mp
# import multiprocessing as mp
import time
import resnet
# import psutil  # after suspending a process, it stays blocked and cannot resume
import statistics

#############################################################################
# Architectures offered by the local ``resnet`` module: lower-case callables
# whose name starts with "resnet" (dunder names are excluded).
model_names = sorted(
    name
    for name, member in resnet.__dict__.items()
    if name.islower()
    and not name.startswith("__")
    and name.startswith("resnet")
    and callable(member)
)

# Command-line hyper-parameters of the experiment.
parser = argparse.ArgumentParser(description='RENHD on RESNET tested on CIFAR10 appending noise')

# (flag, type, default) for every plain option, in help-output order.
# Trailing notes list values that were tried previously.
_OPTIONS = (
    ('--train-batch-size', int, 256),  # 64
    ('--test-batch-size', int, 10000),
    ('--num-burn-in', int, 10000),  # number of burn-in iterations
    ('--num-epochs', int, 1000),
    ('--evaluation-interval', int, 50),  # 1000
    ('--exchange-per-evaluation', int, 50),  # 100
    ('--eta-theta', float, 1e-7),  # step size; 1.7e-8 also tried
    ('--c-theta', float, 0.1),  # 0.1, 0.001
    ('--mu', float, 1),  # 0.1, 0.001 (note: the default stays the int 1)
    ('--xi-base', float, 1.2),  # base of the temperature ladder; 2
    ('--prior-precision', float, 1e-3),  # weight of the L2 penalty; 1e-3
    ('--permutation', float, 0.3),  # fraction of each batch's labels replaced by noise
    ('--device-num', int, 3),
    ('--cuda-id', int, 1),
    ('--num-process-per-processor', int, 5),  # 2
    ('--renhd-model', str, 'resnet_renhd'),
    ('--check-point-path', str, 'check-point'),
    ('--seed', int, 10),
)
for _flag, _type, _default in _OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default)
del _OPTIONS, _flag, _type, _default

parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet20',
                    choices=model_names,
                    help='model architecture: {} (default: resnet20)'.format(' | '.join(model_names)))
args = parser.parse_args()

cuda_availability = torch.cuda.is_available()

# Reproducibility: always seed the CPU generator, and every visible GPU's
# generator when CUDA is present.
torch.manual_seed(args.seed)
if cuda_availability:
    torch.cuda.manual_seed_all(args.seed)

#############################################################################
# Processor layout: with CUDA, replicas are spread over ``--device-num`` GPUs;
# without it, everything runs on a single CPU "processor".
num_processor, curr_device = (
    (args.device_num, 'cuda:{}'.format(args.cuda_id))
    if cuda_availability
    else (1, 'cpu')
)

# Total number of replicas = replicas per processor x processors.
num_process = args.num_process_per_processor * num_processor

# Make sure the checkpoint directory exists before any process writes to it.
if not os.path.exists(args.check_point_path):
    os.makedirs(os.path.abspath(args.check_point_path), exist_ok=True)

#############################################################################
# Load CIFAR-10.
# Per-channel mean/std of the CIFAR-10 training images. These replace MNIST's
# single-channel (0.1307,)/(0.3081,), which were being broadcast unchanged over
# all three RGB channels.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# The same deterministic preprocessing is used for training and test data.
cifar10_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# drop_last=True keeps every training batch at exactly --train-batch-size.
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./cifar10-dataset', train=True, download=True,
                     transform=cifar10_transform),
    batch_size=args.train_batch_size, shuffle=True, drop_last=True)

# drop_last=False so that all len(test_loader.dataset) images are yielded,
# matching the size FullyBayesian is constructed with in train(). Otherwise a
# --test-batch-size that does not divide 10000 would silently skip test images.
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./cifar10-dataset', train=False,
                     transform=cifar10_transform),
    batch_size=args.test_batch_size, shuffle=False, drop_last=False)

# Training-set size, passed to the sampler.
N = len(train_loader.dataset)


#############################################################################

def _checkpoint_filename(process_id, MODEL_OR_OPTIM):
    """Return the checkpoint file name for one replica's model or sampler state.

    Args:
        process_id: index of the replica.
        MODEL_OR_OPTIM: ``'MODEL'`` for the network weights, ``'OPTIM'`` for the
            sampler state.

    Raises:
        ValueError: if ``MODEL_OR_OPTIM`` is neither ``'MODEL'`` nor ``'OPTIM'``.
            Previously this only printed a warning and then read/wrote the bare
            checkpoint directory path.
    """
    infix = {'MODEL': "_process_", 'OPTIM': "_optimizer_"}.get(MODEL_OR_OPTIM)
    if infix is None:
        raise ValueError('Wrong MODEL_OR_OPTIM: {!r}'.format(MODEL_OR_OPTIM))
    return (str(args.renhd_model) + infix + str(process_id) +
            "_permutation_" + str(args.permutation) + ".th")


def save_checkpoint(state, process_id, MODEL_OR_OPTIM):
    """Save ``state`` (a state dict) under ``args.check_point_path``.

    Args:
        state: object passed to ``torch.save``, e.g. ``model.state_dict()``.
        process_id: index of the replica that owns ``state``.
        MODEL_OR_OPTIM: ``'MODEL'`` or ``'OPTIM'``; selects the file name.

    Raises:
        ValueError: for an unknown ``MODEL_OR_OPTIM``.
    """
    filename = _checkpoint_filename(process_id, MODEL_OR_OPTIM)
    torch.save(state, os.path.join(args.check_point_path, filename))


def load_checkpoint(process_id, MODEL_OR_OPTIM):
    """Load a state dict previously written by ``save_checkpoint``.

    Args:
        process_id: index of the replica whose state is wanted.
        MODEL_OR_OPTIM: ``'MODEL'`` or ``'OPTIM'``; selects the file name.

    Returns:
        The loaded object, or ``None`` if the checkpoint file does not exist.

    Raises:
        ValueError: for an unknown ``MODEL_OR_OPTIM``.
        Any error ``torch.load`` raises for an existing but unreadable file.
            Previously a bare ``except`` also hid corrupt checkpoints.
    """
    output_model_file = os.path.join(args.check_point_path,
                                     _checkpoint_filename(process_id, MODEL_OR_OPTIM))
    try:
        return torch.load(output_model_file)
    except FileNotFoundError:
        print('Cannot find: ' + str(output_model_file))
        return None


#############################################################################

def train(model, process_id, xi, shared_dict, sampler, exchange_event, burnin=False):
    """Run one replica of the sampler at temperature ``xi``.

    Used as the target of the processes spawned in ``__main__``. The replica
    calls ``sampler.update(xi)`` after back-propagating, for every batch of
    ``train_loader`` over up to ``args.num_epochs`` epochs. The loss is
    cross-entropy plus an L2 penalty on all parameters, scaled by
    ``args.prior_precision``. When ``args.permutation > 0``, the first
    ``int(args.permutation * batch_size)`` labels of each batch are replaced
    by random labels drawn with ``np.random.choice``.

    Every ``args.evaluation_interval`` iterations the momenta are resampled.
    Outside burn-in, process 0 also evaluates test accuracy via
    ``FullyBayesian`` at that point. Every ``args.exchange_per_evaluation``
    evaluations, the replica does three things:

    - publishes its mean training loss since the previous exchange as
      ``shared_dict[process_id]``;
    - increments ``shared_dict['ready_exchange']``;
    - blocks on ``exchange_event`` until ``exchange_process`` sets it.

    With ``burnin=True`` there is no evaluation or exchange, and training
    stops as soon as ``args.num_burn_in`` iterations have run.

    Finally the model and sampler state dicts are saved with
    ``save_checkpoint``. Outside burn-in, ``shared_dict['finished_training']``
    is then incremented.

    Args:
        model: network to train; must expose ``outputdim``.
        process_id: replica index; also selects the GPU (``process_id % num_processor``).
        xi: temperature passed to ``sampler.update`` / ``resample_momenta``.
        shared_dict: Manager dict shared with ``exchange_process``.
        sampler: RENHD sampler bound to ``model``'s parameters.
        exchange_event: event set by ``exchange_process`` after each exchange.
        burnin: run the burn-in phase instead of the main training phase.
    """
    if process_id == 0:
        print(args)
        print(model)

    # Replicas are distributed round-robin over the processors.
    cuda_id = process_id % num_processor
    if cuda_availability:
        model.cuda(cuda_id)
        sampler.cuda(cuda_id)

    print("Running the process_" + str(process_id) +
          " on the GPU_" + str(cuda_id) +
          ' with temperature ' + str(xi))

    num_labels = model.outputdim
    estimator = FullyBayesian((len(test_loader.dataset), num_labels),
                              model,
                              test_loader,
                              cuda_availability,
                              cuda_id)
    nIter = 0
    num_valid_sample = 1
    # Loss summed since the last exchange, and the number of iterations it covers.
    total_loss = 0.0
    window_iters = 0
    batch_loss_list = []

    sampler.resample_momenta(xi)

    for epoch in range(1, 1 + args.num_epochs):
        if process_id == 0:
            print("#################################################################")
            print("This is the epoch {}; Iterated: {}".format(epoch, nIter))

        for x, y in train_loader:
            nIter += 1
            batch_size = x.size(0)
            if args.permutation > 0.0:
                # Label noise: overwrite the leading fraction of the batch with random labels.
                num_noisy = int(args.permutation * batch_size)
                y = y.clone()
                y[:num_noisy] = torch.as_tensor(
                    np.random.choice(num_labels, num_noisy), dtype=torch.long)
            if cuda_availability:
                x, y = x.cuda(device=cuda_id), y.cuda(device=cuda_id)

            model.zero_grad()
            yhat = model(x)
            loss = F.cross_entropy(yhat, y)
            # L2 penalty on every parameter (the prior term).
            for param in model.parameters():
                loss += args.prior_precision * torch.sum(param ** 2)
            loss.backward()
            batch_loss = loss.item()
            batch_loss_list.append(batch_loss)
            # Accumulate a Python float instead of keeping a device tensor around.
            total_loss += batch_loss
            window_iters += 1

            # Update the parameters.
            sampler.update(xi)

            if nIter % args.evaluation_interval == 0:
                if process_id == 0:
                    print(
                        'process:{}; loss:{:6.4f}; thermostats_param:{:6.4f}'.format(
                            process_id,
                            batch_loss,
                            sampler.get_z_theta()))

                if not burnin:
                    num_valid_sample += 1
                    if process_id == 0:
                        # Evaluation (switches the model back to train mode afterwards).
                        acc = estimator.evaluation()
                        loss_std = statistics.stdev(batch_loss_list)
                        print('This is the accuracy: %{:6.2f} of process_{}; the standard devivation is {}'
                              .format(acc, process_id, loss_std))
                        model.train()
                        shared_dict['loss_std'] = loss_std

                    if num_valid_sample % args.exchange_per_evaluation == 0:
                        # Publish the mean loss ("energy") for the exchange. Divide by
                        # the iterations actually accumulated: num_valid_sample starts
                        # at 1, so the first window is one evaluation interval shorter
                        # than exchange_per_evaluation * evaluation_interval.
                        U = total_loss / window_iters
                        shared_dict[process_id] = U
                        total_loss = 0.0
                        window_iters = 0

                        # NOTE(review): ``+=`` on a Manager dict is a get followed by a
                        # set, not an atomic increment. Two replicas updating at the
                        # same time can lose an increment, and exchange_process would
                        # then never see num_process. Also, a replica that reaches
                        # wait() after another replica has already cleared the event
                        # blocks until the next exchange. Consider a Manager lock
                        # around the increment and a Barrier for the rendezvous.
                        shared_dict['ready_exchange'] += 1
                        print('process {} waiting for replica exchange ...'.format(process_id))
                        exchange_event.wait()
                        exchange_event.clear()
                        print('Continue training in process {}'.format(process_id))

                sampler.resample_momenta(xi)

            # Stop burn-in exactly at the iteration budget, not at the end of the epoch.
            if burnin and nIter >= args.num_burn_in:
                break
        if burnin and nIter >= args.num_burn_in:
            break

    # Save the trained model and sampler state.
    save_checkpoint(model.state_dict(), process_id, MODEL_OR_OPTIM='MODEL')
    save_checkpoint(sampler.state_dict(), process_id, MODEL_OR_OPTIM='OPTIM')
    if not burnin:
        # NOTE(review): same non-atomic increment caveat as 'ready_exchange' above.
        shared_dict['finished_training'] += 1
    print('Finished the process {}'.format(process_id))

def exchange_process(xi_ladder, shared_dict, exchange_event, poll_interval=0.01):
    """Coordinate replica exchanges between the ``num_process`` trainers.

    Polls ``shared_dict`` until every trainer has reported it is ready
    (``shared_dict['ready_exchange'] == num_process``). It then calls
    ``exchange_replica`` (star-imported, defined outside this file), resets the
    counter and sets ``exchange_event`` to release the waiting trainers.

    The loop ends once ``shared_dict['finished_training'] == num_process``.
    The recorded temperatures of the first replica are then saved with
    ``np.save`` under ``args.check_point_path``.

    Args:
        xi_ladder: list of temperatures, one per replica; updated by
            ``exchange_replica`` on every exchange.
        shared_dict: Manager dict shared with the trainer processes.
        exchange_event: event the trainers block on during an exchange.
        poll_interval: seconds to sleep between polls. Each poll is an IPC
            round-trip to the manager process; without a pause this loop spins
            a CPU core and floods the manager.
    """
    odd_exchange = False
    suc_exchange = False
    first_replica_xi = []  # record the temperature of the first replica

    while True:
        if shared_dict['ready_exchange'] == num_process:
            print('conducting repleica exchange here ...')

            print('Temperatures before exchange: {}'.format(xi_ladder))
            odd_exchange, suc_exchange, xi_ladder, first_replica_xi = \
                exchange_replica(shared_dict, odd_exchange, suc_exchange,
                                 num_process, first_replica_xi, xi_ladder)
            print('Temperatures after exchange: {}'.format(xi_ladder))

            shared_dict['ready_exchange'] = 0
            exchange_event.set()

        if shared_dict['finished_training'] == num_process:
            break
        time.sleep(poll_interval)

    f_name = str(args.renhd_model) + str(args.permutation) + "_permutation.npy"
    output_f_name = os.path.join(args.check_point_path, f_name)
    np.save(output_f_name, first_replica_xi)
    print("Saving temperature of the standard replica: " + str(f_name))

#############################################################################


if __name__ == '__main__':
    import copy

    # To use CUDA with multiprocessing, you must use 'spawn'.
    ctx = torch_mp.get_context("spawn")  # not compatible with mp.event()
    # Keep the manager in a variable so its server process clearly outlives
    # every proxy handed to the children.
    manager = torch_mp.Manager()
    shared_dict = manager.dict()
    exchange_event = ctx.Event()

    print("{} processs will be running simultaneously!".format(num_process))
    # Geometric temperature ladder: replica j runs at xi_base ** j.
    xi_ladder = [args.xi_base ** j for j in range(num_process)]

    # Honour --arch (validated against model_names; defaults to resnet20).
    model = resnet.__dict__[args.arch]()
    sampler = RENHD(model, N, args.eta_theta, args.c_theta, args.mu)

    # torch.multiprocessing moves the storage of CPU tensors passed to a child
    # into shared memory instead of copying it. Handing the same model/sampler
    # to every replica would make them alias one set of weights. On CPU, all
    # replicas would train the same parameters. On GPU, a later
    # load_state_dict below could overwrite a replica's weights before that
    # replica has moved them to its device. Each replica therefore gets its
    # own deep copy. The pair is copied together so the sampler stays bound to
    # its copy of the model's parameters.

    ################################################################################
    ## burnin
    ################################################################################
    print("#############################################"
          "############################################")
    # Number of replicas restored from disk. Nothing restores them yet, so
    # burn-in always runs.
    pretrained = 0

    if pretrained >= num_process:
        print("Successfully load the pretrained model and skipping burn-in")
    else:
        print("burn-in ...")
        processes = []
        for process_id in range(num_process):
            replica_model, replica_sampler = copy.deepcopy((model, sampler))
            p = ctx.Process(target=train, args=(replica_model, process_id,
                                                xi_ladder[process_id],
                                                shared_dict, replica_sampler,
                                                exchange_event, True))
            p.start()
            processes.append(p)

        for p in processes:
            # Wait until every burn-in replica has finished and saved its checkpoint.
            p.join()

    ################################################################################
    ## load checkpoints and start new processes
    ################################################################################
    processes = []
    print("#############################################"
          "############################################")
    print('training ...')
    shared_dict['ready_exchange'] = 0
    shared_dict['finished_training'] = 0
    # Per-neighbour-pair lists, created before any worker starts.
    for process_id in range(num_process - 1):
        shared_dict['deltaE_' + str(process_id) + str(process_id + 1)] = []

    for process_id in range(num_process):
        replica_model, replica_sampler = copy.deepcopy((model, sampler))
        # Resume this replica from its burn-in checkpoint when both files exist.
        m_checkpoint = load_checkpoint(process_id, MODEL_OR_OPTIM='MODEL')
        s_checkpoint = load_checkpoint(process_id, MODEL_OR_OPTIM='OPTIM')
        if m_checkpoint is not None and s_checkpoint is not None:
            replica_model.load_state_dict(m_checkpoint)
            replica_sampler.load_state_dict(s_checkpoint)
        else:
            print('Process {} starts from the initial model (checkpoint missing)'.format(process_id))
        p = ctx.Process(target=train, args=(replica_model, process_id,
                                            xi_ladder[process_id], shared_dict,
                                            replica_sampler, exchange_event))
        p.start()
        processes.append(p)

    # The coordinator that performs the replica exchanges.
    p = ctx.Process(target=exchange_process, args=(xi_ladder, shared_dict, exchange_event))
    p.start()
    processes.append(p)

    for p in processes:
        # Wait until every process has completed its work and exited.
        p.join()

    print("Finished Training")
