import chex
import optax
from clu import metrics as clu_metrics
from flax import struct as flax_struct
from flax.training import train_state
from typing_extensions import Self


class MetricTrainState(train_state.TrainState):
    """Train state that carries a ``clu`` metric collection with it.

    Attributes:
        metrics: Metrics accumulated so far. When not supplied it is built by
            calling ``clu_metrics.Collection.create_collection`` with no
            arguments.
    """

    metrics: clu_metrics.Collection = flax_struct.field(
        default_factory=clu_metrics.Collection.create_collection
    )

    def apply_gradients(
        self,
        /,
        grads: optax.Updates,
        metrics: clu_metrics.Collection | None = None,
        **kwargs,
    ) -> Self:
        """Apply ``grads`` and optionally merge new metrics into the state.

        Args:
            grads: Gradients forwarded to the parent ``apply_gradients``.
            metrics: Collection to merge into ``self.metrics``. ``None``
                leaves the accumulated metrics untouched.
            **kwargs: Extra field overrides, passed to ``replace`` on the
                state returned by the parent update.

        Returns:
            The updated state with the (possibly merged) metrics attached.
        """
        state = super().apply_gradients(grads=grads)
        # Compare against None explicitly: whether to merge must not depend
        # on how a Collection happens to evaluate in a boolean context.
        # NOTE(review): merge is invoked on self.metrics, so which entries
        # survive is decided by clu's Collection.merge -- confirm that the
        # empty default collection accumulates the way callers expect.
        if metrics is None:
            new_metrics = self.metrics
        else:
            new_metrics = self.metrics.merge(metrics)
        # kwargs are applied after the parent update, so they act on the
        # already-updated state rather than being passed into the parent.
        return state.replace(metrics=new_metrics, **kwargs)

    @classmethod
    def create(
        cls,
        /,
        *,
        params: chex.ArrayTree,
        **kwargs,
    ) -> Self:
        """Build a state by delegating to the parent ``create``.

        Args:
            params: Parameter tree; keyword-only, forwarded as ``params``.
            **kwargs: Remaining keyword arguments, forwarded unchanged.

        Returns:
            Whatever the parent ``create`` returns for this class.
        """
        return super().create(
            params=params,
            **kwargs,
        )


def _missing_support_map() -> train_state.TrainState:
    """Default factory for ``support_map``; always raises ``TypeError``.

    The field only carries a default so it can follow the defaulted
    ``metrics`` field in the dataclass field order. The previous factory,
    ``train_state.TrainState.create``, was invoked with no arguments and so
    could not build a state either (per the Flax documentation it requires
    ``apply_fn``, ``params`` and ``tx``); this keeps the same exception type
    but names the field that was actually omitted.
    """
    raise TypeError(
        "WeightedParticleState requires an explicit 'support_map' "
        "(a train_state.TrainState); no usable default can be constructed."
    )


class WeightedParticleState(MetricTrainState):
    """Metric-tracking train state that also holds a nested train state.

    Attributes:
        support_map: A second ``train_state.TrainState`` stored on this
            state. It must be passed explicitly; omitting it raises
            ``TypeError``.
    """

    support_map: train_state.TrainState = flax_struct.field(
        default_factory=_missing_support_map
    )
