Graph Neural Networks

From BloomWiki

How to read this page: This article maps the topic from beginner to expert across six levels: Remembering, Understanding, Applying, Analyzing, Evaluating, and Creating. Scan the headings to see the full scope, then read from wherever your knowledge starts to feel uncertain. Learn more about how BloomWiki works.

Graph Neural Networks (GNNs) are a class of deep learning models designed to operate on graph-structured data — data where entities (nodes) are connected by relationships (edges). Unlike images or text, which have regular grid or sequence structure, graphs can represent complex relational systems: social networks, molecular structures, knowledge graphs, road networks, biological protein interaction networks, and recommendation systems. GNNs learn by propagating and aggregating information across graph neighborhoods, enabling predictions about nodes, edges, or entire graphs.

Remembering[edit]

  • Graph — A mathematical structure consisting of nodes (vertices) connected by edges. Edges can be directed or undirected, weighted or unweighted.
  • Node — An entity in a graph (a person in a social network, an atom in a molecule, a paper in a citation network).
  • Edge — A relationship between two nodes (friendship, chemical bond, citation, road connection).
  • Node features — Attribute vectors associated with each node (e.g., user profile features, atom type, paper topic).
  • Edge features — Attribute vectors associated with each edge (e.g., bond type, relationship strength, road distance).
  • Adjacency matrix — A square matrix A where A[i][j] = 1 if there is an edge from node i to node j (a concrete example follows this list).
  • Neighborhood — The set of nodes directly connected to a given node by edges.
  • Message passing — The core GNN operation: each node sends "messages" (feature vectors) to its neighbors, which aggregate them to update node representations.
  • Node-level task — Predicting a property of each node (e.g., classifying users as spammers or not).
  • Edge-level task — Predicting properties of or between pairs of nodes (e.g., link prediction: will user A befriend user B?).
  • Graph-level task — Predicting a property of the entire graph (e.g., classifying a molecule's toxicity).
  • GCN (Graph Convolutional Network) — A foundational GNN that aggregates node features from neighbors with normalized averaging.
  • GraphSAGE — A GNN that samples a fixed number of neighbors for scalable inductive learning on large graphs.
  • GAT (Graph Attention Network) — A GNN that learns attention weights for neighbor aggregation, giving more weight to important neighbors.
  • Readout/Pooling — Aggregating all node representations into a single graph-level representation for graph classification.
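To make these terms concrete, here is a minimal sketch of a tiny graph in PyTorch Geometric, the library used later in this article. Note that PyG stores edges as a 2×E edge_index tensor (COO format) rather than a dense adjacency matrix; the values below are illustrative.

<syntaxhighlight lang="python">
import torch
from torch_geometric.data import Data

# A tiny undirected triangle: nodes 0-1-2, each pair connected.
# PyG represents each undirected edge as two directed edges.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 0],
                           [1, 0, 2, 1, 0, 2]], dtype=torch.long)
x = torch.tensor([[1.0], [2.0], [3.0]])   # one scalar feature per node
data = Data(x=x, edge_index=edge_index)
print(data.num_nodes, data.num_edges)     # 3 nodes, 6 directed edges
</syntaxhighlight>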

Understanding[edit]

The key insight of GNNs is: a node's representation should depend on both its own features and the features of its neighborhood. This mirrors how we understand entities in relation to their context — a person's role in society depends partly on who they're connected to.

Message passing is the universal GNN framework:

For each layer l and each node v:

  1. Message: compute a message from each neighbor u: m_{u→v} = M(h_u^l, h_v^l, e_{uv})
  2. Aggregate: combine all incoming messages: a_v^l = AGG({m_{u→v} : u ∈ N(v)})
  3. Update: compute the new node representation: h_v^{l+1} = U(h_v^l, a_v^l)

Where M is a message function, AGG is an aggregation function (sum, mean, max, attention-weighted), and U is an update function (typically a neural network). After L layers, each node's representation encodes information from its L-hop neighborhood.
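A minimal sketch of one such layer in plain PyTorch may make the three steps concrete. The weight matrices W_msg and W_upd and the edge-list format are illustrative assumptions, not any specific library's API:

<syntaxhighlight lang="python">
import torch

def message_passing_layer(h, edges, W_msg, W_upd):
    """One round of mean-aggregation message passing (a minimal sketch).

    h:     [num_nodes, d] node representations
    edges: list of (u, v) pairs; messages flow u -> v
    """
    agg = torch.zeros_like(h)
    deg = torch.zeros(h.size(0), 1)
    for u, v in edges:
        agg[v] += h[u] @ W_msg            # Message: m_{u->v} = M(h_u)
        deg[v] += 1
    agg = agg / deg.clamp(min=1)          # Aggregate: mean over incoming messages
    return torch.relu(h @ W_upd + agg)    # Update: combine self and aggregate
</syntaxhighlight>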

Think of it like rumor spreading: after 1 layer, each node knows about its direct neighbors; after 2 layers, about neighbors' neighbors; and so on.

Why not just use a standard neural network? Graphs are irregular — nodes have different numbers of neighbors, and there's no canonical ordering of neighbors. GNNs handle this through permutation-invariant aggregation (sum, mean, max produce the same result regardless of neighbor ordering).
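A two-line sanity check of this property (values are illustrative):

<syntaxhighlight lang="python">
import torch

# Permutation invariance of sum aggregation: shuffling the neighbor
# order leaves the aggregated message unchanged.
msgs = torch.tensor([[1., 0.], [0., 2.], [3., 1.]])   # messages from 3 neighbors
perm = torch.randperm(msgs.size(0))
assert torch.allclose(msgs.sum(dim=0), msgs[perm].sum(dim=0))
</syntaxhighlight>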

Expressiveness limits: The Weisfeiler-Leman (WL) graph isomorphism test provides a theoretical upper bound on GNN expressiveness. Standard message-passing GNNs cannot distinguish certain graph structures — for example, regular graphs of the same degree look identical from any node's perspective, even if the graphs have different global structure.
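As a concrete illustration, the following sketch runs 1-WL color refinement on two 2-regular graphs, a single 6-cycle versus two disjoint triangles, and shows that the resulting color histograms are identical even though the graphs are clearly non-isomorphic (the helper function and graph names are illustrative):

<syntaxhighlight lang="python">
from collections import Counter

def wl_colors(adj, rounds=3):
    """1-WL color refinement; returns the final color histogram."""
    colors = {v: 0 for v in adj}                      # uniform initial colors
    for _ in range(rounds):
        sig = {v: (colors[v], tuple(sorted(colors[u] for u in adj[v])))
               for v in adj}
        palette = {s: i for i, s in enumerate(sorted(set(sig.values())))}
        colors = {v: palette[sig[v]] for v in adj}
    return Counter(colors.values())

hexagon = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}   # C6
two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1],             # 2 x C3
                 3: [4, 5], 4: [3, 5], 5: [3, 4]}
print(wl_colors(hexagon) == wl_colors(two_triangles))  # True: 1-WL can't tell them apart
</syntaxhighlight>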

Applying[edit]

Node classification with PyTorch Geometric (PyG):

<syntaxhighlight lang="python">
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# Load the Cora citation network dataset:
# 2708 nodes (papers), 10556 edges (citations), 7 classes (topics)
dataset = Planetoid(root='data/Cora', name='Cora', transform=NormalizeFeatures())
data = dataset[0]

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

model = GCN(dataset.num_features, 64, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Training loop (full-batch: the whole graph fits in memory)
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# Evaluation on the held-out test nodes
model.eval()
pred = model(data.x, data.edge_index).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
accuracy = float(correct) / int(data.test_mask.sum())
print(f'Accuracy: {accuracy:.4f}')  # ~0.81 for GCN on Cora
</syntaxhighlight>

GNN variant selection by task:

  • Node classification (small graph) → GCN or GAT (simple, well-understood)
  • Node classification (large graph, millions of nodes) → GraphSAGE (mini-batch neighbor sampling; see the sketch below)
  • Link prediction → GAT or SEAL (subgraph-based link prediction)
  • Molecular property prediction → MPNN, DimeNet, SchNet (3D-aware)
  • Knowledge graph completion → RotatE, ComplEx, RGCN
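For the large-graph case, here is a minimal sketch of neighbor-sampled mini-batch training with PyG's NeighborLoader, reusing the data, model, and optimizer names from the example above (exact arguments may vary across PyG versions):

<syntaxhighlight lang="python">
import torch.nn.functional as F
from torch_geometric.loader import NeighborLoader

# Sample at most 10 neighbors per node at each of 2 hops, batched by seed nodes.
loader = NeighborLoader(data, num_neighbors=[10, 10], batch_size=128,
                        input_nodes=data.train_mask)

model.train()
for batch in loader:
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index)
    # Only the first batch_size nodes are seeds; the rest are sampled context.
    loss = F.nll_loss(out[:batch.batch_size], batch.y[:batch.batch_size])
    loss.backward()
    optimizer.step()
</syntaxhighlight>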

Analyzing[edit]

GNN Architecture Comparison

{| class="wikitable"
! Architecture !! Aggregation !! Key Innovation !! Best For
|-
| GCN || Normalized mean || Spectral convolution approximation || Semi-supervised node classification
|-
| GraphSAGE || Sampling + mean/max/LSTM || Inductive learning on unseen nodes || Large-scale production graphs
|-
| GAT || Attention-weighted sum || Learned neighbor importance || Heterogeneous neighbor importance
|-
| GIN (Graph Isomorphism Net) || Sum || Maximally expressive (matches the WL test) || Graph classification
|-
| MPNN || Custom || Generalized message passing || Molecular property prediction
|}

Failure modes and challenges:

  • Over-smoothing — As GNN depth increases, node representations converge toward the same value: stacking too many layers homogenizes representations and destroys discriminative power. Mitigations: skip connections, JK-Net (jumping knowledge), layer normalization (see the sketch after this list).
  • Over-squashing — Information from exponentially many nodes must be compressed into fixed-size representations as layers deepen. Bottlenecks lose important long-range information.
  • Scalability — Full-graph message passing requires materializing the adjacency matrix, which is infeasible for graphs with millions of nodes. Mini-batch sampling (GraphSAGE, ClusterGCN) is essential.
  • Dynamic graphs — Most GNN architectures assume static graphs. Real-world graphs (social networks, transaction graphs) evolve over time. Temporal GNNs (TGNN, EvolveGCN) address this.
  • Heterogeneous graphs — Many real-world graphs have multiple node types and edge types. Standard GNNs treat all of them identically; HGT (Heterogeneous Graph Transformer) and similar architectures handle this.
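One common mitigation for over-smoothing, sketched below, is to add residual (skip) connections so each node retains its own signal as depth grows; the class name and hyperparameters are illustrative:

<syntaxhighlight lang="python">
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class DeepGCNWithSkips(torch.nn.Module):
    """Deeper GCN with residual connections to counter over-smoothing (a sketch)."""
    def __init__(self, in_ch, hidden, out_ch, num_layers=6):
        super().__init__()
        self.lin_in = torch.nn.Linear(in_ch, hidden)
        self.convs = torch.nn.ModuleList(
            [GCNConv(hidden, hidden) for _ in range(num_layers)])
        self.lin_out = torch.nn.Linear(hidden, out_ch)

    def forward(self, x, edge_index):
        x = self.lin_in(x)
        for conv in self.convs:
            x = x + F.relu(conv(x, edge_index))   # residual: keep each node's own signal
        return self.lin_out(x)
</syntaxhighlight>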

Evaluating[edit]

Expert evaluation of GNNs requires care about what the performance numbers actually mean:

Transductive vs. inductive evaluation: Transductive GNNs (GCN) train and test on the same graph with different node masks. Inductive GNNs (GraphSAGE) generalize to entirely unseen graphs or nodes. These are different capabilities — inductive evaluation is more practically useful and harder to fake.

OGB (Open Graph Benchmark): The standard benchmark suite for GNNs. Provides large, realistic graphs with standardized train/val/test splits and a public leaderboard. Much more rigorous than small academic datasets like Cora/Citeseer.
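Loading an OGB dataset with its official split looks roughly like this (using the ogb package's PyG interface; ogbn-arxiv is one of the standard node-property benchmarks):

<syntaxhighlight lang="python">
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator

# Standardized split and metric: results are directly comparable to the leaderboard.
dataset = PygNodePropPredDataset(name='ogbn-arxiv')
data = dataset[0]
split_idx = dataset.get_idx_split()        # {'train': ..., 'valid': ..., 'test': ...}
evaluator = Evaluator(name='ogbn-arxiv')   # official metric (accuracy for ogbn-arxiv)
</syntaxhighlight>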

Link prediction pitfalls: Naively splitting edges can cause data leakage (two nodes that appear together in training shouldn't appear as positive test pairs without careful subgraph isolation). The SEAL paper identified this issue; OGB provides proper splits.
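In PyG, RandomLinkSplit implements a leakage-aware split: held-out edges are removed from the message-passing graph and supervised only through separate label tensors. A sketch, with illustrative split ratios:

<syntaxhighlight lang="python">
from torch_geometric.transforms import RandomLinkSplit

transform = RandomLinkSplit(num_val=0.05, num_test=0.10, is_undirected=True)
train_data, val_data, test_data = transform(data)
# Each split carries (edge_label_index, edge_label) supervision pairs;
# test edges never appear in the edge_index used for message passing.
</syntaxhighlight>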

Expert practitioners verify that their GNN is learning graph structure rather than node features alone: they ablate edges, shuffle node labels, and measure the performance delta. If removing edges doesn't hurt performance, the model isn't actually using graph structure.
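A minimal version of the edge-ablation check, reusing the trained Cora model from the Applying section (GCNConv adds self-loops by default, so the model still runs with an empty edge set):

<syntaxhighlight lang="python">
import torch

# Sanity check (a sketch): score the trained model with the real edges,
# then with all edges removed.
model.eval()
with torch.no_grad():
    full = model(data.x, data.edge_index).argmax(dim=1)
    empty_edges = torch.empty((2, 0), dtype=torch.long)   # no edges at all
    bare = model(data.x, empty_edges).argmax(dim=1)

acc = lambda pred: (pred[data.test_mask] == data.y[data.test_mask]).float().mean().item()
print(f'with edges: {acc(full):.3f}  without edges: {acc(bare):.3f}')
# A small gap suggests the model is mostly memorizing node features.
</syntaxhighlight>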

Creating[edit]

Designing a GNN-based recommendation system:

1. Graph construction

<syntaxhighlight lang="text">
Nodes:
├── Users (features: demographics, preferences)
└── Items (features: content embeddings, metadata)

Edges:
├── User-Item: interaction (click, purchase, rating)
├── Item-Item: co-purchase, co-view, semantic similarity
└── User-User: social connections (if available)

Edge features: timestamp, interaction type, rating value
</syntaxhighlight>
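This construction maps naturally onto PyG's HeteroData container. In the sketch below, all node counts, feature dimensions, and edge counts are placeholder assumptions:

<syntaxhighlight lang="python">
import torch
from torch_geometric.data import HeteroData

data = HeteroData()
data['user'].x = torch.randn(1000, 32)   # user features: demographics, preferences
data['item'].x = torch.randn(5000, 64)   # item features: content embeddings, metadata

# User-Item interaction edges: row 0 indexes users, row 1 indexes items.
data['user', 'interacts', 'item'].edge_index = torch.stack([
    torch.randint(0, 1000, (20000,)),    # source user ids
    torch.randint(0, 5000, (20000,)),    # target item ids
])
data['user', 'interacts', 'item'].edge_attr = torch.randn(20000, 3)  # timestamp, type, rating
data['item', 'co_viewed', 'item'].edge_index = torch.randint(0, 5000, (2, 8000))
</syntaxhighlight>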

2. Model architecture (LightGCN for collaborative filtering)

<syntaxhighlight lang="text">
User/Item IDs → Embedding lookup
        ↓
[GNN layers: propagate embeddings across bipartite graph]
        ↓
[Sum pooling across all layers (jumping knowledge)]
        ↓
[User embedding · Item embedding → interaction score]
        ↓
[BPR loss: maximize score of interacted items vs. non-interacted]
</syntaxhighlight>
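The propagation and loss above reduce to very little code. The sketch below assumes norm_adj is the symmetrically normalized sparse adjacency matrix over the concatenated user+item node set and emb stacks user and item embeddings; both names are assumptions for illustration:

<syntaxhighlight lang="python">
import torch
import torch.nn.functional as F

def lightgcn_embeddings(emb, norm_adj, num_layers=3):
    """LightGCN propagation (a sketch): no weights, no nonlinearity,
    just repeated normalized neighbor averaging, then a mean over layers."""
    layers = [emb]
    for _ in range(num_layers):
        layers.append(torch.sparse.mm(norm_adj, layers[-1]))
    return torch.stack(layers).mean(dim=0)   # combine all layers (jumping knowledge)

def bpr_loss(user_e, pos_item_e, neg_item_e):
    """Bayesian Personalized Ranking: an interacted item should outscore a random one."""
    pos = (user_e * pos_item_e).sum(dim=-1)
    neg = (user_e * neg_item_e).sum(dim=-1)
    return -F.logsigmoid(pos - neg).mean()
</syntaxhighlight>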

3. Scalability considerations

  • Use FAISS for approximate nearest-neighbor retrieval of top-K items at inference (see the sketch after this list)
  • Pre-compute item embeddings offline; only user embeddings need real-time update
  • Cluster graph into subgraphs for mini-batch training (ClusterGCN, GraphSAINT)
  • Cache neighborhood aggregations for stable, frequently-seen nodes
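The FAISS retrieval step from the first bullet, in sketch form. Here item_embs and user_emb are assumed NumPy arrays of precomputed embeddings; the flat index is illustrative, and production systems typically swap in an IVF or HNSW index:

<syntaxhighlight lang="python">
import faiss
import numpy as np

d = item_embs.shape[1]                      # embedding dimension
index = faiss.IndexFlatIP(d)                # exact inner product (brute force)
index.add(item_embs.astype(np.float32))     # item embeddings, precomputed offline

# user_emb must be 2D: (num_queries, d). Returns top-20 scores and item ids.
scores, item_ids = index.search(user_emb.astype(np.float32), 20)
</syntaxhighlight>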

4. Cold start problem

  • New users (no interaction history): fall back to content-based features or popularity
  • New items: use item content embeddings as initial node features and propagate with the GNN (see the sketch after this list)
  • Inductive GNNs (GraphSAGE) generalize to new nodes naturally
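For the new-item case, an inductive model can embed an unseen node directly from content features once it is attached to the graph. The sketch below is purely illustrative: content_encoder and attach_similarity_edges are hypothetical helpers, and model is assumed to be a trained inductive GNN such as GraphSAGE:

<syntaxhighlight lang="python">
import torch

# Hypothetical helpers: content_encoder maps metadata into the node feature space;
# attach_similarity_edges links the new node to similar existing items.
new_x = content_encoder(new_item_metadata).unsqueeze(0)   # content embedding as features
x = torch.cat([data.x, new_x], dim=0)                     # append the new node
edge_index = attach_similarity_edges(data.edge_index, new_node_id=x.size(0) - 1)

model.eval()
with torch.no_grad():
    emb = model(x, edge_index)      # the inductive model embeds the unseen node directly
new_item_embedding = emb[-1]
</syntaxhighlight>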