前言
按照时间上的迭代顺序,近些年神经网络先后出现了 Gradient Descent (GD)、Momentum、Adaptive Gradient (AdaGrad)、Root Mean Square prop (RMSprop)、Adaptive Moment estimation (Adam) 等优秀的优化器。到如今,大部分 NLP 预训练模型已不再使用这些方法,而是使用 Adam Weight Decay Regularization (AdamW) 和去年首度亮相的 Layer-wise Adaptive Moments optimizer for Batching training (LAMB)。为何最为传统的 GD,包括衍生的 stochastic GD、mini-batch GD 优化器已不再使用,下文会有详细的介绍。
Gradient Descent (GD)
梯度下降法是最为经典的凸优化优化器,思想也非常明确:通过 loss 反向传导计算参数的梯度,参数往哪个方向跑可以让 loss 下降,就让参数往哪个方向更新:
ΔWk?=?Wk??loss?=?Zn??loss??Zn?1??Zn??...?Wk??Zk+1??
Wk?←Wk??αΔWk?
需要注意的是,
Wk? 中的每一个浮点元素的梯度计算和梯度更新,相互之间是完全独立的,这对于理解梯度更新的机理非常重要。上式中,
α 为学习率,通常是一个固定的超参数,学习率越高,收敛越快。但需要注意控制范围。学习率过大,容易造成梯度跨过参数的局部最优点造成参数震荡;学习率过小,会导致训练过程过于漫长。为避免参数震荡,使用 GD 时,学习率通常设置在一个较低值,且训练的 batch_size 越大,学习率越低。梯度裁剪虽能一定程度上解决梯度震荡的问题,但由于输出的概率分布发生偏移,模型收敛也受到一定负面影响,因此需尽可能避免对梯度裁剪的依赖。
Adaptive Moment estimation (Adam)
为解决 GD 中固定学习率带来的不同参数间收敛速度不一致的弊端,AdaGrad 和 RMSprop 诞生出来,为每个参数赋予独立的学习率。计算梯度后,梯度较大的参数获得的学习率较低,反之亦然。此外,为避免每次梯度更新时都独立计算梯度,导致梯度方向持续变化,Momentum 将上一轮梯度值加入到当前梯度的计算中,通过某种权重对两者加权求和,获得当前批次参数更新的更新值。 Adam 结合了这两项考虑,既为每一个浮点参数自适应性地设置学习率,又将过去的梯度历史纳入考量:
mt?=β1?mt?1?+(1?β1?)ΔW
vt?=β2?vt?1?+(1?β2?)ΔW2
mt?^?=1?β1t?mt??
vt?^?=1?β2t?vt??
Wt?←Wt?1??vt?^?
实际使用中,通常
β1?=0.9,
β2?>0.9。BERT 源代码中,预训练的
β2? 为 0.98,微调的
β2? 为 0.999,其目的是为了减少对预训练中得到的原始参数结构的破坏,使收敛更为平缓。此外,
m0? 和
v0? 皆为初始化得来,因此训练时参数种子的设置往往对模型结果的影响较大。从上述公式可以看出,训练前期的学习率和梯度更新是比较激进的,到后期逐渐平稳。
虽然 Adam 优化器的使用会导致内存中多出两倍于原参数体量的占用,但与之换来的训练收益使得学术界并没有放弃这一高效的方法。
Adam Weight Decay Regularization (AdamW)
Adam 虽然收敛速度快,但没能解决参数过拟合的问题。学术界讨论了诸多方案,其中包括在损失函数中引入参数的 L2 正则项。这样的方法在其他的优化器中或许有效,但会因为 Adam 中自适应学习率的存在而对使用 Adam 优化器的模型失效。AdamW 的出现便是为了解决这一问题,达到同样使参数接近于 0 的目的。具体的举措,是在最终的参数更新时引入参数自身:
mt?=β1?mt?1?+(1?β1?)ΔW
vt?=β2?vt?1?+(1?β2?)ΔW2
mt?^?=1?β1t?mt??
vt?^?=1?β2t?vt??
Wt?←Wt?1??α(vt?^?
λ 即为权重衰减因子,常见的设置为 0.005/0.01。这一优化策略目前正广泛应用于各大预训练语言模型。
Layer-wise Adaptive Moments optimizer for Batching training (LAMB)
LAMB 优化器是 2019 年出现的一匹新秀,原论文标题后半部分叫做 “Training BERT in 76 Minutes”,足以看出其野心之大。 LAMB 出现的目的是加速预训练进程,这个优化器也成为 NLP 社区为泛机器学习领域做出的一大贡献。在使用 Adam 和 AdamW 等优化器时,一大问题在于 batch size 存在一定的隐式上限,一旦突破这个上限,梯度更新极端的取值会导致自适应学习率调整后极为困难的收敛,从而无法享受增加的 batch size 带来的提速增益。LAMB 优化器的作用便在于使模型在进行大批量数据训练时,能够维持梯度更新的精度:
mt?=β1?mt?1?+(1?β1?)ΔW
vt?=β2?vt?1?+(1?β2?)ΔW2
rt?=vt?
Wt?←Wt?1??α??(∣∣rt?+λWt?1?∣∣∣∣Wt?1?∣∣?)(rt?+λWt?1?)
其中,
? 是一个可选择的映射函数,一种是
?(z)=z,另一种则为起到归一化作用的
?(z)=min(max(z,γl?),γu?)。
γl? 和
γu? 为预先设定的超参数,分别代表参数调整的下界和上界。这一简单的调整所带来的实际效果非常显著。使用 AdamW 时,batch size 超过 512 便会导致模型效果大幅下降,但在 LAMB 下,batch size 可以直接提到 32,000 而不会导致精度损失。
由于在下游微调预训练模型时,通常无需过大的数据集,因而 LAMB 仅在预训练环节使用。遗憾的是,LAMB 在 batch size 512 以下时无法起到显著作用,目前只能作为大体量财团的工具。
附录
以下是 LAMB 优化器的 tensorflow1.x 代码,可作为参考以理解算法,具体的代码出处已无法找寻。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 | class LAMBOptimizer(tf.train.Optimizer): ''' LAMBOptimizer optimizer. # Important Note - This is NOT an official implementation. - LAMB optimizer is changed from arXiv v1 ~ v3. - We implement v3 version (which is the latest version on June, 2019.). - Our implementation is based on `AdamWeightDecayOptimizer` in BERT (provided by Google). # References - LAMB optimier: https://github.com/ymcui/LAMB_Optimizer_TF - Large Batch Optimization for Deep Learning: Training BERT in 76 minutes. https://arxiv.org/abs/1904.00962v3 - BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. https://arxiv.org/abs/1810.04805 # Parameters - There is nothing special, just the same as `AdamWeightDecayOptimizer`. ''' def __init__(self, learning_rate, weight_decay_rate=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-6, exclude_from_weight_decay=None, name="LAMBOptimizer"): """Constructs a LAMBOptimizer.""" super(LAMBOptimizer, self).__init__(False, name) self.learning_rate = learning_rate self.weight_decay_rate = weight_decay_rate self.beta_1 = beta_1 self.beta_2 = beta_2 self.epsilon = epsilon self.exclude_from_weight_decay = exclude_from_weight_decay def apply_gradients(self, grads_and_vars, global_step=None, name=None): """See base class.""" assignments = [] for (grad, param) in grads_and_vars: if grad is None or param is None: continue param_name = self._get_variable_name(param.name) m = tf.get_variable( name=param_name + "/lamb_m", shape=param.shape.as_list(), dtype=tf.float32, trainable=False, initializer=tf.zeros_initializer()) v = tf.get_variable( name=param_name + "/lamb_v", shape=param.shape.as_list(), dtype=tf.float32, trainable=False, initializer=tf.zeros_initializer()) # Standard Adam update. next_m = ( tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) next_v = ( tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, tf.square(grad))) update = next_m / (tf.sqrt(next_v) + self.epsilon) # Just adding the square of the weights to the loss function is *not* # the correct way of using L2 regularization/weight decay with Adam, # since that will interact with the m and v parameters in strange ways. # # Instead we want ot decay the weights in a manner that doesn't interact # with the m/v parameters. This is equivalent to adding the square # of the weights to the loss with plain (non-momentum) SGD. if self._do_use_weight_decay(param_name): update += self.weight_decay_rate * param ############## BELOW ARE THE SPECIFIC PARTS FOR LAMB ############## # Note: Here are two choices for scaling function \phi(z) # minmax: \phi(z) = min(max(z, \gamma_l), \gamma_u) # identity: \phi(z) = z # The authors does not mention what is \gamma_l and \gamma_u # UPDATE: after asking authors, they provide me the code below. # ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where( # math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0) r1 = tf.sqrt(tf.reduce_sum(tf.square(param))) r2 = tf.sqrt(tf.reduce_sum(tf.square(update))) r = tf.where(tf.greater(r1, 0.0), tf.where(tf.greater(r2, 0.0), r1 / r2, 1.0), 1.0) eta = self.learning_rate * r update_with_lr = eta * update next_param = param - update_with_lr assignments.extend( [param.assign(next_param), m.assign(next_m), v.assign(next_v)]) return tf.group(*assignments, name=name) def _do_use_weight_decay(self, param_name): """Whether to use L2 weight decay for `param_name`.""" if not self.weight_decay_rate: return False if self.exclude_from_weight_decay: for r in self.exclude_from_weight_decay: if re.search(r, param_name) is not None: return False return True def _get_variable_name(self, param_name): """Get the variable name from the tensor name.""" m = re.match("^(.*):\\d+$", param_name) if m is not None: param_name = m.group(1) return param_name |