We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c330f2d commit 596d3ebCopy full SHA for 596d3eb
1 file changed
nemo_rl/algorithms/distillation.py
@@ -153,13 +153,7 @@ def setup(
153
assert generation_config is not None, (
154
"A generation config in the PolicyConfig is required for distillation"
155
)
156
- assert ( # [TODO] we may support this for tp in the future
157
- not loss_config.get("zero_outside_topk", False)
158
- or (policy_config["dtensor_cfg"]["tensor_parallel_size"] == 1)
159
- ), (
160
- f"zero_outside_topk=True requires tensor_parallel_size=1, "
161
- f"but got tensor_parallel_size={policy_config['dtensor_cfg']['tensor_parallel_size']}. "
162
- )
+
163
# Set random seed
164
set_seed(distillation_config["seed"])
165
0 commit comments