import numpy as np
import math
from math import pi
import matplotlib.pyplot as plt
from Environment import Environment


def K_best(m, k=None):
    """Return the indices of the ``k`` largest entries of ``m`` and their combined reward.

    Parameters
    ----------
    m : array_like
        1-D sequence of per-arm scores (the script passes true means and UCB
        indices). Lists are accepted and converted with ``np.asarray``.
    k : int, optional
        Number of arms to select. Defaults to the module-level ``K``, which is
        read at call time so existing ``K_best(m)`` calls behave as before.

    Returns
    -------
    tuple of (numpy.ndarray, float)
        The selected indices ordered from largest to smallest score, and
        ``1 - prod(1 - m[idx])``: the probability that at least one selected
        arm fires if the entries are independent success probabilities.
        For ``k <= 0`` the index array is empty and the reward is 0.
    """
    if k is None:
        k = K
    m = np.asarray(m)
    if k <= 0:
        # np.argsort(m)[-0:] would return *every* index, so handle k <= 0 explicitly.
        idx = np.array([], dtype=np.intp)
    else:
        # Ascending argsort: take the last k entries, then reverse so the best comes first.
        idx = np.argsort(m)[-k:][::-1]
    reward = 1 - np.prod(1 - m[idx])
    return idx, reward


def alpha(arm):
    """Return the attacker's score for suppressing a click on ``arm``.

    The caller invokes this only for arms outside ``List``, right after a click
    was drawn, and suppresses the click when the returned value is > 0.

    For every arm ``j`` in ``List`` a lower estimate ``tem`` of j's click rate
    is formed. It is the empirical rate ``S0[j] / TiP[j]`` minus ``delta0``,
    minus a log-based confidence width, clipped at 0. The score for ``j`` is::

        S0[arm] / TiP[arm] * (TiP[arm] - 1) + 1 - Attack[arm] - tem * TiP[arm]

    If that score exceeds 1, it is recomputed with the threshold previously
    stored in ``Tstars[arm, j]``. Otherwise ``tem`` is stored there.

    Returns the largest score over ``j``, and never less than 0.

    Reads the module-level state ``S0``, ``TiP``, ``Attack``, ``Tstars``,
    ``List``, ``L``, ``delta`` and ``delta0``; writes only ``Tstars``.
    """
    ret = 0
    # The part of the score that depends only on ``arm`` is the same for every j.
    base = S0[arm] / TiP[arm] * (TiP[arm] - 1) + 1 - Attack[arm]
    for j in List:
        # A target arm that was never recommended still has TiP[j] == 1, which
        # would make (TiP[j] - 1) zero in the width term below. Substitute 2
        # locally rather than writing it back: TiP is the learner's own
        # position-weighted count, and mutating it would shrink that arm's UCB
        # radius merely because the attacker looked at it.
        n_j = 2 if TiP[j] == 1 else TiP[j]
        width = 2 * math.sqrt(
            math.log(pi * pi * L * (n_j - 1) * (n_j - 1) / (3 * delta)) / (2 * (n_j - 1))
        )
        tem = max(S0[j] / n_j - delta0 - width, 0)
        rtemp = base - tem * TiP[arm]
        if rtemp <= 1:
            Tstars[arm, j] = tem
        else:
            # Fresh threshold would demand too much; fall back to the stored one.
            rtemp = base - Tstars[arm, j] * TiP[arm]
        ret = max(ret, rtemp)
    return ret


# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
L, K, T = 100, 10, 100000  # arms (length of per-arm arrays), arms recommended per round, rounds per run
delta0, delta = 0.1, 1.1   # delta0: margin used in alpha(); delta: UCB exploration scale, also used in alpha()
Repeat = 20                # independent runs per strategy
M = 11                     # arm whose recommendation frequency is tracked and plotted
List = list(range(10, 20))  # arms whose clicks "Our attack" and Trivialₖ₋₁ never suppress

# Per-round, per-repeat records with shape (T, Repeat): cumulative attack cost
# (Cost*) and fraction of rounds so far in which arm M was recommended (Time*).
Cost1, CostT, CostK = np.zeros((T, Repeat)), np.zeros((T, Repeat)), np.zeros((T, Repeat))
Time1, TimeT, TimeK = np.zeros((T, Repeat)), np.zeros((T, Repeat)), np.zeros((T, Repeat))
# Running means over repeats for "Our attack": cumulative cost and Ntimes[M].
A = np.zeros(T)
Target = np.zeros(T)


# Per-slot probabilities, one per recommendation position (K entries). In the
# simulation loops a click also requires a uniform draw below Position[slot].
Position = np.array([0.6, 0.6, 0.6, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1])


# Load the environment and report the true means of the best K arms, followed
# by the full list of means.
# NOTE(review): the meaning of Environment's positional arguments (100, 4,
# False, True) is not visible here; 100 presumably matches L — confirm.
E = Environment(100, 4, False, True, 'ml_1000user_1000item.npy')
Mu = E.means
optK, optR = K_best(Mu)
print("Optimal Arms:\t", end='')
print(''.join(str(Mu[arm]) + " " for arm in optK))
for mean in Mu:
    print(mean)


# Save the run's inputs (means, slot probabilities, tracked arm, protected
# arms) alongside the figures. The context manager closes the file even if a
# write fails.
with open('mu_position.txt', 'w') as file:
    file.write('\n'.join([str(Mu), str(Position), str(M), str(List)]))


# ---------------------------------------------------------------------------
# Experiment 1: "Our attack" — a click on an arm outside List is suppressed
# only when alpha(item) > 0.
# ---------------------------------------------------------------------------
for counter in range(Repeat):
    print()
    print("Counter: " + str(counter))

    # Fresh learner/attacker state for this repeat. These stay module globals
    # on purpose: alpha() reads S0, TiP, Attack and Tstars directly.
    S = np.zeros(L)            # clicks seen by the learner (after the attack)
    S0 = np.zeros(L)           # clicks actually drawn (before the attack)
    Ti0 = np.ones(L)           # 1 + number of rounds each arm was recommended
    TiP = np.ones(L)           # 1 + position-weighted recommendation count
    Attack = np.zeros(L)       # suppressed clicks per arm
    Ntimes = np.zeros(L)       # rounds each arm was recommended
    Tstars = np.zeros((L, L))  # thresholds cached by alpha()

    for t in range(T):
        Radius = np.sqrt(delta * Ti0 * np.log(t + 1) / (2 * TiP * TiP))
        UCB = np.clip(S / TiP + Radius, a_min=0, a_max=1)
        Rec, _ = K_best(UCB)

        # Rec holds distinct indices, so a fancy-indexed increment is safe.
        Ntimes[Rec] += 1

        for clock, item in enumerate(Rec):
            Ti0[item] += 1
            TiP[item] += Position[clock]
            # Two uniform draws, in this order: one against Mu[item], one
            # against Position[clock]. A click needs both to succeed.
            temp = ((np.random.rand() < Mu[item]) * (np.random.rand() < Position[clock])).astype(int)
            S0[item] += temp
            if item not in List and temp == 1 and alpha(item) > 0:
                temp = 0
                Attack[item] += 1
            S[item] += temp

        # Sum once per round (Python's sum() over a NumPy array iterates
        # element by element and was called twice per round).
        attack_total = Attack.sum()
        # Incremental mean over the repeats finished so far.
        Target[t] = (Target[t] * counter + Ntimes[M]) / (counter + 1)
        A[t] = (A[t] * counter + attack_total) / (counter + 1)
        Cost1[t, counter] = attack_total
        Time1[t, counter] = Ntimes[M] / (t + 1)


# Mean fraction of rounds (up to each t) in which arm M was recommended.
Ratio = Target / np.arange(1, T + 1)


# ---------------------------------------------------------------------------
# Experiment 2: "Trivial₁" baseline — suppress every click except those on the
# tracked arm M.
# ---------------------------------------------------------------------------
AT = np.zeros(T)       # running mean over repeats of cumulative attack cost
TargetT = np.zeros(T)  # running mean over repeats of Ntimes[M]


for counter in range(Repeat):
    print()
    print("Counter: " + str(counter))

    # Fresh learner state for this repeat.
    S = np.zeros(L)        # clicks seen by the learner (after the attack)
    S0 = np.zeros(L)       # clicks actually drawn (before the attack)
    Ti0 = np.ones(L)       # 1 + number of rounds each arm was recommended
    TiP = np.ones(L)       # 1 + position-weighted recommendation count
    Attack = np.zeros(L)   # suppressed clicks per arm
    Ntimes = np.zeros(L)   # rounds each arm was recommended

    for t in range(T):
        Radius = np.sqrt(delta * Ti0 * np.log(t + 1) / (2 * TiP * TiP))
        UCB = np.clip(S / TiP + Radius, a_min=0, a_max=1)
        Rec, _ = K_best(UCB)

        # Rec holds distinct indices, so a fancy-indexed increment is safe.
        Ntimes[Rec] += 1

        for clock, item in enumerate(Rec):
            Ti0[item] += 1
            TiP[item] += Position[clock]
            # Two uniform draws, in this order: one against Mu[item], one
            # against Position[clock]. A click needs both to succeed.
            temp = ((np.random.rand() < Mu[item]) * (np.random.rand() < Position[clock])).astype(int)
            S0[item] += temp
            # Compare by value: `item` is a NumPy integer taken from an index
            # array, never the same object as the int M, so an identity test
            # (`is not`) is always True and would suppress M's clicks too.
            if temp == 1 and item != M:
                temp = 0
                Attack[item] += 1
            S[item] += temp

        attack_total = Attack.sum()
        # Incremental mean over the repeats finished so far.
        TargetT[t] = (TargetT[t] * counter + Ntimes[M]) / (counter + 1)
        AT[t] = (AT[t] * counter + attack_total) / (counter + 1)
        CostT[t, counter] = attack_total
        TimeT[t, counter] = Ntimes[M] / (t + 1)


# Mean fraction of rounds (up to each t) in which arm M was recommended.
RatioT = TargetT / np.arange(1, T + 1)


# ---------------------------------------------------------------------------
# Experiment 3: "Trivialₖ₋₁" baseline — suppress every click on an arm that
# is not in List.
# ---------------------------------------------------------------------------
AK = np.zeros(T)       # running mean over repeats of cumulative attack cost
TargetK = np.zeros(T)  # running mean over repeats of Ntimes[M]


for counter in range(Repeat):
    print()
    print("Counter: " + str(counter))

    # Fresh learner state for this repeat.
    S = np.zeros(L)        # clicks seen by the learner (after the attack)
    S0 = np.zeros(L)       # clicks actually drawn (before the attack)
    Ti0 = np.ones(L)       # 1 + number of rounds each arm was recommended
    TiP = np.ones(L)       # 1 + position-weighted recommendation count
    Attack = np.zeros(L)   # suppressed clicks per arm
    Ntimes = np.zeros(L)   # rounds each arm was recommended

    for t in range(T):
        Radius = np.sqrt(delta * Ti0 * np.log(t + 1) / (2 * TiP * TiP))
        UCB = np.clip(S / TiP + Radius, a_min=0, a_max=1)
        Rec, _ = K_best(UCB)

        # Rec holds distinct indices, so a fancy-indexed increment is safe.
        Ntimes[Rec] += 1

        for clock, item in enumerate(Rec):
            Ti0[item] += 1
            TiP[item] += Position[clock]
            # Two uniform draws, in this order: one against Mu[item], one
            # against Position[clock]. A click needs both to succeed.
            temp = ((np.random.rand() < Mu[item]) * (np.random.rand() < Position[clock])).astype(int)
            S0[item] += temp
            if temp == 1 and item not in List:
                temp = 0
                Attack[item] += 1
            S[item] += temp

        attack_total = Attack.sum()
        # Incremental mean over the repeats finished so far.
        TargetK[t] = (TargetK[t] * counter + Ntimes[M]) / (counter + 1)
        AK[t] = (AK[t] * counter + attack_total) / (counter + 1)
        CostK[t, counter] = attack_total
        TimeK[t, counter] = Ntimes[M] / (t + 1)


# Mean fraction of rounds (up to each t) in which arm M was recommended.
RatioK = TargetK / np.arange(1, T + 1)


# Per-round spread across the Repeat runs, drawn as +/- bands around the mean
# curves in the figures below. Despite the "Var" prefix these are standard
# deviations (np.std), taken over the repeat axis (axis=1) of (T, Repeat) arrays.
VarCost1, VarCostT, VarCostK = (np.std(cost, axis=1) for cost in (Cost1, CostT, CostK))
# The chosen-ratio bands must come from the per-repeat ratios Time*, whose mean
# over repeats is Ratio / RatioT / RatioK. The averaged curves A / AT / AK hold
# a single number per round, so the std of one entry is always 0.
VarTime1, VarTimeT, VarTimeK = (np.std(ratio, axis=1) for ratio in (Time1, TimeT, TimeK))


# Figure 1: mean fraction of rounds in which arm M was recommended, per
# strategy, each with a shaded +/- one-spread band.
plt.figure(figsize=(8, 6), dpi=600)
plt.grid(True)
x = np.linspace(1, T, T)
lowest = min(min(Ratio), min(RatioK), min(RatioT))
plt.ylim(lowest - 0.1, 1.1)

# (mean curve, spread, legend label, line colour, band colour)
ratio_series = (
    (Ratio, VarTime1, "Our attack", 'red', 'pink'),
    (RatioT, VarTimeT, "Trivial₁", 'blue', 'skyblue'),
    (RatioK, VarTimeK, "Trivialₖ₋₁", 'green', 'lightgreen'),
)
for mean_curve, spread, label_text, line_colour, band_colour in ratio_series:
    upper = mean_curve + spread
    lower = mean_curve - spread
    plt.plot(x, mean_curve, label=label_text, color=line_colour, lw=2.4)
    plt.plot(x, upper, color=band_colour, lw=0.8)
    plt.plot(x, lower, color=band_colour, lw=0.8)
    plt.fill_between(x, upper, lower, alpha=0.25, color=band_colour)

plt.xlabel("t", fontsize=28)
plt.ylabel("Chosen ratio", fontsize=28)
plt.legend(fontsize=24)
plt.tick_params(labelsize=24)
for axis_name in ("y", "x"):
    plt.ticklabel_format(axis=axis_name, style="sci", scilimits=(0, 0))
plt.savefig("Ratio_PBMnew1.png", dpi=600, bbox_inches='tight')
# plt.show()


# Figure 2: mean cumulative attack cost (suppressed clicks) per strategy, each
# with a shaded +/- one-spread band.
plt.figure(figsize=(8, 6), dpi=600)
plt.grid(True)
x = np.linspace(1, T, T)
# NOTE(review): the y-limit is fitted to A and AK only. AT (Trivial₁) is not
# included, so its curve may run off the top of the plot — confirm intended.
highest = max(max(A), max(AK))
plt.ylim(0, highest + 100)

# (mean curve, spread, legend label, line colour, band colour)
cost_series = (
    (A, VarCost1, "Our attack", 'red', 'pink'),
    (AT, VarCostT, "Trivial₁", 'blue', 'skyblue'),
    (AK, VarCostK, "Trivialₖ₋₁", 'green', 'lightgreen'),
)
for mean_curve, spread, label_text, line_colour, band_colour in cost_series:
    upper = mean_curve + spread
    lower = mean_curve - spread
    plt.plot(x, mean_curve, label=label_text, color=line_colour, lw=2.4)
    plt.plot(x, upper, color=band_colour, lw=0.8)
    plt.plot(x, lower, color=band_colour, lw=0.8)
    plt.fill_between(x, upper, lower, alpha=0.25, color=band_colour)

plt.xlabel("t", fontsize=28)
plt.ylabel("Cost", fontsize=28)
plt.legend(fontsize=24)
plt.tick_params(labelsize=24)
for axis_name in ("y", "x"):
    plt.ticklabel_format(axis=axis_name, style="sci", scilimits=(0, 0))
plt.savefig("Cost_PBMnew1.png", dpi=600, bbox_inches='tight')
# plt.show()
