Chapter 14: Graph Neural Networks (GNNs) with PyTorch

Abstract:

Graph Neural Networks (GNNs) are a family of deep learning architectures designed to analyze and make predictions on data structured as graphs, which consist of nodes and the relationships (edges) between them. Because they can handle the complex, relational nature of graph-structured data, which traditional neural networks struggle to process, GNNs are used across many fields, including social network analysis, molecular modeling, recommender systems, and computer vision.
How GNNs work
  • Graph structure
    GNNs process data where entities are represented as nodes and their connections as edges. Information can be stored on both nodes and edges. 
  • Learning from neighbors
    GNNs work by having each node aggregate information from its neighbors. Through message-passing layers, nodes iteratively update their representations by combining features from their local neighborhood. 
  • Deepening understanding
    With each message-passing layer, a node's receptive field grows, allowing it to incorporate information from nodes further away. This helps GNNs understand complex dependencies within the graph. 
  • Prediction
    GNNs can make predictions at the node, edge, or graph level, depending on the task. 
Key applications
  • Science and medicine:
    • Molecular modeling: Predicting the properties of molecules, such as their potential as a drug, by representing atoms as nodes and bonds as edges. 
    • Protein interaction: Analyzing protein interactions to understand biological processes. 
  • Natural Language Processing (NLP):
    • Text classification: Analyzing relationships between words to classify documents.
    • Neural machine translation: Incorporating semantic and syntactic information to improve translation quality.
  • Computer vision:
    • Object and interaction detection: Reasoning about the connections between objects in an image.
    • Image classification: Using knowledge graphs to perform zero-shot learning, where the model can classify objects it hasn't seen during training.
  • Other domains:
    • Social networks: Analyzing user connections to make recommendations or detect communities. 
    • Recommender systems: Modeling user-item relationships to predict preferences. 
    • Combinatorial optimization: Solving complex sub-problems in areas like scheduling and logistics. 
    • Simulation: Modeling physical systems like fluid dynamics or traffic flow by representing particles and their interactions as a graph.



Chapter 14: Graph Neural Networks (GNNs)

Learning Objectives

After completing this chapter, readers will be able to:

  • Understand the concept and structure of graph data.

  • Represent graphs mathematically for neural network processing.

  • Explain the concept of Message Passing Neural Networks (MPNNs).

  • Implement a simple GNN model using PyTorch Geometric (PyG).

  • Identify real-world applications of GNNs in social networks and biological systems.


14.1 Introduction

Graphs are a powerful data structure that can represent complex relationships between entities, such as social networks, molecules, traffic systems, or citation networks. Traditional deep learning models such as CNNs and RNNs assume grid-like data (images, sequences), making them unsuitable for irregular, non-Euclidean structures.

Graph Neural Networks (GNNs) are designed to handle such data by learning over nodes (vertices) and edges (connections), enabling effective representation learning for graph-structured problems.


14.2 Graph Data and Representations

14.2.1 Graph Basics

A graph ( G = (V, E) ) consists of:

  • V: A set of nodes (vertices).

  • E: A set of edges connecting the nodes.

Each node ( v_i \in V ) can have a feature vector ( x_i ), and each edge ( (v_i, v_j) \in E ) can also have an edge feature ( e_{ij} ).

Graphs can be:

  • Undirected (edges have no direction)

  • Directed (edges have direction)

  • Weighted (edges have associated weights)

  • Heterogeneous (different types of nodes or edges)


14.2.2 Adjacency Matrix Representation

A graph is often represented by an Adjacency Matrix (A):
[
A_{ij} =
\begin{cases}
1, & \text{if there is an edge between } i \text{ and } j \\
0, & \text{otherwise}
\end{cases}
]

If a graph has ( N ) nodes, then ( A ) is an ( N \times N ) matrix.

For weighted graphs, ( A_{ij} ) holds the edge weight instead of 1.


14.2.3 Node and Edge Features

  • Node features (X): An ( N \times F ) matrix where ( F ) is the number of features per node.

  • Edge features: Represented by an ( |E| \times F_e ) matrix, where ( F_e ) is the number of features per edge.

Combining both, a graph can be described as ( G = (A, X) ).
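
To make this concrete, the sketch below builds ( A ) and ( X ) in PyTorch for the undirected path graph 0 – 1 – 2 with two features per node (the same toy graph used in Section 14.4.3); the feature values are arbitrary placeholders:

import torch

# Adjacency matrix A (N x N) for the undirected path graph 0 - 1 - 2
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])

# Node feature matrix X (N x F), with N = 3 nodes and F = 2 features
X = torch.tensor([[1., 2.],
                  [2., 3.],
                  [3., 4.]])

print(A.shape, X.shape)  # torch.Size([3, 3]) torch.Size([3, 2])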


14.3 Message Passing Neural Networks (MPNNs)

The core idea behind GNNs is message passing, where nodes exchange information with their neighbors to update their own embeddings.

14.3.1 The Message Passing Framework

Each node updates its representation by:

  1. Aggregating messages from its neighbors.

  2. Updating its state based on the aggregated information.

Mathematically:
[
h_i^{(k)} = \text{UPDATE}^{(k)} \left( h_i^{(k-1)}, \ \text{AGGREGATE}^{(k)} \left( \{ h_j^{(k-1)} : j \in \mathcal{N}(i) \} \right) \right)
]
where:

  • ( h_i^{(k)} ) is the node embedding at layer ( k ),

  • ( \mathcal{N}(i) ) is the set of neighbors of node ( i ),

  • AGGREGATE can be sum, mean, or max pooling,

  • UPDATE is typically a neural network layer.
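
To make the framework concrete, here is a minimal sketch of one message-passing layer with mean aggregation over a dense adjacency matrix. The class name MeanMessagePassing and the choice of a single linear layer for UPDATE are illustrative assumptions made here, not a standard API:

import torch
import torch.nn as nn

class MeanMessagePassing(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # UPDATE: a linear layer applied to [own state, aggregated neighbor state]
        self.update = nn.Linear(2 * in_dim, out_dim)

    def forward(self, h, A):
        # AGGREGATE: mean over neighbors, using the dense adjacency matrix
        deg = A.sum(dim=1, keepdim=True).clamp(min=1)  # node degrees (avoid division by zero)
        agg = (A @ h) / deg
        # Combine each node's own state with its aggregated neighborhood
        return torch.relu(self.update(torch.cat([h, agg], dim=1)))

Applied to the ( A ) and ( X ) from Section 14.2.3, MeanMessagePassing(2, 4)(X, A) produces new 4-dimensional embeddings for all three nodes.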


14.3.2 Types of GNN Layers

  1. Graph Convolutional Network (GCN):
    [
    H^{(k+1)} = \sigma \left( \tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(k)} W^{(k)} \right)
    ]
    where ( \tilde{A} = A + I ) (adding self-loops) and ( \tilde{D} ) is the diagonal degree matrix of ( \tilde{A} ). A sketch of this normalization appears after this list.

  2. Graph Attention Network (GAT):
    Uses attention coefficients to weigh neighbor contributions dynamically.

  3. GraphSAGE:
    Uses sampling and aggregation functions for large-scale graphs.
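
As a concrete illustration of the GCN propagation rule above, the following sketch computes the symmetrically normalized adjacency matrix ( \tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} ) for a small dense graph. The helper name gcn_norm is introduced here for illustration; PyG's GCNConv performs the equivalent normalization internally on sparse graph data:

import torch

def gcn_norm(A):
    A_tilde = A + torch.eye(A.size(0))   # add self-loops: A~ = A + I
    deg = A_tilde.sum(dim=1)             # degrees of A~ (always >= 1)
    d_inv_sqrt = deg.pow(-0.5)           # entries of D~^{-1/2}
    # D~^{-1/2} A~ D~^{-1/2} via row-wise and column-wise scaling
    return d_inv_sqrt.unsqueeze(1) * A_tilde * d_inv_sqrt.unsqueeze(0)

# One GCN layer is then: H_next = sigma(gcn_norm(A) @ H @ W)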


14.3.3 Readout Layer

After multiple message-passing steps, the final node embeddings can be used in two ways:

  • Node-level embeddings are used directly (for tasks like node classification).

  • For graph-level tasks (like molecular property prediction), a readout function combines all node embeddings into a single graph-level embedding.

Common readout functions: mean pooling, sum pooling, or attention-based pooling.
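
In PyTorch Geometric, graph-level mean pooling is available as global_mean_pool. Below is a minimal usage sketch, with random embeddings standing in for trained ones and all nodes assigned to a single graph:

import torch
from torch_geometric.nn import global_mean_pool

# 5 node embeddings of size 8 (random placeholders for trained embeddings)
x = torch.randn(5, 8)

# batch maps each node to its graph; here all 5 nodes belong to graph 0
batch = torch.zeros(5, dtype=torch.long)

graph_embedding = global_mean_pool(x, batch)
print(graph_embedding.shape)  # torch.Size([1, 8])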


14.4 GNN Implementation using PyTorch Geometric (PyG)

14.4.1 Overview of PyTorch Geometric

PyTorch Geometric (PyG) is a popular library for graph-based deep learning, extending PyTorch for easy handling of graph data.

It provides data structures like the Data object and layer modules such as GCNConv, GATConv, and SAGEConv.


14.4.2 Installation

pip install torch-geometric

Note: You may also need to install dependencies like torch-scatter, torch-sparse, and torch-cluster depending on your platform.


14.4.3 Creating a Graph Dataset

Example of a simple undirected graph with 3 nodes (the path 0 – 1 – 2), where each undirected edge is stored as two directed edges:

from torch_geometric.data import Data
import torch

# Node features: each of the 3 nodes has 2 features
x = torch.tensor([[1, 2], [2, 3], [3, 4]], dtype=torch.float)

# Edge indices in COO format: row 0 = source nodes, row 1 = target nodes.
# Each undirected edge appears twice (0->1 and 1->0, 1->2 and 2->1).
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)

data = Data(x=x, edge_index=edge_index)
print(data)
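
Printing the Data object should display its tensor shapes, along the lines of Data(x=[3, 2], edge_index=[2, 4]).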

14.4.4 Implementing a Simple GCN

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)   # first graph convolution
        x = F.relu(x)
        x = self.conv2(x, edge_index)   # second graph convolution
        return F.log_softmax(x, dim=1)  # log-probabilities, for use with nll_loss

14.4.5 Training the Model

# Example labels: an integer class (0 or 1) for each of the 3 nodes
data.y = torch.tensor([0, 1, 0])

model = GCN(input_dim=2, hidden_dim=4, output_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for epoch in range(50):
    optimizer.zero_grad()
    out = model(data)               # forward pass: log-probabilities per node
    loss = F.nll_loss(out, data.y)  # negative log-likelihood loss
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

14.5 Applications of GNNs

14.5.1 Social Networks

GNNs are used to model social connections and interactions:

  • Friend recommendation systems

  • Community detection

  • Fake news propagation analysis

Example: Predicting potential friends on Facebook by learning node embeddings that capture mutual connections.


14.5.2 Biological and Chemical Networks

In bioinformatics and chemistry:

  • Protein-protein interaction prediction

  • Molecular property prediction

  • Drug discovery by representing molecules as graphs of atoms and bonds.

Example: Predicting toxicity or solubility using molecular graph embeddings.


14.5.3 Knowledge Graphs and Recommendation Systems

GNNs help link entities and infer relationships, improving:

  • Recommendation systems

  • Semantic reasoning

  • Graph-based search engines


14.6 Advantages and Challenges

Advantages

  • Captures complex relationships and dependencies.

  • Generalizes CNN principles to non-Euclidean data.

  • Works effectively on structured and semi-structured data.

Challenges

  • Computationally intensive for large graphs.

  • Over-smoothing issue (node embeddings become similar after many layers).

  • Data sparsity and scalability issues.


14.7 Summary

Graph Neural Networks extend deep learning to structured, relational data. Using message passing, nodes exchange and aggregate information to form meaningful representations. Libraries like PyTorch Geometric simplify GNN implementation. From social networks to biological systems, GNNs have become an essential tool for learning from connected data.


14.8 Key Terms

  • Graph: Structure of nodes and edges.

  • Adjacency Matrix: Matrix representation of graph connections.

  • Message Passing: Mechanism for exchanging information between nodes.

  • GCN: Graph Convolutional Network.

  • Readout Layer: Aggregates node features into graph-level output.


14.9 Exercises

A. Conceptual Questions

  1. Define a graph and explain how it differs from grid-structured data.

  2. What is message passing in GNNs?

  3. Differentiate between GCN and GAT architectures.

  4. Explain how an adjacency matrix is used in GNN computation.

  5. List two applications of GNNs in social networks.

B. Coding Exercises

  1. Modify the provided GCN model to use SAGEConv (PyTorch Geometric's GraphSAGE layer) instead of GCNConv.

  2. Train the model on a small graph dataset like Cora or PubMed (available via the Planetoid class in torch_geometric.datasets).

  3. Visualize node embeddings before and after training using t-SNE.


14.10 Further Reading

  • Kipf, T. N., & Welling, M. (2016). Semi-Supervised Classification with Graph Convolutional Networks. arXiv:1609.02907.

  • Veličković, P., et al. (2018). Graph Attention Networks. ICLR.

  • PyTorch Geometric Documentation: https://pytorch-geometric.readthedocs.io/
