Vision Transformers in PyTorch

When did convolutions stop being the trend?

This article was originally published by Ta-Ying Cheng on Towards Data Science.


Convolutional neural networks (CNNs) have been the predominant backbone for almost all networks used in computer vision and image-related tasks, thanks to the advantages they have in 2D neighbourhood awareness and translation equivariance compared to traditional multi-layer perceptrons (MLPs). However, with the recent shift in the language processing domain from recurrent neural networks to transformers, one may wonder about the capability of transformers in the image domain.

Luckily, a recent paper at ICLR 2021* has explored exactly this capability and provides a new state-of-the-art architecture, the vision transformer, which stands in stark contrast to convolution-based models.

This article dives into the concept of a transformer, particularly the vision transformer and its comparison to CNNs, and discusses how to incorporate and train transformers in PyTorch despite the difficulty of training these architectures.

*Side Note: The International Conference on Learning Representations (ICLR) is a top-tier conference focusing on deep learning and representations.

What are the benefits of CNNs?

Why are CNNs so popular in the computer vision domain? The answer lies in the inherent nature of convolutions. The kernels, or convolutional windows, aggregate features from nearby pixels, allowing neighbouring features to be considered together during learning. In addition, as we slide the kernels throughout the image, features appearing anywhere in the image can be detected and utilised for classification; we refer to this as translation equivariance. These characteristics allow CNNs to extract features regardless of where they lie in the image, and hence drove significant improvements in image classification tasks over the past years.
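
To make this concrete, here is a minimal sketch (not from the original article) of translation equivariance in PyTorch: shifting a convolution's input shifts its output by the same amount, away from the image borders.

import torch
import torch.nn as nn

# A toy demonstration of translation equivariance.
conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)

x = torch.randn(1, 1, 8, 8)
shifted = torch.roll(x, shifts=2, dims=-1)  # shift the input 2 pixels to the right

out = conv(x)
out_shifted = conv(shifted)

# Away from the borders (where padding and wrap-around interfere),
# the shifted input produces an identically shifted output.
print(torch.allclose(torch.roll(out, shifts=2, dims=-1)[..., 3:-3],
                     out_shifted[..., 3:-3], atol=1e-5))  # True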

But if CNNs do all of this, what do transformers do?

Transformers in NLP

Figure 1. Scaled Dot-Product Attention mechanism and the Multi-Head Attention mechanism in the original transformer. Source: https://arxiv.org/abs/1706.03762.

Transformers were first proposed in the area of natural language processing in the paper “Attention Is All You Need”. The traditional approaches in this area (e.g., RNNs and LSTMs) take into account information from nearby words within a phrase when computing predictions. However, as the current state (input) requires all the previous inputs to be computed first, the process is sequential and thus rather slow.

Transformers instead utilise an attention scheme, which in some sense computes the correlation of vectorised words with one another, to produce the final prediction. As the correlation of one word with the others is independent of the correlations of the other words, these computations can happen simultaneously, which makes deep networks far more feasible in terms of computation. By considering all the words and their pairwise correlations, the results are significantly better than those of traditional recurrent approaches.
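
In code, the scaled dot-product attention of Figure 1 is only a few lines. The sketch below (with illustrative shapes, not from the original article) computes the pairwise scores between five token embeddings and uses them to mix the value vectors:

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v):
    # Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5  # pairwise "correlations"
    weights = F.softmax(scores, dim=-1)            # normalise per query
    return weights @ v

x = torch.randn(1, 5, 64)                    # five 64-dimensional tokens
out = scaled_dot_product_attention(x, x, x)  # self-attention: q = k = v
print(out.shape)                             # torch.Size([1, 5, 64])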

Moreover, transformers incorporate multi-head attention, which runs the attention mechanism multiple times in parallel and concatenates the resulting vectors into the final output.
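
PyTorch ships this as a ready-made module; a quick sketch with illustrative sizes:

import torch
import torch.nn as nn

# 8 attention heads run in parallel on 64-dimensional tokens;
# their outputs are concatenated and projected back to embed_dim.
mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)

x = torch.randn(1, 5, 64)         # a batch with a sequence of 5 tokens
out, attn_weights = mha(x, x, x)  # self-attention
print(out.shape)                  # torch.Size([1, 5, 64])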

Shifting to the Vision World

Figure 2. Vision Transformer Pipeline. Images are divided into patches and flattened to mimic a sequence. Source: https://arxiv.org/abs/2010.11929.

With the success transformers brought to language processing, a question arises: how can we shift the technique from language to images? The vision transformer paper provides the most straightforward answer: divide each image into patches, convert those patches into embeddings, and feed them to the transformer as a sequence, just like word embeddings in language processing, so that attention can be computed between the patches.
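
A minimal sketch of this patchify step (the 224-pixel resolution and 16-pixel patch size are illustrative, matching the ViT-Base defaults):

import torch
import torch.nn as nn

img = torch.randn(1, 3, 224, 224)
p = 16  # patch size

# Cut the image into non-overlapping 16x16 patches and flatten each one,
# giving a "sequence" of (224/16)^2 = 196 tokens of dimension 3*16*16 = 768.
patches = img.unfold(2, p, p).unfold(3, p, p)  # (1, 3, 14, 14, 16, 16)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, -1, 3 * p * p)
print(patches.shape)  # torch.Size([1, 196, 768])

# A learned linear projection then maps each flattened patch to the model dimension.
to_embedding = nn.Linear(3 * p * p, 768)
tokens = to_embedding(patches)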

Experimental Code

In this section we will explore well-pretrained vision transformers and test their capabilities on various datasets. It is worth noting that, through the extensive studies in the original paper, vision transformers only outperform CNNs when the pre-training dataset reaches a very large scale. Hence, training one from scratch is not advisable if your computational resources are fairly limited.

Datasets

To explore the capability and generalisation of vision transformers, we may want to test them on multiple datasets. The Graviti open dataset platform provides many famous CV datasets for free. These datasets are fast to download and can be integrated directly into your own code using the SDK provided by Graviti.

Vision Transformer in PyTorch

As mentioned previously, vision transformers are extremely hard to train due to the very large amounts of data needed to learn good feature extraction. Fortunately, many GitHub repositories now offer pre-built and pre-trained vision transformers. Our tutorial is based on the vision transformer from lucidrains.

To import their models, first install the package via pip:

pip install vit-pytorch

Make sure that the PyTorch and Torchvision libraries are also updated so that their versions align with each other.

You may then initialise a vision transformer with the following (a minimal sketch following the vit-pytorch README; the hyperparameters are illustrative):
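
import torch
from vit_pytorch import ViT

model = ViT(
    image_size = 256,    # input resolution
    patch_size = 32,     # size of each square patch
    num_classes = 1000,  # number of output classes
    dim = 1024,          # embedding dimension
    depth = 6,           # number of transformer blocks
    heads = 16,          # attention heads per block
    mlp_dim = 2048,      # hidden dimension of the feed-forward layers
    dropout = 0.1,
    emb_dropout = 0.1
)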

For inference, simply perform the following (a random tensor stands in for a real image batch here):
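
img = torch.randn(1, 3, 256, 256)  # a dummy 256x256 RGB image
preds = model(img)                 # class logits of shape (1, 1000)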

If you want to further train your vision transformer, you may refer to data-efficient training via distillation, published recently in this paper. This method of training is much more efficient than directly training a vision transformer. The code is also available in the above-mentioned vit-pytorch repository.
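
As a rough sketch of what that looks like with vit-pytorch (following its README; the ResNet-50 teacher and all hyperparameters here are illustrative):

import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)  # a pre-trained CNN acts as the teacher

student = DistillableViT(
    image_size = 256, patch_size = 32, num_classes = 1000,
    dim = 1024, depth = 6, heads = 8, mlp_dim = 2048
)

distiller = DistillWrapper(
    student = student,
    teacher = teacher,
    temperature = 3,  # softens the teacher's predictions
    alpha = 0.5       # balances the label loss and the distillation loss
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)  # combined classification + distillation loss
loss.backward()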

Results

Figure 3. Results of Vision Transformers on multiple datasets. Source: https://arxiv.org/abs/2010.11929.

If we refer back to the paper, we can see that large vision transformer models provide state-of-the-art results when pre-trained on very-large-scale datasets. Nevertheless, such pre-training requires significant compute for these models to achieve high accuracy.

Conclusion

In recent years, the computer vision community has been dedicated to adapting transformers to the needs of image-based tasks, and even 3D point cloud tasks. Recent ICCV 2021 papers, such as cloud transformers and the best paper awardee Swin Transformer, show that the attention mechanism is becoming the new trend in image tasks. So that's it: a brief overview of the trending transformer and its application in computer vision.