import os
import cv2
import time
import numpy as np
from termcolor import cprint
from time import gmtime, strftime

import torch

def resume_model(resume, net, map_location=None):
    """Load model weights from a checkpoint file into ``net``.

    The checkpoint is expected to be a dict with the keys
    ``'model_state_dict'``, ``'log'`` and ``'epoch'`` (all three are read).

    Args:
        resume: path to the checkpoint file.
        net: module whose ``load_state_dict`` receives
            ``checkpoint['model_state_dict']``.
        map_location: forwarded to ``torch.load`` (e.g. ``'cpu'`` to load a
            GPU-saved checkpoint on a CPU-only machine). The default ``None``
            keeps ``torch.load``'s default placement.

    Returns:
        The value stored under ``checkpoint['epoch']``.

    Raises:
        ValueError: if ``resume`` is not an existing file.
    """
    if not os.path.isfile(resume):
        print("=> no checkpoint found at '{}'".format(resume))
        # The previous message was just 'End'; name the missing path instead.
        raise ValueError("no checkpoint found at '{}'".format(resume))

    c_print("=> loading checkpoint '{}'".format(resume), color='blue', attrs=['bold'])
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(resume, map_location=map_location)
    net.load_state_dict(checkpoint['model_state_dict'])
    c_print("Resume Epoch log: {}".format(checkpoint['log']), color='blue', attrs=['bold'])
    c_print("=> loaded checkpoint '{}' (epoch {})".format(resume, checkpoint['epoch']),
            color='blue', attrs=['bold'])
    return checkpoint['epoch']
    
def time2str():
    """Return the current UTC time as text, e.g. ``'Mon, 01 Jan 2024 12:00:00'``."""
    pattern = "%a, %d %b %Y %H:%M:%S"
    utc_now = gmtime()
    formatted = strftime(pattern, utc_now)
    return formatted

def c_print(text, color=None, on_color=None, attrs=None):
    """Print ``text`` with termcolor styling, or plainly when ``cprint`` is None.

    Args:
        text: message to print.
        color, on_color, attrs: forwarded unchanged to ``termcolor.cprint``.

    NOTE(review): ``cprint`` is imported unconditionally at module level, so
    the plain ``print`` fallback only runs if that name has been rebound to
    None; a missing termcolor package fails at import time instead.
    """
    if cprint is None:
        print(text)
        return
    cprint(text, color=color, on_color=on_color, attrs=attrs)

class Timer(object):
    """A simple accumulating stopwatch.

    ``tic()`` starts an interval; ``toc()`` ends it, adds the elapsed seconds
    to a running total and returns either the mean over all intervals or the
    latest interval.
    """

    def __init__(self):
        self.total_time = 0.    # sum of all measured intervals, in seconds
        self.calls = 0          # number of completed toc() calls
        self.start_time = 0.    # timestamp taken by the most recent tic()
        self.diff = 0.          # length of the most recent interval
        self.average_time = 0.  # total_time / calls

    def tic(self):
        """Start timing a new interval."""
        # perf_counter is monotonic and high-resolution, so intervals are not
        # corrupted by system-clock adjustments the way time.time() can be.
        self.start_time = time.perf_counter()

    def toc(self, average=True):
        """Stop the current interval and record it.

        Args:
            average: if True return the running mean, otherwise the interval
                just measured.

        Returns:
            Elapsed time in seconds (mean or last interval).
        """
        self.diff = time.perf_counter() - self.start_time
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls
        if average:
            return self.average_time
        return self.diff

def set_trainable(model, requires_grad):
    """Set ``requires_grad`` on every parameter of ``model`` (freeze/unfreeze)."""
    flag = requires_grad
    params = model.parameters()
    for weight in params:
        weight.requires_grad = flag


def weights_normal_init(model, dev=0.01):
    """Re-initialise Conv2d and Linear weights in place from N(0, dev**2).

    Args:
        model: a module, or a list of modules (each handled recursively).
            Every submodule reached via ``modules()`` is inspected.
        dev: standard deviation of the normal distribution.

    Biases and all other module types are left untouched.
    """
    if isinstance(model, list):
        for m in model:
            weights_normal_init(m, dev)
        return

    # The file only does `import torch`; the previous bare `nn` was a NameError.
    for m in model.modules():
        if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
            with torch.no_grad():
                m.weight.normal_(0.0, dev)

def clip_gradient(model, clip_norm):
    """Scale gradients in place so their global L2 norm is at most ``clip_norm``.

    The norm is taken over the gradients of all trainable parameters.
    Parameters whose ``grad`` is None (e.g. unused in the forward pass) are
    skipped. When the norm is already within ``clip_norm`` the gradients are
    left unchanged.

    Args:
        model: module whose parameters' ``.grad`` tensors are rescaled.
        clip_norm: maximum allowed global gradient norm.
    """
    grads = [p.grad for p in model.parameters()
             if p.requires_grad and p.grad is not None]
    # Accumulate in Python floats: np.sqrt cannot consume CUDA tensors.
    totalnorm = sum(g.detach().norm().item() ** 2 for g in grads) ** 0.5
    if totalnorm <= clip_norm:
        # Scale would be exactly 1.0; also avoids 0/0 when both are zero.
        return
    scale = clip_norm / totalnorm
    for g in grads:
        g.mul_(scale)

class EarlyStopping(object):
    """Decide when to stop training because a monitored metric stopped improving.

    Call :meth:`step` once per epoch with the metric value; it returns True
    when training should stop.

    Args:
        mode: ``'min'`` if lower values are better, ``'max'`` if higher are.
        min_delta: margin by which a new value must beat ``best`` to count as
            an improvement.
        patience: number of consecutive non-improving steps tolerated before
            stopping. ``0`` disables early stopping (``step`` always returns
            False).

    Raises:
        ValueError: if ``mode`` is neither ``'min'`` nor ``'max'``.
    """

    def __init__(self, mode='min', min_delta=0, patience=10):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None          # best metric seen so far (None before first step)
        self.num_bad_epochs = 0   # consecutive steps without improvement
        self.is_better = None     # comparison callable, set by _init_is_better
        self._init_is_better(mode, min_delta)

        if patience == 0:
            # Every value counts as an improvement, so bad epochs never accrue.
            self.is_better = lambda a, b: True

    def step(self, metrics):
        """Record one metric value and return True if training should stop.

        Stops immediately on NaN. Otherwise stops once ``num_bad_epochs``
        reaches ``patience``.
        """
        if self.patience == 0:
            # Disabled. Without this guard, the check `0 >= 0` below stopped
            # training on the second call.
            return False

        if np.isnan(metrics):
            # Checked before seeding `best`, so a NaN first value is reported
            # instead of becoming a baseline nothing can ever beat.
            return True

        if self.best is None:
            self.best = metrics
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta):
        """Install ``self.is_better(a, best)`` for the given mode and margin."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode {} is unknown!'.format(mode))
        if mode == 'min':
            self.is_better = lambda a, best: a < best - min_delta
        if mode == 'max':
            self.is_better = lambda a, best: a > best + min_delta

def get_part_half_box(n, m, Mask_1, Mask_2, obj_box):
    """Return two part boxes centred on the peaks of ``Mask_1`` and ``Mask_2``.

    Each mask's argmax is decoded into (row, col) assuming 28 columns
    (presumably 28x28 masks — TODO confirm). The peak is mapped into image
    coordinates: the grid is scaled by ``min(n, m) / 28`` and shifted by a
    constant 8 pixels plus half the difference between the two sides along
    the longer one. The box half-extents are a quarter of ``obj_box``'s
    extents along each axis.

    Args:
        n: image extent used to clip y coordinates (presumably height).
        m: image extent used to clip x coordinates (presumably width).
        Mask_1, Mask_2: score maps whose flat argmax locates each part.
        obj_box: sequence indexed as (x1, y1, x2, y2) — assumed; only the
            differences [2]-[0] and [3]-[1] are used.

    Returns:
        List of two ``(x1, y1, x2, y2)`` int tuples, clipped to [0, m] x [0, n].
    """
    # The mask grid maps onto a square of side min(n, m); the longer axis is
    # offset by half the difference.
    if m > n:
        side, dx, dy = n, (m - n) / 2, 0
    else:
        side, dx, dy = m, 0, (n - m) / 2
    width = (obj_box[2] - obj_box[0]) / 4
    length = (obj_box[3] - obj_box[1]) / 4

    boxes = []
    for mask in (Mask_1, Mask_2):
        row, col = divmod(int(np.argmax(mask)), 28)
        # Evaluation order kept identical to the original float arithmetic.
        x = col * side / 28 + 8 + dx
        y = row * side / 28 + 8 + dy
        # Builtin int() replaces np.int, which was removed in NumPy 1.24.
        box = (max(0, int(x - width)), max(0, int(y - length)),
               min(m, int(x + width)), min(n, int(y + length)))
        boxes.append(tuple(int(v) for v in box))
    return boxes


def get_box(n, m, mask):
    """Return a square box around the salient region of ``mask``.

    The mask is bilinearly resized to 448x448. Columns and rows whose maximum
    exceeds 10% of the global maximum define the salient extent. A square of
    side ``max(width, height)`` centred on that extent is mapped back into
    image coordinates. It appears to assume the mask covers a centred square
    of side ``min(n, m)`` (TODO confirm).

    Args:
        n: image extent used to clip y coordinates (presumably height).
        m: image extent used to clip x coordinates (presumably width).
        mask: 2-D score map (assumed single-channel — TODO confirm).

    Returns:
        ``(x1, y1, x2, y2)`` int tuple clipped to [0, m] x [0, n].

    Raises:
        ValueError: if no mask value exceeds the threshold (e.g. an all-zero
            mask). Previously this surfaced as an UnboundLocalError.
    """
    size = 448
    mask = cv2.resize(mask, dsize=(size, size), interpolation=cv2.INTER_LINEAR)
    t = np.max(mask) * 0.1
    cols = np.flatnonzero(np.max(mask, axis=0) > t)
    rows = np.flatnonzero(np.max(mask, axis=1) > t)
    if cols.size == 0 or rows.size == 0:
        raise ValueError('mask has no value above 10% of its maximum; cannot locate a box')
    left, right = cols[0], cols[-1]
    up, down = rows[0], rows[-1]

    x = (left + right) / 2
    y = (up + down) / 2
    l = max(right - left, down - up) / 2

    # Map from the 448 grid back to image coordinates; the longer axis gets
    # an offset of half the side difference.
    if m > n:
        x = x * n / size + (m - n) / 2
        y = y * n / size
        l = l * n / size
    else:
        x = x * m / size
        y = y * m / size + (n - m) / 2
        l = l * m / size

    # Builtin int() replaces np.int, which was removed in NumPy 1.24.
    box = (max(0, int(x - l)), max(0, int(y - l)),
           min(m, int(x + l)), min(n, int(y + l)))
    return tuple(int(v) for v in box)

def get_part(n, m, Mask_1, Mask_2):
    """Return two fixed-size part boxes centred on the peaks of the masks.

    Each mask's argmax is decoded into (row, col) assuming 28 columns
    (presumably 28x28 masks — TODO confirm). The peak is mapped into image
    coordinates: the grid is scaled by ``min(n, m) / 28`` and shifted by a
    constant 8 pixels plus half the difference between the two sides along
    the longer one. The box half-side is ``64 * min(n, m) / 448``.

    Args:
        n: image extent used to clip y coordinates (presumably height).
        m: image extent used to clip x coordinates (presumably width).
        Mask_1, Mask_2: score maps whose flat argmax locates each part.

    Returns:
        List of two ``(x1, y1, x2, y2)`` int tuples, clipped to [0, m] x [0, n].
    """
    # The mask grid maps onto a square of side min(n, m); the longer axis is
    # offset by half the difference.
    if m > n:
        side, dx, dy = n, (m - n) / 2, 0
    else:
        side, dx, dy = m, 0, (n - m) / 2
    l = 64 * side / 448

    boxes = []
    for mask in (Mask_1, Mask_2):
        row, col = divmod(int(np.argmax(mask)), 28)
        # Evaluation order kept identical to the original float arithmetic.
        x = col * side / 28 + 8 + dx
        y = row * side / 28 + 8 + dy
        # Builtin int() replaces np.int, which was removed in NumPy 1.24.
        box = (max(0, int(x - l)), max(0, int(y - l)),
               min(m, int(x + l)), min(n, int(y + l)))
        boxes.append(tuple(int(v) for v in box))
    return boxes
