精华内容
下载资源
问答
  • 文 / Ankur Parikh 和 Xuezhi Wang,Google Research 研究员在过去几年中,自然语言生成 (Neural Language Generation, N...

    文 / Ankur Parikh 和 Xuezhi Wang,Google Research 研究员

    在过去几年中,自然语言生成 (Neural Language Generation, NLG) 方向取得了很大的进步,相关的研究已被应用在生成文本摘要等任务中。然而,尽管神经网络已可以生成流畅的文本,但仍然容易产生幻觉(Hallucination,如:生成通顺但与原文不相关的内容),导致无法将这些系统部署在对准确性有较高要求的许多场景。

    • 幻觉
      https://arxiv.org/abs/1707.08052

    Wikibio 数据集为例,向神经基线模型分配一个任务,使其总结比利时足球运动员 Constant Vanden Stock 的 Wikipedia 信息框内容,但该模型得出了他是一名美国花样滑冰运动员错误结论

    • Wikibio 数据集
      https://arxiv.org/abs/1603.07771

    • 基线模型
      https://arxiv.org/abs/1704.04368

    评估所生成文本对源内容契合程度非常具有挑战性,但如果将源内容结构化(例如,以表格形式),那么这一过程在一定程度上会变得容易一些。此外,结构化数据还可以测试模型的推理能力和数字推断能力。但现有的大规模结构化数据集往往含有噪声(即无法完全根据表格式数据推断出作为参考的句子),因此无法通过现有数据集对模型开发中的幻觉进行客观测量。

    在“ToTTo:受控的表到文本生成数据集(ToTTo: A Controlled Table-to-Text Generation Dataset) 一文中,我们提出了一个开放域表到文本生成数据集,并使用全新的注释处理方式(通过句子修订)以及一个用于评估模型幻觉的受控文本生成任务构建该数据集。

    • ToTTo:受控的表到文本生成数据集
      https://arxiv.org/abs/2004.14373

    ToTTo 为“Table-To-Text”(表到文本)的缩写,包含训练样本 121,000 个,以及用于开发和测试的样本各 7,500 个。由于注释的准确性,此数据集适合作为高精度文本生成研究中的挑战性 benchmark。数据集和代码已在我们的 GitHub 仓库上开源。

    • GitHub 仓库
      https://github.com/google-research-datasets/totto

    表到文本生成

    ToTTo 引入了一项受控的 (Controlled) 生成任务,在该任务中,将包含一组选定单元格的 Wikipedia 表用作源材料,生成一句话总结表中的单元格内容。下方示例说明了该任务中的挑战,如数字推理、大型开放域词汇表和不同的表结构等。

    例如,在此 ToTTo 数据集中,给定源表及其中突出显示的单元格集(左侧),目标则是生成一句话,如“目标句子”(右侧)。请注意,生成目标句子需要进行数字推断(十一个 NFL 赛季)并理解 NFL(美国国家橄榄球联盟) 领域

    注释策略

    设计一种根据表格数据注释出自然且简明的句子极具挑战性。一种办法是:如 Wikibio 和 RotoWire 等许多数据集会启发式地将自然生成的文本与表配对,但是这个过程会引入噪声,因此很难区分幻觉主要是由数据噪声还是模型缺陷导致的。另一种办法是:让标注者从头开始编写契合表内容的目标句子,但这样得到的目标句子在结构和风格方面往往缺乏多样性。

    相比之下,ToTTo 使用一种全新的数据注释策略,标注者分阶段修订现有的 Wikipedia 句子。这样一来,目标句子就能变得简洁自然,同时表现出有趣多样的语言特性。

    数据收集和注释处理的第一步是从 Wikipedia 收集表。在这一步中,给定表会启发式地与相关页面上下文的总结句配对,如页面文本和表之间单词重叠以及引用表格数据的超链接。此总结句可能包含表上没有的信息,也可能包含所指示先行词仅存在于表中(而非句子本身)的代词。

    然后,标注者会突出显示表中与总结句相关的单元格,并删除该句中与表无关的短语。标注者还会对句子进行去语境化处理(例如,正确替换/定向代词),以便根据需要生成语法正确的独立句子。

    我们的研究表明,标注者在上述任务中能够实现高度一致:突出显示的单元格的 Fleiss Kappa 为 0.856,最终目标句子的 BLEU 为 67.0。

    数据集分析

    我们对 ToTTo 数据集中超过 44 个类别进行了主题分析,发现体育和国家/地区主题包含一系列细粒度主题,例如体育中的足球/奥林匹克主题以及国家/地区中的人口/建筑物主题,总共占数据集的 56.4%。另外的 44% 的主题范围更广泛,包括表演艺术、交通与娱乐。

    此外,我们还随机选取了 100 多个样本,对数据集中不同类型的语言现象展开人工分析。下表汇总了部分需要参考页面和分区标题的样本,以及数据集中可能会对当前系统构成全新挑战的一些语言现象。

    语言现象百分比
    需要参考页面标题82%
    需要参考分区标题19%
    需要参考表说明3%
    推理(逻辑、数字、时间等)21%
    跨行/列/单元格比较13%
    需要背景信息12%

    基线模型结果

    我们列出了文献中三个最先进模型(BERT-to-BERTPointer Generator Puduppully 于 2019 发表的模型)在两个评估指标 BLEU 和 PARENT 上的一些基线结果。除了报告基于整个测试集的分数外,我们还基于一个更具挑战性的子集(由域外样本组成)评估了每个模型。如下表所示,BERT-to-BERT 模型在 BLEU 和 PARENT 这两个指标上的表现都优于其他两个模型。此外,所有模型在挑战性子集上的表现都相当不理想,这表示域外泛化仍具挑战性。


    BLEUPARENTBLEUPARENT
    模型总体总体挑战挑战
    BERT-to-BERT43.952.634.846.7
    Pointer Generator41.651.632.245.2
    Puduppully 等人于 2019 年发表的模型19.229.213.925.8

    • BERT-to-BERT
      https://arxiv.org/abs/1907.12461

    • Pointer Generator 
      https://arxiv.org/abs/1704.04368

    • Puduppully 于 2019 发表的模型
      https://arxiv.org/abs/1809.00582

       

    虽然自动指标可以在一定程度上表明性能,但目前尚不足以评估文本生成系统中的幻象。为更好地理解幻象,我们假设差异表示幻象,并在此前提下手动评估最高性能基线,以确定目标句子对源表内容的忠实度。为计算“专家”性能,我们为多参考测试集中的各个样本设定一个参考模型,并要求标注者比较该样本与其他参考模型的忠实度。如结果所示,最高性能基线下,出现幻象信息的概率约为 20%。

     忠实度忠实度
    模型(总体)(挑战)
    专家93.691.4
    BERT-to-BERT76.274.2

    模型错误与挑战

    下表列出了观察到的模型错误,重点说明 ToTTo 数据集的一些更具挑战性的方面。我们发现,即便使用简洁明了的参考,最先进的模型在应对幻觉、数字推理和罕见主题时仍会遇到困难(以红色标记的错误)。最后一个示例表明,即使模型输出正确,有时也无法像原始参考(包含更多表相关推理,以蓝色显示)那样信息丰富。

    参考模型预测
    在 1939 年库里杯中,西部省在开普敦以 17–6 的比分输给了德兰士瓦。第一届库里杯于 1939 年在新大陆德兰士瓦 1 举行,西部省以 17–6 的比分赢得了比赛。
    IBM 公司在 2000 年发布了第二代微型硬盘,其容量增加到 512 MB 和 1 GB。2000 年有 512 块微型硬盘模型:1 千兆字节。
    1956 年世界摩托车锦标赛赛季包括 6 场大奖赛,每场包含 5 个级别:500cc、350cc、250cc、125cc 和边车 500cc。1956 年世界摩托车锦标赛赛季包括 8 场大奖赛,每场包含 5 个级别:500cc、350cc、250cc、125cc 和边车 500cc。
    在 Travis Kelce 的最后一个大学赛季中,他在接球次数 (45)、接球码数 (722)、单次接球码数 (16.0) 和接球达阵次数 (8) 方面创下了个人的职业生涯记录Travis Kelce 以 45 次,共计 722 码(平均为 16.0 码)的接球和 8 次达阵结束了 2012 年赛季。

    结论

    在这项研究中,我们提出了 ToTTo,这是一个表到文本的大型英语数据集,不仅会提供受控的生成任务,还会提供基于迭代句子修订的数据注释处理。我们还提供了几个最先进的基线,并证明了 ToTTo 数据集有助于研究建模以及开发可更好地检测模型改进情况的评估指标。

    除了提及的任务,我们希望我们的数据集也可以为其他任务提供帮助,如表理解和句子修订等。您可通过我们的 GitHub 仓库获取 ToTTo。

    • GitHub 仓库(或“阅读原文”)
      https://github.com/google-research-datasets/totto

    致谢

    作者要感谢 Ming-Wei Chang、Jonathan H. Clark、Kenton Lee 和 Jennimaria Palomaki 提供的深刻探讨和支持。同时非常感谢 Ashwin Kakarla 及其团队在注释工作中的帮助。

    更多 AI 相关阅读:

     点击屏末  | 即刻访问 GitHub

    展开全文
  • DP-SeqGAN通过生成对抗网络自动提取数据集的重要特征并生成与原数据分布接近的新数据集,基于差分隐私对模型做随机加扰以提高生成数据集的隐私性,并进一步降低鉴别器过拟合。DP-SeqGAN 具有直观通用性,无须对具体...
  • 关于文本生成数据集记录

    千次阅读 2018-07-17 17:18:18
    摘要数据集 cnn/dailymail Gigaword Gigaword corpus [Graff and Cieri, 2003] preprocessed identically to [Rush et al., 2015], which leads to around 3.8M training samples, 190K validation samples and ...

    摘要数据集

    cnn/dailymail

    Gigaword
    Gigaword corpus [Graff and Cieri, 2003] preprocessed identically to [Rush et al., 2015], which leads to around 3.8M training samples, 190K validation samples and 1951 test samples for evaluation. The input summary pairs consist of the head- line and the first sentence of the source articles.

    中文摘要数据集
    a large corpus of Chinese short text summarization (LCSTS) dataset [Hu et al., 2015] collected and constructed from the Chinese microblogging website Sina Weibo.

    散文生成数据集

    数据集和代码地址
    论文:Topic-to-Essay Generation with Neural Networks
    数据集介绍:
    In order to guarantee the quality of the crawled text, we only crawl the compositions which contain some reviews and scores. The process of the data collection is summarized as follows: a) We crawl 228,110 articles, which have high scores. b) We choose paragraphs composed of 50 to 120 words to be our corpus from these articles. c) We follow [Wang et al., 2016b] and also employ TextRank [Mihalcea and Tarau, 2004] to extract keywords as topic words. In the end, we obtain 305,000 paragraph-level essays and randomly select 300,000 as training set and 5,000 as test set. We name this dataset as ESSAY
    ZhiHu:
    In this paper, we also find some articles that conform to our requirements on ZhiHu, a Chinese question-and-answer website, where questions are created, answered, edited and organized by users in the community. In particular, users also give the topic words of each article. Based on the information mentioned above, we crawl a large number of Zhihu’ articles and corresponding topic words. Referring

    展开全文
  • 因此,图像文本识别能够将图像中的文本区域转化成计算机可以读取和编辑的符号,打通了从图像到文本再到信息的通路。 随着计算机算力的提升,基于深度学习方法的本文识别技术逐渐成为主流,而深度学习中数...

    代码地址如下:
    http://www.demodashi.com/demo/14792.html

    一、开发背景

    图像中的文本识别近几年来备受瞩目。通常来说,图片中的文本能够比图片中其他内容提供更加丰富的信息。因此,图像文本识别能够将图像中的文本区域转化成计算机可以读取和编辑的符号,打通了从图像到文本再到信息的通路。

    随着计算机算力的提升,基于深度学习方法的本文识别技术逐渐成为主流,而深度学习中数据集的获取是重中之重。本脚本实现读取语料集中的文本内容,以保存为图像形式的数据集,用于模型训练。

    二、脚本效果

    1、IDE中的运行界面

    (1)选择字体文件

    (2)生成数据集

    2、生成的图像

    不使用数据增强


    使用数据增强



    3、映射表

    存储图像文件名和类别序列标注的对应关系

    三、具体开发

    1、功能需求

    1. 根据用户指定的语料数据生成图像文件及映射表
    2. 用户可自行更改文本长度,图像数量及图像尺寸
    3. 用户可自行选择是否进行增强处理

    2、实际项目

    1. 项目结构

    (1)根目录下的fonts文件夹用于存放ttf字体文件, imageset文件夹用于存放输出图像和映射表
    (2)config中设置相关参数并存放语料文件, dict5990.txt是字典, sentences.txt是语料集

    2. 实现思路

    3. 代码实现

    1. 设置参数
    # 语料集
    corpus = 'config/sentences.txt'
    dict = 'config/dict5990.txt'
    
    # 字体文件路径
    FONT_PATH = 'fonts/'
    
    # 输出路径
    OUTPUT_DIR = 'imageset/'
    
    # 样本总数
    n_samples = 50
    # 每行最大长度
    sentence_lim = 10
    # 画布能容纳的最大序列长度,对应img_w
    canvas_lim = 50
    
    2. 构建生成器

    1. 加载字体文件

    # 选择字体
    root = tk.Tk()
    root.withdraw()
    self.font_path = filedialog.askopenfilename()
    		
    def load_fonts(self, factor, font_path):
    	""" 加载字体文件并设定字体大小
            """
    	self.fonts = []
    	# 加载字体文件
    	font = ImageFont.truetype(font_path, int(self.img_h*factor), 0)
    	self.fonts.append(font)
    

    2. 构建字典

    def build_dict(self):
    	""" 打开字典,加载全部字符到list
                每行是一个字
            """
    	with codecs.open(self.dictfile, mode='r', encoding='utf-8') as f:
    		# 按行读取语料
    		for line in f:
    			# 当前行单词去除结尾,为了正常读取空格,第一行两个空格
    			word = line.strip('\r\n')
    			# 只要没超出上限就继续添加单词
    			self.dict.append(word)
    			# 最后一位作为空白符
    			self.blank_label = len(self.dict)
    

    3. 加载语料

    def build_train_list(self, num_rows, max_row_len=None):
    	# 过滤语料,留下适合的内容组成训练list
    	assert max_row_len <= self.img_lim
    	self.num_rows = num_rows
    	self.max_row_len = max_row_len
    	sentence_list = []
    	self.train_list = []
    	
    	with codecs.open(self.corpus_file, mode='r', encoding='utf-8') as f:
    		# 按行读取语料
    		for line in f:
    			sentence = line.rstrip().replace(' ', '')  # 当前行单词
    			if len(sentence) <= max_row_len and len(sentence_list) < num_rows:
    				# 只要句子长度不超过画布上限且句子数量没超出上限就继续添加
                        sentence_list.append(sentence)
    			elif len(sentence) > max_row_len and len(sentence_list) < num_rows:
                    # 截断句子
                    sentence_list.append(sentence[0:max_row_len])
    
    	if len(sentence_list) < self.num_rows:
    		raise IOError('语料不够')
    
    	for i, sentence in enumerate(sentence_list):
    		# 遍历语料中的每一句(行)
    		# 将单词分成字符,然后找到每个字符对应的整数ID list
    		label_sequence = []
    		for j, word in enumerate(sentence):  # 检查句子中是否包含生僻字
    			try:
    				index = self.dict.index(word)
    				label_sequence.append(index)
    			except ValueError:
    				print("字典不包含:{},已忽略".format(word))
    				sentence_list[i] = sentence_list[i][0:j] + sentence_list[i][j+1:]  # 从该句中删除生僻字
    
    	self.train_list = sentence_list  # 过滤后的训练集
    	np.random.shuffle(self.train_list)  # 打乱顺序
    

    4. 保存映射表

    def mapping_list(self):
    	# 写图像文件名和类别序列的对照表
    	file_path = os.path.join(cfg.OUTPUT_DIR, 'map_list.txt')
    	with codecs.open(file_path, mode='w', encoding='utf-8') as f:
    		for i in range(len(self.train_list)):
    			f.write("{}.png {} \n".format(i, self.train_list[i]))
    

    5. 绘制图像

        def paint_text(self, text, i):
            """ 使用PIL绘制文本图像,传入画布尺寸,返回文本图像
            :param h: 画布高度
            :param w: 画布宽度
            """
            # 创建画布
            canvas = np.zeros(shape=(self.img_h, self.img_w), dtype=np.uint8)
            canvas[0:] = 255
            # 转换图像模式,保证合成的两张图尺寸模式一致
            ndimg = Image.fromarray(canvas).convert('RGBA')
            draw = ImageDraw.Draw(ndimg)
    
            font = self.fonts[-1]
            text_size = font.getsize(text)  # 获取当前字体下的文本区域大小
    
            # 自动调整字体大小避免超出边界, 至少留白水平20%
            margin = [self.img_w - int(0.2*self.img_w), self.img_h - int(0.2*self.img_h)]
            while (text_size[0] > margin[0]) or (text_size[1] > margin[1]):
                self.font_factor -= 0.1
                self.load_fonts(self.font_factor, self.font_path)
                font = self.fonts[-1]
                text_size = font.getsize(text)
    
            # 随机平移
            horizontal_space = self.img_w - text_size[0]
            vertical_space = self.img_h - text_size[1]
            start_x = np.random.randint(2, horizontal_space-2)
            start_y = np.random.randint(2, vertical_space-2)
    
            # 绘制当前文本行
            draw.text((start_x, start_y), text, font=font, fill=(0, 0, 0, 255))
            img_array = np.array(ndimg)
    
            # 转灰度图
            grey_img = img_array[:, :, 0]  # [32, 256, 4]
            if self.aug == True:
                auged = augmentation(grey_img)
                ndimg = Image.fromarray(auged).convert('RGBA')
    
            save_path = os.path.join(cfg.OUTPUT_DIR, '{}.png'.format(i))  # 类别序列即文件名
            ndimg.save(save_path)
    

    6. 数据增强

    def speckle(img):
        severity = np.random.uniform(0, 0.6*255)
        blur = ndimage.gaussian_filter(np.random.randn(*img.shape) * severity, 1)
        img_speck = (img + blur)
        img_speck[img_speck > 255] = 255
        img_speck[img_speck <= 0] = 0
        return img_speck
    
    
    def augmentation(img, ):
        # 不能直接在原始image上改动
        image = img.copy()
        img_h, img_w = img.shape
        mode = np.random.randint(0, 9)
        '''添加随机模糊和噪声'''
        # 高斯模糊
        if mode == 0:
            image = cv2.GaussianBlur(image,(5, 5), np.random.randint(1, 10))
    
        # 模糊后二值化,虚化边缘
        if mode == 1:
            image = cv2.GaussianBlur(image, (9, 9), np.random.randint(1, 8))
            ret, th = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
            thresh = image.copy()
            thresh[thresh >= th] = 0
            thresh[thresh < th] = 255
            image = thresh
    
        # 横线干扰
        if mode == 2:
            for i in range(0, img_w, 2):
                cv2.line(image, (0, i), (img_w, i), 0, 1)
    
        # 竖线
        if mode == 3:
            for i in range(0, img_w, 2):
                cv2.line(image, (i, 0), (i, img_h), 0, 1)
    
        # 十字线
        if mode == 4:
            for i in range(0, img_h, 2):
                cv2.line(image, (0, i), (img_w, i), 0, 1)
            for i in range(0, img_w, 2):
                cv2.line(image, (i, 0), (i, img_h), 0, 1)
    
        # 左右运动模糊
        if mode == 5:
            kernel_size = 5
            kernel_motion_blur = np.zeros((kernel_size, kernel_size))
            kernel_motion_blur[int((kernel_size - 1) / 2), :] = np.ones(kernel_size)
            kernel_motion_blur = kernel_motion_blur / kernel_size
            image = cv2.filter2D(image, -1, kernel_motion_blur)
    
        # 上下运动模糊
        if mode == 6:
            kernel_size = 9
            kernel_motion_blur = np.zeros((kernel_size, kernel_size))
            kernel_motion_blur[:, int((kernel_size - 1) / 2)] = np.ones(kernel_size)
            kernel_motion_blur = kernel_motion_blur / kernel_size
            image = cv2.filter2D(image, -1, kernel_motion_blur)
    
        # 高斯噪声
        if mode == 7:
            row, col = [img_h, img_w]
            mean = 0
            sigma = 1
            gauss = np.random.normal(mean, sigma, (row, col))
            gauss = gauss.reshape(row, col)
            noisy = image + gauss
            image = noisy.astype(np.uint8)
    
        # 污迹
        if mode == 8:
            image = speckle(image)
        return image
    

    4. 使用说明

    运行sample_generator.py后会跳出对话框, 选择字体文件即可生成数据集
    从文本到图像——文本识别数据集生成器

    代码地址如下:
    http://www.demodashi.com/demo/14792.html

    注:本文著作权归作者,由demo大师发表,拒绝转载,转载需要作者授权

    展开全文
  • 利用莎士比亚数据集进行RNN文本生成的训练 import tensorflow as tf import numpy as np from tensorflow import keras import pandas as pd import sklearn import sys import os import matplotlib.pyplot as plt ...

    利用莎士比亚数据集进行RNN文本生成的训练

    import tensorflow as tf
    import numpy as np
    from tensorflow import keras
    import pandas as pd
    import sklearn
    import sys
    import os
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    from sklearn.preprocessing import StandardScaler
    
    print(tf.__version__)
    print(sys.version_info)
    for module in mpl,np,pd,sklearn,tf,keras:
        print(module.__name__,module.__version__)
    
    #莎士比亚数据集:https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
    input_filepath = "./shakespeare.txt"
    text = open(input_filepath,'r').read()
    print(len(text))
    print(text[0:100])
    
    #1.生成词表
    #2.映射 char -->id
    #3.data -->id_data
    #4.abcd -->bcd<eos>:预测下一个字符
    vocab = sorted(set(text))
    print(len(vocab))
    print(vocab)
    
    char2idx = {char:idx for idx, char in enumerate(vocab)}
    print(char2idx)
    idx2char = np.array(vocab)
    print(idx2char)
    
    #对text中每个字符都做一个映射
    text_as_int = np.array([char2idx[c] for c in text])
    print(text_as_int[0:10])
    print(text[0:10])
    
    
    # 定义输入输出函数
    def split_input_target(id_text):
        """
        abcde -->输入abcd,输出bcde
        """
        return id_text[0:-1], id_text[1:]
    
    
    # 将it_text转为dataset
    char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
    seq_length = 100
    seq_dataset = char_dataset.batch(seq_length + 1, drop_remainder=True)  # 当做batch操作时,如果最后一个长度不都,就丢掉
    # 取出ch_id对应的字符
    for ch_id in char_dataset.take(2):
        print(ch_id, idx2char[ch_id.numpy()])
    # 取出seq_id对应的字符
    for seq_id in seq_dataset.take(2):
        print(seq_id)
        print(repr(' '.join(idx2char[seq_id.numpy()])))
    
    seq_dataset = seq_dataset.map(split_input_target)
    for item_input,item_output in seq_dataset:
        print(item_input.numpy())
        print(item_output.numpy())
    
    batch_size = 64
    buffer_size = 10000
    seq_dataset = seq_dataset.shuffle(buffer_size).batch(
        batch_size,drop_remainder=True)
    
    #定义模型
    vocab_size = len(vocab)
    embedding_dim = 256
    rnn_units = 1024
    #模型函数
    def build_model(vocab_size,embedding_dim,rnn_units,batch_size):
        model = keras.models.Sequential([
            keras.layers.Embedding(vocab_size,embedding_dim,
                                  batch_input_shape = [batch_size,None]),
            keras.layers.SimpleRNN(units = rnn_units,
                                  return_sequences=True),
            keras.layers.Dense(vocab_size),])
        return model
    
    model = build_model(
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        rnn_units=rnn_units,
        batch_size=batch_size)
    
    model.summary()
    
    
    for input_example_batch,target_example_batch in seq_dataset.take(1):
        example_batch_predictions = model(input_example_batch)
        print(example_batch_predictions.shape)
    
    
    #随机采样
    #在计算分类任务softmax之前的那个值就是logits
    sample_indices = tf.random.categorical(logits=example_batch_predictions[0],
                         num_samples = 1)
    print(sample_indices)
    #将(100,1)转换为(100,)形式
    sample_indices = tf.squeeze(sample_indices,axis=-1)
    print(sample_indices)
    
    #定义模型的损失函数
    def loss(labels,logits):
        return keras.losses.sparse_categorical_crossentropy(
                labels,logits,from_logits=True)
    
    model.compile(optimizer = 'adam', loss = loss)
    example_loss = loss(input_example_batch,example_batch_predictions)
    print(example_loss.shape)
    print(example_loss.numpy().mean())
    
    # 保存模型
    output_dir = "./text_generation_checkpoints"
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    
    checkpoint_prefix = os.path.join(output_dir, 'ckpt_{epochs}')
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_prefix,
        save_weights_only=True, )
    
    epochs = 100
    history = model.fit(seq_dataset, epochs=epochs,
                        callbacks=[checkpoint_callback])
    
    
    
    #导入模型
    model2 = build_model(vocab_size,embedding_dim,
                        rnn_units,
                        batch_size=1)
    model2.load_weights(tf.train.latest_checkpoint(output_dir))
    # 1:指一个样本
    model2.build(tf.TensorShape([1,None]))
    
    
    # 文本生成的流程
    # start ch sequence A,
    # A -->model -->b
    # A.append(b) -->B -->model -->c -->B.appden(c) -->C(abc).....
    def generate_text(model, start_string, num_generate=1000):
        input_eval = [char2idx[ch] for ch in start_string]
        # 维度扩展,因为模型的输入时一个[1,None]的矩阵,而此时是一维的
        input_eval = tf.expand_dims(input_eval, 0)
    
        text_generated = []
        model.reset_states()
    
        for _ in range(num_generate):
            # 1.model inference --> prediction
            # 2.sample --> ch --> text_generated
            # 3.update input_eval
    
            # predictions : [batch_size,input_eval_len,vocab_size]
            predictions = model(input_eval)
            # 去掉第一维: [input_eval_len,vocab_size]
            predictions = tf.squeeze(predictions, 0)
            # predictions : [input_eval_len,1]
            predicted_id = tf.random.categorical(
                predictions, num_samples=1)[-1, 0].numpy()
            text_generated.append(idx2char[predicted_id])
            input_eval = tf.expand_dims([predicted_id], 0)
            return start_string + ' '.join(text_generated)
    
    
    new_text = generate_text(models, "All: ")
    print(new_text)
    
    展开全文
  • 利用SynthText生成自然场景文本检测数据集

    万次阅读 热门讨论 2017-08-23 09:32:52
    二,生成文本检测数据集 1 , 预处理的背景图像 下载本文中使用的8000个背景图像,以及它们的分割和深度模板,下载链接地址如下: `http://zeus.robots.ox.ac.uk/textspot/static/db/ <filename>`,...
  • cnews中文文本分类数据集;由清华大学根据新浪新闻RSS订阅频道2005-2011年间的历史 数据筛选过滤生成,训练过程见我的博客;
  • special_auto.py :自动生成具有所有特殊字符的数据集。 python3 special_auto.py create_lmdb_dataset.py :从生成的图像+ label.txt文件生成Lmdb文件。 *将图像文件夹+ label.txt放入数据/ * python3 create_...
  • 只需几行代码,即可在任何文本数据集上轻松训练您自己的任意大小和复杂度的文本生成神经网络,或者使用预先训练的模型快速训练文本。 textgenrnn是上的顶部一个Python 3模块 / 用于创建 S,与许多凉爽特性: 一种...
  • 朴素贝叶斯 分类算法数据集文本挖掘(Text Mining,从文字中获取信息)是一个比较宽泛的概念,这一技术在如今每天都有海量文本数据生成的时代越来越受到关注。目前,在机器学习模型的帮助下,包括情绪分析,文件分类...
  • 文本生成是NLP的最新应用程序之一。深度学习技术已用于各种文本生成任务,例如写作诗歌,生成电影脚本甚至创作...完成本文之后,您将能够使用所选的数据集执行文本生成。所以,让我们开始吧。导入库和数据集第一步是...
  • 欢迎使用ImageToLatex :waving_hand: 一个能够将笔迹转换为乳胶的神经网络。 该项目还提供了AZ工具,用于生成原始乳胶,生成图像和转换图像,就好像它们是...数据集 输入shape = (64, 64) 。 输入eq_n_b_c格式: eq_n
  • 每个文本有一列数据,将选中的几个文本按要求合并为训练供机器学习算法使用 将单个文本的hdfs路径设置为参数,提高程序的通用性,将所有文本都追加为一个数组,随后按规定切分读写,速度不是很慢。测试效果还可以...
  • 具有递归神经网络的文本生成 使用基于特征的RNN进行文本生成。 我们使用安德烈·卡帕蒂(Andrej Karpathy)的莎士比亚作品。 给定来自此数据的字符序列(“莎士比亚”),训练模型以预测序列中的下一个字符。 通过...
  • 我们提供9种基准文本生成数据集的支持。 用户可以应用我们的库来处理原始数据副本,或者简单地由我们的团队下载处理后的数据集。 图片:TextBox的总体架构 特征 统一和模块化的框架。 TextBox建立在PyTorch的基础上...
  • 文本生成:基于GPT-2的中文新闻文本生成

    千次阅读 热门讨论 2020-03-07 00:42:27
    文本生成一直是NLP领域内研究特别活跃的一个任务,应用前景特别广泛。BERT类预训练模型基于MLM,融合了双向上下文信息,不是天然匹配文本生成类任务(也有针对BERT模型进行改进的多种方式完善了BERT的这个缺点,如...
  • 基于结构化数据生成文本(data-to-text)的任务旨在生成人类可读的文本来直观地描述给定的结构化数据。然而,目前主流任务设定所基于的数据集有较好的对齐 (well-aligned)关系,即输入(i.e. 结构化数据)和输出...
  • 文本整个文本只有一行,无换行,字之间空格隔开 方法一:torchtext对于纯文本数据,通常我们会使用LanguageModelingDataset建立数据集,然后使用BPTTIterator创建迭代器。注意:如果文本数过小,且BPTTIterator中...
  • 线稿上色的数据集: dataset link:https://pan.baidu.com/s/1Abm7V6J2uNOy5U6nvsRSlg key:eepv txt文件生成 import os import glob def Create_Txt(data_name, data_path, data_class,txt_path, ratio = 0.01):...
  • 示例在SuperMap Objects.NET 6R中实现将数据集的属性字段生成一个文本数据集
  • 针对传统的分类算法在处理不均衡样本数据...最后使用原始数据集和新数据集分别训练K最近邻(K nearest neighbor,KNN)及支持向量机(support vector machine,SVM)分类器。实验结果表明此方法有效改善了少数类分类效果。

空空如也

空空如也

1 2 3 4 5 ... 20
收藏数 1,301
精华内容 520
关键字:

文本生成数据集