首页 > Python资料 博客日记
【多模态大模型】LLaMA in arXiv 2023
2024-08-17 15:00:11Python资料围观105次
一、引言
论文: LLaMA: Open and Efficient Foundation Language Models
作者: Meta AI
代码: LLaMA
特点: 该方法在Transformer的基础上增加了Pre-normalization (RMSNorm)、SwiGLU activation function (SwiGLU)、Rotary Embeddings (RoPE)、FlashAttention。
⚠️ 在学习该方法前,建议补充BatchNorm、LayerNorm、位置编码、Attention的相关知识。
二、详情
Transformer和LLaMA的结构图如下:
可见,其结构差异主要体现在如下方面:
- Transformer采用了左编码器+右解码器(Encoder+Decoder)的结构,LLaMA采用了仅解码器(Decoder-only)的结构。由于仅包含解码器不需要与编码器输出交互,故LLaMA去掉了Transformer中Decoder中间的交叉Multi-Head Attention和Add & Norm。
- LLaMA采用了归一化前置(Pre-normalization)的策略,将归一化操作放在了注意力、FFN前并在线性映射前增加了一个归一化。此外,LLaMA还将LayerNorm替换为了
RMSNorm
。- LLaMA将绝对位置编码替换为了旋转位置编码,即
RoPE
,这是一种只对Q和K进行位置编码的方式。- 为加速训练,LLaMA引入了
FlashAttention
。- LLaMA将ReLU替换为了
SwiGLU
。
2.1 RMSNorm
均方根归一化RMSNorm
简化了LayerNorm
的计算。
要了解RMSNorm
,首先需回顾LayerNorm
的公式:
其中, x \boldsymbol{x} x为输入的token序列, E [ x ] = 1 n ∑ i = 1 n x i {\bf E}\boldsymbol{[x]}=\frac{1}{n}\sum_{i=1}^{n}\boldsymbol{x}_i E[x]=n1∑i=1nxi和 V a r [ x ] = 1 n ∑ i = 1 n ( x i − E [ x ] ) 2 {\bf Var}\boldsymbol{[x]}=\sqrt{\frac{1}{n}\sum_{i=1}^n(\boldsymbol{x}_i-{\bf E}\boldsymbol{[x]})^2} Var[x]=n1∑i=1n(xi−E[x])2为 x \boldsymbol{x} x的均值和有偏方差, ϵ \boldsymbol{\epsilon} ϵ用来防止分母为0, γ \boldsymbol{\gamma} γ和 β \boldsymbol{\beta} β是可学习的参数用来缩放和平移。
RMSNorm
简化了LayerNorm
的计算,其公式如下:
其中, R M S [ x ] = 1 n ∑ i = 1 n x i 2 {\bf RMS}\boldsymbol{[x]}=\sqrt{\frac{1}{n}\sum_{i=1}^{n}\boldsymbol{x}_i^2} RMS[x]=n1∑i=1nxi2是均方根。
可见,RMSNorm
与LayerNorm
主要有如下差别:
RMSNorm
无需计算均值 E [ x ] {\bf E}[\boldsymbol{x}] E[x]。RMSNorm
将有偏方差 V a r [ x ] {\bf Var[\boldsymbol{x}]} Var[x]替换为了均方根 R M S [ x ] {\bf RMS[\boldsymbol{x}]} RMS[x]。RMSNorm
无需平移项 γ \boldsymbol{\gamma} γ。
与LayerNorm
一样,RMSNorm
也能以句子或单词(token)为单位进行归一化,如下给出了以token为单位的代码示例。
import torch
import torch.nn as nn
class MyRMSNorm(nn.Module):
def __init__(self, hidden_dim, eps=1e-8):
super().__init__()
# 防止分母计算为0
self._eps = eps
# 仿射变换参数,缩放norm后的数据分布
self._gamma = nn.Parameter(torch.ones(hidden_dim))
def forward(self, input):
# input(N,L,C)
ms = input.pow(2).mean(dim=-1, keepdim=True) # 计算均方,token-wise
input = input / torch.sqrt(ms + self._eps) # 执行标准化
return input * self._gamma # 仿射变换
if __name__ == '__main__':
batch_size = 4
length = 2
hidden_dim = 3
input = torch.rand(4, 2, 3)
myRMSN = MyRMSNorm(hidden_dim=hidden_dim)
MyO = myRMSN(input)
pytorchRMSN = nn.RMSNorm(normalized_shape=hidden_dim, elementwise_affine=False) # 不使用可学习的gamma和beta
pytorchO = pytorchRMSN(input)
print(MyO == pytorchO)
2.2 RoPE
旋转位置编码RoPE
使用绝对位置信息设计旋转规则,使旋转后的数据能够表达相对位置信息。
要了解RoPE
,首先我们来了解一下二维空间的旋转。如下图:
其中, X = [ ρ cos ϕ , ρ sin ϕ ] X=[\rho\cos\phi,\rho\sin\phi] X=[ρcosϕ,ρsinϕ]是一个二维向量,逆时针旋转 θ \theta θ度变成 X R ( θ ) XR(\theta) XR(θ)。此时 R ( θ ) = [ cos θ , sin θ − sin θ , cos θ ] R(\theta)=\left[\begin{matrix}\cos\theta,~\sin\theta\\-\sin\theta,~\cos\theta\end{matrix}\right] R(θ)=[cosθ, sinθ−sinθ, cosθ],证明如下:
X R ( θ ) = [ ρ cos ϕ , ρ sin ϕ ] [ cos θ , sin θ − sin θ , cos θ ] = ρ [ cos ϕ cos θ − sin ϕ sin θ , cos ϕ sin θ + sin ϕ cos θ ] = [ ρ cos ( ϕ + θ ) , ρ sin ( ϕ + θ ) ] XR(\theta)=[\rho\cos\phi,\rho\sin\phi]\left[\begin{matrix}\cos\theta,~\sin\theta\\-\sin\theta,~\cos\theta\end{matrix}\right]\\=\rho[\cos\phi\cos\theta-\sin\phi\sin\theta,\cos\phi\sin\theta+\sin\phi\cos\theta]=[\rho\cos(\phi+\theta),\rho\sin(\phi+\theta)] XR(θ)=[ρcosϕ,ρsinϕ][cosθ, sinθ−sinθ, cosθ]=ρ[cosϕcosθ−sinϕsinθ,cosϕsinθ+sinϕcosθ]=[ρcos(ϕ+θ),ρsin(ϕ+θ)]
可见, X X X与 X R ( θ ) XR(\theta) XR(θ)仅差一个 θ \theta θ,所以二维空间逆时针旋转 θ \theta θ度可通过 R ( θ ) R(\theta) R(θ)实现。
旋转只改变角度,不改变长度。
RoPE
将旋转应用在了注意力模块的查询
Q
Q
Q和
K
K
K上。它将第
i
i
i个查询
Q
i
Q_i
Qi旋转
i
θ
i\theta
iθ的角度,再将第
j
j
j个键
K
j
K_j
Kj旋转
j
θ
j\theta
jθ的角度,那么
Q
i
K
j
T
Q_iK_j^T
QiKjT就会变成一个与相对位置
i
−
j
i-j
i−j相关的值。推导过程如下:
i i i和 j j j是查询 Q i Q_i Qi和 K j K_j Kj的绝对位置, i − j i-j i−j是它们的相对位置。
然而, Q i Q_i Qi和 K j K_j Kj的维度通常都是大于2的,我们假设它是 D D D且 D D D是2的整数倍,于是我们可以将 Q i Q_i Qi和 K j K_j Kj分别划分为 d = D 2 d=\frac{D}{2} d=2D个子空间,每个子空间都是二维的。
下图给出了一个 D = 10 D=10 D=10的例子,我们将 Q i Q_i Qi和 K j K_j Kj分为5个子空间并分配1个包括5个角度的旋转序列 Θ = ( θ 1 , θ 2 , ⋯ , θ 5 ) \Theta=(\theta_1,\theta_2,\cdots,\theta_5) Θ=(θ1,θ2,⋯,θ5),每个子空间的旋转角度是在对应旋转序列的基础上乘以 i i i或 j j j。
将其扩展到 d d d个子空间,可以得到如下信息:
其中, X i X_i Xi代指 Q i Q_i Qi或 K j K_j Kj。此时,这种旋转仍然具有相对位置的表达能力,证明如下:
显然,上面的 R ( i Θ ) R(i\Theta) R(iΘ)过于稀疏,为了提升计算效率,通常 d d d个子空间的旋转使用下式表达:
为避免token数过多,
i
θ
k
i\theta_k
iθk和
j
θ
k
j\theta_k
jθk重叠导致相对位置得不到表达(同一个子空间
k
k
k,绝对位置
i
i
i和
j
j
j不同,
i
θ
k
−
j
θ
k
=
2
m
π
i\theta_k-j\theta_k=2m\pi
iθk−jθk=2mπ时重叠,
m
m
m是一个整数),RoPE
使用了一个递减的等比数列作为
θ
\theta
θ序列,如下:
θ k \theta_k θk是递减的,这表示token中前几个子空间的旋转角度较大,越往后旋转角度越小。
事实上,为了方便我们通常不是将相邻的两个值划分至同一子空间,而是将D分为前后两个部分,前后各取一个依次组成子空间,例如[q0,q1,q2,q3]被划分为[q0,q2], [q1,q3]而不是[q0,q1], [q2,q3]。以下为使用这种方式进行子空间划分的RoPE
代码:
from torch.nn import functional as F
import torch.nn as nn
import torch
import math
class Rotator:
"""根据hidden_dim,和position_ids 生成对应的旋转位置编码, 和论文中定义略有不同,一个个二维的子空间被
分割到了前后两部分,分别进行旋转,然后拼接起来
"""
def __init__(self, D, position_ids):
""" position_ids: [seq_len], D 和单个头的hidden_dim对应 """
base = 10000
d = D / 2
B = base ** (1/d)
theta_base = 1.0 / (B ** (torch.arange(0, d))) # 等比数列, $\Theta$
thetas = position_ids.outer(theta_base) # [seq_len, D/2]
# 这里的子空间划分与讲解不同,[q0,q1,q2,q3] -> [q0,q2],[q1,q3]是两个子空间而不是[q0,q1],[q2,q3]
full_thetas = torch.cat((thetas, thetas), dim=-1) # [seq_len, D]
self.cos = full_thetas.cos()
self.sin = full_thetas.sin()
def rotate(self, x):
"""
x: [bs, num_attention_heads, seq_len, D]
q: [bs, num_attention_heads, seq_len, D]
cos: [seq_len, D]
[x,y] @ [[cos, sin], [-sin, cos]] = [x*cos-y*sin, ycos+x*sin] =[x,y]*cos+[-y, x]*sin
"""
return x * self.cos + Rotator.reverse_half(x) * self.sin
@staticmethod
def reverse_half(q):
""" q: [bs, num_attention_heads, seq_len, D] trick2 """
u = q[..., :q.shape[-1] // 2] # 认为是各个二维子空间的第一维的向量集结
v = q[..., q.shape[-1] // 2:] # 认为是各个二维子空间的第二维的向量集结
return torch.cat((-v, u), dim=-1)
if __name__ == "__main__":
batch_size = 2
num_heads = 3
D = 6 # 单个头的token向量长度
hidden_dim = D * num_heads
seq_len = 4
position_ids = torch.arange(seq_len)
rotator = Rotator(D, position_ids)
x = torch.randn((batch_size, seq_len, hidden_dim))
# 对每个头分别进行旋转,[batch_size,seq_len,hidden_dim] -> [batch_size,seq_len,num_heads,D] -> [batch_size,num_heads,seq_len,D]
x = x.view(batch_size, seq_len, num_heads, D).transpose(1, 2)
x = rotator.rotate(x)
2.3 FlashAttention
FlashAttention
以分块的形式进行注意力计算,避免了SRAM和HBM之间频繁读写导致的时间浪费。
详情请参考我之前的博客FlashAttention in NeurIPS 2022。
2.4 SwiGLU
激活函数SwiGLU
是门控线性单元(Gated Linear Units, GLU
)的变体,下图红框中表达了GLU
的计算过程:
可见,GLU
会先使用两个带偏执的线性层映射输入
x
\boldsymbol{x}
x,分别记为
x
W
1
+
b
1
\boldsymbol{xW_1+b_1}
xW1+b1和
x
W
2
+
b
2
\boldsymbol{xW_2+b_2}
xW2+b2;其中一个线性映射后会跟一个非线性激活函数sigmoid
,记为
σ
(
x
W
1
+
b
1
)
\sigma(\boldsymbol{xW_1+b_1})
σ(xW1+b1);然后将左右两边的结果对应元素相乘即完成了GLU
,记为
σ
(
x
W
1
+
b
1
)
⊗
(
x
W
2
+
b
2
)
\sigma(\boldsymbol{xW_1+b_1})\otimes(\boldsymbol{xW_2+b_2})
σ(xW1+b1)⊗(xW2+b2)。
SwiGLU
对GLU
做了两点改进:
- 去掉了两个线性映射的偏执项,此时公式变成 σ ( x W 1 ) ⊗ ( x W 2 ) \sigma(\boldsymbol{xW_1})\otimes(\boldsymbol{xW_2}) σ(xW1)⊗(xW2)。
- 将
sigmoid
替换为了Swish
,此时公式变成 Swish β ( x W 1 ) ⊗ ( x W 2 ) \text{Swish}_{\beta}(\boldsymbol{xW_1})\otimes(\boldsymbol{xW_2}) Swishβ(xW1)⊗(xW2)。
Swish
的公式为
Swish
β
(
a
)
=
a
σ
(
β
a
)
=
a
1
+
e
−
β
a
\text{Swish}_{\beta}(a)=a\sigma(\beta a)=\frac{a}{1+e^{-\beta a}}
Swishβ(a)=aσ(βa)=1+e−βaa,在不同的
β
\beta
β下该非线性激活函数的曲线如下:
可见,当
β
\beta
β较大时,该曲线与ReLU
十分接近;当
β
=
1
\beta=1
β=1时,小于0但接近0的曲线变得更光滑且非单调。
SwiGLU
则选用了
β
=
1
\beta=1
β=1的Swish
,于是我们得到SwiGLU
的公式如下:
Swish
(
x
W
1
)
⊗
(
x
W
2
)
=
x
W
1
1
+
e
−
x
W
1
⊗
x
W
2
\text{Swish}(\boldsymbol{xW_1})\otimes(\boldsymbol{xW_2})=\frac{\boldsymbol{xW_1}}{1+e^{-\boldsymbol{xW_1}}}\otimes\boldsymbol{xW_2}
Swish(xW1)⊗(xW2)=1+e−xW1xW1⊗xW2
致谢:
本博客仅做记录使用,无任何商业用途,参考内容如下:
解密旋转位置编码:数学基础、代码实现与绝对编码一体化探索
一文为你深度解析 LLaMA2 模型架构
Llama改进之——SwiGLU激活函数
标签:
相关文章
最新发布
- 光流法结合深度学习神经网络的原理及应用(完整代码都有Python opencv)
- Python 图像处理进阶:特征提取与图像分类
- 大数据可视化分析-基于python的电影数据分析及可视化系统_9532dr50
- 【Python】入门(运算、输出、数据类型)
- 【Python】第一弹---解锁编程新世界:深入理解计算机基础与Python入门指南
- 华为OD机试E卷 --第k个排列 --24年OD统一考试(Java & JS & Python & C & C++)
- Python已安装包在import时报错未找到的解决方法
- 【Python】自动化神器PyAutoGUI —告别手动操作,一键模拟鼠标键盘,玩转微信及各种软件自动化
- Pycharm连接SQL Sever(详细教程)
- Python编程练习题及解析(49题)
点击排行
- 版本匹配指南:Numpy版本和Python版本的对应关系
- 版本匹配指南:PyTorch版本、torchvision 版本和Python版本的对应关系
- Python 可视化 web 神器:streamlit、Gradio、dash、nicegui;低代码 Python Web 框架:PyWebIO
- 相关性分析——Pearson相关系数+热力图(附data和Python完整代码)
- Anaconda版本和Python版本对应关系(持续更新...)
- Python与PyTorch的版本对应
- Windows上安装 Python 环境并配置环境变量 (超详细教程)
- Python pyinstaller打包exe最完整教程