Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions deepmd_utils/model_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ def deserialize(cls, data: dict) -> "NativeLayer":
"""
return cls(
w=data["@variables"]["w"],
b=data["@variables"]["b"],
idt=data.get("idt", None),
b=data["@variables"].get("b", None),
idt=data["@variables"].get("idt", None),
activation_function=data["activation_function"],
resnet=data.get("resnet", False),
)
Expand Down Expand Up @@ -241,7 +241,7 @@ def call(self, x: np.ndarray) -> np.ndarray:
np.ndarray
The output.
"""
if self.w is None or self.b is None or self.activation_function is None:
if self.w is None or self.activation_function is None:
raise ValueError("w, b, and activation_function must be set")
if self.activation_function == "tanh":
fn = np.tanh
Expand All @@ -251,7 +251,12 @@ def fn(x):
return x
else:
raise NotImplementedError(self.activation_function)
y = fn(np.matmul(x, self.w) + self.b)
y = (
np.matmul(x, self.w) + self.b
if self.b is not None
else np.matmul(x, self.w)
)
y = fn(y)
if self.idt is not None:
y *= self.idt
if self.resnet and self.w.shape[1] == self.w.shape[0]:
Expand Down
21 changes: 20 additions & 1 deletion source/tests/test_model_format_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import itertools
import os
import unittest
from copy import (
Expand All @@ -8,16 +9,34 @@
import numpy as np

from deepmd_utils.model_format import (
NativeLayer,
NativeNet,
load_dp_model,
save_dp_model,
)


class TestNativeLayer(unittest.TestCase):
    """Round-trip serialize/deserialize checks for a single NativeLayer."""

    def setUp(self) -> None:
        # Shared fixtures: weight matrix (2, 3), bias (3,), idt vector (3,).
        self.w = np.full((2, 3), 3.0)
        self.b = np.full((3,), 4.0)
        self.idt = np.full((3,), 5.0)

    def test_serialize_deserize(self):
        # Cover every combination of optional bias/idt, activation, and
        # the resnet flag; the weight is always present.
        bias_options = (self.b, None)
        idt_options = (self.idt, None)
        activation_options = ("tanh", "none")
        resnet_options = (True, False)
        for bias in bias_options:
            for idt in idt_options:
                for activation in activation_options:
                    for use_resnet in resnet_options:
                        original = NativeLayer(
                            self.w, bias, idt, activation, use_resnet
                        )
                        restored = NativeLayer.deserialize(original.serialize())
                        x = np.arange(self.w.shape[0])
                        # A layer rebuilt from its serialized form must give
                        # the same output as the layer it came from.
                        np.testing.assert_allclose(
                            original.call(x), restored.call(x)
                        )


class TestNativeNet(unittest.TestCase):
def setUp(self) -> None:
self.w = np.full((3, 2), 3.0)
self.w = np.full((2, 3), 3.0)
self.b = np.full((3,), 4.0)
self.idt = np.full((3,), 5.0)

def test_serialize(self):
network = NativeNet()
Expand Down