optimizing-attention-flash by davila7/claude-code-templates
npx skills add https://github.com/davila7/claude-code-templates --skill optimizing-attention-flash
Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.
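The tiling idea is easiest to see in code. The sketch below is pure NumPy and illustrative only (the real kernels fuse this loop into on-chip SRAM tiles): it processes K/V in blocks while maintaining a running max and normalizer per query row, so the full [seq, seq] score matrix is never materialized.

```python
import numpy as np

def naive_attention(q, k, v):
    """Reference attention: materializes the full [seq, seq] score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def tiled_attention(q, k, v, block=32):
    """Flash-style attention: iterate over K/V blocks with an online softmax.
    Only one [seq, block] score tile exists at a time."""
    seq, dim = q.shape
    scale = 1.0 / np.sqrt(dim)
    out = np.zeros_like(q)
    row_max = np.full(seq, -np.inf)   # running max per query row
    row_sum = np.zeros(seq)           # running softmax denominator
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = q @ kb.T * scale                          # score tile only
        new_max = np.maximum(row_max, s.max(axis=-1))
        correction = np.exp(row_max - new_max)        # rescale old accumulators
        p = np.exp(s - new_max[:, None])
        out = out * correction[:, None] + p @ vb
        row_sum = row_sum * correction + p.sum(axis=-1)
        row_max = new_max
    return out / row_sum[:, None]

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((128, 64)) for _ in range(3))
print(np.allclose(tiled_attention(q, k, v), naive_attention(q, k, v)))
```

Each K/V block is loaded once, used, and discarded; the running statistics make the blockwise softmax exactly equal to the full softmax.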
PyTorch native (easiest, PyTorch 2.2+):
import torch
import torch.nn.functional as F
q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
# Automatically uses Flash Attention if available
out = F.scaled_dot_product_attention(q, k, v)
flash-attn library (more features):
pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func
# q, k, v: [batch, seqlen, nheads, headdim]
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
Copy this checklist:
Flash Attention Integration:
- [ ] Step 1: Check PyTorch version (≥2.2)
- [ ] Step 2: Enable Flash Attention backend
- [ ] Step 3: Verify speedup with profiling
- [ ] Step 4: Test accuracy matches baseline
Step 1: Check PyTorch version
python -c "import torch; print(torch.__version__)"
# Should be ≥2.2.0
If <2.2, upgrade:
pip install --upgrade torch
Step 2: Enable Flash Attention backend
Replace standard attention:
# Before (standard attention); assumes `import math` and d_k = head_dim
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
out = attn_weights @ v
# After (Flash Attention)
import torch.nn.functional as F
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
Force Flash Attention backend:
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)
Step 3: Verify speedup with profiling
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
def test_attention(use_flash):
    q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
    if use_flash:
        # Disable the other backends too; all three default to enabled,
        # so enable_flash=True alone does not force the flash kernel
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            return F.scaled_dot_product_attention(q, k, v)
    else:
        attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)  # 8.0 = sqrt(head_dim=64)
        return attn @ v
# Benchmark
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())
print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
Expected: 2-4x speedup for sequences >512 tokens.
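The 10-20x memory reduction quoted above follows from what standard attention allocates and Flash Attention never does: the [batch, heads, seq, seq] score matrix. A back-of-envelope sketch (shapes are illustrative):

```python
# Bytes for the score matrix that standard attention materializes
# (and Flash Attention never does), assuming float16 (2 bytes/element).
def score_matrix_bytes(batch, heads, seq, bytes_per_el=2):
    return batch * heads * seq * seq * bytes_per_el

# Bytes for the q, k, v inputs themselves (the part both methods must hold).
def qkv_bytes(batch, heads, seq, dim, bytes_per_el=2):
    return 3 * batch * heads * seq * dim * bytes_per_el

for seq in (512, 2048, 8192):
    scores = score_matrix_bytes(2, 8, seq)
    qkv = qkv_bytes(2, 8, seq, 64)
    print(f"seq={seq:5d}  scores={scores/2**20:8.1f} MiB  "
          f"q+k+v={qkv/2**20:6.1f} MiB  ratio={scores/qkv:5.1f}x")
```

The score matrix grows quadratically while q/k/v grow linearly, so the savings ratio scales with sequence length; at a few thousand tokens it lands in the 10-20x range.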
Step 4: Test accuracy matches baseline
# Compare outputs
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
# Flash Attention
out_flash = F.scaled_dot_product_attention(q, k, v)
# Standard attention
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)  # 8.0 = sqrt(head_dim=64)
out_standard = attn_weights @ v
# Check difference
diff = (out_flash - out_standard).abs().max()
print(f"Max difference: {diff:.6f}")
# Should be <1e-3 for float16
Use the flash-attn library when you need multi-query attention, sliding windows, or H100 FP8.
Copy this checklist:
flash-attn Library Setup:
- [ ] Step 1: Install flash-attn library
- [ ] Step 2: Modify attention code
- [ ] Step 3: Enable advanced features
- [ ] Step 4: Benchmark performance
Step 1: Install flash-attn library
# NVIDIA GPUs (CUDA 12.0+)
pip install flash-attn --no-build-isolation
# Verify installation
python -c "from flash_attn import flash_attn_func; print('Success')"
Step 2: Modify attention code
from flash_attn import flash_attn_func
# Input: [batch_size, seq_len, num_heads, head_dim]
# Transpose from [batch, heads, seq, dim] if needed
q = q.transpose(1, 2) # [batch, seq, heads, dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = flash_attn_func(
    q, k, v,
    dropout_p=0.1,
    causal=True,            # For autoregressive models
    window_size=(-1, -1),   # No sliding window
    softmax_scale=None      # Defaults to 1/sqrt(head_dim)
)
out = out.transpose(1, 2) # Back to [batch, heads, seq, dim]
Step 3: Enable advanced features
Multi-query attention (shared K/V across heads):
from flash_attn import flash_attn_func
# q: [batch, seq, num_q_heads, dim]
# k, v: [batch, seq, num_kv_heads, dim] # Fewer KV heads
out = flash_attn_func(q, k, v) # Automatically handles MQA
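What "automatically handles MQA" means: each group of query heads shares one KV head, which is mathematically equivalent to repeating the KV heads to match the query head count. A NumPy sketch of that equivalence (the real kernel does the grouping inside the kernel without the copy; the consecutive-grouping convention here is our assumption):

```python
import numpy as np

def gqa_attention(q, k, v):
    """Grouped-query attention via explicit KV-head repetition.
    q: [seq, num_q_heads, dim]; k, v: [seq, num_kv_heads, dim]."""
    n_q, n_kv = q.shape[1], k.shape[1]
    assert n_q % n_kv == 0, "query heads must be a multiple of KV heads"
    rep = n_q // n_kv
    k = np.repeat(k, rep, axis=1)   # broadcast KV heads to [seq, num_q_heads, dim]
    v = np.repeat(v, rep, axis=1)
    scale = 1.0 / np.sqrt(q.shape[-1])
    out = np.empty_like(q)
    for h in range(n_q):            # per-head standard attention
        s = q[:, h] @ k[:, h].T * scale
        w = np.exp(s - s.max(axis=-1, keepdims=True))
        w /= w.sum(axis=-1, keepdims=True)
        out[:, h] = w @ v[:, h]
    return out

rng = np.random.default_rng(1)
q = rng.standard_normal((128, 32, 64))   # 32 query heads
k = rng.standard_normal((128, 8, 64))    # 8 shared KV heads (4 queries per group)
v = rng.standard_normal((128, 8, 64))
print(gqa_attention(q, k, v).shape)      # (128, 32, 64)
```

The payoff is the KV cache: with 8 KV heads instead of 32, cached K/V shrink 4x while output shape is unchanged.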
Sliding window attention (local attention):
# Only attend to window of 256 tokens before/after
out = flash_attn_func(
    q, k, v,
    window_size=(256, 256),  # (left, right) window
    causal=True
)
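To see what window_size restricts, here is a sketch of the attention mask it implies (our reading of the (left, right) convention, where -1 means unbounded on that side):

```python
import numpy as np

def sliding_window_mask(seq_len, left, right, causal=True):
    """Boolean mask: True where query i may attend to key j.
    Mirrors flash-attn's window_size=(left, right); -1 = unbounded."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    lo = np.inf if left == -1 else left
    hi = np.inf if right == -1 else right
    mask = (j >= i - lo) & (j <= i + hi)
    if causal:
        mask &= (j <= i)
    return mask

m = sliding_window_mask(1024, 256, 256, causal=True)
# Each row attends to at most window+1 keys instead of all 1024
print(m.sum(axis=1).max())  # 257: self + 256 previous tokens
```

Cost per query drops from O(seq) to O(window), so total attention work goes from quadratic to linear in sequence length.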
Step 4: Benchmark performance
import torch
from flash_attn import flash_attn_func
import time
q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
# Warmup
for _ in range(10):
    _ = flash_attn_func(q, k, v)
# Benchmark
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    out = flash_attn_func(q, k, v)
torch.cuda.synchronize()
end = time.time()
print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
Use FP8 for maximum performance on H100 GPUs.
Copy this checklist:
FP8 Setup:
- [ ] Step 1: Verify H100 GPU available
- [ ] Step 2: Install flash-attn with FP8 support
- [ ] Step 3: Convert inputs to FP8
- [ ] Step 4: Run with FP8 attention
Step 1: Verify H100 GPU
nvidia-smi --query-gpu=name --format=csv
# Should show "H100" or "H800"
Step 2: Install flash-attn with FP8 support
pip install flash-attn --no-build-isolation
# Note: the PyPI flash-attn package is FlashAttention-2; FP8 kernels ship with FlashAttention-3 (see Step 4)
Step 3: Convert inputs to FP8
import torch
q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
# Convert to float8_e4m3 (FP8)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
Step 4: Run with FP8 attention
# FP8 kernels are part of FlashAttention-3, which targets Hopper (H100) GPUs.
# As of this writing, FlashAttention-3 is built from the repository's hopper/
# directory and imported as flash_attn_interface rather than flash_attn.
from flash_attn_interface import flash_attn_func
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
# Reported result: ~1.2 PFLOPS, 1.5-2x faster than FP16
Use Flash Attention when:
- Sequences are longer than ~512 tokens
- You can run in float16 or bfloat16
- You have an NVIDIA GPU with compute capability ≥(7, 5) (Turing or newer)
Use alternatives instead when:
- You need float32 attention
- You are on V100 (Volta) or CPU
- Sequences are short enough that standard attention is already fast
Issue: ImportError: cannot import flash_attn
Install with no-build-isolation flag:
pip install flash-attn --no-build-isolation
Or install CUDA toolkit first:
conda install cuda -c nvidia
pip install flash-attn --no-build-isolation
Issue: Slower than expected (no speedup)
Flash Attention benefits increase with sequence length:
- 2K tokens: 3-4x speedup
Check sequence length is sufficient.
Issue: RuntimeError: CUDA error
Verify GPU supports Flash Attention:
import torch
print(torch.cuda.get_device_capability())
# Should be ≥(7, 5) for Turing+
Flash Attention requires:
- NVIDIA GPU with compute capability ≥(7, 5) (Turing, Ampere, Ada, or Hopper)
- float16 or bfloat16 inputs
- CUDA (no CPU support)
Issue: Accuracy degradation
Check dtype is float16 or bfloat16 (not float32):
q = q.to(torch.float16) # Or torch.bfloat16
Flash Attention uses float16/bfloat16 for speed. Float32 not supported.
Integration with HuggingFace Transformers: See references/transformers-integration.md for enabling Flash Attention in BERT, GPT, and Llama models.
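For reference, HuggingFace Transformers exposes Flash Attention through the attn_implementation argument to from_pretrained. A sketch (the model name is illustrative, and loading requires downloading weights, so this is configuration rather than a runnable test):

```python
import torch
from transformers import AutoModelForCausalLM

# Any architecture with Flash Attention 2 support in transformers
# (Llama, Mistral, GPT-NeoX, ...) works the same way.
model_id = "meta-llama/Llama-3.1-8B"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,              # FA2 requires fp16/bf16
    attn_implementation="flash_attention_2",
    device_map="auto",
)
# Or use PyTorch's built-in kernels without installing flash-attn:
# attn_implementation="sdpa"
```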
Performance benchmarks: See references/benchmarks.md for detailed speed and memory comparisons across GPUs and sequence lengths.
Algorithm details: See references/algorithm.md for tiling strategy, recomputation, and IO complexity analysis.
Advanced features: See references/advanced-features.md for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.
Not supported: V100 (Volta), CPU inference
Weekly Installs: 169
Repository
GitHub Stars: 23.4K
First Seen: Jan 21, 2026
Security Audits: Gen Agent Trust Hub Pass · Socket Pass · Snyk Pass
Installed on:
- claude-code: 140
- opencode: 137
- gemini-cli: 129
- cursor: 127
- codex: 118
- antigravity: 113