import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

# Network names advertised by this module.
# NOTE(review): get_network() below also accepts 'SNN', which is not listed
# here -- confirm whether that omission is intentional.
NETS = ['ResNet', 'BNN', 'IM']

def get_network(args):
    '''Instantiate the model selected by ``args.net``.

    Args:
        args: Object exposing ``net`` plus the options the chosen model
            needs: ``droprate`` for ``'ResNet'``; ``layer`` and
            ``stoch_varianz`` for ``'SNN'``. ``'BNN'`` and ``'IM'`` need
            no extra options.

    Returns:
        A freshly constructed ``nn.Module``.

    Raises:
        ValueError: If ``args.net`` is not one of the supported names.
    '''
    name = args.net

    # 'BNN' and 'IM' share the same stochastic architecture.
    if name in ('BNN', 'IM'):
        return StochNet()
    if name == 'ResNet':
        # Wide-ResNet 28x10 with 10 output classes.
        return Wide_ResNet(28, 10, args.droprate, 10)
    if name == 'SNN':
        return SNN(args.layer, args.stoch_varianz)

    raise ValueError('[+] Model: not supported')


class SNN(nn.Module):
    '''Fully connected 784-128-128-10 network with additive Gaussian noise.

    Zero-mean Gaussian noise is added at exactly one position of the
    forward pass, chosen by ``layer``:

    * ``0`` -- to the input,
    * ``1`` -- to the first hidden activation (after ReLU),
    * ``2`` -- to the second hidden activation (after ReLU).

    Any other value of ``layer`` disables the noise.

    Args:
        layer: Position at which the noise is injected (see above).
        var: Variance of the injected noise.
    '''

    def __init__(self, layer, var):
        super(SNN, self).__init__()
        self.layer = layer
        # Despite its name, ``self.var`` stores the standard deviation
        # (sqrt of the requested variance).  It is a non-persistent buffer so
        # that it follows ``.to()`` / ``.cuda()`` with the rest of the module
        # without adding a key to ``state_dict`` (old checkpoints still load).
        self.register_buffer('var', torch.tensor(math.sqrt(var)), persistent=False)
        self.fc_xh1 = nn.Linear(784, 128)
        self.fc_h1h2 = nn.Linear(128, 128)
        self.fc_h2y = nn.Linear(128, 10)

    def _perturb(self, t):
        '''Return ``t`` plus Gaussian noise drawn on the device/dtype of ``t``.'''
        return t + torch.randn_like(t) * self.var

    def forward(self, x):
        '''Return a single stochastic prediction (logits) for ``x``.

        Args:
            x: Input whose last dimension is 784.

        Returns:
            Tensor with last dimension 10.
        '''
        if self.layer == 0:
            x = self._perturb(x)
        h1 = F.relu(self.fc_xh1(x))
        if self.layer == 1:
            h1 = self._perturb(h1)
        h2 = F.relu(self.fc_h1h2(h1))
        if self.layer == 2:
            h2 = self._perturb(h2)
        y = self.fc_h2y(h2)
        return y


class VMGLinear(nn.Module):
    '''Linear layer for a Bayesian neural net with a matrix-variate Gaussian
    posterior over its weights.

    The weight matrix (with the bias folded in as one extra input row) is
    distributed as ``MN(mu, diag(exp(logvar_in)), diag(exp(logvar_out)))``.
    A separate weight matrix is sampled for every example in the batch.

    Args:
        in_dim: Number of input features (the bias row is added internally).
        out_dim: Number of output features.
        prior_var: Variance ``V`` of the prior ``MN(0, I, V * I)``.
    '''

    def __init__(self, in_dim, out_dim, prior_var):
        super(VMGLinear, self).__init__()
        in_dim += 1  # one extra row that multiplies the constant-1 bias input
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.prior_var = prior_var
        self.mu = nn.Parameter(nn.init.kaiming_normal_(torch.zeros(in_dim, out_dim)))
        # Log-variances start tightly around -9, i.e. a near-deterministic layer.
        self.logvar_in = nn.Parameter(nn.init.normal_(torch.zeros(in_dim), -9, 1e-6))
        self.logvar_out = nn.Parameter(nn.init.normal_(torch.zeros(out_dim), -9, 1e-6))

    def forward(self, x, pr=False):
        '''Apply the layer with freshly sampled per-example weights.

        Args:
            x: Input of shape ``(batch, in_dim)`` (without the bias column).
            pr: If true, print one entry of the sampled weights (debug aid).

        Returns:
            Tuple ``(h, D_KL)``: the output of shape ``(batch, out_dim)`` and
            the scalar KL divergence of the posterior from the prior.
        '''
        m = x.size(0)
        W, D_KL = self.sample(m)
        # Append the constant bias input on the same device/dtype as ``x``.
        x = torch.cat([x, x.new_ones(m, 1)], 1)
        h = torch.bmm(x.unsqueeze(1), W).squeeze(1)
        if pr:
            print(W.data[0, 0, 0])
        return h, D_KL

    def sample(self, m):
        '''Draw ``m`` weight matrices and compute the KL term.

        Args:
            m: Number of weight samples (one per batch element).

        Returns:
            Tuple ``(W, D_KL)`` where ``W`` has shape
            ``(m, in_dim + 1, out_dim)`` and ``D_KL`` is a 0-dim tensor.
        '''
        r, c = self.in_dim, self.out_dim
        M = self.mu
        logvar_r = self.logvar_in
        logvar_c = self.logvar_out
        var_r = torch.exp(logvar_r)
        var_c = torch.exp(logvar_c)
        # Noise lives wherever the parameters live (CPU, CUDA, ...).
        E = torch.randn(m, r, c, device=M.device, dtype=M.dtype)
        # Reparametrization trick
        W = M + torch.sqrt(var_r).view(1, r, 1) * E * torch.sqrt(var_c).view(1, 1, c)
        # KL divergence to prior MVN(0, I, V * I).  Because the prior column
        # covariance is isotropic, tr(M^T M V^-1) reduces to sum(M^2) / V and
        # r * log|V * I| to r * c * log(V); no diagonal matrix is needed.
        V = self.prior_var
        D_KL = 0.5 * (
            var_r.sum() * var_c.sum() / V
            + M.pow(2).sum() / V
            - r * c
            + r * c * math.log(V)
            - c * logvar_r.sum()
            - r * logvar_c.sum()
        )

        return W, D_KL


class StochNet(nn.Module):
    '''Bayesian 784-128-128-10 network built from three VMGLinear layers,
    each constructed with prior variance 1.'''

    def __init__(self):
        super(StochNet, self).__init__()
        prior_variance = 1
        self.fc_xh1 = VMGLinear(784, 128, prior_variance)
        self.fc_h1h2 = VMGLinear(128, 128, prior_variance)
        self.fc_h2y = VMGLinear(128, 10, prior_variance)

    def forward(self, x, returnKL=False):
        '''Return a single stochastic prediction for ``x``.

        Args:
            x: Input batch fed to the first layer.
            returnKL: When true, also return the summed KL term of the
                three layers.

        Returns:
            The output ``y`` of the last layer, or the tuple
            ``(y, total_KL)`` when ``returnKL`` is true.
        '''
        hidden, kl_first = self.fc_xh1(x)
        hidden, kl_second = self.fc_h1h2(F.relu(hidden))
        y, kl_third = self.fc_h2y(F.relu(hidden))

        if not returnKL:
            return y
        return y, kl_first + kl_second + kl_third


def conv3x3(in_planes, out_planes, stride=1):
    '''Build a 3x3 ``nn.Conv2d`` with padding 1 and a bias term.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the convolution.

    Returns:
        The configured ``nn.Conv2d`` module.

    Code from https://github.com/meliketoy/wide-resnet.pytorch
    MIT License

    Copyright (c) 2018 Bumsoo Kim

    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to deal
    in the Software without restriction, including without limitation the rights
    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all
    copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.
    '''
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=True)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


class wide_basic(nn.Module):
    '''Wide-ResNet residual block: BN -> ReLU -> conv -> dropout -> BN ->
    ReLU -> conv, added to a shortcut of the block input.

    The shortcut is the identity unless the stride or the channel count
    changes, in which case it is a 1x1 convolution.

    Code from https://github.com/meliketoy/wide-resnet.pytorch
    MIT License

    Copyright (c) 2018 Bumsoo Kim

    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to deal
    in the Software without restriction, including without limitation the rights
    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all
    copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.
    '''
    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super(wide_basic, self).__init__()
        # Submodules are created in a fixed order so parameter ordering and
        # random initialisation stay stable.
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)

        # A projection is only needed when the output shape differs from the input.
        if in_planes != planes or stride != 1:
            projection = [nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True)]
        else:
            projection = []
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        '''Return the residual branch applied to ``x`` plus the shortcut of ``x``.'''
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.dropout(out)
        out = F.relu(self.bn2(out))
        out = self.conv2(out)
        out += self.shortcut(x)
        return out


class Wide_ResNet(nn.Module):
    '''Wide residual network: a 3x3 stem convolution, three stages of
    residual blocks, BN + ReLU, average pooling and a linear classifier.

    Args:
        depth: Total depth; must be of the form ``6n + 4``.
        widen_factor: Multiplier applied to the stage widths 16/32/64.
        dropout_rate: Dropout probability passed to every block.
        num_classes: Size of the classifier output.

    Code from https://github.com/meliketoy/wide-resnet.pytorch
    MIT License

    Copyright (c) 2018 Bumsoo Kim

    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to deal
    in the Software without restriction, including without limitation the rights
    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all
    copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.
    '''
    def __init__(self, depth, widen_factor, dropout_rate, num_classes):
        super(Wide_ResNet, self).__init__()
        # Running channel count; updated by _wide_layer as blocks are added.
        self.in_planes = 16
        assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
        blocks_per_stage = (depth - 4) / 6
        print('| Wide-Resnet %dx%d' % (depth, widen_factor))

        # Channel widths: stem, then the three residual stages.
        widths = [16] + [base * widen_factor for base in (16, 32, 64)]

        self.conv1 = conv3x3(3, widths[0])
        self.layer1 = self._wide_layer(wide_basic, widths[1], blocks_per_stage, dropout_rate, stride=1)
        self.layer2 = self._wide_layer(wide_basic, widths[2], blocks_per_stage, dropout_rate, stride=2)
        self.layer3 = self._wide_layer(wide_basic, widths[3], blocks_per_stage, dropout_rate, stride=2)
        # NOTE(review): PyTorch's BatchNorm momentum weights the *new* batch
        # statistic, so 0.9 tracks the latest batch closely -- confirm intended.
        self.bn1 = nn.BatchNorm2d(widths[3], momentum=0.9)
        self.linear = nn.Linear(widths[3], num_classes)

    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        '''Stack blocks into one stage; only the first block uses ``stride``.

        Updates ``self.in_planes`` to ``planes`` as a side effect.
        '''
        per_block_strides = [stride] + [1] * (int(num_blocks) - 1)
        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_planes, planes, dropout_rate, block_stride))
            self.in_planes = planes
        return nn.Sequential(*stage)

    def forward(self, x):
        '''Return class scores for the image batch ``x``.'''
        out = x
        for module in (self.conv1, self.layer1, self.layer2, self.layer3):
            out = module(out)
        out = F.relu(self.bn1(out))
        # Fixed 8x8 pooling window; presumably sized for 32x32 inputs -- TODO confirm.
        out = F.avg_pool2d(out, 8)
        flat = out.view(out.size(0), -1)
        return self.linear(flat)