RLHF (Part 3): A Deep Dive into TRL's GRPOTrainer

Preface: the mainstream LLM post-training frameworks today are trl, OpenRLHF, and verl. The latter two are highly integrated and well suited to zero-code LLM training, while trl is the more flexible option. This post walks through the training flow of the GRPO Trainer.

The GRPOTrainer class

It inherits from transformers.Trainer and overrides or extends several methods, including:
__init__
Purpose: initialize the policy model, the reference model (ref_model), the reward models/functions (reward_funcs), etc., and set a few hyperparameters (such as num_generations and beta).

  • model: the policy model to train; can be a string (model ID or path) or a pretrained model object. Only causal language models are supported.
  • reward_funcs: the reward function(s); can be a pretrained model (only SequenceClassification models are supported), a user-defined Python function, or a list of these, in which case multiple reward functions are used together.
  • args: a GRPOConfig object containing all training parameters.
  • train_dataset: the training dataset; it must contain a column named "prompt" and can be a Dataset or an IterableDataset.
  • eval_dataset: the evaluation dataset.
  • processing_class: the processing class used to preprocess the training and evaluation data (typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None).
  • reward_processing_classes: the tokenizer(s) matching the reward function(s); a single tokenizer or a list of tokenizers.
  • callbacks: a list of custom training callbacks that extend or override the default training loop (typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None).
  • optimizers: the optimizer and learning-rate scheduler pair to use instead of the defaults.
  • peft_config: the PEFT configuration (e.g. a LoRA config) used to wrap the model for parameter-efficient training.

Additional notes (a minimal construction sketch follows these notes):

  1. The padding side of processing_class must be set to "left". If it is None, the processing class is loaded with from_pretrained from the model name.
  2. reward_processing_classes: optional, defaults to None. When provided, it must match the order and length of the reward functions in reward_funcs.
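
To make these parameters concrete, here is a minimal construction sketch in the spirit of the TRL quickstart; the model ID, dataset, and toy reward function are illustrative assumptions rather than anything prescribed in this post.

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions whose length is close to 100 characters.
    return [-abs(100 - len(completion)) for completion in completions]

# The training dataset must contain a "prompt" column.
train_dataset = load_dataset("trl-lib/tldr", split="train")

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    num_generations=8,   # G: number of completions sampled per prompt
    beta=0.04,           # weight of the KL penalty against the reference model
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # a string model ID is accepted
    reward_funcs=reward_len,           # a single function, a reward model, or a list of them
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()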

_prepare_inputs
Purpose: in the training loop, for each batch, sample multiple completions per prompt, score them with the reward model(s), and compute the within-group relative advantages.
Outline:

  1. For each prompt in the batch, generate num_generations completions in one call. If an EOS token appears early in a completion, the tokens after the first EOS are masked out via completion_mask.
  2. Use ref_model to compute token-level log probabilities over the full sequence (prompt + completion), to be used later for the KL term.
  3. Score each completion with the reward model(s), producing a reward tensor of shape [B*G].
  4. Compute the relative advantages: reshape [B*G] -> [B, G], take the mean and standard deviation over the G completions of each prompt, and broadcast them back to obtain each completion's relative advantage (a toy illustration follows the full code below).
    Full code:
    def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
        device = self.accelerator.device
        prompts = [x["prompt"] for x in inputs]
        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
        prompt_inputs = self.processing_class(
            prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

        if self.max_prompt_length is not None:
            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
            prompt_mask = prompt_mask[:, -self.max_prompt_length :]

        # Generate completions using either vLLM or regular generation
        if self.args.use_vllm:
            # First, have main process load weights if needed
            if self.state.global_step != self._last_loaded_step:
                with unwrap_model_for_generation(
                    self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
                ) as unwrapped_model:
                    if is_compiled_module(unwrapped_model):
                        state_dict = unwrapped_model._orig_mod.state_dict()
                    else:
                        state_dict = unwrapped_model.state_dict()
                if self.accelerator.is_main_process:
                    llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                    llm_model.load_weights(state_dict.items())
                self._last_loaded_step = self.state.global_step

            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
            all_prompts_text = gather_object(prompts_text)
            if self.accelerator.is_main_process:
                outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
                completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
            else:
                completion_ids = [None] * len(all_prompts_text) * self.num_generations

            # Broadcast the completions from the main process to all processes, ensuring each process receives its
            # corresponding slice.
            completion_ids = broadcast_object_list(completion_ids, from_process=0)
            process_slice = slice(
                self.accelerator.process_index * len(prompts) * self.num_generations,
                (self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
            )
            completion_ids = completion_ids[process_slice]

            # Pad the completions, and concatenate them with the prompts
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
            prompt_ids = torch.repeat_interleave(prompt_ids, self.num_generations, dim=0)
            prompt_mask = torch.repeat_interleave(prompt_mask, self.num_generations, dim=0)
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        else:
            # Regular generation path
            with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
                prompt_completion_ids = unwrapped_model.generate(
                    prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
                )

            # Compute prompt length and extract completion ids
            prompt_length = prompt_ids.size(1)
            prompt_ids = prompt_completion_ids[:, :prompt_length]
            completion_ids = prompt_completion_ids[:, prompt_length:]
            prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)

        # Mask everything after the first EOS token
        is_eos = completion_ids == self.processing_class.eos_token_id
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

        # Concatenate prompt_mask with completion_mask for logit computation
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)

        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        with torch.inference_mode():
            if self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
                )
            else:
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    ref_per_token_logps = self._get_per_token_logps(
                        self.model, prompt_completion_ids, attention_mask, logits_to_keep
                    )

        # Decode the generated completions
        completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        if is_conversational(inputs[0]):
            completions = [[{"role": "assistant", "content": completion}] for completion in completions]

        # Compute the rewards
        prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]  # repeat prompts

        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        for i, (reward_func, reward_processing_class) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes)
        ):
            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
            else:
                # Repeat all input columns (but "prompt" and "completion") to match the number of generations
                reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
                for key in reward_kwargs:
                    for example in inputs:
                        # Repeat each value in the column for `num_generations` times
                        reward_kwargs[key].extend([example[key]] * self.num_generations)
                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # Sum the rewards from all reward functions
        rewards = rewards_per_func.sum(dim=1)

        # Compute grouped-wise rewards
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

        # Normalize the rewards to compute the advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

        # Log the metrics
        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
        }
    For a line-by-line walkthrough, see: https://blog.csdn.net/shizheng_Li/article/details/145794949
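
As a sanity check on step 4, here is a small self-contained illustration of the group-relative advantage computation, applying the same operations as the code above to made-up rewards (B = 2 prompts, G = 3 completions each):

import torch

num_generations = 3  # G
rewards = torch.tensor([1.0, 0.0, 2.0,   # completions for prompt 1
                        0.5, 0.5, 2.0])  # completions for prompt 2

# Per-prompt (group) statistics, broadcast back to shape [B*G]
mean_grouped = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0)
std_grouped = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations, dim=0)

advantages = (rewards - mean_grouped) / (std_grouped + 1e-4)
print(advantages)  # each completion is scored relative to its own group, not the whole batch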

compute_loss
Purpose: compute the final loss according to the GRPO objective, combining the KL penalty term with the relative advantages, which is then backpropagated.
Outline:

  1. Compute token-level log probabilities of the full sequence under the current policy.
  2. Compute the per-token KL divergence from the token log probabilities of the actor model and the ref model.
  3. Use the KL term and the relative advantages obtained in _prepare_inputs to compute the GRPO loss, then backpropagate to update the policy (a note on the exp(x - x.detach()) trick follows the full code below).
    Full code:
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
        # Compute the per-token log probabilities for the model

        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        # Compute the KL divergence between the model and the reference model
        ref_per_token_logps = inputs["ref_per_token_logps"]
        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

        # x - x.detach() allows for preserving gradients from x
        advantages = inputs["advantages"]
        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        return loss
    For a line-by-line walkthrough, see: https://blog.csdn.net/shizheng_Li/article/details/145793070
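
The comment "x - x.detach() allows for preserving gradients from x" is the heart of the loss: exp(logp - logp.detach()) always evaluates to 1, but its gradient with respect to logp is 1, so -exp(logp - sg(logp)) * A has the same gradient as the vanilla policy-gradient term -A * logp. A tiny self-contained check with made-up numbers:

import torch

logp = torch.tensor([-1.2, -0.3], requires_grad=True)  # per-token log-probs under the current policy
advantage = torch.tensor([2.0, -1.0])                   # made-up advantages

ratio = torch.exp(logp - logp.detach())  # value is exactly 1 for every token
loss = -(ratio * advantage).sum()
loss.backward()

print(ratio.detach())  # tensor([1., 1.])
print(logp.grad)       # tensor([-2., 1.]) == -advantage, i.e. the policy-gradient signal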

prediction_step
Purpose: defines how _prepare_inputs is called and the loss is obtained during training / evaluation.
During evaluation or prediction, _prepare_inputs must still be executed to generate multiple completions and compute the loss, only without a backward pass. This function is what produces the loss in the eval/predict phase, which is then used for logging, early stopping, and so on. A simplified sketch follows.
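
Roughly, the method reduces to the sketch below; this is a simplified illustration consistent with the description above, not the verbatim TRL source.

def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
    # Sampling, reward scoring, and advantage computation still happen here...
    inputs = self._prepare_inputs(inputs)
    with torch.no_grad():
        # ...but the loss is only computed, never backpropagated.
        loss = self.compute_loss(model, inputs)
        loss = loss.mean().detach()
    # Only the loss is returned; it feeds eval logging / early stopping.
    return loss, None, None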

log and create_model_card
Purpose: logging and the model card, which can be uploaded to the Hugging Face Hub for model management.
create_model_card parameters (a usage sketch follows the list):

  • model_name: the name of the model
  • dataset_name: the name of the dataset used for training
  • tags: tags to associate with the model card
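
A hedged usage sketch (the name, dataset, and tags below are placeholders):

# Typically called after training, before pushing the model to the Hub.
trainer.create_model_card(
    model_name="Qwen2-0.5B-GRPO",   # placeholder
    dataset_name="trl-lib/tldr",    # placeholder
    tags=["grpo", "rlhf", "trl"],
)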

Custom reward functions

GRPOTrainer supports custom reward functions in place of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:

  1. Input arguments
    The function must accept the following as keyword arguments:
    • prompts (containing the prompts),
    • completions (containing the generated completions),
    • every column name the dataset may contain, except prompt. For example, if the dataset contains a column named ground_truth, the function is called with ground_truth as a keyword argument. The simplest way to satisfy this requirement is to use **kwargs in the function signature.

The inputs differ depending on the dataset format:
- For the standard format, prompts and completions will be lists of strings.
- For the conversational format, prompts and completions will be lists of message dictionaries.

  2. Return value
    The function must return a list of floats, one per completion; each float is the reward for the corresponding completion.

Example 1: rewarding longer completions

def reward_func(completions, **kwargs):
    """Reward function that gives higher scores to longer completions."""
    return [float(len(completion)) for completion in completions]

Example 2: rewarding completions with a specific format

import re

def format_reward_func(completions, **kwargs):
    """Reward function that checks whether the completion follows a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
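
Both example functions can also be passed together as a list; as seen in _prepare_inputs, rewards_per_func.sum(dim=1) simply adds the scores of all reward functions for each completion. A hedged sketch (the model and dataset names are placeholders, and note that every reward function receives the prompts/completions in the same format, so all of them must agree with the dataset's format):

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")  # placeholder dataset with a "prompt" column

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=[format_reward_func, reward_func],  # per-completion rewards are summed
    args=GRPOConfig(output_dir="Qwen2-0.5B-GRPO-multi-reward"),
    train_dataset=dataset,
)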

Closing note: for full details, see https://huggingface.co/docs/trl/main/en/grpo_trainer

