Stable Diffusion推理概览
本文最后更新于 2025-07-04,文章内容可能已经过时。
Stable Diffusion属于一类称为扩散模型(diffusion model)的深度学习模型,属于生成类模型。 目前有 SD1.x、SD2.x 以及 SDXL系列,支持 文生图、图生图以及图像重绘。
训练过程包括前向扩散和反向扩散, 推理时只使用反向扩散。
前向扩散(Forward Diffusion): 通过添加噪声,把训练图像逐渐转换为没有特点的噪声图像。
逆向扩散(Reverse Diffusion):通过逐步去噪,把嘈杂无意义的图像(噪声)逐步恢复为有意义的图像。
1 模型结构
SD模型主要包括三个主要部分,Clip Text Encoder 、Unet + Scheduler 、VAE Decoder。
- Clip Text Encoder:用于对用户输入的prompt进行 文本编码,提取 Text Embedding特征, 形状是 [77,768]; 其中77是文本转换成 token_id之后的序列长度。如果原始prompt经过分词器之后长度不足77,则进行补padding, 如果超过77,则截断。
- Unet + Scheduler:输入是 Text Encoder输出的特征编码 + 随机高斯噪声矩阵, 形状是[64,64,4];作用是经过 N 次迭代之后,逐步去除噪声,还原 Latent 特征。
- Unet:预测噪声残差
- 主要模块:U型结构,主要包括 CrossAttnDownBlock2D 、 CrossAttnMidBlock 、CrossAttnUpBlock2D
- CrossAttn[Down/up]Block = ResNetBlock + SpatialTransformer + [Down/Up]Block
- 输入:潜在特征(Latent)、prompt的文字特征编码以及当前 时间步(timestep)
- 输出:当前步的 噪声预测(noise_pred)。
- 主要模块:U型结构,主要包括 CrossAttnDownBlock2D 、 CrossAttnMidBlock 、CrossAttnUpBlock2D
- Scheduler: 使用Unet预测出来的噪声,结合特定的算法(DDPM, DDIM. Euler等),从当前的潜在特征(Latent)中祛除预测出来的噪声。
- Unet:预测噪声残差
- VAE: 从 Unet+Scheduler 模块经过若干step之后生成的 latet feature,重建成像素级图像。
2 推理流程
2.1 文生图
2.1.1 模块
Clip Text Encoder , Unet + Scheduler, VAE Decoder2.1.2 流程
-
用户的prompt,经过 Text Encoder 提取出形状为 [77, 768]的文本特征编码。
- 如果包含负面输入(negative prompt),则会对负面输入也进行特征编码。
- 同时,生成一个形状为[64, 64, 4]的高斯噪声矩阵。
-
提取出的文本特征和高斯噪声矩阵一同送入 图像优化器(Unet + Schduler),
-
Unet负责生成当前step的噪声预测(noise_pred);
-
如果有负面输入(negative prompt),则噪声预测也会包含 正/负面两部分,需要把这两者进行相减。
-
Scheduler 根据特定的算法(DDPM、DDIM等)控制降噪的幅度,把 预测的噪声 从当前step的 潜在特征(Latent)中祛除,得到新的潜在特征。
-
以上两步迭代多次(num_steps)
-
图像优化器迭代多次生成的最终 潜在特征,经过VAE Decoder 解码还原成有意义的像素级图像。
2.2 图生图
图像图的推理过程与文生图差不多,主要区别是模型结构上多了一个 VAE Encoder 用于对输入图像进行编码,再加上一定强度的随机噪声作为Unet的输入;文生图是完全随机的噪声矩阵。
2.2.1 模块
Clip Text Encoder + VAE Encoder , Unet + Scheduler, VAE Decoder2.2.2 流程
-
用户的prompt,经过 Text Encoder 提取出形状为 [77, 768]的文本特征。
- 同时,输入图像经过VAE Encoder进行图像编码,然后 加上一定强度的随机噪声 生成形状为[64, 64, 4]的矩阵。
-
提取出的文本特征和加了噪声的图像编码一同送入 图像优化器(Unet + Schduler),
-
Unet负责进行当前step的噪声预测;
-
Scheduler 根据特定的算法(DDPM、DDIM等)把预测的噪声从当前step的 潜在特征(Latent)中祛除,得到新的潜在特征。
-
迭代以上两步多次(num_steps)
-
图像优化器迭代多次生成的最终 潜在特征,经过VAE Decoder 解码还原成有意义的像素级图像。
3 解决显存问题
在 Stable Diffusion 推理阶段,UNet 和 VAE 模块的显存使用大体可分为两部分:
- 常驻显存(常量部分):模型权重 + 固定缓冲区(比如 LayerNorm 常量、常驻缓存等);
- 峰值临时显存(动态部分):各层激活(activations)、中间临时张量(比如 attention 的 Q/K/V、卷积临时缓冲、softmax 临时结果)及其它算子中间态。
模型本身的权重大小、批次大小(batch size)、数据类型以及设置的图片分辨率大小 都会影响SD模型推理时峰值显存的占用。
优化思路
-
半精度浮点型/INT8量化
- 推理速度加快,显存降低,但可能影响出图质量
-
CPU Offload
-
需要动态H2D 权重拷贝,影响推理速度
-
Tiled diffusion & Tiled VAE
-
影响推理速度
-
图级别优化
-
小算子融合
3.1 VAE Tiling
VAE 解码器在生成最终图像时,需要将 latent 图像(如 [4, 64, 64])解码成 RGB 图像 [3, 512, 512]VAE 解码中后期的卷积操作(尤其是Upsample)会产生巨大的临时张量,这是推理显存占用的瓶颈之一。VAE Tiling的做法是:
将 latent 图像分成小块(tiles),如 64×64 latent patch(即原图像 512×512 区块),分别送入 VAE 解码器推理,再拼接成完整图像。
这使得显存使用与 tile 大小成正比,而非整图分辨率。
3.2 分片Attention(Attention Slicing)
Transformer 中的 self-attention 层通常需要将 Query、Key、Value 全部加载到显存中计算,对于输入 token 数量 N=H×W,计算 QK^T 需要 O(N^2) 的显存。高分辨率图像(如 latent 的 128×128)会导致巨大中间张量。
分片 attention 的做法是:
将 attention 的输入 batch 或头部分成多个小块,分块计算 QK 和 softmax,再拼接。
这大大减少了某一时刻同时驻留显存中的 token 数。
4 XPU上支持SD的推理
4.1 支持方法
主要思路是自定义 XPUStableDiffusionPipeline;对于SD中3个主要模块如VAE、Unet等,使用C++在XPU后端开发模型级别的高效推理实现,并通过torch的机制来爆出出Python的接口,来复用 diffusers Pipeline的一些能力,同时在推理时调用到XPU的后端。
- 首先通过继承StableDiffusionPipeline 自定义了XPUStableDiffusionPipeline的类;
- 对于SD模型的三个主要模块 Encoder、 Unet 与 VAE ,实现XPU上优化后的C++模块级实现, 并通过pytorch的 torch::CustomClassHolder机制绑定Python 类级别的 python接口。
- 这三个主要模块的接口里,实现了 模型的初始化、权重加载、模型前向计算以及动态Lora等功能
- 另外,由于XPU的推理后端只支持 .bin 或者 .npz格式的权重,需要提前把 safetensor格式的权重转换成 .npz 格式的权重,以供 XPUUnet等模块加载。
- 在推理时,使用方式与使用 diffusers的StableDiffusionPipeline 接口基本保持一致。
- Unet等模块会调用XPU自定义的后端进行推理; scheduler调用的是官方的版本
4.2 推理加速手段
-
算子级别优化
- 半精度类型 / INT8 量化
- 关键算子的高效实现如(FlashAttention、FusedConv等)
- 若干稀碎小算子使用大的融合算子替代(Kernel Fusion)
-
调度器优化
-
使用更高效的shceduler(如DPM-Slover-Multistep),减少步数值20-30步
-
模型结构优化
-
蒸馏
-
使用参数量较小的学生模型代替原始的教师模型
-
Efficient Unet (如SnapFusion):
-
通过一定的指标来分析模块的重要性,从而移除不重要的模块,降低推理时间
-
LCM(潜在一致性模型)
-
LCM-Lora: Lora作用于Unet的Attention 和 Conv模块
-
动态Lora:根据 lora_key 读取 Lora的权重,先计算原始权重的结果 res1,在计算Lora的结果res2,最后 y = res1 + res2
-
prompt缓存机制
-
类似于 LLM 里的前缀缓存,维护一个hashtable, 键:输入的hash值, 值:图片的缓存(内存或者存储路径),若检测到的当前输入的hash值已在 hashtable里存在,在直接读取缓存的图片返回
- 感谢你赐予我前进的力量