metal-kernel by pytorch/pytorch
npx skills add https://github.com/pytorch/pytorch --skill metal-kernel
This skill guides you through implementing Metal kernels for PyTorch operators on Apple Silicon.
Important: The goal of this skill is to use native Metal capabilities via the c10/metal/ infrastructure, NOT MPSGraph. Native Metal kernels provide better control, performance, and maintainability.
There are two workflows covered by this skill:
- Implementing a new operator with a native Metal kernel
- Migrating an existing MPSGraph-based operator to native Metal
Both workflows involve:
- Updating dispatch in aten/src/ATen/native/native_functions.yaml
- Writing the Metal kernel in aten/src/ATen/native/mps/kernels/
- Implementing the host-side stub in aten/src/ATen/native/mps/operations/
Location: aten/src/ATen/native/native_functions.yaml
Find the operator entry and add MPS dispatch:
# Simple MPS-specific implementation
- func: my_op(Tensor self) -> Tensor
dispatch:
CPU: my_op_cpu
CUDA: my_op_cuda
MPS: my_op_mps
# Shared implementation across devices (preferred for structured kernels)
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA, MPS: my_op_out
# Structured kernel (preferred for new ops)
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA, MPS: my_op_out
When migrating an existing operator from MPSGraph to native Metal, consolidate the dispatch entry:
# BEFORE (MPSGraph-based, separate dispatch)
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: atan2_out
MPS: atan2_out_mps # Separate MPS implementation
# AFTER (native Metal, shared dispatch via stub)
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA, MPS: atan2_out # MPS now uses the same stub mechanism
Key change: Replace MPS: my_op_out_mps with adding MPS to the shared dispatch line (e.g., CPU, CUDA, MPS: my_op_out).
Dispatch naming conventions:
- MPS: function_name_mps - MPS-specific implementation (old MPSGraph pattern)
- CPU, CUDA, MPS: function_name - Shared stub implementation (native Metal pattern)
Location: aten/src/ATen/native/mps/kernels/
// MyKernel.metal
#include <c10/metal/indexing.h>
#include <c10/metal/utils.h>
#include <metal_stdlib>
using namespace metal;
using namespace c10::metal;
// Define operation functor
struct my_op_functor {
template <typename T>
inline T operator()(const T x) {
return /* your operation */;
}
};
// Register for supported types
REGISTER_UNARY_OP(my_op, float, float);
REGISTER_UNARY_OP(my_op, half, half);
REGISTER_UNARY_OP(my_op, bfloat, bfloat);
struct my_binary_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return /* your operation */;
}
};
REGISTER_BINARY_OP(my_binary, float, float);
REGISTER_BINARY_OP(my_binary, half, half);
For binary operations, use the convenience macros defined in BinaryKernel.metal:
// Floating-point types only (float, half, bfloat)
REGISTER_FLOAT_BINARY_OP(my_op);
// Integral types with float output (for math ops like atan2, copysign)
// Registers: long->float, int->float, short->float, uchar->float, char->float, bool->float
REGISTER_INT2FLOAT_BINARY_OP(my_op);
// Integral types with same-type output (for bitwise/logical ops)
// Registers: long, int, short, uchar, char, bool
REGISTER_INTEGER_BINARY_OP(my_op);
// Floating-point with opmath precision (for ops needing higher precision)
REGISTER_OPMATH_FLOAT_BINARY_OP(my_op);
Common patterns:
- REGISTER_FLOAT_BINARY_OP and REGISTER_INT2FLOAT_BINARY_OP
- REGISTER_FLOAT_BINARY_OP and REGISTER_INTEGER_BINARY_OP
Example for atan2 (supports both float and int inputs):
struct atan2_functor {
template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
inline T operator()(const T a, const T b) {
return static_cast<T>(precise::atan2(float(a), float(b)));
}
template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
inline float operator()(const T a, const T b) {
return precise::atan2(float(a), float(b));
}
};
REGISTER_FLOAT_BINARY_OP(atan2);
REGISTER_INT2FLOAT_BINARY_OP(atan2);
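The int-to-float promotion that REGISTER_INT2FLOAT_BINARY_OP encodes is the same convention other numeric libraries follow; as a quick reference (not Metal code), NumPy's arctan2 promotes integer inputs to a floating-point result in just this way:

```python
import math
import numpy as np

# Integer inputs are promoted to a floating-point result, matching the
# long->float, int->float, ... registrations described above.
a = np.array([1, 0], dtype=np.int32)
b = np.array([1, 1], dtype=np.int32)
out = np.arctan2(a, b)

print(out.dtype)  # float64
print(out)        # [pi/4, 0.0]
```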
struct my_alpha_functor {
template <typename T>
inline T operator()(const T a, const T b, const T alpha) {
return a + c10::metal::mul(alpha, b);
}
};
REGISTER_UNARY_ALPHA_OP(my_alpha, float, float, float);
REGISTER_UNARY_ALPHA_OP(my_alpha, half, half, half);
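The functor body above computes a + alpha * b elementwise, the same semantics as the alpha argument of torch.add. A plain-Python sketch of that arithmetic (the helper name here is illustrative, not part of the c10/metal API):

```python
# Reference semantics of the alpha functor: out = a + alpha * b,
# applied elementwise, mirroring torch.add(a, b, alpha=...).
def add_alpha(a, b, alpha):
    return [x + alpha * y for x, y in zip(a, b)]

print(add_alpha([1.0, 2.0], [10.0, 20.0], 0.5))  # [6.0, 12.0]
```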
struct special_functor {
// Floating point types
template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
inline T operator()(const T x) {
return precise::exp(x); // Use precise math
}
// Integral types
template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
inline float operator()(const T x) {
return precise::exp(float(x));
}
// Complex types (float2 for cfloat, half2 for chalf)
template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
inline T operator()(const T x) {
// x.x = real, x.y = imaginary
return T(/* real */, /* imag */);
}
};
Note on complex types: Complex numbers in Metal are represented as vector types:
- c10::complex<float> maps to float2 (x = real, y = imaginary)
- c10::complex<half> maps to half2
Use is_complex_v<T> to specialize for complex types in functors.
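To see what a complex functor must compute on the (x, y) = (real, imag) pair, here is a plain-Python model of the complex exponential over such pairs; this only illustrates the arithmetic a Metal functor would perform on a float2, it is not Metal code:

```python
import cmath
import math

# Model a Metal float2 complex value as a (real, imag) tuple and compute
# exp(z) componentwise: exp(re) * (cos(im), sin(im)).
def complex_exp(z):
    re, im = z
    scale = math.exp(re)
    return (scale * math.cos(im), scale * math.sin(im))

z = (1.0, 0.5)
got = complex_exp(z)
ref = cmath.exp(complex(*z))
print(got)  # approximately (2.3855, 1.3033)
```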
utils.h:
- opmath_t<T> - Operation math type (half->float)
- accum_t<T> - Accumulation type for reductions
- max(), min() with NaN propagation
special_math.h:
- precise::exp(), precise::log(), precise::sqrt()
- precise::sin(), precise::cos(), precise::tan()
- erf(), erfc(), erfinv()
indexing.h:
- REGISTER_UNARY_OP(name, in_type, out_type)
- REGISTER_BINARY_OP(name, in_type, out_type)
- REGISTER_UNARY_ALPHA_OP(name, in_type, alpha_type, out_type)
Location: aten/src/ATen/native/mps/operations/
Choose or create an appropriate file based on operation type:
- UnaryKernel.mm - Single-input operations via stub dispatch
- BinaryKernel.mm - Two-input operations via stub dispatch
- UnaryOps.mm / BinaryOps.mm - Legacy MPSGraph implementations (for reference)
- ReduceOps.mm - Reductions (sum, mean, max, etc.)
For structured kernels that use the TensorIterator pattern:
// In BinaryKernel.mm (or appropriate file)
static void my_op_mps_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "my_op"); // "my_op" matches the functor name in .metal
}
// Register the MPS stub - this connects to the dispatch system
REGISTER_DISPATCH(my_op_stub, &my_op_mps_kernel)
For unary operations:
static void my_unary_mps_kernel(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "my_unary");
}
REGISTER_DISPATCH(my_unary_stub, &my_unary_mps_kernel)
When migrating from MPSGraph, also remove the old implementation:
Remove from BinaryOps.mm (or UnaryOps.mm):
- TORCH_IMPL_FUNC(my_op_out_mps) implementation
- #include <ATen/ops/my_op_native.h> header
Add to BinaryKernel.mm (or UnaryKernel.mm):
- REGISTER_DISPATCH call
After making changes, compile to verify everything builds correctly:
cd build && ninja torch_cpu
Basic operator support is already tested by test_output_match in test/test_mps.py. After implementing an operator, enable testing by removing expected failures:
Location: torch/testing/_internal/common_mps.py
Find and remove the operator from skip/xfail lists:
# Remove entries like:
MPS_XFAILLIST = {
"my_op": ..., # Remove this line
}
MPS_SKIPLIST = {
"my_op": ..., # Remove this line
}
Location: torch/testing/_internal/common_methods_invocations.py (or related files)
Remove MPS-specific decorators from the OpInfo:
OpInfo(
"my_op",
# Remove decorators like:
# decorators=[skipMPS, expectedFailureMPS("reason")],
...
)
# Run the specific operator test
python test/test_mps.py -k test_output_match_my_op
# Or run full MPS test suite
python test/test_mps.py
Debugging Metal kernels with torch.mps.compile_shader
Use torch.mps.compile_shader to JIT-compile and test individual Metal kernels in isolation. This is invaluable for debugging multi-kernel pipelines where you need to verify each stage independently.
import torch
source = '''
#include <metal_stdlib>
using namespace metal;
kernel void my_kernel(
const device float* input [[buffer(0)]],
device float* output [[buffer(1)]],
uint tid [[thread_position_in_grid]]) {
output[tid] = input[tid] * 2.0;
}
'''
lib = torch.mps.compile_shader(source)
inp = torch.tensor([1.0, 2.0, 3.0], device='mps')
out = torch.zeros(3, device='mps')
lib.my_kernel(inp, out, threads=[3, 1, 1], group_size=[3, 1, 1])
torch.mps.synchronize()
print(out) # tensor([2., 4., 6.], device='mps:0')
compile_shader uses dispatchThreads semantics (same as mtl_dispatch1DJob in PyTorch):
- threads=[N, 1, 1] — total number of threads (NOT threadgroups)
- group_size=[G, 1, 1] — threads per threadgroup
This differs from the dispatchThreadgroups API used by some host-side code. To match dispatchThreadgroups:MTLSizeMake(num_tgs, num_slices, 1) threadsPerThreadgroup:MTLSizeMake(TG_SIZE, 1, 1):
# Equivalent compile_shader call:
lib.kernel(args...,
threads=[num_tgs * TG_SIZE, num_slices, 1],
group_size=[TG_SIZE, 1, 1])
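The conversion between the two dispatch conventions is mechanical, so a small helper can make the mapping explicit (the helper name is illustrative, not a PyTorch API):

```python
# Convert dispatchThreadgroups-style arguments (number of threadgroups
# plus threadgroup size) into compile_shader's dispatchThreads-style
# threads / group_size lists.
def to_dispatch_threads(num_tgs, tg_size, num_slices=1):
    threads = [num_tgs * tg_size, num_slices, 1]
    group_size = [tg_size, 1, 1]
    return threads, group_size

threads, group_size = to_dispatch_threads(num_tgs=5, tg_size=256)
print(threads)     # [1280, 1, 1]
print(group_size)  # [256, 1, 1]
```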
Pass scalar constants as single-element tensors:
slice_size = torch.tensor([1024], dtype=torch.int32, device='mps')
lib.my_kernel(data, output, slice_size, threads=[1024, 1, 1], group_size=[256, 1, 1])
When a pipeline of kernels (e.g., histogram → prefix_sum → scatter) produces wrong results, test each kernel individually and verify its output against a Python/NumPy reference:
# 1. Run GPU kernel
lib.histogram(keys, hist, ..., threads=[N, 1, 1], group_size=[256, 1, 1])
torch.mps.synchronize()
# 2. Compute reference in Python
ref_hist = compute_histogram_cpu(keys.cpu().numpy(), ...)
# 3. Compare
assert np.array_equal(hist.cpu().numpy(), ref_hist), "Histogram mismatch!"
This isolates which kernel in the pipeline is broken, rather than debugging the entire pipeline at once.
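The reference in step 2 can be any CPU implementation you trust. For the histogram stage, one possible sketch of the hypothetical compute_histogram_cpu helper used in the snippet above, assuming keys are non-negative integers below num_bins:

```python
import numpy as np

# Minimal CPU reference for a histogram kernel: count how many keys
# fall into each of num_bins buckets. Assumes keys are non-negative
# integers strictly less than num_bins.
def compute_histogram_cpu(keys, num_bins):
    hist = np.zeros(num_bins, dtype=np.int64)
    for k in keys:
        hist[k] += 1
    return hist

keys = np.array([0, 2, 2, 1, 3, 2], dtype=np.int64)
print(compute_histogram_cpu(keys, 4))  # [1 1 3 1]
```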
Common gotchas:
- threads count — threads is total threads, not threadgroups. For 5 threadgroups of 256, use threads=[1280, 1, 1].
- compile_shader doesn't support [[threadgroup(N)]] parameters directly. If your kernel needs threadgroup memory, restructure to use threadgroup arrays declared inside the kernel body instead.
Final checklist:
- Updated dispatch in native_functions.yaml
- Implemented the Metal kernel in kernels/
- Implemented the host-side operator in operations/
- Removed expected failures from torch/testing/_internal/common_mps.py