By Rishi Puri, with support from the NVIDIA team and Matthias Fey (PyG)
Introduction
The diagram below, from the original paper, explains the Heterogeneous Graph Transformer (HGT) architecture.

Simply put, HGT applies attention over the features of each node and edge type in a heterogeneous graph, rather than over the tokens in a sentence or the pixels in an image or video as traditional transformers do. Transformer models have recently been applied with great success to large-scale tasks over structured inputs of many modalities, including images, video, and language. Likewise, thanks to the expressivity of its heterogeneous graph attention mechanism, HGT is a highly general and effective method for heterogeneous graph learning tasks of all kinds.

The original HGTConv implementation in PyTorch Geometric (PyG) iterates over node and edge types, which leaves ample room for optimization. In the fall of 2022, we integrated NVIDIA’s CUTLASS Grouped GEMM kernel into the new pyg-lib, a low-level backend for PyG. This integration enables typed matrix multiplication, which provides major accelerations for heterogeneous GNNs. This talk at the Graph Learning Conference covers the integration of typed matrix multiplication into HeteroLinear and RGCNConv and the speedups it yields. The same process of parallelizing over types can be applied to both node and edge types in HGT.
Below is a breakdown of the forward pass and how it has been accelerated with no code changes needed by users.

Implementation
Step 1)

# Iterate over node-types:
for node_type, x in x_dict.items():
    k_dict[node_type] = self.k_lin[node_type](x).view(-1, H, D)
    q_dict[node_type] = self.q_lin[node_type](x).view(-1, H, D)
    v_dict[node_type] = self.v_lin[node_type](x).view(-1, H, D)
    out_dict[node_type] = []
Here, Linear layers are applied iteratively over node types. This for-loop can be replaced with a pyg-lib node-typed matrix multiply.
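To make the idea concrete, here is a minimal NumPy sketch (with made-up node types, counts, and shapes) of why the per-type loop of linear layers is equivalent to one segmented matrix multiplication over the concatenated node features — the data layout that pyg-lib's typed matrix multiply exploits to fuse all per-type GEMMs into a single Grouped GEMM kernel launch:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: two node types with different node counts but the
# same feature dimension, each with its own weight matrix (as in k_lin).
x_dict = {'paper': rng.standard_normal((4, 8)),
          'author': rng.standard_normal((3, 8))}
w_dict = {'paper': rng.standard_normal((8, 8)),
          'author': rng.standard_normal((8, 8))}

# Baseline: iterate over node types, one matmul per type.
k_loop = {t: x @ w_dict[t] for t, x in x_dict.items()}

# Typed matmul layout: concatenate all node features into one tensor and
# apply each type's weight to its contiguous segment. On GPU, pyg-lib
# performs all segments in one fused Grouped GEMM instead of this loop.
types = list(x_dict)
x_cat = np.concatenate([x_dict[t] for t in types])      # [7, 8]
ptr = np.cumsum([0] + [len(x_dict[t]) for t in types])  # segment boundaries
k_cat = np.empty_like(x_cat)
for i, t in enumerate(types):
    k_cat[ptr[i]:ptr[i + 1]] = x_cat[ptr[i]:ptr[i + 1]] @ w_dict[t]

# Each segment of the fused result matches the per-type loop output.
assert all(np.allclose(k_cat[ptr[i]:ptr[i + 1]], k_loop[t])
           for i, t in enumerate(types))
```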
Step 2)

# Iterate over edge-types:
for edge_type, edge_index in edge_index_dict.items():
    src_type, _, dst_type = edge_type
    edge_type = '__'.join(edge_type)
    a_rel = self.a_rel[edge_type]
    k = (k_dict[src_type].transpose(0, 1) @ a_rel).transpose(1, 0)
    m_rel = self.m_rel[edge_type]
    v = (v_dict[src_type].transpose(0, 1) @ m_rel).transpose(1, 0)
    # propagate_type: (k: Tensor, q: Tensor, v: Tensor, rel: Tensor)
    out = self.propagate(edge_index, k=k, q=q_dict[dst_type], v=v,
                         rel=self.p_rel[edge_type], size=None)
    out_dict[dst_type].append(out)
Similar to the previous step, there is an iterative matrix multiplication process occurring, except over edge types instead. For each edge type, we perform a matrix multiplication followed by message propagation using PyG’s gather-scatter scheme, which is described in detail in the PyG docs. Here, the matrix multiplications can be parallelized with a pyg-lib edge-typed matrix multiply. The edge_index_dict can then be concatenated into a single edge index tensor, so that a single step of message propagation covers all edge types.
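As an illustration of that concatenation step, here is a small NumPy sketch (with hypothetical node counts and edge indices) showing how per-edge-type edge indices, which use local per-type node numbering, can be shifted by each type's global node offset and merged into one edge index tensor suitable for a single propagation pass:

```python
import numpy as np

# Hypothetical node counts per type; once node features are concatenated
# into a single tensor, each type's nodes start at a global offset.
num_nodes = {'paper': 4, 'author': 3}
offset, cum = {}, 0
for t, n in num_nodes.items():
    offset[t] = cum
    cum += n  # paper -> 0, author -> 4

# Per-edge-type edge indices in local (per-type) node numbering.
edge_index_dict = {
    ('author', 'writes', 'paper'): np.array([[0, 1, 2], [0, 0, 3]]),
    ('paper', 'cites', 'paper'): np.array([[0, 1], [2, 3]]),
}

# Shift source/destination rows by the corresponding type offsets and
# concatenate, yielding one global edge index over all edge types.
parts = []
for (src, _, dst), ei in edge_index_dict.items():
    parts.append(ei + np.array([[offset[src]], [offset[dst]]]))
edge_index = np.concatenate(parts, axis=1)
# -> [[4, 5, 6, 0, 1],
#     [0, 0, 3, 2, 3]]
```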
Step 3)

# Iterate over node-types:
for node_type, outs in out_dict.items():
    out = group(outs, self.group)
    if out is None:
        out_dict[node_type] = None
        continue
    out = self.a_lin[node_type](F.gelu(out))
    if out.size(-1) == x_dict[node_type].size(-1):
        alpha = self.skip[node_type].sigmoid()
        out = alpha * out + (1 - alpha) * x_dict[node_type]
    out_dict[node_type] = out
Again, this step can be accelerated by parallelizing a_lin across node types.
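For completeness, here is a minimal NumPy sketch (with made-up shapes and a hypothetical scalar skip weight) of the sigmoid-gated residual connection applied in this step whenever the output and input feature dimensions of a node type match:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)

# Hypothetical per-node-type tensors: `out` stands in for the attention
# output after a_lin and GELU, `x` for the original input features.
x = rng.standard_normal((4, 8))
out = rng.standard_normal((4, 8))
skip = 1.0  # learnable scalar skip weight, as in self.skip[node_type]

# Sigmoid-gated residual: blend new output with the input features.
alpha = sigmoid(skip)
blended = alpha * out + (1 - alpha) * x
```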
Results

As can be seen in the above figure, pyg-lib provides up to a 3.6x speedup on GPU with no code changes needed. These results were gathered on an NVIDIA A100 GPU using HGTConv on a FakeHeteroDataset while varying num_node_types and num_edge_types (avg_num_nodes: int = 1000, avg_degree: int = 10, avg_num_channels: int = 64). These GPU accelerations are enabled starting with PyG 2.3. Below is the same figure viewed from a different angle for ease of visualization. These accelerations are made available through FastHGTConv, which will become HGTConv in future releases.

Conclusion
HGT is one of the most powerful GNNs for solving complex heterogeneous graph tasks. Oftentimes, these tasks require enormous graphs with countless edge and node types. Using FastHGTConv, PyG users can leverage NVIDIA GPUs for most real-world applications with efficiency and ease of use.