重要前提
安装AI Skills的关键前提是:必须科学上网,且开启TUN模式,这一点至关重要,直接决定安装能否顺利完成,在此郑重提醒三遍:科学上网,科学上网,科学上网。查看完整安装教程 →
model-pruning by orchestra-research/ai-research-skills
npx skills add https://github.com/orchestra-research/ai-research-skills --skill model-pruning
当您需要以下情况时,请使用模型剪枝:
关键技术:Wanda(权重 × 激活值)、SparseGPT(二阶)、结构化剪枝、N:M 稀疏性
论文:Wanda ICLR 2024 (arXiv 2306.11695)、SparseGPT (arXiv 2301.00774)
# Wanda 实现
git clone https://github.com/locuslab/wanda
cd wanda
pip install -r requirements.txt
# 可选:SparseGPT
git clone https://github.com/IST-DASLab/sparsegpt
cd sparsegpt
pip install -e .
# 依赖项
pip install torch transformers accelerate
来源:ICLR 2024 (arXiv 2306.11695)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# 加载模型
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# 校准数据(用于激活统计的小数据集)
calib_data = [
"The quick brown fox jumps over the lazy dog.",
"Machine learning is transforming the world.",
"Artificial intelligence powers modern applications.",
]
# Wanda 剪枝函数
def wanda_prune(model, calib_data, sparsity=0.5):
    """
    Wanda: prune weights by importance = |weight| x input-activation magnitude.

    Args:
        model: Causal LM whose torch.nn.Linear layers are pruned in place.
        calib_data: Iterable of calibration text strings; encoded with the
            module-level ``tokenizer``.
        sparsity: Fraction of weights to prune per layer (0.5 = 50%).

    Returns:
        The pruned model (same object, modified in place).
    """
    # 1. Collect per-input-channel activation statistics.
    # Two fixes vs. the original snippet:
    #   - reduce over ALL leading dims (batch AND sequence), not just dim=0;
    #     the old (seq_len, hidden) statistic could not broadcast against the
    #     (out_features, in_features) weight matrix,
    #   - accumulate across calibration samples instead of keeping only the
    #     activations of the last forward pass.
    activations = {}
    counts = {}

    def hook_fn(name):
        def hook(module, input, output):
            x = input[0].detach().abs()
            # Mean over every dim except the last -> shape (in_features,).
            stat = x.mean(dim=tuple(range(x.dim() - 1)))
            if name in activations:
                activations[name] += stat
                counts[name] += 1
            else:
                activations[name] = stat
                counts[name] = 1
        return hook

    # Register hooks on every linear layer.
    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            hooks.append(module.register_forward_hook(hook_fn(name)))

    # Run the calibration data through the model.
    model.eval()
    with torch.no_grad():
        for text in calib_data:
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            model(**inputs)

    # Remove the hooks so the model is left clean.
    for hook in hooks:
        hook.remove()

    # 2. Prune based on |weight| x mean activation.
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and name in activations:
            W = module.weight.data
            act = activations[name] / counts[name]
            # (out, in) * (1, in) broadcasts to a per-weight importance score.
            importance = W.abs() * act.unsqueeze(0)
            # torch.quantile requires a floating dtype; cast the copy only.
            threshold = torch.quantile(importance.flatten().float(), sparsity)
            mask = importance >= threshold
            # Zero out the unimportant weights, keeping W's dtype.
            W *= mask.to(W.dtype)
    return model
# 应用 Wanda 剪枝(50% 稀疏度,一次性,无需重新训练)
pruned_model = wanda_prune(model, calib_data, sparsity=0.5)
# 保存
pruned_model.save_pretrained("./llama-2-7b-wanda-50")
广告位招租
在这里展示您的产品或服务
触达数万 AI 开发者,精准高效
来源:arXiv 2301.00774
from sparsegpt import SparseGPT
# 加载模型
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# 初始化 SparseGPT
pruner = SparseGPT(model)
# 校准数据
calib_data = load_calibration_data() # ~128 个样本
# 剪枝(一次性,逐层重建)
pruned_model = pruner.prune(
calib_data=calib_data,
sparsity=0.5, # 50% 稀疏度
prunen=0, # 非结构化 (0) 或 N:M 结构化
prunem=0,
percdamp=0.01, # Hessian 逆矩阵的阻尼
)
# 结果:在 50% 稀疏度下实现近乎无损的剪枝
def nm_prune(weight, n=2, m=4):
    """
    N:M pruning: keep the N largest-magnitude weights in every group of M
    consecutive weights.

    Example: 2:4 keeps 2 out of every 4 weights, the pattern accelerated by
    NVIDIA sparse tensor cores (2:4, 4:8).

    Args:
        weight: Tensor of any shape to prune.
        n: Number of weights kept per group.
        m: Group size.

    Returns:
        A tensor with the same shape as ``weight``, pruned entries set to 0.
    """
    shape = weight.shape
    weight_flat = weight.flatten()
    # Pad with zeros to a multiple of m so the flat tensor reshapes cleanly.
    # Fully qualify the pad call: the original snippet used the conventional
    # alias ``F`` without ever importing torch.nn.functional.
    pad_size = (m - weight_flat.numel() % m) % m
    weight_padded = torch.nn.functional.pad(weight_flat, (0, pad_size))
    # One row per group of m consecutive weights.
    weight_grouped = weight_padded.reshape(-1, m)
    # Indices of the n largest magnitudes inside each group.
    _, indices = torch.topk(weight_grouped.abs(), n, dim=-1)
    # Binary keep-mask built from those indices.
    mask = torch.zeros_like(weight_grouped)
    mask.scatter_(1, indices, 1.0)
    weight_pruned = weight_grouped * mask
    # Drop padding and restore the original shape.
    weight_pruned = weight_pruned.flatten()[:weight_flat.numel()]
    return weight_pruned.reshape(shape)
# 应用 2:4 稀疏度(NVIDIA 硬件)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
module.weight.data = nm_prune(module.weight.data, n=2, m=4)
# 50% 稀疏度,在配备稀疏张量核心的 A100 上实现 2 倍加速
幅度剪枝(基线):
# 剪除绝对值最小的权重
importance = weight.abs()
threshold = torch.quantile(importance, sparsity)
mask = importance >= threshold
Wanda(权重 × 激活值):
# 重要性 = |权重| × 输入_激活值
importance = weight.abs() * activation
# 优于仅使用幅度(考虑了使用情况)
SparseGPT(二阶):
# 使用 Hessian 矩阵(二阶导数)计算重要性
# 更准确但计算成本更高
importance = weight^2 / diag(Hessian)
非结构化(细粒度):
结构化(粗粒度):
半结构化 (N:M):
# 非结构化(随机)
# [1, 0, 1, 0, 1, 1, 0, 0]
# 优点:灵活,质量高
# 缺点:无加速
# 结构化(块状)
# [1, 1, 0, 0, 1, 1, 0, 0]
# 优点:硬件友好
# 缺点:更多精度损失
# N:M(半结构化)
# [1, 0, 1, 0] [1, 1, 0, 0] (2:4 模式)
# 优点:硬件加速 + 良好质量
# 缺点:需要特定硬件 (NVIDIA)
def gradual_prune(model, initial_sparsity=0.0, final_sparsity=0.5, num_steps=100):
    """Gradually raise sparsity from initial to final while training.

    Fixes an off-by-one in the original schedule: ``step / num_steps`` tops
    out at ``(num_steps - 1) / num_steps``, so ``final_sparsity`` was never
    actually reached; ``(step + 1) / num_steps`` hits it on the last step.

    Args:
        model: Model whose torch.nn.Linear layers are pruned in place.
        initial_sparsity: Starting fraction of pruned weights.
        final_sparsity: Target fraction reached on the final step.
        num_steps: Number of prune-then-train steps.

    Returns:
        The pruned model.
    """
    for step in range(num_steps):
        # Linear schedule that reaches final_sparsity at the last step.
        frac = (step + 1) / num_steps
        current_sparsity = initial_sparsity + (final_sparsity - initial_sparsity) * frac
        # Prune every linear layer at the current sparsity level.
        for module in model.modules():
            if isinstance(module, torch.nn.Linear):
                weight = module.weight.data
                threshold = torch.quantile(weight.abs().flatten(), current_sparsity)
                mask = weight.abs() >= threshold
                weight *= mask.float()
        # One optimization step so the surviving weights can adapt.
        # NOTE(review): ``train_step`` is assumed to be defined by the caller.
        train_step(model)
    return model
def layer_wise_prune(model, sparsity_per_layer=None):
    """Apply a different sparsity level to each layer.

    Early layers are typically more important, so they get less pruning.

    Args:
        model: Model whose torch.nn.Linear submodules are pruned in place.
        sparsity_per_layer: Optional mapping from layer-name substring to
            sparsity fraction. The original implementation accepted this
            argument but silently ignored it; it is now honored, with the
            previous hard-coded schedule kept as the default.

    Returns:
        The pruned model.
    """
    default_schedule = {
        "layer.0": 0.3,  # 30% sparsity
        "layer.1": 0.4,
        "layer.2": 0.5,
        "layer.3": 0.6,  # 60% sparsity
    }
    sparsity_schedule = sparsity_per_layer if sparsity_per_layer else default_schedule
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            # First schedule entry whose key appears in the module name wins.
            for layer_name, sparsity in sparsity_schedule.items():
                if layer_name in name:
                    # NOTE(review): ``prune_layer`` is assumed defined elsewhere.
                    prune_layer(module, sparsity)
                    break
    return model
def iterative_prune_finetune(model, target_sparsity=0.5, iterations=5):
    """Alternate pruning with fine-tuning so accuracy can recover each round."""
    step_size = target_sparsity / iterations
    sparsity_now = 0.0
    for _ in range(iterations):
        # Raise the sparsity target, prune to it, then fine-tune.
        sparsity_now += step_size
        prune_model(model, sparsity=sparsity_now)
        fine_tune(model, epochs=2, lr=1e-5)
    return model
# 结果:在高稀疏度下比一次性剪枝精度更好
from transformers import Trainer, TrainingArguments
def production_pruning_pipeline(
    model_name="meta-llama/Llama-2-7b-hf",
    target_sparsity=0.5,
    method="wanda",  # or "sparsegpt"
    finetune_dataset=None,
    output_dir="./pruned-llama-7b-50",
):
    """End-to-end pipeline: load -> calibrate -> prune -> (optional) fine-tune -> save.

    Args:
        model_name: Hugging Face hub id of the model to prune.
        target_sparsity: Fraction of weights to remove (0.5 = 50%).
        method: Pruning method, "wanda" or "sparsegpt".
        finetune_dataset: Optional dataset for recovery fine-tuning. The
            original code referenced an undefined global ``finetune_dataset``
            and crashed; fine-tuning is now skipped when this is None.
        output_dir: Directory where the pruned model and tokenizer are saved.

    Returns:
        The pruned (and possibly fine-tuned) model.

    Raises:
        ValueError: If ``method`` is not a supported pruning method.
    """
    # 1. Load model and tokenizer.
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # 2. Calibration data for activation / Hessian statistics.
    calib_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1000]")
    # 3. Apply the selected pruning method.
    if method == "wanda":
        pruned_model = wanda_prune(model, calib_dataset, sparsity=target_sparsity)
    elif method == "sparsegpt":
        pruner = SparseGPT(model)
        pruned_model = pruner.prune(calib_dataset, sparsity=target_sparsity)
    else:
        # Original left ``pruned_model`` unbound here -> NameError downstream.
        raise ValueError(f"Unknown pruning method: {method}")
    # 4. Optional fine-tuning to recover accuracy.
    if finetune_dataset is not None:
        training_args = TrainingArguments(
            output_dir="./pruned-model",
            num_train_epochs=1,
            per_device_train_batch_size=4,
            learning_rate=1e-5,
            bf16=True,
        )
        trainer = Trainer(
            model=pruned_model,
            args=training_args,
            train_dataset=finetune_dataset,
        )
        trainer.train()
    # 5. Save model and tokenizer side by side.
    pruned_model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    return pruned_model
# 用法
pruned_model = production_pruning_pipeline(
model_name="meta-llama/Llama-2-7b-hf",
target_sparsity=0.5,
method="wanda"
)
from lm_eval import evaluator

# Evaluate the pruned model against the original on a few benchmarks.
original_results = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Llama-2-7b-hf",
    tasks=["arc_easy", "hellaswag", "winogrande"],
)
pruned_results = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=./pruned-llama-7b-50",
    tasks=["arc_easy", "hellaswag", "winogrande"],
)
# Compare accuracies. The original code subtracted the two result dicts
# directly (``original_results - pruned_results``), which raises a
# TypeError; subtract the scalar accuracy values instead.
orig_acc = original_results['results']['arc_easy']['acc']
pruned_acc = pruned_results['results']['arc_easy']['acc']
print(f"原始模型: {orig_acc:.3f}")
print(f"剪枝模型: {pruned_acc:.3f}")
print(f"精度下降: {orig_acc - pruned_acc:.3f}")
# 50% 稀疏度下的典型结果:
# - Wanda: <1% 精度损失
# - SparseGPT: <0.5% 精度损失
# - 幅度剪枝: 2-3% 精度损失
# 保守(安全)
sparsity = 0.3 # 30%, <0.5% 损失
# 平衡(推荐)
sparsity = 0.5 # 50%, ~1% 损失
# 激进(有风险)
sparsity = 0.7 # 70%, 2-5% 损失
# 极端(取决于模型)
sparsity = 0.9 # 90%, 显著精度下降
# 一次性,无需重新训练 → Wanda 或 SparseGPT
if no_retraining_budget:
use_method = "wanda" # 更快
# 最佳质量 → SparseGPT
if need_best_quality:
use_method = "sparsegpt" # 更准确
# 硬件加速 → N:M 结构化
if need_speedup:
use_method = "nm_prune" # 2:4 或 4:8
# ❌ 错误:在没有校准数据的情况下剪枝
prune_random(model) # 无激活统计信息
# ✅ 正确:使用校准数据
prune_wanda(model, calib_data)
# ❌ 错误:一次性设置过高的稀疏度
prune(model, sparsity=0.9) # 巨大的精度损失
# ✅ 正确:渐进式或迭代式
iterative_prune(model, target=0.9, steps=10)
50% 稀疏度下的剪枝方法 (LLaMA-7B):
| 方法 | 精度损失 | 速度 | 内存 | 是否需要重新训练 |
|---|---|---|---|---|
| 幅度剪枝 | -2.5% | 1.0× | -50% | 否 |
| Wanda | -0.8% | 1.0× | -50% | 否 |
| SparseGPT | -0.4% | 1.0× | -50% | 否 |
| N:M (2:4) | -1.0% | 2.0× | -50% | 否 |
| 结构化剪枝 | -3.0% | 2.0× | -50% | 否 |
来源:Wanda 论文 (ICLR 2024)、SparseGPT 论文
每周安装次数
66
代码仓库
GitHub 星标数
5.6K
首次出现
2026年2月7日
安全审计
安装于
opencode57
codex56
cursor56
gemini-cli55
claude-code55
github-copilot54
Use Model Pruning when you need to:
Key Techniques : Wanda (weights × activations), SparseGPT (second-order), structured pruning, N:M sparsity
Papers : Wanda ICLR 2024 (arXiv 2306.11695), SparseGPT (arXiv 2301.00774)
# Wanda implementation
git clone https://github.com/locuslab/wanda
cd wanda
pip install -r requirements.txt
# Optional: SparseGPT
git clone https://github.com/IST-DASLab/sparsegpt
cd sparsegpt
pip install -e .
# Dependencies
pip install torch transformers accelerate
Source : ICLR 2024 (arXiv 2306.11695)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load model
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# Calibration data (small dataset for activation statistics)
calib_data = [
"The quick brown fox jumps over the lazy dog.",
"Machine learning is transforming the world.",
"Artificial intelligence powers modern applications.",
]
# Wanda pruning function
def wanda_prune(model, calib_data, sparsity=0.5):
    """
    Wanda: prune weights by importance = |weight| x input-activation magnitude.

    Args:
        model: Causal LM whose torch.nn.Linear layers are pruned in place.
        calib_data: Iterable of calibration text strings; encoded with the
            module-level ``tokenizer``.
        sparsity: Fraction of weights to prune per layer (0.5 = 50%).

    Returns:
        The pruned model (same object, modified in place).
    """
    # 1. Collect per-input-channel activation statistics.
    # Two fixes vs. the original snippet:
    #   - reduce over ALL leading dims (batch AND sequence), not just dim=0;
    #     the old (seq_len, hidden) statistic could not broadcast against the
    #     (out_features, in_features) weight matrix,
    #   - accumulate across calibration samples instead of keeping only the
    #     activations of the last forward pass.
    activations = {}
    counts = {}

    def hook_fn(name):
        def hook(module, input, output):
            x = input[0].detach().abs()
            # Mean over every dim except the last -> shape (in_features,).
            stat = x.mean(dim=tuple(range(x.dim() - 1)))
            if name in activations:
                activations[name] += stat
                counts[name] += 1
            else:
                activations[name] = stat
                counts[name] = 1
        return hook

    # Register hooks on every linear layer.
    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            hooks.append(module.register_forward_hook(hook_fn(name)))

    # Run the calibration data through the model.
    model.eval()
    with torch.no_grad():
        for text in calib_data:
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            model(**inputs)

    # Remove the hooks so the model is left clean.
    for hook in hooks:
        hook.remove()

    # 2. Prune based on |weight| x mean activation.
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and name in activations:
            W = module.weight.data
            act = activations[name] / counts[name]
            # (out, in) * (1, in) broadcasts to a per-weight importance score.
            importance = W.abs() * act.unsqueeze(0)
            # torch.quantile requires a floating dtype; cast the copy only.
            threshold = torch.quantile(importance.flatten().float(), sparsity)
            mask = importance >= threshold
            # Zero out the unimportant weights, keeping W's dtype.
            W *= mask.to(W.dtype)
    return model
# Apply Wanda pruning (50% sparsity, one-shot, no retraining)
pruned_model = wanda_prune(model, calib_data, sparsity=0.5)
# Save
pruned_model.save_pretrained("./llama-2-7b-wanda-50")
Source : arXiv 2301.00774
from sparsegpt import SparseGPT
# Load model
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# Initialize SparseGPT
pruner = SparseGPT(model)
# Calibration data
calib_data = load_calibration_data() # ~128 samples
# Prune (one-shot, layer-wise reconstruction)
pruned_model = pruner.prune(
calib_data=calib_data,
sparsity=0.5, # 50% sparsity
prunen=0, # Unstructured (0) or N:M structured
prunem=0,
percdamp=0.01, # Damping for Hessian inverse
)
# Results: Near-lossless pruning at 50% sparsity
def nm_prune(weight, n=2, m=4):
    """
    N:M pruning: keep the N largest-magnitude weights in every group of M
    consecutive weights.

    Example: 2:4 keeps 2 out of every 4 weights, the pattern accelerated by
    NVIDIA sparse tensor cores (2:4, 4:8).

    Args:
        weight: Tensor of any shape to prune.
        n: Number of weights kept per group.
        m: Group size.

    Returns:
        A tensor with the same shape as ``weight``, pruned entries set to 0.
    """
    shape = weight.shape
    weight_flat = weight.flatten()
    # Pad with zeros to a multiple of m so the flat tensor reshapes cleanly.
    # Fully qualify the pad call: the original snippet used the conventional
    # alias ``F`` without ever importing torch.nn.functional.
    pad_size = (m - weight_flat.numel() % m) % m
    weight_padded = torch.nn.functional.pad(weight_flat, (0, pad_size))
    # One row per group of m consecutive weights.
    weight_grouped = weight_padded.reshape(-1, m)
    # Indices of the n largest magnitudes inside each group.
    _, indices = torch.topk(weight_grouped.abs(), n, dim=-1)
    # Binary keep-mask built from those indices.
    mask = torch.zeros_like(weight_grouped)
    mask.scatter_(1, indices, 1.0)
    weight_pruned = weight_grouped * mask
    # Drop padding and restore the original shape.
    weight_pruned = weight_pruned.flatten()[:weight_flat.numel()]
    return weight_pruned.reshape(shape)
# Apply 2:4 sparsity (NVIDIA hardware)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
module.weight.data = nm_prune(module.weight.data, n=2, m=4)
# 50% sparsity, 2× speedup on A100 with sparse tensor cores
Magnitude Pruning (baseline):
# Prune weights with smallest absolute values
importance = weight.abs()
threshold = torch.quantile(importance, sparsity)
mask = importance >= threshold
Wanda (weights × activations):
# Importance = |weight| × input_activation
importance = weight.abs() * activation
# Better than magnitude alone (considers usage)
SparseGPT (second-order):
# Uses Hessian (second derivative) for importance
# More accurate but computationally expensive
importance = weight^2 / diag(Hessian)
Unstructured (fine-grained):
Structured (coarse-grained):
Semi-structured (N:M) :
# Unstructured (random)
# [1, 0, 1, 0, 1, 1, 0, 0]
# Pros: Flexible, high quality
# Cons: No speedup
# Structured (block)
# [1, 1, 0, 0, 1, 1, 0, 0]
# Pros: Hardware friendly
# Cons: More accuracy loss
# N:M (semi-structured)
# [1, 0, 1, 0] [1, 1, 0, 0] (2:4 pattern)
# Pros: Hardware speedup + good quality
# Cons: Requires specific hardware (NVIDIA)
def gradual_prune(model, initial_sparsity=0.0, final_sparsity=0.5, num_steps=100):
    """Gradually raise sparsity from initial to final while training.

    Fixes an off-by-one in the original schedule: ``step / num_steps`` tops
    out at ``(num_steps - 1) / num_steps``, so ``final_sparsity`` was never
    actually reached; ``(step + 1) / num_steps`` hits it on the last step.

    Args:
        model: Model whose torch.nn.Linear layers are pruned in place.
        initial_sparsity: Starting fraction of pruned weights.
        final_sparsity: Target fraction reached on the final step.
        num_steps: Number of prune-then-train steps.

    Returns:
        The pruned model.
    """
    for step in range(num_steps):
        # Linear schedule that reaches final_sparsity at the last step.
        frac = (step + 1) / num_steps
        current_sparsity = initial_sparsity + (final_sparsity - initial_sparsity) * frac
        # Prune every linear layer at the current sparsity level.
        for module in model.modules():
            if isinstance(module, torch.nn.Linear):
                weight = module.weight.data
                threshold = torch.quantile(weight.abs().flatten(), current_sparsity)
                mask = weight.abs() >= threshold
                weight *= mask.float()
        # One optimization step so the surviving weights can adapt.
        # NOTE(review): ``train_step`` is assumed to be defined by the caller.
        train_step(model)
    return model
def layer_wise_prune(model, sparsity_per_layer=None):
    """Apply a different sparsity level to each layer.

    Early layers are typically more important, so they get less pruning.

    Args:
        model: Model whose torch.nn.Linear submodules are pruned in place.
        sparsity_per_layer: Optional mapping from layer-name substring to
            sparsity fraction. The original implementation accepted this
            argument but silently ignored it; it is now honored, with the
            previous hard-coded schedule kept as the default.

    Returns:
        The pruned model.
    """
    default_schedule = {
        "layer.0": 0.3,  # 30% sparsity
        "layer.1": 0.4,
        "layer.2": 0.5,
        "layer.3": 0.6,  # 60% sparsity
    }
    sparsity_schedule = sparsity_per_layer if sparsity_per_layer else default_schedule
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            # First schedule entry whose key appears in the module name wins.
            for layer_name, sparsity in sparsity_schedule.items():
                if layer_name in name:
                    # NOTE(review): ``prune_layer`` is assumed defined elsewhere.
                    prune_layer(module, sparsity)
                    break
    return model
def iterative_prune_finetune(model, target_sparsity=0.5, iterations=5):
    """Alternate pruning with fine-tuning so accuracy can recover each round."""
    step_size = target_sparsity / iterations
    sparsity_now = 0.0
    for _ in range(iterations):
        # Raise the sparsity target, prune to it, then fine-tune.
        sparsity_now += step_size
        prune_model(model, sparsity=sparsity_now)
        fine_tune(model, epochs=2, lr=1e-5)
    return model
# Results: Better accuracy than one-shot at high sparsity
from transformers import Trainer, TrainingArguments
def production_pruning_pipeline(
    model_name="meta-llama/Llama-2-7b-hf",
    target_sparsity=0.5,
    method="wanda",  # or "sparsegpt"
    finetune_dataset=None,
    output_dir="./pruned-llama-7b-50",
):
    """End-to-end pipeline: load -> calibrate -> prune -> (optional) fine-tune -> save.

    Args:
        model_name: Hugging Face hub id of the model to prune.
        target_sparsity: Fraction of weights to remove (0.5 = 50%).
        method: Pruning method, "wanda" or "sparsegpt".
        finetune_dataset: Optional dataset for recovery fine-tuning. The
            original code referenced an undefined global ``finetune_dataset``
            and crashed; fine-tuning is now skipped when this is None.
        output_dir: Directory where the pruned model and tokenizer are saved.

    Returns:
        The pruned (and possibly fine-tuned) model.

    Raises:
        ValueError: If ``method`` is not a supported pruning method.
    """
    # 1. Load model and tokenizer.
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # 2. Calibration data for activation / Hessian statistics.
    calib_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1000]")
    # 3. Apply the selected pruning method.
    if method == "wanda":
        pruned_model = wanda_prune(model, calib_dataset, sparsity=target_sparsity)
    elif method == "sparsegpt":
        pruner = SparseGPT(model)
        pruned_model = pruner.prune(calib_dataset, sparsity=target_sparsity)
    else:
        # Original left ``pruned_model`` unbound here -> NameError downstream.
        raise ValueError(f"Unknown pruning method: {method}")
    # 4. Optional fine-tuning to recover accuracy.
    if finetune_dataset is not None:
        training_args = TrainingArguments(
            output_dir="./pruned-model",
            num_train_epochs=1,
            per_device_train_batch_size=4,
            learning_rate=1e-5,
            bf16=True,
        )
        trainer = Trainer(
            model=pruned_model,
            args=training_args,
            train_dataset=finetune_dataset,
        )
        trainer.train()
    # 5. Save model and tokenizer side by side.
    pruned_model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    return pruned_model
# Usage
pruned_model = production_pruning_pipeline(
model_name="meta-llama/Llama-2-7b-hf",
target_sparsity=0.5,
method="wanda"
)
from lm_eval import evaluator

# Evaluate the pruned model against the original on a few benchmarks.
original_results = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Llama-2-7b-hf",
    tasks=["arc_easy", "hellaswag", "winogrande"],
)
pruned_results = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=./pruned-llama-7b-50",
    tasks=["arc_easy", "hellaswag", "winogrande"],
)
# Compare accuracies. The original code subtracted the two result dicts
# directly (``original_results - pruned_results``), which raises a
# TypeError; subtract the scalar accuracy values instead.
orig_acc = original_results['results']['arc_easy']['acc']
pruned_acc = pruned_results['results']['arc_easy']['acc']
print(f"Original: {orig_acc:.3f}")
print(f"Pruned: {pruned_acc:.3f}")
print(f"Degradation: {orig_acc - pruned_acc:.3f}")
# Typical results at 50% sparsity:
# - Wanda: <1% accuracy loss
# - SparseGPT: <0.5% accuracy loss
# - Magnitude: 2-3% accuracy loss
# Conservative (safe)
sparsity = 0.3 # 30%, <0.5% loss
# Balanced (recommended)
sparsity = 0.5 # 50%, ~1% loss
# Aggressive (risky)
sparsity = 0.7 # 70%, 2-5% loss
# Extreme (model-dependent)
sparsity = 0.9 # 90%, significant degradation
# One-shot, no retraining → Wanda or SparseGPT
if no_retraining_budget:
use_method = "wanda" # Faster
# Best quality → SparseGPT
if need_best_quality:
use_method = "sparsegpt" # More accurate
# Hardware speedup → N:M structured
if need_speedup:
use_method = "nm_prune" # 2:4 or 4:8
# ❌ Bad: Pruning without calibration data
prune_random(model) # No activation statistics
# ✅ Good: Use calibration data
prune_wanda(model, calib_data)
# ❌ Bad: Too high sparsity in one shot
prune(model, sparsity=0.9) # Massive accuracy loss
# ✅ Good: Gradual or iterative
iterative_prune(model, target=0.9, steps=10)
Pruning methods at 50% sparsity (LLaMA-7B):
| Method | Accuracy Loss | Speed | Memory | Retraining Needed |
|---|---|---|---|---|
| Magnitude | -2.5% | 1.0× | -50% | No |
| Wanda | -0.8% | 1.0× | -50% | No |
| SparseGPT | -0.4% | 1.0× | -50% | No |
| N:M (2:4) | -1.0% | 2.0× | -50% | No |
| Structured | -3.0% | 2.0× | -50% | No |
Source : Wanda paper (ICLR 2024), SparseGPT paper
Weekly Installs
66
Repository
GitHub Stars
5.6K
First Seen
Feb 7, 2026
Security Audits
Gen Agent Trust Hub: Pass · Socket: Pass · Snyk: Warn
Installed on
opencode57
codex56
cursor56
gemini-cli55
claude-code55
github-copilot54
超能力技能使用指南:AI助手技能调用优先级与工作流程详解
53,700 周安装
JSON 转 React Email 渲染器:用 JSON 规范生成 HTML/纯文本邮件 | @json-render/react-email
8 周安装
runtime-context技能:AI智能体运行时环境检测与工具适配,实现跨平台兼容性
9 周安装
JSON 转 SVG/PNG 图像渲染器 - 使用 Satori 将 JSON 规范快速生成图片
9 周安装
学术论文写作智能体工具:12个AI智能体流水线,支持LaTeX/PDF输出,涵盖所有学科
63 周安装
check-version插件:70毫秒快速检查Claude插件版本,自动提示更新
9 周安装
对抗性代码审查工作流:Hunter/Skeptic/Referee 消除偏见,提升缺陷检测准确率
10 周安装