2.3. Index and Shape Expressions

You already know that a shape can be a tuple of symbols such as (n, m) and the elements can be accessed via indexing, e.g. a[i, j]. In practice, both shapes and indices may be computed through complex expressions. We will go through several examples in this section.

import d2ltvm
import numpy as np
import tvm
from tvm import te

2.3.1. Matrix Transpose

Our first example is matrix transpose a.T, in which we access a’s elements by columns.

n = te.var('n')
m = te.var('m')
A = te.placeholder((n, m), name='a')
B = te.compute((m, n), lambda i, j: A[j, i], 'b')
s = te.create_schedule(B.op)
tvm.lower(s, [A, B], simple_mode=True)
produce b {
  for (i, 0, m) {
    for (j, 0, n) {
      b[((i*stride) + (j*stride))] = a[((j*stride) + (i*stride))]

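Note that a 2-D index such as b[i, j] is collapsed to a 1-D offset, following the C (row-major) convention. For a compact layout this offset is b[((i*n) + j)]; because the shapes here are symbolic, the lowered code prints each per-dimension multiplier as a generic stride.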
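To make the row-major flattening concrete, here is a minimal pure-Python sketch that does not depend on TVM. It assumes a compact layout with no padding, and the helper names flat_index and transpose_flat are made up for this illustration.

```python
def flat_index(i, j, num_cols):
    """Row-major offset of element [i, j] in a matrix with num_cols columns."""
    return i * num_cols + j


def transpose_flat(a_flat, n, m):
    """Transpose an n-by-m matrix stored as a flat row-major list."""
    b_flat = [None] * (m * n)
    for i in range(m):        # rows of the m-by-n output
        for j in range(n):    # columns of the output
            # b[i, j] = a[j, i], with both sides written as 1-D offsets
            b_flat[flat_index(i, j, n)] = a_flat[flat_index(j, i, m)]
    return b_flat


a_flat = list(range(12))               # a 3-by-4 matrix holding 0..11
b_flat = transpose_flat(a_flat, 3, 4)  # its 4-by-3 transpose, still flat
print(b_flat)
```

The two loops mirror the loop nest in the lowered code: only the offset arithmetic differs between the read and the write.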

Now verify the results.

a = np.arange(12, dtype='float32').reshape((3, 4))
b = np.empty((4, 3), dtype='float32')
a, b = tvm.nd.array(a), tvm.nd.array(b)

mod = tvm.build(s, [A, B])
mod(a, b)
print(a)
print(b)
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]]
[[ 0.  4.  8.]
 [ 1.  5.  9.]
 [ 2.  6. 10.]
 [ 3.  7. 11.]]

2.3.2. Reshaping

Next let’s use expressions for indexing. The following code block reshapes a 2-D array a (n by m as defined above) to 1-D (just like a.reshape(-1) in NumPy). Note how we convert the 1-D index i to the 2-D index [i//m, i%m].

B = te.compute((m*n, ), lambda i: A[i//m, i%m], name='b')
s = te.create_schedule(B.op)
tvm.lower(s, [A, B], simple_mode=True)
produce b {
  for (i, 0, (m*n)) {
    b[i] = a[((floordiv(i, m)*stride) + (floormod(i, m)*stride))]

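Since an \(n\)-D array is stored in memory as a 1-D array, reshaping does not rearrange any data. For a compact row-major layout the 2-D index collapses back to the 1-D one, because (i//m)*m + i%m equals i, and the compiler can use this identity to simplify the index expression and improve efficiency. In the lowered code above the strides are symbolic, so the index is printed in its unsimplified floordiv/floormod form.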
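The identity behind this simplification is easy to check in plain Python. The sketch below is independent of TVM; roundtrip is a hypothetical helper name, and a compact row-major layout is assumed.

```python
def roundtrip(i, m):
    """Map a 1-D index to (row, col) in an m-column array, then back to 1-D."""
    row, col = i // m, i % m
    return row * m + col


m = 4
# Every 1-D index of a 3-by-4 array survives the round trip unchanged.
assert all(roundtrip(i, m) == i for i in range(3 * m))

# The (row, col) pairs that the first few 1-D indices map to.
print([(i // m, i % m) for i in range(6)])
```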

We can implement a general 2-D reshape function as well.

p, q = te.var('p'), te.var('q')
B = te.compute((p, q), lambda i, j: A[(i*q+j)//m, (i*q+j)%m], name='b')
s = te.create_schedule(B.op)
tvm.lower(s, [A, B], simple_mode=True)
produce b {
  for (i, 0, p) {
    for (j, 0, q) {
      b[((i*stride) + (j*stride))] = a[((floordiv(((i*q) + j), m)*stride) + (floormod(((i*q) + j), m)*stride))]

When testing the results, keep in mind that we put no constraint on the output shape (p, q), so TVM cannot check that \(pq = nm\) for us. In the following example, b has 20 elements while a has only 12. The first 12 elements of b come from a; the remaining 8 are read from memory past the end of a and hold arbitrary values.

mod = tvm.build(s, [A, B])
a = np.arange(12, dtype='float32').reshape((3,4))
b = np.zeros((5, 4), dtype='float32')
a, b = tvm.nd.array(a), tvm.nd.array(b)

mod(a, b)
print(b)
[[0.000000e+00 1.000000e+00 2.000000e+00 3.000000e+00]
 [4.000000e+00 5.000000e+00 6.000000e+00 7.000000e+00]
 [8.000000e+00 9.000000e+00 1.000000e+01 1.100000e+01]
 [8.722636e-23 3.066461e-41 9.108440e-44 0.000000e+00]
 [8.743578e-23 3.066461e-41 8.741932e-23 3.066461e-41]]
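Because the compiled module does not verify that the element counts agree, one option is to check the shapes on the caller's side before invoking it. The following is a minimal sketch of such a guard; check_reshape is a hypothetical helper, not part of TVM.

```python
from math import prod


def check_reshape(src_shape, dst_shape):
    """Raise ValueError unless both shapes hold the same number of elements."""
    src_size, dst_size = prod(src_shape), prod(dst_shape)
    if src_size != dst_size:
        raise ValueError(
            f"cannot reshape {src_shape} ({src_size} elements) "
            f"into {dst_shape} ({dst_size} elements)")


check_reshape((3, 4), (2, 6))        # passes silently: 12 == 12
try:
    check_reshape((3, 4), (5, 4))    # the mismatched shapes used above
except ValueError as err:
    print(err)
```

Calling such a check on the concrete array shapes just before the module call would turn the silent mismatch into an explicit error.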

2.3.3. Slicing

Now let’s consider a special slicing operator a[bi::si, bj::sj], where bi, bj, si and sj are specified only at run time. The output shape must therefore be computed from these arguments. In addition, we need to pass the variables bi, bj, si and sj as arguments when compiling the module.

bi, bj, si, sj = [te.var(name) for name in ['bi', 'bj', 'si', 'sj']]
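# Ceiling division, so the shape matches NumPy's a[bi::si, bj::sj] even when
# the step does not evenly divide the remaining extent.
B = te.compute(((n-bi+si-1)//si, (m-bj+sj-1)//sj), lambda i, j: A[i*si+bi, j*sj+bj], name='b')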
s = te.create_schedule(B.op)
mod = tvm.build(s, [A, B, bi, si, bj, sj])

Now test two cases to verify the correctness.

b = tvm.nd.array(np.empty((1, 3), dtype='float32'))
mod(a, b, 1, 2, 1, 1)
np.testing.assert_equal(b.asnumpy(), a.asnumpy()[1::2, 1::1])

b = tvm.nd.array(np.empty((1, 2), dtype='float32'))
mod(a, b, 2, 1, 0, 2)
np.testing.assert_equal(b.asnumpy(), a.asnumpy()[2::1, 0::2])
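An output-shape expression for a slice start::step along an axis of length n can be sanity-checked against Python's own len(range(start, n, step)). The sketch below is independent of TVM and uses hypothetical helper names; it assumes 0 <= start <= n and step > 0. It suggests that ceiling division matches in every case, whereas plain floor division (n - start) // step matches only when the step evenly divides n - start, as happens in the two test cases above.

```python
def slice_len_ceil(n, start, step):
    """Number of elements in start::step via ceiling division."""
    return (n - start + step - 1) // step


def slice_len_floor(n, start, step):
    """Floor-division variant; undercounts when step does not divide n - start."""
    return (n - start) // step


# Columns: n, start, step, reference length, ceiling result, floor result.
for n, start, step in [(3, 1, 2), (4, 0, 2), (3, 0, 2), (5, 1, 3)]:
    reference = len(range(start, n, step))
    print(n, start, step, reference,
          slice_len_ceil(n, start, step), slice_len_floor(n, start, step))
```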

2.3.4. Summary

  • Both shape dimensions and indices can be expressions with variables.

  • A variable that cannot be inferred from the shapes of the tensor arguments, such as bi or si in the slicing example, must be passed as an explicit argument when compiling.