Source code for erbs.descriptor.example_descriptor
from typing import Any, Callable
import einops
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from apax.utils.jax_md_reduced import space
[docs]
class RBFDescriptorFlax(nn.Module):
    """Per-atom descriptor built from Gaussian radial basis functions.

    For every pair listed in ``neighbor.idx`` the pair distance ``d`` is
    expanded into ``n_basis`` Gaussians ``exp(-betta * (mu_k - d) ** 2)``.
    The expansions are then summed into the atom named by the second row of
    ``neighbor.idx``, giving an array of shape ``(n_atoms, n_basis)``.

    The centers are ``mu_k = r_min + k * (r_max - r_min) / n_basis`` for
    ``k = 0, ..., n_basis - 1``. The last center therefore lies one spacing
    below ``r_max``. All Gaussians share the width parameter
    ``betta = n_basis**2 / r_max**2``.

    NOTE(review): no cutoff or envelope function is applied, so every pair
    present in the neighbor list contributes regardless of its distance.
    """

    # Displacement function used to build the pair metric; defaults to the
    # first element returned by ``space.free()``.
    displacement_fn: Callable = space.free()[0]
    # Number of Gaussian basis functions per atom (width of the output).
    n_basis: int = 5
    # Position of the first Gaussian center.
    r_min: float = 0.5
    # Upper end of the center grid; also enters the Gaussian width.
    r_max: float = 6.0
    # dtype that positions and centers are cast to.
    dtype: Any = jnp.float32

    def setup(self):
        """Precompute the Gaussian width, the centers and the pair metric."""
        self.betta = self.n_basis**2 / self.r_max**2

        spacing = (self.r_max - self.r_min) / self.n_basis
        centers = self.r_min + spacing * np.arange(self.n_basis)
        # Leading unit axis -> shape (1, n_basis), so the centers broadcast
        # against a column of pair distances in __call__.
        self.shifts = jnp.asarray(centers[np.newaxis, :], dtype=self.dtype)

        pair_metric = space.canonicalize_displacement_or_metric(
            self.displacement_fn
        )
        # Presumably map_bond applies the metric row-wise to two equally long
        # position arrays (jax_md convention) -- one distance per pair.
        self.metric = space.map_bond(pair_metric)

    def __call__(self, R, neighbor):
        """Evaluate the descriptor.

        Args:
            R: Positions, shape ``(n_atoms, 3)``.
            neighbor: Neighbor list whose ``idx`` holds pair indices. Rows 0
                and 1 are read as the two ends of each pair; this assumes a
                sparse ``(2, n_pairs)`` layout -- TODO confirm.

        Returns:
            Array of shape ``(n_atoms, n_basis)``. Row ``a`` is the sum of
            the Gaussian expansions of every pair whose ``idx[1]`` entry
            equals ``a``.
        """
        positions = R.astype(self.dtype)
        senders, receivers = neighbor.idx[0], neighbor.idx[1]

        # One scalar distance per pair, turned into a column so it
        # broadcasts with the (1, n_basis) centers -> (n_pairs, n_basis).
        pair_dist = self.metric(positions[senders], positions[receivers])
        offsets = self.shifts - pair_dist[:, jnp.newaxis]
        radial_basis = jnp.exp(-self.betta * offsets**2)

        # segment_sum drops segment ids outside [0, n_atoms), so entries
        # padded with an index >= n_atoms (if that is the padding scheme)
        # do not contribute to any atom.
        return jax.ops.segment_sum(
            radial_basis, receivers, num_segments=positions.shape[0]
        )