MNIST as a Graph: Tensor-Native Node Classification with TGraphX
MNIST has been the field's "Hello, World" dataset since the 1990s. Modern image-classification benchmarks have largely moved on to CIFAR-10 and ImageNet, but MNIST remains useful for teaching: it is small, simple, well understood, and well suited to demonstrating workflows.
This tutorial uses MNIST as the substrate for showing how TGraphX handles image-as-graph workflows. We convert each MNIST image into a patch graph, train a tensor-aware node classifier, and audit the run.
Setup
pip install tgraphx
pip install "tgraphx[pillow]" # for image dataset loading
You need PyTorch ≥ 1.13 and torchvision (installed automatically as a TGraphX dependency).
Step 1: Load MNIST
import torch
import torchvision
from torchvision import transforms
transform = transforms.Compose([transforms.ToTensor()])
dataset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform,
)
print(f"Dataset size: {len(dataset)}")
image_0, label_0 = dataset[0]
print(f"Image shape: {image_0.shape}, label: {label_0}")
# Image shape: torch.Size([1, 28, 28]), label: 5
Each MNIST image is [1, 28, 28] — single channel, 28x28 pixels.
Step 2: Convert images to patch graphs
The idea: divide each image into 7x7 patches (producing 4x4 = 16 patches), then build a graph where adjacent patches are connected.
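The patch split itself is plain row-major slicing. Here is a minimal pure-Python sketch of what this step computes. It assumes non-overlapping square patches read left to right, top to bottom, and represents the single-channel image as a 2-D list with the channel dimension dropped. `extract_patches` is a hypothetical helper for illustration, not part of TGraphX, and the library may order or store patches differently.

```python
def extract_patches(image, patch_size):
    """Split a 2-D image (a list of equal-length rows) into non-overlapping
    patch_size x patch_size patches, ordered row-major (left to right, top to bottom)."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            # Take patch_size rows, each cut down to a patch_size-wide slice.
            patches.append([row[left:left + patch_size] for row in image[top:top + patch_size]])
    return patches


# Synthetic 28x28 image: each pixel stores its flat index row * 28 + col.
toy_image = [[r * 28 + c for c in range(28)] for r in range(28)]
toy_patches = extract_patches(toy_image, patch_size=7)
print(len(toy_patches))                              # 16
print(len(toy_patches[0]), len(toy_patches[0][0]))  # 7 7
print(toy_patches[1][0][0], toy_patches[4][0][0])   # 7 196
```

Patch 1 starts at column 7 of the first patch row, and patch 4 starts the second patch row at pixel row 7 (index 7 * 28 = 196), which is the 4x4 row-major layout described above.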
import tgraphx as tgx
# Take one image as an example
image = image_0 # shape [1, 28, 28]
# Convert to patch graph
patches, edge_index = tgx.image_to_patch_graph(image, patch_size=7)
# patches.shape: [16, 1, 7, 7] — 16 nodes, each a 7x7 patch
# edge_index.shape: [2, E] — grid-adjacent connections
tgx.image_to_patch_graph returns the patch tensor and an edge index where adjacent patches in the grid are connected. The result is a graph with 16 nodes; each node carries a 7x7 patch as its feature.
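The text does not say which notion of "adjacent" image_to_patch_graph uses. As an assumption, here is a sketch of the two common choices for a grid of patches: 4-neighbor (up, down, left, right) and 8-neighbor (diagonals too), with each edge stored in both directions as a [2, E] source/target pair. `grid_edge_index` is a hypothetical helper; comparing its counts with `edge_index.shape[1]` from the real call is one way to see which rule the library follows.

```python
def grid_edge_index(rows, cols, diagonal=False):
    """Return [sources, targets] for a rows x cols grid of nodes numbered row-major.
    Every undirected neighbor pair appears twice (i -> j and j -> i)."""
    offsets = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    if diagonal:
        offsets += [(-1, -1), (-1, 1), (1, -1), (1, 1)]
    sources, targets = [], []
    for r in range(rows):
        for c in range(cols):
            for dr, dc in offsets:
                nr, nc = r + dr, c + dc
                if 0 <= nr < rows and 0 <= nc < cols:  # skip neighbors outside the grid
                    sources.append(r * cols + c)
                    targets.append(nr * cols + nc)
    return [sources, targets]


four = grid_edge_index(4, 4)                  # 4-neighbor adjacency
eight = grid_edge_index(4, 4, diagonal=True)  # 8-neighbor adjacency
print(len(four[0]), len(eight[0]))            # 48 84
```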
Step 3: Build the full dataset
For node classification across many images, we can use either one large graph or one graph per image. We start with a per-image graph:
def image_to_graph(image, label):
    patches, edge_index = tgx.image_to_patch_graph(image, patch_size=7)
    # Give every patch node the image's overall label (simple supervision).
    # int() accepts both a plain int and a 0-d tensor label.
    node_labels = torch.full((patches.shape[0],), int(label), dtype=torch.long)
    return tgx.Graph(x=patches, edge_index=edge_index, labels=node_labels)
# Build a list of per-image graphs (a single image here, for illustration)
graphs = [image_to_graph(image_0, label_0)]
In a real experiment you would batch many graphs. For this tutorial we will use a single combined graph constructed by concatenating patches across many images, with appropriate edge index offsets.
Step 4: Concatenate into a single training graph
def build_combined_graph(dataset, num_images=1000, patch_size=7):
    all_patches = []
    all_labels = []
    all_edges = []
    offset = 0
    for i in range(num_images):
        img, lbl = dataset[i]
        patches, ei = tgx.image_to_patch_graph(img, patch_size=patch_size)
        all_patches.append(patches)
        all_labels.append(torch.full((patches.shape[0],), lbl, dtype=torch.long))
        all_edges.append(ei + offset)  # offset edge indices
        offset += patches.shape[0]
    x = torch.cat(all_patches, dim=0)
    y = torch.cat(all_labels, dim=0)
    edge_index = torch.cat(all_edges, dim=1)
    return tgx.Graph(x=x, edge_index=edge_index, labels=y)
g = build_combined_graph(dataset, num_images=1000)
print(f"Combined: {g.num_nodes} nodes, {g.num_edges} edges, feature shape {g.node_features.shape}")
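# Combined: 16000 nodes, (1000 × E) edges, feature shape torch.Size([16000, 1, 7, 7])
Each image contributes 16 nodes, so 1000 images give 16,000 nodes. The edge count is 1000 times the per-image E from Step 2, which depends on the adjacency rule image_to_patch_graph uses. Each node carries a [1, 7, 7] patch feature.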
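The offset step in build_combined_graph is what keeps the images separate. Here is a minimal pure-Python sketch of the same idea, assuming every image has the same number of nodes. Shifting image i's node ids by i × nodes_per_image means no edge ever connects two images, so the combined graph is a set of disconnected per-image copies. `combine_edge_indices` is a hypothetical helper, not a TGraphX function.

```python
def combine_edge_indices(per_image_edges, nodes_per_image):
    """Concatenate per-image [sources, targets] lists into one edge index,
    shifting each image's node ids by the number of nodes that came before it."""
    sources, targets = [], []
    offset = 0
    for src, dst in per_image_edges:
        sources.extend(s + offset for s in src)
        targets.extend(t + offset for t in dst)
        offset += nodes_per_image
    return [sources, targets]


# Three copies of a tiny 2-node graph with one edge in each direction.
tiny = [[0, 1], [1, 0]]
combined = combine_edge_indices([tiny, tiny, tiny], nodes_per_image=2)
print(combined)  # [[0, 1, 2, 3, 4, 5], [1, 0, 3, 2, 5, 4]]
```

With a fixed node count per image, checking that `s // nodes_per_image == t // nodes_per_image` for every edge is a quick way to confirm that no cross-image edges slipped in.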
Step 5: Validate and split
tgx.validate_graph(g, strict=True)
tgx.assert_tensor_native(g, min_rank=3)
torch.manual_seed(42)
perm = torch.randperm(g.num_nodes)
n = g.num_nodes
train_mask = torch.zeros(n, dtype=torch.bool); train_mask[perm[:int(.7*n)]] = True
val_mask = torch.zeros(n, dtype=torch.bool); val_mask[perm[int(.7*n):int(.85*n)]] = True
test_mask = torch.zeros(n, dtype=torch.bool); test_mask[perm[int(.85*n):]] = True
tgx.check_leakage(train_mask, val_mask, test_mask, strict=True)
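The validation and leakage checks are quick sanity safeguards on the graph structure and the masks, and they catch obscure errors before training. A disjoint-mask check does not catch one subtler issue here: the split is per node, so patches from the same image can land in different splits. All 16 patches of an image share one label and exchange messages, so validation and test accuracy can be optimistic. Splitting by image avoids this.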
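Here is a minimal pure-Python sketch of an image-level split, which keeps every patch of an image in the same split. It assumes each image owns a consecutive block of nodes_per_image node ids, as build_combined_graph produces, and it reuses the 70/15/15 ratios from the node-level split. `image_level_masks` is a hypothetical helper, not a TGraphX function.

```python
import random


def image_level_masks(num_images, nodes_per_image, seed=42, train=0.7, val=0.15):
    """Assign whole images to train/val/test and expand to per-node boolean masks.
    Assumes image i owns node ids [i * nodes_per_image, (i + 1) * nodes_per_image)."""
    order = list(range(num_images))
    random.Random(seed).shuffle(order)  # seeded, so the same split comes back every run
    n_train = int(train * num_images)
    n_val = int(val * num_images)
    num_nodes = num_images * nodes_per_image
    masks = {name: [False] * num_nodes for name in ("train", "val", "test")}
    for rank, img in enumerate(order):
        if rank < n_train:
            name = "train"
        elif rank < n_train + n_val:
            name = "val"
        else:
            name = "test"
        start = img * nodes_per_image
        for node in range(start, start + nodes_per_image):
            masks[name][node] = True  # every patch of this image goes to the same split
    return masks["train"], masks["val"], masks["test"]


train_mask, val_mask, test_mask = image_level_masks(num_images=1000, nodes_per_image=16)
print(sum(train_mask), sum(val_mask), sum(test_mask))  # 11200 2400 2400
```

The returned lists can be converted to tensors with torch.tensor(mask, dtype=torch.bool) to replace the node-level masks.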
Step 6: Train
with tgx.reproducible(seed=42, deterministic=False):
    result = tgx.classify_nodes(
        x=g.node_features,
        edge_index=g.edge_index,
        labels=g.node_labels,
        model="tensor_gcn",  # uses ConvMessagePassing internally
        seed=42,
        device="auto",
    )
print(result.metrics)
result.summary()
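This trains a tensor-aware GCN on the combined graph. The tensor_gcn model preserves the [1, 7, 7] patch shape through message-passing aggregation, applying convolutional operations instead of flattening each patch. Note that the call above does not receive the masks built in Step 5. Before relying on those splits, check how classify_nodes divides the data and whether it accepts custom masks.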
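To make "preserves the patch shape" concrete, here is a simplified pure-Python sketch of one aggregation step. Each node averages its own patch with its incoming neighbors' patches element-wise, so features stay C x H x W nested lists instead of flat vectors. This only illustrates shape-preserving aggregation and is not TGraphX's ConvMessagePassing: it has no learned weights, no convolution, and no degree normalization. `mean_aggregate` is a hypothetical helper.

```python
def mean_aggregate(patches, edge_index):
    """One simplified message-passing step on patch-shaped features.
    Each node's new feature is the element-wise mean of its own patch and the
    patches of its incoming neighbors, so the C x H x W shape is preserved."""
    sources, targets = edge_index
    inputs = [[p] for p in patches]  # start every node with a self-loop
    for s, t in zip(sources, targets):
        inputs[t].append(patches[s])
    out = []
    for group in inputs:
        n_ch, n_rows, n_cols = len(group[0]), len(group[0][0]), len(group[0][0][0])
        out.append([[[sum(p[ch][row][col] for p in group) / len(group)
                      for col in range(n_cols)]
                     for row in range(n_rows)]
                    for ch in range(n_ch)])
    return out


p0 = [[[0.0, 3.0], [6.0, 9.0]]]  # each patch is 1 x 2 x 2
p1 = [[[3.0, 3.0], [3.0, 3.0]]]
p2 = [[[0.0, 0.0], [0.0, 6.0]]]
# Nodes 0 and 1 both send a message to node 2.
aggregated = mean_aggregate([p0, p1, p2], [[0, 1], [2, 2]])
print(aggregated[2])  # [[[1.0, 2.0], [3.0, 6.0]]]
```

A real tensor-aware layer would add learned parameters, such as a convolution applied to the aggregated patch. The sketch keeps only the part that differs from a flat GCN: the aggregation never reshapes the patch.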
Step 7: Inspect the result
# Where did the run land?
print(result.run_dir)
print(tgx.audit_run_dir(result.run_dir))
# Compare against a flat-vector baseline
result_flat = tgx.classify_nodes(
    x=g.node_features.view(g.num_nodes, -1),  # flattened to [N, 49]
    edge_index=g.edge_index,
    labels=g.node_labels,
    model="gcn",
    seed=42,
)
print(f"Tensor: {result.metrics['val_accuracy']:.3f}")
print(f"Flat: {result_flat.metrics['val_accuracy']:.3f}")
For this specific setup (patches as nodes, each labeled with its image's class), the tensor and flat versions may reach similar accuracy. A 7x7 single-channel patch holds only 49 values, so there is little within-patch structure to preserve, and the message-passing structure dominates. With larger patches (16x16 or more) and richer inputs such as 3-channel CIFAR-10 patches, the tensor version has more spatial structure to exploit and may begin to pull ahead; confirm this with your own runs.
What this teaches
The workflow. Image → patch graph → tensor-aware GNN is a clean, repeatable pattern for any image-as-graph task. The same code structure works for CIFAR-10, ImageNet patches, medical imaging regions, etc.
The validation discipline. Even on a small dataset, tgx.validate_graph and tgx.check_leakage are cheap enough to run every time and can catch bugs before they reach training.
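The reproducibility. Training runs inside tgx.reproducible(seed=42, deterministic=False), and the split is seeded with torch.manual_seed(42), so rerunning the script should reproduce the same split and closely matching numbers. Because deterministic=False is passed, nondeterministic GPU kernels may still cause small differences. The flat baseline in Step 7 also sits outside the reproducible block.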
The audit artifacts. The run directory contains everything needed to verify what happened. A teammate can audit without rerunning.
Honest assessment
MNIST is a toy dataset. The accuracy you get from this workflow is not interesting in itself. What is interesting is that the workflow scales: the same code structure with different patch_size, num_images, and model parameters applies to real research data.
For real research, you would:
- Use CIFAR-10 or a domain-specific dataset
- Use larger patches that capture meaningful spatial content
- Tune hyperparameters
- Run multiple seeds and report variance
- Inspect the audit artifacts before publishing
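For the multiple-seeds point, the reporting itself takes a few lines of standard-library Python. This sketch assumes you collect one val_accuracy value per seed from your own runs, for example by looping the Step 6 call over several seeds. `summarize_seeds` and `format_summary` are hypothetical helpers.

```python
import statistics


def summarize_seeds(accuracies):
    """Return (mean, sample standard deviation) of per-seed accuracies."""
    if len(accuracies) < 2:
        raise ValueError("need at least two seeds to estimate variance")
    return statistics.mean(accuracies), statistics.stdev(accuracies)


def format_summary(accuracies):
    """Render the per-seed results as 'mean ± std over N seeds'."""
    mean, std = summarize_seeds(accuracies)
    return f"{mean:.3f} ± {std:.3f} over {len(accuracies)} seeds"
```

Calling format_summary on the list of per-seed val_accuracy values gives a single line suitable for a results table.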
The tutorial pattern is the same.
FAQ
Q: How does this compare to standard image classification CNNs?
A: For plain image classification, expect a standard CNN to outperform this approach, since CNNs are designed for grid-structured image data. The patch graph is interesting when the graph structure carries information CNNs cannot easily use (e.g., non-grid adjacency, multi-image relationships).
Q: Why per-image labels for nodes?
A: Simplest supervision. For per-patch labels (e.g., semantic segmentation), provide a different label tensor.
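For example, with a per-pixel segmentation mask, one simple way to build per-patch labels is a majority vote inside each patch. The sketch below is pure Python and assumes the same row-major patch order as the patch graph, which is an assumption about image_to_patch_graph's ordering. `patch_labels_from_mask` is hypothetical, and ties go to the class encountered first, following Counter.most_common.

```python
from collections import Counter


def patch_labels_from_mask(mask, patch_size):
    """Give each non-overlapping patch the most common class among its pixels.
    mask: 2-D list of integer class ids; patches are visited row-major."""
    labels = []
    for top in range(0, len(mask), patch_size):
        for left in range(0, len(mask[0]), patch_size):
            pixels = [v for row in mask[top:top + patch_size]
                      for v in row[left:left + patch_size]]
            labels.append(Counter(pixels).most_common(1)[0][0])
    return labels


# 4x4 mask cut into 2x2 patches.
mask = [
    [1, 1, 0, 0],
    [1, 1, 0, 0],
    [0, 0, 2, 2],
    [0, 2, 2, 2],
]
print(patch_labels_from_mask(mask, patch_size=2))  # [1, 0, 0, 2]
```

The bottom-left patch contains three background pixels and one class-2 pixel, so the vote assigns it class 0. Wrap the result in torch.tensor(labels, dtype=torch.long) to use it as a node label tensor.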
Q: Can I batch graphs with NeighborLoader?
A: Yes. NeighborLoader works on the combined graph and samples mini-batches across all the per-image subgraphs.
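Because build_combined_graph never adds edges between images, any neighborhood grown from a seed node stays inside that seed's image. The pure-Python sketch below shows this with a full k-hop expansion. That is a simplification of what a neighbor loader does, since it has no sampling and no fan-out limits. `k_hop_nodes` is a hypothetical helper.

```python
from collections import defaultdict


def k_hop_nodes(edge_index, seed, k):
    """Return every node reachable from `seed` in at most k hops (source -> target)."""
    neighbors = defaultdict(set)
    for s, t in zip(*edge_index):
        neighbors[s].add(t)
    frontier, seen = {seed}, {seed}
    for _ in range(k):
        frontier = {n for node in frontier for n in neighbors[node]} - seen
        seen |= frontier
    return seen


# Two "images", each a 3-node path stored in both directions; image 1 is offset by 3.
edges = [[0, 1, 1, 2, 3, 4, 4, 5], [1, 0, 2, 1, 4, 3, 5, 4]]
print(sorted(k_hop_nodes(edges, seed=4, k=2)))  # [3, 4, 5]
```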
Q: What about CIFAR-10?
A: Same pattern. Use patches, ei = tgx.image_to_patch_graph(image, patch_size=8). The 3-channel [3, 8, 8] patches are where tensor-aware aggregation is more likely to help.
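The shape arithmetic generalizes: a C x H x W image cut into non-overlapping p x p patches gives (H / p) x (W / p) nodes, each with a [C, p, p] feature. Here is a small sketch, assuming image sides divisible by p (`patch_grid_shape` is a hypothetical helper). For CIFAR-10's 3 x 32 x 32 images, patch_size=8 gives the same 4 x 4 grid of 16 nodes as MNIST with patch_size=7, but each node holds 3 * 8 * 8 = 192 values instead of 49.

```python
def patch_grid_shape(channels, height, width, patch_size):
    """Return (num_patches, patch_shape) for non-overlapping square patches."""
    if height % patch_size or width % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    rows, cols = height // patch_size, width // patch_size
    return rows * cols, (channels, patch_size, patch_size)


print(patch_grid_shape(1, 28, 28, 7))  # MNIST:    (16, (1, 7, 7))
print(patch_grid_shape(3, 32, 32, 8))  # CIFAR-10: (16, (3, 8, 8))
```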
Q: Where can I see this as a complete script?
A: The TGraphX repository's tutorials/tensor_node_classification_neighbor_loader.py and tutorials/image_patch_tensor_graph_demo.py show end-to-end versions.