大模型显存占用计算

1. 显存单位换算

在讨论显存占用时,首先要明白“B”和“G”的含义。通常,“B”指的是十亿(1B = 10^9),而“G”则表示千兆字节(1G = 10^9字节)。例如,1B参数意味着有10亿个参数。显存的单位通常以字节计算,而1个字节等于8位。
🎈如果使用全精度训练(fp32),每个参数需要占用32位(即4个字节),因此1B的参数需要占用4GB的显存。
🎈如果使用半精度(fp16或bf16),则每个参数占用2字节,1B的参数只需占用2GB的显存。

2. 显存开销的其他组成部分

除了模型参数本身外,训练过程中还会消耗一定的显存,主要包括以下几部分:
🎈梯度:每个参数对应一个梯度,因此梯度的显存占用与参数量相同。
🎈优化器状态:优化器,如Adam,通常会为每个参数保存一阶动量和二阶动量,因此优化器的显存开销为参数量的2倍(对于Adam)。对于其他优化器(如SGD),则取决于优化器的具体实现,若是带动量的SGD,则为参数量的1倍。

3. 显存总占用计算

假设我们训练一个参数量为1B的模型,采用全精度(fp32)并使用Adam优化器,显存的占用计算如下:
🎈参数:1B × 4GB = 4GB
🎈梯度:1B × 4GB = 4GB
🎈优化器状态:1B × 8GB = 8GB
因此,总显存占用为16GB。如果使用半精度(bf16),则显存占用减半,为8GB。混合精度训练则会根据各部分精度调整计算结果。


大模型显存占用计算
http://example.com/2025/02/17/大模型显存占用计算/
作者
Peter
发布于
2025年2月17日
许可协议