import torch
import numpy as np
import matplotlib.pyplot as plt

def _to_numpy(x):
    """Convert a torch tensor or array-like to a numpy array.

    torch tensors are detached and moved to CPU first, so tensors that
    require grad or live on an accelerator can be plotted too.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _split_pos_bag_instances(values, y_true, T_true, bag_idx):
    """Select instances of positive bags and split their values by instance label.

    Args:
        values: (num_inst,) per-instance quantity to split
        y_true: (num_inst,) label of the instances
        T_true: (num_bags,) label of the bags
        bag_idx: (num_inst,) maps each instance to its bag (an index into T_true)

    Returns:
        (pos_values, neg_values): numpy arrays holding the values of the
        instances labelled 1 and 0 respectively, restricted to bags whose
        label is 1. Instances with any other label are dropped.
    """
    values = _to_numpy(values)
    y_true = _to_numpy(y_true)
    T_true = _to_numpy(T_true)
    bag_idx = _to_numpy(bag_idx)

    # Keep instances belonging to positive bags only
    pos_bags_idx = np.where(T_true == 1)[0]
    keep = np.isin(bag_idx, pos_bags_idx)

    values = values[keep]
    y_kept = y_true[keep]
    return values[y_kept == 1], values[y_kept == 0]


def _plot_pos_bag_histogram(values, y_true, T_true, bag_idx, xlabel, bins):
    """Overlay density histograms of negative vs positive instances of positive bags.

    Each group is binned over its own range with `bins` bins. A group with no
    instances is skipped: np.histogram(..., density=True) on an empty array
    divides by zero and would produce NaN bar heights.
    """
    pos_inst, neg_inst = _split_pos_bag_instances(values, y_true, T_true, bag_idx)

    fig, ax = plt.subplots()
    # Negative first, then positive, so colours follow the default cycle order.
    for inst, label in ((neg_inst, 'Negative instances'),
                        (pos_inst, 'Positive instances')):
        if inst.size == 0:
            continue
        counts, edges = np.histogram(inst, bins=bins, density=True)
        # Pre-binned data: one sample per left bin edge, weighted by its density.
        ax.hist(edges[:-1], edges, weights=counts, label=label,
                edgecolor='black', alpha=0.8)
    ax.set_xlabel(xlabel)
    # Bars are density-normalised (density=True), not raw counts.
    ax.set_ylabel('Density')
    ax.legend()

    return fig


def plot_att_score(s_pred, y_true, T_true, bag_idx, bins=20):
    """Plot the attention-score distribution of instances inside positive bags.

    Args:
        s_pred: (num_inst,) attention score of the instances
        y_true: (num_inst,) label of the instances
        T_true: (num_bags,) label of the bags
        bag_idx: (num_inst,) maps each instance to its bag
        bins: number of histogram bins per group (default 20)

    Returns:
        matplotlib Figure with overlaid density histograms for negative and
        positive instances. Groups with no instances are omitted.
    """
    print('Plotting attention score distribution')
    return _plot_pos_bag_histogram(s_pred, y_true, T_true, bag_idx,
                                   'Attention score', bins)


def plot_att_val(f_pred, y_true, T_true, bag_idx, bins=20):
    """Plot the attention-value distribution of instances inside positive bags.

    Args:
        f_pred: (num_inst,) attention value of the instances
        y_true: (num_inst,) label of the instances
        T_true: (num_bags,) label of the bags
        bag_idx: (num_inst,) maps each instance to its bag
        bins: number of histogram bins per group (default 20)

    Returns:
        matplotlib Figure with overlaid density histograms for negative and
        positive instances. Groups with no instances are omitted.
    """
    print('Plotting attention values distribution')
    return _plot_pos_bag_histogram(f_pred, y_true, T_true, bag_idx,
                                   'Attention value', bins)