torch-geometric by davila7/claude-code-templates
npx skills add https://github.com/davila7/claude-code-templates --skill torch-geometric

PyTorch Geometric 是一个基于 PyTorch 构建的库,用于开发和训练图神经网络(GNN)。应用此技能进行图和不规则结构的深度学习,包括小批量处理、多 GPU 训练和几何深度学习应用。
此技能应在处理以下任务时使用:
uv pip install torch_geometric
对于额外的依赖项(稀疏操作、聚类):
uv pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
import torch
from torch_geometric.data import Data
# 创建一个包含 3 个节点的简单图
edge_index = torch.tensor([[0, 1, 1, 2], # 源节点
[1, 0, 2, 1]], dtype=torch.long) # 目标节点
x = torch.tensor([[-1], [0], [1]], dtype=torch.float) # 节点特征
data = Data(x=x, edge_index=edge_index)
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")
广告位招租
在这里展示您的产品或服务
触达数万 AI 开发者,精准高效
from torch_geometric.datasets import Planetoid
# 加载 Cora 引文网络
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0] # 获取第一个(也是唯一一个)图
print(f"Dataset: {dataset}")
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")
print(f"Features: {data.num_node_features}, Classes: {dataset.num_classes}")
PyG 使用 torch_geometric.data.Data 类表示图,具有以下关键属性:
data.x:节点特征矩阵 [num_nodes, num_node_features];data.edge_index:COO 格式的图连接性 [2, num_edges];data.edge_attr:边特征矩阵 [num_edges, num_edge_features](可选);data.y:节点或图的目标标签;data.pos:节点空间位置 [num_nodes, num_dimensions](可选);以及自定义属性(如 data.train_mask、data.batch)。重要提示:这些属性不是强制性的——可以根据需要扩展 Data 对象以包含自定义属性。
边以 COO(坐标)格式存储为 [2, num_edges] 张量:
第一行:源节点索引
第二行:目标节点索引
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
PyG 通过创建块对角邻接矩阵来处理批处理,将多个图连接成一个大的不连通图:
邻接矩阵沿对角线堆叠
节点特征沿节点维度连接
batch 向量将每个节点映射到其源图
无需填充——计算效率高
from torch_geometric.loader import DataLoader
loader = DataLoader(dataset, batch_size=32, shuffle=True) for batch in loader: print(f"Batch size: {batch.num_graphs}") print(f"Total nodes: {batch.num_nodes}") # batch.batch 将节点映射到图
PyG 中的 GNN 遵循邻域聚合方案:
PyG 提供了 40 多种卷积层。常见的包括:
GCNConv(图卷积网络):
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
class GCN(torch.nn.Module):
    """Two-layer Graph Convolutional Network for node classification.

    Args:
        num_features: Size of each input node feature vector.
        num_classes: Number of target classes.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        # Hidden width of 16 follows the original GCN recipe.
        self.conv1 = GCNConv(num_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        features, edge_index = data.x, data.edge_index
        hidden = F.relu(self.conv1(features, edge_index))
        # Dropout is active only while self.training is True.
        hidden = F.dropout(hidden, training=self.training)
        logits = self.conv2(hidden, edge_index)
        return F.log_softmax(logits, dim=1)
GATConv(图注意力网络):
from torch_geometric.nn import GATConv
class GAT(torch.nn.Module):
    """Two-layer Graph Attention Network for node classification."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        # 8 heads of width 8 are concatenated, giving a hidden size of 64.
        self.conv1 = GATConv(num_features, 8, heads=8, dropout=0.6)
        # The output layer uses a single, non-concatenated head.
        self.conv2 = GATConv(8 * 8, num_classes, heads=1, concat=False, dropout=0.6)

    def forward(self, data):
        h, edge_index = data.x, data.edge_index
        h = F.dropout(h, p=0.6, training=self.training)
        h = F.elu(self.conv1(h, edge_index))
        h = F.dropout(h, p=0.6, training=self.training)
        return F.log_softmax(self.conv2(h, edge_index), dim=1)
GraphSAGE :
from torch_geometric.nn import SAGEConv
class GraphSAGE(torch.nn.Module):
    """Two-layer GraphSAGE model for node classification."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = SAGEConv(num_features, 64)
        self.conv2 = SAGEConv(64, num_classes)

    def forward(self, data):
        h = F.relu(self.conv1(data.x, data.edge_index))
        # Dropout only applies in training mode.
        h = F.dropout(h, training=self.training)
        h = self.conv2(h, data.edge_index)
        return F.log_softmax(h, dim=1)
对于自定义层,继承自 MessagePassing:
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class CustomConv(MessagePassing):
    """GCN-style message-passing layer built from MessagePassing primitives."""

    def __init__(self, in_channels, out_channels):
        # Aggregation may be "add", "mean", or "max"; "add" matches GCN.
        super().__init__(aggr='add')
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # Self-loops let every node keep its own features during aggregation.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x)
        # Symmetric normalization: deg^{-1/2} on both endpoints of each edge.
        src, dst = edge_index
        deg = degree(dst, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[src] * deg_inv_sqrt[dst]
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j holds the source-node features for each edge (the "_j" suffix
        # maps tensors to source nodes automatically).
        return norm.view(-1, 1) * x_j
关键方法:
forward():主要入口点;message():从源节点到目标节点构建消息;aggregate():聚合消息(通常不重写——设置 aggr 参数即可);update():聚合后更新节点嵌入。变量命名约定:在张量名称后附加 _i 或 _j 会自动将它们映射到目标或源节点。
PyG 提供了广泛的基准数据集:
# 引文网络(节点分类)
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora') # 或 'CiteSeer', 'PubMed'
# 图分类
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
# 分子数据集
from torch_geometric.datasets import QM9
dataset = QM9(root='/tmp/QM9')
# 大规模数据集
from torch_geometric.datasets import Reddit
dataset = Reddit(root='/tmp/Reddit')
查看 references/datasets_reference.md 获取完整列表。
对于可以放入内存的数据集,继承自 InMemoryDataset:
from torch_geometric.data import InMemoryDataset, Data
import torch
class MyOwnDataset(InMemoryDataset):
    """Minimal InMemoryDataset template: download, process, save, load."""

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        # Load the already-processed collated data from disk.
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Files expected inside raw_dir; download() runs when they are missing.
        return ['my_data.csv']

    @property
    def processed_file_names(self):
        # Files expected inside processed_dir; process() runs when missing.
        return ['data.pt']

    def download(self):
        # Fetch raw data into self.raw_dir (no-op in this template).
        pass

    def process(self):
        # Build the list of Data objects; here a single two-node graph.
        edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
        sample = Data(x=torch.randn(2, 16),
                      edge_index=edge_index,
                      y=torch.tensor([0], dtype=torch.long))
        data_list = [sample]
        # Honor the optional user-supplied filter and transform hooks.
        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]
        self.save(data_list, self.processed_paths[0])
对于无法放入内存的大型数据集,继承自 Dataset 并实现 len() 和 get(idx)。
import pandas as pd
import torch
# BUG FIX: this snippet constructs a `Data` object below, so `Data` must be
# imported (the original imported only `HeteroData`, leaving `Data` undefined).
from torch_geometric.data import Data

# Load node features from CSV into a float tensor [num_nodes, 2].
nodes_df = pd.read_csv('nodes.csv')
x = torch.tensor(nodes_df[['feat1', 'feat2']].values, dtype=torch.float)
# Load edges from CSV into COO format [2, num_edges].
edges_df = pd.read_csv('edges.csv')
edge_index = torch.tensor([edges_df['source'].values,
                           edges_df['target'].values], dtype=torch.long)
data = Data(x=x, edge_index=edge_index)
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
# Load the Cora citation dataset (single-graph node classification).
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
# Create the model (the GCN class defined earlier) and an Adam optimizer
# with L2 weight decay.
model = GCN(dataset.num_features, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# Training: full-batch gradient descent restricted to the training mask.
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data)
# NLL loss pairs with the model's log_softmax output.
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
if epoch % 10 == 0:
print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
# Evaluation: accuracy on the held-out test mask.
# NOTE(review): consider wrapping inference in torch.no_grad() to avoid
# tracking gradients during evaluation.
model.eval()
pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Test Accuracy: {acc:.4f}')
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool
class GraphClassifier(torch.nn.Module):
    """GCN encoder + global mean pooling + linear head for graph classification."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.conv2 = GCNConv(64, 64)
        self.lin = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        h, edge_index, batch = data.x, data.edge_index, data.batch
        h = F.relu(self.conv1(h, edge_index))
        h = F.relu(self.conv2(h, edge_index))
        # Collapse node embeddings into one vector per graph in the batch.
        graph_repr = global_mean_pool(h, batch)
        return F.log_softmax(self.lin(graph_repr), dim=1)
# 加载数据集
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = GraphClassifier(dataset.num_features, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# 训练
model.train()
for epoch in range(100):
total_loss = 0
for batch in loader:
optimizer.zero_grad()
out = model(batch)
loss = F.nll_loss(out, batch.y)
loss.backward()
optimizer.step()
total_loss += loss.item()
if epoch % 10 == 0:
print(f'Epoch {epoch}, Loss: {total_loss / len(loader):.4f}')
对于大型图,使用 NeighborLoader 来采样子图:
from torch_geometric.loader import NeighborLoader
# 创建邻居采样器
train_loader = NeighborLoader(
data,
num_neighbors=[25, 10], # 第一跳采样 25 个邻居,第二跳采样 10 个邻居
batch_size=128,
input_nodes=data.train_mask,
)
# 训练
model.train()
for batch in train_loader:
optimizer.zero_grad()
out = model(batch)
# 仅在种子节点(前 batch_size 个节点)上计算损失
loss = F.nll_loss(out[:batch.batch_size], batch.y[:batch.batch_size])
loss.backward()
optimizer.step()
重要提示 :
对于具有多种节点和边类型的图,使用 HeteroData:
from torch_geometric.data import HeteroData
data = HeteroData()
# 为不同类型添加节点特征
data['paper'].x = torch.randn(100, 128) # 100 篇论文,128 个特征
data['author'].x = torch.randn(200, 64) # 200 位作者,64 个特征
# 为不同类型添加边(源类型, 边类型, 目标类型)
data['author', 'writes', 'paper'].edge_index = torch.randint(0, 200, (2, 500))
data['paper', 'cites', 'paper'].edge_index = torch.randint(0, 100, (2, 300))
print(data)
将同构模型转换为异构:
from torch_geometric.nn import to_hetero
# 定义同构模型
model = GNN(...)
# 转换为异构
model = to_hetero(model, data.metadata(), aggr='sum')
# 正常使用
out = model(data.x_dict, data.edge_index_dict)
或者使用 HeteroConv 进行自定义的边类型特定操作:
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv
class HeteroGNN(torch.nn.Module):
    """Two HeteroConv layers with a dedicated operator per edge type."""

    def __init__(self, metadata):
        # `metadata` is accepted for API symmetry but unused in this body.
        super().__init__()
        # -1 lets PyG infer input channel sizes lazily on the first forward.
        self.conv1 = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, 64),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
        }, aggr='sum')
        self.conv2 = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(64, 32),
            ('author', 'writes', 'paper'): SAGEConv((64, 64), 32),
        }, aggr='sum')

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {node_type: F.relu(h) for node_type, h in x_dict.items()}
        return self.conv2(x_dict, edge_index_dict)
应用变换来修改图结构或特征:
from torch_geometric.transforms import NormalizeFeatures, AddSelfLoops, Compose
# 单一变换
transform = NormalizeFeatures()
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)
# 组合多个变换
transform = Compose([
AddSelfLoops(),
NormalizeFeatures(),
])
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)
常见变换:
结构类:ToUndirected、AddSelfLoops、RemoveSelfLoops、KNNGraph、RadiusGraph;特征类:NormalizeFeatures、NormalizeScale、Center;划分类:RandomNodeSplit、RandomLinkSplit;位置编码类:AddLaplacianEigenvectorPE、AddRandomWalkPE。查看 references/transforms_reference.md 获取完整列表。
PyG 提供可解释性工具来理解模型预测:
from torch_geometric.explain import Explainer, GNNExplainer
# 创建解释器
explainer = Explainer(
model=model,
algorithm=GNNExplainer(epochs=200),
explanation_type='model', # 或 'phenomenon'
node_mask_type='attributes',
edge_mask_type='object',
model_config=dict(
mode='multiclass_classification',
task_level='node',
return_type='log_probs',
),
)
# 为特定节点生成解释
node_idx = 10
explanation = explainer(data.x, data.edge_index, index=node_idx)
# 可视化
print(f'Node {node_idx} explanation:')
print(f'Important edges: {explanation.edge_mask.topk(5).indices}')
print(f'Important features: {explanation.node_mask[node_idx].topk(5).indices}')
用于分层图表示:
from torch_geometric.nn import TopKPooling, global_mean_pool
class HierarchicalGNN(torch.nn.Module):
    """Hierarchical GNN alternating GCN layers with TopK pooling."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.pool1 = TopKPooling(64, ratio=0.8)
        self.conv2 = GCNConv(64, 64)
        self.pool2 = TopKPooling(64, ratio=0.8)
        self.lin = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        h, edge_index, batch = data.x, data.edge_index, data.batch
        h = F.relu(self.conv1(h, edge_index))
        # Each pooling step keeps the top 80% of nodes and remaps edges/batch.
        h, edge_index, _, batch, _, _ = self.pool1(h, edge_index, None, batch)
        h = F.relu(self.conv2(h, edge_index))
        h, edge_index, _, batch, _, _ = self.pool2(h, edge_index, None, batch)
        h = global_mean_pool(h, batch)
        return F.log_softmax(self.lin(h), dim=1)
# Check whether every edge has a reverse counterpart.
from torch_geometric.utils import is_undirected
print(f"Is undirected: {is_undirected(data.edge_index)}")
# Connected components.
# NOTE(review): `connected_components` may not be exported by
# torch_geometric.utils in all PyG versions — confirm against the installed
# version (alternatively use to_networkx + networkx.connected_components).
from torch_geometric.utils import connected_components
print(f"Connected components: {connected_components(data.edge_index)}")
# Check for self-loop edges (i -> i).
from torch_geometric.utils import contains_self_loops
print(f"Has self-loops: {contains_self_loops(data.edge_index)}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = data.to(device)
# 对于 DataLoader
for batch in loader:
batch = batch.to(device)
# 训练...
# 保存
torch.save(model.state_dict(), 'model.pth')
# 加载
model = GCN(num_features, num_classes)
model.load_state_dict(torch.load('model.pth'))
model.eval()
选择层时,请考虑以下能力:
查看 references/layer_capabilities.md 中的 GNN 速查表。
此技能包含详细的参考文档:
references/layers_reference.md : 所有 40 多种 GNN 层的完整列表,包含描述和能力references/datasets_reference.md : 按类别组织的综合数据集目录references/transforms_reference.md : 所有可用的变换及其用例references/api_patterns.md : 常见的 API 模式和编码示例scripts/ 目录中提供了实用脚本:
scripts/visualize_graph.py : 使用 networkx 和 matplotlib 可视化图结构scripts/create_gnn_template.py : 为常见的 GNN 架构生成样板代码scripts/benchmark_model.py : 在标准数据集上对模型性能进行基准测试直接执行脚本或阅读它们以获取实现模式。
每周安装次数
138
代码库
GitHub 星标数
22.6K
首次出现
Jan 21, 2026
安全审计
安装于
claude-code117
opencode111
gemini-cli103
cursor103
codex95
antigravity94
PyTorch Geometric is a library built on PyTorch for developing and training Graph Neural Networks (GNNs). Apply this skill for deep learning on graphs and irregular structures, including mini-batch processing, multi-GPU training, and geometric deep learning applications.
This skill should be used when working with:
uv pip install torch_geometric
For additional dependencies (sparse operations, clustering):
uv pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
import torch
from torch_geometric.data import Data
# Create a simple graph with 3 nodes
edge_index = torch.tensor([[0, 1, 1, 2], # source nodes
[1, 0, 2, 1]], dtype=torch.long) # target nodes
x = torch.tensor([[-1], [0], [1]], dtype=torch.float) # node features
data = Data(x=x, edge_index=edge_index)
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")
from torch_geometric.datasets import Planetoid
# Load Cora citation network
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0] # Get the first (and only) graph
print(f"Dataset: {dataset}")
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")
print(f"Features: {data.num_node_features}, Classes: {dataset.num_classes}")
PyG represents graphs using the torch_geometric.data.Data class with these key attributes:
data.x: node feature matrix [num_nodes, num_node_features]; data.edge_index: graph connectivity in COO format [2, num_edges]; data.edge_attr: edge feature matrix [num_edges, num_edge_features] (optional); data.y: target labels for nodes or graphs; data.pos: node spatial positions [num_nodes, num_dimensions] (optional); plus custom attributes such as data.train_mask and data.batch. Important: these attributes are not mandatory—extend Data objects with custom attributes as needed.
Edges are stored in COO (coordinate) format as a [2, num_edges] tensor:
First row: source node indices
Second row: target node indices
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
PyG handles batching by creating block-diagonal adjacency matrices, concatenating multiple graphs into one large disconnected graph:
Adjacency matrices are stacked diagonally
Node features are concatenated along the node dimension
A batch vector maps each node to its source graph
No padding needed—computationally efficient
from torch_geometric.loader import DataLoader
loader = DataLoader(dataset, batch_size=32, shuffle=True) for batch in loader: print(f"Batch size: {batch.num_graphs}") print(f"Total nodes: {batch.num_nodes}") # batch.batch maps nodes to graphs
GNNs in PyG follow a neighborhood aggregation scheme:
PyG provides 40+ convolutional layers. Common ones include:
GCNConv (Graph Convolutional Network):
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
class GCN(torch.nn.Module):
    """Two-layer Graph Convolutional Network for node classification.

    Args:
        num_features: Size of each input node feature vector.
        num_classes: Number of target classes.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        # Hidden width of 16 follows the original GCN recipe.
        self.conv1 = GCNConv(num_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        features, edge_index = data.x, data.edge_index
        hidden = F.relu(self.conv1(features, edge_index))
        # Dropout is active only while self.training is True.
        hidden = F.dropout(hidden, training=self.training)
        logits = self.conv2(hidden, edge_index)
        return F.log_softmax(logits, dim=1)
GATConv (Graph Attention Network):
from torch_geometric.nn import GATConv
class GAT(torch.nn.Module):
    """Two-layer Graph Attention Network for node classification."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        # 8 heads of width 8 are concatenated, giving a hidden size of 64.
        self.conv1 = GATConv(num_features, 8, heads=8, dropout=0.6)
        # The output layer uses a single, non-concatenated head.
        self.conv2 = GATConv(8 * 8, num_classes, heads=1, concat=False, dropout=0.6)

    def forward(self, data):
        h, edge_index = data.x, data.edge_index
        h = F.dropout(h, p=0.6, training=self.training)
        h = F.elu(self.conv1(h, edge_index))
        h = F.dropout(h, p=0.6, training=self.training)
        return F.log_softmax(self.conv2(h, edge_index), dim=1)
GraphSAGE :
from torch_geometric.nn import SAGEConv
class GraphSAGE(torch.nn.Module):
    """Two-layer GraphSAGE model for node classification."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = SAGEConv(num_features, 64)
        self.conv2 = SAGEConv(64, num_classes)

    def forward(self, data):
        h = F.relu(self.conv1(data.x, data.edge_index))
        # Dropout only applies in training mode.
        h = F.dropout(h, training=self.training)
        h = self.conv2(h, data.edge_index)
        return F.log_softmax(h, dim=1)
For custom layers, inherit from MessagePassing:
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class CustomConv(MessagePassing):
    """GCN-style message-passing layer built from MessagePassing primitives."""

    def __init__(self, in_channels, out_channels):
        # Aggregation may be "add", "mean", or "max"; "add" matches GCN.
        super().__init__(aggr='add')
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # Self-loops let every node keep its own features during aggregation.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x)
        # Symmetric normalization: deg^{-1/2} on both endpoints of each edge.
        src, dst = edge_index
        deg = degree(dst, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[src] * deg_inv_sqrt[dst]
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j holds the source-node features for each edge (the "_j" suffix
        # maps tensors to source nodes automatically).
        return norm.view(-1, 1) * x_j
Key methods:
forward(): main entry point; message(): constructs messages from source to target nodes; aggregate(): aggregates messages (usually not overridden—set the aggr parameter instead); update(): updates node embeddings after aggregation. Variable naming convention: appending _i or _j to tensor names automatically maps them to target or source nodes.
PyG provides extensive benchmark datasets:
# Citation networks (node classification)
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora') # or 'CiteSeer', 'PubMed'
# Graph classification
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
# Molecular datasets
from torch_geometric.datasets import QM9
dataset = QM9(root='/tmp/QM9')
# Large-scale datasets
from torch_geometric.datasets import Reddit
dataset = Reddit(root='/tmp/Reddit')
Check references/datasets_reference.md for a comprehensive list.
For datasets that fit in memory, inherit from InMemoryDataset:
from torch_geometric.data import InMemoryDataset, Data
import torch
class MyOwnDataset(InMemoryDataset):
    """Minimal InMemoryDataset template: download, process, save, load."""

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        # Load the already-processed collated data from disk.
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Files expected inside raw_dir; download() runs when they are missing.
        return ['my_data.csv']

    @property
    def processed_file_names(self):
        # Files expected inside processed_dir; process() runs when missing.
        return ['data.pt']

    def download(self):
        # Fetch raw data into self.raw_dir (no-op in this template).
        pass

    def process(self):
        # Build the list of Data objects; here a single two-node graph.
        edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
        sample = Data(x=torch.randn(2, 16),
                      edge_index=edge_index,
                      y=torch.tensor([0], dtype=torch.long))
        data_list = [sample]
        # Honor the optional user-supplied filter and transform hooks.
        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]
        self.save(data_list, self.processed_paths[0])
For large datasets that don't fit in memory, inherit from Dataset and implement len() and get(idx).
import pandas as pd
import torch
# BUG FIX: this snippet constructs a `Data` object below, so `Data` must be
# imported (the original imported only `HeteroData`, leaving `Data` undefined).
from torch_geometric.data import Data

# Load node features from CSV into a float tensor [num_nodes, 2].
nodes_df = pd.read_csv('nodes.csv')
x = torch.tensor(nodes_df[['feat1', 'feat2']].values, dtype=torch.float)
# Load edges from CSV into COO format [2, num_edges].
edges_df = pd.read_csv('edges.csv')
edge_index = torch.tensor([edges_df['source'].values,
                           edges_df['target'].values], dtype=torch.long)
data = Data(x=x, edge_index=edge_index)
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
# Load the Cora citation dataset (single-graph node classification).
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
# Create the model (the GCN class defined earlier) and an Adam optimizer
# with L2 weight decay.
model = GCN(dataset.num_features, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# Training: full-batch gradient descent restricted to the training mask.
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data)
# NLL loss pairs with the model's log_softmax output.
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
if epoch % 10 == 0:
print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
# Evaluation: accuracy on the held-out test mask.
# NOTE(review): consider wrapping inference in torch.no_grad() to avoid
# tracking gradients during evaluation.
model.eval()
pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Test Accuracy: {acc:.4f}')
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool
class GraphClassifier(torch.nn.Module):
    """GCN encoder + global mean pooling + linear head for graph classification."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.conv2 = GCNConv(64, 64)
        self.lin = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        h, edge_index, batch = data.x, data.edge_index, data.batch
        h = F.relu(self.conv1(h, edge_index))
        h = F.relu(self.conv2(h, edge_index))
        # Collapse node embeddings into one vector per graph in the batch.
        graph_repr = global_mean_pool(h, batch)
        return F.log_softmax(self.lin(graph_repr), dim=1)
# Load dataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = GraphClassifier(dataset.num_features, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Training
model.train()
for epoch in range(100):
total_loss = 0
for batch in loader:
optimizer.zero_grad()
out = model(batch)
loss = F.nll_loss(out, batch.y)
loss.backward()
optimizer.step()
total_loss += loss.item()
if epoch % 10 == 0:
print(f'Epoch {epoch}, Loss: {total_loss / len(loader):.4f}')
For large graphs, use NeighborLoader to sample subgraphs:
from torch_geometric.loader import NeighborLoader
# Create a neighbor sampler
train_loader = NeighborLoader(
data,
num_neighbors=[25, 10], # Sample 25 neighbors for 1st hop, 10 for 2nd hop
batch_size=128,
input_nodes=data.train_mask,
)
# Training
model.train()
for batch in train_loader:
optimizer.zero_grad()
out = model(batch)
# Only compute loss on seed nodes (first batch_size nodes)
loss = F.nll_loss(out[:batch.batch_size], batch.y[:batch.batch_size])
loss.backward()
optimizer.step()
Important :
For graphs with multiple node and edge types, use HeteroData:
from torch_geometric.data import HeteroData
data = HeteroData()
# Add node features for different types
data['paper'].x = torch.randn(100, 128) # 100 papers with 128 features
data['author'].x = torch.randn(200, 64) # 200 authors with 64 features
# Add edges for different types (source_type, edge_type, target_type)
data['author', 'writes', 'paper'].edge_index = torch.randint(0, 200, (2, 500))
data['paper', 'cites', 'paper'].edge_index = torch.randint(0, 100, (2, 300))
print(data)
Convert homogeneous models to heterogeneous:
from torch_geometric.nn import to_hetero
# Define homogeneous model
model = GNN(...)
# Convert to heterogeneous
model = to_hetero(model, data.metadata(), aggr='sum')
# Use as normal
out = model(data.x_dict, data.edge_index_dict)
Or use HeteroConv for custom edge-type-specific operations:
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv
class HeteroGNN(torch.nn.Module):
    """Two HeteroConv layers with a dedicated operator per edge type."""

    def __init__(self, metadata):
        # `metadata` is accepted for API symmetry but unused in this body.
        super().__init__()
        # -1 lets PyG infer input channel sizes lazily on the first forward.
        self.conv1 = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, 64),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
        }, aggr='sum')
        self.conv2 = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(64, 32),
            ('author', 'writes', 'paper'): SAGEConv((64, 64), 32),
        }, aggr='sum')

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {node_type: F.relu(h) for node_type, h in x_dict.items()}
        return self.conv2(x_dict, edge_index_dict)
Apply transforms to modify graph structure or features:
from torch_geometric.transforms import NormalizeFeatures, AddSelfLoops, Compose
# Single transform
transform = NormalizeFeatures()
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)
# Compose multiple transforms
transform = Compose([
AddSelfLoops(),
NormalizeFeatures(),
])
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)
Common transforms:
Structural: ToUndirected, AddSelfLoops, RemoveSelfLoops, KNNGraph, RadiusGraph; feature: NormalizeFeatures, NormalizeScale, Center; splitting: RandomNodeSplit, RandomLinkSplit; positional encodings: AddLaplacianEigenvectorPE, AddRandomWalkPE. See references/transforms_reference.md for the full list.
PyG provides explainability tools to understand model predictions:
from torch_geometric.explain import Explainer, GNNExplainer
# Create explainer
explainer = Explainer(
model=model,
algorithm=GNNExplainer(epochs=200),
explanation_type='model', # or 'phenomenon'
node_mask_type='attributes',
edge_mask_type='object',
model_config=dict(
mode='multiclass_classification',
task_level='node',
return_type='log_probs',
),
)
# Generate explanation for a specific node
node_idx = 10
explanation = explainer(data.x, data.edge_index, index=node_idx)
# Visualize
print(f'Node {node_idx} explanation:')
print(f'Important edges: {explanation.edge_mask.topk(5).indices}')
print(f'Important features: {explanation.node_mask[node_idx].topk(5).indices}')
For hierarchical graph representations:
from torch_geometric.nn import TopKPooling, global_mean_pool
class HierarchicalGNN(torch.nn.Module):
    """Hierarchical GNN alternating GCN layers with TopK pooling."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.pool1 = TopKPooling(64, ratio=0.8)
        self.conv2 = GCNConv(64, 64)
        self.pool2 = TopKPooling(64, ratio=0.8)
        self.lin = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        h, edge_index, batch = data.x, data.edge_index, data.batch
        h = F.relu(self.conv1(h, edge_index))
        # Each pooling step keeps the top 80% of nodes and remaps edges/batch.
        h, edge_index, _, batch, _, _ = self.pool1(h, edge_index, None, batch)
        h = F.relu(self.conv2(h, edge_index))
        h, edge_index, _, batch, _, _ = self.pool2(h, edge_index, None, batch)
        h = global_mean_pool(h, batch)
        return F.log_softmax(self.lin(h), dim=1)
# Undirected check: does every edge have a reverse counterpart?
from torch_geometric.utils import is_undirected
print(f"Is undirected: {is_undirected(data.edge_index)}")
# Connected components
# NOTE(review): `connected_components` may not be exported by
# torch_geometric.utils in all PyG versions — confirm against the installed
# version (alternatively use to_networkx + networkx.connected_components).
from torch_geometric.utils import connected_components
print(f"Connected components: {connected_components(data.edge_index)}")
# Contains self-loops (edges of the form i -> i)
from torch_geometric.utils import contains_self_loops
print(f"Has self-loops: {contains_self_loops(data.edge_index)}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = data.to(device)
# For DataLoader
for batch in loader:
batch = batch.to(device)
# Train...
# Save
torch.save(model.state_dict(), 'model.pth')
# Load
model = GCN(num_features, num_classes)
model.load_state_dict(torch.load('model.pth'))
model.eval()
When choosing layers, consider these capabilities:
See the GNN cheatsheet at references/layer_capabilities.md.
This skill includes detailed reference documentation:
references/layers_reference.md : Complete listing of all 40+ GNN layers with descriptions and capabilitiesreferences/datasets_reference.md : Comprehensive dataset catalog organized by categoryreferences/transforms_reference.md : All available transforms and their use casesreferences/api_patterns.md : Common API patterns and coding examplesUtility scripts are provided in scripts/:
scripts/visualize_graph.py : Visualize graph structure using networkx and matplotlibscripts/create_gnn_template.py : Generate boilerplate code for common GNN architecturesscripts/benchmark_model.py : Benchmark model performance on standard datasetsExecute scripts directly or read them for implementation patterns.
Weekly Installs
138
Repository
GitHub Stars
22.6K
First Seen
Jan 21, 2026
Security Audits
Gen Agent Trust Hub: Pass · Socket: Pass · Snyk: Warn
Installed on
claude-code117
opencode111
gemini-cli103
cursor103
codex95
antigravity94
超能力技能使用指南:AI助手技能调用优先级与工作流程详解
46,500 周安装
(Extraction fragments, restored: custom attributes such as data.train_mask and data.batch; positional-encoding transforms AddLaplacianEigenvectorPE and AddRandomWalkPE.)