Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 18 additions & 25 deletions deepmd/pd/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(
self.restart_training = restart_model is not None
model_params = config["model"]
training_params = config["training"]
optimizer_params = config.get("optimizer", {})
Comment thread
OutisLi marked this conversation as resolved.
self.multi_task = "model_dict" in model_params
self.finetune_links = finetune_links
self.finetune_update_stat = False
Expand Down Expand Up @@ -157,14 +158,17 @@ def __init__(
self.lcurve_should_print_header = True

def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
opt_type = params.get("opt_type", "Adam")
opt_param = {
"kf_blocksize": params.get("kf_blocksize", 5120),
"kf_start_pref_e": params.get("kf_start_pref_e", 1),
"kf_limit_pref_e": params.get("kf_limit_pref_e", 1),
"kf_start_pref_f": params.get("kf_start_pref_f", 1),
"kf_limit_pref_f": params.get("kf_limit_pref_f", 1),
}
"""
Extract optimizer parameters.

Note: Default values are already filled by argcheck.normalize()
before this function is called.
"""
opt_type = params.get("type", "Adam")
if opt_type != "Adam":
raise ValueError(f"Not supported optimizer type '{opt_type}'")
Comment thread
OutisLi marked this conversation as resolved.
opt_param = dict(params)
opt_param.pop("type", None)
return opt_type, opt_param

def get_data_loader(
Expand Down Expand Up @@ -256,22 +260,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
return lr_schedule

# Optimizer
if self.multi_task and training_params.get("optim_dict", None) is not None:
self.optim_dict = training_params.get("optim_dict")
missing_keys = [
key for key in self.model_keys if key not in self.optim_dict
]
assert not missing_keys, (
f"These keys are not in optim_dict: {missing_keys}!"
)
self.opt_type = {}
self.opt_param = {}
for model_key in self.model_keys:
self.opt_type[model_key], self.opt_param[model_key] = get_opt_param(
self.optim_dict[model_key]
)
else:
self.opt_type, self.opt_param = get_opt_param(training_params)
self.opt_type, self.opt_param = get_opt_param(optimizer_params)

# loss_param_tmp for Hessian activation
loss_param_tmp = None
Expand Down Expand Up @@ -677,7 +666,11 @@ def single_model_finetune(
),
)
self.optimizer = paddle.optimizer.Adam(
learning_rate=self.scheduler, parameters=self.wrapper.parameters()
learning_rate=self.scheduler,
parameters=self.wrapper.parameters(),
beta1=float(self.opt_param["adam_beta1"]),
beta2=float(self.opt_param["adam_beta2"]),
weight_decay=float(self.opt_param["weight_decay"]),
)
if optimizer_state_dict is not None and self.restart_training:
self.optimizer.set_state_dict(optimizer_state_dict)
Expand Down
146 changes: 49 additions & 97 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def __init__(
self.restart_training = restart_model is not None
model_params = config["model"]
training_params = config["training"]
optimizer_params = config.get("optimizer", {})
self.multi_task = "model_dict" in model_params
self.finetune_links = finetune_links
self.finetune_update_stat = False
Expand Down Expand Up @@ -185,26 +186,17 @@ def __init__(
self.lcurve_should_print_header = True

def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
opt_type = params.get("opt_type", "Adam")
opt_param = {
# LKF parameters
"kf_blocksize": params.get("kf_blocksize", 5120),
"kf_start_pref_e": params.get("kf_start_pref_e", 1),
"kf_limit_pref_e": params.get("kf_limit_pref_e", 1),
"kf_start_pref_f": params.get("kf_start_pref_f", 1),
"kf_limit_pref_f": params.get("kf_limit_pref_f", 1),
# Common parameters
"weight_decay": params.get("weight_decay", 0.001),
# Muon/AdaMuon parameters
"momentum": params.get("momentum", 0.95),
"adam_beta1": params.get("adam_beta1", 0.9),
"adam_beta2": params.get("adam_beta2", 0.95),
"lr_adjust": params.get("lr_adjust", 10.0),
"lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2),
"muon_2d_only": params.get("muon_2d_only", True),
"min_2d_dim": params.get("min_2d_dim", 1),
"flash_muon": params.get("flash_muon", True),
}
"""
Extract optimizer parameters.

Note: Default values are already filled by argcheck.normalize()
before this function is called.
"""
opt_type = params.get("type", "Adam")
if opt_type not in ("Adam", "AdamW", "LKF", "AdaMuon", "HybridMuon"):
raise ValueError(f"Not supported optimizer type '{opt_type}'")
opt_param = dict(params)
opt_param.pop("type", None)
return opt_type, opt_param
Comment thread
OutisLi marked this conversation as resolved.

def cycle_iterator(iterable: Iterable) -> Generator[Any, None, None]:
Expand Down Expand Up @@ -313,22 +305,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
return lr_schedule

# Optimizer
if self.multi_task and training_params.get("optim_dict", None) is not None:
self.optim_dict = training_params.get("optim_dict")
missing_keys = [
key for key in self.model_keys if key not in self.optim_dict
]
assert not missing_keys, (
f"These keys are not in optim_dict: {missing_keys}!"
)
self.opt_type = {}
self.opt_param = {}
for model_key in self.model_keys:
self.opt_type[model_key], self.opt_param[model_key] = get_opt_param(
self.optim_dict[model_key]
)
else:
self.opt_type, self.opt_param = get_opt_param(training_params)
self.opt_type, self.opt_param = get_opt_param(optimizer_params)
if self.zero_stage > 0 and self.multi_task:
raise ValueError(
"training.zero_stage is currently only supported in single-task training."
Expand Down Expand Up @@ -792,71 +769,48 @@ def single_model_finetune(
# TODO add optimizers for multitask
# author: iProzd
initial_lr = self.lr_schedule.value(self.start_step)
if self.opt_type in ["Adam", "AdamW"]:
# Initialize optimizer with the actual learning rate at start_step
# to ensure warmup is applied from the first step
if self.opt_type == "Adam":
self.optimizer = self._create_optimizer(
torch.optim.Adam,
lr=initial_lr,
fused=DEVICE.type != "cpu",
)
else:
self.optimizer = self._create_optimizer(
torch.optim.AdamW,
lr=initial_lr,
weight_decay=float(self.opt_param["weight_decay"]),
fused=DEVICE.type != "cpu",
)
self._load_optimizer_state(optimizer_state_dict)
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
self.optimizer,
lambda step: (
self.lr_schedule.value(step + self.start_step) / initial_lr
),
last_epoch=self.start_step - 1,
)
elif self.opt_type == "LKF":
if self.opt_type == "LKF":
self.optimizer = LKFOptimizer(
self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"]
)
elif self.opt_type == "AdaMuon":
self.optimizer = self._create_optimizer(
AdaMuonOptimizer,
lr=initial_lr,
momentum=float(self.opt_param["momentum"]),
weight_decay=float(self.opt_param["weight_decay"]),
adam_betas=(
float(self.opt_param["adam_beta1"]),
float(self.opt_param["adam_beta2"]),
),
lr_adjust=float(self.opt_param["lr_adjust"]),
lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
)
if optimizer_state_dict is not None and self.restart_training:
self.optimizer.load_state_dict(optimizer_state_dict)
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
self.optimizer,
lambda step: (
self.lr_schedule.value(step + self.start_step) / initial_lr
),
last_epoch=self.start_step - 1,
else:
# === Common path for gradient-based optimizers ===
adam_betas = (
float(self.opt_param["adam_beta1"]),
float(self.opt_param["adam_beta2"]),
)
elif self.opt_type == "HybridMuon":
weight_decay = float(self.opt_param["weight_decay"])

if self.opt_type in ("Adam", "AdamW"):
cls = torch.optim.Adam if self.opt_type == "Adam" else torch.optim.AdamW
extra = {"betas": adam_betas, "fused": DEVICE.type != "cpu"}
elif self.opt_type == "AdaMuon":
cls = AdaMuonOptimizer
extra = {
"adam_betas": adam_betas,
"momentum": float(self.opt_param["momentum"]),
"lr_adjust": float(self.opt_param["lr_adjust"]),
"lr_adjust_coeff": float(self.opt_param["lr_adjust_coeff"]),
}
elif self.opt_type == "HybridMuon":
cls = HybridMuonOptimizer
extra = {
"adam_betas": adam_betas,
"momentum": float(self.opt_param["momentum"]),
"lr_adjust": float(self.opt_param["lr_adjust"]),
"lr_adjust_coeff": float(self.opt_param["lr_adjust_coeff"]),
"muon_2d_only": bool(self.opt_param["muon_2d_only"]),
"min_2d_dim": int(self.opt_param["min_2d_dim"]),
"flash_muon": bool(self.opt_param["flash_muon"]),
}
else:
raise ValueError(f"Not supported optimizer type '{self.opt_type}'")

self.optimizer = self._create_optimizer(
HybridMuonOptimizer,
cls,
lr=initial_lr,
momentum=float(self.opt_param["momentum"]),
weight_decay=float(self.opt_param["weight_decay"]),
adam_betas=(
float(self.opt_param["adam_beta1"]),
float(self.opt_param["adam_beta2"]),
),
lr_adjust=float(self.opt_param["lr_adjust"]),
lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
muon_2d_only=bool(self.opt_param["muon_2d_only"]),
min_2d_dim=int(self.opt_param["min_2d_dim"]),
flash_muon=bool(self.opt_param["flash_muon"]),
weight_decay=weight_decay,
**extra,
)
self._load_optimizer_state(optimizer_state_dict)
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
Expand All @@ -866,8 +820,6 @@ def single_model_finetune(
),
last_epoch=self.start_step - 1,
)
else:
raise ValueError(f"Not supported optimizer type '{self.opt_type}'")

if self.zero_stage > 0 and self.rank == 0:
if self.zero_stage == 1:
Expand Down
1 change: 0 additions & 1 deletion deepmd/tf/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ def train(
jdata["model"] = json.loads(t_training_script)["model"]

jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")

jdata = normalize(jdata)

if not is_compress and not skip_neighbor_stat:
Expand Down
36 changes: 33 additions & 3 deletions deepmd/tf/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,22 @@ def get_lr_and_coef(
# learning rate
lr_param = jdata["learning_rate"]
self.lr, self.scale_lr_coef = get_lr_and_coef(lr_param)
# optimizer
# Note: Default values are already filled by argcheck.normalize()
optimizer_param = jdata.get("optimizer", {})
self.optimizer_type = optimizer_param.get("type", "Adam")
self.optimizer_beta1 = float(optimizer_param.get("adam_beta1"))
self.optimizer_beta2 = float(optimizer_param.get("adam_beta2"))
self.optimizer_weight_decay = float(optimizer_param.get("weight_decay"))
Comment thread
OutisLi marked this conversation as resolved.
if self.optimizer_type != "Adam":
raise RuntimeError(
f"Unsupported optimizer type {self.optimizer_type} for TensorFlow backend."
)
if self.optimizer_weight_decay != 0.0:
raise RuntimeError(
"TensorFlow Adam optimizer does not support weight_decay. "
"Set optimizer/weight_decay to 0."
)
# loss
# infer loss type by fitting_type
loss_param = jdata.get("loss", {})
Expand Down Expand Up @@ -328,17 +344,31 @@ def _build_network(self, data: DeepmdDataSystem, suffix: str = "") -> None:
log.info("built network")

def _build_optimizer(self) -> Any:
if self.optimizer_type != "Adam":
raise RuntimeError(
f"Unsupported optimizer type {self.optimizer_type} for TensorFlow backend."
)
if self.run_opt.is_distrib:
if self.scale_lr_coef > 1.0:
log.info("Scale learning rate by coef: %f", self.scale_lr_coef)
optimizer = tf.train.AdamOptimizer(
self.learning_rate * self.scale_lr_coef
self.learning_rate * self.scale_lr_coef,
beta1=self.optimizer_beta1,
beta2=self.optimizer_beta2,
)
else:
optimizer = tf.train.AdamOptimizer(self.learning_rate)
optimizer = tf.train.AdamOptimizer(
self.learning_rate,
beta1=self.optimizer_beta1,
beta2=self.optimizer_beta2,
)
optimizer = self.run_opt._HVD.DistributedOptimizer(optimizer)
else:
optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
optimizer = tf.train.AdamOptimizer(
learning_rate=self.learning_rate,
beta1=self.optimizer_beta1,
beta2=self.optimizer_beta2,
)

if self.mixed_prec is not None:
_TF_VERSION = Version(TF_VERSION)
Expand Down
2 changes: 2 additions & 0 deletions deepmd/tf/utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from deepmd.utils.compat import (
convert_input_v0_v1,
convert_input_v1_v2,
convert_optimizer_v31_to_v32,
deprecate_numb_test,
update_deepmd_input,
)

__all__ = [
"convert_input_v0_v1",
"convert_input_v1_v2",
"convert_optimizer_v31_to_v32",
"deprecate_numb_test",
"update_deepmd_input",
]
Loading