"""Source code for erbs.bias.potential."""

from functools import partial
from typing import Optional, Union

import jax
import jax.numpy as jnp
import numpy as np
from apax.config.train_config import Config
from apax.data.input_pipeline import CachedInMemoryDataset
from apax.layers.descriptor import GaussianMomentDescriptor
from apax.layers.descriptor.basis_functions import (
    GaussianBasis,
    RadialFunction,
)
from apax.nn.models import FeatureModel
from apax.train.checkpoints import restore_parameters
from apax.utils.jax_md_reduced import partition, space
from ase.calculators.calculator import Calculator, all_changes
from matplotlib.path import Path
from tqdm import trange

from erbs.bias.energy_function_factory import OPESExploreFactory
from erbs.bias.state import BiasState
from erbs.dim_reduction.elementwise_pca import DimReduction


def build_feature_neighbor_fns(
    atoms,
    n_basis,
    r_max,
    dr_threshold,
    feature_fn: Optional[callable] = None,
    config: Optional[Config] = None,
    params=None,
    batched=False,
):
    """Build the collective-variable (feature) function and the neighbor-list function.

    The feature function is selected with this precedence:

    1. ``config`` and ``params`` both given: the last-layer feature model of the
       apax model described by ``config`` is built and bound to ``params``.
    2. ``feature_fn`` given: it is returned unchanged.
    3. Otherwise: a parameter-free ``GaussianMomentDescriptor`` feature model
       built from ``n_basis`` and ``r_max``.

    Args:
        atoms: ASE ``Atoms``. Only its cell lengths are read. They decide between
            free-space and periodic displacement and serve as the initial box.
        n_basis: Number of Gaussian radial basis functions (default descriptor only).
        r_max: Neighbor-list cutoff; also the outer radius of the default basis.
        dr_threshold: Passed as ``dr_threshold`` to ``partition.neighbor_list``.
        feature_fn: Optional custom feature function. It is called elsewhere as
            ``feature_fn(positions, numbers, idx, box, offsets)``.
        config: apax training ``Config`` of a trained model.
        params: Parameters restored for ``config``.
        batched: If True, no displacement or neighbor-list function is built and
            the returned ``neighbor_fn`` is ``None``.

    Returns:
        Tuple ``(feature_fn, neighbor_fn)``.
    """
    box = np.asarray(atoms.get_cell().lengths(), dtype=jnp.float32)

    displacement_fn = None
    neighbor_fn = None
    if not batched:
        if np.all(box < 1e-6):
            # No cell: free-space displacement on Cartesian coordinates.
            displacement_fn, _ = space.free()
            frac_coords = False
        else:
            # Periodic cell: the neighbor list works on fractional coordinates.
            displacement_fn, _ = space.periodic_general(box, fractional_coordinates=True)
            frac_coords = True
        neighbor_fn = partition.neighbor_list(
            displacement_fn,
            box,
            r_max,
            dr_threshold,
            fractional_coordinates=frac_coords,
            disable_cell_list=True,
            format=partition.Sparse,
        )

    if config and params:
        n_species = 119  # int(np.max(Z) + 1)
        Builder = config.model.get_builder()
        builder = Builder(config.model.get_dict(), n_species=n_species)

        feature_model = builder.build_ll_feature_model(
            apply_mask=True, init_box=np.array(box), inference_disp_fn=displacement_fn
        )
        feature_fn = partial(feature_model.apply, params)
    elif feature_fn is None:
        # Fall back to the default descriptor only when no custom function was
        # supplied; previously a user-provided feature_fn was silently discarded.
        descriptor = GaussianMomentDescriptor(
            radial_fn=RadialFunction(
                n_basis,
                basis_fn=GaussianBasis(
                    n_basis=n_basis,
                    r_min=1.5,
                    r_max=r_max,
                ),
                emb_init=None,
            ),
            n_contr=8,
        )
        feature_model = FeatureModel(
            descriptor,
            readout=None,
            should_average=True,
            init_box=box,
            inference_disp_fn=displacement_fn,
        )
        feature_fn = partial(feature_model.apply, {})
    return feature_fn, neighbor_fn


class ERBS(Calculator):
    """ASE calculator that adds a bias potential on top of ``base_calc``.

    Every ``interval`` calls to :meth:`calculate` (while ``accumulate`` is True),
    the collective variables (CVs) of the current structure are computed. They
    are reduced with ``dim_reduction_factory`` and stored in a ``BiasState``. On
    every call, the bias energy and forces from the energy function built by
    ``energy_fn_factory`` are added to the results of ``base_calc``.
    """

    # NOTE(review): "stress" is advertised, but only energy and forces receive a
    # bias contribution; any stress in the results comes unchanged from base_calc.
    implemented_properties = ["energy", "forces", "stress"]

    def __init__(
        self,
        base_calc: Calculator,
        dim_reduction_factory: DimReduction,
        energy_fn_factory: OPESExploreFactory,
        feature_fn: Optional[callable] = None,
        # NOTE(review): `Path` is imported from matplotlib.path at file level, not
        # pathlib; the value is only forwarded to `restore_parameters`.
        model_dir: Optional[Union[Path, list[Path]]] = None,
        n_basis=5,
        r_max=6.0,
        dr_threshold=0.5,
        interval=10_000,
        update_iterations=np.inf,
        **kwargs,
    ):
        """Create the biased calculator.

        Args:
            base_calc: Calculator providing the unbiased energy and forces.
            dim_reduction_factory: Fitted on the stored CVs via ``fit_transform``;
                ``create_dim_reduction_fn`` then provides the reduction function.
            energy_fn_factory: Builds the bias energy function from the CV and
                dim-reduction functions. It also supplies ``std`` and
                ``compression_threshold`` for the ``BiasState``.
            feature_fn: Optional custom CV function, forwarded to
                ``build_feature_neighbor_fns``.
            model_dir: apax model directory (or directories) restored with
                ``restore_parameters``; its features are used as CVs.
            n_basis: Radial basis size of the default descriptor.
            r_max: Neighbor-list cutoff (and default-descriptor radius).
            dr_threshold: Neighbor-list ``dr_threshold``.
            interval: Number of ``calculate`` calls between bias updates.
            update_iterations: While the step counter is below this value, every
                bias update refits the dimensionality reduction. Afterwards new CVs
                are projected with the existing reduction.
            **kwargs: Forwarded to ``Calculator.__init__``.
        """
        Calculator.__init__(self, **kwargs)
        self.base_calc = base_calc
        self.n_basis = n_basis
        self.r_max = r_max
        self.dr_threshold = dr_threshold
        self.update_iterations = update_iterations
        self.interval = interval

        self.feature_fn = feature_fn
        self.model_config = None
        self.params = None
        if model_dir:
            self.model_config, self.params = restore_parameters(model_dir)

        self.cv_fn = None
        self.dim_reduction_factory = dim_reduction_factory
        self.dim_red_fn = None
        self.energy_fn_factory = energy_fn_factory
        self.energy_fn = None
        self.body_fn = None

        self.auxilliary_cvs = []  # used for dimensionality reduction only
        self.ref_cvs = []  # bias centres (before reduction)
        self.bias_state = None

        self.neighbors = None
        self.neighbor_fn = None

        self._step_counter = 0
        self.accumulate = True  # when False, calculate() stops updating the bias
        self.bias_results = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _is_periodic(atoms):
        """Return True when any cell length of ``atoms`` exceeds 1e-6."""
        return np.any(atoms.get_cell().lengths() > 1e-6)

    @staticmethod
    def _to_model_frame(position, box, is_pbc):
        """Map Cartesian positions and a cell array to the neighbor-list frame.

        For periodic systems the cell array is transposed and the positions are
        mapped with its inverse (``space.transform``). This matches the
        fractional-coordinate neighbor list. Otherwise both inputs are returned
        unchanged.
        """
        if is_pbc:
            box = box.T
            position = space.transform(jnp.linalg.inv(box), position)
        return position, box

    def _allocate_neighbors(self, position, box, is_pbc):
        """Allocate a new neighbor list; ``box`` is only passed for periodic systems."""
        if is_pbc:
            return self.neighbor_fn.allocate(position, box=box)
        return self.neighbor_fn.allocate(position)

    # ------------------------------------------------------------- bias updates

    def _initialize_nl(self, atoms):
        """Build (and JIT) the CV function and build the neighbor-list function."""
        cv_fn, self.neighbor_fn = build_feature_neighbor_fns(
            atoms,
            self.n_basis,
            self.r_max,
            self.dr_threshold,
            feature_fn=self.feature_fn,
            config=self.model_config,
            params=self.params,
        )
        self.cv_fn = jax.jit(cv_fn)

    def update_with_new_dimred(self, g_new):
        """Store ``g_new``, refit the reduction and rebuild energy fn and bias state."""
        self.ref_cvs.append(g_new)
        n_ref = len(self.ref_cvs)

        # The reduction is fit on reference and auxiliary CVs together, but only
        # the reference CVs become bias centres: auxiliary CVs are added with
        # `for_dimred_only=True` and must not be biased against.
        reduced_cvs = self.dim_reduction_factory.fit_transform(
            np.array(self.ref_cvs + self.auxilliary_cvs)
        )
        reduced_ref_cvs = reduced_cvs[:n_ref]

        self.dim_red_fn = jax.jit(self.dim_reduction_factory.create_dim_reduction_fn())
        # Energy function must be recreated because the reduction changed.
        self.energy_fn = self.energy_fn_factory.create(self.cv_fn, self.dim_red_fn)

        bias_state = BiasState(
            std=self.energy_fn_factory.std,
            g=reduced_ref_cvs,
            compression_threshold=self.energy_fn_factory.compression_threshold,
        ).initialize()
        if len(reduced_ref_cvs) > 2:
            bias_state = bias_state.compress()
        self.bias_state = bias_state

    def update_with_fixed_dimred(self, g_new):
        """Store ``g_new`` and add its projection with the existing reduction.

        Raises:
            ValueError: If the bias state has not been initialized yet.
        """
        # Validate first so a failed call leaves ref_cvs untouched.
        if self.bias_state is None:
            raise ValueError("Bias state has not yet been initialized")
        self.ref_cvs.append(g_new)
        g_new_red = self.dim_red_fn(g_new)
        # NOTE(review): initialize()/compress() return a new BiasState, but the
        # return value of add_configuration is discarded; confirm it mutates.
        self.bias_state.add_configuration(g_new_red)

    def update_neighbors(self, position, box, is_pbc):
        """Allocate the neighbor list on first use, otherwise update it."""
        if self.neighbors is None:
            self.neighbors = self._allocate_neighbors(position, box, is_pbc)
        elif is_pbc:
            self.neighbors = self.neighbors.update(position, box=box)
        else:
            self.neighbors = self.neighbors.update(position)

    def update_bias(self, atoms):
        """Compute the CVs of ``atoms``, update the bias, and rebuild ``body_fn``."""
        numbers = jnp.array(atoms.numbers, dtype=jnp.int32)
        is_pbc = self._is_periodic(atoms)
        position, box = self._to_model_frame(
            jnp.array(atoms.positions, dtype=jnp.float64),
            jnp.asarray(atoms.cell.array),
            is_pbc,
        )

        self.update_neighbors(position, box, is_pbc)
        if self.neighbors.did_buffer_overflow:
            print("neighbor list overflowed, reallocating.")
            self.neighbors = self._allocate_neighbors(position, box, is_pbc)

        offsets = jnp.zeros((self.neighbors.idx.shape[1], 3))
        g_new = self.cv_fn(position, numbers, self.neighbors.idx, box, offsets)

        should_reinit = self._step_counter < self.update_iterations
        if self.bias_state is None or should_reinit:
            self.update_with_new_dimred(g_new)
        else:
            self.update_with_fixed_dimred(g_new)

        energy_fn = self.energy_fn
        ef_function = jax.value_and_grad(energy_fn)

        @jax.jit
        def body_fn(positions, neighbor, box, bias_state):
            # `positions`/`box` arrive in Cartesian / raw cell form from calculate().
            positions, box = self._to_model_frame(positions, box, is_pbc)
            if is_pbc:
                neighbor = neighbor.update(positions, box=box)
            else:
                neighbor = neighbor.update(positions)
            offsets = jnp.full([neighbor.idx.shape[1], 3], 0)
            energy, neg_forces = ef_function(
                positions, numbers, neighbor, box, offsets, bias_state
            )
            return {"energy": energy, "forces": -neg_forces}, neighbor

        self.body_fn = body_fn

    # ----------------------------------------------------------- ASE interface

    def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes):
        """Run ``base_calc`` and add the bias energy and forces.

        ``results`` holds the biased ``energy`` and ``forces`` plus the bias-only
        ``energy_bias`` and ``forces_bias``.

        Raises:
            RuntimeError: If the bias has never been built (``accumulate`` was
                False on the first call).
        """
        Calculator.calculate(self, atoms, properties, system_changes)
        self.base_calc.calculate(atoms, properties, system_changes)
        # Copy: adding the bias below must not overwrite base_calc.results.
        self.results = dict(self.base_calc.results)

        positions = jnp.asarray(atoms.positions, dtype=jnp.float64)
        box = jnp.asarray(atoms.cell.array, dtype=jnp.float64)

        if self._step_counter == 0:
            self._initialize_nl(atoms)

        should_update_bias = self._step_counter % self.interval == 0
        if should_update_bias and self.accumulate:
            self.update_bias(atoms)

        if self.body_fn is None:
            raise RuntimeError(
                "ERBS bias has not been built yet: the first calculate() call "
                "requires accumulate=True."
            )

        bias_results, self.neighbors = self.body_fn(
            positions, self.neighbors, box, self.bias_state
        )
        if self.neighbors.did_buffer_overflow:
            print("neighbor list overflowed, reallocating.")
            # Reallocate in the same (fractional / transposed-cell) frame that
            # update_bias uses; the Cartesian positions would size it wrongly.
            is_pbc = self._is_periodic(atoms)
            frame_positions, frame_box = self._to_model_frame(positions, box, is_pbc)
            self.neighbors = self._allocate_neighbors(frame_positions, frame_box, is_pbc)
            bias_results, self.neighbors = self.body_fn(
                positions, self.neighbors, box, self.bias_state
            )

        self.bias_results = {
            k: np.array(v, dtype=np.float64) for k, v in bias_results.items()
        }
        energy_bias = self.bias_results["energy"]
        forces_bias = self.bias_results["forces"]
        self.results["energy"] = self.results["energy"] + energy_bias
        self.results["forces"] = self.results["forces"] + forces_bias
        self.results["energy_bias"] = energy_bias
        self.results["forces_bias"] = forces_bias
        self._step_counter += 1

    # ------------------------------------------------------- descriptor storage

    def compute_cvs(self, atoms_list, batch_size=4):
        """Compute the CVs of every structure in ``atoms_list`` in batches.

        The feature source matches the simulation (model from ``model_dir``,
        custom ``feature_fn``, or default descriptor). ``self.cv_fn`` is left
        untouched.

        Returns:
            List with one NumPy CV array per structure yielded by the dataset.
        """
        dataset = CachedInMemoryDataset(
            atoms_list,
            self.r_max,
            batch_size,
            n_epochs=1,
            ignore_labels=True,
        )
        pbar = trange(dataset.n_data, desc="Evaluating data", ncols=100, leave=False)
        try:
            # Local function on purpose: assigning self.cv_fn here would replace
            # the jitted, neighbor-list-aware CV function used by update_bias.
            cv_fn, _ = build_feature_neighbor_fns(
                atoms_list[0],
                self.n_basis,
                self.r_max,
                dr_threshold=self.dr_threshold,
                feature_fn=self.feature_fn,
                config=self.model_config,
                params=self.params,
                batched=True,
            )
            calc_descriptor = jax.jit(jax.vmap(cv_fn, in_axes=(0, 0, 0, 0, 0)))

            descriptors = []
            for inputs in dataset.batch():
                g = calc_descriptor(
                    inputs["positions"],
                    inputs["numbers"],
                    inputs["idx"],
                    inputs["box"],
                    inputs["offsets"],
                )
                # One host transfer per batch, then split along the batch axis.
                descriptors.extend(np.asarray(g))
                pbar.update(batch_size)
        finally:
            pbar.close()
            dataset.cleanup()
        return descriptors

    def _store_descriptors(self, descriptors, for_dimred_only):
        """Append ``descriptors`` to the auxiliary or the reference CV list."""
        if for_dimred_only:
            self.auxilliary_cvs.extend(descriptors)
        else:
            self.ref_cvs.extend(descriptors)

    def add_configs(self, atoms_list, batch_size=4, for_dimred_only=True):
        """Compute CVs for ``atoms_list`` and store them (see ``for_dimred_only``)."""
        self._store_descriptors(self.compute_cvs(atoms_list, batch_size), for_dimred_only)

    def add_descriptors(self, path, for_dimred_only=True):
        """Load CVs stored under key ``"g"`` in the ``.npz`` file at ``path``."""
        with np.load(path) as data:  # closes the NpzFile handle
            descriptors = list(data["g"])
        self._store_descriptors(descriptors, for_dimred_only)

    def save_descriptors(self, path):
        """Save reference CVs as ``"g"`` and, if any, auxiliary CVs as ``"g_aux"``."""
        data = {"g": np.array(self.ref_cvs)}
        if self.auxilliary_cvs:
            data["g_aux"] = np.array(self.auxilliary_cvs)
        np.savez(path, **data)