LoRA Adapter Debugging Trace

I wanted to spend some time looking at how LoRA is injected into the base model in the peft library; here is a quick summary.

First, a small test program:

```python
import torch
from peft import LoraModel, LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer


model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf', torch_dtype=torch.float16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
tokenizer.pad_token = tokenizer.eos_token
lora_config = LoraConfig(
    r=32,
    lora_alpha=16,
    target_modules=["gate_proj", "up_proj", "q_proj", "down_proj", "o_proj", "k_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = LoraModel(model, lora_config, "default")  # set a breakpoint here
model.print_trainable_parameters()
for name, param in model.named_parameters():
    print(name, param)
```
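
After injection, each matched projection layer has been replaced by a lora.Linear wrapper that keeps the original layer as base_layer and adds per-adapter lora_A / lora_B sub-modules. A quick way to confirm this from the test program above (just a sketch; the module path assumes the standard Llama-2 layout, and the shapes are what I would expect for Llama-2-7B with r=32):

```python
# Inspect one injected layer (path assumes the standard Llama-2 module layout):
q_proj = model.model.model.layers[0].self_attn.q_proj
print(type(q_proj))                           # peft's lora.Linear wrapper
print(q_proj.base_layer)                      # the original nn.Linear is preserved
print(q_proj.lora_A["default"].weight.shape)  # expected: torch.Size([32, 4096])
print(q_proj.lora_B["default"].weight.shape)  # expected: torch.Size([4096, 32])
```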

Stepping in with the debugger: LoraModel is a subclass built on top of BaseTuner, and BaseTuner's __init__ calls an inject_adapter method, which implements how the modules in the base model get replaced by LoRA's target modules:

```python
self.active_adapter: str | list[str] = adapter_name
self._pre_injection_hook(self.model, self.peft_config[adapter_name], adapter_name)
if peft_config != PeftType.XLORA or peft_config[adapter_name] != PeftType.XLORA:
    self.inject_adapter(self.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
```

Stepping into inject_adapter: it first collects all of the base model's named modules into a key list and then loops over them; whenever a key matches one of the target_modules defined in the peft config, it calls _create_and_replace (an @abstractmethod, implemented by each tuner subclass).
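
The matching step roughly works like this (a minimal sketch of my own, not the actual peft implementation, which also supports regex-style target_modules and various exclusion rules):

```python
# Simplified sketch of the matching loop inside inject_adapter (not the real code):
def sketch_inject_adapter(self, model, adapter_name):
    peft_config = self.peft_config[adapter_name]
    key_list = [key for key, _ in model.named_modules()]  # all module names of the base model
    for key in key_list:
        # list-style target_modules match on the suffix of the full module name
        if not any(key == t or key.endswith(f".{t}") for t in peft_config.target_modules):
            continue
        parent_name, _, target_name = key.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        target = model.get_submodule(key)
        # hand the actual wrapping/replacement off to the subclass
        self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
```

The real _create_and_replace is worth a closer look: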

```python
def _create_and_replace(
    self,
    lora_config,
    adapter_name,
    target,
    target_name,
    parent,
    current_key,
):
    if current_key is None:
        raise ValueError("Current Key shouldn't be `None`")

    # Regexp matching - Find key which matches current target_name in patterns provided
    pattern_keys = list(chain(lora_config.rank_pattern.keys(), lora_config.alpha_pattern.keys()))
    target_name_key = next(filter(lambda key: re.match(rf".*\.{key}$", current_key), pattern_keys), current_key)
    r = lora_config.rank_pattern.get(target_name_key, lora_config.r)
    alpha = lora_config.alpha_pattern.get(target_name_key, lora_config.lora_alpha)

    kwargs = {
        "r": r,
        "lora_alpha": alpha,
        "lora_dropout": lora_config.lora_dropout,
        "fan_in_fan_out": lora_config.fan_in_fan_out,
        "init_lora_weights": lora_config.init_lora_weights,
        "use_rslora": lora_config.use_rslora,
        "use_dora": lora_config.use_dora,
        "ephemeral_gpu_offload": lora_config.runtime_config.ephemeral_gpu_offload,
        "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
        "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
    }

    quant_methods = ["gptq", "aqlm", "awq"]
    for quant_method in quant_methods:
        quantization_config = get_quantization_config(self.model, method=quant_method)
        if quantization_config is not None:
            kwargs[f"{quant_method}_quantization_config"] = quantization_config

    # note: AdaLoraLayer is a subclass of LoraLayer, we need to exclude it
    from peft.tuners.adalora import AdaLoraLayer

    if isinstance(target, LoraLayer) and not isinstance(target, AdaLoraLayer):
        target.update_layer(
            adapter_name,
            r,
            lora_alpha=alpha,
            lora_dropout=lora_config.lora_dropout,
            init_lora_weights=lora_config.init_lora_weights,
            use_rslora=lora_config.use_rslora,
            use_dora=lora_config.use_dora,
        )
    else:
        new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
        if adapter_name not in self.active_adapters:
            # adding an additional adapter: it is not automatically trainable
            new_module.requires_grad_(False)
        self._replace_module(parent, target_name, new_module, target)
```
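
Note the rank_pattern / alpha_pattern lookup at the top of the method: it lets individual target modules override the global r and lora_alpha. A hypothetical config illustrating this (the values are made up, not from the test program above):

```python
# Hypothetical example: q_proj / v_proj get rank 64 and alpha 32,
# everything else keeps the global r=32, lora_alpha=16.
lora_config = LoraConfig(
    r=32,
    lora_alpha=16,
    target_modules=["gate_proj", "up_proj", "q_proj", "down_proj", "o_proj", "k_proj", "v_proj"],
    rank_pattern={"q_proj": 64, "v_proj": 64},
    alpha_pattern={"q_proj": 32, "v_proj": 32},
    task_type="CAUSAL_LM",
)
```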

kwargs contains the parameters needed to construct a LoRA layer; this is also where the code checks whether the base model has already been quantized to 8 bit or 4 bit, so that incompatible operations are avoided. The method then checks whether target is already a LoraLayer: if it is, its parameters (such as r and lora_alpha) are updated in place; AdaLoraLayer is explicitly excluded from this path because AdaLoRA has its own adaptation logic; otherwise a new LoRA layer is created and the target layer is replaced. self._replace_module is what actually swaps the layer named target_name inside parent for new_module:

```python
def _replace_module(self, parent, child_name, new_module, child):
    setattr(parent, child_name, new_module)
    # It's not necessary to set requires_grad here, as that is handled by
    # _mark_only_adapters_as_trainable

    # child layer wraps the original module, unpack it
    if hasattr(child, "base_layer"):
        child = child.base_layer

    if not hasattr(new_module, "base_layer"):
        if hasattr(new_module, "W_q"):  # HQQ
            new_module.W_q = child.W_q
        else:
            new_module.weight = child.weight
        if hasattr(child, "bias"):
            new_module.bias = child.bias

    if getattr(child, "state", None) is not None:
        if hasattr(new_module, "base_layer"):
            new_module.base_layer.state = child.state
        else:
            new_module.state = child.state
        new_module.to(child.weight.device)

    meta = torch.device("meta")
    # dispatch to correct device
    for name, module in new_module.named_modules():
        if (self.prefix in name) or ("ranknum" in name):
            weight = (
                child.qweight
                if hasattr(child, "qweight")
                else child.W_q
                if hasattr(child, "W_q")
                else child.weight
                if hasattr(child, "weight")
                else next(child.parameters())
            )
            if not any(p.device == meta for p in module.parameters()):
                module.to(weight.device)
```
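
To round things off: once a target has been replaced, the wrapped layer's forward pass adds the low-rank update on top of the frozen base layer. A simplified sketch of what the injected lora.Linear computes for a single active adapter (ignoring DoRA, merged weights, and multi-adapter bookkeeping):

```python
import torch
import torch.nn as nn

def lora_linear_forward(x, base_layer, lora_A, lora_B, dropout, lora_alpha, r, use_rslora=False):
    # y = W x + scaling * B A dropout(x), with scaling = lora_alpha / r
    # (or lora_alpha / sqrt(r) when rsLoRA is enabled)
    scaling = lora_alpha / (r ** 0.5) if use_rslora else lora_alpha / r
    return base_layer(x) + lora_B(lora_A(dropout(x))) * scaling

# Tiny self-contained check with random weights:
base = nn.Linear(4096, 4096, bias=False)
A = nn.Linear(4096, 32, bias=False)  # lora_A: in_features -> r
B = nn.Linear(32, 4096, bias=False)  # lora_B: r -> out_features
y = lora_linear_forward(torch.randn(1, 4096), base, A, B, nn.Dropout(0.05), lora_alpha=16, r=32)
print(y.shape)  # torch.Size([1, 4096])
```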
