Abstract
SpecInfer 是一个基于树的推测推理和验证的加速 LLM 服务的系统。利用小型投机模型来预测 LLM 输出;这些预测被组织成一个 token 树,每个节点代表一个候选 token 序列。使用一种新颖的基于树的并行解码机制,在LLM上并行验证由 token 树表示的所有候选 token 序列的正确性。SpecInfer使用LLM作为 token 树验证器,而不是增量解码器,大大减少了服务生成LLM的端到端延迟和计算需求,同时保持模型质量。
1. Introduction
LLM生成文本的方式是逐个 token 生成的。每个新生成的 token 都依赖于输入的 prompt和所有先前已生成的 token. 之前的工作使用一个模型参数量远小于LLM的小型投机模型 (Small Speculative Model, SSM)来预先生成一个 token 序列,然后让LLM并行的验证这个序列的正确性。这种方法只考虑了由单个SSM生成的一个 token 序列。由于SSM和LLM之间存在巨大的模型能力差距 (SSM通常比LLM小100-1000倍),SSM的预测序列很难与LLM的真实输出完全对齐,导致投机成功率不高。
SpecInfer 同时考虑多种投机候选,以最大化投机性能。这些候选标记被组织成一个标记树,每个节点代表一系列推测的标记。所有候选 token 序列的正确性都是针对LLM并行验证的,这使得SpecInfer可以在LLM解码步骤中显著增加生成 token 的数量。实现这个技术有两个难点。
- 巨大的搜索空间: 要最大化投机性能,SpecInfer必须在一个极其庞大的候选序列空间中进行探索。这主要是因为LLM的词汇表非常大,并且需要预测未来的多个 token. SpecInfer不是使用单个SSM进行基于序列的推测,而是通过同时考虑为给定输入提示以树结构组织的各种 token 序列来最大化推测性能。SpecInfer引入了一种基于扩展和合并的机制,分别通过利用单个SSM和多个SSM中的多样性来构建 token 树。
- 保证验证的正确性: 许多LLM应用采用随机解码 (stochastic decoding),即从一个概率分布中采样下一个 token ,而非总是选择概率最高的那个。SpecInfer必须确保其复杂的树状验证机制最终生成的 token,遵循与原始LLM完全相同的概率分布。为此,论文提出了新的多步投机采样 (multi-step speculative sampling)并结合基于树的并行解码机制,以最低的验证成本,最大化可被验证的 token 数量,同时保证了数学上的等价性。
2. SpecInfer’s Overview
SpecInfer的系统设计可以清晰地划分为两个主要部分:
学习型投机器 (Learning-based Speculator): 负责猜测,即生成一个包含多种可能性预测的 token 树。SSM 通常使用与目标LLM同系列但尺寸小得多的预训练模型,因为他们是在同一数据集上训练的。为了最大化投机树能够覆盖LLM真实输出的概率,SpecInfer提出了两种构建方法 (在Figure 2顶部展示):
- 基于扩展的构建 (Expansion-based): 利用单个SSM在某些步骤生成多个候选 token ,形成树的分支。
- 基于合并的构建 (Merge-based): 并行使用多个不同的SSM,然后将它们各自的预测结果合并成一棵更多样化的树。
token 树验证器 (Token Tree Verifier): 负责验证,即利用 LLM 高效地判断投机器给出的猜测哪些是正确的。在传统系统中,LLM是一个增量解码器,一次只预测下一个 token 。在SpecInfer中,LLM的角色转变为一个 token 树验证器,负责一次性地对整个投机树进行评估。为了验证树中某个特定 token 的正确性,系统会将其所有祖先节点构成的序列作为它的上下文。
SpecInfer相较于现有LLM服务系统实现了两大关键优势:
- 减少对LLM参数的内存访问: LLM推理的性能瓶颈主要在于从GPU内存中读取庞大的模型参数,而非计算本身 。在传统的增量解码中,每生成一个 token ,都需要完整地访问一次LLM的所有参数。只要投机树与LLM的实际输出存在重叠 (即猜对了至少一个 token ),SpecInfer就能显著减少访问LLM参数的次数。
- 降低端到端推理延迟: LLM服务面临着很长的端到端延迟 。这主要是因为增量解码中, token 的生成是串行依赖的,系统必须等待前一个 token 生成完毕才能开始计算下一个。Spec 将 LLM作为验证器,一次性接收整个投机树作为输入,并能同时检查树中的所有 token 。
3. Learning-based Speculator
现有的投机解码方法只预测一个单一的 token 序列。这种方法的成功概率会随着期望对齐的序列长度呈指数级衰减。SpecInfer的目标是通过在每一步提供更多样化的候选 token 来提升匹配成功的概率 。它将这些由一个或多个SSM预测出的多样化 token 组织成一个 token 树结构,以最大化投机性能,同时保持较低的内存和延迟开销。
定义 3.1 (Token Tree): token 树 N 是一个树形结构,其中每个节点 $u\in N$ 都被标记为一个 token $t_u$. 对于每个节点 u,它所代表的 token 序列 $S_u$ 由其父节点所代表的序列 $S_{p_u}$ 与其自身的 token ${t_u}$ 拼接而成。
Expansion-based token tree construction. 该方法基于一个非常重要的观察: 当SSM的预测与LLM不一致时 (即它们的top-1选择不同),LLM实际选择的 token 通常也存在于SSM预测的 top-k token 列表中,且k值通常很小。
如果在每一步都进行top-k扩展,会导致序列数量呈指数级爆炸,带来巨大的延迟和内存开销。因此,SpecInfer采用了一种静态策略,即遵循一个预设的扩展配置。这个配置是一个整数向量 $
Merge-based token tree construction. 该方法利用多个SSM来共同预测LLM的输出,以获得更强的多样性。借鉴了机器学习中的集成学习思想。通过组合多个弱分类器 (SSMs),来形成一个更强大的强分类器,使其联合输出能更好地逼近LLM. 在一个通用文本语料库上,首先用LLM为每个prompt生成标准答案。然后,逐个训练SSM。第一个 SSM 被充分微调后,标记出所有它能猜对LLM输出的样本。接着,用剩下的、第一个 SSM 猜错的样本来训练第二个SSM,以此类推。通过这个过程得到了一组多样化的SSM,每个SSM可能擅长处理不同类型的文本模式,它们的聚合输出能够最大程度地覆盖LLM的输出。
当使用多个SSM时,每个SSM都会生成自己的 token 树 (或序列),SpecInfer会执行 token 树合并 (Token Tree Merge) 操作,将所有这些树聚合成一个单一的、更大的树结构。
定义 3.2 (Token Tree Merge): 合并多个 token 树会产生一棵新树,这棵新树包含了原始所有树中的全部 token 序列。
4. Token Tree Verifier
通过一次对 LLM 参数的访问,并行地验证 token 树中的所有候选序列 。为了实现这一目标,作者设计了一整套精巧的机制,可以分解为以下三个关键步骤:
- 树注意力 (Tree Attention): 将传统作用于线性序列的注意力机制,从概念上推广到树形结构。
- 基于树的并行解码 (Tree-based Parallel Decoding): 实现“树注意力”的高效计算方法,这是本章的技术核心。
- token 验证 (Token Verification): 在计算完成后,根据贪心或随机解码规则,最终确定哪些投机 token 被接受。
定义 4.1 (Tree Attention): 对于 token 树 N 中的任意一个节点 u,它的树注意力被定义为: 对该节点所代表的 token 序列 $S_u$ (即从树根到节点u的路径)进行一次标准的 Transformer 序列注意力计算。
如果对树中的每条路径都独立计算一次,效率会非常低下。在树状结构中,不同的分支在同一深度 (位置) 上会有不同的 token ,这意味着它们的 KV 缓存是相互冲突的。SpecInfer 引入了两项关键技术来解决上述问题,如 Figure 4 右侧所示:
- 深度优先搜索 (Depth-first search) 来更新共享的 KV 缓存: SpecInfer 并不为每个序列创建独立的缓存,而是让所有序列共享同一个 KV 缓存区。系统按照深度优先的顺序遍历 token 树,并依次更新这个共享的缓存。这种遍历方式保证了在计算任何一个节点的注意力时,其所有祖先节点的 KV 值都已经被正确计算并存储在缓存中。
- 拓扑感知的因果掩码 (Topology-aware causal mask): 为了避免为树中每个节点单独启动计算核,SpecInfer 将树中的所有 token 打包成一个批次,送入一个单一的、融合的计算核中进行处理。SpecInfer 对标准的因果掩码进行了巧妙的修改,使其能够感知树的拓扑结构。这个特殊的掩码会精确地屏蔽掉所有不具有祖先-后代关系的 token 对之间的注意力计算。
在通过并行解码计算出树中每个节点的 LLM 输出 (即每个位置最可能的下一个 token )后,最后一步就是根据具体的解码策略来决定接受哪些投机 token 。SpecInfer 同时支持贪心解码和随机解码
- 贪心解码 (Greedy decoding): 系统从 token 树的根节点开始,沿着一条路径向下遍历 。在每个节点 u,它会检查 u 的子节点中,是否有某个子节点 v 的 token $t_v$ 与 LLM 在 u 位置预测出的最优 token O(u) 相匹配。如果匹配成功,就将 $t_v$ 视为已验证的 token ,并继续从节点 v 向下验证。如果所有子节点都不匹配,则验证终止,系统接受 LLM 预测的 token O(u) 作为最后一个正确的 token ,并将这条已验证的路径附加到最终的生成结果中。
- 随机解码 (Stochastic decoding): 在投机式推理的框架下,随机解码面临一个严峻的挑战: 如何确保经过复杂的“猜测-验证”流程后,最终输出的文本仍然严格遵循原始 LLM 的概率分布?作者提出的 多步投机采样 (MSS) 流程如下:
假设我们当前已经验证到了 token 树的节点 u,需要决定下一个 token 是什么。u 的子节点集合为 H,代表了所有投机的候选。
- 进入验证循环: 只要候选集合 H 不为空,就持续进行尝试。
- 随机挑选候选: 从 H 中随机挑选一个候选子节点 s,其对应的 token 为 $x_s$.
- 进行概率接受测试: 系统会以一定的概率接受 token $x_s$. 这个接受概率 $P_accept$ 是基于 LLM 和 SSM 对该 token 的预测概率之比计算
- 根据测试结果进行分支处理:
- 如果接受: 将 token $x_s$ 添加到已验证序列 V 中。将当前节点更新为 $s (u = s)$,然后跳出当前节点的验证循环,回到步骤1,对新的 u 的子节点进行验证。
- 如果拒绝: 系统会从 LLM 的原始概率分布中减去刚刚被拒绝的 SSM 候选所对应的概率质量,然后对剩余的分布进行归一化,形成一个新的、修正后的残差分布 (Normalized residual distribution).
- 所有候选均被拒绝: 系统会从最终修正后的 LLM 概率分布中进行一次标准的采样加入到已验证序列中,结束整个过程。
5. System Design and Implementation
请求管理器 (Request Manager) 通常在 CPU 上运行,主要负责以下功能:
- 请求调度 (Request Scheduling): 接收用户的 LLM 服务请求,并决定处理哪些请求。
- token 树合并 (Token Tree Merge): 从多个 SSM 收集生成的 token ,并将它们合并成每条请求对应的投机树。
- token 树验证 (Token Tree Verification): 从 LLM 处获取验证结果,并执行最终的验证逻辑 (贪心或 MSS 算法).
工作流程如下:
- 请求管理器将待处理的请求分发给多个并行的 SSM 实例。
- 使用数据并行的方式来服务 SSM. 因为 SSM 模型较小,可以在单张 GPU 上运行,因此系统会将不同的请求分配到不同的 GPU 上,让多个 SSM 并行地生成候选 token .
- SSM 生成的候选 token 被送回请求管理器,管理器为每条请求构建一棵投机树,然后将这些树分发给 LLM 进行验证。
- 使用混合并行策略,所有参与 LLM 计算的 GPU 都会执行基于树的并行解码来计算树注意力分数。
- LLM 生成的验证结果被送回请求管理器,完成最后的 token 接受/拒绝判断 。
请求管理器和 GPU 工作节点之间只传输轻量级的 token ID,而非高维度的向量表示,因此通信开销可以忽略不计。SpecInfer 采用了 Orca 系统中提出的连续批处理技术。
SpecInfer 在 FlexFlow 之上构建。为了避免直接调用 cuBLAS 和 cuDNN 库计算注意力时产生的高昂的核函数启动开销,SpecInfer 采用了基于 FasterTransformer 的定制化 CUDA 核函数。每个 GPU 线程块负责计算单个请求的单个注意力头。通过利用 GPU 的共享内存来缓存 query tensor 并进行中间结果的广播。
论文提出了一个关键的观点,即传统的增量解码模式在很多时候未能充分利用 GPU 的计算资源,导致 GPU 处于半闲置状态。SpecInfer 利用了这些闲置的计算资源 (under-utilized resources) 来执行树的并行验证,因此,尽管总计算量增加了,但并不会显著增加每次迭代的实际运行时间,引入的运行时开销可以忽略不计。
最后,论文指出了 SpecInfer 的技术能够显著受益的两个实际应用场景。
- 分布式 LLM 推理 (Distributed LLM inference): SpecInfer 通过一次性验证多个 token ,增加了通信的粒度,减少了通信的频率。虽然不能减少单次通信的数据量,但通过显著减少总的解码步骤 (和通信次数),从而降低了端到端延迟 。
- 基于卸载的 LLM 推理 (Offloading-based LLM inference): SpecInfer 的优势在这里被进一步放大。通过机会性地一次验证多个 token ,它直接减少了 LLM 解码的总步数,从而显著减少了参数在 CPU 和 GPU 之间的传输次数。