LLM之speculative-sampling

1 Introduction

1.1 paper

1.2 code

https://github.com/jaymody/speculative-sampling/blob/main/main.py

1.3 core ideas

在整个方法中,会涉及到两个模型:

  • 更小、推理更快的draft model, 比如7B的Chinchilla模型
  • 更大、推理更慢的target model, 比如70B的Chinchilla模型
    draft model会推测接下来的K个token是什么, 而target model会决定采用多少个draft model预测的token。Speculative sampling的想法很简单, 有一些token的预测非常容易, 使用一个更小、表达能力更差的模型(draft model)一样可以得到很好的预测结果, 而无需target model的参与。

形式上看, 该方法和知识蒸馏很相似, 但是两个模型的interaction还是很不一样的,毕竟任务也不同。

1.4 limitations

Our method is easy to employ in actual production settings,
doesn’t require training new models, and doesn’t change the
outputs
. Therefore, in common situations where memory
bandwidth is the bottleneck
, and compute resources are
available
, it may be a good default to accelerate sampling
from autoregressive models like Transformers.

  • 带宽瓶颈
  • 计算资源有富余

根据在我司几张卡上跑主流开源大模型(e.g. Llama, GLM, Opt)的经验, 大模型的推理都是带宽瓶颈。

最后想额外提及的一点是: speculative sampling推理得到的结果和原来只使用target model推理得到结果是完全一样的,也就是说speculative sampling算法是一种确定性的算法,而非对target model结果的近似或者逼近。

2 Methods

2.1 core algorithm

表示target model预测的token概率;
表示draft model预测的token概率。
Cb2jFv

  • draft model以自回归的方式解码K个token。假设prompt context有N个token, 经过第一步后, 得到一个N+K个token的序列。(此时draft model前向推理K次)
  • 对新的到的序列(N+K个token), 使用target model前向推理一遍, 可以得到target model的K个预测概率
  • 比较draft modeltarget model的概率值, 来决定对draft model的K个预测token的取舍。如果某个token被接受了, 那么继续比较下一个token; 否则从一个新的数据分布进行采样,然后返回第一步
  • 如果K个token都被target model接受了,那么target model会预测下一个token, 然后继续返回第一步, 迭代进行

Q: 什么情况接受draft model的结果,什么时候拒绝?
A: 时, draft model的结果一定被采纳,否则有的概率被采纳。
Q: draft model给出的预测token被拒绝后,应该如何确定下一个token?
A: 按$(q(x)-p(x)){+}=\frac{max(0,q(x)-p(x))}{\Sigma{x}(0,q(x)-p(x))}$的概率进行采样。

2.2 code explanation

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def speculative_sampling(x, draft_model, target_model, N, K):
# NOTE: paper indexes arrays starting from 1, python indexes from 0, so
# we have to add an extra -1 term when indexing using n, T, or t
n = len(x)
T = len(x) + N

with tqdm(total=N, desc="speculative sampling") as pbar:
while n < T:
prev_n = n

# Step 1: auto-regressive decode K tokens from draft model and get final p
x_draft = x
for _ in range(K):
# 解释: p的shape是[seq_len, vocab_size], 表示每个token的softmax后的概率值
p = draft_model(x_draft)
x_draft = np.append(x_draft, sample(p[-1]))

# Step 2: target model forward passes on x_draft
# 解释: 此时x_draft是一个n + K个token的序列, 返回的q也是一个shape为[seq_len, vocab_size]的Tensor
q = target_model(x_draft)

# Step 3: append draft tokens based on rejection criterion and resample
# a token on rejection
all_accepted = True
for _ in range(K):
i = n - 1
j = x_draft[i + 1]
if np.random.random() < min(1, q[i][j] / p[i][j]): # accepted
x = np.append(x, j)
n += 1
else: # rejected
x = np.append(x, sample(max_fn(q[i] - p[i]))) # resample
n += 1
all_accepted = False
break

# Step 4: if all draft tokens were accepted, sample a final token
if all_accepted:
x = np.append(x, sample(q[-1]))
n += 1

# just keeping my sanity
pbar.update(n - prev_n)
assert n == len(x), f"{n} {len(x)}"

return x

2.3 proof

证明:按照2.1节的算法运作, 得到的结果和采用target model的结果是一样的。
假设:

  • 事件表示draft model预测的token被target model采纳
  • 事件表示draft model预测的token被target model拒绝

根据SpS算法,有:

  • $P(x=x’|A_2) = (q(x’)-p(x’)){+}=\frac{max(0, q(x’)-p(x’))}{\Sigma{x}max(0,q(x)-p(x))}$

根据全概率公式, 有:

上述公式的前半部分很好计算:

由于, 可得:

所以后半部分的公式有:

综合前后部分公式,可得:

Reference

  1. https://www.youtube.com/watch?v=E1YFY6Ag70s
  2. Speculative Sampling
  3. Probability Review
  • 条件概率公式:
  • 全概率公式:
  • 贝叶斯公式:

Comments

Unable to load Disqus, please make sure your network can access.