import numpy as np
import torch
import torch.nn.functional as F
import exp_utils as PQ
import matplotlib.pyplot as plt
import pickle


def softminus(x: torch.Tensor):
    """Return ``-softplus(-x)``, a smooth counterpart of ``min(x, 0)``.

    Equals ``-log(1 + exp(-x))``: close to ``x`` for very negative inputs and
    close to 0 for large positive inputs.
    """
    flipped = F.softplus(torch.neg(x))
    return torch.neg(flipped)


def swish(x):
    """Swish activation: ``x * sigmoid(x)`` (elementwise)."""
    gate = torch.sigmoid(x)
    return x * gate


def barrier(states: torch.Tensor):
    """Barrier penalty keeping ``states[..., 0]`` inside ``[-pi/2, pi/2]``.

    Only the first coordinate of the last axis is used; the result has the
    shape of ``states[..., 0]``.

    After normalizing the interval to ``(0, 1)``, the penalty is a log barrier
    that is 0 at the midpoint and rises to about ``-log(4 * eps)`` at the
    edges. Outside the interval it continues as a steep linear ramp (slope
    ``1/eps - 1`` in normalized units), so the value stays finite everywhere.
    """
    max_angle = +np.pi / 2
    min_angle = -np.pi / 2

    def interval_barrier(x, lb, rb):
        # Normalize so the admissible interval [lb, rb] maps to [0, 1].
        x = (x - lb) / (rb - lb)
        eps = 1e-6
        inside = (0 < x) & (x < 1)
        # torch.where evaluates both branches. Clamp the log argument so the
        # discarded branch never produces inf/NaN, whose gradient would poison
        # backward (e.g. NaN at x == -eps exactly). Values where `inside` holds
        # are unaffected by the clamp.
        x_in = x.clamp(0., 1.)
        b = -((x_in + eps) * (1 - x_in + eps) / (0.5 + eps) ** 2).log()
        b_max = -np.log(4 * eps)
        grad = 1. / eps - 1
        # Linear extension; equals b_max at x == 0 and x == 1 (≈ the log branch there).
        out = grad * torch.max(-x, x - 1)
        return torch.where(inside, b, b_max + out)

    return interval_barrier(states[..., 0], min_angle, max_angle)


def plot_pendulum_set(fns, device, clouds, filename, title, max_thresh=3, *,
                      x_min=-np.pi / 2, x_max=np.pi / 2, y_min=-2, y_max=2,
                      encode=lambda x: x, decode=lambda x: (x[:, 0], x[:, 1]), xlabel="angle", ylabel="angvel"):
    """Render each function in ``fns`` as a heatmap over a 2-D grid and save it.

    A 201 x 201 grid spanning ``[x_min, x_max] x [y_min, y_max]`` is evaluated
    by every callable in ``fns``. One subplot per function is drawn side by
    side, and the figure is written to ``filename``.

    Args:
        fns: mapping ``name -> callable``. Each callable receives a float32
            tensor on ``device`` of shape ``(201, 201, d)`` (the output of
            ``encode`` with its leading axis moved last) and must return a
            tensor. The result is presumably a ``(201, 201)`` map that
            ``imshow`` can draw — TODO confirm with callers. The names
            'hardD', 'softD', 'U', 'L', 'barrier' and 'logBarrier' have
            dedicated colormaps; other names fall back to BrBG. If an 'L'
            entry is present, its level set ``L == 1`` is overlaid on every
            panel and ``clouds`` are scattered on the 'L' panel.
        device: torch device the grid tensor is created on.
        clouds: mapping ``name -> points``; ``decode(points)`` must return the
            (x, y) coordinates to scatter. 'traj' and 's' have fixed colors;
            other names use matplotlib's color cycle.
        filename: output path passed to ``Figure.savefig``.
        title: prefix of each subplot title.
        max_thresh: upper bound on the symmetric color range used for keys
            without a fixed range.
        x_min, x_max, y_min, y_max: grid and axis extent.
        encode: maps the stacked ``(2, 201, 201)`` grid array to model input
            features (features on the first axis).
        decode: maps a cloud to ``(xs, ys)`` plot coordinates.
        xlabel, ylabel: axis labels.
    """
    grid_x = np.linspace(x_min, x_max, 201)
    grid_y = np.linspace(y_min, y_max, 201)
    X, Y = np.meshgrid(grid_x, grid_y)
    # Stack into (2, H, W), encode, then move the feature axis last -> (H, W, d).
    points = torch.tensor(encode(np.array([X, Y])), dtype=torch.float32, device=device).permute(1, 2, 0)
    values = {key: fn(points).detach().cpu().numpy() for key, fn in fns.items()}

    cmaps = {
        'hardD': plt.cm.RdBu,
        'softD': plt.cm.BrBG,
        'U': plt.cm.PRGn,
        'L': plt.cm.BrBG,
        'barrier': plt.cm.BrBG,
        'logBarrier': plt.cm.BrBG,
    }
    cloud_colors = {'traj': 'C2', 's': 'C3'}

    # squeeze=False always yields a 2-D array of Axes, even for a single panel.
    fig, axes = plt.subplots(nrows=1, ncols=len(values), figsize=(8 * len(values), 6), squeeze=False)
    try:
        for ax, (key, value) in zip(axes[0], values.items()):
            if key in ('L', 'U'):
                vmin, vmax = 0, 2
            elif key == 'logBarrier':
                vmin, vmax = -3, 3
            else:
                # Symmetric range centred on 0 for the diverging colormaps: the
                # smaller of the positive peak, the negative peak's magnitude and
                # max_thresh, floored at 1e-6 so the range is never empty.
                thresh = max(min(np.max(value), -np.min(value), max_thresh), 1e-6)
                vmin, vmax = -thresh, thresh

            im = ax.imshow(value, cmap=cmaps.get(key, plt.cm.BrBG), extent=[x_min, x_max, y_min, y_max],
                           aspect='auto', origin='lower', vmin=vmin, vmax=vmax)
            fig.colorbar(im, ax=ax)
            if 'L' in values:
                # Overlay the L == 1 level set on this panel (not pyplot's current axes).
                contours = ax.contour(X, Y, values['L'], levels=[1.], colors=['tab:orange'])
                ax.clabel(contours)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)

            if key == 'L':
                for name, cloud in clouds.items():
                    cloud_x, cloud_y = decode(cloud)
                    # Fade dense clouds: alpha is 1 up to ~272 points (100 * e),
                    # then 1 / log(n / 100). The inner max avoids log(0) for empty clouds.
                    alpha = 1 / max(np.log(max(len(cloud), 1) / 100.), 1.)
                    ax.plot(cloud_x, cloud_y, label=name, markersize=1, ls=' ', marker='o',
                            color=cloud_colors.get(name), alpha=alpha)
                if clouds:
                    ax.legend(loc=1)

            ax.set_title(title + f", fn = {key}")
            ax.grid()

        fig.tight_layout()
        fig.savefig(filename, dpi=150)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def find_max_barrier(states, model, policy, barrier, horizon):
    """Roll ``states`` forward ``horizon`` steps and track the worst barrier value.

    At each step the action is ``policy(states)`` and the next state is
    ``model(states, action)``. The running elementwise maximum of
    ``barrier(states)`` is taken over the initial states and every rolled-out
    state. The step index is printed every 10 steps as progress output.

    Returns:
        Elementwise maximum of the barrier values along the rollout (same
        shape as ``barrier(states)``).
    """
    worst = barrier(states)
    for step in range(horizon):
        if step % 10 == 0:
            print(step)
        states = model(states, policy(states))
        worst = worst.max(barrier(states))
    return worst

