import sys
from pathlib import Path

sys.path.append(str(Path(__file__).parents[1]))
print(sys.executable)
from functools import partial, wraps

import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum, nn

from arom.fourier_features import MultiScaleNeRFEncoding, NeRFEncoding


def count_parameters(model):
    """Return how many scalar weights of ``model`` are trainable.

    Only parameters whose ``requires_grad`` flag is set contribute to the count.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


class DiagonalGaussianDistribution(object):
    """Gaussian with diagonal covariance described by a mean and a log-variance.

    Args:
        mean: tensor of means.
        logvar: tensor of log-variances, same shape as ``mean``. It is clamped
            to ``[-30, 20]`` before being exponentiated.
        deterministic: when True the distribution collapses onto ``mean``
            (standard deviation and variance are set to zero).
    """

    def __init__(self, mean, logvar, deterministic=False):
        self.mean = mean
        # Clamp before exponentiating so std/var stay in a bounded range.
        self.logvar = torch.clamp(logvar, -30.0, 20.0)
        self.deterministic = deterministic
        if self.deterministic:
            # zeros_like already matches the device of ``mean``.
            self.var = self.std = torch.zeros_like(self.mean)
        else:
            self.std = torch.exp(0.5 * self.logvar)
            self.var = torch.exp(self.logvar)

    def sample(self, K=1):
        """Draw reparameterised samples ``mean + std * eps``.

        Args:
            K: number of samples. ``K == 1`` returns a tensor shaped like
                ``mean``; any other value returns ``(K, *mean.shape)``.

        Returns:
            The sampled tensor, located on the device of ``mean``.
        """
        device = self.mean.device
        if K == 1:
            noise = torch.randn(self.mean.shape, device=device)
            return self.mean + self.std * noise

        # Broadcasting against the leading K axis works for a mean of any rank,
        # unlike an explicit ``repeat(K, 1, 1, 1)`` which assumes a 3-D mean.
        noise = torch.randn((K, *self.mean.shape), device=device)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """KL divergence to ``other`` (or to a standard normal when omitted).

        The per-element divergence is averaged over dimensions 1 and 2, so the
        parameters must have at least three dimensions. In deterministic mode a
        single zero is returned instead.
        """
        if self.deterministic:
            # Keep the zero on the same device as the parameters so it can be
            # combined with other loss terms without a device mismatch.
            return torch.zeros(1, device=self.mean.device)

        if other is None:
            return 0.5 * torch.mean(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2]
            )

        return 0.5 * torch.mean(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 2],
        )

    def nll(self, sample, dims=(1, 2)):
        """Negative log-likelihood of ``sample``, summed over ``dims``.

        In deterministic mode a single zero is returned instead.
        """
        if self.deterministic:
            return torch.zeros(1, device=self.mean.device)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=tuple(dims),
        )

    def mode(self):
        """Return the most likely value, i.e. the mean."""
        return self.mean


# helpers


def exists(val):
    """Tell whether ``val`` carries a value, i.e. is anything other than ``None``."""
    return not (val is None)


def default(val, d):
    """Return ``val`` when it is set, otherwise fall back to ``d``.

    "Set" means ``exists(val)``; falsy values such as ``0`` are kept.
    """
    if exists(val):
        return val
    return d


def cache_fn(f):
    """Wrap ``f`` so that its first non-None result is remembered and reused.

    The wrapper accepts an extra keyword ``_cache`` (default True). With
    ``_cache=False`` the call goes straight to ``f`` and nothing is stored.
    The stored result is shared by all cached calls, whatever their arguments.
    """
    slot = {"value": None}

    @wraps(f)
    def cached_fn(*args, _cache=True, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        stored = slot["value"]
        if stored is None:
            # Nothing remembered yet (or f returned None last time): compute it.
            stored = slot["value"] = f(*args, **kwargs)
        return stored

    return cached_fn


# structured dropout, more effective than traditional attention dropouts


def dropout_seq(images, coordinates, mask=None, dropout=0.25):
    """Randomly keep the same subset of sequence positions in every input.

    For each batch row, ``max(1, int((1 - dropout) * n))`` positions are chosen
    by taking the top-k of random scores; ``n`` is the size of dimension 1 of
    ``images``. The kept positions come back in top-k order, not in their
    original order.

    Args:
        images: tensor whose first two dimensions are (batch, sequence).
        coordinates: tensor indexed with the same kept positions as ``images``.
        mask: optional tensor indexed the same way.
        dropout: fraction of positions to drop.

    Returns:
        ``(images, coordinates)`` when ``mask`` is None, otherwise
        ``(images, coordinates, mask)``.
    """
    batch_size, seq_len = images.shape[:2]
    device = images.device

    scores = torch.randn(batch_size, seq_len, device=device)
    num_keep = max(1, int((1.0 - dropout) * seq_len))
    keep_indices = scores.topk(num_keep, dim=1).indices

    # Column vector of row ids so that advanced indexing pairs each row with
    # its own kept positions.
    batch_indices = rearrange(torch.arange(batch_size, device=device), "b -> b 1")

    kept_images = images[batch_indices, keep_indices]
    kept_coordinates = coordinates[batch_indices, keep_indices]
    if mask is None:
        return kept_images, kept_coordinates

    kept_mask = mask[batch_indices, keep_indices]
    return kept_images, kept_coordinates, kept_mask

# helper classes


class PreNorm(nn.Module):
    """Apply LayerNorm to the input (and optionally the context) before ``fn``.

    Args:
        dim: feature size of the main input.
        fn: module called with the normalised input and the keyword arguments.
        context_dim: when given, the ``context`` keyword argument is normalised
            with its own LayerNorm of this size before being passed to ``fn``.
    """

    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        if exists(context_dim):
            self.norm_context = nn.LayerNorm(context_dim)
        else:
            self.norm_context = None

    def forward(self, x, **kwargs):
        normed_x = self.norm(x)

        if exists(self.norm_context):
            # A context norm was configured, so ``context`` must be supplied.
            kwargs["context"] = self.norm_context(kwargs["context"])

        return self.fn(normed_x, **kwargs)


class PreNormCross(nn.Module):
    """Apply LayerNorm to the query, keys and values before calling ``fn``.

    Args:
        dim: feature size of the query input ``x``.
        fn: module called as ``fn(x, **kwargs)`` with the normalised tensors.
        k_dim: feature size of the ``k`` keyword argument; when given, ``k`` is
            normalised with its own LayerNorm.
        v_dim: feature size of the ``v`` keyword argument; when given, ``v`` is
            normalised with its own LayerNorm.
    """

    def __init__(self, dim, fn, k_dim=None, v_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_k = nn.LayerNorm(k_dim) if k_dim is not None else None
        self.norm_v = nn.LayerNorm(v_dim) if v_dim is not None else None

    def forward(self, x, **kwargs):
        x = self.norm(x)

        # Each norm is applied on its own. Gating both on ``norm_v`` alone would
        # call a missing ``norm_k`` when only ``v_dim`` was configured and would
        # skip the key norm when only ``k_dim`` was configured.
        if self.norm_k is not None:
            kwargs["k"] = self.norm_k(kwargs["k"])
        if self.norm_v is not None:
            kwargs["v"] = self.norm_v(kwargs["v"])

        return self.fn(x, **kwargs)


class GEGLU(nn.Module):
    """GELU-gated linear unit: halves the last dimension of its input."""

    def forward(self, x):
        # First half carries the values, second half the gates.
        values, gates = torch.chunk(x, 2, dim=-1)
        activated_gates = F.gelu(gates)
        return values * activated_gates


class FeedForward(nn.Module):
    """Two-layer MLP: expand to ``dim * mult``, GELU, project back to ``dim``.

    No dropout layer is part of the stack.
    """

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden_features = dim * mult
        # Layers are created in the order they are applied.
        expand = nn.Linear(dim, hidden_features)
        activation = nn.GELU()
        project = nn.Linear(hidden_features, dim)
        self.net = nn.Sequential(expand, activation, project)

    def forward(self, x):
        out = self.net(x)
        return out


class CrossAttention(nn.Module):
    """Multi-head attention with separately supplied queries, keys and values.

    Args:
        query_dim: feature size of the query input (also the output size).
        key_dim: feature size of the key input.
        value_dim: feature size of the value input.
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: dropout probability on the attention weights and the output.
    """

    def __init__(self, query_dim, key_dim, value_dim, heads=8, dim_head=64, dropout=0):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head

        # Projections are created in a fixed order (q, k, v, out).
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(key_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(value_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

        self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.resid_drop = nn.Dropout(dropout)

    def forward(self, x, k, v, mask=None, pos=None):
        """Attend from ``x`` to the (``k``, ``v``) pairs.

        Args:
            x: queries, ``(batch, n_queries, query_dim)``.
            k: keys, ``(batch, n_keys, key_dim)``.
            v: values, ``(batch, n_keys, value_dim)``.
            mask: optional per-key mask with ``batch`` as first dimension; it is
                cast to bool and key positions where it is False are masked out.
            pos: accepted but not used.

        Returns:
            Tensor of shape ``(batch, n_queries, query_dim)``.
        """
        num_heads = self.heads

        def split_heads(t):
            # Fold the head axis into the batch axis.
            return rearrange(t, "b n (h d) -> (b h) n d", h=num_heads)

        q = split_heads(self.to_q(x))
        k = split_heads(self.to_k(k))
        v = split_heads(self.to_v(v))

        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            flat_mask = rearrange(mask.bool(), "b ... -> b (...)")
            fill_value = -torch.finfo(sim.dtype).max
            # One copy of the key mask per head and per query position.
            keep = repeat(
                flat_mask, "b j -> (b h) (n) j", h=num_heads, n=x.shape[1]
            )
            sim.masked_fill_(~keep, fill_value)

        weights = self.attn_drop(sim.softmax(dim=-1))

        attended = einsum("b i j, b j d -> b i d", weights, v)
        merged = rearrange(attended, "(b h) n d -> b n (h d)", h=num_heads)
        return self.resid_drop(self.to_out(merged))


class MultiScaleAttention(nn.Module):
    """Multi-head attention whose queries carry an extra "scale" axis.

    Queries are laid out as ``(batch, n, scales, query_dim)``; every scale
    attends to the same context and the result keeps the
    ``(batch, n, scales, out_dim)`` layout.

    Args:
        query_dim: feature size of the queries.
        context_dim: feature size of the context (defaults to ``query_dim``).
        out_dim: feature size of the output (defaults to ``query_dim``).
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: dropout probability on the attention weights and the output.
    """

    def __init__(
        self, query_dim, context_dim=None, out_dim=None, heads=8, dim_head=64, dropout=0
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        out_dim = default(out_dim, query_dim)
        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, out_dim)

        self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.resid_drop = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos=None):
        """Attend from multi-scale queries to the context.

        Args:
            x: queries, ``(batch, n, scales, query_dim)``.
            context: ``(batch, n_context, context_dim)``. It falls back to ``x``
                when omitted, but the key/value rearrangement below expects a
                3-D tensor, so a 3-D context should be supplied.
            mask: optional per-context-position mask with ``batch`` as first
                dimension; positions where it is False are masked out.
            pos: accepted but not used.

        Returns:
            Tensor of shape ``(batch, n, scales, out_dim)``.
        """
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (k, v))
        q = rearrange(q, "b n s (h d) -> (b h) s n d", h=h)

        sim = einsum("b s i d, b j d -> b s i j", q, k) * self.scale

        if exists(mask):
            # Cast so that ``~mask`` is a logical (not bitwise) negation.
            mask = rearrange(mask.bool(), "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            # ``sim`` is (b*h, s, i, j): the mask needs two singleton axes so it
            # broadcasts over both the scale and the query axes.
            mask = repeat(mask, "b j -> (b h) () () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = einsum("b s i j, b j d -> b s i d", attn, v)
        out = rearrange(out, "(b h) s n d -> b n s (h d)", h=h)
        return self.resid_drop(self.to_out(out))


class Attention(nn.Module):
    """Multi-head attention; self-attention when no context is supplied.

    Args:
        query_dim: feature size of the queries.
        context_dim: feature size of the context (defaults to ``query_dim``).
        out_dim: feature size of the output (defaults to ``query_dim``).
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: dropout probability on the attention weights and the output.
    """

    def __init__(
        self, query_dim, context_dim=None, out_dim=None, heads=8, dim_head=64, dropout=0
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head
        context_dim = default(context_dim, query_dim)
        out_dim = default(out_dim, query_dim)

        # Projections are created in a fixed order (q, kv, out).
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, out_dim)

        self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.resid_drop = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself).

        Args:
            x: queries, ``(batch, n, query_dim)``.
            context: optional ``(batch, m, context_dim)`` tensor.
            mask: optional boolean per-key mask with ``batch`` as first
                dimension; key positions where it is False are masked out.
            pos: accepted but not used.

        Returns:
            Tensor of shape ``(batch, n, out_dim)``.
        """
        num_heads = self.heads

        def split_heads(t):
            # Fold the head axis into the batch axis.
            return rearrange(t, "b n (h d) -> (b h) n d", h=num_heads)

        queries = self.to_q(x)
        source = default(context, x)
        keys, values = self.to_kv(source).chunk(2, dim=-1)

        queries = split_heads(queries)
        keys = split_heads(keys)
        values = split_heads(values)

        sim = einsum("b i d, b j d -> b i j", queries, keys) * self.scale

        if exists(mask):
            flat_mask = rearrange(mask, "b ... -> b (...)")
            fill_value = -torch.finfo(sim.dtype).max
            # One copy per head; the singleton axis broadcasts over queries.
            keep = repeat(flat_mask, "b j -> (b h) () j", h=num_heads)
            sim.masked_fill_(~keep, fill_value)

        weights = self.attn_drop(sim.softmax(dim=-1))

        attended = einsum("b i j, b j d -> b i d", weights, values)
        merged = rearrange(attended, "(b h) n d -> b n (h d)", h=num_heads)
        return self.resid_drop(self.to_out(merged))


# main class


class PerceiverIO(nn.Module):
    """Perceiver-IO style variational encoder with a multi-scale query decoder.

    Input values and their coordinates are cross-attended into a fixed set of
    learned latents. The latents then go through ``depth`` self-attention
    blocks; right before block ``bottleneck_index`` they are compressed into a
    diagonal Gaussian (``mean_fc`` / ``logvar_fc``), sampled, and lifted back
    with ``lift_z``. Finally multi-scale positional queries cross-attend to the
    latents to produce the output.

    Args:
        depth: number of latent self-attention blocks.
        logits_dim: size of an optional final linear layer; identity if None.
        num_channels: feature size of the input values.
        num_latents: number of learned latent tokens.
        hidden_dim: feature size of the latent tokens.
        latent_dim: feature size of the Gaussian bottleneck.
        cross_heads: heads used by the cross-attention modules.
        latent_heads: heads used by the latent self-attention modules.
        cross_dim_head: per-head size in the cross-attention modules.
        latent_dim_head: per-head size in the latent self-attention modules.
        weight_tie_layers: share one attention/feed-forward pair across blocks.
        decoder_ff: add a feed-forward block after the decoder cross-attention.
        seq_dropout_prob: fraction of input positions dropped during training.
        use_query_residual: add the queries to the decoder output.
        input_dim: dimensionality of the coordinates.
        max_freq: ``max_freq_log2`` of the key positional encoding.
        num_freq: number of frequencies of both positional encodings.
        embed_dim: output size of the decoder cross-attention.
        scales: scales of the query positional encoding (default ``[3, 4, 5]``).
        bottleneck_index: index of the block preceded by the bottleneck; it must
            satisfy ``0 <= bottleneck_index < depth`` for ``forward`` and
            ``get_features`` to work.
        use_norm: apply LayerNorm + SiLU before the bottleneck projections.
        encode_geo: add a first cross-attention that uses the positional keys
            as both keys and values.
    """

    def __init__(
        self,
        *,
        depth,
        logits_dim=None,
        num_channels=1,
        num_latents=512,
        hidden_dim=64,
        latent_dim=16,
        cross_heads=8,
        latent_heads=8,
        cross_dim_head=64,
        latent_dim_head=64,
        weight_tie_layers=False,
        decoder_ff=False,
        seq_dropout_prob=0.0,
        use_query_residual=False,
        input_dim=2,
        max_freq=4,
        num_freq=12,
        embed_dim=16,
        scales=None,
        bottleneck_index=0,
        use_norm=False,
        encode_geo=False,
    ):
        super().__init__()
        # Resolve the default here rather than in the signature so that no
        # mutable list is shared between instances.
        scales = [3, 4, 5] if scales is None else scales

        self.depth = depth
        self.bottleneck_index = bottleneck_index
        self.encode_geo = encode_geo

        self.seq_dropout_prob = seq_dropout_prob
        self.use_query_residual = use_query_residual

        # NOTE: sub-modules are created in a fixed order on purpose; the order
        # determines both the state_dict layout and the seeded initialisation.
        self.pos_encoding = NeRFEncoding(
            num_freq=num_freq,
            max_freq_log2=max_freq,
            input_dim=input_dim,
            base_freq=2,
            log_sampling=True,
            include_input=False,
            use_pi=True,
        )

        self.pos_query = MultiScaleNeRFEncoding(
            num_freq,
            log_sampling=True,
            include_input=False,
            input_dim=input_dim,
            base_freq=2,
            scales=scales,
            use_pi=True,
            disjoint=True,
        )

        self.latents = nn.Parameter(torch.randn(num_latents, hidden_dim))
        value_dim = hidden_dim
        key_dim = self.pos_encoding.out_dim
        queries_dim = self.pos_query.out_dim_per_scale

        self.lift_values = nn.Linear(num_channels, hidden_dim)

        self.cross_attend_blocks = nn.ModuleList(
            [
                PreNormCross(
                    hidden_dim,
                    CrossAttention(
                        hidden_dim,
                        key_dim,
                        value_dim,
                        heads=cross_heads,
                        dim_head=cross_dim_head,
                    ),
                    k_dim=key_dim,
                    v_dim=value_dim,
                ),
                PreNorm(hidden_dim, FeedForward(hidden_dim)),
            ]
        )

        if self.encode_geo:
            # Geometry-only cross-attention: positional keys act as values too.
            self.cross_attend_geo = nn.ModuleList(
                [
                    PreNormCross(
                        hidden_dim,
                        CrossAttention(
                            hidden_dim,
                            key_dim,
                            key_dim,
                            heads=cross_heads,
                            dim_head=cross_dim_head,
                        ),
                        k_dim=key_dim,
                        v_dim=key_dim,
                    ),
                    PreNorm(hidden_dim, FeedForward(hidden_dim)),
                ]
            )

        get_latent_attn = lambda: PreNorm(
            hidden_dim,
            Attention(hidden_dim, heads=latent_heads, dim_head=latent_dim_head),
        )
        get_latent_ff = lambda: PreNorm(hidden_dim, FeedForward(hidden_dim))
        get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))

        self.layers = nn.ModuleList([])
        # With weight tying every block reuses the first modules built.
        cache_args = {"_cache": weight_tie_layers}

        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [get_latent_attn(**cache_args), get_latent_ff(**cache_args)]
                )
            )

        self.decoder_cross_attn = PreNorm(
            queries_dim,
            MultiScaleAttention(
                queries_dim,
                hidden_dim,
                embed_dim,
                heads=cross_heads,
                dim_head=cross_dim_head,
            ),
            context_dim=hidden_dim,
        )
        self.decoder_ff = (
            PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None
        )

        if use_norm:
            self.norm = nn.LayerNorm(hidden_dim)
            self.act = nn.SiLU()
        self.use_norm = use_norm

        self.mean_fc = nn.Linear(hidden_dim, latent_dim)
        self.logvar_fc = nn.Linear(hidden_dim, latent_dim)
        self.lift_z = nn.Linear(latent_dim, hidden_dim)

        self.to_logits = (
            nn.Linear(embed_dim, logits_dim) if exists(logits_dim) else nn.Identity()
        )

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------

    def _check_bottleneck(self):
        """Raise if no self-attention block is preceded by the bottleneck."""
        if not 0 <= self.bottleneck_index < len(self.layers):
            raise ValueError(
                f"bottleneck_index={self.bottleneck_index} is outside the range "
                f"of the {len(self.layers)} self-attention block(s); the "
                "bottleneck would never be applied"
            )

    def _cross_encode(self, images, coords, mask=None):
        """Cross-attend the inputs into the learned latents."""
        x = repeat(self.latents, "n d -> b n d", b=images.shape[0])

        k = self.pos_encoding(coords)
        v = self.lift_values(images)

        if self.encode_geo:
            cross_attn, cross_ff = self.cross_attend_geo
            x = cross_attn(x, k=k, v=k, mask=mask) + x
            x = cross_ff(x) + x

        # cross attention only happens once for Perceiver IO
        cross_attn, cross_ff = self.cross_attend_blocks
        x = cross_attn(x, k=k, v=v, mask=mask) + x
        return cross_ff(x) + x

    def _posterior_stats(self, x):
        """Project latents to the mean and log-variance of the bottleneck."""
        if self.use_norm:
            x = self.act(self.norm(x))
        return self.mean_fc(x), self.logvar_fc(x)

    def _decode(self, queries, x):
        """Cross-attend ``queries`` to latents ``x`` and apply the output head."""
        latents = self.decoder_cross_attn(queries, context=x)

        if self.use_query_residual:
            latents = latents + queries

        # optional decoder feedforward
        if exists(self.decoder_ff):
            latents = latents + self.decoder_ff(latents)

        return self.to_logits(latents)

    def _decode_codes(self, z, queries):
        """Lift ``z``, run every self-attention block, then decode ``queries``.

        Returns the latents alone when ``queries`` is None, otherwise
        ``(output, z)``.
        """
        x = self.lift_z(z)

        for self_attn, self_ff in self.layers:
            x = self_attn(x) + x
            x = self_ff(x) + x

        if not exists(queries):
            return x

        # make sure queries contains batch dimension
        if queries.ndim == 2:
            # NOTE(review): the decoder cross-attention rearranges its queries
            # as "b n s (h d)", i.e. 4-D, so a 2-D input may still be rejected
            # downstream; this only makes the batch size well-defined.
            queries = repeat(queries, "n d -> b n d", b=x.shape[0])

        return self._decode(queries, x), z

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------

    def forward(
        self,
        images,
        coords,
        mask=None,
        target_coords=None,
        sample_posterior=True,
        return_stats=False,
    ):
        """Encode the inputs, pass through the bottleneck and decode.

        Args:
            images: input values, with ``num_channels`` as last dimension.
            coords: coordinates of the input values.
            mask: optional mask over the input positions.
            target_coords: coordinates to decode at; defaults to ``coords``.
            sample_posterior: sample the bottleneck (True) or use its mean.
            return_stats: also return the bottleneck mean and log-variance.

        Returns:
            ``(output, kl_loss)`` or ``(output, kl_loss, mu, logvar)``.

        Raises:
            ValueError: if ``bottleneck_index`` is not in ``[0, depth)``.
        """
        self._check_bottleneck()

        # Queries are built before any sequence dropout, so without
        # target_coords the decoder is still queried at every input coordinate.
        queries = self.pos_query(coords if target_coords is None else target_coords)

        # structured dropout (as done in perceiver AR https://arxiv.org/abs/2202.07765)
        if self.training and self.seq_dropout_prob > 0.0:
            # Values, coordinates and mask must lose the *same* positions, so
            # they are subsampled together in a single call.
            if mask is None:
                images, coords = dropout_seq(
                    images, coords, dropout=self.seq_dropout_prob
                )
            else:
                images, coords, mask = dropout_seq(
                    images, coords, mask, dropout=self.seq_dropout_prob
                )

        x = self._cross_encode(images, coords, mask)

        for index, (self_attn, self_ff) in enumerate(self.layers):
            if index == self.bottleneck_index:
                mu, logvar = self._posterior_stats(x)
                posterior = DiagonalGaussianDistribution(mu, logvar)
                z = posterior.sample() if sample_posterior else posterior.mode()
                x = self.lift_z(z)

            x = self_attn(x) + x
            x = self_ff(x) + x

        # cross attend from decoder queries to latents
        output = self._decode(queries, x)

        # KL summed over the batch, then divided by the batch size.
        kl_loss = posterior.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        if return_stats:
            return output, kl_loss, mu, logvar

        return output, kl_loss

    def get_features(
        self,
        images,
        coords,
        mask=None
    ):
        """Return the bottleneck ``(mu, logvar)`` for the given inputs.

        Raises:
            ValueError: if ``bottleneck_index`` is not in ``[0, depth)``.
        """
        self._check_bottleneck()

        # The coordinates are cloned before being encoded, as in the original
        # implementation of this method.
        x = self._cross_encode(images, coords.clone(), mask)

        for index, (self_attn, self_ff) in enumerate(self.layers):
            if index == self.bottleneck_index:
                return self._posterior_stats(x)

            x = self_attn(x) + x
            x = self_ff(x) + x

    def process(self, features, coords):
        """Decode latent codes ``features`` at ``coords``.

        The codes are lifted with ``lift_z`` at the bottleneck block and only
        the blocks from ``bottleneck_index`` onwards are applied.
        """
        queries = self.pos_query(coords)
        x = features

        for index, (self_attn, self_ff) in enumerate(self.layers):
            if index < self.bottleneck_index:
                # Blocks before the bottleneck belong to the encoder side.
                continue
            if index == self.bottleneck_index:
                x = self.lift_z(features)
            x = self_attn(x) + x
            x = self_ff(x) + x

        # cross attend from decoder queries to latents
        return self._decode(queries, x)

    def process_from_stats(self, mean, logvar, queries):
        """Sample codes from ``(mean, logvar)`` and decode already-encoded queries.

        Every self-attention block is applied after lifting, regardless of
        ``bottleneck_index``.

        Returns:
            The latents alone when ``queries`` is None, otherwise
            ``(output, z)`` with ``z`` the sampled codes.
        """
        posterior = DiagonalGaussianDistribution(mean, logvar)
        z = posterior.sample()
        return self._decode_codes(z, queries)

    def process_from_codes(self, codes, queries):
        """Decode already-encoded queries from given latent codes.

        Every self-attention block is applied after lifting, regardless of
        ``bottleneck_index``.

        Returns:
            The latents alone when ``queries`` is None, otherwise
            ``(output, codes)``.
        """
        return self._decode_codes(codes, queries)

    def encode(self, data, mask=None):
        """Legacy entry point; not usable with this architecture.

        The previous body referenced ``self.pre_layers``, which this class never
        defines, and invoked the cross-attention with ``context=``/``pos=``
        although it requires ``k=`` and ``v=``, so every call failed. Without
        coordinates the keys cannot be built, hence an explicit error is raised.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "PerceiverIO.encode(data, mask) is not supported; use "
            "get_features(images, coords, mask) to obtain (mu, logvar)."
        )


# decoder and positional-embedding modules


class LocalityAwareINRDecoder(nn.Module):
    """MLP mapping concatenated per-scale latents to output features.

    The network always has three hidden layers of width ``dim`` with ReLU
    activations; ``depth`` is stored but does not change the architecture.

    Args:
        output_dim: feature size of the output.
        embed_dim: feature size of each per-scale latent.
        num_scales: number of scales concatenated on the input.
        dim: hidden width of the MLP.
        depth: stored on the instance only.
    """

    def __init__(
        self, output_dim=1, embed_dim=16, num_scales=3, dim=128, depth=3
    ):  # dim=32
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.modulation_layers = []

        # Three Linear+ReLU stages followed by the output projection.
        stages = []
        in_features = embed_dim * num_scales
        for _ in range(3):
            stages.append(nn.Linear(in_features, dim))
            stages.append(nn.ReLU())
            in_features = dim
        stages.append(nn.Linear(dim, output_dim))
        self.mlp = nn.Sequential(*stages)

    def forward(self, localized_latents):
        # Merge the scale axis into the feature axis before the MLP.
        flattened = einops.rearrange(localized_latents, "b n k c -> b n (k c)")
        return self.mlp(flattened)


class FourierPositionalEmbedding(nn.Module):
    """NeRF-style Fourier encoding of coordinates followed by a linear layer.

    Args:
        hidden_dim: output feature size.
        num_freq: number of frequencies of the encoding.
        max_freq_log2: ``max_freq_log2`` passed to the encoding.
        input_dim: dimensionality of the coordinates.
        base_freq: ``base_freq`` passed to the encoding.
        use_relu: apply a ReLU after the linear layer.
    """

    def __init__(
        self,
        hidden_dim=128,
        num_freq=32,
        max_freq_log2=5,
        input_dim=2,
        base_freq=2,
        use_relu=True,
    ):
        super().__init__()

        encoder_kwargs = dict(
            num_freq=num_freq,
            max_freq_log2=max_freq_log2,
            input_dim=input_dim,
            base_freq=base_freq,
            log_sampling=False,
            include_input=False,
        )
        self.nerf_embedder = NeRFEncoding(**encoder_kwargs)

        self.linear = nn.Linear(self.nerf_embedder.out_dim, hidden_dim)
        self.use_relu = use_relu

    def forward(self, coords):
        # NOTE: coords are expected to lie in [-1, 1]; nothing here enforces it.
        projected = self.linear(self.nerf_embedder(coords))
        if self.use_relu:
            return torch.relu(projected)
        return projected


class AROMAEncoderDecoderKL(nn.Module):
    """Variational encoder/decoder: a ``PerceiverIO`` encoder and an MLP decoder.

    The encoder turns (values, coordinates) into per-query multi-scale latents
    and a KL term; ``LocalityAwareINRDecoder`` maps those latents to output
    features with ``num_channels`` channels.

    Args:
        input_dim: dimensionality of the coordinates.
        num_channels: feature size of the input values and of the output.
        hidden_dim: feature size of the encoder's latent tokens.
        dim: hidden width of the decoder MLP.
        num_self_attentions: number of latent self-attention blocks.
        num_latents: number of latent tokens.
        latent_dim: feature size of the Gaussian bottleneck.
        latent_heads: heads of the latent self-attention.
        latent_dim_head: per-head size of the latent self-attention.
        cross_heads: heads of the cross-attention modules.
        cross_dim_head: per-head size of the cross-attention modules.
        scales: scales of the query positional encoding (default ``[3, 4, 5]``).
        dropout_seq: fraction of input positions dropped during training.
        embed_dim: feature size of each per-scale latent fed to the decoder.
        depth_inr: forwarded to the decoder as ``depth``.
        bottleneck_index: index of the self-attention block preceded by the
            bottleneck. NOTE: the encoder only applies the bottleneck when
            ``bottleneck_index < num_self_attentions``; the defaults (1 and 1)
            do not satisfy this, so at least one of them must be overridden.
        encode_geo: enable the encoder's geometry-only cross-attention.
        max_encoding_freq: ``max_freq`` of the encoder's key encoding.
        num_freq: number of frequencies of the positional encodings.
    """

    def __init__(
        self,
        input_dim=2,
        num_channels=1,
        hidden_dim=64,  # 256
        dim=256,
        num_self_attentions=1,
        num_latents=16,
        latent_dim=8,  # latent_dim=8
        latent_heads=12,
        latent_dim_head=64,
        cross_heads=8,
        cross_dim_head=64,
        scales=None,
        dropout_seq=0,
        embed_dim=16,
        depth_inr=3,
        bottleneck_index=1,
        encode_geo=False,
        max_encoding_freq=4,
        num_freq=12
    ):
        super().__init__()
        # Resolve the default here rather than in the signature so that no
        # mutable list is shared between instances.
        scales = [3, 4, 5] if scales is None else scales

        self.encoder = PerceiverIO(
            logits_dim=None,  # no final projection; the MLP decoder follows
            num_channels=num_channels,  # feature size of the input values
            depth=num_self_attentions,  # number of latent self-attention blocks
            num_latents=num_latents,  # number of latent tokens
            hidden_dim=hidden_dim,  # feature size of the latent tokens
            latent_dim=latent_dim,  # feature size of the bottleneck
            cross_heads=cross_heads,  # heads for cross attention
            latent_heads=latent_heads,  # heads for latent self attention
            cross_dim_head=cross_dim_head,  # dimensions per cross attention head
            latent_dim_head=latent_dim_head,  # dimensions per latent self attention head
            weight_tie_layers=False,  # every block gets its own weights
            seq_dropout_prob=dropout_seq,  # fraction of input positions dropped in training
            input_dim=input_dim,
            max_freq=max_encoding_freq,  # 4
            scales=scales,
            embed_dim=embed_dim,
            bottleneck_index=bottleneck_index,
            encode_geo=encode_geo,
            num_freq=num_freq
        )

        self.decoder = LocalityAwareINRDecoder(
            output_dim=num_channels,
            embed_dim=embed_dim,
            num_scales=len(scales),
            dim=dim,
            depth=depth_inr,
        )

    def forward(self, images, coords, mask=None, target_coords=None,return_stats=False):
        """Encode, sample the bottleneck and decode.

        Returns:
            ``(output, kl_loss)`` or, with ``return_stats``,
            ``(output, kl_loss, mean, logvar)``.
        """
        # The encoder yields 2 or 4 values depending on return_stats; anything
        # after the KL term is passed through unchanged.
        localized_latents, kl_loss, *stats = self.encoder(
            images, coords, mask, target_coords, return_stats=return_stats
        )

        output_features = self.decoder(localized_latents)

        return (output_features, kl_loss, *stats)

    def decode_from_stats(self, mean, logvar, coords):
        """Sample codes from ``(mean, logvar)`` and decode them at ``coords``."""
        # The queries come from the encoder's multi-scale positional encoding;
        # this class defines no separate query network.
        queries = self.encoder.pos_query(coords)
        localized_latents, _ = self.encoder.process_from_stats(mean, logvar, queries)

        return self.decoder(localized_latents)

    def decode_from_codes(self, codes, coords):
        """Decode the given latent codes at ``coords``."""
        queries = self.encoder.pos_query(coords)
        localized_latents, _ = self.encoder.process_from_codes(codes, queries)

        return self.decoder(localized_latents)

    def encode(self, images, coords, mask=None):
        """Return the bottleneck ``(mu, logvar)`` for the given inputs."""
        mu, logvar = self.encoder.get_features(images, coords, mask)

        return mu, logvar

    def decode(self, features, coords):
        """Decode latent codes ``features`` at ``coords``."""
        localized_latents = self.encoder.process(features, coords)
        output_features = self.decoder(localized_latents)

        return output_features


def linear_scheduler(start, end, num_steps):
    """Return ``num_steps`` evenly spaced values going from ``start`` towards ``end``.

    The first value is ``start`` and the step is ``(end - start) / num_steps``,
    so ``end`` itself is not included.

    Args:
        start: first value of the schedule.
        end: value the schedule approaches (exclusive).
        num_steps: number of values to produce.

    Returns:
        A list of ``num_steps`` values; empty when ``num_steps`` is not positive.
    """
    # Guard against division by zero: a schedule without steps is simply empty,
    # which is also what a negative step count already produced.
    if num_steps <= 0:
        return []
    delta = (end - start) / num_steps
    return [start + i * delta for i in range(num_steps)]
