Getting Started with Tensor-Valued Nodes in TGraphX
This tutorial walks through the basics of representing graph data with tensor-valued node features in TGraphX. By the end, you should be able to construct a graph where each node carries a multi-dimensional tensor, validate it, and run a minimal training workflow.
The API examples here are drawn from TGraphX v1.4.2. All API calls shown are verified against the library source.
Prerequisites
```bash
pip install tgraphx
```
Optionally, for the image-patch examples:
```bash
pip install "tgraphx[pillow]"
```
TGraphX requires Python ≥ 3.10 and PyTorch ≥ 1.13. No GPU is required for this tutorial — everything runs on CPU.
Part 1: Understanding Why Tensor-Valued Nodes Exist
Most GNN frameworks represent each node as a flat vector. If your graph has 100 nodes and each node has 32 features, the node feature matrix has shape [100, 32]. This layout is efficient and well supported.
The challenge arises when node features are not naturally flat:
| Node type | Natural representation | Standard GNN approach | TGraphX approach |
|---|---|---|---|
| Image patch | [3, 8, 8] (C, H, W) | Flatten to [192] | Keep as [3, 8, 8] |
| Volumetric block | [1, 4, 4, 4] (C, D, H, W) | Flatten to [64] | Keep as [1, 4, 4, 4] |
| Time series | [32, 16] (T, D) | Flatten to [512] | Keep as [32, 16] |
| Text embedding | [D] | Already flat | Pass-through |
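Flattening loses no information in principle: it is a reshape that can be undone exactly, as long as the original shape is known. In practice, though, it hides the spatial layout from the model. That makes architecture choices less transparent and makes it harder to apply convolution-style operations inside message passing.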
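To make this concrete, here is a small NumPy sketch (torch.Tensor.reshape behaves the same way). It shows two things:

- Flattening a [3, 8, 8] patch is exactly reversible.
- In the flat vector, pixels that are spatially adjacent end up far apart.

```python
import numpy as np

# One image-patch node: 3 channels, 8x8 pixels, i.e. shape (C, H, W)
patch = np.random.default_rng(0).standard_normal((3, 8, 8))

# Flattening is a pure reshape: 3 * 8 * 8 = 192 values, nothing discarded
flat = patch.reshape(-1)

# Reshaping back with the original shape recovers the patch exactly
restored = flat.reshape(3, 8, 8)

# In flat form, the pixel directly below (c=0, h=0, w=0) sits 8 positions
# later, and the same pixel in the next channel sits 64 positions later.
below_offset = np.ravel_multi_index((0, 1, 0), patch.shape)
next_channel_offset = np.ravel_multi_index((1, 0, 0), patch.shape)
print(flat.shape, below_offset, next_channel_offset)
```

A layer that sees only the 192-vector has to relearn these offsets. A layer that sees [3, 8, 8] gets the neighbourhood structure for free.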
TGraphX's tensor-aware layers handle [N, C, H, W] and [N, C, D, H, W] features natively through convolutional aggregation.
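This tutorial doesn't show how ConvMessagePassing works internally. The following is only a hypothetical NumPy sketch of the underlying idea: aggregate neighbour features without flattening them, so the result keeps the [C, H, W] layout a convolution needs.

Two caveats apply:

- mean_aggregate is an illustrative name, not a TGraphX function.
- Mean aggregation is an assumption; real layers may use other schemes.

```python
import numpy as np

def mean_aggregate(x, edge_index):
    """Average incoming neighbour tensors for every node.

    x: float array of shape [N, C, H, W].
    edge_index: int array of shape [2, E]; row 0 = sources, row 1 = targets.
    Nodes without incoming edges get all-zero tensors. Output shape is [N, C, H, W].
    """
    src, dst = edge_index
    out = np.zeros_like(x)
    # Sum each source tensor into its target's slot; np.add.at handles repeated targets
    np.add.at(out, dst, x[src])
    # Divide by in-degree, clamped to 1 so isolated nodes stay zero instead of NaN
    deg = np.bincount(dst, minlength=x.shape[0]).astype(x.dtype)
    out /= np.maximum(deg, 1.0).reshape(-1, 1, 1, 1)
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3, 8, 8))          # 4 nodes, each a [3, 8, 8] tensor
edge_index = np.array([[0, 1, 2], [2, 2, 3]])  # edges 0->2, 1->2, 2->3
h = mean_aggregate(x, edge_index)
```

Because h still has shape [4, 3, 8, 8], a 2D convolution can be applied to each node's aggregated tensor. Flattening first would rule that out.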
Part 2: Constructing a TGX Graph
Minimal graph with flat vector features
```python
import torch
import tgraphx as tgx

# 50 nodes, each with a 16-dimensional feature vector
x_flat = torch.randn(50, 16)

# Edge index: pairs [source, target] in shape [2, E]
edge_index = torch.tensor([
    [0, 1, 2, 3, 4, 5],
    [1, 2, 3, 4, 5, 0],
], dtype=torch.long)

# Labels for each node
labels = torch.randint(0, 3, (50,))

g = tgx.Graph(x=x_flat, edge_index=edge_index, labels=labels)
print(g)
```
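In the flat-vector example, the hand-written edge_index connects only nodes 0 to 5, so the other 44 nodes have no edges at all. To connect every node, generate the edge index programmatically. Here is a plain-Python sketch; ring_edge_index and isolated_nodes are hypothetical helper names.

```python
def ring_edge_index(num_nodes):
    """Directed ring 0 -> 1 -> ... -> N-1 -> 0 as [sources, targets]."""
    sources = list(range(num_nodes))
    targets = [(i + 1) % num_nodes for i in sources]
    return [sources, targets]

def isolated_nodes(edge_index, num_nodes):
    """Node IDs that appear in neither the source nor the target row."""
    touched = set(edge_index[0]) | set(edge_index[1])
    return [n for n in range(num_nodes) if n not in touched]

# The hand-written six-edge ring from the example
small_ring = [[0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 0]]
print(len(isolated_nodes(small_ring, 50)))   # 44

full_ring = ring_edge_index(50)
print(len(isolated_nodes(full_ring, 50)))    # 0
```

torch.tensor(full_ring, dtype=torch.long) turns the result into a [2, 50] edge_index tensor.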
Graph with tensor-valued nodes
```python
# 50 nodes, each an 8×8 RGB image patch
x_spatial = torch.randn(50, 3, 8, 8)
g_spatial = tgx.Graph(
    x=x_spatial,
    edge_index=edge_index,
    labels=labels,
)
print(g_spatial)  # shows node_features shape: [50, 3, 8, 8]
```
The tgx.Graph constructor accepts the same keyword arguments regardless of feature rank. PyG-style aliases also work, with x standing in for node_features and y for node_labels:
```python
# PyG-style
g = tgx.Graph(
    x=x_spatial,            # alias for node_features
    edge_index=edge_index,
    y=labels,               # alias for node_labels
)
```
Part 3: Validating Your Graph
Before any training, validate the graph. Validation checks that shapes are consistent and that edge indices are within bounds, catching problems that would otherwise go unnoticed:
```python
# Basic validation
tgx.validate_graph(g_spatial)

# Strict validation: raises on any anomaly
tgx.validate_graph(g_spatial, strict=True)

# Assert that node features are genuinely tensor-native (rank ≥ 3)
tgx.assert_tensor_native(g_spatial, min_rank=3)
```
Suppose you accidentally construct a graph with fewer labels than nodes, or with an edge that points to a nonexistent node. Validation catches the problem immediately, rather than at training time with an opaque error.
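The implementation of tgx.validate_graph isn't shown here. As a rough, hypothetical sketch, here are the two kinds of check described above (consistent shapes and in-bounds edge indices) written in plain Python on lists. check_edge_index and check_labels are illustrative names, not TGraphX functions.

```python
def check_edge_index(edge_index, num_nodes):
    """Raise ValueError if edge_index is not a valid [2, E] list of node IDs."""
    if len(edge_index) != 2:
        raise ValueError(f"edge_index must have 2 rows, got {len(edge_index)}")
    sources, targets = edge_index
    if len(sources) != len(targets):
        raise ValueError(f"row lengths differ: {len(sources)} vs {len(targets)}")
    for row_name, row in (("source", sources), ("target", targets)):
        for col, node in enumerate(row):
            if not isinstance(node, int):
                raise ValueError(f"{row_name}[{col}] is not an integer: {node!r}")
            if not 0 <= node < num_nodes:
                raise ValueError(
                    f"{row_name}[{col}] = {node} is outside [0, {num_nodes})")

def check_labels(labels, num_nodes):
    """Raise ValueError unless there is exactly one label per node."""
    if len(labels) != num_nodes:
        raise ValueError(f"expected {num_nodes} labels, got {len(labels)}")

check_edge_index([[0, 1, 2], [1, 2, 0]], num_nodes=3)  # passes silently
try:
    check_edge_index([[0, 1, 50], [1, 2, 0]], num_nodes=50)
except ValueError as err:
    print(err)  # source[2] = 50 is outside [0, 50)
```

Failing at construction time, with the offending position in the message, is much easier to debug than an index error raised deep inside a message-passing layer.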
Checking graph invariants
```python
tgx.check_graph_invariants(g_spatial)
```
This runs a broader set of consistency checks and reports any structural anomalies.
Part 4: Building a Graph from Image Patches
A common use case for tensor-valued nodes is converting images to patch graphs:
```python
import torch
import tgraphx as tgx

# Simulate a single RGB image of 32×32 pixels
image = torch.randn(3, 32, 32)

# Convert to a patch graph: 8×8 patches, connected in a grid
patches, edge_index = tgx.image_to_patch_graph(image, patch_size=8)
# patches shape: [16, 3, 8, 8] (16 patches of 8×8 pixels each)
# edge_index connects adjacent patches in a grid

g = tgx.Graph(x=patches, edge_index=edge_index)
tgx.validate_graph(g, strict=True)
print(f"Nodes: {g.num_nodes}, Edges: {g.num_edges}")
```
This pattern — image to patch graph — is the starting point for several of TGraphX's example notebooks and tutorials.
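tgx.image_to_patch_graph handles this for you, but its exact connectivity rule (4- or 8-neighbour, self-loops or not) isn't specified in this tutorial. The following NumPy sketch is a hypothetical reconstruction. It assumes:

- non-overlapping patches in row-major order;
- 4-neighbour grid edges stored in both directions.

patchify and grid_edge_index are illustrative names.

```python
import numpy as np

def patchify(image, patch_size):
    """Split a [C, H, W] image into non-overlapping [C, p, p] patches (row-major)."""
    c, h, w = image.shape
    rows, cols = h // patch_size, w // patch_size
    # Crop any remainder, then [C, rows, p, cols, p] -> [rows, cols, C, p, p]
    cropped = image[:, :rows * patch_size, :cols * patch_size]
    blocks = cropped.reshape(c, rows, patch_size, cols, patch_size)
    patches = blocks.transpose(1, 3, 0, 2, 4).reshape(
        rows * cols, c, patch_size, patch_size)
    return patches, rows, cols

def grid_edge_index(rows, cols):
    """4-neighbour grid edges in both directions, as a [2, E] array."""
    sources, targets = [], []
    for r in range(rows):
        for col in range(cols):
            node = r * cols + col
            if col + 1 < cols:  # right neighbour
                sources += [node, node + 1]
                targets += [node + 1, node]
            if r + 1 < rows:    # neighbour below
                sources += [node, node + cols]
                targets += [node + cols, node]
    return np.array([sources, targets])

image = np.random.default_rng(0).standard_normal((3, 32, 32))
patches, rows, cols = patchify(image, patch_size=8)
edge_index = grid_edge_index(rows, cols)
print(patches.shape, edge_index.shape)  # (16, 3, 8, 8) (2, 48)
```

With 4-neighbour connectivity, a 4×4 grid has 12 horizontal and 12 vertical adjacencies: 24 undirected edges, or 48 directed ones. If g.num_edges reports a different count, the library is using a different connectivity rule.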
Part 5: A First Training Run
Using the easy mode (no boilerplate)
The simplest way to run a training experiment is the zero-boilerplate easy namespace:
```python
import tgraphx as tgx

data = tgx.easy.synthetic_tensor_node_classification(
    num_nodes=500,
    node_shape=(8, 8, 8),  # each node is an [8, 8, 8] block
    num_classes=5,
    seed=42,
)
result = tgx.easy.train_node_classifier(
    data,
    model="tensor_gcn",
    sampler="neighbor",
    epochs=10,
    seed=42,
)
result.summary()
print(f"Val accuracy: {result.metrics['val_accuracy']:.3f}")
```
Using the one-call API
If you have your own data:
```python
import torch
import tgraphx as tgx

# Your own tensor-valued graph data
num_nodes = 200
x = torch.randn(num_nodes, 3, 8, 8)  # [N, C, H, W]
edge_index = tgx.knn_graph(
    x.view(num_nodes, -1),  # kNN on flattened features for connectivity
    k=5,
    metric="cosine",
    make_symmetric=True,
)
labels = torch.randint(0, 4, (num_nodes,))

result = tgx.classify_nodes(
    x=x,
    edge_index=edge_index,
    labels=labels,
    model="tensor_gcn",
    seed=42,
    device="auto",
)
print(result.metrics)
```
Using the standard training API
For researchers who want explicit control:
```python
import torch
from tgraphx import Graph, NeighborLoader, build_model, set_seed, fit

set_seed(42)

# Construct graph (x, edge_index, labels come from the previous example)
g = Graph(x=x, edge_index=edge_index, labels=labels)

# Build a model with tensor-aware convolutional message passing
model = build_model(
    task="node_classification",
    layer="conv",          # uses ConvMessagePassing
    in_shape=x.shape[1:],  # (3, 8, 8)
    out_dim=4,
    hidden_dim=64,
    num_layers=2,
)

# NeighborLoader for mini-batch training
loader = NeighborLoader(
    g,
    num_neighbors=[10, 5],
    batch_size=32,
    seed=42,
)

# Train using the standard fit() function on a random ~70% of nodes
train_mask = torch.rand(g.num_nodes) < 0.7
history = fit(
    model=model,
    graph=g,
    loader=loader,
    train_mask=train_mask,
    epochs=20,
)
```
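Because train_mask is drawn at random, it covers roughly (not exactly) 70% of the nodes and leaves no held-out validation set. If you want exact, disjoint splits, a plain-Python sketch like the one below works. A few notes:

- split_node_ids and to_mask are hypothetical helpers.
- Whether fit() also accepts a validation mask isn't covered in this tutorial.
- torch.tensor(train_mask) turns the list into a boolean mask tensor.

```python
import random

def split_node_ids(num_nodes, train_frac=0.7, val_frac=0.15, seed=42):
    """Shuffle node IDs once with a fixed seed, then cut into train/val/test."""
    ids = list(range(num_nodes))
    random.Random(seed).shuffle(ids)
    n_train = round(train_frac * num_nodes)
    n_val = round(val_frac * num_nodes)
    return ids[:n_train], ids[n_train:n_train + n_val], ids[n_train + n_val:]

def to_mask(ids, num_nodes):
    """Boolean list that is True exactly at the given node IDs."""
    mask = [False] * num_nodes
    for i in ids:
        mask[i] = True
    return mask

train_ids, val_ids, test_ids = split_node_ids(200)
train_mask = to_mask(train_ids, 200)
val_mask = to_mask(val_ids, 200)
print(len(train_ids), len(val_ids), len(test_ids))  # 140 30 30
```

Shuffling once and slicing guarantees the three sets are disjoint and together cover every node. Fixing the seed makes the split identical across runs.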
Part 6: Saving and Loading Graphs
The .tgx format preserves multi-dimensional node tensors, which most standard graph formats cannot represent natively:
```python
# Save
g.save("my_patch_graph.tgx")

# Load
g_loaded = tgx.Graph.load("my_patch_graph.tgx")
print(g_loaded.node_features.shape)  # [N, 3, 8, 8]
```
This is useful for pre-computed patch graphs or any workflow where re-creating the graph is expensive.
Part 7: Reproducibility
Set up reproducibility before any randomized operation runs. The context-manager form applies it to a single block of code:
```python
with tgx.reproducible(seed=42, deterministic=True):
    result = tgx.classify_nodes(
        x=x, edge_index=edge_index, labels=labels,
    )
```
Or set it globally for the rest of the session:
```python
tgx.set_seed(42, deterministic=False)
# ... rest of experiment
```
deterministic=True enables CUDA determinism. This can be slower, but it should give identical results across runs on the same hardware and software versions.
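This tutorial doesn't show what tgx.set_seed does internally. As a hedged sketch, a seeding helper in a PyTorch project typically does two things:

- seeds every random number generator in use;
- for determinism, asks PyTorch to use deterministic kernels.

seed_everything is a hypothetical name. The sketch runs whether or not NumPy and PyTorch are installed.

```python
import random

def seed_everything(seed, deterministic=False):
    """Hypothetical helper: seed Python, NumPy, and PyTorch RNGs when installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)  # seeds the CPU and all CUDA generators
    if deterministic:
        # Raise an error on ops that lack a deterministic implementation
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.benchmark = False

seed_everything(42)
first_run = [random.random() for _ in range(3)]
seed_everything(42)
second_run = [random.random() for _ in range(3)]
print(first_run == second_run)  # True
```

Per PyTorch's reproducibility notes, some CUDA operations also need the CUBLAS_WORKSPACE_CONFIG environment variable set before deterministic mode will accept them.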
Common Mistakes
1. Forgetting to validate before training
A mismatched edge index (containing a node ID that's out of bounds) will cause an obscure error at runtime. Call tgx.validate_graph(g) early.
2. Using tensor features with vector-only layers
Not all layers support tensor-valued inputs. The layers GCNConv, GATv2Conv, and APPNP expect [N, D] vector features. For [N, C, H, W] inputs, use ConvMessagePassing, TensorGATLayer, TensorGraphSAGELayer, or TensorGINLayer. The build_model() factory selects the right layer automatically based on feature rank.
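If you call vector-only layers directly instead of going through build_model(), a small guard in your own code can turn a confusing shape error into a clear message. The sketch below is hypothetical:

- layers_for is not a TGraphX function.
- It encodes only the rank-2 and rank-4 rule of thumb stated above, not build_model()'s actual selection logic.

```python
VECTOR_LAYERS = ("GCNConv", "GATv2Conv", "APPNP")
TENSOR_LAYERS = ("ConvMessagePassing", "TensorGATLayer",
                 "TensorGraphSAGELayer", "TensorGINLayer")

def layers_for(feature_shape):
    """Map a full node-feature shape to the layer names listed above.

    (N, D) -> vector layers; (N, C, H, W) -> tensor layers. Other ranks are
    outside this rule of thumb, so fail loudly instead of guessing.
    """
    rank = len(feature_shape)
    if rank == 2:
        return VECTOR_LAYERS
    if rank == 4:
        return TENSOR_LAYERS
    raise ValueError(
        f"rank-{rank} features {tuple(feature_shape)}: check the layer docs")

print(layers_for((200, 16)))       # vector layers
print(layers_for((200, 3, 8, 8)))  # tensor layers
```

torch.Size is a tuple subclass, so layers_for(g.node_features.shape) works directly on a graph's features.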
3. kNN on large tensors
tgx.knn_graph(x, k=10) on large x costs O(N²), because every node is compared with every other node. For N beyond a few thousand, either subsample nodes first or pass a pre-built edge index.
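The internals of tgx.knn_graph aren't documented here, but a brute-force version shows where the quadratic cost comes from. This NumPy sketch is a hypothetical stand-in, not the library's implementation. It mirrors the metric="cosine" and make_symmetric=True options from Part 5; its edge direction (node to its neighbour) is an assumption.

```python
import numpy as np

def knn_cosine_edges(features, k, make_symmetric=True):
    """Brute-force cosine kNN over [N, F] features, returned as a [2, E] array.

    Builds the full N x N similarity matrix, so time and memory grow as O(N^2).
    Assumes no all-zero rows (their cosine similarity is undefined) and k < N.
    """
    normed = features / np.linalg.norm(features, axis=1, keepdims=True)
    sim = normed @ normed.T                # [N, N] cosine similarities
    np.fill_diagonal(sim, -np.inf)         # never pick a node as its own neighbour
    # Column indices of the k largest similarities in each row (unordered)
    nbrs = np.argpartition(-sim, k, axis=1)[:, :k]
    pairs = {(i, int(j)) for i in range(len(features)) for j in nbrs[i]}
    if make_symmetric:
        pairs |= {(j, i) for i, j in pairs}  # add the reverse of every edge
    sources, targets = zip(*sorted(pairs))
    return np.array([sources, targets])

rng = np.random.default_rng(0)
features = rng.standard_normal((200, 192))  # e.g. 200 flattened [3, 8, 8] nodes
edge_index = knn_cosine_edges(features, k=5)
```

The sim matrix is the bottleneck. For N = 50,000 nodes it holds 2.5 billion entries, about 10 GB at float32, and NumPy's default float64 doubles that.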
4. Assuming tensor nodes improve accuracy automatically
Tensor-valued nodes preserve spatial structure. Whether that structure is useful depends on whether spatial correlation in your features is informative for the task. The benefit is not automatic; it depends on your data and model.
Summary
TGraphX's tgx.Graph container handles multi-dimensional node features natively. The key steps for a tensor-node workflow are:
1. Construct: tgx.Graph(x=..., edge_index=..., labels=...)
2. Validate: tgx.validate_graph(g, strict=True)
3. Build: tgx.build_model(task="node_classification", layer="conv", in_shape=...)
4. Train: tgx.classify_nodes(...) or fit(model, graph, loader, ...)
5. Save: g.save("file.tgx")

For first experiments, tgx.easy.train_node_classifier() or tgx.classify_nodes() handles steps 2–4 automatically.
FAQ
Q: When should I use tensor-valued nodes vs. flat vectors?
A: Use tensor-valued nodes when your features have spatial or temporal structure that you want to preserve across message-passing steps — image patches, volumetric blocks, or time sequences. If your features are already a bag-of-words, embeddings, or a list of scalars, flat vectors are simpler and sufficient.
Q: Does classify_nodes work with flat [N, D] features too?
A: Yes. tgx.classify_nodes() auto-detects feature rank and selects appropriate layers. For rank-2 inputs ([N, D]), it uses vector layers like GCNConv.
Q: How do I handle edge features?
A: tgx.Graph accepts edge_features of any shape. Edge features are preserved but their integration in message-passing depends on the specific layer. Check the layer documentation for ConvMessagePassing and TensorGATLayer.
Q: Is there a way to visualize what the patch graph looks like?
A: The tgx.plotting submodule has utilities for graph visualization. For patch graphs specifically, tgx.dashboard_audit() and the local dashboard can display run artifacts. For custom visualization, convert to NetworkX with g.to_networkx() and use NetworkX's drawing utilities.
Q: Can I use TGraphX graphs with standard PyTorch modules?
A: Yes. g.node_features is a regular torch.Tensor. You can pass it to any PyTorch module. The graph container and loaders are helpers; the underlying data is standard PyTorch.