import os
import time
import pprint
import torch
import numpy as np
import torch.nn as nn
from learn2learn.vision.models.resnet12 import BasicBlock
import learn2learn as l2l


def hinge(a, b):
    """Return the element-wise positive part of ``a - b``, i.e. ``max(a - b, 0)``.

    ``a - b`` must produce an object that supports boolean-mask assignment
    (e.g. a ``torch.Tensor`` or ``numpy.ndarray``); plain Python scalars are
    not supported. A new object is returned; the inputs are not modified.
    """
    margin = a - b
    below_zero = margin < 0
    # Masked in-place assignment keeps autograd working: masked entries get
    # zero gradient, the rest pass the gradient of ``a - b`` through.
    margin[below_zero] = 0
    return margin

def pertubated_hinge(a, b, eps):
    """Return ``max(a - b + eps, 0)`` element-wise (a hinge shifted by ``eps``).

    ``a - b + eps`` must produce an object supporting boolean-mask assignment
    (e.g. a ``torch.Tensor`` or ``numpy.ndarray``). The inputs are not modified.
    """
    shifted = (a - b) + eps
    negative = shifted < 0
    shifted[negative] = 0
    return shifted
    

def square_hinge(l, gamma):
    """Return the squared hinge ``max(l - gamma, 0) ** 2`` element-wise.

    Args:
        l: losses (tensor/array; see ``hinge`` for the supported types).
        gamma: threshold subtracted from ``l``.
    """
    # Compute the hinge once: calling it twice doubled the work and built two
    # identical autograd subgraphs.
    h = hinge(l, gamma)
    return h * h
        
def btk_eps_surrogate(k, n_tasks, losses, gamma, eps):
    """Compute an eps-perturbed surrogate built from ``min(losses, gamma - eps)``.

    Returns ``(1/k) * sum_i [ min(l_i, gamma - eps) - (n_tasks - k)/n_tasks * gamma ]``.
    When ``losses`` holds exactly ``n_tasks`` entries, the offset terms sum to
    ``(n_tasks - k) * gamma``. For ``eps = 0``, the maximum of this expression
    over ``gamma`` equals the mean of the ``k`` smallest losses. The name
    "btk" presumably stands for "bottom-k"; verify against the caller.

    Args:
        k: number of tasks kept (divisor of the result).
        n_tasks: total number of tasks; presumably ``len(losses)``, TODO confirm.
        losses: per-task losses (torch tensor).
        gamma: threshold variable of the surrogate.
        eps: perturbation added inside the hinge.
    """
    # losses - max(losses - gamma + eps, 0) == min(losses, gamma - eps)
    capped = losses - pertubated_hinge(losses, gamma, eps)
    per_task_offset = (1/n_tasks) * (n_tasks - k) * gamma
    total = torch.sum(capped - per_task_offset)
    return (1/k) * total
        
def log(log_file_path, string):
    '''
    Append one line to a log file and echo it to stdout.

        log_file_path: Path of the log file; created if it does not exist.
        string:        Text to log; a newline is appended in the file.
    '''
    # Explicit UTF-8 so non-ASCII messages do not raise UnicodeEncodeError on
    # platforms whose default locale encoding is not UTF-8. Append-only mode is
    # enough because the file is never read here, and leaving the `with` block
    # closes (and therefore flushes) the file.
    with open(log_file_path, 'a', encoding='utf-8') as f:
        f.write(string + '\n')
    print(string)

# ResNet-12 backbone (feature extractor only), kept here for compatibility
# with older versions of the learn2learn library.
class ResNet12Backbone(nn.Module):

    """
    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/models/resnet12.py)

    **Description**

    The 12-layer residual network from Mishra et al, 2017, without a
    classification head: `forward` returns the flattened (optionally
    average-pooled) feature embedding, after dropout.

    The code is adapted from [Lee et al, 2019](https://github.com/kjunelee/MetaOptNet/)
    who share it under the Apache 2 license.

    List of changes:

    * Rename ResNet to ResNet12.
    * Small API modifications.
    * Fix code style to be compatible with PEP8.
    * Support multiple devices in DropBlock

    **Notes**

    * When `avg_pool=False`, pooling is replaced by `nn.Identity()`, not by a
      lambda wrapper. A lambda would make the model unpicklable (e.g. with
      `torch.save(model, path)`).

    **References**

    1. Mishra et al. 2017. “A Simple Neural Attentive Meta-Learner.” ICLR 18.
    2. Lee et al. 2019. “Meta-Learning with Differentiable Convex Optimization.” CVPR 19.
    3. Lee et al's code: [https://github.com/kjunelee/MetaOptNet/](https://github.com/kjunelee/MetaOptNet/)
    4. Oreshkin et al. 2018. “TADAM: Task Dependent Adaptive Metric for Improved Few-Shot Learning.” NeurIPS 18.

    **Arguments**

    * **output_size** (int) - Dimensionality of the output of the full
        ResNet12. Stored as `self.output_size` but not used by this backbone,
        which has no classifier layer.
    * **hidden_size** (int, *optional*, default=640) - Embedding size used by
        the classifier layer of the full ResNet12. Accepted for API
        compatibility; not used by this backbone.
    * **keep_prob** (float, *optional*, default=1.0) - Keep probability of the
        dropout applied to the embedding (the dropout rate is `1 - keep_prob`).
    * **avg_pool** (bool, *optional*, default=True) - Set to False for the 16k-dim embeddings of Lee et al, 2019.
    * **drop_rate** (float, *optional*, default=0.1) - Dropout rate for the residual layers.
    * **dropblock_size** (int, *optional*, default=5) - Size of drop blocks.

    **Example**
    ~~~python
    model = ResNet12Backbone(output_size=ways, avg_pool=False)
    embeddings = model(images)
    ~~~
    """

    def __init__(
        self,
        output_size,
        hidden_size=640,  # unused here; kept for API compatibility
        keep_prob=1.0,  # keep probability of the embedding dropout
        avg_pool=True,  # Set to False for 16000-dim embeddings
        drop_rate=0.1,  # dropout for residual layers
        dropblock_size=5,
    ):
        super().__init__()
        self.inplanes = 3
        self.output_size = output_size
        block = BasicBlock

        # Four residual stages (64 -> 160 -> 320 -> 640 channels), each built
        # with stride=2; DropBlock is enabled only in the last two.
        self.layer1 = self._make_layer(
            block,
            64,
            stride=2,
            drop_rate=drop_rate,
        )
        self.layer2 = self._make_layer(
            block,
            160,
            stride=2,
            drop_rate=drop_rate,
        )
        self.layer3 = self._make_layer(
            block,
            320,
            stride=2,
            drop_rate=drop_rate,
            drop_block=True,
            block_size=dropblock_size,
        )
        self.layer4 = self._make_layer(
            block,
            640,
            stride=2,
            drop_rate=drop_rate,
            drop_block=True,
            block_size=dropblock_size,
        )
        if avg_pool:
            # Kernel 5 presumably matches the 5x5 map produced from 84x84
            # (mini-ImageNet) inputs, giving a 640-dim embedding. TODO confirm
            # for other input resolutions.
            self.avgpool = nn.AvgPool2d(5, stride=1)
        else:
            # Picklable no-op (a lambda wrapper would break torch.save(model)).
            self.avgpool = nn.Identity()
        self.keep_prob = keep_prob
        self.keep_avg_pool = avg_pool
        self.dropout = nn.Dropout(p=1.0 - self.keep_prob, inplace=False)
        self.drop_rate = drop_rate

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight,
                    mode='fan_out',
                    nonlinearity='leaky_relu',
                )
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # The same module instances are chained here; no new parameters are
        # created.
        self.features = torch.nn.Sequential(
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            self.avgpool,
            l2l.nn.Flatten(),
            self.dropout,
        )

    def _make_layer(
        self,
        block,
        planes,
        stride=1,
        drop_rate=0.0,
        drop_block=False,
        block_size=1,
    ):
        """Build one residual stage made of a single `block`.

        A 1x1 conv + BatchNorm shortcut is added when `stride != 1` or the
        channel count changes. The shortcut conv uses stride 1; presumably
        `block` performs the spatial downsampling itself (learn2learn's
        BasicBlock pools after the residual add) — TODO confirm.
        Updates `self.inplanes` to this stage's output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(
            self.inplanes,
            planes,
            stride,
            downsample,
            drop_rate,
            drop_block,
            block_size)
        )
        self.inplanes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the flattened embedding of the image batch `x`."""
        return self.features(x)
    
def flip_label(y, ratio, pattern='sym', one_hot=False, n_class=5):
    """Inject synthetic label noise by randomly flipping labels.

    Args:
        y: sequence of integer class labels, or one-hot rows if ``one_hot``.
        ratio: flip probability per label.
        pattern: ``'sym'`` flips to any other class uniformly (each with
            probability ``ratio / (n_class - 1)``); ``'asym'`` flips label
            ``c`` to ``(c + 1) % n_class``. Any other value leaves labels
            unchanged.
        one_hot: if True, ``y`` is decoded with ``argmax`` and the result is
            re-encoded as one-hot rows.
        n_class: number of classes.

    Returns:
        A tuple of (possibly flipped) labels, or of one-hot rows.

    Uses the global ``np.random`` state; seed it for reproducibility.
    """
    labels = np.argmax(np.array(y), axis=1) if one_hot else np.array(y)

    if pattern == 'sym':
        for idx, label in enumerate(labels):
            probs = ratio / (n_class - 1) * np.ones(n_class)
            probs[label] = 1 - ratio
            labels[idx] = np.random.choice(n_class, p=probs)
    elif pattern == 'asym':
        for idx, label in enumerate(labels):
            successor = (label + 1) % n_class
            labels[idx] = np.random.choice([label, successor], p=[1 - ratio, ratio])

    if one_hot:
        labels = np.eye(n_class)[labels]
    return tuple(labels)