import numpy as np
import torch
from skimage.transform import resize
import xarray as xr


def get_measurement(
    R,
    noise_std=0.04,
    ood=False,
    path="/XXXX-2/XXXX-1/scratch/uq_diffusion/era5/data/era5_2m_temperature_2009-2017_01.grib",
):
    """Build a noisy, normalized ERA5 2 m temperature measurement.

    The ``t2m`` variable is min-max normalized using the min/max of the whole
    array (all frames). Its last frame is resized to ``R x R`` and optionally
    corrupted with a small saturated disc (an out-of-distribution artifact).
    Gaussian noise is then added and the result is rescaled to ``[-1, 1]``.

    Args:
        R: Output side length in pixels.
        noise_std: Standard deviation of the additive Gaussian noise, applied
            in the ``[0, 1]``-normalized units.
        ood: If True, set a radius-4 disc centred at ``(x=100, y=20)`` to 1
            before adding noise.
        path: Dataset file to read. Defaults to the original hard-coded path.

    Returns:
        A ``torch.float32`` tensor with values in ``[-1, 1]``. Its shape is
        ``(R, R)`` assuming each frame of ``t2m`` is 2-D (TODO confirm for
        other input files).
    """
    # The context manager releases the file handle. ``.values`` materializes
    # a NumPy array, so the data stays valid after the dataset is closed.
    with xr.open_dataset(path) as ds:
        field = ds["t2m"].values

    # Normalize with the statistics of the whole array (all frames), not only
    # the single frame used below.
    field = (field - field.min()) / (field.max() - field.min())
    # Last entry along the leading axis (presumably time; verify).
    y = field[-1]
    # Resize to R x R.
    y = resize(y, (R, R), mode="reflect", anti_aliasing=True, preserve_range=True)

    if ood:
        # Saturate a small disc to the maximum normalized value.
        mask = create_circular_mask(R, R, radius=4, center=(100, 20))
        y[mask] = 1

    # Add noise. This uses NumPy's global RNG, so results depend on np.random.seed.
    y += np.random.randn(*y.shape) * noise_std
    # Rescale to [0, 1], then map to [-1, 1].
    y = (y - y.min()) / (y.max() - y.min())
    y = y * 2 - 1
    return torch.from_numpy(y).float()


def PSNR(x, y, data_range=1.0):
    """Return the peak signal-to-noise ratio between *x* and *y* in decibels.

    PSNR = 20 * log10(data_range / sqrt(MSE)).

    Args:
        x: First array.
        y: Second array. It must be broadcast-compatible with *x*.
        data_range: Peak-to-peak range of the signal. The default 1.0 matches
            the previous hard-coded behavior (data in ``[0, 1]``). Use 2.0 for
            data in ``[-1, 1]``.

    Returns:
        The PSNR in dB, or ``inf`` when the inputs are identical (MSE == 0).
    """
    mse = np.mean((x - y) ** 2)
    if mse == 0:
        # Identical inputs give infinite PSNR. Return it explicitly instead of
        # dividing by zero, which only emits a RuntimeWarning.
        return float("inf")
    return 20 * np.log10(data_range / np.sqrt(mse))


def create_circular_mask(h, w, center=None, radius=None):
    """Return a boolean ``(h, w)`` array that is True inside a disc.

    Args:
        h: Number of rows.
        w: Number of columns.
        center: ``(x, y)`` pair, i.e. (column, row). Defaults to the image
            middle, ``(int(w / 2), int(h / 2))``.
        radius: Disc radius in pixels. Defaults to the distance from the
            centre to the nearest image border.

    Returns:
        Boolean array; a pixel is True when its Euclidean distance to the
        centre is <= ``radius``.
    """
    if center is not None:
        cx, cy = center[0], center[1]
    else:
        cx, cy = int(w / 2), int(h / 2)

    if radius is None:
        radius = min(cx, cy, w - cx, h - cy)

    # Open grids broadcast to a full (h, w) distance map without
    # materializing both coordinate arrays.
    rows, cols = np.ogrid[:h, :w]
    distance = np.sqrt((cols - cx) ** 2 + (rows - cy) ** 2)
    return distance <= radius
