import os
import arff
import random
import numpy as np
from time import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

class Linear_Cls(nn.Module):
	"""Single linear layer whose outputs are passed through a sigmoid."""

	def __init__(self, input_dim, num_label):
		"""Create the ``input_dim -> num_label`` linear layer ``fc1``."""
		super().__init__()
		self.fc1 = nn.Linear(in_features=input_dim, out_features=num_label)

	def forward(self, x):
		"""Return the element-wise sigmoid of ``fc1(x)``."""
		logits = self.fc1(x)
		return logits.sigmoid()

class MLP_Cls(nn.Module):
	"""Two-layer perceptron: linear -> ReLU -> linear -> sigmoid."""

	def __init__(self, input_dim, dim, num_label):
		"""Create ``fc1`` (input_dim -> dim), ``fc2`` (dim -> num_label) and the ReLU."""
		super().__init__()
		self.fc1 = nn.Linear(in_features=input_dim, out_features=dim)
		self.fc2 = nn.Linear(in_features=dim, out_features=num_label)
		# In-place ReLU: overwrites the output of fc1 instead of allocating a new tensor.
		self.relu = nn.ReLU(inplace=True)

	def forward(self, x):
		"""Return the element-wise sigmoid of ``fc2(relu(fc1(x)))``."""
		hidden = self.relu(self.fc1(x))
		return self.fc2(hidden).sigmoid()


def train_onestep_OGD(model, X, arm, y, loss_type='square', lr=0.005, batch_size=64, device='cpu'):
	"""Run one shuffled pass of Adam updates over the data and return the model.

	Only the entries whose ``arm`` value equals 1 contribute to the loss: the
	model output and the target at those positions are gathered (row-major)
	and compared with the chosen loss.

	Args:
		model: module called as ``model(x)``; its output is indexed with the
			``arm`` mask, so it must have the same shape as the ``arm``/``y``
			mini-batch. With the BCE loss the outputs must lie in [0, 1].
		X: non-empty list of feature tensors, concatenated along dim 0.
		arm: list of tensors concatenated along dim 0 and cast to int;
			entries equal to 1 mark the positions used in the loss.
		y: list of target tensors, concatenated along dim 0 and cast to float.
		loss_type: ``'square'`` selects ``nn.MSELoss``; any other value
			selects ``nn.BCELoss``.
		lr: learning rate of the freshly created Adam optimizer.
		batch_size: mini-batch size of the shuffled ``DataLoader``.
		device: device the mini-batches are moved to; the model is not moved
			here, so it is expected to live on that device already.

	Returns:
		The same ``model`` object, updated in place.
	"""
	model.train()
	features = torch.cat(X).float()
	selected = torch.cat(arm).int()
	targets = torch.cat(y).float()

	optimizer = optim.Adam(model.parameters(), lr=lr)
	dataloader = DataLoader(
		TensorDataset(features, selected, targets),
		batch_size=batch_size,
		shuffle=True,
	)

	if loss_type == 'square':
		loss_fn = nn.MSELoss().to(device)
	else:
		loss_fn = nn.BCELoss().to(device)

	for x_batch, arm_batch, y_batch in dataloader:
		x_batch, y_batch = x_batch.to(device), y_batch.to(device)
		# Boolean mask of the selected (row, label) positions; indexing with it
		# yields the same row-major order as an explicit double loop.
		mask = arm_batch.to(device) == 1
		if not mask.any():
			# Nothing selected in this mini-batch: there is no loss to compute,
			# so skip it instead of failing on an empty selection.
			continue

		pred = model(x_batch)
		loss = loss_fn(pred[mask], y_batch[mask])
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

	return model

def train_cls_batch(model, X, arm, y, num_epochs=20, lr=0.001, batch_size=64, device='cpu'):
	"""Train ``model`` with binary cross-entropy on the selected entries only.

	For every mini-batch the model output and the target are gathered at the
	positions where ``arm == 1`` (row-major) and compared with ``nn.BCELoss``.
	Training stops early once the mean mini-batch loss of an epoch is at most
	1e-3.

	Args:
		model: module called as ``model(x)``; its output is indexed with the
			``arm`` mask, so it must have the same shape as the ``arm``/``y``
			mini-batch, and its values must lie in [0, 1] for BCE.
		X: non-empty list of feature tensors, concatenated along dim 0.
		arm: list of tensors concatenated along dim 0 and cast to int;
			entries equal to 1 mark the positions used in the loss.
		y: list of target tensors, concatenated along dim 0 and cast to float.
		num_epochs: maximum number of passes over the data.
		lr: learning rate of the freshly created Adam optimizer.
		batch_size: mini-batch size of the shuffled ``DataLoader``.
		device: device the mini-batches are moved to; the model is not moved
			here, so it is expected to live on that device already.

	Returns:
		Mean mini-batch loss (float) of the last epoch that ran, averaged over
		the mini-batches that had at least one selected entry. Returns 0.0
		when no epoch ran or no entry was selected at all.
	"""
	model.train()
	features = torch.cat(X).float()
	selected = torch.cat(arm).int()
	targets = torch.cat(y).float()

	optimizer = optim.Adam(model.parameters(), lr=lr)
	dataloader = DataLoader(
		TensorDataset(features, selected, targets),
		batch_size=batch_size,
		shuffle=True,
	)
	loss_fn = nn.BCELoss().to(device)

	# Defined up front so num_epochs == 0 returns a value instead of failing.
	mean_loss = 0.0
	for _ in range(num_epochs):
		epoch_loss = 0.0
		used_batches = 0
		for x_batch, arm_batch, y_batch in dataloader:
			x_batch, y_batch = x_batch.to(device), y_batch.to(device)
			# Boolean mask of the selected (row, label) positions; indexing with
			# it yields the same row-major order as an explicit double loop.
			mask = arm_batch.to(device) == 1
			if not mask.any():
				# Nothing selected in this mini-batch: no loss to compute.
				continue

			pred = model(x_batch)
			loss = loss_fn(pred[mask], y_batch[mask])
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()

			epoch_loss += loss.item()
			used_batches += 1

		mean_loss = epoch_loss / used_batches if used_batches else 0.0
		if mean_loss <= 1e-3:
			break

	return mean_loss

def train_cls_MC_batch(model, X, arm, y, num_epochs=20, lr=0.001, batch_size=64, device='cpu'):
	"""Train ``model`` with mean squared error on the selected entries only.

	For every mini-batch the model output and the target are gathered at the
	positions where ``arm == 1`` (row-major) and compared with ``nn.MSELoss``.
	Training stops early once the mean mini-batch loss of an epoch is at most
	1e-3.

	Args:
		model: module called as ``model(x)``; its output is indexed with the
			``arm`` mask, so it must have the same shape as the ``arm``/``y``
			mini-batch.
		X: non-empty list of feature tensors, concatenated along dim 0.
		arm: list of tensors concatenated along dim 0 and cast to int;
			entries equal to 1 mark the positions used in the loss.
		y: list of target tensors, concatenated along dim 0 and cast to float.
		num_epochs: maximum number of passes over the data.
		lr: learning rate of the freshly created Adam optimizer.
		batch_size: mini-batch size of the shuffled ``DataLoader``.
		device: device the mini-batches are moved to; the model is not moved
			here, so it is expected to live on that device already.

	Returns:
		Mean mini-batch loss (float) of the last epoch that ran, averaged over
		the mini-batches that had at least one selected entry. Returns 0.0
		when no epoch ran or no entry was selected at all.
	"""
	model.train()
	features = torch.cat(X).float()
	selected = torch.cat(arm).int()
	targets = torch.cat(y).float()

	optimizer = optim.Adam(model.parameters(), lr=lr)
	dataloader = DataLoader(
		TensorDataset(features, selected, targets),
		batch_size=batch_size,
		shuffle=True,
	)
	loss_fn = nn.MSELoss().to(device)

	# Defined up front so num_epochs == 0 returns a value instead of failing.
	mean_loss = 0.0
	for _ in range(num_epochs):
		epoch_loss = 0.0
		used_batches = 0
		for x_batch, arm_batch, y_batch in dataloader:
			x_batch, y_batch = x_batch.to(device), y_batch.to(device)
			# Boolean mask of the selected (row, label) positions; indexing with
			# it yields the same row-major order as an explicit double loop.
			mask = arm_batch.to(device) == 1
			if not mask.any():
				# Nothing selected in this mini-batch: no loss to compute.
				continue

			pred = model(x_batch)
			loss = loss_fn(pred[mask], y_batch[mask])
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()

			epoch_loss += loss.item()
			used_batches += 1

		mean_loss = epoch_loss / used_batches if used_batches else 0.0
		if mean_loss <= 1e-3:
			break

	return mean_loss