RLHF (Part 3): A Deep Dive into TRL's GRPOTrainer
Preface: the mainstream LLM post-training frameworks today are trl, OpenRLHF, and verl. The latter two are highly integrated and well suited to (near) zero-code LLM training, whereas trl is more flexible. This post walks through the training flow of the GRPO Trainer.
The GRPOTrainer class
It inherits from transformers.Trainer and overrides or extends several methods, including:
__init__
Purpose: initialize the policy model, the reference model (ref_model), the reward models (reward_funcs), etc., and set hyperparameters such as num_generations and beta.
- model: the policy model to train; either a string (model ID or path) or a pretrained model object. Only causal language models are supported.
- reward_funcs: the reward function(s); either a pretrained model (only SequenceClassification models are supported), a user-defined Python function, or a list of them, meaning several reward functions are combined.
- args: a GRPOConfig object containing all training arguments.
- train_dataset: the training dataset; it must contain a column named "prompt" and can be a Dataset or an IterableDataset.
- eval_dataset: the evaluation dataset.
- processing_class: the processing class (tokenizer) used to preprocess the training and evaluation data; Optional[PreTrainedTokenizerBase], defaults to None.
- reward_processing_classes: the tokenizer(s) corresponding to the reward functions; a single tokenizer or a list of tokenizers.
- callbacks: a list of custom training callbacks that extend or override the default training process; Optional[list[TrainerCallback]], defaults to None.
- optimizers: the (optimizer, lr_scheduler) pair to use.
- peft_config: an optional PEFT configuration used to wrap the model.
Supplementary notes (a minimal usage sketch follows below):
- The padding side of processing_class must be set to "left". If it is None, the processing class is loaded from the model name via from_pretrained.
- reward_processing_classes: optional, defaults to None. When provided, it must match the order and length of the reward functions in reward_funcs.
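The sketch below shows how these arguments fit together. It is only illustrative: the model name, dataset, and hyperparameter values are assumptions, and the reward function is a toy one.

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Toy reward function: longer completions get a higher score (standard string format).
def reward_len(completions, **kwargs):
    return [float(len(completion)) for completion in completions]

# Any dataset with a "prompt" column works; "trl-lib/tldr" is just an example.
dataset = load_dataset("trl-lib/tldr", split="train")

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    num_generations=8,  # G: completions sampled per prompt
    beta=0.04,          # weight of the KL penalty
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # string model ID; a PreTrainedModel also works
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```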
_prepare_inputs
Purpose: in the training loop, for every batch, first sample multiple completions per prompt, call the reward model(s) to score them, and compute the group-relative advantages.
Brief flow:
- For each prompt in the batch, generate num_generations completions in one call. If an EOS token appears early in a completion, everything after the first EOS is masked out via completion_mask.
- Use ref_model to compute token-level log probabilities over the full sequence (prompt + completion), which are later used for the KL term.
- Call the reward model(s) to score each completion, producing a reward tensor of shape [B*G].
- Compute the relative advantage: reshape [B*G] -> [B, G], take the mean and standard deviation over the G completions of each prompt, broadcast them back, and normalize to obtain each completion's relative advantage (a small numeric sketch follows below).
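To make the last step concrete, here is a standalone sketch of the group-relative advantage computation with made-up rewards (B = 2, G = 4); it mirrors the normalization used in _prepare_inputs below.

```python
import torch

num_generations = 4  # G
# Hypothetical rewards for 2 prompts x 4 completions, flattened to shape [B*G]
rewards = torch.tensor([1.0, 0.0, 0.5, 0.5,   # group for prompt 0
                        2.0, 2.0, 1.0, 3.0])  # group for prompt 1

grouped = rewards.view(-1, num_generations)                    # [B, G]
mean = grouped.mean(dim=1).repeat_interleave(num_generations)  # broadcast back to [B*G]
std = grouped.std(dim=1).repeat_interleave(num_generations)
advantages = (rewards - mean) / (std + 1e-4)                   # same formula as in the trainer
print(advantages.view(-1, num_generations))
```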
Full code (a line-by-line walkthrough: https://blog.csdn.net/shizheng_Li/article/details/145794949):
```python
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
    device = self.accelerator.device
    prompts = [x["prompt"] for x in inputs]
    prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
    prompt_inputs = self.processing_class(
        prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
    )
    prompt_inputs = super()._prepare_inputs(prompt_inputs)
    prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

    if self.max_prompt_length is not None:
        prompt_ids = prompt_ids[:, -self.max_prompt_length :]
        prompt_mask = prompt_mask[:, -self.max_prompt_length :]

    # Generate completions using either vLLM or regular generation
    if self.args.use_vllm:
        # First, have main process load weights if needed
        if self.state.global_step != self._last_loaded_step:
            with unwrap_model_for_generation(
                self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ) as unwrapped_model:
                if is_compiled_module(unwrapped_model):
                    state_dict = unwrapped_model._orig_mod.state_dict()
                else:
                    state_dict = unwrapped_model.state_dict()
            if self.accelerator.is_main_process:
                llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                llm_model.load_weights(state_dict.items())
            self._last_loaded_step = self.state.global_step

        # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
        all_prompts_text = gather_object(prompts_text)
        if self.accelerator.is_main_process:
            outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
            completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
        else:
            completion_ids = [None] * len(all_prompts_text) * self.num_generations
        # Broadcast the completions from the main process to all processes, ensuring each process receives its
        # corresponding slice.
        completion_ids = broadcast_object_list(completion_ids, from_process=0)
        process_slice = slice(
            self.accelerator.process_index * len(prompts) * self.num_generations,
            (self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
        )
        completion_ids = completion_ids[process_slice]

        # Pad the completions, and concatenate them with the prompts
        completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
        completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
        prompt_ids = torch.repeat_interleave(prompt_ids, self.num_generations, dim=0)
        prompt_mask = torch.repeat_interleave(prompt_mask, self.num_generations, dim=0)
        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    else:
        # Regular generation path
        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
            prompt_completion_ids = unwrapped_model.generate(
                prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
            )

        # Compute prompt length and extract completion ids
        prompt_length = prompt_ids.size(1)
        prompt_ids = prompt_completion_ids[:, :prompt_length]
        completion_ids = prompt_completion_ids[:, prompt_length:]
        prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)

    # Mask everything after the first EOS token
    is_eos = completion_ids == self.processing_class.eos_token_id
    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
    sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

    # Concatenate prompt_mask with completion_mask for logit computation
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)

    logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

    with torch.inference_mode():
        if self.ref_model is not None:
            ref_per_token_logps = self._get_per_token_logps(
                self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
            )
        else:
            with self.accelerator.unwrap_model(self.model).disable_adapter():
                ref_per_token_logps = self._get_per_token_logps(
                    self.model, prompt_completion_ids, attention_mask, logits_to_keep
                )

    # Decode the generated completions
    completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
    if is_conversational(inputs[0]):
        completions = [[{"role": "assistant", "content": completion}] for completion in completions]

    # Compute the rewards
    prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]  # repeat prompts
    rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
    for i, (reward_func, reward_processing_class) in enumerate(
        zip(self.reward_funcs, self.reward_processing_classes)
    ):
        if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
            if is_conversational(inputs[0]):
                messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
            else:
                texts = [p + c for p, c in zip(prompts, completions)]
            reward_inputs = reward_processing_class(
                texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
            )
            reward_inputs = super()._prepare_inputs(reward_inputs)
            with torch.inference_mode():
                rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
        else:
            # Repeat all input columns (but "prompt" and "completion") to match the number of generations
            reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
            for key in reward_kwargs:
                for example in inputs:
                    # Repeat each value in the column for `num_generations` times
                    reward_kwargs[key].extend([example[key]] * self.num_generations)
            output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
            rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

    # Sum the rewards from all reward functions
    rewards = rewards_per_func.sum(dim=1)

    # Compute grouped-wise rewards
    mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
    std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

    # Normalize the rewards to compute the advantages
    mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
    std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

    # Log the metrics
    reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
    for i, reward_func in enumerate(self.reward_funcs):
        if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
            reward_func_name = reward_func.config._name_or_path.split("/")[-1]
        else:
            reward_func_name = reward_func.__name__
        self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

    self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
    self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

    return {
        "prompt_ids": prompt_ids,
        "prompt_mask": prompt_mask,
        "completion_ids": completion_ids,
        "completion_mask": completion_mask,
        "ref_per_token_logps": ref_per_token_logps,
        "advantages": advantages,
    }
```
compute_loss
Purpose: compute the final loss according to the GRPO objective, combining the KL penalty with the relative advantages, and backpropagate.
Brief flow:
- Compute the token-level log probabilities of the full sequence under the current policy.
- Compute the token-level KL divergence from the actor model's and the reference model's token log-probs.
- Combine the KL term with the relative advantages obtained in _prepare_inputs to form the GRPO loss, then backpropagate for the gradient update (the objective is written out below).
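Written out, the per-token objective that the code below implements is the following. The KL term is the unbiased "k3" estimator, and because the sampling policy and the updated policy coincide in this single-update setting, the probability ratio reduces to exp(logp - logp.detach()), which is 1 in value but still carries the gradient:

$$
\hat{D}_{\mathrm{KL}}\big[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\big]_t
= \exp\big(\log\pi_{\mathrm{ref}}(o_t) - \log\pi_\theta(o_t)\big)
- \big(\log\pi_{\mathrm{ref}}(o_t) - \log\pi_\theta(o_t)\big) - 1
$$

$$
\mathcal{L}_{\mathrm{GRPO}}
= -\,\frac{1}{|o|}\sum_{t=1}^{|o|}
\Big(\frac{\pi_\theta(o_t)}{\big[\pi_\theta(o_t)\big]_{\mathrm{no\ grad}}}\,\hat{A}
- \beta\,\hat{D}_{\mathrm{KL}}\big[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\big]_t\Big)
$$

averaged over all completions in the batch, with completion_mask restricting the sum to valid completion tokens.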
Full code (a line-by-line walkthrough: https://blog.csdn.net/shizheng_Li/article/details/145793070):
```python
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    if return_outputs:
        raise ValueError("The GRPOTrainer does not support returning outputs")

    # Compute the per-token log probabilities for the model
    prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
    completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
    logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

    per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

    # Compute the KL divergence between the model and the reference model
    ref_per_token_logps = inputs["ref_per_token_logps"]
    per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

    # x - x.detach() allows for preserving gradients from x
    advantages = inputs["advantages"]
    per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
    per_token_loss = -(per_token_loss - self.beta * per_token_kl)
    loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

    # Log the metrics
    completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
    self._metrics["completion_length"].append(completion_length)

    mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
    self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

    return loss
```
prediction_step
Purpose: how the training / evaluation loop calls _prepare_inputs and obtains the loss.
During evaluation or prediction, _prepare_inputs is still executed to generate multiple completions and compute the loss; there is simply no backpropagation. This function is what produces the loss in the eval/predict phase, e.g. for logging or early stopping.
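For reference, a simplified sketch of what prediction_step boils down to (the exact signature and context managers in TRL may differ slightly; treat this as an approximation):

```python
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
    inputs = self._prepare_inputs(inputs)        # generation + rewards + advantages, as above
    with torch.no_grad():
        loss = self.compute_loss(model, inputs)  # same GRPO loss, just not backpropagated
        loss = loss.mean().detach()
    return loss, None, None                      # no logits/labels are returned
```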
log and create_model_card
Purpose: logging and the model card, which can be uploaded to the Hugging Face Hub for model management.
create_model_card arguments (a usage sketch follows below):
- model_name: the name of the model
- dataset_name: the name of the dataset used for training
- tags: tags to associate with the model card
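A small sketch of how this is typically used after training; the model name, dataset name, and tags here are hypothetical placeholders.

```python
# Write a model card with metadata about the run (values are placeholders).
trainer.create_model_card(
    model_name="Qwen2-0.5B-GRPO",
    dataset_name="trl-lib/tldr",
    tags=["grpo", "rlhf"],
)
# Optionally push the model (and its card) to the Hugging Face Hub.
trainer.push_to_hub()
```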
Custom reward functions
GRPOTrainer supports custom reward functions in place of a dense reward model. To ensure compatibility, your reward function must satisfy the following requirements:
- Input arguments. The function must accept the following as keyword arguments:
  - prompts (containing the prompts),
  - completions (containing the generated completions),
  - all column names the dataset may contain, except prompt. For example, if the dataset contains a column named ground_truth, the function is called with ground_truth as a keyword argument. The easiest way to satisfy this requirement is to include **kwargs in the function signature.
  Depending on the dataset format, the inputs differ:
  - For the standard format, prompts and completions are lists of strings.
  - For the conversational format, prompts and completions are lists of message dictionaries.
- Return value. The function must return a list of floats, where each float is the reward for the corresponding completion.
Example 1: rewarding longer completions
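The code block here did not survive extraction; the following is a sketch in the spirit of the TRL documentation's example, assuming the standard (string) format: the reward is simply the completion's length.

```python
def reward_func(completions, **kwargs):
    """Reward function that gives higher scores to longer completions."""
    return [float(len(completion)) for completion in completions]
```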
Example 2: rewarding completions with a specific format
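Again the original block is missing; this sketch follows the format-reward example from the TRL documentation, assuming the conversational format and a <think>/<answer> response template.

```python
import re

def format_reward_func(completions, **kwargs):
    """Reward completions that match the <think>...</think><answer>...</answer> format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    # Conversational format: each completion is a list of message dicts.
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
```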
Closing note: for more details, see https://huggingface.co/docs/trl/main/en/grpo_trainer