import os
import sys
from collections import defaultdict
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from sklearn.metrics import mean_squared_error
from numpy import inf
import torch
import math
import time

# ---------------------------------------------------------------------------
# Problem dimensions and solver settings (module-level state).
# Some of these names are rebound at runtime by set_extra_payment(),
# change_day_num() and reset_day_num().
# ---------------------------------------------------------------------------
nurse_num = 10
total_day_num = 7
day_num = total_day_num
shift_num = 4

# Flattened sizes derived from the dimensions above.
day_shift_num = day_num * shift_num
day_work_shift_num = day_num * (shift_num - 1)
x_num = nurse_num * day_shift_num
sigma_num = day_shift_num
gamma_num = nurse_num * day_shift_num
decision_num = x_num + sigma_num
t_decision_num = x_num + sigma_num + gamma_num

penaltyTerm = 0.01
extra_serve_patient_num = 1
extra_payment = 0
relax_val = 1e-5

# Rest-day bounds are only defined for the two horizons handled here.
if day_num == 7:
    minimum_relax_day, maximum_relax_day = 1, 2
elif day_num == 2:
    minimum_relax_day, maximum_relax_day = 0, 1


def mkdir(default_path, folder_name):
    """Create the directory ``default_path/folder_name`` if it is missing.

    Intermediate directories are created as needed.  If something already
    exists at the target path the call does nothing.

    Args:
        default_path: Parent directory path.
        folder_name: Name (or relative path) of the folder to create.
    """
    path = os.path.join(default_path, folder_name)
    # Create directly instead of "check, then create": a separate
    # os.path.exists() test can race with another process creating the
    # same folder and make os.makedirs() fail.
    try:
        os.makedirs(path)
    except FileExistsError:
        # Already present -- same outcome as the previous exists() check.
        pass


def set_extra_payment(ep):
    """Rebind the module-level ``extra_payment`` to ``ep``."""
    global extra_payment
    extra_payment = ep


def change_day_num(cur_day_num):
    """Switch the module-level horizon to ``cur_day_num`` days.

    Rebinds ``day_num`` and the sizes derived from it (``day_shift_num``,
    ``day_work_shift_num``, ``x_num``, ``sigma_num``, ``gamma_num`` and
    ``t_decision_num``).  ``decision_num`` and the relax-day bounds are
    not touched by this function.
    """
    global day_num, day_shift_num, day_work_shift_num
    global x_num, sigma_num, gamma_num, t_decision_num

    shifts_in_horizon = cur_day_num * shift_num
    nurse_slots = nurse_num * shifts_in_horizon

    day_num = cur_day_num
    day_shift_num = shifts_in_horizon
    day_work_shift_num = cur_day_num * (shift_num - 1)
    x_num = nurse_slots
    sigma_num = shifts_in_horizon
    gamma_num = nurse_slots
    t_decision_num = x_num + sigma_num + gamma_num
    

def reset_day_num():
    """Restore the module-level horizon to ``total_day_num`` days.

    Rebinds ``day_num`` and the sizes derived from it (``day_shift_num``,
    ``day_work_shift_num``, ``x_num``, ``sigma_num``, ``gamma_num`` and
    ``t_decision_num``).  ``decision_num`` is not touched.
    """
    global day_num, day_shift_num, day_work_shift_num
    global x_num, sigma_num, gamma_num, t_decision_num

    full_shifts = total_day_num * shift_num
    full_nurse_slots = nurse_num * full_shifts

    day_num = total_day_num
    day_shift_num = full_shifts
    day_work_shift_num = total_day_num * (shift_num - 1)
    x_num = full_nurse_slots
    sigma_num = full_shifts
    gamma_num = full_nurse_slots
    t_decision_num = x_num + sigma_num + gamma_num


def gen_matrix(nurse_num, day_num, shift_num, serve_patient_num, decision_num, day_shift_num):
    """Build the constraint matrices of the initial scheduling program.

    Columns ``nurse * day_shift_num + day * shift_num + shift`` hold the
    assignment variables; the extra-staff columns start at offset
    ``nurse_num * day_num * shift_num``.

    Returns:
        ``(A, b, G, h2, h3, h4, h5)``: the equality system ``A z == b`` and
        the stacked inequality matrix ``G`` (blocks G1..G5) together with
        the right-hand sides of blocks 2..5.  The right-hand side of the
        first block (patient demand) is not produced here.
    """
    x_size = nurse_num * day_num * shift_num

    # Each nurse must be scheduled for exactly one shift per day.
    A = np.zeros((nurse_num * day_num, decision_num))
    shift_offsets = np.arange(shift_num)
    for nurse in range(nurse_num):
        for day in range(day_num):
            first_col = nurse * day_shift_num + shift_num * day
            A[nurse * day_num + day, first_col + shift_offsets] = 1
    b = np.ones(nurse_num * day_num)

    # Each schedule must satisfy the patients' need (relax shifts included).
    G1 = np.zeros((day_shift_num, decision_num))
    for slot in range(day_shift_num):
        for nurse in range(nurse_num):
            G1[slot, nurse * day_shift_num + slot] = -serve_patient_num[nurse]
            G1[slot, x_size + slot] = -extra_serve_patient_num

    # No nurse may work a night shift followed immediately by a morning shift.
    G2 = np.zeros((nurse_num * (day_num - 1), decision_num))
    for nurse in range(nurse_num):
        for day in range(day_num - 1):
            row = nurse * (day_num - 1) + day
            next_morning = nurse * day_shift_num + shift_num * (day + 1)
            G2[row, next_morning - 2] = 1
            G2[row, next_morning] = 1
    h2 = np.ones(nurse_num * (day_num - 1))

    # Each nurse gets between minimum_relax_day and maximum_relax_day
    # day-off shifts over the horizon (last shift of every day).
    G3 = np.zeros((nurse_num, decision_num))
    G4 = np.zeros((nurse_num, decision_num))
    for nurse in range(nurse_num):
        for day in range(1, day_num + 1):
            rest_col = nurse * day_shift_num + shift_num * day - 1
            G3[nurse, rest_col] = -1
            G4[nurse, rest_col] = 1
    h3 = np.full(nurse_num, -minimum_relax_day, dtype=int)
    h4 = np.full(nurse_num, maximum_relax_day, dtype=int)

    # Upper bound of 1 on every assignment variable.
    G5 = np.zeros((x_size, decision_num))
    diagonal = np.arange(x_size)
    G5[diagonal, diagonal] = 1
    h5 = np.ones(x_size)

    G = np.concatenate([G1, G2, G3, G4, G5], axis=0)

    return A, b, G, h2, h3, h4, h5


def gen_t_matrix(t, prev_sol, real_h1, pre_h1, serve_patient_num):
    """Assemble the matrix-form constraints of the day-``t`` re-planning program.

    The variable vector has ``t_decision_num`` entries: assignment
    variables first, extra-staff variables next, and the change indicators
    starting at column ``decision_num``.  All sizes are read from the
    module-level settings.

    Args:
        t: Current day.  Demand rows of the first ``t`` days take their
            right-hand side from ``real_h1``; assignments of the first ``t``
            days and extra-staff values of the first ``t - 1`` days are
            pinned to ``prev_sol``.
        prev_sol: Previous solution, assignment entries followed by
            extra-staff entries.
        real_h1: Demand right-hand side for already observed days.
        pre_h1: Demand right-hand side for the remaining days.
        serve_patient_num: Per-nurse coefficient used in the demand rows.

    Returns:
        ``(A, b, G, h)`` describing ``A z == b`` and ``G z <= h``.
    """
    x_size = nurse_num * day_shift_num

    # Each nurse must be scheduled for exactly one shift per day.
    one_shift = np.zeros((nurse_num * day_num, t_decision_num))
    for nurse in range(nurse_num):
        for day in range(day_num):
            for shift in range(shift_num):
                one_shift[nurse * day_num + day, nurse * day_shift_num + shift_num * day + shift] = 1
    one_shift_rhs = np.ones(nurse_num * day_num)

    # Each schedule must satisfy the patients' need (relax shifts included).
    G1 = np.zeros((day_shift_num, t_decision_num))
    for slot in range(day_shift_num):
        for nurse in range(nurse_num):
            G1[slot, nurse * day_shift_num + slot] = -serve_patient_num[nurse]
            G1[slot, nurse_num * day_num * shift_num + slot] = -extra_serve_patient_num
    h1 = np.zeros(day_shift_num)
    observed = t * shift_num
    for slot in range(observed):
        h1[slot] = real_h1[slot]
    for slot in range(observed, day_shift_num):
        h1[slot] = pre_h1[slot]

    # No nurse may work a night shift followed immediately by a morning shift.
    G2 = np.zeros((nurse_num * (day_num - 1), t_decision_num))
    for nurse in range(nurse_num):
        for day in range(day_num - 1):
            row = nurse * (day_num - 1) + day
            next_morning = nurse * day_shift_num + shift_num * (day + 1)
            G2[row, next_morning - 2] = 1
            G2[row, next_morning] = 1
    h2 = np.ones(nurse_num * (day_num - 1))

    # Bounds on the number of day-off shifts per nurse.
    G3 = np.zeros((nurse_num, t_decision_num))
    G4 = np.zeros((nurse_num, t_decision_num))
    for nurse in range(nurse_num):
        for day in range(1, day_num + 1):
            rest_col = nurse * day_shift_num + shift_num * day - 1
            G3[nurse, rest_col] = -1
            G4[nurse, rest_col] = 1
    h3 = np.full(nurse_num, -minimum_relax_day, dtype=int)
    h4 = np.full(nurse_num, maximum_relax_day, dtype=int)

    # x range
    G5 = np.zeros((nurse_num * day_num * shift_num, t_decision_num))
    for col in range(nurse_num * day_num * shift_num):
        G5[col, col] = 1
    h5 = np.ones(nurse_num * day_num * shift_num)

    # x - gamma <= x_(t-1)
    G6 = np.zeros((x_size, t_decision_num))
    h6 = prev_sol[:x_size]
    for nurse in range(nurse_num):
        for slot in range(day_shift_num):
            col = nurse * day_shift_num + slot
            G6[col, col] = 1
            G6[col, decision_num + col] = -1

    # gamma range
    G7 = np.zeros((x_size, t_decision_num))
    for col in range(x_size):
        G7[col, decision_num + col] = 1
    h7 = np.ones(x_size)

    # Fix hard commitments: assignments of days < t, extra staff of days < t-1.
    fixed_rows = nurse_num * t * shift_num + (t - 1) * shift_num
    A2 = np.zeros((fixed_rows, t_decision_num))
    b2 = np.zeros(fixed_rows)
    row = 0
    for nurse in range(nurse_num):
        for day in range(t):
            for shift in range(shift_num):
                col = nurse * day_shift_num + day * shift_num + shift
                A2[row, col] = 1
                b2[row] = prev_sol[col]
                row += 1
    for day in range(t - 1):
        for shift in range(shift_num):
            col = x_size + day * shift_num + shift
            A2[row, col] = 1
            b2[row] = prev_sol[col]
            row += 1

    G = np.concatenate([G1, G2, G3, G4, G5, G6, G7], axis=0)
    h = np.concatenate([h1, h2, h3, h4, h5, h6, h7], axis=0)
    A = np.concatenate([one_shift, A2], axis=0)
    b = np.concatenate([one_shift_rhs, b2], axis=0)

    return A, b, G, h
    
# cost, penalty: full version
def gen_obj(t, cost, penalty=None):
    """Build the linear objective vector.

    Args:
        t: Current day; only used for the change-penalty weights, i.e. when
            ``penalty`` is given.
        cost: Per-nurse cost of a working shift.
        penalty: Optional per-assignment penalty factors.  When omitted the
            vector covers ``decision_num`` entries (assignments followed by
            ``extra_payment`` entries).  When given, change-indicator
            coefficients are appended after the assignment and extra-staff
            coefficients.

    Returns:
        1-D ``numpy`` array of objective coefficients.
    """
    x_size = nurse_num * day_shift_num
    # The rest shift is the last shift of each day, matching the day-off
    # columns (shift_num * day - 1) used by the constraint builders.  The
    # previous hard-coded index 3 was only right for shift_num == 4.
    relax_shift = shift_num - 1

    c_for_x = np.zeros(x_size)
    for i in range(nurse_num):
        for j in range(day_shift_num):
            if j % shift_num != relax_shift:
                c_for_x[i * day_shift_num + j] = cost[i]

    if penalty is None:
        c = np.zeros(decision_num)
        c[:x_size] = c_for_x
        c[x_size:] = extra_payment
        return c

    c_for_sigma = np.full(day_shift_num, extra_payment, dtype=float)

    c_for_gamma = np.zeros(x_size)
    for i in range(nurse_num):
        for j in range(day_num):
            for k in range(shift_num):
                idx = i * day_shift_num + j * shift_num + k
                c_for_gamma[idx] = (day_num - j + t) * penalty[idx] * c_for_x[idx]

    return np.concatenate([c_for_x, c_for_sigma, c_for_gamma], axis=0)
    

def actual_obj(c, A, b, G, real_patient_num, h2, h3, h4, h5, n_instance):
    """Solve the scheduling IP once per instance using the real demand.

    Args:
        c: Objective vector over the ``decision_num`` variables.
        A, b: Equality constraints ``A z == b``.
        G: Stacked inequality matrix; its first ``day_shift_num`` rows get
            the demand right-hand side built here.
        real_patient_num: Flat sequence of real demands,
            ``day_work_shift_num`` consecutive entries per instance (the
            rest shift of each day has no entry).
        h2, h3, h4, h5: Right-hand sides of the remaining inequality blocks.
        n_instance: Number of instances to solve.

    Returns:
        ``numpy`` array with one objective value per instance; ``0`` is
        recorded for an instance whose model yields no solution.
    """
    rowSizeA = A.shape[0]
    rowSizeG = G.shape[0]
    c = c.tolist()
    A = A.tolist()
    b = b.tolist()
    G = G.tolist()

    # Rest shift = last shift of each day (was hard-coded as index 3).
    relax_shift = shift_num - 1

    obj_list = []
    for num in range(n_instance):
        # Demand right-hand side: negative demand on working shifts,
        # zero on the rest shift.
        h1 = np.zeros(day_shift_num)
        cnt = num * day_work_shift_num
        for i in range(day_shift_num):
            if i % shift_num != relax_shift:
                h1[i] = -real_patient_num[cnt]
                cnt = cnt + 1
        h = np.concatenate([h1, h2, h3, h4, h5], axis=0).tolist()

        m = gp.Model()
        m.setParam('OutputFlag', 0)
        x = m.addVars(decision_num, vtype=GRB.INTEGER, name='x')

        m.setObjective(x.prod(c), GRB.MINIMIZE)
        for i in range(rowSizeA):
            m.addConstr(x.prod(A[i]) == b[i])
        for j in range(rowSizeG):
            m.addConstr(x.prod(G[j]) <= h[j])

        m.optimize()
        try:
            objective = m.objVal
        except (AttributeError, gp.GurobiError):
            # No solution available (e.g. infeasible model).  Only the
            # solver's "attribute unavailable" errors are treated this way,
            # so interrupts and programming errors are no longer swallowed.
            objective = 0

        obj_list.append(objective)

    return np.array(obj_list)


def get_init_plan(c, A, b, G, real_patient_num, pre_patient_num, h2, h3, h4, h5):
    """Solve the initial scheduling IP against the predicted demand.

    Args:
        c: Objective vector over the ``decision_num`` variables.
        A, b: Equality constraints ``A z == b``.
        G: Stacked inequality matrix; its first ``day_shift_num`` rows get
            the demand right-hand side built here.
        real_patient_num: Not used by this function; kept so existing
            callers keep working.
        pre_patient_num: Predicted demand, one entry per working shift (the
            rest shift of each day has no entry).
        h2, h3, h4, h5: Right-hand sides of the remaining inequality blocks.

    Returns:
        ``(predSol, sigmaSol)``: rounded assignment values
        (``nurse_num * day_shift_num`` entries) and the values of the
        remaining variables.  Both are all zeros, and ``"cannot solve"`` is
        printed, when the model yields no solution.
    """
    rowSizeA = A.shape[0]
    rowSizeG = G.shape[0]
    x_sol_size = nurse_num * day_shift_num
    c = c.tolist()
    A = A.tolist()
    b = b.tolist()
    G = G.tolist()

    # Rest shift = last shift of each day (was hard-coded as index 3).
    relax_shift = shift_num - 1

    pre_h1 = np.zeros(day_shift_num)
    cnt = 0
    for i in range(day_shift_num):
        if i % shift_num != relax_shift:
            pre_h1[i] = -pre_patient_num[cnt]
            cnt = cnt + 1
    pre_h = np.concatenate([pre_h1, h2, h3, h4, h5], axis=0).tolist()

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(decision_num, vtype=GRB.INTEGER, name='x')

    m.setObjective(x.prod(c), GRB.MINIMIZE)
    for i in range(rowSizeA):
        m.addConstr(x.prod(A[i]) == b[i])
    for j in range(rowSizeG):
        m.addConstr(x.prod(G[j]) <= pre_h[j])
    m.optimize()

    predSol = np.zeros(x_sol_size)
    sigmaSol = np.zeros(day_shift_num)
    try:
        for i in range(x_sol_size):
            predSol[i] = round(x[i].x)
        for i in range(x_sol_size, decision_num):
            sigmaSol[i - x_sol_size] = x[i].x
    except (AttributeError, gp.GurobiError):
        # Solution attributes are unavailable (e.g. infeasible model).
        # Interrupts and unrelated errors are deliberately not caught.
        print("cannot solve")
        predSol = np.zeros(x_sol_size)
        sigmaSol = np.zeros(day_shift_num)

    return predSol, sigmaSol


def get_t_updated_plan(t, pre_sol, pre_sigmaSol, c, A, b, G, real_patient_num, pre_patient_num, h2, h3, h4, h5, penalty):
    """Re-plan on day ``t`` with explicit assignment, change and extra-staff variables.

    Only the first ``nurse_num * day_shift_num`` columns of ``c``, ``A`` and
    ``G`` are used; the extra-staff term of the demand rows is added through
    the separate ``sigma`` variables.

    Args:
        t: Current day.  Demand rows of the first ``t`` days use the real
            demand, later rows the predicted demand.  Assignments of the
            first ``t`` days and extra staff of the first ``t - 1`` days are
            pinned to the previous values.
        pre_sol: Previous assignment vector.
        pre_sigmaSol: Previous extra-staff vector.
        c, A, b, G: Objective and constraint data of the initial program.
        real_patient_num, pre_patient_num: Real / predicted demand, one
            entry per working shift.
        h2, h3, h4, h5: Right-hand sides of the non-demand inequality blocks.
        penalty: Per-assignment penalty factors.

    Returns:
        ``(t_updated_sol, t_sigma_sol, t_incur_penalty)``.  The two solution
        vectors stay at zero when the model yields no solution.
    """
    rowSizeA = A.shape[0]
    rowSizeG = G.shape[0]
    x_sol_size = nurse_num * day_shift_num
    c = c[:x_sol_size].tolist()
    A = A[:, :x_sol_size].tolist()
    b = b.tolist()
    G = G[:, :x_sol_size].tolist()

    # Rest shift = last shift of each day (was hard-coded as index 3).
    relax_shift = shift_num - 1

    pre_h1 = np.zeros(day_shift_num)
    real_h1 = np.zeros(day_shift_num)
    cnt = 0
    for i in range(day_shift_num):
        if i % shift_num != relax_shift:
            pre_h1[i] = -pre_patient_num[cnt]
            real_h1[i] = -real_patient_num[cnt]
            cnt = cnt + 1
    pre_h = np.concatenate([pre_h1, h2, h3, h4, h5], axis=0).tolist()
    real_h = np.concatenate([real_h1, h2, h3, h4, h5], axis=0).tolist()

    # Cost of switching an assignment on, computed once and shared by the
    # objective and the incurred-penalty sum below.
    change_cost = []
    for i in range(nurse_num):
        for j in range(day_num):
            for k in range(shift_num):
                idx = i * day_shift_num + j * shift_num + k
                change_cost.append((idx, (day_num - j + t) * penalty[idx] * c[idx]))

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(x_sol_size, vtype=GRB.BINARY, name='x')
    gamma = m.addVars(x_sol_size, vtype=GRB.BINARY, name='gamma')
    sigma = m.addVars(day_shift_num, vtype=GRB.INTEGER, name='sigma')

    OBJ = x.prod(c)
    for idx, weight in change_cost:
        OBJ = OBJ + weight * gamma[idx]
    for i in range(day_shift_num):
        OBJ = OBJ + extra_payment * sigma[i]
    m.setObjective(OBJ, GRB.MINIMIZE)

    for i in range(rowSizeA):
        m.addConstr(x.prod(A[i]) == b[i])
    # Demand rows: observed days use the real demand, the rest the forecast.
    for j in range(t * shift_num):
        m.addConstr(x.prod(G[j]) - extra_serve_patient_num * sigma[j] <= real_h[j])
    for j in range(t * shift_num, day_shift_num):
        m.addConstr(x.prod(G[j]) - extra_serve_patient_num * sigma[j] <= pre_h[j])
    for j in range(day_shift_num, rowSizeG):
        m.addConstr(x.prod(G[j]) <= pre_h[j])
    # gamma flags assignments that are switched on relative to pre_sol.
    for i in range(x_sol_size):
        m.addConstr(gamma[i] >= x[i] - pre_sol[i])
    # Hard commitments that can no longer be changed.
    for i in range(nurse_num):
        for j in range(t):
            for k in range(shift_num):
                idx = i * day_shift_num + j * shift_num + k
                m.addConstr(x[idx] == pre_sol[idx])
    for j in range(t - 1):
        for k in range(shift_num):
            m.addConstr(sigma[j * shift_num + k] == pre_sigmaSol[j * shift_num + k])

    m.optimize()
    t_updated_sol = np.zeros(x_sol_size)
    t_sigma_sol = np.zeros(day_shift_num)
    t_gamma = np.zeros(x_sol_size)

    try:
        for i in range(x_sol_size):
            t_updated_sol[i] = round(x[i].x)
            t_gamma[i] = gamma[i].x
        for i in range(day_shift_num):
            t_sigma_sol[i] = sigma[i].x
    except (AttributeError, gp.GurobiError):
        # No solution available (e.g. infeasible model): keep the zero
        # vectors.  Interrupts and unrelated errors are not swallowed.
        pass

    t_incur_penalty = 0
    for idx, weight in change_cost:
        t_incur_penalty = t_incur_penalty + weight * t_gamma[idx]

    return t_updated_sol, t_sigma_sol, t_incur_penalty


def correction_single_obj(c, A, b, G, real_patient, pre_patient, h2, h3, h4, h5, penalty):
    """Return the total cost of planning once and then re-planning every day.

    An initial plan is obtained from ``get_init_plan`` and then updated with
    ``get_t_updated_plan`` for ``t = 1 .. day_num``.  The result is the
    assignment cost of the final plan, plus ``extra_payment`` times the
    final extra-staff values, plus the penalties accumulated over all days.
    """
    plan, extra_staff = get_init_plan(c, A, b, G, real_patient, pre_patient, h2, h3, h4, h5)
    penalty_sum = 0
    for day in range(1, day_num + 1):
        plan, extra_staff, day_penalty = get_t_updated_plan(
            day, plan, extra_staff, c, A, b, G,
            real_patient, pre_patient, h2, h3, h4, h5, penalty)
        penalty_sum = penalty_sum + day_penalty

    assignment_cost = np.dot(plan, c[:nurse_num * day_shift_num])
    hiring_cost = np.sum(extra_staff * extra_payment)
    return assignment_cost + hiring_cost + penalty_sum


def check_IP_t_updated(t, c, A, b, G, h, penalty):
    """Solve the matrix-form re-planning IP and split its solution.

    Args:
        t: Not used inside this function.
        c: Objective vector over the ``t_decision_num`` variables.
        A, b: Equality constraints ``A z == b``.
        G, h: Inequality constraints ``G z <= h``.
        penalty: Per-assignment penalty factors.

    Returns:
        ``(t_updated_sol, t_sigma_sol, t_incur_penalty)``: rounded values of
        the first ``nurse_num * day_shift_num`` variables, of the variables
        up to ``decision_num`` and the penalty computed from the variables
        after ``decision_num``.  The vectors stay at zero when the model
        yields no solution.
    """
    rowSizeA = A.shape[0]
    rowSizeG = G.shape[0]
    x_size = nurse_num * day_shift_num

    c = c.tolist()
    A = A.tolist()
    b = b.tolist()
    G = G.tolist()
    h = h.tolist()

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(t_decision_num, vtype=GRB.INTEGER, name='x')

    m.setObjective(x.prod(c), GRB.MINIMIZE)

    for i in range(rowSizeA):
        m.addConstr(x.prod(A[i]) == b[i])
    for j in range(rowSizeG):
        m.addConstr(x.prod(G[j]) <= h[j])

    m.optimize()
    t_updated_sol = np.zeros(x_size)
    t_sigma_sol = np.zeros(day_shift_num)
    t_gamma = np.zeros(x_size)

    try:
        for i in range(x_size):
            t_updated_sol[i] = round(x[i].x)
        for i in range(x_size, decision_num):
            t_sigma_sol[i - x_size] = round(x[i].x)
        for i in range(decision_num, t_decision_num):
            t_gamma[i - decision_num] = round(x[i].x)
    except (AttributeError, gp.GurobiError):
        # No solution available (e.g. infeasible model): keep the zero
        # vectors.  Interrupts and unrelated errors are not swallowed.
        pass

    # NOTE(review): this sum uses penalty[i] * c[i] without the
    # (day_num - j + t) factor that get_t_updated_plan applies to its
    # incurred penalty -- confirm whether the weight is intentionally
    # omitted here.
    t_incur_penalty = 0
    for i in range(x_size):
        t_incur_penalty = t_incur_penalty + penalty[i] * c[i] * t_gamma[i]

    return t_updated_sol, t_sigma_sol, t_incur_penalty


def check_intOpt(cost, serve_patient_num, real_patient, pre_patient, penalty):
    """Run the plan / daily re-plan pipeline through the matrix-form IPs.

    Builds the initial program with ``gen_obj`` / ``gen_matrix``, solves it
    with ``get_init_plan`` and then, for ``t = 1 .. day_num``, rebuilds the
    day-``t`` program (``gen_obj``, ``gen_t_matrix``) and solves it with
    ``check_IP_t_updated``.

    Args:
        cost: Per-nurse cost of a working shift.
        serve_patient_num: Per-nurse coefficient used in the demand rows.
        real_patient: Real demand, one entry per working shift.
        pre_patient: Predicted demand, one entry per working shift.
        penalty: Per-assignment penalty factors.

    Returns:
        Assignment cost of the final plan, plus ``extra_payment`` times the
        final extra-staff values, plus the accumulated penalties.
    """
    c = gen_obj(0, cost)
    A, b, G, h2, h3, h4, h5 = gen_matrix(nurse_num, day_num, shift_num, serve_patient_num, decision_num, day_shift_num)
    init_plan, init_sigma = get_init_plan(c, A, b, G, real_patient, pre_patient, h2, h3, h4, h5)

    t_updated_x = init_plan
    t_sigma_sol = init_sigma
    t_updated_sol = np.concatenate([init_plan, init_sigma], axis=0)
    total_penalty = 0

    # Rest shift = last shift of each day (was hard-coded as index 3).
    relax_shift = shift_num - 1
    pre_h1 = np.zeros(day_shift_num)
    real_h1 = np.zeros(day_shift_num)
    cnt = 0
    for i in range(day_shift_num):
        if i % shift_num != relax_shift:
            pre_h1[i] = -pre_patient[cnt]
            real_h1[i] = -real_patient[cnt]
            cnt = cnt + 1

    for t in range(1, day_num + 1):
        c_t = gen_obj(t, cost, penalty)
        A_t, b_t, G_t, h_t = gen_t_matrix(t, t_updated_sol, real_h1, pre_h1, serve_patient_num)
        t_updated_x, t_sigma_sol, t_incur_penalty = check_IP_t_updated(t, c_t, A_t, b_t, G_t, h_t, penalty)
        t_updated_sol = np.concatenate([t_updated_x, t_sigma_sol], axis=0)
        total_penalty = total_penalty + t_incur_penalty

    # Evaluated once from the final plan; previously recomputed in every
    # iteration and left unbound when the loop body never ran.
    x_cost = c[:nurse_num * day_shift_num]
    total_cost = np.dot(t_updated_x, x_cost) + np.sum(t_sigma_sol * extra_payment) + total_penalty
    return total_cost


# remaining_schedule: (curr) nurse_num * cur_day_num * shift_num
# has_rested: nurse_num
# real_h1, pre_h1: (curr) cur_day_num * shift_num
# A, b, G, h: (curr) x_num + sigma_num + gamma_num
def gen_constraints_latter_days(t, remaining_schedule, has_rested, real_h1, pre_h1, serve_patient_num):
    """Assemble the matrix-form constraints for the remaining horizon.

    Sizes come from the current module-level settings.  Demand rows of the
    first day take their right-hand side from ``real_h1``, later rows from
    ``pre_h1``; the day-off bounds are shifted by ``has_rested``; the
    first-day assignments are pinned to ``remaining_schedule``.  ``t`` is
    not used inside this function.

    Returns:
        ``(A, b, G, h)`` describing ``A z == b`` and ``G z <= h``.
    """
    x_size = nurse_num * day_shift_num
    gamma_offset = x_num + sigma_num

    # Each nurse must be scheduled for exactly one shift per day.
    one_shift = np.zeros((nurse_num * day_num, t_decision_num))
    for nurse in range(nurse_num):
        for day in range(day_num):
            for shift in range(shift_num):
                one_shift[nurse * day_num + day, nurse * day_shift_num + shift_num * day + shift] = 1
    one_shift_rhs = np.ones(nurse_num * day_num)

    # Each schedule must satisfy the patients' need (relax shifts included).
    G1 = np.zeros((day_shift_num, t_decision_num))
    for slot in range(day_shift_num):
        for nurse in range(nurse_num):
            G1[slot, nurse * day_shift_num + slot] = -serve_patient_num[nurse]
            G1[slot, nurse_num * day_num * shift_num + slot] = -extra_serve_patient_num
    h1 = np.zeros(day_shift_num)
    for slot in range(shift_num):
        h1[slot] = real_h1[slot]
    for slot in range(shift_num, day_shift_num):
        h1[slot] = pre_h1[slot]

    # No nurse may work a night shift followed immediately by a morning shift.
    G2 = np.zeros((nurse_num * (day_num - 1), t_decision_num))
    for nurse in range(nurse_num):
        for day in range(day_num - 1):
            row = nurse * (day_num - 1) + day
            next_morning = nurse * day_shift_num + shift_num * (day + 1)
            G2[row, next_morning - 2] = 1
            G2[row, next_morning] = 1
    h2 = np.ones(nurse_num * (day_num - 1))

    # Bounds on the number of day-off shifts, adjusted by has_rested.
    G3 = np.zeros((nurse_num, t_decision_num))
    G4 = np.zeros((nurse_num, t_decision_num))
    for nurse in range(nurse_num):
        for day in range(1, day_num + 1):
            rest_col = nurse * day_shift_num + shift_num * day - 1
            G3[nurse, rest_col] = -1
            G4[nurse, rest_col] = 1
    h3 = np.empty(nurse_num, dtype=int)
    h4 = np.empty(nurse_num, dtype=int)
    for nurse in range(nurse_num):
        h3[nurse] = -minimum_relax_day + has_rested[nurse]
        h4[nurse] = maximum_relax_day - has_rested[nurse]

    # x range
    G5 = np.zeros((nurse_num * day_num * shift_num, t_decision_num))
    for col in range(nurse_num * day_num * shift_num):
        G5[col, col] = 1
    h5 = np.ones(nurse_num * day_num * shift_num)

    # x - gamma <= x_(t-1)
    G6 = np.zeros((x_size, t_decision_num))
    h6 = remaining_schedule[:x_size]
    for nurse in range(nurse_num):
        for slot in range(day_shift_num):
            col = nurse * day_shift_num + slot
            G6[col, col] = 1
            G6[col, gamma_offset + col] = -1

    # gamma range
    G7 = np.zeros((x_size, t_decision_num))
    for col in range(x_size):
        G7[col, gamma_offset + col] = 1
    h7 = np.ones(x_size)

    # Fix hard commitments: the first day's assignments of every nurse.
    A2 = np.zeros((nurse_num * shift_num, t_decision_num))
    b2 = np.zeros(nurse_num * shift_num)
    row = 0
    for nurse in range(nurse_num):
        for shift in range(shift_num):
            col = nurse * day_shift_num + shift
            A2[row, col] = 1
            b2[row] = remaining_schedule[col]
            row += 1

    G = np.concatenate([G1, G2, G3, G4, G5, G6, G7], axis=0)
    h = np.concatenate([h1, h2, h3, h4, h5, h6, h7], axis=0)
    A = np.concatenate([one_shift, A2], axis=0)
    b = np.concatenate([one_shift_rhs, b2], axis=0)

    return A, b, G, h


def gen_obj_latter_days(t, cost, penalty):
    """Build the objective coefficient vector for a latter-day re-planning problem.

    The vector is the concatenation of three segments whose sizes come from
    the module-level globals at call time:

    * ``nurse_num * day_shift_num`` assignment coefficients: ``cost[i]`` for
      nurse ``i`` on every slot whose in-day index is not 3, and 0 on in-day
      index 3 (presumably the day-off slot -- confirm with the constraint
      builder);
    * ``day_shift_num`` coefficients, each equal to ``extra_payment``;
    * ``nurse_num * day_shift_num`` change coefficients, each equal to
      ``(total_day_num - day) * penalty[idx] * assignment_coefficient[idx]``.

    Args:
        t: Stage index. Not used by this function.
        cost: Indexable by nurse index.
        penalty: Indexable by the flat index
            ``nurse * day_shift_num + day * shift_num + shift``.

    Returns:
        1-D numpy array holding the three segments in the order above.
    """
    n_assign = nurse_num * day_shift_num

    # Assignment segment: one entry per (nurse, slot).
    wage = np.zeros(n_assign)
    for nurse in range(nurse_num):
        offset = nurse * day_shift_num
        for slot in range(day_shift_num):
            # In-day index 3 is free of charge; every other slot costs cost[nurse].
            wage[offset + slot] = 0 if slot % shift_num == 3 else cost[nurse]

    # Second segment: a constant price per slot.
    extra_hire = np.zeros(day_shift_num)
    extra_hire.fill(extra_payment)

    # Change segment: weight shrinks by one for each later day in the horizon.
    change_cost = np.zeros(n_assign)
    for nurse in range(nurse_num):
        for day in range(day_num):
            for shift in range(shift_num):
                idx = nurse * day_shift_num + day * shift_num + shift
                change_cost[idx] = (total_day_num - day) * penalty[idx] * wage[idx]

    return np.concatenate([wage, extra_hire, change_cost], axis=0)


# prev_sol: (curr) x_num + sigma_num
# has_rested: nurse_num (full version)
# real_h1, pre_h1: curr version
# cost, penalty: full version
def get_updated_plan_for_each_day(t, c, A, b, G, h, penalty):
    """Solve one stage of the re-planning integer program with Gurobi.

    Minimises ``c @ x`` subject to ``A @ x == b`` and ``G @ x <= h`` over
    ``t_decision_num`` integer variables (Gurobi's default lower bound of 0
    applies). The solution is split into three consecutive segments: the first
    ``x_num`` entries, the next ``sigma_num`` entries, and the remaining
    entries up to ``t_decision_num`` (the "gamma" segment).

    If no solution can be read back, diagnostics are written to the working
    directory (``model.ilp``, ``A.txt``, ``b.txt``, ``G.txt``, ``h.txt``), the
    function sleeps for 100 seconds, and all-zero segments are returned.

    Args:
        t: Stage index; only used in the diagnostic message.
        c: Objective coefficients (object with ``tolist()``).
        A, b: Equality constraint matrix and right-hand side.
        G, h: Inequality constraint matrix and right-hand side.
        penalty: Not used by this function.

    Returns:
        Tuple ``(t_updated_sol, t_sigma_sol, t_incur_penalty)`` where the first
        two are the rounded first and second solution segments and the third is
        the sum of the gamma-segment objective coefficients times the rounded
        gamma-segment values.
    """
    n_eq = A.shape[0]
    n_ineq = G.shape[0]

    # gurobipy's tupledict.prod is fed plain Python lists.
    c = c.tolist()
    A = A.tolist()
    b = b.tolist()
    G = G.tolist()
    h = h.tolist()

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(t_decision_num, vtype=GRB.INTEGER, name='x')

    m.setObjective(x.prod(c), GRB.MINIMIZE)

    for i in range(n_eq):
        m.addConstr(x.prod(A[i]) == b[i])
    for j in range(n_ineq):
        m.addConstr(x.prod(G[j]) <= h[j])

    m.optimize()

    gamma_start = x_num + sigma_num
    t_updated_sol = np.zeros(nurse_num*day_shift_num)
    t_sigma_sol = np.zeros(day_shift_num)
    t_gamma = np.zeros(nurse_num*day_shift_num)

    try:
        for i in range(x_num):
            t_updated_sol[i] = round(x[i].x)
        for i in range(x_num, gamma_start):
            t_sigma_sol[i-x_num] = round(x[i].x)
        for i in range(gamma_start, t_decision_num):
            t_gamma[i-gamma_start] = round(x[i].x)
    except (AttributeError, gp.GurobiError):
        # Reading `.x` fails when the solver produced no solution. Only that
        # failure is handled here, so Ctrl-C and genuine programming errors
        # are no longer swallowed as they were by the former bare `except:`.
        print("Stage ", t, ", cannot solve")
        m.computeIIS()
        m.write('model.ilp')
        print(t)
        np.savetxt('A.txt', A, fmt="%.2f")
        np.savetxt('b.txt', b)
        np.savetxt('G.txt', G, fmt="%.2f")
        np.savetxt('h.txt', h, fmt="%.2f")
        # NOTE(review): long debugging pause kept from the original code.
        time.sleep(100)

    # Objective contribution of the gamma segment only.
    t_incur_penalty = np.sum(np.asarray(c[gamma_start:]) * t_gamma)

    return t_updated_sol, t_sigma_sol, t_incur_penalty

# remaining_schedule: curr version
# curr_cost: cur_NN before cost
# has_rested: cur_NN before rested
# pre_h1, real_h1: full version
def _first_day_payments(cost, schedule, sigma):
    """Return ``(nurse_payment, extra_nurse_payment)`` for the first day of a plan.

    Uses the current module-level ``day_shift_num`` as the per-nurse stride of
    ``schedule``; the last in-day slot (index ``shift_num - 1``) is not paid.
    """
    nurse_payment = 0
    for i in range(nurse_num):
        for k in range(shift_num-1):
            nurse_payment += cost[i] * schedule[i*day_shift_num+k]
    extra_nurse_payment = 0
    for k in range(shift_num):
        extra_nurse_payment += sigma[k] * extra_payment
    return nurse_payment, extra_nurse_payment


def correction_single_for_latter_days(cur_NN, pre_h1, real_h1, cost, penalty, serve_patient_num, remaining_schedule, curr_cost, curr_penalty, has_rested):
    """Roll the schedule forward day by day from ``cur_NN`` and return the total cost.

    For every stage ``t`` in ``cur_NN .. total_day_num`` the horizon globals are
    shrunk to the remaining ``total_day_num - t + 1`` days via
    ``change_day_num``, the constraints and objective are rebuilt, and the
    stage problem is re-solved. From the second stage on, the first day of the
    previous solution is first charged to the running cost, its penalty is
    accumulated, and nurses whose in-day slot 3 equals 1 have their rest
    counter incremented.

    The horizon globals are always restored to ``total_day_num - cur_NN + 1``
    days on exit, including when a builder or the solver raises.

    Args:
        cur_NN: First stage (day number) to process.
        pre_h1, real_h1: Demand-side vectors laid out as ``day * shift_num +
            shift`` for the horizon starting at ``cur_NN``; both are sliced by
            whole days for later stages and passed to the constraint builder.
        cost: Indexable by nurse index.
        penalty: Passed through to the objective builder and the solver.
        serve_patient_num: Passed through to the constraint builder.
        remaining_schedule: Plan for the horizon starting at ``cur_NN`` with a
            per-nurse stride of ``day_shift_num``.
        curr_cost: Cost already incurred before ``cur_NN``.
        curr_penalty: Penalty already incurred before ``cur_NN``.
        has_rested: Indexable by nurse index; copied, not modified.

    Returns:
        ``curr_cost`` plus all first-day payments plus the accumulated penalty.
        The penalty reported by the last stage's solve is not added.
    """
    # Work on a copy so the caller's rest counters are left untouched.
    has_rested_temp = np.zeros(nurse_num)
    for i in range(nurse_num):
        has_rested_temp[i] = has_rested[i]
    t_updated_x = remaining_schedule
    total_penalty = curr_penalty

    try:
        for t in range(cur_NN, total_day_num+1):
            t_real_h1 = real_h1
            t_pre_h1 = pre_h1

            if t > cur_NN:
                # Realise the first day of the previous stage's plan. The
                # globals still describe the previous (one day longer) horizon.
                nurse_payment, extra_nurse_payment = _first_day_payments(cost, t_updated_x, t_updated_sigma)
                curr_cost += nurse_payment + extra_nurse_payment
                total_penalty += t_incur_penalty

                for i in range(nurse_num):
                    if t_updated_x[i*day_shift_num+3] == 1:
                        has_rested_temp[i] += 1

            # Shrink the horizon globals to the days that are still open.
            change_day_num(total_day_num - t + 1)

            if t > cur_NN:
                t_real_h1 = real_h1[(t-cur_NN)*shift_num:]
                t_pre_h1 = pre_h1[(t-cur_NN)*shift_num:]
                # Drop the day that has already happened: the old per-nurse
                # stride is (day_num+1)*shift_num and its first day is skipped.
                remaining_schedule = np.zeros(nurse_num*day_shift_num)
                for i in range(nurse_num):
                    for j in range(day_shift_num):
                        remaining_schedule[i*day_shift_num+j] = t_updated_x[i*(day_num+1)*shift_num+shift_num+j]

            A_t, b_t, G_t, h_t = gen_constraints_latter_days(t, remaining_schedule, has_rested_temp, t_real_h1, t_pre_h1, serve_patient_num)
            c_t = gen_obj_latter_days(t, cost, penalty)
            t_updated_x, t_updated_sigma, t_incur_penalty = get_updated_plan_for_each_day(t, c_t, A_t, b_t, G_t, h_t, penalty)

        # Payment for the last remaining day (globals describe that horizon).
        nurse_payment, extra_nurse_payment = _first_day_payments(cost, t_updated_x, t_updated_sigma)
    finally:
        # Always hand the horizon globals back in the state for stage cur_NN.
        change_day_num(total_day_num - cur_NN + 1)

    final_cost = curr_cost + nurse_payment + extra_nurse_payment + total_penalty

    return final_cost

# t = cur_stage
# cur_day_num = total_day_num - cur_NN + 1
# real_h1, pre_h1: cur_day_num * shift_num
# remaining_schedule, x_prev_stage: nurse_num * cur_day_num * shift_num
def gen_constraints_latter_days_full_intOpt(cur_NN, cur_stage, remaining_schedule, has_rested, real_h1, pre_h1, serve_patient_num, x_prev_stage=None):
    """Build ``A x == b`` and ``G x <= h`` for stage ``cur_stage`` over the full horizon.

    All sizes come from the module-level globals at call time. Variables are
    laid out as ``x_num`` assignment entries (per-nurse stride
    ``day_shift_num``), then ``sigma_num`` entries, then the gamma entries.

    Equality rows, in order:
      * every nurse occupies exactly one slot per day;
      * the slots of the first ``cur_stage - cur_NN + 1`` days are fixed to
        ``remaining_schedule`` (when ``cur_stage == cur_NN``) or to
        ``x_prev_stage`` (otherwise).

    Inequality rows, in order:
      1. per-slot coverage, with right-hand side from ``real_h1`` for the
         fixed days and ``pre_h1`` for the rest;
      2. in-day slot ``shift_num - 2`` of a day plus slot 0 of the next day
         is at most 1 (night followed by morning, per the original comment);
      3. lower bound on the number of last-in-day slots per nurse, shifted by
         ``has_rested``;
      4. upper bound on the same count;
      5. assignment entries are at most 1;
      6. ``x - gamma <= remaining_schedule``;
      7. gamma entries are at most 1.

    Args:
        cur_NN: First day of the horizon being planned.
        cur_stage: Current stage, ``>= cur_NN``.
        remaining_schedule: Reference plan; its first ``nurse_num *
            day_shift_num`` entries are used.
        has_rested: Indexable by nurse index.
        real_h1, pre_h1: Indexable by slot.
        serve_patient_num: Indexable by nurse index.
        x_prev_stage: Previous stage's plan; required when
            ``cur_stage != cur_NN``.

    Returns:
        Tuple ``(A, b, G, h)`` of numpy arrays.
    """
    fixed_days = cur_stage - cur_NN + 1
    n_cells = nurse_num * day_num * shift_num
    n_plan = nurse_num * day_shift_num
    gamma_off = x_num + sigma_num

    # Exactly one slot per nurse per day.
    one_slot_lhs = np.zeros((nurse_num * day_num, t_decision_num))
    for nurse in range(nurse_num):
        for day in range(day_num):
            first = nurse * day_shift_num + shift_num * day
            one_slot_lhs[nurse * day_num + day, first:first + shift_num] = 1
    one_slot_rhs = np.ones(nurse_num * day_num)

    # Coverage per slot (the last in-day slot is included).
    G1 = np.zeros((day_shift_num, t_decision_num))
    for slot in range(day_shift_num):
        for nurse in range(nurse_num):
            G1[slot, nurse * day_shift_num + slot] = -serve_patient_num[nurse]
        G1[slot, n_cells + slot] = -extra_serve_patient_num
    h1 = np.zeros(day_shift_num)
    known_slots = fixed_days * shift_num
    for slot in range(known_slots):
        h1[slot] = real_h1[slot]
    for slot in range(known_slots, day_shift_num):
        h1[slot] = pre_h1[slot]

    # No night shift immediately followed by a morning shift.
    G2 = np.zeros((nurse_num * (day_num - 1), t_decision_num))
    for nurse in range(nurse_num):
        for day in range(day_num - 1):
            row = nurse * (day_num - 1) + day
            next_day_start = nurse * day_shift_num + shift_num * (day + 1)
            G2[row, next_day_start - 2] = 1
            G2[row, next_day_start] = 1
    h2 = np.ones(nurse_num * (day_num - 1))

    # Bounds on the number of last-in-day slots taken by each nurse.
    G3 = np.zeros((nurse_num, t_decision_num))
    G4 = np.zeros((nurse_num, t_decision_num))
    for nurse in range(nurse_num):
        for day in range(1, day_num + 1):
            rest_col = nurse * day_shift_num + shift_num * day - 1
            G3[nurse, rest_col] = -1
            G4[nurse, rest_col] = 1
    h3 = np.empty(nurse_num, dtype=int)
    h4 = np.empty(nurse_num, dtype=int)
    for nurse in range(nurse_num):
        h3[nurse] = -minimum_relax_day + has_rested[nurse]
        h4[nurse] = maximum_relax_day - has_rested[nurse]

    # Assignment entries are bounded above by 1.
    cell_idx = np.arange(n_cells)
    G5 = np.zeros((n_cells, t_decision_num))
    G5[cell_idx, cell_idx] = 1
    h5 = np.ones(n_cells)

    # x - gamma <= reference plan.
    plan_idx = np.arange(n_plan)
    G6 = np.zeros((n_plan, t_decision_num))
    G6[plan_idx, plan_idx] = 1
    G6[plan_idx, gamma_off + plan_idx] = -1
    h6 = remaining_schedule[:n_plan]

    # Gamma entries are bounded above by 1.
    G7 = np.zeros((n_plan, t_decision_num))
    G7[plan_idx, gamma_off + plan_idx] = 1
    h7 = np.ones(n_plan)

    # Fix the slots of the days that are already committed. At the first
    # stage exactly one day is fixed and it comes from the reference plan.
    committed = remaining_schedule if cur_stage == cur_NN else x_prev_stage
    n_fixed = nurse_num * fixed_days * shift_num
    A2 = np.zeros((n_fixed, t_decision_num))
    b2 = np.zeros(n_fixed)
    row = 0
    for nurse in range(nurse_num):
        for day in range(fixed_days):
            for shift in range(shift_num):
                col = nurse * day_shift_num + day * shift_num + shift
                A2[row, col] = 1
                b2[row] = committed[col]
                row += 1

    G = np.concatenate([G1, G2, G3, G4, G5, G6, G7], axis=0)
    h = np.concatenate([h1, h2, h3, h4, h5, h6, h7], axis=0)
    A = np.concatenate([one_slot_lhs, A2], axis=0)
    b = np.concatenate([one_slot_rhs, b2], axis=0)

    return A, b, G, h

# t = cur_stage
def gen_obj_latter_days_full_intOpt(cur_NN, cur_stage, cost, penalty):
    """Build the objective coefficient vector for the full-horizon stage problem.

    The vector is the concatenation of three segments sized from the
    module-level globals at call time:

    * ``nurse_num * day_shift_num`` assignment coefficients: ``cost[i]`` for
      nurse ``i`` on every slot whose in-day index is not 3, and 0 on in-day
      index 3;
    * ``day_shift_num`` coefficients, each equal to ``extra_payment``;
    * ``nurse_num * day_shift_num`` change coefficients, each equal to
      ``(total_day_num - (day + cur_NN) + cur_stage) * penalty[idx] *
      assignment_coefficient[idx]``.

    Args:
        cur_NN: First day of the horizon being planned.
        cur_stage: Current stage.
        cost: Indexable by nurse index.
        penalty: Indexable by the flat index
            ``nurse * day_shift_num + day * shift_num + shift``.

    Returns:
        1-D numpy array holding the three segments in the order above.
    """
    total_cells = nurse_num * day_shift_num

    # Assignment segment.
    assign_coef = np.zeros(total_cells)
    for nurse in range(nurse_num):
        start = nurse * day_shift_num
        for slot in range(day_shift_num):
            in_day = slot % shift_num
            if in_day == 3:
                assign_coef[start + slot] = 0
            else:
                assign_coef[start + slot] = cost[nurse]

    # Second segment: a constant price per slot.
    sigma_coef = np.zeros(day_shift_num)
    sigma_coef.fill(extra_payment)

    # Change segment, weighted by how far the day lies from the current stage.
    gamma_coef = np.zeros(total_cells)
    for nurse in range(nurse_num):
        for day in range(day_num):
            weight = total_day_num - (day + cur_NN) + cur_stage
            for shift in range(shift_num):
                idx = nurse * day_shift_num + day * shift_num + shift
                gamma_coef[idx] = weight * penalty[idx] * assign_coef[idx]

    return np.concatenate([assign_coef, sigma_coef, gamma_coef], axis=0)

# prev_sol: (curr) x_num + sigma_num
# has_rested: nurse_num (full version)
# real_h1, pre_h1: curr version
# cost, penalty: full version
def get_updated_plan_for_each_day_full_intOpt(t, c, A, b, G, h, penalty):
    """Solve one stage of the full-horizon integer program with Gurobi.

    Minimises ``c @ x`` subject to ``A @ x == b`` and ``G @ x <= h`` over
    ``t_decision_num`` integer variables (Gurobi's default lower bound of 0
    applies). The solution is split into three consecutive segments: the first
    ``x_num`` entries, the next ``sigma_num`` entries, and the remaining
    entries up to ``t_decision_num`` (the "gamma" segment).

    If no solution can be read back, an IIS is computed and written to
    ``model.ilp`` and all-zero segments are returned.

    Args:
        t: Stage index. Not used by this function.
        c: Objective coefficients (object with ``tolist()``).
        A, b: Equality constraint matrix and right-hand side.
        G, h: Inequality constraint matrix and right-hand side.
        penalty: Not used by this function.

    Returns:
        Tuple ``(t_updated_sol, t_sigma_sol, t_incur_penalty)`` where the first
        two are the rounded first and second solution segments and the third is
        the sum of the gamma-segment objective coefficients times the rounded
        gamma-segment values.
    """
    n_eq = A.shape[0]
    n_ineq = G.shape[0]

    # gurobipy's tupledict.prod is fed plain Python lists.
    c = c.tolist()
    A = A.tolist()
    b = b.tolist()
    G = G.tolist()
    h = h.tolist()

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(t_decision_num, vtype=GRB.INTEGER, name='x')

    m.setObjective(x.prod(c), GRB.MINIMIZE)

    for i in range(n_eq):
        m.addConstr(x.prod(A[i]) == b[i])
    for j in range(n_ineq):
        m.addConstr(x.prod(G[j]) <= h[j])

    m.optimize()

    gamma_start = x_num + sigma_num
    t_updated_sol = np.zeros(nurse_num*day_shift_num)
    t_sigma_sol = np.zeros(day_shift_num)
    t_gamma = np.zeros(nurse_num*day_shift_num)

    try:
        for i in range(x_num):
            t_updated_sol[i] = round(x[i].x)
        for i in range(x_num, gamma_start):
            t_sigma_sol[i-x_num] = round(x[i].x)
        for i in range(gamma_start, t_decision_num):
            t_gamma[i-gamma_start] = round(x[i].x)
    except (AttributeError, gp.GurobiError):
        # Reading `.x` fails when the solver produced no solution. Only that
        # failure is handled here, so Ctrl-C and genuine programming errors
        # are no longer swallowed as they were by the former bare `except:`.
        print("cannot solve")

        m.computeIIS()
        m.write('model.ilp')

    # Objective contribution of the gamma segment only.
    t_incur_penalty = np.sum(np.asarray(c[gamma_start:]) * t_gamma)

    return t_updated_sol, t_sigma_sol, t_incur_penalty


def check_grad_compute_used(cur_NN, pre_h1, real_h1, cost, penalty, serve_patient_num, remaining_schedule, curr_cost, curr_penalty, has_rested):
    """Re-solve the full-horizon problem stage by stage and return the final cost.

    For every stage from ``cur_NN`` to ``total_day_num`` the full-horizon
    constraints and objective are rebuilt and solved. The first stage fixes
    its committed day from ``remaining_schedule``; later stages fix the
    committed days from the previous stage's solution.

    Args:
        cur_NN: First stage to process.
        pre_h1, real_h1: Passed to the constraint builder.
        cost, penalty: Passed to the objective builder (``penalty`` also to
            the solver wrapper).
        serve_patient_num: Passed to the constraint builder.
        remaining_schedule: Reference plan for the horizon.
        curr_cost: Cost already incurred before ``cur_NN``.
        curr_penalty: Penalty already incurred before ``cur_NN``.
        has_rested: Passed to the constraint builder.

    Returns:
        ``curr_cost + curr_penalty`` plus the assignment cost of the last
        stage's plan plus ``extra_payment`` times the sum of its second
        solution segment.
    """
    plan = remaining_schedule
    # NOTE(review): this running total is never read when building the result;
    # the returned cost uses `curr_penalty`. The in-place `+=` is kept because
    # it updates `curr_penalty` too if that object is mutable -- confirm which
    # behaviour is intended.
    total_penalty = curr_penalty

    for stage in range(cur_NN, total_day_num+1):
        # Only stages after the first have a previous solution to commit to.
        prev_plan = None if stage == cur_NN else plan
        A_t, b_t, G_t, h_t = gen_constraints_latter_days_full_intOpt(
            cur_NN, stage, remaining_schedule, has_rested, real_h1, pre_h1,
            serve_patient_num, prev_plan)
        c_t = gen_obj_latter_days_full_intOpt(cur_NN, stage, cost, penalty)

        plan, sigma, stage_penalty = get_updated_plan_for_each_day_full_intOpt(
            stage, c_t, A_t, b_t, G_t, h_t, penalty)
        total_penalty += stage_penalty

        if stage == total_day_num:
            final_cost = curr_cost + curr_penalty + np.sum(c_t[:x_num]*plan) + extra_payment * np.sum(sigma)

    return final_cost
