model-pruning by davila7/claude-code-templates
npx skills add https://github.com/davila7/claude-code-templates --skill model-pruning
当您需要以下情况时,请使用模型剪枝:
关键技术:Wanda(权重 × 激活值)、SparseGPT(二阶)、结构化剪枝、N:M 稀疏性
论文:Wanda ICLR 2024 (arXiv 2306.11695)、SparseGPT (arXiv 2301.00774)
# Wanda implementation
git clone https://github.com/locuslab/wanda
cd wanda
pip install -r requirements.txt
# Return to the parent directory so the next clone lands beside wanda/, not inside it
cd ..
# Optional: SparseGPT
git clone https://github.com/IST-DASLab/sparsegpt
cd sparsegpt
# NOTE(review): `pip install -e .` needs a setup.py/pyproject.toml in the repo root - confirm upstream ships one
pip install -e .
cd ..
# Dependencies
pip install torch transformers accelerate
来源:ICLR 2024 (arXiv 2306.11695)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# One checkpoint id shared by the model and its tokenizer
MODEL_ID = "meta-llama/Llama-2-7b-hf"
# Load the model in float16 with device_map="cuda"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Calibration data: a small set of texts used only to gather activation statistics
calib_data = [
    "The quick brown fox jumps over the lazy dog.",
    "Machine learning is transforming the world.",
    "Artificial intelligence powers modern applications.",
]
# Wanda 剪枝函数
def wanda_prune(model, calib_data, sparsity=0.5, tokenizer=None):
    """
    Wanda: prune each linear layer by |weight| x input-activation norm.

    For every ``torch.nn.Linear`` in ``model`` the squared inputs are summed
    per input feature over all calibration tokens. Each weight is then scored
    as ``|W[i, j]| * norm[j]`` and, within every output row, the
    lowest-scoring ``sparsity`` fraction of weights is zeroed in place.

    Args:
        model: Module to prune; it is called as ``model(**inputs)`` on the
            tokenizer output.
        calib_data: Iterable of calibration texts. Items may be strings or
            dicts with a ``"text"`` key; blank texts are skipped.
        sparsity: Fraction of weights to prune in each output row
            (0.5 = 50%).
        tokenizer: Callable used as ``tokenizer(text, return_tensors="pt")``.
            When omitted, the module-level ``tokenizer`` is used.

    Returns:
        The same ``model`` object, pruned in place.

    Raises:
        ValueError: If ``sparsity`` is outside [0, 1] or no tokenizer is
            available.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")
    if tokenizer is None:
        # Backward compatible: earlier callers relied on a global tokenizer.
        tokenizer = globals().get("tokenizer")
    if tokenizer is None:
        raise ValueError("no tokenizer given and no module-level `tokenizer` is defined")

    # Use `model.device` when the model exposes it, else the first parameter's device.
    device = getattr(model, "device", None)
    if device is None:
        device = next(model.parameters()).device

    # 1. Collect activation statistics: running sum of squared inputs per input feature.
    squared_sums = {}

    def hook_fn(name):
        def hook(module, inputs, output):
            # Collapse all leading dims so every row is one token's feature vector.
            tokens = inputs[0].detach().reshape(-1, inputs[0].shape[-1]).float()
            total = tokens.pow(2).sum(dim=0)
            # Accumulate instead of overwrite so every calibration sample counts.
            squared_sums[name] = squared_sums[name] + total if name in squared_sums else total
        return hook

    # Register hooks for all linear layers
    hooks = [
        module.register_forward_hook(hook_fn(name))
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
    ]

    # Run calibration data; the hooks are removed even if a forward pass raises.
    model.eval()
    try:
        with torch.no_grad():
            for item in calib_data:
                text = item["text"] if isinstance(item, dict) else item
                if not text.strip():
                    continue
                inputs = tokenizer(text, return_tensors="pt").to(device)
                model(**inputs)
    finally:
        for handle in hooks:
            handle.remove()

    # 2. Prune weights based on |weight| x activation norm
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear) or name not in squared_sums:
            continue
        W = module.weight.data
        n_prune = int(W.shape[1] * sparsity)
        if n_prune == 0:
            continue
        act_norm = squared_sums[name].sqrt().to(W.device)
        # Score in float32 regardless of the weight dtype.
        importance = W.abs().float() * act_norm.unsqueeze(0)
        # Rank within each output row and drop the n_prune lowest scores;
        # sorting per row avoids torch.quantile on a whole (possibly fp16) matrix.
        prune_idx = torch.argsort(importance, dim=1)[:, :n_prune]
        mask = torch.ones_like(importance, dtype=torch.bool)
        mask.scatter_(1, prune_idx, False)
        # Apply mask (prune) in place
        W *= mask.to(W.dtype)
    return model
# Apply Wanda pruning (50% sparsity, one-shot, no retraining)
pruned_model = wanda_prune(model, calib_data, sparsity=0.5)
# Save the pruned weights and the tokenizer into the same directory,
# matching what the production pipeline does with its output directory
pruned_model.save_pretrained("./llama-2-7b-wanda-50")
tokenizer.save_pretrained("./llama-2-7b-wanda-50")
广告位招租
在这里展示您的产品或服务
触达数万 AI 开发者,精准高效
来源:arXiv 2301.00774
from sparsegpt import SparseGPT
# Load model (no dtype or device_map is passed here, unlike the Wanda example)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# Initialize SparseGPT
# NOTE(review): the upstream IST-DASLab/sparsegpt code appears to expose a per-layer
# SparseGPT(layer) with add_batch()/fasterprune(); confirm that a model-level
# SparseGPT(model).prune(...) wrapper exists in the installed package
pruner = SparseGPT(model)
# Calibration data
# NOTE(review): load_calibration_data is not defined in this snippet - it must be supplied by the caller
calib_data = load_calibration_data() # ~128 samples
# Prune (one-shot, layer-wise reconstruction)
pruned_model = pruner.prune(
calib_data=calib_data,
sparsity=0.5, # 50% sparsity
prunen=0, # Unstructured (0) or N:M structured
prunem=0,
percdamp=0.01, # Damping for Hessian inverse
)
# Results: Near-lossless pruning at 50% sparsity
def nm_prune(weight, n=2, m=4):
"""
N:M 剪枝:每 M 个连续权重中保留 N 个权重。
示例:2:4 = 每 4 个权重中保留 2 个。
兼容 NVIDIA 稀疏张量核心(2:4, 4:8)。
"""
# 将权重重塑为 M 个一组
shape = weight.shape
weight_flat = weight.flatten()
# 填充到 M 的倍数
pad_size = (m - weight_flat.numel() % m) % m
weight_padded = F.pad(weight_flat, (0, pad_size))
# 重塑为 (num_groups, m)
weight_grouped = weight_padded.reshape(-1, m)
# 找到每组中的前 N 个
_, indices = torch.topk(weight_grouped.abs(), n, dim=-1)
# 创建掩码
mask = torch.zeros_like(weight_grouped)
mask.scatter_(1, indices, 1.0)
# 应用掩码
weight_pruned = weight_grouped * mask
# 重塑回原状
weight_pruned = weight_pruned.flatten()[:weight_flat.numel()]
return weight_pruned.reshape(shape)
# Enforce the 2:4 pattern (NVIDIA hardware) on every linear layer of the model
linear_layers = [layer for layer in model.modules() if isinstance(layer, torch.nn.Linear)]
for layer in linear_layers:
    layer.weight.data = nm_prune(layer.weight.data, n=2, m=4)
# 50% sparsity, 2× speedup on A100 with sparse tensor cores
幅度剪枝(基线):
# 剪枝绝对值最小的权重
importance = weight.abs()
threshold = torch.quantile(importance, sparsity)
mask = importance >= threshold
Wanda(权重 × 激活值):
# 重要性 = |权重| × 输入激活
importance = weight.abs() * activation
# 比单独使用幅度更好(考虑了使用情况)
SparseGPT(二阶):
# 使用 Hessian(二阶导数)计算重要性
# 更准确但计算成本更高
importance = weight**2 / diag(inverse(Hessian))
非结构化(细粒度):
结构化(粗粒度):
半结构化(N:M):
# 非结构化(随机)
# [1, 0, 1, 0, 1, 1, 0, 0]
# 优点:灵活,质量高
# 缺点:无加速
# 结构化(块)
# [1, 1, 0, 0, 1, 1, 0, 0]
# 优点:硬件友好
# 缺点:精度损失更大
# N:M(半结构化)
# [1, 0, 1, 0] [1, 1, 0, 0] (2:4 模式)
# 优点:硬件加速 + 良好质量
# 缺点:需要特定硬件(NVIDIA)
def _prune_linear_layers_by_magnitude(model, sparsity):
    """Zero the smallest-magnitude ``sparsity`` fraction of each Linear weight, in place."""
    for module in model.modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        weight = module.weight.data
        n_prune = int(weight.numel() * sparsity)
        if n_prune <= 0:
            continue
        magnitude = weight.abs().float()
        # kthvalue rather than torch.quantile, which is restricted in the
        # dtypes and input sizes it accepts
        threshold = torch.kthvalue(magnitude.flatten(), n_prune).values
        weight *= (magnitude > threshold).to(weight.dtype)


def gradual_prune(model, initial_sparsity=0.0, final_sparsity=0.5, num_steps=100):
    """Gradually increase sparsity during training.

    At each of ``num_steps`` steps every ``torch.nn.Linear`` weight is
    magnitude-pruned at the current sparsity and ``train_step(model)`` is
    called. Sparsity rises linearly and reaches ``final_sparsity`` on the
    last step. After the last training step the model is pruned once more at
    ``final_sparsity``.

    Args:
        model: Module whose linear layers are pruned in place.
        initial_sparsity: Sparsity the linear schedule starts from.
        final_sparsity: Sparsity reached on the last step.
        num_steps: Number of prune-then-train steps.

    Returns:
        The same ``model`` object.
    """
    for step in range(num_steps):
        # (step + 1) so that the last step lands exactly on final_sparsity
        progress = (step + 1) / num_steps
        current_sparsity = initial_sparsity + (final_sparsity - initial_sparsity) * progress
        # Prune at current sparsity
        _prune_linear_layers_by_magnitude(model, current_sparsity)
        # Train one step
        train_step(model)
    if num_steps >= 1:
        # NOTE(review): train_step presumably updates weights, which would
        # un-zero pruned entries; prune once more so the result is sparse
        _prune_linear_layers_by_magnitude(model, final_sparsity)
    return model
def layer_wise_prune(model, sparsity_per_layer):
    """Prune linear layers with a different sparsity for different layers.

    Args:
        model: Module whose ``torch.nn.Linear`` sub-modules are pruned via
            ``prune_layer(module, sparsity)``.
        sparsity_per_layer: Mapping from a dotted module-name fragment
            (e.g. ``"layer.0"``) to the sparsity for matching layers. If it
            is empty, ``None`` or not a mapping, the built-in example
            schedule below is used.

    Returns:
        The same ``model`` object.
    """
    # Example schedule: the first listed layers are pruned less than the last ones
    default_schedule = {
        "layer.0": 0.3,  # 30% sparsity
        "layer.1": 0.4,
        "layer.2": 0.5,
        "layer.3": 0.6,  # 60% sparsity
    }
    if sparsity_per_layer and hasattr(sparsity_per_layer, "items"):
        sparsity_schedule = sparsity_per_layer
    else:
        sparsity_schedule = default_schedule
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        # Match on whole dotted components so "layer.1" does not also hit "layer.10"
        dotted_name = f".{name}."
        for layer_name, sparsity in sparsity_schedule.items():
            if f".{layer_name}." in dotted_name:
                # Prune at layer-specific sparsity; first matching entry wins
                prune_layer(module, sparsity)
                break
    return model
def iterative_prune_finetune(model, target_sparsity=0.5, iterations=5):
    """Prune gradually with fine-tuning between iterations.

    Sparsity is raised in ``iterations`` equal steps; after every pruning
    call the model is fine-tuned with ``fine_tune(model, epochs=2, lr=1e-5)``.

    Args:
        model: Model passed to ``prune_model`` and ``fine_tune``.
        target_sparsity: Sparsity used in the last iteration.
        iterations: Number of prune/fine-tune rounds; must be at least 1.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: If ``iterations`` is smaller than 1.
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")
    for round_index in range(1, iterations + 1):
        # Computed directly instead of accumulated, so the last round uses
        # exactly target_sparsity with no floating-point drift
        current_sparsity = target_sparsity * round_index / iterations
        # Prune
        prune_model(model, sparsity=current_sparsity)
        # Fine-tune (recover accuracy)
        fine_tune(model, epochs=2, lr=1e-5)
    return model
# 结果:在高稀疏性下比一次性剪枝精度更好
from transformers import Trainer, TrainingArguments
def production_pruning_pipeline(
    model_name="meta-llama/Llama-2-7b-hf",
    target_sparsity=0.5,
    method="wanda",  # or "sparsegpt"
    finetune_dataset=None,
    output_dir="./pruned-llama-7b-50",
):
    """Load a model, prune it, optionally fine-tune it, and save it.

    Args:
        model_name: Checkpoint passed to ``from_pretrained`` for both the
            model and the tokenizer.
        target_sparsity: Sparsity forwarded to the pruning method.
        method: ``"wanda"`` or ``"sparsegpt"``.
        finetune_dataset: Training dataset for the optional recovery
            fine-tune. When omitted, a module-level ``finetune_dataset`` is
            used if one exists; otherwise fine-tuning is skipped.
        output_dir: Directory the pruned model and tokenizer are saved to.

    Returns:
        The pruned model.

    Raises:
        ValueError: If ``method`` is not a supported pruning method.
    """
    # Fail before any download instead of hitting an unbound `pruned_model` later
    if method not in ("wanda", "sparsegpt"):
        raise ValueError(f"unknown pruning method: {method!r}")
    if finetune_dataset is None:
        # Backward compatible with callers that defined the dataset globally
        finetune_dataset = globals().get("finetune_dataset")

    # 1. Load model
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # 2. Load calibration data
    calib_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1000]")

    # 3. Apply pruning
    if method == "wanda":
        # wanda_prune iterates over raw strings, so hand it the non-blank text column
        calib_texts = [text for text in calib_dataset["text"] if text.strip()]
        pruned_model = wanda_prune(model, calib_texts, sparsity=target_sparsity)
    else:
        pruner = SparseGPT(model)
        pruned_model = pruner.prune(calib_dataset, sparsity=target_sparsity)

    # 4. (Optional) Fine-tune to recover accuracy
    if finetune_dataset is not None:
        # NOTE(review): plain fine-tuning presumably updates pruned weights
        # too, so sparsity may not survive this step - confirm or re-apply masks
        training_args = TrainingArguments(
            output_dir="./pruned-model",
            num_train_epochs=1,
            per_device_train_batch_size=4,
            learning_rate=1e-5,
            bf16=True,
        )
        trainer = Trainer(
            model=pruned_model,
            args=training_args,
            train_dataset=finetune_dataset,
        )
        trainer.train()

    # 5. Save
    pruned_model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    return pruned_model
# Usage: gather the pipeline settings, then run the pipeline once
pipeline_settings = {
    "model_name": "meta-llama/Llama-2-7b-hf",
    "target_sparsity": 0.5,
    "method": "wanda",
}
pruned_model = production_pruning_pipeline(**pipeline_settings)
from lm_eval import evaluator
# Evaluate pruned vs original model on the same task list
eval_tasks = ["arc_easy", "hellaswag", "winogrande"]
original_results = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Llama-2-7b-hf",
    tasks=eval_tasks,
)
pruned_results = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=./pruned-llama-7b-50",
    tasks=eval_tasks,
)
# Compare: extract the scalar accuracies first - the result objects are
# indexed like nested dicts and cannot be subtracted from each other
# NOTE(review): some lm-eval releases key this metric as "acc,none" - confirm for the installed version
original_acc = original_results['results']['arc_easy']['acc']
pruned_acc = pruned_results['results']['arc_easy']['acc']
print(f"原始模型: {original_acc:.3f}")
print(f"剪枝模型: {pruned_acc:.3f}")
print(f"精度下降: {(original_acc - pruned_acc):.3f}")
# Typical results at 50% sparsity:
# - Wanda: <1% accuracy loss
# - SparseGPT: <0.5% accuracy loss
# - Magnitude: 2-3% accuracy loss
# Sparsity presets: these are alternatives - pick one (each assignment overrides the previous)
# Conservative (safe)
sparsity = 0.3  # 30%, <0.5% loss
# Balanced (recommended)
sparsity = 0.5  # 50%, ~1% loss
# Aggressive (risky)
sparsity = 0.7  # 70%, 2-5% loss
# Extreme (model-dependent)
sparsity = 0.9  # 90%, significant degradation
# One-shot, no retraining → Wanda or SparseGPT
if no_retraining_budget:
    use_method = "wanda"  # Faster
# Best quality → SparseGPT
if need_best_quality:
    use_method = "sparsegpt"  # More accurate
# Hardware speedup → N:M structured
if need_speedup:
    use_method = "nm_prune"  # 2:4 or 4:8
# ❌ Bad: Pruning without calibration data
prune_random(model)  # No activation statistics
# ✅ Good: Use calibration data (same name as the wanda_prune function defined earlier)
wanda_prune(model, calib_data)
# ❌ Bad: Too high sparsity in one shot
prune(model, sparsity=0.9)  # Massive accuracy loss
# ✅ Good: Gradual or iterative (name and keywords match iterative_prune_finetune)
iterative_prune_finetune(model, target_sparsity=0.9, iterations=10)
50% 稀疏性下的剪枝方法(LLaMA-7B):
| 方法 | 精度损失 | 速度 | 内存 | 是否需要重新训练 |
|---|---|---|---|---|
| 幅度剪枝 | -2.5% | 1.0× | -50% | 否 |
| Wanda | -0.8% | 1.0× | -50% | 否 |
| SparseGPT | -0.4% | 1.0× | -50% | 否 |
| N:M (2:4) | -1.0% | 2.0× | -50% | 否 |
| 结构化剪枝 | -3.0% | 2.0× | -50% | 否 |
来源:Wanda 论文(ICLR 2024)、SparseGPT 论文
每周安装次数
191
代码仓库
GitHub 星标数
23.4K
首次出现
2026 年 1 月 21 日
安全审计
安装于
opencode 156
claude-code 153
gemini-cli 144
cursor 138
codex 137
github-copilot 123
Use Model Pruning when you need to:
Key Techniques : Wanda (weights × activations), SparseGPT (second-order), structured pruning, N:M sparsity
Papers : Wanda ICLR 2024 (arXiv 2306.11695), SparseGPT (arXiv 2301.00774)
# Wanda implementation
git clone https://github.com/locuslab/wanda
cd wanda
pip install -r requirements.txt
# Return to the parent directory so the next clone lands beside wanda/, not inside it
cd ..
# Optional: SparseGPT
git clone https://github.com/IST-DASLab/sparsegpt
cd sparsegpt
# NOTE(review): `pip install -e .` needs a setup.py/pyproject.toml in the repo root - confirm upstream ships one
pip install -e .
cd ..
# Dependencies
pip install torch transformers accelerate
Source : ICLR 2024 (arXiv 2306.11695)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# One checkpoint id shared by the model and its tokenizer
MODEL_ID = "meta-llama/Llama-2-7b-hf"
# Load the model in float16 with device_map="cuda"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Calibration data: a small set of texts used only to gather activation statistics
calib_data = [
    "The quick brown fox jumps over the lazy dog.",
    "Machine learning is transforming the world.",
    "Artificial intelligence powers modern applications.",
]
# Wanda pruning function
def wanda_prune(model, calib_data, sparsity=0.5, tokenizer=None):
    """
    Wanda: prune each linear layer by |weight| x input-activation norm.

    For every ``torch.nn.Linear`` in ``model`` the squared inputs are summed
    per input feature over all calibration tokens. Each weight is then scored
    as ``|W[i, j]| * norm[j]`` and, within every output row, the
    lowest-scoring ``sparsity`` fraction of weights is zeroed in place.

    Args:
        model: Module to prune; it is called as ``model(**inputs)`` on the
            tokenizer output.
        calib_data: Iterable of calibration texts. Items may be strings or
            dicts with a ``"text"`` key; blank texts are skipped.
        sparsity: Fraction of weights to prune in each output row
            (0.5 = 50%).
        tokenizer: Callable used as ``tokenizer(text, return_tensors="pt")``.
            When omitted, the module-level ``tokenizer`` is used.

    Returns:
        The same ``model`` object, pruned in place.

    Raises:
        ValueError: If ``sparsity`` is outside [0, 1] or no tokenizer is
            available.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")
    if tokenizer is None:
        # Backward compatible: earlier callers relied on a global tokenizer.
        tokenizer = globals().get("tokenizer")
    if tokenizer is None:
        raise ValueError("no tokenizer given and no module-level `tokenizer` is defined")

    # Use `model.device` when the model exposes it, else the first parameter's device.
    device = getattr(model, "device", None)
    if device is None:
        device = next(model.parameters()).device

    # 1. Collect activation statistics: running sum of squared inputs per input feature.
    squared_sums = {}

    def hook_fn(name):
        def hook(module, inputs, output):
            # Collapse all leading dims so every row is one token's feature vector.
            tokens = inputs[0].detach().reshape(-1, inputs[0].shape[-1]).float()
            total = tokens.pow(2).sum(dim=0)
            # Accumulate instead of overwrite so every calibration sample counts.
            squared_sums[name] = squared_sums[name] + total if name in squared_sums else total
        return hook

    # Register hooks for all linear layers
    hooks = [
        module.register_forward_hook(hook_fn(name))
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
    ]

    # Run calibration data; the hooks are removed even if a forward pass raises.
    model.eval()
    try:
        with torch.no_grad():
            for item in calib_data:
                text = item["text"] if isinstance(item, dict) else item
                if not text.strip():
                    continue
                inputs = tokenizer(text, return_tensors="pt").to(device)
                model(**inputs)
    finally:
        for handle in hooks:
            handle.remove()

    # 2. Prune weights based on |weight| x activation norm
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear) or name not in squared_sums:
            continue
        W = module.weight.data
        n_prune = int(W.shape[1] * sparsity)
        if n_prune == 0:
            continue
        act_norm = squared_sums[name].sqrt().to(W.device)
        # Score in float32 regardless of the weight dtype.
        importance = W.abs().float() * act_norm.unsqueeze(0)
        # Rank within each output row and drop the n_prune lowest scores;
        # sorting per row avoids torch.quantile on a whole (possibly fp16) matrix.
        prune_idx = torch.argsort(importance, dim=1)[:, :n_prune]
        mask = torch.ones_like(importance, dtype=torch.bool)
        mask.scatter_(1, prune_idx, False)
        # Apply mask (prune) in place
        W *= mask.to(W.dtype)
    return model
# Apply Wanda pruning (50% sparsity, one-shot, no retraining)
pruned_model = wanda_prune(model, calib_data, sparsity=0.5)
# Save the pruned weights and the tokenizer into the same directory,
# matching what the production pipeline does with its output directory
pruned_model.save_pretrained("./llama-2-7b-wanda-50")
tokenizer.save_pretrained("./llama-2-7b-wanda-50")
Source : arXiv 2301.00774
from sparsegpt import SparseGPT
# Load model (no dtype or device_map is passed here, unlike the Wanda example)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# Initialize SparseGPT
# NOTE(review): the upstream IST-DASLab/sparsegpt code appears to expose a per-layer
# SparseGPT(layer) with add_batch()/fasterprune(); confirm that a model-level
# SparseGPT(model).prune(...) wrapper exists in the installed package
pruner = SparseGPT(model)
# Calibration data
# NOTE(review): load_calibration_data is not defined in this snippet - it must be supplied by the caller
calib_data = load_calibration_data() # ~128 samples
# Prune (one-shot, layer-wise reconstruction)
pruned_model = pruner.prune(
calib_data=calib_data,
sparsity=0.5, # 50% sparsity
prunen=0, # Unstructured (0) or N:M structured
prunem=0,
percdamp=0.01, # Damping for Hessian inverse
)
# Results: Near-lossless pruning at 50% sparsity
def nm_prune(weight, n=2, m=4):
"""
N:M pruning: Keep N weights per M consecutive weights.
Example: 2:4 = keep 2 out of every 4 weights.
Compatible with NVIDIA sparse tensor cores (2:4, 4:8).
"""
# Reshape weight into groups of M
shape = weight.shape
weight_flat = weight.flatten()
# Pad to multiple of M
pad_size = (m - weight_flat.numel() % m) % m
weight_padded = F.pad(weight_flat, (0, pad_size))
# Reshape into (num_groups, m)
weight_grouped = weight_padded.reshape(-1, m)
# Find top-N in each group
_, indices = torch.topk(weight_grouped.abs(), n, dim=-1)
# Create mask
mask = torch.zeros_like(weight_grouped)
mask.scatter_(1, indices, 1.0)
# Apply mask
weight_pruned = weight_grouped * mask
# Reshape back
weight_pruned = weight_pruned.flatten()[:weight_flat.numel()]
return weight_pruned.reshape(shape)
# Enforce the 2:4 pattern (NVIDIA hardware) on every linear layer of the model
linear_layers = [layer for layer in model.modules() if isinstance(layer, torch.nn.Linear)]
for layer in linear_layers:
    layer.weight.data = nm_prune(layer.weight.data, n=2, m=4)
# 50% sparsity, 2× speedup on A100 with sparse tensor cores
Magnitude Pruning (baseline):
# Prune weights with smallest absolute values
importance = weight.abs()
threshold = torch.quantile(importance, sparsity)
mask = importance >= threshold
Wanda (weights × activations):
# Importance = |weight| × input_activation
importance = weight.abs() * activation
# Better than magnitude alone (considers usage)
SparseGPT (second-order):
# Uses Hessian (second derivative) for importance
# More accurate but computationally expensive
importance = weight**2 / diag(inverse(Hessian))
Unstructured (fine-grained):
Structured (coarse-grained):
Semi-structured (N:M) :
# Unstructured (random)
# [1, 0, 1, 0, 1, 1, 0, 0]
# Pros: Flexible, high quality
# Cons: No speedup
# Structured (block)
# [1, 1, 0, 0, 1, 1, 0, 0]
# Pros: Hardware friendly
# Cons: More accuracy loss
# N:M (semi-structured)
# [1, 0, 1, 0] [1, 1, 0, 0] (2:4 pattern)
# Pros: Hardware speedup + good quality
# Cons: Requires specific hardware (NVIDIA)
def _prune_linear_layers_by_magnitude(model, sparsity):
    """Zero the smallest-magnitude ``sparsity`` fraction of each Linear weight, in place."""
    for module in model.modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        weight = module.weight.data
        n_prune = int(weight.numel() * sparsity)
        if n_prune <= 0:
            continue
        magnitude = weight.abs().float()
        # kthvalue rather than torch.quantile, which is restricted in the
        # dtypes and input sizes it accepts
        threshold = torch.kthvalue(magnitude.flatten(), n_prune).values
        weight *= (magnitude > threshold).to(weight.dtype)


def gradual_prune(model, initial_sparsity=0.0, final_sparsity=0.5, num_steps=100):
    """Gradually increase sparsity during training.

    At each of ``num_steps`` steps every ``torch.nn.Linear`` weight is
    magnitude-pruned at the current sparsity and ``train_step(model)`` is
    called. Sparsity rises linearly and reaches ``final_sparsity`` on the
    last step. After the last training step the model is pruned once more at
    ``final_sparsity``.

    Args:
        model: Module whose linear layers are pruned in place.
        initial_sparsity: Sparsity the linear schedule starts from.
        final_sparsity: Sparsity reached on the last step.
        num_steps: Number of prune-then-train steps.

    Returns:
        The same ``model`` object.
    """
    for step in range(num_steps):
        # (step + 1) so that the last step lands exactly on final_sparsity
        progress = (step + 1) / num_steps
        current_sparsity = initial_sparsity + (final_sparsity - initial_sparsity) * progress
        # Prune at current sparsity
        _prune_linear_layers_by_magnitude(model, current_sparsity)
        # Train one step
        train_step(model)
    if num_steps >= 1:
        # NOTE(review): train_step presumably updates weights, which would
        # un-zero pruned entries; prune once more so the result is sparse
        _prune_linear_layers_by_magnitude(model, final_sparsity)
    return model
def layer_wise_prune(model, sparsity_per_layer):
    """Prune linear layers with a different sparsity for different layers.

    Args:
        model: Module whose ``torch.nn.Linear`` sub-modules are pruned via
            ``prune_layer(module, sparsity)``.
        sparsity_per_layer: Mapping from a dotted module-name fragment
            (e.g. ``"layer.0"``) to the sparsity for matching layers. If it
            is empty, ``None`` or not a mapping, the built-in example
            schedule below is used.

    Returns:
        The same ``model`` object.
    """
    # Example schedule: the first listed layers are pruned less than the last ones
    default_schedule = {
        "layer.0": 0.3,  # 30% sparsity
        "layer.1": 0.4,
        "layer.2": 0.5,
        "layer.3": 0.6,  # 60% sparsity
    }
    if sparsity_per_layer and hasattr(sparsity_per_layer, "items"):
        sparsity_schedule = sparsity_per_layer
    else:
        sparsity_schedule = default_schedule
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        # Match on whole dotted components so "layer.1" does not also hit "layer.10"
        dotted_name = f".{name}."
        for layer_name, sparsity in sparsity_schedule.items():
            if f".{layer_name}." in dotted_name:
                # Prune at layer-specific sparsity; first matching entry wins
                prune_layer(module, sparsity)
                break
    return model
def iterative_prune_finetune(model, target_sparsity=0.5, iterations=5):
    """Prune gradually with fine-tuning between iterations.

    Sparsity is raised in ``iterations`` equal steps; after every pruning
    call the model is fine-tuned with ``fine_tune(model, epochs=2, lr=1e-5)``.

    Args:
        model: Model passed to ``prune_model`` and ``fine_tune``.
        target_sparsity: Sparsity used in the last iteration.
        iterations: Number of prune/fine-tune rounds; must be at least 1.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: If ``iterations`` is smaller than 1.
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")
    for round_index in range(1, iterations + 1):
        # Computed directly instead of accumulated, so the last round uses
        # exactly target_sparsity with no floating-point drift
        current_sparsity = target_sparsity * round_index / iterations
        # Prune
        prune_model(model, sparsity=current_sparsity)
        # Fine-tune (recover accuracy)
        fine_tune(model, epochs=2, lr=1e-5)
    return model
# Results: Better accuracy than one-shot at high sparsity
from transformers import Trainer, TrainingArguments
def production_pruning_pipeline(
    model_name="meta-llama/Llama-2-7b-hf",
    target_sparsity=0.5,
    method="wanda",  # or "sparsegpt"
    finetune_dataset=None,
    output_dir="./pruned-llama-7b-50",
):
    """Load a model, prune it, optionally fine-tune it, and save it.

    Args:
        model_name: Checkpoint passed to ``from_pretrained`` for both the
            model and the tokenizer.
        target_sparsity: Sparsity forwarded to the pruning method.
        method: ``"wanda"`` or ``"sparsegpt"``.
        finetune_dataset: Training dataset for the optional recovery
            fine-tune. When omitted, a module-level ``finetune_dataset`` is
            used if one exists; otherwise fine-tuning is skipped.
        output_dir: Directory the pruned model and tokenizer are saved to.

    Returns:
        The pruned model.

    Raises:
        ValueError: If ``method`` is not a supported pruning method.
    """
    # Fail before any download instead of hitting an unbound `pruned_model` later
    if method not in ("wanda", "sparsegpt"):
        raise ValueError(f"unknown pruning method: {method!r}")
    if finetune_dataset is None:
        # Backward compatible with callers that defined the dataset globally
        finetune_dataset = globals().get("finetune_dataset")

    # 1. Load model
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # 2. Load calibration data
    calib_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1000]")

    # 3. Apply pruning
    if method == "wanda":
        # wanda_prune iterates over raw strings, so hand it the non-blank text column
        calib_texts = [text for text in calib_dataset["text"] if text.strip()]
        pruned_model = wanda_prune(model, calib_texts, sparsity=target_sparsity)
    else:
        pruner = SparseGPT(model)
        pruned_model = pruner.prune(calib_dataset, sparsity=target_sparsity)

    # 4. (Optional) Fine-tune to recover accuracy
    if finetune_dataset is not None:
        # NOTE(review): plain fine-tuning presumably updates pruned weights
        # too, so sparsity may not survive this step - confirm or re-apply masks
        training_args = TrainingArguments(
            output_dir="./pruned-model",
            num_train_epochs=1,
            per_device_train_batch_size=4,
            learning_rate=1e-5,
            bf16=True,
        )
        trainer = Trainer(
            model=pruned_model,
            args=training_args,
            train_dataset=finetune_dataset,
        )
        trainer.train()

    # 5. Save
    pruned_model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    return pruned_model
# Usage: gather the pipeline settings, then run the pipeline once
pipeline_settings = {
    "model_name": "meta-llama/Llama-2-7b-hf",
    "target_sparsity": 0.5,
    "method": "wanda",
}
pruned_model = production_pruning_pipeline(**pipeline_settings)
from lm_eval import evaluator
# Evaluate pruned vs original model on the same task list
eval_tasks = ["arc_easy", "hellaswag", "winogrande"]
original_results = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Llama-2-7b-hf",
    tasks=eval_tasks,
)
pruned_results = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=./pruned-llama-7b-50",
    tasks=eval_tasks,
)
# Compare: extract the scalar accuracies first - the result objects are
# indexed like nested dicts and cannot be subtracted from each other
# NOTE(review): some lm-eval releases key this metric as "acc,none" - confirm for the installed version
original_acc = original_results['results']['arc_easy']['acc']
pruned_acc = pruned_results['results']['arc_easy']['acc']
print(f"Original: {original_acc:.3f}")
print(f"Pruned: {pruned_acc:.3f}")
print(f"Degradation: {(original_acc - pruned_acc):.3f}")
# Typical results at 50% sparsity:
# - Wanda: <1% accuracy loss
# - SparseGPT: <0.5% accuracy loss
# - Magnitude: 2-3% accuracy loss
# Sparsity presets: these are alternatives - pick one (each assignment overrides the previous)
# Conservative (safe)
sparsity = 0.3  # 30%, <0.5% loss
# Balanced (recommended)
sparsity = 0.5  # 50%, ~1% loss
# Aggressive (risky)
sparsity = 0.7  # 70%, 2-5% loss
# Extreme (model-dependent)
sparsity = 0.9  # 90%, significant degradation
# One-shot, no retraining → Wanda or SparseGPT
if no_retraining_budget:
    use_method = "wanda"  # Faster
# Best quality → SparseGPT
if need_best_quality:
    use_method = "sparsegpt"  # More accurate
# Hardware speedup → N:M structured
if need_speedup:
    use_method = "nm_prune"  # 2:4 or 4:8
# ❌ Bad: Pruning without calibration data
prune_random(model)  # No activation statistics
# ✅ Good: Use calibration data (same name as the wanda_prune function defined earlier)
wanda_prune(model, calib_data)
# ❌ Bad: Too high sparsity in one shot
prune(model, sparsity=0.9)  # Massive accuracy loss
# ✅ Good: Gradual or iterative (name and keywords match iterative_prune_finetune)
iterative_prune_finetune(model, target_sparsity=0.9, iterations=10)
Pruning methods at 50% sparsity (LLaMA-7B):
| Method | Accuracy Loss | Speed | Memory | Retraining Needed |
|---|---|---|---|---|
| Magnitude | -2.5% | 1.0× | -50% | No |
| Wanda | -0.8% | 1.0× | -50% | No |
| SparseGPT | -0.4% | 1.0× | -50% | No |
| N:M (2:4) | -1.0% | 2.0× | -50% | No |
| Structured | -3.0% | 2.0× | -50% | No |
Source : Wanda paper (ICLR 2024), SparseGPT paper
Weekly Installs
191
Repository
GitHub Stars
23.4K
First Seen
Jan 21, 2026
Security Audits
Gen Agent Trust Hub: Fail · Socket: Pass · Snyk: Warn
Installed on
opencode 156
claude-code 153
gemini-cli 144
cursor 138
codex 137
github-copilot 123
AI 代码实施计划编写技能 | 自动化开发任务分解与 TDD 流程规划工具
50,900 周安装
PostgreSQL优化助手 - JSONB操作、性能调优、窗口函数、全文搜索实战指南
9,600 周安装
GitHub Copilot create-readme:AI自动生成专业README文档工具
9,600 周安装
React Native 最佳实践与性能优化指南 | 提升应用FPS、启动速度与包体积
9,600 周安装
Web无障碍性(a11y)指南:WCAG 2.1原则、Lighthouse审计与代码实践
10,500 周安装
Vue Router 最佳实践指南:导航守卫、路由生命周期与常见陷阱解决方案
9,900 周安装
SEO优化指南:技术性SEO、页面优化与结构化数据实践
10,700 周安装