A Deeper Tutorial on Tensor-Valued Nodes in TGraphX
The introduction tutorial showed tgx.classify_nodes() — the one-call entry point. This tutorial goes deeper. We build a tensor-valued node classification workflow using the explicit API: graph construction, validation, mini-batch sampling, training loop, and artifact persistence.
Setup
pip install tgraphx
You will need Python ≥ 3.10 and PyTorch ≥ 1.13. Everything in this tutorial runs on CPU; CUDA is optional.
import torch
import tgraphx as tgx
print(tgx.__version__) # 1.4.1
Step 1: Constructing the graph
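Suppose we have 1000 image-patch nodes, each with shape [3, 8, 8]. In a real dataset the edges would come from patch adjacency in the image layout; the features here are random, so a cosine kNN graph stands in for that adjacency: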
N = 1000
node_shape = (3, 8, 8)
x = torch.randn(N, *node_shape)
edge_index = tgx.knn_graph(x.view(N, -1), k=8, metric="cosine", make_symmetric=True)
labels = torch.randint(0, 10, (N,))
g = tgx.Graph(x=x, edge_index=edge_index, labels=labels)
A few things to notice:
tgx.knn_graph operates on a flattened view of x for the purpose of computing similarity. The graph it produces is just an edge index; we still pass the original x to the Graph constructor.
tgx.Graph accepts node features of any rank. The [1000, 3, 8, 8] tensor is stored as-is.
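To make the flatten-then-compare idea concrete, here is a minimal sketch in plain PyTorch of how a cosine kNN edge index can be computed. It is not TGraphX's implementation: the function name cosine_knn_edges, its symmetric flag, and the source-to-target edge orientation are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def cosine_knn_edges(x: torch.Tensor, k: int, symmetric: bool = True) -> torch.Tensor:
    """Hypothetical helper: link each node to its k most cosine-similar nodes.

    Returns a [2, E] edge index. The input tensor x is not modified.
    """
    n = x.shape[0]
    # Flatten only for the similarity computation; x keeps its shape.
    flat = F.normalize(x.reshape(n, -1), dim=1)
    sim = flat @ flat.T                      # [n, n] cosine similarities
    sim.fill_diagonal_(float("-inf"))        # a node is never its own neighbor
    nbrs = sim.topk(k, dim=1).indices        # [n, k]
    src = torch.arange(n).repeat_interleave(k)
    edges = torch.stack([src, nbrs.reshape(-1)])
    if symmetric:
        # Add the reverse of every edge, then drop duplicates.
        edges = torch.unique(torch.cat([edges, edges.flip(0)], dim=1), dim=1)
    return edges


torch.manual_seed(0)
x_demo = torch.randn(50, 3, 8, 8)
edges_demo = cosine_knn_edges(x_demo, k=4)
print(edges_demo.shape, x_demo.shape)
```

The full [n, n] similarity matrix is what makes this approach quadratic in the number of nodes.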
Step 2: Validate
tgx.validate_graph(g, strict=True)
tgx.assert_tensor_native(g, min_rank=3)
Validation catches shape mismatches and out-of-bounds edge indices before training. assert_tensor_native enforces that node features have at least rank 3 — useful when you want to fail fast if someone accidentally flattens the data.
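The checks described above are simple to express by hand. The sketch below shows the kind of tests such a validator might run; it is an illustration in plain PyTorch, not the body of tgx.validate_graph or tgx.assert_tensor_native. The name find_graph_issues is hypothetical, and it assumes that "rank" counts per-node dimensions, excluding the leading node dimension.

```python
import torch


def find_graph_issues(x, edge_index, labels=None, min_rank=None):
    """Hypothetical validator: return a list of problems (empty means OK)."""
    issues = []
    n = x.shape[0]
    if edge_index.dim() != 2 or edge_index.shape[0] != 2:
        issues.append(f"edge_index must have shape [2, E], got {tuple(edge_index.shape)}")
    else:
        if edge_index.dtype != torch.long:
            issues.append(f"edge_index must be int64, got {edge_index.dtype}")
        if edge_index.numel() > 0:
            lo, hi = int(edge_index.min()), int(edge_index.max())
            if lo < 0 or hi >= n:
                issues.append(f"edge endpoints span [{lo}, {hi}] but there are {n} nodes")
    if labels is not None and labels.shape[0] != n:
        issues.append(f"{labels.shape[0]} labels for {n} nodes")
    # Assumption: rank is measured per node, so [N, 3, 8, 8] has rank 3.
    if min_rank is not None and x.dim() - 1 < min_rank:
        issues.append(f"node features have rank {x.dim() - 1}, expected at least {min_rank}")
    return issues


x_ok = torch.randn(5, 3, 8, 8)
good_edges = torch.tensor([[0, 1, 2], [1, 2, 3]])
bad_edges = torch.tensor([[0, 1], [1, 7]])   # node 7 does not exist

print(find_graph_issues(x_ok, good_edges, min_rank=3))
print(find_graph_issues(x_ok, bad_edges, min_rank=3))
print(find_graph_issues(x_ok.flatten(1), good_edges, min_rank=3))
```

Collecting every issue before reporting, rather than stopping at the first one, makes it easier to fix a badly formed graph in one pass.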
Step 3: Train/val/test masks
torch.manual_seed(0)
perm = torch.randperm(N)
train_mask = torch.zeros(N, dtype=torch.bool); train_mask[perm[:700]] = True
val_mask = torch.zeros(N, dtype=torch.bool); val_mask[perm[700:850]] = True
test_mask = torch.zeros(N, dtype=torch.bool); test_mask[perm[850:]] = True
tgx.check_leakage(train_mask, val_mask, test_mask, strict=True)
check_leakage is a safety check against the very common bug where one node appears in two splits.
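For boolean masks, the leakage condition reduces to "no two masks are True at the same index". A minimal sketch of that test follows; the helper name overlapping_nodes is hypothetical, and this is not the code behind tgx.check_leakage.

```python
import torch


def overlapping_nodes(**masks):
    """Hypothetical helper: map each pair of split names to the node ids they share."""
    names = list(masks)
    overlaps = {}
    for i, a in enumerate(names):
        for b in names[i + 1:]:
            shared = torch.nonzero(masks[a] & masks[b]).flatten().tolist()
            if shared:
                overlaps[(a, b)] = shared
    return overlaps


n = 10
perm = torch.randperm(n, generator=torch.Generator().manual_seed(0))
train = torch.zeros(n, dtype=torch.bool); train[perm[:7]] = True
val = torch.zeros(n, dtype=torch.bool); val[perm[7:9]] = True
test = torch.zeros(n, dtype=torch.bool); test[perm[9:]] = True

clean = overlapping_nodes(train=train, val=val, test=test)

# Simulate the bug: one training node is also marked as a test node.
leaky_test = test.clone()
leaky_test[perm[0]] = True
leaky = overlapping_nodes(train=train, val=val, test=leaky_test)
print(clean, leaky)
```

Slicing one permutation into consecutive ranges, as Step 3 does, guarantees disjoint splits by construction; the check matters most when masks come from separate files or separate code paths.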
Step 4: Build a model
from tgraphx import build_model
model = build_model(
    task="node_classification",
    layer="conv",  # uses ConvMessagePassing for tensor-aware aggregation
    in_shape=node_shape,
    out_dim=10,
    hidden_dim=64,
    num_layers=2,
)
The factory selects appropriate layer types based on in_shape. For a rank-3 input ((3, 8, 8)), it returns a stack of tensor-aware convolutional message-passing layers.
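To give a feel for what "tensor-aware convolutional message passing" can mean, the sketch below implements one such layer in plain PyTorch: each message is a 2D convolution of the source node's feature map, and messages are mean-aggregated per target node without flattening. The class TinyConvMessagePassing is hypothetical and is not TGraphX's ConvMessagePassing; kernel size, mean aggregation, and the residual-style update are all assumptions.

```python
import torch
from torch import nn


class TinyConvMessagePassing(nn.Module):
    """Hypothetical layer: conv messages, mean aggregation, spatial dims preserved."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.message = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.update = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index                      # messages flow src -> dst
        msgs = self.message(x[src])                # [E, C_out, H, W]
        agg = torch.zeros(x.shape[0], *msgs.shape[1:], dtype=msgs.dtype, device=msgs.device)
        agg.index_add_(0, dst, msgs)               # sum incoming messages per node
        deg = torch.bincount(dst, minlength=x.shape[0]).clamp(min=1)
        agg = agg / deg.view(-1, 1, 1, 1)          # sum -> mean
        return torch.relu(self.update(x) + agg)


torch.manual_seed(0)
x = torch.randn(6, 3, 8, 8)
ring_src = torch.arange(6)
edge_index = torch.stack([ring_src, (ring_src + 1) % 6])   # directed ring graph
layer = TinyConvMessagePassing(3, 16)
out = layer(x, edge_index)
print(out.shape)
```

A classifier built on such layers would still need a pooling step over the channel and spatial dimensions, followed by a linear head, to produce one logit vector per node.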
Step 5: Mini-batch training with NeighborLoader
For graphs that do not fit comfortably in memory, use sampled mini-batches:
from tgraphx import NeighborLoader, fit
loader = NeighborLoader(
    g,
    num_neighbors=[10, 5],  # 2-hop sample: 10 first-hop neighbors, 5 second-hop
    batch_size=64,
    seed=42,
)
history = fit(
    model=model,
    graph=g,
    loader=loader,
    train_mask=train_mask,
    val_mask=val_mask,
    epochs=20,
    lr=2e-3,
)
print(history.metrics_per_epoch[-1])
fit runs a standard training loop and returns a history object with per-epoch metrics. There is nothing magical — under the hood it is a PyTorch loop you could write yourself.
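As a rough picture of "a PyTorch loop you could write yourself", here is a minimal full-batch sketch with masked loss and a per-epoch history. It makes several simplifying assumptions: the function simple_fit is hypothetical, there is no neighbor sampling, the stand-in model ignores edges entirely, and the metric names are invented. It does not reproduce what tgx.fit actually records.

```python
import torch
from torch import nn


def simple_fit(model, x, labels, train_mask, val_mask, epochs=5, lr=2e-3):
    """Hypothetical full-batch loop: train on train_mask, report accuracy on val_mask."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    history = []
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        logits = model(x)
        loss = loss_fn(logits[train_mask], labels[train_mask])  # loss on training nodes only
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            pred = model(x).argmax(dim=1)
            val_acc = (pred[val_mask] == labels[val_mask]).float().mean().item()
        history.append({"epoch": epoch, "train_loss": loss.item(), "val_acc": val_acc})
    return history


torch.manual_seed(0)
n = 200
x = torch.randn(n, 3, 8, 8)
labels = torch.randint(0, 10, (n,))
perm = torch.randperm(n)
train_mask = torch.zeros(n, dtype=torch.bool); train_mask[perm[:140]] = True
val_mask = torch.zeros(n, dtype=torch.bool); val_mask[perm[140:170]] = True

# Stand-in model that flattens each node; a graph model would also take edge_index.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
history = simple_fit(model, x, labels, train_mask, val_mask)
print(history[-1])
```

Because the labels here are random, validation accuracy should hover around chance; only the training loss is expected to fall.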
Step 6: Save the experiment
The .tgx format preserves rank-4 tensors and metadata that GraphML cannot:
g.save("runs/exp_001/graph.tgx")
torch.save(model.state_dict(), "runs/exp_001/model.pt")
You can later load with tgx.Graph.load("runs/exp_001/graph.tgx").
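Restoring the model weights is the mirror image of the torch.save call. The sketch below shows the round trip with a stand-in model and a temporary directory; make_model is a hypothetical factory, and in the tutorial's workflow you would rebuild the model with the same build_model arguments before calling load_state_dict.

```python
import os
import tempfile

import torch
from torch import nn


def make_model():
    """Hypothetical stand-in; the architecture must match the one that was saved."""
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))


torch.manual_seed(0)
trained = make_model()

with tempfile.TemporaryDirectory() as root:
    run_dir = os.path.join(root, "runs", "exp_001")
    os.makedirs(run_dir, exist_ok=True)        # torch.save does not create directories
    path = os.path.join(run_dir, "model.pt")
    torch.save(trained.state_dict(), path)

    restored = make_model()                    # fresh, randomly initialized weights
    restored.load_state_dict(torch.load(path, map_location="cpu"))

trained.eval()
restored.eval()
x_check = torch.randn(4, 3, 8, 8)
with torch.no_grad():
    same = torch.equal(trained(x_check), restored(x_check))
print(same)
```

Note that torch.save raises an error if the parent directory is missing, so creating runs/exp_001 first is a safe habit regardless of how the graph file is written.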
Step 7: Reproducibility
Wrap the entire experiment in a reproducibility context so the run can be reproduced bit-for-bit on the same hardware:
with tgx.reproducible(seed=42, deterministic=True):
    # construct, train, evaluate
    ...
This seeds Python's random, NumPy, and PyTorch RNGs. With deterministic=True, CUDA operations also run deterministically (slower, but bitwise repeatable). A reproducibility_report.json is written alongside other run artifacts.
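The seeding behavior described above can be approximated with standard calls. The context manager below is a sketch under stated assumptions, not the implementation of tgx.reproducible: the name reproducible_sketch is hypothetical, it writes no report file, and it restores the previous determinism setting on exit.

```python
import contextlib
import os
import random

import numpy as np
import torch


@contextlib.contextmanager
def reproducible_sketch(seed: int, deterministic: bool = False):
    """Hypothetical stand-in: seed the three RNGs and optionally force determinism."""
    previous = torch.are_deterministic_algorithms_enabled()
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)                    # seeds CPU and CUDA generators
    if deterministic:
        # cuBLAS needs this for deterministic matmuls; it only takes effect
        # if set before the CUDA context is created.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.use_deterministic_algorithms(True)
    try:
        yield
    finally:
        torch.use_deterministic_algorithms(previous)


def draw():
    """Take one sample from each random number generator."""
    return (random.random(), float(np.random.rand()), torch.rand(1).item())


with reproducible_sketch(123, deterministic=True):
    first = draw()
with reproducible_sketch(123, deterministic=True):
    second = draw()
print(first == second)
```

With deterministic algorithms enabled, PyTorch raises an error for operations that have no deterministic implementation, which is usually preferable to silently non-repeatable results.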
Common mistakes
Mixing vector-only layers with tensor inputs. GCNConv, GATv2Conv, and APPNP expect rank-2 input. If you pass a [N, 3, 8, 8] tensor to them, you get a shape error. Use build_model(..., layer="conv") or pick a tensor-aware layer explicitly.
kNN graph on huge x. tgx.knn_graph is O(N²). For very large N, either sample first or pre-compute edges from a different similarity proxy.
Not validating before training. A mismatched edge index causes opaque errors at runtime. tgx.validate_graph(g, strict=True) catches them in milliseconds.
Assuming tensor features improve accuracy by default. They preserve spatial structure. Whether that structure is useful depends on your task. Benchmark against a flat-vector baseline before committing.
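On the kNN cost noted above: when the quadratic similarity matrix is the bottleneck, one common mitigation is to process query rows in chunks. The sketch below is an assumption-laden illustration, not a TGraphX feature; chunked_cosine_knn and chunk_size are hypothetical names. It still performs O(N²) arithmetic, but peak memory drops from N × N to chunk_size × N.

```python
import torch
import torch.nn.functional as F


def chunked_cosine_knn(x: torch.Tensor, k: int, chunk_size: int = 256) -> torch.Tensor:
    """Hypothetical helper: [N, k] neighbor ids, computed chunk_size rows at a time."""
    n = x.shape[0]
    flat = F.normalize(x.reshape(n, -1), dim=1)
    neighbor_ids = []
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        sim = flat[start:stop] @ flat.T            # [chunk, n] block of similarities
        rows = torch.arange(stop - start)
        sim[rows, start + rows] = float("-inf")    # exclude each node's self-match
        neighbor_ids.append(sim.topk(k, dim=1).indices)
    return torch.cat(neighbor_ids)


torch.manual_seed(0)
x = torch.randn(300, 3, 8, 8)
nbrs = chunked_cosine_knn(x, k=8, chunk_size=64)
print(nbrs.shape)
```

For graphs where even the arithmetic is too expensive, an approximate nearest-neighbor index is the usual next step; that is outside the scope of this sketch.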
FAQ
Q: How does NeighborLoader interact with tensor-valued node features?
A: The sampler operates on edge_index. The node features are gathered for the sampled subset and preserved in their original shape. No flattening is done implicitly.
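The reason no flattening is needed is that gathering features for a sampled subset is just indexing along the first dimension. The sketch below illustrates this with a toy one-hop sampler in plain PyTorch; one_hop_sample is a hypothetical helper and does not reflect how NeighborLoader is implemented.

```python
import torch


def one_hop_sample(edge_index, seeds, num_neighbors, generator):
    """Hypothetical sampler: seeds plus up to num_neighbors in-neighbors of each seed."""
    src, dst = edge_index
    picked = [seeds]
    for s in seeds.tolist():
        nbrs = src[dst == s]                       # nodes with an edge into s
        if nbrs.numel() > num_neighbors:
            keep = torch.randperm(nbrs.numel(), generator=generator)[:num_neighbors]
            nbrs = nbrs[keep]
        picked.append(nbrs)
    return torch.unique(torch.cat(picked))         # sorted, de-duplicated node ids


gen = torch.Generator().manual_seed(42)
x = torch.randn(1000, 3, 8, 8)
edge_index = torch.randint(0, 1000, (2, 8000), generator=gen)
seeds = torch.tensor([0, 1, 2, 3])

node_ids = one_hop_sample(edge_index, seeds, num_neighbors=10, generator=gen)
x_batch = x[node_ids]                              # indexing touches dim 0 only
print(node_ids.numel(), x_batch.shape)
```

The sampler only ever reads edge_index, so the per-node shape (3, 8, 8) passes through untouched.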
Q: What is RandomBalancedPartitioner for?
A: It is used by ClusterLoader for graph partitioning when training on graphs too large for NeighborLoader. See docs/cluster_gcn.md in the TGraphX repo for details.
Q: Can I export to a standard format like GraphML?
A: For flat-vector graphs, yes — TGraphX has GraphML I/O for that case. For tensor-valued nodes, GraphML cannot represent the data, so use .tgx.
Q: What happens if validation fails mid-experiment?
A: tgx.validate_graph(g, strict=True) raises a ValueError with a description of the violation. In non-strict mode it returns a result object with ok=False and a list of issues.