From af9b98028a5a881578156d85e3eff256971838ca Mon Sep 17 00:00:00 2001 From: caic99 Date: Wed, 28 May 2025 07:51:38 +0000 Subject: [PATCH 1/2] perf: skip bincount if unnecessary --- deepmd/pt/model/network/utils.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/deepmd/pt/model/network/utils.py b/deepmd/pt/model/network/utils.py index f9837952fe..9cbfcd500f 100644 --- a/deepmd/pt/model/network/utils.py +++ b/deepmd/pt/model/network/utils.py @@ -30,15 +30,17 @@ def aggregate( ------- output: [num_owner, feature_dim] """ - bin_count = torch.bincount(owners) - bin_count = bin_count.where(bin_count != 0, bin_count.new_ones(1)) - - if (num_owner is not None) and (bin_count.shape[0] != num_owner): - difference = num_owner - bin_count.shape[0] - bin_count = torch.cat([bin_count, bin_count.new_ones(difference)]) - - # make sure this operation is done on the same device of data and owners - output = data.new_zeros([bin_count.shape[0], data.shape[1]]) + if num_owner is None or average: + # requires bincount + bin_count = torch.bincount(owners) + bin_count = bin_count.where(bin_count != 0, bin_count.new_ones(1)) + if (num_owner is not None) and (bin_count.shape[0] != num_owner): + difference = num_owner - bin_count.shape[0] + bin_count = torch.cat([bin_count, bin_count.new_ones(difference)]) + else: + num_owner = bin_count.shape[0] + + output = data.new_zeros([num_owner, data.shape[1]]) output = output.index_add_(0, owners, data) if average: output = (output.T / bin_count).T From aefd3b560df26189f230f561ce08105a1dbc11fd Mon Sep 17 00:00:00 2001 From: caic99 Date: Wed, 28 May 2025 08:02:41 +0000 Subject: [PATCH 2/2] fix ut --- deepmd/pt/model/network/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deepmd/pt/model/network/utils.py b/deepmd/pt/model/network/utils.py index 9cbfcd500f..c4baa2cd8d 100644 --- a/deepmd/pt/model/network/utils.py +++ b/deepmd/pt/model/network/utils.py @@ -39,10 +39,13 @@ def aggregate( bin_count = torch.cat([bin_count, bin_count.new_ones(difference)]) else: num_owner = bin_count.shape[0] + else: + bin_count = None output = data.new_zeros([num_owner, data.shape[1]]) output = output.index_add_(0, owners, data) if average: + assert bin_count is not None output = (output.T / bin_count).T return output