Hands-on Graph Neural Networks with PyTorch & PyTorch Geometric

In my last article, I introduced the concept of Graph Neural Networks (GNNs) and some of their recent advancements. Since this topic is getting seriously hyped up, I decided to make this tutorial on how to easily implement your own Graph Neural Network in your project. You will learn how to construct your own GNN with PyTorch Geometric, and how to use a GNN to solve a real-world problem (RecSys Challenge 2015).

In this blog post, we will be using PyTorch and PyTorch Geometric (PyG), a Graph Neural Network framework built on top of PyTorch that runs blazingly fast. Well … how fast is it? Compared to DGL, another popular Graph Neural Network library, PyG is up to 80% faster in terms of training time!!

Benchmark Speed Test (https://github.com/rusty1s/pytorch_geometric)

Aside from its remarkable speed, PyG comes with a collection of well-implemented GNN models proposed in various papers. This makes it very handy for reproducing the experiments from those papers.

Given its advantages in speed and convenience, PyG is without a doubt one of the most popular and widely used GNN libraries. Let’s dive into the topic and get our hands dirty!

Requirements

  • PyTorch — 1.1.0
  • PyTorch Geometric — 1.2.0

PyTorch Geometric Basics

This section will walk you through the basics of PyG. Essentially, it will cover torch_geometric.data and torch_geometric.nn. You will learn how to pass geometric data into your GNN, and how to design a custom MessagePassing layer, the core of GNN.

Data

The torch_geometric.data module contains a Data class that allows you to create graphs from your data very easily. You only need to specify:

  1. the attributes/features associated with each node
  2. the connectivity/adjacency of each node (edge index)

Let’s use the following graph to demonstrate how to create a Data object:

Example Graph

So there are 4 nodes in the graph, v1 … v4, each of which is associated with a 2-dimensional feature vector and a label y indicating its class. Both can be represented as FloatTensors.
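To make this concrete, here is a minimal sketch of the two tensors. The actual values in the example graph are not reproduced in the text, so the numbers below are placeholders; only the shapes ([4, 2] for the features, [4] for the labels) follow from the description above.

```python
import torch

# Placeholder values: the example graph's real numbers are not given in the text.
# One row per node (v1 ... v4), each a 2-dimensional feature vector.
x = torch.tensor([[2, 1], [5, 6], [3, 7], [12, 0]], dtype=torch.float)

# One class label per node (placeholder values), stored as floats as described.
y = torch.tensor([0, 1, 0, 1], dtype=torch.float)

print(x.shape)  # torch.Size([4, 2])
print(y.shape)  # torch.Size([4])
```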

The graph connectivity (edge index) should be provided in COO format, i.e. the first list contains the indices of the source nodes, while the indices of the target nodes are specified in the second list.
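A sketch of the COO layout follows. The edges are hypothetical because the figure’s edges are not reproduced here; v1 … v4 map to indices 0 … 3, and each undirected edge is listed once per direction. Writing the edges as (source, target) pairs and transposing them is one convenient way to build the [2, E] tensor.

```python
import torch

# Hypothetical edges as (source, target) pairs; each undirected edge appears twice.
edge_pairs = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)]

# Transpose to COO: row 0 = source node indices, row 1 = target node indices.
edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()
print(edge_index)
# tensor([[0, 1, 1, 2, 2, 3],
#         [1, 0, 2, 1, 3, 2]])
```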

Note that the order of the edges in the edge index is irrelevant to the Data object you create, since this information is only used to compute the adjacency matrix. Therefore, listing the same edges in a different order expresses exactly the same graph.
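One way to check this is to build a dense adjacency matrix from two edge indices that list the same (hypothetical) edges in different column orders. The helper to_dense below is an illustrative name, not a PyG function.

```python
import torch


def to_dense(edge_index, num_nodes):
    """Build a dense adjacency matrix from a [2, E] COO edge index."""
    adj = torch.zeros(num_nodes, num_nodes)
    adj[edge_index[0], edge_index[1]] = 1.0  # mark each (source, target) pair
    return adj


# Hypothetical 4-node graph with edges listed in one order...
edge_index_a = torch.tensor([[0, 1, 1, 2, 2, 3],
                             [1, 0, 2, 1, 3, 2]], dtype=torch.long)
# ...and the same edges (columns) listed in a shuffled order.
edge_index_b = edge_index_a[:, torch.tensor([5, 2, 0, 4, 1, 3])]

same = torch.equal(to_dense(edge_index_a, 4), to_dense(edge_index_b, 4))
print(same)  # True
```

Permuting columns only changes the listing order, never the set of (source, target) pairs, which is all the adjacency matrix records.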

Putting them together, we can create a Data object by passing x, y, and edge_index to the Data constructor.

Dataset

The dataset creation procedure is not very straightforward, but it may seem familiar to those who’ve used torchvision, as PyG follows its conventions. PyG provides two different types of dataset classes, InMemoryDataset and Dataset. As their names suggest, the former is for data that fits in your RAM, while the latter is for much larger data. Since their implementations are quite similar, I will only cover InMemoryDataset.

To create an InMemoryDataset object, there are 4 methods you need to implement:

  • raw_file_names():

It returns a list of raw, unprocessed file names (in PyG it is implemented as a property). If you only have one file, the returned list should contain only 1 element. In fact, you can simply return an empty list and specify your file later in process().

  • processed_file_names():

Similar to the last method, it returns a list containing the file names of all the processed data (also a property). Usually, the returned list has only one element: the name of the single processed data file that process() writes.

  • download():

This method should download the data you are working on to the directory specified in self.raw_dir. If you don’t need to download data, simply put a pass statement in the method body.

  • process():

This is the most important method of Dataset. You need to gather your data into a list of Data objects. Then, call self.collate() to compute the slices that will be used by the DataLoader object. The PyG documentation shows an example of such a custom dataset.

I will show you how I create a custom dataset from the data provided in RecSys Challenge 2015 later in this article.

DataLoader

The DataLoader class allows you to feed data by batch into the model effortlessly. To create a DataLoader object, you simply specify the Dataset and the batch size you want.

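```python
# PyG 1.x import path; newer PyG releases provide DataLoader in torch_geometric.loader
from torch_geometric.data import DataLoader

loader = DataLoader(dataset, batch_size=512, shuffle=True)
```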

Every iteration of a DataLoader object yields a Batch object, which is very much like a Data object but with an extra attribute, “batch”. It indicates which graph each node belongs to. Since a DataLoader concatenates x, y, and edge_index from different samples/graphs into a single Batch, the GNN model needs this “batch” information to know which nodes belong to the same graph within a batch.

```python
for batch in loader:
    print(batch)
    # In PyG 1.2.0 this prints something like:
    # Batch(x=[1024, 21], edge_index=[2, 1568], y=[512], batch=[1024])
```
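The batch vector is what lets a model pool node embeddings back into one vector per graph. Below is a plain-PyTorch sketch of graph-level mean pooling on a hypothetical two-graph batch; in practice, PyG’s global_mean_pool(x, batch) performs this kind of per-graph reduction for you.

```python
import torch

# Hypothetical mini-batch: graph 0 has 3 nodes, graph 1 has 2 nodes.
x = torch.tensor([[1.0], [2.0], [3.0], [10.0], [20.0]])
# batch[k] = index of the graph that node k belongs to.
batch = torch.tensor([0, 0, 0, 1, 1])
num_graphs = int(batch.max()) + 1

# Sum node features per graph, then divide by each graph's node count.
sums = torch.zeros(num_graphs, x.size(1)).index_add_(0, batch, x)
counts = torch.bincount(batch, minlength=num_graphs).unsqueeze(1).to(x.dtype)
graph_emb = sums / counts
print(graph_emb.tolist())  # [[2.0], [15.0]]
```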

MessagePassing

Message passing is the essence of GNNs; it describes how node embeddings are learned. I talked about it in my last post, so I will just briefly run through it here with terms that conform to the PyG documentation.
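
$$\mathbf{x}_i^{(k)} = \gamma^{(k)}\left(\mathbf{x}_i^{(k-1)},\ \square_{j \in \mathcal{N}(i)}\ \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)}, \mathbf{e}_{j,i}\right)\right)$$
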
Message Passing

Here, x denotes the node embeddings, e the edge features, 𝜙 the message function, □ the aggregation function, and 𝛾 the update function. If the edges in the graph have no features other than connectivity, e is essentially just the edge index of the graph. The superscript is the layer index, so at k = 1 the layer’s input x^(0) is simply the input feature of each node. Below I will illustrate how each function works:

  • propagate(edge_index, size=None, **kwargs):

It takes in the edge index and other optional information, such as node features (embeddings). Calling this function will in turn call message, aggregate, and update.

  • message(**kwargs):

You specify how to construct the “message” for each node pair (x_i, x_j). Since it is called by propagate, it can take any argument passed to propagate. One thing to note is that you can map arguments to specific nodes with the “_i” and “_j” suffixes: with PyG’s default flow, x_i refers to the features of the target (central) node of each edge and x_j to those of the source (neighboring) node. Therefore, you must be very careful when naming the arguments of this function.

  • update(aggr_out, **kwargs):

It takes in the aggregated message and other arguments passed into propagate, and assigns a new embedding value to each node.
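To make the message, aggregate, update flow concrete, here is a plain-PyTorch sketch of one propagation step. It is not how PyG implements propagate; it assumes sum aggregation and PyG’s default flow direction (edge_index[0] holds the sources j, edge_index[1] the targets i), and propagate_sketch is a hypothetical name.

```python
import torch


def propagate_sketch(x, edge_index, message, update):
    """One message passing step with sum aggregation (illustrative only)."""
    src, dst = edge_index
    x_i, x_j = x[dst], x[src]           # per-edge target and source features
    msgs = message(x_i, x_j)            # phi: one message per edge, [E, F']
    aggr_out = torch.zeros(x.size(0), msgs.size(1), dtype=msgs.dtype)
    aggr_out.index_add_(0, dst, msgs)   # box: sum the messages at each target
    return update(aggr_out, x)          # gamma: combine with the old embedding


x = torch.tensor([[1.0], [2.0], [3.0]])
edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 1]])  # edges 0->1, 1->2, 2->1
out = propagate_sketch(
    x, edge_index,
    message=lambda x_i, x_j: x_j,             # message = neighbor's feature
    update=lambda aggr_out, x: aggr_out + x,  # new embedding = sum + self
)
print(out.tolist())  # [[1.0], [6.0], [5.0]]
```

Node 1 receives 1.0 from node 0 and 3.0 from node 2, so its new value is 4.0 + 2.0 = 6.0; node 0 has no incoming edges and keeps its own value.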

Implementing the SAGEConv Layer

Let’s see how we can implement a SAGEConv layer from the paper “Inductive Representation Learning on Large Graphs”. The message passing formula of SAGEConv is defined as:
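$$\begin{aligned} \mathbf{h}_{\mathcal{N}(v)}^{k} &\leftarrow \mathrm{AGGREGATE}_k\left(\left\{\mathbf{h}_u^{k-1}, \forall u \in \mathcal{N}(v)\right\}\right) \\ \mathbf{h}_v^{k} &\leftarrow \sigma\left(\mathbf{W}^{k} \cdot \mathrm{CONCAT}\left(\mathbf{h}_v^{k-1}, \mathbf{h}_{\mathcal{N}(v)}^{k}\right)\right) \end{aligned}$$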

Here, we use max pooling as the aggregation method. Therefore, the right-hand side of the first line can be written as:
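$$\mathrm{AGGREGATE}_k^{\mathrm{pool}} = \max\left(\left\{\sigma\left(\mathbf{W}_{\mathrm{pool}}\,\mathbf{h}_{u_i}^{k-1} + \mathbf{b}\right), \forall u_i \in \mathcal{N}(v)\right\}\right)$$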

which illustrates how the “message” is constructed. Each neighboring node’s embedding is multiplied by a weight matrix, shifted by a bias, and passed through an activation function. This can be easily done with torch.nn.Linear.

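```python
import torch
from torch_geometric.nn import MessagePassing


class SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(SAGEConv, self).__init__(aggr='max')  # max pooling aggregation
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()

    def message(self, x_j):
        # x_j has shape [E, in_channels]: features of each edge's source node

        x_j = self.lin(x_j)  # W_pool * x_j + b
        x_j = self.act(x_j)  # sigma(...)

        # messages have shape [E, out_channels]
        return x_j
```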

As for the update part, the aggregated message and the current node embedding are concatenated. Then, the result is multiplied by another weight matrix and passed through another activation function.

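```python
import torch
from torch_geometric.nn import MessagePassing


class SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(SAGEConv, self).__init__(aggr='max')
        # Input is [aggr_out || x] (out_channels + in_channels features);
        # the output has out_channels features, so the layer maps in -> out.
        self.update_lin = torch.nn.Linear(in_channels + out_channels, out_channels, bias=False)
        self.update_act = torch.nn.ReLU()

    def update(self, aggr_out, x):
        # aggr_out has shape [N, out_channels], x has shape [N, in_channels]

        new_embedding = torch.cat([aggr_out, x], dim=1)
        new_embedding = self.update_lin(new_embedding)
        new_embedding = self.update_act(new_embedding)

        return new_embedding
```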

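Putting it together, the complete SAGEConv layer combines __init__, message(), and update() from the two snippets above, plus a forward() method that calls propagate().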

A Real-World Example — RecSys Challenge 2015

The RecSys Challenge 2015 challenges data scientists to build a session-based recommender system. Participants in this challenge are asked to solve two tasks:

  1. Predict whether a sequence of clicks will be followed by a buy event
  2. Predict which item will be bought

First, we download the data from the official website of RecSys Challenge 2015 and construct a Dataset. We’ll start with the first task as that one is easier.

The challenge provides two main sets of data, yoochoose-clicks.dat and yoochoose-buys.dat, containing click events and buy events, respectively. Let’s quickly glance through the data.



After downloading the data, we preprocess it so that it can be fed to our model. The item_ids are categorically encoded so that the encoded item_ids, which will later be mapped to an embedding matrix, start at 0.
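A minimal sketch of the encoding step on a hypothetical toy click log, using pandas.factorize. sklearn’s LabelEncoder would work as well, though it assigns codes in sorted order rather than in order of appearance.

```python
import pandas as pd

# Hypothetical click log; the real yoochoose-clicks.dat has more columns
# (e.g. timestamp, category), which are omitted here.
clicks = pd.DataFrame({
    "session_id": [1, 1, 2, 2, 2],
    "item_id": [214536502, 214536500, 214577561, 214536502, 214662742],
})

# Map raw item ids to consecutive integers starting at 0, so they can index
# the rows of an embedding matrix.
clicks["item_id"], item_vocab = pd.factorize(clicks["item_id"])
print(clicks["item_id"].tolist())  # [0, 1, 2, 0, 3]
print(len(item_vocab))             # 4
```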

Since the data is quite large, we subsample it for easier demonstration.

Number of unique elements in the subsampled data
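The text doesn’t specify how the data was subsampled. One reasonable choice, sketched below, is to sample whole sessions rather than individual rows, so that no session graph gets cut in half. subsample_sessions is a hypothetical helper and the toy data is made up.

```python
import numpy as np
import pandas as pd


def subsample_sessions(df, n_sessions, seed=0):
    """Keep every click of n_sessions randomly chosen sessions."""
    rng = np.random.default_rng(seed)
    sessions = df["session_id"].unique()
    keep = rng.choice(sessions, size=n_sessions, replace=False)
    return df[df["session_id"].isin(keep)]


clicks = pd.DataFrame({"session_id": [1, 1, 2, 3, 3, 3, 4],
                       "item_id": [10, 11, 12, 10, 13, 14, 11]})
sample = subsample_sessions(clicks, n_sessions=2)
print(sample["session_id"].nunique())  # 2
```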

To determine the ground truth, i.e. whether there is any buy event for a given session, we simply check whether a session_id in yoochoose-clicks.dat is also present in yoochoose-buys.dat.
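In pandas, this check is a single isin() call. A sketch on hypothetical click and buy logs:

```python
import pandas as pd

clicks = pd.DataFrame({"session_id": [1, 1, 2, 3, 3]})
buys = pd.DataFrame({"session_id": [3, 3]})  # hypothetical buy events

# A click row is labelled 1 if its session appears anywhere in the buy log.
clicks["label"] = clicks["session_id"].isin(buys["session_id"]).astype(int)
print(clicks["label"].tolist())  # [0, 0, 0, 1, 1]
```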

Dataset Construction

The data is ready to be transformed into a Dataset object after the preprocessing step. Here, we treat each item in a session as a node, and therefore all items in the same session form a graph. To build the dataset, we group the preprocessed data by session_id and iterate over these groups. In each iteration, the item_ids in each group are categorically encoded again since, for each graph, the node index should count from 0.
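A sketch of the per-session conversion on a hypothetical toy click log. session_to_graph is a hypothetical helper, and the edge rule (each click connects to the next click in the same session) is an assumption, since the text only fixes the nodes. The tensors are kept as plain tuples here; inside process() they would be wrapped as Data objects.

```python
import numpy as np
import pandas as pd
import torch


def session_to_graph(group):
    """Convert one session's clicks into (x, edge_index) tensors.

    Assumption: a directed edge links each click to the next click.
    """
    # Re-encode item ids inside this session so node indices start at 0.
    sess_item_id, uniques = pd.factorize(group["item_id"])
    # One node per distinct item; its feature is the globally encoded item id.
    x = torch.from_numpy(uniques.to_numpy().astype(np.int64)).unsqueeze(1)
    # Directed edges from each click to the following click.
    edges = np.stack([sess_item_id[:-1], sess_item_id[1:]]).astype(np.int64)
    edge_index = torch.from_numpy(edges)
    return x, edge_index


clicks = pd.DataFrame({"session_id": [1, 1, 1, 2, 2],
                       "item_id": [5, 9, 5, 3, 7],
                       "label": [0, 0, 0, 1, 1]})

graphs = []
for session_id, group in clicks.groupby("session_id"):
    x, edge_index = session_to_graph(group)
    y = torch.tensor([float(group["label"].iloc[0])])
    # Inside process(), this would become Data(x=x, edge_index=edge_index, y=y).
    graphs.append((x, edge_index, y))

print(graphs[0][1].tolist())  # [[0, 1], [1, 0]]
```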

After building the dataset, we call shuffle() to make sure it has been randomly shuffled and then split it into three sets for training, validation, and testing.
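With a PyG dataset, dataset.shuffle() followed by slicing (for example dataset[:n]) does this directly. The sketch below shows the same idea on a plain Python list; the 80/10/10 ratios and the helper name shuffle_split are assumptions, not the split used in this post.

```python
import torch


def shuffle_split(dataset, train_frac=0.8, val_frac=0.1, seed=0):
    """Shuffle a list of graphs and split it into train/val/test parts."""
    g = torch.Generator().manual_seed(seed)
    perm = torch.randperm(len(dataset), generator=g).tolist()
    shuffled = [dataset[i] for i in perm]
    n_train = int(train_frac * len(dataset))
    n_val = int(val_frac * len(dataset))
    return (shuffled[:n_train],
            shuffled[n_train:n_train + n_val],
            shuffled[n_train + n_val:])


train, val, test = shuffle_split(list(range(100)))
print(len(train), len(val), len(test))  # 80 10 10
```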

Build a Graph Neural Network

The custom GNN is based on one of the examples in PyG’s official GitHub repository. I replaced the GraphConv layer with the self-implemented SAGEConv layer illustrated above. In addition, the output layer was modified to match a binary classification setup.

Training

Training our custom GNN is very easy: we simply iterate over the DataLoader constructed from the training set and back-propagate the loss. Here, we use Adam as the optimizer with the learning rate set to 0.005 and Binary Cross Entropy as the loss function.
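A sketch of one training epoch under the setup described above. It assumes each batch exposes y and num_graphs (as PyG Batch objects do) and that the model ends in a sigmoid, so BCELoss receives probabilities. ToyModel and the list-of-SimpleNamespace loader are hypothetical stand-ins so the loop can run without a real dataset.

```python
from types import SimpleNamespace

import torch


def train_one_epoch(model, loader, optimizer, crit):
    """Run one epoch over the loader and return the mean loss per graph."""
    model.train()
    total_loss, total_graphs = 0.0, 0
    for batch in loader:
        optimizer.zero_grad()
        output = model(batch).view(-1)         # [num_graphs] probabilities
        loss = crit(output, batch.y.view(-1))  # binary cross-entropy
        loss.backward()                        # back-propagate
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
        total_graphs += batch.num_graphs
    return total_loss / total_graphs


# Hypothetical stand-ins: each "batch" holds 4 graphs already reduced to
# 3-dimensional graph-level features.
class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(3, 1)

    def forward(self, batch):
        return torch.sigmoid(self.lin(batch.x))


torch.manual_seed(0)
loader = [SimpleNamespace(x=torch.randn(4, 3),
                          y=torch.tensor([0.0, 1.0, 0.0, 1.0]),
                          num_graphs=4)
          for _ in range(5)]
model = ToyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)  # as in the text
crit = torch.nn.BCELoss()                                    # as in the text
losses = [train_one_epoch(model, loader, optimizer, crit) for _ in range(3)]
print(losses)
```

Weighting each batch loss by num_graphs makes the returned value a per-graph average even when the last batch is smaller.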

Validation

The labels are highly imbalanced, with an overwhelming number of negatives, since most sessions are not followed by any buy event. In other words, a dumb model that always predicts negative would achieve above 90% accuracy. Therefore, instead of accuracy, the Area Under the ROC Curve (AUC) is a better metric for this task, as it only cares whether positive examples are scored higher than negative ones. We use the off-the-shelf AUC function from scikit-learn.
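The accuracy trap is easy to reproduce with scikit-learn’s roc_auc_score on hypothetical labels with a 90/10 class split: an all-negative predictor reaches 90% accuracy but only 0.5 AUC, while a scorer that ranks every positive above every negative gets an AUC of 1.0.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Hypothetical labels: 18 negatives, 2 positives (90/10 imbalance).
labels = np.array([0] * 18 + [1] * 2)

# A "dumb" model that scores every session as negative.
dumb_scores = np.zeros(len(labels))
dumb_acc = ((dumb_scores > 0.5) == labels).mean()
dumb_auc = roc_auc_score(labels, dumb_scores)
print(dumb_acc, dumb_auc)  # 0.9 0.5

# A model that ranks both positives above every negative.
good_scores = np.concatenate([np.linspace(0.0, 0.4, 18), [0.6, 0.7]])
print(roc_auc_score(labels, good_scores))  # 1.0
```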

Result

I trained the model for 1 epoch and measured the training, validation, and test AUC scores.

With only 1 million rows of training data (around 10% of all data) and 1 epoch of training, we obtain an AUC score of around 0.73 on the validation and test sets. The score would likely improve if the model were trained on more data for more training steps.

Conclusion

You have learned the basic usage of PyTorch Geometric, including dataset construction, custom graph layers, and training GNNs with real-world data. All the code in this post can also be found in my GitHub repo, where you can find another Jupyter notebook in which I solve the second task of the RecSys Challenge 2015. I hope you have enjoyed this article. Should you have any questions or comments, please leave them below! Make sure to follow me on Twitter, where I share my blog posts and interesting Machine Learning/Deep Learning news! Have fun playing with GNNs in PyG!