#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 22 17:35:37 2020

@author: zw
"""

from torch.nn import Parameter
import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
import torch.nn.functional as F
from torchvision import models

import numpy as np
import math
import matplotlib.pyplot as plt  

import cv2

from resnet_v2 import *
from resnet_model import BasicBlock

class Encoding(nn.Module):
    """ResNet-50 encoder that returns the outputs of all four residual stages.

    Args:
        trinum: number of extra ``BasicBlock(2048, 2048)`` blocks applied in
            sequence to the deepest (layer4) output.
    """

    def __init__(self, trinum=1):
        super(Encoding, self).__init__()
        self.trinum = trinum
        self.resnet = resnet50()  # project-local resnet_v2 implementation

        # nn.ModuleList instead of a plain list: with a plain list nn.Module
        # does not register the blocks, so their weights were absent from
        # parameters()/state_dict() (never trained, never saved) and were not
        # moved by .to()/.cuda().
        # NOTE: state_dicts saved before this change lack the ``resb.*`` keys.
        self.resb = nn.ModuleList(BasicBlock(2048, 2048) for _ in range(trinum))

    def forward(self, x):
        """Encode ``x`` and return ``[x1, x2, x3, x4]`` (layer1..layer4 outputs).

        ``x4`` is additionally passed through every block in ``self.resb``.
        """
        r = self.resnet
        x = r.relu1(r.bn1(r.conv1(x)))
        if r.deep_base:
            # Deep-stem variant: two extra conv-BN-ReLU stages.
            x = r.relu2(r.bn2(r.conv2(x)))
            x = r.relu3(r.bn3(r.conv3(x)))
        x = r.maxpool(x)

        # Original notes (352x352 input): spatial size / channels per stage.
        x1 = r.layer1(x)   # [88] 256
        x2 = r.layer2(x1)  # [44] 512
        x3 = r.layer3(x2)  # [22] 1024
        x4 = r.layer4(x3)  # [11] 2048

        for block in self.resb:
            x4 = block(x4)

        return [x1, x2, x3, x4]

class UpConvBlock(nn.Module):
    """Upsample by 2, fuse with a skip connection, then conv-BN-ReLU.

    Args:
        inp: channels of the concatenation, i.e. channels of ``x`` plus
            channels of ``x_skip`` in :meth:`forward`.
        out: number of output channels.
    """

    def __init__(self, inp, out):
        super(UpConvBlock, self).__init__()
        # Attribute names double as state_dict keys; keep them stable.
        self.Up = nn.Conv2d(inp, out, 3, padding=1)
        self.Up_bn = nn.BatchNorm2d(out)
        self.Up_relu = nn.ReLU(inplace=True)
        self.upscore = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x, x_skip):
        """Return conv-BN-ReLU of ``cat(upsample2x(x), x_skip)`` along channels."""
        upsampled = self.upscore(x)
        fused = torch.cat([upsampled, x_skip], dim=1)
        conv = self.Up(fused)
        return self.Up_relu(self.Up_bn(conv))

class Decoding(nn.Module):
    """Top-down decoder: repeatedly upsample and fuse with shallower skips.

    Args:
        ups: number of UpConvBlock stages (at most 3, the number of adjacent
            pairs in the fixed ``channels`` table).
    """

    def __init__(self, ups=3):
        super(Decoding, self).__init__()
        self.ups = ups

        # Assumes encoder stage channels 256/512/1024/2048 (deepest first here).
        channels = [2048, 1024, 512, 256]
        # nn.ModuleList (not a plain list) so the blocks are registered:
        # included in parameters()/state_dict() and moved by .to()/.cuda().
        self.layer = nn.ModuleList(
            UpConvBlock(channels[i] + channels[i + 1], channels[i + 1])
            for i in range(ups)
        )

    def forward(self, Enc_list):
        """Decode an encoder feature list ordered shallow -> deep.

        Returns:
            ``ups + 1`` tensors: the deepest encoder feature followed by the
            output of each decoding stage (coarse -> fine).
        """
        base_in = Enc_list[-1]
        out = [base_in]
        # Skip tensors from second-deepest to shallowest.
        skips = list(reversed(Enc_list))[1:]
        for i, block in enumerate(self.layer):
            base_in = block(base_in, skips[i])
            out.append(base_in)
        return out

class OutConvBlock(nn.Module):
    """Map a feature tensor to an upsampled single-channel sigmoid map.

    Args:
        inp: input channel count.
        upscale: bilinear upsampling factor applied after the 3x3 conv.
    """

    def __init__(self, inp, upscale):
        super(OutConvBlock, self).__init__()
        self.outconv = nn.Conv2d(inp, 1, 3, padding=1)
        self.upscore = nn.Upsample(scale_factor=upscale, mode='bilinear', align_corners=True)

    def forward(self, x):
        """Return ``sigmoid(upsample(conv(x)))``, values in (0, 1)."""
        logits = self.upscore(self.outconv(x))
        return logits.sigmoid()

class Out(nn.Module):
    """Prediction heads, one OutConvBlock per decoder output.

    Head ``i`` takes ``inp_channel[i]`` channels and upsamples by ``scale[i]``.

    Args:
        outnum: number of heads (at most 4).
    """

    def __init__(self, outnum):
        super(Out, self).__init__()
        self.outnum = outnum
        scale = [32, 16, 8, 4]
        inp_channel = [2048, 1024, 512, 256]
        # nn.ModuleList (not a plain list) so the heads are registered:
        # included in parameters()/state_dict() and moved by .to()/.cuda().
        self.layer = nn.ModuleList(
            OutConvBlock(inp_channel[i], scale[i]) for i in range(outnum)
        )

    def forward(self, out_list):
        """Apply head ``i`` to ``out_list[i]`` and return results in reverse order."""
        out = [head(out_list[i]) for i, head in enumerate(self.layer)]
        return out[::-1]

class Priming(nn.Module):
    """Re-weight a feature map using a frozen ResNet-50 embedding of the masked image.

    Args:
        diffusion: stored on the instance; not used by this class.
    """

    def __init__(self, diffusion=False):
        super(Priming, self).__init__()
        self.diffusion = diffusion

        # NOTE(review): ``pretrained=True`` is deprecated in newer torchvision
        # in favour of ``weights=``.
        self.resnet_f = models.resnet50(pretrained=True)
        # Freeze the backbone (at this point it owns every parameter here).
        # NOTE(review): BatchNorm running stats in resnet_f still update
        # whenever the module is in train mode.
        for p in self.parameters():
            p.requires_grad = False

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, base, mask, img):
        """Scale ``base`` channel-wise by a mask/image priming factor.

        Args:
            base: feature map (B, C, H, W). C must equal the pooled channel
                count of ``resnet_f`` (2048 for ResNet-50) for the matmuls.
            mask: map broadcastable against ``img`` and ``base`` after
                resizing; assumed (B, 1, h, w) — TODO confirm.
            img: input image batch (B, 3, H_img, W_img).

        Returns:
            Tensor with the same shape as ``base``.
        """
        # Priming branch: embed the masked image at 224x224.
        mask_img = F.interpolate(mask, size=(224, 224), mode='bilinear', align_corners=True)
        priming_x = F.interpolate(img, size=(224, 224), mode='bilinear', align_corners=True)
        priming_x = priming_x * mask_img

        r = self.resnet_f
        priming_x = r.maxpool(r.relu(r.bn1(r.conv1(priming_x))))
        priming_x = r.layer4(r.layer3(r.layer2(r.layer1(priming_x))))
        priming = r.avgpool(priming_x)  # (B, C_p, 1, 1)

        # Main branch: mask-weighted global average of base. The mask is
        # resized to base's own spatial size (previously hard-coded to 11x11,
        # which only matched a 352x352 input).
        mask_f = F.interpolate(mask, size=tuple(base.shape[-2:]), mode='bilinear', align_corners=True)
        mp = self.avgpool(mask_f * base)
        mp = mp.view(mp.size(0), mp.size(1), 1)                      # (B, C, 1)
        priming = priming.view(priming.size(0), 1, priming.size(1))  # (B, 1, C_p)

        # (mp @ priming) @ mp == mp @ (priming @ mp) up to float rounding; the
        # right-hand form avoids materialising a (B, C, C_p) matrix.
        mp_priming = torch.matmul(mp, torch.matmul(priming, mp))     # (B, C, 1)
        mp_priming = mp_priming.view(mp_priming.size(0), mp_priming.size(1), 1, 1)
        return base * mp_priming

class BaseModel(nn.Module):
    """Encoder with an initial decoder pass and two priming/refinement passes.

    Each pass has its own decoder and prediction heads. Passes 1 and 2 replace
    the deepest encoder feature with a Priming-reweighted version driven by
    the previous pass's first prediction.

    Args:
        ups: number of upsampling stages per decoder; each head set has
            ``ups + 1`` outputs.
    """

    def __init__(self, ups=3):
        super(BaseModel, self).__init__()
        self.encoding = Encoding()
        self.decoding = Decoding(ups=ups)
        self.out = Out(outnum=ups+1)

        self.priming1 = Priming()
        self.decoding1 = Decoding(ups=ups)
        self.out1 = Out(outnum=ups+1)

        self.priming2 = Priming()
        self.decoding2 = Decoding(ups=ups)
        self.out2 = Out(outnum=ups+1)

    def forward(self, x):
        """Run all three passes on ``x``.

        Encoder stage shapes for a 352x352 input (from the original notes):
            '1': [256 , 88, 88]
            '2': [512 , 44, 44]
            '3': [1024, 22, 22]
            '4': [2048, 11, 11]

        Returns:
            List concatenation ``Ox2 + Ox1 + Ox`` of the three passes' head
            outputs (pass 2 first), ``3 * (ups + 1)`` tensors in total.
        """
        Ex = self.encoding(x)       # [x1, x2, x3, x4]
        Dx = self.decoding(Ex)
        Ox = self.out(Dx)

        # Pass 1: previously reused self.out, leaving self.out1 unused.
        Px1 = self.priming1(Ex[-1], Ox[0], x)
        Dx1 = self.decoding1(Ex[:-1] + [Px1])
        Ox1 = self.out1(Dx1)

        # Pass 2: previously reused self.out, leaving self.out2 unused.
        Px2 = self.priming2(Ex[-1], Ox1[0], x)
        Dx2 = self.decoding2(Ex[:-1] + [Px2])
        Ox2 = self.out2(Dx2)

        return Ox2 + Ox1 + Ox

if __name__ == "__main__":
    # Smoke test: one forward pass on random data. Guarded so that importing
    # this module no longer builds the model (and downloads pretrained
    # weights) or prints as a side effect.
    inp = torch.rand((2, 3, 352, 352))

    net = BaseModel()
    out = net(inp)

    # Three passes x (ups + 1) heads; 12 with the default ups=3.
    print(len(out))


