2.5. Conditional Expression: if-then-else
The if-then-else expression is supported through te.if_then_else.
In this section, we introduce this expression using the computation of
the lower triangle of a matrix as an example.
import tvm
from tvm import te
import numpy as np
import d2ltvm
In NumPy, we can easily use np.tril to obtain the lower triangle.
a = np.arange(12, dtype='float32').reshape((3, 4))
np.tril(a)
array([[ 0., 0., 0., 0.],
[ 4., 5., 0., 0.],
[ 8., 9., 10., 0.]], dtype=float32)
Now let’s implement it in TVM with if_then_else. It accepts three
arguments: the first is the condition; if the condition is true, it returns
the second argument, otherwise it returns the third one.
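To make these semantics concrete, the following NumPy sketch mirrors the lambda that is passed to te.compute below. Here np.indices supplies the row index i and column index j of every element, and np.where plays the role of if_then_else, picking a[i, j] where i >= j and 0.0 elsewhere. This is only an analogy written for illustration; it is not how TVM evaluates the expression.

```python
import numpy as np

a = np.arange(12, dtype='float32').reshape((3, 4))

# Row and column index of every element, analogous to the
# lambda arguments i and j in te.compute.
i, j = np.indices(a.shape)

# Element-wise select: condition, value if true, value if false.
b_np = np.where(i >= j, a, np.float32(0.0))
print(b_np)
```

The printed array is the same as np.tril(a) shown above.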
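# Symbolic shape: n rows and m columns
n, m = te.var('n'), te.var('m')
A = te.placeholder((n, m))
# Keep A[i, j] on and below the diagonal (i >= j), and 0 elsewhere
B = te.compute(A.shape, lambda i, j: te.if_then_else(i >= j, A[i,j], 0.0))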
Next, we print the lowered program, build the module, and verify the results.
b = tvm.nd.array(np.empty_like(a))
s = te.create_schedule(B.op)
print(tvm.lower(s, [A, B], simple_mode=True))
mod = tvm.build(s, [A, B])
mod(tvm.nd.array(a), b)
b
produce compute {
for (i, 0, n) {
for (j, 0, m) {
compute[((i*stride) + (j*stride))] = tvm_if_then_else((j <= i), placeholder[((i*stride) + (j*stride))], 0f)
}
}
}
<tvm.nd.NDArray shape=(3, 4), cpu(0)>
array([[ 0., 0., 0., 0.],
[ 4., 5., 0., 0.],
[ 8., 9., 10., 0.]], dtype=float32)
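The lowered program printed above is a plain double loop over i and j in which every element passes through the tvm_if_then_else select. A rough pure-Python rendering of that loop nest is sketched below for illustration only; the function name lower_triangle is hypothetical, and it uses ordinary 2-D indexing instead of the explicit stride arithmetic that appears in the IR.

```python
import numpy as np

def lower_triangle(placeholder):
    """Hypothetical pure-Python mirror of the lowered loop nest (illustrative only)."""
    n, m = placeholder.shape
    compute = np.empty_like(placeholder)
    for i in range(n):
        for j in range(m):
            # Corresponds to tvm_if_then_else((j <= i), placeholder[...], 0f)
            compute[i, j] = placeholder[i, j] if j <= i else 0.0
    return compute

a = np.arange(12, dtype='float32').reshape((3, 4))
print(lower_triangle(a))
```

Because the loop bounds come from the input shape, this sketch, like the TVM module built from the symbolic n and m, is not tied to the 3x4 example.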
2.5.1. Summary
We can use te.if_then_else for the if-then-else expression.