speculative-decoding by davila7/claude-code-templates
npx skills add https://github.com/davila7/claude-code-templates --skill speculative-decoding
Use Speculative Decoding when you need to cut LLM generation latency (faster per-token decoding) without changing the model's output quality.
Key Techniques : Draft model speculative decoding, Medusa (multiple heads), Lookahead Decoding (Jacobi iteration)
Papers : Medusa (arXiv 2401.10774), Lookahead Decoding (ICML 2024), Speculative Decoding Survey (ACL 2024)
# Standard speculative decoding (transformers)
pip install transformers accelerate
# Medusa (multiple decoding heads)
git clone https://github.com/FasterDecoding/Medusa
cd Medusa
pip install -e .
# Lookahead Decoding
git clone https://github.com/hao-ai-lab/LookaheadDecoding
cd LookaheadDecoding
pip install -e .
# Optional: vLLM with speculative decoding
pip install vllm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load target model (large, slow)
target_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-70b-hf",
device_map="auto",
torch_dtype=torch.float16
)
# Load draft model (small, fast)
draft_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
device_map="auto",
torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf")
# Generate with speculative decoding
prompt = "Explain quantum computing in simple terms:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
# Transformers 4.36+ supports assisted generation
outputs = target_model.generate(
**inputs,
assistant_model=draft_model, # Enable speculative decoding
max_new_tokens=256,
do_sample=True,
temperature=0.7,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
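To check the gain on your own hardware, the two models loaded above can be timed with and without the assistant model. This is only a rough measurement sketch (greedy decoding for repeatability; the `timed_generate` helper is not part of the skill):

```python
import time
import torch

def timed_generate(model, model_inputs, **kwargs):
    """Return (text, seconds) for one generate() call -- illustrative helper only."""
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = model.generate(**model_inputs, max_new_tokens=256, do_sample=False, **kwargs)
    torch.cuda.synchronize()
    return tokenizer.decode(out[0], skip_special_tokens=True), time.perf_counter() - start

_, t_plain = timed_generate(target_model, inputs)                               # baseline
_, t_spec = timed_generate(target_model, inputs, assistant_model=draft_model)   # speculative
print(f"plain: {t_plain:.1f}s  assisted: {t_spec:.1f}s  speedup: {t_plain / t_spec:.2f}x")
```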
import torch
from transformers import AutoTokenizer
from medusa.model.medusa_model import MedusaModel
# Load Medusa-enhanced model
model = MedusaModel.from_pretrained(
"FasterDecoding/medusa-vicuna-7b-v1.3", # Pre-trained with Medusa heads
torch_dtype=torch.float16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("FasterDecoding/medusa-vicuna-7b-v1.3")
# Generate with Medusa (2-3× speedup)
prompt = "Write a Python function to calculate fibonacci numbers:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.medusa_generate(
**inputs,
max_new_tokens=256,
temperature=0.7,
posterior_threshold=0.09, # Acceptance threshold
posterior_alpha=0.3, # Tree construction parameter
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from lookahead.lookahead_decoding import LookaheadDecoding  # module path may differ by repo version
# Load model
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# Initialize lookahead decoding
lookahead = LookaheadDecoding(
model=model,
tokenizer=tokenizer,
window_size=15, # Lookahead window (W)
ngram_size=5, # N-gram size (N)
guess_size=5 # Number of parallel guesses
)
# Generate (1.5-2.3× speedup)
prompt = "Implement quicksort in Python:"
output = lookahead.generate(prompt, max_new_tokens=256)
print(output)
Idea : Use small draft model to generate candidates, large target model to verify in parallel.
Algorithm :
import torch
import torch.nn.functional as F

def speculative_decode(target_model, draft_model, input_ids, K=4):
    """One round of speculative decoding (simplified sketch)."""
    # 1. Draft model proposes K tokens; keep its per-step distributions
    draft = draft_model.generate(
        input_ids, max_new_tokens=K, do_sample=True,
        output_scores=True, return_dict_in_generate=True,
    )
    draft_tokens = draft.sequences[0, input_ids.shape[1]:]        # the proposed tokens
    p_draft = [F.softmax(s[0], dim=-1) for s in draft.scores]     # draft distribution per step
    # 2. Target model scores the prompt plus all drafts in ONE forward pass
    target_logits = target_model(draft.sequences).logits[0]
    p_target = F.softmax(target_logits[input_ids.shape[1] - 1:-1], dim=-1)
    # 3. Accept draft token i with probability min(1, p_target / p_draft)
    accepted = []
    for i in range(len(draft_tokens)):
        tok = draft_tokens[i]
        ratio = p_target[i, tok] / p_draft[i][tok]
        if torch.rand(()) < torch.clamp(ratio, max=1.0):
            accepted.append(tok.item())
        else:
            break  # rejected: resample this position from the target distribution instead
    return accepted
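A one-round driver for the sketch above might look like this (hypothetical prompt; it reuses the target model, draft model, and tokenizer loaded in the first example):

```python
input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids.to("cuda")
accepted = speculative_decode(target_model, draft_model, input_ids, K=4)
print("accepted draft tokens:", tokenizer.decode(accepted))
```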
Performance : typically 1.5-2× wall-clock speedup over standard decoding with no quality loss (see the comparison table below), since every accepted token is one the target model would have produced anyway.
Source : arXiv 2401.10774 (2024)
Innovation : Add multiple prediction heads to existing model, predict future tokens without separate draft model.
Architecture :
Input → Base LLM (frozen) → Hidden State
├→ Head 1 (predicts token t+1)
├→ Head 2 (predicts token t+2)
├→ Head 3 (predicts token t+3)
└→ Head 4 (predicts token t+4)
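As a rough sketch of the layout in the diagram (hypothetical sizes; not the classes used by the Medusa repo), each head is a plain linear projection of the same final hidden state, so the extra guesses cost one matrix multiply each:

```python
import torch
import torch.nn as nn

hidden_size, vocab_size, num_heads = 4096, 32000, 4   # illustrative Llama-7B-like sizes
heads = nn.ModuleList(
    nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(num_heads)
)

# One hidden state from the frozen base LLM at the current position...
h_t = torch.randn(1, hidden_size)
# ...yields parallel guesses for tokens t+1 .. t+4, one per head
candidates = [head(h_t).argmax(dim=-1).item() for head in heads]
print(candidates)
```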
Training : Medusa-1 freezes the base model and trains only the new heads, each with a cross-entropy loss against the token (i+1) positions ahead (see the training-loop example further below).
Tree-based Attention :
# Medusa constructs tree of candidates
# Example: Predict 2 steps ahead with top-2 per step
# Root
# / \
# T1a T1b (Step 1: 2 candidates)
# / \ / \
# T2a T2b T2c T2d (Step 2: 4 candidates total)
# Single forward pass evaluates entire tree!
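How one forward pass can score the whole tree: flatten the candidates into a single sequence and apply a tree attention mask so each node attends only to itself and its ancestors. A minimal sketch for the two-level, top-2 tree drawn above (illustrative; the Medusa code uses a larger, precomputed tree):

```python
import torch

# Depth-first flattening of the tree above: [T1a, T1b, T2a, T2b, T2c, T2d]
# parent[i] is the index of node i's parent; -1 means the root (current token).
parent = [-1, -1, 0, 0, 1, 1]
n = len(parent)

# Candidate i may attend to itself and its ancestors only, so sibling branches
# never see each other while all six nodes are scored in a single pass.
mask = torch.zeros(n, n, dtype=torch.bool)
for i in range(n):
    j = i
    while j != -1:
        mask[i, j] = True
        j = parent[j]
print(mask.int())
```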
Advantages : no separate draft model to serve, only the lightweight heads need training, and it reaches 2-3.6× speedup with no quality loss.
Source : ICML 2024
Core idea : Reformulate autoregressive decoding as solving a system of equations, which Jacobi iteration can solve in parallel.
Mathematical formulation :
Traditional: y_t = f(x, y_1, ..., y_{t-1}) (sequential)
Jacobi: y_t^{(k+1)} = f(x, y_1^{(k)}, ..., y_{t-1}^{(k)}) (parallel)
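To make the Jacobi view concrete, here is a toy greedy fixed-point loop (a small GPT-2 model is used only to keep the sketch runnable; the real Lookahead Decoding adds the n-gram pool and verification branch on top of this):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
tok = AutoTokenizer.from_pretrained("gpt2")

prompt_ids = tok("The quick brown fox", return_tensors="pt").input_ids
W = 8
guess = torch.full((1, W), tok.eos_token_id)      # arbitrary initial guesses for the window

with torch.no_grad():
    for _ in range(W):                            # converges in at most W iterations
        seq = torch.cat([prompt_ids, guess], dim=-1)
        logits = model(seq).logits
        # Every position is refined in parallel, conditioned on last iteration's guesses
        new_guess = logits[:, prompt_ids.shape[1] - 1:-1].argmax(dim=-1)
        if torch.equal(new_guess, guess):         # fixed point = greedy autoregressive output
            break
        guess = new_guess

print(tok.decode(guess[0]))
```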
Two branches :
Lookahead Branch : Generate n-grams in parallel
Verification Branch : Verify promising n-grams
# Simplified sketch of the two-branch loop; generate_ngram, verify, and
# generate_next stand in for the n-gram pool and verification logic.
class LookaheadDecoding:
def __init__(self, model, window_size=15, ngram_size=5):
self.model = model
self.W = window_size # Lookahead window
self.N = ngram_size # N-gram size
def generate_step(self, tokens):
# Lookahead branch: Generate W × N candidates
candidates = {}
for w in range(1, self.W + 1):
for n in range(1, self.N + 1):
# Generate n-gram starting at position w
ngram = self.generate_ngram(tokens, start=w, length=n)
candidates[(w, n)] = ngram
# Verification branch: Find matching n-grams
verified = []
for ngram in candidates.values():
if ngram[0] == tokens[-1]: # First token matches last input
if self.verify(tokens, ngram):
verified.append(ngram)
# Accept longest verified n-gram
return max(verified, key=len) if verified else [self.model.generate_next(tokens)]
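One possible way to flesh out the verify step (a sketch only: it runs one forward pass per candidate for readability, whereas the real implementation batches every candidate behind a single attention mask):

```python
import torch

def verify_candidates(model, tokens, ngram_pool):
    """Accept the longest cached n-gram whose tokens all match the model's greedy picks."""
    last = tokens[0, -1].item()
    best = []
    for cand in ngram_pool.get(last, []):          # candidates keyed by the last confirmed token
        seq = torch.cat([tokens, torch.tensor([cand], dtype=torch.long)], dim=-1)
        with torch.no_grad():
            preds = model(seq).logits[0, tokens.shape[1] - 1:-1].argmax(dim=-1)
        matched = []
        for tok_id, pred in zip(cand, preds.tolist()):
            if tok_id != pred:
                break
            matched.append(tok_id)
        if len(matched) > len(best):
            best = matched
    return best                                    # [] means fall back to normal decoding
```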
Performance : 1.5-2.3× speedup with no training and no draft model (see the comparison table below).
| Method | Speedup | Training Needed | Draft Model | Quality Loss |
|---|---|---|---|---|
| Draft Model Speculative | 1.5-2× | No | Yes (external) | None |
| Medusa | 2-3.6× | Minimal (heads only) | No (built-in heads) | None |
| Lookahead | 1.5-2.3× | None | No | None |
| Naive Batching | 1.2-1.5× | No | No | None |
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from medusa.model.medusa_model import MedusaModel
from medusa.model.kv_cache import initialize_past_key_values
# 1. Load base model
base_model = AutoModelForCausalLM.from_pretrained(
"lmsys/vicuna-7b-v1.3",
torch_dtype=torch.float16
)
# 2. Add Medusa heads
num_heads = 4
medusa_heads = nn.ModuleList([
nn.Linear(base_model.config.hidden_size, base_model.config.vocab_size, bias=False)
for _ in range(num_heads)
])
# 3. Training loop (freeze base model for Medusa-1)
for param in base_model.parameters():
param.requires_grad = False # Freeze base
optimizer = torch.optim.Adam(medusa_heads.parameters(), lr=1e-3)
for batch in dataloader:
# Forward pass
hidden_states = base_model(**batch, output_hidden_states=True).hidden_states[-1]
# Predict future tokens with each head
loss = 0
for i, head in enumerate(medusa_heads):
logits = head(hidden_states)
# Target: tokens shifted by (i+1) positions
target = batch['input_ids'][:, i+1:]
        loss += F.cross_entropy(logits[:, :-(i + 1)].reshape(-1, logits.size(-1)), target.reshape(-1))
# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Use Medusa as draft model for speculative decoding
draft_medusa = MedusaModel.from_pretrained("medusa-vicuna-7b")
target_model = AutoModelForCausalLM.from_pretrained("vicuna-33b")
# Standalone: the draft can propose candidates directly with its Medusa heads
draft_tokens = draft_medusa.medusa_generate(prompt, max_new_tokens=5)
# In practice, assisted generation drives the drafting internally and the target
# verifies the candidates in a single forward pass
inputs = tokenizer(prompt, return_tensors="pt").to(target_model.device)
outputs = target_model.generate(
    **inputs,
    assistant_model=draft_medusa,  # Use Medusa as draft
    max_new_tokens=256
)
# Combines benefits: Medusa speed + large model quality
def select_draft_model(target_model_size):
"""Select optimal draft model for speculative decoding."""
# Rule: Draft should be 5-10× smaller
if target_model_size == "70B":
return "7B" # 10× smaller
elif target_model_size == "33B":
return "7B" # 5× smaller
elif target_model_size == "13B":
return "1B" # 13× smaller
else:
return None # Target too small, use Medusa/Lookahead instead
# Example
draft = select_draft_model("70B")
# Returns "7B" → Use Llama-2-7b as draft for Llama-2-70b
# New deployment → Medusa (best overall speedup, no draft model)
if deploying_new_model:
use_method = "Medusa"
# Existing deployment with small model available → Draft speculative
elif have_small_version_of_model:
use_method = "Draft Model Speculative"
# Want zero training/setup → Lookahead
elif want_plug_and_play:
use_method = "Lookahead Decoding"
Draft Model Speculative :
# K = number of speculative tokens
K = 4 # Good default
K = 2 # Conservative (higher acceptance)
K = 8 # Aggressive (lower acceptance, but more when accepted)
# Rule: Larger K → more speedup IF draft model is good
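A back-of-the-envelope way to reason about K: if each draft token is accepted independently with probability alpha, the expected number of tokens emitted per target forward pass is (1 - alpha^(K+1)) / (1 - alpha), the standard result from the speculative sampling analysis. The numbers below are purely illustrative:

```python
def expected_tokens_per_pass(alpha: float, K: int) -> float:
    """Expected tokens per target-model forward pass under i.i.d. acceptance (idealized)."""
    return (1 - alpha ** (K + 1)) / (1 - alpha)

for K in (2, 4, 8):
    for alpha in (0.6, 0.8, 0.9):
        print(f"K={K} alpha={alpha}: {expected_tokens_per_pass(alpha, K):.2f} tokens/pass")
```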
Medusa :
# Posterior threshold (acceptance confidence)
posterior_threshold = 0.09 # Standard (from paper)
posterior_threshold = 0.05 # More conservative (slower, higher quality)
posterior_threshold = 0.15 # More aggressive (faster, may degrade quality)
# Tree depth (how many steps ahead)
medusa_choices = [[0], [0, 0], [0, 1], [0, 0, 0]] # Depth 3 (standard)
Lookahead :
# Window size W (lookahead distance)
# N-gram size N (context for generation)
# 7B model (more resources)
W, N = 15, 5
# 13B model (moderate)
W, N = 10, 5
# 33B+ model (limited resources)
W, N = 7, 5
# vLLM with speculative decoding
from vllm import LLM, SamplingParams
# Initialize with draft model
llm = LLM(
model="meta-llama/Llama-2-70b-hf",
speculative_model="meta-llama/Llama-2-7b-hf", # Draft model
num_speculative_tokens=5,
use_v2_block_manager=True,
)
# Generate
prompts = ["Tell me about AI:", "Explain quantum physics:"]
sampling_params = SamplingParams(temperature=0.7, max_tokens=256)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
print(output.outputs[0].text)
references/draft_model.md - Draft model selection and training
references/medusa.md - Medusa architecture and training
references/lookahead.md - Lookahead decoding implementation details