Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 37 additions & 31 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,10 @@ def get_loss(loss_params, start_lr, _ntypes, _model):
f"training in {model_key}",
to_numpy_array(self.training_dataloader[model_key].sampler.weights),
)
if validation_data is not None:
if (
validation_data is not None
and validation_data[model_key] is not None
):
validation_data[model_key].print_summary(
f"validation in {model_key}",
to_numpy_array(
Expand Down Expand Up @@ -723,7 +726,7 @@ def log_loss_valid(_task_key="Default"):
)
if input_dict == {}:
# no validation data
return "", None
return {}
_, loss, more_loss = self.wrapper(
**input_dict,
cur_lr=pref_lr,
Expand All @@ -744,23 +747,24 @@ def log_loss_valid(_task_key="Default"):
if not self.multi_task:
train_results = log_loss_train(loss, more_loss)
valid_results = log_loss_valid()
log.info(
format_training_message_per_task(
batch=_step_id,
task_name="trn",
rmse=train_results,
learning_rate=cur_lr,
)
)
if valid_results is not None:
if self.rank == 0:
log.info(
format_training_message_per_task(
batch=_step_id,
task_name="val",
rmse=valid_results,
learning_rate=None,
task_name="trn",
rmse=train_results,
learning_rate=cur_lr,
)
)
if valid_results:
log.info(
format_training_message_per_task(
batch=_step_id,
task_name="val",
rmse=valid_results,
learning_rate=None,
)
)
else:
train_results = {_key: {} for _key in self.model_keys}
valid_results = {_key: {} for _key in self.model_keys}
Expand All @@ -783,33 +787,35 @@ def log_loss_valid(_task_key="Default"):
loss, more_loss, _task_key=_key
)
valid_results[_key] = log_loss_valid(_task_key=_key)
log.info(
format_training_message_per_task(
batch=_step_id,
task_name=_key + "_trn",
rmse=train_results[_key],
learning_rate=cur_lr,
)
)
if valid_results is not None:
if self.rank == 0:
log.info(
format_training_message_per_task(
batch=_step_id,
task_name=_key + "_val",
rmse=valid_results[_key],
learning_rate=None,
task_name=_key + "_trn",
rmse=train_results[_key],
learning_rate=cur_lr,
)
)
if valid_results is not None and valid_results[_key]:
log.info(
format_training_message_per_task(
batch=_step_id,
task_name=_key + "_val",
rmse=valid_results[_key],
learning_rate=None,
)
)

current_time = time.time()
train_time = current_time - self.t0
self.t0 = current_time
log.info(
format_training_message(
batch=_step_id,
wall_time=train_time,
if self.rank == 0:
log.info(
format_training_message(
batch=_step_id,
wall_time=train_time,
)
)
)

if fout:
if self.lcurve_should_print_header:
Expand Down