简洁实现
1 | import torch |
使用数学方法实现的softmax函数的缺陷
我们已经知道softmax函数的具体定义为:
$$
\hat{y}{j}=\frac{\exp \left(o{j}\right)}{\sum_{k} \exp \left(o_{k}\right)}
$$
现在思考这样一个问题,如果oj的值为100,会发生什么?
考虑到float32的表示范围,e^100根本无法表示,此时会产生数值上溢!
一种解决办法是将softmax改进,将每一项都减去最大的项,最后的结果会保持不变,即:
$$
\hat{y}_j = \frac{\exp(o_j - \max(o_k)) \exp(\max(o_k))}{\sum_k \exp(o_k - \max(o_k)) \exp(\max(o_k))}
= \frac{\exp(o_j - \max(o_k))}{\sum_k \exp(o_k - \max(o_k))}.
$$
这样做又会带来一个新的问题,假设现在oj - max(o_k)的值为-100, 那么又会产生数值下溢问题,再几次反向传播后,很容易出现梯度变为NAN的问题
由于softmax函数输出的是每个类别的预测概率,这些中间变量最终会通过交叉熵损失函数(crossEntropy)转换为损失值,而交叉熵损失函数又需要使用到log函数,log和e组合到了一起,就让事情变得简单起来,我们可以直接省去softmax函数的计算,直接将交叉熵损失函数变为下列形式:
$$
\log(\hat{y}_j) = \log\left( \frac{\exp(o_j - \max(o_k))}{\sum_k \exp(o_k - \max(o_k))} \right)
= \log(\exp(o_j - \max(o_k))) - \log\left( \sum_k \exp(o_k - \max(o_k)) \right)
= o_j - \max(o_k) - \log\left( \sum_k \exp(o_k - \max(o_k)) \right).
$$
我们没有将softmax概率传递到损失函数中, 而是在交叉熵损失函数中传递未规范化的预测,并同时计算softmax及其对数