From 9cad115c715b3f9df7813410153c2fc192a8240c Mon Sep 17 00:00:00 2001 From: Franz Goerlich <102743232+frgoe003@users.noreply.github.com> Date: Tue, 17 Jun 2025 10:24:55 +0200 Subject: [PATCH 1/3] Fix CG mapping (#8) Fixed mapping under periodic boundary conditions --- chemtrain/data/preprocessing.py | 87 +++++++++--------- .../test_mapping/CG_weights_ala2.npy | 3 + .../test_mapping/forces_ala2_AT.npy | 3 + .../test_mapping/forces_ala2_CG.npy | 3 + .../test_mapping/positions_ala2_AT.npy | 3 + .../test_mapping/positions_ala2_CG.npy | 3 + tests/test_utils/test_mapping.py | 91 +++++++++++++++++++ 7 files changed, 149 insertions(+), 44 deletions(-) create mode 100644 tests/data/test_utils/test_mapping/CG_weights_ala2.npy create mode 100644 tests/data/test_utils/test_mapping/forces_ala2_AT.npy create mode 100644 tests/data/test_utils/test_mapping/forces_ala2_CG.npy create mode 100644 tests/data/test_utils/test_mapping/positions_ala2_AT.npy create mode 100644 tests/data/test_utils/test_mapping/positions_ala2_CG.npy create mode 100644 tests/test_utils/test_mapping.py diff --git a/chemtrain/data/preprocessing.py b/chemtrain/data/preprocessing.py index aa3d0b03..55b61899 100644 --- a/chemtrain/data/preprocessing.py +++ b/chemtrain/data/preprocessing.py @@ -198,19 +198,18 @@ def map_dataset(position_dataset, shift_fn, c_map, d_map=None, - force_dataset = None): - """Maps fine-scaled positions and forces to a coarser scale. + force_dataset=None): + """ + Maps fine-scaled positions and forces to a coarser scale. Uses the linear mapping from [Noid2008]_ to map fine-scaled positions and forces to coarse grained positions and forces via the relations: .. math:: - \\mathbf R_I = \\sum_{i \\in \\mathcal I_I} c_{Ii} \\mathbf r_i,\\quad \\text{and} \\mathbf{F}_I = \\sum_{i \\in \\mathcal I_I} \\frac{d_{Ii}}{c_{Ii}} \\mathbf f_i. - Args: position_dataset: Dataset of fine-scaled positions. displacement_fn: Function to compute the displacement between two @@ -232,49 +231,49 @@ def map_dataset(position_dataset, *The multiscale coarse-graining method. I. A rigorous bridge between atomistic and coarse-grained models*. J. Chem. Phys. 28 June 2008; 128 (24): 244114. https://doi-org.eaccess.tum.edu/10.1063/1.2938860 - - """ - # Compute the mapping via displacements to take care of periodic - # boundary conditions - - disp_fn = jax.vmap(displacement_fn, in_axes=(None, 0)) - - ref_positions = jnp.zeros_like(position_dataset[0, 0, :]) - displacements = lax.map( - functools.partial(disp_fn, ref_positions), - position_dataset - ) - - c_map /= jnp.sum(c_map, axis=1, keepdims=True) - - cg_dislacements = lax.map( - functools.partial(jnp.einsum, 'Ii..., id->Id', c_map), - -displacements - ) - - cg_positions = lax.map( - functools.partial(jax.vmap(shift_fn, in_axes=(None, 0)), ref_positions), - cg_dislacements - ) - - # Map forces if provided + # Normalise mapping weights + c_norm = c_map / jnp.sum(c_map, axis=1, keepdims=True) + if d_map is not None: + d_norm = d_map / jnp.sum(d_map, axis=1, keepdims=True) + else: + d_norm = None + + def _map_single(ipt, shift_fn, displacement_fn, c_norm, d_norm): + pos, forc = ipt + + # Choose reference for each CG bead + ref_idx = jnp.argmax(c_map, axis=1) + ref_positions = pos[ref_idx, :] + + # Compute displacements for each reference position and map + disp = jax.vmap( + lambda r: jax.vmap(lambda p: displacement_fn(p, r))(pos) + )(ref_positions) + cg_disp = jnp.einsum('Ii,Iid->Id', c_map, disp) + cg_positions = jax.vmap(shift_fn)(ref_positions, cg_disp) + + + if (forc is not None) and (d_norm is not None): + mask = (c_norm > 0.0) + safe_c = jnp.where(mask, c_norm, 1.0) + cg_forces = jnp.einsum('Ii, id->Id', mask * d_norm / safe_c, forc) + else: + cg_forces = None + + return cg_positions, cg_forces + + _map_single = functools.partial(_map_single, + shift_fn=shift_fn, + displacement_fn=displacement_fn, + c_norm=c_norm, + d_norm=d_norm) if force_dataset is None: - return cg_positions - - d_map /= jnp.sum(d_map, axis=1, keepdims=True) - - # Avoid division by zero. - mask = (c_map > 0.0) - safe_c = jnp.where(mask, c_map, 1.0) - - cg_forces = lax.map( - functools.partial(jnp.einsum, 'Ii..., id->Id', mask * d_map / safe_c), - force_dataset - ) - - return cg_positions, cg_forces + # map positions only + return lax.map(lambda pos: _map_single((pos, None))[0], position_dataset) + else: + return lax.map(_map_single, (position_dataset, force_dataset)) def allocate_neighborlist(dataset, diff --git a/tests/data/test_utils/test_mapping/CG_weights_ala2.npy b/tests/data/test_utils/test_mapping/CG_weights_ala2.npy new file mode 100644 index 00000000..8297d331 --- /dev/null +++ b/tests/data/test_utils/test_mapping/CG_weights_ala2.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d7e92b6ccee1e4f926cdc6aae6a3625fa87a25b2b724d888c7693dea30fbb800 +size 1008 diff --git a/tests/data/test_utils/test_mapping/forces_ala2_AT.npy b/tests/data/test_utils/test_mapping/forces_ala2_AT.npy new file mode 100644 index 00000000..cb44aa9d --- /dev/null +++ b/tests/data/test_utils/test_mapping/forces_ala2_AT.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0b05bbd540fa21dcc881e80d71f073b11e566b82fbc49899fcc18849f4dfaa41 +size 264128 diff --git a/tests/data/test_utils/test_mapping/forces_ala2_CG.npy b/tests/data/test_utils/test_mapping/forces_ala2_CG.npy new file mode 100644 index 00000000..648c3b14 --- /dev/null +++ b/tests/data/test_utils/test_mapping/forces_ala2_CG.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f8c9228e3fb16ae458dbbea2af8e47fc20db8c5ccb4b01c3132f34e9aabd049 +size 120128 diff --git a/tests/data/test_utils/test_mapping/positions_ala2_AT.npy b/tests/data/test_utils/test_mapping/positions_ala2_AT.npy new file mode 100644 index 00000000..26c150c7 --- /dev/null +++ b/tests/data/test_utils/test_mapping/positions_ala2_AT.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:95bce0ba937718604ba9d0b109d1e56acf7330c7ff71532ec843ed22ae31fdf9 +size 264128 diff --git a/tests/data/test_utils/test_mapping/positions_ala2_CG.npy b/tests/data/test_utils/test_mapping/positions_ala2_CG.npy new file mode 100644 index 00000000..b607b1d7 --- /dev/null +++ b/tests/data/test_utils/test_mapping/positions_ala2_CG.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a3e48051ddaf68112b1880a088627e3285df0f74a2103504f086db0b99e4e015 +size 120128 diff --git a/tests/test_utils/test_mapping.py b/tests/test_utils/test_mapping.py new file mode 100644 index 00000000..36a79610 --- /dev/null +++ b/tests/test_utils/test_mapping.py @@ -0,0 +1,91 @@ +# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License and limitations under the License. + +import jax.numpy as jnp +from jax_md import space +from chemtrain.data import preprocessing +from pathlib import Path +import pytest + +class TestMappingAla2: + @pytest.fixture + def setup_problem(self, datafiles): + # Directory containing test mapping data for alanine dipeptide + data_dir = Path(datafiles) + + # Load all-atom and coarse-grained data + Ala2_AT_F = jnp.load(str(data_dir / 'forces_ala2_AT.npy')) + Ala2_AT_R = jnp.load(str(data_dir / 'positions_ala2_AT.npy')) + Ala2_CG_F = jnp.load(str(data_dir / 'forces_ala2_CG.npy')) + Ala2_CG_R = jnp.load(str(data_dir / 'positions_ala2_CG.npy')) + weights = jnp.load(str(data_dir / 'CG_weights_ala2.npy')) + + # Define periodic box + box = jnp.identity(3) * 6 + displacement_fn, shift_fn = space.periodic_general( + box=box, fractional_coordinates=False + ) + + return { + 'AT_F': Ala2_AT_F, + 'AT_R': Ala2_AT_R, + 'CG_F': Ala2_CG_F, + 'CG_R': Ala2_CG_R, + 'weights': weights, + 'displacement_fn': displacement_fn, + 'shift_fn': shift_fn + } + + @pytest.mark.test_mapping_combined + def test_map_ala2(self, setup_problem): + # Get data from setup + data = setup_problem + + # Map dataset: positions and forces + mapped_R, mapped_F = preprocessing.map_dataset( + data['AT_R'], + data['displacement_fn'], + data['shift_fn'], + data['weights'], + data['weights'], + data['AT_F'], + ) + + assert mapped_R.shape == data['CG_R'].shape, \ + f"Mapped positions shape {mapped_R.shape} does not match expected {data['CG_R'].shape}" + assert mapped_F.shape == data['CG_F'].shape, \ + f"Mapped forces shape {mapped_F.shape} does not match expected {data['CG_F'].shape}" + + assert jnp.allclose(mapped_R, data['CG_R'], rtol=1e-3, atol=1e-3), \ + "Mapped positions not close to expected CG positions" + assert jnp.allclose(mapped_F, data['CG_F'], rtol=1e-3, atol=1e-3), \ + "Mapped forces not close to expected CG forces" + + @pytest.mark.test_mapping_positions + def test_map_ala2_positions(self, setup_problem): + # Get data from setup + data = setup_problem + + # Map dataset: positions only + mapped_R = preprocessing.map_dataset( + data['AT_R'], + data['displacement_fn'], + data['shift_fn'], + data['weights'], + ) + + assert mapped_R.shape == data['CG_R'].shape, \ + f"Mapped positions shape {mapped_R.shape} does not match expected {data['CG_R'].shape}" + + assert jnp.allclose(mapped_R, data['CG_R'], rtol=1e-3, atol=1e-3), \ + "Mapped positions not close to expected CG positions" + + \ No newline at end of file From b1d34a5c6b8824f43c39dfcdd0359d2e1f5508c2 Mon Sep 17 00:00:00 2001 From: Paul Fuchs Date: Mon, 28 Jul 2025 14:42:00 +0200 Subject: [PATCH 2/3] Update compatibility with newest jax version --- chemtrain/learn/difftre.py | 6 +----- chemtrain/learn/max_likelihood.py | 3 ++- chemtrain/trainers/base.py | 3 ++- chemtrain/util.py | 3 ++- chemtrain/version.py | 2 +- pyproject.toml | 7 ++++--- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/chemtrain/learn/difftre.py b/chemtrain/learn/difftre.py index c1630f15..c53d1d14 100644 --- a/chemtrain/learn/difftre.py +++ b/chemtrain/learn/difftre.py @@ -36,7 +36,7 @@ import numpy as onp import jax -from jax import jit, numpy as jnp, lax, tree_map +from jax import jit, numpy as jnp, lax from jax.typing import ArrayLike from jax_md_mod import custom_quantity @@ -182,10 +182,6 @@ def difftre_weights_fn(params, traj_state, reduction="min"): def difftre_loss_fn(params, traj_state, state_dict, targets): partial_loss = functools.partial(_difftre_loss, params) - # print(f"Trajstate shapes are {tree_map(jnp.shape, traj_state)}") - # print(f"Statedict shapes are {tree_map(jnp.shape, state_dict)}") - # print(f"Target shapes are {tree_map(jnp.shape, targets)}") - if not batched: return partial_loss(traj_state, state_dict, targets) diff --git a/chemtrain/learn/max_likelihood.py b/chemtrain/learn/max_likelihood.py index ce0f8c00..2b75476c 100644 --- a/chemtrain/learn/max_likelihood.py +++ b/chemtrain/learn/max_likelihood.py @@ -18,8 +18,9 @@ from functools import partial import jax -from jax import (lax, vmap, pmap, value_and_grad, tree_map, device_count, +from jax import (lax, vmap, value_and_grad, device_count, numpy as jnp, device_put, jit) +from jax.tree_util import tree_map from jax.sharding import Mesh, PartitionSpec, NamedSharding, SingleDeviceSharding from jax.experimental.shard_map import shard_map from jax_sgmc import data diff --git a/chemtrain/trainers/base.py b/chemtrain/trainers/base.py index 9e33aa69..ea14dff8 100644 --- a/chemtrain/trainers/base.py +++ b/chemtrain/trainers/base.py @@ -31,9 +31,10 @@ import jax import numpy as onp from jax import ( - tree_map, numpy as jnp, random, device_count, jit, device_get, + numpy as jnp, random, device_count, jit, device_get, tree_util ) +from jax.tree_util import tree_map from jax_sgmc import data from chemtrain import util diff --git a/chemtrain/util.py b/chemtrain/util.py index c8544d28..4a180d8a 100644 --- a/chemtrain/util.py +++ b/chemtrain/util.py @@ -20,7 +20,8 @@ import cloudpickle as pickle # import h5py import jax -from jax import tree_map, tree_util, device_count, numpy as jnp, tree_unflatten, lax +from jax import tree_util, device_count, numpy as jnp, lax +from jax.tree_util import tree_map import jax_md_mod from jax_md import simulate, partition diff --git a/chemtrain/version.py b/chemtrain/version.py index ba047b20..345886c6 100644 --- a/chemtrain/version.py +++ b/chemtrain/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.2.0" +__version__ = "0.2.1" diff --git a/pyproject.toml b/pyproject.toml index d2d39692..a774ed4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ classifiers = [ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: Apache Software License", "Operating System :: MacOS", "Operating System :: POSIX :: Linux", @@ -24,10 +25,10 @@ classifiers = [ ] requires-python = ">=3.10" dependencies = [ - 'jax <= 0.4.37', - 'jaxlib <= 0.4.37', + 'jax <= 0.7.0', + 'jaxlib <= 0.7.0', 'scipy < 1.13', # Removed scipy.linal.tril, etc. - 'jax-md', + 'jax-md @ git+https://github.com/jax-md/jax-md@jax-md-v0.2.25', 'jax-sgmc', 'optax', 'dm-haiku', From 854c2766b79c04eb2fbef843310e079ba6ab9210 Mon Sep 17 00:00:00 2001 From: Paul Fuchs Date: Mon, 28 Jul 2025 16:01:54 +0200 Subject: [PATCH 3/3] Updated readme --- README.md | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 9a10eb58..19d258cb 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,15 @@ # Training Molecular Dynamics Potentials in JAX -[**Documentation**](https://chemtrain.readthedocs.io/en/latest/) | [**Preprint**](https://web3.arxiv.org/abs/2408.15852) | [**Getting Started**](#getting-started) | [**Installation**](#installation) | [**Contents**](#contents) | [**Contact**](#contact) +[**Documentation**](https://chemtrain.readthedocs.io/en/latest/) | [**Getting Started**](#getting-started) | [**Installation**](#installation) | [**Contents**](#contents) | [**Contact**](#contact) [![PyPI version](https://badge.fury.io/py/chemtrain.svg)](https://badge.fury.io/py/chemtrain) [![Documentation Status](https://readthedocs.org/projects/chemtrain/badge/?version=latest)](https://chemtrain.readthedocs.io/en/latest/?badge=latest) [![Test](https://github.com/tummfm/chemtrain/actions/workflows/test.yml/badge.svg)](https://github.com/tummfm/chemtrain/actions/workflows/test.yml) [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) +[![Paper](https://img.shields.io/badge/Paper-chemtrain-yellow)](https://doi.org/10.1016/j.cpc.2025.109512) +[![Paper](https://img.shields.io/badge/Paper-chemtrain--deploy-yellow) +](https://doi.org/10.1021/acs.jctc.5c00996) + Neural Networks are promising models for enhancing the accuracy of classical molecular simulations. However, the training of accurate models is challenging. @@ -130,9 +134,9 @@ Within the repository, we provide the following directories: ## Citation -If you use chemtrain, please cite the following [paper](https://www.sciencedirect.com/science/article/pii/S0010465525000153): +If you use chemtrain or chemtrain-deploy, please cite the following [paper](https://www.sciencedirect.com/science/article/pii/S0010465525000153): -``` +```bibtex @article{fuchs2025chemtrain, title = {chemtrain: Learning deep potential models via automatic differentiation and statistical physics}, journal = {Computer Physics Communications}, @@ -146,6 +150,19 @@ If you use chemtrain, please cite the following [paper](https://www.sciencedirec } ``` +```bibtex +@article{fuchsChemtrainDeploy2025, + title = {Chemtrain-{{Deploy}}: {{A Parallel}} and {{Scalable Framework}} for {{Machine Learning Potentials}} in {{Million-Atom MD Simulations}}}, + author = {Fuchs, Paul and Chen, Weilong and Thaler, Stephan and Zavadlav, Julija}, + year = {2025}, + month = jul, + journal = {Journal of Chemical Theory and Computation}, + publisher = {American Chemical Society}, + issn = {1549-9618}, + doi = {10.1021/acs.jctc.5c00996} +} +``` + ## Contributing Contributions are always welcome! Please open a pull request to discuss the code additions.