Source code for erbs.bias.state

import dataclasses
from typing import Optional

import jax.numpy as jnp
import numpy as np
from flax.struct import dataclass
from jax import Array

from erbs.bias.kernel import compress, global_mc_normalisation, incremental_compress


[docs] @dataclass class BiasState: std: float compression_threshold: float g: Optional[Array] = None height: Optional[Array] = None cov: Optional[Array] = None normalisation: Optional[Array] = None def replace(self, **kwargs) -> "BiasState": return dataclasses.replace(self, **kwargs) def initialize(self): cv_dim = self.g.shape[-1] # TODO use determinant here somewhere # k0 = self.std ** (cv_dim * 2) k0 = 1 / (np.sqrt(2.0 * np.pi * self.std**2) ** cv_dim) height = jnp.full((self.g.shape[0], 1), k0) cov = jnp.full(self.g.shape, self.std) Zn = global_mc_normalisation( np.asarray(self.g), np.asarray(height), np.asarray(cov), ) return self.replace(height=height, cov=cov, normalisation=Zn) def compress(self): g, cov, height = compress( self.g, self.cov, self.height, thresh=self.compression_threshold ) return self.replace(g=g, cov=cov, height=height) def incremental_compress(self, gnew): cv_dim = self.g.shape[-1] covnew = self.std ** (cv_dim * 2) # move to bias state hnew = 1 / (np.sqrt(2.0 * np.pi * self.std**2) ** cv_dim) g, cov, height = incremental_compress( self.g, self.cov, self.height, gnew, covnew, hnew, thresh=self.compression_threshold, ) return self.replace(g=g, cov=cov, height=height) def add_configuration(self, gnew): cv_dim = self.g.shape[-1] covnew = np.full((cv_dim), self.std) # move to bias state hnew = 1 / (np.sqrt(2.0 * np.pi * self.std**2) ** cv_dim) hnew = np.array([hnew]) g, cov, height = incremental_compress( np.array(self.g), np.array(self.cov), np.array(self.height), np.array(gnew), covnew, hnew, thresh=self.compression_threshold, ) # for Zn: when compressing kernels, just compute the uncompressed contribution # to the normalisation # Zn = incremental_mc_normalisation() Zn = global_mc_normalisation( g, height, cov, ) new_state = self.replace( g=jnp.array(g), cov=jnp.array(cov), height=jnp.array(height), normalisation=Zn ) return new_state