Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,9 @@ def main_parser() -> argparse.ArgumentParser:
parser_train.add_argument(
"--use-pretrain-script",
action="store_true",
help="Use model parameters from the script of the pretrained model instead of user input when doing finetuning. Note: This behavior is default and unchangeable in TensorFlow.",
help="When performing fine-tuning or init-model, "
"utilize the model parameters provided by the script of the pretrained model rather than relying on user input. "
"It is important to note that in TensorFlow, this behavior is the default and cannot be modified for fine-tuning. ",
)
parser_train.add_argument(
"-o",
Expand Down
15 changes: 15 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,21 @@ def train(FLAGS):
model_branch=FLAGS.model_branch,
change_model_params=FLAGS.use_pretrain_script,
)
# update init_model or init_frz_model config if necessary
if (
FLAGS.init_model is not None or FLAGS.init_frz_model is not None
) and FLAGS.use_pretrain_script:
if FLAGS.init_model is not None:
init_state_dict = torch.load(FLAGS.init_model, map_location=DEVICE)
if "model" in init_state_dict:
init_state_dict = init_state_dict["model"]
config["model"] = init_state_dict["_extra_state"]["model_params"]
else:
config["model"] = json.loads(
torch.jit.load(
FLAGS.init_frz_model, map_location=DEVICE
).get_model_def_script()
)

# argcheck
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
Expand Down
39 changes: 39 additions & 0 deletions deepmd/tf/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def train(
is_compress: bool = False,
skip_neighbor_stat: bool = False,
finetune: Optional[str] = None,
use_pretrain_script: bool = False,
**kwargs,
):
"""Run DeePMD model training.
Expand Down Expand Up @@ -93,6 +94,9 @@ def train(
skip checking neighbor statistics
finetune : Optional[str]
path to pretrained model or None
use_pretrain_script : bool
Whether to use model script in pretrained model when doing init-model or init-frz-model.
Note that this option is true and unchangeable for fine-tuning.
**kwargs
additional arguments

Expand Down Expand Up @@ -123,6 +127,41 @@ def train(
jdata, run_opt.finetune
)

if (
run_opt.init_model is not None or run_opt.init_frz_model is not None
) and use_pretrain_script:
from deepmd.tf.utils.errors import (
GraphWithoutTensorError,
)
from deepmd.tf.utils.graph import (
get_tensor_by_name,
get_tensor_by_name_from_graph,
)

err_msg = (
f"The input model: {run_opt.init_model if run_opt.init_model is not None else run_opt.init_frz_model} has no training script, "
f"Please use the model pretrained with v2.1.5 or higher version of DeePMD-kit."
)
if run_opt.init_model is not None:
with tf.Graph().as_default() as graph:
tf.train.import_meta_graph(
f"{run_opt.init_model}.meta", clear_devices=True
)
try:
t_training_script = get_tensor_by_name_from_graph(
graph, "train_attr/training_script"
)
except GraphWithoutTensorError as e:
raise RuntimeError(err_msg) from e
else:
try:
t_training_script = get_tensor_by_name(
run_opt.init_frz_model, "train_attr/training_script"
)
except GraphWithoutTensorError as e:
raise RuntimeError(err_msg) from e
jdata["model"] = json.loads(t_training_script)["model"]

jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")

jdata = normalize(jdata)
Expand Down
56 changes: 49 additions & 7 deletions source/tests/pt/test_init_frz_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import os
import shutil
import tempfile
import unittest
from argparse import (
Namespace,
Expand All @@ -21,12 +24,17 @@
DeepPot,
)

from .common import (
run_dp,
)


class TestInitFrzModel(unittest.TestCase):
def setUp(self):
input_json = str(Path(__file__).parent / "water/se_atten.json")
with open(input_json) as f:
config = json.load(f)
config["model"]["descriptor"]["smooth_type_embedding"] = True
config["training"]["numb_steps"] = 1
config["training"]["save_freq"] = 1
config["learning_rate"]["start_lr"] = 1.0
Expand All @@ -38,15 +46,30 @@ def setUp(self):
]

self.models = []
for imodel in range(2):
if imodel == 1:
config["training"]["numb_steps"] = 0
trainer = get_trainer(deepcopy(config), init_frz_model=self.models[-1])
for imodel in range(3):
frozen_model = f"frozen_model{imodel}.pth"
if imodel == 0:
temp_config = deepcopy(config)
trainer = get_trainer(temp_config)
elif imodel == 1:
temp_config = deepcopy(config)
temp_config["training"]["numb_steps"] = 0
trainer = get_trainer(temp_config, init_frz_model=self.models[-1])
else:
trainer = get_trainer(deepcopy(config))
trainer.run()
empty_config = deepcopy(config)
empty_config["model"]["descriptor"] = {}
empty_config["model"]["fitting_net"] = {}
empty_config["training"]["numb_steps"] = 0
tmp_input = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
with open(tmp_input.name, "w") as f:
json.dump(empty_config, f, indent=4)
run_dp(
f"dp --pt train {tmp_input.name} --init-frz-model {self.models[-1]} --use-pretrain-script --skip-neighbor-stat"
)
trainer = None

frozen_model = f"frozen_model{imodel}.pth"
if imodel in [0, 1]:
trainer.run()
ns = Namespace(
model="model.pt",
output=frozen_model,
Expand All @@ -58,6 +81,7 @@ def setUp(self):
def test_dp_test(self):
dp1 = DeepPot(str(self.models[0]))
dp2 = DeepPot(str(self.models[1]))
dp3 = DeepPot(str(self.models[2]))
cell = np.array(
[
5.122106549439247480e00,
Expand Down Expand Up @@ -96,8 +120,26 @@ def test_dp_test(self):
e1, f1, v1, ae1, av1 = ret1[0], ret1[1], ret1[2], ret1[3], ret1[4]
ret2 = dp2.eval(coord, cell, atype, atomic=True)
e2, f2, v2, ae2, av2 = ret2[0], ret2[1], ret2[2], ret2[3], ret2[4]
ret3 = dp3.eval(coord, cell, atype, atomic=True)
e3, f3, v3, ae3, av3 = ret3[0], ret3[1], ret3[2], ret3[3], ret3[4]
np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(e1, e3, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(f1, f3, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(v1, v3, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(ae1, ae2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(ae1, ae3, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(av1, av2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(av1, av3, rtol=1e-10, atol=1e-10)

def tearDown(self):
for f in os.listdir("."):
if f.startswith("frozen_model") and f.endswith(".pth"):
os.remove(f)
if f.startswith("model") and f.endswith(".pt"):
os.remove(f)
if f in ["lcurve.out"]:
os.remove(f)
if f in ["stat_files"]:
shutil.rmtree(f)
136 changes: 136 additions & 0 deletions source/tests/pt/test_init_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import os
import shutil
import tempfile
import unittest
from copy import (
deepcopy,
)
from pathlib import (
Path,
)

import numpy as np

from deepmd.pt.entrypoints.main import (
get_trainer,
)
from deepmd.pt.infer.deep_eval import (
DeepPot,
)

from .common import (
run_dp,
)


class TestInitModel(unittest.TestCase):
    """Check that models started from a checkpoint reproduce the original model.

    Three checkpoints are produced in ``setUp``:

    * model 0 -- trained for one step from the ``water/se_atten.json`` input;
    * model 1 -- created with ``get_trainer(..., init_model=<model 0>)`` and
      ``numb_steps = 0``;
    * model 2 -- created through the ``dp --pt train`` command line with
      ``--init-model <model 1> --use-pretrain-script`` from an input whose
      ``descriptor`` and ``fitting_net`` sections are empty, also with
      ``numb_steps = 0``.

    ``test_dp_test`` then asserts that all three give the same predictions.
    """

    def setUp(self):
        """Build the three checkpoints and record their paths in ``self.models``."""
        input_json = str(Path(__file__).parent / "water/se_atten.json")
        with open(input_json) as f:
            config = json.load(f)
        data_dir = str(Path(__file__).parent / "water/data/single")
        config["model"]["descriptor"]["smooth_type_embedding"] = True
        config["training"]["numb_steps"] = 1
        config["training"]["save_freq"] = 1
        config["learning_rate"]["start_lr"] = 1.0
        config["training"]["training_data"]["systems"] = [data_dir]
        config["training"]["validation_data"]["systems"] = [data_dir]

        self.models = []
        for imodel in range(3):
            ckpt_model = f"model{imodel}.ckpt"
            if imodel == 0:
                # Reference model: plain training from the input script.
                temp_config = deepcopy(config)
                temp_config["training"]["save_ckpt"] = ckpt_model
                get_trainer(temp_config).run()
            elif imodel == 1:
                # Zero further steps, initialised from the previous checkpoint.
                temp_config = deepcopy(config)
                temp_config["training"]["numb_steps"] = 0
                temp_config["training"]["save_ckpt"] = ckpt_model
                get_trainer(temp_config, init_model=self.models[-1]).run()
            else:
                # The model sections are emptied on purpose: with
                # --use-pretrain-script they are expected to be taken from the
                # checkpoint given to --init-model instead of this input.
                empty_config = deepcopy(config)
                empty_config["model"]["descriptor"] = {}
                empty_config["model"]["fitting_net"] = {}
                empty_config["training"]["numb_steps"] = 0
                empty_config["training"]["save_ckpt"] = ckpt_model
                # Write through the handle and close it before the command
                # runs; delete=False keeps the file on disk after closing.
                with tempfile.NamedTemporaryFile(
                    "w", delete=False, suffix=".json"
                ) as tmp_input:
                    # Registered first so the file is removed even if a later
                    # step of setUp raises (tearDown is skipped in that case).
                    self.addCleanup(os.remove, tmp_input.name)
                    json.dump(empty_config, tmp_input, indent=4)
                run_dp(
                    f"dp --pt train {tmp_input.name} --init-model {self.models[-1]} --use-pretrain-script --skip-neighbor-stat"
                )
            self.models.append(ckpt_model + ".pt")

    def test_dp_test(self):
        """Assert that models 1 and 2 match model 0 on a fixed 5-atom frame."""
        dp1 = DeepPot(str(self.models[0]))
        dp2 = DeepPot(str(self.models[1]))
        dp3 = DeepPot(str(self.models[2]))
        cell = np.array(
            [
                5.122106549439247480e00,
                4.016537340154059388e-01,
                6.951654033828678081e-01,
                4.016537340154059388e-01,
                6.112136112297989143e00,
                8.178091365465004481e-01,
                6.951654033828678081e-01,
                8.178091365465004481e-01,
                6.159552512682983760e00,
            ]
        ).reshape(1, 3, 3)
        coord = np.array(
            [
                2.978060152121375648e00,
                3.588469695887098077e00,
                2.792459820604495491e00,
                3.895592322591093115e00,
                2.712091020667753760e00,
                1.366836847133650501e00,
                9.955616170888935690e-01,
                4.121324820711413039e00,
                1.817239061889086571e00,
                3.553661462345699906e00,
                5.313046969500791583e00,
                6.635182659098815883e00,
                6.088601018589653080e00,
                6.575011420004332585e00,
                6.825240650611076099e00,
            ]
        ).reshape(1, -1, 3)
        atype = np.array([0, 0, 0, 1, 1]).reshape(1, -1)

        ret1 = dp1.eval(coord, cell, atype, atomic=True)
        ret2 = dp2.eval(coord, cell, atype, atomic=True)
        ret3 = dp3.eval(coord, cell, atype, atomic=True)
        # Compare the first five returned arrays of each derived model with
        # the reference; the label names the mismatching quantity on failure.
        for idx, label in enumerate(("e", "f", "v", "ae", "av")):
            for tag, ret in (("2", ret2), ("3", ret3)):
                np.testing.assert_allclose(
                    ret1[idx],
                    ret[idx],
                    rtol=1e-10,
                    atol=1e-10,
                    err_msg=f"{label}1 != {label}{tag}",
                )

    def tearDown(self):
        """Remove checkpoints and training by-products from the working directory."""
        for f in os.listdir("."):
            if f.startswith("model") and f.endswith(".pt"):
                os.remove(f)
            elif f == "lcurve.out":
                os.remove(f)
            elif f == "stat_files":
                shutil.rmtree(f)