
What should be the convention for variable dtypes in deepmd-kit modules? #4234

@wanghan-iapcm

Description


Summary

How variable dtypes should be defined needs to be made clearer for developers. See details.

DeePMD-kit Version

devel-c2944eb

Backend and its version

PyTorch, Python

Python Version, CUDA Version, GCC Version, LAMMPS Version, etc

No response

Details

There are at least two cases:

  1. Model attributes: in this case it seems that the dtype of a variable is defined by the precision parameter of the constructor, but I find some exceptions, like

        self.scale = torch.tensor(
            self.scale, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
        ).view(ntypes, 1)
        self.shift_diag = shift_diag
        self.constant_matrix = torch.zeros(
            ntypes, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
        )

     where the dtype is defined globally.
  2. Model output: in this case it seems that the dtype should be either
     a. consistent with the attributes' dtype, or
     b. consistent with the input dtype,
     but I find exceptions like

        outs = torch.zeros(
            (nf, nloc, net_dim_out),
            dtype=env.GLOBAL_PT_FLOAT_PRECISION,
            device=descriptor.device,
        )  # jit assertion
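To make the two cases concrete, here is a minimal sketch of one possible convention, not deepmd-kit's actual code. It assumes that attribute tensors take their dtype from the constructor's `precision` argument (case 1) and that outputs follow the module precision, option (a), unless a flag requests option (b). `PRECISION_MAP`, `ToyFitting`, and `keep_input_dtype` are hypothetical names used only for illustration, and the sum over the descriptor stands in for a real fitting network.

```python
import torch

# Hypothetical precision-string -> dtype mapping (illustration only, not deepmd-kit's table).
PRECISION_MAP = {"float32": torch.float32, "float64": torch.float64}


class ToyFitting(torch.nn.Module):
    """Hypothetical module: attributes follow `precision`; output follows option (a) or (b)."""

    def __init__(
        self,
        ntypes: int,
        net_dim_out: int,
        precision: str = "float64",
        keep_input_dtype: bool = False,
    ):
        super().__init__()
        self.prec = PRECISION_MAP[precision]
        self.net_dim_out = net_dim_out
        self.keep_input_dtype = keep_input_dtype
        # Case 1: attribute tensors use the constructor precision, not a global constant.
        self.register_buffer("scale", torch.ones(ntypes, 1, dtype=self.prec))
        self.register_buffer("constant_matrix", torch.zeros(ntypes, dtype=self.prec))

    def forward(self, descriptor: torch.Tensor, atype: torch.Tensor) -> torch.Tensor:
        in_dtype = descriptor.dtype
        nf, nloc, _ = descriptor.shape
        # Cast once on entry so all arithmetic runs in the module precision.
        descriptor = descriptor.to(self.prec)
        # Case 2: the output buffer takes the module dtype, matching the attributes.
        outs = torch.zeros(
            (nf, nloc, self.net_dim_out), dtype=self.prec, device=descriptor.device
        )
        # Placeholder for a real network: broadcast the summed descriptor over the outputs.
        outs = outs + descriptor.sum(dim=-1, keepdim=True)
        # Per-type scale and shift; both already share the module dtype.
        outs = outs * self.scale[atype] + self.constant_matrix[atype].unsqueeze(-1)
        # Option (b): cast back to the caller's dtype; option (a): return as is.
        return outs.to(in_dtype) if self.keep_input_dtype else outs
```

With `keep_input_dtype=False`, a `float32` descriptor fed to a `float64` module yields a `float64` output (option a); with `keep_input_dtype=True`, the same call returns `float32` (option b). Either way, no tensor in the module depends on `env.GLOBAL_PT_FLOAT_PRECISION`.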

I suggest we make a decision for both cases, and that all developers follow the same convention.
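Whatever convention is agreed on, it could be checked automatically in unit tests. The following is a minimal sketch under that assumption: the hypothetical helper `find_dtype_violations` lists floating-point parameters and buffers whose dtype differs from the precision a module is expected to use. `Mixed` is a hypothetical toy module that reproduces the kind of exception described in case 1, where one buffer uses a different, globally defined precision.

```python
from typing import List

import torch


def find_dtype_violations(module: torch.nn.Module, expected: torch.dtype) -> List[str]:
    """Hypothetical helper: names of floating-point tensors whose dtype != `expected`."""
    bad = []
    for name, tensor in list(module.named_parameters()) + list(module.named_buffers()):
        # Integer tensors (e.g. type indices) are outside the float-precision convention.
        if tensor.is_floating_point() and tensor.dtype != expected:
            bad.append(name)
    return bad


class Mixed(torch.nn.Module):
    """Hypothetical toy module mixing constructor precision with a 'global' one."""

    def __init__(self):
        super().__init__()
        # Built with the intended module precision.
        self.linear = torch.nn.Linear(4, 2, dtype=torch.float32)
        # Built with a different precision, like the env.GLOBAL_PT_FLOAT_PRECISION cases.
        self.register_buffer("constant_matrix", torch.zeros(3, dtype=torch.float64))
        # Integer buffer: ignored by the check.
        self.register_buffer("type_map", torch.arange(3))


print(find_dtype_violations(Mixed(), torch.float32))  # ['constant_matrix']
```

A test of this shape would flag the mixed-precision attributes quoted in case 1 without needing to know how each module builds its tensors.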
