  • Transformer translation (English → Chinese): Transformer structure, data format, and a full PyTorch implementation.

    Transformer structure

    For the underlying principles, see this article on the Transformer.

    Data format

    Our Chinese data uses traditional characters, so it first has to be converted to simplified Chinese:

    import copy
    import math
    import matplotlib.pyplot as plt
    import numpy as np
    import os
    import seaborn as sns
    import time
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    from collections import Counter
    from langconv import Converter
    from nltk import word_tokenize
    from torch.autograd import Variable
    
    def cht_to_chs(sent):
        # convert a traditional-Chinese sentence to simplified Chinese
        sent = Converter("zh-hans").convert(sent)
        return sent
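    A quick sanity check of the converter (the example sentence is only illustrative, assuming langconv is available as imported above):

    print(cht_to_chs("我們的數據是繁體字"))  # -> 我们的数据是繁体字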
    
    

    After the conversion, we split each Chinese sentence at character granularity, build the vocabularies, map the tokens to indices, and group the data into batches by sentence length:

    class PrepareData:
        def __init__(self, train_file, dev_file):
            # 读取数据、分词
            self.train_en, self.train_cn = self.load_data(train_file)
            self.dev_en, self.dev_cn = self.load_data(dev_file)
            # 构建词表
            self.en_word_dict, self.en_total_words, self.en_index_dict = \
                self.build_dict(self.train_en)
            self.cn_word_dict, self.cn_total_words, self.cn_index_dict = \
                self.build_dict(self.train_cn)
            # 单词映射为索引
            self.train_en, self.train_cn = self.word2id(self.train_en, self.train_cn, self.en_word_dict, self.cn_word_dict)
            self.dev_en, self.dev_cn = self.word2id(self.dev_en, self.dev_cn, self.en_word_dict, self.cn_word_dict)
            # 划分批次、填充、掩码
            self.train_data = self.split_batch(self.train_en, self.train_cn, BATCH_SIZE)
            self.dev_data = self.split_batch(self.dev_en, self.dev_cn, BATCH_SIZE)
    
        def load_data(self, path):
            """
            读取英文、中文数据
            对每条样本分词并构建包含起始符和终止符的单词列表
            形式如:en = [['BOS', 'i', 'love', 'you', 'EOS'], ['BOS', 'me', 'too', 'EOS'], ...]
                    cn = [['BOS', '我', '爱', '你', 'EOS'], ['BOS', '我', '也', '是', 'EOS'], ...]
            """
            en = []
            cn = []
            with open(path, mode="r", encoding="utf-8") as f:
                for line in f.readlines():
                    sent_en, sent_cn = line.strip().split("\t")
                    sent_en = sent_en.lower()
                    sent_cn = cht_to_chs(sent_cn)
                    sent_en = ["BOS"] + word_tokenize(sent_en) + ["EOS"]
                    # 中文按字符切分
                    sent_cn = ["BOS"] + [char for char in sent_cn] + ["EOS"]
                    en.append(sent_en)
                    cn.append(sent_cn)
            return en, cn
    
        def build_dict(self, sentences, max_words=5e4):
            """
            构造分词后的列表数据
            构建单词-索引映射(key为单词,value为id值)
            """
            # 统计数据集中单词词频
            word_count = Counter([word for sent in sentences for word in sent])
            # 按词频保留前max_words个单词构建词典
            # 添加UNK和PAD两个单词
            ls = word_count.most_common(int(max_words))
            total_words = len(ls) + 2
            word_dict = {w[0]: index + 2 for index, w in enumerate(ls)}
            word_dict['UNK'] = UNK
            word_dict['PAD'] = PAD
            # 构建id2word映射
            index_dict = {v: k for k, v in word_dict.items()}
            return word_dict, total_words, index_dict
    
        def word2id(self, en, cn, en_dict, cn_dict, sort=True):
            """
            将英文、中文单词列表转为单词索引列表
            `sort=True`表示以英文语句长度排序,以便按批次填充时,同批次语句填充尽量少
            """
            length = len(en)
            # 单词映射为索引
            out_en_ids = [[en_dict.get(word, UNK) for word in sent] for sent in en]
            out_cn_ids = [[cn_dict.get(word, UNK) for word in sent] for sent in cn]
    
            # 按照语句长度排序
            def len_argsort(seq):
                """
                传入一系列语句数据(分好词的列表形式),
                按照语句长度排序后,返回排序后原来各语句在数据中的索引下标
                """
                return sorted(range(len(seq)), key=lambda x: len(seq[x]))
    
            # 按相同顺序对中文、英文样本排序
            if sort:
                # 以英文语句长度排序
                sorted_index = len_argsort(out_en_ids)
                out_en_ids = [out_en_ids[idx] for idx in sorted_index]
                out_cn_ids = [out_cn_ids[idx] for idx in sorted_index]
            return out_en_ids, out_cn_ids
    
        def split_batch(self, en, cn, batch_size, shuffle=True):
            """
            划分批次
            `shuffle=True`表示对各批次顺序随机打乱
            """
            # 每隔batch_size取一个索引作为后续batch的起始索引
            idx_list = np.arange(0, len(en), batch_size)
            # 起始索引随机打乱
            if shuffle:
                np.random.shuffle(idx_list)
            # 存放所有批次的语句索引
            batch_indexs = []
            for idx in idx_list:
                """
                形如[array([4, 5, 6, 7]), 
                     array([0, 1, 2, 3]), 
                     array([8, 9, 10, 11]),
                     ...]
                """
                # 起始索引最大的批次可能发生越界,要限定其索引
                batch_indexs.append(np.arange(idx, min(idx + batch_size, len(en))))
            # 构建批次列表
            batches = []
            for batch_index in batch_indexs:
                # 按当前批次的样本索引采样
                batch_en = [en[index] for index in batch_index]
                batch_cn = [cn[index] for index in batch_index]
                # 对当前批次中所有语句填充、对齐长度
                # 维度为:batch_size * 当前批次中语句的最大长度
                batch_cn = seq_padding(batch_cn)
                batch_en = seq_padding(batch_en)
                # 将当前批次添加到批次列表
                # Batch类用于实现注意力掩码
                batches.append(Batch(batch_en, batch_cn))
            return batches
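    The class above also relies on a seq_padding helper and a Batch wrapper (the Batch class builds the attention masks), as well as the constants PAD, UNK and BATCH_SIZE, none of which are shown in this excerpt. A minimal seq_padding sketch consistent with how it is called, assuming the padding id is 0 for illustration, could look like this:

    def seq_padding(batch, padding=0):
        # pad every sentence in the batch to the length of the longest one
        max_len = max(len(sent) for sent in batch)
        return np.array([sent + [padding] * (max_len - len(sent)) for sent in batch])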
    

    Model structure

    With the data ready, we can now move on to the Transformer itself.

    First the input tokens are turned into vectors by an embedding that consists of a token embedding and a position embedding; the encoded vectors then flow into the two sub-layers of each encoder layer.

    Embedding

    class Embeddings(nn.Module):
        def __init__(self, d_model, vocab):
            super(Embeddings, self).__init__()
            # Embedding层
            self.lut = nn.Embedding(vocab, d_model)
            # Embedding维数
            self.d_model = d_model
    
        def forward(self, x):
            # 返回x的词向量(需要乘以math.sqrt(d_model))
            return self.lut(x) * math.sqrt(self.d_model)
    

    Positional encoding

    First, why do we need positional encoding at all? Self-attention effectively reduces the distance between any two tokens to 1, so the position information of each token is lost and has to be added back in. Two approaches come to mind:

    1. Fixed encoding.

    The Transformer takes this approach: positions are encoded with trigonometric functions, sine on the even dimensions and cosine on the odd dimensions.

    Fixed encodings are simple and need no training. They are not tied to the size of an embedding table, so in principle they support arbitrarily long text (although you should still avoid predicting on text much longer than what was seen in training).

    2. Learned encoding.

    BERT takes this approach: a position embedding table is randomly initialized and then trained, and at prediction time each position's embedding is obtained by an embedding lookup, just like a token embedding.

    Learned position embeddings tend to be more accurate when the corpus is large, but they require training and, most critically, they limit the input length: when the text is longer than the position embedding table, the extra positions cannot be looked up (essentially they are OOV). This is why BERT's maximum input length is 512.

    Position Encoding

    The position encoding uses sinusoids directly: sine on the even dimensions and cosine on the odd dimensions:

    $PE(pos, 2i) = \sin\left(pos / 10000^{2i/d_{model}}\right)$
    $PE(pos, 2i+1) = \cos\left(pos / 10000^{2i/d_{model}}\right)$

    class PositionalEncoding(nn.Module):
        def __init__(self, d_model, dropout, max_len=5000):
            super(PositionalEncoding, self).__init__()
            self.dropout = nn.Dropout(p=dropout)
            # 位置编码矩阵,维度[max_len, embedding_dim]
            pe = torch.zeros(max_len, d_model, device=DEVICE)
            # 单词位置
            position = torch.arange(0.0, max_len, device=DEVICE)
            position.unsqueeze_(1)
            # 使用exp和log实现幂运算
            div_term = torch.exp(torch.arange(0.0, d_model, 2, device=DEVICE) * (- math.log(1e4) / d_model))
            div_term.unsqueeze_(0)
            # 计算单词位置沿词向量维度的纹理值
            pe[:, 0 : : 2] = torch.sin(torch.mm(position, div_term))
            pe[:, 1 : : 2] = torch.cos(torch.mm(position, div_term))
            # 增加批次维度,[1, max_len, embedding_dim]
            pe.unsqueeze_(0)
            # 将位置编码矩阵注册为buffer(不参加训练)
            self.register_buffer('pe', pe)
    
        def forward(self, x):
            # 将一个批次中语句所有词向量与位置编码相加
            # 注意,位置编码不参与训练,因此设置requires_grad=False
            x += Variable(self.pe[:, : x.size(1), :], requires_grad=False)
            return self.dropout(x)
    

    Encoder structure

    Self-Attention

    def attention(query, key, value, mask=None, dropout=None):
        """
        Scaled Dot-Product Attention
        """
        # q、k、v向量长度为d_k
        d_k = query.size(-1)
        # 矩阵乘法实现q、k点积注意力,sqrt(d_k)归一化
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        # 注意力掩码机制
        if mask is not None:
            scores = scores.masked_fill(mask==0, -1e9)
        # 注意力矩阵softmax归一化
        p_attn = F.softmax(scores, dim=-1)
        # dropout
        if dropout is not None:
            p_attn = dropout(p_attn)
        # 注意力对v加权
        return torch.matmul(p_attn, value), p_attn
    

    Multi-Head Attention

    MultiHeadedAttention implements multi-head self-attention: the hidden vectors are split into h heads, self-attention is computed inside each head, and the results are concatenated back together.
    This is done to capture several kinds of semantic information in parallel, so the concatenated output carries multiple layers of information about the input.

    
    def clones(module, N):
        """
        克隆基本单元,克隆的单元之间参数不共享
        """
        return nn.ModuleList([
            copy.deepcopy(module) for _ in range(N)
        ])
    
    
    class MultiHeadedAttention(nn.Module):
        """
        Multi-Head Attention
        """
        def __init__(self, h, d_model, dropout=0.1):
            super(MultiHeadedAttention, self).__init__()
            """
            `h`:注意力头的数量
            `d_model`:词向量维数
            """
            # 确保整除
            assert d_model % h == 0
            # q、k、v向量维数
            self.d_k = d_model // h
            # 头的数量
            self.h = h
            # WQ、WK、WV矩阵及多头注意力拼接变换矩阵WO
            self.linears = clones(nn.Linear(d_model, d_model), 4)
            self.attn = None
            self.dropout = nn.Dropout(p=dropout)
    
        def forward(self, query, key, value, mask=None):
            if mask is not None:
                mask = mask.unsqueeze(1)
            # 批次大小
            nbatches = query.size(0)
            # WQ、WK、WV分别对词向量线性变换,并将结果拆成h块
            query, key, value = [
                l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
                for l, x in zip(self.linears, (query, key, value))
            ]
            # 注意力加权
            x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
            # 多头注意力加权拼接
            x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
            # 对多头注意力加权拼接结果线性变换
            return self.linears[-1](x)
    
    

    Add & Norm

    The Add & Norm layer consists of an Add step and a Norm step:

    $\mathrm{LayerNorm}(X + \mathrm{MultiHeadAttention}(X))$
    $\mathrm{LayerNorm}(X + \mathrm{FeedForward}(X))$

    Here X is the input of the Multi-Head Attention or Feed Forward sub-layer, and MultiHeadAttention(X) / FeedForward(X) is its output (the output has the same dimensions as the input X, so they can be added).

    Add refers to X + MultiHeadAttention(X), a residual connection, commonly used to ease the training of deep networks by letting each layer focus only on the difference it needs to model; it is used heavily in ResNet.

    Norm refers to Layer Normalization, often used with RNN structures; it normalizes the inputs of each layer to the same mean and variance, which speeds up convergence.
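    The SublayerConnection, Encoder and Decoder classes below construct a LayerNorm module that is not defined in this excerpt. A minimal PyTorch sketch consistent with how it is used (it is constructed with the feature size) could be:

    class LayerNorm(nn.Module):
        def __init__(self, features, eps=1e-6):
            super(LayerNorm, self).__init__()
            # learnable per-feature scale and shift
            self.a_2 = nn.Parameter(torch.ones(features))
            self.b_2 = nn.Parameter(torch.zeros(features))
            self.eps = eps

        def forward(self, x):
            # normalize over the last (feature) dimension
            mean = x.mean(-1, keepdim=True)
            std = x.std(-1, keepdim=True)
            return self.a_2 * (x - mean) / (std + self.eps) + self.b_2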

    class SublayerConnection(nn.Module):
        """
        通过层归一化和残差连接,连接Multi-Head Attention和Feed Forward
        """
        def __init__(self, size, dropout):
            super(SublayerConnection, self).__init__()
            self.norm = LayerNorm(size)
            self.dropout = nn.Dropout(dropout)
    
        def forward(self, x, sublayer):
            # 层归一化
            x_ = self.norm(x)
            x_ = sublayer(x_)
            x_ = self.dropout(x_)
            # 残差连接
            return x + x_
    

    Feed Forward

    The Feed Forward layer is simple: a two-layer fully connected network whose first layer uses a ReLU activation and whose second layer has no activation:

    $\mathrm{FFN}(X) = \max(0,\, X W_1 + b_1)\, W_2 + b_2$

    X is the input; the output matrix of the Feed Forward layer has the same dimensions as X.

    class PositionwiseFeedForward(nn.Module):
        def __init__(self, d_model, d_ff, dropout=0.1):
            super(PositionwiseFeedForward, self).__init__()
            self.w_1 = nn.Linear(d_model, d_ff)
            self.w_2 = nn.Linear(d_ff, d_model)
            self.dropout = nn.Dropout(dropout)
    
        def forward(self, x):
            x = self.w_1(x)
            x = F.relu(x)
            x = self.dropout(x)
            x = self.w_2(x)
            return x
    
    

    Encoder Layer

    通过上面描述的 Multi-Head Attention, Feed Forward, Add & Norm 就可以构造出一个 Encoder block,Encoder block 接收输入矩阵 X(n×d),并输出一个矩阵 O(n×d)。通过多个 Encoder block 叠加就可以组成 Encoder。
    第一个 Encoder block 的输入为句子单词的表示向量矩阵,后续 Encoder block 的输入是前一个 Encoder block 的输出,最后一个 Encoder block 输出的矩阵就是 编码信息矩阵 C,这一矩阵后续会用到 Decoder 中。

    class EncoderLayer(nn.Module):
        def __init__(self, size, self_attn, feed_forward, dropout):
            super(EncoderLayer, self).__init__()
            self.self_attn = self_attn
            self.feed_forward = feed_forward
            # SublayerConnection作用连接multi和ffn
            self.sublayer = clones(SublayerConnection(size, dropout), 2)
            # d_model
            self.size = size
    
        def forward(self, x, mask):
            # 将embedding层进行Multi head Attention
            x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
            # attn的结果直接作为下一层输入
            return self.sublayer[1](x, self.feed_forward)
    
    

    Encoder

    class Encoder(nn.Module):
        def __init__(self, layer, N):
            """
            layer = EncoderLayer
            """
            super(Encoder, self).__init__()
            # 复制N个编码器基本单元
            self.layers = clones(layer, N)
            # 层归一化
            self.norm = LayerNorm(layer.size)
    
        def forward(self, x, mask):
            """
            循环编码器基本单元N次
            """
            for layer in self.layers:
                x = layer(x, mask)
            return self.norm(x)
            
    

    Decoder structure

    The Decoder block is similar to the Encoder block, with a few differences:

    • It contains two Multi-Head Attention layers.
    • The first is a Masked Multi-Head Attention layer.
    • In the second Multi-Head Attention layer, the K and V matrices are computed from the Encoder's encoded information matrix C, while Q is computed from the output of the previous Decoder block.
    • A final Softmax layer computes the probability of the next translated word.

    The first Masked Multi-Head Self-Attention

    The first Masked Multi-Head Self-Attention of the Decoder block uses a mask because translation is produced sequentially: word i must be translated before word i+1, so word i must not be allowed to see the words after i+1. Let's walk through the mask using the example of translating "我有一只猫" into "I have a cat".

    The description below uses something similar to Teacher Forcing; if you are not familiar with Teacher Forcing, see the earlier article on the Seq2Seq model. During decoding, the currently most likely word is predicted from the previous translation: first the input "<Begin>" predicts the first word "I", then the input "<Begin> I" predicts the next word "have", and so on.

    During training, the Decoder can use Teacher Forcing and be trained in parallel: the correct word sequence (<Begin> I have a cat) and the corresponding target output (I have a cat <end>) are both passed to the Decoder. When predicting output i, the words after position i must be masked out; note that the mask is applied before the Softmax inside Self-Attention. Below, 0 1 2 3 4 5 denote <Begin> I have a cat <end>.

    Step 1: the Decoder's input matrix and the Mask matrix. The input matrix contains the representation vectors of the five tokens "<Begin> I have a cat" (0, 1, 2, 3, 4), and the Mask is a 5×5 matrix in which token 0 may only use information from token 0, token 1 may use information from tokens 0 and 1, and so on: each token may only use the information before it.

    Step 2: the rest proceeds like ordinary Self-Attention: the matrices Q, K, V are computed from the input matrix X, and the product QK^T is formed.

    Step 3: before the Softmax that produces the attention scores, the Mask matrix is used to hide the information after each word. After the masked QK^T is softmaxed, every row still sums to 1, but the attention scores of word 0 on words 1, 2, 3, 4 are all 0.

    Step 4: the masked, softmaxed QK^T is multiplied with the matrix V to obtain the output Z; the output vector Z1 of word 1 then contains information from words 0 and 1 only.

    Step 5: the steps above give the output matrix Zi of one masked attention head; as in the Encoder, the outputs Zi of the individual heads are concatenated and projected to obtain the output Z of the first Masked Multi-Head Self-Attention, which has the same dimensions as the input X.

    def subsequent_mask(size):
        "Mask out subsequent positions."
        # 设定subsequent_mask矩阵的shape
        attn_shape = (1, size, size)
        # 生成一个右上角(不含主对角线)为全1,左下角(含主对角线)为全0的subsequent_mask矩阵
        subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
        # 返回一个右上角(不含主对角线)为全False,左下角(含主对角线)为全True的subsequent_mask矩阵
        return torch.from_numpy(subsequent_mask) == 0
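    For example, the mask for a length of 5 is lower-triangular: position i may attend only to positions 0..i (older PyTorch versions print 1/0 instead of True/False):

    print(subsequent_mask(5)[0])
    # tensor([[ True, False, False, False, False],
    #         [ True,  True, False, False, False],
    #         [ True,  True,  True, False, False],
    #         [ True,  True,  True,  True, False],
    #         [ True,  True,  True,  True,  True]])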
    
    

    The second Multi-Head Attention

    The second Multi-Head Attention of the Decoder block changes little; the main difference is that its K and V matrices are not computed from the previous Decoder block's output but from the Encoder's encoded information matrix C.

    K and V are computed from the Encoder output C, while Q is computed from the output Z of the previous Decoder block (or from the input matrix X in the first Decoder block); the rest of the computation is the same as described before.

    The benefit is that, when decoding, every position can make use of the information of all the Encoder's words.

    Decoder Layer

    class DecoderLayer(nn.Module):
        def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
            super(DecoderLayer, self).__init__()
            self.size = size
            # 自注意力机制
            self.self_attn = self_attn
            # 上下文注意力机制
            self.src_attn = src_attn
            self.feed_forward = feed_forward
            self.sublayer = clones(SublayerConnection(size, dropout), 3)
    
        def forward(self, x, memory, src_mask, tgt_mask):
            # memory为编码器输出隐表示
            m = memory
            # self-attention: q, k and v all come from the decoder hidden states
            x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
            # cross-attention: q comes from the decoder, while k and v come from the encoder memory
            x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
            return self.sublayer[2](x, self.feed_forward)
    

    Decoder

    class Decoder(nn.Module):
        def __init__(self, layer, N):
            super(Decoder, self).__init__()
            self.layers = clones(layer, N)
            self.norm = LayerNorm(layer.size)
    
        def forward(self, x, memory, src_mask, tgt_mask):
            """
            循环解码器基本单元N次
            """
            for layer in self.layers:
                x = layer(x, memory, src_mask, tgt_mask)
            return self.norm(x)
    

    Linear and Softmax

    class Generator(nn.Module):
        """
        解码器输出经线性变换和softmax函数映射为下一时刻预测单词的概率分布
        """
        def __init__(self, d_model, vocab):
            super(Generator, self).__init__()
            # decode后的结果,先进入一个全连接层变为词典大小的向量
            self.proj = nn.Linear(d_model, vocab)
    
        def forward(self, x):
            # 然后再进行log_softmax操作(在softmax结果上再做多一次log运算)
            return F.log_softmax(self.proj(x), dim=-1)
    
    

    Transformer

    class Transformer(nn.Module):
        def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
            super(Transformer, self).__init__()
            self.encoder = encoder
            self.decoder = decoder
            self.src_embed = src_embed
            self.tgt_embed = tgt_embed
            self.generator = generator
    
        def encode(self, src, src_mask):
            return self.encoder(self.src_embed(src), src_mask)
    
        def decode(self, memory, src_mask, tgt, tgt_mask):
            return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
    
        def forward(self, src, tgt, src_mask, tgt_mask):
            # encoder的结果作为decoder的memory参数传入,进行decode
            return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)
    
    
    def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
        c = copy.deepcopy
        # 实例化Attention对象
        attn = MultiHeadedAttention(h, d_model).to(DEVICE)
        # 实例化FeedForward对象
        ff = PositionwiseFeedForward(d_model, d_ff, dropout).to(DEVICE)
        # 实例化PositionalEncoding对象
        position = PositionalEncoding(d_model, dropout).to(DEVICE)
        # 实例化Transformer模型对象
        model = Transformer(
            Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout).to(DEVICE), N).to(DEVICE),
            Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout).to(DEVICE), N).to(DEVICE),
            nn.Sequential(Embeddings(d_model, src_vocab).to(DEVICE), c(position)),
            nn.Sequential(Embeddings(d_model, tgt_vocab).to(DEVICE), c(position)),
            Generator(d_model, tgt_vocab)).to(DEVICE)
    
        # This was important from their code.
        # Initialize parameters with Glorot / fan_avg.
        for p in model.parameters():
            if p.dim() > 1:
                # 这里初始化采用的是nn.init.xavier_uniform
                nn.init.xavier_uniform_(p)
        return model.to(DEVICE)
    

    Label Smoothing

    To keep the model from becoming over-confident about its predictions during training and to improve generalization, we can add a label smoothing step.

    class LabelSmoothing(nn.Module):
        """
        标签平滑
        """
    
        def __init__(self, size, padding_idx, smoothing=0.0):
            super(LabelSmoothing, self).__init__()
            self.criterion = nn.KLDivLoss(reduction='sum')
            self.padding_idx = padding_idx
            self.confidence = 1.0 - smoothing
            self.smoothing = smoothing
            self.size = size
            self.true_dist = None
    
        def forward(self, x, target):
            assert x.size(1) == self.size
            true_dist = x.data.clone()
            true_dist.fill_(self.smoothing / (self.size - 2))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
            true_dist[:, self.padding_idx] = 0
            mask = torch.nonzero(target.data == self.padding_idx)
            if mask.dim() > 0:
                true_dist.index_fill_(0, mask.squeeze(), 0.0)
            self.true_dist = true_dist
            return self.criterion(x, Variable(true_dist, requires_grad=False))
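    A small illustrative check (the prediction values are made up): with size=5, padding_idx=0 and smoothing=0.4, the smoothed target puts 0.6 on the true class, 0 on the padding class and 0.4 / (size - 2) ≈ 0.133 on each remaining class:

    crit = LabelSmoothing(size=5, padding_idx=0, smoothing=0.4)
    predict = torch.FloatTensor([[0.05, 0.15, 0.60, 0.15, 0.05],
                                 [0.05, 0.15, 0.60, 0.15, 0.05]])
    loss = crit(Variable(predict.log()), Variable(torch.LongTensor([2, 1])))
    print(crit.true_dist)
    # tensor([[0.0000, 0.1333, 0.6000, 0.1333, 0.1333],
    #         [0.0000, 0.6000, 0.1333, 0.1333, 0.1333]])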
    

    Summary

    1. Unlike RNNs, the Transformer can be trained in parallel quite well.
    2. Multi-Head Attention contains several Self-Attention heads, which can capture attention scores between words along several dimensions of relatedness.
    3. Because self-attention has no recurrence, the Transformer needs some way to represent the relative or absolute positions of elements in the sequence; Position Embedding (PE) is the scheme proposed in the paper. Some studies, however, report no clear difference between models with and without PE.

    All the code and data have been uploaded to GitHub: transformer-english2Chinese

  • transformer_news: a Chinese–English parallel-corpus translation system based on the Transformer
  • For machine translation there are a source language and a target language (e.g. English and Chinese); vocabularies of size src_vocab_size and tgt_vocab_size (both including padding) are built for them, and each word is represented by its index in the vocabulary.

    The core of the Transformer is self-attention; for background on self-attention see:

    1. The Illustrated Transformer (图解transformer)
    2. Hung-yi Lee's lecture on the Transformer

    I. Machine translation

    Machine translation involves a source language and a target language (for example English and Chinese).
    A vocabulary of size src_vocab_size is built for the source language (including padding).
    A vocabulary of size tgt_vocab_size is built for the target language (including padding).

    The words in each vocabulary are put in a fixed order and each word is represented by its index (index 0 must be reserved for padding).

    The training set consists of source/target sentence pairs, but:

    1. A source sentence and its target sentence usually have different lengths.
    2. Sentences within each language also vary in length.

    II. Data processing

    1. batch

    The model relies on matrix operations for speed. Feeding one sentence pair at a time would work, but training is done in mini-batches, so all sentences within a batch must have the same length, which is achieved by padding with 0.

    After padding, a batch has shape (B, L), where B is batch_size and L is the number of tokens of the sentences in this batch.
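    A minimal padding sketch (assuming the sentences are already lists of integer ids and 0 is the padding id):

    import numpy as np

    def pad_batch(batch, pad_id=0):
        # pad every sentence to the length of the longest one in the batch -> (B, L)
        max_len = max(len(sent) for sent in batch)
        return np.array([sent + [pad_id] * (max_len - len(sent)) for sent in batch])

    print(pad_batch([[1, 3], [1, 5, 6, 7, 8]]))
    # [[1 3 0 0 0]
    #  [1 5 6 7 8]]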

    2. embedding

    Every word in the source and target vocabularies can be mapped to a vector of length $d_{model}$. The word vectors can be pre-trained or trained together with this network. The embedding is essentially a matrix:

    each row of the source-language embedding matrix is one word vector, so a word's integer id can be used directly as a row index to look up its vector.

    This turns a batch of inputs into shape $(B, L, d_{model})$, where B is batch_size, L is the number of tokens of the sentences in this batch, and $d_{model}$ is the word-vector length.
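    A PyTorch sketch of this step (the sizes are illustrative):

    import torch
    import torch.nn as nn

    src_vocab_size, d_model = 1000, 512
    emb = nn.Embedding(src_vocab_size, d_model, padding_idx=0)  # row 0 stays the padding vector
    batch = torch.tensor([[1, 3, 0, 0, 0],
                          [1, 5, 6, 7, 8]])    # (B, L) padded ids
    print(emb(batch).shape)                    # torch.Size([2, 5, 512])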

    3. pad mask

    We added padding, but padding must not take part in the self-attention computation. What does that mean in practice?
    Self-attention produces a word-to-word score matrix: each word's query vector is dotted with every word's key vector, divided by the square root of the vector length, and then normalized with a softmax so every entry is positive and each row sums to 1 (scaled dot-product attention).
    The resulting matrix has shape (L, L) and holds the relation scores between the words of one sentence: the first row holds the scores of the first word against all the words, and every row sums to 1.

    But a real word should have no relation with padding, i.e. its score on the padding positions should be 0. Can we simply set those scores to 0 after the softmax? No, because the rows would then no longer sum to 1, which is inconsistent, so we intervene before the softmax instead.

    Recall the softmax formula:

    $\mathrm{softmax}(z)_i = \dfrac{e^{z_i}}{\sum_j e^{z_j}}$

    To force an output to 0, it is enough to set the corresponding $z_i$ to $-\infty$, so we simply set the padding columns of the pre-softmax matrix to $-\infty$.

    Note: why not also handle the padding rows?
    My understanding is that a padding row still produces an output vector, but that vector never contributes to the rest of the computation, so it can be ignored.
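    A sketch of applying such a mask just before the softmax (illustrative names; scores has shape (B, L, L), batch_ids holds the padded token ids and 0 is the padding id):

    import torch
    import torch.nn.functional as F

    def masked_softmax(scores, batch_ids, pad_id=0):
        # key_mask[b, 0, j] is True where token j of sentence b is a real word
        key_mask = (batch_ids != pad_id).unsqueeze(1)          # (B, 1, L), broadcast over rows
        scores = scores.masked_fill(~key_mask, float('-inf'))  # padding columns -> -inf
        # rows belonging to padding queries come out as NaN here; as noted above,
        # they never contribute downstream and can simply be ignored
        return F.softmax(scores, dim=-1)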

    III. encoder

    Quoting the paper's description:
    "The encoder is composed of a stack of N = 6 identical layers. Each layer has two sub-layers.
    The first is a multi-head self-attention mechanism,
    and the second is a simple, position-wise fully connected feed-forward network.
    We employ a residual connection around each of the two sub-layers, followed by layer normalization.
    That is, the output of each sub-layer is LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer itself.
    To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension dmodel = 512."

    One encoder layer (marked as N layers in the figure):

    1. Input

    (1) input matrix: (B, L)

    (2) word embedding: (B, L, d_model)

    (3) plus position embedding: (B, L, d_model)

    2. multi-head attention

    Note: in the paper, $d_q = d_k = d_v = d_{model} / n_{heads}$.

    (1) Project to the Q, K, V matrices

    $(B, L, d_{model}) \rightarrow (B, L, d_q \cdot n_{heads})$ : Q matrix
    $(B, L, d_{model}) \rightarrow (B, L, d_k \cdot n_{heads})$ : K matrix
    $(B, L, d_{model}) \rightarrow (B, L, d_v \cdot n_{heads})$ : V matrix
    A single linear layer (nn.Linear) is enough for each projection.

    (2) Scaled dot-product attention for each head

    (a) Split out the heads:
    $(B, L, d_q \cdot n_{heads}) \rightarrow (B, n_{heads}, L, d_q)$ : Q matrix
    $(B, L, d_k \cdot n_{heads}) \rightarrow (B, n_{heads}, L, d_k)$ : K matrix
    $(B, L, d_v \cdot n_{heads}) \rightarrow (B, n_{heads}, L, d_v)$ : V matrix

    (b) Scaled Dot-Product Attention:
    <1> $(Q K^T) / \sqrt{d_k} \rightarrow (B, n_{heads}, L, L)$
    (for $K^T$ the last two dimensions are transposed, giving $(B, n_{heads}, d_k, L)$)
    <2> apply the mask: set the padding columns to $-\infty$
    <3> softmax over each row
    <4> multiply by the V matrix $\rightarrow (B, n_{heads}, L, d_v)$

    (3) Merge the heads

    <1> $(B, n_{heads}, L, d_v) \rightarrow (B, L, n_{heads} \cdot d_v)$
    <2> $(B, L, n_{heads} \cdot d_v) \rightarrow (B, L, d_{model})$ (linear layer)

    A shape-only sketch of these steps is given below.
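    A shape-only PyTorch sketch of the steps above (illustrative sizes, mask omitted):

    import torch
    import torch.nn as nn

    B, L, d_model, n_heads = 2, 5, 512, 8
    d_k = d_model // n_heads
    x = torch.randn(B, L, d_model)

    w_q, w_k, w_v, w_o = (nn.Linear(d_model, d_model) for _ in range(4))
    q = w_q(x).view(B, L, n_heads, d_k).transpose(1, 2)           # (B, n_heads, L, d_k)
    k = w_k(x).view(B, L, n_heads, d_k).transpose(1, 2)
    v = w_v(x).view(B, L, n_heads, d_k).transpose(1, 2)

    scores = q @ k.transpose(-2, -1) / d_k ** 0.5                 # (B, n_heads, L, L)
    attn = scores.softmax(dim=-1)                                 # softmax over each row
    out = (attn @ v).transpose(1, 2).contiguous().view(B, L, d_model)  # merge the heads
    print(w_o(out).shape)                                         # torch.Size([2, 5, 512])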

    3. Residual connection + layerNorm

    (1) The layer input and the multi-head attention output have the same shape $(B, L, d_{model})$, so they are added directly.

    (2) layerNorm

    4. Position-wise Feed-Forward

    (1) Fully connected layer + ReLU:

    $(B, L, d_{model}) \rightarrow (B, L, d_{ff})$

    (2) Fully connected layer:

    $(B, L, d_{ff}) \rightarrow (B, L, d_{model})$

    5. Residual connection + layerNorm

    (1) The sub-layer input and the feed-forward output have the same shape $(B, L, d_{model})$, so they are added directly.

    (2) layerNorm

    The final encoder output has shape $(B, L, d_{model})$.

    IV. decoder

    The decoder differs from the encoder in two main places.

    1. subsequence_mask

    The pad mask described earlier exists because a real word must have no relation with padding (their relation score should be 0).
    The subsequence mask concerns the target-language input embedding. During self-attention we compute relation scores between all pairs of words, but think of an RNN translation model, where the target words are fed in one at a time: an earlier word does not know the later words, so it must not compute relations with them, while a later word may attend to the earlier ones. This gives a lower-triangular mask, and the padding mask must also be applied, so the two are combined into the final mask.
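    A sketch of building the combined mask (illustrative; tgt_ids has shape (B, L) and 0 is the padding id):

    import torch

    def make_tgt_mask(tgt_ids, pad_id=0):
        B, L = tgt_ids.shape
        pad_mask = (tgt_ids != pad_id).unsqueeze(1)                 # (B, 1, L): real tokens only
        sub_mask = torch.tril(torch.ones(L, L, dtype=torch.bool))   # (L, L): no peeking ahead
        return pad_mask & sub_mask                                  # (B, L, L): True = allowed

    print(make_tgt_mask(torch.tensor([[1, 3, 0, 0]])).int())
    # tensor([[[1, 0, 0, 0],
    #          [1, 1, 0, 0],
    #          [1, 1, 0, 0],
    #          [1, 1, 0, 0]]], dtype=torch.int32)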

    2. encoder-decoder attention

    The input to this attention differs from before. Previously a single input was projected to Q, K and V:

    $(B, L, d_{model}) \rightarrow (B, L, d_q \cdot n_{heads})$ : Q matrix
    $(B, L, d_{model}) \rightarrow (B, L, d_k \cdot n_{heads})$ : K matrix
    $(B, L, d_{model}) \rightarrow (B, L, d_v \cdot n_{heads})$ : V matrix

    Now it becomes:

    $(B, L_{outputs}, d_{model})$ (from the outputs) $\rightarrow (B, L_{outputs}, d_q \cdot n_{heads})$ : Q matrix
    $(B, L_{inputs}, d_{model})$ (from the inputs) $\rightarrow (B, L_{inputs}, d_k \cdot n_{heads})$ : K matrix
    $(B, L_{inputs}, d_{model})$ (from the inputs) $\rightarrow (B, L_{inputs}, d_v \cdot n_{heads})$ : V matrix

    The query matrix comes from the outputs while the key and value matrices come from the inputs, so the decoder is effectively selecting information from the encoder's output to build its own output, much in the spirit of the sequence-to-sequence model.

    The output just before the final Linear layer has shape $(B, L_{outputs}, d_{model})$.

    3. Translation output

    A softmax follows the Linear layer:
    $(B, L_{outputs}, d_{model}) \rightarrow (B, L_{outputs}, tgt\_vocab\_size)$
    At each position, the word with the highest probability is taken as the translation.
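    A sketch of this last step (illustrative sizes; dec_out stands for the decoder output):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    B, L_out, d_model, tgt_vocab_size = 2, 7, 512, 6000
    dec_out = torch.randn(B, L_out, d_model)

    proj = nn.Linear(d_model, tgt_vocab_size)
    probs = F.softmax(proj(dec_out), dim=-1)   # (B, L_out, tgt_vocab_size)
    pred_ids = probs.argmax(dim=-1)            # (B, L_out): highest-probability word at each position
    print(pred_ids.shape)                      # torch.Size([2, 7])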

  • Combining the advantages of CNNs and RNNs, Vaswani et al., 2017 innovatively designed the Transformer model using an attention mechanism.
  • Building on the earlier Transformer implementation, this post applies it to machine-translation data for a hands-on Transformer translation exercise (dataset and code linked in the original post).
  • A Transformer-based translation system: data processing (English and Chinese tokenization, vocabulary, data generator) and model construction.

    A Transformer-based translation system

    Paper: https://arxiv.org/abs/1706.03762
    Project: https://github.com/audier/my_deep_project/tree/master/NLP/4.transformer

    This post implements a translation system based on self-attention. Attention has been a hot research direction over the past two years, and the self-attention mechanism proposed last year has been especially popular. The most readable code available online has one small flaw: the masks never actually take effect. Since I have also been building a translation system, I am sharing my approach here.
    The code follows the fairly readable project https://github.com/Kyubyong/transformer,
    but that author made some mistakes in key_mask and queries_mask, so this post modifies the corresponding model and multihead layer to make them work correctly.

    When reposting, please credit the source: https://blog.csdn.net/chinatelecom08

    1. Data processing

    Data used in this post: https://github.com/audier/my_deep_project/tree/master/NLP/4.transformer

    • Read the data
    • Save it as inputs and outputs
    with open('cmn.txt', 'r', encoding='utf8') as f:
        data = f.readlines()
    
    from tqdm import tqdm
    

    inputs = []
    outputs = []
    for line in tqdm(data[:10000]):
        [en, ch] = line.strip('\n').split('\t')
        inputs.append(en.replace(',', ' ,')[:-1].lower())
        outputs.append(ch[:-1])

    100%|██████████| 10000/10000 [00:00<00:00, 473991.57it/s]
    
     
     
    • Check the data format

    print(inputs[:10])

    ['hi', 'hi', 'run', 'wait', 'hello', 'i try', 'i won', 'oh no', 'cheers', 'he ran']

    print(outputs[:10])

    ['嗨', '你好', '你用跑的', '等等', '你好', '让我来', '我赢了', '不会吧', '乾杯', '他跑了']

    1.1 English tokenization

    English only needs to be split on whitespace, with one small tweak: all uppercase letters are replaced with lowercase, which was already done with .lower() above.

    for line in tqdm(data):
        [en, ch] = line.strip('\n').split('\t')
        inputs.append(en[:-1].lower())
        outputs.append(ch[:-1])
    

    Here we just split the English sentences on spaces.

    inputs = [en.split(' ') for en in inputs]
    
     
     
    print(inputs[:10])
    
     
     
    [['hi'], ['hi'], ['run'], ['wait'], ['hello'], ['i', 'try'], ['i', 'won'], ['oh', 'no'], ['cheers'], ['he', 'ran']]
    
     
     

    1.2 Chinese tokenization

    • For Chinese tokenization we use the jieba tokenizer:
    import jieba
    outputs = [[char for char in jieba.cut(line) if char != ' '] for line in outputs]
    
     
     
    • hanlp can also be used:
    from pyhanlp import *
    outputs = [[term.word for term in HanLP.segment(line) if term.word != ' '] for line in outputs]
    
     
     
    • Or tokenize by character?

    • In the end I chose jieba.

    import jieba
    jieba_outputs = [[char for char in jieba.cut(line) if char != ' '] for line in outputs[-10:]]
    print(jieba_outputs)
    
     
     
    [['你', '不應', '該', '去', '那裡', '的'], ['你', '以前', '吸煙', ',', '不是', '嗎'], ['你現', '在', '最好', '回家'], ['你', '今天', '最好', '不要', '出門'], ['你', '滑雪', '比', '我', '好'], ['你', '正在', '把', '我', '杯子', '里', '的', '东西', '喝掉'], ['你', '并', '不', '满意', ',', '对', '吧'], ['你', '病', '了', ',', '该', '休息', '了'], ['你', '很', '勇敢', ',', '不是', '嗎'], ['你', '的', '意志力', '很強']]
    
     
     
    outputs = [[char for char in jieba.cut(line) if char != ' '] for line in tqdm(outputs)]
    
     
     
    100%|██████████| 10000/10000 [00:00<00:00, 11981.68it/s]
    
     
     

    1.3 Building the vocabularies

    Map the English and Chinese words to ids.

    def get_vocab(data, init=['<PAD>']):
        # copy init so the caller's list is not modified in place
        vocab = list(init)
        for line in tqdm(data):
            for word in line:
                if word not in vocab:
                    vocab.append(word)
        return vocab
    

    SOURCE_CODES = ['<PAD>']
    TARGET_CODES = ['<PAD>', '<GO>', '<EOS>']
    encoder_vocab = get_vocab(inputs, init=SOURCE_CODES)
    decoder_vocab = get_vocab(outputs, init=TARGET_CODES)

    100%|██████████| 10000/10000 [00:00<00:00, 20585.73it/s]
    100%|██████████| 10000/10000 [00:01<00:00, 7808.17it/s]
    
     
     
    print(encoder_vocab[:10])
    print(decoder_vocab[:10])
    
     
     
    ['<PAD>', 'hi', 'run', 'wait', 'hello', 'i', 'try', 'won', 'oh', 'no']
    ['<PAD>', '<GO>', '<EOS>', '嗨', '你好', '你', '用', '跑', '的', '等等']
    
     
     

    1.4 Data generator

    The data format needed for training the translation system is the same as for Google's GNMT; the principles of GNMT are described at https://github.com/tensorflow/nmt.
    Roughly:

    • Encoder input: I am a student
    • Decoder input: (go) Je suis étudiant
    • Decoder output: Je suis étudiant (end)

    That is, the decoder input starts with a start symbol and the decoder output ends with an end symbol.

    encoder_inputs = [[encoder_vocab.index(word) for word in line] for line in inputs]
    decoder_inputs = [[decoder_vocab.index('<GO>')] + [decoder_vocab.index(word) for word in line] for line in outputs]
    decoder_targets = [[decoder_vocab.index(word) for word in line] + [decoder_vocab.index('<EOS>')] for line in outputs]
    
     
     
    print(decoder_inputs[:4])
    print(decoder_targets[:4])
    
     
     
    [[1, 3], [1, 4], [1, 5, 6, 7, 8], [1, 9]]
    [[3, 2], [4, 2], [5, 6, 7, 8, 2], [9, 2]]
    
     
     
    import numpy as np
    

    def get_batch(encoder_inputs, decoder_inputs, decoder_targets, batch_size=4):
        batch_num = len(encoder_inputs) // batch_size
        for k in range(batch_num):
            begin = k * batch_size
            end = begin + batch_size
            en_input_batch = encoder_inputs[begin:end]
            de_input_batch = decoder_inputs[begin:end]
            de_target_batch = decoder_targets[begin:end]
            max_en_len = max([len(line) for line in en_input_batch])
            max_de_len = max([len(line) for line in de_input_batch])
            en_input_batch = np.array([line + [0] * (max_en_len - len(line)) for line in en_input_batch])
            de_input_batch = np.array([line + [0] * (max_de_len - len(line)) for line in de_input_batch])
            de_target_batch = np.array([line + [0] * (max_de_len - len(line)) for line in de_target_batch])
            yield en_input_batch, de_input_batch, de_target_batch

    batch = get_batch(encoder_inputs, decoder_inputs, decoder_targets, batch_size=4)
    next(batch)
    
     
     
    (array([[1],
            [1],
            [2],
            [3]]), array([[1, 3, 0, 0, 0],
            [1, 4, 0, 0, 0],
            [1, 5, 6, 7, 8],
            [1, 9, 0, 0, 0]]), array([[3, 2, 0, 0, 0],
            [4, 2, 0, 0, 0],
            [5, 6, 7, 8, 2],
            [9, 2, 0, 0, 0]]))
    
     
     

    2. Building the model

    The model structure follows the architecture figure in the paper; the main building blocks are given below.

    Paper: https://arxiv.org/abs/1706.03762
    Explanations of the paper are easy to find online; read them alongside the original paper and the code.
    Personally I find the paper much easier to understand with the code next to it.

    import tensorflow as tf
    
     
     

    2.1 Building blocks

    The code below implements the individual components of the architecture figure.

    Layer norm

    def normalize(inputs, 
                  epsilon = 1e-8,
                  scope="ln",
                  reuse=None):
        '''Applies layer normalization.
    
    Args:
      inputs: A tensor with 2 or more dimensions, where the first dimension has
        `batch_size`.
      epsilon: A floating number. A very small number for preventing ZeroDivision Error.
      scope: Optional scope for `variable_scope`.
      reuse: Boolean, whether to reuse the weights of a previous layer
        by the same name.
    
    Returns:
      A tensor with the same shape and data dtype as `inputs`.
        '''
        with tf.variable_scope(scope, reuse=reuse):
            inputs_shape = inputs.get_shape()
            params_shape = inputs_shape[-1:]

            mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
            beta = tf.Variable(tf.zeros(params_shape))
            gamma = tf.Variable(tf.ones(params_shape))
            normalized = (inputs - mean) / ((variance + epsilon) ** (.5))
            outputs = gamma * normalized + beta

        return outputs
    

    Embedding layer

    Worth noting: in this implementation the position encoding is also represented with an embedding layer; the paper says that either the fixed sinusoidal formula or a trained embedding works.

    def embedding(inputs, 
                  vocab_size, 
                  num_units, 
                  zero_pad=True, 
                  scale=True,
                  scope="embedding", 
                  reuse=None):
        '''Embeds a given tensor.
        Args:
          inputs: A `Tensor` with type `int32` or `int64` containing the ids
             to be looked up in `lookup table`.
          vocab_size: An int. Vocabulary size.
          num_units: An int. Number of embedding hidden units.
          zero_pad: A boolean. If True, all the values of the first row (id 0)
            should be constant zeros.
          scale: A boolean. If True, the outputs are multiplied by sqrt(num_units).
          scope: Optional scope for `variable_scope`.
          reuse: Boolean, whether to reuse the weights of a previous layer
            by the same name.
        Returns:
          A `Tensor` with one more rank than inputs's. The last dimensionality
            should be `num_units`.
    
    For example,
    
    ```
    import tensorflow as tf
    
    inputs = tf.to_int32(tf.reshape(tf.range(2*3), (2, 3)))
    outputs = embedding(inputs, 6, 2, zero_pad=True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print sess.run(outputs)
    >>
    [[[ 0.          0.        ]
      [ 0.09754146  0.67385566]
      [ 0.37864095 -0.35689294]]
     [[-1.01329422 -1.09939694]
      [ 0.7521342   0.38203377]
      [-0.04973143 -0.06210355]]]
    ```
    
    ```
    import tensorflow as tf
    
    inputs = tf.to_int32(tf.reshape(tf.range(2*3), (2, 3)))
    outputs = embedding(inputs, 6, 2, zero_pad=False)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print sess.run(outputs)
    >>
    [[[-0.19172323 -0.39159766]
      [-0.43212751 -0.66207761]
      [ 1.03452027 -0.26704335]]
     [[-0.11634696 -0.35983452]
      [ 0.50208133  0.53509563]
      [ 1.22204471 -0.96587461]]]
    ```
        '''
        with tf.variable_scope(scope, reuse=reuse):
            lookup_table = tf.get_variable('lookup_table',
                                           dtype=tf.float32,
                                           shape=[vocab_size, num_units],
                                           initializer=tf.contrib.layers.xavier_initializer())
            if zero_pad:
                lookup_table = tf.concat((tf.zeros(shape=[1, num_units]),
                                          lookup_table[1:, :]), 0)
            outputs = tf.nn.embedding_lookup(lookup_table, inputs)

            if scale:
                outputs = outputs * (num_units ** 0.5)

        return outputs
    

    Multihead layer

    This is the core idea of self-attention, so make sure the principle is clear: the input attends to itself, but before doing so a linear projection maps it into 8 different subspaces; attention is computed in each, and the results are concatenated back together (hats off to Google). This layer implements the multi-head attention of the paper.

    def multihead_attention(key_emb,
                            que_emb,
                            queries, 
                            keys, 
                            num_units=None, 
                            num_heads=8, 
                            dropout_rate=0,
                            is_training=True,
                            causality=False,
                            scope="multihead_attention", 
                            reuse=None):
        '''Applies multihead attention.
    
    Args:
      queries: A 3d tensor with shape of [N, T_q, C_q].
      keys: A 3d tensor with shape of [N, T_k, C_k].
      num_units: A scalar. Attention size.
      dropout_rate: A floating point number.
      is_training: Boolean. Controller of mechanism for dropout.
      causality: Boolean. If true, units that reference the future are masked. 
      num_heads: An int. Number of heads.
      scope: Optional scope for `variable_scope`.
      reuse: Boolean, whether to reuse the weights of a previous layer
        by the same name.
        
    Returns
      A 3d tensor with shape of (N, T_q, C)  
        '''
        with tf.variable_scope(scope, reuse=reuse):
            # Set the fall back option for num_units
            if num_units is None:
                num_units = queries.get_shape().as_list()[-1]

            # Linear projections
            Q = tf.layers.dense(queries, num_units, activation=tf.nn.relu)  # (N, T_q, C)
            K = tf.layers.dense(keys, num_units, activation=tf.nn.relu)  # (N, T_k, C)
            V = tf.layers.dense(keys, num_units, activation=tf.nn.relu)  # (N, T_k, C)

            # Split and concat
            Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)  # (h*N, T_q, C/h)
            K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)
            V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)

            # Multiplication
            outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))  # (h*N, T_q, T_k)

            # Scale
            outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)

            # Key Masking
            key_masks = tf.sign(tf.abs(tf.reduce_sum(key_emb, axis=-1)))  # (N, T_k)
            key_masks = tf.tile(key_masks, [num_heads, 1])  # (h*N, T_k)
            key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(queries)[1], 1])  # (h*N, T_q, T_k)

            paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)
            outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

            # Causality = Future blinding
            if causality:
                diag_vals = tf.ones_like(outputs[0, :, :])  # (T_q, T_k)
                tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
                masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1])  # (h*N, T_q, T_k)

                paddings = tf.ones_like(masks) * (-2 ** 32 + 1)
                outputs = tf.where(tf.equal(masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

            # Activation
            outputs = tf.nn.softmax(outputs)  # (h*N, T_q, T_k)

            # Query Masking
            query_masks = tf.sign(tf.abs(tf.reduce_sum(que_emb, axis=-1)))  # (N, T_q)
            query_masks = tf.tile(query_masks, [num_heads, 1])  # (h*N, T_q)
            query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(keys)[1]])  # (h*N, T_q, T_k)
        outputs <span class="token operator">*=</span> query_masks <span class="token comment"># broadcasting. (N, T_q, C)</span>
          
        <span class="token comment"># Dropouts</span>
        outputs <span class="token operator">=</span> tf<span class="token punctuation">.</span>layers<span class="token punctuation">.</span>dropout<span class="token punctuation">(</span>outputs<span class="token punctuation">,</span> rate<span class="token operator">=</span>dropout_rate<span class="token punctuation">,</span> training<span class="token operator">=</span>tf<span class="token punctuation">.</span>convert_to_tensor<span class="token punctuation">(</span>is_training<span class="token punctuation">)</span><span class="token punctuation">)</span>
               
        <span class="token comment"># Weighted sum</span>
        outputs <span class="token operator">=</span> tf<span class="token punctuation">.</span>matmul<span class="token punctuation">(</span>outputs<span class="token punctuation">,</span> V_<span class="token punctuation">)</span> <span class="token comment"># ( h*N, T_q, C/h)</span>
        
        <span class="token comment"># Restore shape</span>
        outputs <span class="token operator">=</span> tf<span class="token punctuation">.</span>concat<span class="token punctuation">(</span>tf<span class="token punctuation">.</span>split<span class="token punctuation">(</span>outputs<span class="token punctuation">,</span> num_heads<span class="token punctuation">,</span> axis<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">,</span> axis<span class="token operator">=</span><span class="token number">2</span> <span class="token punctuation">)</span> <span class="token comment"># (N, T_q, C)</span>
              
        <span class="token comment"># Residual connection</span>
        outputs <span class="token operator">+=</span> queries
              
        <span class="token comment"># Normalize</span>
        outputs <span class="token operator">=</span> normalize<span class="token punctuation">(</span>outputs<span class="token punctuation">)</span> <span class="token comment"># (N, T_q, C)</span>
    
    <span class="token keyword">return</span> outputs
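    The future-blinding branch above is the only thing that distinguishes decoder self-attention from encoder self-attention. As a quick illustration (plain NumPy, not part of the original TensorFlow code; the 4x4 size is an arbitrary assumption), the lower-triangular mask pushes every "future" score to a huge negative value so that the softmax assigns it essentially zero weight:

    import numpy as np

    # Toy 4x4 score matrix for a single head (T_q = T_k = 4); the values are arbitrary.
    scores = np.random.randn(4, 4).astype(np.float32)

    # Lower-triangular mask: query position i may only attend to key positions <= i.
    tril = np.tril(np.ones_like(scores))

    # Same trick as tf.where(tf.equal(masks, 0), paddings, outputs) above:
    # masked positions get a huge negative value before the softmax.
    masked = np.where(tril == 0, -2.0**32 + 1, scores)

    weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    print(np.round(weights, 3))   # the upper triangle is all zeros

    Row i of the printed matrix only has non-zero weights at columns 0..i, which is exactly what autoregressive decoding requires.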
    

    feedforward

    Two fully connected layers, implemented with kernel-size-1 convolutions to speed up the computation; dense layers would work just as well. With this, every component the architecture needs is ready and we can assemble the whole model.

    def feedforward(inputs,
                    num_units=[2048, 512],
                    scope="multihead_attention",
                    reuse=None):
        '''Point-wise feed forward net.

        Args:
          inputs: A 3d tensor with shape of [N, T, C].
          num_units: A list of two integers.
          scope: Optional scope for `variable_scope`.
          reuse: Boolean, whether to reuse the weights of a previous layer
            by the same name.

        Returns:
          A 3d tensor with the same shape and dtype as inputs
        '''
        with tf.variable_scope(scope, reuse=reuse):
            # Inner layer
            params = {"inputs": inputs, "filters": num_units[0], "kernel_size": 1,
                      "activation": tf.nn.relu, "use_bias": True}
            outputs = tf.layers.conv1d(**params)

            # Readout layer
            params = {"inputs": outputs, "filters": num_units[1], "kernel_size": 1,
                      "activation": None, "use_bias": True}
            outputs = tf.layers.conv1d(**params)

            # Residual connection
            outputs += inputs

            # Normalize
            outputs = normalize(outputs)

        return outputs
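    Because the kernel size is 1, each convolution applies the same weights at every time step, which is exactly what a dense layer does. A minimal dense-layer version (a sketch, not from the original post; it reuses the `normalize` helper defined earlier in the article) would be:

    def feedforward_dense(inputs, num_units=[2048, 512]):
        '''Position-wise FFN built from dense layers instead of kernel-size-1 convolutions.'''
        outputs = tf.layers.dense(inputs, num_units[0], activation=tf.nn.relu)
        outputs = tf.layers.dense(outputs, num_units[1])
        outputs += inputs            # residual connection
        return normalize(outputs)    # layer normalization, defined earlier in the post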
    

    label_smoothing

    Label smoothing helps training: hard 0s are replaced by small values close to zero and hard 1s by values just below 1. The docstring below spells out exactly what happens.

    def label_smoothing(inputs, epsilon=0.1):
        '''Applies label smoothing. See https://arxiv.org/abs/1512.00567.

        Args:
          inputs: A 3d tensor with shape of [N, T, V], where V is the number of vocabulary.
          epsilon: Smoothing rate.

        For example,

        ```
        import tensorflow as tf
        inputs = tf.convert_to_tensor([[[0, 0, 1],
                                        [0, 1, 0],
                                        [1, 0, 0]],
                                       [[1, 0, 0],
                                        [1, 0, 0],
                                        [0, 1, 0]]], tf.float32)

        outputs = label_smoothing(inputs)

        with tf.Session() as sess:
            print(sess.run([outputs]))

        >>
        [array([[[ 0.03333334,  0.03333334,  0.93333334],
                 [ 0.03333334,  0.93333334,  0.03333334],
                 [ 0.93333334,  0.03333334,  0.03333334]],
                [[ 0.93333334,  0.03333334,  0.03333334],
                 [ 0.93333334,  0.03333334,  0.03333334],
                 [ 0.03333334,  0.93333334,  0.03333334]]], dtype=float32)]
        ```
        '''
        K = inputs.get_shape().as_list()[-1]  # number of channels
        return ((1 - epsilon) * inputs) + (epsilon / K)
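    As a quick sanity check on the formula: with V = 3 and epsilon = 0.1, a hard 1 becomes (1 - 0.1) * 1 + 0.1 / 3 ≈ 0.9333 and a hard 0 becomes 0.1 / 3 ≈ 0.0333, which matches the docstring output above.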
    

    2.2 Assembling the model

    Looking at the model diagram one more time, every component it needs has now been implemented.
    All that is left is to assemble them following this structure.
    The code is as follows:

    class Graph():
        def __init__(self, is_training=True):
            tf.reset_default_graph()
            self.is_training = arg.is_training
            self.hidden_units = arg.hidden_units
            self.input_vocab_size = arg.input_vocab_size
            self.label_vocab_size = arg.label_vocab_size
            self.num_heads = arg.num_heads
            self.num_blocks = arg.num_blocks
            self.max_length = arg.max_length
            self.lr = arg.lr
            self.dropout_rate = arg.dropout_rate

            # input placeholder
            self.x = tf.placeholder(tf.int32, shape=(None, None))
            self.y = tf.placeholder(tf.int32, shape=(None, None))
            self.de_inp = tf.placeholder(tf.int32, shape=(None, None))

            # Encoder
            with tf.variable_scope("encoder"):
                # embedding
                self.en_emb = embedding(self.x, vocab_size=self.input_vocab_size, num_units=self.hidden_units,
                                        scale=True, scope="enc_embed")
                self.enc = self.en_emb + embedding(
                    tf.tile(tf.expand_dims(tf.range(tf.shape(self.x)[1]), 0), [tf.shape(self.x)[0], 1]),
                    vocab_size=self.max_length, num_units=self.hidden_units,
                    zero_pad=False, scale=False, scope="enc_pe")

                ## Dropout
                self.enc = tf.layers.dropout(self.enc,
                                             rate=self.dropout_rate,
                                             training=tf.convert_to_tensor(self.is_training))

                ## Blocks
                for i in range(self.num_blocks):
                    with tf.variable_scope("num_blocks_{}".format(i)):
                        ### Multihead Attention
                        self.enc = multihead_attention(key_emb=self.en_emb,
                                                       que_emb=self.en_emb,
                                                       queries=self.enc,
                                                       keys=self.enc,
                                                       num_units=self.hidden_units,
                                                       num_heads=self.num_heads,
                                                       dropout_rate=self.dropout_rate,
                                                       is_training=self.is_training,
                                                       causality=False)

                ### Feed Forward
                self.enc = feedforward(self.enc, num_units=[4 * self.hidden_units, self.hidden_units])

            # Decoder
            with tf.variable_scope("decoder"):
                # embedding
                self.de_emb = embedding(self.de_inp, vocab_size=self.label_vocab_size, num_units=self.hidden_units,
                                        scale=True, scope="dec_embed")
                self.dec = self.de_emb + embedding(
                    tf.tile(tf.expand_dims(tf.range(tf.shape(self.de_inp)[1]), 0), [tf.shape(self.de_inp)[0], 1]),
                    vocab_size=self.max_length, num_units=self.hidden_units,
                    zero_pad=False, scale=False, scope="dec_pe")

                ## Dropout
                self.dec = tf.layers.dropout(self.dec,
                                             rate=self.dropout_rate,
                                             training=tf.convert_to_tensor(self.is_training))

                ## Multihead Attention (self-attention)
                for i in range(self.num_blocks):
                    with tf.variable_scope("num_blocks_{}".format(i)):
                        ### Multihead Attention
                        self.dec = multihead_attention(key_emb=self.de_emb,
                                                       que_emb=self.de_emb,
                                                       queries=self.dec,
                                                       keys=self.dec,
                                                       num_units=self.hidden_units,
                                                       num_heads=self.num_heads,
                                                       dropout_rate=self.dropout_rate,
                                                       is_training=self.is_training,
                                                       causality=True,
                                                       scope='self_attention')

                ## Multihead Attention (vanilla attention)
                for i in range(self.num_blocks):
                    with tf.variable_scope("num_blocks_{}".format(i)):
                        ### Multihead Attention
                        self.dec = multihead_attention(key_emb=self.en_emb,
                                                       que_emb=self.de_emb,
                                                       queries=self.dec,
                                                       keys=self.enc,
                                                       num_units=self.hidden_units,
                                                       num_heads=self.num_heads,
                                                       dropout_rate=self.dropout_rate,
                                                       is_training=self.is_training,
                                                       causality=True,
                                                       scope='vanilla_attention')

                ### Feed Forward
                self.outputs = feedforward(self.dec, num_units=[4 * self.hidden_units, self.hidden_units])

            # Final linear projection
            self.logits = tf.layers.dense(self.outputs, self.label_vocab_size)
            self.preds = tf.to_int32(tf.argmax(self.logits, axis=-1))
            self.istarget = tf.to_float(tf.not_equal(self.y, 0))
            self.acc = tf.reduce_sum(tf.to_float(tf.equal(self.preds, self.y)) * self.istarget) / (tf.reduce_sum(self.istarget))
            tf.summary.scalar('acc', self.acc)

            if is_training:
                # Loss
                self.y_smoothed = label_smoothing(tf.one_hot(self.y, depth=self.label_vocab_size))
                self.loss = tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.logits, labels=self.y_smoothed)
                self.mean_loss = tf.reduce_sum(self.loss * self.istarget) / (tf.reduce_sum(self.istarget))

                # Training Scheme
                self.global_step = tf.Variable(0, name='global_step', trainable=False)
                self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.9, beta2=0.98, epsilon=1e-8)
                self.train_op = self.optimizer.minimize(self.mean_loss, global_step=self.global_step)

                # Summary
                tf.summary.scalar('mean_loss', self.mean_loss)
                self.merged = tf.summary.merge_all()
    

    3. Training the model

    Train the model we just built on the data prepared earlier.

    3.1 Hyperparameter settings

    def create_hparams():
        params = tf.contrib.training.HParams(
            num_heads = 8,
            num_blocks = 6,
            # vocab
            input_vocab_size = 50,
            label_vocab_size = 50,
            max_length = 100,
            # embedding size
            hidden_units = 512,
            dropout_rate = 0.2,
            lr = 0.0003,
            is_training = True)
        return params
    

    arg = create_hparams()
    arg.input_vocab_size = len(encoder_vocab)
    arg.label_vocab_size = len(decoder_vocab)
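    encoder_vocab and decoder_vocab are built in the data-processing part of the original post, which is not included in this excerpt. A purely illustrative sketch of what they, and the index-encoded inputs used by the training loop below, might look like (all names and toy data here are assumptions; index 0 is reserved for padding, matching istarget = tf.to_float(tf.not_equal(self.y, 0)) in the model):

    # Hypothetical sketch only; the real vocabularies come from the data-processing step.
    source_tokens = [['how', 'are', 'you'], ['i', 'love', 'cats']]   # toy English sentences
    target_tokens = [['你', '好', '吗'], ['我', '爱', '猫']]           # toy Chinese sentences

    encoder_vocab = ['<PAD>'] + sorted({w for sent in source_tokens for w in sent})
    decoder_vocab = ['<PAD>', '<GO>', '<EOS>'] + sorted({w for sent in target_tokens for w in sent})

    encoder_inputs = [[encoder_vocab.index(w) for w in sent] for sent in source_tokens]
    decoder_inputs = [[decoder_vocab.index('<GO>')] + [decoder_vocab.index(w) for w in sent]
                      for sent in target_tokens]
    decoder_targets = [[decoder_vocab.index(w) for w in sent] + [decoder_vocab.index('<EOS>')]
                       for sent in target_tokens]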


    3.2 Model training

    import os
    from tqdm import tqdm   # progress bar used in the loop below

    epochs = 25
    batch_size = 64

    g = Graph(arg)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        merged = tf.summary.merge_all()
        sess.run(tf.global_variables_initializer())
        if os.path.exists('logs/model.meta'):
            saver.restore(sess, 'logs/model')
        writer = tf.summary.FileWriter('tensorboard/lm', tf.get_default_graph())
        for k in range(epochs):
            total_loss = 0
            batch_num = len(encoder_inputs) // batch_size
            batch = get_batch(encoder_inputs, decoder_inputs, decoder_targets, batch_size)
            for i in tqdm(range(batch_num)):
                encoder_input, decoder_input, decoder_target = next(batch)
                feed = {g.x: encoder_input, g.y: decoder_target, g.de_inp: decoder_input}
                cost, _ = sess.run([g.mean_loss, g.train_op], feed_dict=feed)
                total_loss += cost
                if (k * batch_num + i) % 10 == 0:
                    rs = sess.run(merged, feed_dict=feed)
                    writer.add_summary(rs, k * batch_num + i)
            if (k + 1) % 5 == 0:
                print('epochs', k + 1, ': average loss = ', total_loss / batch_num)
        saver.save(sess, 'logs/model')
        writer.close()

    100%|██████████| 156/156 [00:31<00:00,  6.19it/s]
    100%|██████████| 156/156 [00:24<00:00,  5.83it/s]
    100%|██████████| 156/156 [00:24<00:00,  6.23it/s]
    100%|██████████| 156/156 [00:24<00:00,  6.11it/s]
    100%|██████████| 156/156 [00:24<00:00,  6.14it/s]
    

    epochs 5 : average loss = 3.3463863134384155

    100%|██████████| 156/156 [00:23<00:00, 6.27it/s]
    100%|██████████| 156/156 [00:23<00:00, 5.86it/s]
    100%|██████████| 156/156 [00:23<00:00, 6.33it/s]
    100%|██████████| 156/156 [00:24<00:00, 6.08it/s]
    100%|██████████| 156/156 [00:23<00:00, 6.29it/s]

    epochs 10 : average loss = 2.0142565186207113

    100%|██████████| 156/156 [00:24<00:00, 6.18it/s]
    100%|██████████| 156/156 [00:24<00:00, 5.84it/s]
    100%|██████████| 156/156 [00:24<00:00, 6.10it/s]
    100%|██████████| 156/156 [00:24<00:00, 6.10it/s]
    100%|██████████| 156/156 [00:23<00:00, 6.38it/s]

    epochs 15 : average loss = 1.5278632457439716

    100%|██████████| 156/156 [00:24<00:00, 6.15it/s]
    100%|██████████| 156/156 [00:24<00:00, 5.86it/s]
    100%|██████████| 156/156 [00:24<00:00, 6.23it/s]
    100%|██████████| 156/156 [00:23<00:00, 6.13it/s]
    100%|██████████| 156/156 [00:23<00:00, 6.32it/s]

    epochs 20 : average loss = 1.4216684783116365

    100%|██████████| 156/156 [00:23<00:00, 6.26it/s]
    100%|██████████| 156/156 [00:23<00:00, 5.89it/s]
    100%|██████████| 156/156 [00:24<00:00, 6.26it/s]
    100%|██████████| 156/156 [00:24<00:00, 6.10it/s]
    100%|██████████| 156/156 [00:23<00:00, 6.35it/s]

    epochs 25 : average loss = 1.3833287457625072
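    The training loop relies on a get_batch generator from the data-processing code, which is also not shown in this excerpt. A minimal sketch, assuming encoder_inputs, decoder_inputs and decoder_targets are already index-encoded and padded to a common length:

    import numpy as np

    def get_batch(encoder_inputs, decoder_inputs, decoder_targets, batch_size):
        '''Yield len(encoder_inputs) // batch_size mini-batches as NumPy arrays.'''
        for start in range(0, len(encoder_inputs) - batch_size + 1, batch_size):
            end = start + batch_size
            yield (np.array(encoder_inputs[start:end]),
                   np.array(decoder_inputs[start:end]),
                   np.array(decoder_targets[start:end]))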


    3.3 Inference

    Feed in a few test sentences and see how the model does. Decoding is greedy: the decoder is rerun on the prefix generated so far and the argmax token is appended, until <EOS> is produced:

    arg.is_training = False

    g = Graph(arg)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        saver.restore(sess, 'logs/model')
        while True:
            line = input('输入测试拼音: ')
            if line == 'exit':
                break
            line = line.lower().replace(',', ' ,').strip('\n').split(' ')
            x = np.array([encoder_vocab.index(pny) for pny in line])
            x = x.reshape(1, -1)
            de_inp = [[decoder_vocab.index('<GO>')]]
            while True:
                y = np.array(de_inp)
                preds = sess.run(g.preds, {g.x: x, g.de_inp: y})
                if preds[0][-1] == decoder_vocab.index('<EOS>'):
                    break
                de_inp[0].append(preds[0][-1])
            got = ''.join(decoder_vocab[idx] for idx in de_inp[0][1:])
            print(got)

    INFO:tensorflow:Restoring parameters from logs/model
    输入测试拼音: You could be right, I suppose
    我猜想你可能是对的
    输入测试拼音: You don't believe Tom, do you
    你不信任汤姆,对吗
    输入测试拼音: Tom has lived here since 2003
    汤姆自从2003年就住在这里
    输入测试拼音: Tom asked if I'd found my key
    湯姆問我找到我的鑰匙了吗
    输入测试拼音: They have a very nice veranda
    他们有一个非常漂亮的暖房
    输入测试拼音: She was married to a rich man
    她嫁給了一個有錢的男人
    输入测试拼音: My parents sent me a postcard
    我父母給我寄了一張明信片
    输入测试拼音: Just put yourself in my shoes
    你站在我的立場上考慮看看
    输入测试拼音: It was a very stupid decision
    这是一个十分愚蠢的决定
    输入测试拼音: I'm really sorry to hear that
    听到这样的消息我真的很难过
    输入测试拼音: His wife is one of my friends
    他的妻子是我的一個朋友
    输入测试拼音: He thought of a good solution
    他想到了一個解決的好辦法
    输入测试拼音: exit
    
     
     

    The results are quite good, and training is much faster than an RNN-based encoder-decoder. Credit to Google for the architecture.

    Please credit the source when reposting: https://blog.csdn.net/chinatelecom08

    If you find this useful, please give the project a star!
    https://github.com/audier
