ctc_loss_calculator.cc:144] No valid path found.或loss: inf

目录

        • 引言
        • 问题分析及解决方案
        • 总结
        • 参考资料

引言

  • 最近在使用CTC Loss来训练文本识别的相关模型,在制作好数据集,开始训练时,不出几个batch,就出现本文题目的错误
  • 经过几经查找,有些眉目

问题分析及解决方案

  • 问题出现原因:输入CTC Loss的input_length和对应label_length对应问题
  • 问题分析:
    • 经过阅读CTC Loss发表论文中,可以得知如下:
      在这里插入图片描述
    • 标黄之处说明在原始标签字母L上长度小于或等于T的序列集,而这里的L的长度指的就是label_length, T指的就是input_length
    • 那么input_lengthlabel_length之间的长度关系是怎样的呢?
      • 我曾经试图通过解析相关深度学习框架下实现CTC Loss的源码来分析具体计算过程,但是没有找到直观的,比较相信参考资料[1]中所述,尽管没有找到对应源码来支撑他的观点。

      ctc_loss在计算预测结果和真值的loss的时候,会在你真值label中重复的字符之间插入空符,所以必须将label_length加上空符个数大于input_length的图片删除掉。

      • 借用参考资料[1]中的示例:

      举个例子,你图片高度为32,宽度为160,那么input_length=40(160 // 4这里取决于网络结构,输入图像长度为160,除以4是因为有两个pool层)。
      label='abbbccddddcccaa'label_length=15,经过计算repreat_number=2(bbb)+1(cc)+3(dddd)+2(ccc)+1(aa),然后再加上开头结果的空符数2,最终等于11。也就是说必须满足label_length(15)+repreat_number(11)<=input_length(40)的图片才是合格的图片。

      • 实现代码:
        1
        2
        3
        4
        5
        6
        7
        8
        9
        10
        11
        12
        13
        14
        15
        16
        17
        # 读取图像
        Img = np.array(Image.open(ImgRootPath + '/' + imgName).convert('L'))
        ResizedImg = cv2.resize(Img, (int(Img.shape[1] * (32 / Img.shape[0])), 32))

        # 统计Label中重复元素
        l = [len(list(g)) for k, g in itertools.groupby(Label)]
        repeat_number = 0
        for n in l:
            if n > 1:
                repeat_number += (n - 1)

        # 获得输入CTC Loss时的input_length,这主要取决于输入图像的尺寸
        input_length = ResizedImg.shape[1] // 4

        # 最终判断是否为合格图像
        if len(Label) + repeat_number + 2 > input_length:
            continue

总结

  • 通过以上对数据集进行过滤处理后,一般都能解决该文章题目问题。

参考资料

[1] 训练CRNN时,关于ctc_loss的几点注意事项
[2] CTC Algorithm Explained Part 1:Training the Network(CTC算法详解之训练篇)