3.6. Batch Normalization

This section describes how to use TVM to implement batch normalization (batch_norm). Like pooling, batch_norm is also a common operator in CNNs. D2L introduces this operator in detail.

From the calculation perspective, for a given value, batch_norm subtracts the \(mean\) from it, and then divides it by the square root of the \(variance\), no different from a regular normalization. It is called batch_norm because the mean and variance are obtained from the batches of data processed during training. After that, batch_norm also applies an affine transformation to the value, i.e. multiplies it by a scale value \(gamma\) and then adds a shift value \(beta\). \(Gamma\) and \(beta\) are obtained from the gradient computation during training. Lastly, a small positive value \(epsilon\) is added to the variance to prevent the divisor from being 0.

In the case of inference, both the mean and variance are given, so the process of batch_norm is just a combination of several simple element-wise operations.
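
As a concrete illustration of this element-wise computation, the following NumPy snippet normalizes a single value; all of the numbers below are made up for illustration and are not taken from any model.

import numpy as np

# illustration only: apply inference-time batch_norm to one value
x, mean, var = 2.0, 0.5, 4.0        # hypothetical input and its given statistics
gamma, beta, eps = 1.5, 0.1, 1e-5   # hypothetical scale, shift, and epsilon
out = (x - mean) / np.sqrt(var + eps) * gamma + beta
print(out)  # approximately (2.0 - 0.5) / 2.0 * 1.5 + 0.1 = 1.225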

import tvm
from tvm import te
import d2ltvm
import numpy as np

3.6.1. Compute definition

In practice, we are not going to perform batch_norm on a single value. Instead, batch_norm is executed on the output of a convolution, namely, 3-D data in (channel, height, width). Data in different channels have different values of \(mean\), \(variance\), \(gamma\), and \(beta\). The calculation can be expressed as the following formula.

(3.6.1)\[out[i,:,:] = \frac{data[i,:,:] - mean[i]}{\sqrt{var[i]+\epsilon}} \ * gamma[i] + beta[i]\]

During model training, \(mean\) and \(var\) are computed from the input \(data\). However, in model inference which we focus on here, \(mean\) and \(var\) are given; therefore we don’t need to compute them from \(data\).

We will now define the compute of this formula. Essentially, batch_norm is a combination of a number of simple broadcast and element-wise calculations. Note that in Section 3.1 we defined a limited broadcast_add to perform broadcast addition only for 2-D tensors. If we generalize it to more dimensions and more kinds of calculation, we can reuse the generalized operators to compose the batch_norm operator. This is actually what TVM does.
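
To make this composition concrete, here is a small NumPy sketch of formula (3.6.1) that relies only on broadcasting; the helper batch_norm_np and the toy shapes are ours for illustration and are not part of the d2ltvm package.

def batch_norm_np(data, mean, var, gamma, beta, eps=1e-5):
    """NumPy sketch of (3.6.1); data is (c, n, n), the statistics are (c, 1, 1)."""
    return (data - mean) / np.sqrt(var + eps) * gamma + beta

# the (c, 1, 1) tensors broadcast across the (c, n, n) data channel by channel
x = np.ones((2, 3, 3), dtype='float32')
args = [np.ones((2, 1, 1), dtype='float32') for _ in range(4)]
print(batch_norm_np(x, *args).shape)  # (2, 3, 3)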

Here, for simplicity, we use TVM's basic operators for broadcast calculation. These operators are defined in TOPI, which stands for Tensor OPerator Inventory. TOPI follows the NumPy convention to override the arithmetic operators (i.e. +, -, *, /) for broadcast calculation. The element-wise square root can be found in TOPI, too.

The code snippet to define batch_norm is as follows.

# Save to the d2ltvm package.
import topi

def batch_norm(c, n, eps=1e-5):
    """batch normalization

    c : channels
    n : input width and height
    eps : small positive value to prevent division by zero
    """

    X = te.placeholder((c, n, n), name='X')
    Mean = te.placeholder((c, 1, 1), name='Mean')
    Var = te.placeholder((c, 1, 1), name='Var')
    Gamma = te.placeholder((c, 1, 1), name='Gamma')
    Beta = te.placeholder((c, 1, 1), name='Beta')
    C1 = X - Mean
    C2 = topi.sqrt(Var + eps)
    Y = C1 / C2 * Gamma + Beta
    return X, Mean, Var, Gamma, Beta, Y

We can then compile it and print the IR. The IR contains several stages but should be easy to follow.

c = 32
n = 28
X, Mean, Var, Gamma, Beta, Y = batch_norm(c, n)

sch = te.create_schedule(Y.op)
mod = tvm.build(sch, [X, Mean, Var, Gamma, Beta, Y])

print(tvm.lower(sch, [X, Mean, Var, Gamma, Beta], simple_mode=True))
// attr [T_subtract] storage_scope = "global"
allocate T_subtract[float32 * 25088]
// attr [T_add] storage_scope = "global"
allocate T_add[float32 * 32]
produce T_subtract {
  for (ax0, 0, 32) {
    for (ax1, 0, 28) {
      for (ax2, 0, 28) {
        T_subtract[(((ax0*784) + (ax1*28)) + ax2)] = (X[(((ax0*784) + (ax1*28)) + ax2)] - Mean[ax0])
      }
    }
  }
}
produce T_add {
  for (ax0, 0, 32) {
    T_add[ax0] = (Var[ax0] + 1e-05f)
  }
}
produce compute {
  for (i0, 0, 32) {
    T_add[i0] = sqrt(T_add[i0])
  }
}
produce T_divide {
  for (ax0, 0, 32) {
    for (ax1, 0, 28) {
      for (ax2, 0, 28) {
        T_subtract[(((ax0*784) + (ax1*28)) + ax2)] = (T_subtract[(((ax0*784) + (ax1*28)) + ax2)]/T_add[ax0])
      }
    }
  }
}
produce T_multiply {
  for (ax0, 0, 32) {
    for (ax1, 0, 28) {
      for (ax2, 0, 28) {
        T_subtract[(((ax0*784) + (ax1*28)) + ax2)] = (T_subtract[(((ax0*784) + (ax1*28)) + ax2)]*Gamma[ax0])
      }
    }
  }
}
produce T_add {
  for (ax0, 0, 32) {
    for (ax1, 0, 28) {
      for (ax2, 0, 28) {
        T_subtract[(((ax0*784) + (ax1*28)) + ax2)] = (T_subtract[(((ax0*784) + (ax1*28)) + ax2)] + Beta[ax0])
      }
    }
  }
}

To execute it, we need to create data for batch_norm. Similar to how we got the conv and pooling data in the previous sections, we define a get_bn_data method to generate the data for batch_norm. One tricky thing is that the variance must be non-negative. Therefore, we move the mean of the random number generator's normal distribution to 1 (by default it has mean 0 and standard deviation 1), and take the absolute value of the generated numbers.

After getting the data, we can simply call the compiled module to execute.

# Save to the d2ltvm package.
def get_bn_data(c, n, constructor=None):
    """Return the batch norm data, mean, variance, gamma and beta tensors.
       Also return the empty tensor for output.

    c : channels
    n : input width and height
    constructor : user-defined tensor constructor
    """
    np.random.seed(0)
    data = np.random.normal(size=(c, n, n)).astype('float32')
    mean = np.random.normal(size=(c, 1, 1)).astype('float32')
    # move the mean of the normal distribution to be 1
    var = np.random.normal(loc=1.0, size=(c, 1, 1)).astype('float32')
    # make sure all variance numbers are not negative
    var = np.absolute(var)
    gamma = np.random.normal(size=(c, 1, 1)).astype('float32')
    beta = np.random.normal(size=(c, 1, 1)).astype('float32')
    out = np.empty((c, n, n), dtype='float32')
    if constructor:
        data, mean, var, gamma, beta, out = \
        (constructor(x) for x in [data, mean, var, gamma, beta, out])
    return data, mean, var, gamma, beta, out

data, mean, var, gamma, beta, out = get_bn_data(c, n, tvm.nd.array)
mod(data, mean, var, gamma, beta, out)
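
Before comparing with MXNet in the next subsection, we can optionally sanity-check the result against a direct NumPy evaluation of formula (3.6.1). This check is only a sketch we add for illustration; the original workflow relies on the MXNet baseline below.

# optional sanity check against the NumPy form of (3.6.1)
ref = (data.asnumpy() - mean.asnumpy()) / np.sqrt(var.asnumpy() + 1e-5) \
      * gamma.asnumpy() + beta.asnumpy()
np.testing.assert_allclose(out.asnumpy(), ref, atol=1e-5)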

3.6.2. MXNet Baseline

We use the batch_norm function of MXNet as the baseline to check the correctness of our compiled function. This function in MXNet is defined to be generic for both training and inference. For the inference case we discuss here, we need to set the corresponding input arguments properly. One is \(use\_global\_stats\), which needs to be set to True so that the input mean and variance are used for the batch_norm computation instead of being computed from the input data (as training would do). The other is \(fix\_gamma\), which needs to be set to False so that the input \(gamma\) is used instead of setting \(gamma\) to be all 1.

Lastly, as we have discussed in other cases, MXNet's batch_norm takes input data in 4-D, including batch as the outermost dimension. So we will expand this dimension in the data accordingly.

import mxnet as mx

# Save to the d2ltvm package.
def get_bn_data_mxnet(c, n, ctx='cpu'):
    ctx = getattr(mx, ctx)()
    data, mean, var, gamma, beta, out = get_bn_data(c, n,
                                      lambda x: mx.nd.array(x, ctx=ctx))
    data, out = data.expand_dims(axis=0), out.expand_dims(axis=0)
    return data, mean, var, gamma, beta, out

# Save to the d2ltvm package.
def batch_norm_mxnet(data, mean, var, gamma, beta, out, eps=1e-5):
    # use_global_stats=True to use the input mean and var instead of computing
    # the mean and var of the input data.
    # fix_gamma=False so that gamma won't be set to 1.
    mx.nd.BatchNorm(data, gamma, beta, mean, var, eps,
                    use_global_stats=True, fix_gamma=False, out=out)

data, mean, var, gamma, beta, out_mx = get_bn_data_mxnet(c, n)
batch_norm_mxnet(data, mean, var, gamma, beta, out_mx)

Finally, we check if our results are close enough to the results produced by MXNet.

np.testing.assert_allclose(out_mx[0].asnumpy(), out.asnumpy(), atol=1e-5)

3.6.3. Summary

  • From the computation perspective, batch_norm is a combination of a number of simple broadcast and element-wise operators, which can be easily obtained from TVM's Tensor OPerator Inventory (TOPI).

  • In inference, \(mean\) and \(var\) of batch_norm are pre-defined.