Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,14 @@
build_multiple_neighbor_list,
get_multiple_nlist_key,
)
from deepmd.pt.utils.tabulate import (
DPTabulate,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.pt.utils.utils import (
ActivationFn,
to_numpy_array,
)
from deepmd.utils.data_system import (
Expand Down Expand Up @@ -306,6 +310,7 @@ def init_subclass_params(sub_data, sub_class):
# set trainable
for param in self.parameters():
param.requires_grad = trainable
self.compress = False

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -859,3 +864,85 @@ def update_sel(
)
local_jdata_cpy["repformer"]["nsel"] = repformer_sel[0]
return local_jdata_cpy, min_nbor_dist

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data.

Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
# do some checks before the mocel compression process
if self.compress:
raise ValueError("Compression is already enabled.")
assert (
not self.repinit.resnet_dt
), "Model compression error: repinit resnet_dt must be false!"
for tt in self.repinit.exclude_types:
if (tt[0] not in range(self.repinit.ntypes)) or (
tt[1] not in range(self.repinit.ntypes)
):
raise RuntimeError(
"Repinit exclude types"
+ str(tt)
+ " must within the number of atomic types "
+ str(self.repinit.ntypes)
+ "!"
)
if (
self.repinit.ntypes * self.repinit.ntypes - len(self.repinit.exclude_types)
== 0
):
raise RuntimeError(
"Repinit empty embedding-nets are not supported in model compression!"
)

if self.repinit.attn_layer != 0:
raise RuntimeError(
"Cannot compress model when repinit attention layer is not 0."
)

if self.repinit.tebd_input_mode != "strip":
raise RuntimeError(
"Cannot compress model when repinit tebd_input_mode == 'concat'"
)
Comment thread
njzjz marked this conversation as resolved.

# repinit doesn't have a serialize method
data = self.serialize()
self.table = DPTabulate(
self,
data["repinit_args"]["neuron"],
data["repinit_args"]["type_one_side"],
data["exclude_types"],
ActivationFn(data["repinit_args"]["activation_function"]),
)
self.table_config = [
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
]
self.lower, self.upper = self.table.build(
min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2
)

self.repinit.enable_compression(
self.table.data, self.table_config, self.lower, self.upper
)
self.compress = True
19 changes: 14 additions & 5 deletions deepmd/pt/utils/tabulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,14 @@ def __init__(
raise RuntimeError("Unknown activation function type!")

self.activation_fn = activation_fn
self.davg = self.descrpt.serialize()["@variables"]["davg"]
self.dstd = self.descrpt.serialize()["@variables"]["dstd"]
self.ntypes = self.descrpt.get_ntypes()
serialized = self.descrpt.serialize()
if isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptDPA2):
serialized = serialized["repinit_variable"]
self.davg = serialized["@variables"]["davg"]
self.dstd = serialized["@variables"]["dstd"]
self.embedding_net_nodes = serialized["embeddings"]["networks"]
Comment thread
njzjz marked this conversation as resolved.

self.embedding_net_nodes = self.descrpt.serialize()["embeddings"]["networks"]
self.ntypes = self.descrpt.get_ntypes()

self.layer_size = self._get_layer_size()
self.table_size = self._get_table_size()
Expand Down Expand Up @@ -291,7 +294,13 @@ def _layer_1(self, x, w, b):
return t, self.activation_fn(torch.matmul(x, w) + b) + t

def _get_descrpt_type(self):
if isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptDPA1):
if isinstance(
self.descrpt,
(
deepmd.pt.model.descriptor.DescrptDPA1,
deepmd.pt.model.descriptor.DescrptDPA2,
),
):
return "Atten"
elif isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptSeA):
return "A"
Expand Down
149 changes: 149 additions & 0 deletions source/tests/pt/model/test_compressed_descriptor_dpa2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest
from typing import (
Any,
)

import numpy as np
import torch

from deepmd.dpmodel.descriptor.dpa2 import (
RepformerArgs,
RepinitArgs,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.pt.model.descriptor.dpa2 import (
DescrptDPA2,
)
from deepmd.pt.utils.env import DEVICE as PT_DEVICE
from deepmd.pt.utils.nlist import build_neighbor_list as build_neighbor_list_pt
from deepmd.pt.utils.nlist import (
extend_coord_with_ghosts as extend_coord_with_ghosts_pt,
)

from ...consistent.common import (
parameterized,
)


def eval_pt_descriptor(
pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False
) -> Any:
ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt(
torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3),
torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1),
torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3),
pt_obj.get_rcut(),
)
nlist = build_neighbor_list_pt(
ext_coords,
ext_atype,
natoms[0],
pt_obj.get_rcut(),
pt_obj.get_sel(),
distinguish_types=(not mixed_types),
)
result, _, _, _, _ = pt_obj(ext_coords, ext_atype, nlist, mapping=mapping)
return result


@parameterized(("float32", "float64"), (True, False))
class TestDescriptorDPA2(unittest.TestCase):
def setUp(self):
(self.dtype, self.type_one_side) = self.param
if self.dtype == "float32":
self.skipTest("FP32 has bugs:")
# ../../../../deepmd/pt/model/descriptor/repformer_layer.py:521: in forward
# torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni)
# E RuntimeError: expected scalar type Float but found Double
if self.dtype == "float32":
Comment thread
njzjz marked this conversation as resolved.
self.atol = 1e-5
elif self.dtype == "float64":
self.atol = 1e-10
self.seed = 21
self.sel = [10]
self.rcut_smth = 5.80
self.rcut = 6.00
self.neuron = [6, 12, 24]
self.axis_neuron = 3
self.ntypes = 2
self.coords = np.array(
[
12.83,
2.56,
2.18,
12.09,
2.87,
2.74,
00.25,
3.32,
1.68,
3.36,
3.00,
1.81,
3.51,
2.51,
2.60,
4.27,
3.22,
1.56,
],
dtype=GLOBAL_NP_FLOAT_PRECISION,
)
self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32)
self.box = np.array(
[13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0],
dtype=GLOBAL_NP_FLOAT_PRECISION,
)
self.natoms = np.array([6, 6, 2, 4], dtype=np.int32)

repinit = RepinitArgs(
rcut=self.rcut,
rcut_smth=self.rcut_smth,
nsel=10,
tebd_input_mode="strip",
type_one_side=self.type_one_side,
)
repformer = RepformerArgs(
rcut=self.rcut - 1,
rcut_smth=self.rcut_smth - 1,
nsel=9,
)

self.descriptor = DescrptDPA2(
ntypes=self.ntypes,
repinit=repinit,
repformer=repformer,
precision=self.dtype,
)

def test_compressed_forward(self):
result_pt = eval_pt_descriptor(
self.descriptor,
self.natoms,
self.coords,
self.atype,
self.box,
)
self.descriptor.enable_compression(0.5)
result_pt_compressed = eval_pt_descriptor(
self.descriptor,
self.natoms,
self.coords,
self.atype,
self.box,
)

self.assertEqual(result_pt.shape, result_pt_compressed.shape)
torch.testing.assert_close(
result_pt,
result_pt_compressed,
atol=self.atol,
rtol=self.atol,
)


if __name__ == "__main__":
unittest.main()
Comment thread
njzjz marked this conversation as resolved.