grpo
GRPO (trl 库)
重要参数
- Num_generations: Number of generations to sample. The effective batch size (num_processes * per_device_batch_size * gradient_accumulation_steps) must be evenly divisible by this value.
- generation_batch_size: Batch size to use for generation. If
None
, it defaults to the effective training batch size:per_device_train_batch_size * num_processes * steps_per_generation
. - steps_per_generation: Number of optimization steps per generation.
If
None
, it defaults to gradient_accumulation_steps. - Num_iterations: Number of iterations per batch (denoted as μ in the algorithm).
- Per_device_train_batch_size
- Num_processes (world_size)
trl 库的重要参数比较少。其中根据官方文档,generation_batch_size = `per_device_train_batch_size * num_processes * steps_per_generation Gradient_accumulation_steps 一般就是 steps_per_generation (对应 verl 中的 mini_batch_size / n_gpus / ppo_micro_batch_size_per_gpu),可以理解为 per_device_train_bs (对应 verl 中的 ppo_micro_batch_size_per_gpu) 是使用梯度累计后的 bs,乘 gpu 数,再乘梯度累计的 steps 就是总的 batch_size(对应 verl 中的 train_batch_size * rollout. N)。所以注意,总的 batch_size (generation_batch_size) 是已经 rollout 采样后的 bs,除以 num_generations 才是针对 prompts 的 bs(verl 中的 train_batch_size)。 下面是_get_train_sampler 方法的注释,对每一个 prompt 重复 num_generations 是该方法实现的。
if dataset is None:
= self.train_dataset
dataset return RepeatSampler(
=dataset,
data_source=self.num_generations, # 每个 prompt 生成 self.num_generations 个 completions
mini_repeat_count# 例如,如果 per_device_train_batch_size=8, num_generations=2, steps_per_generation=4,
# 则 generation_batch_size = 8 (per_device_train_batch_size) * 4 (steps_per_generation) = 32
# 这里的 batch_size = 32 / 2 = 16,表示一个 "generation block" 中有16个不同的prompt。
=self.args.generation_batch_size // self.num_generations,
batch_size# 每个 "generation block" (包含16个不同prompt,每个prompt有2个completion) 会被用于 num_iterations * steps_per_generation 次更新
# 例如 num_iterations=1, steps_per_generation=4, 则这个 block 会被重复 1*4=4 次,每次取出一个 per_device_train_batch_size 的数据进行训练
=self.num_iterations * self.args.steps_per_generation,
repeat_count=self.shuffle_dataset,
shuffle=self.args.seed,
seed )
结合下面的例子帮助理解,例子中梯度累计 steps 不等于
steps_per_generation
在 GRPO_trainer 中,最重要的方法是
_generate_and_score_completions
方法,输入为
input,输出为计算得到的优势值和 old_logp 用于计算
ratio。一些核心的部分和注释如下:
with unwrap_model_for_generation(
self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
as unwrapped_model:
) with (
self.model_wrapped, recurse=False)
FSDP.summon_full_params(if self.is_fsdp_enabled
else nullcontext()
):# prompt_ids: (B_gen_local, P_max)
# prompt_mask: (B_gen_local, P_max)
# prompt_completion_ids: torch.Tensor (B_gen_local, P_max + C_new), C_new 是 HF generate 生成的新 token 数量 (最大为 max_completion_length)
= unwrapped_model.generate(
prompt_completion_ids =prompt_mask, generation_config=self.generation_config
prompt_ids, attention_mask
)
# Compute prompt length and extract completion ids
= prompt_ids.size(1) # P_max
prompt_length # prompt_ids 保持不变: (B_gen_local, P_max)
= prompt_completion_ids[:, :prompt_length]
prompt_ids # completion_ids: torch.Tensor (B_gen_local, C_new_hf)
= prompt_completion_ids[:, prompt_length:] completion_ids
上面为 generate 的过程,不过现在基本上使用 vllm 或者 sglang 加速推理。为了逻辑简单,这里展示了 HF generate 的过程。Trl 实现的时候,将一个 prompt 采样多次的逻辑实现在了 get_train_dataloader 方法中,即一开始就使用 get_train_sampler 方法对同一个 prompt repeat 了多次。因此这里不需要再进行 repeat。 之后得到补充部分的 mask:
# Mask everything after the first EOS token
# is_eos: torch.Tensor (B_gen_local, C_new), C_new 是 completion 的实际长度 (C_max_vllm 或 C_new_hf)
= completion_ids == self.processing_class.eos_token_id
is_eos # eos_idx: torch.Tensor (B_gen_local,), 存储每个 completion 中第一个 EOS token 的索引,如果没有EOS则为序列长度
= torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
eos_idx any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
eos_idx[is_eos.= torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
sequence_indices # completion_mask: torch.Tensor (B_gen_local, C_new), 标记有效 token (EOS之前及EOS本身)
= (sequence_indices <= eos_idx.unsqueeze(1)).int()
completion_mask
# Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
# to re-tokenize completions if the reward is computed from tokens.
# completion_ids_list: list[list[int]], 长度 B_gen_local, 移除了 padding 和 EOS 之后的 token
= [
completion_ids_list id.item() for id, m in zip(row, mask_row) if m] for row, mask_row in zip(completion_ids, completion_mask)
[ ]
然后根据 generate 的 ids 得到 old_logp:
with torch.no_grad():
# When using num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps
# old_per_token_logps == per_token_logps, so we can skip it's computation here, and use
# per_token_logps.detach() instead.
if self.num_iterations > 1 or self.args.steps_per_generation > self.args.gradient_accumulation_steps:
# old_per_token_logps: torch.Tensor (B_gen_local, C_new), 代表生成这些 completion 时所用模型的 log probabilities
= self._get_per_token_logps(
old_per_token_logps self.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size
)else:
= None # 在特定条件下,可以用当前模型的 logprobs.detach() 代替,以节省计算 old_per_token_logps
在上一步 generate 的时候我们不是已经进行过完整 batch 的推理了么?为什么现在还要重复进行一次 forward 来计算 log_prob,而不是在 generate 的过程中就把 log_prob 保存下来?
因为 forward 的时候和 generate 的时候 logprob 由于推理引擎和训练引擎的优化目标不一样,会造成两者对不上,因此需要做两次。 Batch 算子的细微差异,都会造成这两个 log_prob 不完全一致。推理引擎要的是快速出 token id,训练引擎需要保证一定的 log_prob 精度。
注意这里的很关键的一点是如果符合分支条件将 old_logp 设置成了
None,那么后续计算 ratio 时就固定为 1(old_logps = logps. Detach
)。如果 num_iterations > 1,说明一批数据会被训练多次,ratio
就不固定为 1 了。所以要保存生成训练数据的那个模型对应的
logps。Steps_per_generation > ga_steps 也一样,因为
steps_per_generation 参数就代表一批数据训练多少次。Ga_steps 更新 actor
参数,这之后 ratio 就不为 1 了。一共要经历 steps_per_gen / ga_steps
参数更新。如果我们不设置 steps_per_generation 默认就是
ga_steps,这里还是看这个图就可以理解了:
然后根据定义的奖励函数计算 reward:
for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names)
):with profiling_context(self, reward_func_name):
if isinstance(reward_func, nn.Module):
# GRPO一般不需要这部分
else:
# 自定义奖励函数
= reward_func(
output_reward_func =prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
prompts
)# Convert None values to NaN
= [reward if reward is not None else torch.nan for reward in output_reward_func]
output_reward_func # rewards_per_func[:, i]: torch.Tensor (B_gen_local,)
= torch.tensor(output_reward_func, dtype=torch.float32, device=device) rewards_per_func[:, i]
得到了奖励,就可以计算组内优势了:
# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
# completions may be distributed across processes
= gather(rewards_per_func) # (N_proc * B_gen_local, num_reward_funcs)
rewards_per_func
# Apply weights to each reward function's output and sum
= (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) # (N_proc * B_gen_local,)
rewards
# Compute grouped-wise rewards
= rewards.view(-1, self.num_generations).mean(dim=1) # (N_groups_total, G) -> (N_groups_total,)
mean_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1) # (N_groups_total, G) -> (N_groups_total,)
std_grouped_rewards = torch.isclose(std_grouped_rewards, torch.zeros_like(std_grouped_rewards)) # (N_groups_total,)
is_std_zero
# Normalize the rewards to compute the advantages
= mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) # (N_groups_total * G,),即(N_proc * B_gen_local)
mean_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) # (N_groups_total * G,),即(N_proc * B_gen_local)
std_grouped_rewards = rewards - mean_grouped_rewards advantages
这样 grpo 训练所需要的 experience 就生产好了。下面进入训练阶段,计算 kl 散度:
if self.beta != 0.0: # 仅当 beta 不为0时才需要计算 KL 散度
with torch.no_grad():
if self.ref_model is not None:
= self._get_per_token_logps(
ref_per_token_logps self.ref_model, input_ids, attention_mask, logits_to_keep
)else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
= self._get_per_token_logps(
ref_per_token_logps self.model, input_ids, attention_mask, logits_to_keep
)# per_token_kl: torch.Tensor (B_eff, C_new), 每个 token 的 KL 散度
# KL(P || Q) = sum P(x) log(P(x)/Q(x)) 的一种估计形式
# 使用 exp(log P - log Q) - (log P - log Q) - 1 来避免直接计算 P/Q 可能的数值不稳定
= (
per_token_kl - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
torch.exp(ref_per_token_logps # k3 )
核心计算部分:
# Compute the loss
= inputs["advantages"] # torch.Tensor (B_eff,)
advantages # old_per_token_logps: torch.Tensor (B_eff, C_new), 旧策略(生成数据时)的对数概率
= (
old_per_token_logps if inputs["old_per_token_logps"] is None else inputs["old_per_token_logps"]
per_token_logps.detach()
)# coef_1 (r_t(θ)): torch.Tensor (B_eff, C_new), 概率比率 exp(log_probs_new - log_probs_old)
= torch.exp(per_token_logps - old_per_token_logps)
coef_1 # coef_2 (clipped r_t(θ)): torch.Tensor (B_eff, C_new), 裁剪后的概率比率
= torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
coef_2
# Two-sided clipping
if self.args.delta is not None: # GRPO 论文中的 δ 参数,用于额外限制概率比率的上限
= torch.clamp(coef_1, max=self.args.delta)
coef_1
# per_token_loss1: torch.Tensor (B_eff, C_new), PPO 目标的第一项 r_t(θ) * A_t
= coef_1 * advantages.unsqueeze(1)
per_token_loss1 # per_token_loss2: torch.Tensor (B_eff, C_new), PPO 目标的第二项 clip(r_t(θ), 1-ε, 1+ε) * A_t
= coef_2 * advantages.unsqueeze(1)
per_token_loss2 # per_token_loss: torch.Tensor (B_eff, C_new), PPO 损失的 surrogate 部分 -min(loss1, loss2)
= -torch.min(per_token_loss1, per_token_loss2)
per_token_loss if self.beta != 0.0:
# 如果 beta 不为0,则加入 KL 散度惩罚项
= per_token_loss + self.beta * per_token_kl per_token_loss
这样,我们就得到了补全部分的每一个有效 token 的损失。这次还可以加入
entropy loss,指策略分布的熵
(Entropy):策略对选择下一个动作(在这里是下一个
token)的不确定性程度。熵越高,表示策略输出的概率分布越均匀,选择各个动作的概率越接近,策略的探索性越强;熵越低,表示策略越倾向于选择少数几个高概率的动作,确定性越强。
entropy_loss
指 entropy
的平均值,是一个标量,表示探索性高低。 得到 token_loss
后根据不同的方法计算 batch 损失:
if self.loss_type == "grpo":
# GRPO 论文中的标准损失:对每个序列的 token 损失求和后取平均,然后再对批次取平均
# (sum_t (L_t * mask_t) / sum_t mask_t).mean_batch()
= ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
loss elif self.loss_type == "bnpo":
# BNPO (Batch Normalized Policy Optimization) 损失:对所有 token 的损失求和后,除以所有有效 token 的总数
# sum_batch sum_t (L_t * mask_t) / sum_batch sum_t mask_t
= (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
loss elif self.loss_type == "dr_grpo":
# DR-GRPO (Dense Reward GRPO) 损失:对所有 token 的损失求和后,除以 (批次大小 * 最大完成长度)
# sum_batch sum_t (L_t * mask_t) / (B_eff * C_max)
= (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
loss else:
raise ValueError(f"Unknown loss type: {self.loss_type}")