[pytorch入门] 9. 优化器

介绍

在pytorch的官方文档中,所有的优化器都集中在torch.optim中
在这里插入图片描述
在官方文档中,会告诉你如何去创建一个优化器
在这里插入图片描述
选择一种优化器创建,传入模型的参数(必需的)、学习速率(几乎是每个优化器都有的参数)、优化器算法中特定需要设置的参数

可以在其中选择优化器的算法,设置相应参数,包括一些必备参数以及学习速率等
参数比较多比较复杂,初始阶段先设置params和lr(学习速率)就可以了

使用方法

官方文档中也给出了一些optim的使用

for input, target in dataset:
    optimizer.zero_grad()  # 对之前训练的梯度清零
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

对之前训练的梯度清零这一步一定要写,不然会导致模型出现问题

算法

在这里插入图片描述

实例

import torch
import torchvision
from torch import nn
from torch.utils.data import dataloader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor())
dataloader = dataloader.DataLoader(dataset=dataset, batch_size = 64)

class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )
    
    def forward(self, x):
        x = self.model1(x)
        return x

loss = nn.CrossEntropyLoss()
net = Test()
optimizer = torch.optim.SGD(net.parameters(), lr = 0.01)
# 学习速率的设置不能太大也不能太小 
    # 太大:跨度大,可能会跨过最优值
    # 太小:学习慢
# 一般情况下,学习速率的设置先大后小

for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        img, target = data
        output = net(img)
        result_loss = loss(output, target)

        optimizer.zero_grad()  # 对网络模型中每个可以调节的参数设置为0
        result_loss.backward() # 使用优化器对每个参数进行优化,首先就需要获取每个参数的梯度
        optimizer.step()
        # print(result_loss)
        running_loss += result_loss
    print(running_loss)

在这里插入图片描述
可以发现每轮训练的误差都在减小

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/367268.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【制作100个unity游戏之23】实现类似七日杀、森林一样的生存游戏9(附项目源码)

本节最终效果演示 文章目录 本节最终效果演示系列目录前言回收物品素材绘制UI代码控制垃圾桶回收功能效果 源码完结 系列目录 前言 欢迎来到【制作100个Unity游戏】系列!本系列将引导您一步步学习如何使用Unity开发各种类型的游戏。在这第23篇中,我们将…

DPVS 多活部署架构部署

一、目标 利用DPVS部署一个基于OSPF/ECMP的提供HTTP服务的多活高可用的测试环境。 本次部署仅用于验证功能,不提供性能验证。 配置两台DPVS组成集群、两台REAL SERVER提供实际HTTP服务。 注:在虚拟环境里面,通过在一台虚拟服务器上面安装FR…

2024牛客寒假算法基础集训营1

文章目录 A DFS搜索M牛客老粉才知道的秘密G why外卖E 本题又主要考察了贪心B 关鸡C 按闹分配 今天的牛客,说是都是基础题,头昏昏的,感觉真不会写,只能赛后补题了 A DFS搜索 写的时候刚开始以为还是比较难的,和dfs有关…

老版本labelme如何不保存imagedata

我的版本是3.16,默认英文且不带取消保存imagedata的选项。 最简单粗暴的方法就是在json文件保存时把传递过来的imagedata数据设定为None,方法如下: 找到labelme的源文件,例如:D:\conda\envs\deeplab\Lib\site-packages…

数据分析基础之《pandas(4)—pandas画图》

1、DataFrame.plot(xNone, yNone, kindline) 说明: x:设置x轴标签 y:设置y轴标签 kind: line 折线图 bar 柱状图 hist 直方图 pie 饼图 scatter 散点图 # 找到p_change和turnover之间的关系 data.plot(xvolume, yturnover, kinds…

dubbo+sentinel最简集成实例

说明 在集成seata后,下面来集成sentinel进行服务链路追踪管理~ 背景 sample-front网关服务已配置好 集成 一、启动sentinel.jar 1、官网下载 选择1:在本地启动 nohup java -Dserver.port8082 -Dcsp.sentinel.dashboard.serverlocalhost:8082 -Dp…

Simulink|光伏阵列模拟多类故障(开路/短路/阴影遮挡/老化)

目录 主要内容 模型研究 1.正常模型 2.断路故障 3.短路故障 4.阴影遮挡 5.老化模型 结果一览 1.U-I曲线 2.P-V曲线 下载链接 主要内容 该模型为光伏阵列模拟故障情况simulink模型,程序实现了多种故障方式下的光伏阵列输出功率-电压-电流关系特…

链表——C语言——day17

链表 链表是一种常见的重要的数据结构。它是动态地进行存储分配的一种结构。在用数组存放数据时,必须事先定义固定的长度(即元素个数)。链表则没有这种缺点,它根据需要开辟内存单元。 链表有一个“头指针“变量,图中…

Docker极速入门掌握基本概念和用法

1、Docker概念 1.1什么是docker Docker是一个快速交付应用、运行应用的技术,具备以下优势 可将程序及其依赖、运行环境一起打包为一个镜像,可以迁移到任意Linux操作系统运行时利用沙箱机制形成隔离容器,各个应用互不干扰启动、移除都可以通…

jmeter-03界面介绍

文章目录 主界面介绍测试计划介绍线程组介绍线程组——选择测试计划,右键-->添加-->线程-->线程组 主界面介绍 测试计划介绍 测试计划:本次测试所需要的所有内容,即父线程 线程组介绍 jmeter讲究一个概念:一个线程一…

如何在docker中访问电脑上的GPU?如何在docker中使用GPU进行模型训练或者加载调用?

如何在docker中访问电脑上的GPU?如何在docker中使用GPU进行模型训练或者加载调用? 其实使用非常简单,只是一行命令的事,最主要的事配置好驱动和权限。 docker run -it --rm --gpus all ycj520/centos:1.0.0 nvidia-smi先看看 st…

AI在线写作软件推荐:5款不可错过的写作工具

现在人工智能(AI)技术已经渗透到了各个领域,包括写作。AI在线写作软件的出现,为我们提供了更加高效、准确的写作工具。在本文中,我将向大家推荐5款功能强大的AI在线写作软件,这些软件可以帮助我们提高写作效…

一文掌握SpringBoot注解之@Configuration知识文集(3)

🏆作者简介,普修罗双战士,一直追求不断学习和成长,在技术的道路上持续探索和实践。 🏆多年互联网行业从业经验,历任核心研发工程师,项目技术负责人。 🎉欢迎 👍点赞✍评论…

QEMU源码全解析 —— 内存虚拟化(2)

接前一篇文章: 本文内容参考: 《趣谈Linux操作系统》 —— 刘超,极客时间 《QEMU/KVM》源码解析与应用 —— 李强,机械工业出版社 QEMU内存管理模型 特此致谢! QEMU内存初始化 1. 基本结构 在开始介绍内存初始化…

对象克隆Objects

对象克隆 把A对象的属性值完全拷贝给B对象,也叫对象拷贝,对象复制。 package MyApi.a04objectdemo;public class ObjectDemo03 {public static void main(String[] args) throws CloneNotSupportedException {//1.先创建一个对象int []data{1,2,3,4,5,…

深度学习(12)--Mnist分类任务

一.Mnist分类任务流程详解 1.1.引入数据集 Mnist数据集是官方的数据集,比较特殊,可以直接通过%matplotlib inline自动下载,博主此处已经完成下载,从本地文件中引入数据集。 设置数据路径 from pathlib import Path# 设置数据路…

C# Onnx GroundingDINO 开放世界目标检测

目录 介绍 效果 模型信息 项目 代码 下载 介绍 地址:https://github.com/IDEA-Research/GroundingDINO Official implementation of the paper "Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection" 效果 …

RT-Thread 28. Nano实现MSH及CPU利用率显示

Nano版源码官网下载 https://github.com/RT-Thread/rtthread-nano/archive/refs/heads/master.zip 1. 代码结构 2.代码 //main.c #include "gd32f3x0.h" #include <rthw.h> #include <rtthread.h> #include "cpuusage.h"#define delay_ms(x…

树型结构构建,模糊查询,过滤

一、前言 1、最近在做甘特图&#xff0c;有些需求和树型结构要求很大&#xff0c;看的是 pingCode&#xff0c;有搜索 2、还有抽取一部分树型结构的&#xff0c;如下是抽取上面的结构类型为需求的&#xff0c;重新组成树型 二、构建多颗树型结构 1、某些业务下&#xff0c;从…

【Nginx】Ubuntu如何安装使用Nginx反向代理?

文章目录 使用Nginx反向代理2个web接口服务步骤 1&#xff1a;安装 Nginx步骤 2&#xff1a;启动 Nginx 服务步骤 3&#xff1a;配置 Nginx步骤 4&#xff1a;启用配置步骤 5&#xff1a;检查配置步骤 6&#xff1a;重启 Nginx步骤 7&#xff1a;访问网站 proxy_set_header 含义…
最新文章