CTC的Decode算法-Prefix Beam Search
CTC几种常见的解码方式
- greedy decode,每帧输出最大值,然后规整。
- 在ctc字符串上做beam search,输出的n个结果规整,并合并相同序列,然后再应用语言模型。(secondpass LM)
- 在规整字符串上做beam search, 可额外在应用语言模型。该算法叫Prefix Beam Search,可以在解码过程中直接应用LM。(firstpass LM)
- 使用fst静态解码。可引入语言模型和字典模型。
prefix beam search的笔记
本文为prefix beam search的笔记。
Grave最早提出的prefix search不好理解,现在也没有人用。可直接参考Awni Hannun提出的prefix beam search。
定义两个概念
- ctc字符串:模型在每个时间点输出的字符组成的字符串。
- 规整字符串:即ctc字符串去除连续重复和blank后的字符串。
prefix beam search基本思想:
- 在每个时刻t,计算所有当前可能输出的规整字符串(即去除连续重复和blank的字符串)的概率。
- 因为t时刻的规整字符串长度最长为t,最短为0(全都是blank,所以规整完是0),所有t时刻所有可能的候选规整字符串个数是\(\sum_{0}^{t}{C^t}\),其中C是字符词典大小。 该值随着t的增大而增大,穷举搜索空间太大。
- 用beam search的方法,在每个时刻选取最好的N个路径,从而将每个时间点t上的搜索空间变为常数。
下图a表示一个规整字符串,表示其前缀,a表示其最后一个字符是a,箭头表示t+1时刻的字符串可能由哪些t时刻的字符串产生。
比如t+1时刻的*ab来自于两种情况:
- t时刻输出规整字符串*a,t+1时刻输出字符b
- t时刻输出规整字符串*ab,t+1时刻输出字符b
一个直接的想法是,对t时刻处在beam中的每个规整字符串,更新其对应的t+1时刻规整字符串的概率值。
但是这里并不能直接使用t时刻规整字符串的概率乘上t+1时刻输出字符概率得到t+1时刻规整字符串的概率。
我们看一个具体例子,假设第3时刻输出规整字符串a,第4时刻输出字符b。该规整字符串a的概率为-a-,-aa,aa-,aaa等不同的ctc字符串的概率和。
- 如果是在aa-这个ctc字符串基础上,在t+1时刻再输出a,得到的ctc字符串为aa-a,其规整字符串为aa.
- 如果是在aaa这个ctc字符串基础上,在t+1时刻再输出a,得到的ctc字符串为aaaa,其规整字符串为a.
可以看到两者使用同样的t+1的输出字符,却产生了不同的规整字符串。因此,需要区分对待blank和非blank结尾的ctc字符的规整概率:
- \(p_b(L)\) 表示所有以blank结尾且规整后是L的各ctc字符串的概率之和
- \(p_nb(L)\) 表示所有以非blank结尾且规整后是L的各ctc字符串的概率之和
比如,假设T=3,则 \(\begin{aligned} & p_b(a) = p(aa-) + p(-a-) + p(a--) \\ & p_{nb}(a) = p(aaa)+p(-aa) + p(--a) \\ \end{aligned}\)
这里只随便看其中一个,比如\(*a\), 其t+1时刻可以产生规整字符串有四种情况。
- 当t+1输出是blank时,产生规整字符串*a
- 当t+1输出是a时,可以产生规整字符串*a
- 当t+1输出是a时,也可产生规整字符串*aa
- 当t+1输出是b时(或其他不等于a和blank的字符),产生规整字符串*ab
四种情况对应的需要更新的统计量的公式如下
\[\begin{aligned} & p_{b}^{t+1}(*a)\! & \Leftarrow & \;[ p_{b}^{t}(*a) + p_{nb}^{t}(*a) ] p_{ctc}^{t+1}(-) \\ & p_{nb}^{t+1}(*a)\! & \Leftarrow & \;p_{nb}^{t}(*a)p_{ctc}^{t+1}(a) \\ & p_{nb}^{t+1}(*aa)\! & \Leftarrow & \;p_{b}^{t}(*a)p_{ctc}^{t+1}(a) \\ & p_{nb}^{t+1}(*ab)\! & \Leftarrow & \;[p_{b}^{t}(*a) + p_{nb}^{t}(*a)]p_{ctc}^{t+1}(b) \\ \end{aligned}\]- \(\Leftarrow\) 不是赋值,而是C语言中的
+=
操作。 - 注意, \(p_{nb}^{t+1}(*ab)\)的值不仅来源于
t时刻beam里的*a路径
, 在对t时刻beam里的*b路径
进行更新t+1的路径时,也会贡献概率。
说明
ctc字符串上的beam search和规整字符串上的beam search的区别:
- 同样的beam size下ctc字符串上的beam search,其丢掉的ctc路径比在规整字符串上做beam search的更多,所以最终的结果就更差一些。
参考Awni在Distill上的文章中的图片
直接做beam search,可以看到beam size=3时,每个时刻只有三个路径v
在规整字符串上做beam search,可以看到beam size=3时,每个时刻可以保留更多路径
Prefix beam search仍然会丢失一些ctc序列的概率,比如上例中,字符串ba里没有包含blank blank blank blank b a这条CTC序列。
- 在规整字符串上做beam search允许在fisrt pass引入LM得分,因为解码过程中就知道规整后的形式和space的位置,在ctc字符串上做beam search则不行。
Hannun论文《First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs 》中给出的算法。 注意两点:
- 可以在first pass里引入word LM的得分,只要在输出space的时候加入语言模型的得分。
- if $l^{+}$ not in $A_{prev}$ 的作用.当t时刻的beam里只有a而没有ab时,在t+1时刻计算ab,只使用了t时刻a的扩展,会丢失来自t时刻*ab的得分。
Hannun给出的一个python实现。
参考资料:
- https://distill.pub/2017/ctc/
- https://arxiv.org/abs/1408.2873
- https://towardsdatascience.com/beam-search-decoding-in-ctc-trained-neural-networks-5a889a3d85a7