Spatial Transformer Network
The first model I was supposed to implement as part of my Google Summer of Code project was the Spatial Transformer Network. A Spatial Transformer Network (STN) is a learnable module that can be placed inside a Convolutional Neural Network (CNN) to increase its spatial invariance in an efficient manner. Spatial invariance refers to the invariance of the model to spatial transformations of images such as rotation, translation and scaling. Invariance is the ability of the model to recognize and identify features even when the input is transformed or slightly modified. Spatial Transformers can be inserted into CNNs to benefit many tasks; one example is image classification. Suppose the task is to classify handwritten digits, where the position, size and orientation of the digit vary significantly from sample to sample. A spatial transformer crops out, transforms and scales the region of interest in each sample, after which a CNN can perform the classification.
A Spatial Transformer Network consists of 3 main components:
(i) Localization Network: This network takes a 4D tensor representing a batch of images (Width × Height × Channels × Batch Size, following Flux conventions) as input. It is a simple neural network with a few convolutional layers followed by a few dense layers. As output it predicts the transformation parameters: these determine the angle by which the input has to be rotated, the amount of translation, and the scaling factor required to focus on the region of interest in the input feature map. A sketch of such a network is shown below.
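As a rough illustration, a localization network in Flux might look like the following sketch, assuming 28 × 28 grayscale inputs; the layer sizes here are purely illustrative and not necessarily the ones used in my implementation. (In the STN paper, the final layer is initialised so that θ starts out as the identity transform.)

```julia
using Flux

# A minimal sketch of a localization network, assuming 28 × 28
# grayscale inputs (WHCN order); layer sizes are illustrative.
localization_net = Chain(
    Conv((5, 5), 1 => 8, relu),   # 28×28 -> 24×24
    MaxPool((2, 2)),              # 24×24 -> 12×12
    Conv((5, 5), 8 => 10, relu),  # 12×12 -> 8×8
    MaxPool((2, 2)),              # 8×8  -> 4×4
    Flux.flatten,                 # 4×4×10 -> 160 features
    Dense(160, 32, relu),
    Dense(32, 6),                 # the six affine parameters θ11 … θ23
)

x = rand(Float32, 28, 28, 1, 16)  # dummy batch of 16 images
θ = localization_net(x)           # 6 × 16 matrix of parameters
```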
(ii) Sampling Grid Generator: The transformation parameters predicted by the localization network are arranged, for each image in the batch, into a 2 × 3 affine transformation matrix. An affine transformation is one that preserves points, straight lines and planes; parallel lines remain parallel after an affine transformation. Rotation, scaling and translation are all affine transformations.
The sampling grid is obtained by applying this transformation to a regular grid of target coordinates:

$$T_\theta(G_i) = \begin{pmatrix} x_i^s \\ y_i^s \end{pmatrix} = A_\theta \begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}, \qquad A_\theta = \begin{bmatrix} \theta_{11} & \theta_{12} & \theta_{13} \\ \theta_{21} & \theta_{22} & \theta_{23} \end{bmatrix}$$

Here, T is the transformation and A is the matrix representing the affine transformation. θ11, θ12, θ21 and θ22 encode the rotation (and scaling) applied to the image, while θ13 and θ23 determine the translations along the width and height of the image respectively. Applying A to every target coordinate gives us a sampling grid of transformed indices.
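A minimal sketch of this grid-generation step for a single image is given below, assuming target coordinates normalised to [-1, 1] as in the STN paper; the function name affine_grid is my own choice, not a library function.

```julia
# A minimal sketch of the grid generator for a single image, assuming
# target coordinates normalised to [-1, 1] as in the STN paper; the
# name `affine_grid` is my own, not from a library.
function affine_grid(θ::AbstractMatrix, H::Integer, W::Integer)
    xs = LinRange(-1.0f0, 1.0f0, W)   # target x coordinates
    ys = LinRange(-1.0f0, 1.0f0, H)   # target y coordinates
    grid = Array{Float32}(undef, 2, H, W)
    for (j, x) in enumerate(xs), (i, y) in enumerate(ys)
        # source coordinates (x_s, y_s) = A_θ * [x_t, y_t, 1]
        grid[:, i, j] = θ * Float32[x, y, 1]
    end
    return grid   # 2 × H × W array of sampling locations
end

# The identity transform leaves the grid unchanged:
θ_identity = Float32[1 0 0; 0 1 0]
g = affine_grid(θ_identity, 28, 28)
```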
(iii) Bilinear Interpolation on transformed indices: The indices and axes of the image have now undergone an affine transformation, so its pixels have moved around. For example, the point (1, 1) becomes (√2, 0) after a rotation of the axes by 45 degrees counter-clockwise. Since the transformed sampling points generally do not land on integer pixel positions, we find the pixel value at each transformed point by bilinear interpolation over the four nearest pixels.
To find the pixel value at a point (x, y), we take the four nearest grid points: (floor(x), floor(y)), (ceil(x), floor(y)), (floor(x), ceil(y)) and (ceil(x), ceil(y)), where floor is the greatest-integer function and ceil is the ceiling function. Linear interpolations (which can simply be thought of as applications of the section formula) are performed along both the x and y directions. This step returns the completely transformed image, with proper pixel values at the transformed indices.
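Below is a minimal sketch of bilinear sampling at a single point, assuming (x, y) is already expressed in the pixel coordinates of the input image; the helper name bilinear_sample is mine, and a real implementation would vectorise this over the whole sampling grid.

```julia
# A minimal sketch of bilinear sampling at one real-valued point
# (x, y) of a single-channel image `img`; helper name is my own.
function bilinear_sample(img::AbstractMatrix, x::Real, y::Real)
    x0, x1 = floor(Int, x), ceil(Int, x)
    y0, y1 = floor(Int, y), ceil(Int, y)
    # clamp to the image bounds so edge pixels stay valid
    x0, x1 = clamp(x0, 1, size(img, 2)), clamp(x1, 1, size(img, 2))
    y0, y1 = clamp(y0, 1, size(img, 1)), clamp(y1, 1, size(img, 1))
    # fractional offsets inside the pixel cell
    dx, dy = x - floor(x), y - floor(y)
    # interpolate along x on the top and bottom edges, then along y
    top    = (1 - dx) * img[y0, x0] + dx * img[y0, x1]
    bottom = (1 - dx) * img[y1, x0] + dx * img[y1, x1]
    return (1 - dy) * top + dy * bottom
end
```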
The code for the pure Julia implementation of the Spatial Transformer Network can be found here. I tested my spatial transformer module on a few images; below are some sample outputs from the transformation function. In each case the left image is the input and the right image is the output of the transformer module.
1) Zooming in on the region of interest.
2) Zooming in on the face and rotating by 45 degrees.
3) Translating the image along its width to center it.
It is clear from the above examples that the spatial transformer module is capable of performing any kind of affine transformation. During the implementation, I spent a lot of time understanding how reshape, permutedims and array concatenation work, since it was difficult to debug how these functions moved the pixels and indices around. Debugging the interpolation and image indexing was the most time-consuming and frustrating part of implementing the STN.
Next, I plan to train this spatial transformer module together with a CNN for handwritten digit classification on a cluttered and distorted MNIST dataset. The spatial transformer should increase the spatial invariance of the CNN, so the combined model is expected to classify well even when the digits are translated, rotated or scaled. A rough sketch of how the two parts could be chained is shown below.
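This is only a hedged sketch of the planned pipeline, assuming 28 × 28 grayscale inputs for concreteness; spatial_transformer is a stand-in for the module described above, with an identity placeholder so the sketch runs as written.

```julia
using Flux

# Placeholder for the learned warp; the real module applies the
# affine transform produced by the localization network.
spatial_transformer(x) = x

model = Chain(
    spatial_transformer,          # warp each image toward a canonical pose
    Conv((3, 3), 1 => 16, relu),  # 28×28 -> 26×26
    MaxPool((2, 2)),              # 26×26 -> 13×13
    Flux.flatten,                 # 13×13×16 = 2704 features
    Dense(2704, 10),              # ten digit classes
    softmax,
)
```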