at-dispatch-v2 by pytorch/pytorch
npx skills add https://github.com/pytorch/pytorch --skill at-dispatch-v2
This skill helps convert PyTorch's legacy AT_DISPATCH macros to the new AT_DISPATCH_V2 format, as defined in aten/src/ATen/Dispatch_v2.h.
Use this skill when:
- Working on files under aten/src/ATen/native/ that use dispatch macros

Old format:
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, dtype, "kernel_name", [&]() {
// lambda body
});
New format:
AT_DISPATCH_V2(dtype, "kernel_name", AT_WRAP([&]() {
// lambda body
}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool);
Key differences:
- scalar_type and name come first, then the lambda, then the type list
- AT_WRAP(lambda) handles internal commas
- AT_EXPAND(AT_ALL_TYPES) replaces implicit expansion
- #include <ATen/Dispatch_v2.h> goes near other Dispatch includes

Add the v2 header near the existing #include <ATen/Dispatch.h>:
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
Keep the old Dispatch.h include for now (other code may still need it).
Common patterns to convert:
- AT_DISPATCH_ALL_TYPES_AND{2,3,4}(type1, type2, ..., scalar_type, name, lambda)
- AT_DISPATCH_FLOATING_TYPES_AND{2,3}(type1, type2, ..., scalar_type, name, lambda)
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND{2,3}(type1, ..., scalar_type, name, lambda)
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3}(type1, ..., scalar_type, name, lambda)

Identify which type group macro corresponds to the base types:
| Old macro base | AT_DISPATCH_V2 type group |
|---|---|
| ALL_TYPES | AT_EXPAND(AT_ALL_TYPES) |
| FLOATING_TYPES | AT_EXPAND(AT_FLOATING_TYPES) |
| INTEGRAL_TYPES | AT_EXPAND(AT_INTEGRAL_TYPES) |
| COMPLEX_TYPES | AT_EXPAND(AT_COMPLEX_TYPES) |
| ALL_TYPES_AND_COMPLEX | AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX) |
For combined patterns, use multiple AT_EXPAND() entries:
// Old: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(...)
// New: AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), type1, type2
From AT_DISPATCH_*_AND2(type1, type2, ...) or AT_DISPATCH_*_AND3(type1, type2, type3, ...), extract the individual types (type1, type2, etc.).
These become the trailing arguments after the type group:
AT_DISPATCH_V2(..., AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool)
^^^^^^^^^^^^^^^^^^^^^^^^
Individual types from AND3
Apply the transformation:
Pattern:
AT_DISPATCH_V2(
scalar_type, // 1st: The dtype expression
"name", // 2nd: The debug string
AT_WRAP(lambda), // 3rd: The lambda wrapped in AT_WRAP
type_groups, // 4th+: Type groups with AT_EXPAND()
individual_types // Last: Individual types
)
Example transformation:
// BEFORE
AT_DISPATCH_ALL_TYPES_AND3(
kBFloat16, kHalf, kBool,
iter.dtype(),
"min_values_cuda",
[&]() {
min_values_kernel_cuda_impl<scalar_t>(iter);
}
);
// AFTER
AT_DISPATCH_V2(
iter.dtype(),
"min_values_cuda",
AT_WRAP([&]() {
min_values_kernel_cuda_impl<scalar_t>(iter);
}),
AT_EXPAND(AT_ALL_TYPES),
kBFloat16, kHalf, kBool
);
For lambdas with internal commas or complex expressions, AT_WRAP is essential:
AT_DISPATCH_V2(
dtype,
"complex_kernel",
AT_WRAP([&]() {
gpu_reduce_kernel<scalar_t, scalar_t>(
iter,
MinOps<scalar_t>{},
thrust::pair<scalar_t, int64_t>(upper_bound(), 0) // Commas inside!
);
}),
AT_EXPAND(AT_ALL_TYPES)
);
Check that:
- AT_WRAP() wraps the entire lambda
- Every type group uses AT_EXPAND()
- Individual types are NOT wrapped in AT_EXPAND() (just kBFloat16, not AT_EXPAND(kBFloat16))
- The file includes #include <ATen/Dispatch_v2.h>

Available type group macros (use with AT_EXPAND()):
AT_INTEGRAL_TYPES // kByte, kChar, kInt, kLong, kShort
AT_FLOATING_TYPES // kDouble, kFloat
AT_COMPLEX_TYPES // kComplexDouble, kComplexFloat
AT_QINT_TYPES // kQInt8, kQUInt8, kQInt32
AT_ALL_TYPES // INTEGRAL_TYPES + FLOATING_TYPES
AT_ALL_TYPES_AND_COMPLEX // ALL_TYPES + COMPLEX_TYPES
AT_INTEGRAL_TYPES_V2 // INTEGRAL_TYPES + unsigned types
AT_BAREBONES_UNSIGNED_TYPES // kUInt16, kUInt32, kUInt64
AT_FLOAT8_TYPES // Float8 variants
// Before
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
kernel<scalar_t>(data);
});
// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>(data);
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
// Before
AT_DISPATCH_FLOATING_TYPES_AND3(kHalf, kBFloat16, kFloat8_e4m3fn,
tensor.scalar_type(), "float_op", [&] {
process<scalar_t>(tensor);
});
// After
AT_DISPATCH_V2(tensor.scalar_type(), "float_op", AT_WRAP([&] {
process<scalar_t>(tensor);
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn);
// Before
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
kComplexHalf, kHalf,
self.scalar_type(),
"complex_op",
[&] {
result = compute<scalar_t>(self);
}
);
// After
AT_DISPATCH_V2(
self.scalar_type(),
"complex_op",
AT_WRAP([&] {
result = compute<scalar_t>(self);
}),
AT_EXPAND(AT_ALL_TYPES),
AT_EXPAND(AT_COMPLEX_TYPES),
kComplexHalf,
kHalf
);
// Before
AT_DISPATCH_ALL_TYPES(dtype, "op", [&]() { kernel<scalar_t>(); });
// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES));
// Before
AT_DISPATCH_FLOATING_TYPES_AND4(kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2,
dtype, "float8_op", [&]() { kernel<scalar_t>(); });
// After
AT_DISPATCH_V2(dtype, "float8_op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2);
// Before
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, dtype, "op", []() {
static_kernel<scalar_t>();
});
// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([]() {
static_kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBool);
Tips:
- Mix and match type groups with AT_EXPAND()
- Keep the existing #include <ATen/Dispatch.h> - other code may need it
- AT_WRAP() is mandatory - it prevents comma parsing issues in the lambda
- Type groups need AT_EXPAND(); individual types don't
- The full list lives in aten/src/ATen/Dispatch_v2.h - refer to it for full docs

When asked to convert AT_DISPATCH macros:
- Add #include <ATen/Dispatch_v2.h> if not present
- Do NOT compile or test the code - focus on accurate conversion only
Weekly Installs: 203
Repository: https://github.com/pytorch/pytorch
GitHub Stars: 98.5K
First Seen: Jan 20, 2026
Security Audits: Gen Agent Trust Hub (Pass), Socket (Pass), Snyk (Pass)
Installed on:
claude-code: 190
opencode: 189
cursor: 187
gemini-cli: 184
codex: 184
github-copilot: 175