为什么需要 FP8?
训练大模型的成本主要是显存和算力。FP8(8位浮点数)可以:
- 显存占用降低 40-50%
- 训练速度提升 1.5-2 倍
- 成本降低 30-40%
FP8 格式详解
| 格式 | 指数位 | 尾数位 | 动态范围 | 精度 |
|---|---|---|---|---|
| E5M2 | 5 | 2 | 1.4e-4 ~ 5.7e4 | 低 |
| E4M3 | 4 | 3 | 9.5e-5 ~ 4.4e3 | 高 |
推荐:前向用 E4M3(精度高),反向用 E5M2(范围大)
实战配置
1. 环境准备
pip install transformer-engine[pytorch] nvidia-smi # 需要 H100/H800/A100
2. 代码示例
import torch import transformer_engine.pytorch as te # 初始化 FP8 层 linear = te.Linear(4096, 4096) # 自动处理 FP8 计算 output = linear(input)
3. 关键参数
fp8_recipe = te.recipe.DelayedScaling(
fp8_format=te.recipe.Format.E4M3,
amax_history_len=1024,
amax_compute_algo=max
)
踩坑记录
问题 1:Loss 发散
原因:梯度下溢,缩放因子太小
解决:
fp8_recipe = te.recipe.DelayedScaling(
fp8_format=te.recipe.Format.E4M3,
override_linear_precision={'wgrad': False}
)
问题 2:精度下降
解决:关键层保持 BF16
# Embedding 和输出层用 BF16
x = embedding(input_ids) # BF16
with te.fp8_autocast(enabled=True):
x = fp8_layer(x) # FP8
x = output_head(x) # BF16
性能对比
| 配置 | 显存占用 | 训练速度 | 精度损失 |
|---|---|---|---|
| BF16 | 100% | 1.0x | 0% |
| FP8 (E4M3) | 58% | 1.8x | < 0.5% |
| FP8 (E5M2) | 55% | 1.9x | < 1% |
最佳实践
- 渐进式迁移:先用 BF16 训练稳定,再切换到 FP8 微调
- 关键层保护:Embedding 和 Output 层保持 BF16
- 监控指标:缩放因子、梯度范数、验证集精度
总结
FP8 训练是降低成本的有效手段,但需要谨慎配置。
不适合:第一次训模型、对精度极度敏感的任务。