[PyTorch][chapter 53][Auto Encoder 实战]

前言:

     结合手写数字识别的例子,实现以下AutoEncoder

     ae.py:  实现autoEncoder 网络

     main.py: 加载手写数字数据集,以及训练,验证,测试网络。

左图:原图像

右图:重构图像

 ----main-----

 每轮训练时间 : 91
0 loss: 0.02758789248764515

 每轮训练时间 : 95
1 loss: 0.024654878303408623

 每轮训练时间 : 149
2 loss: 0.018874473869800568

目录:

     1: AE 实现

     2: main 实现


一  ae(AutoEncoder) 实现

  文件名: ae.py

               模型的搭建

   注意点:

            手写数字数据集 提供了 标签y,但是AutoEncoder 网络不需要,

它的标签就是输入的x, 需要重构本身

自编码器(autoencoder, AE)是一类在半监督学习和非监督学习中使用的人工神经网络(Artificial Neural Networks, ANNs),其功能是通过将输入信息作为学习目标,对输入信息进行表征学习(representation learning) [1-2]  。

自编码器包含编码器(encoder)和解码器(decoder)两部分 [2]  。按学习范式,自编码器可以被分为收缩自编码器(contractive autoencoder)、正则自编码器(regularized autoencoder)和变分自编码器(Variational AutoEncoder, VAE),其中前两者是判别模型、后者是生成模型 [2]  。按构筑类型,自编码器可以是前馈结构或递归结构的神经网络。

自编码器具有一般意义上表征学习算法的功能,被应用于降维(dimensionality reduction)和异常值检测(anomaly detection) [2]  。包含卷积层构筑的自编码器可被应用于计算机视觉问题,包括图像降噪(image denoising) [3]  、神经风格迁移(neural style transfer)等 [4]  。

   

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:19:19 2023

@author: chengxf2
"""

import torch
from torch import nn

#ae: AutoEncoder

class AE(nn.Module):
    
    def __init__(self,hidden_size=10):
        
        super(AE, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(in_features=784, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=hidden_size),
            nn.ReLU()
            )
         # hidden [batch_size, 10]

        self.decoder = nn.Sequential(
             nn.Linear(in_features=hidden_size, out_features=64),
             nn.ReLU(),
             nn.Linear(in_features=64, out_features=128),
             nn.ReLU(),
             nn.Linear(in_features=128, out_features=256),
             nn.ReLU(),
             nn.Linear(in_features=256, out_features=784),
             nn.Sigmoid()
             )
        
        
    def forward(self, x):
            '''
            param x:[batch, 1,28,28]
            return 
        
            '''
      
            m= x.size(0)
            
            x = x.view(m, 784)
            
            hidden= self.encoder(x)
            x =   self.decoder(hidden)
            
            #reshape
            x = x.view(m,1,28,28)
            
            return x
        
    


二 main 实现

  文件名: main.py

  作用:

      加载数据集

     训练模型

     测试模型泛化能力

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:24:10 2023

@author: chengxf2
"""

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import time
from torch import optim,nn
from ae import AE
import visdom





def main():
   
   batchNum = 32
   lr = 1e-3
   epochs = 20
   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
   torch.manual_seed(1234)
   viz = visdom.Visdom()
   viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))

    
   

   tf= transforms.Compose([ transforms.ToTensor()])
   mnist_train = datasets.MNIST('mnist',True,transform= tf,download=True)
   train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)
   
   mnist_test = datasets.MNIST('mnist',False,transform= tf,download=True)
   test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)
   global_step =0

   
   

   
  
   model =AE().to(device)
   criteon = nn.MSELoss().to(device) #损失函数
   optimizer = optim.Adam(model.parameters(),lr=lr) #梯度更新规则
   
   print("\n ----main-----")
   for epoch in range(epochs):
       
       start = time.perf_counter()
       for step ,(x,y) in enumerate(train_data):
           #[b,1,28,28]
           x = x.to(device)
           x_hat = model(x)
           
           loss = criteon(x_hat, x)
           
           #backprop
           optimizer.zero_grad()
           loss.backward()
           optimizer.step()
           viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')
           global_step +=1



    
       end = time.perf_counter()    
       interval = end - start
       print("\n 每轮训练时间 :",int(interval))
       print(epoch, 'loss:',loss.item())
       
       x,target = iter(test_data).next()
       x = x.to(device)
       with torch.no_grad():
           x_hat = model(x)
       
       tip = 'hat'+str(epoch)
       viz.images(x,nrow=8, win='x',opts=dict(title='x'))
       viz.images(x_hat,nrow=8, win='x_hat',opts=dict(title=tip))
           
           
           
           
   

if __name__ == '__main__':
    
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/96939.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【哈士奇赠书活动 - 37期】- 〖深入浅出SSD:固态存储核心技术、原理与实战 第2版〗

文章目录 ⭐️ 赠书 - 《深入浅出SSD:固态存储核心技术、原理与实战 第2版》⭐️ 内容简介⭐️ 作者简介⭐️ 编辑推荐⭐️ 赠书活动 → 获奖名单 ⭐️ 赠书 - 《深入浅出SSD:固态存储核心技术、原理与实战 第2版》 ⭐️ 内容简介 本书从基础认知、核心技…

快速制作餐厅签到抽奖营销活动,吸引更多顾客

在如今竞争激烈的市场中,吸引用户参与活动是企业获取关注和提升转化率的重要手段。而签到抽奖活动无疑是一种简单而又有效的方式。本文将教你如何利用乔拓云平台后台制作一个快速而有效的签到抽奖活动。 首先,登录乔拓云平台后台,进入【营销活…

【云原生进阶之PaaS中间件】第一章Redis-2.3.3集群模式

1 集群模式 Redis集群是一个提供在多个Redis节点之间共享数据的程序集。它并不像Redis主从复制模式那样只提供一个master节点提供写服务,而是会提供多个master节点提供写服务,每个master节点中存储的数据都不一样,这些数据通过数据分片的方式被自动分割到不同的master节点上…

香橙派OrangePi zero H2+ 驱动移远4G/5G模块

目录 1 安装系统和内核文件: 1.1 下载镜像 1.2 内核头安装 1.2.1 下载内核 1.2.2 将内核头文件导入开发板中 1.2.3 安装内核头 2 安装依赖工具: 2.1 Installing Required Host Utilities 3 驱动步骤: 3.1 下载模块驱动文件…

IO模型:阻塞和非阻塞

一、五种IO模型------读写外设数据的方式 阻塞: 不能操作就睡觉 非阻塞:不能操作就返回错误 多路复用:委托中介监控 信号驱动:让内核如果能操作时发信号,在信号处理函数中操作 异步IO:向内核注册操作请求&…

5G NR:PRACH时域资源

PRACH occasion时域位置由高层参数RACH-ConfigGeneric->prach-ConfigurationIndex指示,根据小区不同的频域和模式,38.211的第6.3.3节中给出了prach-ConfigurationIndex所对应的表格。 小区频段为FR1,FDD模式(paired频谱)/SUL,…

CSRF(跨站请求伪造)和SSRF(服务端请求伪造)漏洞复现:风险与防护方法

这篇文章旨在用于网络安全学习,请勿进行任何非法行为,否则后果自负。 环境准备 一、CSRF(跨站请求伪造) 示例:假设用户在银行网站A上登录并保持会话活动,同时他也在浏览其他网站。攻击者在一个不可信任…

AMBA_AXI Protocol_基本读写事务

基本读写事务 1. 握手的过程 2. 信道信令要求 3. 通道之间的关系1. 握手的过程 当地址、数据或控制信息可用时,源端(source)产生VALID信号。终端(destination)生成READY信号,表示它可以接受该信息。传输只…

微前端:重塑大型项目的前沿技术

引言 随着互联网技术的飞速发展,前端开发已经从简单的页面制作逐渐转变为复杂的应用开发。在这个过程中,传统的前端开发模式已经难以满足大型项目的需求。微前端作为一种新的前端架构模式,应运而生,它旨在解决大型项目中的前端开…

Docker从认识到实践再到底层原理(一)|技术架构

前言 那么这里博主先安利一些干货满满的专栏了! 首先是博主的高质量博客的汇总,这个专栏里面的博客,都是博主最最用心写的一部分,干货满满,希望对大家有帮助。 高质量博客汇总 然后就是博主最近最花时间的一个专栏…

适应高速率网络设备的-2.5G/5G/10G网络变压器/网络滤波器介绍

Hqst盈盛(华强盛)电子导读:在高速发展的互联网/物联网时代,为满足高网速的网络数据传输需求,网络设备在制造中也要选用合适的网络变压器/滤波器产品,有哪些可供选择的高速率网络变压器产品也是广大采购人员…

javaee spring 自动注入,如果满足条件的类有多个如何区别

如图IDrinkDao有两个实现类 方法一 方法二 Resource(name“对象名”) Resource(name"oracleDrinkDao") private IDrinkDao drinkDao;

异步迭代器

一、什么是异步迭代器? 实现了 __aiter__() 和 __anext__() 方法的对象。__anext__ 必须返回一个 awaitable对象。async for 会处理异步迭代器的 __anext__() 方法所返回的可等待对象,直到其引发一个 StopAsyncIteration 异常。 二、实例 class Async…

LeetCode239.滑动窗口最大值

看到这道题我就有印象, 我在剑指offer里面做过这道题,我记得当时用的是优先队列,然后我脑子里一下子就有了想法,拿优先队列作为窗口,每往右移动一步,把左边的数remove掉,把右边的数add进来&…

SpringAOP详解(下)

proxyFactory代理对象创建方式和代理对象调用方法过程: springaop创建动态代理对象和代理对象调用方法过程: 一、TargetSource的使用 Lazy注解,当加在属性上时,会产生一个代理对象赋值给这个属性,产生代理对象的代码为…

Linux系统下vim常用命令

一、基础命令: v:可视模式 i:插入模式 esc:命令模式下 :q :退出 :wq :保存并退出 ZZ:保存并退出 :q! :不保存并强制退出二、在Esc下: dd : 删除当前行 yy:复制当前行 p:复制已粘贴的文本 u:撤销上一步 U:…

IC芯片 trustzone学习

搭建Airplay TA环境需要在IC的TrustZone中进行。TrustZone是一种安全技术,用于隔离安全和非安全环境,并保护敏感文件。在TrustZone中,我们需要编写一个叫做TA(Trusted Application)的应用程序来控制这些私密文档。 &am…

重磅!TikTok将于8月底关闭半闭环 切断外链意在电商业务发展?

自2019年开始,TikTok电商业务逐渐走进人们的视线,并引起了市场的广泛关注。作为一家短视频平台,TikTok能够依靠其强大的用户基数与精准的推广策略,将流量成功转化为商业价值。截至目前,TikTok电商业务已经初步形成完整…

Nacos安装

一、下载Nacos1.4.1二、单机版本安装 2.1 将下载的nacos安装包传输到服务器2.2 解压文件2.3 进入bin目录下 单机版本启动2.4 关闭nacos2.5 访问Nacos地址 IP:8848/nacos三、集群版本的安装 3.1 复制nacos安装包,修改为nacos8849,nacos8850&am…

案例实操-获取员工数据

案例:获取员工数据,返回统一响应结果,在页面渲染展示 package com.bignyi.controller;import com.bignyi.pojo.Emp; import com.bignyi.pojo.Result; import com.bignyi.utils.XmlParserUtils; import org.springframework.web.bind.annotat…
最新文章