
    Transformer-based translation system

    Paper: https://arxiv.org/abs/1706.03762
    Project repo: https://github.com/audier/my_deep_project/tree/master/NLP/4.transformer

    This post implements a translation system based on the self-attention mechanism. Attention has been one of the hottest directions of the past two years, and the self-attention mechanism proposed last year has become a favorite of the big names. The most readable code available online has one small imperfection: the masks never actually take effect. Since I have been working on a translation system myself recently, I wrote this post to share my approach.
    The code in this post is based on the most readable project I found online: https://github.com/Kyubyong/transformer
    However, that author made a mistake in the key mask and query mask, so this post modifies the model and the multihead layer accordingly to make masking work correctly.

    Please credit the source when reposting: https://blog.csdn.net/chinatelecom08

    1. Data processing

    Data used in this post: https://github.com/audier/my_deep_project/tree/master/NLP/4.transformer

    • Read the data
    • Save the English and Chinese sides as inputs and outputs respectively
    with open('cmn.txt', 'r', encoding='utf8') as f:
        data = f.readlines()
    
    from tqdm import tqdm

    inputs = []
    outputs = []
    for line in tqdm(data[:10000]):
        [en, ch] = line.strip('\n').split('\t')
        inputs.append(en.replace(',', ' ,')[:-1].lower())
        outputs.append(ch[:-1])

    100%|██████████| 10000/10000 [00:00<00:00, 473991.57it/s]
    
    • Inspect the data format
    print(inputs[:10])
    
    ['hi', 'hi', 'run', 'wait', 'hello', 'i try', 'i won', 'oh no', 'cheers', 'he ran']
    
    print(outputs[:10])
    
    ['嗨', '你好', '你用跑的', '等等', '你好', '让我来', '我赢了', '不会吧', '乾杯', '他跑了']
    

    1.1 English tokenization

    For English we just split on spaces, with one small adjustment: replace every uppercase letter with its lowercase form. This was done above with .lower().

    for line in tqdm(data):
        [en, ch] = line.strip('\n').split('\t')
        inputs.append(en[:-1].lower())
        outputs.append(ch[:-1])
    

    Here all we need to do is split the English sentences on spaces.

    inputs = [en.split(' ') for en in inputs]
    
    print(inputs[:10])
    
    [['hi'], ['hi'], ['run'], ['wait'], ['hello'], ['i', 'try'], ['i', 'won'], ['oh', 'no'], ['cheers'], ['he', 'ran']]
    

    1.2 Chinese word segmentation

    • For Chinese word segmentation we choose the jieba tool.
    import jieba
    outputs = [[char for char in jieba.cut(line) if char != ' '] for line in outputs]
    
    • HanLP also works.
    from pyhanlp import *
    outputs = [[term.word for term in HanLP.segment(line) if term.word != ' '] for line in outputs]
    
    • Or segment character by character? (a one-line sketch closes this subsection)

    • In the end I chose jieba.

    import jieba
    jieba_outputs = [[char for char in jieba.cut(line) if char != ' '] for line in outputs[-10:]]
    print(jieba_outputs)
    
    [['你', '不應', '該', '去', '那裡', '的'], ['你', '以前', '吸煙', ',', '不是', '嗎'], ['你現', '在', '最好', '回家'], ['你', '今天', '最好', '不要', '出門'], ['你', '滑雪', '比', '我', '好'], ['你', '正在', '把', '我', '杯子', '里', '的', '东西', '喝掉'], ['你', '并', '不', '满意', ',', '对', '吧'], ['你', '病', '了', ',', '该', '休息', '了'], ['你', '很', '勇敢', ',', '不是', '嗎'], ['你', '的', '意志力', '很強']]
    
    outputs = [[char for char in jieba.cut(line) if char != ' '] for line in tqdm(outputs)]
    
    100%|██████████| 10000/10000 [00:00<00:00, 11981.68it/s]
    
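    As for the character-level option mentioned above: it would be a one-liner over the raw target strings (a sketch only; `raw_lines` is a placeholder for the unsegmented Chinese sentences, which at this point have already been replaced by jieba tokens):

    # a sketch: character-level segmentation of raw, unsegmented sentences
    char_outputs = [[char for char in raw_line] for raw_line in raw_lines]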

    1.3 Building the vocabularies

    Map the English and Chinese tokens to integer ids.

    def get_vocab(data, init=['<PAD>']):
        vocab = list(init)  # copy so the caller's list is not mutated
        for line in tqdm(data):
            for word in line:
                if word not in vocab:
                    vocab.append(word)
        return vocab
    

    SOURCE_CODES = ['<PAD>']
    TARGET_CODES = ['<PAD>', '<GO>', '<EOS>']
    encoder_vocab = get_vocab(inputs, init=SOURCE_CODES)
    decoder_vocab = get_vocab(outputs, init=TARGET_CODES)

    100%|██████████| 10000/10000 [00:00<00:00, 20585.73it/s]
    100%|██████████| 10000/10000 [00:01<00:00, 7808.17it/s]
    
    print(encoder_vocab[:10])
    print(decoder_vocab[:10])
    
    ['<PAD>', 'hi', 'run', 'wait', 'hello', 'i', 'try', 'won', 'oh', 'no']
    ['<PAD>', '<GO>', '<EOS>', '嗨', '你好', '你', '用', '跑', '的', '等等']
    

    1.4 Data generator

    The data format required for training the translation system is the same as the input to Google's GNMT; the principles behind GNMT are described at: https://github.com/tensorflow/nmt
    Roughly:

    • Encoder input: I am a student
    • Decoder input: (go) Je suis étudiant
    • Decoder output: Je suis étudiant (end)

    That is, the decoder input is prefixed with a start symbol and the decoder output ends with an end symbol.

    encoder_inputs = [[encoder_vocab.index(word) for word in line] for line in inputs]
    decoder_inputs = [[decoder_vocab.index('<GO>')] + [decoder_vocab.index(word) for word in line] for line in outputs]
    decoder_targets = [[decoder_vocab.index(word) for word in line] + [decoder_vocab.index('<EOS>')] for line in outputs]
    
    print(decoder_inputs[:4])
    print(decoder_targets[:4])
    
    [[1, 3], [1, 4], [1, 5, 6, 7, 8], [1, 9]]
    [[3, 2], [4, 2], [5, 6, 7, 8, 2], [9, 2]]
    
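    A side note: `list.index` rescans the whole vocabulary for every token, so the mapping above costs O(V) per lookup. On a larger corpus a dict-based variant is much faster; a sketch (`en2id`/`zh2id` are names of my own, not from the original code):

    # a sketch: constant-time token-to-id lookup via dicts
    en2id = {word: i for i, word in enumerate(encoder_vocab)}
    zh2id = {word: i for i, word in enumerate(decoder_vocab)}
    encoder_inputs = [[en2id[word] for word in line] for line in inputs]
    decoder_inputs = [[zh2id['<GO>']] + [zh2id[word] for word in line] for line in outputs]
    decoder_targets = [[zh2id[word] for word in line] + [zh2id['<EOS>']] for line in outputs]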
    import numpy as np


    def get_batch(encoder_inputs, decoder_inputs, decoder_targets, batch_size=4):
        batch_num = len(encoder_inputs) // batch_size
        for k in range(batch_num):
            begin = k * batch_size
            end = begin + batch_size
            en_input_batch = encoder_inputs[begin:end]
            de_input_batch = decoder_inputs[begin:end]
            de_target_batch = decoder_targets[begin:end]
            max_en_len = max([len(line) for line in en_input_batch])
            max_de_len = max([len(line) for line in de_input_batch])
            # pad every line with 0 (<PAD>) up to the longest line in the batch
            en_input_batch = np.array([line + [0] * (max_en_len - len(line)) for line in en_input_batch])
            de_input_batch = np.array([line + [0] * (max_de_len - len(line)) for line in de_input_batch])
            de_target_batch = np.array([line + [0] * (max_de_len - len(line)) for line in de_target_batch])
            yield en_input_batch, de_input_batch, de_target_batch

    batch = get_batch(encoder_inputs, decoder_inputs, decoder_targets, batch_size=4)
    next(batch)
    
    (array([[1],
            [1],
            [2],
            [3]]), array([[1, 3, 0, 0, 0],
            [1, 4, 0, 0, 0],
            [1, 5, 6, 7, 8],
            [1, 9, 0, 0, 0]]), array([[3, 2, 0, 0, 0],
            [4, 2, 0, 0, 0],
            [5, 6, 7, 8, 2],
            [9, 2, 0, 0, 0]]))
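    Note that the generator is exhausted after one pass over the data, so for multi-epoch training it has to be recreated each epoch. A minimal usage sketch (`n_epochs` is a placeholder):

    n_epochs = 2  # placeholder
    for epoch in range(n_epochs):
        batches = get_batch(encoder_inputs, decoder_inputs, decoder_targets, batch_size=4)
        for en_batch, de_in_batch, de_target_batch in batches:
            pass  # feed the padded batch to the training op here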
    

    2. Building the model

    The model structure is as follows:

    [figure: Transformer model architecture]
    Each of the main building blocks is implemented below.

    Paper: https://arxiv.org/abs/1706.03762
    Explanations of the paper: a quick search turns up plenty; read one alongside the original paper and the code.
    Personally, I find it all much easier to understand with the code at hand.

    import tensorflow as tf
    

    2.1 Building blocks

    The code below implements each functional component of the architecture diagram.

    Layer norm

    It sits at the boxed positions in the diagram.
    [figure: position of layer normalization in the Transformer block]
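    For reference, layer normalization computes the following over the last (feature) axis of size C, which is exactly what the `normalize` function below does:

    \mathrm{LN}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta,
    \qquad
    \mu = \frac{1}{C}\sum_{i=1}^{C} x_i,
    \quad
    \sigma^2 = \frac{1}{C}\sum_{i=1}^{C} (x_i - \mu)^2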

    def normalize(inputs, 
                  epsilon=1e-8,
                  scope="ln",
                  reuse=None):
        '''Applies layer normalization.

        Args:
          inputs: A tensor with 2 or more dimensions, where the first dimension has
            `batch_size`.
          epsilon: A floating point number. A very small number for preventing ZeroDivision Error.
          scope: Optional scope for `variable_scope`.
          reuse: Boolean, whether to reuse the weights of a previous layer
            by the same name.

        Returns:
          A tensor with the same shape and data dtype as `inputs`.
        '''
        with tf.variable_scope(scope, reuse=reuse):
            inputs_shape = inputs.get_shape()
            params_shape = inputs_shape[-1:]

            # mean and variance over the feature (last) axis
            mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
            beta = tf.Variable(tf.zeros(params_shape))
            gamma = tf.Variable(tf.ones(params_shape))
            normalized = (inputs - mean) / ((variance + epsilon) ** .5)
            outputs = gamma * normalized + beta

        return outputs
    

    Embedding layer

    Worth mentioning: in this post the positional encoding is also represented by an embedding layer; the original paper says that either the fixed formula or a trained embedding works.
    [figure: input embedding plus positional encoding]
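    For comparison, here is a minimal numpy sketch of the fixed sinusoidal encoding from the paper (this post does not use it; positions are learned through `embedding` instead):

    import numpy as np

    def sinusoid_position_encoding(max_len, num_units):
        # PE[pos, 2i]   = sin(pos / 10000^(2i/num_units))
        # PE[pos, 2i+1] = cos(pos / 10000^(2i/num_units))
        pos = np.arange(max_len)[:, None]   # (max_len, 1)
        i = np.arange(num_units)[None, :]   # (1, num_units)
        angle = pos / np.power(10000.0, (2 * (i // 2)) / num_units)
        return np.where(i % 2 == 0, np.sin(angle), np.cos(angle))  # (max_len, num_units)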

    def embedding(inputs, 
                  vocab_size, 
                  num_units, 
                  zero_pad=True, 
                  scale=True,
                  scope="embedding", 
                  reuse=None):
        '''Embeds a given tensor.

        Args:
          inputs: A `Tensor` with type `int32` or `int64` containing the ids
             to be looked up in `lookup table`.
          vocab_size: An int. Vocabulary size.
          num_units: An int. Number of embedding hidden units.
          zero_pad: A boolean. If True, all the values of the first row (id 0)
            should be constant zeros.
          scale: A boolean. If True, the outputs are multiplied by sqrt(num_units).
          scope: Optional scope for `variable_scope`.
          reuse: Boolean, whether to reuse the weights of a previous layer
            by the same name.

        Returns:
          A `Tensor` with one more rank than inputs's. The last dimensionality
            should be `num_units`.

        For example,

        ```
        import tensorflow as tf

        inputs = tf.to_int32(tf.reshape(tf.range(2*3), (2, 3)))
        outputs = embedding(inputs, 6, 2, zero_pad=True)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            print(sess.run(outputs))
        >>
        [[[ 0.          0.        ]
          [ 0.09754146  0.67385566]
          [ 0.37864095 -0.35689294]]
         [[-1.01329422 -1.09939694]
          [ 0.7521342   0.38203377]
          [-0.04973143 -0.06210355]]]
        ```

        ```
        import tensorflow as tf

        inputs = tf.to_int32(tf.reshape(tf.range(2*3), (2, 3)))
        outputs = embedding(inputs, 6, 2, zero_pad=False)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            print(sess.run(outputs))
        >>
        [[[-0.19172323 -0.39159766]
          [-0.43212751 -0.66207761]
          [ 1.03452027 -0.26704335]]
         [[-0.11634696 -0.35983452]
          [ 0.50208133  0.53509563]
          [ 1.22204471 -0.96587461]]]
        ```
        '''
        with tf.variable_scope(scope, reuse=reuse):
            lookup_table = tf.get_variable('lookup_table',
                                           dtype=tf.float32,
                                           shape=[vocab_size, num_units],
                                           initializer=tf.contrib.layers.xavier_initializer())
            if zero_pad:
                # keep row 0 (the <PAD> id) fixed at all zeros
                lookup_table = tf.concat((tf.zeros(shape=[1, num_units]),
                                          lookup_table[1:, :]), 0)
            outputs = tf.nn.embedding_lookup(lookup_table, inputs)

            if scale:
                outputs = outputs * (num_units ** 0.5)

        return outputs
    

    Multi-head attention

    This is the core idea of self-attention, so make sure the principle is clear.
    [figure: self-attention]
    The input attends to itself, but before that, linear transformations map it into 8 different subspaces where attention is computed separately, and the results are concatenated at the end.
    [figure: multi-head attention]
    This layer implements the following (a round of applause for Google):
    [figure: Attention(Q, K, V) = softmax(QKᵀ / √d_k) V]
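    To make the formula concrete, here is a tiny numpy sketch of single-head scaled dot-product attention (shapes only; the real layer below adds the linear projections, masks, heads, dropout and the residual connection):

    import numpy as np

    def scaled_dot_product_attention(Q, K, V):
        # Q: (T_q, d_k), K: (T_k, d_k), V: (T_k, d_v)
        scores = Q @ K.T / np.sqrt(K.shape[-1])             # (T_q, T_k)
        weights = np.exp(scores - scores.max(-1, keepdims=True))
        weights /= weights.sum(-1, keepdims=True)           # softmax over the keys
        return weights @ V                                  # (T_q, d_v)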

    def multihead_attention(key_emb,
                            que_emb,
                            queries, 
                            keys, 
                            num_units=None, 
                            num_heads=8, 
                            dropout_rate=0,
                            is_training=True,
                            causality=False,
                            scope="multihead_attention", 
                            reuse=None):
        '''Applies multihead attention.

        Args:
          key_emb: The key-side embedding (before positional encoding); its
            all-zero rows mark <PAD> positions and are used to build the key mask.
          que_emb: The query-side embedding (before positional encoding), used
            to build the query mask in the same way.
          queries: A 3d tensor with shape of [N, T_q, C_q].
          keys: A 3d tensor with shape of [N, T_k, C_k].
          num_units: A scalar. Attention size.
          dropout_rate: A floating point number.
          is_training: Boolean. Controller of mechanism for dropout.
          causality: Boolean. If true, units that reference the future are masked. 
          num_heads: An int. Number of heads.
          scope: Optional scope for `variable_scope`.
          reuse: Boolean, whether to reuse the weights of a previous layer
            by the same name.

        Returns:
          A 3d tensor with shape of (N, T_q, C)  
        '''
        with tf.variable_scope(scope, reuse=reuse):
            # Set the fall back option for num_units
            if num_units is None:
                num_units = queries.get_shape().as_list()[-1]

            # Linear projections
            Q = tf.layers.dense(queries, num_units, activation=tf.nn.relu)  # (N, T_q, C)
            K = tf.layers.dense(keys, num_units, activation=tf.nn.relu)  # (N, T_k, C)
            V = tf.layers.dense(keys, num_units, activation=tf.nn.relu)  # (N, T_k, C)

            # Split and concat
            Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)  # (h*N, T_q, C/h)
            K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)
            V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)

            # Multiplication
            outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))  # (h*N, T_q, T_k)

            # Scale
            outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)

            # Key Masking: zero embedding rows mark <PAD> keys
            key_masks = tf.sign(tf.abs(tf.reduce_sum(key_emb, axis=-1)))  # (N, T_k)
            key_masks = tf.tile(key_masks, [num_heads, 1])  # (h*N, T_k)
            key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(queries)[1], 1])  # (h*N, T_q, T_k)

            paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)
            outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

            # Causality = Future blinding
            if causality:
                diag_vals = tf.ones_like(outputs[0, :, :])  # (T_q, T_k)
                tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
                masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1])  # (h*N, T_q, T_k)

                paddings = tf.ones_like(masks) * (-2 ** 32 + 1)
                outputs = tf.where(tf.equal(masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

            # Activation
            outputs = tf.nn.softmax(outputs)  # (h*N, T_q, T_k)

            # Query Masking: zero out the attention rows of <PAD> queries
            query_masks = tf.sign(tf.abs(tf.reduce_sum(que_emb, axis=-1)))  # (N, T_q)
            query_masks = tf.tile(query_masks, [num_heads, 1])  # (h*N, T_q)
            query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(keys)[1]])  # (h*N, T_q, T_k)
            outputs *= query_masks  # (h*N, T_q, T_k)

            # Dropouts
            outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=tf.convert_to_tensor(is_training))

            # Weighted sum
            outputs = tf.matmul(outputs, V_)  # (h*N, T_q, C/h)

            # Restore shape
            outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)  # (N, T_q, C)

            # Residual connection
            outputs += queries

            # Normalize
            outputs = normalize(outputs)  # (N, T_q, C)

        return outputs
    

    Feed-forward

    Two fully connected layers, implemented with convolutions to speed up the computation; dense layers would work too. With this, every component the framework needs has been collected, and we can summon the dragon.
    [figure: position-wise feed-forward network]
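    Why a convolution can "simulate" a dense layer here: a conv1d with kernel_size=1 applies one shared weight matrix independently at every time step, which is exactly a position-wise dense layer. A numpy sketch of the equivalence:

    import numpy as np

    x = np.random.randn(2, 5, 8)   # (batch, time, channels)
    W = np.random.randn(8, 16)     # one weight matrix shared by every time step
    b = np.zeros(16)

    dense_out = x @ W + b          # position-wise dense: (2, 5, 16)

    # a kernel_size-1 convolution has weights of shape (1, in, out) and slides a
    # one-step window over time, so it computes exactly the same thing:
    conv_w = W[None, :, :]
    conv_out = np.einsum('btc,kcf->btf', x, conv_w) + b
    assert np.allclose(dense_out, conv_out)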

    def feedforward(inputs, 
                    num_units=[2048, 512],
                    scope="feedforward", 
                    reuse=None):
        '''Point-wise feed forward net.

        Args:
          inputs: A 3d tensor with shape of [N, T, C].
          num_units: A list of two integers.
          scope: Optional scope for `variable_scope`.
          reuse: Boolean, whether to reuse the weights of a previous layer
            by the same name.

        Returns:
          A 3d tensor with the same shape and dtype as inputs
        '''
        with tf.variable_scope(scope, reuse=reuse):
            # Inner layer: a kernel_size-1 convolution acting as a position-wise dense layer
            params = {"inputs": inputs, "filters": num_units[0], "kernel_size": 1,
                      "activation": tf.nn.relu, "use_bias": True}
            outputs = tf.layers.conv1d(**params)

            # Readout layer
            params = {"inputs": outputs, "filters": num_units[1], "kernel_size": 1,
                      "activation": None, "use_bias": True}
            outputs = tf.layers.conv1d(**params)

            # Residual connection
            outputs += inputs

            # Normalize
            outputs = normalize(outputs)

        return outputs
    

    Label smoothing

    It helps training: hard 0s become small values near zero and hard 1s become values slightly below 1, i.e. a one-hot label y becomes (1 − ε)·y + ε/V over a vocabulary of size V. The docstring below makes this concrete.

    def label_smoothing(inputs, epsilon=0.1):
        '''Applies label smoothing. See https://arxiv.org/abs/1512.00567.

        Args:
          inputs: A 3d tensor with shape of [N, T, V], where V is the number of vocabulary.
          epsilon: Smoothing rate.

        For example,

        ```
        import tensorflow as tf
        inputs = tf.convert_to_tensor([[[0, 0, 1], 
           [0, 1, 0],
           [1, 0, 0]],
          [[1, 0, 0],
           [1, 0, 0],
           [0, 1, 0]]], tf.float32)

        outputs = label_smoothing(inputs)

        with tf.Session() as sess:
            print(sess.run([outputs]))

        >>
        [array([[[ 0.03333334,  0.03333334,  0.93333334],
            [ 0.03333334,  0.93333334,  0.03333334],
            [ 0.93333334,  0.03333334,  0.03333334]],
           [[ 0.93333334,  0.03333334,  0.03333334],
            [ 0.93333334,  0.03333334,  0.03333334],
            [ 0.03333334,  0.93333334,  0.03333334]]], dtype=float32)]
        ```
        '''
        K = inputs.get_shape().as_list()[-1]  # number of channels (vocabulary size)
        return ((1 - epsilon) * inputs) + (epsilon / K)
    

    2.2 Assembling the model

    Looking at the model once more, we find that every component inside it has already been built.
    Assemble them following this structure and we are done!
    [figure: Transformer model architecture]
    The code is as follows:

    class Graph():
        def __init__(self, is_training=True):
            tf.reset_default_graph()
            # `arg` is the global hyperparameter object set up before the graph is built
            self.is_training = arg.is_training
            self.hidden_units = arg.hidden_units
            self.input_vocab_size = arg.input_vocab_size
            self.label_vocab_size = arg.label_vocab_size
            self.num_heads = arg.num_heads
            self.num_blocks = arg.num_blocks
            self.max_length = arg.max_length
            self.lr = arg.lr
            self.dropout_rate = arg.dropout_rate

            # input placeholder
            self.x = tf.placeholder(tf.int32, shape=(None, None))
            self.y = tf.placeholder(tf.int32, shape=(None, None))
            self.de_inp = tf.placeholder(tf.int32, shape=(None, None))

            # Encoder
            with tf.variable_scope("encoder"):
                # embedding
                self.en_emb = embedding(self.x, vocab_size=self.input_vocab_size, num_units=self.hidden_units, scale=True, scope="enc_embed")
                self.enc = self.en_emb + embedding(tf.tile(tf.expand_dims(tf.range(tf.shape(self.x)[1]), 0), [tf.shape(self.x)[0], 1]),
                                                   vocab_size=self.max_length, num_units=self.hidden_units, zero_pad=False, scale=False, scope="enc_pe")
                ## Dropout
                self.enc = tf.layers.dropout(self.enc, 
                                             rate=self.dropout_rate, 
                                             training=tf.convert_to_tensor(self.is_training))

                ## Blocks
                for i in range(self.num_blocks):
                    with tf.variable_scope("num_blocks_{}".format(i)):
                        ### Multihead Attention
                        self.enc = multihead_attention(key_emb=self.en_emb,
                                                       que_emb=self.en_emb,
                                                       queries=self.enc, 
                                                       keys=self.enc, 
                                                       num_units=self.hidden_units, 
                                                       num_heads=self.num_heads, 
                                                       dropout_rate=self.dropout_rate,
                                                       is_training=self.is_training,
                                                       causality=False)

                ### Feed Forward
                self.enc = feedforward(self.enc, num_units=[4 * self.hidden_units, self.hidden_units])

            # Decoder
            with tf.variable_scope("decoder"):
                # embedding
                self.de_emb = embedding(self.de_inp, vocab_size=self.label_vocab_size, num_units=self.hidden_units, scale=True, scope="dec_embed")
                self.dec = self.de_emb + embedding(tf.tile(tf.expand_dims(tf.range(tf.shape(self.de_inp)[1]), 0), [tf.shape(self.de_inp)[0], 1]),
                                                   vocab_size=self.max_length, num_units=self.hidden_units, zero_pad=False, scale=False, scope="dec_pe")
                ## Dropout
                self.dec = tf.layers.dropout(self.dec, 
                                             rate=self.dropout_rate, 
                                             training=tf.convert_to_tensor(self.is_training))

                ## Multihead Attention ( self-attention)
                for i in range(self.num_blocks):
                    with tf.variable_scope("num_blocks_{}".format(i)):
                        ### Multihead Attention
                        self.dec = multihead_attention(key_emb=self.de_emb,
                                                       que_emb=self.de_emb,
                                                       queries=self.dec, 
                                                       keys=self.dec, 
                                                       num_units=self.hidden_units, 
                                                       num_heads=self.num_heads, 
                                                       dropout_rate=self.dropout_rate,
                                                       is_training=self.is_training,
                                                       causality=True,
                                                       scope='self_attention')

                ## Multihead Attention ( vanilla attention)
                for i in range(self.num_blocks):
                    with tf.variable_scope("num_blocks_{}".format(i)):
                        ### Multihead Attention
                        self.dec = multihead_attention(key_emb=self.en_emb,
                                                   que_emb <span class="token operator">=</span> self<span class="token punctuation">.</span>de_emb<span class="token punctuation">,</span>
                                                   queries<span class="token operator">=</span>self<span class="token punctuation">.</span>dec<span class="token punctuation">,</span> 
                                                    keys<span class="token operator">=</span>self<span class="token punctuation">.</span>enc<span class="token punctuation">,</span> 
                                                    num_units<span class="token operator">=</span>self<span class="token punctuation">.</span>hidden_units<span class="token punctuation">,</span> 
                                                    num_heads<span class="token operator">=</span>self<span class="token punctuation">.</span>num_heads<span class="token punctuation">,</span> 
                                                    dropout_rate<span class="token operator">=</span>self<span class="token punctuation">.</span>dropout_rate<span class="token punctuation">,</span>
                                                    is_training<span class="token operator">=</span>self<span class="token punctuation">.</span>is_training<span class="token punctuation">,</span>
                                                    causality<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">,</span>
                                                    scope<span class="token operator">=</span><span class="token string">'vanilla_attention'</span><span class="token punctuation">)</span> 
    
            <span class="token comment">### Feed Forward</span>
            self<span class="token punctuation">.</span>outputs <span class="token operator">=</span> feedforward<span class="token punctuation">(</span>self<span class="token punctuation">.</span>dec<span class="token punctuation">,</span> num_units<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">4</span><span class="token operator">*</span>self<span class="token punctuation">.</span>hidden_units<span class="token punctuation">,</span> self<span class="token punctuation">.</span>hidden_units<span class="token punctuation">]</span><span class="token punctuation">)</span>
                
        <span class="token comment"># Final linear projection</span>
        self<span class="token punctuation">.</span>logits <span class="token operator">=</span> tf<span class="token punctuation">.</span>layers<span class="token punctuation">.</span>dense<span class="token punctuation">(</span>self<span class="token punctuation">.</span>outputs<span class="token punctuation">,</span> self<span class="token punctuation">.</span>label_vocab_size<span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>preds <span class="token operator">=</span> tf<span class="token punctuation">.</span>to_int32<span class="token punctuation">(</span>tf<span class="token punctuation">.</span>argmax<span class="token punctuation">(</span>self<span class="token punctuation">.</span>logits<span class="token punctuation">,</span> axis<span class="token operator">=</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>istarget <span class="token operator">=</span> tf<span class="token punctuation">.</span>to_float<span class="token punctuation">(</span>tf<span class="token punctuation">.</span>not_equal<span class="token punctuation">(</span>self<span class="token punctuation">.</span>y<span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>acc <span class="token operator">=</span> tf<span class="token punctuation">.</span>reduce_sum<span class="token punctuation">(</span>tf<span class="token punctuation">.</span>to_float<span class="token punctuation">(</span>tf<span class="token punctuation">.</span>equal<span class="token punctuation">(</span>self<span class="token punctuation">.</span>preds<span class="token punctuation">,</span> self<span class="token punctuation">.</span>y<span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token operator">*</span>self<span class="token punctuation">.</span>istarget<span class="token punctuation">)</span><span class="token operator">/</span> <span class="token punctuation">(</span>tf<span class="token punctuation">.</span>reduce_sum<span class="token punctuation">(</span>self<span class="token punctuation">.</span>istarget<span class="token punctuation">)</span><span class="token punctuation">)</span>
        tf<span class="token punctuation">.</span>summary<span class="token punctuation">.</span>scalar<span class="token punctuation">(</span><span class="token string">'acc'</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>acc<span class="token punctuation">)</span>
                
        <span class="token keyword">if</span> is_training<span class="token punctuation">:</span>  
            <span class="token comment"># Loss</span>
            self<span class="token punctuation">.</span>y_smoothed <span class="token operator">=</span> label_smoothing<span class="token punctuation">(</span>tf<span class="token punctuation">.</span>one_hot<span class="token punctuation">(</span>self<span class="token punctuation">.</span>y<span class="token punctuation">,</span> depth<span class="token operator">=</span>self<span class="token punctuation">.</span>label_vocab_size<span class="token punctuation">)</span><span class="token punctuation">)</span>
            self<span class="token punctuation">.</span>loss <span class="token operator">=</span> tf<span class="token punctuation">.</span>nn<span class="token punctuation">.</span>softmax_cross_entropy_with_logits_v2<span class="token punctuation">(</span>logits<span class="token operator">=</span>self<span class="token punctuation">.</span>logits<span class="token punctuation">,</span> labels<span class="token operator">=</span>self<span class="token punctuation">.</span>y_smoothed<span class="token punctuation">)</span>
            self<span class="token punctuation">.</span>mean_loss <span class="token operator">=</span> tf<span class="token punctuation">.</span>reduce_sum<span class="token punctuation">(</span>self<span class="token punctuation">.</span>loss<span class="token operator">*</span>self<span class="token punctuation">.</span>istarget<span class="token punctuation">)</span> <span class="token operator">/</span> <span class="token punctuation">(</span>tf<span class="token punctuation">.</span>reduce_sum<span class="token punctuation">(</span>self<span class="token punctuation">.</span>istarget<span class="token punctuation">)</span><span class="token punctuation">)</span>
               
            <span class="token comment"># Training Scheme</span>
            self<span class="token punctuation">.</span>global_step <span class="token operator">=</span> tf<span class="token punctuation">.</span>Variable<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">,</span> name<span class="token operator">=</span><span class="token string">'global_step'</span><span class="token punctuation">,</span> trainable<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span>
            self<span class="token punctuation">.</span>optimizer <span class="token operator">=</span> tf<span class="token punctuation">.</span>train<span class="token punctuation">.</span>AdamOptimizer<span class="token punctuation">(</span>learning_rate<span class="token operator">=</span>self<span class="token punctuation">.</span>lr<span class="token punctuation">,</span> beta1<span class="token operator">=</span><span class="token number">0.9</span><span class="token punctuation">,</span> beta2<span class="token operator">=</span><span class="token number">0.98</span><span class="token punctuation">,</span> epsilon<span class="token operator">=</span><span class="token number">1e</span><span class="token operator">-</span><span class="token number">8</span><span class="token punctuation">)</span>
            self<span class="token punctuation">.</span>train_op <span class="token operator">=</span> self<span class="token punctuation">.</span>optimizer<span class="token punctuation">.</span>minimize<span class="token punctuation">(</span>self<span class="token punctuation">.</span>mean_loss<span class="token punctuation">,</span> global_step<span class="token operator">=</span>self<span class="token punctuation">.</span>global_step<span class="token punctuation">)</span>
                   
            <span class="token comment"># Summary </span>
            tf<span class="token punctuation">.</span>summary<span class="token punctuation">.</span>scalar<span class="token punctuation">(</span><span class="token string">'mean_loss'</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>mean_loss<span class="token punctuation">)</span>
            self<span class="token punctuation">.</span>merged <span class="token operator">=</span> tf<span class="token punctuation">.</span>summary<span class="token punctuation">.</span>merge_all<span class="token punctuation">(</span><span class="token punctuation">)</span>
    

    3. Training the model

    Now train the model we built with the data prepared above!

    3.1 Hyperparameter settings

    def create_hparams():
        params = tf.contrib.training.HParams(
            num_heads = 8,
            num_blocks = 6,
            # vocab
            input_vocab_size = 50,
            label_vocab_size = 50,
            # max sentence length (size of the position-embedding table)
            max_length = 100,
            # embedding / model size
            hidden_units = 512,
            dropout_rate = 0.2,
            lr = 0.0003,
            is_training = True)
        return params
    

    arg = create_hparams()
    arg.input_vocab_size = len(encoder_vocab)
    arg.label_vocab_size = len(decoder_vocab)


    3.2 Model training

    import os

    epochs = 25
    batch_size = 64

    g = Graph(arg)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        merged = tf.summary.merge_all()
        sess.run(tf.global_variables_initializer())
        if os.path.exists('logs/model.meta'):
            saver.restore(sess, 'logs/model')
        writer = tf.summary.FileWriter('tensorboard/lm', tf.get_default_graph())
        for k in range(epochs):
            total_loss = 0
            batch_num = len(encoder_inputs) // batch_size
            batch = get_batch(encoder_inputs, decoder_inputs, decoder_targets, batch_size)
            for i in tqdm(range(batch_num)):
                encoder_input, decoder_input, decoder_target = next(batch)
                feed = {g.x: encoder_input, g.y: decoder_target, g.de_inp: decoder_input}
                cost, _ = sess.run([g.mean_loss, g.train_op], feed_dict=feed)
                total_loss += cost
                if (k * batch_num + i) % 10 == 0:
                    rs = sess.run(merged, feed_dict=feed)
                    writer.add_summary(rs, k * batch_num + i)
            if (k + 1) % 5 == 0:
                print('epochs', k + 1, ': average loss = ', total_loss / batch_num)
        saver.save(sess, 'logs/model')
        writer.close()

    100%|██████████| 156/156 [00:31<00:00,  6.19it/s]
    100%|██████████| 156/156 [00:24<00:00,  5.83it/s]
    100%|██████████| 156/156 [00:24<00:00,  6.23it/s]
    100%|██████████| 156/156 [00:24<00:00,  6.11it/s]
    100%|██████████| 156/156 [00:24<00:00,  6.14it/s]
    

    epochs 5 : average loss = 3.3463863134384155

    100%|██████████| 156/156 [00:23<00:00, 6.27it/s]
    100%|██████████| 156/156 [00:23<00:00, 5.86it/s]
    100%|██████████| 156/156 [00:23<00:00, 6.33it/s]
    100%|██████████| 156/156 [00:24<00:00, 6.08it/s]
    100%|██████████| 156/156 [00:23<00:00, 6.29it/s]

    epochs 10 : average loss = 2.0142565186207113

    100%|██████████| 156/156 [00:24<00:00, 6.18it/s]
    100%|██████████| 156/156 [00:24<00:00, 5.84it/s]
    100%|██████████| 156/156 [00:24<00:00, 6.10it/s]
    100%|██████████| 156/156 [00:24<00:00, 6.10it/s]
    100%|██████████| 156/156 [00:23<00:00, 6.38it/s]

    epochs 15 : average loss = 1.5278632457439716

    100%|██████████| 156/156 [00:24<00:00, 6.15it/s]
    100%|██████████| 156/156 [00:24<00:00, 5.86it/s]
    100%|██████████| 156/156 [00:24<00:00, 6.23it/s]
    100%|██████████| 156/156 [00:23<00:00, 6.13it/s]
    100%|██████████| 156/156 [00:23<00:00, 6.32it/s]

    epochs 20 : average loss = 1.4216684783116365

    100%|██████████| 156/156 [00:23<00:00, 6.26it/s]
    100%|██████████| 156/156 [00:23<00:00, 5.89it/s]
    100%|██████████| 156/156 [00:24<00:00, 6.26it/s]
    100%|██████████| 156/156 [00:24<00:00, 6.10it/s]
    100%|██████████| 156/156 [00:23<00:00, 6.35it/s]

    epochs 25 : average loss = 1.3833287457625072


    3.3 Inference

    Feed in a few English sentences to see how well the model translates:

    import numpy as np

    arg.is_training = False

    tf.reset_default_graph()   # clear the training graph before rebuilding it in inference mode
    g = Graph(arg)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        saver.restore(sess, 'logs/model')
        while True:
            # prompt string kept from the author's earlier pinyin project; type an English sentence here
            line = input('输入测试拼音: ')
            if line == 'exit': break
            line = line.lower().replace(',', ' ,').strip('\n').split(' ')
            x = np.array([encoder_vocab.index(pny) for pny in line])
            x = x.reshape(1, -1)
            de_inp = [[decoder_vocab.index('<GO>')]]
            while True:
                y = np.array(de_inp)
                preds = sess.run(g.preds, {g.x: x, g.de_inp: y})
                if preds[0][-1] == decoder_vocab.index('<EOS>'):
                    break
                de_inp[0].append(preds[0][-1])
            got = ''.join(decoder_vocab[idx] for idx in de_inp[0][1:])
            print(got)

    INFO:tensorflow:Restoring parameters from logs/model
    输入测试拼音: You could be right, I suppose
    我猜想你可能是对的
    输入测试拼音: You don't believe Tom, do you
    你不信任汤姆,对吗
    输入测试拼音: Tom has lived here since 2003
    汤姆自从2003年就住在这里
    输入测试拼音: Tom asked if I'd found my key
    湯姆問我找到我的鑰匙了吗
    输入测试拼音: They have a very nice veranda
    他们有一个非常漂亮的暖房
    输入测试拼音: She was married to a rich man
    她嫁給了一個有錢的男人
    输入测试拼音: My parents sent me a postcard
    我父母給我寄了一張明信片
    输入测试拼音: Just put yourself in my shoes
    你站在我的立場上考慮看看
    输入测试拼音: It was a very stupid decision
    这是一个十分愚蠢的决定
    输入测试拼音: I'm really sorry to hear that
    听到这样的消息我真的很难过
    输入测试拼音: His wife is one of my friends
    他的妻子是我的一個朋友
    输入测试拼音: He thought of a good solution
    他想到了一個解決的好辦法
    输入测试拼音: exit
    

    The results turn out quite well, and training is much faster than an RNN-based encoder-decoder. Google really did a great job with this architecture.

    Please credit the source when reposting: https://blog.csdn.net/chinatelecom08

    If you like this project, give it a star!
    https://github.com/audier

    The core of the transformer is self-attention. For background on self-attention, see:

    1. The Illustrated Transformer
    2. Hung-yi Lee's transformer lecture

    一、Machine translation

    Machine translation involves a source language and a target language (e.g., English and Chinese).
    Build a vocabulary over the source language, of size src_vocab_size (padding included).
    Build a vocabulary over the target language, of size tgt_vocab_size (padding included).

    Fix an order for the words in each vocabulary (any fixed order works) and represent each word by its index (index 0 must be reserved for padding; see the sketch below).

    The training set consists of one-to-one pairs of source and target sentences, but:

    1. a source sentence and its target translation will usually differ in length;
    2. sentences within each language also vary in length.
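    A minimal sketch of such an index mapping (these toy vocabularies are made up for illustration):

    src_vocab = ['<PAD>', 'tom', 'runs', 'fast']          # index 0 is padding
    tgt_vocab = ['<PAD>', '<GO>', '<EOS>', '汤姆', '跑', '得', '快']

    src_word2id = {w: i for i, w in enumerate(src_vocab)}
    tgt_word2id = {w: i for i, w in enumerate(tgt_vocab)}

    print([src_word2id[w] for w in 'tom runs fast'.split()])   # [1, 2, 3]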

    二、Data processing

    1、Batch

    The model uses matrix operations for speed. Feeding one sentence pair at a time would work, but training feeds data to the model in batches,
    so the sentences within a batch must be of equal length, which is achieved by padding, i.e. filling with 0.
    (figure: how one batch is padded to a common length)
    The final input format of a batch is (B, L),
    where B is batch_size and L is the number of words per (padded) sentence in this batch. A minimal padding sketch follows.
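    A minimal padding sketch (pad_batch is a hypothetical helper, not from the original post):

    import numpy as np

    def pad_batch(batch, pad_id=0):
        """Pad a list of id sequences to the longest length in the batch -> (B, L)."""
        max_len = max(len(seq) for seq in batch)
        return np.array([seq + [pad_id] * (max_len - len(seq)) for seq in batch])

    print(pad_batch([[1, 2, 3], [4, 5], [6]]))
    # [[1 2 3]
    #  [4 5 0]
    #  [6 0 0]]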

    2、Embedding

    Every word in the source and target vocabularies can be converted to a vector of length $d_{model}$. The word vectors can be pretrained, or trained jointly within this network. The embedding table is just a matrix:
    (figure: the source-language embedding matrix, one word vector per row)

    Each row of the source-language embedding matrix is one word vector, so a word's integer id can be used directly as a row index to fetch the corresponding vector, as sketched below.

    This converts a batch of inputs to shape (B, L, d_model),
    where B is batch_size, L is the number of words per sentence in this batch, and d_model is the word-vector length.
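    A shape-level sketch of the lookup (random numbers stand in for trained embeddings):

    import numpy as np

    vocab_size, d_model = 10, 4
    embedding_table = np.random.randn(vocab_size, d_model)   # one row per word

    ids = np.array([[1, 2, 0],                               # a padded batch, shape (B, L)
                    [3, 0, 0]])
    emb = embedding_table[ids]                               # row lookup by id
    print(emb.shape)                                         # (2, 3, 4) == (B, L, d_model)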

    3、Pad mask

    We added padding, but padding must not take part in the self-attention computation. What does that imply?
    Recall that self-attention produces a word-to-word relation matrix: each word's query vector is dotted with every word's key vector, divided by the square root of the vector length, and then normalized row-wise by a softmax so that all scores are positive and each row sums to 1 (scaled dot-product attention).
    (figure: scaled dot-product attention)
    The resulting matrix has shape (L, L) and holds the relation scores between the words of a sentence.
    The first row holds the scores between the first word and all other words; every row sums to 1.
    (figure: the (L, L) attention-score matrix)

    But we said a real word has no relation to padding, i.e. their relation score should be 0. Can we simply set those scores to 0 after the softmax?
    (figure: zeroing scores after the softmax)
    No: the rows would no longer sum to 1, which is inconsistent, so the operation has to happen before the softmax.

    Recall the softmax formula:

    $\mathrm{softmax}(z_i) = \dfrac{e^{z_i}}{\sum_j e^{z_j}}$

    To force some output to 0 it suffices to make the corresponding $z_i$ negative infinity, so we only need to fill the padding columns of the pre-softmax matrix with $-\infty$. The pre-softmax matrix then looks like this:
    (figure: the pre-softmax matrix with padding columns set to $-\infty$)
    Note:
    why not mask the padding rows as well?
    My understanding is that a padding row still produces an output vector, but that vector never participates in the subsequent computation, so it can be ignored. A small numpy sketch follows.
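    A minimal numpy sketch of the pad mask, using -1e9 as a stand-in for $-\infty$:

    import numpy as np

    def softmax(z):
        e = np.exp(z - z.max(axis=-1, keepdims=True))   # shift for numerical stability
        return e / e.sum(axis=-1, keepdims=True)

    ids = np.array([7, 3, 0, 0])            # the last two positions are padding (id 0)
    L = len(ids)
    scores = np.random.randn(L, L)          # pre-softmax attention scores, shape (L, L)

    scores[:, ids == 0] = -1e9              # mask the padding *columns* before softmax
    attn = softmax(scores)
    print(attn.round(2))                    # padding columns ~0, each row still sums to 1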

    三、Encoder

    Quoting the paper:
    "The encoder is composed of a stack of N = 6 identical layers. Each layer has two sub-layers.
    The first is a multi-head self-attention mechanism,
    and the second is a simple, position-wise fully connected feed-forward network.
    We employ a residual connection around each of the two sub-layers, followed by layer normalization.
    That is, the output of each sub-layer is LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer itself.
    To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension d_model = 512."
    As shown:
    (figure: the encoder stack)

    One encoder layer (the part marked N in the figure):

    1、Input

    (1) Input matrix: (B, L)

    (2) Word embedding: (B, L, d_model)

    (3) Add the position embedding: (B, L, d_model)

    2、Multi-head attention

    Note: in the paper, $d_q = d_k = d_v = d_{model} / n_{heads}$.

    (1) Project to the Q, K, V matrices

    $(B, L, d_{model}) \rightarrow (B, L, d_q \cdot n_{heads})$  Q matrix
    $(B, L, d_{model}) \rightarrow (B, L, d_k \cdot n_{heads})$  K matrix
    $(B, L, d_{model}) \rightarrow (B, L, d_v \cdot n_{heads})$  V matrix
    A single linear layer (nn.Linear) performs each projection.

    (2) Scaled dot-product attention for each head

    (a) Split out the heads:
    $(B, L, d_q \cdot n_{heads}) \rightarrow (B, n_{heads}, L, d_q)$  Q matrix
    $(B, L, d_k \cdot n_{heads}) \rightarrow (B, n_{heads}, L, d_k)$  K matrix
    $(B, L, d_v \cdot n_{heads}) \rightarrow (B, n_{heads}, L, d_v)$  V matrix

    (b) Scaled dot-product attention:
    <1> $(Q K^T) / \sqrt{d_k} \rightarrow (B, n_{heads}, L, L)$
    ($K^T$ transposes the last two dimensions, giving $(B, n_{heads}, d_k, L)$)
    <2> mask the result: set the padding columns to $-\infty$
    <3> softmax over each row
    <4> multiply by the V matrix $\rightarrow (B, n_{heads}, L, d_v)$

    (3) Merge the heads (see the sketch below)

    <1> $(B, n_{heads}, L, d_v) \rightarrow (B, L, n_{heads} \cdot d_v)$
    <2> $(B, L, n_{heads} \cdot d_v) \rightarrow (B, L, d_{model})$ (linear layer)
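    A shape-level numpy sketch of steps (1)-(3); the pad mask is omitted for brevity, and the random W matrices stand in for the linear layers:

    import numpy as np

    B, L, n_heads, d_model = 2, 5, 4, 16
    d_k = d_model // n_heads                   # d_q = d_k = d_v = d_model / n_heads

    x = np.random.randn(B, L, d_model)
    Wq, Wk, Wv = (np.random.randn(d_model, d_model) for _ in range(3))

    def split_heads(t):                        # (B, L, n_heads*d_k) -> (B, n_heads, L, d_k)
        return t.reshape(B, L, n_heads, d_k).transpose(0, 2, 1, 3)

    Q, K, V = split_heads(x @ Wq), split_heads(x @ Wk), split_heads(x @ Wv)

    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)    # (B, n_heads, L, L)
    e = np.exp(scores - scores.max(-1, keepdims=True))
    attn = e / e.sum(-1, keepdims=True)                    # row-wise softmax
    out = attn @ V                                         # (B, n_heads, L, d_v)

    out = out.transpose(0, 2, 1, 3).reshape(B, L, n_heads * d_k)   # merge heads
    print(out.shape)                                       # (2, 5, 16) == (B, L, d_model)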

    3、Residual connection + LayerNorm

    (1) The sub-layer input and the multi-head attention output have the same shape $(B, L, d_{model})$, so they are added directly.

    (2) LayerNorm

    4、Position-wise feed-forward

    (1) Fully connected layer + ReLU:

    $(B, L, d_{model}) \rightarrow (B, L, d_{ff})$

    (2) Fully connected layer:

    $(B, L, d_{ff}) \rightarrow (B, L, d_{model})$

    5、Residual connection + LayerNorm

    (1) The sub-layer input and the feed-forward output have the same shape $(B, L, d_{model})$, so they are added directly.

    (2) LayerNorm

    The final encoder output has shape $(B, L, d_{model})$. A sketch of the feed-forward sub-layer follows.
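    A numpy sketch of the position-wise feed-forward sub-layer with its residual connection and LayerNorm (the learnable gain and bias of LayerNorm are omitted):

    import numpy as np

    B, L, d_model, d_ff = 2, 5, 16, 64
    x = np.random.randn(B, L, d_model)             # sub-layer input

    W1, b1 = np.random.randn(d_model, d_ff), np.zeros(d_ff)
    W2, b2 = np.random.randn(d_ff, d_model), np.zeros(d_model)

    def layer_norm(t, eps=1e-6):                   # normalize over the last axis
        mu, sigma = t.mean(-1, keepdims=True), t.std(-1, keepdims=True)
        return (t - mu) / (sigma + eps)

    ffn = np.maximum(0, x @ W1 + b1) @ W2 + b2     # FC + ReLU, then FC, per position
    y = layer_norm(x + ffn)                        # residual connection + LayerNorm
    print(y.shape)                                 # (2, 5, 16)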

    四、Decoder

    (figure: the decoder stack)
    The decoder differs from the encoder mainly in two parts.

    1、Subsequence mask

    The pad mask described earlier exists because real words must have no relation to padding (their relation score is 0).
    The subsequence mask concerns the target-language input embedding: in self-attention we compute relation scores between all pairs of words. Think of an RNN translation model, where the target words enter the network one by one in order: an earlier word does not know the later words, so it cannot compute relations with them, while a later word may attend to all earlier ones. The mask matrix looks like this:
    (figure: the lower-triangular subsequence mask)
    The padding mask must also be taken into account, so the final mask combines the two (see the sketch below):
    (figure: the subsequence mask combined with the pad mask)
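    A numpy sketch of the combined mask (1 marks a forbidden connection, to be filled with $-\infty$ before the softmax):

    import numpy as np

    ids = np.array([4, 8, 2, 0, 0])                    # target ids; 0 is padding
    L = len(ids)

    subsequent_mask = np.triu(np.ones((L, L)), k=1)    # 1 above the diagonal: future words
    pad_mask = np.tile(ids == 0, (L, 1))               # 1 where the key is padding

    combined = (subsequent_mask + pad_mask) > 0
    print(combined.astype(int))
    # [[0 1 1 1 1]
    #  [0 0 1 1 1]
    #  [0 0 0 1 1]
    #  [0 0 0 1 1]
    #  [0 0 0 1 1]]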

    2、Encoder-decoder attention

    (figure: the encoder-decoder attention sub-layer)
    The inputs of this part differ from before. Previously it was:
    $(B, L, d_{model}) \rightarrow (B, L, d_q \cdot n_{heads})$  Q matrix
    $(B, L, d_{model}) \rightarrow (B, L, d_k \cdot n_{heads})$  K matrix
    $(B, L, d_{model}) \rightarrow (B, L, d_v \cdot n_{heads})$  V matrix

    i.e. one and the same input is projected to the Q, K and V matrices.

    Now it becomes:
    $(B, L_{outputs}, d_{model})$ (from the outputs) $\rightarrow (B, L_{outputs}, d_q \cdot n_{heads})$  Q matrix
    $(B, L_{inputs}, d_{model})$ (from the inputs) $\rightarrow (B, L_{inputs}, d_k \cdot n_{heads})$  K matrix
    $(B, L_{inputs}, d_{model})$ (from the inputs) $\rightarrow (B, L_{inputs}, d_v \cdot n_{heads})$  V matrix

    The query matrix comes from the outputs while the key and value matrices come from the inputs: the decoder selects information from the encoder output to build its own output, much in the spirit of sequence-to-sequence models. A single-head shape sketch follows.
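    A single-head numpy sketch of the cross-attention shapes (the multi-head split is omitted):

    import numpy as np

    B, L_in, L_out, d_model = 2, 7, 5, 16
    enc = np.random.randn(B, L_in, d_model)    # encoder output (inputs side)
    dec = np.random.randn(B, L_out, d_model)   # decoder state  (outputs side)

    Wq, Wk, Wv = (np.random.randn(d_model, d_model) for _ in range(3))
    Q = dec @ Wq                               # (B, L_out, d_model): queries from outputs
    K = enc @ Wk                               # (B, L_in,  d_model): keys from inputs
    V = enc @ Wv                               # (B, L_in,  d_model): values from inputs

    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_model)   # (B, L_out, L_in)
    e = np.exp(scores - scores.max(-1, keepdims=True))
    attn = e / e.sum(-1, keepdims=True)
    print((attn @ V).shape)                    # (2, 5, 16) == (B, L_out, d_model)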

    The output right before the final Linear layer has shape $(B, L_{outputs}, d_{model})$.

    3、Translation output

    A Linear layer followed by a softmax:
    $(B, L_{outputs}, d_{model}) \rightarrow (B, L_{outputs}, tgt\_vocab\_size)$
    At each position, pick the word with the highest probability; that sequence is the translation.
