diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 70912fb336c..5e941a181e8 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -146,9 +146,7 @@ class PPOTrainer(BaseTrainer): def __init__( self, args: PPOConfig, - processing_class: Optional[ - Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] - ], + processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin], model: nn.Module, ref_model: Optional[nn.Module], reward_model: nn.Module,