↑ 点击蓝字 关注视学算法
作者丨汤凯华@知乎
来源丨https://zhuanlan.zhihu.com/p/259569655
编辑丨极市平台
本文主要介绍我们组今年被NeurIPS 2020接收的论文《Long-Tailed Classification by Keeping the Good and Removing the Bad Momentum Causal Effect》。目前代码已经在Github上开源,链接如下:
https://github.com/KaihuaTang/Long-Tailed-Recognition.pytorch
论文链接:https://kaihuatang.github.io/Files/long-tail.pdf
这个工作从因果分析的角度,利用一种非常优雅的实现,提出了一种崭新的长尾问题的通用解决思路。而且实现非常简单,能够广泛适用于各种不同类型的任务。之前几期介绍了很多过往的工作,然而有几个问题(缺陷)却一直萦绕在我的脑中没有被解决:
虽然利用数据集分布的re-sampling和re-weighting训练方法可以一定程度上缓解长尾分布的问题。然而这种利用其实是违背现实学习场景的,他们都需要在训练/学习之前,了解“未来”将要看到的数据分布,这显然不符合人类的学习模式,也因此无法适用于各种动态的数据流。
目前长尾分类最优的Decoupling算法依赖于2-stage的分步训练,这显然不符合深度学习end-to-end的训练传统,而论文本身也没有提出让人信服的理由解释为什么特征提取backbone需要在长尾分布下学,而偏偏classifier又需要re-balancing的学。
长尾分布下简单的图片分类问题和其他复杂问题(诸如物体检测和实例分割)研究的割裂,目前长尾分布下图片分类问题的算法日趋复杂,导致很难运用于本来框架就很繁琐的检测分割等任务。而我觉得长尾问题的本质都是相似的,真正的解决方案一定是简洁的,可以通用的。
基于上面这些问题,也就最终诞生了我们的这篇工作。我们提出的De-confound-TDE的优势如下:
我们的训练过程完全不依赖于提前获取的数据分布,只需要在传统训练框架的基础上统计一个特征的移动平均向量,并且这个平均特征在训练中并不会参与梯度计算(只在测试时使用)。这也就解决了传统长尾分类方法依赖“提前获取未来数据分布”的问题。
尽管我们的测试过程和训练过程有所不同,但我们的模型是一次训练到位的,并不需要依赖繁琐的多步训练,这大大简化了拓展至其他任务时的修改成本。
并且,我们成功的将这个方法运用于图片分类(ImageNet-LT,Long-tailed CIFAR-10/-100)和物体检测/实例分割(LVIS dataset)等多个任务,均取得了最优的结果(截止至我们投稿也就是2020年5月)。这证明了我们的方法可以作为继re-balancing之后又一个在长尾数据下通用的Strong Single-Stage baseline。
长尾分布这个问题是什么我已经在往期文章里介绍过了,我一直觉得大家普遍运用的re-balancing不是一种方法而更像是一个trick,当我决定做这个task时,我follow的Decoupling给了我启发。他的2-stage训练模式让我意识到,re-balancing确实是有问题的,因为他会破坏backbone的特征学习,而必须为此额外增加一个stage来预训练所有的特征提取部分,并且在后续re-balancing学习中freeze住backbone。但既然backbone可以在原始长尾数据上直接训练,classifier真的需要再利用额外的一步训练来balance吗?还是只是目前没有找到对的方法而已呢?
在介绍本文复杂的因果图构建和后续推导实现,让同学们失去耐心之前,对于想赶紧快速食用我们方法的同学,我给个4步速成指南吧:
1) 训练时需要De-confound Training,说人话就是classifier需要使用multi-head normalized classifier,即每个类的logits计算如下: ,其中 是超参,K是multi-head的数量。分子部分为正常的无bias项的线性分类器,分母部分可以是任何形式的normalization(公式中是我们自己提出的形式,不过事实上如果分母变成 ,也就是cosine classifier也一样work)。
2) 同时不要忘记在训练时统计一个移动平均特征 , 并将他的单位方向看作是特征对头部类的倾向方向 。
3) 在测试时做counterfactual TDE inference,人话就是从training的logits中剔除我们认为代表对头部类过度倾向的部分,即测试时改用如下公式计算TDE logits:
详细实现可以参考我们的代码文件:
De-confound-TDE
https://github.com/KaihuaTang/Long-Tailed-Recognition.pytorch/blob/master/classification/models/CausalNormClassifier.py
4) 最后,当运用到诸如物体检测,实例分割的任务中时,还需要对background类做特殊处理,因为background类也是一个头部大类,但是对background的bias却是有益的,因为我们需要依赖他来剔除大量琐碎的细节。其计算方式如下,其中i=0代表background类, 是利用原始training的logits计算出的probability, 是利用TDE logits计算出的softmax后的概率。实现可参考链接:
五. De-confound-TDE 算法
至于4.2的inference时TDE的减法其实还是比较直接的,我这就不细说了。同时考虑到在一些特殊任务中,有些大类是需要保持合理的倾向性的,比如物体检测和实例分割时,就需要合理地倾向于background类这个大类,否则就会检测到过多无意义地细节。因此我们在4.3中介绍了Background-Exempted Inference这种特殊处理。
六. How to Understanding TDE (怎么理解TDE)
七. Experiments (实验结果)