@@ -118,15 +118,17 @@ struct random_indices {
118118
119119// Weighted sample without replacement.
120120// Use perturbed Gumbel variates as keys.
121+ template <typename IType>
121122struct generate_keys {
122- MSHADOW_XINLINE static void Map (index_t i, float *uniforms, float *weights) {
123+ MSHADOW_XINLINE static void Map (index_t i, float *uniforms, IType *weights) {
123124 uniforms[i] = -logf (-logf (uniforms[i])) + logf (weights[i]);
124125 }
125126};
126127
127128// Weighted sample with replacement.
129+ template <typename IType>
128130struct categorical_sampling {
129- MSHADOW_XINLINE static void Map (index_t i, float *weights, size_t length,
131+ MSHADOW_XINLINE static void Map (index_t i, IType *weights, size_t length,
130132 float *uniforms, int64_t *outs) {
131133 outs[i] = 0 ;
132134 float acc = 0.0 ;
@@ -179,15 +181,19 @@ void NumpyChoiceForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
179181 prnd->SampleUniform (&random_numbers, 0 , 1 );
180182 workspace_ptr += ((random_tensor_size * sizeof (float ) / 7 + 1 ) * 8 );
181183 if (replace) {
182- Kernel<categorical_sampling, xpu>::Launch (
183- s, output_size, inputs[weight_index].dptr <float >(), input_size,
184- random_numbers.dptr_ , outputs[0 ].dptr <int64_t >());
184+ MSHADOW_REAL_TYPE_SWITCH (inputs[weight_index].type_flag_ , IType, {
185+ Kernel<categorical_sampling<IType>, xpu>::Launch (
186+ s, output_size, inputs[weight_index].dptr <IType>(), input_size,
187+ random_numbers.dptr_ , outputs[0 ].dptr <int64_t >());
188+ });
185189 } else {
186190 Tensor<xpu, 1 , int64_t > indices = Tensor<xpu, 1 , int64_t >(
187191 reinterpret_cast <int64_t *>(workspace_ptr), Shape1 (indices_size), s);
188192 indices = expr::range ((int64_t )0 , input_size);
189- Kernel<generate_keys, xpu>::Launch (s, input_size, random_numbers.dptr_ ,
190- inputs[weight_index].dptr <float >());
193+ MSHADOW_REAL_TYPE_SWITCH (inputs[weight_index].type_flag_ , IType, {
194+ Kernel<generate_keys<IType>, xpu>::Launch (s, input_size, random_numbers.dptr_ ,
195+ inputs[weight_index].dptr <IType>());
196+ });
191197 _sort<xpu>(random_numbers.dptr_ , indices.dptr_ , input_size);
192198 Copy (outputs[0 ].FlatTo1D <xpu, int64_t >(s), indices.Slice (0 , output_size), s);
193199 }
0 commit comments