相比于传统的 RNN 方法构建的翻译模型,使用 Google 提出的 multi-head self attention 方法可以更好的捕捉长距离的词与语词之间的关系。北大裴剑锋课题组最近在 arXiv 发布了其利用该方法应用于逆合成的研究,本问尝试利用该论文提供的先做进行重复试验。虽然谷歌已经提供了“非官方”的 TensorFlow 扩展 tensor2tensor 以实现这种 Transformer 模型的构建,本文则基于一个 NLP 社区贡献者提供的 Pytorch 版本。

原理介绍

Self-attention 介绍

要了解 self-attention,首先先要清楚 attention 的意思,该 博文 已经详细的介绍了 attention 以及其一些变种。

这个词 attention 即注意力,直观的感受是当我们判断一张图片进行分类时,会把注意力集中到部分区域,这些区域中的信息极大程度地决定了最终分类的结果。所以其本质可以认为是对输入向量进行权重判断,然后根据加权后的结果进行后续的计算。

建议读者先了解 attention,这里仅仅给出最重要的几个概念,即我们经常在相关文献中看到的几个字母 Q, K, V,分别对应 Query、Key 和 Value。简单的理解是 attention 层能够根据不同的 Query,对每个 value 进行加权,该权重又跟每个 Value 所对应的 Key 息息相关。
可以用如下公式表示:

$$ a = softmax(f(key, query)) $$

$$ c = \sum_{i=1}^{N}(a \times value) $$

self-attention 指的是 attention 得到的途径来源于编码器自身。假设有输入 X1 到 Xn,为了得到第 t 个输入对应的 attention,此时 query 就是 Xt(其实不是,还需要乘以一个权重W),K和V就是 X1 到 Xn(其实也不是,也要乘以一个权重)。
因此才会看到公式

$$ Attention (Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V $$

$$ head = Attention(QW_i^Q,KW_i^K,VW_i^V) $$

需要注意的是第一个公式里的 QKV 三个值都是不同的,但是第二个公式里的 QKV 却是相同的,都是编码器中原始的输入,只是它们乘以了不同的权重参数 attention 计算(公式一)中的值不同。而这三个权重正是神经网络需要学习的参数。

Multi-head self attention

博客 已经详细的描述了 Multi-head self attention 这个算法的过程,英文原版可 点击此处 查看。这里仅仅是简单描述下 mulit-head 指的是什么。

其实根据前面的公式(head=xxx)已经能发现,一个 head 其实就是一套权重,即 WQ, WK 以及 WV并得到一个 attention,因此 multi-head 就是多套这样的权重得到多个 attention,这样的 attention 再拼接(concate)到一起,再通过另一个权重 WO 合并到一个 attention。

下图就是一个由 8 个 head 拼接而成,再经过 8d * d 的 WO 变回 维度为 d 的 attention。

实现过程

程序安装

首先使用anaconda下载最新的Pytorch

1
conda install pytorch torchvision cudatoolkit=9.0 -c pytorch

下载模型实现源码
1
git clone https://github.com/jadore801120/attention-is-all-you-need-pytorch

源码解析参见:
(1)https://blog.csdn.net/weixin_42744102/article/details/87006081
(2)https://blog.csdn.net/weixin_42744102/article/details/87076089
(3)https://blog.csdn.net/weixin_42744102/article/details/87088748

化学反应数据收集与处理

比如一个化学反应BrCCO.ClS(=O)(=O)CC>>BrCCOS(=O)(=O)CC,我们可以认为这是一个翻译(或者应答)的过程,提问的是生成物BrCCOS(=O)(=O)CC,回答的是反应物BrCCO.ClS(=O)(=O)CC,然后要将他们拆成一个个单词,在这里我们可以认为每个字母代表一个单词,除了溴这种两个字母连在一起表示一个特定的原子,因此我们可以经过简单的预处理,变成
Br C C O S ( = O ) ( = O ) C C -> Br C C O . Cl S ( = O) ( = O ) C C,由于括号、等号都有有特殊的含义,因此也都算作一个单词。最后把问题整理到一个文件,如 train_src,把回答整理到另一个文件,如train_tgt,就完成了数据的准备。

构建模型

参考源码中的介绍,我们可以经过如下处理

1
2
3
4
5
6
7
8
9
10
11
12
# 预处理
python preprocess.py -train_src data/reaction_dev/train_src -train_tgt data/reaction_dev/train_tgt -valid_src data/reaction_dev/valid_src -valid_tgt data/reaction_dev/valid_tgt -save_data data/reaction_dev.low.pt -max_len 150 -share_vocab

# 训练
python train.py -data data/reaction_dev.low.pt -save_model trained -save_mode best -proj_share_weight -label_smoothing

# 预测
python translate.py -model trained.chkpt -vocab data/reaction_dev.low.pt -src data/reaction_dev/test_src

# 以下是笔者自己设定的训练参数,仅供参考
python train.py -data data/reaction_dev.low.pt -save_model trained -save_mode best -proj_share_weight -label_smoothing -d_model 64 -d_k 16 -d_v 16 -d_inner_hid 1024

以下是使用一万条数据得到的结果,平均每个 epoch 需要 1 分钟

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
[ Epoch 0 ]
- (Validation) : 0%| | 0/32 [00:00<?, ?it/s] - (Training) ppl: 17.45903, accuracy: 36.509 %, elapse: 1.069 min
- (Training) : 0%| | 0/157 [00:00<?, ?it/s] - (Validation) ppl: 10.20915, accuracy: 39.760 %, elapse: 0.068 min
- [Info] The checkpoint file has been updated.
[ Epoch 1 ]
- (Validation) : 0%| | 0/32 [00:00<?, ?it/s] - (Training) ppl: 11.90654, accuracy: 40.024 %, elapse: 1.058 min
- (Training) : 0%| | 0/157 [00:00<?, ?it/s] - (Validation) ppl: 6.40426, accuracy: 45.837 %, elapse: 0.068 min
- [Info] The checkpoint file has been updated.
[ Epoch 2 ]
- (Validation) : 0%| | 0/32 [00:00<?, ?it/s] - (Training) ppl: 8.44371, accuracy: 46.783 %, elapse: 1.075 min
- (Training) : 0%| | 0/157 [00:00<?, ?it/s] - (Validation) ppl: 5.06044, accuracy: 50.435 %, elapse: 0.069 min
- [Info] The checkpoint file has been updated.
[ Epoch 3 ]
- (Validation) : 0%| | 0/32 [00:00<?, ?it/s] - (Training) ppl: 7.18980, accuracy: 50.097 %, elapse: 1.085 min
- (Training) : 0%| | 0/157 [00:00<?, ?it/s] - (Validation) ppl: 4.30800, accuracy: 52.780 %, elapse: 0.068 min
- [Info] The checkpoint file has been updated.
[ Epoch 4 ]
- (Validation) : 0%| | 0/32 [00:00<?, ?it/s] - (Training) ppl: 6.58546, accuracy: 52.369 %, elapse: 1.069 min
- (Training) : 0%| | 0/157 [00:00<?, ?it/s] - (Validation) ppl: 3.94653, accuracy: 54.824 %, elapse: 0.068 min
- [Info] The checkpoint file has been updated.
[ Epoch 5 ]
- (Validation) : 0%| | 0/32 [00:00<?, ?it/s] - (Training) ppl: 6.19902, accuracy: 54.195 %, elapse: 1.066 min
- (Training) : 0%| | 0/157 [00:00<?, ?it/s] - (Validation) ppl: 3.64838, accuracy: 56.507 %, elapse: 0.068 min
- [Info] The checkpoint file has been updated.
[ Epoch 6 ]
- (Validation) : 0%| | 0/32 [00:00<?, ?it/s] - (Training) ppl: 5.95536, accuracy: 55.393 %, elapse: 1.066 min
- (Training) : 0%| | 0/157 [00:00<?, ?it/s] - (Validation) ppl: 3.54062, accuracy: 57.446 %, elapse: 0.069 min
- [Info] The checkpoint file has been updated.
[ Epoch 7 ]
- (Validation) : 0%| | 0/32 [00:00<?, ?it/s] - (Training) ppl: 5.75498, accuracy: 56.546 %, elapse: 1.084 min
- (Training) : 0%| | 0/157 [00:00<?, ?it/s] - (Validation) ppl: 3.34297, accuracy: 59.352 %, elapse: 0.068 min
- [Info] The checkpoint file has been updated.
[ Epoch 8 ]
- (Validation) : 0%| | 0/32 [00:00<?, ?it/s] - (Training) ppl: 5.57317, accuracy: 57.624 %, elapse: 1.062 min
- (Training) : 0%| | 0/157 [00:00<?, ?it/s] - (Validation) ppl: 3.17742, accuracy: 61.413 %, elapse: 0.069 min
- [Info] The checkpoint file has been updated.
[ Epoch 9 ]
- (Validation) : 0%| | 0/32 [00:00<?, ?it/s] - (Training) ppl: 5.40228, accuracy: 58.825 %, elapse: 1.074 min
- (Validation) ppl: 3.16756, accuracy: 60.991 %, elapse: 0.068 min

参考文献