.. _ch_batch_norm:

Batch Normalization
===================

This section talks about how to use TVM to do batch normalization
(``batch_norm``). Like pooling, ``batch_norm`` is also a common operator
in CNN. D2L introduces this operator in `details `__.

From the calculation perspective, for a given value, ``batch_norm``
subtracts the :math:`mean` from it, and then divides it by the square
root of the :math:`variance`, no different from a regular
normalization. It is called ``batch_norm`` because the mean and
variance are obtained from batches of data during training. After that,
``batch_norm`` also applies an affine transformation to the value,
i.e., multiplies it by a scale value :math:`gamma` and then adds a
shift value :math:`beta`. :math:`Gamma` and :math:`beta` are obtained
from the gradient computation during training. Lastly, a small positive
value :math:`epsilon` is added to the variance to prevent the divisor
from being 0.

In the case of inference, both the mean and variance are already
determined, so the process of ``batch_norm`` is just a combination of
several simple element-wise operations.

.. code:: python

    import tvm
    from tvm import te
    import d2ltvm
    import numpy as np

Compute definition
------------------

In practice, we are not going to perform ``batch_norm`` on a single
value. Instead, ``batch_norm`` is executed on the output of a
convolution, namely, 3-D data in (channel, height, width). Data in
different channels have different values of :math:`mean`,
:math:`variance`, :math:`gamma`, and :math:`beta`. The calculation can
be expressed as the following formula.

.. math::

   out[i,:,:] = \frac{data[i,:,:] - mean[i]}{\sqrt{var[i]+\epsilon}} * gamma[i] + beta[i]

During model training, :math:`mean` and :math:`var` are computed from
the input :math:`data`. However, in model inference, which we focus on
here, :math:`mean` and :math:`var` are given; therefore we don't need
to compute them from :math:`data`.

We will define the compute of this formula. Essentially, ``batch_norm``
is a combination of a number of simple broadcast and element-wise
calculations. Note that in :numref:`ch_bcast_add` we defined a limited
``broadcast_add`` to perform only broadcast addition for 2-D tensors.
If we generalize it to more dimensions and more operations, we can
reuse it to compose the ``batch_norm`` operator. This is actually what
TVM does.

Here, for simplicity, we use TVM basic operators for broadcast
calculation. TVM operators are defined in ``TOPI``, which stands for
Tensor OPerator Inventory. It follows the NumPy convention to override
the arithmetic operators (i.e., ``+``, ``-``, ``*``, ``/``) for
broadcast calculation. The element-wise square root can be found in
``TOPI``, too. The code snippet to define ``batch_norm`` is as follows.

.. code:: python

    # Save to the d2ltvm package.
    import topi

    def batch_norm(c, n, eps=1e-5):
        """batch normalization

        c : channels
        n : input width and height
        eps : small positive value to prevent divide 0
        """
        X = te.placeholder((c, n, n), name='X')
        Mean = te.placeholder((c, 1, 1), name='Mean')
        Var = te.placeholder((c, 1, 1), name='Var')
        Gamma = te.placeholder((c, 1, 1), name='Gamma')
        Beta = te.placeholder((c, 1, 1), name='Beta')
        C1 = X - Mean
        C2 = topi.sqrt(Var + eps)
        Y = C1 / C2 * Gamma + Beta
        return X, Mean, Var, Gamma, Beta, Y
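
To make the broadcasting in the formula above concrete, here is a
minimal NumPy sketch of the same computation. The helper name
``batch_norm_np`` is ours for illustration only and is not part of the
``d2ltvm`` package; it simply mirrors the formula.

.. code:: python

    def batch_norm_np(data, mean, var, gamma, beta, eps=1e-5):
        """NumPy reference of the batch_norm formula.

        data is (c, n, n); mean, var, gamma and beta are (c, 1, 1), so
        NumPy broadcasting aligns the per-channel statistics with data.
        """
        return (data - mean) / np.sqrt(var + eps) * gamma + beta
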

We can then print the IR and compile it. The IR contains several stages
but should be easy to follow.

.. code:: python

    c = 32
    n = 28
    X, Mean, Var, Gamma, Beta, Y = batch_norm(c, n)

    sch = te.create_schedule(Y.op)
    mod = tvm.build(sch, [X, Mean, Var, Gamma, Beta, Y])

    print(tvm.lower(sch, [X, Mean, Var, Gamma, Beta], simple_mode=True))


.. parsed-literal::
    :class: output

    // attr [T_subtract] storage_scope = "global"
    allocate T_subtract[float32 * 25088]
    // attr [T_add] storage_scope = "global"
    allocate T_add[float32 * 32]
    produce T_subtract {
      for (ax0, 0, 32) {
        for (ax1, 0, 28) {
          for (ax2, 0, 28) {
            T_subtract[(((ax0*784) + (ax1*28)) + ax2)] = (X[(((ax0*784) + (ax1*28)) + ax2)] - Mean[ax0])
          }
        }
      }
    }
    produce T_add {
      for (ax0, 0, 32) {
        T_add[ax0] = (Var[ax0] + 1e-05f)
      }
    }
    produce compute {
      for (i0, 0, 32) {
        T_add[i0] = sqrt(T_add[i0])
      }
    }
    produce T_divide {
      for (ax0, 0, 32) {
        for (ax1, 0, 28) {
          for (ax2, 0, 28) {
            T_subtract[(((ax0*784) + (ax1*28)) + ax2)] = (T_subtract[(((ax0*784) + (ax1*28)) + ax2)]/T_add[ax0])
          }
        }
      }
    }
    produce T_multiply {
      for (ax0, 0, 32) {
        for (ax1, 0, 28) {
          for (ax2, 0, 28) {
            T_subtract[(((ax0*784) + (ax1*28)) + ax2)] = (T_subtract[(((ax0*784) + (ax1*28)) + ax2)]*Gamma[ax0])
          }
        }
      }
    }
    produce T_add {
      for (ax0, 0, 32) {
        for (ax1, 0, 28) {
          for (ax2, 0, 28) {
            T_subtract[(((ax0*784) + (ax1*28)) + ax2)] = (T_subtract[(((ax0*784) + (ax1*28)) + ax2)] + Beta[ax0])
          }
        }
      }
    }


To execute it, we need to create data for ``batch_norm``. Similar to
the previous sections that get conv and pooling data, we define a
``get_bn_data`` method to generate the data of ``batch_norm``. One
tricky thing is that the variance must be non-negative. Therefore, we
shift the mean of the random number generator's normal distribution to
1 (by default it has mean 0 and standard deviation 1), and then take
the absolute value of the generated numbers. After getting the data, we
can simply call the compiled module to execute.

.. code:: python

    # Save to the d2ltvm package.
    def get_bn_data(c, n, constructor=None):
        """Return the batch norm data, mean, variance, gamma and beta tensors.
           Also return the empty tensor for output.

        c : channels
        n : input width and height
        constructor : user-defined tensor constructor
        """
        np.random.seed(0)
        data = np.random.normal(size=(c, n, n)).astype('float32')
        mean = np.random.normal(size=(c, 1, 1)).astype('float32')
        # move the mean of the normal distribution to be 1
        var = np.random.normal(loc=1.0, size=(c, 1, 1)).astype('float32')
        # make sure all variance numbers are not negative
        var = np.absolute(var)
        gamma = np.random.normal(size=(c, 1, 1)).astype('float32')
        beta = np.random.normal(size=(c, 1, 1)).astype('float32')
        out = np.empty((c, n, n), dtype='float32')
        if constructor:
            data, mean, var, gamma, beta, out = \
                (constructor(x) for x in [data, mean, var, gamma, beta, out])
        return data, mean, var, gamma, beta, out

    data, mean, var, gamma, beta, out = get_bn_data(c, n, tvm.nd.array)
    mod(data, mean, var, gamma, beta, out)

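
As a quick sanity check, we can also compare the TVM result against the
illustrative NumPy helper ``batch_norm_np`` defined above (again, that
helper is our own sketch, not part of ``d2ltvm``). Since
``get_bn_data`` fixes the random seed, calling it again without a
constructor returns the same NumPy arrays.

.. code:: python

    data_np, mean_np, var_np, gamma_np, beta_np, _ = get_bn_data(c, n)
    out_np = batch_norm_np(data_np, mean_np, var_np, gamma_np, beta_np)
    np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5)
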

MXNet Baseline
--------------

We use the ``batch_norm`` function of MXNet as the baseline to check
the correctness of our compiled functions. This function in MXNet is
defined to be generic for both training and inference. In the inference
case that we discuss here, we need to set the corresponding input
arguments properly. One is ``use_global_stats``, which needs to be set
to ``True`` so that the input mean and variance are used for the
computation instead of being computed from the input data (as training
does). The other is ``fix_gamma``, which needs to be set to ``False``
so that the input :math:`gamma` is used instead of setting
:math:`gamma` to all ones.

Lastly, as we have discussed in other cases, MXNet ``batch_norm`` takes
4-D input data, with batch as the outermost dimension, so we expand
this dimension in the data accordingly.

.. code:: python

    import mxnet as mx

    # Save to the d2ltvm package.
    def get_bn_data_mxnet(c, n, ctx='cpu'):
        ctx = getattr(mx, ctx)()
        data, mean, var, gamma, beta, out = get_bn_data(c, n,
                                          lambda x: mx.nd.array(x, ctx=ctx))
        data, out = data.expand_dims(axis=0), out.expand_dims(axis=0)
        return data, mean, var, gamma, beta, out

    # Save to the d2ltvm package.
    def batch_norm_mxnet(data, mean, var, gamma, beta, out, eps=1e-5):
        # use_global_stats=True to use the input mean and var instead of
        # computing the mean and var of the input data.
        # fix_gamma=False so that gamma won't be set to 1.
        mx.nd.BatchNorm(data, gamma, beta, mean, var, eps,
                        use_global_stats=True, fix_gamma=False, out=out)

    data, mean, var, gamma, beta, out_mx = get_bn_data_mxnet(c, n)
    batch_norm_mxnet(data, mean, var, gamma, beta, out_mx)

Finally, we check if our results are close enough to the results
produced by MXNet.

.. code:: python

    np.testing.assert_allclose(out_mx[0].asnumpy(), out.asnumpy(), atol=1e-5)

Summary
-------

-  From the computation perspective, ``batch_norm`` is a combination of
   a number of simple broadcast and element-wise operators, which can
   be easily obtained from TVM's Tensor OPerator Inventory (TOPI).
-  In inference, :math:`mean` and :math:`var` of ``batch_norm`` are
   pre-defined.