论文介绍
论文的地址在此,作者使用了 Lua 语言实现,代码地址在这,然而我不会 Lua 语言,于是找了找是否有 Python 的实现版本。还真有,Python 版本代码地址在这。 但是 Python 版本的代码篇幅太长,且几乎没有注释,于是我将其重写了一遍,一些工具类是直接复制别人的,但是核心代码我改写了一下,并添加了一些注释。
我将里面的数据获取模块移除了。
论文实现
论文共用了两个办法:1)普通 seq2seq 模型;2)作者自创的 seq2tree 模型。其中每个模型又分别有 lstm 实现和 lstm + attention 实现两种版本。虽然两个版本使用的技术不同,但是说到底也只是同一个模型。以下讲解原理。
seq2seq 模型
RNN
论文中使用 LSTM 实现 seq2seq 模型,训练之后,accuracy 大约在 70%。
Transformer
我自己用了 Transformer 改写了一下,并且调了几天的参数,发现效果出奇的差,accuracy 最好只有 18%。然后我还发现,对于短句子几乎是百分比预测正确,对于长句子百分比预测错误。所以我怀疑是否是位置编码那产生的问题,考虑到一个逻辑形式它并不是纯粹的线性结构,它的内部是由很多括号的。 经调参后得到最好的一组参数如下: 1. learning rate: 0.001 2. dim_feedforward: 随意(我设置为 256) 3. h_model: 256 4. nhead: 4 5. encoder_layer/decoder_layer: 1 6. dropout: 0.4 7. batch_size: 32(16 的效果可能更好) 8. epoch: 95(epoch 可以进一步修改) 9. src_mask: False 10. tgt_mask: True 11. memory_mask: False