Stem: Rethinking Causal Information Flow in Sparse Attention#

Stem 是 AngelSlim 的稀疏注意力算法,用于加速长上下文 LLM 的 Prefill 阶段。它通过在 block 粒度上估计注意力重要性,动态选择 top-k 关键块执行 block-sparse attention,在保持生成质量的同时大幅降低 Prefill 延迟。

1. 算法动机#

长上下文推理(如 32K–128K tokens)中,Prefill 阶段的全量 attention 计算是主要瓶颈:

  • 计算量随序列长度 二次增长,显存和延迟双重压力

  • 实际上大部分 attention block 对最终输出贡献极小,存在大量冗余

Stem 的核心思路是:先用低成本的 block-level scoring 估计每个 attention block 的重要性,再只对重要的 block 执行精确 attention

2. 技术原理#

Stem 的 Prefill 过程分为三步:

2.1 Block-Level Scoring#

使用 Triton 加速的 strided group GEMM 计算下采样的 Q·K^T 分数矩阵,并结合 value-norm bonus 项,得到每个 query-block 对每个 key-block 的重要性估计:

$$\text{score}(Q_i, K_j) = \frac{Q_i \cdot K_j^T}{\sqrt{d} \cdot s \cdot n} + \lambda \cdot \text{ReLU}(\bar{v}_j)$$

其中 $s$ 为 stride 因子,$n$ 为归一化系数,$\bar{v}_j$ 为 value-norm 的标准化对数值。

2.2 Top-k Schedule#

每层根据预设的 keep-ratio 和 alpha 衰减因子,生成 per-block 的 top-k budget:

  • 前 N 层(warmup):alpha=1.0,保留更多 block 以保证底层特征提取的完整性

  • 后续层:alpha=0.7,更激进的稀疏化以加速计算

  • 额外保证 initial blocks(sink tokens)和 sliding window blocks 始终被保留

2.3 Block-Sparse Attention#

根据 top-k mask 执行稀疏 attention:

  • 如果安装了 block-sparse-attn 库,使用真正的 block-sparse kernel

  • 否则自动 fallback 到 pseudo-sparse 实现(展开 mask 后做 dense attention)

  • HPC 后端支持 bf16 dense prefill 和 fp8 block-sparse prefill(varlen / paged 两种路径)

Decode 阶段不受影响,仍使用模型原始的 attention 实现(FlashAttention-2 / eager / SDPA)。

3. 支持范围#

维度

支持情况

后端

torch(纯 PyTorch + Triton)、hpc(HPC C++ 扩展)

HPC 精度

bf16(dense prefill)、fp8(block-sparse prefill,varlen / paged)

序列长度

无上限,建议 4K+ tokens 以体现加速效果

4. 性能评测#

我们在长上下文与 Agent 类任务上评测了 Stem 的精度保持能力。在 FP8-W8A8 + Stem 配置下,模型在 LongBench v2、CL-bench、CL-bench Life、SWE-bench Verified、Terminal-Bench 2.0、ClawEval 等多个 benchmark 上的得分与 BF16 基线基本持平,部分任务(如 ClawEval)甚至略有提升,验证了 Stem 稀疏注意力在大幅加速 Prefill 的同时几乎无损模型质量。

Stem 在多个 benchmark 上的精度对比(BF16 vs FP8-W8A8+Stem)。

5. 快速开始#

确保已安装 AngelSlim(pip install -e .uv sync),然后在项目根目录运行:

Dense 对照(无 Stem patch)#

python tools/run_stem.py \
    --mode dense \
    --model-path /path/to/Qwen3-8B \
    --prompt-file prompt.txt \
    --max-new-tokens 160

Stem + HPC bf16#

python tools/run_stem.py \
    --mode stem \
    --stem-backend hpc \
    --hpc-dtype bf16 \
    --model-path /path/to/Qwen3-8B \
    --prompt-file prompt.txt \
    --max-new-tokens 160

Stem + HPC fp8#

python tools/run_stem.py \
    --mode stem \
    --stem-backend hpc \
    --hpc-dtype fp8 \
    --model-path /path/to/Qwen3-8B \
    --prompt-file prompt.txt \
    --max-new-tokens 160

使用自定义 prompt#

python tools/run_stem.py \
    --mode stem \
    --stem-backend hpc \
    --hpc-dtype bf16 \
    --model-path /path/to/Qwen3-8B \
    --prompt-file my_long_document.txt \
    --max-new-tokens 256

也可以通过封装脚本启动:

bash scripts/sparsity/run_stem.sh /path/to/Qwen3-8B prompt.txt stem

6. 参数说明#

参数

默认值

说明

backend

"torch"

后端选择:"torch""hpc"

hpc_dtype

"bf16"

HPC 后端精度:"bf16""fp8"

hpc_fp8_path

"paged"

FP8 执行路径:"varlen""paged"

stem_alpha

1.0

per-layer alpha 衰减因子,可传 list 实现分层控制

block_size

128

attention block 大小

stride

8

scoring 阶段的下采样步长

chunk_size

2048

scoring 阶段的分块宽度

norm

1.0

scoring 阶段的额外归一化系数

initial_blocks

4

始终保留的头部 block 数量(sink tokens)

window_size

4

sliding window 保留的尾部 block 数量

7. 代码结构#

angelslim/compressor/sparsity/
├── __init__.py                          # re-export StemInference
└── stem/
    ├── __init__.py                      # 包入口
    ├── stem.py                          # StemInference 类(主入口)
    ├── patch.py                         # 模型 patch 逻辑
    ├── stem_configuration.py            # StemConfig 配置
    ├── backends/
    │   ├── dispatcher.py                # torch / hpc 路由
    │   ├── torch_impl.py               # PyTorch + Triton 实现
    │   └── hpc_impl.py                 # HPC C++ 扩展实现
    ├── modules/
    │   └── forward.py                   # patched attention forward
    └── ops/
        └── stem_kernel.py               # Triton kernel

tools/run_stem.py                            # 推理入口
scripts/sparsity/run_stem.sh             # 启动脚本

8. Python API#

from angelslim.compressor.sparsity import StemInference

stem = StemInference(attn_kwargs={
    "backend": "hpc",
    "hpc_dtype": "fp8",
    "stem_alpha": [1.0] * 5 + [0.7] * 31,  # 36 层 Qwen3-8B
})
model = stem(model)  # 返回 patched 后的同一个 model 对象