2.5. Conditional Expression: if-then-else
The if-then-else statement is supported through te.if_then_else.
In this section, we introduce this expression using the computation of
the lower triangle of a matrix as the example.
import tvm
from tvm import te
import numpy as np
import d2ltvm
In NumPy, we can easily use np.tril to obtain the lower triangle.
a = np.arange(12, dtype='float32').reshape((3, 4))
np.tril(a)
array([[ 0., 0., 0., 0.],
[ 4., 5., 0., 0.],
[ 8., 9., 10., 0.]], dtype=float32)
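The lower triangle can also be written as an elementwise conditional over the row and column indices: keep a[i, j] where i >= j, otherwise write 0. This is the same rule we express with TVM next. The sketch below uses np.indices and np.where, which are our choice for illustration and are not needed by the TVM code:

```python
import numpy as np

a = np.arange(12, dtype='float32').reshape((3, 4))

# Row and column index grids, each with the same shape as a:
# i[r, c] == r and j[r, c] == c.
i, j = np.indices(a.shape)

# Elementwise if-then-else: take a where the condition holds, else 0.
lower = np.where(i >= j, a, 0.0).astype('float32')

print(np.array_equal(lower, np.tril(a)))
```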
Now let’s implement it in TVM with if_then_else. It accepts three
arguments: the condition, the value returned if the condition is true,
and the value returned otherwise.
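# Symbolic shape: n rows and m columns
n, m = te.var('n'), te.var('m')
A = te.placeholder((n, m))
# Keep A[i, j] on or below the diagonal (i >= j), otherwise write 0
B = te.compute(A.shape, lambda i, j: te.if_then_else(i >= j, A[i,j], 0.0))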
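Conceptually, te.compute evaluates the lambda once per output index, and the if-then-else picks one of its last two arguments for each element. The plain-Python sketch below spells that out with explicit loops. The helpers if_then_else and reference_compute are hypothetical, written only for illustration; they are not TVM APIs and they compute values eagerly instead of building an expression. One difference to keep in mind: Python evaluates both branch arguments before the call, whereas TVM documents its if_then_else as evaluating only the selected branch (which is what lets it guard out-of-bounds reads). Here both branches are safe reads, so the results agree.

```python
import numpy as np

def if_then_else(cond, then_value, else_value):
    # Hypothetical helper mirroring the three-argument form:
    # return the second argument when cond is true, else the third.
    # Note: Python has already evaluated both arguments at this point.
    return then_value if cond else else_value

def reference_compute(shape, fcompute, dtype='float32'):
    # Hypothetical stand-in for te.compute on a 2-D shape:
    # call fcompute(i, j) at every output index and store the result.
    n, m = shape
    out = np.empty(shape, dtype=dtype)
    for i in range(n):
        for j in range(m):
            out[i, j] = fcompute(i, j)
    return out

a = np.arange(12, dtype='float32').reshape((3, 4))
b = reference_compute(a.shape, lambda i, j: if_then_else(i >= j, a[i, j], 0.0))
print(b)
```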
Verify the results.
b = tvm.nd.array(np.empty_like(a))
s = te.create_schedule(B.op)
print(tvm.lower(s, [A, B], simple_mode=True))
mod = tvm.build(s, [A, B])
mod(tvm.nd.array(a), b)
b
produce compute {
for (i, 0, n) {
for (j, 0, m) {
compute[((i*stride) + (j*stride))] = tvm_if_then_else((j <= i), placeholder[((i*stride) + (j*stride))], 0f)
}
}
}
<tvm.nd.NDArray shape=(3, 4), cpu(0)>
array([[ 0., 0., 0., 0.],
[ 4., 5., 0., 0.],
[ 8., 9., 10., 0.]], dtype=float32)
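The printed IR shows that the condition i >= j was rewritten as the equivalent (j <= i). The same pattern extends to the optional diagonal offset k of np.tril, which keeps the entries with j <= i + k. The NumPy check below confirms that condition against np.tril for several offsets. The corresponding TVM rule would presumably be lambda i, j: te.if_then_else(j <= i + k, A[i, j], 0.0) for a Python integer k; that adaptation is our own untested illustration, not part of the original example.

```python
import numpy as np

a = np.arange(12, dtype='float32').reshape((3, 4))
i, j = np.indices(a.shape)

# np.tril(a, k) keeps entries on or below the k-th diagonal, i.e. j <= i + k.
for k in (-1, 0, 1, 2):
    masked = np.where(j <= i + k, a, 0.0)
    assert np.array_equal(masked, np.tril(a, k)), k

print("condition j <= i + k matches np.tril for k = -1, 0, 1, 2")
```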
2.5.1. Summary
We can use te.if_then_else for the if-then-else statement.