[oneAPI] 使用字符级 RNN 生成名称

[oneAPI] 使用字符级 RNN 生成名称

  • oneAPI特殊写法
  • 使用字符级 RNN 生成名称
    • Intel® Optimization for PyTorch
    • 数据下载
    • 加载数据并对数据进行处理
    • 创建网络
    • 训练过程
      • 准备训练
      • 训练网络
    • 结果
  • 参考资料

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

oneAPI特殊写法

import intel_extension_for_pytorch as ipex

# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')

rnn = RNN(n_letters, 128, n_letters)
optim = torch.optim.SGD(rnn.parameters(), lr=0.01)

'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
rnn, optim = ipex.optimize(rnn, optimizer=optim)

criterion = nn.NLLLoss()

使用字符级 RNN 生成名称

为了深入探索语言模型在分类和生成方面的卓越能力,我们特意设计了一个独特的任务。此任务的独特之处在于,它旨在综合学习多种语言的词义特征,以确保生成的内容与各种语言的词组相关性一致。

在任务的具体描述中,我们提供了一个多语言数据集,这个数据集包含多种语言的文本。通过这个数据集,我们的目标是使模型能够在生成名称时融合不同语言的特征。具体来说,我们会提供一个词的开头作为提示,然后模型将能够根据这个开头生成对应语言的名称,从而将不同语言的词意和语法特征进行完美融合。

通过这一任务,我们旨在实现一个在多语言环境中具有卓越生成和分类能力的语言模型。通过学习并融合不同语言的词义和语法特征,我们让使模型具备更广泛的应用潜力,能够在不同语境下生成准确、符合语法规则的名称。

> python sample.py Russian RUS
Rovakov
Uantov
Shavakov

> python sample.py German GER
Gerren
Ereng
Rosher

> python sample.py Spanish SPA
Salla
Parer
Allan

> python sample.py Chinese CHI
Chan
Hang
Iun

Intel® Optimization for PyTorch

在本次实验中,我们利用PyTorch和Intel® Optimization for PyTorch的强大功能,对PyTorch进行了精心的优化和扩展。这些优化举措极大地增强了PyTorch在各种任务中的性能,尤其是在英特尔硬件上的表现更加突出。通过这些优化策略,我们的模型在训练和推断过程中变得更加敏捷和高效,显著地减少了计算时间,提高了整体效能。我们通过深度融合硬件和软件的精巧设计,成功地释放了硬件潜力,使得模型的训练和应用变得更加快速和高效。这一系列优化举措为人工智能应用开辟了新的前景,带来了全新的可能性。
在这里插入图片描述

数据下载

从这里下载数据 并将其解压到当前目录。

加载数据并对数据进行处理

简而言之,有一堆data/names/[Language].txt每行都有一个名称的纯文本文件。我们将行分割成一个数组,将 Unicode 转换为 ASCII,最后得到一个字典。{language: [names …]}

from io import open
import glob
import os
import unicodedata
import string

all_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1  # Plus EOS marker


def findFiles(path): return glob.glob(path)


# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )


# Read a file and split into lines
def readLines(filename):
    with open(filename, encoding='utf-8') as some_file:
        return [unicodeToAscii(line.strip()) for line in some_file]


# Build the category_lines dictionary, a list of lines per category
category_lines = {}
all_categories = []
for filename in findFiles('data/names/*.txt'):
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    lines = readLines(filename)
    category_lines[category] = lines

n_categories = len(all_categories)

if n_categories == 0:
    raise RuntimeError('Data not found. Make sure that you downloaded data '
                       'from https://download.pytorch.org/tutorial/data.zip and extract it to '
                       'the current directory.')

print('# categories:', n_categories, all_categories)
print(unicodeToAscii("O'Néàl"))

Output:

# categories: 18 ['Arabic', 'Chinese', 'Czech', 'Dutch', 'English', 'French', 'German', 'Greek', 'Irish', 'Italian', 'Japanese', 'Korean', 'Polish', 'Portuguese', 'Russian', 'Scottish', 'Spanish', 'Vietnamese']
O'Neal

创建网络

序列到序列网络,或 seq2seq 网络,或编码器解码器网络,是由两个称为编码器和解码器的 RNN 组成的模型。编码器读取输入序列并输出单个向量,解码器读取该向量以产生输出序列。

我添加了第二个线性层o2o(在组合隐藏层和输出层之后)以赋予其更多的功能。还有一个 dropout 层,它以给定的概率(此处为 0.1)随机将部分输入归零,通常用于模糊输入以防止过度拟合。在这里,我们在网络末端使用它来故意添加一些混乱并增加采样多样性。
在这里插入图片描述

######################################################################
# Creating the Network
# ====================
import torch
import torch.nn as nn

import intel_extension_for_pytorch as ipex


class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size

        self.i2h = nn.Linear(n_categories + input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(n_categories + input_size + hidden_size, output_size)
        self.o2o = nn.Linear(hidden_size + output_size, output_size)
        self.dropout = nn.Dropout(0.1)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, category, input, hidden):
        input_combined = torch.cat((category, input, hidden), 1)
        hidden = self.i2h(input_combined)
        output = self.i2o(input_combined)
        output_combined = torch.cat((hidden, output), 1)
        output = self.o2o(output_combined)
        output = self.dropout(output)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

训练过程

准备训练

import random

# Random item from a list
def randomChoice(l):
    return l[random.randint(0, len(l) - 1)]

# Get a random category and random line from that category
def randomTrainingPair():
    category = randomChoice(all_categories)
    line = randomChoice(category_lines[category])
    return category, line
# One-hot vector for category
def categoryTensor(category):
    li = all_categories.index(category)
    tensor = torch.zeros(1, n_categories)
    tensor[0][li] = 1
    return tensor

# One-hot matrix of first to last letters (not including EOS) for input
def inputTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li in range(len(line)):
        letter = line[li]
        tensor[li][0][all_letters.find(letter)] = 1
    return tensor

# ``LongTensor`` of second letter to end (EOS) for target
def targetTensor(line):
    letter_indexes = [all_letters.find(line[li]) for li in range(1, len(line))]
    letter_indexes.append(n_letters - 1) # EOS
    return torch.LongTensor(letter_indexes)

为了训练过程中的方便,我们将创建一个randomTrainingExample 函数来获取随机(类别、线)对并将它们转换为所需的(类别、输入、目标)张量。

# Make category, input, and target tensors from a random category, line pair
def randomTrainingExample():
    category, line = randomTrainingPair()
    category_tensor = categoryTensor(category)
    input_line_tensor = inputTensor(line)
    target_line_tensor = targetTensor(line)
    return category_tensor, input_line_tensor, target_line_tensor

训练网络

criterion = nn.NLLLoss()

learning_rate = 0.0005

def train(category_tensor, input_line_tensor, target_line_tensor):
    target_line_tensor.unsqueeze_(-1)
    hidden = rnn.initHidden()

    rnn.zero_grad()

    loss = torch.Tensor([0]) # you can also just simply use ``loss = 0``

    for i in range(input_line_tensor.size(0)):
        output, hidden = rnn(category_tensor, input_line_tensor[i], hidden)
        l = criterion(output, target_line_tensor[i])
        loss += l

    loss.backward()

    for p in rnn.parameters():
        p.data.add_(p.grad.data, alpha=-learning_rate)

    return output, loss.item() / input_line_tensor.size(0)

为了跟踪训练需要多长时间,我添加了一个 timeSince(timestamp)返回人类可读字符串的函数:

import time
import math

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

训练就像平常一样 - 多次调用训练并等待几分钟,打印当前时间和每个print_every 示例的损失,并存储每个plot_every示例的平均损失all_losses以供稍后绘制。

rnn = RNN(n_letters, 128, n_letters)

n_iters = 100000
print_every = 5000
plot_every = 500
all_losses = []
total_loss = 0 # Reset every ``plot_every`` ``iters``

start = time.time()

for iter in range(1, n_iters + 1):
    output, loss = train(*randomTrainingExample())
    total_loss += loss

    if iter % print_every == 0:
        print('%s (%d %d%%) %.4f' % (timeSince(start), iter, iter / n_iters * 100, loss))

    if iter % plot_every == 0:
        all_losses.append(total_loss / plot_every)
        total_loss = 0

结果

在这里插入图片描述

参考资料

https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html#

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/80769.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

万字长文·通俗易懂·一篇包掌握——输入/输出·文件操作(c语言超详细系列)(二)

前言:Hello,大家好😘,我是心跳sy,上一节我们主要学习了格式化输入输出的基本内容,这一节我们对格式化进行更加深入的了解,对文件概念进行介绍,并且对输入、输出与文件读写的基本概念…

Next.js - Pages and Layouts

Pages 页面是路由独有的用户界面。你可以通过从 page.js 文件导出组件来定义页面。使用嵌套文件夹定义路由&#xff0c;并使用 page.js 文件公开访问路由。 // app/page.tsx is the UI for the / URL export default function Page() {return <h1>Hello, Home page!<…

设计模式之适配器模式(Adapter)的C++实现

1、适配器模式的提出 在软件功能开发中&#xff0c;由于使用环境的改变&#xff0c;之前一些类的旧接口放在新环境的功能模块中不再适用。如何使旧接口能适用于新的环境&#xff1f;适配器可以解决此类问题。适配器模式&#xff1a;通过增加一个适配器类&#xff0c;在适配器接…

QT的mysql(数据库)最佳实践和常见问题解答

涉及到数据库&#xff0c;首先安利一个软件Navicat Premium&#xff0c;用来查询数据库很方便 QMysql驱动是Qt SQL模块使用的插件&#xff0c;用于与MySQL数据库进行通信。要编译QMysql驱动&#xff0c;您需要满足以下条件&#xff1a; 您需要安装MySQL的客户端库和开发头文件…

Maven之JDK编译问题

IDEA Maven 默认使用 JDK 1.5 编译问题 IDEA 在「调用」maven 时&#xff0c;IDEA 默认都会采用 JDK 1.5 编译&#xff0c;不管你安装的 JDK 版本是 JDK 7 还是 JDK 8 或者更高。这样一来非常不方便&#xff0c;尤其是时不时使用 JDK 7/8 的新特性时。如果使用新特性&#xff…

【二叉树前沿篇】树

【二叉树前沿篇】树 1 树的概念2. 树的相关概念3. 树的表示4. 树在实际中的运用&#xff08;表示文件系统的目录树结构&#xff09; 1 树的概念 树是一种非线性的数据结构&#xff0c;它是由n&#xff08;n>0&#xff09;个有限结点组成一个具有层次关系的集合。把它叫做树是…

Python项目实战:基于napari的3D可视化(点云+slice)

文章目录 一、napari 简介二、napari 安装与更新三、napari【巨巨巨大的一个BUG】四、napari 使用指南4.1、菜单栏&#xff08;File View Plugins Window Help&#xff09;4.2、Window&#xff1a;layer list&#xff08;参数详解&#xff09;4.3、Window&#xff1a;layer…

grafana-zabbix基础操作篇------导入数据源

文章目录 一、grafana的安装1.1、下载地址1.2、下载后导入所安装机器1.3、yum安装解决依赖1.4、启动grafana1.5、查看端口是否启用&#xff08;端口默认3000&#xff09;1.6、浏览器访问 二、添加zabbix数据源2.1、导入数据源 **下一篇 我们讲讲构建仪表板的操作** 今天&#x…

SAP ABAP SQL查询语句-内连接与左连接使用

SQL查询语句-内连接与左连接使用 程序: data:begin of itab OCCURS 0,material type /bic/oizmaterial,batch type /bi0/oibatch,brd type /bic/oizzbrd,cpquabu type /bi0/oicpquabu,end of itab. *&**例一: clear: itab[]. select * into corresponding fields…

OpenLayers实战,OpenLayers判断点位是否与多边形有交集,判断车辆是否在电子围栏内

专栏目录: OpenLayers实战进阶专栏目录 前言 OpenLayers实战,OpenLayers判断点位是否与多边形有交集,可以用于判断车辆是否在电子围栏内,船舶是否在锚泊位中等常用案例。 在实际GIS地图业务开发中,一般是不会在前端实现是否在电子围栏这种计算的。 如果有人让你在前端实…

【数据结构】 单链表面试题讲解

文章目录 引言反转单链表题目描述示例&#xff1a;题解思路代码实现&#xff1a; 移除链表元素题目描述&#xff1a;示例思路解析&#xff1a; 链表的中间结点题目描述&#xff1a;示例&#xff1a;思路解析代码实现如下&#xff1a; 链表中倒数第k个结点题目描述示例思路解析&…

SpringBoot统⼀功能处理

前言&#x1f36d; ❤️❤️❤️SSM专栏更新中&#xff0c;各位大佬觉得写得不错&#xff0c;支持一下&#xff0c;感谢了&#xff01;❤️❤️❤️ Spring Spring MVC MyBatis_冷兮雪的博客-CSDN博客 本章是讲Spring Boot 统⼀功能处理模块&#xff0c;也是 AOP 的实战环节&…

Redis 扩展资料

Redis 扩展资料 1.缓存简介2.缓存分类3.常⻅缓存使⽤4.Redis 数据类型和使⽤5.持久化6.常⻅⾯试题7.Redis 集群&#xff08;选学&#xff09; 1.缓存简介 2.缓存分类 3.常⻅缓存使⽤ 4.Redis 数据类型和使⽤ 5.持久化 Redis 和 Memcached 有什么区别&#xff1f; 6.常⻅⾯试…

Java日志框架-JUL

JUL全称Java util logging 入门案例 先来看着入门案例&#xff0c;直接创建logger对象&#xff0c;然后传入日志级别和打印的信息&#xff0c;就能在控制台输出信息。 可以看出只输出了部分的信息&#xff0c;其实默认的日志控制器是有一个默认的日志级别的&#xff0c;默认就…

Linux命令200例:nc非常有用的网络工具(常用)

&#x1f3c6;作者简介&#xff0c;黑夜开发者&#xff0c;全栈领域新星创作者✌。CSDN专家博主&#xff0c;阿里云社区专家博主&#xff0c;2023年6月csdn上海赛道top4。 &#x1f3c6;数年电商行业从业经验&#xff0c;历任核心研发工程师&#xff0c;项目技术负责人。 &…

PHP-MD5注入

0x00 前言 有些零散的知识未曾关注过&#xff0c;偶然捡起反而更加欢喜。 0x01 md5 注入绕过 md5函数有两个参数&#xff0c;第一个参数是要进行md5的值&#xff0c;第二个值默认为false&#xff0c;如果为true则返回16位原始二进制格式的字符串。意思就是会将md5后的结果当…

MVCC 是否彻底解决了事物的隔离性 ?

目录 1. 什么是 MVCC 2. MVCC 是否彻底解决了事物的隔离性 3. MySQL 中如何实现共享锁和排他锁 4. MySQL 中如何实现悲观锁和乐观锁 1. 什么是 MVCC MVCC&#xff08;Multi-Version Concurrency Control&#xff0c;多版本并发控制&#xff09;是一种多版本并发控制机制&…

vue项目使用qrcodejs2遇到Cannot read property ‘appendChild‘ of null

这个问题是节点还没创建渲染完就读取节点&#xff0c;这个时候应该先让节点渲染完成在生成&#xff0c;解决方法有以下两种 1、使用$nextTick&#xff08;&#xff09;方法进行&#xff0c;这个方法是用来在节点创建渲染完成后进行的操作 that.$nextTick(() > {let qrcode …

Java基础知识小结(内部类、BigInteger、枚举、接口、重写重载和序列化)

一、Java内部类 1、内部类 在Java中&#xff0c;也可以嵌套类&#xff08;类中的类&#xff09;。嵌套类的目的是将属于同一类的类分组&#xff0c;这使代码更具可读性和可维护性。 要访问内部类&#xff0c;请创建外部类的对象&#xff0c;然后创建内部类的对象&#xff1a;…

SpringBoot中优雅的实现隐私数据脱敏(提供Gitee源码)

前言&#xff1a;在实际项目开发中&#xff0c;可能会对一些用户的隐私信息进行脱敏操作&#xff0c;传统的方式很多都是用replace方法进行手动替换&#xff0c;这样会由很多冗余的代码并且后续也不好维护&#xff0c;本期就讲解一下如何在SpringBoot中优雅的通过序列化的方式去…
最新文章