这是论文《在神经网络中提炼知识》的 PyT orch 实现/教程。
这是一种使用经过训练的大型网络中的知识来训练小型网络的方法;即从大型网络中提炼知识。
直接在数据和标签上训练时,具有正则化或模型集合(使用 dropout)的大型模型比小型模型的概化效果更好。但是,在大型模型的帮助下,可以训练小模型以更好地进行概括。较小的模型在生产中更好:速度更快、计算更少、内存更少。
经过训练的模型的输出概率比标签提供的信息更多,因为它也会为错误的类分配非零概率。这些概率告诉我们,样本有可能属于某些类别。例如,在对数字进行分类时,当给定数字 7 的图像时,广义模型会给出7的高概率,给2的概率很小但不是零,而给其他数字分配几乎为零的概率。蒸馏利用这些信息来更好地训练小型模型。