numpy.einsum, with examples
Problem and solution
I recently ran into a problem for which I found numpy.einsum
extremely useful and elegant, provided you know exactly what you are doing when manipulating arrays.
The problem is: given an array A of shape [M, C, H, W] and an array B of shape [N, M], compute an array C of shape [N, C, H, W] such that:
\[C_{nchw} = \sum_{m=1}^M A_{mchw} B_{nm}.\]
This is simply an extended matrix multiplication. In numpy it can be implemented as:
import numpy as np
M, N, C, H, W = 2, 3, 4, 5, 6
A = np.empty((M, C, H, W))
B = np.empty((N, M))
C = (
    B[np.newaxis, np.newaxis, np.newaxis, ...]
    @ A.transpose(1, 2, 3, 0)[..., np.newaxis]
).squeeze(axis=-1).transpose(3, 0, 1, 2)
print(C.shape)  # (3, 4, 5, 6), i.e. (N, C, H, W)
Let’s explain the code:
- A.transpose(1, 2, 3, 0)[..., np.newaxis]: transpose A into shape [C, H, W, M], then append another dimension at the end.
- B[np.newaxis, np.newaxis, np.newaxis, ...]: insert three extra dimensions at the beginning.
- By now, the operands around @ have shapes [1, 1, 1, N, M] and [C, H, W, M, 1].
- The @ operator is syntactic sugar for numpy.matmul, which performs matrix multiplication. For n-dimensional arrays like the ones here, it multiplies over the last two dimensions of the operands and broadcasts along the leading axes (see the sketch after this list), so the multiplication yields an array of shape [C, H, W, N, 1].
- squeeze(axis=-1) removes the extra dimension introduced for @, and transpose(3, 0, 1, 2) rearranges the dimensions into [N, C, H, W].
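To see that broadcasting rule in isolation, here is a tiny sketch with made-up shapes (not part of the original solution):
import numpy as np

# matmul multiplies over the last two axes and broadcasts the leading ones:
x = np.empty((1, 1, 1, 3, 2))  # leading axes (1, 1, 1), matrices of shape (3, 2)
y = np.empty((4, 5, 6, 2, 1))  # leading axes (4, 5, 6), matrices of shape (2, 1)
print((x @ y).shape)           # (4, 5, 6, 3, 1)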
It seems our work is done! But look back at the code, and you’ll find that unless experienced, one can hardly tell its purpose at first sight, even though it is basically just an extended matrix multiplication.
One can easily get lost in the transpose, squeeze and np.newaxis (which can be replaced by None).
I had heard of the einops package during my study of deep learning, and knew that numpy has a built-in counterpart, numpy.einsum. As the numpy documentation puts it:
Using the Einstein summation convention, many common multi-dimensional, linear algebraic array operations can be represented in a simple fashion.
So I decided to simplify the code above with its help.
In fact, with numpy.einsum, the core procedure shrinks to one line of code:
C = np.einsum('mchw,nm->nchw', A, B)
Quite amazing, right? Plus, the code itself almost explains what it does.
Explanation
Let’s dive into the detailed usage of numpy.einsum now.
Before everything, note that numpy.einsum has an explicit and an implicit mode. We’ll work in the explicit mode, as it is more flexible. (Whatever you can do in implicit mode, you can also do in explicit mode.)
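To make the distinction concrete, here is a small sketch of my own (the arrays a and b are arbitrary):
import numpy as np

a = np.random.rand(2, 3)
b = np.random.rand(3, 4)

# Implicit mode: no '->'. Repeated labels are summed over, and the
# surviving labels are ordered alphabetically in the output.
c1 = np.einsum('ij,jk', a, b)      # matrix product, shape (2, 4)

# Explicit mode: '->' states the output labels, giving full control
# over which axes survive and in what order.
c2 = np.einsum('ij,jk->ik', a, b)  # same matrix product
c3 = np.einsum('ij,jk->ki', a, b)  # transposed result, shape (4, 2)

print(np.allclose(c1, c2), np.allclose(c2.T, c3))  # True True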
The basic syntax is einsum(subscripts, *operands), where subscripts defines the computational rule and operands lists all input arrays.
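As a warm-up, here is a sketch of a few familiar operations expressed in this syntax (my own toy examples, not part of the original problem):
import numpy as np

a = np.arange(9.0).reshape(3, 3)

print(np.einsum('ii->', a))    # trace: the repeated label i is summed over
print(np.einsum('ij->ji', a))  # transpose: output labels reordered
print(np.einsum('ij->i', a))   # row sums: j is dropped, hence summed over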
First and most importantly, we must determine mathematically what we want as the output array. More specifically, we must figure out the computation rule for each element of the output, i.e. which elements of the inputs contribute to it and how they contribute (perhaps by multiplying other elements, or by being summed over some axis).
In the problem above, the computation rule is that the [n, c, h, w] element of the output is the dot product of A[:, c, h, w] and B[n, :].
In the Einstein summation convention, indices that are repeated across input arrays but absent from the output are called dummy indices, meaning the result is summed along those axes. Here m is the dummy index, so we can now determine that subscripts is 'mchw,nm->nchw'.
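As a quick sanity check, here is a sketch that verifies the rule element-wise (using random data of my own in place of the np.empty arrays above):
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((2, 4, 5, 6))  # (M, C, H, W)
B = rng.random((3, 2))        # (N, M)

C = np.einsum('mchw,nm->nchw', A, B)
print(C.shape)  # (3, 4, 5, 6)

# C[n, c, h, w] should be the dot product of A[:, c, h, w] and B[n, :]:
n, c, h, w = 1, 2, 3, 4
print(np.isclose(C[n, c, h, w], A[:, c, h, w] @ B[n, :]))  # True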
What if we suddenly don’t want the sum any more, and simply want the element-wise products, with an output of shape [M, N, C, H, W]?
Simply use 'mchw,nm->mnchw': since m now appears in the output, it is no longer summed over.
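Here is a sketch of this variant (again with random data of my own), alongside an equivalent broadcasting expression for comparison:
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((2, 4, 5, 6))  # (M, C, H, W)
B = rng.random((3, 2))        # (N, M)

# 'm' appears on both sides of '->', so nothing is summed:
D = np.einsum('mchw,nm->mnchw', A, B)
print(D.shape)  # (2, 3, 4, 5, 6), i.e. (M, N, C, H, W)

# The equivalent broadcasting expression is noticeably noisier:
D2 = A[:, np.newaxis] * B.T[:, :, np.newaxis, np.newaxis, np.newaxis]
print(np.allclose(D, D2))  # True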
Other examples
TBD