From 100 Lines of Glue Code to One Explicit Call
A typical first GNN experiment in PyTorch (here with PyTorch Geometric) looks something like this:
```python
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

# x (node features), edge_index (2 x E edges), y (node labels) and num_classes
# are assumed to come from your dataset.

# Construct graph
data = Data(x=x, edge_index=edge_index, y=y)

# Train/val/test masks (random 70/15/15 split)
n = data.num_nodes
perm = torch.randperm(n)
data.train_mask = torch.zeros(n, dtype=torch.bool)
data.train_mask[perm[:int(.7 * n)]] = True
data.val_mask = torch.zeros(n, dtype=torch.bool)
data.val_mask[perm[int(.7 * n):int(.85 * n)]] = True
data.test_mask = torch.zeros(n, dtype=torch.bool)
data.test_mask[perm[int(.85 * n):]] = True

# Model
class GCN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index).relu()
        return self.conv2(h, edge_index)

model = GCN(in_dim=x.shape[1], hidden_dim=64, out_dim=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

# The masks live on `data`, so .to(device) moves them together with x, edge_index and y
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
data = data.to(device)

# Train
for epoch in range(50):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# Validation (once, after training)
model.eval()
with torch.no_grad():
    out = model(data.x, data.edge_index)
    val_acc = (out[data.val_mask].argmax(-1) == data.y[data.val_mask]).float().mean().item()
print(f"Final val accuracy: {val_acc:.4f}")
```
That's ~50 lines, and there is still no logging, no proper random seeding, no validation tracking across epochs, no test evaluation, and no reproducibility report. A research-grade version runs to 100+ lines.
Most of it is boilerplate. Variants of the same 100 lines turn up in GNN repository after GNN repository, slightly different each time and occasionally with bugs.
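To give a sense of what "proper random seeding" alone adds to the manual script, here is a minimal sketch of the kind of seeding helper such scripts usually grow. seed_everything is a hypothetical name; this is a common community pattern, not TGraphX's set_seed or reproducible implementation. The NumPy and PyTorch steps run only if those packages are installed.

```python
import os
import random


def seed_everything(seed: int, deterministic: bool = False) -> None:
    """Seed every RNG a training script is likely to touch (hypothetical helper)."""
    random.seed(seed)  # Python's built-in RNG

    try:
        import numpy as np
        np.random.seed(seed)  # NumPy's legacy global RNG
    except ImportError:
        pass

    try:
        import torch
    except ImportError:
        return

    torch.manual_seed(seed)           # PyTorch's default generator
    torch.cuda.manual_seed_all(seed)  # every GPU; safe to call without CUDA
    if deterministic:
        # cuBLAS needs this for deterministic matmuls; set it before CUDA initializes.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.use_deterministic_algorithms(True)  # error on nondeterministic ops
        torch.backends.cudnn.benchmark = False     # no autotuned kernel selection
        torch.backends.cudnn.deterministic = True
```

Call it once at the top of the script, before torch.randperm builds the masks, so both the split and the weight initialization repeat across runs. With deterministic=True, PyTorch raises an error for operations that have no deterministic implementation, and some kernels run slower.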
The one-call alternative
In TGraphX:
```python
import tgraphx as tgx

with tgx.reproducible(seed=42, deterministic=True):
    result = tgx.classify_nodes(
        x=x, edge_index=edge_index, labels=y,
        model="gcn", seed=42, device="auto",
    )
print(f"Final val accuracy: {result.metrics['val_accuracy']:.4f}")
```
That is essentially the same experiment in a handful of lines instead of 50. Train/val/test masks are generated automatically (a random 70/15/15 split by default), the model is constructed with sensible defaults, training runs with default hyperparameters (20 epochs by default, rather than the 50 used above), validation and test metrics are computed, a run directory is created with audit artifacts, and the result object includes everything you need to build on.
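The default random split follows the same logic as the mask-building lines in the manual script. As a minimal sketch, assuming nothing about TGraphX internals (random_split_masks is a hypothetical helper, and masks are plain Python lists to keep it dependency-free):

```python
import random


def random_split_masks(num_nodes, train_frac=0.70, val_frac=0.15, seed=42):
    """Return disjoint boolean train/val/test masks covering every node.

    Hypothetical helper mirroring a 70/15/15 random split; not TGraphX's code.
    """
    rng = random.Random(seed)  # private RNG, so global random state is untouched
    perm = list(range(num_nodes))
    rng.shuffle(perm)

    n_train = int(train_frac * num_nodes)
    n_val = int(val_frac * num_nodes)
    train_idx = perm[:n_train]
    val_idx = perm[n_train:n_train + n_val]
    test_idx = perm[n_train + n_val:]  # the remainder, so rounding never drops a node

    def to_mask(indices):
        mask = [False] * num_nodes
        for i in indices:
            mask[i] = True
        return mask

    return to_mask(train_idx), to_mask(val_idx), to_mask(test_idx)
```

Giving the test set the remainder, rather than its own rounded fraction, guarantees the three masks partition the nodes exactly.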
What the one-call API does
Under the hood, tgx.classify_nodes() does what you would have written manually:
- Constructs a tgx.Graph and validates it.
- Generates or accepts train/val/test masks.
- Builds an appropriate model using build_model() (auto-selects the layer based on feature rank).
- Sets up an optimizer (Adam by default).
- Trains for a default number of epochs (20).
- Evaluates on validation and test.
- Writes run artifacts to runs/exp_NNN/.
- Returns a WorkflowResult with .model, .graph, .loader, .optimizer, .metrics, .config, .run_dir.
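The runs/exp_NNN/ naming suggests an auto-incrementing run directory. The sketch below shows one common way to allocate such directories. It is an assumption about the scheme (zero-padded three-digit counter starting at 001), not TGraphX's code; next_run_name and create_run_dir are hypothetical names.

```python
import re
from pathlib import Path

_RUN_PATTERN = re.compile(r"exp_(\d+)")


def next_run_name(existing_names):
    """Return the next exp_NNN name after the highest existing one."""
    taken = [int(m.group(1)) for name in existing_names
             if (m := _RUN_PATTERN.fullmatch(name))]
    return f"exp_{max(taken, default=0) + 1:03d}"


def create_run_dir(root="runs"):
    """Create and return the next free run directory under `root`."""
    root = Path(root)
    root.mkdir(parents=True, exist_ok=True)
    run_dir = root / next_run_name(p.name for p in root.iterdir() if p.is_dir())
    run_dir.mkdir()  # raises if it already exists, instead of overwriting a run
    return run_dir
```

Never reusing a directory is what makes the artifacts useful for auditing: each run's config and metrics stay next to each other and are never silently replaced.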
You can override every default by passing arguments. You can also drop down to the explicit API when you need control the one-call API does not give you.
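The override-by-keyword pattern can be sketched with the defaults this article names (Adam, learning rate 2e-3, 20 epochs, random 70/15/15 split). The option names and resolve_config are hypothetical; the real parameter names of classify_nodes are in its docstring.

```python
DEFAULTS = {
    "optimizer": "adam",
    "lr": 2e-3,
    "epochs": 20,
    "split": (0.70, 0.15, 0.15),
}


def resolve_config(**overrides):
    """Merge keyword overrides into the defaults (hypothetical helper)."""
    unknown = set(overrides) - set(DEFAULTS)
    if unknown:
        # A typo such as lr_rate=... should fail loudly, not be silently ignored.
        raise TypeError(f"unknown option(s): {sorted(unknown)}")
    config = {**DEFAULTS, **overrides}  # overrides win; DEFAULTS is not mutated
    train, val, test = config["split"]
    if abs(train + val + test - 1.0) > 1e-9:
        raise ValueError(f"split fractions must sum to 1, got {config['split']}")
    return config
```

Returning the fully resolved dict, rather than only the overrides, is what makes it possible to record exactly which settings a run used.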
The explicit API for when you need control
When the one-call API is too coarse, the framework's lower-level building blocks are available:
```python
import tgraphx as tgx
from tgraphx import Graph, build_model, fit, NeighborLoader, set_seed

# x, edge_index, y and num_classes as in the previous examples.
set_seed(42, deterministic=True)

g = Graph(x=x, edge_index=edge_index, labels=y)
tgx.validate_graph(g, strict=True)

model = build_model(
    task="node_classification",
    layer="gcn",
    in_shape=(x.shape[1],),
    out_dim=num_classes,
    hidden_dim=64,
    num_layers=2,
)

loader = NeighborLoader(g, num_neighbors=[15, 10], batch_size=128, seed=42)

# my_custom_train_mask / my_custom_val_mask: boolean node masks you build yourself.
history = fit(
    model=model,
    graph=g,
    loader=loader,
    train_mask=my_custom_train_mask,
    val_mask=my_custom_val_mask,
    epochs=50,
    lr=2e-3,
    weight_decay=5e-4,
)
```
This is more verbose than the one-call API but still far less boilerplate than writing the full loop by hand. Each piece is independently usable: you can use just build_model() and write your own training loop, or use just NeighborLoader with your own model.
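In PyTorch Geometric's NeighborLoader, a list like num_neighbors=[15, 10] means "sample up to 15 neighbors of each seed node, then up to 10 neighbors of each node reached at the first hop". Assuming TGraphX's NeighborLoader follows the same convention (check its docstring), here is a stdlib sketch of that sampling rule on an adjacency dict. sample_neighborhood is a hypothetical helper, not the library's sampler.

```python
import random


def sample_neighborhood(adj, seeds, num_neighbors, rng):
    """Fanout sampling: up to num_neighbors[k] neighbors per frontier node at hop k.

    adj: dict mapping node -> list of neighbor nodes.
    Returns the set of sampled nodes and the sampled (src, dst) edges.
    """
    nodes = set(seeds)
    edges = []
    frontier = list(seeds)
    for fanout in num_neighbors:  # e.g. [15, 10]: 15 at hop 1, then 10 at hop 2
        next_frontier = []
        for dst in frontier:
            nbrs = adj.get(dst, [])
            picked = nbrs if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
            for src in picked:
                edges.append((src, dst))  # messages flow from src to dst
                if src not in nodes:  # expand only newly reached nodes next hop
                    nodes.add(src)
                    next_frontier.append(src)
        frontier = next_frontier
    return nodes, edges
```

The list length is the number of hops, which is why a two-entry num_neighbors pairs naturally with num_layers=2: each message-passing layer consumes one hop of the sampled subgraph.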
When to use each level
One-call API (tgx.classify_nodes, tgx.kg_completion, etc.):
- Quick experimentation.
- Reproducible baselines for a research project.
- Teaching examples.
- AI-assisted code generation.
Explicit API (build_model, fit, samplers):
- Custom loss functions.
- Custom training-loop logic (e.g., curriculum learning).
- Custom samplers.
- When you need to drop into raw PyTorch for any reason.
Raw PyTorch with TGraphX as a data layer (Graph, NeighborLoader):
- When you have an existing PyTorch training loop you do not want to rewrite.
- When you need full control over every step.
- When you are integrating TGraphX into a larger system that has its own training infrastructure.
The framework supports all three levels. Pick the one that fits the project's stage.
The cost of brevity
A one-call API hides decisions. The default optimizer is Adam; the default learning rate is 2e-3; the default split is random 70/15/15. If you don't know the defaults, you don't know what you measured.
This is a real trade-off. The mitigations are:
- The defaults are documented in the docstring.
- The config used is saved in result.config and in the run directory.
- The defaults are sensible: they should give reasonable, if not optimal, results on standard tasks.
For a research baseline, the defaults are usually fine. For a competitive benchmark, hyperparameter search is needed regardless of which API you use.
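A hyperparameter search can wrap either API level. Here is a minimal exhaustive grid search, assuming a hypothetical train_and_eval callable that runs one training with the given settings and returns a validation score (higher is better). Which keyword arguments it accepts depends on the API you wrap.

```python
import itertools


def grid_search(train_and_eval, grid):
    """Try every combination in `grid` and return the best by validation score.

    grid: dict mapping parameter name -> list of candidate values.
    Returns (best_params, best_score, all_results).
    """
    names = sorted(grid)
    results = []
    for values in itertools.product(*(grid[n] for n in names)):
        params = dict(zip(names, values))
        results.append((train_and_eval(**params), params))
    best_score, best_params = max(results, key=lambda r: r[0])
    return best_params, best_score, results
```

Select on validation scores only, and evaluate the test set once for the chosen configuration; otherwise the test metric stops being an unbiased estimate.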
A word on premature explicitness
The opposite failure mode is also real: writing a 100-line custom training loop for a problem where the one-call API would have produced the same result in 5 lines. Custom loops are easy to get wrong. The one-call API is tested. For exploratory work, use the one-call. Drop down only when you have a concrete reason.
TGraphX-specific notes
The one-call APIs are documented in docs/easy_mode.md and docs/api_cheatsheet.json; the explicit APIs are exported from tgraphx/__init__.py. The framework's API stability document lists both as Beta and describes them as unlikely to change in v1.x.
FAQ
Q: Can I pass my own model to the one-call API?
A: No, classify_nodes builds the model internally. For custom models, use the explicit API: build_model() plus fit() with the model you constructed.
Q: What if the auto-generated split is wrong for my data?
A: Pass train_mask, val_mask, test_mask arguments. The one-call API accepts them.
Q: How do I get the trained model out of the one-call result?
A: result.model. It is a standard PyTorch module.
Q: What about graph classification (not node classification)?
A: There is no one-call API for graph classification because the task is more configurable (different pooling, different aggregation). Use build_model(task="graph_classification", ...) plus fit().
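To see why the readout choice matters, here is the pooling step in plain Python: the same node embeddings give different graph-level vectors under sum, mean and max pooling. global_pool is a hypothetical helper, not an option of build_model; the "batch vector" convention (batch[i] is the graph that node i belongs to) is borrowed from PyTorch Geometric.

```python
from collections import defaultdict


def global_pool(node_feats, batch, mode="mean"):
    """Pool per-node feature vectors into one vector per graph.

    node_feats: list of equal-length feature lists, one per node.
    batch: batch[i] is the index of the graph that node i belongs to.
    """
    groups = defaultdict(list)
    for feats, g in zip(node_feats, batch):
        groups[g].append(feats)

    pooled = []
    for g in sorted(groups):
        cols = list(zip(*groups[g]))  # one tuple per feature dimension
        if mode == "sum":
            pooled.append([sum(c) for c in cols])
        elif mode == "mean":
            pooled.append([sum(c) / len(c) for c in cols])
        elif mode == "max":
            pooled.append([max(c) for c in cols])
        else:
            raise ValueError(f"unknown pooling mode: {mode!r}")
    return pooled
```

Sum pooling keeps information about graph size, mean pooling normalizes it away, and max pooling keeps only the strongest value per feature; which one is appropriate depends on the task.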
Q: Where do the metrics come from?
A: result.metrics is a dict with train_loss, train_accuracy, val_accuracy, test_accuracy, and per-epoch history. The full history is in result.history.
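If you want the epoch with the best validation accuracy rather than the final one, a few lines suffice. This sketch assumes the per-epoch history is a list of dicts keyed by metric name; that layout is an assumption, so check the docstring for the actual structure of result.history. best_epoch is a hypothetical helper.

```python
def best_epoch(history, key="val_accuracy"):
    """Return (epoch_index, record) for the epoch with the highest `key`.

    history: list of per-epoch dicts, e.g. {"train_loss": ..., "val_accuracy": ...}.
    On ties, the earliest epoch wins (max() keeps the first maximum it sees).
    """
    if not history:
        raise ValueError("history is empty")
    idx = max(range(len(history)), key=lambda i: history[i][key])
    return idx, history[idx]
```

Under that assumption, usage would look like epoch, record = best_epoch(result.history).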