Graph Machine Learning in Genomic Prediction

This work explores how genetic relationships can be exploited alongside genomic information to predict genetic traits with the aid of graph machine learning algorithms.

By Thanh Nguyen Mueller, CSIRO Data61


Photo by David Becker on Unsplash


In genomic settings, the large number of individuals and the complexity of genomic architecture make it difficult to extract useful analytics and insight. Deep learning is widely known for its flexibility and its ability to uncover complex patterns in large datasets, and applications of deep learning in genomics are emerging as a result.

One such application is genomic prediction, where the traits of individuals, such as susceptibility to disease or yield-related traits, are predicted from their genomic information. Understanding how genetic traits correlate with variations in genomes could have many benefits, such as advancing crop breeding processes and hence improving food security.

In this article, we explore how genetic relationships can be exploited alongside genomic information to predict genetic traits, with the aid of graph machine learning algorithms.


Deep learning in a genomic prediction context

In genomic prediction, traditional deep learning would use an individual’s genomic information, such as single nucleotide polymorphisms (SNPs), as input features to the neural network. A SNP is essentially a difference that occurs at a specific position in an individual’s genome.
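
As a minimal sketch of what such input features can look like, the snippet below encodes each SNP as the number of non-reference alleles an individual carries (0, 1 or 2). This additive encoding is a common convention rather than something this article prescribes, and the genotypes, reference alleles and the helper name `encode_snps` are made up for illustration.

```python
# Hypothetical example: genotypes at three SNP positions for two individuals.
# Each genotype is a pair of alleles, one per chromosome copy.
REFERENCE_ALLELES = ["A", "C", "G"]  # assumed reference allele at each position

genotypes = {
    "plant_1": [("A", "A"), ("C", "T"), ("T", "T")],
    "plant_2": [("A", "G"), ("C", "C"), ("G", "T")],
}


def encode_snps(genotype, reference_alleles):
    """Count the non-reference alleles (0, 1 or 2) at each SNP position."""
    return [
        sum(allele != ref for allele in pair)
        for pair, ref in zip(genotype, reference_alleles)
    ]


features = {name: encode_snps(g, REFERENCE_ALLELES) for name, g in genotypes.items()}
print(features)  # {'plant_1': [0, 1, 2], 'plant_2': [1, 0, 1]}
```

Each individual ends up with one numeric vector of equal length, which is the form a neural network's input layer expects.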

Given individuals’ genomic information (e.g. SNPs) together with their observed traits, the neural network learns to predict the traits of unseen individuals from their genomic information.

Taking the MultiLayer Perceptron (MLP) network below as an example, the network contains an input layer holding the SNPs, one or more hidden layers, and an output layer that predicts the trait (quantitative or categorical). We train the network by adjusting its parameters so as to minimise the average error between the predicted and observed traits across the individuals in the training set, using one of the flavours of the gradient descent optimisation algorithm, e.g. stochastic gradient descent.
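
The training objective described above can be written compactly. The notation is ours and is only an assumption for illustration: x_i is the SNP feature vector of individual i, y_i its observed trait, f_θ the network with parameters θ, ℓ a per-individual error (for example the squared error for a quantitative trait), N the number of training individuals and η the learning rate.

```latex
\[
\mathcal{L}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \ell\bigl(f_\theta(x_i),\, y_i\bigr),
\qquad
\theta \leftarrow \theta - \eta \, \nabla_\theta \, \ell\bigl(f_\theta(x_i),\, y_i\bigr)
\]
```

The update on the right is the stochastic gradient descent step: instead of the full sum, each step uses the error of one randomly chosen individual (or a small batch of them).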


Figure 1: An MLP neural network that takes SNP features as input, followed by two hidden (fully connected) layers and an output layer that predicts the trait value.


Alongside this genomic information, individuals also have genetic relationships that can help improve trait prediction accuracy. Our question is then: how can these relationships be leveraged for trait prediction?


Graph representation for trait prediction

Graph machine learning is a tool that allows us to utilise not only intrinsic information about entities (e.g. SNP features) but also the relationships between those entities when performing a prediction task. It is an extension of deep learning to data that can be modelled as a graph.

A graph of individuals would represent the individuals as nodes, and the relations between them as edges. A pedigree-based kinship matrix is one source of such relations. This N x N matrix, where N is the number of individuals, contains pedigree-based relationship coefficients that indicate the biological relationships between the individuals, e.g. first-order (parent-child, siblings), second-order (aunts, uncles, grandparents), third-order (first cousins, great-grandparents) and so on.

With pedigree-based relationships, we can construct a graph consisting of nodes with genetic features, e.g. SNPs, and edges representing a certain degree of relatedness between them. This is a natural representation of the data that can be used for trait prediction.
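
The construction described above can be sketched as thresholding the relationship matrix. The matrix below is invented, and it assumes the coefficient-of-relationship convention (0.5 for first-order, 0.25 for second-order, 0.125 for third-order relatives); a kinship-coefficient convention would halve these thresholds. The function name `pedigree_edges` is hypothetical.

```python
# Hypothetical pedigree-based relationship matrix for four individuals.
individuals = ["A", "B", "C", "D"]
kinship = [
    [1.0,   0.5,  0.25, 0.125],
    [0.5,   1.0,  0.5,  0.25],
    [0.25,  0.5,  1.0,  0.0],
    [0.125, 0.25, 0.0,  1.0],
]


def pedigree_edges(names, matrix, min_coefficient):
    """Return undirected edges between individuals at least this closely related."""
    edges = []
    for i in range(len(names)):
        # Upper triangle only: no self-loops and no duplicate pairs.
        for j in range(i + 1, len(names)):
            if matrix[i][j] >= min_coefficient:
                edges.append((names[i], names[j]))
    return edges


first_order = pedigree_edges(individuals, kinship, 0.5)
up_to_second_order = pedigree_edges(individuals, kinship, 0.25)
print(first_order)         # [('A', 'B'), ('B', 'C')]
print(up_to_second_order)  # [('A', 'B'), ('A', 'C'), ('B', 'C'), ('B', 'D')]
```

Lowering the threshold admits more distant relatives, so the graph becomes denser as higher-order relationships are included.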


Figure 2: A graph of some of purple straw wheat’s relationships. The graph on the left includes only first-order relationships. The second graph considers both first-order and second-order relationships, and the third graph shows the density of the connections when first-order, second-order and third-order relationships are included.


In the context of genomic breeding (e.g. wheat), growing conditions also have an important effect on an individual’s traits, in addition to the genetic features. That is, genetically identical individuals grown in different environments share the same SNPs but have distinct environment-related features and traits due to the different growing conditions. Adding this information to the graph is therefore useful, as we might want to:

  1. observe a plant under one environmental condition while predicting the trait of the same plant in another environment
  2. observe a plant grown in all kinds of environmental conditions and predict traits for completely different plants treated in the same set of environments.

One possible way to incorporate this information into the graph is by creating replicas of the individuals for each environmental condition and drawing an edge between the replicas, which encodes the fact that these are replicas of the same genome.
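
A minimal sketch of this replication step follows. The individual `relative_1` and the single pedigree edge are invented, and repeating each pedigree edge within every environment is our assumption, not something the article specifies.

```python
from itertools import combinations

# Hypothetical inputs: two individuals, two growing conditions, one pedigree edge.
individuals = ["purple_straw", "relative_1"]
environments = ["long_day", "short_day"]
pedigree = [("purple_straw", "relative_1")]

# One replica node per (individual, environment) pair.
nodes = [(ind, env) for ind in individuals for env in environments]

edges = []
# Assumption: pedigree edges are repeated within each environment.
for env in environments:
    for a, b in pedigree:
        edges.append(((a, env), (b, env)))
# Replica edges join the copies of the same individual across environments.
for ind in individuals:
    for env_a, env_b in combinations(environments, 2):
        edges.append(((ind, env_a), (ind, env_b)))

print(len(nodes), "nodes;", len(edges), "edges")  # 4 nodes; 4 edges
```

Each replica can then carry the shared SNP features plus the features and observed trait specific to its environment.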


Figure 3: The graph illustrates purple straw’s first-order relationships in long-day and short-day environmental treatments. An edge represents either a pedigree relationship or a connection to a replica.


However, the edges that join the individuals’ replicas have a different semantic meaning from the edges of the pedigree relationships. To take this into account, we construct a heterogeneous graph that has individuals as a single node type, with pedigree and environmental condition as the two distinct edge types.
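
One simple way to hold such a graph in memory is to label every edge with its type and index each node's neighbours by that label. This is a sketch with invented node names, not the data structure any particular library uses.

```python
from collections import defaultdict

# Hypothetical typed edges: (source, target, edge_type). Nodes are
# (individual, environment) replicas; all names are illustrative.
typed_edges = [
    (("purple_straw", "long_day"), ("relative_1", "long_day"), "pedigree"),
    (("purple_straw", "short_day"), ("relative_1", "short_day"), "pedigree"),
    (("purple_straw", "long_day"), ("purple_straw", "short_day"), "condition"),
    (("relative_1", "long_day"), ("relative_1", "short_day"), "condition"),
]

# neighbours[node][edge_type] -> list of nodes reachable via that edge type
neighbours = defaultdict(lambda: defaultdict(list))
for source, target, edge_type in typed_edges:
    neighbours[source][edge_type].append(target)
    neighbours[target][edge_type].append(source)  # edges are undirected

node = ("purple_straw", "long_day")
print(dict(neighbours[node]))
```

Keeping the two neighbourhoods separate is what later allows relatives and replicas to be treated differently.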


Figure 4: Purple straw in long-day and short-day environmental treatments with two edge types: “pedigree” represents a first-order relationship and “condition” represents a connection to a replica grown in a different environment.


So far, we have represented individuals with their environmental condition and pedigree-based relationships as a graph. Our last question is then: how can a neural network be applied to such graph-structured data for trait prediction?


Trait prediction from graph

GraphSAGE [1], which belongs to the class of graph convolutional neural networks, is a neural network that, when applied to a graph, learns to produce latent vector representations (also called “embeddings”) for each node that are well suited to the downstream prediction task (e.g. node classification or regression). It does this by fusing each node’s features with the aggregated features of the nodes in its neighbourhood.

Applying that to our graph of individuals above, a GraphSAGE layer forms a new embedding vector for each individual, fusing the individual’s features with those from their direct relatives and their replicas in other environments.
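
The fusing step can be sketched as one layer with a mean aggregator: concatenate a node's own features with the mean of its neighbours' features, apply a linear map, then a ReLU. The graph, features and weights below are invented (real weights are learned during training), and extras such as bias terms and embedding normalisation are left out.

```python
# Made-up node features and neighbourhoods for three individuals.
features = {
    "A": [1.0, 0.0],
    "B": [0.0, 2.0],
    "C": [2.0, 2.0],
}
neighbours = {"A": ["B", "C"], "B": ["A"], "C": ["A"]}

# One row per output dimension; 4 input columns = 2 own + 2 aggregated features.
weights = [
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 1.0, 0.0, -1.0],
]


def mean_vector(vectors):
    """Element-wise mean of a non-empty list of equal-length vectors."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]


def graphsage_layer(node):
    """Fuse a node's features with the mean of its neighbours' features."""
    aggregated = mean_vector([features[n] for n in neighbours[node]])
    combined = features[node] + aggregated  # list concatenation
    linear = [sum(w * x for w, x in zip(row, combined)) for row in weights]
    return [max(0.0, value) for value in linear]  # ReLU


embeddings = {node: graphsage_layer(node) for node in features}
print(embeddings)  # {'A': [2.0, 0.0], 'B': [1.0, 2.0], 'C': [3.0, 2.0]}
```

The new vector for "A" now depends on "B" and "C" as well as on "A" itself, which is the sense in which relatives inform each individual's embedding.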

When stacking k GraphSAGE layers, we expand each node’s neighbourhood so that it also fuses information from neighbours up to k hops away. As an example, with two GraphSAGE layers, we would also include the information from relatives of relatives for each individual.
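
To make the effect of stacking concrete, the sketch below computes which nodes can influence a given node after a number of layers, ignoring neighbour sampling. The chain-shaped pedigree and the function name `receptive_field` are made up for illustration.

```python
# Hypothetical pedigree chain: A - B - C - D
neighbours = {"A": ["B"], "B": ["A", "C"], "C": ["B", "D"], "D": ["C"]}


def receptive_field(node, num_layers):
    """Nodes whose features can influence `node` after `num_layers` layers."""
    reached = {node}
    frontier = {node}
    for _ in range(num_layers):
        # Step one hop outwards from the current frontier.
        frontier = {n for f in frontier for n in neighbours[f]} - reached
        reached |= frontier
    return reached


print(sorted(receptive_field("A", 1)))  # ['A', 'B']
print(sorted(receptive_field("A", 2)))  # ['A', 'B', 'C']
```

With one layer, "A" only sees its direct relative "B"; with two layers it also sees "C", a relative of a relative.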

For scalability, rather than fusing features of all neighbours, each GraphSAGE layer only fuses features of a set of randomly selected neighbours. With GraphSAGE, the number of layers and the number of neighbours per layer are user-defined.
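
A sketch of the sampling step follows. Sampling with replacement when a node has fewer neighbours than requested is one common convention and is an assumption here; the relative names and the `sample_neighbours` helper are hypothetical.

```python
import random


def sample_neighbours(neighbour_list, num_samples, rng):
    """Pick a fixed number of neighbours uniformly at random.

    Samples without replacement when enough neighbours exist and with
    replacement otherwise, so any node with at least one neighbour yields
    exactly `num_samples` entries (an assumed convention).
    """
    if not neighbour_list:
        return []
    if len(neighbour_list) >= num_samples:
        return rng.sample(neighbour_list, num_samples)
    return rng.choices(neighbour_list, k=num_samples)


rng = random.Random(0)  # seeded only to make the example repeatable
relatives = ["parent_1", "parent_2", "sibling_1", "sibling_2", "sibling_3"]
num_samples_per_layer = [3, 2]  # user-defined: 3 neighbours at hop 1, 2 at hop 2

hop_1 = sample_neighbours(relatives, num_samples_per_layer[0], rng)
print(hop_1)
```

Fixing the sample size bounds the cost per node regardless of how many relatives an individual has, at the price of using a random subset of the neighbourhood in each pass.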

Finally, the node embeddings are fed into a stack of hidden layers and an output layer. During training, the neural network adjusts the GraphSAGE layers and the remaining model parameters together, so that the embeddings it produces are the ones best suited to trait prediction. As with the MLP, the output layer of the neural network contains the predicted trait of each individual.


Figure 5: An end-to-end graph neural network with an input layer containing the nodes and the edges (adjacency matrix), two GraphSAGE layers, two fully connected layers and an output layer.



HinSAGE for trait prediction on heterogeneous graph

The GraphSAGE algorithm only works with homogeneous graphs, and thus does not make any distinction between node types and edge types when fusing information from nodes and their “neighbours” in the graph. However, such a distinction is desirable for the pedigree and environmental condition relationships shown in Figure 4, as those are semantically different relationships, and the corresponding node neighbourhoods are also different.

HinSAGE (Heterogeneous GraphSAGE) [2] is an extension of the GraphSAGE algorithm that allows us to leverage the heterogeneity of nodes and edges in the graph. HinSAGE follows a neighbourhood aggregation strategy where neighbours are selected and fused together by edge type. As a result, instead of fusing the relatives together with the environment-dependent replicas, HinSAGE first fuses features from the relatives, then from the replicas (or vice versa), and only at the end fuses the results with the features of the individuals themselves.
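
The per-edge-type strategy can be illustrated with a simplified sketch: average the neighbours separately for each edge type, then join the results with the node's own features. This is not the library's implementation. The learned per-type weights and the non-linearity are omitted, the exact way HinSAGE combines the per-type results may differ, and all names and values are invented.

```python
# Made-up features for replica nodes keyed by (individual, environment).
features = {
    ("purple_straw", "long_day"): [1.0, 0.0],
    ("purple_straw", "short_day"): [1.0, 1.0],
    ("relative_1", "long_day"): [0.0, 2.0],
    ("relative_2", "long_day"): [2.0, 4.0],
}
neighbours_by_type = {
    ("purple_straw", "long_day"): {
        "pedigree": [("relative_1", "long_day"), ("relative_2", "long_day")],
        "condition": [("purple_straw", "short_day")],
    }
}


def mean_vector(vectors):
    """Element-wise mean of a non-empty list of equal-length vectors."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]


def aggregate_by_edge_type(node):
    """Average neighbours per edge type, then join with the node's own features."""
    per_type = {
        edge_type: mean_vector([features[n] for n in nodes])
        for edge_type, nodes in neighbours_by_type[node].items()
    }
    combined = list(features[node])
    for edge_type in sorted(per_type):  # fixed order keeps the layout stable
        combined += per_type[edge_type]
    return per_type, combined


per_type, combined = aggregate_by_edge_type(("purple_straw", "long_day"))
print(per_type)  # {'pedigree': [1.0, 3.0], 'condition': [1.0, 1.0]}
print(combined)  # [1.0, 0.0, 1.0, 1.0, 1.0, 3.0]
```

Because relatives and replicas are averaged separately, a node with many relatives and a single replica does not have the replica's signal diluted by the relatives.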

Similar to the GraphSAGE neural network, our graph neural network’s architecture consists of an input layer, one or more HinSAGE layers, one or more fully connected layers and an output layer. The input layer holds the graph with individuals as nodes, each node having SNPs and environmental features. The pedigree and environmental conditions are represented with edges of different types.


Figure 6: An end-to-end graph neural network with an input layer, two HinSAGE layers, two fully connected layers and an output layer. The input layer shows the adjacency matrix with two edge types, pedigree (blue cells) and environmental condition (yellow cells).



A new potential

Graph machine learning opens up new potential in the landscape of genomic prediction. Along with the flexibility and scalability that deep learning offers, graph machine learning lets us exploit the relationship information available in the data for our prediction task.

Despite its advantages, graph machine learning faces challenges similar to those of deep learning (such as tuning the architecture and hyperparameters for best performance) and requires a large enough dataset for training. Moreover, the graph representation of genomic data needs further exploration.

Our work to apply graph machine learning to genomic prediction is a work in progress. Nevertheless, graph machine learning is a promising tool which deserves its place in the genomic prediction toolkit.

StellarGraph is an open source Python library that delivers state-of-the-art graph machine learning algorithms on TensorFlow and Keras. To get started, run pip install stellargraph, and follow one of the GraphSAGE or HinSAGE demos.

Thanks to Anna Leontjeva for her big contribution to this project, and Yuriy Tyshetskiy and Leda Kalleske for reviewing the blog post.

This work is supported by CSIRO’s Data61, Australia’s leading digital research network, and by the Science and Industry Endowment Fund.




  1. Inductive Representation Learning on Large Graphs. W.L. Hamilton, R. Ying, and J. Leskovec. Neural Information Processing Systems (NIPS), 2017
  2. Heterogeneous GraphSAGE (HinSAGE): Data61’s generalisation of GraphSAGE. StellarGraph Release v0.10.0, 2020

Bio: Thanh Nguyen Mueller is a senior software engineer at CSIRO’s Data61, Australia’s leading digital research network.

Original. Reposted with permission.