diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index feaded0b01..100a0c13b6 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -785,7 +785,115 @@ def deserialize(cls, data: dict) -> "EmbeddingNet": return EN -EmbeddingNet = make_embedding_network(NativeNet, NativeLayer) +class EmbeddingNet(NativeNet): + """The embedding network. + + Parameters + ---------- + in_dim + Input dimension. + neuron + The number of neurons in each layer. The output dimension + is the same as the dimension of the last layer. + activation_function + The activation function. + resnet_dt + Use time step at the resnet architecture. + precision + Floating point precision for the model parameters. + seed : int, optional + Random seed. + bias : bool, Optional + Whether to use bias in the embedding layer. + trainable : bool or list[bool], Optional + Whether the weights are trainable. If a list, each element + corresponds to a layer. + """ + + def __init__( + self, + in_dim: int, + neuron: list[int] = [24, 48, 96], + activation_function: str = "tanh", + resnet_dt: bool = False, + precision: str = DEFAULT_PRECISION, + seed: int | list[int] | None = None, + bias: bool = True, + trainable: bool | list[bool] = True, + ) -> None: + layers = [] + i_in = in_dim + if isinstance(trainable, bool): + trainable = [trainable] * len(neuron) + for idx, ii in enumerate(neuron): + i_ot = ii + layers.append( + NativeLayer( + i_in, + i_ot, + bias=bias, + use_timestep=resnet_dt, + activation_function=activation_function, + resnet=True, + precision=precision, + seed=child_seed(seed, idx), + trainable=trainable[idx], + ).serialize() + ) + i_in = i_ot + super().__init__(layers) + self.in_dim = in_dim + self.neuron = neuron + self.activation_function = activation_function + self.resnet_dt = resnet_dt + self.precision = precision + self.bias = bias + + def serialize(self) -> dict: + """Serialize the network to a dict. + + Returns + ------- + dict + The serialized network. + """ + return { + "@class": "EmbeddingNetwork", + "@version": 2, + "in_dim": self.in_dim, + "neuron": self.neuron.copy(), + "activation_function": self.activation_function, + "resnet_dt": self.resnet_dt, + "bias": self.bias, + # make deterministic + "precision": np.dtype(PRECISION_DICT[self.precision]).name, + "layers": [layer.serialize() for layer in self.layers], + } + + @classmethod + def deserialize(cls, data: dict) -> "EmbeddingNet": + """Deserialize the network from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 2, 1) + data.pop("@class", None) + layers = data.pop("layers") + obj = cls(**data) + # Reinitialize layers from serialized data, using the same layer type + # that __init__ created (respects subclass overrides via MRO). + if obj.layers: + layer_type = type(obj.layers[0]) + obj.layers = type(obj.layers)( + [layer_type.deserialize(layer) for layer in layers] + ) + else: + obj.layers = type(obj.layers)([]) + return obj def make_fitting_network( diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index 9f453c895c..226948adba 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -34,7 +34,7 @@ # only linux ncpus = len(os.sched_getaffinity(0)) except AttributeError: - ncpus = os.cpu_count() + ncpus = os.cpu_count() or 1 NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(4, ncpus))) if multiprocessing.get_start_method() != "fork": # spawn or forkserver does not support NUM_WORKERS > 0 for DataLoader diff --git a/deepmd/pt_expt/utils/env.py b/deepmd/pt_expt/utils/env.py index 56cec25d49..0f4d38ba84 100644 --- a/deepmd/pt_expt/utils/env.py +++ b/deepmd/pt_expt/utils/env.py @@ -34,7 +34,7 @@ # only linux ncpus = len(os.sched_getaffinity(0)) except AttributeError: - ncpus = os.cpu_count() + ncpus = os.cpu_count() or 1 NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(4, ncpus))) if multiprocessing.get_start_method() != "fork": # spawn or forkserver does not support NUM_WORKERS > 0 for DataLoader diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 84d0024a85..b115214056 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -10,11 +10,11 @@ from deepmd.dpmodel.common import ( NativeOP, ) +from deepmd.dpmodel.utils.network import EmbeddingNet as EmbeddingNetDP from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP from deepmd.dpmodel.utils.network import ( - make_embedding_network, make_fitting_network, make_multilayer_network, ) @@ -91,8 +91,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.call(x) -class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)): - pass +class EmbeddingNet(EmbeddingNetDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + EmbeddingNetDP.__init__(self, *args, **kwargs) + # EmbeddingNetDP.__init__ creates dpmodel NativeLayer instances. + # Convert to pt_expt NativeLayer and wrap in ModuleList. + self.layers = torch.nn.ModuleList( + [NativeLayer.deserialize(layer.serialize()) for layer in self.layers] + ) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return torch.nn.Module.__call__(self, *args, **kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.call(x) + + +register_dpmodel_mapping( + EmbeddingNetDP, + lambda v: EmbeddingNet.deserialize(v.serialize()), +) class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)): diff --git a/source/tests/common/dpmodel/test_network.py b/source/tests/common/dpmodel/test_network.py index 1ea5b1fdf9..a63d4f356a 100644 --- a/source/tests/common/dpmodel/test_network.py +++ b/source/tests/common/dpmodel/test_network.py @@ -180,6 +180,114 @@ def test_embedding_net(self) -> None: inp = np.ones([ni], dtype=get_xp_precision(np, prec)) np.testing.assert_allclose(en0.call(inp), en1.call(inp)) + def test_is_concrete_class(self) -> None: + """Verify EmbeddingNet is a concrete class, not factory-generated.""" + in_dim = 4 + neuron = [8, 16, 32] + net = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + activation_function="tanh", + resnet_dt=True, + precision="float64", + ) + # Check it's the actual EmbeddingNet class, not a dynamic class + self.assertEqual(net.__class__.__name__, "EmbeddingNet") + self.assertEqual(net.__class__.__module__, "deepmd.dpmodel.utils.network") + # Verify it has the expected attributes + self.assertEqual(net.in_dim, in_dim) + self.assertEqual(net.neuron, neuron) + self.assertEqual(net.activation_function, "tanh") + self.assertEqual(net.resnet_dt, True) + self.assertEqual(len(net.layers), len(neuron)) + + def test_forward_pass(self) -> None: + """Test EmbeddingNet forward pass produces correct shapes.""" + in_dim = 4 + neuron = [8, 16, 32] + net = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + activation_function="tanh", + resnet_dt=True, + precision="float64", + ) + rng = np.random.default_rng() + x = rng.standard_normal((5, in_dim)) + out = net.call(x) + self.assertEqual(out.shape, (5, neuron[-1])) + self.assertEqual(out.dtype, np.float64) + + def test_trainable_parameter_variants(self) -> None: + """Test EmbeddingNet with different trainable configurations.""" + in_dim = 4 + neuron = [8, 16] + + # All trainable + net_trainable = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + trainable=True, + ) + for layer in net_trainable.layers: + self.assertTrue(layer.trainable) + + # All frozen + net_frozen = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + trainable=False, + ) + for layer in net_frozen.layers: + self.assertFalse(layer.trainable) + + # Mixed trainable + net_mixed = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + trainable=[True, False], + ) + self.assertTrue(net_mixed.layers[0].trainable) + self.assertFalse(net_mixed.layers[1].trainable) + + def test_empty_layers_round_trip(self) -> None: + """Test EmbeddingNet with empty neuron list (edge case for deserialize). + + This tests the fix for IndexError when neuron=[] results in empty layers. + The deserialize method should handle this case without trying to access + layers[0] when the list is empty. + """ + in_dim = 4 + neuron = [] # Empty neuron list + + # Create network with empty layers + net = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + activation_function="tanh", + resnet_dt=True, + precision="float64", + ) + + # Verify it has no layers + self.assertEqual(len(net.layers), 0) + + # Serialize and deserialize + serialized = net.serialize() + net_restored = EmbeddingNet.deserialize(serialized) + + # Verify restored network also has no layers + self.assertEqual(len(net_restored.layers), 0) + self.assertEqual(net_restored.in_dim, in_dim) + self.assertEqual(net_restored.neuron, neuron) + + # Verify forward pass works (should return input unchanged) + rng = np.random.default_rng() + x = rng.standard_normal((5, in_dim)) + out = net_restored.call(x) + # With no layers, output should equal input + np.testing.assert_allclose(out, x) + class TestFittingNet(unittest.TestCase): def test_fitting_net(self) -> None: diff --git a/source/tests/pt_expt/utils/test_network.py b/source/tests/pt_expt/utils/test_network.py index ad7c2a7e3d..24d61c5fd5 100644 --- a/source/tests/pt_expt/utils/test_network.py +++ b/source/tests/pt_expt/utils/test_network.py @@ -1,9 +1,22 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +import numpy as np +import torch + +from deepmd.dpmodel.utils.network import EmbeddingNet as DPEmbeddingNet +from deepmd.pt_expt.utils import ( + env, +) from deepmd.pt_expt.utils.network import ( + EmbeddingNet, NativeLayer, ) +from ...seed import ( + GLOBAL_SEED, +) + def test_native_layer_clears_parameter_on_none() -> None: layer = NativeLayer(2, 3, trainable=True) @@ -19,3 +32,252 @@ def test_native_layer_clears_buffer_on_none() -> None: layer.w = None assert layer.w is None assert layer._buffers.get("w") is None + + +class TestEmbeddingNetRefactor(unittest.TestCase): + """Tests for the refactored EmbeddingNet pt_expt wrapper and integration.""" + + def setUp(self) -> None: + self.in_dim = 4 + self.neuron = [8, 16, 32] + self.activation = "tanh" + self.resnet_dt = True + self.precision = "float64" + + def test_pt_expt_embedding_net_wraps_dpmodel(self) -> None: + """Verify pt_expt EmbeddingNet correctly wraps dpmodel.""" + net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + # Check it's a torch.nn.Module + self.assertIsInstance(net, torch.nn.Module) + # Check it's also a DPEmbeddingNet + self.assertIsInstance(net, DPEmbeddingNet) + # Check layers are converted to pt_expt NativeLayer (torch modules) + self.assertIsInstance(net.layers, torch.nn.ModuleList) + for layer in net.layers: + self.assertIsInstance(layer, NativeLayer) + self.assertIsInstance(layer, torch.nn.Module) + + def test_pt_expt_embedding_net_forward(self) -> None: + """Test pt_expt EmbeddingNet forward pass returns torch.Tensor.""" + net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + x = torch.randn(5, self.in_dim, dtype=torch.float64, device=env.DEVICE) + out = net(x) + self.assertIsInstance(out, torch.Tensor) + self.assertEqual(out.shape, (5, self.neuron[-1])) + self.assertEqual(out.dtype, torch.float64) + + def test_serialization_round_trip_pt_expt(self) -> None: + """Test pt_expt EmbeddingNet serialization/deserialization.""" + net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + x = torch.randn(5, self.in_dim, dtype=torch.float64, device=env.DEVICE) + out1 = net(x) + + # Serialize and deserialize + serialized = net.serialize() + net2 = EmbeddingNet.deserialize(serialized) + + # Verify layers are still pt_expt NativeLayer modules + self.assertIsInstance(net2.layers, torch.nn.ModuleList) + for layer in net2.layers: + self.assertIsInstance(layer, NativeLayer) + + out2 = net2(x) + np.testing.assert_allclose( + out1.detach().cpu().numpy(), + out2.detach().cpu().numpy(), + ) + + def test_deserialize_preserves_layer_type(self) -> None: + """Test that deserialize uses type(obj.layers[0]) to preserve subclass layers. + + This is the key fix: dpmodel's deserialize no longer hardcodes + super(EmbeddingNet, obj).__init__(layers), which would overwrite + pt_expt's converted layers. Instead it uses type(obj.layers[0]) + to respect the subclass's layer type. + """ + # Create pt_expt EmbeddingNet + net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Verify layers are pt_expt NativeLayer (torch modules) + for layer in net.layers: + self.assertIsInstance(layer, torch.nn.Module) + self.assertTrue(hasattr(layer, "_parameters")) + + # Deserialize + serialized = net.serialize() + net2 = EmbeddingNet.deserialize(serialized) + + # Verify deserialized layers are STILL pt_expt NativeLayer, not dpmodel + for layer in net2.layers: + self.assertIsInstance(layer, torch.nn.Module) + self.assertTrue(hasattr(layer, "_parameters")) + # This would fail if deserialize used hardcoded dpmodel layers + self.assertIsInstance(layer, NativeLayer) + + def test_cross_backend_consistency(self) -> None: + """Test numerical consistency between dpmodel and pt_expt EmbeddingNet.""" + # Create both with same seed + dp_net = DPEmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + pt_net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Test forward pass + rng = np.random.default_rng() + x_np = rng.standard_normal((5, self.in_dim)) + x_torch = torch.from_numpy(x_np).to(env.DEVICE) + + out_dp = dp_net.call(x_np) + out_pt = pt_net(x_torch).detach().cpu().numpy() + + np.testing.assert_allclose(out_dp, out_pt, rtol=1e-10, atol=1e-10) + + def test_registry_converts_dpmodel_to_pt_expt(self) -> None: + """Test that the registry auto-converts dpmodel EmbeddingNet to pt_expt.""" + from deepmd.pt_expt.common import ( + try_convert_module, + ) + + # Create dpmodel EmbeddingNet + dp_net = DPEmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Try to convert via registry + converted = try_convert_module(dp_net) + + # Should return pt_expt EmbeddingNet + self.assertIsNotNone(converted) + self.assertIsInstance(converted, torch.nn.Module) + self.assertIsInstance(converted, EmbeddingNet) + + # Verify layers are pt_expt NativeLayer + for layer in converted.layers: + self.assertIsInstance(layer, NativeLayer) + self.assertIsInstance(layer, torch.nn.Module) + + def test_auto_conversion_in_setattr(self) -> None: + """Test that dpmodel_setattr auto-converts EmbeddingNet attributes.""" + from deepmd.pt_expt.common import ( + dpmodel_setattr, + ) + + # Create a simple torch module + class TestModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.dummy = None + + obj = TestModule() + + # Create dpmodel EmbeddingNet + dp_net = DPEmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Use dpmodel_setattr to set it + handled, value = dpmodel_setattr(obj, "embedding_net", dp_net) + + # Should not be handled (returns converted value for caller to set) + self.assertFalse(handled) + # Value should be converted to pt_expt EmbeddingNet + self.assertIsInstance(value, torch.nn.Module) + self.assertIsInstance(value, EmbeddingNet) + + def test_trainable_parameter_handling(self) -> None: + """Test that trainable parameters work correctly in pt_expt.""" + # Test with trainable=True + net_trainable = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + trainable=True, + seed=GLOBAL_SEED, + ) + + # Count trainable parameters + param_count = sum( + p.numel() for p in net_trainable.parameters() if p.requires_grad + ) + self.assertGreater(param_count, 0) + + # Check all layer parameters are trainable + for layer in net_trainable.layers: + if layer.w is not None: + self.assertTrue(layer.w.requires_grad) + if layer.b is not None: + self.assertTrue(layer.b.requires_grad) + + # Test with trainable=False + net_frozen = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + trainable=False, + seed=GLOBAL_SEED, + ) + + # Count trainable parameters (should be 0) + param_count_frozen = sum( + p.numel() for p in net_frozen.parameters() if p.requires_grad + ) + self.assertEqual(param_count_frozen, 0) + + # Check all layer weights are buffers, not parameters + for layer in net_frozen.layers: + if layer.w is not None: + self.assertFalse(layer.w.requires_grad)