5.5. Convolution
Open the notebook in Colab

In this section, we will extend Section 4.8 to optimize convolution on GPUs.

import d2ltvm
import numpy as np
import timeit
import tvm
from tvm import te

target = 'cuda'

5.5.1. Setup

As usual, we will use MXNet as our baseline, which calls cuDNN to execute the convolution. Like what we have done on CPUs, we benchmark the performance with various numbers of channels, when the input and kernel width/height are fixed to be 64 and 3, respectively. The benchmark method bench_conv_mxnet has already been defined in Section 4.8. The only change that we need to make is to specify to the method that the target device is GPU.

channels = 2**np.arange(4, 9)
# a list of (c, n, k)
sizes = [(int(c), 64, 3) for c in channels]
mx_gflops = d2ltvm.bench_conv_mxnet(sizes, 'gpu')
d2ltvm.plot_gflops(channels, [mx_gflops], ['mxnet'])
mx_gflops
[136.79632757961235,
 544.4900466560006,
 1744.3097712377894,
 4596.087785743593,
 6286.238545279585]
../_images/output_conv_0e1d8e_3_1.svg

The results above show that on GPUs the performance of convolution increases while the number of channels increases. For showing the gradual performance improvement brought by TVM scheduling, we now fix the channel size to be 64, where MXNet could get to the performance of about 1700 GFLOPS or 1.7 TGLOPS. Please keep this in mind while we are working on the TVM scheduling.

sizes = [(64, 64, 3)]

5.5.2. Default schedule of CONV

We then describe the computation of convolution in TVM using conv method, which is defined in Section 4.8. For a default schedule, we can simply bind two axes of the convolution loop nest into block and thread indexing axes.

def default_sch(oc, ic, n, k, p, s):
    X, K, Y, PaddedX = d2ltvm.conv(oc, ic, n, n, k, k, p, p, s, s)
    sch = te.create_schedule(Y.op)
    sch[PaddedX].compute_inline()
    _, y, x = sch[Y].op.axis
    sch[Y].bind(y, te.thread_axis("blockIdx.x"))
    sch[Y].bind(x, te.thread_axis("threadIdx.x"))
    print(tvm.lower(sch, [X, K, Y], simple_mode=True))
    return sch, (X, K, Y)

tvm_gflops = d2ltvm.bench_conv_tvm(default_sch, sizes, target)
tvm_gflops
produce Y {
  for (c, 0, 64) {
    // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 64
    // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 64
    Y[(((c*4096) + (blockIdx.x*64)) + threadIdx.x)] = 0f
    for (ric, 0, 64) {
      for (rkh, 0, 3) {
        for (rkw, 0, 3) {
          Y[(((c*4096) + (blockIdx.x*64)) + threadIdx.x)] = (Y[(((c*4096) + (blockIdx.x*64)) + threadIdx.x)] + (tvm_if_then_else((((((blockIdx.x + rkh) < 1) || (65 <= (blockIdx.x + rkh))) || ((threadIdx.x + rkw) < 1)) || (65 <= (threadIdx.x + rkw))), 0f, X[((((((ric*4096) + (blockIdx.x*64)) + (rkh*64)) + threadIdx.x) + rkw) - 65)])*K[((((c*576) + (ric*9)) + (rkh*3)) + rkw)]))
        }
      }
    }
  }
}
array([128.32613589])

The default scheduling gives us the performance around 100 GFLOPS.

5.5.3. Tiling

As described in the last section, we can do tiling and bring the data to the shared and local memory of the GPU explicitly to improve the performance. Specifically, we tile three dimensions of the output (channel, height and width) as well as the three reduce dimensions (input channel, kernel height and kernel width). For the output dimensions, we split each of them into three parts for block binding, thread binding and CUDA kernel processing, respectively. The splitting factors are chosen to make sure the data can be fit into the shared and local memory of the GPU.

In our case, we specify the output tile (YL) and its corresponding input tiles (XL and KL) to the local memory, and the tiled input data (XX) and kernel (KK) to the shared memory. Fig. 5.5.1 shows how the tiling works for convolution.

../_images/conv_thread_block.svg

Fig. 5.5.1 Blocked tiling for convolution to put small tiles into shared and local memory of GPU.

Under our tiling factors below, each thread needs to access \(64\)(YL)\(+48\)(XL)\(+24\)(KL)\(=136\) 32-bit floats. And in our setting each block contains \(4 \times 2 \times 16 = 128\) threads, making the total occupied local memory registers to be \(128 \times 136 = 17,408\), which is easy to be fit into one SM. Similarly, we can reason that XX and KK are fittable to the shared memory.

In addition, as in the last section, we use cooperative fetching to local the data from GPU host memory to the shared memory in parallel.

tile_c = [4, 8]
tile_h = [2, 2]
tile_w = [16, 4]
tile_rc = [1, 1]
tile_rh = [1, 1]
tile_rw = [1, 3]

# Save to the d2ltvm package.
def split_axis(factors, sch, op, axis):
        """Splitting an axis into factors

        Parameters
        ----------
        factors: array of integers
            The factors that the split applies
        sch: tvm.te.schedule.Schedule
            The tvm schedule
        op: tvm.te.tensor.Operation
            The stage to be applied
        axis: tvm.te.schedule.IterVar
            axis to split

        Returns
        -------
        axes : list of Axis
            The transformed axes.
        """
        ret = []
        for i in range(0, len(factors)):
            ax0, ax1 = sch[op].split(axis, factor=int(np.prod(factors[i:])))
            ret.append(ax0)
            axis = ax1
        return ret + [axis]

def tiling(oc, ic, n, k, p, s):
    X, K, Y, PaddedX = d2ltvm.conv(oc, ic, n, n, k, k, p, p, s, s)
    sch = te.create_schedule(Y.op)
    sch[PaddedX].compute_inline()

    YL = sch.cache_write(Y, 'local')

    # create cache stage
    XX = sch.cache_read(PaddedX, 'shared', [YL])
    KK = sch.cache_read(K, 'shared', [YL])
    XL = sch.cache_read(XX, 'local', [YL])
    KL = sch.cache_read(KK, 'local', [YL])

    c, h, w = sch[Y].op.axis

    bc, tc, ic = split_axis(tile_c, sch, Y, c)
    bh, th, ih = split_axis(tile_h, sch, Y, h)
    bw, tw, iw = split_axis(tile_w, sch, Y, w)

    sch[Y].bind(bc, te.thread_axis("blockIdx.z"))
    sch[Y].bind(bh, te.thread_axis("blockIdx.y"))
    sch[Y].bind(bw, te.thread_axis("blockIdx.x"))
    sch[Y].bind(tc, te.thread_axis("threadIdx.z"))
    sch[Y].bind(th, te.thread_axis("threadIdx.y"))
    sch[Y].bind(tw, te.thread_axis("threadIdx.x"))
    sch[Y].reorder(bc, bh, bw, tc, th, tw, ic, ih, iw)

    sch[YL].compute_at(sch[Y], tw)

    # tile reduction axes
    c, h, w = sch[YL].op.axis
    rc, rh, rw = sch[YL].op.reduce_axis
    rco, rcm, rci = split_axis(tile_rc, sch, YL, rc)
    rho, rhm, rhi = split_axis(tile_rh, sch, YL, rh)
    rwo, rwm, rwi = split_axis(tile_rw, sch, YL, rw)
    sch[YL].reorder(rco, rho, rwo, rcm, rhm, rwm, rci, rhi, rwi, c, h, w)

    sch[XX].compute_at(sch[YL], rwo)
    sch[KK].compute_at(sch[YL], rwo)
    sch[XL].compute_at(sch[YL], rwm)
    sch[KL].compute_at(sch[YL], rwm)

    # cooperative fetching
    for load in [XX, KK]:
        args = sch[load].op.axis
        fused = sch[load].fuse(*args)
        # align thread layout
        tz, fused = sch[load].split(fused, nparts=tile_c[0])
        ty, fused = sch[load].split(fused, nparts=tile_h[0])
        tx, _ = sch[load].split(fused, nparts=tile_w[0])
        sch[load].bind(tz, te.thread_axis("threadIdx.z"))
        sch[load].bind(ty, te.thread_axis("threadIdx.y"))
        sch[load].bind(tx, te.thread_axis("threadIdx.x"))

    return sch, (X, K, Y)

tvm_gflops = d2ltvm.bench_conv_tvm(tiling, sizes, target)
tvm_gflops
array([1861.83916811])

The performance increases over one order of magnitude which is already on par with our baseline. We can still improve it by solving some bank conflict issue when accessing the data.

5.5.4. Optimizing the data access on GPUs

5.5.4.1. Bank Conflict

In the scheduling above, all threads will read XX and KK simultaneously, which may harm the performance. To understand it, we need to dive a little bit into the shared memory architecture and how threads are executed.

Remember that we created 128 threads for a thread block. Due to resource limits, we cannot execute all of them at the same time. Instead, each time we select a group of threads and run time simultaneously, and then switch to another group quickly. Such a group is called a warp, which contains 32 threads, and each having a consecutive thread index.

Each thread in a warp can access data in the shared memory simultaneously. So the shared memory is designed to support parallel load and store. The basic unit of the shared memory is word. Each word has 4 bytes, which can hold a single 32-bit floating number. Words are grouped into 32 banks. The \(j\)-th word is in the \(i\)-th bank if j % 32 == i.

Each bank can only handle a single request at a time, while these 32 banks are processed in parallel. So the shared memory performs fastest when each bank gets one request from a single thread in the warp, therefore we can access 32 words at the same time. However, if there are two threads requesting data from the same bank, we need to serialize these two requests. This is called a bank conflict. A special case is that if all (or some) threads of a warp request the same word in a bank, then the word will only read once and broadcast (multicast) to each thread, so there is no bank conflict. Fig. 5.5.2 illustrated these three cases.

../_images/bank_conflict.svg

Fig. 5.5.2 Accessing shared memory

5.5.4.2. Data Access Pattern

Now let’s analyze the read pattern of XX and KK. Note from Fig. 5.5.1 that a thread reads a row segment of XX with a length of 4 consecutive numbers, which spread in 4 adjacent banks which causes severe bank conflict. The same thing applies to the reading of KK.

One way to mitigate this is to let the thread read in columns instead of rows, so each reading would have a stride, making the numbers to be read from one thread more spread among banks. Fig. 5.5.3 shows the difference of the reading patterns from shared memory to local memory.

../_images/conv_row_column.svg

Fig. 5.5.3 Different reading patterns from shared memory to local memory

5.5.4.3. Virtual Threads

In addition, TVM provides another mechanism, called virtual thread, to further increase the data access stride to mitigate bank conflict. Let’s revisit the thread structure we defined above. Each block has 16 threads in the x dimension, each thread processes 4 elements. It conceptually gets data in the pattern depicted in the left of Fig. 5.5.4.

In order to let threads to process data in a spread manner, we can use virtual threads to obtain strided data chunks. For example, we can first split the data into 2 parts to bind to 2 virtual threads. Then we further split the each part into 16 pieces to bind to 16 CUDA threads. In practice, the \(i\)-th CUDA thread in all virtual threads will be merged into a single one, so we will only get 16 CUDA threads instead of \(16 \times 2\) threads. In this case, each thread processes two spread data pieces as depicted in the right of Fig. 5.5.4.

../_images/conv_virtual_thread.svg

Fig. 5.5.4 Virtual thread binding

tile_c = [1, 4, 8]
tile_h = [1, 2, 2]
tile_w = [2, 16, 2] # making 2 virtual thread along the ow dimension
tile_rc = [1, 1]
tile_rh = [1, 3] # making the data access in columns
tile_rw = [1, 1]

def vthread(oc, ic, n, k, p, s):
    X, K, Y, PaddedX = d2ltvm.conv(oc, ic, n, n, k, k, p, p, s, s)
    sch = te.create_schedule(Y.op)
    sch[PaddedX].compute_inline()

    YL = sch.cache_write(Y, 'local')

    # create cache stage
    XX = sch.cache_read(PaddedX, 'shared', [YL])
    KK = sch.cache_read(K, 'shared', [YL])
    XL = sch.cache_read(XX, 'local', [YL])
    KL = sch.cache_read(KK, 'local', [YL])

    c, h, w = sch[Y].op.axis

    bc, vc, tc, ic = split_axis(tile_c, sch, Y, c)
    bh, vh, th, ih = split_axis(tile_h, sch, Y, h)
    bw, vw, tw, iw = split_axis(tile_w, sch, Y, w)

    sch[Y].bind(bc, te.thread_axis("blockIdx.z"))
    sch[Y].bind(bh, te.thread_axis("blockIdx.y"))
    sch[Y].bind(bw, te.thread_axis("blockIdx.x"))
    sch[Y].bind(vc, te.thread_axis("vthread"))
    sch[Y].bind(vh, te.thread_axis("vthread"))
    sch[Y].bind(vw, te.thread_axis("vthread"))
    sch[Y].bind(tc, te.thread_axis("threadIdx.z"))
    sch[Y].bind(th, te.thread_axis("threadIdx.y"))
    sch[Y].bind(tw, te.thread_axis("threadIdx.x"))
    sch[Y].reorder(bc, bh, bw, vc, vh, vw, tc, th, tw, ic, ih, iw)

    sch[YL].compute_at(sch[Y], tw)

    # tile reduction axes
    c, h, w = sch[YL].op.axis
    rc, rh, rw = sch[YL].op.reduce_axis
    rco, rcm, rci = split_axis(tile_rc, sch, YL, rc)
    rho, rhm, rhi = split_axis(tile_rh, sch, YL, rh)
    rwo, rwm, rwi = split_axis(tile_rw, sch, YL, rw)
    sch[YL].reorder(rco, rho, rwo, rcm, rhm, rwm, rci, rhi, rwi, c, h, w)

    sch[XX].compute_at(sch[YL], rwo)
    sch[KK].compute_at(sch[YL], rwo)
    sch[XL].compute_at(sch[YL], rwm)
    sch[KL].compute_at(sch[YL], rwm)

    # cooperative fetching
    for load in [XX, KK]:
        args = sch[load].op.axis
        fused = sch[load].fuse(*args)
        # align thread layout
        tz, fused = sch[load].split(fused, nparts=tile_c[1])
        ty, fused = sch[load].split(fused, nparts=tile_h[1])
        tx, _ = sch[load].split(fused, nparts=tile_w[1])
        sch[load].bind(tz, te.thread_axis("threadIdx.z"))
        sch[load].bind(ty, te.thread_axis("threadIdx.y"))
        sch[load].bind(tx, te.thread_axis("threadIdx.x"))

    return sch, (X, K, Y)

tvm_gflops = d2ltvm.bench_conv_tvm(vthread, sizes, target)
tvm_gflops
array([2697.69346504])

After carefully optimizing the data access, the performance we get outperforms our baseline at channel=64.

Now let’s vary the number of channels to test out the convolution performance obtained by TVM more comprehensively. And then we can plot the chart to compare MXNet and TVM.

channels = 2**np.arange(4, 9)
# a list of (c, n, k)
sizes = [(int(c), 64, 3) for c in channels]
target = 'cuda'
tvm_gflops = d2ltvm.bench_conv_tvm(vthread, sizes, target)
d2ltvm.plot_gflops(channels, [mx_gflops, tvm_gflops], legend=['MXNet', 'TVM'])
tvm_gflops
array([ 708.63586743, 1831.78216238, 2691.95795593, 3390.56281695,
       3705.44656231])
../_images/output_conv_0e1d8e_13_1.svg

From the figure we see that TVM outperforms MXNet in smaller channel sizes but MXNet becomes better as the number of channels increases. This is mostly because cuDNN used by MXNet has manually optimized implementation for different data shapes, but here we use only one scheduling strategy for convolution kernels in different sizes.

You may wonder how we can choose different schedules for convolutions of different data sizes, and even better, if we can automate the choice of schedules given a specific set of data shapes for convolution. We will talk about these techniques later.

In addition, there are ways that one can do to further increase the performance. For example, we can try to avoid all bank conflict by making the data reading stride always a multiple of 32. However, these tricks may be ad hoc and require intensive programming efforts. Our goal is to come up with some more high-level and generic scheduling scheme to achieve the reasonable performance.

5.5.5. Summary

  • We leverage the memory hierarchy of GPU to tile the data for better convolution performance.

  • We carefully manipulate the data access pattern to mitigate bank conflict which harms the performance.

5.5.6. Exercise

  • Try our different factors to split the axes and observe the performance difference.

  • Vary the size of input data and observe the performance difference.