LLM之Llama模型

概览

出于LLM基建需求, 最近以hugging face的Llama实现作为baseline, 通过调用基础算子对Llama进行了C++组网。在组网前, 阅读了一下Llama的论文和hugging face的组网逻辑(modelling_llama.py), 总结了Llama网络的几个主要特点:

  • 和GLM类似, 也是使用了RoPE, 且实现和原始RoPE是一致的;
  • 没有采用通常的Layer Normalization, 而是使用RMSNorm;
  • GLU单元实现不一样, 像GPT、GLM、BLOOM模型的GLU都是两个参数矩阵做FC, Llama的GLU单元有3个参数矩阵, 采用Swish激活函数;
  • 采用BPE(byte-pair encoding) Tokenizer方案
  • Llama的模型参数重没有bias;
    后面会具体介绍一下RMSNorm、GLU和BPE Tokenizer这三个feature。

RMSNorm

paper: 《Root Mean Square Layer Normalization》
一般的layer normalization的计算可以表示成。对于基于Transformer结构的LLM, normalization的输入一般都是shape为[seq_len, batch_size, hidden_dim](hidden_dim = size_per_head * head_num), 其中均值和标准差是在hidden_dim所在的维度计算出来的。
RMSNorm的计算略有不同, 可以表示成,其中。和layer normalization一样, 也是在hidden_dim这个维度计算得到的。

由于在搭建GLM网络时,layer normalization使用float16数据类型进行计算, 所以一开始RMSNorm也是用float16。后来发现虽然第一层的精度可以对齐, 但是经过多层block后, 精度偏差越来越大。仔细检查了一下, 发现由于的计算方式不同, 可能会很大, 导致为0, 使用float计算后精度可以完全对齐。

SwiGLU

  • SiLU激活函数:
  • Swish激活函数:

SiLU可以看成Swish的一个特例。
可以看看hugging face上对Llama的SwiGLU的实现, self.act_fn就是Swish激活函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.act_fn = ACT2FN[hidden_act]

def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

拓展阅读:

BPE

BPE最初是作为一种文本压缩算法,后来被很多的Transformer模型用作tokenization, 比如GPT系列、RoBERTa, BART, DeBERTa。

tokenization(标记化)是将自然语言转换成模型可以处理的标记的过程。自然语言处理的基本对象可以分成3类:1)单词(word); 2)字符(character); 3)子词(subword)。这是一种介于单词和字符之间的处理粒度。BPE就是一种基于subword的处理方法, 还有其它一些subword tokenization方法, 比如WordPiece, Unigram, SentencePiece等。
BPE的一个典型例子可以查看wiki:

TJ1EUs

具体到Llama, 论文里说使用的是Google的SentencePiece方法。感兴趣的可以了解一下WordPiece、Unigram和SentencePiece算法。

拓展阅读:

Comments

Unable to load Disqus, please make sure your network can access.