MNIST as a Graph: Tensor-Native Node Classification with TGraphX

MNIST has been the field's "Hello, World" dataset since the 1990s. Modern image-classification benchmarks have largely moved on to CIFAR-10 and ImageNet, but MNIST remains useful as a teaching dataset: small, simple, well understood, and well suited to demonstrating workflows.

This tutorial uses MNIST as the substrate for showing how TGraphX handles image-as-graph workflows. We convert each MNIST image into a patch graph, train a tensor-aware node classifier, and audit the run.

Setup

```bash
pip install tgraphx
pip install "tgraphx[pillow]"  # for image dataset loading
```

You need PyTorch ≥ 1.13 and torchvision (installed automatically as a TGraphX dependency).

Step 1: Load MNIST

```python
import torch
import torchvision
from torchvision import transforms

transform = transforms.Compose([transforms.ToTensor()])
dataset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform,
)

print(f"Dataset size: {len(dataset)}")
image_0, label_0 = dataset[0]
print(f"Image shape: {image_0.shape}, label: {label_0}")
# Image shape: torch.Size([1, 28, 28]), label: 5
```

Each MNIST image is [1, 28, 28] — single channel, 28x28 pixels.

Step 2: Convert images to patch graphs

The idea: divide each image into 7x7 patches (producing 4x4 = 16 patches), then build a graph where adjacent patches are connected.
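
The patch split itself is plain tensor reshaping. The sketch below shows one way to do it with `Tensor.unfold`, assuming non-overlapping square patches numbered row-major; it illustrates the idea and is not necessarily how `tgx.image_to_patch_graph` is implemented. `split_into_patches` is a hypothetical helper.

```python
import torch

def split_into_patches(image: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split a [C, H, W] image into non-overlapping [C, p, p] patches.

    Hypothetical helper for illustration. Returns [num_patches, C, p, p],
    with patches ordered row-major (patch index = row * cols + col).
    """
    c, h, w = image.shape
    if h % patch_size or w % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    # unfold(dim, size, step) slides a window along one dimension:
    # [C, H, W] -> [C, rows, W, p] -> [C, rows, cols, p, p]
    tiles = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    rows, cols = tiles.shape[1], tiles.shape[2]
    # Move the grid dims to the front, then merge them into one node dimension
    return tiles.permute(1, 2, 0, 3, 4).reshape(rows * cols, c, patch_size, patch_size)

mnist_like = torch.rand(1, 28, 28)
print(split_into_patches(mnist_like, 7).shape)   # torch.Size([16, 1, 7, 7])
cifar_like = torch.rand(3, 32, 32)
print(split_into_patches(cifar_like, 8).shape)   # torch.Size([16, 3, 8, 8])
```

Because the patches are ordered row-major, patch `k` sits at grid position `(k // cols, k % cols)`, which is exactly what an adjacency builder needs to know.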

```python
import tgraphx as tgx

# Take one image as an example
image = image_0   # shape [1, 28, 28]

# Convert to patch graph
patches, edge_index = tgx.image_to_patch_graph(image, patch_size=7)
# patches.shape:    [16, 1, 7, 7]   (16 nodes, each a 7x7 patch)
# edge_index.shape: [2, E]          (grid-adjacent connections)
```

tgx.image_to_patch_graph returns the patch tensor and an edge index where adjacent patches in the grid are connected. The result is a graph with 16 nodes; each node carries a 7x7 patch as its feature.
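
To make "adjacent patches are connected" concrete, here is a sketch of a 4-neighbour grid edge index for row-major patch numbering, with each neighbour pair stored in both directions. The connectivity TGraphX actually uses (4- or 8-neighbour, with or without self-loops) is an assumption here, so compare against `edge_index.shape` from the real call. `grid_edge_index` is a hypothetical helper.

```python
import torch

def grid_edge_index(rows: int, cols: int) -> torch.Tensor:
    """Return a [2, E] edge index linking 4-neighbour cells of a rows x cols grid.

    Hypothetical helper. Nodes are numbered row-major (node = r * cols + c);
    each undirected neighbour pair is stored in both directions.
    """
    ids = torch.arange(rows * cols).view(rows, cols)
    # Horizontal neighbours: (r, c) <-> (r, c + 1)
    h_src, h_dst = ids[:, :-1].reshape(-1), ids[:, 1:].reshape(-1)
    # Vertical neighbours: (r, c) <-> (r + 1, c)
    v_src, v_dst = ids[:-1, :].reshape(-1), ids[1:, :].reshape(-1)
    src = torch.cat([h_src, v_src, h_dst, v_dst])
    dst = torch.cat([h_dst, v_dst, h_src, v_src])
    return torch.stack([src, dst], dim=0)

edge_index = grid_edge_index(4, 4)
print(edge_index.shape)  # torch.Size([2, 48])
```

A 4x4 grid has 12 horizontal and 12 vertical neighbour pairs, so storing both directions gives 48 directed edges per image under this assumed connectivity.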

Step 3: Build the full dataset

For node classification across all images, we need either a single large graph or one graph per image. We start with per-image graphs:

```python
def image_to_graph(image, label):
    patches, edge_index = tgx.image_to_patch_graph(image, patch_size=7)
    # Every patch node inherits the image's label (simple supervision).
    # int() accepts both Python ints and 0-dim tensors.
    node_labels = torch.full((patches.shape[0],), int(label), dtype=torch.long)
    return tgx.Graph(x=patches, edge_index=edge_index, labels=node_labels)

# Build one graph per image for the first few images
graphs = [image_to_graph(*dataset[i]) for i in range(8)]
```

In a real experiment you would batch many graphs. For this tutorial we will use a single combined graph constructed by concatenating patches across many images, with appropriate edge index offsets.

Step 4: Concatenate into a single training graph

```python
def build_combined_graph(dataset, num_images=1000, patch_size=7):
    all_patches = []
    all_labels  = []
    all_edges   = []
    offset = 0
    for i in range(num_images):
        img, lbl = dataset[i]
        patches, ei = tgx.image_to_patch_graph(img, patch_size=patch_size)
        all_patches.append(patches)
        all_labels.append(torch.full((patches.shape[0],), lbl, dtype=torch.long))
        all_edges.append(ei + offset)  # shift node ids so each image's subgraph stays disjoint
        offset += patches.shape[0]
    x = torch.cat(all_patches, dim=0)
    y = torch.cat(all_labels, dim=0)
    edge_index = torch.cat(all_edges, dim=1)
    return tgx.Graph(x=x, edge_index=edge_index, labels=y)

g = build_combined_graph(dataset, num_images=1000)
print(f"Combined: {g.num_nodes} nodes, {g.num_edges} edges, feature shape {g.node_features.shape}")
# Combined: 16000 nodes, E edges, feature shape torch.Size([16000, 1, 7, 7])
# (E depends on the patch-grid connectivity that image_to_patch_graph uses)
```

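Each image contributes 16 nodes; with 1000 images that is 16,000 nodes total. Each node carries a [1, 7, 7] patch feature. The edge count depends on the connectivity that tgx.image_to_patch_graph uses: with 4-neighbour grid adjacency stored in both directions, a 4x4 grid has 24 neighbour pairs, or 48 directed edges, which gives 48,000 edges for 1000 images.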

Step 5: Validate and split

```python
tgx.validate_graph(g, strict=True)
tgx.assert_tensor_native(g, min_rank=3)

torch.manual_seed(42)
perm = torch.randperm(g.num_nodes)
n = g.num_nodes
train_mask = torch.zeros(n, dtype=torch.bool); train_mask[perm[:int(.7*n)]] = True
val_mask   = torch.zeros(n, dtype=torch.bool); val_mask[perm[int(.7*n):int(.85*n)]] = True
test_mask  = torch.zeros(n, dtype=torch.bool); test_mask[perm[int(.85*n):]] = True
tgx.check_leakage(train_mask, val_mask, test_mask, strict=True)
```

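The validation and leakage checks are cheap sanity safeguards that surface malformed graphs and overlapping splits before training starts. They cannot catch everything, though. The leakage check sees only the masks, and because this split is per node, patches from the same image (which share a label and are joined by edges) almost always land in different splits. Validation and test nodes therefore have labelled training neighbours from their own image, and the reported accuracy will be optimistic.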
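
A stricter alternative to the node-level split above assigns whole images to splits, so every patch of a test image is unseen during training. The sketch assumes the node layout produced by `build_combined_graph` (image `i` owns nodes `16*i` through `16*i + 15`); `image_level_masks` is a hypothetical helper, not a TGraphX function.

```python
import torch

def image_level_masks(num_nodes: int, nodes_per_image: int,
                      fractions=(0.7, 0.15), seed: int = 42):
    """Return (train, val, test) boolean node masks that split by source image.

    Hypothetical helper. Assumes image i owns the contiguous node range
    [i * nodes_per_image, (i + 1) * nodes_per_image).
    """
    image_id = torch.arange(num_nodes) // nodes_per_image  # node -> source image
    num_images = int(image_id.max()) + 1
    gen = torch.Generator().manual_seed(seed)
    perm = torch.randperm(num_images, generator=gen)        # shuffle images, not nodes
    n_train = int(fractions[0] * num_images)
    n_val = int(fractions[1] * num_images)
    train_imgs = perm[:n_train]
    val_imgs = perm[n_train:n_train + n_val]
    test_imgs = perm[n_train + n_val:]
    # A node belongs to a split if its source image was assigned to that split
    return (torch.isin(image_id, train_imgs),
            torch.isin(image_id, val_imgs),
            torch.isin(image_id, test_imgs))

train_mask, val_mask, test_mask = image_level_masks(16000, nodes_per_image=16)
```

The resulting masks can be passed to `tgx.check_leakage` exactly like the node-level ones; they simply encode a harder and more honest evaluation.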

Step 6: Train

```python
with tgx.reproducible(seed=42, deterministic=False):
    result = tgx.classify_nodes(
        x=g.node_features,
        edge_index=g.edge_index,
        labels=g.node_labels,
        model="tensor_gcn",   # uses ConvMessagePassing internally
        seed=42,
        device="auto",
    )

print(result.metrics)
result.summary()
```

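This trains a tensor-aware GCN on the combined graph. The tensor_gcn model preserves the [1, 7, 7] patch shape across message-passing aggregation, applying convolutional operations rather than flattening. As written, the call does not receive the masks built in Step 5; check whether classify_nodes in your TGraphX version accepts explicit split masks or performs its own internal split.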
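
To illustrate what "preserving the patch shape" means, here is a minimal PyTorch layer that mean-aggregates neighbouring patch tensors and then applies a 2D convolution, so node features stay [N, C, 7, 7] throughout. It is a hypothetical sketch of the general idea, not TGraphX's `ConvMessagePassing`; the concatenate-then-convolve update rule is an assumption chosen for simplicity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchConvLayer(nn.Module):
    """Hypothetical tensor-preserving message-passing layer: average the
    incoming neighbour patches, then convolve [self, neighbour mean]."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # 3x3 kernel with padding=1 keeps the spatial size (e.g. 7x7) unchanged
        self.conv = nn.Conv2d(2 * in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        # x: [N, C, H, W]; edge_index: [2, E] with rows (source, target)
        src, dst = edge_index
        agg = torch.zeros_like(x).index_add_(0, dst, x[src])   # sum of incoming patches
        deg = torch.bincount(dst, minlength=x.size(0)).clamp(min=1)
        agg = agg / deg.view(-1, 1, 1, 1).to(x.dtype)           # mean; isolated nodes get 0
        # Stack own patch and neighbourhood mean as channels, then convolve
        return F.relu(self.conv(torch.cat([x, agg], dim=1)))

torch.manual_seed(0)
x = torch.randn(16, 1, 7, 7)                   # 16 patch nodes, as in one MNIST image
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
out = PatchConvLayer(1, 8)(x, edge_index)
print(out.shape)  # torch.Size([16, 8, 7, 7])
```

A flat baseline would instead reshape `x` to [N, 49] and use a linear layer, which discards the 7x7 pixel layout that the convolution can exploit.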

Step 7: Inspect the result

```python
# Where did the run land?
print(result.run_dir)
print(tgx.audit_run_dir(result.run_dir))

# Compare against a flat-vector baseline
result_flat = tgx.classify_nodes(
    x=g.node_features.view(g.num_nodes, -1),  # flattened to [N, 49]
    edge_index=g.edge_index,
    labels=g.node_labels,
    model="gcn",
    seed=42,
)
print(f"Tensor: {result.metrics['val_accuracy']:.3f}")
print(f"Flat:   {result_flat.metrics['val_accuracy']:.3f}")
```

For this specific setup (patches as nodes, each inheriting its image's label), the tensor and flat versions may produce similar accuracy because each patch is small (7x7) and the message-passing structure dominates. With larger patches (16x16 and up) and richer inputs such as CIFAR-10's 3-channel patches, the tensor version tends to pull ahead.

What this teaches

The workflow. Image → patch graph → tensor-aware GNN is a clean, repeatable pattern for any image-as-graph task. The same code structure works for CIFAR-10, ImageNet patches, medical imaging regions, etc.

The validation discipline. Even on a small dataset, tgx.validate_graph and tgx.check_leakage catch malformed graphs and overlapping splits before any training time is spent.

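The reproducibility. The training call is wrapped in tgx.reproducible(seed=42), and the split uses a fixed torch.manual_seed(42). Running the script twice should reproduce the same numbers on CPU; because the example passes deterministic=False, GPU runs may still differ slightly.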
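
For readers curious what such a context manager typically does, here is a minimal sketch built only on Python and PyTorch APIs. `seeded` is hypothetical and is not the implementation of `tgx.reproducible`; it seeds Python's and PyTorch's generators, optionally enables PyTorch's deterministic-algorithm mode, and restores the previous mode on exit.

```python
import contextlib
import random

import torch

@contextlib.contextmanager
def seeded(seed: int, deterministic: bool = False):
    """Hypothetical reproducibility context manager (not tgx.reproducible)."""
    previous = torch.are_deterministic_algorithms_enabled()
    random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and all CUDA generators
    # In deterministic mode, ops without a deterministic implementation raise;
    # some CUDA ops additionally need CUBLAS_WORKSPACE_CONFIG to be set.
    torch.use_deterministic_algorithms(deterministic)
    try:
        yield
    finally:
        torch.use_deterministic_algorithms(previous)

with seeded(42):
    first = torch.rand(3)
with seeded(42):
    second = torch.rand(3)
print(torch.equal(first, second))  # True
```

Seeding alone makes CPU runs repeatable; on GPU, identical numbers also require deterministic kernels, which is the trade-off a `deterministic` flag controls.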

The audit artifacts. The run directory contains everything needed to verify what happened. A teammate can audit without rerunning.

Honest assessment

MNIST is a toy dataset. The accuracy you get from this workflow is not interesting in itself. What is interesting is that the workflow scales: the same code structure with different patch_size, num_images, and model parameters applies to real research data.

For real research, you would:

  • Use CIFAR-10 or a domain-specific dataset
  • Use larger patches that capture meaningful spatial content
  • Tune hyperparameters
  • Run multiple seeds and report variance
  • Inspect the audit artifacts before publishing

The tutorial pattern is the same.


FAQ

Q: How does this compare to standard image classification CNNs?
A: For straight image classification, a CNN beats this approach by a large margin — CNNs are designed for image data. The patch graph is interesting when the graph structure carries information CNNs cannot easily use (e.g., non-grid adjacency, multi-image relationships).

Q: Why per-image labels for nodes?
A: Simplest supervision. For per-patch labels (e.g., semantic segmentation), provide a different label tensor.

Q: Can I batch graphs with NeighborLoader?
A: Yes. NeighborLoader works on the combined graph and samples mini-batches across all the per-image subgraphs.
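
Conceptually, neighbour sampling picks a set of seed nodes and keeps only a bounded random subset of each seed's incoming edges, so a mini-batch stays small even on the 16,000-node combined graph. The sketch below shows a single hop in plain PyTorch; `sample_one_hop` is a hypothetical helper and says nothing about how TGraphX's NeighborLoader is implemented or configured.

```python
import torch

def sample_one_hop(edge_index: torch.Tensor, seeds: torch.Tensor,
                   fanout: int, seed: int = 0) -> torch.Tensor:
    """Hypothetical sampler: keep at most `fanout` random incoming edges per seed node."""
    gen = torch.Generator().manual_seed(seed)
    dst = edge_index[1]
    kept = []
    for node in seeds.tolist():
        incoming = (dst == node).nonzero(as_tuple=True)[0]  # positions of edges into node
        if incoming.numel() > fanout:
            pick = torch.randperm(incoming.numel(), generator=gen)[:fanout]
            incoming = incoming[pick]
        kept.append(incoming)
    idx = torch.cat(kept) if kept else torch.empty(0, dtype=torch.long)
    return edge_index[:, idx]

# Toy graph: node 0 has four incoming edges, node 1 has one
edge_index = torch.tensor([[1, 2, 3, 4, 0],
                           [0, 0, 0, 0, 1]])
sub = sample_one_hop(edge_index, seeds=torch.tensor([0, 1]), fanout=2)
print(sub.shape)  # torch.Size([2, 3])
```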

Q: What about CIFAR-10?
A: Same pattern. Use patches, ei = tgx.image_to_patch_graph(image, patch_size=8). The 3-channel [3, 8, 8] patches benefit more from tensor-aware aggregation.

Q: Where can I see this as a complete script?
A: The TGraphX repository's tutorials/tensor_node_classification_neighbor_loader.py and tutorials/image_patch_tensor_graph_demo.py show end-to-end versions.