import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models

from basic_model import BasicModel
from actfuns import actfun_name2factory


class identity(nn.Module):
    """Pass-through module: ``forward`` returns its argument untouched.

    It is assigned in place of a network's final layer so that the layer's
    input features are exposed directly as the network output.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself (same object, no copy)."""
        return x

class SRAN(BasicModel):
    """Scores 8 candidate panels against an 8-panel context.

    The model is built from three ResNet-18 encoders and three MLP "gates":

    * ``conv1`` encodes single panels (1 input channel),
    * ``conv2`` encodes whole rows (3 panels stacked as channels),
    * ``conv3`` encodes pairs of rows (6 panels stacked as channels),
    * ``h1`` fuses the three panel embeddings of a row,
    * ``h2`` fuses the embeddings of two rows,
    * ``h3`` produces the final rule embedding plus meta-target logits.

    The input to :meth:`forward` is indexed as ``B x 16 x H x W`` with
    ``H == W == args.img_size``: panels 0-7 are the context, 8-15 the
    candidates.
    """

    def __init__(self, args):
        """Build encoders, gates and the optimizer.

        Attributes read from ``args``: ``dataset``, ``img_size``, ``actfun``,
        ``feature_width``, ``lr``, ``beta1``, ``beta2``, ``epsilon``,
        ``weight_decay`` and ``meta_beta``.

        Raises:
            ValueError: if ``args.dataset`` neither contains ``"RAVEN"`` nor
                equals ``"PGM"``.
        """
        super(SRAN, self).__init__(args)

        # Creation order is kept fixed so that parameter registration order
        # (and random-init order) stays stable.
        self.conv1 = self._make_encoder(1)
        self.conv2 = self._make_encoder(3)
        self.conv3 = self._make_encoder(6)

        self.dataset = args.dataset

        # Number of meta-target logits emitted next to the rule embedding.
        if "RAVEN" in self.dataset:
            self.meta_target_length = 9
        elif self.dataset == 'PGM':
            self.meta_target_length = 12
        else:
            raise ValueError("Unsupported dataset: {}".format(self.dataset))

        self.img_size = args.img_size

        actfun_factory = actfun_name2factory(args.actfun)
        # A throw-away instance is probed for optional sizing attributes.
        _actfun_example = actfun_factory()
        # Hidden widths are rounded to a multiple of `divisor`.
        divisor = getattr(_actfun_example, "divisor", 1)
        # The layer after an activation takes `width * feature_factor` inputs
        # (presumably the activation rescales the feature dimension by this
        # factor -- confirm against the actfuns module).
        feature_factor = getattr(_actfun_example, "feature_factor", 1)

        actfun_args = {}
        if hasattr(_actfun_example, "feature_factor"):
            actfun_args["dim"] = -1

        # h1: three concatenated panel embeddings (3 x 512) -> row embedding.
        n_out1_1 = int(int(round(3 * 512 * args.feature_width / divisor)) * divisor)
        n_out1_2 = int(round(512 / feature_factor))
        gate_function_1 = [
            nn.Linear(3 * 512, n_out1_1, bias=False),
            nn.LayerNorm(n_out1_1),
            actfun_factory(**actfun_args),
            nn.Linear(int(round(n_out1_1 * feature_factor)), n_out1_2, bias=False),
            nn.LayerNorm(n_out1_2),
            actfun_factory(**actfun_args),
        ]

        # h2: two concatenated 1024-d row relations (4 x 512) -> 512.
        n_out2_1 = int(int(round(4 * 512 * args.feature_width / divisor)) * divisor)
        n_out2_2 = int(int(round(1024 * args.feature_width / divisor)) * divisor)
        gate_function_2 = [
            nn.LayerNorm(4 * 512),
            nn.Linear(4 * 512, n_out2_1, bias=False),
            nn.LayerNorm(n_out2_1),
            actfun_factory(**actfun_args),
            nn.Linear(int(round(n_out2_1 * feature_factor)), n_out2_2, bias=False),
            nn.LayerNorm(n_out2_2),
            actfun_factory(**actfun_args),
            nn.Linear(int(round(n_out2_2 * feature_factor)), 512, bias=False),
        ]

        # h3: 1024-d fused rule features -> 512-d rule + meta-target logits.
        n_out3_1 = int(int(round(1024 * args.feature_width / divisor)) * divisor)
        n_out3_2 = int(int(round(512 * args.feature_width / feature_factor / divisor)) * divisor)
        n_out3_3 = int(int(round(512 * args.feature_width / feature_factor / divisor)) * divisor)
        gate_function_3 = [
            nn.LayerNorm(1024),
            nn.Linear(1024, n_out3_1, bias=False),
            nn.LayerNorm(n_out3_1),
            actfun_factory(**actfun_args),
            nn.Linear(int(round(n_out3_1 * feature_factor)), n_out3_2, bias=False),
            nn.LayerNorm(n_out3_2),
            actfun_factory(**actfun_args),
            nn.Linear(int(round(n_out3_2 * feature_factor)), n_out3_3, bias=False),
            nn.LayerNorm(n_out3_3),
            actfun_factory(**actfun_args),
            nn.LayerNorm(int(round(n_out3_3 * feature_factor))),
            nn.Dropout(0.5),
            nn.Linear(int(round(n_out3_3 * feature_factor)), 512 + self.meta_target_length, bias=False),
            nn.LayerNorm(512 + self.meta_target_length),
        ]

        self.h1 = nn.Sequential(*gate_function_1)
        self.h2 = nn.Sequential(*gate_function_2)
        self.h3 = nn.Sequential(*gate_function_3)

        self.optimizer = optim.AdamW(
            self.parameters(),
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            eps=args.epsilon,
            weight_decay=args.weight_decay,
        )
        # Weight of the meta-target term in compute_loss.
        self.meta_beta = args.meta_beta
        # Reorders the 8 context panels so that get_row_rules, applied to the
        # reordered input, groups panels {0,3,6}, {1,4,7}, {2,5} together.
        self.row_to_column_idx = [0,3,6,1,4,7,2,5]

    @staticmethod
    def _make_encoder(in_channels):
        """Return a randomly initialised ResNet-18 emitting 512-d features.

        The stem convolution is replaced to accept ``in_channels`` channels
        and the classification head is replaced by a pass-through module.
        """
        # Called without arguments: random init under both the legacy
        # `pretrained=` and the newer `weights=` torchvision APIs.
        encoder = models.resnet18()
        encoder.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        encoder.fc = identity()
        return encoder

    def compute_loss(self, output, target, meta_target):
        """Return cross-entropy on the answer plus weighted BCE on meta-targets.

        Args:
            output: sequence whose first item holds the answer logits and
                whose second item holds the meta-target logits.
            target: class indices for ``F.cross_entropy``.
            meta_target: targets for the BCE-with-logits term; must have the
                same shape as the meta-target logits.

        Returns:
            ``target_loss + self.meta_beta * meta_target_loss``.
        """
        pred, meta_target_pred = output[0], output[1]
        target_loss = F.cross_entropy(pred, target)
        meta_target_loss = F.binary_cross_entropy_with_logits(meta_target_pred, meta_target)
        return target_loss + self.meta_beta * meta_target_loss

    def get_pair(self, row1, row2):
        """Stack both orderings of two rows.

        ``row1`` and ``row2`` are concatenated along dim 1 as ``(row1, row2)``
        and ``(row2, row1)``; the two results are stacked on a new dim 1.
        """
        x = torch.cat((row1, row2), dim = 1)
        y = torch.cat((row2, row1), dim = 1)
        z = torch.stack((x, y), dim = 1)
        return z

    def get_row_rules(self, x, panel_features):
        """Compute rule embeddings and meta-target logits for 17 row pairs.

        Pair 0 is (row 1, row 2). Pairs 1-8 combine row 1 with the third row
        completed by each of the 8 candidates; pairs 9-16 do the same with
        row 2.

        Args:
            x: panel images, indexed as ``B x 16 x H x W``.
            panel_features: per-panel embeddings, ``B x 16 x 512``.

        Returns:
            Tuple ``(rules, meta_logits)`` of shapes ``B x 17 x 512`` and
            ``B x 17 x meta_target_length``.
        """
        row1_img = x[:,0:3,:,:]
        row2_img = x[:,3:6,:,:]

        row3 = x[:,6:8,:,:].unsqueeze(1)
        #B×1×2×H×W
        row3_candidates = torch.cat((row3.repeat(1,8,1,1,1), x[:,8:16,:,:].unsqueeze(2)), dim = 2)
        #B×8×3×H×W
        row_all = torch.cat((row1_img.unsqueeze(1), row2_img.unsqueeze(1), row3_candidates), dim = 1)
        #B×10×3×H×W
        intra_row_relations = self.conv2(row_all.view(-1,3,self.img_size,self.img_size))
        #(10B)×512

        # The candidate-completed third rows are reused for the row-pair
        # images instead of being rebuilt.
        conv_row_list = [self.get_pair(row1_img, row2_img)]
        conv_row_list += [self.get_pair(row1_img, row3_candidates[:,i,:,:,:]) for i in range(8)]
        conv_row_list += [self.get_pair(row2_img, row3_candidates[:,i,:,:,:]) for i in range(8)]
        conv_rows = torch.stack(conv_row_list, dim = 1)
        #B×17×2×6×H×W
        inter_row_relations = self.conv3(conv_rows.view(-1,6,self.img_size,self.img_size))
        #(34B)×512
        # Summing over both orderings makes the pair feature order-invariant.
        inter_row_relations = torch.sum(inter_row_relations.view(-1,17,2,512), dim = 2)
        #B×17×512

        row3_12features = panel_features[:,6:8,:].unsqueeze(1).repeat(1,8,1,1)
        #B×8×2×512
        candidate_features = panel_features[:,8:16,:].unsqueeze(2)
        #B×8×1×512
        row3_features = torch.cat((row3_12features,candidate_features), dim = 2)
        #B×8×3×512
        row_features = [panel_features[:,0:3,:].unsqueeze(1), panel_features[:,3:6,:].unsqueeze(1), row3_features]
        row_features = torch.cat(row_features, dim = 1)
        #B×10×3×512
        row_relations = self.h1(row_features.view(-1,1536))
        #(10B)×512
        row_relations = torch.cat((row_relations, intra_row_relations), dim = 1)
        #(10B)×1024
        row_relations = row_relations.view(-1,10,1024)
        #B×10×1024
        first_row = row_relations[:,0,:]
        second_row = row_relations[:,1,:]
        row_list = [self.get_pair(first_row, second_row)]
        row_list += [self.get_pair(first_row, row_relations[:,i,:]) for i in range(2,10)]
        row_list += [self.get_pair(second_row, row_relations[:,i,:]) for i in range(2,10)]
        row_relations = torch.stack(row_list, dim = 1)
        #B×17×2×2048
        rules = self.h2(row_relations.view(-1,2048))
        #(34B)×512
        rules = torch.sum(rules.view(-1,17,2,512), dim = 2)
        #B×17×512
        rules = torch.cat((rules, inter_row_relations), dim = 2)
        #B×17×1024
        rules = self.h3(rules)
        #B×17×(512+L)

        return rules[:,:,:512], rules[:,:,512:]

    def row_to_column(self, x, panel_features):
        """Reorder the 8 context panels with ``self.row_to_column_idx``.

        The same permutation is applied to the images and to their features;
        panels from index 8 onward are appended unchanged.

        Returns:
            Tuple ``(image, features)`` with the reordered tensors.
        """
        context_image = x[:,self.row_to_column_idx]
        image = torch.cat((context_image, x[:,8:,:,:]), dim = 1)
        context_features = panel_features[:,self.row_to_column_idx]
        features = torch.cat((context_features, panel_features[:,8:,:]), dim = 1)
        return image, features

    def forward(self, x):
        """Score the 8 candidates and predict the meta-target.

        Args:
            x: panel images, indexed as ``B x 16 x H x W`` with
                ``H == W == self.img_size``.

        Returns:
            Tuple ``(similarity, meta_target_pred)`` of shapes ``B x 8`` and
            ``B x meta_target_length``.

        Raises:
            ValueError: if ``self.dataset`` is not a supported dataset.
        """
        B = x.size(0)
        # reshape (not view) so non-contiguous input batches are accepted.
        panel_features = self.conv1(x.reshape(-1,1,self.img_size,self.img_size))
        #(16B)×512
        panel_features = panel_features.view(-1,16,512)
        #B×16×512

        row_output = self.get_row_rules(x, panel_features)
        row_rules, meta_target_row_pred = row_output[0], row_output[1]
        #B×17×512
        #B×17×L

        if "RAVEN" in self.dataset:
            # Zero placeholders for the column branch. new_zeros inherits
            # device and dtype from the row tensors, so this works on CPU
            # and on any GPU.
            column_rules = row_rules.new_zeros(B,17,512)
            meta_target_column_pred = meta_target_row_pred.new_zeros(B,17,self.meta_target_length)
        elif self.dataset == 'PGM':
            x_c, panel_features_c = self.row_to_column(x, panel_features)
            column_output = self.get_row_rules(x_c, panel_features_c)
            column_rules, meta_target_column_pred = column_output[0], column_output[1]
            #B×17×512
            #B×17×L
        else:
            raise ValueError("Unsupported dataset: {}".format(self.dataset))

        rules = torch.cat((row_rules, column_rules), dim = 2)
        #B×17×1024
        # Only pair 0 (the two complete rows) contributes meta-target logits.
        meta_target_pred = meta_target_row_pred[:,0,:] + meta_target_column_pred[:,0,:]
        #B×L

        # Rule from the two complete rows vs. the 16 candidate-dependent rules.
        dominant_rule = rules[:,0,:].unsqueeze(1)
        pseudo_rules = rules[:,1:,:]
        similarity = torch.bmm(dominant_rule, torch.transpose(pseudo_rules, 1, 2)).squeeze(1)
        #B×16
        # Entries i and i+8 belong to the same candidate (paired with row 1
        # and with row 2 respectively).
        similarity = similarity[:,:8] + similarity[:,8:]
        #B×8

        return similarity, meta_target_pred
