机器学习中的层规范化(及PyTorch示例)

2023-04-24

机器学习中的层规范化(及PyTorch示例)

机器学习中层规范化的快速介绍,以及PyTorch示例代码。

导语

训练机器学习算法是一项颇具挑战性的任务,尤其是在使用来自现实世界的数据集时。在人们可能遇到的众多问题中,中间激活层统计量的稳定性问题是非常常见的一个。这篇文章中,我们将简要讨论保证统计量稳定性的常用方法之一:层规范化。

到底什么是层规范化?

遇到的问题

正如你所知道的,训练一个机器学习模型是一个随机化的过程。其根源在于初始化即使是最常见的优化器(如SGD、Adam 等)在本质上都是一个随机的过程。

因此,机器学习的优化往往有收敛到解决方案空间上尖锐的(不可泛化的)最小值的风险,从而导致巨大的梯度变化。也就是简单地说:激活层的结果(即非线性层的输出)有骤升到非常大的值的趋势。这至少可以说是一种不理想的情况,而解决这个问题的最常见方法就是使用批规范化。

但是,这里有一个问题。一旦批大小减少,批规范化很快就会失效。然而随着如今机器学习算法在数据分辨率方面的提高,这成为一个严重问题:因为为了便于将数据放入内存,批大小就需要很小。此外,进行批规范化需要计算每一层激活结果的均值/方差。此方法不适用于迭代式的模型(如 RNN),由于这些层的统计量估计值取决于序列的长度(即同一隐藏层被调用的次数)。

解决的方法

LayerNorm提供了一种同时解决上述两种问题的方法,即通过计算每一批激活结果中各项的统计量(即均值和方差),来规范化这些项。

举例来说,给定一个形如 [N,C,H,W]的样本,LayerNorm会计算每一批数据中每一个形如[C,H,W]的元素的均值和方差(如下图)。这种方法不仅解决了前面提到的问题,并且还不需要存储均值和方差来进行推理(而批规范化层在训练时需要这么做)。

代码实现

在PyTorch中实现层规范化是一件相对简单的事情。你需要做的就是使用torch.nn.LayerNorm()

不过对于卷积神经网络,还需要在给定执行卷积时使用的参数的情况下计算输出激活结果的形状,如下函数calc_activation_shape()给出了一个简要的实现。

class Network(torch.nn.Module):
    @staticmethod
    def calc_activation_shape(
        dim, ksize, dilation=(1, 1), stride=(1, 1), padding=(0, 0)
    ):
        def shape_each_dim(i):
            odim_i = dim[i] + 2 * padding[i] - dilation[i] * (ksize[i] - 1) - 1
            return (odim_i / stride[i]) + 1

        return shape_each_dim(0), shape_each_dim(1)

    def __init__(self, idim, num_classes=10):
        self.layer1 = torch.nn.Conv2D(3, 5, 3)
        ln_shape = Network.calc_activation_shape(idim, 3) # <--- Calculate the shape of output of Convolution
        self.norm1 = torch.nn.LayerNorm([5, *ln_shape]) # <--- Normalize activations over C, H, and W (see fig.above)
        self.layer2 = torch.nn.Conv2D(5, 10, 3)
        ln_shape = Network.calc_activation_shape(ln_shape, 3)
        self.norm2 = torch.nn.LayerNorm([10, *ln_shape])
        self.layer3 = torch.nn.Dense(num_classes)

    def __call__(self, inputs):
        x = F.relu(self.norm1(self.layer1(input)))
        x = F.relu(self.norm2(self.layer2(x)))
        x = F.sigmoid(self.layer3(x))
        return x

我们在colab notebook中对使用和不使用层规范化化的模型进行了比对,如下表所示。这里可以看出Layer Norm 的表现非常出色。 (注意:我们取了 4 次运行的平均值,实线表示这些运行的平均结果,浅色区域表示标准差。)