def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray:
    """Reduce the values of ``src`` into ``input`` at the positions given by ``index``.

    Only JAX arrays are handled; the actual work is delegated to
    ``deepmd.jax.common.scatter_sum``, which receives the four arguments
    unchanged and whose result is returned as is.

    Parameters
    ----------
    input
        Array the values are reduced into. Its array type selects the backend.
    dim
        Forwarded to ``scatter_sum`` as the dimension argument.
    index : np.ndarray
        Forwarded to ``scatter_sum`` as the index tensor.
    src : np.ndarray
        Forwarded to ``scatter_sum`` as the source tensor.

    Returns
    -------
    np.ndarray
        Whatever ``deepmd.jax.common.scatter_sum`` returns.

    Raises
    ------
    NotImplementedError
        If ``input`` is not a JAX array.
    """
    if not array_api_compat.is_jax_array(input):
        raise NotImplementedError("Only JAX arrays are supported.")

    # Function-scope import: deepmd.jax is only loaded once a JAX array
    # actually reaches this function.
    from deepmd.jax.common import (
        scatter_sum,
    )

    return scatter_sum(input, dim, index, src)
def enable_hessian(self):
    """Switch the model to an output definition that requests the energy Hessian.

    A deep copy of the current atomic output definition is stored in
    ``self.hess_fitting_def`` and its ``"energy"`` entry gets
    ``r_hessian = True``; the definition the copy was taken from is not
    modified.
    """
    hess_def = deepcopy(self.atomic_output_def())
    self.hess_fitting_def = hess_def
    hess_def["energy"].r_hessian = True
    self._enable_hessian = True


def atomic_output_def(self) -> FittingOutputDef:
    """Return the atomic output definition currently in effect.

    Returns
    -------
    FittingOutputDef
        ``self.hess_fitting_def`` once ``enable_hessian`` has been called,
        otherwise the definition provided by the parent class.
    """
    return (
        self.hess_fitting_def
        if self._enable_hessian
        else super().atomic_output_def()
    )
hess_1 = xp.permute_dims( + hess_, + ( + 0, + def_ndim + 1, + def_ndim + 3, + *range(1, def_ndim + 1), + def_ndim + 2, + def_ndim + 4, + ), + ) + nall = hess_1.shape[1] + # (1) -> [nf, nloc1, nall2, *def, 3(1), 3(2)] + hessian1 = xp.zeros( + [*vldims, nall, *vdef.shape, 3, 3], dtype=vv.dtype + ) + mapping_hess = xp.reshape( + mapping_, (mldims + [1] * (len(vdef.shape) + 3)) + ) + mapping_hess = xp.tile( + mapping_hess, + [1] * len(mldims) + [nall, *vdef.shape, 3, 3], + ) + hessian1 = xp_scatter_sum( + hessian1, + 1, + mapping_hess, + hess_1, + ) + # [nf, nall2, nloc1, *def, 3(1), 3(2)] + hessian1 = xp.permute_dims( + hessian1, + (0, 2, 1, *range(3, def_ndim + 5)), + ) + nloc = hessian1.shape[2] + # (2) -> [nf, nloc2, nloc1, *def, 3(1), 3(2)] + hessian = xp.zeros( + [*vldims, nloc, *vdef.shape, 3, 3], dtype=vv.dtype + ) + mapping_hess = xp.reshape( + mapping_, (mldims + [1] * (len(vdef.shape) + 3)) + ) + mapping_hess = xp.tile( + mapping_hess, + [1] * len(mldims) + [nloc, *vdef.shape, 3, 3], + ) + hessian = xp_scatter_sum( + hessian, + 1, + mapping_hess, + hessian1, + ) + # -> [nf, *def, nloc1, 3(1), nloc2, 3(2)] + hessian = xp.permute_dims( + hessian, + ( + 0, + *range(3, def_ndim + 3), + 2, + def_ndim + 3, + 1, + def_ndim + 4, + ), + ) + # -> [nf, *def nloc1 * 3, nloc2 * 3] + hessian = xp.reshape( + hessian, + (hessian.shape[0], *vdef.shape, nloc * 3, nloc * 3), + ) + + new_ret[kk_hess] = hessian + else: + new_ret[kk_hess] = None if vdef.c_differentiable: assert vdef.r_differentiable if model_ret[kk_derv_c] is not None: diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py index 5ca372c86a..7c97ff692f 100644 --- a/deepmd/jax/model/base_model.py +++ b/deepmd/jax/model/base_model.py @@ -8,6 +8,7 @@ ) from deepmd.dpmodel.output_def import ( get_deriv_name, + get_hessian_name, get_reduce_name, ) from deepmd.jax.env import ( @@ -87,6 +88,18 @@ def eval_output( ) model_predict[kk_derv_r] = extended_force + if vdef.r_hessian: + # [nf, *def, nall, 
@unittest.skipIf(
    sys.version_info < (3, 10),
    "JAX requires Python 3.10 or later",
)
class TestCaseSingleFrameWithoutNlist:
    """Fixture mixin describing one frame with three atoms of two types.

    ``setUp`` only stores plain attributes (coordinates, atom types, cell and
    descriptor hyper-parameters); no neighbor list is built here.
    """

    def setUp(self) -> None:
        """Populate the system definition and the descriptor settings."""
        # system size: 3 local atoms, 1 frame, 2 atom types
        self.nloc = 3
        self.nf, self.nt = 1, 2

        # flattened per-frame arrays: coord is (1, nloc * 3), atype is (1, nloc)
        positions = [
            [0, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
        ]
        self.coord = np.array(positions, dtype=np.float64).reshape(
            [1, self.nloc * 3]
        )
        self.atype = np.array([0, 0, 1], dtype=int).reshape([1, self.nloc])
        # cell is the 3x3 identity scaled by 2, flattened to (1, 9)
        self.cell = 2.0 * np.eye(3).reshape([1, 9])

        # descriptor / neighbor selection settings
        self.sel = [16, 8]
        self.sel_mix = [24]
        self.natoms = [3, 3, 2, 1]
        self.rcut = 2.2
        self.rcut_smth = 0.4

        # absolute tolerance used by the comparisons in derived tests
        self.atol = 1e-12
@unittest.skipIf(
    sys.version_info < (3, 10),
    "JAX requires Python 3.10 or later",
)
class TestEnergyHessianModel(unittest.TestCase, TestCaseSingleFrameWithoutNlist):
    """Check that a Hessian-enabled energy model survives a serialization round trip."""

    def setUp(self):
        TestCaseSingleFrameWithoutNlist.setUp(self)

    def _assert_outputs_match(self, ret0, ret1, keys) -> None:
        """Assert that ``ret0[key]`` and ``ret1[key]`` agree within ``self.atol``.

        Parameters
        ----------
        ret0, ret1
            Output dictionaries of the two models being compared.
        keys
            Names of the outputs to compare, checked in the given order.
        """
        for key in keys:
            np.testing.assert_allclose(
                to_numpy_array(ret0[key]),
                to_numpy_array(ret1[key]),
                atol=self.atol,
                # name the output so a failure is attributable without a debugger
                err_msg=f"model output {key!r} differs between the two models",
            )

    def test_self_consistency(self):
        """Compare a model with its serialize/deserialize copy, Hessian enabled on both."""
        ds = DescrptSeA(
            self.rcut,
            self.rcut_smth,
            self.sel,
        )
        ft = EnergyFittingNet(
            self.nt,
            ds.get_dim_out(),
            mixed_types=ds.mixed_types(),
        )
        type_map = ["foo", "bar"]
        md0 = EnergyModel(ds, ft, type_map=type_map)
        md1 = EnergyModel.deserialize(md0.serialize())
        # the Hessian is switched on for each model separately, after the
        # round trip
        md0.enable_hessian()
        md1.enable_hessian()
        args = [to_jax_array(ii) for ii in [self.coord, self.atype, self.cell]]

        ret0 = md0.call(*args)
        ret1 = md1.call(*args)
        self._assert_outputs_match(
            ret0,
            ret1,
            (
                "energy",
                "energy_redu",
                "energy_derv_r",
                "energy_derv_c_redu",
                "energy_derv_r_derv_r",
            ),
        )

        # the atomic virial is compared from a second call that requests it
        ret0 = md0.call(*args, do_atomic_virial=True)
        ret1 = md1.call(*args, do_atomic_virial=True)
        self._assert_outputs_match(ret0, ret1, ("energy_derv_c",))
def finite_hessian(f, x, delta=1e-6):
    """Approximate the Hessian of ``f`` at ``x`` with central finite differences.

    Every entry ``(i, j)`` is estimated from four evaluations of ``f`` at
    ``x`` shifted by ``+/- delta`` along components ``i`` and ``j``.

    Parameters
    ----------
    f
        Callable taking an array shaped like ``x`` and returning an object
        with a ``shape`` attribute (for example a NumPy array or scalar).
    x
        One-dimensional NumPy array at which the Hessian is evaluated.
    delta
        Finite-difference step.

    Returns
    -------
    np.ndarray
        Array of shape ``f(x).shape + x.shape + x.shape``.

    Raises
    ------
    AssertionError
        If ``x`` is not one-dimensional.
    """
    in_shape = x.shape
    assert len(in_shape) == 1
    out_shape = f(x).shape
    res = np.empty(out_shape + in_shape + in_shape)

    def shifted(iidx, jidx, sign_i, sign_j):
        # value of f at x displaced by sign_i*delta along iidx and by
        # sign_j*delta along jidx (the two shifts add up when iidx == jidx)
        step = np.zeros(in_shape)
        step[iidx] += sign_i * delta
        step[jidx] += sign_j * delta
        return f(x + step)

    for iidx in np.ndindex(*in_shape):
        for jidx in np.ndindex(*in_shape):
            y_pp = shifted(iidx, jidx, +1, +1)
            y_mp = shifted(iidx, jidx, -1, +1)
            y_pm = shifted(iidx, jidx, +1, -1)
            y_mm = shifted(iidx, jidx, -1, -1)
            res[(Ellipsis, *iidx, *jidx)] = (y_pp + y_mm - y_mp - y_pm) / (
                4 * delta**2.0
            )
    return res


class HessianTest:
    """Mixin comparing the model Hessian with a finite-difference reference.

    ``test`` reads ``self.nloc``, ``self.nf``, ``self.nv``, ``self.model_hess``
    and ``self.model_valu``; the concrete test case has to provide them.
    The atom types below are written out for two frames of three atoms.
    """

    def test(
        self,
    ) -> None:
        """Check ``energy_derv_r_derv_r`` against ``finite_hessian`` of ``energy_redu``."""
        # setup test case
        places = 5
        delta = 1e-3
        natoms = self.nloc
        nf = self.nf
        nv = self.nv
        # One independent subkey per random draw. Reusing a single key makes
        # draws of equal shape identical (cell0 would equal cell1) and
        # correlates all the others.
        key_cell0, key_cell1, key_coord, key_fparam, key_aparam = jax.random.split(
            jax.random.key(GLOBAL_SEED), 5
        )
        cell0 = jax.random.uniform(key_cell0, [3, 3], dtype=dtype)
        cell0 = 1.0 * (cell0 + cell0.T) + 5.0 * jnp.eye(3)
        cell1 = jax.random.uniform(key_cell1, [3, 3], dtype=dtype)
        cell1 = 1.0 * (cell1 + cell1.T) + 5.0 * jnp.eye(3)
        cell = jnp.stack([cell0, cell1])
        coord = jax.random.uniform(key_coord, [nf, natoms, 3], dtype=dtype)
        coord = jnp.matmul(coord, cell)
        cell = cell.reshape([nf, 9])
        coord = coord.reshape([nf, natoms * 3])
        atype = jnp.stack(
            [
                jnp.asarray([0, 0, 1], dtype=jnp.int64),
                jnp.asarray([1, 0, 1], dtype=jnp.int64),
            ]
        ).reshape([nf, natoms])
        nfp, nap = 2, 3
        fparam = jax.random.uniform(key_fparam, [nf, nfp], dtype=dtype)
        aparam = jax.random.uniform(key_aparam, [nf, natoms * nap], dtype=dtype)
        # forward hess and value models
        ret_dict0 = self.model_hess(
            coord, atype, box=cell, fparam=fparam, aparam=aparam
        )
        ret_dict1 = self.model_valu(
            coord, atype, box=cell, fparam=fparam, aparam=aparam
        )
        # compare hess and value models
        np.testing.assert_allclose(ret_dict0["energy"], ret_dict1["energy"])
        ana_hess = ret_dict0["energy_derv_r_derv_r"]

        # compute finite difference, one frame at a time
        fnt_hess = []
        for ii in range(nf):

            def ff(xx, ii=ii):
                # reduced energy of frame ii as a NumPy function of its
                # flattened coordinates
                ret = self.model_valu(
                    to_jax_array(xx)[None, ...],
                    atype[ii][None, ...],
                    box=cell[ii][None, ...],
                    fparam=fparam[ii][None, ...],
                    aparam=aparam[ii][None, ...],
                )
                return to_numpy_array(ret["energy_redu"])

            xx = to_numpy_array(coord[ii])
            fnt_hess.append(finite_hessian(ff, xx, delta=delta).squeeze())

        # compare finite difference with autodiff
        fnt_hess = np.stack(fnt_hess).reshape([nf, nv, natoms * 3, natoms * 3])
        np.testing.assert_almost_equal(
            fnt_hess, to_numpy_array(ana_hess), decimal=places
        )
@unittest.skipIf(
    sys.version_info < (3, 10),
    "JAX requires Python 3.10 or later",
)
class TestDPModel(unittest.TestCase, HessianTest):
    """Run ``HessianTest`` on an se_e2_a energy model.

    ``model_hess`` has the Hessian enabled; ``model_valu`` is rebuilt from the
    serialized ``model_hess`` and never has ``enable_hessian`` called on it.
    """

    def setUp(self) -> None:
        """Build the two models and the sizes read by ``HessianTest.test``."""
        # No PRNG call here: a key that is created and discarded seeds
        # nothing, because JAX has no global random state.
        self.nf = 2
        self.nloc = 3
        self.rcut = 4.0
        self.rcut_smth = 3.0
        self.sel = [10, 10]
        self.nt = 2
        self.nv = 1
        ds = DescrptSeA(
            self.rcut,
            self.rcut_smth,
            self.sel,
            neuron=[2, 4, 8],
            axis_neuron=2,
        )
        ft0 = EnergyFittingNet(
            self.nt,
            ds.get_dim_out(),
            # self.nv,
            mixed_types=ds.mixed_types(),
            neuron=[4, 4, 4],
        )
        type_map = ["foo", "bar"]
        self.model_hess = EnergyModel(ds, ft0, type_map=type_map)
        self.model_hess.enable_hessian()
        self.model_valu = EnergyModel.deserialize(self.model_hess.serialize())

    def test_output_def(self) -> None:
        """Only the Hessian-enabled model advertises ``r_hessian`` on ``energy``."""
        self.assertTrue(self.model_hess.atomic_output_def()["energy"].r_hessian)
        self.assertFalse(self.model_valu.atomic_output_def()["energy"].r_hessian)
        self.assertTrue(self.model_hess.model_output_def()["energy"].r_hessian)
        self.assertEqual(
            self.model_hess.model_output_def()["energy_derv_r_derv_r"].category,
            OutputVariableCategory.DERV_R_DERV_R,
        )