Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
# Training Molecular Dynamics Potentials in JAX

[**Documentation**](https://chemtrain.readthedocs.io/en/latest/) | [**Preprint**](https://web3.arxiv.org/abs/2408.15852) | [**Getting Started**](#getting-started) | [**Installation**](#installation) | [**Contents**](#contents) | [**Contact**](#contact)
[**Documentation**](https://chemtrain.readthedocs.io/en/latest/) | [**Getting Started**](#getting-started) | [**Installation**](#installation) | [**Contents**](#contents) | [**Contact**](#contact)

[![PyPI version](https://badge.fury.io/py/chemtrain.svg)](https://badge.fury.io/py/chemtrain)
[![Documentation Status](https://readthedocs.org/projects/chemtrain/badge/?version=latest)](https://chemtrain.readthedocs.io/en/latest/?badge=latest)
[![Test](https://github.com/tummfm/chemtrain/actions/workflows/test.yml/badge.svg)](https://github.com/tummfm/chemtrain/actions/workflows/test.yml)
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Paper](https://img.shields.io/badge/Paper-chemtrain-yellow)](https://doi.org/10.1016/j.cpc.2025.109512)
[![Paper](https://img.shields.io/badge/Paper-chemtrain--deploy-yellow)
](https://doi.org/10.1021/acs.jctc.5c00996)


Neural Networks are promising models for enhancing the accuracy of classical molecular
simulations. However, the training of accurate models is challenging.
Expand Down Expand Up @@ -130,9 +134,9 @@ Within the repository, we provide the following directories:

## Citation

If you use chemtrain, please cite the following [paper](https://www.sciencedirect.com/science/article/pii/S0010465525000153):
If you use chemtrain or chemtrain-deploy, please cite the following [paper](https://www.sciencedirect.com/science/article/pii/S0010465525000153):

```
```bibtex
@article{fuchs2025chemtrain,
title = {chemtrain: Learning deep potential models via automatic differentiation and statistical physics},
journal = {Computer Physics Communications},
Expand All @@ -146,6 +150,19 @@ If you use chemtrain, please cite the following [paper](https://www.sciencedirec
}
```

```bibtex
@article{fuchsChemtrainDeploy2025,
title = {Chemtrain-{{Deploy}}: {{A Parallel}} and {{Scalable Framework}} for {{Machine Learning Potentials}} in {{Million-Atom MD Simulations}}},
author = {Fuchs, Paul and Chen, Weilong and Thaler, Stephan and Zavadlav, Julija},
year = {2025},
month = jul,
journal = {Journal of Chemical Theory and Computation},
publisher = {American Chemical Society},
issn = {1549-9618},
doi = {10.1021/acs.jctc.5c00996}
}
```

## Contributing
Contributions are always welcome! Please open a pull request to discuss the code
additions.
Expand Down
87 changes: 43 additions & 44 deletions chemtrain/data/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,19 +198,18 @@ def map_dataset(position_dataset,
shift_fn,
c_map,
d_map=None,
force_dataset = None):
"""Maps fine-scaled positions and forces to a coarser scale.
force_dataset=None):
"""
Maps fine-scaled positions and forces to a coarser scale.

Uses the linear mapping from [Noid2008]_ to map fine-scaled positions and
forces to coarse grained positions and forces via the relations:

.. math::

\\mathbf R_I = \\sum_{i \\in \\mathcal I_I} c_{Ii} \\mathbf r_i,\\quad \\text{and}

\\mathbf{F}_I = \\sum_{i \\in \\mathcal I_I} \\frac{d_{Ii}}{c_{Ii}} \\mathbf f_i.


Args:
position_dataset: Dataset of fine-scaled positions.
displacement_fn: Function to compute the displacement between two
Expand All @@ -232,49 +231,49 @@ def map_dataset(position_dataset,
*The multiscale coarse-graining method. I. A rigorous bridge between
atomistic and coarse-grained models*. J. Chem. Phys. 28 June 2008;
128 (24): 244114. https://doi-org.eaccess.tum.edu/10.1063/1.2938860


"""
# Compute the mapping via displacements to take care of periodic
# boundary conditions

disp_fn = jax.vmap(displacement_fn, in_axes=(None, 0))

ref_positions = jnp.zeros_like(position_dataset[0, 0, :])
displacements = lax.map(
functools.partial(disp_fn, ref_positions),
position_dataset
)

c_map /= jnp.sum(c_map, axis=1, keepdims=True)

cg_dislacements = lax.map(
functools.partial(jnp.einsum, 'Ii..., id->Id', c_map),
-displacements
)

cg_positions = lax.map(
functools.partial(jax.vmap(shift_fn, in_axes=(None, 0)), ref_positions),
cg_dislacements
)

# Map forces if provided
# Normalise mapping weights
c_norm = c_map / jnp.sum(c_map, axis=1, keepdims=True)
if d_map is not None:
d_norm = d_map / jnp.sum(d_map, axis=1, keepdims=True)
else:
d_norm = None

def _map_single(ipt, shift_fn, displacement_fn, c_norm, d_norm):
pos, forc = ipt

# Choose reference for each CG bead
ref_idx = jnp.argmax(c_map, axis=1)
ref_positions = pos[ref_idx, :]

# Compute displacements for each reference position and map
disp = jax.vmap(
lambda r: jax.vmap(lambda p: displacement_fn(p, r))(pos)
)(ref_positions)
cg_disp = jnp.einsum('Ii,Iid->Id', c_map, disp)
cg_positions = jax.vmap(shift_fn)(ref_positions, cg_disp)


if (forc is not None) and (d_norm is not None):
mask = (c_norm > 0.0)
safe_c = jnp.where(mask, c_norm, 1.0)
cg_forces = jnp.einsum('Ii, id->Id', mask * d_norm / safe_c, forc)
else:
cg_forces = None

return cg_positions, cg_forces

_map_single = functools.partial(_map_single,
shift_fn=shift_fn,
displacement_fn=displacement_fn,
c_norm=c_norm,
d_norm=d_norm)

if force_dataset is None:
return cg_positions

d_map /= jnp.sum(d_map, axis=1, keepdims=True)

# Avoid division by zero.
mask = (c_map > 0.0)
safe_c = jnp.where(mask, c_map, 1.0)

cg_forces = lax.map(
functools.partial(jnp.einsum, 'Ii..., id->Id', mask * d_map / safe_c),
force_dataset
)

return cg_positions, cg_forces
# map positions only
return lax.map(lambda pos: _map_single((pos, None))[0], position_dataset)
else:
return lax.map(_map_single, (position_dataset, force_dataset))


def allocate_neighborlist(dataset,
Expand Down
6 changes: 1 addition & 5 deletions chemtrain/learn/difftre.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import numpy as onp

import jax
from jax import jit, numpy as jnp, lax, tree_map
from jax import jit, numpy as jnp, lax
from jax.typing import ArrayLike

from jax_md_mod import custom_quantity
Expand Down Expand Up @@ -182,10 +182,6 @@ def difftre_weights_fn(params, traj_state, reduction="min"):
def difftre_loss_fn(params, traj_state, state_dict, targets):
partial_loss = functools.partial(_difftre_loss, params)

# print(f"Trajstate shapes are {tree_map(jnp.shape, traj_state)}")
# print(f"Statedict shapes are {tree_map(jnp.shape, state_dict)}")
# print(f"Target shapes are {tree_map(jnp.shape, targets)}")


if not batched:
return partial_loss(traj_state, state_dict, targets)
Expand Down
3 changes: 2 additions & 1 deletion chemtrain/learn/max_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
from functools import partial

import jax
from jax import (lax, vmap, pmap, value_and_grad, tree_map, device_count,
from jax import (lax, vmap, value_and_grad, device_count,
numpy as jnp, device_put, jit)
from jax.tree_util import tree_map
from jax.sharding import Mesh, PartitionSpec, NamedSharding, SingleDeviceSharding
from jax.experimental.shard_map import shard_map
from jax_sgmc import data
Expand Down
3 changes: 2 additions & 1 deletion chemtrain/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@
import jax
import numpy as onp
from jax import (
tree_map, numpy as jnp, random, device_count, jit, device_get,
numpy as jnp, random, device_count, jit, device_get,
tree_util
)
from jax.tree_util import tree_map
from jax_sgmc import data

from chemtrain import util
Expand Down
3 changes: 2 additions & 1 deletion chemtrain/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import cloudpickle as pickle
# import h5py
import jax
from jax import tree_map, tree_util, device_count, numpy as jnp, tree_unflatten, lax
from jax import tree_util, device_count, numpy as jnp, lax
from jax.tree_util import tree_map

import jax_md_mod
from jax_md import simulate, partition
Expand Down
2 changes: 1 addition & 1 deletion chemtrain/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.2.0"
__version__ = "0.2.1"
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: Apache Software License",
"Operating System :: MacOS",
"Operating System :: POSIX :: Linux",
Expand All @@ -24,10 +25,10 @@ classifiers = [
]
requires-python = ">=3.10"
dependencies = [
'jax <= 0.4.37',
'jaxlib <= 0.4.37',
'jax <= 0.7.0',
'jaxlib <= 0.7.0',
'scipy < 1.13', # Removed scipy.linal.tril, etc.
'jax-md',
'jax-md @ git+https://github.com/jax-md/jax-md@jax-md-v0.2.25',
'jax-sgmc',
'optax',
'dm-haiku',
Expand Down
3 changes: 3 additions & 0 deletions tests/data/test_utils/test_mapping/CG_weights_ala2.npy
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/data/test_utils/test_mapping/forces_ala2_AT.npy
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/data/test_utils/test_mapping/forces_ala2_CG.npy
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/data/test_utils/test_mapping/positions_ala2_AT.npy
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/data/test_utils/test_mapping/positions_ala2_CG.npy
Git LFS file not shown
91 changes: 91 additions & 0 deletions tests/test_utils/test_mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License and limitations under the License.

import jax.numpy as jnp
from jax_md import space
from chemtrain.data import preprocessing
from pathlib import Path
import pytest

class TestMappingAla2:
@pytest.fixture
def setup_problem(self, datafiles):
# Directory containing test mapping data for alanine dipeptide
data_dir = Path(datafiles)

# Load all-atom and coarse-grained data
Ala2_AT_F = jnp.load(str(data_dir / 'forces_ala2_AT.npy'))
Ala2_AT_R = jnp.load(str(data_dir / 'positions_ala2_AT.npy'))
Ala2_CG_F = jnp.load(str(data_dir / 'forces_ala2_CG.npy'))
Ala2_CG_R = jnp.load(str(data_dir / 'positions_ala2_CG.npy'))
weights = jnp.load(str(data_dir / 'CG_weights_ala2.npy'))

# Define periodic box
box = jnp.identity(3) * 6
displacement_fn, shift_fn = space.periodic_general(
box=box, fractional_coordinates=False
)

return {
'AT_F': Ala2_AT_F,
'AT_R': Ala2_AT_R,
'CG_F': Ala2_CG_F,
'CG_R': Ala2_CG_R,
'weights': weights,
'displacement_fn': displacement_fn,
'shift_fn': shift_fn
}

@pytest.mark.test_mapping_combined
def test_map_ala2(self, setup_problem):
# Get data from setup
data = setup_problem

# Map dataset: positions and forces
mapped_R, mapped_F = preprocessing.map_dataset(
data['AT_R'],
data['displacement_fn'],
data['shift_fn'],
data['weights'],
data['weights'],
data['AT_F'],
)

assert mapped_R.shape == data['CG_R'].shape, \
f"Mapped positions shape {mapped_R.shape} does not match expected {data['CG_R'].shape}"
assert mapped_F.shape == data['CG_F'].shape, \
f"Mapped forces shape {mapped_F.shape} does not match expected {data['CG_F'].shape}"

assert jnp.allclose(mapped_R, data['CG_R'], rtol=1e-3, atol=1e-3), \
"Mapped positions not close to expected CG positions"
assert jnp.allclose(mapped_F, data['CG_F'], rtol=1e-3, atol=1e-3), \
"Mapped forces not close to expected CG forces"

@pytest.mark.test_mapping_positions
def test_map_ala2_positions(self, setup_problem):
# Get data from setup
data = setup_problem

# Map dataset: positions only
mapped_R = preprocessing.map_dataset(
data['AT_R'],
data['displacement_fn'],
data['shift_fn'],
data['weights'],
)

assert mapped_R.shape == data['CG_R'].shape, \
f"Mapped positions shape {mapped_R.shape} does not match expected {data['CG_R'].shape}"

assert jnp.allclose(mapped_R, data['CG_R'], rtol=1e-3, atol=1e-3), \
"Mapped positions not close to expected CG positions"