diff --git a/examples_sensecore/grpo_scripts_verl_0626/qwen-dapo-req-sched-dapo-trick-test.sh b/examples_sensecore/grpo_scripts_verl_0626/qwen-dapo-req-sched-dapo-trick-test.sh index c33fd38d135..9422c508872 100755 --- a/examples_sensecore/grpo_scripts_verl_0626/qwen-dapo-req-sched-dapo-trick-test.sh +++ b/examples_sensecore/grpo_scripts_verl_0626/qwen-dapo-req-sched-dapo-trick-test.sh @@ -24,21 +24,21 @@ test_files="['$aime2024_test_path', '$aime2025_test_path']" # resume config export resume_mode=${resume_mode:-auto} export resume_from_path=${resume_from_path:-null} -export model_path=${model_path:-/afs/chatrl/public/models/Qwen2.5-32B} +export model_path=/afs/chatrl/public/models/DeepSeek-R1-Distill-Qwen-7B export model_name=$(basename "$model_path") # project config -export project_name=${project_name:-verl_dapo_req_sched_v0626} +export project_name=${project_name:-verl_qwen_7b_dapo_req_sched_v0626} # train params export total_epochs=${total_epochs:-50} -export vllm_tp=${vllm_tp:-4} +export vllm_tp=${vllm_tp:-2} -export train_prompt_batch_size=${train_prompt_batch_size:-512} +export train_prompt_batch_size=${train_prompt_batch_size:-32} export grpo_rollout_n=${grpo_rollout_n:-16} # model params export max_response_length=${max_response_length:-20000} -export prompt_key=${prompt_key:-prompt} +export prompt_key=${prompt_key:-messages} export resume_type=${resume_type:-no_resume} # env config export nnode=${WORLD_SIZE:-1} @@ -69,7 +69,7 @@ infer_micro_batch_size=null max_prompt_length=$((1024 * 2)) -enable_overlong_buffer=True +enable_overlong_buffer=False overlong_buffer_len=$((1024 * 4)) overlong_penalty_factor=1.0 @@ -117,7 +117,7 @@ echo "real_train_batch_size = $real_train_batch_size, train_prompt_batch_size = sleep 1 export base_model_suffix=${base_model_suffix:-Base} -export experiment_name=${model_name}-${base_model_suffix}_dapo-${req_algo}-${agg}_${nnode}node_rollout${grpo_rollout_n}_bs${train_prompt_batch_size}_minibatch${ppo_mini_batch_size}_lr${lr}_sp${ulysses_sequence_parallel_size}_tp${vllm_tp}_maxlen${max_response_length}_all_dapo_trick_${resume_type}_${TIMESTAMP} +export experiment_name=${model_name}-${base_model_suffix}_dapo-${req_algo}-${agg}_${nnode}node_rollout${grpo_rollout_n}_bs${train_prompt_batch_size}_minibatch${ppo_mini_batch_size}_lr${lr}_sp${ulysses_sequence_parallel_size}_tp${vllm_tp}_maxlen${max_response_length}_overlong_punish_${enable_overlong_buffer}_all_dapo_trick_${resume_type} rm -rf /workspace/tmp_tensorboard/* export TENSORBOARD_DIR=/afs/chatrl/users/hxh/models/verl_rl_models/${project_name}/${experiment_name} diff --git a/examples_sensecore/grpo_scripts_verl_0626/qwen-dapo-req-sched-dapo-trick-test0806.sh b/examples_sensecore/grpo_scripts_verl_0626/qwen-dapo-req-sched-dapo-trick-test0806.sh new file mode 100755 index 00000000000..7f36b9706b2 --- /dev/null +++ b/examples_sensecore/grpo_scripts_verl_0626/qwen-dapo-req-sched-dapo-trick-test0806.sh @@ -0,0 +1,234 @@ +pip install langdetect +pip install math-verify sympy +set -x + +# export dapo_train_path=${dapo_train_path:-/afs/chatrl/users/kzl/data/rule_based_rl/DAPO-Math-17k/data/dapo-math-17k_dedup.parquet} +# export aime2024_test_path=${aime2024_test_path:-/afs/chatrl/users/kzl/data/rule_based_rl/AIME-2024/dapo_aime2024_sample8.parquet} +# export dapo_train_path=${dapo_train_path:-/afs/chatrl/users/kzl/data/rule_based_rl/filter_by_32b_cold_start_20250614/filtered_dapo-math-17k_by_acc_0.2_0.7.parquet} +# export 
deepmath_train_path=${deepmath_train_path:-/afs/chatrl/users/kzl/data/rule_based_rl/filter_by_32b_cold_start_20250614/filtered_deepmath_by_acc_0.2_0.7.parquet} +# export math7d5k_train_path=${math7d5k_train_path:-/afs/chatrl/users/kzl/data/rule_based_rl/filter_by_32b_cold_start_20250614/filtered_math_train_by_acc_0_0.7.parquet} + +export aime2024_test_path=${aime2024_test_path:-/afs/chatrl/users/hxh/data/rule_based_rl/AIME-2024/dapo_aime2024_sample8_no_prompt.parquet} +export aime2025_test_path=${aime2025_test_path:-/afs/chatrl/users/hxh/data/rule_based_rl/AIME-2025/dapo_aime2025_sample8_no_prompt.parquet} + +# export aime2024_test_path_from_lyy=${aime2024_test_path_from_lyy:-/afs/chatrl/users/kzl/data/eval/aime2024_dapo_sample32_new.parquet} +# export aime2025_test_path_from_lyy=${aime2025_test_path_from_lyy:-/afs/chatrl/users/lyy/data/eval/aime2025_dapo_sample32.parquet} + +export dapo_train_path=/afs/chatrl/users/hxh/data/math_data/dapo-math/rule_based_rl/dapo-math-17k_dedup_no_prompt_sft_0614_acc_0d1-0d7.parquet +export math7d5k_train_path=/afs/chatrl/users/hxh/data/math_data/MATH_train/rule_based_rl/train_7d5k_math_verify_sft_0614_acc_0-0d7.parquet + +export math_zh=/afs/chatrl/users/kzl/code/awesome_scripts/math_zh_test/processed_data/0730-250627-0710_science_label_question_combined-12121_single_no_repeat_30.parquet + +export math_dapo_zh=/afs/chatrl/users/kzl/data/math_data/dapo-math/prompts/dapo-math-17k_zh.parquet +export math_dapo_en=/afs/chatrl/users/kzl/data/math_data/dapo-math/prompts/dapo-math-17k_en.parquet +# export train_files="['$math7d5k_train_path', '$dapo_train_path']" + +# export train_files="['$math7d5k_train_path', '$dapo_train_path']" +export train_files="['$math_dapo_zh','$math_dapo_en']" + +# train_files="['$math7d5k_train_path', '$dapo_train_path', '$deepmath_train_path']" +# export train_files=${train_files:-"['$math7d5k_train_path', '$dapo_train_path', '$deepmath_train_path']"} + +# test_files="['$aime2024_test_path_from_lyy', '$aime2025_test_path_from_lyy']" +test_files="['$aime2024_test_path', '$aime2025_test_path']" + +# resume config +export resume_mode=${resume_mode:-auto} +export resume_from_path=${resume_from_path:-null} +export model_path=/afs/chatrl/public/models/DeepSeek-R1-Distill-Qwen-7B +export model_name=$(basename "$model_path") + + +# project config +export project_name=${project_name:-verl_qwen_7b_dapo_req_sched_v0626} +# train params +export total_epochs=${total_epochs:-50} +export vllm_tp=${vllm_tp:-2} + +export train_prompt_batch_size=${train_prompt_batch_size:-32} +export grpo_rollout_n=${grpo_rollout_n:-16} +# model params +export max_response_length=${max_response_length:-20000} +export prompt_key=${prompt_key:-messages} +export resume_type=${resume_type:-no_resume} +# env config +export nnode=${WORLD_SIZE:-1} + +export ulysses_sequence_parallel_size=${ulysses_sequence_parallel_size:-1} + +export filter_score_high=${filter_score_high:-1.1} +export filter_score_low=${filter_score_low:--1} +# export filter_score_high=${filter_score_high:-0.7} +# export filter_score_low=${filter_score_low:-0.2} + +export save_freq=${save_freq:-20} +export test_freq=${test_freq:-20} + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 + + +use_dynamic_bsz=True +infer_micro_batch_size=null + +max_prompt_length=$((1024 * 2)) + +export val_before_train=${val_before_train:-True} +export 
trust_remote_code=${trust_remote_code:-False} + +export enable_overlong_buffer=${enable_overlong_buffer:-False} +export overlong_buffer_len=${overlong_buffer_len:-$((1024 * 4))} +overlong_penalty_factor=1.0 + +export gen_prompt_bsz=${gen_prompt_bsz:-$((train_prompt_batch_size * 1))} + + +real_train_batch_size=$((train_prompt_batch_size * grpo_rollout_n)) +ppo_mini_batch_size=32 + + +lr=1e-6 + +# Algorithm +export temperature=${temperature:-1.0} +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +shuffle=False + +offload=False +max_tokens=$((max_prompt_length + max_response_length)) +gen_max_tokens=$((max_tokens * 2)) +log_prob_max_tokens=$((max_tokens * 2)) + + +export seq_dir=${seq_dir:-/afs/chatrl/users/kzl/data/req_sched_seq_dir/filter_by_32b_cold_start_20250614/init} +export log_dir=${log_dir:-/afs/chatrl/users/kzl/data/req_sched_seq_dir/filter_by_32b_cold_start_20250614/log} + +cap_dataset_size=$((1024 * 80000)) +filter_overlong_prompts=False + +#req_algo="long_short" +# req_algo="even_prompt" +# req_algo="even_token" +# agg="max" # sum / max + +export req_algo=${req_algo:-even_token} +export agg=${agg:-max} + +export entropy_coeff=${entropy_coeff:-0} +export entropy_max=${entropy_max:-null} + +percentile=90 +export TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") + + +echo "real_train_batch_size = $real_train_batch_size, train_prompt_batch_size = $train_prompt_batch_size, nnode = $nnode, offload = $offload, max_tokens = $max_tokens, model = $model_name, vllm_tp = $vllm_tp, vllm_mem = $vllm_mem, seq_dir = $seq_dir, log_dir = $log_dir, cap_dataset_size = $cap_dataset_size, filter_overlong_prompts = $filter_overlong_prompts, max_prompt_length = $max_prompt_length, max_response_length = $max_response_length, req_algo = $req_algo, percentile = $percentile, agg = $agg" + +sleep 1 +export base_model_suffix=${base_model_suffix:-Base} +export experiment_name=${model_name}-${base_model_suffix}_dapo-${req_algo}-${agg}_${nnode}node_rollout${grpo_rollout_n}_bs${train_prompt_batch_size}_minibatch${ppo_mini_batch_size}_lr${lr}_sp${ulysses_sequence_parallel_size}_tp${vllm_tp}_maxlen${max_response_length}_overlong_punish_${enable_overlong_buffer}_all_dapo_trick_${resume_type}_dapo_mix_zh_en + +rm -rf /workspace/tmp_tensorboard/* +export TENSORBOARD_DIR=/afs/chatrl/users/kzl/models/verl_rl_models/${project_name}/${experiment_name} + +#data.max_batch_size=${train_prompt_batch_size} \ +#python3 -u -m verl.trainer.main_ppo \ +# python3 -u -m verl.trainer.main_ppo_with_time \ +python3 -u -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name='dapo_trainer.yaml' \ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=${prompt_key} \ + data.train_batch_size=${train_prompt_batch_size} \ + actor_rollout_ref.rollout.n=${grpo_rollout_n} \ + data.shuffle=True \ + data.filter_overlong_prompts=${filter_overlong_prompts} \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + req_scheduler.seq_dir="$seq_dir" \ + req_scheduler.log_dir="$log_dir" \ + req_scheduler.agg="$agg" \ + req_scheduler.algo="$req_algo" \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.truncation='left' \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.profiler.all_ranks=True \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ 
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + data.trust_remote_code=${trust_remote_code} \ + actor_rollout_ref.model.trust_remote_code=${trust_remote_code} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.filter_score_low=${filter_score_low} \ + algorithm.filter_groups.filter_score_high=${filter_score_high} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_tokens} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${log_prob_max_tokens} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${log_prob_max_tokens} \ + actor_rollout_ref.model.path=${model_path} \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=${lr} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${ulysses_sequence_parallel_size} \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=${entropy_coeff} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${vllm_tp} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${gen_max_tokens} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.resume_mode=${resume_mode} \ + trainer.resume_from_path=${resume_from_path} \ + trainer.logger=['tensorboard'] \ + trainer.default_local_dir=/afs/chatrl/users/kzl/models/verl_rl_models/${project_name}/${experiment_name} \ + trainer.project_name=${project_name} \ + trainer.experiment_name=${experiment_name} \ + trainer.n_gpus_per_node=8 \ + trainer.val_before_train=${val_before_train} \ + trainer.nnodes=${nnode} \ + trainer.save_freq=${save_freq} \ + trainer.test_freq=${test_freq} \ + trainer.total_epochs=${total_epochs} 2>&1 | tee 
/afs/chatrl/users/kzl/code/verlmix/logs_sensecore/$experiment_name.log \ No newline at end of file diff --git a/examples_sensecore/grpo_scripts_verl_0626/qwen-dapo-req-sched-dapo-trick-test0911.sh b/examples_sensecore/grpo_scripts_verl_0626/qwen-dapo-req-sched-dapo-trick-test0911.sh new file mode 100755 index 00000000000..e1dfb9f0814 --- /dev/null +++ b/examples_sensecore/grpo_scripts_verl_0626/qwen-dapo-req-sched-dapo-trick-test0911.sh @@ -0,0 +1,234 @@ +pip install langdetect +pip install math-verify sympy +set -x + +# export dapo_train_path=${dapo_train_path:-/afs/chatrl/users/kzl/data/rule_based_rl/DAPO-Math-17k/data/dapo-math-17k_dedup.parquet} +# export aime2024_test_path=${aime2024_test_path:-/afs/chatrl/users/kzl/data/rule_based_rl/AIME-2024/dapo_aime2024_sample8.parquet} +# export dapo_train_path=${dapo_train_path:-/afs/chatrl/users/kzl/data/rule_based_rl/filter_by_32b_cold_start_20250614/filtered_dapo-math-17k_by_acc_0.2_0.7.parquet} +# export deepmath_train_path=${deepmath_train_path:-/afs/chatrl/users/kzl/data/rule_based_rl/filter_by_32b_cold_start_20250614/filtered_deepmath_by_acc_0.2_0.7.parquet} +# export math7d5k_train_path=${math7d5k_train_path:-/afs/chatrl/users/kzl/data/rule_based_rl/filter_by_32b_cold_start_20250614/filtered_math_train_by_acc_0_0.7.parquet} + +export aime2024_test_path=${aime2024_test_path:-/afs/chatrl/users/hxh/data/rule_based_rl/AIME-2024/dapo_aime2024_sample8_no_prompt.parquet} +export aime2025_test_path=${aime2025_test_path:-/afs/chatrl/users/hxh/data/rule_based_rl/AIME-2025/dapo_aime2025_sample8_no_prompt.parquet} + +# export aime2024_test_path_from_lyy=${aime2024_test_path_from_lyy:-/afs/chatrl/users/kzl/data/eval/aime2024_dapo_sample32_new.parquet} +# export aime2025_test_path_from_lyy=${aime2025_test_path_from_lyy:-/afs/chatrl/users/lyy/data/eval/aime2025_dapo_sample32.parquet} + +export dapo_train_path=/afs/chatrl/users/hxh/data/math_data/dapo-math/rule_based_rl/dapo-math-17k_dedup_no_prompt_sft_0614_acc_0d1-0d7.parquet +export math7d5k_train_path=/afs/chatrl/users/hxh/data/math_data/MATH_train/rule_based_rl/train_7d5k_math_verify_sft_0614_acc_0-0d7.parquet + +export math_zh=/afs/chatrl/users/kzl/code/awesome_scripts/math_zh_test/processed_data/0730-250627-0710_science_label_question_combined-12121_single_no_repeat_30.parquet + +export math_dapo_zh=/afs/chatrl/users/kzl/data/math_data/dapo-math/prompts/dapo-math-17k_zh.parquet +export math_dapo_en=/afs/chatrl/users/kzl/data/math_data/dapo-math/prompts/dapo-math-17k_en.parquet +# export train_files="['$math7d5k_train_path', '$dapo_train_path']" + +# export train_files="['$math7d5k_train_path', '$dapo_train_path']" +export train_files="['$math_dapo_zh','$math_dapo_en']" + +# train_files="['$math7d5k_train_path', '$dapo_train_path', '$deepmath_train_path']" +# export train_files=${train_files:-"['$math7d5k_train_path', '$dapo_train_path', '$deepmath_train_path']"} + +# test_files="['$aime2024_test_path_from_lyy', '$aime2025_test_path_from_lyy']" +test_files="['$aime2024_test_path', '$aime2025_test_path']" + +# resume config +export resume_mode=${resume_mode:-auto} +export resume_from_path=${resume_from_path:-null} +export model_path=/afs/chatrl/public/models/DeepSeek-R1-Distill-Qwen-7B +export model_name=$(basename "$model_path") + + +# project config +export project_name=${project_name:-verl_qwen_7b_dapo_req_sched_v0626} +# train params +export total_epochs=${total_epochs:-50} +export vllm_tp=${vllm_tp:-2} + +export train_prompt_batch_size=${train_prompt_batch_size:-32} +export 
grpo_rollout_n=${grpo_rollout_n:-16} +# model params +export max_response_length=${max_response_length:-20000} +export prompt_key=${prompt_key:-messages} +export resume_type=${resume_type:-no_resume} +# env config +export nnode=${WORLD_SIZE:-1} + +export ulysses_sequence_parallel_size=${ulysses_sequence_parallel_size:-1} + +export filter_score_high=${filter_score_high:-1.1} +export filter_score_low=${filter_score_low:--1} +# export filter_score_high=${filter_score_high:-0.7} +# export filter_score_low=${filter_score_low:-0.2} + +export save_freq=${save_freq:-20} +export test_freq=${test_freq:-20} + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 + + +use_dynamic_bsz=True +infer_micro_batch_size=null + +max_prompt_length=$((1024 * 2)) + +export val_before_train=${val_before_train:-True} +export trust_remote_code=${trust_remote_code:-False} + +export enable_overlong_buffer=${enable_overlong_buffer:-False} +export overlong_buffer_len=${overlong_buffer_len:-$((1024 * 4))} +overlong_penalty_factor=1.0 + +export gen_prompt_bsz=${gen_prompt_bsz:-$((train_prompt_batch_size * 1))} + + +real_train_batch_size=$((train_prompt_batch_size * grpo_rollout_n)) +ppo_mini_batch_size=32 + + +lr=1e-6 + +# Algorithm +export temperature=${temperature:-1.0} +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +shuffle=False + +offload=False +max_tokens=$((max_prompt_length + max_response_length)) +gen_max_tokens=$((max_tokens * 2)) +log_prob_max_tokens=$((max_tokens * 2)) + + +export seq_dir=${seq_dir:-/afs/chatrl/users/kzl/data/req_sched_seq_dir/filter_by_32b_cold_start_20250614/init} +export log_dir=${log_dir:-/afs/chatrl/users/kzl/data/req_sched_seq_dir/filter_by_32b_cold_start_20250614/log} + +cap_dataset_size=$((1024 * 80000)) +filter_overlong_prompts=False + +#req_algo="long_short" +# req_algo="even_prompt" +# req_algo="even_token" +# agg="max" # sum / max + +export req_algo=${req_algo:-even_token} +export agg=${agg:-max} + +export entropy_coeff=${entropy_coeff:-0} +export entropy_max=${entropy_max:-null} + +percentile=90 +export TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") + + +echo "real_train_batch_size = $real_train_batch_size, train_prompt_batch_size = $train_prompt_batch_size, nnode = $nnode, offload = $offload, max_tokens = $max_tokens, model = $model_name, vllm_tp = $vllm_tp, vllm_mem = $vllm_mem, seq_dir = $seq_dir, log_dir = $log_dir, cap_dataset_size = $cap_dataset_size, filter_overlong_prompts = $filter_overlong_prompts, max_prompt_length = $max_prompt_length, max_response_length = $max_response_length, req_algo = $req_algo, percentile = $percentile, agg = $agg" + +sleep 1 +export base_model_suffix=${base_model_suffix:-Base} +export experiment_name=${model_name}-${base_model_suffix}_dapo-${req_algo}-${agg}_${nnode}node_rollout${grpo_rollout_n}_bs${train_prompt_batch_size}_minibatch${ppo_mini_batch_size}_lr${lr}_sp${ulysses_sequence_parallel_size}_tp${vllm_tp}_maxlen${max_response_length}_overlong_punish_${enable_overlong_buffer}_all_dapo_trick_${resume_type}_dapo_zh_en_0929 + +rm -rf /workspace/tmp_tensorboard/* +export TENSORBOARD_DIR=/afs/chatrl/users/kzl/models/verl_rl_models/${project_name}/${experiment_name} + +#data.max_batch_size=${train_prompt_batch_size} \ +#python3 -u -m verl.trainer.main_ppo \ +# python3 -u -m verl.trainer.main_ppo_with_time \ +python3 -u -m recipe.dapo.main_dapo \ + --config-path=config 
\ + --config-name='dapo_trainer.yaml' \ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=${prompt_key} \ + data.train_batch_size=${train_prompt_batch_size} \ + actor_rollout_ref.rollout.n=${grpo_rollout_n} \ + data.shuffle=True \ + data.filter_overlong_prompts=${filter_overlong_prompts} \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + req_scheduler.seq_dir="$seq_dir" \ + req_scheduler.log_dir="$log_dir" \ + req_scheduler.agg="$agg" \ + req_scheduler.algo="$req_algo" \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.truncation='left' \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.profiler.all_ranks=True \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + data.trust_remote_code=${trust_remote_code} \ + actor_rollout_ref.model.trust_remote_code=${trust_remote_code} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.filter_score_low=${filter_score_low} \ + algorithm.filter_groups.filter_score_high=${filter_score_high} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_tokens} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${log_prob_max_tokens} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${log_prob_max_tokens} \ + actor_rollout_ref.model.path=${model_path} \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=${lr} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${ulysses_sequence_parallel_size} \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=${entropy_coeff} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${vllm_tp} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${gen_max_tokens} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + 
actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.resume_mode=${resume_mode} \ + trainer.resume_from_path=${resume_from_path} \ + trainer.logger=['tensorboard'] \ + trainer.default_local_dir=/afs/chatrl/users/kzl/models/verl_rl_models/${project_name}/${experiment_name} \ + trainer.project_name=${project_name} \ + trainer.experiment_name=${experiment_name} \ + trainer.n_gpus_per_node=8 \ + trainer.val_before_train=${val_before_train} \ + trainer.nnodes=${nnode} \ + trainer.save_freq=${save_freq} \ + trainer.test_freq=${test_freq} \ + trainer.total_epochs=${total_epochs} 2>&1 | tee /afs/chatrl/users/kzl/code/verlmix0923/logs_sensecore/$experiment_name.log \ No newline at end of file diff --git a/examples_sensecore/grpo_scripts_verl_0626/qwen32b-dapo-req-sched-dapo-trick-judge-model.sh b/examples_sensecore/grpo_scripts_verl_0626/qwen32b-dapo-req-sched-dapo-trick-judge-model.sh new file mode 100755 index 00000000000..d99a8b8089f --- /dev/null +++ b/examples_sensecore/grpo_scripts_verl_0626/qwen32b-dapo-req-sched-dapo-trick-judge-model.sh @@ -0,0 +1,224 @@ +set -x + +# export dapo_train_path=${dapo_train_path:-/afs/chatrl/users/hxh/data/rule_based_rl/DAPO-Math-17k/data/dapo-math-17k_dedup.parquet} +# export aime2024_test_path=${aime2024_test_path:-/afs/chatrl/users/hxh/data/rule_based_rl/AIME-2024/dapo_aime2024_sample8.parquet} +export dapo_train_path=${dapo_train_path:-/afs/chatrl/users/hxh/data/rule_based_rl/filter_by_32b_cold_start_20250614/filtered_dapo-math-17k_by_acc_0.2_0.7.parquet} +export deepmath_train_path=${deepmath_train_path:-/afs/chatrl/users/hxh/data/rule_based_rl/filter_by_32b_cold_start_20250614/filtered_deepmath_by_acc_0.2_0.7.parquet} +export math7d5k_train_path=${math7d5k_train_path:-/afs/chatrl/users/hxh/data/rule_based_rl/filter_by_32b_cold_start_20250614/filtered_math_train_by_acc_0_0.7.parquet} + +export aime2024_test_path=${aime2024_test_path:-/afs/chatrl/users/hxh/data/rule_based_rl/AIME-2024/dapo_aime2024_sample8_no_prompt.parquet} +export aime2025_test_path=${aime2025_test_path:-/afs/chatrl/users/hxh/data/rule_based_rl/AIME-2025/dapo_aime2025_sample8_no_prompt.parquet} + +# train_files="['$math7d5k_train_path', '$dapo_train_path', '$deepmath_train_path']" + + +export train_files=${train_files:-"['$math7d5k_train_path', '$dapo_train_path', '$deepmath_train_path']"} + +# test_files="['$aime2024_test_path', '$aime2025_test_path']" +export test_files=${test_files:-"['$aime2024_test_path', '$aime2025_test_path']"} + + +# resume config +export resume_mode=${resume_mode:-auto} +export resume_from_path=${resume_from_path:-null} +export model_path=${model_path:-/afs/chatrl/public/models/Qwen2.5-32B} +export model_name=$(basename "$model_path") + +# project config +export project_name=${project_name:-verl_dapo_math_grpo_dapo_req_sched} +# train params +export total_epochs=${total_epochs:-50} +export vllm_tp=${vllm_tp:-4} + +export train_prompt_batch_size=${train_prompt_batch_size:-512} +export grpo_rollout_n=${grpo_rollout_n:-16} +# model params +export max_response_length=${max_response_length:-20000} +export prompt_key=${prompt_key:-prompt} +export resume_type=${resume_type:-no_resume} +# env config +export 
nnode=${WORLD_SIZE:-1} + +export ulysses_sequence_parallel_size=${ulysses_sequence_parallel_size:-1} + +export filter_score_high=${filter_score_high:-null} +export filter_score_low=${filter_score_low:-null} + + +export save_freq=${save_freq:-20} +export test_freq=${test_freq:-20} + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 + + +use_dynamic_bsz=True +infer_micro_batch_size=null + +max_prompt_length=$((1024 * 2)) + +export val_before_train=${val_before_train:-True} + +export trust_remote_code=${trust_remote_code:-False} + +export enable_overlong_buffer=${enable_overlong_buffer:-True} +export overlong_buffer_len=${overlong_buffer_len:-$((1024 * 4))} +overlong_penalty_factor=1.0 + +export gen_prompt_bsz=${gen_prompt_bsz:-$((train_prompt_batch_size * 1))} + + +real_train_batch_size=$((train_prompt_batch_size * grpo_rollout_n)) +ppo_mini_batch_size=32 + + +export lr=${lr:-1e-6} + +# Algorithm +export temperature=${temperature:-1.0} +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +shuffle=False + +offload=False +max_tokens=$((max_prompt_length + max_response_length)) +gen_max_tokens=$((max_tokens * 2)) +log_prob_max_tokens=$((max_tokens * 2)) + + +export seq_dir=${seq_dir:-/afs/chatrl/users/hxh/data/req_sched_seq_dir/filter_by_32b_cold_start_20250614/init} +export log_dir=${log_dir:-/afs/chatrl/users/hxh/data/req_sched_seq_dir/filter_by_32b_cold_start_20250614/log} + +cap_dataset_size=$((1024 * 80000)) +filter_overlong_prompts=False + +#req_algo="long_short" +# req_algo="even_prompt" +# req_algo="even_token" +# agg="max" # sum / max + +export req_algo=${req_algo:-even_token} +export agg=${agg:-max} + + +export entropy_coeff=${entropy_coeff:-0} +export entropy_max=${entropy_max:-null} + +percentile=90 +export TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") + + +echo "real_train_batch_size = $real_train_batch_size, train_prompt_batch_size = $train_prompt_batch_size, nnode = $nnode, offload = $offload, max_tokens = $max_tokens, model = $model, vllm_tp = $vllm_tp, vllm_mem = $vllm_mem, seq_dir = $seq_dir, log_dir = $log_dir, cap_dataset_size = $cap_dataset_size, filter_overlong_prompts = $filter_overlong_prompts, max_prompt_length = $max_prompt_length, max_response_length = $max_response_length, req_algo = $req_algo, percentile = $percentile, agg = $agg" + +sleep 1 +export base_model_suffix=${base_model_suffix:-Base} +export experiment_name=${base_model_suffix}_dapo-${req_algo}-${agg}_${nnode}node_rollout${grpo_rollout_n}_temp${temperature}_bs${train_prompt_batch_size}_minibatch${ppo_mini_batch_size}_lr${lr}_sp${ulysses_sequence_parallel_size}_tp${vllm_tp}_maxlen${max_response_length}_overlong_punish_${enable_overlong_buffer}_entropy_coeff_${entropy_coeff}${resume_type} + +rm -rf /workspace/tmp_tensorboard/* +export TENSORBOARD_DIR=/afs/chatrl/users/hxh/models/verl_rl_models/${project_name}/${experiment_name} +export save_judge_path=/afs/chatrl/users/hxh/code/verl/logs/remote-reward/${project_name}-${experiment_name}.log + +#data.max_batch_size=${train_prompt_batch_size} \ +#python3 -u -m verl.trainer.main_ppo \ +# python3 -u -m verl.trainer.main_ppo_with_time \ +python3 -u -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name='dapo_trainer.yaml' \ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=${prompt_key} \ + 
data.train_batch_size=${train_prompt_batch_size} \ + actor_rollout_ref.rollout.n=${grpo_rollout_n} \ + data.shuffle=True \ + data.filter_overlong_prompts=${filter_overlong_prompts} \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + req_scheduler.seq_dir="$seq_dir" \ + req_scheduler.log_dir="$log_dir" \ + req_scheduler.agg="$agg" \ + req_scheduler.algo="$req_algo" \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.truncation='left' \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.profiler.all_ranks=True \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + data.trust_remote_code=${trust_remote_code} \ + actor_rollout_ref.model.trust_remote_code=${trust_remote_code} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.filter_score_low=${filter_score_low} \ + algorithm.filter_groups.filter_score_high=${filter_score_high} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_tokens} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${log_prob_max_tokens} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${log_prob_max_tokens} \ + actor_rollout_ref.model.path=${model_path} \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=${lr} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${ulysses_sequence_parallel_size} \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=${entropy_coeff} \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${vllm_tp} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${gen_max_tokens} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + 
reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.resume_mode=${resume_mode} \ + trainer.resume_from_path=${resume_from_path} \ + trainer.logger=['tensorboard'] \ + trainer.default_local_dir=/afs/chatrl/users/hxh/models/verl_rl_models/${project_name}/${experiment_name} \ + trainer.project_name=${project_name} \ + trainer.experiment_name=${experiment_name} \ + trainer.n_gpus_per_node=8 \ + trainer.val_before_train=${val_before_train} \ + trainer.nnodes=${nnode} \ + remote_reward.base_url=http://111.31.225.52:16669/v1 \ + remote_reward.save_judge_path=${save_judge_path} \ + remote_reward.api_key=EMPTY \ + remote_reward.model_name="Qwen3-30B-A3B" \ + reward_model.reward_manager=remote_batch \ + trainer.save_freq=${save_freq} \ + trainer.test_freq=${test_freq} \ + trainer.total_epochs=${total_epochs} 2>&1 | tee /afs/chatrl/users/hxh/code/verl/logs_sensecore/$experiment_name.log \ No newline at end of file diff --git a/examples_sensecore/grpo_scripts_verl_0626/qwen32b-dapo-req-sched-dapo-trick.sh b/examples_sensecore/grpo_scripts_verl_0626/qwen32b-dapo-req-sched-dapo-trick.sh index 42c876ead2b..8ed3754b53a 100755 --- a/examples_sensecore/grpo_scripts_verl_0626/qwen32b-dapo-req-sched-dapo-trick.sh +++ b/examples_sensecore/grpo_scripts_verl_0626/qwen32b-dapo-req-sched-dapo-trick.sh @@ -14,7 +14,9 @@ export aime2025_test_path=${aime2025_test_path:-/afs/chatrl/users/hxh/data/rule_ export train_files=${train_files:-"['$math7d5k_train_path', '$dapo_train_path', '$deepmath_train_path']"} -test_files="['$aime2024_test_path', '$aime2025_test_path']" +# test_files="['$aime2024_test_path', '$aime2025_test_path']" +export test_files=${test_files:-"['$aime2024_test_path', '$aime2025_test_path']"} + # resume config export resume_mode=${resume_mode:-auto} @@ -39,9 +41,12 @@ export nnode=${WORLD_SIZE:-1} export ulysses_sequence_parallel_size=${ulysses_sequence_parallel_size:-1} -export filter_score_high=${filter_score_high:-0.7} -export filter_score_low=${filter_score_low:-0.2} +export filter_score_high=${filter_score_high:-null} +export filter_score_low=${filter_score_low:-null} + +export save_freq=${save_freq:-20} +export test_freq=${test_freq:-20} use_kl_in_reward=False kl_coef=0.0 @@ -63,8 +68,12 @@ infer_micro_batch_size=null max_prompt_length=$((1024 * 2)) -enable_overlong_buffer=True -overlong_buffer_len=$((1024 * 4)) +export val_before_train=${val_before_train:-True} + +export trust_remote_code=${trust_remote_code:-False} + +export enable_overlong_buffer=${enable_overlong_buffer:-True} +export overlong_buffer_len=${overlong_buffer_len:-$((1024 * 4))} overlong_penalty_factor=1.0 export gen_prompt_bsz=${gen_prompt_bsz:-$((train_prompt_batch_size * 1))} @@ -74,10 +83,10 @@ real_train_batch_size=$((train_prompt_batch_size * grpo_rollout_n)) ppo_mini_batch_size=32 -lr=1e-6 +export lr=${lr:-1e-6} # Algorithm -temperature=1.0 +export temperature=${temperature:-1.0} top_p=1.0 top_k=-1 # 0 for HF rollout, -1 for vLLM rollout @@ -103,6 +112,10 @@ filter_overlong_prompts=False export req_algo=${req_algo:-even_token} export agg=${agg:-max} + +export entropy_coeff=${entropy_coeff:-0} +export entropy_max=${entropy_max:-null} + percentile=90 export TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") @@ -111,7 +124,7 @@ echo "real_train_batch_size = $real_train_batch_size, train_prompt_batch_size = sleep 1 export base_model_suffix=${base_model_suffix:-Base} -export 
experiment_name=Qwen25-32B-${base_model_suffix}_dapo-${req_algo}-${agg}_${nnode}node_rollout${grpo_rollout_n}_bs${train_prompt_batch_size}_minibatch${ppo_mini_batch_size}_lr${lr}_sp${ulysses_sequence_parallel_size}_tp${vllm_tp}_maxlen${max_response_length}_all_dapo_trick_${resume_type}_${TIMESTAMP} +export experiment_name=${base_model_suffix}_dapo-${req_algo}-${agg}_${nnode}node_rollout${grpo_rollout_n}_temp${temperature}_bs${train_prompt_batch_size}_minibatch${ppo_mini_batch_size}_lr${lr}_sp${ulysses_sequence_parallel_size}_tp${vllm_tp}_maxlen${max_response_length}_overlong_punish_${enable_overlong_buffer}_entropy_coeff_${entropy_coeff}${resume_type} rm -rf /workspace/tmp_tensorboard/* export TENSORBOARD_DIR=/afs/chatrl/users/hxh/models/verl_rl_models/${project_name}/${experiment_name} @@ -146,6 +159,8 @@ python3 -u -m recipe.dapo.main_dapo \ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ actor_rollout_ref.actor.clip_ratio_c=10.0 \ + data.trust_remote_code=${trust_remote_code} \ + actor_rollout_ref.model.trust_remote_code=${trust_remote_code} \ algorithm.filter_groups.enable=${enable_filter_groups} \ algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ algorithm.filter_groups.metric=${filter_groups_metric} \ @@ -168,7 +183,7 @@ python3 -u -m recipe.dapo.main_dapo \ actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \ actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ - actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.entropy_coeff=${entropy_coeff} \ actor_rollout_ref.actor.grad_clip=1.0 \ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ actor_rollout_ref.rollout.name=vllm \ @@ -197,7 +212,8 @@ python3 -u -m recipe.dapo.main_dapo \ trainer.project_name=${project_name} \ trainer.experiment_name=${experiment_name} \ trainer.n_gpus_per_node=8 \ + trainer.val_before_train=${val_before_train} \ trainer.nnodes=${nnode} \ - trainer.save_freq=10 \ - trainer.test_freq=20 \ + trainer.save_freq=${save_freq} \ + trainer.test_freq=${test_freq} \ trainer.total_epochs=${total_epochs} 2>&1 | tee /afs/chatrl/users/hxh/code/verl/logs_sensecore/$experiment_name.log \ No newline at end of file diff --git a/examples_sensecore/grpo_scripts_verl_0626/qwen7b-dapo-test.sh b/examples_sensecore/grpo_scripts_verl_0626/qwen7b-dapo-test.sh new file mode 100755 index 00000000000..ffec4c55ed8 --- /dev/null +++ b/examples_sensecore/grpo_scripts_verl_0626/qwen7b-dapo-test.sh @@ -0,0 +1,218 @@ +set -x + +export math7d5k_train_path=${math7d5k_train_path:-/afs/chatrl/users/hxh/data/math_data/MATH_train/rule_based_rl/train_7d5k_with_refined_answers_math_verify_telechat3_base_rl_onplicy_step400_acc_0-0d7.parquet} + +export aime2024_test_path=${aime2024_test_path:-/afs/chatrl/users/hxh/data/rule_based_rl/AIME-2024/dapo_aime2024_sample8_no_prompt.parquet} +export aime2025_test_path=${aime2025_test_path:-/afs/chatrl/users/hxh/data/rule_based_rl/AIME-2025/dapo_aime2025_sample8_no_prompt.parquet} + +# train_files="['$math7d5k_train_path', '$dapo_train_path', '$deepmath_train_path']" + + +export train_files=${train_files:-"['$math7d5k_train_path']"} + +# test_files="['$aime2024_test_path', '$aime2025_test_path']" +export test_files=${test_files:-"['$aime2024_test_path', '$aime2025_test_path']"} + + +# resume config +export resume_mode=${resume_mode:-auto} +export resume_from_path=${resume_from_path:-null} +export 
model_path=${model_path:-/afs/chatrl/public/models/DeepSeek-R1-Distill-Qwen-7B} +export model_name=$(basename "$model_path") + +# project config +export project_name=${project_name:-verl_dapo_math_grpo_test} +# train params +export total_epochs=${total_epochs:-50} +export vllm_tp=${vllm_tp:-4} + +export train_prompt_batch_size=${train_prompt_batch_size:-32} +export grpo_rollout_n=${grpo_rollout_n:-16} +# model params +export max_response_length=${max_response_length:-8000} +export prompt_key=${prompt_key:-messages} +export resume_type=${resume_type:-no_resume} +# env config +export nnode=${WORLD_SIZE:-1} + +export ulysses_sequence_parallel_size=${ulysses_sequence_parallel_size:-1} + +export filter_score_high=${filter_score_high:-null} +export filter_score_low=${filter_score_low:-null} + + +export save_freq=${save_freq:-20} +export test_freq=${test_freq:-20} + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 + + +use_dynamic_bsz=True +infer_micro_batch_size=null + +max_prompt_length=$((1024 * 2)) + +export val_before_train=${val_before_train:-True} + +export trust_remote_code=${trust_remote_code:-True} + +export enable_overlong_buffer=${enable_overlong_buffer:-True} +export overlong_buffer_len=${overlong_buffer_len:-$((1024 * 4))} +overlong_penalty_factor=1.0 + +export gen_prompt_bsz=${gen_prompt_bsz:-$((train_prompt_batch_size * 1))} + + +real_train_batch_size=$((train_prompt_batch_size * grpo_rollout_n)) +ppo_mini_batch_size=32 + + +lr=1e-6 + +# Algorithm +export temperature=${temperature:-1.0} +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +shuffle=False + +offload=False +max_tokens=$((max_prompt_length + max_response_length)) +gen_max_tokens=$((max_tokens * 2)) +log_prob_max_tokens=$((max_tokens * 2)) + + +export seq_dir=${seq_dir:-/afs/chatrl/users/hxh/data/req_sched_seq_dir/filter_by_32b_cold_start_20250614/init} +export log_dir=${log_dir:-/afs/chatrl/users/hxh/data/req_sched_seq_dir/filter_by_32b_cold_start_20250614/log} + +cap_dataset_size=$((1024 * 80000)) +filter_overlong_prompts=False + +#req_algo="long_short" +# req_algo="even_prompt" +# req_algo="even_token" +# agg="max" # sum / max + +export req_algo=${req_algo:-even_token} +export agg=${agg:-max} + + +export entropy_coeff=${entropy_coeff:-0} +export entropy_max=${entropy_max:-null} + +percentile=90 +export TIMESTAMP=$(date +"%Y-%m-%d_%H-%M-%S") + + +echo "real_train_batch_size = $real_train_batch_size, train_prompt_batch_size = $train_prompt_batch_size, nnode = $nnode, offload = $offload, max_tokens = $max_tokens, model = $model, vllm_tp = $vllm_tp, vllm_mem = $vllm_mem, seq_dir = $seq_dir, log_dir = $log_dir, cap_dataset_size = $cap_dataset_size, filter_overlong_prompts = $filter_overlong_prompts, max_prompt_length = $max_prompt_length, max_response_length = $max_response_length, req_algo = $req_algo, percentile = $percentile, agg = $agg" + +sleep 1 +export base_model_suffix=${base_model_suffix:-Base} +export experiment_name=${base_model_suffix}_dapo-${req_algo}-${agg}_${nnode}node_rollout${grpo_rollout_n}_temp${temperature}_bs${train_prompt_batch_size}_minibatch${ppo_mini_batch_size}_lr${lr}_sp${ulysses_sequence_parallel_size}_tp${vllm_tp}_maxlen${max_response_length}_overlong_punish_${enable_overlong_buffer}_entropy_coeff_${entropy_coeff}${resume_type} + +rm -rf /workspace/tmp_tensorboard/* +export 
TENSORBOARD_DIR=/afs/chatrl/users/hxh/models/verl_rl_models/${project_name}/${experiment_name} + +#data.max_batch_size=${train_prompt_batch_size} \ +#python3 -u -m verl.trainer.main_ppo \ +# python3 -u -m verl.trainer.main_ppo_with_time \ +python3 -u -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name='dapo_trainer.yaml' \ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=${prompt_key} \ + data.train_batch_size=${train_prompt_batch_size} \ + actor_rollout_ref.rollout.n=${grpo_rollout_n} \ + data.shuffle=True \ + data.filter_overlong_prompts=${filter_overlong_prompts} \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + req_scheduler.seq_dir="$seq_dir" \ + req_scheduler.log_dir="$log_dir" \ + req_scheduler.agg="$agg" \ + req_scheduler.algo="$req_algo" \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.truncation='left' \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.profiler.all_ranks=True \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + data.trust_remote_code=${trust_remote_code} \ + actor_rollout_ref.model.trust_remote_code=${trust_remote_code} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.filter_score_low=${filter_score_low} \ + algorithm.filter_groups.filter_score_high=${filter_score_high} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_tokens} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${log_prob_max_tokens} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${log_prob_max_tokens} \ + actor_rollout_ref.model.path=${model_path} \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=${lr} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${ulysses_sequence_parallel_size} \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=${entropy_coeff} \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${vllm_tp} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${gen_max_tokens} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} 
\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.resume_mode=${resume_mode} \ + trainer.resume_from_path=${resume_from_path} \ + trainer.logger=['tensorboard'] \ + trainer.default_local_dir=/afs/chatrl/users/hxh/models/verl_rl_models/${project_name}/${experiment_name} \ + trainer.project_name=${project_name} \ + trainer.experiment_name=${experiment_name} \ + trainer.n_gpus_per_node=8 \ + trainer.val_before_train=${val_before_train} \ + trainer.nnodes=${nnode} \ + remote_reward.base_url=http://111.31.225.52:16669/v1 \ + remote_reward.api_key=EMPTY \ + remote_reward.model_name="Qwen3-30B-A3B" \ + reward_model.reward_manager=remote_batch \ + trainer.save_freq=${save_freq} \ + trainer.test_freq=${test_freq} \ + trainer.total_epochs=${total_epochs} 2>&1 | tee /afs/chatrl/users/hxh/code/verl/logs_sensecore/$experiment_name.log \ No newline at end of file diff --git a/recipe/dapo/dapo_ray_trainer.py b/recipe/dapo/dapo_ray_trainer.py index 10862796658..bdea9ccbb2c 100644 --- a/recipe/dapo/dapo_ray_trainer.py +++ b/recipe/dapo/dapo_ray_trainer.py @@ -34,6 +34,9 @@ compute_throughout_metrics, compute_timing_metrics, reduce_metrics, + compute_mix_metrics, + compute_acc_metrics, + compute_mix_language_metrics, ) from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask from verl.utils.debug import marked_timer @@ -459,6 +462,9 @@ def fit(self): # collect metrics metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_mix_metrics(batch=batch)) + metrics.update(compute_acc_metrics(batch=batch)) + metrics.update(compute_mix_language_metrics(batch=batch, tokenizer_name_or_path="/afs/chatrl/public/models/DeepSeek-R1-Distill-Qwen-7B")) metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) # TODO: implement actual tflpo and theoretical tflpo n_gpus = self.resource_pool_manager.get_n_gpus() diff --git a/recipe/dapo/main_dapo.py b/recipe/dapo/main_dapo.py index d16500d104b..acb91541a48 100644 --- a/recipe/dapo/main_dapo.py +++ b/recipe/dapo/main_dapo.py @@ -33,7 +33,6 @@ def main(config): def run_ppo(config) -> None: - print(f"> config = {config}") if not ray.is_initialized(): # this is for local ray cluster ray.init( diff --git a/scripts/model_merger.py b/scripts/model_merger.py index 8e9ee290c71..e6f7261211a 100644 --- a/scripts/model_merger.py +++ b/scripts/model_merger.py @@ -50,6 +50,11 @@ required=False, help="test correctness of hf_model, , with hf_model in checkpoint.contents", ) +parser.add_argument( + "--trust_remote_code", + action="store_true", + help="Whether to trust remote code when loading the model config", +) args = parser.parse_args() os.makedirs(args.target_dir, exist_ok=True) if args.test: @@ -181,7 +186,7 @@ def process_one_shard(rank, model_state_dict_lst): print("Writing to local disk") hf_path = os.path.join(local_dir, "huggingface") if args.target_dir is None else args.target_dir - config = 
AutoConfig.from_pretrained(args.hf_model_path)
+config = AutoConfig.from_pretrained(args.hf_model_path, trust_remote_code=args.trust_remote_code)
 if "ForTokenClassification" in config.architectures[0]:
     auto_model = AutoModelForTokenClassification
@@ -193,7 +198,7 @@ def process_one_shard(rank, model_state_dict_lst):
     raise NotImplementedError(f"Unknown architecture {config['architectures']}")
 with torch.device("meta"):
-    model = auto_model.from_config(config, torch_dtype=torch.bfloat16)
+    model = auto_model.from_config(config, torch_dtype=torch.bfloat16, trust_remote_code=args.trust_remote_code)
 model.to_empty(device="cpu")
 print(f"Saving model to {hf_path}")
@@ -259,7 +264,7 @@ def process_one_shard(shard_dir, model_state_dict_lst):
     process_one_shard(sharded_dir, model_state_dict_lst)
 state_dict = {}
-config = AutoConfig.from_pretrained(args.hf_model_path)
+config = AutoConfig.from_pretrained(args.hf_model_path, trust_remote_code=args.trust_remote_code)
 if args.test:
     ref_state_dict = load_file(os.path.join(args.test_hf_dir, "model.safetensors"))
@@ -398,7 +403,7 @@ def merge_across_tp(key, tp_data):
     raise NotImplementedError(f"Unknown architecture {config['architectures']}")
 with torch.device("meta"):
-    model = auto_model.from_config(config, torch_dtype=torch.bfloat16)
+    model = auto_model.from_config(config, torch_dtype=torch.bfloat16, trust_remote_code=args.trust_remote_code)
 model.to_empty(device="cpu")
 print(f"Saving model to {hf_path}")
diff --git a/scripts/sensecore/grpo_req_sched_0626/start_run_qwen2-32b-req-sched-dapo-judge.sh b/scripts/sensecore/grpo_req_sched_0626/start_run_qwen2-32b-req-sched-dapo-judge.sh
new file mode 100755
index 00000000000..cb7a5fb6d7d
--- /dev/null
+++ b/scripts/sensecore/grpo_req_sched_0626/start_run_qwen2-32b-req-sched-dapo-judge.sh
@@ -0,0 +1,54 @@
+#!/bin/bash
+
+# Print the environment variables for debugging
+env
+echo "-----------------------"
+echo "MASTER_ADDR = ${MASTER_ADDR}"
+echo "MASTER_PORT = ${MASTER_PORT}"
+echo "RANK = ${RANK}"
+
+# Enter the project directory
+cd /afs/chatrl/users/hxh/code/verl
+
+# "re" is part of the Python standard library and does not need to be installed with pip
+pip install math_verify -i https://mirrors.aliyun.com/pypi/simple/
+pip install sympy -i https://mirrors.aliyun.com/pypi/simple/
+# Install the Python package in the current directory
+pip install -v -e . -i https://mirrors.aliyun.com/pypi/simple/
+
+# Run a different task depending on the RANK environment variable
+if [ "$RANK" = "0" ]; then
+    # The head node starts the Ray cluster
+    echo "Starting Ray head node..."
+    ray start --head --port=$MASTER_PORT &
+
+    # Wait for Ray to finish starting
+    while true; do
+        echo "Checking Ray cluster status..."
+        ray status
+        NODE_COUNT=$(ray status | grep -c '^ 1 node_')
+        EXPECTED_NODE_COUNT=${WORLD_SIZE} # number of nodes requested for the job
+
+        echo "Current alive nodes: ${NODE_COUNT}/${EXPECTED_NODE_COUNT}"
+
+        if [ "$NODE_COUNT" -eq "$EXPECTED_NODE_COUNT" ]; then
+            echo "All Ray nodes are ready."
+            break
+        else
+            echo "Waiting for all Ray nodes to be ready..."
+            sleep 10
+        fi
+    done
+
+    # Check Ray status
+    echo "Checking Ray status:"
+    ray status
+
+    # Run the training script and record its log
+    echo "Running training script..."
+    ./examples_sensecore/grpo_scripts_verl_0626/qwen32b-dapo-req-sched-dapo-trick-judge-model.sh
+else
+    # Worker nodes connect to the Ray head node
+    echo "Starting Ray worker node..."
+    ray start --address=${MASTER_ADDR}:${MASTER_PORT} --block
+fi
diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index 8970f5120e8..60d6a0e7d17 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -218,6 +218,7 @@ actor_rollout_ref:
     # Entropy regularization coefficient in PPO loss
     entropy_coeff: 0
+    entropy_max: null
     # Whether to use KL loss instead of KL reward penalty. True for GRPO
     use_kl_loss: false
@@ -842,10 +843,10 @@ reward_model:
     ranks: null
 remote_reward:
-  base_url: "http://127.0.0.1:6664/v1"
-  api_key: "EMPTY"
-  model_name: "Qwen2.5-32B-Instruct"
-  save_judge_path: "/afs/chatrl/users/hwq/code/verl-remote-reward/examples_sensecore/grpo_remote_reward/output_test_para.jsonl"
+  base_url: "http://111.31.225.52:16669/v1"
+  api_key: null
+  model_name: "Qwen3-30B-A3B"
+  save_judge_path: "/afs/chatrl/users/hxh/code/verl/logs/remote-reward/output_test_para.jsonl"
 # custom reward function definition
 custom_reward_function:
diff --git a/verl/trainer/ppo/metric_utils.py b/verl/trainer/ppo/metric_utils.py
index f246f47c354..61ed16a9d96 100644
--- a/verl/trainer/ppo/metric_utils.py
+++ b/verl/trainer/ppo/metric_utils.py
@@ -25,6 +25,9 @@
 from verl import DataProto
 from verl.utils.import_utils import deprecated
+from verl.utils.reward_score.language_detect import detect_language
+from verl.utils.tokenizer import hf_tokenizer
+
 @deprecated("verl.utils.metric.reduce_metrics")
 def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
@@ -424,3 +427,134 @@ def process_validation_metrics(data_sources: list[str], sample_inputs: list[str]
             data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(prompt_vals)
 
     return data_src2var2metric2val
+
+
+def compute_mix_metrics(batch: DataProto) -> dict:
+    """
+    Compute metrics for language-mixed (mix) samples: scores of 0.2 / -1.0 are counted as mix correct / wrong, and 1.0 / -0.8 as non-mix correct / wrong.
+    """
+    reward_scores = batch.batch["token_level_scores"].sum(-1)
+
+    num_samples = len(reward_scores)
+    num_mix = 0
+    num_non_mix = 0
+    num_mix_correct = 0
+    num_mix_wrong = 0
+    num_non_mix_correct = 0
+    num_non_mix_wrong = 0
+
+    for score in reward_scores:
+        if score == 0.2:
+            num_mix_correct += 1
+            num_mix += 1
+        elif score == -1.0:
+            num_mix_wrong += 1
+            num_mix += 1
+        elif score == 1.0:
+            num_non_mix_correct += 1
+            num_non_mix += 1
+        elif score == -0.8:
+            num_non_mix_wrong += 1
+            num_non_mix += 1
+        else:
+            num_non_mix += 1
+
+    mix_ratio = num_mix / num_samples if num_samples > 0 else 0
+    mix_acc = num_mix_correct / num_mix if num_mix > 0 else 0
+    non_mix_acc = num_non_mix_correct / num_non_mix if num_non_mix > 0 else 0
+
+    return {
+        "training/mix_sample_ratio": mix_ratio,
+        "training/mix_accuracy": mix_acc,
+        "training/non_mix_accuracy": non_mix_acc,
+    }
+
+def compute_acc_metrics(batch: DataProto) -> dict:
+    """
+    Compute the overall accuracy over the batch (correct: score 0.2 or 1.0; incorrect: score -1.0 or -0.8).
+    """
+    reward_scores = batch.batch["token_level_scores"].sum(-1)
+
+    num_samples = len(reward_scores)
+    num_correct = 0
+    num_incorrect = 0
+
+    for score in reward_scores:
+        if score == 0.2 or score == 1.0:
+            num_correct += 1
+        elif score == -1.0 or score == -0.8:
+            num_incorrect += 1
+        else:
+            # Other score values can be counted here if needed
+            pass
+
+    accuracy = num_correct / (num_correct + num_incorrect) if (num_correct + num_incorrect) > 0 else 0
+
+    return {
+        "training/accuracy": accuracy,
+    }
+
+def compute_mix_language_metrics(batch: DataProto, tokenizer_name_or_path: str) -> Dict[str, float]:
+    """
+    Compute the ratio of language-mixed (mix) responses in the batch, as well as the mix ratio among samples with score > 0 and among samples with score < 0.
+
+    Args:
+        batch: DataProto object containing token_level_scores, responses, etc.
+        tokenizer_name_or_path: path or name of the tokenizer used for decoding
+
+    Returns:
dict,包含三个key: + - "mix_ratio": batch中mix语言样本占比 + - "mix_ratio_score_pos": score>0样本中mix占比 + - "mix_ratio_score_neg": score<0样本中mix占比 + """ + # 加载tokenizer + tokenizer = hf_tokenizer(tokenizer_name_or_path) + + # 获取batch中token_level_scores的序列分数(sum) + sequence_scores = batch.batch["token_level_scores"].sum(-1).tolist() # list[float] + + # 获取batch中responses的tokenid,形状(batch_size, seq_len) + responses = batch.batch["responses"].tolist() + + # 解码responses为文本 + texts = [tokenizer.decode(resp, skip_special_tokens=True).strip() for resp in responses] + + # 统计 + total = len(texts) + if total == 0: + return { + "training/mix_ratio": 0.0, + "training/mix_ratio_score_pos": 0.0, + "training/mix_ratio_score_neg": 0.0, + } + + mix_count = 0 + mix_score_pos_count = 0 + mix_score_neg_count = 0 + score_pos_count = 0 + score_neg_count = 0 + + for text, score in zip(texts, sequence_scores): + lang = detect_language(text) + is_mix = (lang == "mix") + if is_mix: + mix_count += 1 + if score > 0: + score_pos_count += 1 + if is_mix: + mix_score_pos_count += 1 + elif score < 0: + score_neg_count += 1 + if is_mix: + mix_score_neg_count += 1 + + mix_ratio = mix_count / total + mix_ratio_score_pos = mix_score_pos_count / score_pos_count if score_pos_count > 0 else 0.0 + mix_ratio_score_neg = mix_score_neg_count / score_neg_count if score_neg_count > 0 else 0.0 + + return { + "training/mix_ratio": mix_ratio, + "training/mix_ratio_score_pos": mix_ratio_score_pos, + "training/mix_ratio_score_neg": mix_ratio_score_neg, + } \ No newline at end of file diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 647656016fe..6898858d378 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -32,8 +32,9 @@ from typing import Optional, Type, Dict from codetiming import Timer from contextlib import contextmanager -import glob +from datetime import datetime +import glob import numpy as np import ray import torch @@ -365,7 +366,7 @@ def load_table(self): with open(json_file, 'r') as f: data = json.load(f) - print(f"[ReqScheduler] data keys = {data.keys()} in {filename}") + # print(f"[ReqScheduler] data keys = {data.keys()} in {filename}") # 按格式保存 ps = data['prompts'] ls = data['lengths'] @@ -373,7 +374,7 @@ def load_table(self): p = tuple(p) if p not in ans: ans[p] = l - print(f"[ReqScheduler] Processed {filename}, found {len(ans)} unique prompts") + # print(f"[ReqScheduler] Processed {filename}, found {len(ans)} unique prompts") except Exception as e: print(f"[ReqScheduler] Error processing {filename}: {str(e)}") raise e @@ -844,6 +845,7 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl collate_fn = default_collate_fn + print(f"> [debug] Size of train dataset: {len(self.train_dataset)}") self.train_dataloader = StatefulDataLoader( dataset=self.train_dataset, batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), @@ -967,7 +969,8 @@ def _validate(self): # Store original inputs input_ids = test_batch.batch["input_ids"] # TODO: Can we keep special tokens except for padding tokens? 
- input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + # @xiaohui: keep special tokens + input_texts = [self.tokenizer.decode(ids, skip_special_tokens=False) for ids in input_ids] sample_inputs.extend(input_texts) batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] @@ -1016,18 +1019,22 @@ def _validate(self): # Store generated outputs output_ids = test_output_gen_batch.batch["responses"] - output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + # @xiaohui: keep special tokens + output_texts = [self.tokenizer.decode(ids, skip_special_tokens=False) for ids in output_ids] sample_outputs.extend(output_texts) test_batch = test_batch.union(test_output_gen_batch) # evaluate using reward_function + print(f">> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, [ray_trainer.py] evaluate start...") result = self.val_reward_fn(test_batch, return_dict=True) reward_tensor = result["reward_tensor"] scores = reward_tensor.sum(-1).cpu().tolist() sample_scores.extend(scores) reward_extra_infos_dict["reward"].extend(scores) + print(f">> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, [ray_trainer.py] evaluate end !!!") + print(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}") if "reward_extra_info" in result: for key, lst in result["reward_extra_info"].items(): @@ -1300,6 +1307,7 @@ def fit(self): last_val_metrics = None for epoch in range(self.config.trainer.total_epochs): + print(f"Epoch {epoch} / {self.config.trainer.total_epochs}") for batch_dict in self.train_dataloader: do_profile = self.global_steps in self.config.trainer.profile_steps if self.config.trainer.profile_steps is not None else False if do_profile: @@ -1493,8 +1501,9 @@ def fit(self): if rollout_data_dir: with marked_timer("dump_rollout_generations", timing_raw, color="green"): print(batch.batch.keys()) - inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) - outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) + # @xiaohui: keep special tokens + inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=False) + outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=False) scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() self._dump_generations( inputs=inputs, diff --git a/verl/trainer/ppo/ray_trainer_with_time.py b/verl/trainer/ppo/ray_trainer_with_time.py deleted file mode 100644 index 28502d83084..00000000000 --- a/verl/trainer/ppo/ray_trainer_with_time.py +++ /dev/null @@ -1,1779 +0,0 @@ -#opyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -FSDP PPO Trainer with Ray-based single controller. 
-This trainer supports model-agonistic model initialization with huggingface -""" - -# tp size for each worker -#MODEL_DEPLOYMENT = [1, 1, 1, 1, 1, 1, 1, 1] -#MODEL_DEPLOYMENT = [2, 1, 1, 1, 1, 1, 1] -#MODEL_DEPLOYMENT = [2,2,2,2] -MODEL_DEPLOYMENT = None - -import os -import sys -import contextlib -import io -import uuid -from collections import defaultdict - -from contextlib import contextmanager -from dataclasses import dataclass, field -from enum import Enum -from pprint import pprint -from typing import Type, Dict -from copy import deepcopy -from time import time -import json -import heapq -import glob - - -import numpy as np -import pandas as pd -from codetiming import Timer -from omegaconf import OmegaConf, open_dict -from torch.utils.data import RandomSampler, SequentialSampler -from torchdata.stateful_dataloader import StatefulDataLoader -from tqdm import tqdm -from verl import DataProto -from verl.trainer.ppo.metric_utils import ( - compute_data_metrics, - compute_throughout_metrics, - compute_timing_metrics, - reduce_metrics, - process_validation_metrics, -) -from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto -from verl.single_controller.base import Worker -from verl.single_controller.ray import ( - RayClassWithInitArgs, - RayResourcePool, - RayWorkerGroup, -) -from verl.single_controller.ray.base import create_colocated_worker_cls -from verl.trainer.ppo import core_algos -from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance -from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path -from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn, process_image -from verl.utils.model import compute_position_id_with_mask -import verl.utils.torch_functional as verl_F - - -WorkerType = Type[Worker] - - -@contextlib.contextmanager -def suppress_stdout(): - old_stdout = sys.stdout - sys.stdout = io.StringIO() - try: - yield - finally: - sys.stdout = old_stdout - - -def unpad_responses(padded_tensor, pad_token_id): - if isinstance(pad_token_id, list): - # from all worker - pid = pad_token_id[0] - for worker_pad_token_id in pad_token_id: - if worker_pad_token_id != pid: - raise ValueError("pad_token_id is not the same across all workers") - pad_token_id = pid - - padded_tensor = padded_tensor.cpu() - # Convert tensor to list if it's a tensor - if isinstance(padded_tensor, torch.Tensor): - padded_list = padded_tensor.tolist() - else: - padded_list = padded_tensor - - # Reconstruct original responses by removing padding tokens - unpadded_responses = [] - for padded_response in padded_list: - # Find where padding starts (first occurrence of pad_token_id) - try: - pad_start_idx = padded_response.index(pad_token_id) - # Get only the tokens before padding - original_response = padded_response[:pad_start_idx] - except ValueError: - # No padding found, use the full response - original_response = padded_response - - unpadded_responses.append(original_response) - return unpadded_responses - -class RLHFDatasetFilter(RLHFDataset): - def __init__(self, - parquet_files, - tokenizer, - processor, - prompt_key='prompt', - image_key='images', - filter_prompts=True, - cache_dir='~/.cache/verl/rlhf', - chat_template_func=None, - return_raw_chat=False, - truncation='error', - # - filter_overlong_prompts=False, # NOTE: this will filter both prompt and responses - min_prompt_length=None, - max_prompt_length=1024, - min_response_length=None, - max_response_length=None, - cap_dataset_size=None, - req_scheduler=None, - ): - 
self.min_prompt_length = min_prompt_length - self.max_prompt_length = max_prompt_length - self.min_response_length = min_response_length - self.max_response_length = max_response_length - # - self.filter_overlong_prompts = filter_overlong_prompts - self.cap_dataset_size = cap_dataset_size - self.req_scheduler = req_scheduler - - super().__init__( - parquet_files=parquet_files, - tokenizer=tokenizer, - processor=processor, - prompt_key=prompt_key, - image_key=image_key, - max_prompt_length=max_prompt_length, - filter_prompts=filter_prompts, - cache_dir=cache_dir, - chat_template_func=chat_template_func, - return_raw_chat=return_raw_chat, - truncation=truncation, - filter_overlong_prompts=filter_overlong_prompts, - ) - - def _generate_cache_id(self): - import hashlib - """Generate a unique identifier for the current dataset configuration""" - # [RLHFDatasetFilter]: _generate_cache_id, self.parquet_files=['/afs/chatrl/users/hxh/data/rule_based_rl/DAPO-Math-17k/data/dapo-math-17k_dedup.parquet']. - print(f'[RLHFDatasetFilter]: _generate_cache_id, {self.parquet_files=}.') - # Create a string containing all parameters that would affect filtering - config_str = str(sorted(self.parquet_files)) + str(self.max_prompt_length) + str(self.min_prompt_length) - # Hash the configuration string to create a shorter identifier - return hashlib.md5(config_str.encode()).hexdigest()[:10] - - def _read_files_and_tokenize(self): - import pickle - # Create a cache identifier based on input files and filtering parameters - cache_id = self._generate_cache_id() - cache_dir = os.path.dirname(self.parquet_files[0]) - cache_file = os.path.join(cache_dir, f"filtered_data_{cache_id}.parquet") - cache_metadata = os.path.join(cache_dir, f"metadata_{cache_id}.pkl") - # [RLHFDatasetFilter]: cache_file='/afs/chatrl/users/hxh/data/rule_based_rl/DAPO-Math-17k/data/filtered_data_8aff3605b0.parquet'. - print(f'[RLHFDatasetFilter]: {cache_file=}.') - - # Check if cached filtered data exists - if os.path.exists(cache_file) and os.path.exists(cache_metadata): - try: - # Load metadata to verify cache is valid - with open(cache_metadata, 'rb') as f: - metadata = pickle.load(f) - - # Verify the cache is still valid for our current parameters - if (metadata['max_prompt_length'] == self.max_prompt_length and - metadata['min_prompt_length'] == self.min_prompt_length): - print(f"[RLHFDatasetFilter] Loading pre-filtered dataset from cache: {cache_file}") - self.dataframe = pd.read_parquet(cache_file) - print(f'[RLHFDatasetFilter]: {len(self.dataframe)=} {list(self.dataframe.columns)}') - return - - except Exception as e: - print(f"[RLHFDatasetFilter] Cache loading failed: {e}. 
Rebuilding cache.") - - # If cache doesn't exist or is invalid, process the data - # Read and concatenate all input files - print("[RLHFDatasetFilter] Processing data and building cache...") - dataframes = [] - for parquet_file in self.parquet_files: - dataframe = pd.read_parquet(parquet_file) - dataframes.append(dataframe) - - # XXX, for req scheduling, we need to build a table map from req_id to output len - # the dataset is too large, we just hacky downsmaple for now - self.dataframe = pd.concat(dataframes) - full_len = len(self.dataframe) - if self.cap_dataset_size is not None: - self.dataframe = self.dataframe.iloc[:self.cap_dataset_size] - print(f'[RLHFDatasetFilter]: {full_len=} {len(self.dataframe)=} {list(self.dataframe.columns)}') - - # apply filter - if self.filter_overlong_prompts: - tokenizer = self.tokenizer - prompt_key = self.prompt_key - - # add a new col for efficient filtering! - new_key = 'applied_chat_template_prompts' - self.dataframe[new_key] = self.dataframe[prompt_key].apply(lambda prompt: tokenizer.apply_chat_template(prompt, add_generation_prompt=True,)) - - # filter prompt - t1 = time() - def filter_long(doc): - return len(doc[new_key]) <= self.max_prompt_length - def filter_short(doc): - return len(doc[new_key]) >= self.min_prompt_length - if self.min_prompt_length is not None: - self.dataframe = self.dataframe[self.dataframe.apply(filter_short, axis=1)] - if self.max_prompt_length is not None: - self.dataframe = self.dataframe[self.dataframe.apply(filter_long, axis=1)] - t2 = time() - print(f'[RLHFDatasetFilter] filter prompt: {len(self.dataframe)=}, {self.min_prompt_length}:{self.max_prompt_length}, time cost: {t2-t1:.2f}s') - - # filter response - t1 = time() - def filter_long(doc): - outlen = self.req_scheduler.lookup_table(doc[new_key]) - if outlen is None: - return False - return outlen <= self.max_response_length - def filter_short(doc): - outlen = self.req_scheduler.lookup_table(doc[new_key]) - if outlen is None: - return False - return outlen >= self.min_response_length - # - if self.min_prompt_length is not None: - self.dataframe = self.dataframe[self.dataframe.apply(filter_short, axis=1)] - if self.max_prompt_length is not None: - self.dataframe = self.dataframe[self.dataframe.apply(filter_long, axis=1)] - t2 = time() - print(f'[RLHFDatasetFilter] filter response: {len(self.dataframe)=}, {self.min_response_length}:{self.max_response_length}, time cost: {t2-t1:.2f}s') - - def __getitem__(self, item): - # print(f'[RLHFDatasetFilter] items: {item} {type(item)}, end=') - row_dict: dict = self.dataframe.iloc[item].to_dict() - # 原数据格式 - # 输出结果:row_dict keys: dict_keys(['data_source', 'prompt', 'ability', 'reward_model', 'extra_info']) - # print(f'[RLHFDatasetFilter] row_dict keys: {row_dict.keys()}') - - chat = row_dict.pop(self.prompt_key) - # self.prompt_key='prompt' or message - # print(f'[RLHFDatasetFilter] {self.prompt_key=}') - - prompt_with_chat_template = self.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False) - - assert self.image_key not in row_dict, 'multi-modal is not supported yet' - raw_prompt = prompt_with_chat_template - - input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(prompt=prompt_with_chat_template, - tokenizer=self.tokenizer, - max_length=self.max_prompt_length, - pad_token_id=self.tokenizer.pad_token_id, - left_pad=True, - truncation=self.truncation, - ) - position_ids = compute_position_id_with_mask(attention_mask) - - row_dict['input_ids'] = input_ids[0] - row_dict['attention_mask'] = 
attention_mask[0] - row_dict['position_ids'] = position_ids[0] - row_dict['raw_prompt_ids'] = self.tokenizer.encode(raw_prompt, add_special_tokens=False) - - # encode prompts without chat template - if self.return_raw_chat: - row_dict['raw_prompt'] = chat.tolist() - - # add index for each prompt - index = row_dict.get("extra_info", {}).get("index", 0) - row_dict["index"] = index - - return row_dict - - -class ReqScheduler: - def __init__(self, config): - self.config = config - - # prompt_ids -> len(reponse) - self.table: dict[tuple[int], int] = self.load_table() - - def load_table(self): - ''' 加载预存的 prompts 信息 - 预存的 table 数据格式 - { - "prompts": [ - [prompt_token_ids_1], - [prompt_token_ids_2], - ... - ], - "lengths": [ - [120, 88, 85, 92, 95, 100, 90, 110], // prompt 1 对应 sample n 个 response 长度 - [105, 90, 95, 92, 100, 94, 90, 88], // prompt 2 - ... - ], - "stats": [ // 初始预计算存储,仍可保留便于快速调用 - {"max": 120, "min": 85, "mean": 97.5, "std": 10.2, "sum": 780}, - {"max": 105, "min": 88, "mean": 94.3, "std": 5.6, "sum": 754}, - ... - ] - } - ''' - if self.config.seq_dir is None: - return {} - - # Find all JSON files in the directory - json_files = glob.glob(os.path.join(self.config.seq_dir, "*.json")) - print(f"[ReqScheduler] Found {len(json_files)} JSON files to process") - - # prompts -> list[responses] - ans = {} - for json_file in json_files: - filename = os.path.basename(json_file) - #if key not in filename: - # continue - try: - with open(json_file, 'r') as f: - data = json.load(f) - - # [ReqScheduler] data keys = dict_keys(['prompts', 'response', 'reqs_idx', 'outlens']) - print(f"[ReqScheduler] data keys = {data.keys()} in {filename}") - # 按格式保存 - ps = data['prompts'] - ls = data['lengths'] - for p, l in zip(ps, ls): - p = tuple(p) - if p not in ans: - ans[p] = l - print(f"[ReqScheduler] Processed {filename}, found {len(ans)} unique prompts") - except Exception as e: - print(f"[ReqScheduler] Error processing {filename}: {str(e)}") - raise e - - # Aggregate prompts -> responses - agg = self.config.get('agg', 'mean') - if agg == 'max': - ans = {k: max(v) for k, v in ans.items()} - elif agg == 'min': - ans = {k: min(v) for k, v in ans.items()} - elif agg == 'mean': - ans = {k: int(np.mean(v)) for k, v in ans.items()} - elif agg =='median': - ans = {k: int(np.median(v)) for k, v in ans.items()} - elif agg == 'sum': - ans = {k: sum(v) for k, v in ans.items()} - else: - raise ValueError(f"Unknown agg {agg}") - print(f'[ReqScheduler] Table-Size: {len(ans)=}') - return ans - - def lookup_table(self, prompt): - ''' 根据 table 预存的信息 查找 prompt 的相关信息 - ''' - if isinstance(prompt, list): - prompt = tuple(prompt) - assert isinstance(prompt, tuple), f"prompt type {type(prompt)} is not supported" - if prompt in self.table: - # print(f"[ReqScheduler] Found prompt {len(prompt)} in table with response length {self.table[prompt]}") - return self.table[prompt] - return None - - def update_table(self, raw_prompt_ids, responses): - new_table = {} - for p, r in zip(raw_prompt_ids, responses): - p = tuple(p) - r = tuple(r) - if p not in new_table: - new_table[p] = [] - new_table[p].append(len(r)) - - # Aggregate prompts -> responses - agg = self.config.get('agg', 'mean') - if agg == 'max': - new_table = {k: max(v) for k, v in new_table.items()} - elif agg == 'min': - new_table = {k: min(v) for k, v in new_table.items()} - elif agg == 'mean': - new_table = {k: int(np.mean(v)) for k, v in new_table.items()} - elif agg =='median': - new_table = {k: int(np.median(v)) for k, v in new_table.items()} - elif agg == 'sum': - 
new_table = {k: sum(v) for k, v in new_table.items()} - else: - raise ValueError(f"Unknown agg {agg}") - - # add or overwrite - for k, v in new_table.items(): - self.table[k] = v - print(f'[ReqScheduler] in update_table, Table-Size: {len(self.table)=}') - - def log_seqlen(self, raw_prompt_ids, responses, prefix): - print(f'[ReqScheduler] in log_seqlen, {type(raw_prompt_ids)}, {type(responses)}, {len(raw_prompt_ids)}, {len(responses)}') - assert len(raw_prompt_ids) == len(responses), f'{len(raw_prompt_ids)}, {len(responses)}' - prompts_dict = {} - prompts, response = [], [] - for p, r in zip(raw_prompt_ids, responses): - if tuple(p) not in prompts_dict: - prompts_dict[tuple(p)] = [] - prompts_dict[tuple(p)].append(len(r)) - - for pid in prompts_dict: - prompts.append(list(pid)) - response.append(prompts_dict[pid]) - - log_dir = self.config.log_dir - os.makedirs(log_dir, exist_ok=True) - data_files = glob.glob(f"{log_dir}/{prefix}_*.json") - file_num = len(data_files) + 1 - output_file = f"{log_dir}/{prefix}_{file_num}.json" - with open(output_file, 'w') as f: - json.dump({ - 'prompts': prompts, - 'lengths': response - }, f) - - def restore_order(self, - gen_batch_output: DataProto, - reqs_idx, - n_samples, - ): - # the output is permutated by req scheduler - # this step store the original orders - # - bs = len(gen_batch_output) - assert bs % n_samples == 0, f'bs {bs} must be divisible by n_samples {n_samples}' - assert bs//n_samples == len(reqs_idx), f'bs//n_samples {bs//n_samples} != len(reqs_idx) {len(reqs_idx)}' - print(f"[ReqScheduler] restore_order, {bs=}, {n_samples=}, {len(reqs_idx)=}") - - # e.g. [1, 0] -> [16, 17, ..., 31, 0, 1, ... , 15] - cnt = 0 - global_idx = [None for _ in range(bs)] - group_idx = 0 - max_id = max(reqs_idx) - while group_idx <= max_id: - for i, idx in enumerate(reqs_idx): - if idx == group_idx: - start_position = i * n_samples - end_position = start_position + n_samples - global_idx[start_position: end_position] = [j for j in range(cnt, cnt+n_samples)] - cnt += n_samples - group_idx += 1 - - assert len(global_idx) == bs, f'len(global_idx) {len(global_idx)} != bs {bs}' - - global_idx = torch.tensor(global_idx) - gen_batch_output.reorder(global_idx) - - def sched(self, batch_dict: dict, - world_size: int, - config, - ): - print(f"[ReqScheduler] sched, {world_size=}, {config=}") - # get OUT len - outlens = [] - for raw_prompt_ids in batch_dict['raw_prompt_ids']: - outlen = self.lookup_table(raw_prompt_ids) - outlens.append(outlen) - - # sched - tp_size = config.rollout.tensor_model_parallel_size - assert world_size % tp_size == 0, f'world_size {world_size} must be divisible by tp_size {tp_size}' - dp_size = world_size // tp_size - res = self._sched(outlens, dp_size, tp_size) - - # idx -> dp group idx: - batch_dict['reqs_idx'] = res - batch_dict['outlens'] = np.array(outlens, dtype=np.int32) - # len(batch_dict['outlens']) = train_prompt_bs - # print(f"[ReqScheduler] calculate reqs_idx, outlens = {len(batch_dict['outlens'])}") - - - def print_stats(self, outlens, res): - longest = max(outlens) - shortest = min(outlens) - avg = np.mean(outlens) - std = np.std(outlens) - print(f"[ReqScheduler] Stats: {longest=}, {shortest=}, avg: {avg:.2f}, std: {std:.2f}") - num_group = np.unique(res) - group = [0 for _ in range(len(num_group))] - for v in res: - group[v] += 1 - print(f"[ReqScheduler] Group: {group}") - - def _sched(self, outlens, dp_size, tp_size): - algo = self.config.algo - - # if has None, the prompt is not in table - # so we use even_prompt - has_none = False - 
for outlen in outlens: - if outlen is None: - has_none = True - break - - agg = self.config.get('agg', 'mean') - if has_none: - print(f"[ReqScheduler] has None, reset {algo} to even_prompt; {agg=}") - algo = 'even_prompt' - - # so that print stats will not fail - for i in range(len(outlens)): - outlens[i] = -1 - else: - print(f"[ReqScheduler] algo: {algo}, {agg=}") - - # get method - method = getattr(self, algo) - res = method(outlens, dp_size, tp_size, self.config) - self.print_stats(outlens, res) - return res - - def dummy(self, outlens, dp_size, tp_size, config): - res = [0] * (len(outlens) - 1) + [1] - res = np.array(res, dtype=np.int32) - return res - - def even_prompt(self, outlens: list[int], dp_size, tp_size, config): - per_dp = len(outlens) // dp_size - res = [] - cnt = 0 - for i in range(0, len(outlens), per_dp): - for j in range(per_dp): - res.append(cnt) - cnt += 1 - return np.array(res, dtype=np.int32) - - # def even_token(self, outlens, dp_size, tp_size, config): - # total_num_token = sum(outlens) - # per_dp = total_num_token // dp_size - # res = [] - # group_idx = 0 - # cnt = 0 - # for i in range(0, len(outlens)): - # cnt += outlens[i] - # if cnt > per_dp: - # group_idx += 1 - # cnt = 0 - # res.append(group_idx) - # return np.array(res, dtype=np.int32) - - def even_token(self, outlens, dp_size, tp_size, config): - prompt_indices = list(range(len(outlens))) - sorted_pairs = sorted(zip(outlens, prompt_indices), reverse=True) - heap = [(0, i) for i in range(dp_size)] - heapq.heapify(heap) - res = [None] * len(outlens) - for token_len, orig_idx in sorted_pairs: - total, group = heapq.heappop(heap) - res[orig_idx] = group - heapq.heappush(heap, (total + token_len, group)) - return np.array(res, dtype=np.int32) - - def long_short(self, outlens, dp_size, tp_size, config): - p = np.percentile(outlens, config.percentile) - long = set() - for i in range(len(outlens)): - if outlens[i] > p: - long.add(i) - - # TODO assume only 1 long workers, the rest is short worker - # n_long_worker = dp_size//2 - # n_short_worker = dp_size - n_long_worker - global MODEL_DEPLOYMENT - if MODEL_DEPLOYMENT is None: - n_short_worker = dp_size-1 - else: - #n_short_worker = sum(MODEL_DEPLOYMENT) - MODEL_DEPLOYMENT[0] + 1 - n_short_worker = len(MODEL_DEPLOYMENT)-1 - - # 1. even_prompt for the rest: - #short_worker_cnt = 1 - #res = [] - #for i in range(len(outlens)): - # if i in long: - # # only one long worker - # res.append(0) - # else: - # # round-robin the rest prompts - # res.append(short_worker_cnt) - # short_worker_cnt += 1 - # if short_worker_cnt > n_short_worker: - # short_worker_cnt = 1 - - # 2. 
even_token for the rest - res = [None for _ in range(len(outlens))] - total_num_token_for_short = 0 - for i in range(len(outlens)): - if i in long: - # only one long worker - res[i] = 0 - else: - total_num_token_for_short += outlens[i] - - per_dp = total_num_token_for_short // n_short_worker + 1 - group_idx = 1 - cnt = 0 - for i in range(len(outlens)): - if i not in long: - res[i] = group_idx - cnt += outlens[i] - if cnt >= per_dp: - group_idx += 1 - cnt = 0 - - print(f"[ReqScheduler] p: {p}, {res=}") - return np.array(res, dtype=np.int32) - - - -############################ ############################ -############################ ############################ -############################ ############################ -############################ ############################ -############################ ############################ -############################ ############################ -############################ ############################ - - -class Role(Enum): - """ - To create more roles dynamically, you can subclass Role and add new members - """ - Actor = 0 - Rollout = 1 - ActorRollout = 2 - Critic = 3 - RefPolicy = 4 - RewardModel = 5 - ActorRolloutRef = 6 - - -class AdvantageEstimator(str, Enum): - """ - Using an enumeration class to avoid spelling errors in adv_estimator - """ - GAE = 'gae' - GRPO = 'grpo' - REINFORCE_PLUS_PLUS = 'reinforce_plus_plus' - REMAX = 'remax' - RLOO = 'rloo' - - -@dataclass -class ResourcePoolManager: - """ - Define a resource pool specification. Resource pool will be initialized first. - Mapping - """ - resource_pool_spec: dict[str, list[int]] - mapping: dict[Role, str] - resource_pool_dict: dict[str, - RayResourcePool] = field(default_factory=dict) - - def create_resource_pool(self): - for resource_pool_name, process_on_nodes in self.resource_pool_spec.items( - ): - # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool - # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. 
- # For Megatron backend, we recommend using max_colocate_count>1 that can utilize different WorkerGroup for differnt models - resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, - use_gpu=True, - max_colocate_count=1, - name_prefix=resource_pool_name) - self.resource_pool_dict[resource_pool_name] = resource_pool - - def get_resource_pool(self, role: Role) -> RayResourcePool: - """Get the resource pool of the worker_cls""" - return self.resource_pool_dict[self.mapping[role]] - - def get_n_gpus(self) -> int: - """Get the number of gpus in this cluster.""" - return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) - - def _check_resource_available(self): - """Check if the resource pool can be satisfied in this ray cluster.""" - node_available_resources = ray.state.available_resources_per_node() - node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()} - - # check total required gpus can be satisfied - total_available_gpus = sum(node_available_gpus.values()) - total_required_gpus = sum( - [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) - if total_available_gpus < total_required_gpus: - raise ValueError( - f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}") - - # check each resource pool can be satisfied, O(#resource_pools * #nodes) - for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): - num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes) - for node, available_gpus in node_available_gpus.items(): - if available_gpus >= num_gpus: - node_available_gpus[node] -= num_gpus - num_nodes -= 1 - if num_nodes == 0: - break - if num_nodes > 0: - raise ValueError( - f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes} cannot be satisfied in this ray cluster" - ) - - -import torch -from verl.utils.torch_functional import masked_mean - - -def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): - responses = data.batch["responses"] - response_length = responses.size(1) - token_level_scores = data.batch["token_level_scores"] - batch_size = data.batch.batch_size[0] - attention_mask = data.batch["attention_mask"] - response_mask = attention_mask[:, -response_length:] - - # compute kl between ref_policy and current policy - # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled. 
- kld = core_algos.kl_penalty(data.batch["old_log_probs"], data.batch["ref_log_prob"], - kl_penalty=kl_penalty) # (batch_size, response_length) - kld = kld * response_mask - beta = kl_ctrl.value - - token_level_rewards = token_level_scores - beta * kld - - current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence - current_kl = torch.mean(current_kl, dim=0).item() - - # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837 - kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) - data.batch["token_level_rewards"] = token_level_rewards - - metrics = { - "actor/reward_kl_penalty": current_kl, - "actor/reward_kl_penalty_coeff": beta, - } - - return data, metrics - - -def compute_response_mask(data: DataProto): - responses = data.batch["responses"] - response_length = responses.size(1) - attention_mask = data.batch["attention_mask"] - return attention_mask[:, -response_length:] - - -def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1): - # Back-compatible with trainers that do not compute response mask in fit - if "response_mask" not in data.batch.keys(): - data.batch["response_mask"] = compute_response_mask(data) - # prepare response group - # TODO: add other ways to estimate advantages - if adv_estimator == AdvantageEstimator.GAE: - values = data.batch["values"] - advantages, returns = core_algos.compute_gae_advantage_return( - token_level_rewards=data.batch["token_level_rewards"], - values=data.batch["values"], - response_mask=data.batch["response_mask"], - gamma=gamma, - lam=lam, - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.GRPO: - advantages, returns = core_algos.compute_grpo_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=data.batch["response_mask"], - index=data.non_tensor_batch["uid"], - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS: - advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=data.batch["response_mask"], - gamma=gamma, - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.REMAX: - advantages, returns = core_algos.compute_remax_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - reward_baselines=data.batch["reward_baselines"], - response_mask=data.batch["response_mask"], - ) - - data.batch["advantages"] = advantages - data.batch["returns"] = returns - elif adv_estimator == AdvantageEstimator.RLOO: - advantages, returns = core_algos.compute_rloo_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=data.batch["response_mask"], - index=data.non_tensor_batch["uid"], - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - else: - raise NotImplementedError - return data - - -def _compute_response_info(batch): - response_length = batch.batch['responses'].shape[-1] - - prompt_mask = batch.batch['attention_mask'][:, :-response_length] - response_mask = batch.batch['attention_mask'][:, -response_length:] - - prompt_length = prompt_mask.sum(-1).float() - response_length = response_mask.sum(-1).float() # (batch_size,) - - return dict( - response_mask=response_mask, - 
prompt_length=prompt_length, - response_length=response_length, - ) - - -@contextmanager -def _timer(name: str, timing_raw: Dict[str, float]): - with Timer(name=name, logger=None) as timer: - yield - if name not in timing_raw: - timing_raw[name] = 0 - timing_raw[name] += timer.last - - -class RayPPOTrainer(object): - """ - Note that this trainer runs on the driver process on a single CPU/GPU node. - """ - - # TODO: support each role have individual ray_worker_group_cls, - # i.e., support different backend of different role - def __init__(self, - config, - tokenizer, - role_worker_mapping: dict[Role, WorkerType], - resource_pool_manager: ResourcePoolManager, - ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, - processor=None, - reward_fn=None, - val_reward_fn=None): - - # assert torch.cuda.is_available(), 'cuda must be available on driver' - - self.tokenizer = tokenizer - self.processor = processor - self.config = config - self.reward_fn = reward_fn - self.val_reward_fn = val_reward_fn - - self.hybrid_engine = config.actor_rollout_ref.hybrid_engine - assert self.hybrid_engine, 'Currently, only support hybrid engine' - - if self.hybrid_engine: - assert Role.ActorRollout in role_worker_mapping, f'{role_worker_mapping.keys()=}' - - self.role_worker_mapping = role_worker_mapping - self.resource_pool_manager = resource_pool_manager - self.use_reference_policy = Role.RefPolicy in role_worker_mapping - self.use_rm = Role.RewardModel in role_worker_mapping - self.ray_worker_group_cls = ray_worker_group_cls - - # define KL control - if self.use_reference_policy: - if config.algorithm.kl_ctrl.type == 'fixed': - self.kl_ctrl = core_algos.FixedKLController( - kl_coef=config.algorithm.kl_ctrl.kl_coef) - elif config.algorithm.kl_ctrl.type == 'adaptive': - assert config.algorithm.kl_ctrl.horizon > 0, f'horizon must be larger than 0. Got {config.critic.kl_ctrl.horizon}' - self.kl_ctrl = core_algos.AdaptiveKLController( - init_kl_coef=config.algorithm.kl_ctrl.kl_coef, - target_kl=config.algorithm.kl_ctrl.target_kl, - horizon=config.algorithm.kl_ctrl.horizon) - else: - raise NotImplementedError - else: - self.kl_ctrl = core_algos.FixedKLController(kl_coef=0.) - - if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: - self.use_critic = True - elif self.config.algorithm.adv_estimator in [ - AdvantageEstimator.GRPO, - AdvantageEstimator.REINFORCE_PLUS_PLUS, - AdvantageEstimator.REMAX, AdvantageEstimator.RLOO - ]: - self.use_critic = False - else: - raise NotImplementedError - - # gh512 - init Req Scheduler - self.req_scheduler = ReqScheduler( - config=self.config.req_scheduler, - ) - - self._validate_config() - self._create_dataloader() - - def _validate_config(self): - config = self.config - # number of GPUs total - n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes - - # 1. Check total batch size for data correctness - real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n - assert real_train_batch_size % n_gpus == 0, \ - f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})." - - # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" - # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". 
- def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): - if mbs is None and mbs_per_gpu is None: - raise ValueError( - f"[{name}] Please set at least one of '{name}.micro_batch_size' or " - f"'{name}.micro_batch_size_per_gpu'.") - - if mbs is not None and mbs_per_gpu is not None: - raise ValueError( - f"[{name}] You have set both '{name}.micro_batch_size' AND " - f"'{name}.micro_batch_size_per_gpu'. Please remove '{name}.micro_batch_size' " - f"because only '*_micro_batch_size_per_gpu' is supported (the former is deprecated)." - ) - - if not config.actor_rollout_ref.actor.use_dynamic_bsz: - # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu - check_mutually_exclusive( - config.actor_rollout_ref.actor.ppo_micro_batch_size, - config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, - "actor_rollout_ref.actor") - - # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu - check_mutually_exclusive( - config.actor_rollout_ref.ref.log_prob_micro_batch_size, - config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, - "actor_rollout_ref.ref") - - # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu - check_mutually_exclusive( - config.actor_rollout_ref.rollout.log_prob_micro_batch_size, - config.actor_rollout_ref.rollout. - log_prob_micro_batch_size_per_gpu, "actor_rollout_ref.rollout") - - if self.use_critic and not config.critic.use_dynamic_bsz: - # Check for critic micro-batch size conflicts - check_mutually_exclusive( - config.critic.ppo_micro_batch_size, - config.critic.ppo_micro_batch_size_per_gpu, "critic") - - # Check for reward model micro-batch size conflicts - if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: - check_mutually_exclusive( - config.reward_model.micro_batch_size, - config.reward_model.micro_batch_size_per_gpu, "reward_model") - - # Actor - # if NOT dynamic_bsz, we must ensure: - # ppo_mini_batch_size is divisible by ppo_micro_batch_size - # ppo_micro_batch_size * sequence_parallel_size >= n_gpus - if not config.actor_rollout_ref.actor.use_dynamic_bsz: - sp_size = config.actor_rollout_ref.actor.get( - 'ulysses_sequence_parallel_size', 1) - if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None: - assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0 - assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus - - # critic - if self.use_critic and not config.critic.use_dynamic_bsz: - sp_size = config.critic.get('ulysses_sequence_parallel_size', 1) - if config.critic.ppo_micro_batch_size is not None: - assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0 - assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus - - # Check if use_remove_padding is enabled when using sequence parallelism for fsdp - if config.actor_rollout_ref.actor.strategy == 'fsdp': - if config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1) > 1 or \ - config.actor_rollout_ref.ref.get('ulysses_sequence_parallel_size', 1) > 1: - assert config.actor_rollout_ref.model.use_remove_padding, \ - "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." - - if self.use_critic and config.critic.strategy == 'fsdp': - if config.critic.get('ulysses_sequence_parallel_size', 1) > 1: - assert config.critic.model.use_remove_padding, \ - "When using sequence parallelism for critic, you must enable `use_remove_padding`." 
- - if config.data.get('val_batch_size', None) is not None: - print( - f"WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves." - ) - - print( - "[validate_config] All configuration checks passed successfully!") - - def _create_dataloader(self): - # TODO: we have to make sure the batch size is divisible by the dp size - #self.train_dataset = RLHFDataset( - # parquet_files=self.config.data.train_files, - # tokenizer=self.tokenizer, - # processor=self.processor, - # prompt_key=self.config.data.prompt_key, - # image_key=self.config.data.get('image_key', 'images'), - # max_prompt_length=self.config.data.max_prompt_length, - # filter_prompts=True, - # return_raw_chat=self.config.data.get('return_raw_chat', False), - # truncation='error', - # ) - self.train_dataset = RLHFDatasetFilter( - parquet_files=self.config.data.train_files, - tokenizer=self.tokenizer, - processor=self.processor, - prompt_key=self.config.data.prompt_key, - image_key=self.config.data.get('image_key', 'images'), - filter_prompts=True, - return_raw_chat=self.config.data.get('return_raw_chat', False), - truncation='left', - # gh512 - filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', False), - min_prompt_length=self.config.data.min_prompt_length, - max_prompt_length=self.config.data.max_prompt_length, - min_response_length=self.config.data.min_response_length, - max_response_length=self.config.data.max_response_length, - cap_dataset_size=self.config.data.get('cap_dataset_size', None), - req_scheduler=self.req_scheduler, - ) - - # use sampler for better ckpt resume - if self.config.data.shuffle: - train_dataloader_generator = torch.Generator() - train_dataloader_generator.manual_seed( - self.config.data.get('seed', 1)) - sampler = RandomSampler(data_source=self.train_dataset, - generator=train_dataloader_generator) - else: - sampler = SequentialSampler(data_source=self.train_dataset) - - self.train_dataloader = StatefulDataLoader( - dataset=self.train_dataset, - batch_size=self.config.data.train_batch_size, - num_workers=8, - drop_last=True, - collate_fn=collate_fn, - sampler=sampler) - - #self.val_dataset = RLHFDataset( - # parquet_files=self.config.data.val_files, - # tokenizer=self.tokenizer, - # processor=self.processor, - # prompt_key=self.config.data.prompt_key, - # image_key=self.config.data.get('image_key', 'images'), - # max_prompt_length=self.config.data.max_prompt_length, - # filter_prompts=True, - # return_raw_chat=self.config.data.get('return_raw_chat', False), - # truncation='error', - # ) - self.val_dataset = RLHFDatasetFilter( - parquet_files=self.config.data.val_files, - tokenizer=self.tokenizer, - processor=self.processor, - prompt_key=self.config.data.prompt_key, - image_key=self.config.data.get('image_key', 'images'), - max_prompt_length=self.config.data.max_prompt_length, - filter_prompts=True, - return_raw_chat=self.config.data.get('return_raw_chat', False), - truncation=self.config.data.get("truncation", "left"), - # xiaohui - filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', False), - min_prompt_length=0, - min_response_length=0, - max_response_length=self.config.data.max_response_length * 10, - cap_dataset_size=None, - req_scheduler=self.req_scheduler, - ) - self.val_dataloader = StatefulDataLoader( - dataset=self.val_dataset, - # Validation datasets are sent to inference engines as a whole batch, - # which will schedule the memory themselves. 
- batch_size=len(self.val_dataset), - num_workers=8, - shuffle=False, - drop_last=False, - collate_fn=collate_fn) - - assert len(self.train_dataloader) >= 1 - assert len( - self.val_dataloader - ) == 1, "Validation dataloader must have a single batch, which inference engines will schedule the memory themselves." - - print(f'[RayPPOTrainer] in _create_dataloader, Size of train dataloader: {len(self.train_dataloader)}') - - # inject total_training_steps to actor/critic optim_config. This is hacky. - total_training_steps = len( - self.train_dataloader) * self.config.trainer.total_epochs - - if self.config.trainer.total_training_steps is not None: - total_training_steps = self.config.trainer.total_training_steps - - self.total_training_steps = total_training_steps - print(f'[RayPPOTrainer] in _create_dataloader, Total training steps: {self.total_training_steps}') - - OmegaConf.set_struct(self.config, True) - with open_dict(self.config): - self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps - self.config.critic.optim.total_training_steps = total_training_steps - - def _maybe_log_val_generations(self, inputs, outputs, scores): - """Log a table of validation samples to the configured logger (wandb or swanlab)""" - - generations_to_log = self.config.trainer.log_val_generations - - if generations_to_log == 0: - return - - import numpy as np - - # Create tuples of (input, output, score) and sort by input text - samples = list(zip(inputs, outputs, scores)) - samples.sort(key=lambda x: x[0]) # Sort by input text - - # Use fixed random seed for deterministic shuffling - rng = np.random.RandomState(42) - rng.shuffle(samples) - - # Take first N samples after shuffling - samples = samples[:generations_to_log] - - # Log to each configured logger - self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) - - def _validate(self): - reward_tensor_lst = [] - data_source_lst = [] - reward_extra_infos_dict: dict[str, list] = defaultdict(list) - - # Lists to collect samples for the table - sample_inputs = [] - sample_outputs = [] - sample_scores = [] - - for test_data in self.val_dataloader: - # xiaohui: we need to schedule the val requests - self.req_scheduler.sched( - test_data, self.actor_rollout_wg.world_size, self.config.actor_rollout_ref, - ) - test_batch = DataProto.from_single_dict(test_data) - - # repeat test batch - test_batch = test_batch.repeat( - repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, - interleave=True, - ) - - # we only do validation on rule-based rm - if self.config.reward_model.enable and test_batch[ - 0].non_tensor_batch['reward_model']['style'] == 'model': - return {} - - # Store original inputs - input_ids = test_batch.batch['input_ids'] - input_texts = [ - self.tokenizer.decode(ids, skip_special_tokens=True) - for ids in input_ids - ] - sample_inputs.extend(input_texts) - - if 'multi_modal_inputs' in test_batch.non_tensor_batch.keys(): - test_gen_batch = test_batch.pop( - batch_keys=['input_ids', 'attention_mask', 'position_ids'], - non_tensor_batch_keys=[ - 'raw_prompt_ids', 'multi_modal_data', - 'multi_modal_inputs' - ], - ) - else: - test_gen_batch = test_batch.pop( - batch_keys=['input_ids', 'attention_mask', 'position_ids'], - # 添加 resp 长度相关的信息 - non_tensor_batch_keys=['raw_prompt_ids', 'reqs_idx', 'outlens'], - ) - test_reqs_idx = test_gen_batch.non_tensor_batch['reqs_idx'] - test_gen_batch.meta_info = { - 'eos_token_id': self.tokenizer.eos_token_id, - 'pad_token_id': 
self.tokenizer.pad_token_id, - 'recompute_log_prob': False, - 'do_sample': self.config.actor_rollout_ref.rollout.val_kwargs.do_sample, - 'validate': True, - } - - print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") - # pad to be divisible by dp_size - test_gen_batch_padded, pad_size = pad_dataproto_to_divisor( - test_gen_batch, self.actor_rollout_wg.world_size) - test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences( - test_gen_batch_padded) - self.req_scheduler.restore_order( - test_output_gen_batch_padded, - test_reqs_idx, - n_samples=self.config.actor_rollout_ref.rollout.val_kwargs.n - ) - - # unpad - test_output_gen_batch = unpad_dataproto( - test_output_gen_batch_padded, pad_size=pad_size) - - # Store generated outputs - output_ids = test_output_gen_batch.batch['responses'] - output_texts = [ - self.tokenizer.decode(ids, skip_special_tokens=True) - for ids in output_ids - ] - sample_outputs.extend(output_texts) - test_batch = test_batch.union(test_output_gen_batch) - # evaluate using reward_function - result = self.val_reward_fn(test_batch, return_dict=True) - reward_tensor = result["reward_tensor"] - scores = reward_tensor.sum(-1).cpu().tolist() - sample_scores.extend(scores) - - reward_extra_infos_dict["reward"].extend(scores) - if "reward_extra_info" in result: - for key, lst in result["reward_extra_info"].items(): - reward_extra_infos_dict[key].extend(lst) - - data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) - - self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) - - for key_info, lst in reward_extra_infos_dict.items(): - assert len(lst) == 0 or len(lst) == len(sample_scores), (f"{key_info}: {len(lst)=}, {len(sample_scores)=}") - - data_sources = np.concatenate(data_source_lst, axis=0) - - data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict) - metric_dict = {} - for data_source, var2metric2val in data_src2var2metric2val.items(): - core_var = "acc" if "acc" in var2metric2val else "reward" - for var_name, metric2val in var2metric2val.items(): - n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) - for metric_name, metric_val in metric2val.items(): - if ((var_name == core_var) and - any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) and - (f"@{n_max}" in metric_name)): - metric_sec = "val-core" - else: - metric_sec = "val-aux" - pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" - metric_dict[pfx] = metric_val - - return metric_dict - - def init_workers(self): - """Init resource pool and worker group""" - self.resource_pool_manager.create_resource_pool() - self.resource_pool_to_cls = { - pool: {} - for pool in self.resource_pool_manager.resource_pool_dict.values() - } - # create actor and rollout - if self.hybrid_engine: - resource_pool = self.resource_pool_manager.get_resource_pool( - Role.ActorRollout) - actor_rollout_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.ActorRollout], - config=self.config.actor_rollout_ref, - role='actor_rollout', - model_deployment=MODEL_DEPLOYMENT, - ) - self.resource_pool_to_cls[resource_pool][ - 'actor_rollout'] = actor_rollout_cls - else: - raise NotImplementedError - - # create critic - if self.use_critic: - resource_pool = self.resource_pool_manager.get_resource_pool( - Role.Critic) - critic_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.Critic], - 
config=self.config.critic) - self.resource_pool_to_cls[resource_pool]['critic'] = critic_cls - - # create reference policy if needed - if self.use_reference_policy: - resource_pool = self.resource_pool_manager.get_resource_pool( - Role.RefPolicy) - ref_policy_cls = RayClassWithInitArgs( - self.role_worker_mapping[Role.RefPolicy], - config=self.config.actor_rollout_ref, - role='ref') - self.resource_pool_to_cls[resource_pool]['ref'] = ref_policy_cls - - # create a reward model if reward_fn is None - if self.use_rm: - # we create a RM here - resource_pool = self.resource_pool_manager.get_resource_pool( - Role.RewardModel) - rm_cls = RayClassWithInitArgs( - self.role_worker_mapping[Role.RewardModel], - config=self.config.reward_model) - self.resource_pool_to_cls[resource_pool]['rm'] = rm_cls - - # initialize WorkerGroup - # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, - # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups. - # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. - all_wg = {} - self.wg_dicts = [] - for resource_pool, class_dict in self.resource_pool_to_cls.items(): - print(f'{resource_pool=} | {class_dict=}') - - for resource_pool, class_dict in self.resource_pool_to_cls.items(): - worker_dict_cls = create_colocated_worker_cls( - class_dict=class_dict) - wg_dict = self.ray_worker_group_cls( - resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) - spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) - all_wg.update(spawn_wg) - # keep the referece of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699 - self.wg_dicts.append(wg_dict) - - if self.use_critic: - self.critic_wg = all_wg['critic'] - self.critic_wg.init_model() - print("Critic model initialized.") - print("=" * 100) - if self.use_reference_policy: - self.ref_policy_wg = all_wg['ref'] - self.ref_policy_wg.init_model() - print("Reference policy initialized.") - print("=" * 100) - if self.use_rm: - self.rm_wg = all_wg['rm'] - self.rm_wg.init_model() - print("Reward model initialized.") - print("=" * 100) - # we should create rollout at the end so that vllm can have a better estimation of kv cache memory - self.actor_rollout_wg = all_wg['actor_rollout'] - self.actor_rollout_wg.init_model() - print("Actor rollout initialized.") - print("=" * 100) - - def _save_checkpoint(self): - # path: given_path + `/global_step_{global_steps}` + `/actor` - local_global_step_folder = os.path.join( - self.config.trainer.default_local_dir, - f'global_step_{self.global_steps}') - actor_local_path = os.path.join(local_global_step_folder, 'actor') - - actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( - self.config.trainer.default_hdfs_dir, - f'global_step_{self.global_steps}', 'actor') - self.actor_rollout_wg.save_checkpoint( - actor_local_path, - actor_remote_path, - self.global_steps, - #remove_previous_ckpt=self.config.trainer. 
- #remove_previous_ckpt_in_save) - ) - - if self.use_critic: - critic_local_path = os.path.join(local_global_step_folder, - 'critic') - critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( - self.config.trainer.default_hdfs_dir, - f'global_step_{self.global_steps}', 'critic') - self.critic_wg.save_checkpoint( - critic_local_path, - critic_remote_path, - self.global_steps, - #remove_previous_ckpt=self.config.trainer. - #remove_previous_ckpt_in_save) - ) - - # save dataloader - dataloader_local_path = os.path.join(local_global_step_folder, - 'data.pt') - dataloader_state_dict = self.train_dataloader.state_dict() - torch.save(dataloader_state_dict, dataloader_local_path) - - # latest checkpointed iteration tracker (for atomic usage) - local_latest_checkpointed_iteration = os.path.join( - self.config.trainer.default_local_dir, - 'latest_checkpointed_iteration.txt') - with open(local_latest_checkpointed_iteration, 'w') as f: - f.write(str(self.global_steps)) - - def _load_checkpoint(self): - if self.config.trainer.resume_mode == 'disable': - return 0 - - # load from hdfs - if self.config.trainer.default_hdfs_dir is not None: - NotImplementedError('load from hdfs is not implemented yet') - else: - checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path - if not os.path.isabs(checkpoint_folder): - working_dir = os.getcwd() - checkpoint_folder = os.path.join(working_dir, - checkpoint_folder) - global_step_folder = find_latest_ckpt_path( - checkpoint_folder) # None if no latest - - # find global_step_folder - if self.config.trainer.resume_mode == 'auto': - if global_step_folder is None: - print('Training from scratch') - return 0 - else: - if not (self.config.trainer.resume_from_path - and global_step_folder is not None): - assert isinstance(self.config.trainer.resume_mode, - str), "resume ckpt must be str type" - assert 'global_step_' in self.config.trainer.resume_mode, "resume ckpt must specify the global_steps" - global_step_folder = self.config.trainer.resume_mode - if not os.path.isabs(global_step_folder): - working_dir = os.getcwd() - global_step_folder = os.path.join(working_dir, - global_step_folder) - print(f'Load from checkpoint folder: {global_step_folder}') - # set global step - self.global_steps = int(global_step_folder.split('global_step_')[-1]) - - print(f'Setting global step to {self.global_steps}') - print(f'Resuming from {global_step_folder}') - - actor_path = os.path.join(global_step_folder, 'actor') - critic_path = os.path.join(global_step_folder, 'critic') - # load actor - self.actor_rollout_wg.load_checkpoint( - actor_path, - del_local_after_load=self.config.trainer.del_local_ckpt_after_load) - # load critic - if self.use_critic: - self.critic_wg.load_checkpoint(critic_path, - del_local_after_load=self.config. 
- trainer.del_local_ckpt_after_load) - - # load dataloader, - # TODO: from remote not implemented yet - dataloader_local_path = os.path.join(global_step_folder, 'data.pt') - if os.path.exists(dataloader_local_path): - dataloader_state_dict = torch.load(dataloader_local_path) - self.train_dataloader.load_state_dict(dataloader_state_dict) - else: - print( - f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch" - ) - - def _balance_batch(self, - batch: DataProto, - metrics, - logging_prefix='global_seqlen'): - """Reorder the data on single controller such that each dp rank gets similar total tokens""" - attention_mask = batch.batch['attention_mask'] - batch_size = attention_mask.shape[0] - global_seqlen_lst = batch.batch['attention_mask'].view( - batch_size, -1).sum(-1).tolist() # (train_batch_size,) - world_size = self.actor_rollout_wg.world_size - global_partition_lst = get_seqlen_balanced_partitions( - global_seqlen_lst, k_partitions=world_size, equal_size=True) - # reorder based on index. The data will be automatically equally partitioned by dispatch function - global_idx = torch.tensor( - [j for partition in global_partition_lst for j in partition]) - batch.reorder(global_idx) - global_balance_stats = log_seqlen_unbalance( - seqlen_list=global_seqlen_lst, - partitions=global_partition_lst, - prefix=logging_prefix) - metrics.update(global_balance_stats) - - def fit(self): - """ - The training loop of PPO. - The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. - The light-weight advantage computation is done on the driver process. - """ - from verl.utils.tracking import Tracking - from omegaconf import OmegaConf - - logger = Tracking(project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, - resolve=True)) - - self.global_steps = 0 - - # load checkpoint before doing anything - self._load_checkpoint() - - # perform validation before training - # currently, we only support validation using the reward_function. 
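# A minimal sketch of the idea behind `_balance_batch` above: `get_seqlen_balanced_partitions`
# must split the batch into `world_size` equally sized groups whose total token counts are as
# even as possible. The greedy routine below (longest sequence first, into the lightest
# partition that still has room) is an illustrative assumption, not necessarily the algorithm
# verl actually uses.
import heapq

def greedy_seqlen_partitions(seqlens, k_partitions):
    n = len(seqlens)
    assert n % k_partitions == 0, "equal_size partitions need n divisible by k_partitions"
    cap = n // k_partitions
    heap = [(0, p) for p in range(k_partitions)]  # (token total, partition id)
    heapq.heapify(heap)
    parts = [[] for _ in range(k_partitions)]
    for idx in sorted(range(n), key=lambda i: -seqlens[i]):
        skipped = []
        total, p = heapq.heappop(heap)
        while len(parts[p]) >= cap:  # this partition is already full, try the next lightest
            skipped.append((total, p))
            total, p = heapq.heappop(heap)
        parts[p].append(idx)
        heapq.heappush(heap, (total + seqlens[idx], p))
        for item in skipped:
            heapq.heappush(heap, item)
    return parts

# Example: greedy_seqlen_partitions([900, 100, 800, 200], 2) -> [[0, 1], [2, 3]], i.e. 1000 tokens per partition.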
- if self.val_reward_fn is not None and self.config.trainer.get( - 'val_before_train', True): - # XXX gh512 disable for now - val_metrics = self._validate() - pprint(f'Initial validation metrics: {val_metrics}') - logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.get('val_only', False): - return - - # add tqdm - progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") - # we start from step 1 - self.global_steps += 1 - last_val_metrics = None - - timings = [] - - for epoch in range(self.config.trainer.total_epochs): - print('*'*100) - print('='*100) - print(f"Epoch {epoch}: ") - - for bs_idx, batch_dict in enumerate(self.train_dataloader): - # gh512; add sched results - self.req_scheduler.sched(batch_dict, - self.actor_rollout_wg.world_size, - self.config.actor_rollout_ref, - ) - - metrics = {} - timing_raw = {} - # print(f'[BATCH] {len(batch_dict)} {batch_dict.keys()} {len(batch_dict["input_ids"])}') - batch: DataProto = DataProto.from_single_dict(batch_dict) - - # pop those keys for generation - if 'multi_modal_inputs' in batch.non_tensor_batch.keys(): - gen_batch = batch.pop( - batch_keys=[ - 'input_ids', 'attention_mask', 'position_ids' - ], - non_tensor_batch_keys=[ - 'raw_prompt_ids', 'multi_modal_data', - 'multi_modal_inputs' - ], - ) - else: - gen_batch = batch.pop( - batch_keys=['input_ids', 'attention_mask', 'position_ids'], - # 添加 resp 长度相关的信息 - non_tensor_batch_keys=['raw_prompt_ids', 'reqs_idx', 'outlens'], - ) - - is_last_step = self.global_steps >= self.total_training_steps - - # gh512: data examine - idx = gen_batch.batch['input_ids'] # (bs, prompt_length) - attention_mask = gen_batch.batch['attention_mask'] - position_ids = gen_batch.batch['position_ids'] - raw_prompt_ids = gen_batch.non_tensor_batch['raw_prompt_ids'] # (bs, varlen) - - # NOTE: we put raw_prompt_ids back to batch for repeated-interleave purpose and log seq len - # raw_prompt_ids 存储的是 prompts 的原始 token ids - batch.non_tensor_batch['raw_prompt_ids'] = raw_prompt_ids - reqs_idx = gen_batch.non_tensor_batch['reqs_idx'] - outlens = gen_batch.non_tensor_batch['outlens'] - # print(f'[BATCH INPUT]: reqs_idx = {reqs_idx[0]}, outlens = {len(outlens)}') - print( - f'[BATCH INPUT]: {idx.shape}, {attention_mask.shape}, {position_ids.shape}, {gen_batch.non_tensor_batch.keys()} {type(raw_prompt_ids)}' - ) - - with _timer('step', timing_raw): - # generate a batch - # 这里传入的 batch 是所有的数据,到具体 rank 上再做分配 - with _timer('gen', timing_raw): - gen_batch_output = self.actor_rollout_wg.generate_sequences( - gen_batch) - - if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - with _timer('gen_max', timing_raw): - gen_baseline_batch = deepcopy(gen_batch) - gen_baseline_batch.meta_info['do_sample'] = False - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) - - batch = batch.union(gen_baseline_output) - reward_baseline_tensor = self.reward_fn(batch) - reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) - - batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) - - batch.batch['reward_baselines'] = reward_baseline_tensor - - del gen_baseline_batch, gen_baseline_output - - with _timer('post_processing', timing_raw): - self.req_scheduler.restore_order(gen_batch_output, - reqs_idx, - self.config.actor_rollout_ref.rollout.n, - ) - batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(batch.batch))], - dtype=object, - ) - # repeat to align with repeated responses in rollout - 
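# A short note on the repeat step below: `batch.repeat(repeat_times=n, interleave=True)`
# expands the prompt-level batch so that each prompt appears n consecutive times, matching
# its n rollout responses. The per-prompt `uid` assigned above is therefore shared by all
# responses of the same prompt, which is what group-based advantage estimators such as GRPO
# later group rewards on. Minimal numpy analogue of the interleaved repeat (illustration only):
#   >>> import numpy as np
#   >>> np.repeat(np.array(["p0", "p1"], dtype=object), 3)
#   array(['p0', 'p0', 'p0', 'p1', 'p1', 'p1'], dtype=object)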
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) - batch = batch.union(gen_batch_output) - - ##################################### - # gh512: data examine2 (union-ed) - seq = batch.batch['input_ids'] - response = batch.batch['responses'] - raw_prompt_ids = batch.non_tensor_batch['raw_prompt_ids'] - print(f'[BATCH OUTPUT]: {seq.shape}, {response.shape} {len(batch)} {batch.batch.keys()} {batch.non_tensor_batch.keys()}') - # gh512: log - pad_ids = self.actor_rollout_wg.get_tokenizer_pad_id() - model = self.config.actor_rollout_ref.model.path.split('/')[-1] - dataset = self.config.data.train_files[0].split('/')[-1] - prefix = f'{dataset}_{model}_E{epoch}B{bs_idx}_data' - unpadded = unpad_responses(response, pad_ids) - self.req_scheduler.log_seqlen( - raw_prompt_ids, - unpadded, - prefix, - ) - self.req_scheduler.update_table( - raw_prompt_ids, - unpadded, - ) - # gh512: data examine2 (union-ed) - ##################################### - - batch.batch["response_mask"] = compute_response_mask(batch) - # balance the number of valid tokens on each dp rank. - # Note that this breaks the order of data inside the batch. - # Please take care when you implement group based adv computation such as GRPO and rloo - if self.config.trainer.balance_batch: - self._balance_batch(batch, metrics=metrics) - - # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() - - - # recompute old_log_probs - with _timer('old_log_prob', timing_raw): - old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) - batch = batch.union(old_log_prob) - - if self.use_reference_policy: - # compute reference log_prob - with _timer('ref', timing_raw): - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - - # compute values - if self.use_critic: - with _timer('values', timing_raw): - values = self.critic_wg.compute_values(batch) - batch = batch.union(values) - - with _timer('adv', timing_raw): - # compute scores. Support both model and function-based. - # We first compute the scores using reward model. Then, we call reward_fn to combine - # the results from reward model and rule-based results. - if self.use_rm: - # we first compute reward model score - reward_tensor = self.rm_wg.compute_rm_score(batch) - batch = batch.union(reward_tensor) - - # we combine with rule-based rm - with suppress_stdout(): - reward_tensor = self.reward_fn(batch) - batch.batch['token_level_scores'] = reward_tensor - - # compute rewards. 
apply_kl_penalty if available - if self.config.algorithm.use_kl_in_reward: - batch, kl_metrics = apply_kl_penalty( - batch, - kl_ctrl=self.kl_ctrl_in_reward, - kl_penalty=self.config.algorithm.kl_penalty) - metrics.update(kl_metrics) - else: - batch.batch['token_level_rewards'] = batch.batch[ - 'token_level_scores'] - - # compute advantages, executed on the driver process - batch = compute_advantage( - batch, - adv_estimator=self.config.algorithm.adv_estimator, - gamma=self.config.algorithm.gamma, - lam=self.config.algorithm.lam, - num_repeat=self.config.actor_rollout_ref.rollout.n) - - # update critic - if self.use_critic: - with _timer('update_critic', timing_raw): - critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics( - critic_output.meta_info['metrics']) - metrics.update(critic_output_metrics) - # implement critic warmup - if self.config.trainer.critic_warmup <= self.global_steps: - # update actor - with _timer('update_actor', timing_raw): - actor_output = self.actor_rollout_wg.update_actor( - batch) - actor_output_metrics = reduce_metrics( - actor_output.meta_info['metrics']) - metrics.update(actor_output_metrics) - - # validate - # XXX gh512 disable - if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \ - (is_last_step or self.global_steps % self.config.trainer.test_freq == 0): - with _timer('testing', timing_raw): - val_metrics: dict = self._validate() - if is_last_step: - last_val_metrics = val_metrics - metrics.update(val_metrics) - - if self.config.trainer.save_freq > 0 and (is_last_step or - self.global_steps % self.config.trainer.save_freq == 0): - with _timer('save_checkpoint', timing_raw): - self._save_checkpoint() - - with _timer('collecting', timing_raw): - # collect metrics - metrics.update( - compute_data_metrics(batch=batch, - use_critic=self.use_critic)) - metrics.update( - compute_timing_metrics(batch=batch, - timing_raw=timing_raw)) - # TODO: implement actual tflpo and theoretical tflpo - n_gpus = self.resource_pool_manager.get_n_gpus() - metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) - timing_raw = defaultdict(float) # clear timing - - - # TODO: make a canonical logger that supports various backend - logger.log(data=metrics, step=self.global_steps) - - if is_last_step: - pprint(f'Final validation metrics: {last_val_metrics}') - progress_bar.close() - return - progress_bar.update(1) - self.global_steps += 1 - # gh512 - # print _timer - print(f'{epoch=}: {bs_idx=}') - print(timing_raw) - print('*' * 100) - timings.append(timing_raw) - - - # print time - keys = timings[0].keys() - stats = {key: [] for key in keys} - for timing in timings: - for key in keys: - stats[key].append(timing[key]) - - print(f'timing: {len(timings)}') - for key in keys: - print(f'{key}: ') - print(f'{np.mean(stats[key])} - {np.std(stats[key])}') diff --git a/verl/utils/reward_score/__init__ copy.py b/verl/utils/reward_score/__init__ copy.py new file mode 100644 index 00000000000..ae86a665034 --- /dev/null +++ b/verl/utils/reward_score/__init__ copy.py @@ -0,0 +1,117 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from . import gsm8k, math, prime_math, prime_code + +from verl.utils.import_utils import deprecated +import numpy as np + +def default_compute_score(data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None, memory_limit_mb=None): + """Compute the score for a given solution based on the data source. + + Args: + data_source (str): The source dataset identifier which determines the scoring method. + solution_str (str): The solution string to be evaluated. + ground_truth (str): The ground truth answer for comparison. + extra_info (dict, optional): Additional information that might be needed for scoring. Defaults to None. + + Returns: + float: The computed score as a floating point number. If the result is a dictionary, + it returns the dictionary instead. + + Raises: + NotImplementedError: If the reward function is not implemented for the given data source. + """ + if isinstance(data_source, (list, tuple, np.ndarray)): + if all(isinstance(ds, str) and (ds.startswith("math_judge") or ds.startswith("expert")) for ds in data_source): + from . import remote_reward_batch + return remote_reward_batch.compute_score_batched_from_response(data_source, solution_str, ground_truth, extra_info) + # return remote_reward_batch.compute_score_batched_from_response(data_source, solution_str, ground_truth, extra_info) + elif all(isinstance(ds, str) and (ds.startswith("dapo") or ds.startswith("train-math-numinamath")) for ds in data_source): + from . import math_dapo + print("use math_dapo") + # 批量调用 math_dapo + return [math_dapo.compute_score(s, g) for s, g in zip(solution_str, ground_truth)] + # from . import remote_reward_batch + # return remote_reward_batch.compute_score_batched(data_source, solution_str, ground_truth, extra_info) + elif all(isinstance(ds, str) and (ds.startswith("boxed")) for ds in + data_source): + from . import math_verify_boxed + # 批量调用 math_dapo + return [math_verify_boxed.compute_score(s, g) for s, g in zip(solution_str, ground_truth)] + else: + from . import remote_reward_batch + return remote_reward_batch.compute_score_batched(data_source, solution_str, ground_truth, extra_info) + + elif data_source.startswith("math_verify"): + from . import math_verify + res = math_verify.compute_score(solution_str, ground_truth) + elif data_source == "openai/gsm8k": + from . import gsm8k + res = gsm8k.compute_score(solution_str, ground_truth) + elif data_source in ["lighteval/MATH", "DigitalLearningGmbH/MATH-lighteval"]: + from . import math + + res = math.compute_score(solution_str, ground_truth) + # [Optional] Math-Verify Integration + # For enhanced accuracy, consider utilizing Math-Verify (https://github.com/huggingface/Math-Verify). + # Note: Math-Verify needs to be manually installed via pip: `pip install math-verify`. + # To use it, override the `compute_score` function with the following implementation: + + # from . 
import math_verify + # res = math_verify.compute_score(solution_str, ground_truth) + elif data_source in [ + 'train-math-numinamath1.5_aops_forum', 'DeepScaleR_no_system', 'dapo_aime2025_s32_no_system', 'dapo_aime2024_s32_no_system', + 'train-math-numinamath1.5_aops_forum_int', 'train-math-numinamath1.5_aops_forum_total', + 'train-math-numinamath1.5_olympiads_int', 'train-math-numinamath1.5_olympiads_total', + 'gpqa_diamond' + ] or data_source.startswith("aime") or data_source.startswith("dapo"): + from . import math_dapo + res = math_dapo.compute_score(solution_str, ground_truth) + elif data_source in ['codecontests', 'apps', 'codeforces', 'taco']: + from . import prime_code + res = prime_code.compute_score(solution_str, ground_truth, continuous=True) + elif data_source in ['hiyouga/geometry3k']: + from . import geo3k + res = geo3k.compute_score(solution_str, ground_truth) + elif data_source in ['/nvfile-heatstorage/chatrl/users/hxh/data/rule_based_rl/math_train/reinforce_step150_wrong_answer/train_sample20_less_than_0d8.jsonl']: + from . import self_developed + res = self_developed.compute_score(solution_str, ground_truth) + elif data_source in ['kk_logic']: + from . import knight_and_knave + res = knight_and_knave.compute_score(solution_str, ground_truth) + elif data_source in ['count_down']: + from . import count_down + res = count_down.compute_score(solution_str, ground_truth) + else: + raise NotImplementedError(f"Reward function is not implemented for {data_source=}") + + if isinstance(res, dict): + return res + elif isinstance(res, (list, tuple)): + return res + elif isinstance(res, (int, float, bool)): + return float(res) + else: + return float(res[0]) + + +@deprecated("verl.utils.reward_score.default_compute_score") +def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None, memory_limit_mb=None): + """ + Legacy function API to be deprecated. Please use `default_compute_score` instead. + """ + return default_compute_score(data_source, solution_str, ground_truth, extra_info, sandbox_fusion_url, concurrent_semaphore, memory_limit_mb) + + +__all__ = ["default_compute_score"] diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py index 4248229697e..6418fa24e8c 100644 --- a/verl/utils/reward_score/__init__.py +++ b/verl/utils/reward_score/__init__.py @@ -14,7 +14,7 @@ # from . import gsm8k, math, prime_math, prime_code from verl.utils.import_utils import deprecated -import numpy as np + def default_compute_score(data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None, memory_limit_mb=None): """Compute the score for a given solution based on the data source. @@ -32,29 +32,9 @@ def default_compute_score(data_source, solution_str, ground_truth, extra_info=No Raises: NotImplementedError: If the reward function is not implemented for the given data source. """ - if isinstance(data_source, (list, tuple, np.ndarray)): - if all(isinstance(ds, str) and (ds.startswith("math_judge") or ds.startswith("expert")) for ds in data_source): - from . import remote_reward_batch - return remote_reward_batch.compute_score_batched_from_response(data_source, solution_str, ground_truth, extra_info) - # return remote_reward_batch.compute_score_batched_from_response(data_source, solution_str, ground_truth, extra_info) - elif all(isinstance(ds, str) and (ds.startswith("dapo") or ds.startswith("train-math-numinamath")) for ds in data_source): - from . 
import math_dapo - print("use math_dapo") - # 批量调用 math_dapo - return [math_dapo.compute_score(s, g) for s, g in zip(solution_str, ground_truth)] - # from . import remote_reward_batch - # return remote_reward_batch.compute_score_batched(data_source, solution_str, ground_truth, extra_info) - elif all(isinstance(ds, str) and (ds.startswith("boxed")) for ds in - data_source): - from . import math_verify_boxed - # 批量调用 math_dapo - return [math_verify_boxed.compute_score(s, g) for s, g in zip(solution_str, ground_truth)] - else: - from . import remote_reward_batch - return remote_reward_batch.compute_score_batched(data_source, solution_str, ground_truth, extra_info) - - elif data_source == "openai/gsm8k": + if data_source == "openai/gsm8k": from . import gsm8k + res = gsm8k.compute_score(solution_str, ground_truth) elif data_source in ["lighteval/MATH", "DigitalLearningGmbH/MATH-lighteval"]: from . import math @@ -67,14 +47,20 @@ def default_compute_score(data_source, solution_str, ground_truth, extra_info=No # from . import math_verify # res = math_verify.compute_score(solution_str, ground_truth) - elif data_source in [ - 'train-math-numinamath1.5_aops_forum', 'DeepScaleR_no_system', 'dapo_aime2025_s32_no_system', 'dapo_aime2024_s32_no_system', - 'train-math-numinamath1.5_aops_forum_int', 'train-math-numinamath1.5_aops_forum_total', - 'train-math-numinamath1.5_olympiads_int', 'train-math-numinamath1.5_olympiads_total', - 'gpqa_diamond' - ] or data_source.startswith("aime") or data_source.startswith("dapo"): + elif data_source == 'deepmath_103k' or data_source.startswith("math_verify"): + # print("Using math_verify for scoring") + from . import math_verify + res = math_verify.compute_score(solution_str, ground_truth) + print(f"verl/verl/utils/reward_score/__init__ default_compute_score Math-Verify score: {res}") + elif data_source == 'math_dapo' or data_source.startswith("aime") or data_source.startswith("dapo_"): from . import math_dapo res = math_dapo.compute_score(solution_str, ground_truth) + elif data_source in [ + 'numina_aops_forum', 'numina_synthetic_math', 'numina_amc_aime', 'numina_synthetic_amc', 'numina_cn_k12', + 'numina_olympiads' + ]: + from . import prime_math + res = prime_math.compute_score(solution_str, ground_truth) elif data_source in ['codecontests', 'apps', 'codeforces', 'taco']: from . 
import prime_code res = prime_code.compute_score(solution_str, ground_truth, continuous=True) @@ -95,8 +81,6 @@ def default_compute_score(data_source, solution_str, ground_truth, extra_info=No if isinstance(res, dict): return res - elif isinstance(res, (list, tuple)): - return res elif isinstance(res, (int, float, bool)): return float(res) else: @@ -111,4 +95,4 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=N return default_compute_score(data_source, solution_str, ground_truth, extra_info, sandbox_fusion_url, concurrent_semaphore, memory_limit_mb) -__all__ = ["default_compute_score"] +__all__ = ["default_compute_score"] \ No newline at end of file diff --git a/verl/utils/reward_score/language_detect.py b/verl/utils/reward_score/language_detect.py new file mode 100644 index 00000000000..0cdfd309578 --- /dev/null +++ b/verl/utils/reward_score/language_detect.py @@ -0,0 +1,104 @@ +import re +from typing import Optional +from langdetect import detect_langs + +# 数学变量白名单 +MATH_WHITELIST = set(['x', 'y', 'z', 'a', 'b', 'c', 'd', 'n', 'm', 'k', 't', 's', 'r', 'p', 'q', 'f', 'g', 'h', 'i', 'j', 'l', 'u', 'v', 'w', + 'A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z','(',')']) + +# 英文专有名词白名单(合并数学变量) +EN_WHITELIST = set([ + 'Answer' +]) | MATH_WHITELIST + +# 常见英文单词(去掉数学常用词) +COMMON_EN_WORDS = set([ + 'the', 'is', 'are', 'that', 'it', 'for', 'with', 'as', 'by', 'an', 'be', 'at' +]) + +def clean_text(text): + # 去除LaTeX公式 + text = re.sub(r'\$\$.*?\$\$|\$.*?\$', '', text, flags=re.DOTALL) + text = re.sub(r'\\\[.*?\\\]', '', text, flags=re.DOTALL) + text = re.sub(r'\\\(.*?\\\)', '', text, flags=re.DOTALL) + text = re.sub(r'\\begin\{([a-zA-Z*]+)\}.*?\\end{\1}', '', text, flags=re.DOTALL) + # 去除数字和常见数学符号 + text = re.sub(r'[0-9+\-*/^_=(){}\[\]<>≤≥≠∑∏∫√π∞±∈∩∪]', '', text) + # 去除数学变量和英文专有名词 + for w in EN_WHITELIST: + text = re.sub(r'\b{}\b'.format(re.escape(w)), '', text) + return text + +def is_english_sentence(text): + # 检查是否有完整英文句子(大写字母开头,句号结尾) + return bool(re.search(r'[A-Z][a-zA-Z ,;:\'\"\-()]+[\.!?]', text)) + +def count_ch_en(text): + ch_count = 0 + en_count = 0 + en_words = [] + for c in text: + if '\u4e00' <= c <= '\u9fff': + ch_count += 1 + elif 'a' <= c.lower() <= 'z': + en_count += 1 + # 统计英文单词 + words = re.findall(r'[a-zA-Z]+', text) + for w in words: + if w not in EN_WHITELIST: + en_words.append(w.lower()) + return ch_count, en_count, en_words + +def detect_language(text): + text_clean = clean_text(text) + if not text_clean.strip(): + return "unknown" + ch_count, en_count, en_words = count_ch_en(text_clean) + total = ch_count + en_count + en_word_set = set(en_words) + + score = 0 + + # 1. 英文字符比例 + if total > 0 and en_count / total > 0.5: + score += 1 + + # 2. 常见英文单词数量 + if len(en_word_set & COMMON_EN_WORDS) >= 8: + score += 1 + + # 3. 是否有完整英文句子 + if is_english_sentence(text_clean): + score += 1 + + # 4. 
langdetect辅助 + try: + lang_result = detect_langs(text_clean) + zh_prob = 0.0 + en_prob = 0.0 + for res in lang_result: + if 'zh' in res.lang: + zh_prob = res.prob + elif 'en' in res.lang: + en_prob = res.prob + if en_prob > 0.7: + score += 1 + elif zh_prob > 0.7: + score -= 1 + except: + pass + + # 根据分数判定初步语言 + if score >= 2: + # 初步判定为英文(包含纯英文和混合) + # 进一步区分纯英文和混合 + # 这里定义:如果中文字符占比小于10%,判为纯英文,否则混合 + if total == 0: + return "en" + ch_ratio = ch_count / total + if ch_ratio < 0.1: + return "en" + else: + return "mix" + else: + return 'zh' \ No newline at end of file diff --git a/verl/utils/reward_score/math_dapo.py b/verl/utils/reward_score/math_dapo.py index 56519c7e01a..04d5d02e8f3 100644 --- a/verl/utils/reward_score/math_dapo.py +++ b/verl/utils/reward_score/math_dapo.py @@ -15,6 +15,7 @@ import re from typing import Optional +from .language_detect import detect_language def last_boxed_only_string(string: str) -> Optional[str]: """Extract the last LaTeX boxed expression from a string. @@ -175,7 +176,18 @@ def is_correct_minerva(solution_str: str, gt: str, gt_need_extract: bool = False """ # Extract answer from solution match = re.findall(answer_pattern, solution_str) - extracted_answer = match[-1] if match else "[INVALID]" + # extracted_answer = match[-1] if match else "[INVALID]" + + if match: + extracted_answer = match[-1] + else: + # 如果匹配不到,则尝试匹配 \boxed{...} + boxed_str = last_boxed_only_string(solution_str) + if boxed_str is not None: + extracted_answer = remove_boxed(boxed_str) + else: + extracted_answer = "[INVALID]" + pred = normalize_final_answer(extracted_answer) # Process ground truth @@ -249,19 +261,36 @@ def compute_score( Returns: Reward score (1.0 for correct, -1.0 for incorrect) """ + lang = detect_language(solution_str) # Limit solution length for efficiency solution_str = solution_str[-300:] # The longest answer in MATH-500 has 159 characters # Verify the solution correct, pred = verify(solution_str, ground_truth, strict_box_verify, pause_tokens_index) + + # if lang == 'mix': + # # 如果是英文,直接惩罚 + # return { + # "score": -1.0, + # "acc": correct, + # "pred": pred, + # } reward = 1.0 if correct else -1.0 - # acc = 1.0 if correct else 0.0 - if reward == 1.0: - acc = 1.0 - elif reward == -1.0: - acc = 0.0 - + acc = correct + # if lang == 'mix': + # if correct: + # reward = 0.2 + # else: + # reward = -1.0 + # else: + # if correct: + # reward = 1.0 + # else: + # reward = -0.8 + + acc = correct + return { "score": reward, "acc": acc, diff --git a/verl/utils/reward_score/math_dapo_before.py b/verl/utils/reward_score/math_dapo_before.py new file mode 100644 index 00000000000..56519c7e01a --- /dev/null +++ b/verl/utils/reward_score/math_dapo_before.py @@ -0,0 +1,269 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
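# A small worked example of the fallback added to `is_correct_minerva` in math_dapo.py above
# (illustration only): when no "Answer: ..." line matches, the last \boxed{...} expression is
# extracted instead of returning "[INVALID]".
#   s = r"Therefore the total is \boxed{42}."
#   last_boxed_only_string(s)    -> "\\boxed{42}"
#   remove_boxed("\\boxed{42}")  -> "42"
#   normalize_final_answer("42") -> "42", which is then compared against the ground truth.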
+# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py + +import re +from typing import Optional + +def last_boxed_only_string(string: str) -> Optional[str]: + """Extract the last LaTeX boxed expression from a string. + + Args: + string: Input string containing LaTeX code + + Returns: + The last boxed expression or None if not found + """ + idx = string.rfind("\\boxed{") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + return string[idx : right_brace_idx + 1] if right_brace_idx is not None else None + + +def remove_boxed(s: str) -> str: + """Remove the LaTeX boxed command from a string. + + Args: + s: String with format "\\boxed{content}" + + Returns: + The content inside the boxed command + """ + left = "\\boxed{" + assert s[: len(left)] == left, f"box error: {s}" + assert s[-1] == "}", f"box error: {s}" + return s[len(left) : -1] + + +# Constants for normalization +SUBSTITUTIONS = [ + ("an ", ""), + ("a ", ""), + (".$", "$"), + ("\\$", ""), + (r"\ ", ""), + (" ", ""), + ("mbox", "text"), + (",\\text{and}", ","), + ("\\text{and}", ","), + ("\\text{m}", "\\text{}"), +] + +REMOVED_EXPRESSIONS = [ + "square", + "ways", + "integers", + "dollars", + "mph", + "inches", + "hours", + "km", + "units", + "\\ldots", + "sue", + "points", + "feet", + "minutes", + "digits", + "cents", + "degrees", + "cm", + "gm", + "pounds", + "meters", + "meals", + "edges", + "students", + "childrentickets", + "multiples", + "\\text{s}", + "\\text{.}", + "\\text{\ns}", + "\\text{}^2", + "\\text{}^3", + "\\text{\n}", + "\\text{}", + r"\mathrm{th}", + r"^\circ", + r"^{\circ}", + r"\;", + r",\!", + "{,}", + '"', + "\\dots", +] + + +def normalize_final_answer(final_answer: str) -> str: + """Normalize a final answer to a quantitative reasoning question. 
+ + Args: + final_answer: The answer string to normalize + + Returns: + Normalized answer string + """ + final_answer = final_answer.split("=")[-1] + + # Apply substitutions and removals + for before, after in SUBSTITUTIONS: + final_answer = final_answer.replace(before, after) + for expr in REMOVED_EXPRESSIONS: + final_answer = final_answer.replace(expr, "") + + # Extract and normalize LaTeX math + final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) + final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) + + # Normalize shorthand TeX: + # \fracab -> \frac{a}{b} + # \frac{abc}{bef} -> \frac{abc}{bef} + # \fracabc -> \frac{a}{b}c + # \sqrta -> \sqrt{a} + # \sqrtab -> sqrt{a}b + final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) + final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) + final_answer = final_answer.replace("$", "") + + # Normalize numbers + if final_answer.replace(",", "").isdigit(): + final_answer = final_answer.replace(",", "") + + return final_answer.strip() + + +def is_correct_minerva(solution_str: str, gt: str, gt_need_extract: bool = False, answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)") -> tuple[bool, str]: + """Check if the solution is correct according to Minerva criteria. + + Args: + solution_str: The solution string to check + gt: The ground truth answer + gt_need_extract: Whether the ground truth needs extraction + answer_pattern: Regex pattern to extract the answer + + Returns: + Tuple of (is_correct, normalized_prediction) + """ + # Extract answer from solution + match = re.findall(answer_pattern, solution_str) + extracted_answer = match[-1] if match else "[INVALID]" + pred = normalize_final_answer(extracted_answer) + + # Process ground truth + if gt_need_extract: + gt = normalize_final_answer(remove_boxed(last_boxed_only_string(gt))) + else: + gt = normalize_final_answer(gt) + + return (pred == gt), pred + + +def is_correct_strict_box(pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None) -> tuple[int, Optional[str]]: + """Check if the prediction is correct using strict boxed answer criteria. + + Args: + pred: The prediction string + gt: The ground truth answer + pause_tokens_index: Indices of pause tokens + + Returns: + Tuple of (score, extracted_prediction) + """ + # Extract the relevant part of the prediction + if pause_tokens_index is not None: + assert len(pause_tokens_index) == 4 + pred = pred[pause_tokens_index[-1] - 100 :] + else: + pred = pred[-100:] + + # Extract and check the boxed answer + boxed_pred = last_boxed_only_string(pred) + extracted_pred = remove_boxed(boxed_pred) if boxed_pred is not None else None + + return 1 if (extracted_pred == gt) else -1, extracted_pred + + +def verify(solution_str: str, answer: str, strict_box_verify: bool = False, pause_tokens_index: Optional[list[int]] = None) -> bool: + """Verify if the solution is correct. 
+ + Args: + solution_str: The solution string to verify + answer: The ground truth answer + strict_box_verify: Whether to use strict box verification + pause_tokens_index: Indices of pause tokens + + Returns: + True if the solution is correct, False otherwise + """ + if strict_box_verify: + correct, pred = is_correct_strict_box(solution_str, answer, pause_tokens_index) + return correct == 1, pred + + correct, pred = is_correct_minerva(solution_str, answer) + return correct, pred + + +def compute_score( + solution_str: str, + ground_truth: str, + strict_box_verify: bool = False, + pause_tokens_index: Optional[list[int]] = None, +) -> float: + """Compute the reward score for a solution. + + Args: + solution_str: The solution string + ground_truth: The ground truth answer + strict_box_verify: Whether to use strict box verification + pause_tokens_index: Indices of pause tokens + + Returns: + Reward score (1.0 for correct, -1.0 for incorrect) + """ + # Limit solution length for efficiency + solution_str = solution_str[-300:] # The longest answer in MATH-500 has 159 characters + + # Verify the solution + correct, pred = verify(solution_str, ground_truth, strict_box_verify, pause_tokens_index) + + reward = 1.0 if correct else -1.0 + # acc = 1.0 if correct else 0.0 + if reward == 1.0: + acc = 1.0 + elif reward == -1.0: + acc = 0.0 + + return { + "score": reward, + "acc": acc, + "pred": pred, + } diff --git a/verl/utils/reward_score/math_verify.py b/verl/utils/reward_score/math_verify.py index f5c6f90f94b..e32b08edf8d 100644 --- a/verl/utils/reward_score/math_verify.py +++ b/verl/utils/reward_score/math_verify.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re +from typing import Optional +from .language_detect import detect_language try: import sympy as sp @@ -23,6 +26,7 @@ print("To use Math-Verify, please install it first by running `pip install math-verify sympy`.") + def remove_unnecessary(s): '''去掉不必要的符号和框 ''' @@ -72,7 +76,49 @@ def parse_set(expr_str): parsed_elements.add(elem) return parsed_elements +def last_boxed_only_string(string: str) -> Optional[str]: + """Extract the last LaTeX boxed expression from a string. + + Args: + string: Input string containing LaTeX code + + Returns: + The last boxed expression or None if not found + """ + idx = string.rfind("\\boxed{") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + return string[idx : right_brace_idx + 1] if right_brace_idx is not None else None + + +def remove_boxed(s: str) -> str: + """Remove the LaTeX boxed command from a string. 
+ + Args: + s: String with format "\\boxed{content}" + + Returns: + The content inside the boxed command + """ + left = "\\boxed{" + assert s[: len(left)] == left, f"box error: {s}" + assert s[-1] == "}", f"box error: {s}" + return s[len(left) : -1] def compute_score(model_output: str, ground_truth: str) -> bool: # Limit solution length for efficiency @@ -83,6 +129,11 @@ def compute_score(model_output: str, ground_truth: str) -> bool: # dapo prompt answer_pattern=r"(?i)Answer\s*:\s*([^\n]+)" ) + # 如果匹配不到Answer,则尝试匹配最后一个\boxed{...} + if extracted_ans is None: + boxed_str = last_boxed_only_string(response_str) + if boxed_str is not None: + extracted_ans = remove_boxed(boxed_str) # 根据 box 里面的内容判断(如果有 boxed 的话) format_correct= -1.0 answer_correct = -1.0 @@ -112,7 +163,25 @@ def compute_score(model_output: str, ground_truth: str) -> bool: # print(f"[ground_truth] = \n{ground_truth}") # print(f"[format_correct] = {format_correct}, [answer_correct] = \n{answer_correct}") # print("--"*10) + + # lang = detect_language(model_output) + # if lang != 'mix' and answer_correct == 1.0: + # reward = 1.0 + # elif lang == 'mix' and answer_correct == 1.0: + # reward = 0.2 + # elif lang != 'mix' and answer_correct <= 0: + # reward = -0.8 + # else: # lang == 'mix' and answer_correct <= 0 + # reward = -1.0 + # correct = 0 if reward <= 0 else 1 + # acc = correct + # return { + # "score": reward, + # "acc": acc, + # "pred": "" if extracted_ans is None else extracted_ans, + # } + # correct 在 -1,1 之间 if format_correct < 0: reward = format_correct @@ -123,8 +192,16 @@ def compute_score(model_output: str, ground_truth: str) -> bool: # acc 在 0,1 之间 correct = 0 if reward <= 0 else reward acc = correct + # if lang == 'mix': + # # 如果是英文,直接惩罚 + # return { + # "score": -1.0, + # "acc": acc, + # "pred": "" if extracted_ans is None else extracted_ans, + # } + # print(reward, acc, extracted_ans) return { "score": reward, "acc": acc, "pred": "" if extracted_ans is None else extracted_ans, - } + } \ No newline at end of file diff --git a/verl/utils/reward_score/math_verify_before.py b/verl/utils/reward_score/math_verify_before.py new file mode 100644 index 00000000000..4faec38b331 --- /dev/null +++ b/verl/utils/reward_score/math_verify_before.py @@ -0,0 +1,136 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import sympy as sp + import re + from sympy.parsing.latex import parse_latex + from math_verify import parse, verify + from math_verify.metric import math_metric + from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig +except ImportError: + print("To use Math-Verify, please install it first by running `pip install math-verify sympy`.") + + +def remove_unnecessary(s): + '''去掉不必要的符号和框 + ''' + for pattern in [ + "^\\circ", + "\\$", "\$", "\\%", "\%", " ", + "tfrac", "dfrac", "^{\\circ}", + "\n", + "\\!" 
+ ]: + s = s.replace(pattern, "") + for item in ["^\\text", "\\mbox{", "\\text{", "^{\\text{"]: + if len(s.split(item)) == 2: + s = s.split(item)[0] + return s + + +def match_answer_content(processed_str, answer_pattern = r'(.*?)'): + if answer_pattern is None: + return processed_str + + matches = list(re.finditer(answer_pattern, processed_str, re.DOTALL)) + if not matches: + # print("verify not matches, return None") + # print("[Error] No valid answer tags found") + return None + final_answer = matches[-1].group(1).strip() + # print("verify matches, return final_answer") + # print(final_answer) + return final_answer + +def convert_to_standard_number(s): + try: + return str(float(s)) if '.' in s or 'e' in s.lower() else str(int(s)) + except ValueError: + return None + +def parse_set(expr_str): + elements = expr_str.split(',') + parsed_elements = set() + for elem in elements: + elem = elem.strip() + num = convert_to_standard_number(elem) + if num is not None: + parsed_elements.add(num) + else: + try: + parsed_elements.add(str(parse_latex(elem))) + except: + parsed_elements.add(elem) + return parsed_elements + + + +def compute_score(model_output: str, ground_truth: str) -> bool: + # Limit solution length for efficiency + # response_str = model_output[-300:] # The longest answer in MATH-500 has 159 characters + # last line + response_str = model_output.split("\n")[-1] + # 按照 pattern 抽取出答案的部分 + extracted_ans = match_answer_content( + response_str, + # dapo prompt + # answer_pattern=r"(?i)Answer\s*:\s*([^\n]+)" + answer_pattern=None + ) + # 根据 box 里面的内容判断(如果有 boxed 的话) + format_correct= -1.0 + answer_correct = -1.0 + if extracted_ans is not None: + format_correct = 1.0 + extracted_ans = remove_unnecessary(extracted_ans.strip()) + ground_truth = remove_unnecessary(ground_truth.strip()) + if len(extracted_ans) > 0: + try: + # math_verify判断 + answer_correct = 1.0 if verify(parse("$" + ground_truth + "$"), parse("$" + extracted_ans + "$")) else -1.0 + if answer_correct < 0: + ans_set = parse_set(extracted_ans) + gt_set = parse_set(ground_truth) + if ans_set == gt_set: + answer_correct = 1.0 + else: + ans_sympy = {sp.simplify(parse_latex(x)) for x in ans_set} + gt_sympy = {sp.simplify(parse_latex(x)) for x in gt_set} + if ans_sympy == gt_sympy: + answer_correct = 1.0 + except: + # print({"math_verify parse error": True, "extracted_ans": extracted_ans, "ground_truth": ground_truth}) + pass + + # print(f"[model_output] = \n{model_output}") + # print(f"[ground_truth] = \n{ground_truth}") + # print(f"[format_correct] = {format_correct}, [answer_correct] = \n{answer_correct}") + # print("--"*10) + + # correct 在 -1,1 之间 + if format_correct < 0: + reward = format_correct + else: + reward = answer_correct + + # correct 在 0,1 之间 + # acc 在 0,1 之间 + correct = 0 if reward <= 0 else reward + acc = correct + return { + "score": reward, + "acc": acc, + "pred": "" if extracted_ans is None else extracted_ans, + } diff --git a/verl/utils/reward_score/remote_reward_batch/__init__.py b/verl/utils/reward_score/remote_reward_batch/__init__.py index 8ebf36734f3..091aa01246a 100644 --- a/verl/utils/reward_score/remote_reward_batch/__init__.py +++ b/verl/utils/reward_score/remote_reward_batch/__init__.py @@ -47,6 +47,7 @@ def init_remote_reward(cfg: dict | None = None): print("----------------------------SAVE_JUDGE_PATH-----------------------------------") print(SAVE_JUDGE_PATH) + PROMPT_TEMPLATE_FROM_RESPONSE = '''I will now give you a question, an answer generated by an AI assistant, and the correct answer to the 
question. Please first extract the final answer stated by the AI assistant from its response, then compare it with the correct answer to decide whether they are consistent. Important: You do not need to compute or solve the problem yourself. Use the question text only as context if needed; base your decision on the AI assistant’s extracted final answer and the provided correct answer. @@ -82,25 +83,30 @@ def init_remote_reward(cfg: dict | None = None): Output: '''.strip() + IMPROVED_PROMPT_TEMPLATE = '''I will now give you a question, an answer generated by an AI assistant, and the correct answer to the question. Please compare the AI assistant's answer with the correct answer and determine whether they are consistent. Important: You do not need to compute or solve the problem yourself. Just judge whether the AI Assistant Answer matches the provided Ground Truth; use the question text only as context if needed. -- ---------------- [Question Start] ----------------- +----------------- [Question Start] ----------------- {problem} -- ---------------- [Question End] ----------------- -- ---------------- [AI Assistant Answer Start] ----------------- +----------------- [Question End] ----------------- + + +----------------- [AI Assistant Answer Start] ----------------- {pred} -- ---------------- [AI Assistant Answer End] ----------------- -- ---------------- [Ground Truth Start] ----------------- +----------------- [AI Assistant Answer End] ----------------- + + +----------------- [Correct Answer Start] ----------------- {ground_truth} -- ---------------- [Ground Truth End] ----------------- +----------------- [Correct Answer End] ----------------- Please structure your response in three lines: 1. Line 1: Analyze the AI assistant's answer @@ -143,10 +149,12 @@ def init_remote_reward(cfg: dict | None = None): -def save_to_jsonl(prompt, response): +def save_to_jsonl(prompt, response, prompt_tokens=-1, completion_tokens=-1): record = { "prompt": prompt, - "response": response + "response": response, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens } #print(record) with open(SAVE_JUDGE_PATH, 'a', encoding='utf-8') as f: @@ -244,10 +252,12 @@ def get_response_by_batch_generate(problems, preds, ground_truths): entry = resp_map.get(i) if entry and entry["responses"]: resp = entry["responses"][0] + prompt_tokens = entry.get("prompt_tokens", -1) + completion_tokens = entry.get("completion_tokens", -1) else: resp = None responses.append(resp) - save_to_jsonl(prompts[i], resp) + save_to_jsonl(prompts[i], resp, prompt_tokens, completion_tokens) return responses # except Exception as e: @@ -306,10 +316,13 @@ def get_response_by_batch_generate_from_response(problems, preds, ground_truths) entry = resp_map.get(i) if entry and entry["responses"]: resp = entry["responses"][0] + prompt_tokens = entry.get("prompt_tokens", -1) + completion_tokens = entry.get("completion_tokens", -1) else: resp = None responses.append(resp) - save_to_jsonl(prompts[i], resp) + save_to_jsonl(prompts[i], resp, prompt_tokens, completion_tokens) + return responses # except Exception as e: @@ -349,7 +362,8 @@ def compute_score(data_source, solution_str, ground_truth, extra_info): # problem = extra_info["question"] problem = extra_info - response_str = solution_str[-300:] # The longest answer in MATH-500 has 159 characters + # response_str = solution_str[-300:] # The longest answer in MATH-500 has 159 characters + response_str = solution_str # 按照 pattern 抽取出答案的部分 pred = match_answer_content( response_str, @@ 
-388,7 +402,8 @@ def compute_score_batched(data_sources, solution_strs, ground_truths, extra_info # from remote_reward import get_response_by_generate, compute_reward_result_tag problems = extra_infos # 提前截断 response,并抽取 pred - response_strs = [s[-300:] for s in solution_strs] + # response_strs = [s[-300:] for s in solution_strs] + response_strs = solution_strs # preds = [ # match_answer_content(resp, answer_pattern=r"(?i)Answer\s*:\s*([^\n]+)") # for resp in response_strs @@ -431,7 +446,8 @@ def compute_score_batched_from_response(data_sources, solution_strs, ground_trut # from remote_reward import get_response_by_generate, compute_reward_result_tag problems = extra_infos # 提前截断 response,并抽取 pred - response_strs = [s[-300:] for s in solution_strs] + # response_strs = [s[-300:] for s in solution_strs] + response_strs = solution_strs preds = response_strs diff --git a/verl/utils/reward_score/remote_reward_batch/api/base.py b/verl/utils/reward_score/remote_reward_batch/api/base.py index dc644285f9c..b5a1877492b 100644 --- a/verl/utils/reward_score/remote_reward_batch/api/base.py +++ b/verl/utils/reward_score/remote_reward_batch/api/base.py @@ -144,9 +144,13 @@ def batch_generate(self, batch_messages, **kwargs) -> List[List[str]]: xid, completion = result["id"], result["completions"] response_index = batch_messages[index].get("response_index", None) - responses[index] = {"id": xid, - "response_index": response_index, - "responses": [choice.message.content for choice in completion.choices]} # Store the response at the correct index + responses[index] = { + "id": xid, + "response_index": response_index, + "responses": [choice.message.content for choice in completion.choices], + "prompt_tokens": completion.usage.prompt_tokens, + "completion_tokens": completion.usage.completion_tokens, + } # Store the response at the correct index # print(f"> result, responses[{index}] = ", responses[index]) except Exception as e: print(f">>> Exception while processing {model_name} API, completion: {completion}, error: {e}") diff --git a/verl/utils/tokenizer.py b/verl/utils/tokenizer.py index c75bb8e9fef..d6df785fbca 100644 --- a/verl/utils/tokenizer.py +++ b/verl/utils/tokenizer.py @@ -63,7 +63,7 @@ def _ensure_fresh_telechat_load(name_or_path, **kwargs): """Ensure a fresh loading of the telechat tokenizer to avoid caching issues.""" from transformers import AutoTokenizer - max_retries = 2 + max_retries = 5 last_error = None for attempt in range(max_retries): @@ -88,9 +88,15 @@ def _ensure_fresh_telechat_load(name_or_path, **kwargs): print("[VERL] All retries failed") break except Exception as e: + last_error = e # Other types of errors, do not retry print(f"[VERL] Load failed (non-retryable error): {e}") - raise + if attempt < max_retries - 1: + print("[VERL] Cleaning cache and retrying...") + continue + else: + print("[VERL] All retries failed") + break raise last_error diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index e2b887fd42b..2c207df595a 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -460,6 +460,7 @@ def update_policy(self, data: DataProto): clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio clip_ratio_c = self.config.get("clip_ratio_c", 3.0) entropy_coeff = self.config.entropy_coeff + entropy_max = self.config.get("entropy_max", None) loss_agg_mode = self.config.loss_agg_mode # all return: (bsz, response_length) @@ -490,7 +491,13 @@ def update_policy(self, data: DataProto): 
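# The hunk below gates the entropy bonus rather than clipping it: when `entropy_coeff` is
# non-zero but the detached entropy already exceeds `entropy_max` (read from config via
# `self.config.get("entropy_max", None)`), the bonus term is dropped and the policy loss falls
# back to the plain pg_loss. A standalone sketch of that behaviour, using the same names as
# the diff:
def policy_loss_with_entropy_ceiling(pg_loss, entropy_loss, entropy_coeff, entropy_max=None):
    if entropy_coeff == 0:
        return pg_loss
    if entropy_max is not None and float(entropy_loss) > entropy_max:
        return pg_loss  # skip the entropy bonus entirely; it is not clipped to the ceiling
    return pg_loss - entropy_loss * entropy_coeff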
entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) # compute policy loss - policy_loss = pg_loss - entropy_loss * entropy_coeff + if entropy_coeff != 0: + entropy_value = entropy_loss.detach().item() + if entropy_max is not None and entropy_value > entropy_max: + logger.warning(f"Entropy loss {entropy_value} exceeds max {entropy_max}, clipping it.") + policy_loss = pg_loss + else: + policy_loss = pg_loss - entropy_loss * entropy_coeff else: policy_loss = pg_loss diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index d9744e85af2..cb79e0c1fe1 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -22,6 +22,8 @@ from dataclasses import asdict from typing import Optional, Union import numpy as np +from datetime import datetime + import gc import psutil import torch @@ -232,7 +234,10 @@ def _build_model_optimizer( torch_dtype = PrecisionType.to_dtype(torch_dtype) # override model kwargs - actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2") + actor_model_config = AutoConfig.from_pretrained( + local_path, trust_remote_code=trust_remote_code, + attn_implementation="flash_attention_2" + ) # patch for kimi-vl if getattr(actor_model_config, "model_type", None) == "kimi_vl": @@ -857,7 +862,7 @@ def generate_sequences(self, prompts: DataProto): actual_max = np.max(actual_outlen) actual_min = np.min(actual_outlen) - print(f"[GENTIME] {rank=}, {timing_generate['generate_sequences']:.2f}s; Sum: predict_totallens={predict_tsum}, pre_outlens={pre_osum}, insum={insum} ; Total: {predict_tlongest=}, {predict_tshortest=}, {predict_tavg=}, {predict_tstd=}; In: {inlongest=}, {inshortest=}, inavg={inavg:.0f}, instd={instd:.0f}; ACTUAL: {actual_sum=}, {actual_mean=}, {actual_max=}, {actual_min=}") + print(f"[GENTIME] {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, {rank=}, {timing_generate['generate_sequences']:.2f}s; Sum: predict_totallens={predict_tsum}, pre_outlens={pre_osum}, insum={insum} ; Total: {predict_tlongest=}, {predict_tshortest=}, {predict_tavg=}, {predict_tstd=}; In: {inlongest=}, {inshortest=}, inavg={inavg:.0f}, instd={instd:.0f}; ACTUAL: {actual_sum=}, {actual_mean=}, {actual_max=}, {actual_min=}") output = self.rollout_sharding_manager.postprocess_data(output) timing_generate.update(self.rollout_sharding_manager.timing) @@ -1613,6 +1618,7 @@ def _switch_chat_template(self, data: DataProto): valid_response_ids = response_ids[:valid_response_length] # decode + # @xiaohui: debug for decode response = src_tokenizer.decode(valid_response_ids) # remove bos and eos response = response.replace(src_tokenizer.eos_token, "") diff --git a/verl/workers/reward_manager/dapo.py b/verl/workers/reward_manager/dapo.py index e561c5c8768..3e7121bb297 100644 --- a/verl/workers/reward_manager/dapo.py +++ b/verl/workers/reward_manager/dapo.py @@ -80,8 +80,9 @@ def __call__(self, data: DataProto, return_dict: bool = False): valid_response_ids = response_ids[:valid_response_length] # decode - prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True) - response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) + # @xiaohui: keep special tokens + prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=False) + response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=False) eos_token = self.tokenizer.eos_token if response_str.endswith(eos_token): response_str = 
response_str[: -len(eos_token)] diff --git a/verl/workers/reward_manager/remote_batch.py b/verl/workers/reward_manager/remote_batch.py index be0a39dd208..6fd6e3d9cbd 100644 --- a/verl/workers/reward_manager/remote_batch.py +++ b/verl/workers/reward_manager/remote_batch.py @@ -24,10 +24,10 @@ from verl.workers.reward_manager import register import os import json - -save_num_examine_path_remote_batch = os.environ.get("SAVE_NUM_EXAMINE_PATH_REMOTE_BATCH", "/afs/chatrl/users/hwq/log/verl/logs_sensecore/save_num_examine_remote_batch_test_val.jsonl") # allow a default path +from datetime import datetime +save_num_examine_path_remote_batch = os.environ.get("SAVE_NUM_EXAMINE_PATH_REMOTE_BATCH", "/afs/chatrl/users/hwq/log/verl/logs_sensecore/save_num_examine_remote_batch_test_val.jsonl") # allow a default path @register("remote_batch") class REMOTEBatchRewardManager: @@ -71,12 +71,23 @@ def verify(self, data): prompt_len = prompt_ids.shape[-1] valid_response_lengths = attention_mask[:, prompt_len:].sum(dim=-1) - responses_str = [] + # responses_str = [] + summary_str = [] for i in range(len(data)): valid_len = valid_response_lengths[i] valid_response_ids = response_ids[i][:valid_len] - response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) - responses_str.append(response_str) + # response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=False) + # responses_str.append(response_str) + # @xiaohui: extract the summary and judge only the summary + # extraction logic: take the content of valid_response_ids that follows the marker [15, 16735, 21, 2298, 18] (<|start|>assistant<|channel|>final<|message|>) + current_summary = "" + for i in range(len(valid_response_ids)): + if valid_response_ids[i] == 15 and i < len(valid_response_ids) - 4 and valid_response_ids[i + 1] == 16735 and valid_response_ids[i + 2] == 21 and valid_response_ids[i + 3] == 2298 and valid_response_ids[i + 4] == 18: + current_summary = self.tokenizer.decode(valid_response_ids[i + 5:], skip_special_tokens=True) + break + # print({"\t> valid_response_ids": valid_response_ids}) + # print({"\t> current_summary": current_summary}) + summary_str.append(current_summary) ground_truths = [item.non_tensor_batch["reward_model"].get("ground_truth", None) for item in data] data_sources = data.non_tensor_batch[self.reward_fn_key] @@ -86,13 +97,16 @@ def verify(self, data): prompt_key = self.remote_reward_cfg.get("prompt_key") prompts_into_extras = data.non_tensor_batch.get(prompt_key, [None] * len(data)) + print(f">>> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, [remote_batch.py] compute_score start...") scores = self.compute_score( data_source=data_sources, - solution_str=responses_str, + # solution_str=responses_str, + solution_str=summary_str, ground_truth=ground_truths, extra_info=prompts_into_extras, **self.reward_kwargs, ) + print(f">>> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, [remote_batch.py] compute_score end!!!") return scores @@ -132,7 +146,8 @@ def __call__(self, data: DataProto, return_dict=False): reward_tensor[i, length - 1] = reward data_source = data_sources[i] - response_str = self.tokenizer.decode(data.batch["responses"][i][:length], skip_special_tokens=True) + # @xiaohui: keep special tokens + response_str = self.tokenizer.decode(data.batch["responses"][i][:length], skip_special_tokens=False) prompt_str = self.tokenizer.decode(data.batch["prompts"][i], skip_special_tokens=True) ground_truth = data[i].non_tensor_batch["reward_model"].get("ground_truth", None) # save as JSONL @@ -147,7 +162,6 @@
fout.write(json.dumps(output_item, ensure_ascii=False) + "\n") if already_printed.get(data_source, 0) < self.num_examine: - # if already_printed.get(data_source, 0) < 500: # response_str = self.tokenizer.decode(data.batch["responses"][i][:length], skip_special_tokens=True) # prompt_str = self.tokenizer.decode(data.batch["prompts"][i], skip_special_tokens=True) # ground_truth = data[i].non_tensor_batch["reward_model"].get("ground_truth", None)
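# A generic sketch of the summary extraction done in RemoteBatchRewardManager.verify above:
# scan the response token ids for a marker subsequence and decode only what follows it. The
# concrete marker ids [15, 16735, 21, 2298, 18] (<|start|>assistant<|channel|>final<|message|>)
# are tokenizer-specific and taken from the diff; treat them as an assumption of that setup.
def decode_after_marker(token_ids, marker, tokenizer):
    """Return the decoded text after the first occurrence of `marker`, or "" if absent."""
    ids = list(token_ids)
    m = len(marker)
    for start in range(len(ids) - m + 1):
        if ids[start:start + m] == list(marker):
            return tokenizer.decode(ids[start + m:], skip_special_tokens=True)
    return ""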