1 Introduction
1.1 paper
- 《Accelerating Large Language Model Decoding with Speculative Sampling
》 [From DeepMind] - 《Fast Inference from Transformers via Speculative Decoding》[From Google Research]
1.2 code
https://github.com/jaymody/speculative-sampling/blob/main/main.py
1.3 core ideas
在整个方法中,会涉及到两个模型:
- 更小、推理更快的
draft model
, 比如7B的Chinchilla模型 - 更大、推理更慢的
target model
, 比如70B的Chinchilla模型draft model
会推测接下来的K
个token是什么, 而target model
会决定采用多少个draft model
预测的token。Speculative sampling的想法很简单, 有一些token的预测非常容易, 使用一个更小、表达能力更差的模型(draft model
)一样可以得到很好的预测结果, 而无需target model
的参与。
形式上看, 该方法和知识蒸馏很相似, 但是两个模型的interaction还是很不一样的,毕竟任务也不同。
1.4 limitations
Our method is easy to employ in actual production settings,
doesn’t require training new models, and doesn’t change the
outputs. Therefore, in common situations where memory
bandwidth is the bottleneck, and compute resources are
available, it may be a good default to accelerate sampling
from autoregressive models like Transformers.
- 带宽瓶颈
- 计算资源有富余
根据在我司几张卡上跑主流开源大模型(e.g. Llama, GLM, Opt)的经验, 大模型的推理都是带宽瓶颈。
最后想额外提及的一点是: speculative sampling推理得到的结果和原来只使用target model
推理得到结果是完全一样的
,也就是说speculative sampling
算法是一种确定性的算法,而非对target model
结果的近似或者逼近。
2 Methods
2.1 core algorithm
target model
预测的token概率;draft model
预测的token概率。
draft model
以自回归的方式解码K个token。假设prompt context有N个token, 经过第一步后, 得到一个N+K个token的序列。(此时draft model
前向推理K次)- 对新的到的序列(N+K个token), 使用
target model
前向推理一遍, 可以得到target model
的K个预测概率 - 比较
draft model
和target model
的概率值, 来决定对draft model
的K个预测token的取舍。如果某个token被接受了, 那么继续比较下一个token; 否则从一个新的数据分布
进行采样,然后返回第一步 - 如果K个token都被
target model
接受了,那么target model
会预测下一个token, 然后继续返回第一步, 迭代进行
Q: 什么情况接受draft model
的结果,什么时候拒绝?
A: draft model
的结果一定被采纳,否则有
Q: draft model
给出的预测token被拒绝后,应该如何确定下一个token?
A: 按$(q(x)-p(x)){+}=\frac{max(0,q(x)-p(x))}{\Sigma{x}(0,q(x)-p(x))}$的概率进行采样。
2.2 code explanation
1 | def speculative_sampling(x, draft_model, target_model, N, K): |
2.3 proof
证明:按照2.1节的算法运作, 得到的结果和采用target model
的结果
假设:
- 事件
表示 draft model
预测的token被target model
采纳 - 事件
表示 draft model
预测的token被target model
拒绝
根据SpS算法,有:
- $P(x=x’|A_2) = (q(x’)-p(x’)){+}=\frac{max(0, q(x’)-p(x’))}{\Sigma{x}max(0,q(x)-p(x))}$
根据全概率公式, 有:
上述公式的前半部分很好计算:
由于
所以后半部分的公式有:
综合前后部分公式,可得:
Reference
- https://www.youtube.com/watch?v=E1YFY6Ag70s
- Speculative Sampling
- Probability Review
- 条件概率公式:
- 全概率公式:
- 贝叶斯公式:
Comments