Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 22e5ae3

Browse files
xidulumarcoabreu
authored andcommitted
add type switch to weight tensor (#16543)
1 parent d12e674 commit 22e5ae3

File tree

2 files changed

+24
-17
lines changed

2 files changed

+24
-17
lines changed

src/operator/numpy/random/np_choice_op.h

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,17 @@ struct random_indices {
118118

119119
// Weighted sample without replacement.
120120
// Use perturbed Gumbel variates as keys.
121+
template <typename IType>
121122
struct generate_keys {
122-
MSHADOW_XINLINE static void Map(index_t i, float *uniforms, float *weights) {
123+
MSHADOW_XINLINE static void Map(index_t i, float *uniforms, IType *weights) {
123124
uniforms[i] = -logf(-logf(uniforms[i])) + logf(weights[i]);
124125
}
125126
};
126127

127128
// Weighted sample with replacement.
129+
template <typename IType>
128130
struct categorical_sampling {
129-
MSHADOW_XINLINE static void Map(index_t i, float *weights, size_t length,
131+
MSHADOW_XINLINE static void Map(index_t i, IType *weights, size_t length,
130132
float *uniforms, int64_t *outs) {
131133
outs[i] = 0;
132134
float acc = 0.0;
@@ -179,15 +181,19 @@ void NumpyChoiceForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
179181
prnd->SampleUniform(&random_numbers, 0, 1);
180182
workspace_ptr += ((random_tensor_size * sizeof(float) / 7 + 1) * 8);
181183
if (replace) {
182-
Kernel<categorical_sampling, xpu>::Launch(
183-
s, output_size, inputs[weight_index].dptr<float>(), input_size,
184-
random_numbers.dptr_, outputs[0].dptr<int64_t>());
184+
MSHADOW_REAL_TYPE_SWITCH(inputs[weight_index].type_flag_, IType, {
185+
Kernel<categorical_sampling<IType>, xpu>::Launch(
186+
s, output_size, inputs[weight_index].dptr<IType>(), input_size,
187+
random_numbers.dptr_, outputs[0].dptr<int64_t>());
188+
});
185189
} else {
186190
Tensor<xpu, 1, int64_t> indices = Tensor<xpu, 1, int64_t>(
187191
reinterpret_cast<int64_t *>(workspace_ptr), Shape1(indices_size), s);
188192
indices = expr::range((int64_t)0, input_size);
189-
Kernel<generate_keys, xpu>::Launch(s, input_size, random_numbers.dptr_,
190-
inputs[weight_index].dptr<float>());
193+
MSHADOW_REAL_TYPE_SWITCH(inputs[weight_index].type_flag_, IType, {
194+
Kernel<generate_keys<IType>, xpu>::Launch(s, input_size, random_numbers.dptr_,
195+
inputs[weight_index].dptr<IType>());
196+
});
191197
_sort<xpu>(random_numbers.dptr_, indices.dptr_, input_size);
192198
Copy(outputs[0].FlatTo1D<xpu, int64_t>(s), indices.Slice(0, output_size), s);
193199
}

tests/python/unittest/test_numpy_op.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2490,16 +2490,17 @@ def test_indexing_mode(sampler, set_size, samples_size, replace, weight=None):
24902490
# test_sample_without_replacement(np.random.choice, num_classes, shape, 10 ** 5, weight)
24912491

24922492
# Test hypridize mode:
2493-
for hybridize in [True, False]:
2494-
for replace in [True, False]:
2495-
test_choice = TestUniformChoice(num_classes // 2, replace)
2496-
test_choice_weighted = TestWeightedChoice(num_classes // 2, replace)
2497-
if hybridize:
2498-
test_choice.hybridize()
2499-
test_choice_weighted.hybridize()
2500-
weight = np.array(_np.random.dirichlet([1.0] * num_classes))
2501-
test_indexing_mode(test_choice, num_classes, num_classes // 2, replace, None)
2502-
test_indexing_mode(test_choice_weighted, num_classes, num_classes // 2, replace, weight)
2493+
for wtype in ['float16', 'float32', 'float64']:
2494+
for hybridize in [True, False]:
2495+
for replace in [True, False]:
2496+
test_choice = TestUniformChoice(num_classes // 2, replace)
2497+
test_choice_weighted = TestWeightedChoice(num_classes // 2, replace)
2498+
if hybridize:
2499+
test_choice.hybridize()
2500+
test_choice_weighted.hybridize()
2501+
weight = np.array(_np.random.dirichlet([1.0] * num_classes)).astype(wtype)
2502+
test_indexing_mode(test_choice, num_classes, num_classes // 2, replace, None)
2503+
test_indexing_mode(test_choice_weighted, num_classes, num_classes // 2, replace, weight)
25032504

25042505

25052506
@with_seed()

0 commit comments

Comments
 (0)