目录
-
-
-
- 引言
- 问题分析及解决方案
- 总结
- 参考资料
-
-
引言
- 最近在使用CTC Loss来训练文本识别的相关模型,在制作好数据集,开始训练时,不出几个batch,就出现本文题目的错误
- 经过几经查找,有些眉目
问题分析及解决方案
- 问题出现原因:输入CTC Loss的
input_length 和对应label_length 对应问题 - 问题分析:
- 经过阅读CTC Loss发表论文中,可以得知如下:

- 标黄之处说明在原始标签字母L上长度小于或等于T的序列集,而这里的L的长度指的就是
label_length , T指的就是input_length - 那么
input_length 和label_length 之间的长度关系是怎样的呢?- 我曾经试图通过解析相关深度学习框架下实现CTC Loss的源码来分析具体计算过程,但是没有找到直观的,比较相信参考资料[1]中所述,尽管没有找到对应源码来支撑他的观点。
ctc_loss在计算预测结果和真值的loss的时候,会在你真值label中重复的字符之间插入空符,所以必须将label_length加上空符个数大于input_length的图片删除掉。
- 借用参考资料[1]中的示例:
举个例子,你图片高度为32,宽度为160,那么
input_length=40 (160 // 4这里取决于网络结构,输入图像长度为160,除以4是因为有两个pool层)。
label='abbbccddddcccaa' ,label_length=15 ,经过计算repreat_number=2(bbb)+1(cc)+3(dddd)+2(ccc)+1(aa) ,然后再加上开头结果的空符数2 ,最终等于11。也就是说必须满足label_length(15)+repreat_number(11)<=input_length(40) 的图片才是合格的图片。- 实现代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17# 读取图像
Img = np.array(Image.open(ImgRootPath + '/' + imgName).convert('L'))
ResizedImg = cv2.resize(Img, (int(Img.shape[1] * (32 / Img.shape[0])), 32))
# 统计Label中重复元素
l = [len(list(g)) for k, g in itertools.groupby(Label)]
repeat_number = 0
for n in l:
if n > 1:
repeat_number += (n - 1)
# 获得输入CTC Loss时的input_length,这主要取决于输入图像的尺寸
input_length = ResizedImg.shape[1] // 4
# 最终判断是否为合格图像
if len(Label) + repeat_number + 2 > input_length:
continue
- 经过阅读CTC Loss发表论文中,可以得知如下:
总结
- 通过以上对数据集进行过滤处理后,一般都能解决该文章题目问题。
参考资料
[1] 训练CRNN时,关于ctc_loss的几点注意事项
[2] CTC Algorithm Explained Part 1:Training the Network(CTC算法详解之训练篇)
