We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent af9b980 commit aefd3b5Copy full SHA for aefd3b5
1 file changed
deepmd/pt/model/network/utils.py
@@ -39,10 +39,13 @@ def aggregate(
39
bin_count = torch.cat([bin_count, bin_count.new_ones(difference)])
40
else:
41
num_owner = bin_count.shape[0]
42
+ else:
43
+ bin_count = None
44
45
output = data.new_zeros([num_owner, data.shape[1]])
46
output = output.index_add_(0, owners, data)
47
if average:
48
+ assert bin_count is not None
49
output = (output.T / bin_count).T
50
return output
51
0 commit comments