CIFAR-10 Patch Graphs: Shape-Aware Image Graph Learning
CIFAR-10 is the standard small-image classification dataset: 60,000 32x32 RGB images across 10 classes. It is mostly used to benchmark CNNs. This tutorial uses it differently: as a substrate for image-as-graph learning, where each image becomes a graph of patches and each patch is a tensor-valued node.
The pattern is not new (graph-based image analysis has been around for decades) but the implementation is unusually clean in TGraphX because the framework supports tensor-valued nodes natively.
What we are building
For each CIFAR-10 image (32x32 RGB):
- Divide into 16 patches of size 8x8 (4 patches across, 4 down).
- Each patch becomes a node with feature shape [3, 8, 8].
- Connect adjacent patches (grid graph).
- Use the image's class label either as the label for every node (simple supervision) or as the label for the whole graph (graph classification, which the steps below use).
The result is a per-image graph with 16 tensor-valued nodes. The working hypothesis is that a tensor-aware GNN trained on these graphs outperforms a flat-vector version that discards the spatial structure of each patch; Step 6 sets up that comparison so you can test it.
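The patch split described above can be sketched in plain PyTorch, independent of TGraphX. The function name `image_to_patches` is an assumption made for illustration, and row-major patch ordering is a choice of this sketch; the built-in adapter may order patches differently.

```python
import torch

def image_to_patches(image: torch.Tensor, patch_size: int = 8) -> torch.Tensor:
    """Split a [C, H, W] image into non-overlapping [N, C, p, p] patches (row-major)."""
    c, h, w = image.shape
    if h % patch_size or w % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    rows, cols = h // patch_size, w // patch_size
    # [C, rows, p, cols, p] -> [rows, cols, C, p, p] -> [rows * cols, C, p, p]
    patches = image.reshape(c, rows, patch_size, cols, patch_size)
    patches = patches.permute(1, 3, 0, 2, 4)
    return patches.reshape(rows * cols, c, patch_size, patch_size)

# A synthetic CIFAR-sized image stands in for real data here.
image = torch.arange(3 * 32 * 32, dtype=torch.float32).reshape(3, 32, 32)
patches = image_to_patches(image)
print(patches.shape)  # torch.Size([16, 3, 8, 8])
```

Each node keeps its [3, 8, 8] layout, so a convolutional message function can still see which pixels are neighbors inside a patch.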
Setup
pip install tgraphx
pip install "tgraphx[pillow]" # for image dataset support
Step 1: Load CIFAR-10
The framework includes a registered dataset adapter for CIFAR-10 patch graphs:
import tgraphx as tgx
# Load CIFAR-10 with patch_size=8
dataset = tgx.load_dataset(
    "cifar10_patch",
    download=True,
    patch_size=8,
    split="train",
)
print(f"Dataset size: {len(dataset)}")
sample = dataset[0]
print(f"Sample type: {type(sample).__name__}")
The adapter returns per-image graphs with patches as nodes. Each sample is a tgx.Graph object with x.shape = [16, 3, 8, 8] and edges connecting adjacent patches in the grid.
Step 2: Inspect a single example
g = dataset[0]
print(f"Nodes: {g.num_nodes}") # 16
print(f"Edges: {g.num_edges}") # ~48 if each of the 24 grid adjacencies is stored in both directions
print(f"Feature shape: {g.node_features.shape}") # [16, 3, 8, 8]
print(f"Label: {g.label}")
tgx.validate_graph(g, strict=True)
tgx.assert_tensor_native(g, min_rank=3)
This is the unit of input to the GNN: a graph of 16 image patches, each with full 3-channel spatial structure.
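An edge count of roughly 48 is what a 4-connected 4x4 grid gives when every adjacency is stored in both directions. The pure-Python sketch below checks that arithmetic; `grid_edges` is a hypothetical helper, and 4-connectivity without self-loops is an assumption about the adapter, not something confirmed here.

```python
def grid_edges(rows: int, cols: int) -> list[tuple[int, int]]:
    """Directed edges between 4-connected neighbours of a rows x cols grid."""
    edges = []
    for r in range(rows):
        for c in range(cols):
            node = r * cols + c
            if c + 1 < cols:  # right neighbour, stored in both directions
                edges += [(node, node + 1), (node + 1, node)]
            if r + 1 < rows:  # neighbour below, stored in both directions
                edges += [(node, node + cols), (node + cols, node)]
    return edges

edges = grid_edges(4, 4)
print(len(edges))  # 48 directed edges = 24 undirected adjacencies
```

If `g.num_edges` reports a different number, the adapter is likely using another convention (for example diagonal neighbors or self-loops), which is worth knowing before comparing results.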
Step 3: Build a training graph from many images
For mini-batch graph classification, we want to process many of these per-image graphs as a batch:
from tgraphx import GraphBatch, GraphDataLoader
# DataLoader produces batched graphs
loader = GraphDataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
    # batch.x: [32*16, 3, 8, 8] (patches across all 32 images)
    # batch.edge_index: [2, ~32*48]
    # batch.batch: [32*16] (which image each patch belongs to)
    # batch.y: [32] (image labels)
    break
print(f"Batch x shape: {batch.x.shape}")
batch is a GraphBatch. The batch.batch tensor maps each node to its source graph index, which is what graph-level pooling needs.
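To make the batching layout concrete, here is a minimal sketch of how several small graphs are commonly merged into one disjoint-union graph. `collate_graphs` is a hypothetical name and this is not claimed to be how `GraphDataLoader` is implemented; it only illustrates the node-offset and batch-vector idea.

```python
import torch

def collate_graphs(xs, edge_indices):
    """Merge per-graph node tensors and edge lists into one disjoint-union graph."""
    batch_x, batch_edges, batch_vec = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(zip(xs, edge_indices)):
        n = x.shape[0]
        batch_x.append(x)
        batch_edges.append(edge_index + offset)  # shift node ids past earlier graphs
        batch_vec.append(torch.full((n,), graph_id, dtype=torch.long))
        offset += n
    return torch.cat(batch_x), torch.cat(batch_edges, dim=1), torch.cat(batch_vec)

xs = [torch.randn(16, 3, 8, 8) for _ in range(2)]
toy_edges = torch.tensor([[0, 1], [1, 0]])  # toy example: one edge pair per graph
x, edge_index, batch_vec = collate_graphs(xs, [toy_edges, toy_edges])
print(x.shape)              # torch.Size([32, 3, 8, 8])
print(edge_index.tolist())  # [[0, 1, 16, 17], [1, 0, 17, 16]]
```

Because the node ids of the second graph are shifted by 16, no edge crosses between images, and message passing on the merged graph behaves like message passing on each image separately.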
Step 4: Build a model
from tgraphx import build_model
model = build_model(
    task="graph_classification",
    layer="conv", # tensor-aware ConvMessagePassing
    in_shape=(3, 8, 8),
    out_dim=10,
    hidden_dim=64,
    num_layers=3,
    pooling="mean", # mean pool over nodes per graph
)
The task="graph_classification" flag tells the factory to add a graph-level pooling layer at the end. Without it, the model would output per-node logits; with it, per-graph logits.
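The step from per-node outputs to per-graph logits can be illustrated with a small mean-pooling sketch in plain PyTorch. `mean_pool` is a hypothetical helper written for this example, not the pooling layer the factory actually installs.

```python
import torch

def mean_pool(node_out: torch.Tensor, batch: torch.Tensor, num_graphs: int) -> torch.Tensor:
    """Average node rows that share a graph index: [N, D] -> [num_graphs, D]."""
    sums = torch.zeros(num_graphs, node_out.shape[1],
                       dtype=node_out.dtype, device=node_out.device)
    sums.index_add_(0, batch, node_out)  # accumulate rows per graph
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    return sums / counts.unsqueeze(1).to(node_out.dtype)

node_out = torch.tensor([[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]])
batch = torch.tensor([0, 0, 1])  # first two nodes belong to graph 0
pooled = mean_pool(node_out, batch, num_graphs=2)
print(pooled)  # rows: [2., 3.] and [10., 20.]
```

With 16 patches per image, every graph contributes the same number of nodes, so mean and sum pooling differ only by a constant factor; the choice matters more once graphs have varying sizes.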
Step 5: Train
import torch
import torch.nn.functional as F
from tgraphx import fit, set_seed
set_seed(42, deterministic=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)
for epoch in range(20):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = F.cross_entropy(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.y.size(0)
        correct += (out.argmax(dim=-1) == batch.y).sum().item()
        total += batch.y.size(0)
    print(f"Epoch {epoch+1:3d} loss={total_loss/total:.4f} acc={correct/total:.4f}")
This is a standard PyTorch training loop. Nothing unusual; the framework provides convenience but does not replace the training-loop mental model.
Step 6: Compare against a flat-vector baseline
# Build a flat-vector dataset for comparison
flat_dataset = [
    type('G', (), {
        # flatten [3, 8, 8] -> [192]; reshape also handles non-contiguous tensors
        'x': g.node_features.reshape(g.num_nodes, -1),
        'edge_index': g.edge_index,
        'batch': g.batch,
        'y': g.y,
    })() for g in dataset
]
# Build a vector model
flat_model = build_model(
    task="graph_classification",
    layer="gcn", # vector-only
    in_shape=(192,),
    out_dim=10,
    hidden_dim=64,
    num_layers=3,
    pooling="mean",
)
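# Train both models with the Step 5 loop and evaluate them on held-out data
# (full training loop omitted for brevity). tensor_acc and flat_acc are the
# resulting accuracies; they are not defined in this snippet.
print(f"Tensor model accuracy: {tensor_acc:.4f}")
print(f"Flat model accuracy: {flat_acc:.4f}")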
For CIFAR-10 patch graphs, the tensor version may show different accuracy than the flat version when patches carry meaningful spatial structure — verify with your own measurements.
For comparison: a standard ResNet-18 on CIFAR-10 reaches ~93% accuracy. A patch-graph GNN on the same task reaches lower accuracy (typically in the 70-80% range, depending on architecture and training). On its own, the patch-graph approach is not competitive with dedicated CNN architectures for image classification. Its value appears when the graph structure carries additional information: multi-image graphs, image regions with non-grid adjacency, or hybrid vision-graph tasks.
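Any accuracy comparison in this step should come from a held-out split, measured the same way for both models. The sketch below is one way to write such an evaluation helper; it assumes batches expose `x`, `edge_index`, `batch`, and `y` as in the training loop, and `ToyModel` is a hypothetical stand-in used only to make the example runnable.

```python
from types import SimpleNamespace

import torch

@torch.no_grad()
def evaluate(model, batches, device="cpu"):
    """Fraction of graphs whose arg-max logit matches the label."""
    model.eval()
    correct = total = 0
    for batch in batches:
        out = model(batch.x.to(device), batch.edge_index.to(device), batch.batch.to(device))
        correct += (out.argmax(dim=-1) == batch.y.to(device)).sum().item()
        total += batch.y.numel()
    return correct / max(total, 1)

class ToyModel(torch.nn.Module):
    """Hypothetical stand-in: ignores its inputs and always predicts class 0."""
    def forward(self, x, edge_index, batch):
        num_graphs = int(batch.max()) + 1
        logits = torch.zeros(num_graphs, 10)
        logits[:, 0] = 1.0
        return logits

toy_batch = SimpleNamespace(
    x=torch.randn(32, 3, 8, 8),
    edge_index=torch.zeros(2, 0, dtype=torch.long),
    batch=torch.arange(2).repeat_interleave(16),
    y=torch.tensor([0, 3]),
)
acc = evaluate(ToyModel(), [toy_batch])
print(acc)  # 0.5
```

Running the same helper over a test-split loader for the tensor model and the flat model gives two numbers that are directly comparable.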
When this pattern is useful
- Image regions with non-trivial adjacency (e.g., segmentation regions connected by perimeter).
- Multi-image graphs where edges represent relationships between images.
- Vision tasks where graph structure encodes domain knowledge (e.g., scene graphs).
- Research on graph-CNN hybrids.
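As an illustration of the first case, non-grid adjacency can be derived from a segmentation label map by connecting regions that share a border. This is a NumPy sketch with a hypothetical helper name, `region_adjacency`; turning the resulting pairs into a TGraphX graph is left out because it depends on the graph-construction API.

```python
import numpy as np

def region_adjacency(labels: np.ndarray) -> set[tuple[int, int]]:
    """Undirected region pairs (a < b) that touch horizontally or vertically."""
    pairs = set()
    # Compare each pixel with its right neighbour, then with the pixel below it.
    for a, b in ((labels[:, :-1], labels[:, 1:]), (labels[:-1, :], labels[1:, :])):
        touching = a != b
        for u, v in zip(a[touching].tolist(), b[touching].tolist()):
            pairs.add((min(u, v), max(u, v)))
    return pairs

labels = np.array([
    [0, 0, 1],
    [0, 2, 1],
    [2, 2, 1],
])
adjacent = sorted(region_adjacency(labels))
print(adjacent)  # [(0, 1), (0, 2), (1, 2)]
```

Unlike the fixed 4x4 grid, the edge set here changes from image to image, which is the situation where a graph model has information a plain CNN does not receive directly.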
When it is not
- Standard image classification — use a CNN.
- ImageNet-scale workloads — the patch-graph approach has high per-image overhead.
Honest limitations
- The patch graph adapter is straightforward but produces a fixed grid structure. Custom adjacencies require manual graph construction.
- For very high-resolution images, patch count grows quickly. A 512x512 image at 16x16 patches has 1024 nodes per image; batches of these are large.
- The tgraphx[pillow] extra is required for image loading. Skipping it disables image dataset adapters.
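The growth in patch count is easy to estimate up front. The sketch below computes nodes per image and the raw float32 input-feature memory for a batch; `patch_graph_size` is a hypothetical helper, and the figure covers input features only, not hidden activations or edges.

```python
def patch_graph_size(height, width, patch_size, channels=3, batch_size=1, bytes_per_value=4):
    """Nodes per image and raw feature bytes per batch for non-overlapping patches."""
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    nodes = (height // patch_size) * (width // patch_size)
    feature_bytes = batch_size * nodes * channels * patch_size**2 * bytes_per_value
    return nodes, feature_bytes

print(patch_graph_size(32, 32, 8))                    # (16, 12288)
print(patch_graph_size(512, 512, 16, batch_size=32))  # (1024, 100663296)
```

The second call corresponds to 96 MiB of input features for a batch of 32 images, before any hidden layer widens the channel dimension.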
Reproducibility
The complete experiment is reproducible with:
with tgx.reproducible(seed=42, deterministic=True):
    # full training loop
    ...
Set deterministic=False if you need the cuDNN speedup and can accept run-to-run variance within a few tenths of a percent.
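For readers who want to see what such a switch usually involves, here is a minimal sketch of seeding in plain PyTorch. `seed_everything` is a hypothetical helper, and this is not claimed to be what `tgx.reproducible` does internally.

```python
import random

import torch

def seed_everything(seed: int, deterministic: bool = True) -> None:
    """Hypothetical helper: seed Python and PyTorch RNGs and set cuDNN flags."""
    random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU generator and all CUDA devices
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic  # autotuning can vary between runs

seed_everything(42)
a = torch.rand(3)
seed_everything(42)
b = torch.rand(3)
print(torch.equal(a, b))  # True
```

Disabling cuDNN autotuning is where the speed cost comes from; data-loader worker seeding and shuffling order are further sources of variance that this sketch does not cover.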
FAQ
Q: Why patch_size=8 specifically?
A: 32/8 = 4, so we get a clean 4x4 grid of patches with no leftover pixels. patch_size=4 gives an 8x8 grid with 64 patches per image, so each graph has four times as many nodes.
Q: Can I use overlapping patches?
A: Not with the built-in adapter. You can construct overlapping patches manually with tgx.image_to_patch_graph(image, patch_size=8, stride=4).
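The overlapping-patch extraction itself can be sketched with `torch.Tensor.unfold`, independent of TGraphX. `overlapping_patches` is a hypothetical name used here; with patch_size=8 and stride=4 on a 32x32 image the window fits 7 times per side, giving 49 nodes.

```python
import torch

def overlapping_patches(image: torch.Tensor, patch_size: int = 8, stride: int = 4) -> torch.Tensor:
    """Slide a patch_size window over a [C, H, W] image: returns [N, C, p, p]."""
    c = image.shape[0]
    # Unfold height, then width: [C, rows, cols, p, p]
    windows = image.unfold(1, patch_size, stride).unfold(2, patch_size, stride)
    # Move the grid dimensions first, then flatten them into the node dimension.
    return windows.permute(1, 2, 0, 3, 4).reshape(-1, c, patch_size, patch_size)

image = torch.arange(3 * 32 * 32, dtype=torch.float32).reshape(3, 32, 32)
patches = overlapping_patches(image)
print(patches.shape)  # torch.Size([49, 3, 8, 8])
```

Halving the stride roughly triples the node count here (49 instead of 16), so overlapping patches trade graph size for finer spatial coverage.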
Q: What about edge features (e.g., visual similarity between patches)?
A: The tgx.Graph container accepts edge_features. Compute them per edge and pass them in; they will be available to layers that support edge features.
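One possible edge feature is the cosine similarity between the two patches an edge connects. The sketch below computes it in plain PyTorch; `patch_similarity` is a hypothetical helper, and how the result is attached to the graph depends on the container's API.

```python
import torch
import torch.nn.functional as F

def patch_similarity(x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    """Cosine similarity between flattened endpoint patches: returns [E, 1]."""
    flat = x.flatten(start_dim=1)  # [N, C*p*p]
    src, dst = edge_index          # each of shape [E]
    return F.cosine_similarity(flat[src], flat[dst], dim=1).unsqueeze(1)

x = torch.ones(3, 3, 8, 8)
x[2] = -1.0                                    # patch 2 is the negation of patch 0
edge_index = torch.tensor([[0, 0], [1, 2]])    # edges 0->1 and 0->2
edge_features = patch_similarity(x, edge_index)
print(edge_features.shape)  # torch.Size([2, 1])
```

The first edge joins identical patches and scores close to 1, the second joins opposite patches and scores close to -1.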
Q: How does this compare to Vision Transformers?
A: ViTs also operate on patches but use attention across all patches (not grid-restricted). For most image classification, ViTs are stronger. The patch-graph approach is useful when you specifically want grid or custom adjacency.
Q: Where is the dataset adapter implemented?
A: tgraphx.datasets.adapters.cifar10_patch in the source repository.