diff --git a/deepmd/main.py b/deepmd/main.py index 4560df9e57..b43f4f8fd5 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -258,7 +258,9 @@ def main_parser() -> argparse.ArgumentParser: parser_train.add_argument( "--use-pretrain-script", action="store_true", - help="Use model parameters from the script of the pretrained model instead of user input when doing finetuning. Note: This behavior is default and unchangeable in TensorFlow.", + help="When performing fine-tuning or init-model, " + "utilize the model parameters provided by the script of the pretrained model rather than relying on user input. " + "It is important to note that in TensorFlow, this behavior is the default and cannot be modified for fine-tuning. ", ) parser_train.add_argument( "-o", diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 1342f928c5..f125c7f1ad 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -243,6 +243,21 @@ def train(FLAGS): model_branch=FLAGS.model_branch, change_model_params=FLAGS.use_pretrain_script, ) + # update init_model or init_frz_model config if necessary + if ( + FLAGS.init_model is not None or FLAGS.init_frz_model is not None + ) and FLAGS.use_pretrain_script: + if FLAGS.init_model is not None: + init_state_dict = torch.load(FLAGS.init_model, map_location=DEVICE) + if "model" in init_state_dict: + init_state_dict = init_state_dict["model"] + config["model"] = init_state_dict["_extra_state"]["model_params"] + else: + config["model"] = json.loads( + torch.jit.load( + FLAGS.init_frz_model, map_location=DEVICE + ).get_model_def_script() + ) # argcheck config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json") diff --git a/deepmd/tf/entrypoints/train.py b/deepmd/tf/entrypoints/train.py index d394773cf2..12a3c59d70 100755 --- a/deepmd/tf/entrypoints/train.py +++ b/deepmd/tf/entrypoints/train.py @@ -65,6 +65,7 @@ def train( is_compress: bool = False, skip_neighbor_stat: bool = False, finetune: Optional[str] = None, + use_pretrain_script: bool = False, **kwargs, ): """Run DeePMD model training. @@ -93,6 +94,9 @@ def train( skip checking neighbor statistics finetune : Optional[str] path to pretrained model or None + use_pretrain_script : bool + Whether to use model script in pretrained model when doing init-model or init-frz-model. + Note that this option is true and unchangeable for fine-tuning. **kwargs additional arguments @@ -123,6 +127,41 @@ def train( jdata, run_opt.finetune ) + if ( + run_opt.init_model is not None or run_opt.init_frz_model is not None + ) and use_pretrain_script: + from deepmd.tf.utils.errors import ( + GraphWithoutTensorError, + ) + from deepmd.tf.utils.graph import ( + get_tensor_by_name, + get_tensor_by_name_from_graph, + ) + + err_msg = ( + f"The input model: {run_opt.init_model if run_opt.init_model is not None else run_opt.init_frz_model} has no training script, " + f"Please use the model pretrained with v2.1.5 or higher version of DeePMD-kit." + ) + if run_opt.init_model is not None: + with tf.Graph().as_default() as graph: + tf.train.import_meta_graph( + f"{run_opt.init_model}.meta", clear_devices=True + ) + try: + t_training_script = get_tensor_by_name_from_graph( + graph, "train_attr/training_script" + ) + except GraphWithoutTensorError as e: + raise RuntimeError(err_msg) from e + else: + try: + t_training_script = get_tensor_by_name( + run_opt.init_frz_model, "train_attr/training_script" + ) + except GraphWithoutTensorError as e: + raise RuntimeError(err_msg) from e + jdata["model"] = json.loads(t_training_script)["model"] + jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json") jdata = normalize(jdata) diff --git a/source/tests/pt/test_init_frz_model.py b/source/tests/pt/test_init_frz_model.py index 223b28515d..1cbc1b29b6 100644 --- a/source/tests/pt/test_init_frz_model.py +++ b/source/tests/pt/test_init_frz_model.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import json +import os +import shutil +import tempfile import unittest from argparse import ( Namespace, @@ -21,12 +24,17 @@ DeepPot, ) +from .common import ( + run_dp, +) + class TestInitFrzModel(unittest.TestCase): def setUp(self): input_json = str(Path(__file__).parent / "water/se_atten.json") with open(input_json) as f: config = json.load(f) + config["model"]["descriptor"]["smooth_type_embedding"] = True config["training"]["numb_steps"] = 1 config["training"]["save_freq"] = 1 config["learning_rate"]["start_lr"] = 1.0 @@ -38,15 +46,30 @@ def setUp(self): ] self.models = [] - for imodel in range(2): - if imodel == 1: - config["training"]["numb_steps"] = 0 - trainer = get_trainer(deepcopy(config), init_frz_model=self.models[-1]) + for imodel in range(3): + frozen_model = f"frozen_model{imodel}.pth" + if imodel == 0: + temp_config = deepcopy(config) + trainer = get_trainer(temp_config) + elif imodel == 1: + temp_config = deepcopy(config) + temp_config["training"]["numb_steps"] = 0 + trainer = get_trainer(temp_config, init_frz_model=self.models[-1]) else: - trainer = get_trainer(deepcopy(config)) - trainer.run() + empty_config = deepcopy(config) + empty_config["model"]["descriptor"] = {} + empty_config["model"]["fitting_net"] = {} + empty_config["training"]["numb_steps"] = 0 + tmp_input = tempfile.NamedTemporaryFile(delete=False, suffix=".json") + with open(tmp_input.name, "w") as f: + json.dump(empty_config, f, indent=4) + run_dp( + f"dp --pt train {tmp_input.name} --init-frz-model {self.models[-1]} --use-pretrain-script --skip-neighbor-stat" + ) + trainer = None - frozen_model = f"frozen_model{imodel}.pth" + if imodel in [0, 1]: + trainer.run() ns = Namespace( model="model.pt", output=frozen_model, @@ -58,6 +81,7 @@ def setUp(self): def test_dp_test(self): dp1 = DeepPot(str(self.models[0])) dp2 = DeepPot(str(self.models[1])) + dp3 = DeepPot(str(self.models[2])) cell = np.array( [ 5.122106549439247480e00, @@ -96,8 +120,26 @@ def test_dp_test(self): e1, f1, v1, ae1, av1 = ret1[0], ret1[1], ret1[2], ret1[3], ret1[4] ret2 = dp2.eval(coord, cell, atype, atomic=True) e2, f2, v2, ae2, av2 = ret2[0], ret2[1], ret2[2], ret2[3], ret2[4] + ret3 = dp3.eval(coord, cell, atype, atomic=True) + e3, f3, v3, ae3, av3 = ret3[0], ret3[1], ret3[2], ret3[3], ret3[4] np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(e1, e3, rtol=1e-10, atol=1e-10) np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(f1, f3, rtol=1e-10, atol=1e-10) np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(v1, v3, rtol=1e-10, atol=1e-10) np.testing.assert_allclose(ae1, ae2, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(ae1, ae3, rtol=1e-10, atol=1e-10) np.testing.assert_allclose(av1, av2, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(av1, av3, rtol=1e-10, atol=1e-10) + + def tearDown(self): + for f in os.listdir("."): + if f.startswith("frozen_model") and f.endswith(".pth"): + os.remove(f) + if f.startswith("model") and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out"]: + os.remove(f) + if f in ["stat_files"]: + shutil.rmtree(f) diff --git a/source/tests/pt/test_init_model.py b/source/tests/pt/test_init_model.py new file mode 100644 index 0000000000..dd264fbe89 --- /dev/null +++ b/source/tests/pt/test_init_model.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import shutil +import tempfile +import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) + +import numpy as np + +from deepmd.pt.entrypoints.main import ( + get_trainer, +) +from deepmd.pt.infer.deep_eval import ( + DeepPot, +) + +from .common import ( + run_dp, +) + + +class TestInitModel(unittest.TestCase): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + config = json.load(f) + config["model"]["descriptor"]["smooth_type_embedding"] = True + config["training"]["numb_steps"] = 1 + config["training"]["save_freq"] = 1 + config["learning_rate"]["start_lr"] = 1.0 + config["training"]["training_data"]["systems"] = [ + str(Path(__file__).parent / "water/data/single") + ] + config["training"]["validation_data"]["systems"] = [ + str(Path(__file__).parent / "water/data/single") + ] + + self.models = [] + for imodel in range(3): + ckpt_model = f"model{imodel}.ckpt" + if imodel == 0: + temp_config = deepcopy(config) + temp_config["training"]["save_ckpt"] = ckpt_model + trainer = get_trainer(temp_config) + elif imodel == 1: + temp_config = deepcopy(config) + temp_config["training"]["numb_steps"] = 0 + temp_config["training"]["save_ckpt"] = ckpt_model + trainer = get_trainer(temp_config, init_model=self.models[-1]) + else: + empty_config = deepcopy(config) + empty_config["model"]["descriptor"] = {} + empty_config["model"]["fitting_net"] = {} + empty_config["training"]["numb_steps"] = 0 + empty_config["training"]["save_ckpt"] = ckpt_model + tmp_input = tempfile.NamedTemporaryFile(delete=False, suffix=".json") + with open(tmp_input.name, "w") as f: + json.dump(empty_config, f, indent=4) + run_dp( + f"dp --pt train {tmp_input.name} --init-model {self.models[-1]} --use-pretrain-script --skip-neighbor-stat" + ) + trainer = None + + if imodel in [0, 1]: + trainer.run() + self.models.append(ckpt_model + ".pt") + + def test_dp_test(self): + dp1 = DeepPot(str(self.models[0])) + dp2 = DeepPot(str(self.models[1])) + dp3 = DeepPot(str(self.models[2])) + cell = np.array( + [ + 5.122106549439247480e00, + 4.016537340154059388e-01, + 6.951654033828678081e-01, + 4.016537340154059388e-01, + 6.112136112297989143e00, + 8.178091365465004481e-01, + 6.951654033828678081e-01, + 8.178091365465004481e-01, + 6.159552512682983760e00, + ] + ).reshape(1, 3, 3) + coord = np.array( + [ + 2.978060152121375648e00, + 3.588469695887098077e00, + 2.792459820604495491e00, + 3.895592322591093115e00, + 2.712091020667753760e00, + 1.366836847133650501e00, + 9.955616170888935690e-01, + 4.121324820711413039e00, + 1.817239061889086571e00, + 3.553661462345699906e00, + 5.313046969500791583e00, + 6.635182659098815883e00, + 6.088601018589653080e00, + 6.575011420004332585e00, + 6.825240650611076099e00, + ] + ).reshape(1, -1, 3) + atype = np.array([0, 0, 0, 1, 1]).reshape(1, -1) + + ret1 = dp1.eval(coord, cell, atype, atomic=True) + e1, f1, v1, ae1, av1 = ret1[0], ret1[1], ret1[2], ret1[3], ret1[4] + ret2 = dp2.eval(coord, cell, atype, atomic=True) + e2, f2, v2, ae2, av2 = ret2[0], ret2[1], ret2[2], ret2[3], ret2[4] + ret3 = dp3.eval(coord, cell, atype, atomic=True) + e3, f3, v3, ae3, av3 = ret3[0], ret3[1], ret3[2], ret3[3], ret3[4] + np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(e1, e3, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(f1, f3, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(v1, v3, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(ae1, ae2, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(ae1, ae3, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(av1, av2, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(av1, av3, rtol=1e-10, atol=1e-10) + + def tearDown(self): + for f in os.listdir("."): + if f.startswith("model") and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out"]: + os.remove(f) + if f in ["stat_files"]: + shutil.rmtree(f)