Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
f1a3a0d
feat: add dos training
anyangml Mar 18, 2024
d4a3965
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2024
3e1f3c6
fix: precommit
anyangml Mar 18, 2024
c4769ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2024
429f03f
feat: add dos stat
anyangml Mar 19, 2024
91d2a8f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
04c7477
fix: training test
anyangml Mar 19, 2024
4d95548
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
d71a117
Merge branch 'devel' into feat/dos-train
anyangml Mar 19, 2024
bc54b68
fix: precommit
anyangml Mar 19, 2024
850ea1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
004cdb2
fix: UTs
anyangml Mar 19, 2024
a116235
fix: UTs
anyangml Mar 19, 2024
dbd2d29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
2ff0ae6
fix: stat
anyangml Mar 20, 2024
b7b69fd
Merge branch 'devel' into feat/dos-train
anyangml Mar 20, 2024
366f6b4
fix: stat
anyangml Mar 20, 2024
915141c
fix: dp test
anyangml Mar 20, 2024
ed65e19
fix: test examples
anyangml Mar 20, 2024
a73d392
fix UTs
anyangml Mar 20, 2024
bf8fac2
Merge branch 'devel' into feat/dos-train
anyangml Mar 20, 2024
3b5be19
Merge branch 'devel' into feat/dos-train
anyangml Mar 20, 2024
1630800
fix: add to test_examples
Mar 20, 2024
be076af
fix: update loss
Mar 20, 2024
36d1674
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 20, 2024
c156b02
fix: add numb_dos to jit model
anyangml Mar 20, 2024
148c196
chore: refactor
anyangml Mar 20, 2024
3ad66fb
fix: UTs
anyangml Mar 20, 2024
b751921
Merge branch 'devel' into feat/dos-train
anyangml Mar 22, 2024
2e14755
fix: loss
anyangml Mar 24, 2024
3c5e6f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def get_sel_type(self) -> List[int]:

def get_numb_dos(self) -> int:
"""Get the number of DOS."""
return 0
return self.dp.model["Default"].get_numb_dos()

def get_has_efield(self):
"""Check if the model has efield."""
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from .denoise import (
DenoiseLoss,
)
from .dos import (
DOSLoss,
)
from .ener import (
EnergyStdLoss,
)
Expand All @@ -21,4 +24,5 @@
"EnergySpinLoss",
"TensorLoss",
"TaskLoss",
"DOSLoss",
]
256 changes: 256 additions & 0 deletions deepmd/pt/loss/dos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
)

import torch

from deepmd.pt.loss.loss import (
TaskLoss,
)
from deepmd.pt.utils import (
env,
)
from deepmd.utils.data import (
DataRequirementItem,
)


class DOSLoss(TaskLoss):
def __init__(
self,
starter_learning_rate: float,
numb_dos: int,
start_pref_dos: float = 1.00,
limit_pref_dos: float = 1.00,
start_pref_cdf: float = 1000,
limit_pref_cdf: float = 1.00,
start_pref_ados: float = 0.0,
limit_pref_ados: float = 0.0,
start_pref_acdf: float = 0.0,
limit_pref_acdf: float = 0.0,
inference=False,
**kwargs,
):
r"""Construct a loss for local and global tensors.

Parameters
----------
tensor_name : str
The name of the tensor in the model predictions to compute the loss.
tensor_size : int
The size (dimension) of the tensor.
label_name : str
The name of the tensor in the labels to compute the loss.
pref_atomic : float
The prefactor of the weight of atomic loss. It should be larger than or equal to 0.
pref : float
The prefactor of the weight of global loss. It should be larger than or equal to 0.
inference : bool
If true, it will output all losses found in output, ignoring the pre-factors.
**kwargs
Other keyword arguments.
"""
super().__init__()
self.starter_learning_rate = starter_learning_rate
self.numb_dos = numb_dos
self.inference = inference

self.start_pref_dos = start_pref_dos
self.limit_pref_dos = limit_pref_dos
self.start_pref_cdf = start_pref_cdf
self.limit_pref_cdf = limit_pref_cdf

self.start_pref_ados = start_pref_ados
self.limit_pref_ados = limit_pref_ados
self.start_pref_acdf = start_pref_acdf
self.limit_pref_acdf = limit_pref_acdf

assert (
self.start_pref_dos >= 0.0
and self.limit_pref_dos >= 0.0
and self.start_pref_cdf >= 0.0
and self.limit_pref_cdf >= 0.0
and self.start_pref_ados >= 0.0
and self.limit_pref_ados >= 0.0
and self.start_pref_acdf >= 0.0
and self.limit_pref_acdf >= 0.0
), "Can not assign negative weight to `pref` and `pref_atomic`"

self.has_dos = (start_pref_dos != 0.0 and limit_pref_dos != 0.0) or inference
self.has_cdf = (start_pref_cdf != 0.0 and limit_pref_cdf != 0.0) or inference
self.has_ados = (start_pref_ados != 0.0 and limit_pref_ados != 0.0) or inference
self.has_acdf = (start_pref_acdf != 0.0 and limit_pref_acdf != 0.0) or inference

assert (
self.has_dos or self.has_cdf or self.has_ados or self.has_acdf
), AssertionError("Can not assian zero weight both to `pref` and `pref_atomic`")

def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False):
"""Return loss on local and global tensors.

Parameters
----------
input_dict : dict[str, torch.Tensor]
Model inputs.
model : torch.nn.Module
Model to be used to output the predictions.
label : dict[str, torch.Tensor]
Labels.
natoms : int
The local atom number.

Returns
-------
model_pred: dict[str, torch.Tensor]
Model predictions.
loss: torch.Tensor
Loss for model to minimize.
more_loss: dict[str, torch.Tensor]
Other losses for display.
"""
model_pred = model(**input_dict)

coef = learning_rate / self.starter_learning_rate
pref_dos = (
Comment thread Fixed
self.limit_pref_dos + (self.start_pref_dos - self.limit_pref_dos) * coef
)
pref_cdf = (
Comment thread Fixed
self.limit_pref_cdf + (self.start_pref_cdf - self.limit_pref_cdf) * coef
)
pref_ados = (
self.limit_pref_ados + (self.start_pref_ados - self.limit_pref_ados) * coef
)
pref_acdf = (
self.limit_pref_acdf + (self.start_pref_acdf - self.limit_pref_acdf) * coef
)

loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
if self.has_ados and "atom_dos" in model_pred and "atom_dos" in label:
find_local = label.get("find_atom_dos", 0.0)
pref_ados = pref_ados * find_local
local_tensor_pred_dos = model_pred["atom_dos"].reshape(
[-1, natoms, self.numb_dos]
)
local_tensor_label_dos = label["atom_dos"].reshape(
[-1, natoms, self.numb_dos]
)
diff = (local_tensor_pred_dos - local_tensor_label_dos).reshape(
[-1, self.numb_dos]
)
if "mask" in model_pred:
diff = diff[model_pred["mask"].reshape([-1]).bool()]
l2_local_loss_dos = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_local_dos_loss"] = self.display_if_exist(
l2_local_loss_dos.detach(), find_local
)
loss += pref_ados * l2_local_loss_dos
rmse_local_dos = l2_local_loss_dos.sqrt()
more_loss["rmse_local_dos"] = self.display_if_exist(
rmse_local_dos.detach(), find_local
)
if self.has_acdf and "atom_dos" in model_pred and "atom_dos" in label:
find_local = label.get("find_atom_dos", 0.0)
pref_acdf = pref_acdf * find_local
local_tensor_pred_cdf = torch.cusum(
model_pred["atom_dos"].reshape([-1, natoms, self.numb_dos]), dim=-1
)
local_tensor_label_cdf = torch.cusum(
label["atom_dos"].reshape([-1, natoms, self.numb_dos]), dim=-1
)
diff = (local_tensor_pred_cdf - local_tensor_label_cdf).reshape(
[-1, self.numb_dos]
)
if "mask" in model_pred:
diff = diff[model_pred["mask"].reshape([-1]).bool()]
l2_local_loss_cdf = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_local_cdf_loss"] = self.display_if_exist(
l2_local_loss_cdf.detach(), find_local
)
loss += pref_acdf * l2_local_loss_cdf
rmse_local_cdf = l2_local_loss_cdf.sqrt()
more_loss["rmse_local_cdf"] = self.display_if_exist(
rmse_local_cdf.detach(), find_local
)
if self.has_dos and "dos" in model_pred and "dos" in label:
find_global = label.get("find_dos", 0.0)
pref_dos = pref_dos * find_global
global_tensor_pred_dos = model_pred["dos"].reshape([-1, self.numb_dos])
global_tensor_label_dos = label["dos"].reshape([-1, self.numb_dos])
diff = global_tensor_pred_dos - global_tensor_label_dos
if "mask" in model_pred:
atom_num = model_pred["mask"].sum(-1, keepdim=True)
l2_global_loss_dos = torch.mean(
torch.sum(torch.square(diff) * atom_num, dim=0) / atom_num.sum()
)
atom_num = torch.mean(atom_num.float())
else:
atom_num = natoms
l2_global_loss_dos = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_global_dos_loss"] = self.display_if_exist(
l2_global_loss_dos.detach(), find_global
)
loss += pref_dos * l2_global_loss_dos
rmse_global_dos = l2_global_loss_dos.sqrt() / atom_num
more_loss["rmse_global_dos"] = self.display_if_exist(
rmse_global_dos.detach(), find_global
)
if self.has_cdf and "dos" in model_pred and "dos" in label:
find_global = label.get("find_dos", 0.0)
pref_cdf = pref_cdf * find_global
global_tensor_pred_cdf = torch.cusum(
model_pred["dos"].reshape([-1, self.numb_dos]), dim=-1
)
global_tensor_label_cdf = torch.cusum(
label["dos"].reshape([-1, self.numb_dos]), dim=-1
)
diff = global_tensor_pred_cdf - global_tensor_label_cdf
if "mask" in model_pred:
atom_num = model_pred["mask"].sum(-1, keepdim=True)
l2_global_loss_cdf = torch.mean(
torch.sum(torch.square(diff) * atom_num, dim=0) / atom_num.sum()
)
atom_num = torch.mean(atom_num.float())
else:
atom_num = natoms
l2_global_loss_cdf = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_global_cdf_loss"] = self.display_if_exist(
l2_global_loss_cdf.detach(), find_global
)
loss += pref_cdf * l2_global_loss_cdf
rmse_global_dos = l2_global_loss_cdf.sqrt() / atom_num
more_loss["rmse_global_cdf"] = self.display_if_exist(
rmse_global_dos.detach(), find_global
)
return model_pred, loss, more_loss

@property
def label_requirement(self) -> List[DataRequirementItem]:
"""Return data label requirements needed for this loss calculation."""
label_requirement = []
if self.has_ados or self.has_acdf:
label_requirement.append(
DataRequirementItem(
"atom_dos",
ndof=self.numb_dos,
atomic=True,
must=False,
high_prec=False,
)
)
if self.has_dos or self.has_cdf:
label_requirement.append(
DataRequirementItem(
"dos",
ndof=self.numb_dos,
atomic=False,
must=False,
high_prec=False,
)
)
return label_requirement
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/dos_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ def forward(
model_predict["updated_coord"] += coord
return model_predict

@torch.jit.export
def get_numb_dos(self) -> int:
"""Get the number of DOS for DOSFittingNet."""
return self.get_fitting_net().dim_out

@torch.jit.export
def forward_lower(
self,
Expand Down
66 changes: 66 additions & 0 deletions deepmd/pt/model/task/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import copy
import logging
from typing import (
Callable,
List,
Optional,
Union,
)

import numpy as np
import torch

from deepmd.dpmodel import (
Expand All @@ -28,6 +30,13 @@
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.out_stat import (
compute_stats_from_atomic,
compute_stats_from_redu,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -96,6 +105,63 @@ def output_def(self) -> FittingOutputDef:
]
)

def compute_output_stats(
Comment thread
wanghan-iapcm marked this conversation as resolved.
self,
merged: Union[Callable[[], List[dict]], List[dict]],
stat_file_path: Optional[DPPath] = None,
) -> None:
"""
Compute the output statistics (e.g. dos bias) for the fitting net from packed data.

Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.

"""
if stat_file_path is not None:
stat_file_path = stat_file_path / "bias_dos"
if stat_file_path is not None and stat_file_path.is_file():
bias_dos = stat_file_path.load_numpy()
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
else:
if callable(merged):
# only get data for once
sampled = merged()
else:
sampled = merged
for sys in range(len(sampled)):
nframs = sampled[sys]["atype"].shape[0]

if "atom_dos" in sampled[sys]:
bias_dos = compute_stats_from_atomic(
sampled[sys]["atom_dos"].numpy(force=True),
sampled[sys]["atype"].numpy(force=True),
)[0]
else:
sys_type_count = np.zeros(
(nframs, self.ntypes), dtype=env.GLOBAL_NP_FLOAT_PRECISION
)
for itype in range(self.ntypes):
type_mask = sampled[sys]["atype"] == itype
sys_type_count[:, itype] = type_mask.sum(dim=1).numpy(
force=True
)
sys_bias_redu = sampled[sys]["dos"].numpy(force=True)

bias_dos = compute_stats_from_redu(
sys_bias_redu, sys_type_count, rcond=self.rcond
)[0]
if stat_file_path is not None:
stat_file_path.save_numpy(bias_dos)
self.bias_dos = torch.tensor(bias_dos, device=env.DEVICE)

@classmethod
def deserialize(cls, data: dict) -> "DOSFittingNet":
data = copy.deepcopy(data)
Expand Down
Loading