【深度学习】常见优化算法的NumPy和PyTorch实现

以下是NumPy和PyTorch实现的几种常见优化算法:

其中参数含义如下:

  • w:待优化参数。

  • grad:参数的梯度。

  • lr:学习率。

  • mu:动量系数(仅对Momentum算法有用)。

  • eps:防止除0操作的小量。

  • cache:参数的暂存值,在不同算法中有不同的含义。

  • decay_rate:衰减率,仅在RMSprop和Adam算法中使用。

  • t:迭代步骤数,仅在Adam算法中使用。

  • m:动量梯度的暂存值,仅在Adam算法中使用。

  • v:平方梯度的暂存值,仅在Adam算法中使用。

    cb0ed52b3a2e6d3738cd37873c68127c.png

1. 随机梯度下降(SGD)

随机梯度下降算法是最基本的优化算法之一,每次更新参数时,使用一个样本的梯度来更新参数,这样可以避免在大数据集中计算整个数据集的梯度。

算法公式:

其中,表示次迭代时的参数值,表示参数的梯度,为学习率。这个算法每次迭代只计算一个样本的梯度,速度快,但可能会引起算法的震荡。

Numpy实现代码:

import numpy as np

def sgd(w, grad, lr):
    w -= lr * grad
    return w

PyTorch实现代码:

import torch.optim as optim

optimizer = optim.SGD(model.parameters(), lr=learning_rate)

2. 动量(Momentum)

动量算法可以加速梯度下降,并减少梯度下降的震荡。它引入一个额外的动量参数,用于记住之前梯度下降的方向,从而减少在各个方向上波动的情况。

算法公式:

其中,表示次迭代时的参数值,表示参数的梯度,为学习率,是动量参数。这个算法将之前梯度下降方向的信息与当前梯度下降方向结合起来,可以更好地适应数据集。

Numpy实现代码:

import numpy as np

def momentum(v, grad, lr, mu):
    v = mu * v - lr * grad
    return v

PyTorch实现代码:

import torch.optim as optim

optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

3. Adagrad

Adagrad是一种自适应学习率算法,它根据每个参数的梯度值来适应学习率的调整,并对使用频率高的参数进行更快的学习率更新。

算法公式:

其中,表示次迭代时的参数值,表示参数的梯度,为学习率,为所有梯度的平方和的累加量,是一个很小的数,用来避免除以0。

Adagrad的缺点是在处理大规模数据集时,学习率会变得过小,导致算法收敛缓慢,并且不能区分参数的重要性。

Numpy实现代码:

import numpy as np

def adagrad(w, grad, lr, eps, cache):
    cache += grad ** 2
    w -= lr * grad / (np.sqrt(cache) + eps)
    return w, cache

PyTorch实现代码:

import torch.optim as optim

optimizer = optim.Adagrad(model.parameters(), lr=learning_rate)

4. RMSprop

RMSprop是一种自适应学习率算法,它根据梯度的有限滑动平均值来调整每个参数的学习率。

算法公式:

其中,表示次迭代时的参数值,表示参数的梯度,为学习率,为所有梯度的平方的滑动平均值,是一个很小的数,用来避免除法时出现除以0的情况,是衰减率。

RMSprop是Adagrad的扩展,它使用了滑动平均,解决了Adagrad不能区分参数重要性的问题。

Numpy实现代码:

import numpy as np

def rmsprop(w, grad, lr, eps, decay_rate, cache):
    cache = decay_rate * cache + (1 - decay_rate) * grad ** 2
    w -= lr * grad / (np.sqrt(cache) + eps)
    return w, cache

PyTorch实现代码:

import torch.optim as optim

optimizer = optim.RMSprop(model.parameters(), lr=learning_rate)

5. Adam

Adam是一种自适应学习率算法,它结合了动量算法和自适应学习率算法,可以对不同梯度的参数调整学习率,对不同方向上的梯度进行更好的调整。

算法公式:

其中,表示次迭代时的参数值,表示参数的梯度,为学习率,为动量估计值,为带平方梯度的指数加权平均数,是一个很小的数,用来避免除以0,是动量系数,是偏差校正数。

Adam算法相对于其他优化算法,具有更好的适应性和计算效率,被广泛应用在深度学习中。

Numpy实现代码:

import numpy as np

def adam(w, grad, lr, eps, decay_rate_1, decay_rate_2, t, m, v):
    m = decay_rate_1 * m + (1 - decay_rate_1) * grad
    v = decay_rate_2 * v + (1 - decay_rate_2) * (grad ** 2)
    mb = m / (1 - decay_rate_1 ** t)
    vb = v / (1 - decay_rate_2 ** t)
    w -= lr * mb / (np.sqrt(vb) + eps)
    return w, m, v

PyTorch实现代码:

import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

以上就是numpy和PyTorch实现的几种优化算法的代码,大家可以根据实际需求选择合适的优化算法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.kler.cn/a/4426.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【C语言蓝桥杯每日一题】——跑步锻炼

【C语言蓝桥杯每日一题】—— 跑步锻炼😎前言🙌排序🙌总结撒花💞😎博客昵称:博客小梦 😊最喜欢的座右铭:全神贯注的上吧!!! 😊作者简介…

SpringBoot启动流程源码分析一、入口参数研究和创建对象

个人搭建的博客唐小码,欢迎大家指教哦 引言 这不最近到金三银四的季节了么,有个朋友去参加了一个面试,回来的时候给我说其它还可以,但是问到SpringBoot的启动原理了,说了解的不深,我仔细转过头来也想了一…

python基础篇:什么是装饰器?装饰器有什么用?

上一篇介绍了python的函数,本文将介绍Python的装饰器,装饰器应用非常广泛,一定要好好掌握啊 什么是装饰器 装饰器是一种Python语言的特性,它允许在不修改已有函数的情况下,向函数添加额外的功能。装饰器本质上是一个函…

【设计模式】单例模式

一,定义单例模式:创建型模式之一,是指在内存中只会创建且仅创建一次对象的设计模式。在程序中多次使用同一个对象且作用相同时,为了防止频繁地创建对象使得内存飙升,单例模式可以让程序仅在内存中创建一个对象&#xf…

蓝桥杯冲刺 - week1

文章目录💬前言🌲day192. 递归实现指数型枚举843. n-皇后问题🌲day2日志统计1209. 带分数🌲day3844. 走迷宫1101. 献给阿尔吉侬的花束🌲day41113. 红与黑🌲day51236. 递增三元组🌲day63491. 完全…

C语言基础——运算符(定义变量、转义字符、输入输出语句、运算符、32个关键字)

文章目录一、定义变量1.如何定义?2.如何调用?二、转义字符二、输入输出语句1.输出语句2.输入语句三、运算符3.1 赋值运算符:3.2 算数运算符:3.3条件运算符3.4 逻辑运算符3.5 赋值复合运算符3.6 自增自减运算符3.7 位运算符3.8 分隔…

用户态--fork函数创建进程

我们一般使用Shell命令行来启动一个程序&#xff0c;其中首先是创建一个子进程。但是由于Shell命令行程序比较复杂&#xff0c;为了便于理解&#xff0c;我们简化了Shell命令行程序&#xff0c;用如下一小段代码来看怎样在用户态创建一个子进程。 #include <stdio.h> #i…

vue 监听器及计算属性高阶用法

文章目录监听器的高级用法深度监听立即触发计算属性的高级用法Getter 和 Setter缓存策略计算属性的依赖总结监听器的高级用法 深度监听 默认情况下&#xff0c;Vue.js 的监听器只会监听对象或数组的第一层属性变化&#xff0c;而不会深度监听其嵌套的属性变化。但是&#xff…

vue Teleport和ref结合复用弹框组件

1、首先新建conform.vue组件&#xff0c;其内容为&#xff1a; <template><div v-if"fade"><div class"xtx-confirm" :class"{fade}"><div class"wrapper" :class"{fade}"><div class"hea…

C语言基础——流程控制语句

文章目录一、流程控制语句 -- 控制程序的运行过程 9条&#xff08;一&#xff09;、条件选择流程控制语句&#xff1a;if语句if……else……语句if……else if……语句switch语句&#xff08;二&#xff09;、循环流程控制语句&#xff1a;for语句while语句do while……语句co…

深度学习的面试小记

随机梯度下降&#xff08;SGD&#xff09; 一种迭代方法&#xff0c;用于优化可微分目标函数。SGD有一个训练速度的问题&#xff0c;学习率过大&#xff0c;无法获得理想的结果&#xff0c;而学习率过小&#xff0c;训练可能会非常耗时。 在微积分里面&#xff0c;对多元函数的…

VUE3 学习笔记(五)UI框架Element Plus

目录 一、安装&#xff1a; 1. 环境支持 2. 版本 3. 安装&#xff08;包管理器npm安装&#xff09; 二、使用 1. 完整引入 2. Volar 支持 3. 国际化 三、国际中文化时错误解决 一、安装&#xff1a; 官网&#xff1a;一个 Vue 3 UI 框架 | Element Plus (gitee.io) 1…

C/C++开发,编译环境搭建

目录 一、 MinGW&#xff08;win&#xff09; 二、 Cygwin(win) 三、纯粹的linux环境 一、 MinGW&#xff08;win&#xff09; 进入Downloads - MinGW-w64下载页面&#xff0c;选择MinGW-w64-builds跳转下载&#xff0c; 再次进行跳转&#xff1a; 然后进入下载页面&#x…

HDFS概述

在现代的企业环境中&#xff0c;单机容量往往无法存储大量数据&#xff0c;需要跨机器存储。统一管理分布式在集群上的文件系统为分布式文件系统。HDFS&#xff08;Hadoop Distributed File System&#xff09;是Apache Hadoop项目的一个子项目&#xff0c;Hadoop非常适合存储大…

查看mysql InnoDB引擎 线程模型信息

InnoDB 引擎 底层采用大量的NIO 进行实现的。 1. 查看 IO Thread show engine innodb status; 2023-03-24 21:53:10 0x2cec INNODB MONITOR OUTPUTPer second averages calculated from the last 42 seconds ----------------- BACKGROUND THREAD ----------------- srv_mast…

Modelsim仿真使用教程

最近写了个设计《基于FPGA的汉明码编译码器设计》之前用QuartusIImodelsim联合仿真&#xff0c;没有出现任何问题&#xff0c;后面在别的电脑上也安装了两个软件&#xff0c;结果QuartusII中无法正常的启动modelsim软件&#xff0c;没有找到很好的解决办法&#xff0c;干脆直接…

数据结构之小端和大端之谜

这些是什么&#xff1f; 小端和大端是存储多字节数据类型&#xff08;int、float 等&#xff09;的两种方式。在小端机器中&#xff0c;多字节数据类型的二进制表示的最后一个字节首先存储。另一方面&#xff0c;在大端机器中&#xff0c;首先存储多字节数据类型的二进制表示的…

Vue 点击图片放大显示功能

1 方式一&#xff1a;列表中感应鼠标显示大图 我管理后台使用的是 element , 列表使用的是 el-tabe <el-table-columnprop"identifImg"header-align"center"align"center"label"证件照"width"100"><template slot-…

信息打点-主机架构蜜罐识别WAF识别端口扫描协议识别服务安全

文章目录识别-Web服务器-请求返回包识别-应用服务器-端口扫描技术web服务器与应用服务器的区别拓展其他类型服务器识别-其他服务协议-端口扫描技术常见端口及潜在威胁识别-WAF防火墙-看图&项目&指纹1、WAF解释&#xff1a;2、WAF分类&#xff1a;3、识别看图&#xff1…

测试开发进阶系列课程

测试开发系列课程1.完善程序思维--------案列&#xff1a;图书管理系统的创建**&#xff08;一&#xff09;图书管理系统的创建**1.完善程序思维--------案列&#xff1a;图书管理系统的创建 &#xff08;一&#xff09;图书管理系统的创建 1.在main中写入主函数&#xff0c;…
最新文章