[NLP]LLM高效微调(PEFT)--LoRA

LoRA

背景

神经网络包含很多全连接层,其借助于矩阵乘法得以实现,然而,很多全连接层的权重矩阵都是满秩的。当针对特定任务进行微调后,模型中权重矩阵其实具有很低的本征秩(intrinsic rank),因此,论文的作者认为权重更新的那部分参数矩阵尽管随机投影到较小的子空间,仍然可以有效的学习,可以理解为针对特定的下游任务这些权重矩阵就不要求满秩。

技术原理

LoRA(论文:LoRA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS),该方法的核心思想就是通过低秩分解来模拟参数的改变量,从而以极小的参数量来实现大模型的间接训练。

在涉及到矩阵相乘的模块,在原始的PLM旁边增加一个新的通路,通过前后两个矩阵A,B相乘,第一个矩阵A负责降维,第二个矩阵B负责升维,中间层维度为r,从而来模拟所谓的本征秩(intrinsic rank)。

 可训练层维度和预训练模型层维度一致为d,先将维度d通过全连接层降维至r,再从r通过全连接层映射回d维度,其中,r<<d,r是矩阵的秩,这样矩阵计算就从d x d变为d x r + r x d,参数量减少很多。

在下游任务训练时,固定模型的其他参数,只优化新增的两个矩阵的权重参数,将W跟新增的通路W1两部分的结果加起来作为最终的结果(两边通路的输入跟输出维度是一致的),即h=Wx+BAx。第一个矩阵的A的权重参数会通过高斯函数初始化,而第二个矩阵的B的权重参数则会初始化为零矩阵,这样能保证训练开始时新增的通路BA=0从而对模型结果没有影响。

在推理时,将左右两部分的结果加到一起即可,h=Wx+BAx=(W+BA)x,所以只要将训练完成的矩阵乘积BA跟原本的权重矩阵W加到一起作为新权重参数替换原本PLM的W即可,对于推理来说,不会增加额外的计算资源。

为什么更新ΔW只需要更新较少的参数呢?

现在,让我们解决房间里的大问题:如果我们引入新的权重矩阵,这个参数的效率如何?新矩阵WAWB可以非常小。例如,假设A=100B=500 ,则ΔW的大小为100 × 500 = 50,000。现在,如果我们将其分解为两个较小的矩阵,一个100×5维矩阵WA和一个5×500维矩阵WB。这两个矩阵总共只有5×100+5×500=3000个参数。

作者也在摘要中明确表示,他们采用lora方法微调,相比于GPT-3全量参数微调,可训练参数下降了10000倍,GPU显存需求下降了3倍,而lora微调后的效果,在特定任务上甚至可以媲美全量微调的模型。

为什么一直强调特定任务呢?因为lora基于的假设就是在特定任务上微调时,更新的参数矩阵具有较低的内在维度。可以把lora想象成一个特定能力的装备,而预训练模型是游戏角色本身。在预训练模型(游戏角色)的基础上,特定lora(装备)可以增强对于某一特定任务的表现,但是在其他不相关任务上该lora模块并不会起到作用。如果想同时在多个任务上有媲美全量参数微调模型的表现的话,就得需要针对不同的任务训练不同的ΔW模块(多个装备),最后整合在一起。但是,如果想模型(游戏角色)本身整体变强大,还是全量参数微调更合适。

至于是否适合作为通用指令微调的解决方案,有个问题我也没有搞懂,就是通用的指令样本是否真的有统一的低秩空间表征?这个表征又是什么含义?因为指令微调阶段的样本其实是混合的多任务指令样本,这种情况下lora是否合适,感觉需要更全面的评估.

## 初始化低秩矩阵A和B
self.lora_A.update(nn.ModuleDict({adapter_name: nn.Linear(self.in_features, r, bias=False)}))
self.lora_B.update(nn.ModuleDict({adapter_name: nn.Linear(r, self.out_features, bias=False)}))
self.scaling[adapter_name] = lora_alpha / r

## 向前计算
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
result += (
    self.lora_B[self.active_adapter](
        self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
    )
    * self.scaling[self.active_adapter]
)

此外,Transformer的权重矩阵包括Attention模块里用于计算query, key, value的Wq,Wk,Wv以及多头attention的Wo,以及MLP层的权重矩阵,LoRA只应用于Attention模块中的4种权重矩阵,而且通过消融实验发现同时调整 Wq 和 Wv 会产生最佳结果。

input_dim = 768  # e.g., the hidden size of the pre-trained model
output_dim = 768  # e.g., the output size of the layer
rank = 8  # The rank 'r' for the low-rank adaptation

W = ... # from pretrained network with shape input_dim x output_dim

W_A = nn.Parameter(torch.empty(input_dim, rank)) # LoRA weight A
W_B = nn.Parameter(torch.empty(rank, output_dim)) # LoRA weight B

# Initialization of LoRA weights
nn.init.kaiming_uniform_(W_A, a=math.sqrt(5))
nn.init.zeros_(W_B)

def regular_forward_matmul(x, W):
    h = x @ W
return h

def lora_forward_matmul(x, W, W_A, W_B):
    h = x @ W  # regular matrix multiplication
    h += x @ (W_A @ W_B)*alpha # use scaled LoRA weights
return h

在上面的伪代码中,alpha是一个缩放因子,用于调整组合结果(原始模型输出加上低秩自适应)的大小。这平衡了预训练模型的知识和新的特定于任务的适应——默认情况下,alpha通常设置为 1。另请注意,虽然W A被初始化为小的随机权重,但W B被初始化为 0,因此
训练开始时ΔW = W AW B = 0 ,这意味着我们以原始权重开始训练。

实验还发现,保证权重矩阵的种类的数量比起增加隐藏层维度r更为重要,增加r并不一定能覆盖更加有意义的子空间。

Rank r 的设置

一个很直接的问题就是:在实践中,rank 应该设为多少比较合适呢?

作者做了几组实验进行比较,结果发现 rank 可以很低,不超过8就很 OK 了,甚至是1也挺好..

关于秩的选择,通常情况下,rank为4,8,16即可。

通过实验也发现,在众多数据集上LoRA在只训练极少量参数的前提下,最终在性能上能和全量微调匹配,甚至在某些任务上优于全量微调。

减少推理开销

请注意,在实践中,如果我们在训练后保持原始权重W和矩阵W AW B分开,如上所示,我们将在推理过程中产生小的效率损失,因为这引入了额外的计算步骤。相反,我们可以在训练后通过W' = W + WA WB更新权重,这类似于前面提到的W' = W + ΔW


然而,将权重矩阵W AW B分开可能具有实际优势。例如,假设我们希望将我们的预训练模型作为各种客户的基础模型,并且我们希望从基础模型开始为每个客户创建一个经过微调的 LLM。在这种情况下,我们不需要为每个客户存储完整的权重矩阵W',其中存储模型的所有权重W' = W + WA WB对于 LLM 来说可能非常大,因为 LLM 通常有数十亿到数万亿个权重参数。因此,我们可以保留原始模型W,只需要存储新的轻量级矩阵WAWB


为了用具体数字说明这一点,一个完整的 7B LLaMA 检查点需要 23GB 的存储容量,而如果我们选择r=8的等级,LoRA 权重可以小到 8MB 。

利用 LoRA可以如下优点:

  1. 在面对不同的下游任务时,仅需训练参数量很少的低秩矩阵,而预训练权重可以在这些任务之间共享
  2. 省去了预训练权重的梯度和相关的 optimizer states,大大增加了训练效率降低了硬件要求
  3. 训练好的低秩矩阵可以合并(merge)到预训练权重中,多分支结构变为单分支,从而达到没有推理延时的效果;
  4. 与之前的一些参数高效的微调方法(如 Adapter, Prefix-Tuning 等)互不影响,并且可以相互结合

QLoRA和AdaLoRA

当红炸子鸡 LoRA,是当代微调 LLMs 的正确姿势? - 知乎 (zhihu.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/51051.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ajax axios json

目录 一、ajax概述 1. 概念 2. 实现方式 &#xff08;1&#xff09;原生的JS实现方式&#xff08;了解&#xff09; &#xff08;2&#xff09; JQeury实现方式 二、axios 介绍 三、axios使用 1. axios 发送get/post请求 2. axios验证用户名称是否存在 四、json 1. …

2023牛客暑期多校-J-Qu‘est-ce Que C‘est?(DP)

题意&#xff1a; 给定长度为n的数列,要求每个数都在的范围&#xff0c;且任意长度大于等于2的区间和都大于等于0&#xff0c;问方案数。。 思路&#xff1a; 首先要看出是dp题&#xff0c;用来表示遍历到第i位且后缀和最小为x的可行方案数&#xff08;此时的后缀可以只有最…

[golang gin框架] 42.Gin商城项目-微服务实战之后台Rbac微服务角色增删改查微服务

一.重构后台Rbac用户登录微服务功能 上一节讲解了后台Rbac微服务用户登录功能以及Gorm数据库配置单独抽离&#xff0c;Consul配置单独抽离&#xff0c;这一节讲解后台Rbac微服务角色增删改查微服务功能&#xff0c;Rbac微服务角色增删改查微服务和后台Rbac用户登录微服务是属于…

苍穹外卖day09——历史订单模块(用户端)+订单管理模块(管理端)

查询历史订单——需求分析与设计 产品原型 业务规则 分页查询历史订单 可以根据订单状态查询 展示订单数据时&#xff0c;需要展示的数据包括&#xff1a;下单时间、订单状态、订单金额、订单明细&#xff08;商品名称、图片&#xff09; 接口设计 查询历史订单——代码开…

AI聊天GPT三步上篮!

1、是什么&#xff1f; CHATGPT是OpenAI开发的基于GPT&#xff08;Generative Pre-trained Transformer&#xff09;架构的聊天型人工智能模型。也就是你问它答&#xff0c;根据网络抓去训练 2、怎么用&#xff1f; 清晰表达自己诉求&#xff0c;因为它就是一个AI助手&#…

Java书签 #解锁MyBatis的4种批量插入方式及ID返回姿势

1. 今日书签 项目开发中&#xff0c;我们经常会用到单条插入和批量插入。但是实际情况可能是&#xff0c;项目初期由于种种原因&#xff0c;在业务各处直接使用单条插入SQL进行开发&#xff08;未开启批处理&#xff09;&#xff0c;在后面的迭代中&#xff0c;系统性能问题渐…

无涯教程-jQuery - Ajax Tutorial函数

AJAX是用于创建交互式Web应用程序的Web开发技术。如果您了解JavaScript,HTML,CSS和XML,则只需花费一个小时即可开始使用AJAX。 为什么要学习Ajax? AJAX代表 A 同步 Ja vaScript和 X ML。 AJAX是一项新技术,可借助XML,HTML,CSS和Java Script创建更好,更快,更具交互性的Web应用…

解决Font family [‘sans-serif’] not found问题

序言 以下测试环境都是在 anaconda3 虚拟环境下执行。 激活虚拟环境 conda activate test_python_env 或 source activate test_python_env工具&#xff1a; WinSCP Visual Studio Code 这里笔者使用 WinSCP 工具连接&#xff0c;编辑工具是 Visual Studio Code 一、字体…

基于fpga_EP4CE6F17C8实现的呼吸灯

文章目录 前言实验手册&#xff08;EP4CE6F17C8&#xff09;一、实验目的二、实验原理理论原理 三、系统架构设计四、模块说明1&#xff0e;模块端口信号列表2&#xff0e;状态转移图3&#xff0e;时序图 五、仿真波形图六、引脚分配七、代码实现八、仿真代码九、板级验证效果 …

【论文阅读】Feature Inference Attack on Shapley Values

摘要 研究背景 近年来&#xff0c;解释性机器学习逐渐成为一个热门的研究领域。解释性机器学习可以帮助我们理解机器学习模型是如何进行预测的&#xff0c;它可以提高模型的可信度和可解释性。Shapley值是一种解释机器学习模型预测结果的方法&#xff0c;它可以计算每个特征对…

视频标注是什么?和图像数据标注的区别?

视频数据标注是对视频剪辑进行标注的过程。进行标注后的视频数据将作为训练数据集用于训练深度学习和机器学习模型。这些预先训练的神经网络之后会被用于计算机视觉领域。 自动化视频标注对训练AI模型有哪些优势 与图像数据标注类似&#xff0c;视频标注是教计算机识别对象…

springboot整合myabtis+mysql

一、pom.xml <!--mysql驱动包--><dependency><groupId>mysql</groupId><artifactId>mysql-connector-java</artifactId></dependency><!--springboot与JDBC整合包--><dependency><groupId>org.springframework.b…

hcip——路由策略

要求&#xff1a; 基础配置 AR1 [R1]int g 0/0/0 [R1-GigabitEthernet0/0/0]ip add 12.0.0.1 24[R1-GigabitEthernet0/0/0]int g 0/0/1 [R1-GigabitEthernet0/0/1]ip add 14.0.0.1 24[R1]int loop0 [R1-LoopBack0]ip add 1.1.1.1 24[R1]rip 1 [R1-rip-1]vers 2 [R1-rip-1]net…

基于扩展(EKF)和无迹卡尔曼滤波(UKF)的电力系统动态状态估计(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

k8s中强制删除pv

K8s 集群内有一个已经不再使用的 PV&#xff0c;虽然已经删除了与其关联的 Pod 及 PVC&#xff0c;并对其执行了删除命令&#xff0c;但仍无法正常删除&#xff0c;一直处于 Terminating 状态&#xff1a; 解决办法&#xff1a; 1. 获取pv信息 kubectl get pv 2. 解除pv锁定 …

2023/7/29总结

项目&#xff1a; 这几天主要实现了评论的功能点: 还是有点小bug&#xff0c;还在更改中…… 修改个人中心的界面 接下来是把收藏完善&#xff0c;因为收藏需要用户自己创建一个新的收藏夹

JAVA 正则表达式(heima)

JAVA 正则表达式&#xff08;heima&#xff09; public class RegexDemo01 {/** 正则表达式介绍&#xff1a;本质来说就是一个字符串&#xff0c;字符串中可以指定规则&#xff0c;来对其他字符串进行校验。* public boolean matches(String regex):根据传入的正则表达式&#…

matplotlib绘图中可选标记

文章目录 简介所有可用的绘图标记绘图函数标记绘制 简介 前面的博客简要介绍了matplotlib中的绘图标记&#xff0c;并列举出了部分可用标记点的类型&#xff0c;并画了个图作为示例&#xff0c;如下图下表所示。本文则将所有标记点的类型均绘制一遍 字符类型字符类型字符类型…

基于springboot+mybatis+thymeleaf+html产品销售与分析系统

基于springbootmybatisthymeleafhtml产品销售与分析系统 一、系统介绍二、功能展示1.下单(批发商)2.订单管理&#xff08;批发商&#xff09;3.首页(厂家管理员)4.订单管理&#xff08;厂家管理员&#xff09;5.商品管理&#xff08;厂家管理员&#xff09;6.统计分析&#xff…

【深度学习】InST,Inversion-Based Style Transfer with Diffusion Models,论文

代码&#xff1a;https://github.com/zyxElsa/InST 论文&#xff1a;https://arxiv.org/abs/2211.13203 文章目录 AbstractIntroductionRelated WorkImage style transferText-to-image synthesisInversion of diffusion models MethodOverview ExperimentsComparison with Sty…
最新文章