5. Convolution
In this section, we will extend Section 8 to optimize convolution on GPUs.
import d2ltvm
import numpy as np
import timeit
import tvm
from tvm import te
target = 'cuda'
5.1. Setup
As usual, we will use MXNet as our baseline, which calls cuDNN to execute the convolution. As we did on CPUs, we benchmark the performance with various numbers of channels, while fixing the input width/height to 64 and the kernel width/height to 3. The benchmark method bench_conv_mxnet has already been defined in Section 8. The only change we need to make is to pass 'gpu' as the target device.
channels = 2**np.arange(4, 9)
# a list of (c, n, k)
sizes = [(int(c), 64, 3) for c in channels]
mx_gflops = d2ltvm.bench_conv_mxnet(sizes, 'gpu')
d2ltvm.plot_gflops(channels, [mx_gflops], ['mxnet'])
mx_gflops
[134.41320431231432,
547.5253539941237,
1639.1830964521912,
4471.396488551705,
6281.556088500076]
The results above show that on GPUs the performance of convolution increases with the number of channels. To show the gradual performance improvement brought by TVM scheduling, we now fix the number of channels to 64, where MXNet reaches about 1.6 TFLOPS (1639 GFLOPS in the run above). Please keep this number in mind while we work on the TVM scheduling.
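To make the GFLOPS numbers more concrete, here is a small sketch of the arithmetic behind them. It assumes the usual convention that each multiply-add counts as two floating-point operations, and it assumes padding 1 and stride 1, so a 64×64 input gives a 64×64 output. The helper conv_gflop_estimate is a hypothetical name used only for this illustration; it is not claimed to be the function d2ltvm uses internally.

```python
def conv_gflop_estimate(oc, ic, n, k, p=1, s=1):
    """Estimated GFLOP of one 2-D convolution (each multiply-add = 2 FLOPs)."""
    out = (n - k + 2 * p) // s + 1          # output height = output width
    return 2 * oc * ic * out * out * k * k / 1e9

gflop = conv_gflop_estimate(64, 64, 64, 3)
rate = 1639.18                               # MXNet GFLOPS at 64 channels (measured above)
print(f"{gflop:.3f} GFLOP per convolution")  # 0.302 GFLOP per convolution
print(f"{gflop / rate * 1e3:.3f} ms per run at {rate} GFLOPS")
```

In other words, at 64 channels a single convolution is only about 0.3 GFLOP, so MXNet finishes it in well under a millisecond.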
sizes = [(64, 64, 3)]
5.2. Default schedule of CONV
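We then describe the computation of convolution in TVM using the conv method, which is defined in Section 8. For a default schedule, we can simply bind two axes of the convolution loop nest to the block and thread indices: the output height to blockIdx.x and the output width to threadIdx.x. Each thread then loops over all output channels of a single output pixel.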
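The lowered loop nest printed below can be hard to read at first. As a sketch, the following plain NumPy function (our own illustration, not part of d2ltvm) computes the same stride-1, zero-padded convolution and marks which loops the default schedule binds to blockIdx.x and threadIdx.x. It assumes the layouts X: (ic, n, n) and K: (oc, ic, k, k), which agree with the index expressions in the lowered code (for example K[c*576 + ric*9 + rkh*3 + rkw]). It is meant for understanding and for checking small results, not for speed.

```python
import numpy as np

def conv_reference(X, K, p=1):
    """Direct stride-1 2-D convolution with zero padding.

    X: input of shape (ic, n, n); K: kernel of shape (oc, ic, k, k).
    Returns Y of shape (oc, n - k + 2p + 1, n - k + 2p + 1).
    """
    ic, n, _ = X.shape
    oc, _, k, _ = K.shape
    Xp = np.pad(X, ((0, 0), (p, p), (p, p)))   # zero padding, like PaddedX
    out = n - k + 2 * p + 1
    Y = np.zeros((oc, out, out), dtype=X.dtype)
    for y in range(out):          # bound to blockIdx.x in default_sch
        for x in range(out):      # bound to threadIdx.x in default_sch
            # one (block, thread) pair computes all oc outputs at pixel (y, x)
            patch = Xp[:, y:y + k, x:x + k]    # shape (ic, k, k)
            Y[:, y, x] = np.tensordot(K, patch, axes=([1, 2, 3], [0, 1, 2]))
    return Y

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 8, 8)).astype("float32")
K = rng.standard_normal((6, 4, 3, 3)).astype("float32")
Y = conv_reference(X, K)
print(Y.shape)   # (6, 8, 8)
```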
def default_sch(oc, ic, n, k, p, s):
    X, K, Y, PaddedX = d2ltvm.conv(oc, ic, n, n, k, k, p, p, s, s)
    sch = te.create_schedule(Y.op)
    sch[PaddedX].compute_inline()
    # bind output height to blocks and output width to threads
    _, y, x = sch[Y].op.axis
    sch[Y].bind(y, te.thread_axis("blockIdx.x"))
    sch[Y].bind(x, te.thread_axis("threadIdx.x"))
    print(tvm.lower(sch, [X, K, Y], simple_mode=True))
    return sch, (X, K, Y)
tvm_gflops = d2ltvm.bench_conv_tvm(default_sch, sizes, target)
tvm_gflops
produce Y {
  for (c, 0, 64) {
    // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 64
    // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 64
    Y[(((c*4096) + (blockIdx.x*64)) + threadIdx.x)] = 0f
    for (ric, 0, 64) {
      for (rkh, 0, 3) {
        for (rkw, 0, 3) {
          Y[(((c*4096) + (blockIdx.x*64)) + threadIdx.x)] = (Y[(((c*4096) + (blockIdx.x*64)) + threadIdx.x)] + (tvm_if_then_else((((((blockIdx.x + rkh) < 1) || (65 <= (blockIdx.x + rkh))) || ((threadIdx.x + rkw) < 1)) || (65 <= (threadIdx.x + rkw))), 0f, X[((((((ric*4096) + (blockIdx.x*64)) + (rkh*64)) + threadIdx.x) + rkw) - 65)])*K[((((c*576) + (ric*9)) + (rkh*3)) + rkw)]))
        }
      }
    }
  }
}
array([129.81401547])
The default schedule gives us a performance of around 130 GFLOPS, more than an order of magnitude below the MXNet baseline.
5.3. Tiling
As described in the last section, we can tile the computation and explicitly bring the data into the shared and local memory of the GPU to improve the performance. Specifically, we tile the three dimensions of the output (channel, height and width) as well as the three reduction dimensions (input channel, kernel height and kernel width). We split each output dimension into three parts, which are bound to blocks, bound to threads, and processed sequentially within each CUDA thread, respectively. The splitting factors are chosen so that the data fit into the shared and local memory of the GPU.
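The three-level split can be checked with a few lines of arithmetic. The sketch below uses a hypothetical helper, split_extents, written only for this illustration. It computes the loop extents that result from the factor lists tile_c = [4, 8], tile_h = [2, 2] and tile_w = [16, 4] used in the tiling code of this section, assuming (as is the case here) that the factors divide the 64-element axes evenly.

```python
import numpy as np

def split_extents(extent, factors):
    """Loop extents, outermost first, after splitting an axis of length
    `extent` by the nested factors, e.g. [block, thread, inner]."""
    inner = int(np.prod(factors))
    assert extent % inner == 0, "factors must divide the axis length"
    return [extent // inner] + list(factors)

c = split_extents(64, [4, 8])    # output channels
h = split_extents(64, [2, 2])    # output height
w = split_extents(64, [16, 4])   # output width

num_blocks = c[0] * h[0] * w[0]
threads_per_block = c[1] * h[1] * w[1]
outputs_per_thread = c[2] * h[2] * w[2]
print(c, h, w)
print(num_blocks, threads_per_block, outputs_per_thread)
```

The product of the three counts equals the \(64 \times 64 \times 64\) output elements, so every output is computed exactly once.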
In our case, we cache the output tile (YL) and its corresponding input tiles (XL and KL) in the local memory, and the tiled input data (XX) and kernel (KK) in the shared memory. Fig. 5.3.2 shows how the tiling works for convolution.
Under the tiling factors below, each thread needs to access \(64 + 48 + 24 = 136\) 32-bit floats for YL, XL and KL, respectively. In our setting each block contains \(4 \times 2 \times 16 = 128\) threads, so the block occupies \(128 \times 136 = 17{,}408\) local memory registers in total, which easily fits into one SM. Similarly, we can reason that XX and KK fit into the shared memory.
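The register budget above can be checked in a few lines. The per-thread counts 64, 48 and 24 are taken from the text. The hardware limits (65,536 32-bit registers per SM and at most 255 registers per thread) are assumptions that hold for many recent NVIDIA GPUs, so check them for your own device. The compiled kernel also needs registers for indices and addresses, so its real usage will be somewhat higher than this estimate.

```python
# Per-thread local (register) footprint, using the counts stated in the text.
yl, xl, kl = 64, 48, 24
per_thread = yl + xl + kl
threads_per_block = 4 * 2 * 16
per_block = threads_per_block * per_thread

# Assumed hardware limits (typical of recent NVIDIA GPUs; check your device).
REGS_PER_SM = 65536          # 32-bit registers per SM
MAX_REGS_PER_THREAD = 255

print(per_thread, per_block)
print(per_thread <= MAX_REGS_PER_THREAD, per_block <= REGS_PER_SM)
```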
In addition, as in the last section, we use cooperative fetching, where all threads of a block load the data from the GPU global memory into the shared memory in parallel.
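In the code below, cooperative fetching fuses all axes of a shared-memory stage into one loop and splits it with nparts equal to the thread counts along z, y and x. The following pure-Python sketch is a simplified model of that idea, not TVM code. It assumes each split(nparts=m) yields an outer loop of m iterations and an inner loop of ceil(extent / m) iterations, with out-of-range iterations skipped, and it reports which thread loads each element. The buffer length 264 is a hypothetical example, not a value taken from the schedule.

```python
import math

def coop_fetch_assignment(extent, nparts=(4, 2, 16)):
    """Model which (tz, ty, tx) thread loads each element of a fused loop."""
    nz, ny, nx = nparts
    a = math.ceil(extent / nz)   # elements per threadIdx.z slice
    b = math.ceil(a / ny)        # elements per threadIdx.y slice
    c = math.ceil(b / nx)        # serial loads per thread
    assignment = []
    for tz in range(nz):
        for ty in range(ny):
            for tx in range(nx):
                for i in range(c):
                    j2 = tx * c + i          # offset inside the ty slice
                    j1 = ty * b + j2         # offset inside the tz slice
                    idx = tz * a + j1        # offset in the fused buffer
                    if j2 < b and j1 < a and idx < extent:
                        assignment.append((idx, (tz, ty, tx)))
    return assignment, c

assignment, serial = coop_fetch_assignment(264)
print(serial)            # 3 loads per thread
print(assignment[:4])
```

Each of the 128 threads issues only a few loads, and together they cover the whole buffer exactly once.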
tile_c = [4, 8]
tile_h = [2, 2]
tile_w = [16, 4]
tile_rc = [1, 1]
tile_rh = [1, 1]
tile_rw = [1, 3]
# Save to the d2ltvm package.
def split_axis(factors, sch, op, axis):
    """Splitting an axis into factors

    Parameters
    ----------
    factors: array of integers
        The factors that the split applies
    sch: tvm.te.schedule.Schedule
        The tvm schedule
    op: tvm.te.tensor.Operation
        The stage to be applied
    axis: tvm.te.schedule.IterVar
        axis to split

    Returns
    -------
    axes : list of Axis
        The transformed axes.
    """
    ret = []
    for i in range(0, len(factors)):
        ax0, ax1 = sch[op].split(axis, factor=int(np.prod(factors[i:])))
        ret.append(ax0)
        axis = ax1
    return ret + [axis]

def tiling(oc, ic, n, k, p, s):
    X, K, Y, PaddedX = d2ltvm.conv(oc, ic, n, n, k, k, p, p, s, s)
    sch = te.create_schedule(Y.op)
    sch[PaddedX].compute_inline()
    YL = sch.cache_write(Y, 'local')
    # create cache stage
    XX = sch.cache_read(PaddedX, 'shared', [YL])
    KK = sch.cache_read(K, 'shared', [YL])
    XL = sch.cache_read(XX, 'local', [YL])
    KL = sch.cache_read(KK, 'local', [YL])
    c, h, w = sch[Y].op.axis
    bc, tc, ic = split_axis(tile_c, sch, Y, c)
    bh, th, ih = split_axis(tile_h, sch, Y, h)
    bw, tw, iw = split_axis(tile_w, sch, Y, w)
    sch[Y].bind(bc, te.thread_axis("blockIdx.z"))
    sch[Y].bind(bh, te.thread_axis("blockIdx.y"))
    sch[Y].bind(bw, te.thread_axis("blockIdx.x"))
    sch[Y].bind(tc, te.thread_axis("threadIdx.z"))
    sch[Y].bind(th, te.thread_axis("threadIdx.y"))
    sch[Y].bind(tw, te.thread_axis("threadIdx.x"))
    sch[Y].reorder(bc, bh, bw, tc, th, tw, ic, ih, iw)
    sch[YL].compute_at(sch[Y], tw)
    # tile reduction axes
    c, h, w = sch[YL].op.axis
    rc, rh, rw = sch[YL].op.reduce_axis
    rco, rcm, rci = split_axis(tile_rc, sch, YL, rc)
    rho, rhm, rhi = split_axis(tile_rh, sch, YL, rh)
    rwo, rwm, rwi = split_axis(tile_rw, sch, YL, rw)
    sch[YL].reorder(rco, rho, rwo, rcm, rhm, rwm, rci, rhi, rwi, c, h, w)
    sch[XX].compute_at(sch[YL], rwo)
    sch[KK].compute_at(sch[YL], rwo)
    sch[XL].compute_at(sch[YL], rwm)
    sch[KL].compute_at(sch[YL], rwm)
    # cooperative fetching
    for load in [XX, KK]:
        args = sch[load].op.axis
        fused = sch[load].fuse(*args)
        # align thread layout
        tz, fused = sch[load].split(fused, nparts=tile_c[0])
        ty, fused = sch[load].split(fused, nparts=tile_h[0])
        tx, _ = sch[load].split(fused, nparts=tile_w[0])
        sch[load].bind(tz, te.thread_axis("threadIdx.z"))
        sch[load].bind(ty, te.thread_axis("threadIdx.y"))
        sch[load].bind(tx, te.thread_axis("threadIdx.x"))
    return sch, (X, K, Y)
tvm_gflops = d2ltvm.bench_conv_tvm(tiling, sizes, target)
tvm_gflops
array([1871.96109223])
The performance improves by more than an order of magnitude, from about 130 to about 1870 GFLOPS, which is already on par with our baseline. We can still improve it by mitigating bank conflicts when accessing the shared memory.
5.4. Optimizing the data access on GPUs
5.4.1. Bank Conflict
In the schedule above, all threads read XX and KK simultaneously, which may harm the performance. To understand why, we need to dive a little into the shared memory architecture and how threads are executed.
Recall that we created 128 threads for a thread block. Due to resource limits, the GPU cannot execute all of them at the same time. Instead, it selects a group of threads, runs them simultaneously, and then quickly switches to another group. Such a group is called a warp, which contains 32 threads with consecutive thread indices.
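To see which of the 128 threads end up in the same warp, the sketch below applies the standard CUDA rule: the 3-D thread index is linearized as x + y*Dx + z*Dx*Dy, and consecutive linear indices are packed into warps of 32 threads. The block shape (16, 2, 4) matches the threadIdx.x, threadIdx.y and threadIdx.z extents of the tiling schedule above. The helper warp_members is only an illustration.

```python
def warp_members(block_dim=(16, 2, 4), warp_size=32):
    """Group the (x, y, z) thread indices of one block into warps."""
    dx, dy, dz = block_dim
    warps = {}
    for z in range(dz):
        for y in range(dy):
            for x in range(dx):
                tid = x + y * dx + z * dx * dy      # linear thread index
                warps.setdefault(tid // warp_size, []).append((x, y, z))
    return warps

warps = warp_members()
print(len(warps))                                       # 4 warps for 128 threads
print(sorted({z for _, _, z in warps[0]}), len(warps[0]))  # [0] 32
```

So in this schedule each warp holds all 16 threadIdx.x values for both threadIdx.y values of a single threadIdx.z.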
The threads in a warp can access data in the shared memory simultaneously, so the shared memory is designed to support parallel loads and stores. The basic unit of the shared memory is a word. Each word has 4 bytes, which can hold a single 32-bit floating-point number. Words are grouped into 32 banks: the \(j\)-th word is in the \(i\)-th bank if \(j \bmod 32 = i\).
Each bank can handle only a single request at a time, while the 32 banks work in parallel. So the shared memory is fastest when each bank receives a request from a single thread in the warp, which lets us access 32 words at the same time. However, if two threads request different words from the same bank, the two requests must be serialized. This is called a bank conflict. A special case is when all (or some) threads of a warp request the same word in a bank: the word is then read only once and broadcast (multicast) to those threads, so there is no bank conflict. Fig. 5.4.1 illustrates these three cases.
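The three cases can be captured in a small model. The sketch below is a simplification: it assumes 32 banks of 4-byte words and ignores architecture-specific details such as wider access modes. For one warp-wide access it counts the largest number of distinct words requested from a single bank, which is how many times the access must be serialized. The function name conflict_degree is our own.

```python
from collections import defaultdict

NUM_BANKS = 32

def conflict_degree(word_addrs):
    """Serialization factor of one warp-wide shared-memory access.

    word_addrs[t] is the index of the 32-bit word read by thread t.
    Requests for the same word are merged (broadcast); distinct words
    in the same bank are serialized.
    """
    words_per_bank = defaultdict(set)
    for addr in word_addrs:
        words_per_bank[addr % NUM_BANKS].add(addr)
    return max(len(w) for w in words_per_bank.values())

print(conflict_degree([t for t in range(32)]))       # 1: one word per bank
print(conflict_degree([2 * t for t in range(32)]))   # 2: two-way conflict
print(conflict_degree([7 for _ in range(32)]))       # 1: broadcast
print(conflict_degree([32 * t for t in range(32)]))  # 32: all in bank 0
```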
5.4.2. Data Access Pattern
Now let’s analyze the read pattern of XX and KK. Note from Fig. 5.3.2 that a thread reads a row segment of XX of 4 consecutive numbers, which lie in 4 adjacent banks. Since neighboring threads start 4 words apart, threads whose indices differ by 8 read words that are 32 words apart, i.e., in the same bank, which causes severe bank conflicts. The same applies to the reading of KK.
One way to mitigate this is to let each thread read in columns instead of rows, so that its reads have a stride, which spreads the numbers read by one thread across more banks. Fig. 5.4.2 shows the difference between the two patterns of reading from shared memory into local memory.
5.4.3. Virtual Threads
In addition, TVM provides another mechanism, called virtual threads, to further increase the data access stride and mitigate bank conflicts. Let’s revisit the thread structure we defined above. Each block has 16 threads in the x dimension, and each thread processes 4 consecutive elements. Conceptually, the threads get data in the pattern depicted on the left of Fig. 5.4.3.
To let threads process data in a spread-out manner, we can use virtual threads to obtain strided data chunks. For example, we can first split the data into 2 parts and bind them to 2 virtual threads, and then split each part into 16 pieces and bind them to 16 CUDA threads. In the generated code, the \(i\)-th CUDA thread of all virtual threads is merged into a single one, so we get only 16 CUDA threads instead of \(16 \times 2\). In this case, each thread processes two spread-out data pieces, as depicted on the right of Fig. 5.4.3.
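The effect of virtual threads on the width axis can be illustrated by listing which output columns each CUDA thread computes. The sketch below assumes the split order [vthread, threadIdx.x, inner] (outermost first) produced by tile_w, and it ignores the block axis, whose extent is 1 for a 64-wide output. columns_per_thread is a hypothetical helper for this illustration.

```python
def columns_per_thread(n_vthread, n_thread, inner):
    """Output columns handled by each CUDA thread when the width axis is
    split into [vthread, threadIdx.x, inner] loops (outermost first)."""
    cols = {t: [] for t in range(n_thread)}
    for v in range(n_vthread):        # virtual threads are merged into each CUDA thread
        for t in range(n_thread):
            for i in range(inner):
                cols[t].append((v * n_thread + t) * inner + i)
    return cols

plain = columns_per_thread(1, 16, 4)   # tile_w = [16, 4]
virt = columns_per_thread(2, 16, 2)    # tile_w = [2, 16, 2]
print(plain[0], plain[1])   # [0, 1, 2, 3] [4, 5, 6, 7]
print(virt[0], virt[1])     # [0, 1, 32, 33] [2, 3, 34, 35]
```

Each thread still computes 4 columns, but with virtual threads they come in two chunks 32 columns apart instead of one contiguous run.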
tile_c = [1, 4, 8]
tile_h = [1, 2, 2]
tile_w = [2, 16, 2]  # making 2 virtual threads along the ow dimension
tile_rc = [1, 1]
tile_rh = [1, 3] # making the data access in columns
tile_rw = [1, 1]
def vthread(oc, ic, n, k, p, s):
    X, K, Y, PaddedX = d2ltvm.conv(oc, ic, n, n, k, k, p, p, s, s)
    sch = te.create_schedule(Y.op)
    sch[PaddedX].compute_inline()
    YL = sch.cache_write(Y, 'local')
    # create cache stage
    XX = sch.cache_read(PaddedX, 'shared', [YL])
    KK = sch.cache_read(K, 'shared', [YL])
    XL = sch.cache_read(XX, 'local', [YL])
    KL = sch.cache_read(KK, 'local', [YL])
    c, h, w = sch[Y].op.axis
    bc, vc, tc, ic = split_axis(tile_c, sch, Y, c)
    bh, vh, th, ih = split_axis(tile_h, sch, Y, h)
    bw, vw, tw, iw = split_axis(tile_w, sch, Y, w)
    sch[Y].bind(bc, te.thread_axis("blockIdx.z"))
    sch[Y].bind(bh, te.thread_axis("blockIdx.y"))
    sch[Y].bind(bw, te.thread_axis("blockIdx.x"))
    sch[Y].bind(vc, te.thread_axis("vthread"))
    sch[Y].bind(vh, te.thread_axis("vthread"))
    sch[Y].bind(vw, te.thread_axis("vthread"))
    sch[Y].bind(tc, te.thread_axis("threadIdx.z"))
    sch[Y].bind(th, te.thread_axis("threadIdx.y"))
    sch[Y].bind(tw, te.thread_axis("threadIdx.x"))
    sch[Y].reorder(bc, bh, bw, vc, vh, vw, tc, th, tw, ic, ih, iw)
    sch[YL].compute_at(sch[Y], tw)
    # tile reduction axes
    c, h, w = sch[YL].op.axis
    rc, rh, rw = sch[YL].op.reduce_axis
    rco, rcm, rci = split_axis(tile_rc, sch, YL, rc)
    rho, rhm, rhi = split_axis(tile_rh, sch, YL, rh)
    rwo, rwm, rwi = split_axis(tile_rw, sch, YL, rw)
    sch[YL].reorder(rco, rho, rwo, rcm, rhm, rwm, rci, rhi, rwi, c, h, w)
    sch[XX].compute_at(sch[YL], rwo)
    sch[KK].compute_at(sch[YL], rwo)
    sch[XL].compute_at(sch[YL], rwm)
    sch[KL].compute_at(sch[YL], rwm)
    # cooperative fetching
    for load in [XX, KK]:
        args = sch[load].op.axis
        fused = sch[load].fuse(*args)
        # align thread layout
        tz, fused = sch[load].split(fused, nparts=tile_c[1])
        ty, fused = sch[load].split(fused, nparts=tile_h[1])
        tx, _ = sch[load].split(fused, nparts=tile_w[1])
        sch[load].bind(tz, te.thread_axis("threadIdx.z"))
        sch[load].bind(ty, te.thread_axis("threadIdx.y"))
        sch[load].bind(tx, te.thread_axis("threadIdx.x"))
    return sch, (X, K, Y)
tvm_gflops = d2ltvm.bench_conv_tvm(vthread, sizes, target)
tvm_gflops
array([2705.41367495])
After carefully optimizing the data access, we get about 2700 GFLOPS, which clearly outperforms our baseline at 64 channels.
Now let’s vary the number of channels to evaluate the convolution performance obtained by TVM more comprehensively, and then plot a chart comparing MXNet and TVM.
channels = 2**np.arange(4, 9)
# a list of (c, n, k)
sizes = [(int(c), 64, 3) for c in channels]
target = 'cuda'
tvm_gflops = d2ltvm.bench_conv_tvm(vthread, sizes, target)
d2ltvm.plot_gflops(channels, [mx_gflops, tvm_gflops], legend=['MXNet', 'TVM'])
tvm_gflops
array([ 710.40119273, 1822.09562259, 2695.62647693, 3405.25584565,
3746.19560406])
From the figure we see that TVM outperforms MXNet for smaller numbers of channels (16 to 64), but MXNet becomes better as the number of channels increases (128 and 256). This is mostly because cuDNN, which MXNet uses, provides manually optimized implementations for different data shapes, whereas here we use a single scheduling strategy for convolution kernels of all sizes.
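To quantify the comparison, the short sketch below computes the TVM-to-MXNet throughput ratio from the two lists printed above. The numbers are copied verbatim from those outputs; a ratio above 1 means TVM is faster.

```python
import numpy as np

channels = 2 ** np.arange(4, 9)
mx = np.array([134.41320431231432, 547.5253539941237, 1639.1830964521912,
               4471.396488551705, 6281.556088500076])      # MXNet GFLOPS
tvm_res = np.array([710.40119273, 1822.09562259, 2695.62647693,
                    3405.25584565, 3746.19560406])           # TVM GFLOPS
ratio = tvm_res / mx
for c, r in zip(channels, ratio):
    print(f"{int(c):4d} channels: TVM/MXNet = {r:.2f}")
```

The crossover lies between 64 and 128 channels, consistent with the figure.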
You may wonder how we can choose different schedules for convolutions of different data sizes, and, even better, whether we can automate the choice of schedule for a given set of convolution data shapes. We will talk about these techniques later.
In addition, there are further ways to increase the performance. For example, we could try to eliminate bank conflicts entirely, e.g., by padding the shared memory buffers so that the threads of a warp always access distinct banks. However, such tricks may be ad hoc and require intensive programming effort. Our goal is to come up with more high-level and generic scheduling schemes that achieve reasonable performance.
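To illustrate the padding idea with the same simplified bank model (32 banks of 4-byte words, our own helper conflict_degree), consider a hypothetical 32×32 float tile stored row by row in shared memory, and a warp in which thread t reads column 0 of row t. With a row pitch of 32 words every thread hits the same bank. Padding each row by one word makes the pitch 33, which is coprime with 32, so the 32 reads land in 32 different banks. This is a generic technique, not what the schedules in this section do.

```python
from collections import defaultdict

def conflict_degree(word_addrs, num_banks=32):
    """Largest number of distinct words one bank must serve for this access."""
    words_per_bank = defaultdict(set)
    for addr in word_addrs:
        words_per_bank[addr % num_banks].add(addr)
    return max(len(w) for w in words_per_bank.values())

def column_read(row_pitch, col=0, warp=32):
    # thread t reads element (t, col) of a row-major tile with row_pitch words per row
    return [t * row_pitch + col for t in range(warp)]

print(conflict_degree(column_read(32)))   # 32: every thread hits the same bank
print(conflict_degree(column_read(33)))   # 1: one padding word per row removes the conflict
```

The same reasoning shows that any odd pitch avoids conflicts for this access pattern, at the cost of a little extra shared memory.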
5.5. Summary
We leverage the memory hierarchy of the GPU to tile the data for better convolution performance.
We carefully manipulate the data access pattern to mitigate bank conflicts, which harm the performance.
5.6. Exercise
Try out different factors to split the axes and observe the performance difference.
Vary the size of input data and observe the performance difference.