首页 > Python资料 博客日记
【多模态大模型】LLaMA in arXiv 2023
2024-08-17 15:00:11Python资料围观73次
一、引言
论文: LLaMA: Open and Efficient Foundation Language Models
作者: Meta AI
代码: LLaMA
特点: 该方法在Transformer的基础上增加了Pre-normalization (RMSNorm)、SwiGLU activation function (SwiGLU)、Rotary Embeddings (RoPE)、FlashAttention。
⚠️ 在学习该方法前,建议补充BatchNorm、LayerNorm、位置编码、Attention的相关知识。
二、详情
Transformer和LLaMA的结构图如下:
可见,其结构差异主要体现在如下方面:
- Transformer采用了左编码器+右解码器(Encoder+Decoder)的结构,LLaMA采用了仅解码器(Decoder-only)的结构。由于仅包含解码器不需要与编码器输出交互,故LLaMA去掉了Transformer中Decoder中间的交叉Multi-Head Attention和Add & Norm。
- LLaMA采用了归一化前置(Pre-normalization)的策略,将归一化操作放在了注意力、FFN前并在线性映射前增加了一个归一化。此外,LLaMA还将LayerNorm替换为了
RMSNorm
。- LLaMA将绝对位置编码替换为了旋转位置编码,即
RoPE
,这是一种只对Q和K进行位置编码的方式。- 为加速训练,LLaMA引入了
FlashAttention
。- LLaMA将ReLU替换为了
SwiGLU
。
2.1 RMSNorm
均方根归一化RMSNorm
简化了LayerNorm
的计算。
要了解RMSNorm
,首先需回顾LayerNorm
的公式:
其中, x \boldsymbol{x} x为输入的token序列, E [ x ] = 1 n ∑ i = 1 n x i {\bf E}\boldsymbol{[x]}=\frac{1}{n}\sum_{i=1}^{n}\boldsymbol{x}_i E[x]=n1∑i=1nxi和 V a r [ x ] = 1 n ∑ i = 1 n ( x i − E [ x ] ) 2 {\bf Var}\boldsymbol{[x]}=\sqrt{\frac{1}{n}\sum_{i=1}^n(\boldsymbol{x}_i-{\bf E}\boldsymbol{[x]})^2} Var[x]=n1∑i=1n(xi−E[x])2为 x \boldsymbol{x} x的均值和有偏方差, ϵ \boldsymbol{\epsilon} ϵ用来防止分母为0, γ \boldsymbol{\gamma} γ和 β \boldsymbol{\beta} β是可学习的参数用来缩放和平移。
RMSNorm
简化了LayerNorm
的计算,其公式如下:
其中, R M S [ x ] = 1 n ∑ i = 1 n x i 2 {\bf RMS}\boldsymbol{[x]}=\sqrt{\frac{1}{n}\sum_{i=1}^{n}\boldsymbol{x}_i^2} RMS[x]=n1∑i=1nxi2是均方根。
可见,RMSNorm
与LayerNorm
主要有如下差别:
RMSNorm
无需计算均值 E [ x ] {\bf E}[\boldsymbol{x}] E[x]。RMSNorm
将有偏方差 V a r [ x ] {\bf Var[\boldsymbol{x}]} Var[x]替换为了均方根 R M S [ x ] {\bf RMS[\boldsymbol{x}]} RMS[x]。RMSNorm
无需平移项 γ \boldsymbol{\gamma} γ。
与LayerNorm
一样,RMSNorm
也能以句子或单词(token)为单位进行归一化,如下给出了以token为单位的代码示例。
import torch
import torch.nn as nn
class MyRMSNorm(nn.Module):
def __init__(self, hidden_dim, eps=1e-8):
super().__init__()
# 防止分母计算为0
self._eps = eps
# 仿射变换参数,缩放norm后的数据分布
self._gamma = nn.Parameter(torch.ones(hidden_dim))
def forward(self, input):
# input(N,L,C)
ms = input.pow(2).mean(dim=-1, keepdim=True) # 计算均方,token-wise
input = input / torch.sqrt(ms + self._eps) # 执行标准化
return input * self._gamma # 仿射变换
if __name__ == '__main__':
batch_size = 4
length = 2
hidden_dim = 3
input = torch.rand(4, 2, 3)
myRMSN = MyRMSNorm(hidden_dim=hidden_dim)
MyO = myRMSN(input)
pytorchRMSN = nn.RMSNorm(normalized_shape=hidden_dim, elementwise_affine=False) # 不使用可学习的gamma和beta
pytorchO = pytorchRMSN(input)
print(MyO == pytorchO)
2.2 RoPE
旋转位置编码RoPE
使用绝对位置信息设计旋转规则,使旋转后的数据能够表达相对位置信息。
要了解RoPE
,首先我们来了解一下二维空间的旋转。如下图:
其中, X = [ ρ cos ϕ , ρ sin ϕ ] X=[\rho\cos\phi,\rho\sin\phi] X=[ρcosϕ,ρsinϕ]是一个二维向量,逆时针旋转 θ \theta θ度变成 X R ( θ ) XR(\theta) XR(θ)。此时 R ( θ ) = [ cos θ , sin θ − sin θ , cos θ ] R(\theta)=\left[\begin{matrix}\cos\theta,~\sin\theta\\-\sin\theta,~\cos\theta\end{matrix}\right] R(θ)=[cosθ, sinθ−sinθ, cosθ],证明如下:
X R ( θ ) = [ ρ cos ϕ , ρ sin ϕ ] [ cos θ , sin θ − sin θ , cos θ ] = ρ [ cos ϕ cos θ − sin ϕ sin θ , cos ϕ sin θ + sin ϕ cos θ ] = [ ρ cos ( ϕ + θ ) , ρ sin ( ϕ + θ ) ] XR(\theta)=[\rho\cos\phi,\rho\sin\phi]\left[\begin{matrix}\cos\theta,~\sin\theta\\-\sin\theta,~\cos\theta\end{matrix}\right]\\=\rho[\cos\phi\cos\theta-\sin\phi\sin\theta,\cos\phi\sin\theta+\sin\phi\cos\theta]=[\rho\cos(\phi+\theta),\rho\sin(\phi+\theta)] XR(θ)=[ρcosϕ,ρsinϕ][cosθ, sinθ−sinθ, cosθ]=ρ[cosϕcosθ−sinϕsinθ,cosϕsinθ+sinϕcosθ]=[ρcos(ϕ+θ),ρsin(ϕ+θ)]
可见, X X X与 X R ( θ ) XR(\theta) XR(θ)仅差一个 θ \theta θ,所以二维空间逆时针旋转 θ \theta θ度可通过 R ( θ ) R(\theta) R(θ)实现。
旋转只改变角度,不改变长度。
RoPE
将旋转应用在了注意力模块的查询
Q
Q
Q和
K
K
K上。它将第
i
i
i个查询
Q
i
Q_i
Qi旋转
i
θ
i\theta
iθ的角度,再将第
j
j
j个键
K
j
K_j
Kj旋转
j
θ
j\theta
jθ的角度,那么
Q
i
K
j
T
Q_iK_j^T
QiKjT就会变成一个与相对位置
i
−
j
i-j
i−j相关的值。推导过程如下:
i i i和 j j j是查询 Q i Q_i Qi和 K j K_j Kj的绝对位置, i − j i-j i−j是它们的相对位置。
然而, Q i Q_i Qi和 K j K_j Kj的维度通常都是大于2的,我们假设它是 D D D且 D D D是2的整数倍,于是我们可以将 Q i Q_i Qi和 K j K_j Kj分别划分为 d = D 2 d=\frac{D}{2} d=2D个子空间,每个子空间都是二维的。
下图给出了一个 D = 10 D=10 D=10的例子,我们将 Q i Q_i Qi和 K j K_j Kj分为5个子空间并分配1个包括5个角度的旋转序列 Θ = ( θ 1 , θ 2 , ⋯ , θ 5 ) \Theta=(\theta_1,\theta_2,\cdots,\theta_5) Θ=(θ1,θ2,⋯,θ5),每个子空间的旋转角度是在对应旋转序列的基础上乘以 i i i或 j j j。
将其扩展到 d d d个子空间,可以得到如下信息:
其中, X i X_i Xi代指 Q i Q_i Qi或 K j K_j Kj。此时,这种旋转仍然具有相对位置的表达能力,证明如下:
显然,上面的 R ( i Θ ) R(i\Theta) R(iΘ)过于稀疏,为了提升计算效率,通常 d d d个子空间的旋转使用下式表达:
为避免token数过多,
i
θ
k
i\theta_k
iθk和
j
θ
k
j\theta_k
jθk重叠导致相对位置得不到表达(同一个子空间
k
k
k,绝对位置
i
i
i和
j
j
j不同,
i
θ
k
−
j
θ
k
=
2
m
π
i\theta_k-j\theta_k=2m\pi
iθk−jθk=2mπ时重叠,
m
m
m是一个整数),RoPE
使用了一个递减的等比数列作为
θ
\theta
θ序列,如下:
θ k \theta_k θk是递减的,这表示token中前几个子空间的旋转角度较大,越往后旋转角度越小。
事实上,为了方便我们通常不是将相邻的两个值划分至同一子空间,而是将D分为前后两个部分,前后各取一个依次组成子空间,例如[q0,q1,q2,q3]被划分为[q0,q2], [q1,q3]而不是[q0,q1], [q2,q3]。以下为使用这种方式进行子空间划分的RoPE
代码:
from torch.nn import functional as F
import torch.nn as nn
import torch
import math
class Rotator:
"""根据hidden_dim,和position_ids 生成对应的旋转位置编码, 和论文中定义略有不同,一个个二维的子空间被
分割到了前后两部分,分别进行旋转,然后拼接起来
"""
def __init__(self, D, position_ids):
""" position_ids: [seq_len], D 和单个头的hidden_dim对应 """
base = 10000
d = D / 2
B = base ** (1/d)
theta_base = 1.0 / (B ** (torch.arange(0, d))) # 等比数列, $\Theta$
thetas = position_ids.outer(theta_base) # [seq_len, D/2]
# 这里的子空间划分与讲解不同,[q0,q1,q2,q3] -> [q0,q2],[q1,q3]是两个子空间而不是[q0,q1],[q2,q3]
full_thetas = torch.cat((thetas, thetas), dim=-1) # [seq_len, D]
self.cos = full_thetas.cos()
self.sin = full_thetas.sin()
def rotate(self, x):
"""
x: [bs, num_attention_heads, seq_len, D]
q: [bs, num_attention_heads, seq_len, D]
cos: [seq_len, D]
[x,y] @ [[cos, sin], [-sin, cos]] = [x*cos-y*sin, ycos+x*sin] =[x,y]*cos+[-y, x]*sin
"""
return x * self.cos + Rotator.reverse_half(x) * self.sin
@staticmethod
def reverse_half(q):
""" q: [bs, num_attention_heads, seq_len, D] trick2 """
u = q[..., :q.shape[-1] // 2] # 认为是各个二维子空间的第一维的向量集结
v = q[..., q.shape[-1] // 2:] # 认为是各个二维子空间的第二维的向量集结
return torch.cat((-v, u), dim=-1)
if __name__ == "__main__":
batch_size = 2
num_heads = 3
D = 6 # 单个头的token向量长度
hidden_dim = D * num_heads
seq_len = 4
position_ids = torch.arange(seq_len)
rotator = Rotator(D, position_ids)
x = torch.randn((batch_size, seq_len, hidden_dim))
# 对每个头分别进行旋转,[batch_size,seq_len,hidden_dim] -> [batch_size,seq_len,num_heads,D] -> [batch_size,num_heads,seq_len,D]
x = x.view(batch_size, seq_len, num_heads, D).transpose(1, 2)
x = rotator.rotate(x)
2.3 FlashAttention
FlashAttention
以分块的形式进行注意力计算,避免了SRAM和HBM之间频繁读写导致的时间浪费。
详情请参考我之前的博客FlashAttention in NeurIPS 2022。
2.4 SwiGLU
激活函数SwiGLU
是门控线性单元(Gated Linear Units, GLU
)的变体,下图红框中表达了GLU
的计算过程:
可见,GLU
会先使用两个带偏执的线性层映射输入
x
\boldsymbol{x}
x,分别记为
x
W
1
+
b
1
\boldsymbol{xW_1+b_1}
xW1+b1和
x
W
2
+
b
2
\boldsymbol{xW_2+b_2}
xW2+b2;其中一个线性映射后会跟一个非线性激活函数sigmoid
,记为
σ
(
x
W
1
+
b
1
)
\sigma(\boldsymbol{xW_1+b_1})
σ(xW1+b1);然后将左右两边的结果对应元素相乘即完成了GLU
,记为
σ
(
x
W
1
+
b
1
)
⊗
(
x
W
2
+
b
2
)
\sigma(\boldsymbol{xW_1+b_1})\otimes(\boldsymbol{xW_2+b_2})
σ(xW1+b1)⊗(xW2+b2)。
SwiGLU
对GLU
做了两点改进:
- 去掉了两个线性映射的偏执项,此时公式变成 σ ( x W 1 ) ⊗ ( x W 2 ) \sigma(\boldsymbol{xW_1})\otimes(\boldsymbol{xW_2}) σ(xW1)⊗(xW2)。
- 将
sigmoid
替换为了Swish
,此时公式变成 Swish β ( x W 1 ) ⊗ ( x W 2 ) \text{Swish}_{\beta}(\boldsymbol{xW_1})\otimes(\boldsymbol{xW_2}) Swishβ(xW1)⊗(xW2)。
Swish
的公式为
Swish
β
(
a
)
=
a
σ
(
β
a
)
=
a
1
+
e
−
β
a
\text{Swish}_{\beta}(a)=a\sigma(\beta a)=\frac{a}{1+e^{-\beta a}}
Swishβ(a)=aσ(βa)=1+e−βaa,在不同的
β
\beta
β下该非线性激活函数的曲线如下:
可见,当
β
\beta
β较大时,该曲线与ReLU
十分接近;当
β
=
1
\beta=1
β=1时,小于0但接近0的曲线变得更光滑且非单调。
SwiGLU
则选用了
β
=
1
\beta=1
β=1的Swish
,于是我们得到SwiGLU
的公式如下:
Swish
(
x
W
1
)
⊗
(
x
W
2
)
=
x
W
1
1
+
e
−
x
W
1
⊗
x
W
2
\text{Swish}(\boldsymbol{xW_1})\otimes(\boldsymbol{xW_2})=\frac{\boldsymbol{xW_1}}{1+e^{-\boldsymbol{xW_1}}}\otimes\boldsymbol{xW_2}
Swish(xW1)⊗(xW2)=1+e−xW1xW1⊗xW2
致谢:
本博客仅做记录使用,无任何商业用途,参考内容如下:
解密旋转位置编码:数学基础、代码实现与绝对编码一体化探索
一文为你深度解析 LLaMA2 模型架构
Llama改进之——SwiGLU激活函数
标签:
相关文章
最新发布
- 【Python】selenium安装+Microsoft Edge驱动器下载配置流程
- Python 中自动打开网页并点击[自动化脚本],Selenium
- Anaconda基础使用
- 【Python】成功解决 TypeError: ‘<‘ not supported between instances of ‘str’ and ‘int’
- manim边学边做--三维的点和线
- CPython是最常用的Python解释器之一,也是Python官方实现。它是用C语言编写的,旨在提供一个高效且易于使用的Python解释器。
- Anaconda安装配置Jupyter(2024最新版)
- Python中读取Excel最快的几种方法!
- Python某城市美食商家爬虫数据可视化分析和推荐查询系统毕业设计论文开题报告
- 如何使用 Python 批量检测和转换 JSONL 文件编码为 UTF-8
点击排行
- 版本匹配指南:Numpy版本和Python版本的对应关系
- 版本匹配指南:PyTorch版本、torchvision 版本和Python版本的对应关系
- Python 可视化 web 神器:streamlit、Gradio、dash、nicegui;低代码 Python Web 框架:PyWebIO
- 相关性分析——Pearson相关系数+热力图(附data和Python完整代码)
- Python与PyTorch的版本对应
- Anaconda版本和Python版本对应关系(持续更新...)
- Python pyinstaller打包exe最完整教程
- Could not build wheels for llama-cpp-python, which is required to install pyproject.toml-based proj