PyG is the ultimate library for Graph Neural Networks

Build graph learning pipelines with ease


What is PyG?

PyG (PyTorch Geometric) is a library built upon PyTorch to easily write and train Graph Neural Networks (GNNs) for a wide range of applications related to structured data.

PyG is friendly both to machine learning researchers and to first-time users of machine learning toolkits.

Key Features

Comprehensive and Flexible Interface to Build GNNs

  • Simple abstractions to build GNNs using state-of-the-art architectures and models
  • Fully extensible to any use case for both homogeneous and heterogeneous graphs (see the sketch after this list)
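
For heterogeneous graphs, PyG can lift a homogeneous model to type-specific message passing with to_hetero. A minimal sketch, assuming data is a HeteroData object whose node types all carry feature matrices (the hidden size of 16 is an arbitrary choice):

import torch
from torch_geometric.nn import SAGEConv, to_hetero

class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Lazy in_channels (-1) let every node type infer its own feature size:
        self.conv1 = SAGEConv((-1, -1), 16)
        self.conv2 = SAGEConv((-1, -1), 16)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

# Duplicate the layers once per node/edge type of the heterogeneous graph:
model = to_hetero(GNN(), data.metadata(), aggr="sum")
out = model(data.x_dict, data.edge_index_dict)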

PyTorch-on-the-rocks Design

  • Tensor-centric API following the design principles of vanilla PyTorch (a minimal Data example follows this list)
  • Unified node-level, link-level and graph-level implementations
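
Under the hood, a graph is just a handful of tensors, which is what keeps the API close to vanilla PyTorch. A minimal sketch of the Data object with toy tensors (three nodes, two edges, node-level labels):

import torch
from torch_geometric.data import Data

x = torch.tensor([[1.0], [2.0], [3.0]])      # node features, shape [num_nodes, num_features]
edge_index = torch.tensor([[0, 1],
                           [1, 2]])          # connectivity in COO format, shape [2, num_edges]
y = torch.tensor([0, 1, 0])                  # node-level labels

data = Data(x=x, edge_index=edge_index, y=y)
print(data.num_nodes, data.num_edges)        # 3 2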

Scalable and Flexible Backend Support

  • Support for customizable feature and graph stores to scale to any graph size

Extensive Tutorials and Examples

  • Learn practically about GNNs via Videos, Colabs & Blogs
  • Application-driven Graph ML Tutorials


Node-level prediction (e.g., on the Cora citation network):

import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GCNConv

# Load a dataset:
dataset = Planetoid(root, name="Cora")

# Create a mini-batch loader:
loader = NeighborLoader(dataset[0], num_neighbors=[25, 10])

# Create your GNN model:
class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Choose between different GNN building blocks:
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

# Train your GNN model:
for data in loader:
    y_hat = model(data.x, data.edge_index)
    loss = criterion(y_hat, data.y)
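
The loop above omits the standard PyTorch boilerplate. A minimal sketch of the missing pieces, assuming the GNN class defined above (the Adam optimizer, learning rate, and cross-entropy loss are illustrative choices):

import torch

model = GNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

model.train()
for data in loader:
    optimizer.zero_grad()
    y_hat = model(data.x, data.edge_index)
    # NeighborLoader puts the seed nodes first, so restrict the loss to them:
    loss = criterion(y_hat[:data.batch_size], data.y[:data.batch_size])
    loss.backward()
    optimizer.step()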

Graph-level prediction (e.g., on the PROTEINS dataset):

import torch
from torch.nn import Linear
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, global_mean_pool

# Load a dataset:
dataset = TUDataset(root, name="PROTEINS")

# Create a mini-batch loader:
loader = DataLoader(dataset, batch_size=256)

# Create your GNN model:
class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GATConv(dataset.num_features, 16)
        self.conv2 = GATConv(16, 16)
        self.lin = Linear(16, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        # Choose between different GNN building blocks:
        x = global_mean_pool(x, batch)
        return self.lin(x)

# Train your GNN model:
for data in loader:
    y_hat = model(data.x, data.edge_index, data.batch)
    loss = criterion(y_hat, data.y)
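
Evaluation follows the same mini-batch pattern, with one graph-level prediction per example. A minimal sketch, reusing the loader above (in practice you would evaluate on a held-out split):

import torch

model.eval()
correct = 0
with torch.no_grad():
    for data in loader:
        y_hat = model(data.x, data.edge_index, data.batch)
        pred = y_hat.argmax(dim=-1)            # predicted class per graph
        correct += int((pred == data.y).sum())
accuracy = correct / len(loader.dataset)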

Link-level prediction (e.g., on the Reddit dataset):

import torch
from torch_geometric.datasets import Reddit
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import GraphSAGE, InnerProductDecoder

# Load a dataset:
dataset = Reddit(root)

# Create a mini-batch loader:
loader = LinkNeighborLoader(dataset[0], num_neighbors=[25, 10])

# Create your GNN model:
class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Choose between different GNN building blocks:
        self.encoder = GraphSAGE(dataset.num_features, 16, num_layers=2)
        self.decoder = InnerProductDecoder()

    def forward(self, x, edge_index, edge_label_index):
        x = self.encoder(x, edge_index)
        return self.decoder(x, edge_label_index)

# Train your GNN model:
for data in loader:
    y_hat = model(data.x, data.edge_index, data.edge_label_index)
    loss = criterion(y_hat, data.edge_label)
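
Link prediction needs negative (non-existent) edges to contrast against, which the loader can sample on the fly. A minimal sketch, assuming the model above; the neg_sampling_ratio value, optimizer, and binary cross-entropy objective are illustrative choices:

import torch
from torch_geometric.loader import LinkNeighborLoader

# Sample one negative edge per positive edge; edge_label then holds 0/1 targets:
loader = LinkNeighborLoader(dataset[0], num_neighbors=[25, 10],
                            neg_sampling_ratio=1.0, batch_size=256)

model = GNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# InnerProductDecoder already applies a sigmoid, so plain BCE matches its output:
criterion = torch.nn.BCELoss()

for data in loader:
    optimizer.zero_grad()
    y_hat = model(data.x, data.edge_index, data.edge_label_index)
    loss = criterion(y_hat, data.edge_label.float())
    loss.backward()
    optimizer.step()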

Ecosystem

A well-established ecosystem of libraries and open-source tools has been built on top of PyG.

Here are some of the most popular ones:

Contact us at team@pyg.org