关于python:使用混淆矩阵理解多标签分类器

Understanding multi-label classifier using confusion matrix

我有一个包含12个类的多标签分类问题。我正在使用Tensorflowslim来训练在ImageNet上预训练的模型。这是每个班级参加培训的人数的百分比


输出校准

我认为首先要意识到的一件事是,神经网络的输出可能校准不良。我的意思是说,它提供给不同实例的输出可能会导致良好的排名(带有标签L的图像比没有标签L的图像具有更高的得分),但是这些得分不能始终可靠地解释为概率(对于没有标签的实例,它可能给出非常高的分数,如0.9;对于带有标签的实例,它可能给出更高的分数,如0.99)。我想是否会发生这种情况,除其他外,取决于您选择的损失函数。

有关此的更多信息,请参见例如:https://arxiv.org/abs/1706.04599

一对一地学习所有课程

0级:AUC(曲线下面积)= 0.99。多数民众赞成在一个很好的成绩。混淆矩阵中的第0列也看起来不错,所以这里没有错。

第1类:AUC = 0.44。如果我没记错的话,那真是太糟糕了,低于0.5,这意味着您最好还是故意做出与网络预测的相反的效果。

看看您的混淆矩阵中的第1列,它的得分几乎到处都是相同的。对我来说,这表明网络并没有学到太多关于该类的知识,而只是根据训练集中包含该标签的图像的百分比"猜测"(55.6%)。由于此百分比在验证集中下降到50%,因此该策略的确意味着它的效果比随机效果要差一些。尽管第1行在此列中仍然是所有行中数量最多的,所以它似乎至少学到了一点点,但学到的并不多。

第2类:AUC = 0.96。这是非常好的。

您对此类的解释是,根据整个列的浅色阴影,通常会预测它不存在。我不认为这种解释是正确的。看看它在对角线上的得分如何> 0,在列中其他任何地方都只有0。它在该行中的得分可能较低,但很容易与同一列中的其他行分开。您可能只需要设置阈值,以选择该标签是否相对较低。我怀疑这是由于上面提到的校准问题。

这也是为什么AUC非常好的原因;可以选择一个阈值,以使分数高于阈值的大多数实例正确地带有标签,而低于阈值的大多数实例正确地带有标签。但是,该阈值可能不是0.5,这是假设进行良好校准后可能会遇到的阈值。绘制此特定标签的ROC曲线可以帮助您准确确定阈值应在哪里。

第3类:AUC = 0.9,非常好。

您将其解释为始终被检测为存在,并且混淆矩阵的确在列中有很多高数字,但是AUC很好,对角线上的像元确实具有足够高的值,可能是容易与其他人分开。我怀疑这与第2类类似(只是四处翻转,到处都有高预测,因此正确决策需要较高的阈值)。

如果您想确定某个选择良好的阈值是否确实可以正确地将大多数"阳性"(类别3的实例)与大多数"阴性"(类别3的实例)分开,则可以\\将要根据标签3的预测得分对所有实例进行排序,然后遍历整个列表,并在每对连续条目之间计算如果您决定将阈值放置在此处而获得的验证集准确性,以及选择最佳阈值。

第4类:与第0类相同。

5类:AUC = 0.01,显然很糟糕。也同意您对混淆矩阵的解释。很难确定为什么它在这里表现这么差。也许这是一种很难识别的物体?可能还会出现一些过拟合现象(从第二个矩阵的列中判断,训练数据为0误报,尽管这种情况也会发生在其他类上)。

从训练到验证数据增加的标签5图像比例可能也无济于事。这意味着网络在训练过程中在此标签上表现良好的重要性不如在验证过程中重要。

6类:AUC = 0.52,仅比随机数稍好。

根据第一个矩阵中的第6列判断,这实际上可能与第2类类似。如果我们也考虑到AUC,它看起来也没有学会很好地对实例进行排名。类似于第5类,只是没有那么糟糕。同样,培训和验证的分配也大不相同。

第7类:A??UC = 0.65,而不是平均值。例如,显然不如第2类好,但也不如仅从矩阵中解释的那样坏。

8级:AUC = 0.97,非常好,类似于3级。

第9类:AUC = 0.82,虽然不佳,但仍然不错。矩阵中的列中有许多暗单元,并且数量非常接近,以至于我认为AUC令人惊讶地好。训练数据中几乎所有图像中都存在它,因此被预测为经常出现也就不足为奇了。也许其中一些非常暗的单元格仅基于绝对数量很少的图像?这将很有趣。

10级:AUC = 0.09,太糟了。对角线上的0非常令人担忧(您的数据是否正确标记?)。根据第一个矩阵的第10行,对于第3类和第9类似乎很困惑(cotton和primary_incision_knives看起来很像secondary_incision_knives吗?)。也许对训练数据也有些过拟合。

第11类:AUC = 0.5,没有比随机数好。性能不佳(矩阵中的得分明显过高)很可能是因为大多数训练图像中都存在此标签,但只有少数验证图像中存在此标签。

还有什么要绘制/测量的?

为了获得对数据的更多了解,我将首先绘制有关每个班级发生频率的热图(一个用于训练,一个用于验证数据)。单元格(i,j)将根据同时包含标签i和j的图像的比例进行着色。这将是一个对称图,对角线上的单元格将根据问题中的第一个数字列表进行着色。比较这两个热图,看看它们有何不同,是否可以帮助您解释模型的性能。

另外,了解(对于两个数据集)每个图像平均具有多少个不同的标签,以及对于每个单独的标签,平均与一个图像共享多少个其他标签可能很有用。例如,我怀疑带有标签10的图像在训练数据中具有相对较少的其他标签。如果网络识别出其他事物,这可能会使网络无法预测标签10,并且如果标签10确实突然在验证数据中更规律地与其他对象共享图像,则会导致性能下降。由于伪代码比单词更容易理解问题,因此打印类似以下内容可能很有趣:

1
2
3
4
5
6
7
8
9
10
11
12
13
# Do all of the following once for training data, AND once for validation data    
tot_num_labels = 0
for image in images:
    tot_num_labels += len(image.get_all_labels())
avg_labels_per_image = tot_num_labels / float(num_images)
print("Avg. num labels per image =", avg_labels_per_image)

for label in range(num_labels):
    tot_shared_labels = 0
    for image in images_with_label(label):
        tot_shared_labels += (len(image.get_all_labels()) - 1)
    avg_shared_labels = tot_shared_labels / float(len(images_with_label(label)))
    print("On average, images with label", label," also have", avg_shared_labels," other labels.")

对于单个数据集,它并不能提供太多有用的信息,但是如果您将其用于训练和验证集,则可以看出,如果数字非常不同,则它们的分布会完全不同

最后,我有点担心您的第一个矩阵中的某些列在许多不同的行上出现的均值预测完全相同。我不太确定是什么原因引起的,但这可能对调查很有帮助。

如何提高?

如果您还没有的话,我建议您对训练数据进行数据扩充。由于您正在使用图像,因此可以尝试将现有图像的旋转版本添加到数据中。

特别是对于多标签情况而言,目标是检测不同类型的对象,尝试简单地将一堆不同的图像(例如,两个或四个图像)连接在一起也可能很有趣。然后,您可以将其缩小到原始图像大小,并在标签分配原始标签集的并集时使用。在合并图像的边缘会出现有趣的不连续点,我不知道这是否有害。也许对您的多对象检测而言,这不是我想要的尝试。