# -*- coding: utf-8 -*-
path = ""  # NOTE(review): not referenced anywhere else in this script

import numpy as np
import matplotlib.pyplot as plt
import torch
import csv
import numpy as np
import random
import torch
from PIL import Image
from torch.utils.data import DataLoader,TensorDataset
#from torchvision.transforms import Compose, RandomCrop, Pad, RandomHorizontalFlip, Resize, RandomAffine,RandomResizedCrop,CenterCrop
#from torchvision.transforms import ToTensor, Normalize,transforms
from PIL.Image import BICUBIC
import scipy.io as sio

# DataLoader settings shared by every data split built below.
batch_size, num_workers = 128, 4

import torch
import torch.nn as nn
import torch.nn.functional as F

def _load_split(filename):
    """Load one split from a MATLAB .mat file.

    Returns the arrays stored under the keys 'features', 'label' and 'g'
    (in that order) as returned by ``scipy.io.loadmat``.
    """
    mat = sio.loadmat(filename)
    print(f"load {filename} done")
    return mat['features'], mat['label'], mat['g']


test_features, test_label, test_g = _load_split('saved_test_feature.mat')
train_features, train_label, train_g = _load_split('saved_train_feature.mat')
# This split is stacked into the training data below and also backs the
# `val` evaluation loader. It used to be disabled inside a string literal,
# which made the script fail with NameError.
validation_features, validation_label, validation_g = _load_split(
    '120_preds-on_validation.mat')
validation_test_features, validation_test_label, validation_test_g = _load_split(
    'saved_validation_feature.mat')

# Train on the train split plus the first validation split (rows stacked).
train_features = np.vstack((train_features, validation_features))
train_label = np.vstack((train_label, validation_label))
train_g = np.vstack((train_g, validation_g))

def _make_dataset(features, label, g):
    """Wrap feature/label/group arrays as float32 tensors in a TensorDataset."""
    return TensorDataset(torch.Tensor(features), torch.Tensor(label), torch.Tensor(g))


def _make_loader(dataset, shuffle):
    """Batch ``dataset`` with the script-wide DataLoader settings."""
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                      shuffle=shuffle, drop_last=False, pin_memory=True)


train_my_dataset = _make_dataset(train_features, train_label, train_g)
test_my_dataset = _make_dataset(test_features, test_label, test_g)
val_my_dataset = _make_dataset(validation_features, validation_label, validation_g)
val_test_my_dataset = _make_dataset(
    validation_test_features, validation_test_label, validation_test_g)

# Only the training loader needs shuffling. Evaluation metrics do not depend
# on sample order, and a shuffling loader without its own generator seeds each
# pass from the global torch RNG, which would tie training randomness to how
# often evaluation runs.
train_my_dataloader = _make_loader(train_my_dataset, shuffle=True)
test_my_dataloader = _make_loader(test_my_dataset, shuffle=False)
val_my_dataloader = _make_loader(val_my_dataset, shuffle=False)
val_test_my_dataloader = _make_loader(val_test_my_dataset, shuffle=False)

num_classes = 2
num_epochs = 150

# Input width of the classifier: number of feature columns.
D_in = test_features.shape[1]

# A single linear layer wrapped in nn.Sequential.
model = torch.nn.Sequential(torch.nn.Linear(D_in, num_classes))

class ba_xent_loss(nn.Module):
  """Inverse-prior weighted cross-entropy over (label, group) cells.

  Each sample falls in cell ``label*2 + group``. Its cross-entropy is divided
  by that cell's prior ``pi_list[label*2 + group]``, averaged over the batch,
  and blended with plain cross-entropy:
  ``la * weighted + (1 - la) * CE(logits, label)``.

  Args:
    pi_list: flat priors of the four cells, indexed ``label*2 + group``.
      The default is the values this script was written with (presumably the
      training-set cell frequencies; confirm).
  """
  def __init__(self, pi_list=(0.73, 0.038, 0.012, 0.22)):
    super().__init__()
    self.pi_list = list(pi_list)

  def forward(self, logits, label, group, la):
    """Return the blended loss.

    ``label`` and ``group`` must be int64 tensors of 0/1 values, one per row
    of ``logits``; ``la`` is the weight of the prior-weighted term.
    """
    # Build the priors on the logits' device. This was hard-coded to CUDA,
    # which crashed on CPU-only runs.
    pi = torch.tensor(self.pi_list, dtype=torch.float32, device=logits.device)
    # Prior of each sample's own (label, group) cell.
    pi_yg = pi[label * 2 + group]

    loss = F.cross_entropy(logits, label, reduction='none')
    weighted_loss = (1 / pi_yg) * loss
    return la * torch.mean(weighted_loss) + (1 - la) * F.cross_entropy(logits, label)

class DEO_loss(nn.Module):
  """Group-gap penalty on inverse-prior weighted losses, blended with CE.

  Each sample falls in cell ``label*2 + group``. Its cross-entropy is divided
  by the cell prior, then summed (not averaged) per cell. The penalty is
  ``|S[0] - S[1]| + |S[2] - S[3]|``, i.e. the gap between group 0 and group 1
  within each label. The result is
  ``la * penalty + (1 - la) * CE(logits, label)``.

  Args:
    pi_list: flat priors of the four cells, indexed ``label*2 + group``.
      The default is the values this script was written with.
  """
  def __init__(self, pi_list=(0.73, 0.038, 0.012, 0.22)):
    super().__init__()
    self.pi_list = list(pi_list)

  def forward(self, logits, label, group, la):
    """Return the blended loss.

    ``label`` and ``group`` must be int64 tensors of 0/1 values, one per row
    of ``logits``; ``la`` is the weight of the gap penalty.
    """
    # Priors live on the logits' device so the loss works on CPU and GPU.
    pi = torch.tensor(self.pi_list, dtype=torch.float32, device=logits.device)
    class_val = label * 2 + group
    one_hot = F.one_hot(class_val, num_classes=4).type(torch.float32)
    pi_yg = pi[class_val]

    loss = F.cross_entropy(logits, label, reduction='none')
    weighted_loss = (1 / pi_yg) * loss
    # Per-cell sums, shape (4,). unsqueeze replaces the deprecated .T on a
    # 1-D tensor.
    cell_sum = (one_hot * weighted_loss.unsqueeze(1)).sum(dim=0)
    loss_DEO = (cell_sum[0] - cell_sum[1]).abs() + (cell_sum[2] - cell_sum[3]).abs()

    return la * loss_DEO + (1 - la) * F.cross_entropy(logits, label)

class LA_loss(nn.Module):
  """Logit-adjusted, group-reweighted cross-entropy blended with plain CE.

  Priors are laid out as a 2x2 table ``pi[label, group]`` (flat index
  ``label*2 + group``). For sample i with label y_i and group g_i:

  * each class logit y gets ``log((pi[y, g_i] / pi[y_i, g_i]) ** tau + 1e-12)``
    added to it;
  * the cross-entropy of the adjusted logits is divided by the group prior
    ``sum_y pi[y, g_i]``.

  The batch mean of that term is blended as
  ``la * adjusted + (1 - la) * CE(logits, label)``.

  Args:
    pi_list: flat priors of the four cells, indexed ``label*2 + group``.
    tau: exponent applied to the prior ratios before the log.
  """
  def __init__(self, pi_list=(0.73, 0.038, 0.012, 0.22), tau=1):
    super().__init__()
    self.pi_list = list(pi_list)
    self.tau = tau

  def forward(self, logits, label, group, la):
    """Return the blended loss.

    ``label`` and ``group`` must be int64 tensors of 0/1 values, one per row
    of ``logits`` (shape (B, 2)); ``la`` weights the adjusted term.
    """
    # Build the priors once, on the logits' device (CPU or GPU).
    pi = torch.tensor(self.pi_list, dtype=torch.float32, device=logits.device)
    pi_table = pi.reshape(2, 2)                  # [label, group]
    pi_g = pi_table.sum(dim=0)[group]            # group prior of each sample
    pi_yg = pi[label * 2 + group]                # prior of each sample's own cell
    pi_yyg = pi_table[:, group].T                # (B, 2): [i, y] = pi[y, g_i]
    base_probs = pi_yyg / pi_yg.unsqueeze(1)
    logits_ad = logits + torch.log(base_probs ** self.tau + 1e-12)

    loss = F.cross_entropy(logits_ad, label, reduction='none')
    weighted_loss = (1 / pi_g) * loss
    loss_google = torch.mean(weighted_loss)
    return la * loss_google + (1 - la) * F.cross_entropy(logits, label)

def eval_per_class(data_loader, model, text,flag=0):
    """Evaluate ``model``; report overall and per-(label, group) accuracy.

    Each sample falls in cell ``label*2 + group``, printed as '00', '01',
    '10', '11'. Accuracy is computed per cell. The unweighted mean over
    non-empty cells is reported as the balanced accuracy.

    Args:
        data_loader: iterable of ``(data, label, group)`` batches; only column
            0 of ``label`` and ``group`` is used.
        model: module mapping ``data`` to logits; predictions are the argmax
            over dim 1.
        text: prefix for the printed summary lines.
        flag: 0 prints each cell's accuracy; any other value suppresses that.

    Returns:
        ``(overall_acc, balanced_acc, per_cell_acc)`` in percent.
        ``per_cell_acc`` lists only the non-empty cells, in cell order.

    Raises:
        ValueError: if a label or group is not 0/1, or the loader is empty.
    """
    classes = ('00', '01', '10', '11')
    cell_correct = [0.] * len(classes)
    cell_total = [0.] * len(classes)
    correct = 0.
    total = 0.

    # Remember the caller's mode so evaluating mid-training does not leave the
    # model stuck in eval mode.
    was_training = model.training
    model.eval()
    try:
        # Scoring needs no gradients; skip building an autograd graph.
        with torch.no_grad():
            for data, label, group in data_loader:
                label = label.long()[:, 0]
                group = group[:, 0]
                if (((label != 0) & (label != 1)) | ((group != 0) & (group != 1))).any():
                    raise ValueError(f'{text}: label and group values must be 0 or 1')
                group = group.long()
                preds = model(data).max(1)[1]
                # No squeeze(): a 1-sample batch must stay 1-D to be indexed.
                hit = preds == label
                cell = label * 2 + group
                for k in range(len(classes)):
                    in_cell = cell == k
                    cell_total[k] += in_cell.sum().item()
                    cell_correct[k] += hit[in_cell].sum().item()
                correct += hit.sum().item()
                total += data.size(0)
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError(f'{text}: data_loader yielded no samples')

    accuracy_4 = []
    for k in range(len(classes)):
        if cell_total[k] == 0:
            print('No image')
            continue
        cell_acc = 100 * cell_correct[k] / cell_total[k]
        if flag == 0:
            print('Accuracy of %5s : %2f %%' % (classes[k], cell_acc))
        accuracy_4.append(cell_acc)

    print(f'{text}:ACC = {correct/total*100.}')
    print(f'{text}:balance ACC = {np.mean(accuracy_4)}')

    return correct/total*100., np.mean(accuracy_4), accuracy_4

# One entry per value in `la_values`: the final val(test) / test accuracy (%).
val_acc_all=[]
test_acc_all=[]

# Weight of the balancing term passed to the criterion; add values to sweep.
# (The body below was previously indented under nothing: an IndentationError.)
la_values = [1]
# Run the per-split evaluation every `eval_every` epochs (1 = every epoch).
eval_every = 1

for la in la_values:
  seed = 9
  torch.manual_seed(seed)
  model = torch.nn.Linear(D_in, num_classes)
  print(model.weight)
  print(model.bias)
  #model.to('cuda')

  lr = 0.005
  criterion = LA_loss()
  #criterion = DEO_loss()
  #criterion = ba_xent_loss()

  optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, weight_decay=1e-4)
  # Halve the learning rate at epochs 60, 100 and 120.
  train_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 100, 120], gamma=0.5)
  losses = []
  acc_train_all = []
  acc_val_all = []
  acc_val_test_all = []
  acc_test_all = []
  train_balanced_acc_all = []
  val_balanced_acc_all = []
  val_test_balanced_acc_all = []
  test_balanced_acc_all = []
  loss_all = 0
  print(la)
  for epoch in range(num_epochs):
    # Evaluation switches the model to eval mode; restore train mode each epoch.
    model.train()
    num_batches = 0
    loss_all = 0
    for data, label, group in train_my_dataloader:
      label = label.long()[:, 0]
      group = group.long()[:, 0]
      #data, label, group = data.cuda(non_blocking=True), label.cuda(non_blocking=True), group.cuda(non_blocking=True)
      logits = model(data)
      loss = criterion(logits, label, group, la)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      loss_all += loss.item()
      num_batches += 1
    if (epoch + 1) % eval_every == 0:
      mean_loss = loss_all / num_batches
      print(num_batches)
      print("Epoch [{}/{}], Loss: {:.4f}".format(epoch + 1, num_epochs, mean_loss))
      losses.append(mean_loss)
      print('train')
      train_acc, train_balanced_acc, _ = eval_per_class(train_my_dataloader, model, ' train_dataset', flag=1)
      print('val')
      val_acc, val_balanced_acc, _ = eval_per_class(val_my_dataloader, model, ' val_dataset', flag=1)
      print('val(test)')
      val_test_acc, val_test_balanced_acc, _ = eval_per_class(val_test_my_dataloader, model, ' val_test_dataset', flag=1)
      print('test')
      test_acc, test_balanced_acc, _ = eval_per_class(test_my_dataloader, model, ' test_dataset', flag=1)
      acc_train_all.append(train_acc)
      acc_val_all.append(val_acc)
      acc_val_test_all.append(val_test_acc)
      acc_test_all.append(test_acc)
      train_balanced_acc_all.append(train_balanced_acc)
      val_balanced_acc_all.append(val_balanced_acc)
      val_test_balanced_acc_all.append(val_test_balanced_acc)
      test_balanced_acc_all.append(test_balanced_acc)
    train_lr_scheduler.step()

  plt.figure()
  plt.plot(losses, label='losses')
  plt.show()
  plt.figure()
  # Each curve gets the label of the list it actually plots.
  plt.plot(acc_train_all, label='acc_train_all')
  plt.plot(acc_val_all, label='acc_val_all')
  plt.plot(acc_val_test_all, label='acc_val_test_all')
  plt.plot(acc_test_all, label='acc_test_all')
  plt.plot(train_balanced_acc_all, label='train_balanced_acc_all')
  plt.plot(val_balanced_acc_all, label='val_balanced_acc_all')
  plt.plot(val_test_balanced_acc_all, label='val_test_balanced_acc_all')
  plt.plot(test_balanced_acc_all, label='test_balanced_acc_all')
  plt.legend()
  plt.show()

  # Final report. Distinct names for the per-cell lists: the old unpacking
  # `val_acc, _, val_acc = ...` overwrote the accuracy with the list.
  val_test_acc, val_test_balanced_acc, val_test_acc_4 = eval_per_class(val_test_my_dataloader, model, ' val_test_dataset')
  test_acc, test_balanced_acc, test_acc_4 = eval_per_class(test_my_dataloader, model, ' test_dataset')
  val_acc_all.append(val_test_acc)
  test_acc_all.append(test_acc)