Loss = \alpha \cdot \mathcal L_{CE} + \beta \cdot \tau ^ 2 \cdot \mathcal L_{KL}
$$


&emsp;&emsp; In general, $\alpha + \beta = 1$ should be maintained; in practice, $\alpha$ is usually set to $0.1$ and $\beta$ to $0.9$. As for why the soft-loss term $\operatorname{KL}$ is multiplied by $\tau^2$: the short explanation is that it keeps the gradients of the soft loss and the hard loss balanced. The detailed derivation is given in the optional section [Derivation of the Loss Function](#derivation-of-the-loss-function-optional).
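The combined loss can be sketched in NumPy as follows. This is a minimal illustration rather than the chapter's training code; the helper names `softmax` and `distill_loss` and the example logits/labels are invented here, with `v`, `u`, `y` following the notation used in the derivation section.

```python
import numpy as np

def softmax(z, tau=1.0):
    # Temperature-scaled softmax; shift by the max for numerical stability.
    z = np.asarray(z, dtype=float) / tau
    e = np.exp(z - z.max())
    return e / e.sum()

def distill_loss(v, u, y, tau=4.0, alpha=0.1, beta=0.9):
    # v: student logits, u: teacher logits, y: one-hot label.
    q1 = softmax(v, 1.0)                      # student prediction at tau = 1
    q, p = softmax(v, tau), softmax(u, tau)   # student / teacher soft labels
    hard = -(np.asarray(y) * np.log(q1)).sum()    # L_CE
    soft = (p * (np.log(p) - np.log(q))).sum()    # L_KL = KL(p || q)
    # Multiply the soft loss by tau^2 to balance the two gradients.
    return alpha * hard + beta * tau ** 2 * soft
```

Both terms are non-negative, so the total loss is as well; the $\tau^2$ factor is exactly the balancing constant derived in the optional section.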


print('best_Test_Acc = ', best_test_acc)
best_Test_Acc = 95.89
```

## Derivation of the Loss Function (Optional)


### Derivative of the Softmax Function


&emsp;&emsp; Consider an arbitrary logits vector $\mathbf z = [z_1, z_2, ..., z_K] \in \mathbb{R}^{1\times K}$, where $K$ is the number of classes in the dataset. Passing it through the temperature-scaled Softmax yields the vector $\mathbf s = [s_1(\tau), s_2(\tau), ..., s_K(\tau)]$, where $s_i(\tau)$ is defined as:
$$
s_i(\tau) = \frac{e^{z_i/ \tau}}{\sum_{j=1}^K e^{z_j / \tau}}
$$
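The temperature's effect can be seen numerically (a NumPy sketch; the helper name `softmax_t` and the vector `z` are invented for this example): raising $\tau$ softens the distribution, moving probability mass from the largest logit toward the others.

```python
import numpy as np

def softmax_t(z, tau):
    # s_i(tau) = exp(z_i / tau) / sum_j exp(z_j / tau), shifted by the max for stability.
    z = np.asarray(z, dtype=float) / tau
    e = np.exp(z - z.max())
    return e / e.sum()

z = [3.0, 1.0, 0.2]
for tau in (1.0, 2.0, 5.0):
    print(tau, np.round(softmax_t(z, tau), 3))
```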


&emsp;&emsp; For any $z_k \in \mathbf z$, the partial derivative of $s_i(\tau)$ with respect to $z_k$ splits into two cases.


&emsp;&emsp; When $i = k$:
$$
\begin{align*}
\frac{\partial s_i(\tau)}{\partial z_k}
&= \frac{\partial}{\partial z_k}\frac{e^{z_k / \tau}}{\sum_{j=1}^K e^{z_j / \tau}} \\
&= \frac{\frac{\partial}{\partial z_k}e^{z_k / \tau}\ \sum_{j=1}^K e^{z_j / \tau} - e^{z_k / \tau}\ \frac{\partial}{\partial z_k}\sum_{j=1}^K e^{z_j / \tau}}{\left( \sum_{j=1}^K e^{z_j / \tau}\right) ^ 2} \\
&= \frac{\frac{1}{\tau}\ e^{z_k/ \tau}}{\sum_{j=1}^K e^{z_j / \tau}} - \frac{e^{z_k/ \tau }\ \frac{1}{\tau}\ e^{z_k/ \tau }}{\left( \sum_{j=1}^K e^{z_j / \tau}\right) ^ 2}\\
&= \frac{1}{\tau} (s_k(\tau) - s_k(\tau)\ s_k(\tau)) \\
&= \frac{1}{\tau}\ s_k(\tau)\ (1 - s_k(\tau))
\end{align*}
$$

&emsp;&emsp; When $i \neq k$:
$$
\begin{align*}
\frac{\partial s_i(\tau)}{\partial z_k}
&= \frac{\partial}{\partial z_k}\frac{e^{z_i / \tau}}{\sum_{j=1}^K e^{z_j / \tau}} \\
&= \frac{\frac{\partial}{\partial z_k}e^{z_i / \tau}\ \sum_{j=1}^K e^{z_j / \tau} - e^{z_i / \tau}\ \frac{\partial}{\partial z_k}\sum_{j=1}^K e^{z_j / \tau}}{\left( \sum_{j=1}^K e^{z_j / \tau}\right) ^ 2} \\
&= 0 - \frac{e^{z_i/ \tau }\ \frac{1}{\tau}\ e^{z_k/ \tau }}{\left(\sum_{j=1}^K e^{z_j / \tau}\right) ^ 2}\\
&= -\frac{1}{\tau}\ s_i(\tau)\ s_k(\tau)
\end{align*}
$$

&emsp;&emsp; Therefore, the partial derivative of $s_i(\tau)$ with respect to $z_k$ is
$$
\frac{\partial s_i(\tau)}{\partial z_k} =
\left\{
\begin{matrix}
\frac{1}{\tau}\ s_k(\tau)\ (1 - s_k(\tau))& \text{if } i = k \\
-\frac{1}{\tau}\ s_i(\tau)\ s_k(\tau) & \text{if } i \neq k
\end{matrix}
\right.
$$
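The two cases can be checked against a finite-difference approximation (a sketch; `softmax_t`, `softmax_jacobian`, and the test values are invented for this check). Note that both cases collapse into the single matrix expression $\frac{1}{\tau}\,(\operatorname{diag}(\mathbf s) - \mathbf s \mathbf s^\top)$.

```python
import numpy as np

def softmax_t(z, tau):
    z = np.asarray(z, dtype=float) / tau
    e = np.exp(z - z.max())
    return e / e.sum()

def softmax_jacobian(z, tau):
    # J[i, k] = ds_i/dz_k = (1/tau) * s_i * (delta_ik - s_k),
    # which combines the i = k and i != k cases above.
    s = softmax_t(z, tau)
    return (np.diag(s) - np.outer(s, s)) / tau

z, tau, eps = np.array([0.5, -1.0, 2.0]), 3.0, 1e-6
J = softmax_jacobian(z, tau)
for k in range(len(z)):
    dz = np.zeros_like(z)
    dz[k] = eps
    # The central difference in direction z_k should match column k of the Jacobian.
    numeric = (softmax_t(z + dz, tau) - softmax_t(z - dz, tau)) / (2 * eps)
    assert np.allclose(J[:, k], numeric, atol=1e-8)
```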


### Derivative of the Hard Loss (CE)

&emsp;&emsp; For any component $v_k$ of the student model's output logits $\mathbf{v} = [v_1, v_2, ..., v_K] \in \mathbb{R}^{1\times K}$, the gradient of the hard loss $\mathcal L_{CE}$ with respect to $v_k$ is as follows, where $\mathbf y$ is the one-hot label, so $\sum_{j=1}^K y_j = 1$:
$$
\begin{align*}
\mathcal L_{CE} &= \operatorname {CE}(\mathbf q(\tau = 1), \mathbf y) \\
&= \sum _ {j=1}^K - y_j \log q_j(\tau = 1) \\
\frac{\partial \mathcal L_{CE}}{\partial v_k} &= \frac{\partial}{\partial v_k}\sum _ {j=1}^K - y_j \log q_j(\tau = 1)\\
&= \frac{\partial}{\partial v_k}\sum _ {j=1, j\neq k}^K - y_{j} \log q_j(\tau = 1) - \frac{\partial}{\partial v_k}\ y_k \log q_k(\tau = 1)\\
&= \sum _ {j=1, j\neq k}^K - y_{j}\ \frac{1}{q_j(\tau=1)}\left[ -\, q_j(\tau=1)\ q_k(\tau=1) \right] \\ &\ \ \ \ \ \ - y_k\ \frac{1}{q_k(\tau=1)}\left[\, q_k(\tau=1)\,(1-q_k(\tau=1))\right]\\
&= (1-y_k)\ q_k(\tau=1) - y_k\ (1-q_k(\tau=1))\\
&= q_k(\tau=1) - y_k \\
\end{align*}
$$
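The closed form $q_k(\tau=1) - y_k$ can be verified with a finite-difference check (a sketch; the helper names and the example logits/label are invented):

```python
import numpy as np

def softmax(z):
    e = np.exp(np.asarray(z, dtype=float) - np.max(z))
    return e / e.sum()

def ce(v, y):
    # Hard loss: cross-entropy between q(tau=1) = softmax(v) and the one-hot label y.
    return -(np.asarray(y) * np.log(softmax(v))).sum()

v = np.array([1.2, -0.3, 0.7])
y = np.array([0.0, 1.0, 0.0])
analytic = softmax(v) - y          # the result q_k(tau=1) - y_k derived above
eps = 1e-6
for k in range(len(v)):
    dv = np.zeros_like(v)
    dv[k] = eps
    numeric = (ce(v + dv, y) - ce(v - dv, y)) / (2 * eps)
    assert abs(analytic[k] - numeric) < 1e-8
```

Note that the gradient components sum to zero, since both $\mathbf q(\tau=1)$ and $\mathbf y$ sum to one.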


### Derivative of the Soft Loss (KL)

&emsp;&emsp; The soft loss $\mathcal L_{KL}$ is the KL divergence between the teacher model's soft labels $\mathbf p(\tau)$ and the student model's soft labels $\mathbf q(\tau)$. Since the teacher's $\mathbf p(\tau)$ does not depend on the student's logits, the gradient of $\mathcal L_{KL}$ with respect to any component $v_k$ of the student model's output is
$$
\begin{align*}
\mathcal{L}_{\operatorname{KL}} &= \operatorname {KL}(\mathbf{p}(\tau)\ \|\ \mathbf{q}(\tau)) \\
&= \sum_{j=1}^{K} \ p_j(\tau) \log \frac{p_j(\tau)}{q_j(\tau)} \\
\frac{\partial \mathcal{L}_{\operatorname{KL}}}{\partial v_k} &= \frac{\partial}{\partial v_k} \sum_{j=1}^{K} \ p_j(\tau) \log \frac{p_j(\tau)}{q_j(\tau)} \\
&= \frac{\partial}{\partial v_k} \sum_{j=1}^{K} \left( \ p_j(\tau) \log {p_j(\tau)} - p_j(\tau)\log{q_j(\tau)} \right) \\
&= \frac{\partial}{\partial v_k} \left(\sum_{j=1}^{K} - p_j(\tau)\log{q_j(\tau)}\right) \\
&= \frac{\partial}{\partial v_k} \left( \sum_{j=1, j\neq k} ^ K -p_j(\tau) \log q_j(\tau) -p_k(\tau) \log q_k(\tau) \right) \\
&= \sum_{j=1, j\neq k}^K \left( -p_j(\tau) \frac{\partial}{\partial v_k} \log q_j(\tau)\right) - \frac{\partial}{\partial v_k}\ p_k(\tau) \log q_k(\tau)\\
&= \sum_{j=1, j\neq k}^K -\frac{p_j(\tau)}{q_j(\tau)}\left[ -\frac{1}{\tau}q_j(\tau)q_k(\tau) \right] - \frac{p_k(\tau)}{q_k(\tau)}\left[ \frac{1}{\tau} q_k(\tau)(1 - q_k(\tau))\right]\\
&= \frac{1}{\tau} \sum_{j=1, j\neq k}^K p_j(\tau)\ q_k(\tau) - \frac{1}{\tau}\ p_k(\tau)(1 - q_k(\tau)) \\
&= \frac{1}{\tau} (1 - p_k(\tau))\ q_k(\tau) - \frac{1}{\tau}\ p_k(\tau)(1 - q_k(\tau)) \qquad \text{since } \sum_{j=1}^K p_j(\tau) = 1\\
&= \frac{1}{\tau} \left[ q_k(\tau) - p_k(\tau)q_k(\tau) - p_k(\tau) + p_k(\tau)q_k(\tau)\right] \\
&= \frac{q_k(\tau) - p_k(\tau)}{\tau}
\end{align*}
$$
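The result $\frac{q_k(\tau) - p_k(\tau)}{\tau}$ can also be verified numerically (a sketch; the helper names and the student/teacher logits are invented):

```python
import numpy as np

def softmax_t(z, tau):
    z = np.asarray(z, dtype=float) / tau
    e = np.exp(z - z.max())
    return e / e.sum()

def kl_soft_loss(v, u, tau):
    # KL(p(tau) || q(tau)) with teacher logits u fixed and student logits v variable.
    p, q = softmax_t(u, tau), softmax_t(v, tau)
    return (p * (np.log(p) - np.log(q))).sum()

v = np.array([0.4, 1.1, -0.6])
u = np.array([1.0, 0.2, 0.3])
tau, eps = 2.5, 1e-6
analytic = (softmax_t(v, tau) - softmax_t(u, tau)) / tau   # (q_k - p_k) / tau
for k in range(len(v)):
    dv = np.zeros_like(v)
    dv[k] = eps
    numeric = (kl_soft_loss(v + dv, u, tau) - kl_soft_loss(v - dv, u, tau)) / (2 * eps)
    assert abs(analytic[k] - numeric) < 1e-8
```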

### Taylor Approximation

&emsp;&emsp; For $e^x$, as $x$ approaches $0$ we have the first-order Taylor expansion $e^x \approx 1 + x$. This lets us expand both gradients when the logits are small.

&emsp;&emsp; Collecting the results, the gradients of the hard loss CE and the soft loss KL with respect to $v_k$ are:
$$
\left\{
\begin{matrix}
\frac{\partial \mathcal L_{CE}}{\partial v_k} = q_k(\tau=1) - y_k
\\
\frac{\partial \mathcal{L}_{\operatorname{KL}}}{\partial v_k} = \frac{1}{\tau}(q_k(\tau) - p_k(\tau))
\end{matrix}
\right.
$$

&emsp;&emsp; Expanding $\frac{\partial \mathcal L_{CE}}{\partial v_k}$, and assuming the logits are small and zero-mean ($\sum_{j=1}^K v_j = 0$):
$$
\begin{align*}
\frac{\partial \mathcal L_{CE}}{\partial v_k} &= q_k(\tau=1)-y_k\\
&= \frac{e^{v_k}}{\sum_{j=1}^K e^{v_j}} - y_k \\
&\approx \frac{1+v_k}{\sum_{j=1}^K (1 + v_j)} - y_k \\
&= \frac{1+v_k}{K} - y_k\\
\end{align*}
$$
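A quick sanity check of this approximation (a sketch; the small zero-mean logits are invented values): for such logits, $q_k(\tau=1) \approx \frac{1+v_k}{K}$ holds to within the second-order error of the Taylor expansion.

```python
import numpy as np

def softmax(z):
    e = np.exp(np.asarray(z, dtype=float) - np.max(z))
    return e / e.sum()

# Small, zero-mean logits make the first-order approximation accurate.
v = np.array([0.02, -0.05, 0.03])
K = len(v)
exact = softmax(v)
approx = (1 + v) / K
assert np.allclose(exact, approx, atol=1e-3)
```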

&emsp;&emsp; Expanding $\frac{\partial \mathcal{L}_{\operatorname{KL}}}{\partial v_k}$ in the same way (also assuming $\sum_{j=1}^K u_j = 0$ for the teacher logits):
$$
\begin{align*}
\frac{\partial \mathcal{L}_{\operatorname{KL}}}{\partial v_k} &= \frac{1}{\tau}(q_k(\tau) - p_k(\tau))\\
&= \frac{1}{\tau} \left(\frac{e^{v_k/\tau}}{\sum_{j=1}^K e^{v_j / \tau}} - \frac{e^{u_k/ \tau}}{\sum_{j=1}^K e^{u_j/\tau}}\right) \\
&\approx \frac{1}{\tau}\left(\frac{1 + v_k/\tau}{\sum_{j=1}^K (1 + v_j/\tau)} - \frac{1+u_k/\tau}{\sum_{j=1}^K (1+u_j/ \tau)}\right)\\
&= \frac{1}{\tau}\left(\frac{v_k/\tau - u_k/\tau}{K}\right) \\
&= \frac{1}{K \ \tau^2} (v_k - u_k)
\end{align*}
$$
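This scaling can be confirmed numerically (a sketch; `softmax_t` and the small zero-mean logits are invented values): for small logits, the exact soft-loss gradient $\frac{q_k(\tau)-p_k(\tau)}{\tau}$ is very close to $\frac{v_k - u_k}{K\tau^2}$.

```python
import numpy as np

def softmax_t(z, tau):
    z = np.asarray(z, dtype=float) / tau
    e = np.exp(z - z.max())
    return e / e.sum()

# Small, zero-mean student/teacher logits and a moderately large temperature.
v = np.array([0.004, -0.007, 0.003])
u = np.array([0.002, 0.001, -0.003])
K, tau = len(v), 4.0
exact = (softmax_t(v, tau) - softmax_t(u, tau)) / tau
approx = (v - u) / (K * tau ** 2)
assert np.allclose(exact, approx, atol=1e-6)
```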

&emsp;&emsp; Comparing the two results, the $v_k$ term in the hard-loss gradient is $\tau^2$ times the corresponding $v_k$ term in the soft-loss gradient. Therefore, when computing the final Loss, $\mathcal{L}_{\operatorname{KL}}$ is multiplied by $\tau^2$ to balance the gradients of the two losses.


## References