diff --git a/deepmd/pt/model/network/utils.py b/deepmd/pt/model/network/utils.py index f9837952fe..c4baa2cd8d 100644 --- a/deepmd/pt/model/network/utils.py +++ b/deepmd/pt/model/network/utils.py @@ -30,17 +30,22 @@ def aggregate( ------- output: [num_owner, feature_dim] """ - bin_count = torch.bincount(owners) - bin_count = bin_count.where(bin_count != 0, bin_count.new_ones(1)) - - if (num_owner is not None) and (bin_count.shape[0] != num_owner): - difference = num_owner - bin_count.shape[0] - bin_count = torch.cat([bin_count, bin_count.new_ones(difference)]) - - # make sure this operation is done on the same device of data and owners - output = data.new_zeros([bin_count.shape[0], data.shape[1]]) + if num_owner is None or average: + # requires bincount + bin_count = torch.bincount(owners) + bin_count = bin_count.where(bin_count != 0, bin_count.new_ones(1)) + if (num_owner is not None) and (bin_count.shape[0] != num_owner): + difference = num_owner - bin_count.shape[0] + bin_count = torch.cat([bin_count, bin_count.new_ones(difference)]) + else: + num_owner = bin_count.shape[0] + else: + bin_count = None + + output = data.new_zeros([num_owner, data.shape[1]]) output = output.index_add_(0, owners, data) if average: + assert bin_count is not None output = (output.T / bin_count).T return output