import torch
from torch import nn


class HaplotypeTokenizer(nn.Module):
    """Embed each haplotype (one row of a (C, H, W) image) as a single token."""

    def __init__(
        self,
        input_channels: int = 1,
        input_width: int = 36,
        hidden_size: int = 128,
    ):
        super().__init__()
        self.input_dim = input_channels * input_width
        self.hidden_size = hidden_size
        # simple linear projection
        self.proj = nn.Linear(self.input_dim, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # shape should be (B, C, H, W) where H is the number
        # of haplotypes and W is the number of SNPs
        B, C, H, W = x.shape
        # permute to (B, H, C, W)
        x = x.permute(0, 2, 1, 3)
        # then, flatten each "patch" of C * W such that
        # each patch is 1D and size (C * W).
        x = x.reshape(B, H, -1)
        # embed "patches" of size (C * W, effectively a 1d
        # array equivalent to the number of SNPs)
        tokens = self.proj(x)
        # output shape: (B, H, hidden_size), one token per haplotype
        return tokens
Tiny transformers for population genetic inference
The self-attention mechanism works quite well for recombination rate classification
Introduction
Representing genetic variation with PNGs
Take a look at the two images below. Each image has shape \((H = 64, W = 36)\): each row represents a human haplotype, and each column represents a single-nucleotide polymorphism (SNP). The “pixel” at row \(i\) and column \(j\) takes a value of \(0\) (meaning haplotype \(i\) carries the ancestral allele at site \(j\)) or \(1\) (meaning it carries the derived allele).
Both images were simulated using a backwards-in-time simulation engine called msprime, and each image represents the genetic variation present in a collection of \(H = 64\) haplotypes sampled at \(W = 36\) consecutive SNPs. The genomes in the first image were simulated from a demographic model in the stdpopsim catalog[1], assuming a recombination rate \(\rho = 10^{-9}\). The genomes in the second image were simulated from the same demographic model, but assuming \(\rho = 10^{-8}\).
Classifying images of genetic variation with machine learning
By eye, we might be able to tell a difference between the two images, but what if we wanted to classify tens or hundreds of thousands of them? This example is totally contrived, but it represents an important kind of inference challenge in population genetics: distinguishing between the patterns of genetic variation observed in different regions of the genome. For example, we might want to discriminate between regions of the genome undergoing positive or negative selection and those evolving neutrally or near-neutrally. Although numerous statistical methods exist for doing this kind of inference, machine learning methods may be particularly well-suited to the task.
Machine learning approaches, in particular convolutional neural networks (CNNs), have proven “unreasonably effective”[2] for population genetic inference. CNNs can be trained to detect population genetic phenomena (e.g., selection[3], adaptive admixture[4], incomplete lineage sorting[5]) and to estimate population genetic parameters (e.g., recombination rate) by training on large amounts of labeled data. These training data often comprise 1- or 2-channel “images” of haplotypes in genomic windows of a defined SNP length (see Figure 1), much like the ones shown above.

Images of human haplotypes aren’t like images of cats
Exchangeability and permutation-invariance
Unlike images of cats, dogs, or handwritten digits, these images of human genetic variation are “row-invariant.” In other words, the information in the image is exactly the same regardless of how you permute the haplotypes. If you were to shuffle the rows of an image of a cat, on the other hand, that image would look dramatically different, and the spatial correlations between pixels would be totally broken.
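Row invariance can be checked directly. In this NumPy sketch on a random toy image (not simulated data), shuffling the haplotypes leaves the per-SNP derived-allele counts unchanged, and sorting the rows into a canonical order recovers the same matrix, so the shuffle neither adds nor removes information. The helper `canonical` is a hypothetical name used only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.integers(0, 2, size=(6, 5))   # toy image: 6 haplotypes x 5 SNPs, 0/1 alleles
perm = rng.permutation(image.shape[0])
shuffled = image[perm]                    # shuffle haplotypes (rows); SNP order unchanged


def canonical(img):
    # Sort rows lexicographically (column 0 first) so every row permutation
    # of the same haplotypes maps to one representative matrix.
    return img[np.lexsort(img.T[::-1])]


print(np.array_equal(image.sum(axis=0), shuffled.sum(axis=0)))  # True: per-SNP counts kept
print(np.array_equal(canonical(image), canonical(shuffled)))    # True: same set of rows
```

Shuffling the pixels of a cat photo has no analogous canonical form that preserves its meaning; for haplotype images, any function of the sorted rows is automatically permutation-invariant.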
To enable permutation-invariant population genetic inference on images of human haplotypes, Chan et al. developed the defiNETti architecture (details in Note 1).
Chan et al. demonstrate that their approach is highly effective for a number of population genetic inference tasks, and that it outperforms established statistical methods on tasks such as recombination hotspot detection. Because defiNETti applies a permutation-invariant aggregation operation before its fully-connected layers, it is also (mostly) agnostic to sample size: even when trained on images of height \(H = 512\), it is largely insensitive to image size at test time, working quite well on images with as few as 64 haplotypes. In the population genetics setting, permutation invariance and sample-size indifference are extremely attractive properties for a machine learning method. For example, we might want to train a model using data from one cohort (say, a subpopulation of the 1000 Genomes Project) and test it on haplotypes sampled from another cohort with fewer or more samples.
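The following NumPy sketch illustrates why a max over the haplotype axis gives both properties. It uses a random linear map plus ReLU, applied to each row independently, as a stand-in for defiNETti's per-haplotype feature extractor (an assumption for illustration, not the published architecture): the aggregated vector does not change when rows are shuffled, and its length does not depend on \(H\).

```python
import numpy as np

rng = np.random.default_rng(1)
W, D = 36, 16                      # SNPs per image; features per haplotype (D is arbitrary)
proj = rng.normal(size=(W, D))     # stand-in for a learned per-haplotype feature extractor


def embed_rows(image):
    # Apply the same map + ReLU to each haplotype (row) independently: (H, W) -> (H, D).
    return np.maximum(image @ proj, 0.0)


def aggregate(image):
    # Collapse the haplotype axis with max, a symmetric function: (H, W) -> (D,).
    return embed_rows(image).max(axis=0)


image = rng.integers(0, 2, size=(200, W)).astype(float)
shuffled = image[rng.permutation(200)]

print(np.allclose(aggregate(image), aggregate(shuffled)))   # True: order does not matter
print(aggregate(image[:64]).shape)                          # (16,): same size for H = 64
```

Any symmetric reduction (max, mean, sum) would give invariance; the fixed output size is what lets the fully-connected layers that follow accept images with any number of haplotypes.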
Transformers as permutation-invariant architectures
Although the defiNETti approach demonstrated that permutation-invariant CNNs are powerful, lightweight tools for population genetic inference, other architectures may be even better. For example, the transformer architecture, and the self-attention mechanism in particular, is a natural choice for embedding images of human genetic variation, because (absent positional embeddings) self-attention operates on an unordered set of inputs of arbitrary size. To my knowledge, though, transformers remain relatively unexplored in population genetics.
In the parlance of the vision transformer (ViT)[6], we can treat input images of genetic variation as collections of “haplotype patches.” Each patch \(P\) has shape \((1, W)\), where \(W\) is the number of SNPs, so \(P \in \{0, 1\}^{1 \times W}\). We first embed each patch in a new, \(d\)-dimensional feature space using a single fully-connected layer (Figure 2).
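Written out, the patch-embedding step applies one affine map to every haplotype. This sketch of the notation follows the ViT row-vector convention but omits ViT's positional-embedding term \(\mathbf{E}_{pos}\), which would tie each token to a row index; \(P_i\) denotes the patch for haplotype \(i\):

\[
\mathbf{z}_i = P_i \mathbf{E} + \mathbf{b}, \qquad i = 1, \dots, H, \qquad \mathbf{E} \in \mathbb{R}^{W \times d}, \; \mathbf{b} \in \mathbb{R}^{1 \times d}
\]
\[
\mathbf{Z} = \begin{bmatrix} \mathbf{z}_1 \\ \vdots \\ \mathbf{z}_H \end{bmatrix} \in \mathbb{R}^{H \times d}
\]

With \(C\) channels, each patch is flattened to length \(C \cdot W\) and \(\mathbf{E} \in \mathbb{R}^{(C \cdot W) \times d}\). For \(W = 36\), \(C = 1\), and \(d = 128\), the layer has \(36 \times 128 + 128 = 4{,}736\) parameters, and the same \(\mathbf{E}\) and \(\mathbf{b}\) are shared by every haplotype.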

I won’t spend much time on the details here, but the basic idea of self-attention is that the embedding of each patch is compared to the embedding of every other patch, and the resulting pairwise weights determine how much each patch’s updated embedding draws on every other patch. Because these comparisons are computed through learned projections, training can tune them to capture the important relationships between patches. The transformer outputs an updated set of haplotype patch embeddings that should reflect these inter-haplotype relationships (Figure 3).
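A minimal single-head version of self-attention, written in NumPy with random (untrained) projection matrices, makes the pairwise comparison concrete. It also shows a property the transformer relies on here: without positional embeddings, shuffling the input patches shuffles the output embeddings in exactly the same way (permutation equivariance). The function name `self_attention` is illustrative.

```python
import numpy as np


def self_attention(X, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention over the rows of X, shape (H, d)."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[1])        # (H, H): every patch vs. every patch
    scores -= scores.max(axis=1, keepdims=True)   # subtract row max for numerical stability
    A = np.exp(scores)
    A /= A.sum(axis=1, keepdims=True)             # row-wise softmax -> attention weights
    return A @ V                                  # updated patch embeddings, (H, d)


rng = np.random.default_rng(2)
H, d = 10, 8
X = rng.normal(size=(H, d))                       # H haplotype-patch embeddings
Wq, Wk, Wv = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))

perm = rng.permutation(H)
out = self_attention(X, Wq, Wk, Wv)
out_perm = self_attention(X[perm], Wq, Wk, Wv)

# Shuffling the input patches shuffles the outputs in exactly the same way.
print(np.allclose(out[perm], out_perm))           # True
```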

Following ViT, we can also prepend a learnable [CLS] classification token to the sequence of patch embeddings when training the transformer and use its output embedding for classification tasks.
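Without positional embeddings, self-attention has no notion of patch order: reordering the patch embeddings simply reorders the outputs, so the [CLS] output embedding is invariant to the order of the haplotypes. We could introduce learned positional or rotary embeddings into the model if their order did matter. Self-attention can also be applied to images with any number of haplotypes, assuming our CPU/GPU hardware can handle the sequence length (the cost of attention grows quadratically with the number of haplotypes).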
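Both claims can be checked numerically. This NumPy sketch prepends a random [CLS] embedding to 64 random patch embeddings, runs one untrained attention head, and reads off the [CLS] output; `pos` is a hypothetical positional-embedding table used only to show what changes when order information is added.

```python
import numpy as np


def attend(X, Wq, Wk, Wv):
    # Single-head scaled dot-product self-attention over the rows of X.
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    s = Q @ K.T / np.sqrt(K.shape[1])
    A = np.exp(s - s.max(axis=1, keepdims=True))
    return (A / A.sum(axis=1, keepdims=True)) @ V


def cls_output(patches, cls, weights, pos=None):
    # Optionally add a positional embedding to each patch by row index,
    # prepend the [CLS] embedding, and return the [CLS] row of the output.
    if pos is not None:
        patches = patches + pos[: len(patches)]
    X = np.vstack([cls, patches])
    return attend(X, *weights)[0]


rng = np.random.default_rng(3)
d = 8
weights = [rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3)]
cls = rng.normal(size=(1, d))
pos = rng.normal(size=(256, d))          # hypothetical learned positional table
patches = rng.normal(size=(64, d))
perm = rng.permutation(64)

# No positional embeddings: the [CLS] output ignores haplotype order...
print(np.allclose(cls_output(patches, cls, weights), cls_output(patches[perm], cls, weights)))
# ...and is defined for any number of haplotypes (here 32 instead of 64).
print(cls_output(patches[:32], cls, weights).shape)
# With positional embeddings, reordering the haplotypes changes the output.
print(np.allclose(cls_output(patches, cls, weights, pos),
                  cls_output(patches[perm], cls, weights, pos)))
```

This is why fine-tuning an off-the-shelf pre-trained transformer would require removing its positional embeddings to keep permutation invariance.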
So, how well do transformers work for population genetic inference?
Materials and Methods
Let’s compare a simple transformer model to a simple implementation of the defiNETti architecture. I’ve included PyTorch code for replicating the two models below.
Architectural details
CNN model architectures were adapted from Chan et al. (defiNETti) and Wang et al. (2021). Convolutional operations used a kernel of shape (1, 5), a stride of (1, 1), and no zero-padding. All convolutional operations were followed by ReLU activations. In the Wang et al. architecture, convolutional operations were further followed by MaxPool with a stride and kernel of (1, 2). The first convolutional layer always outputs a feature map with 32 channels, and the number of channels doubled with each subsequent convolutional layer. I used max as the permutation-invariant function (to collapse along the height, i.e., haplotype, dimension) in all cases; empirically, I found that max outperformed mean as an aggregation function. After applying the permutation-invariant function, feature maps were flattened and passed to two fully-connected layers of 128 dimensions each, with ReLU activations after each.
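A minimal PyTorch sketch of the two CNN variants follows. The class and argument names are hypothetical, and it assumes a final `Linear(128, n_classes)` output layer that the text does not spell out. Under those assumptions, the trainable parameter counts (256,770 and 76,546 for two classes; 256,899 and 76,675 for three) match the table that follows.

```python
import torch
from torch import nn


class PermInvariantCNN(nn.Module):
    """Sketch of a defiNETti-style CNN; max_pool=True gives the Wang et al.-style variant."""

    def __init__(self, width=36, n_classes=2, n_conv=2, max_pool=False, fc_dim=128):
        super().__init__()
        layers, in_ch, out_ch, w = [], 1, 32, width
        for _ in range(n_conv):
            # A (1, 5) kernel only ever looks within a single haplotype (row).
            layers += [nn.Conv2d(in_ch, out_ch, kernel_size=(1, 5), stride=(1, 1)), nn.ReLU()]
            w -= 4                                  # no zero-padding: width shrinks by 4
            if max_pool:
                layers.append(nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2)))
                w //= 2                             # pooling halves the width (rounding down)
            in_ch, out_ch = out_ch, out_ch * 2      # channels double at each layer
        self.conv = nn.Sequential(*layers)
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_ch * w, fc_dim), nn.ReLU(),
            nn.Linear(fc_dim, fc_dim), nn.ReLU(),
            nn.Linear(fc_dim, n_classes),           # assumed output layer (logits)
        )

    def forward(self, x):                           # x: (B, 1, H, W)
        x = self.conv(x)                            # (B, C', H, W')
        x = x.max(dim=2).values                     # max over haplotypes -> (B, C', W')
        return self.head(x)


def n_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


x = torch.randint(0, 2, (4, 1, 200, 36)).float()    # 4 random 0/1 haplotype images
definetti_like = PermInvariantCNN()
wang_like = PermInvariantCNN(max_pool=True)
print(definetti_like(x).shape)                       # torch.Size([4, 2])
print(n_params(definetti_like), n_params(wang_like)) # 256770 76546
```

Most of the defiNETti-style count sits in the first fully-connected layer (64 channels × 28 positions × 128 units); the max-pooling in the Wang et al. variant shrinks that flattened width from 28 to 6, which accounts for its much smaller size.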
Our transformer model architecture follows the same general structure as in the Vision Transformer (ViT) paper; see Equations 1-4 in the ViT preprint (or the code samples below) for details. Briefly, we embedded haplotype patches using a single fully-connected layer with 128 dimensions and passed the resulting tensor of patch embeddings to a single multi-headed self-attention block with 8 heads, LayerNorm, a multi-layer perceptron (MLP), and residual connections.
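A comparable PyTorch sketch of the transformer, assuming a pre-norm ViT-style block (LayerNorm before attention and before the MLP, GELU inside the MLP, as in ViT), a zero-initialized learnable [CLS] token, no positional embeddings, and a final LayerNorm before the classification head. The class name `HaplotypeTransformer` is hypothetical.

```python
import torch
from torch import nn


class HaplotypeTransformer(nn.Module):
    """ViT-style classifier sketch: one token per haplotype, a [CLS] token, no positions."""

    def __init__(self, channels=1, width=36, hidden=128, mlp=256, heads=8, depth=1, n_classes=2):
        super().__init__()
        self.embed = nn.Linear(channels * width, hidden)     # patch embedding
        self.cls = nn.Parameter(torch.zeros(1, 1, hidden))   # learnable [CLS] token
        self.blocks = nn.ModuleList(
            nn.ModuleDict({
                "norm1": nn.LayerNorm(hidden),
                "attn": nn.MultiheadAttention(hidden, heads, batch_first=True),
                "norm2": nn.LayerNorm(hidden),
                "mlp": nn.Sequential(nn.Linear(hidden, mlp), nn.GELU(), nn.Linear(mlp, hidden)),
            })
            for _ in range(depth)
        )
        self.norm = nn.LayerNorm(hidden)
        self.head = nn.Linear(hidden, n_classes)

    def forward(self, x):                                     # x: (B, C, H, W)
        B, C, H, W = x.shape
        z = self.embed(x.permute(0, 2, 1, 3).reshape(B, H, C * W))  # (B, H, hidden)
        z = torch.cat([self.cls.expand(B, -1, -1), z], dim=1)      # (B, 1 + H, hidden)
        for blk in self.blocks:
            h = blk["norm1"](z)
            z = z + blk["attn"](h, h, h, need_weights=False)[0]    # attention + residual
            z = z + blk["mlp"](blk["norm2"](z))                    # MLP + residual
        return self.head(self.norm(z[:, 0]))                        # logits from [CLS]


model = HaplotypeTransformer(n_classes=2)
x = torch.randint(0, 2, (4, 1, 200, 36)).float()
print(model(x).shape)  # torch.Size([4, 2])
```

This layout has 137,858 trainable parameters for the two-class task, somewhat more than the 137,603 reported in the table that follows, so it should be read as an approximation of the model used here rather than a replica.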
model type | kernel size | stride | max-pool | convolutional layers | fully-connected dimension | trainable params (hotspot task) | trainable params (rate task) |
---|---|---|---|---|---|---|---|
CNN (defiNETti) | 5 | 1 | False | 2 | 128 | 256,770 | 256,899 |
CNN (Wang et al.) | 5 | 1 | True | 2 | 128 | 76,546 | 76,675 |
model type | depth | hidden size | MLP size | heads | trainable params (hotspot task) | trainable params (rate task) |
---|---|---|---|---|---|---|
Transformer | 1 | 128 | 256 | 8 | 137,603 | 137,732 |
All models were trained with the Adam optimizer. As in Chan et al., we used the following learning rate schedule: \(10^{-3} \times 0.9^{\frac{m}{I}}\), where \(m\) is the current minibatch and \(I\) is the total number of training iterations.
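The schedule maps directly onto PyTorch's `LambdaLR`, which sets the learning rate to the base rate times a factor computed from the step count. A minimal sketch, assuming the scheduler is stepped once per minibatch and using a placeholder model; `lr_at` is a hypothetical helper that evaluates the formula directly.

```python
import torch


def lr_at(m, total_iters, base_lr=1e-3):
    """Learning rate at minibatch m under base_lr * 0.9 ** (m / total_iters)."""
    return base_lr * 0.9 ** (m / total_iters)


total_iters = 1_000
model = torch.nn.Linear(36, 2)  # placeholder; any nn.Module works here
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
# LambdaLR sets lr = base_lr * lr_lambda(step) after each scheduler.step().
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda m: 0.9 ** (m / total_iters))

for m in range(total_iters):
    # ... simulate a minibatch, compute the loss, call loss.backward() ...
    opt.step()
    sched.step()

print(opt.param_groups[0]["lr"])
```

Because \(m\) only runs from 0 to \(I\), the exponent \(m / I\) never exceeds 1, so over a full run the learning rate decays gently from \(10^{-3}\) to \(9 \times 10^{-4}\).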
Training datasets and classification tasks
I compared the performance of CNNs and transformers on two simple classification tasks.
Recombination hotspot detection
The first task was adapted from Chan et al. Training data belonged to one of two classes: regions that either contained or did not contain a recombination hotspot. All regions were simulated using a simple msprime demographic model and an msprime recombination RateMap object as follows:
In all rate maps, \(\rho_{bg} \sim \mathcal{U}\{1 \times 10^{-8}, 1.5 \times 10^{-8}\}\).
Regions with a hotspot:

left | right | mid | span | rate |
---|---|---|---|---|
0 | 11,500 | 5,750 | 11,500 | \(\rho_{bg}\) |
11,500 | 13,500 | 12,500 | 2,000 | \(\rho_{bg} \times \mathcal{U}\{10, 100\}\) |
13,500 | 25,000 | 19,250 | 11,500 | \(\rho_{bg}\) |
Regions without a hotspot:

left | right | mid | span | rate |
---|---|---|---|---|
0 | 25,000 | 12,500 | 25,000 | \(\rho_{bg}\) |
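The two rate maps can be generated programmatically. This NumPy-only helper (its name is hypothetical) returns breakpoint and rate arrays in the layout of the tables above; it assumes both \(\mathcal{U}\{\cdot\}\) draws are continuous uniform draws, which the text does not state explicitly.

```python
import numpy as np


def hotspot_rate_map(rng, hotspot=True, length=25_000):
    """Return (position, rate) arrays for one simulated region.

    Assumption: rho_bg ~ Uniform(1e-8, 1.5e-8) and the hotspot multiplier
    ~ Uniform(10, 100), both continuous.
    """
    rho_bg = rng.uniform(1e-8, 1.5e-8)
    if not hotspot:
        return np.array([0, length]), np.array([rho_bg])
    multiplier = rng.uniform(10, 100)
    position = np.array([0, 11_500, 13_500, length])      # interval breakpoints (bp)
    rate = np.array([rho_bg, rho_bg * multiplier, rho_bg])  # one rate per interval
    return position, rate


rng = np.random.default_rng(0)
position, rate = hotspot_rate_map(rng, hotspot=True)
print(position, rate)
```

The arrays can then be passed to msprime as `msprime.RateMap(position=position, rate=rate)` and supplied as the `recombination_rate` argument of `msprime.sim_ancestry`.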
Recombination rate classification
In the second task, training data belonged to one of three classes: regions with one of three different recombination rates (\(10^{-7}\), \(10^{-8}\), or \(10^{-9}\)). All regions were simulated using a CEU population from the OutOfAfrica_3G09 model[7] in the stdpopsim catalog.
Simulation “on-the-fly”
Rather than train these models on a fixed dataset of \(N\) images, I simulated training examples “on the fly” as described in Chan et al. In each training iteration, I simulated a fresh minibatch of 256 haplotype images of shape \((C, H, W) = (1, 200, 36)\), with approximately equal contributions from each class, passed the minibatch through the model, and updated the model weights using cross-entropy loss. Because the models never saw the same image twice during training, accuracy on each fresh minibatch (computed before the model is updated on it) already measures performance on unseen data, so there was no need to evaluate them on a separate “held out” validation or test set. Models were trained for 1,000 iterations (one minibatch per iteration); therefore, each model saw a total of 256,000 unique images during training. I used random seeds to ensure that each model saw the same 256,000 images during training.
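A minimal sketch of an on-the-fly training loop in PyTorch. The simulator is a hypothetical stand-in (`simulate_minibatch` draws random 0/1 images whose density depends on the label), not the msprime/stdpopsim simulations used in this post, and the learning-rate schedule is omitted for brevity; swapping in a real simulator only changes that one function. Seeding a dedicated `torch.Generator` is one way to make every model see the same sequence of images.

```python
import torch
from torch import nn


def simulate_minibatch(gen, batch_size=256, n_classes=2, H=200, W=36):
    """Hypothetical stand-in for an msprime-based simulator (not the post's code)."""
    labels = torch.arange(batch_size) % n_classes              # ~equal class sizes
    density = 0.2 + 0.3 * labels.float() / max(n_classes - 1, 1)
    probs = density.view(-1, 1, 1, 1).repeat(1, 1, H, W)       # (B, 1, H, W)
    images = torch.bernoulli(probs, generator=gen)             # random 0/1 "haplotypes"
    return images, labels


def train_on_the_fly(model, n_iters=1_000, seed=0, lr=1e-3):
    gen = torch.Generator().manual_seed(seed)   # same seed -> same stream of images
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    accuracies = []
    for _ in range(n_iters):
        x, y = simulate_minibatch(gen)          # a fresh minibatch every iteration
        logits = model(x)
        # Recorded before the update, i.e. on images the model has never seen.
        accuracies.append((logits.argmax(dim=1) == y).float().mean().item())
        loss = loss_fn(logits, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return accuracies


torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(200 * 36, 2))   # placeholder classifier
accs = train_on_the_fly(model, n_iters=10)
print(accs[-1])
```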
Results
Transformer models outperform CNNs for recombination rate classification
Overall, the transformer exhibits higher accuracy than the CNNs on the recombination rate classification task (Figure 4).
Transformer models slightly outperform CNNs for recombination hotspot detection
Transformer and CNN performance is more similar on the recombination hotspot detection task (Figure 5), though the transformer still achieves higher accuracy.
Future work
This was a very simple experiment, and involved very little hyper-parameter or model tuning. At the very least, it suggests that simple transformers with self-attention are useful architectures for population genetics inference. There are many ways to investigate their utility further.
- test models on more complex and realistic classification tasks
- tweak both CNNs and transformers to find optimal hyper-parameters
  - perhaps we could engineer CNNs to have comparable performance?
- compare performance on phased vs. unphased data
- fine-tune pre-trained transformer models (e.g., from huggingface) instead of training from scratch
  - we’ll need to modify these pre-trained models to ignore positional embeddings to ensure permutation-invariance
- examine the “representations” learned by each model to see if the transformer representation space is generally more or less useful for diverse classification tasks
- ensure that both models are robust to the number of haplotypes in images at test time
  - randomly downsample each batch of haplotype images at test time by taking a batch of shape \((B, C, H, W)\) and randomly subsampling \(H' \sim \mathcal{U}\{32, H\}\) haplotypes, so that the new batch has shape \((B, C, H', W)\)
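The downsampling idea in the last item could be implemented roughly as follows. This NumPy sketch assumes the same random subset of haplotypes is kept for every image in a batch, and the function name is hypothetical.

```python
import numpy as np


def downsample_haplotypes(batch, rng, min_h=32):
    """Keep a random H' ~ U{min_h, H} subset of haplotypes from a (B, C, H, W) batch."""
    H = batch.shape[2]
    h_new = rng.integers(min_h, H, endpoint=True)       # inclusive of both bounds
    keep = rng.choice(H, size=h_new, replace=False)      # distinct haplotype indices
    return batch[:, :, keep, :]                          # (B, C, H', W)


rng = np.random.default_rng(0)
batch = rng.integers(0, 2, size=(8, 1, 200, 36))
small = downsample_haplotypes(batch, rng)
print(small.shape)
```

Sharing one subset per batch keeps the result a rectangular \((B, C, H', W)\) array; drawing a separate subset per image would also work, as long as \(H'\) is shared within the batch. The order of the kept rows does not matter for a permutation-invariant model.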
Footnotes
[1] Precisely, a CEU population from the OutOfAfrica_3G09 demography.
[2] For an overview of the utility of CNNs in popgen, see Flagel et al. 2019.