pytorch-fsdp2 by orchestra-research/ai-research-skills
npx skills add https://github.com/orchestra-research/ai-research-skills --skill pytorch-fsdp2
Use FSDP2 (fully_shard) correctly in a training script

This skill teaches a coding agent how to add PyTorch FSDP2 to a training loop with correct initialization, sharding, mixed precision/offload configuration, and checkpointing.
FSDP2 in PyTorch is exposed primarily via torch.distributed.fsdp.fully_shard and the FSDPModule methods it adds in-place to modules. See: references/pytorch_fully_shard_api.md, references/pytorch_fsdp2_tutorial.md.
Use FSDP2 when:
Avoid (or be careful) if:
Reference: references/pytorch_ddp_notes.md, references/pytorch_fsdp1_api.md.
- Launch with torchrun and set the CUDA device per process (usually via LOCAL_RANK).
- Apply fully_shard() bottom-up, i.e., shard submodules (e.g., Transformer blocks) before the root module.
- Call model(input), not model.forward(input), so the FSDP2 hooks run (unless you explicitly unshard() or register the forward method).
- Create the optimizer after sharding (after fully_shard).
- Do not torch.save(model.state_dict()) unless you deliberately gather to full tensors.

(Each of these rules is directly described in the official API docs/tutorial; see references.)
Launch with torchrun --nproc_per_node <gpus_per_node> ... and ensure RANK, WORLD_SIZE, LOCAL_RANK are visible.

Reference: references/pytorch_fsdp2_tutorial.md (launch commands and setup), references/pytorch_fully_shard_api.md (user contract).
Minimal, correct pattern:
- dist.init_process_group(backend="nccl")
- torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
- Create a DeviceMesh to describe the data-parallel group(s)

Reference: references/pytorch_device_mesh_tutorial.md (why DeviceMesh exists & how it manages process groups).
For big models, initialize on meta, apply sharding, then materialize weights on GPU:
- with torch.device("meta"): model = ...
- fully_shard(...) on submodules, then fully_shard(model)
- model.to_empty(device="cuda")
- model.reset_parameters() (or your init routine)

Reference: references/pytorch_fsdp2_tutorial.md (migration guide shows this flow explicitly).
Apply fully_shard() bottom-up (wrapping policy = “apply where needed”). Do not only call fully_shard on the topmost module.
Recommended sharding pattern for transformer-like models:
- For each submodule: if isinstance(m, TransformerBlock): fully_shard(m, ...)
- Then the root: fully_shard(model, ...)

Why:
fully_shard forms “parameter groups” for collective efficiency and excludes params already grouped by earlier calls. Bottom-up gives better overlap and lower peak memory.

Reference: references/pytorch_fully_shard_api.md (bottom-up requirement and why).
Tune reshard_after_forward for memory/perf trade-offs.

Default behavior:
- None means True for non-root modules and False for root modules (good default).

Heuristics:
- Memory-bound: set True on many blocks.
- Throughput-bound (and memory permits): keep parameters gathered after forward (False).
- Pass an int to reshard to a smaller mesh after forward (e.g., intra-node) if it’s a meaningful divisor.

Reference: references/pytorch_fully_shard_api.md (full semantics).
FSDP2 uses:
- mp_policy=MixedPrecisionPolicy(param_dtype=..., reduce_dtype=..., output_dtype=..., cast_forward_inputs=...)
- offload_policy=CPUOffloadPolicy() if you want CPU offload

Rules of thumb:
- Keep reduce_dtype aligned with your gradient reduction expectations.

Reference: references/pytorch_fully_shard_api.md (MixedPrecisionPolicy / OffloadPolicy classes).
For gradient accumulation, toggle gradient synchronization (set_requires_gradient_sync) instead of FSDP1’s no_sync().

Gradient clipping:
Reference: references/pytorch_fsdp2_tutorial.md.
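Gradient accumulation with sync toggling might look like this (a sketch; assumes model was wrapped by fully_shard, which adds the FSDPModule methods in-place):

```python
import torch


def accumulate_step(model, optimizer, microbatches, loss_fn):
    for i, (inputs, targets) in enumerate(microbatches):
        is_last = i == len(microbatches) - 1
        # Skip the gradient reduce-scatter on all but the final microbatch.
        model.set_requires_gradient_sync(is_last)
        loss = loss_fn(model(inputs), targets)  # model(inputs), never model.forward(...)
        loss.backward()
    # Clipping operates on the DTensor gradients (see the FSDP2 tutorial).
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()
```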
Two recommended approaches:
A) Distributed Checkpoint (DCP) — best default
B) Distributed state dict helpers
- get_model_state_dict / set_model_state_dict with StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True, ...)
- get_optimizer_state_dict / set_optimizer_state_dict

Avoid:
- Saving DTensor state dicts with plain torch.save unless you intentionally convert with DTensor.full_tensor() and manage memory carefully.

References:
- references/pytorch_dcp_overview.md (DCP behavior and caveats)
- references/pytorch_dcp_recipe.md and references/pytorch_dcp_async_recipe.md (end-to-end usage)
- references/pytorch_fsdp2_tutorial.md (DTensor vs DCP state-dict flows)
- references/pytorch_examples_fsdp2.md (working checkpoint scripts)

Quickstart checklist:
- Launch with torchrun and initialize the process group.
- Set the CUDA device from LOCAL_RANK; create a DeviceMesh if you need multi-dim parallelism.
- Build the model (on meta if needed), apply fully_shard bottom-up, then fully_shard(model).
- Call model(inputs) so hooks run; use set_requires_gradient_sync for accumulation.
- Add DCP save/load with torch.distributed.checkpoint helpers.

Reference: references/pytorch_fsdp2_tutorial.md, references/pytorch_fully_shard_api.md, references/pytorch_device_mesh_tutorial.md, references/pytorch_dcp_recipe.md.
- Wrap state in Stateful objects or assemble state via get_state_dict.
- Call dcp.save(...) from all ranks to a shared path.
- On load, call dcp.load(...) and restore with set_state_dict.

Reference: references/pytorch_dcp_recipe.md.
Troubleshooting:
- Check torch.cuda.set_device(LOCAL_RANK) and your torchrun flags.
- Calling forward() directly? Use model(input) or explicitly unshard() / register forward.
- Is fully_shard() applied bottom-up?
- Don’t mix DTensor state dicts with torch.save unless you understand conversions.

Common fixes:
- Use model(inputs) (or unshard() explicitly) instead of model.forward(...).
- Create the optimizer after the fully_shard calls.
- Apply fully_shard bottom-up on submodules before the root.
- Set reshard_after_forward=True for more modules.
- Use set_requires_gradient_sync instead of FSDP1’s no_sync().

Reference: references/pytorch_fully_shard_api.md, references/pytorch_fsdp2_tutorial.md.
The coding agent should implement a script with these labeled blocks:
- init_distributed(): init process group, set device
- build_model_meta(): model on meta, apply fully_shard, materialize weights
- build_optimizer(): optimizer created after sharding
- train_step(): forward/backward/step with model(inputs) and DTensor-aware patterns
- checkpoint_save/load(): DCP or distributed state dict helpers

Concrete examples live in references/pytorch_examples_fsdp2.md and the official tutorial references.
References:
- references/pytorch_fsdp2_tutorial.md
- references/pytorch_fully_shard_api.md
- references/pytorch_ddp_notes.md
- references/pytorch_fsdp1_api.md
- references/pytorch_device_mesh_tutorial.md
- references/pytorch_tp_tutorial.md
- references/pytorch_dcp_overview.md
- references/pytorch_dcp_recipe.md
- references/pytorch_dcp_async_recipe.md
- references/pytorch_examples_fsdp2.md
- references/torchtitan_fsdp_notes.md (optional, production notes)
- references/ray_train_fsdp2_example.md (optional, integration example)

Weekly Installs: 64
GitHub Stars: 5.5K
First Seen: Feb 7, 2026
Security Audits: Gen Agent Trust Hub: Pass; Socket: Pass; Snyk: Pass
Installed on: opencode (55), codex (54), cursor (54), gemini-cli (53), claude-code (53), github-copilot (52)