本篇是《Rust与AI》系列的第二篇,上一篇我们主要介绍了本系列的概览和方向,定下了一个基调。本篇我们将介绍LLM的基本架构,我们会以迄今为止使用最广泛的开源模型LLaMA为例展开介绍。
LLM背景
Rust 本身是不挑 AI 模型的,但是 LLM 是当下最热的方向,我们就从它开始吧,先了解一些非常基础的背景知识。
Token
LLM 中非常重要的一个概念是 Token,我们输入给 LLM 和它输出的都是 Token。Token 在这里可以看做语言的基本单位,中文一般是词或字(其实字也是词)。比如:”我们喜欢 Rust 语言“,Token 化后会变成类似 ”我们/喜欢/Rust/语言“ 这样的四个词,可以理解为四个 Token。
给定一段任意的自然语言文本,我们可以用一个分词器(Tokenizer)将其 Token 化成一个个连续的 Token。这些 Token 接下来就可以映射成一个个数字,其实是在词表中的索引,索引进而可以找到一个稠密向量,用来表示该位置 Token 的语义输入。
我们以刚刚的”我们喜欢 Rust 语言“为例,假定已有词表如下。
1 2 3 4 5 6 7 …… 1000 Rust …… 2000 我们 2001 喜欢 2002 语言 ……
注意,前面的数字是行号,并不是词表内容。刚刚那句话其实就是 [2000, 2001, 1000, 2002],这就是 LLM 的输入。LLM 拿到这些 ID 后,会在一个非常大的表里查找对应的稠密向量。这个非常大的表就是词表,大小是:词表大小N × 模型维度,如下所示。
1 2 3 4 5 6 7 …… 1000 0.9146 , 0.066 , 0.4469 , 0.3867 , 0.3221 , 0.6566 , 0.2895 , ...…… 2000 0.5702 , 0.9579 , 0.0992 , 0.9667 , 0.5013 , 0.4752 , 0.1397 , ...2001 0.2896 , 0.7756 , 0.6392 , 0.4034 , 0.3267 , 0.9643 , 0.4311 , ...2002 0.4344 , 0.6662 , 0.3205 , 0.3929 , 0.6418 , 0.6707 , 0.2414 , ...……
也就是说,输入”我们喜欢Rust语言“这句话,我们实际传递给模型的其实是一个 4×Dim 的矩阵,这里的 4 一般也叫 Sequence Length。
我们可以暂时把模型看作一个函数 f(x),输入一个 Sequence Length × Dim 的矩阵,经过模型 f(x) 各种运算后会输出 Sequence Length × Vocabulary Size 大小的一个概率分布。有了概率分布就可以采样一个 Token ID(基于上下文最后一个 Token ID 的分布),这个 ID 也就是给定当前上下文(”我们喜欢Rust语言“)时生成的下一个 Token。接下来就是把这个 ID 拼在刚刚的 4 个 ID 后面(输入变成 5 个 ID),继续重复这个过程。
生成
如上所言,生成过程就是从刚刚的概率分布中 “选择” 出一个 Token ID 作为下一个 Token ID。选择的方法可以很简单,比如直接选择概率最大的,此时就是 Greedy Search,或 Greedy Decoding。
不过我们平时用到大模型时一般都用的是采样的方法,也就是基于概率分布进行采样。抛硬币也是一种采样,按概率分布(0.5,0.5)进行采样,但假设正面比较重,概率分布就可能变成了(0.8,0.2)了。基于 Vocabulary Size 个概率值进行采样也是类似的,只不过括号里的值就是词表大小那么多个。
top_p/top_k 采样是概率值太多了,大部分都是概率很小的 Token,为了避免可能采样到那些概率很低的 Token(此时生成的结果可能很不连贯),干脆就只从前面的 Token 里挑。
top_k 就是把 Token 按概率从大到小排序,然后从前 k 个里面选择(采用)下一个 Token;top_p 也是把 Token 按概率从大到小排序,不过是从累积概率大于 p 的 Token 里选。就是这么简单。
这里有个小细节需要说明,因为选择了 top_p/k,所以这些备选的 Token 需要重新计算概率,让它们的概率和为 1(100%)。
开源代表——LLaMA
接下来,我们把重心放在函数 f(x) 上,以最流行的开源 LLM——LLaMA 为例,简单介绍一下模型的结构和参数。
结构
LLaMA 的结构相对而言比较简单,如果我们忽略其中的很多细节,只考虑推理过程,看起来如下图所示。
图中 [] 中的是该位置的张量 shape,B 表示 Batch Size,一般时候都是批量丢给 GPU 计算的,L 就是 Sequence Length,D 就是上面提到的 Dim。这是一个简化了的架构图,但是足以清晰地表达模型了。
两个 Hidden states(以下简称 HS),外面(之上和之下)的部分我们前面已经提到过了(注意上面部分,[B,L,D] 会先变成 [B,L,VS],然后取最后一个 Token 就得到了 [B,1,VS]),上面的 HS 会传回到 Block 里面,重复 N 次,N 就是模型的层数。接下来我们就把重点放在中间这个 Block 里。
每个 Block 包括两个主要模块,一个 MHA(Multi-Head Attention)模块,一个 FFN(Feedforward Network)模块,每次传给模块之前都需要 Normalization,这个叫 Pre-Normalization,一般用来稳定训练。另外,每个模块结束后会叠加模块之前的输入,这个叫残差连接,一般能加速收敛。
接下来是 MHA 和 FFN,先看 FFN 模块,它的大概流程如下(@ 表示矩阵/张量乘法)。
1 2 3 4 z1 = ns @ up_weights z2 = ns @ gate_weights z3 = z1 * silu (z2) z4 = z3 @ down_weights
整体来看是先将网络扩大再收缩,扩大时增加了一个激活处理。silu 函数大概长这样:
等价于只激活了一部分参数,这个非线性激活非常重要,可以让模型学习到更丰富的知识和表达。
再就是 MHA 模块了,大概流程如下(为了更直观,去掉了 Batch Size 和 Softmax)。
1 2 3 4 5 6 7 8 9 10 11 q = ns @ q_weights # (L, D) @ (D, D) = (L, D) k = ns @ k_weights # (L, D) @ (D, D) = (L, D) v = ns @ v_weights # (L, D) @ (D, D) = (L, D) q = q.reshape (L, NH, HD) k = k.reshape (L, NH, HD) v = v.reshpae (L, NH, HD) attn = q.trans (NH, L, HD) @ k.trans (NH, HD, L) # (NH, L, HD) @ (NH, HD, L) = (NH, L, L) v = attn @ v.trans (NH, L, HD) # (NH, L, L) @ (NH, L, HD) = (NH, L, HD) v = v.reshpe (L, NH*HD) # (L, D)
其中,NH 表示 Attention 的 Head 数,HD 表示 Head 的维度。因为有 NH 个 Head,所以叫 Multi-Head,但其实我们看上面的过程,在实际计算的时候它们是合并一起算的。我们不妨只看一个 Head,如下所示。
1 2 3 4 5 6 q = ns @ hq_weights # (L, D) @ (D, HD) = (L, HD) k = ns @ hk_weights # (L, D) @ (D, HD) = (L, HD) v = ns @ hv_weights # (L, D) @ (D, HD) = (L, HD) attn = q @ k.T # (L, HD) @ (HD, L) = (L, L) v = attn @ v # (L, L) @ (L, HD) = (L, HD)
上面的多个 Head 的 v 就是下面的每个 Head 的 v 拼接起来的。
Multi-Head 是多个注意力头去执行 Attention,其思想是让每个 Head 去捕获不同角度/层面的 Attention,这些角度/层面是什么?不是特别清楚(但一定是某种特征),但我们可以通过 Attention 的权重看出外在 Token 级别的注意力,知道每个注意力 Head,哪些 Token 之间有比较强的连接。
参数
关于 f(x) 我们已经介绍完了,可以发现这个函数其实还是有点复杂的。接下来,我们看看参数情况。
对一个一元一次方程(比如 f(x) = ax + b)来说,参数就两个:a 和 b,但对于 LLM 来说,参数就非常多了,目前常用的是 7B、13B、20B 的级别,也就是 70亿、130亿和 200亿的参数规模。
在神经网络中,可以把矩阵乘法看作是多元一次方程组的计算过程,输入的 Hidden State 维度是 D,就表示未知变量的维度是 D,也就是 D 元一次方程组。
以前面的但 Head Attention 的 q 为例,q_weights 是一个 DxHD 的参数矩阵,我们把 D 和 HD 设置的小一点(假设为4和2),看一个具体的例子。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 torch.manual_seed(42 ) w = nn.Linear(4 , 2 , bias=False ) hs = torch.rand((3 , 4 )) q = hs @ w.weight.T """ hq_weights = w.weight.T = tensor([[ 0.3823, -0.1096], [ 0.4150, 0.1009], [-0.1171, -0.2434], [ 0.4593, 0.2936]]) hs = tensor([[0.9408, 0.1332, 0.9346, 0.5936], [0.8694, 0.5677, 0.7411, 0.4294], [0.8854, 0.5739, 0.2666, 0.6274]]) q = tensor([[ 0.5781, -0.1428], [ 0.6784, -0.0923], [ 0.8336, 0.0803]]) """
这个例子除了维度小一点,其他逻辑是一样的。它对应这么一个多元方程组。
1 2 3 4 5 6 w11*x11 + w21*x12 + w31*x13 + w41*x14 = y11 w12*x11 + w22*x12 + w32*x13 + w42*x14 = y12 w11*x21 + w21*x22 + w31*x23 + w41*x24 = y21 w12*x21 + w22*x22 + w32*x23 + w42*x24 = y22 w11*x31 + w21*x32 + w31*x33 + w41*x34 = y31 w12*x31 + w22*x32 + w32*x33 + w42*x34 = y32
其中 x 就是 hs,w 就是 hq_weights,写成数学表达式大概就是下面的这样。
[ x 11 x 12 x 13 x 14 x 21 x 22 x 23 x 24 x 31 x 32 x 33 x 34 ] × [ w 11 w 12 w 21 w 22 w 31 w 32 w 41 w 42 ] = [ y 11 y 12 y 21 y 22 y 31 y 32 ] \left[\begin{array}{llll}
x_{11} & x_{12} & x_{13} & x_{14} \\
x_{21} & x_{22} & x_{23} & x_{24} \\
x_{31} & x_{32} & x_{33} & x_{34}
\end{array}\right] \times\left[\begin{array}{ll}
w_{11} & w_{12} \\
w_{21} & w_{22} \\
w_{31} & w_{32} \\
w_{41} & w_{42}
\end{array}\right]=\left[\begin{array}{ll}
y_{11} & y_{12} \\
y_{21} & y_{22} \\
y_{31} & y_{32}
\end{array}\right]
x 11 x 21 x 31 x 12 x 22 x 32 x 13 x 23 x 33 x 14 x 24 x 34 × w 11 w 21 w 31 w 41 w 12 w 22 w 32 w 42 = y 11 y 21 y 31 y 12 y 22 y 32
对于这样的一个 Linear 来说,参数量就是 2×4=8 个。现在让我们看看 LLaMA,就按词表大小=32000,维度=4096来计算。
首先是 Embedding 和 LM Head(就是映射到 32000 个 Token 的那个参数),它们是一样的,都是 32000×4096,有时候这两个地方的参数也可以设计成共享的,LM Head 前面也有一个 Normalization,4096 个参数。
然后是 Block,MHA 的 qkvo 是 4 个 4096×4096 的矩阵,FFN 的 gate、up、down 是 11008×4096 的矩阵,再加上两个 Normalization, 4096×2 个参数。每个 Block 参数量为 4096×(4096×4+11008×3+2)。
这样得到所有的参数总和为:32000*4096*2 + 4096 +(4096*(4096*4+11008*3+2))*32 = 6738415616,67亿多的样子,也就是常说的 7B。
Rust与LLaMA
终于来到了 Rust,之所以前面铺垫那么多,是因为如果我们完全不熟悉模型的基本结构和执行过程,这个代码看起来就会知其然而不知其所以然。当然,即便了解了基本结构,里面也有一些细节需要单独介绍,不过我们会放在后续的内容。
只看上面的内容,我们可以发现 LLM 模型的结构其实不算特别复杂,而且其中涉及到大量的矩阵运算(至少占到 80% 以上)。关于矩阵运算以及相关的优化,我们也会在后面慢慢涉及。
LLaMA 的 Rust 实现有很多个版本,本次选择的是来自 karpathy/llama2.c: Inference Llama 2 in one file of pure C 的 Rust 实现的版本中的:danielgrittner/llama2-rs: LLaMA2 + Rust ,而且我们暂时只会涉及模型基础结构部分,其中涉及一些特别的细节会简单解释,不深入展开。
配置
首先是配置,如下所示。
1 2 3 4 5 6 7 8 9 10 11 struct Config { dim: usize , hidden_dim: usize , n_layers: usize , n_heads: usize , head_size: usize , n_kv_heads: usize , shared_weights: bool , vocab_size: usize , seq_len: usize , }
dim 就是上面一直说的 Dim,hidden_dim 仅在 FFN 层,因为 FFN 层需要先扩大再缩小。n_heads 和 n_kv_heads 是 Query 的 Head 数和 KV 的 Head 数,简单起见可以认为它们是相等的。如果我们加载 karpathy 的 15M 的模型,结果如下。
1 Config { dim: 288 , hidden_dim: 768 , n_layers: 6 , n_heads: 6 , head_size: 48 , n_kv_heads: 6 , shared_weights: true , vocab_size: 32000 , seq_len: 256 }
shared_weights 就是上面提到的 Embedding 和 LM Head 是否共享参数。
Tokenizer 的功能我们暂且略过,目前只需知道它负责将文本转为 ID 列表(encode)以及把 ID 列表转为文本(decode)。
参数
接下来看模型参数,如下所示。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 struct TransformerWeights { token_embedding_table: Vec <f32 >, rms_att_weight: Vec <f32 >, rms_ffn_weight: Vec <f32 >, wq: Vec <f32 >, wk: Vec <f32 >, wv: Vec <f32 >, wo: Vec <f32 >, w1: Vec <f32 >, w2: Vec <f32 >, w3: Vec <f32 >, rms_final_weights: Vec <f32 >, freq_cis_real: Vec <f32 >, freq_cis_imag: Vec <f32 >, wcls: Vec <f32 >, }
上面的参数应该都比较直观,我们不太熟悉的应该是 freq_ 开头的两个参数,它们是和位置编码有关的参数,也就是说,我们每次生成一个 Token 时,都需要传入当前位置的位置信息。
位置编码在 Transformer 中是比较重要的,因为 Self Attention 本质上是无序的,而语言的先后顺序在有些时候是很重要的,比如 “我喜欢你” 和 “你喜欢我”,“你” 和 “我” 的顺序不同,语义也不同。但时候很多语义又不太响影我们解理语义,不妨再仔细读一下刚刚这半句话。你看文本顺序虽然变了,但你读起来毫无障碍。这也是为什么会有研究说不要位置编码语言模型也可以,但效果应该是不如加了位置编码的。
模型创建好后,接下来就是加载参数和执行推理。加载参数要看模型文件的格式设计,本项目来自 karpathy 的 C 代码,模型文件被安排成了 bin 文件,按规定的格式读取即可,核心代码如下。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 fn byte_chunk_to_vec <T>(byte_chunk: &[u8 ], number_elements: usize ) -> Vec <T>where T: Clone , { unsafe { let data = byte_chunk.as_ptr () as *const T; let slice_data : &[T] = std::slice::from_raw_parts (data, number_elements); slice_data.to_vec () } }
byte_chunk 表示原始的字节切片,number_elements 表示结果向量中元素的个数,T 有 Clone 的 Trait 约束,表示 T 必须实现该 Trait,也就是 T 必须能够使用 Clone 方法。其他解释已经在代码中给出了注释,不再赘述。
加载模型就是读取原始的 bin 文件并指定对应的参数大小,我们以 Token Embedding 参数为例,如下所示。
1 2 3 let token_embedding_table_size = config.vocab_size * config.dim;let token_embedding_table : Vec <f32 > = byte_chunk_to_vec (&mmap[offset..], token_embedding_table_size);
类似这样就可以依次把模型参数读取进来了。
模型
接下来就是最复杂的模型部分了。这里最大的不同是 Token by Token 的处理,而不是给定一个上下文生成下一个 Token。我们看一下基本的 Struct,如下所示。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 struct LLaMA2 <'a > { x: Vec <f32 >, xb: Vec <f32 >, xb2: Vec <f32 >, hb: Vec <f32 >, hb2: Vec <f32 >, q: Vec <f32 >, k: Vec <f32 >, v: Vec <f32 >, att: Vec <f32 >, logits: Vec <f32 >, key_cache: Vec <f32 >, value_cache: Vec <f32 >, transformer: &'a TransformerWeights, config: &'a Config, }
最后两个参数我们上面已经介绍过了,其他参数都是模型推理过程中需要用到的中间结果和最初的输入,以及最终的结果,它们均被初始化成 0。至于为什么有些值是多个(比如 xb、hb等),是因为 Block 里面涉及到残差连接,需要额外保存一个输入。
现在我们从 forward 开始,方法如下。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 fn forward (&mut self , token: usize , pos: usize ) { self .x.copy_from_slice ( &self .transformer.token_embedding_table [(token * self .config.dim)..((token + 1 ) * self .config.dim)], ); for l in 0 ..self .config.n_layers { self .layer (l, pos); } rmsnorm ( self .x.as_mut_slice (), self .transformer.rms_final_weights.as_slice (), ); matmul ( self .logits.as_mut_slice (), self .transformer.wcls.as_slice (), self .x.as_slice (), ); }
这块代码是推理的全流程,一共四个步骤:取 Embedding、逐层计算、Normalization、映射到词表大小的 logits(后续会基于此转为概率分布)。
Embedding 是直接从参数里 copy 出对应索引的参数,无序赘述。
Normalization 用的是 RMS(Root Mean Square)Normalization,基本公式如下。
x i ′ = x i ∑ i = 1 N x i ∗ w i x'_i = \frac{x_i} {\sqrt{\sum_{i=1}^N x_i}} * w_i
x i ′ = ∑ i = 1 N x i x i ∗ w i
它是标准 Normalization 的简单形式,但效果尚可,其代码如下。
1 2 3 4 5 6 7 8 9 10 fn rmsnorm (x: &mut [f32 ], weight: &[f32 ]) { let size = x.len (); let squared_sum = x.iter ().fold (0.0 , |acc, x| acc + x * x); let rms = 1 . / (squared_sum / size as f32 ).sqrt (); x.iter_mut () .zip (weight.iter ()) .for_each(|(x, w)| *x *= rms * w); }
代码一目了然,先一个 reduce,然后开方取倒数,接着就是遍历计算更新每个参数值。
最后的矩阵乘法比较标准,输入的 Hidden State(x)因为只有一个 Token,所以可以看成向量,长度为 Dim,与 LM Head 矩阵乘法后就得到一个词表大小的输出值,后续可以归一化成概率值(即概率分布)。矩阵乘法代码如下(准确来说是向量和矩阵乘法)。
1 2 3 4 5 6 7 8 9 10 fn matmul (target: &mut [f32 ], w: &[f32 ], x: &[f32 ]) { let in_dim = x.len (); target.par_iter_mut ().enumerate ().for_each(|(i, t)| { let row_offset = i * in_dim; *t = x .iter () .zip (w[row_offset..].iter ()) .fold (0.0 , |result, (x, w)| result + x * w); }); }
这里需要注意的是 offset,因为参数是一个 Vec 存储的一维数组,要按二维取值,需要每次跳过对应数量的参数。剩下的就很清晰了,最终的结果会存储到 target,也就是 self.logits,进而会转为概率分布。
我们把重心放在中间的逐层计算上,LLM 的核心也在这里。先看 layer 的代码,如下所示。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 fn layer (&mut self , layer: usize , pos: usize ) { rmsnorm_with_dest ( self .xb.as_mut_slice (), self .x.as_slice (), &self .transformer.rms_att_weight [layer * self .config.dim..(layer + 1 ) * self .config.dim], ); self .attn (layer, pos); add_vectors (self .x.as_mut_slice (), self .xb2.as_slice ()); rmsnorm_with_dest ( self .xb.as_mut_slice (), self .x.as_slice (), &self .transformer.rms_ffn_weight [layer * self .config.dim..(layer + 1 ) * self .config.dim], ); self .ffn (layer); add_vectors (self .x.as_mut_slice (), self .xb.as_slice ()); }
非常标准的流程(可回看前面的架构图),先归一化,然后 MHA,残差连接,再归一化,FFN,残差连接。归一化的代码刚刚已经看过了,这里唯一的不同是将输出放到第一个参数(即 self.xb)里。add_vectors 就是对应元素值求和,结果放到第一个参数,这个比较简单,我们就不放代码了。重点就是 ffn 和 attn,它们内部涉及大量矩阵乘法,我们开始。
先看 ffn,它比较简单,主要是几个矩阵乘法加非线性激活,代码如下。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 fn ffn (&mut self , layer: usize ) { let weight_from = layer * self .config.hidden_dim * self .config.dim; let weight_to = (layer + 1 ) * self .config.hidden_dim * self .config.dim; matmul ( self .hb.as_mut_slice (), &self .transformer.w1[weight_from..weight_to], self .xb.as_slice (), ); matmul ( self .hb2.as_mut_slice (), &self .transformer.w3[weight_from..weight_to], self .xb.as_slice (), ); for i in 0 ..self .config.hidden_dim { self .hb[i] = silu (self .hb[i]) * self .hb2[i]; } matmul ( self .xb.as_mut_slice (), &self .transformer.w2[weight_from..weight_to], self .hb.as_slice (), ); }
这个过程和我们《开源代表——LLaMA 结构》一节中是一一对应的,涉及到的主要是刚刚介绍过的 matmul 和一个 silu,后者我们之前看过它的图像,代码如下。
1 2 3 fn silu (x: f32 ) -> f32 { x / (1.0 + (-x).exp ()) }
表达式如下所示。
SiLU ( x ) = x 1 + e − x \text{SiLU}(x) = \frac{x}{1 + e^{-x}}
SiLU ( x ) = 1 + e − x x
好了,最后我们把重心放在 attn 这个方法上,由于逐 Token 生成时,Query 是当前 Token,这没问题,但 Key 和 Value(Attention 里面的 K和V)是需要历史 Token 的(不然怎么算注意力)。常见的做法就是把历史过程中的 K 和 V 缓存起来,每次生成时顺便更新缓存,这样下次生成时拿到的就是之前的所有 K 和 V。
先看一下基本的代码流程,如下所示。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 fn attn (&mut self , layer: usize , pos: usize ) { self .attn_qkv_matmuls (layer); self .attn_rope (layer, pos); self .cache_kv (layer, pos); self .multihead_attn (layer, pos); let weight_from = layer * self .config.dim * self .config.dim; let weight_to = (layer + 1 ) * self .config.dim * self .config.dim; matmul ( self .xb2.as_mut_slice (), &self .transformer.wo[weight_from..weight_to], self .xb.as_slice (), ); }
最后的 wo 比较简单,不再赘述。一开始的 qkv 也比较简单,都是矩阵乘法,如下所示。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 fn attn_qkv_matmuls (&mut self , layer: usize ) { let weight_from = layer * self .config.dim * self .config.dim; let weight_to = (layer + 1 ) * self .config.dim * self .config.dim; matmul ( self .q.as_mut_slice (), &self .transformer.wq[weight_from..weight_to], self .xb.as_slice (), ); matmul ( self .k.as_mut_slice (), &self .transformer.wk[weight_from..weight_to], self .xb.as_slice (), ); matmul ( self .v.as_mut_slice (), &self .transformer.wv[weight_from..weight_to], self .xb.as_slice (), ); }
还剩下三个方法:attn_rope、cache_kv 和 multihead_attn,我们分别看一下。
第一个用来加入位置信息,参数是一开始算好的,这里直接取出对应位置的值进行计算。代码如下所示。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 fn attn_rope (&mut self , layer: usize , pos: usize ) { let freq_cis_real_offset = pos * self .config.head_size / 2 ; let freq_cis_imag_offset = pos * self .config.head_size / 2 ; for i in (0 ..self .config.dim).step_by (2 ) { let q0 = self .q[i]; let q1 = self .q[i + 1 ]; let k0 = self .k[i]; let k1 = self .k[i + 1 ]; let cos = self .transformer.freq_cis_real [freq_cis_real_offset + (i % self .config.head_size) / 2 ]; let sin = self .transformer.freq_cis_imag [freq_cis_imag_offset + (i % self .config.head_size) / 2 ]; self .q[i] = q0 * cos - q1 * sin; self .q[i + 1 ] = q1 * cos + q0 * sin; self .k[i] = k0 * cos - k1 * sin; self .k[i + 1 ] = k1 * cos + k0 * sin; } }
这部分代码就是把位置信息注入到 Q 和 K 中,其理论分析比较复杂,此处不展开。
cache_kv 比较简单,直接把当前的 K 和 V 存起来即可,如下所示。
1 2 3 4 5 6 7 8 9 fn cache_kv (&mut self , layer: usize , pos: usize ) { let layer_offset = layer * self .config.seq_len * self .config.dim; let cache_from = layer_offset + pos * self .config.dim; let cache_to = layer_offset + (pos + 1 ) * self .config.dim; self .key_cache[cache_from..cache_to].copy_from_slice (&self .k.as_slice ()); self .value_cache[cache_from..cache_to].copy_from_slice (&self .v.as_slice ()); }
因为我们不确定用户生成的 Token 长度,所以就把最大长度(seq_len)的所有位置都占上,因为是按层存的,每一层都有计算,所以需要层的 ID。每一层、每个位置都缓存 dim 个中间结果。
最后就是最重要的 multihead_attn 了,这里面的主要逻辑是计算 attention 分数,然后得到 attention 之后的结果,代码如下。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 fn multihead_attn (&mut self , layer: usize , pos: usize ) { let layer_offset_for_cache = layer * self .config.seq_len * self .config.dim; let sqrt_d = (self .config.head_size as f32 ).sqrt (); self .att.par_chunks_exact_mut (self .config.seq_len) .zip (self .xb.par_chunks_exact_mut (self .config.head_size)) .enumerate () .for_each(|(h, (attn_scores, xb))| { assert_eq! (attn_scores.len (), self .config.seq_len); assert_eq! (xb.len (), self .config.head_size); let q_from = h * self .config.head_size; let q_to = (h + 1 ) * self .config.head_size; let q = &self .q[q_from..q_to]; for t in 0 ..=pos { let timestep_and_layer_offset = layer_offset_for_cache + t * self .config.dim; let key_vector_from = timestep_and_layer_offset + h * self .config.head_size; let key_vector_to = timestep_and_layer_offset + (h + 1 ) * self .config.head_size; let key_vector = &self .key_cache[key_vector_from..key_vector_to]; attn_scores[t] = inner_product (q, key_vector) / sqrt_d; } softmax (&mut attn_scores[..(pos + 1 )]); xb.fill (0.0 ); for t in 0 ..=pos { let timestep_and_layer_offset = layer_offset_for_cache + t * self .config.dim; let value_vector_from = timestep_and_layer_offset + h * self .config.head_size; let value_vector_to = timestep_and_layer_offset + (h + 1 ) * self .config.head_size; let value_vector = &self .value_cache[value_vector_from..value_vector_to]; let attention_weight = attn_scores[t]; for i in 0 ..self .config.head_size { xb[i] += attention_weight * value_vector[i]; } } }); }
上面的过程是分 Head 计算的,需要我们深刻理解前面《开源代表——LLaMA 结构》一小节的内容,具体解释可以参考代码里的注释。值得一提的是,分 Head 计算是并行的。
另外,有个新方法 inner_product 是点积,也就是对应元素相乘后求和,代码如下。
1 2 3 fn inner_product (x: &[f32 ], y: &[f32 ]) -> f32 { zip (x, y).fold (0.0 , |acc, (a, b)| acc + a * b) }
比较简单,不再赘述。
生成
最后就是生成(或 Decoding)过程。代码略有不同,我们先看下。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 fn generate (&mut self , prompt_tokens: &Vec <usize >, n_tokens: usize , temperature: f32 ) -> Vec <usize > { let mut tokens = vec! []; tokens.reserve (n_tokens); let mut token = BOS_TOKEN; tokens.push (token); for (pos, prompt_token) in prompt_tokens.iter ().enumerate () { self .forward(token, pos); token = *prompt_token; tokens.push (token); } for pos in prompt_tokens.len ()..(n_tokens - 1 ) { self .forward(token, pos); if temperature == 0.0 { token = argmax (self .logits.as_slice ()); } else { self .logits.iter_mut ().for_each(|p| *p = *p / temperature); softmax (&mut self .logits.as_mut_slice ()); token = sample (self .logits.as_slice ()); } tokens.push (token); } tokens }
这里有两个值得注意的地方。
第一个是推理 Prompt(即第一次输入时的 Context),此时给定的 Context 是多个 Token 组成的,执行该过程目的是填充 KV Cache。
第二个是采样过程,temperature=0.0 时,就是 Greedy Search,每次返回概率最大位置的 Token;否则,会先应用 temperature,然后按照概率分布进行采样。temperature 参数会平滑概率分布,值越大,平滑力度越大,更有可能生成多样的结果。softmax 用来把一系列值归一化成概率分布(所有值加起来和为 1.0)。我们重点看看这个 sample 方法,它的主要思想是根据概率分布进行采样,也就是高概率的位置更容易被采样到,低概率的位置更不容易被采样到。代码如下。
1 2 3 4 5 6 7 8 9 10 11 12 fn sample (probs: &[f32 ]) -> usize { let mut rng = rand::thread_rng (); let mut cdf = 0.0 ; let r = rng.gen_range (0.0 ..1.0 ); for (i, p) in probs.iter ().enumerate () { cdf += p; if cdf > r { return i; } } probs.len () - 1 }
随机生成 0-1 之间的一个值(均匀分布),计算累积概率,当累积概率大于刚刚生成的值时,返回此时的位置。这样就可以保证是按照概率分布进行采样的。我们举个具体的例子,如下所示。
1 2 3 4 probs = [0.1 , 0.2 , 0.1 , 0.5 , 0.1 ] accu_probs = [0.1 , 0.3 , 0.4 , 0.9 , 1.0 ]
假设随机值为 r,因为它是均匀分布的,所以落在不同区间的概率与该区间的长度成正比。我们看上面的累积概率,可以得出如下结果。
r落在区间
返回 Index
[0, 0.1)
0
[0.1, 0.3)
1
[0.3, 0.4)
2
[0.4, 0.9)
3
[0.9, 1.0)
4
也就是说返回 Index=3 的概率为 0.5,其他同理。
拿到 Token 向量后只要用 Tokenizer 解码即可得到生成的文本。
小结
本文我们首先简单介绍了 LLM 相关的背景,着重讨论了关于 Token 和生成过程,这是应用 LLM 时非常重要的两个知识点。然后我们介绍了开源 LLM 的代表——LLaMA 的模型结构和参数,给大家一个整体的感知和认识。最后就是 Rust 的实现,主要包括配置、参数、模型和生成四个方面,其中最重要的就是模型部分,模型部分最重要、也最难理解的是 Multi-Head Attention 的计算。主要是因为具体的计算过程都是把矩阵运算给展开了,这需要对模型有一定程度的理解。
这种展开的写法其实是比较底层的实现,如果能在上面抽象一层,直接操纵矩阵或张量,那计算起来应该会简单很多。事实上,大部分框架都是这么做的,比如 Python 的 NumPy 、PyTorch等,当然 Rust 也有类似的框架,比如 NumPy 对应的 ndarray,以及 Rust 版本的深度学习框架。使用这些框架时,我们使用的是矩阵/张量(或者叫多维数组)这个对象,所有的操作也都在这个粒度进行,这无疑极大地提高了编程效率。同时,还可以利用这些框架底层的性能优化。
不过,有时候当我们需要框架暂未支持的更细致的优化、或在一个框架不支持的设备上运行时,这种 Pure X(此处为 Rust)的方式就比较方便灵活了。
总的来说,算法是多样的,实现更是多样的,优化更更是无止境的,吾辈唯有不断前行,持续向上。