2.3. Index and Shape Expressions
You already know that a shape can be a tuple of symbols such as (n, m), and that elements can be accessed via indexing, e.g. a[i, j]. In practice, both shapes and indices may be computed through complex expressions. We will go through several examples in this section.
import d2ltvm
import numpy as np
import tvm
from tvm import te
2.3.1. Matrix Transpose
Our first example is matrix transpose a.T, in which we access a’s elements by columns.
n = te.var('n')
m = te.var('m')
A = te.placeholder((n, m), name='a')
B = te.compute((m, n), lambda i, j: A[j, i], 'b')
s = te.create_schedule(B.op)
tvm.lower(s, [A, B], simple_mode=True)
produce b {
for (i, 0, m) {
for (j, 0, n) {
b[((i*stride) + (j*stride))] = a[((j*stride) + (i*stride))]
}
}
}
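Note that a 2-D index such as b[i, j] is collapsed to a 1-D index such as b[((i*n) + j)], following the C (row-major) convention. Because the shapes here are symbolic, the printed output shows generic stride terms in place of the concrete multipliers.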
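The row-major collapse can be reproduced in plain Python, without TVM. The following is a minimal sketch, assuming compact row-major buffers; the helper name transpose_flat is introduced here for illustration only and is not part of TVM or the original code.

```python
def transpose_flat(a_flat, n, m):
    """Transpose an n-by-m row-major buffer into an m-by-n row-major buffer."""
    b_flat = [0.0] * (m * n)
    for i in range(m):          # row index of the output b
        for j in range(n):      # column index of the output b
            # b[i, j] lives at i*n + j; a[j, i] lives at j*m + i.
            b_flat[i * n + j] = a_flat[j * m + i]
    return b_flat

# Same data as the 3-by-4 test matrix used in this section.
a_flat = [float(x) for x in range(12)]
b_flat = transpose_flat(a_flat, 3, 4)
print(b_flat)
```

Each output row of length n holds one column of the input, which is the same access pattern as the lowered loop nest.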
Now verify the results.
a = np.arange(12, dtype='float32').reshape((3, 4))
b = np.empty((4, 3), dtype='float32')
a, b = tvm.nd.array(a), tvm.nd.array(b)
mod = tvm.build(s, [A, B])
mod(a, b)
print(a)
print(b)
[[ 0. 1. 2. 3.]
[ 4. 5. 6. 7.]
[ 8. 9. 10. 11.]]
[[ 0. 4. 8.]
[ 1. 5. 9.]
[ 2. 6. 10.]
[ 3. 7. 11.]]
2.3.2. Reshaping
Next let’s use expressions for indexing. The following code block reshapes a 2-D array a (n by m, as defined above) to 1-D (just like a.reshape(-1) in NumPy). Note how we convert the 1-D index i to the 2-D index [i//m, i%m].
B = te.compute((m*n, ), lambda i: A[i//m, i%m], name='b')
s = te.create_schedule(B.op)
tvm.lower(s, [A, B], simple_mode=True)
produce b {
for (i, 0, (m*n)) {
b[i] = a[((floordiv(i, m)*stride) + (floormod(i, m)*stride))]
}
}
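Since an \(n\)-D array is stored in memory as a 1-D array, reshaping does not rearrange the data. For a compact row-major layout, the 2-D index expression (i//m)*m + i%m equals the 1-D index i, so the generated code can in principle use the simpler 1-D index to improve efficiency; the output shown above still carries the floordiv and floormod terms with symbolic strides.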
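The identity between the 2-D and the 1-D index can be checked directly in plain Python. This is a small sketch with concrete sizes chosen for illustration (n = 3, m = 4); it assumes a compact row-major layout and does not involve TVM.

```python
n, m = 3, 4  # illustrative sizes, matching the 3-by-4 test matrix

mismatches = []
for i in range(n * m):
    row, col = i // m, i % m       # 1-D index -> 2-D index
    flat = row * m + col           # 2-D index -> row-major 1-D offset
    if flat != i:
        mismatches.append(i)

# An empty list means the round trip is the identity for every i.
print(mismatches)
```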
We can implement a general 2-D reshape function as well.
p, q = te.var('p'), te.var('q')
B = te.compute((p, q), lambda i, j: A[(i*q+j)//m, (i*q+j)%m], name='b')
s = te.create_schedule(B.op)
tvm.lower(s, [A, B], simple_mode=True)
produce b {
for (i, 0, p) {
for (j, 0, q) {
b[((i*stride) + (j*stride))] = a[((floordiv(((i*q) + j), m)*stride) + (floormod(((i*q) + j), m)*stride))]
}
}
}
When testing the results, we should be aware that we put no constraint on the output shape (p, q), which can be arbitrary, so TVM cannot check that \(pq = nm\) for us. In the following example we create a b with 20 elements, more than the 12 elements of a. Only the first 12 elements of b come from a; the remaining 8 are computed from indices beyond the end of a, so their values are arbitrary.
mod = tvm.build(s, [A, B])
a = np.arange(12, dtype='float32').reshape((3,4))
b = np.zeros((5, 4), dtype='float32')
a, b = tvm.nd.array(a), tvm.nd.array(b)
mod(a, b)
print(b)
[[0.000000e+00 1.000000e+00 2.000000e+00 3.000000e+00]
[4.000000e+00 5.000000e+00 6.000000e+00 7.000000e+00]
[8.000000e+00 9.000000e+00 1.000000e+01 1.100000e+01]
[8.722636e-23 3.066461e-41 9.108440e-44 0.000000e+00]
[8.743578e-23 3.066461e-41 8.741932e-23 3.066461e-41]]
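Because nothing checks that the element counts agree, one option is a small guard in the calling code before invoking the compiled module. The following is a minimal sketch in plain Python; check_reshape is a hypothetical helper introduced here for illustration, not a TVM function.

```python
import math

def check_reshape(src_shape, dst_shape):
    """Raise ValueError unless both shapes hold the same number of elements."""
    src_size = math.prod(src_shape)
    dst_size = math.prod(dst_shape)
    if src_size != dst_size:
        raise ValueError(
            f"cannot reshape {src_shape} ({src_size} elements) "
            f"into {dst_shape} ({dst_size} elements)")

check_reshape((3, 4), (2, 6))      # sizes match, no error
try:
    check_reshape((3, 4), (5, 4))  # 12 vs. 20 elements
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
print(mismatch_detected)
```

Calling such a check with the NumPy shapes of a and b before mod(a, b) would have rejected the (5, 4) output used above.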
2.3.3. Slicing
Now let’s consider a slicing operator a[bi::si, bj::sj] where the start offsets bi, bj and the strides si, sj can be specified later. The output shape now needs to be computed from these arguments. In addition, we need to pass the variables bi, bj, si and sj as arguments when compiling the module.
bi, bj, si, sj = [te.var(name) for name in ['bi', 'bj', 'si', 'sj']]
B = te.compute(((n-bi)//si, (m-bj)//sj), lambda i, j: A[i*si+bi, j*sj+bj], name='b')
s = te.create_schedule(B.op)
# The extra arguments are passed in the order bi, si, bj, sj.
mod = tvm.build(s, [A, B, bi, si, bj, sj])
Now test two cases to verify the correctness.
b = tvm.nd.array(np.empty((1, 3), dtype='float32'))
mod(a, b, 1, 2, 1, 1)
np.testing.assert_equal(b.asnumpy(), a.asnumpy()[1::2, 1::1])
b = tvm.nd.array(np.empty((1, 2), dtype='float32'))
mod(a, b, 2, 1, 0, 2)
np.testing.assert_equal(b.asnumpy(), a.asnumpy()[2::1, 0::2])
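The output shape above uses floor division, which agrees with NumPy for the two cases tested. As a sanity check on the shape arithmetic, the sketch below compares it with the number of elements a begin::step slice selects, computed with Python's range. The helper names are illustrative only, and the comparison assumes 0 <= begin <= n and step > 0. Under those assumptions, ceiling division is the form that matches in general, while floor division can come up one short when (n - begin) is not a multiple of step.

```python
def slice_len(n, begin, step):
    """Number of indices selected by begin::step on an axis of length n."""
    return len(range(begin, n, step))

def floor_len(n, begin, step):
    """Length formula used in the compute definition above."""
    return (n - begin) // step

def ceil_len(n, begin, step):
    """Ceiling division written with integer arithmetic."""
    return (n - begin + step - 1) // step

# (n, begin, step) for each axis of the two test cases above.
tested = [(3, 1, 2), (4, 1, 1), (3, 2, 1), (4, 0, 2)]
floor_ok = all(floor_len(*c) == slice_len(*c) for c in tested)
ceil_ok = all(ceil_len(*c) == slice_len(*c) for c in tested)

# A case where the two formulas differ: rows 1 and 3 of a 4-row axis.
counter = (4, 1, 2)
print(floor_ok, ceil_ok, slice_len(*counter), floor_len(*counter), ceil_len(*counter))
```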
2.3.4. Summary
Both shape dimensions and indices can be expressions with variables.
If a variable cannot be read off directly from the shape of an argument tensor, as with bi, bj, si and sj in the slicing example, we need to pass it as an argument when compiling.