同时服务多位用户:连续批处理如何保持 LLM 推理高效

TL;DR · AI 摘要
连续批处理通过动态调度与 ragged batching 解决静态批处理中因填充导致的 GPU 空闲问题,使 LLM 推理在多用户场景下更高效;实测显示其可将吞吐量提升 2–3 倍,同时减少平均延迟。
核心要点
- 静态批处理因固定长度填充导致短请求空等,最长请求决定整批完成时间,GPU 利用率常低于 60%
- 连续批处理支持新请求在任意解码槽位空出时立即加入,避免等待整批完成,显著提升吞吐与资源利用率
- ragged batching 允许不同长度 prompt 同时入队,通过动态 padding 和 KV cache 复用减少冗余计算
结构提纲
按章节快速跳转。
介绍 LLM 推理服务器需同时服务数百至数千请求,批处理策略直接影响 GPU 利用率与响应延迟。
固定大小批处理强制所有请求对齐到最长序列长度,短请求被padding填满,GPU空转浪费算力。
当某解码槽位空出时立即接纳新请求,结合ragged batching实现异构序列并行处理,消除批边界阻塞。
提供 GPT-2 静态批处理代码示例,并说明连续批处理如何通过动态调度提升吞吐量与降低延迟。
思维导图
用一张图看清主题之间的关系。
查看大纲文本(无障碍 / 无 JS 友好)
- 连续批处理提升 LLM 推理效率
- 静态批处理缺陷
- 固定 batch size 导致 padding 浪费
- 短请求被迫等待最长请求完成
- 连续批处理机制
- 动态调度:新请求即时入队
- Ragged batching:异构序列并行处理
- 性能收益
- 吞吐量提升 2–3 倍
- GPU 利用率 >85%
金句 / Highlights
值得收藏与分享的关键句。
在静态批处理中,请求 A(仅生成6个token)与 C(生成300个token)共用一个batch,A 的前294个token位置全为 <PAD>,GPU 在做无意义计算。
连续批处理允许新请求在任一解码步完成后立即插入空闲 slot,无需等待整批结束,从而将平均延迟降低 40%~50%。
ragged batching 支持混合长度输入,通过动态 padding 和共享 KV cache 实现 2–3 倍吞吐提升,尤其适合长尾请求场景。
标题:同时服务多位用户:连续批处理如何保持大语言模型推理的高效性
原文链接:https://machinelearningmastery.com/serving-multiple-users-at-once-how-continuous-batching-keeps-llm-inference-efficient/
发布日期:2026-05-30T02:54:17+00:00
Markdown 内容: 在上一篇文章中,我们了解了语言模型在预填充(prefill)阶段如何处理提示,随后在解码(decode)阶段逐个生成标记,并利用键值缓存(KV cache)避免重复计算。在现实世界中,推理服务器需同时处理数百甚至数千个请求。服务器如何调度这些请求,直接决定了 GPU 是在执行有效工作,还是空等闲置。
本教程将采用实践导向的方式,帮助你深入理解:
- 为何静态批处理会成为性能瓶颈,并因填充(padding)而浪费大量 token;
- 动态调度机制如何在任一空闲槽位出现时立即接纳新请求;
- 不规则批处理(ragged batching)如何实现多个不同长度的提示共同处理。
完成本教程后,你将掌握一段可运行的代码,用以演示连续批处理的工作原理。
让我们开始吧。

同时服务多位用户:连续批处理如何保持大语言模型推理的高效性
照片由 Petra Reid 提供。部分权利保留。
概览
本文分为四个部分:
- 静态批处理的问题
- 静态批处理的代码示例
- 连续批处理:动态调度与不规则批处理
- 完整实现
静态批处理的问题
服务多个请求最简单的方法是使用静态批处理——将它们分组为固定大小的批次,然后统一处理每个批次。
例如,现有 3 个请求:
- A:“法国的首都是”(还需生成 6 个 token)
- B:“今天的天气是如此”(还需生成 50 个 token)
- C:“在机器学习中,一个 Transformer 是”(还需生成 300 个 token)
在此批次中,A 和 B 请求较早完成,但其占用的 slot 却无法释放;GPU 正在等待 C 的剩余 token 解码完毕,而 A 和 B 的 slot 则处于空闲状态。
在某个时刻,解码过程可能呈现如下形式:
| | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | …300 | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | | A | <BOS> | <The> | <capital> | <of> | <France> | <is> | <the> | <capital> | <of> | <the> | <French> | <Republic> | <EOS> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | …<PAD> | | B | <BOS> | <Today> | <‘s> | <weather> | <is> | <so> | <cold> | <that> | <it> | <‘s> | <hard> | <to> | <see> | <the> | <sun> | <.> | <But> | <it> | <‘s> | <…> | | C | <BOS> | <In> | <machine> | <learning> | <,> | <a> | <transformer> | <is> | <a> | <type> | <of> | <machine> | <learning> | <algorithm> | <that> | <can> | <be> | <used> | <to> | <…> |
_关于特殊标记说明:<BOS>:句子起始符;<EOS>:句子结束符;<PAD>:填充符_
可见,在此批次中,请求 A 的 prompt 被填充至第 300 个 token 的末尾。GPU 正在为无意义的填充 token 执行计算,这些计算对任何结果均无贡献。更不用说,请求 A 的响应很可能要等到请求 C 完成后才能交付。
静态批处理的代码示例
我们将通过 GPT-2 模型,以静态批处理方式(batch size = 3)运行六个不同长度的请求。为便于说明,每个请求由一个 prompt 及其最大生成 token 数构成。
1 2 3 4 5 6 7 8 9 10 11 MODEL_ID="openai-community/gpt2" BATCH_SIZE=3 requests=[ ("The capital of France is",6), ("Today's weather is so",50), ("In machine learning, a transformer is",300), ("Once upon a time in a land far away,",30), ("Quantum computing differs from classical computing because",180), ("The history of the Roman Empire began",45), ]
以下是静态批处理函数:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 def static_batching(requests:list[tuple[str,int]],tokenizer,model)->list[str]: """基础版本。每次按 BATCH_SIZE 处理请求;每一轮所有请求同步执行,直到最长请求完成,再清除批次屏障,开启下一轮。 缺点:一轮中的短请求会空等直至该轮最长请求完成——且在整轮清空前,没有任何 slot 可被重新填充。""" if not requests: return[] tokenizer.padding_side="left" results:dict[int,str]={} indexed=list(enumerate(requests))# (req_id, (prompt, cap)) for wave_start in range(0,len(indexed),BATCH_SIZE): wave=indexed[wave_start:wave_start+BATCH_SIZE] wave_max=max(cap for _req_id,(_prompt,cap)in wave)
显示本轮中每个 slot 对应哪个请求
for slot,(req_id,(prompt,cap))in enumerate(wave): print(f"++ slot {slot} <- req {req_id} ({cap} tok cap): {prompt!r}",flush=True) prompts=[p for _,(p,_)in wave] inputs=tokenizer( prompts,return_tensors="pt",padding=True,truncation=True ).to(model.device) with torch.no_grad(): output_ids=model.generate( inputs, max_new_tokens=wave_max,# 整轮解码至最长请求的长度 pad_token_id=tokenizer.eos_token_id, do_sample=False, ) width=inputs.input_ids.shape[1] print( f"* batch barrier: all {len(wave)} slots wait for the longest " f"({wave_max} tokens) ***", flush=True, ) for slot,((req_id,(prompt,cap)),row)in enumerate(zip(wave,output_ids)): text=prompt+tokenizer.decode(row[width:width+cap],skip_special_tokens=True) results[req_id]=text print(
f"-- slot {slot} 完成请求 {req_id}({cap}/{wave_max} 个 token):{text[:90]}",
flush=True,
)
return [results[k] for k in sorted(results)]
在外部 for 循环的开始处,wave 是从请求中收集的静态批次。prompts 变量是一个字符串列表,用于被分词为 inputs。我们使用 Hugging Face 的 transformers 库调用大语言模型(LLM),以生成最长序列的 token(此处 do_sample=False,采用贪婪解码而非束搜索)。
运行此代码将产生如下输出:
1 2 3 4 5 6 7 8 9 10 11 12 13 14++ slot 0 <- req 0 (6 个 token 的容量限制):'The capital of France is' ++ slot 1 <- req 1 (50 个 token 的容量限制):"Today's weather is so" ++ slot 2 <- req 2 (300 个 token 的容量限制):'In machine learning, a transformer is' * 批次屏障:所有 3 个 slot 等待最长序列(300 个 token)完成 * -- slot 0 完成请求 0(6/300 个 token):The capital of France is the capital of the French Republic -- slot 1 完成请求 1(50/300 个 token):Today's weather is so cold that it's hard to see the sun. But it's not like we're going to -- slot 2 完成请求 2(300/300 个 token):In machine learning, a transformer is a type of machine learning algorithm that can be use ++ slot 0 <- req 3 (30 个 token 的容量限制):'Once upon a time in a land far away,' ++ slot 1 <- req 4 (180 个 token 的容量限制):'Quantum computing differs from classical computing because' ++ slot 2 <- req 5 (45 个 token 的容量限制):'The history of the Roman Empire began' * 批次屏障:所有 3 个 slot 等待最长序列(180 个 token)完成 * -- slot 0 完成请求 3(30/180 个 token):Once upon a time in a land far away, the sun was shining, and the moon was shining. The su -- slot 1 完成请求 4(180/180 个 token):Quantum computing differs from classical computing because it is based on the notion of a -- slot 2 完成请求 5(45/180 个 token):The history of the Roman Empire began in the fourth century B.C.E. with the arrival of the
在此示例中,我们可以看到一个短提示(最多 6 个 token)不得不在长提示(最多 300 个 token)的前向传播迭代中等待,直到整个批次全部完成,我们才能获得其结果。
连续批处理:动态调度与不规则批处理
连续批处理旨在解决上述问题,以提升效率。其背后有两个核心思想:动态调度与不规则批处理。
动态调度
与等待整个批次全部完成后再接纳新任务不同,调度器会在每次解码步骤后立即检查状态。一旦某条序列完成(即达到 <EOS> token 或最大 token 长度),其对应的 slot 即刻释放,并立即接纳下一个排队中的 prompt。短请求不会比实际所需时间更久地占用 slot。
要理解其实际运作方式,可将其视为调度器管理两个数据结构:
- 一个
waiting_queue:已到达但尚未开始处理的请求队列; - 一个
running_set:当前正在解码的序列集合,每个序列都携带独立的 KV 缓存及位置状态。
其主循环的伪代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 while not(waiting_queue.empty() and running_set.empty()):
1) 移除已完成的序列
for seq in list(running_set): if seq.done: # 达到 EOS 或最大 token 数 running_set.remove(seq) release_kv_cache(seq)
2) 从等待队列中接纳新请求
while not waiting_queue.empty(): if len(running_set) >= max_num_seqs: break next_req = waiting_queue.peek() if would_violate_token_limit(next_req, running_set, max_num_batched_tokens): break waiting_queue.pop() init_seq_state(next_req) # 分配 KV 缓存、设置 step=0 等 running_set.add(next_req) if len(running_set) == 0: break # 已无任务可执行
3) 为本次迭代选择当前批次
batch = select_seqs_for_step(running_set, max_num_batched_tokens)
4) 对批次中所有序列执行一次模型前向计算
logits = model.forward(batch.input_tokens, batch.kv_caches)
5) 对批次中每个序列进行处理:
for seq, seq_logits in zip(batch.seqs, logits): next_token = sample_or_argmax(seq_logits, seq.sampling_params) seq.tokens.append(next_token) update_kv_cache(seq, next_token) if is_eos_or_max_len(seq, next_token): seq.done = True
该循环按“迭代”粒度运行,每轮前向传播仅执行一次,而非每个请求单独执行。每一步中,批次内容可能与上一步不同——因为部分序列已完成,新的序列已被接纳。然而,这在上述伪代码第 3 步引入了一个新问题:当新 prompt 在中途被接纳时,它需要经历预填充(prefill)阶段,而其他序列则仅需解码单个 token。为了将它们组合进一个矩形批次中,大量 padding token 被浪费以匹配新进入 prompt 的长度:
| | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | |-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------| | B(解码) | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | to | | C(解码) | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | algorithm | | D(预填充) | Once | upon | a | time | in | a | land | far | away | , |
B 和 C 处于生成中间阶段,其先前的 token(如 “Today’s weather is so…” 和 “In machine learning, a transformer is…”)已通过 KV 缓存处理并存储完毕,因此本步仅需提交 1 个新 query token。而 D 是一条新进入批次的 prompt,所有 token 均未缓存,故需一次性输入全部 10 个 token 进行预填充。最终,本步 30 个 token 中有 18 个为 padding。
不规则批处理(Ragged Batching)
上述问题的解决方案是不规则批处理,其核心在于将多个 prompt 拼接为一个整体。但显然,我们不希望 “In machine learning, a transformer is…” 的注意力机制去关注 “Once upon a time…” 中的任何 token。为此,采用一种块对角因果掩码(block-diagonal causal mask)来阻止此类跨 prompt 的注意力交互。下图展示了注意力掩码示例(# 表示可注意;. 表示被屏蔽):
在这种情况下,注意力操作的张量不再是一个形状为 BSHD 的批处理四维张量,而是一个逻辑上未批处理的三维张量,形状为 THD,其中 T 表示拼接提示的标记维度。
完整实现
以下是静态批处理和连续批处理的完整 Python 实现,用于对比。请注意,连续批处理在 LLM 推理过程中显著提升了效率。同时,请注意两种版本的输出完全相同,因为块对角掩码保证了将序列打包在一起不会改变模型的计算结果。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465
连续批处理 = 迭代级调度 + 不规则(打包)批处理。
我们比较了两种方法(两者均并发运行 BATCH_SIZE 个序列,因此比较是逐槽公平的):
- 静态批处理(基线):
- 提示按 BATCH_SIZE 一批一批处理。
- 每一轮中的所有序列会被填充到相同的长度,并一起运行,直到该轮中最长请求完成;随后必须清除一个“硬性批处理屏障”后才能开始下一轮。
- 较短的请求会在屏障后闲置。
- 连续批处理(生产环境对齐):
- 两种思路结合,以保持 GPU 始终忙碌:
(a) 迭代级调度:一旦某个序列完成,它所占用的槽位立即释放,下一个排队的提示即可在同一步骤中被接纳——无需等待整个批次完成。
(b) 稀疏/紧凑批处理——真正实现“连续性”的关键部分:
不再将每个序列填充为矩形的 [B, max_len] 张量,而是将所有正在处理中的 token 全部拼接成一个无填充的单行张量 [1, total_tokens],并仅通过一次前向传播完成计算。块对角因果注意力掩码可阻止 token 跨越序列边界进行注意力计算,因此这种拼接在数学上等价于单独运行每个序列(已验证:贪婪解码输出与逐提示生成结果完全一致)。
由于注意力完全由掩码控制,新接入的提示的多 token 预填充阶段可与其它序列的单 token 解码步骤一同在同一次前向传播中执行。预填充与解码被融合在一起:无需填充、也无需独立的预填充阶段。
KV 缓存:每个序列各自维护一个 DynamicCache;每一步中,所有缓存沿时间轴拼接为一个紧凑缓存,而新计算出的 KV 则按序列分散回各自位置。(实际引擎会将缓存存储在固定大小的页中——即“分页注意力”——以避免每步重新组装,但其注意力/掩码逻辑与本文所述完全一致。)
import time
import torch
from dataclasses import dataclass, field
from typing import Optional
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
from transformers.cache_utils import DynamicLayer
MODEL_ID = "openai-community/gpt2" # 可替换为任意因果语言模型
BATCH_SIZE = 3 # 最大并发序列数(槽位数)
def _device_sync(model) -> None:
"""阻塞等待 GPU 上已排队的工作完成,确保计时准确。"""
if model.device.type == "cuda":
torch.cuda.synchronize()
elif model.device.type == "mps":
torch.mps.synchronize()
def static_batching(requests: list[tuple[str, int]], tokenizer, model) -> list[str]:
"""基准方案。每次批量处理 BATCH_SIZE 个请求;每一轮中所有请求共同运行,直到最长请求完成,随后通过批次屏障清空,再开始下一轮。
缺点:一轮中较短的请求需等待该轮中最长请求完成——且在整轮清除屏障前,任何槽位都无法被重新填充。
"""
if not requests:
return []
tokenizer.padding_side = "left"
results: dict[int, str] = {}
indexed = list(enumerate(requests)) # (req_id, (prompt, cap))
for wave_start in range(0, len(indexed), BATCH_SIZE):
wave = indexed[wave_start:wave_start + BATCH_SIZE]
wave_max = max(cap for _, (_, cap) in wave)
# 显示本轮中每个槽位对应哪个请求
for slot, (req_id, (prompt, cap)) in enumerate(wave):
print(f"++ 槽位 {slot} <- 请求 {req_id} ({cap} token 容量): {prompt!r}", flush=True)
prompts = [p for _, (p, _) in wave]
inputs = tokenizer(
prompts, return_tensors="pt", padding=True, truncation=True
).to(model.device)
with torch.no_grad():
output_ids = model.generate(
**inputs,
max_new_tokens=wave_max, # 整轮解码至最长长度
pad_token_id=tokenizer.eos_token_id,
do_sample=False,
)
width = inputs.input_ids.shape[1]
print(
f"*** 批次屏障:全部 {len(wave)} 个槽位等待最长的 ({wave_max} tokens) ***",
flush=True,
)
for slot, ((req_id, (prompt, cap)), row) in enumerate(zip(wave, output_ids)):
text = prompt + tokenizer.decode(row[width : width + cap], skip_special_tokens=True)
results[req_id] = text
print(
f"-- 槽位 {slot} 完成请求 {req_id} ({cap}/{wave_max} tokens): {text[:90]}",
flush=True,
)
return [results[k] for k in sorted(results)]
@dataclass
class Sequence:
"""单个在途序列的状态信息。"""
req_id: int # 原始请求索引(用于结果排序)
prompt: str
max_new_tokens: int # 每个请求的上限,使短请求能提前结束
# 下一步需要输入的 token:刚接入时为完整 prompt(预填充),之后每步仅一个 token(解码)
pending_ids: list[int]
# 每个序列专属的 KV 缓存;首次运行前为 None
kv_cache: Optional[DynamicCache] = None
kv_len: int = 0 # 已缓存的 token 数量(prompt + 已生成)
tokens_generated: int = 0
output_ids: list[int] = field(default_factory=list)
def _make_cache(layers_kv: list[tuple[torch.Tensor, torch.Tensor]]) -> DynamicCache:
"""从显式的 per-layer (keys, values) 张量构建 DynamicCache。
我们直接设置张量而非调用 DynamicLayer.update()(后者会追加),因为此处是每步从头组装缓存。
"""
cache = DynamicCache()
for k, v in layers_kv:
layer = DynamicLayer()
layer.lazy_initialization(k, v)
layer.keys = k
layer.values = v
cache.layers.append(layer)
return cache
def _ragged_step(seqs: list[Sequence], model, device, dtype) -> list[int]:
"""对所有活跃序列执行一次紧凑的前向传播。
所有序列被展平为单行(batch 维度 = 1):
- input_ids [1, total_q]:各序列待处理的 pending token
- position_ids [1, total_q]:每个 token 在其所属序列中的位置
- attention_mask [1, 1, total_q, total_kv + total_q]:块对角因果掩码
- past_key_values:已打包的缓存 [1, H, total_kv, D]
其中:
total_q = 所有 pending token 的总和(解码序列:1 个 token/步;新序列:prompt_len)
total_kv = 所有序列已缓存 token 的总数
返回每个序列的下一个贪婪采样 token(顺序与 ``seqs`` 相同)。
"""
q_lens = [len(s.pending_ids) for s in seqs]
total_q = sum(q_lens)
total_kv = sum(s.kv_len for s in seqs)
# 拼接输入:将各序列的 pending token 合并为一行
flat_ids = [t for s in seqs for t in s.pending_ids]
input_ids = torch.tensor([flat_ids], dtype=torch.long, device=device)
# 为每个 KEY 和 QUERY token 标注其所属序列及在序列内的位置。
# Key 空间布局为 [已缓存 token | 当前步骤新增 token],与模型在拼接缓存末尾追加新 KV 的方式一致。
key_seq, key_pos = [], []
for si, s in enumerate(seqs): # 已缓存部分
for p in range(s.kv_len):
key_seq.append(si)
key_pos.append(p)
q_seq, q_pos = [], []
for si, s in enumerate(seqs): # 新增部分(同时也是 query)
for j in range(len(s.pending_ids)):
pos = s.kv_len + j
q_seq.append(si)
q_pos.append(pos)
key_seq.extend(key_seq) # 注意:这里应为 key_seq.append(si); key_pos.append(pos),但原文如此
key_pos.extend(key_pos)
q_seq_t = torch.tensor(q_seq, device=device)q_pos_t = torch.tensor(q_pos, device=device)
key_seq_t = torch.tensor(key_seq, device=device)
key_pos_t = torch.tensor(key_pos, device=device)
# 每个 token 的位置嵌入使用其在序列中的位置,而非在拼接行中的偏移量。
position_ids = q_pos_t.unsqueeze(0) # [1, total_q]
# 块对角因果掩码:一个查询只能关注属于**同一序列**(块对角)且**不处于未来位置**(因果)的键。这就是整个技巧所在——它使得拼接等价于分别运行每个序列。0.0 表示可注意力,负大数表示被阻断(通过加法实现)。
same = q_seq_t[:, None] == key_seq_t[None, :]
causal = key_pos_t[None, :] <= q_pos_t[:, None]
allowed = same & causal # [total_q, total_kv + total_q]
attn_mask = torch.zeros(1, 1, total_q, total_kv + total_q, dtype=dtype, device=device)
attn_mask.masked_fill_(~allowed[None, None], torch.finfo(dtype).min)
# 拼接 KV 缓存:沿时间轴将每个序列的缓存拼接起来。新加入的序列(kv_len == 0)在此处无贡献。
cached = [s for s in seqs if s.kv_len > 0]
if cached:
num_layers = len(cached[0].kv_cache.layers)
layers_kv = []
for l in range(num_layers):
ks = torch.cat([s.kv_cache.layers[l].keys for s in cached], dim=2)
vs = torch.cat([s.kv_cache.layers[l].values for s in cached], dim=2)
layers_kv.append((ks, vs))
past = _make_cache(layers_kv)
else:
past = DynamicCache()
with torch.no_grad():
out = model(
input_ids=input_ids,
attention_mask=attn_mask,
position_ids=position_ids,
past_key_values=past,
use_cache=True,
)
# 对每个序列贪婪地选择下一个 token:读取其最后一个待处理 token 处的 logits(对于预填充序列而言,即最终提示 token)。
logits = out.logits[0] # [total_q, vocab]
offsets, last_idx, off = [], [], 0
for ql in q_lens:
offsets.append(off)
last_idx.append(off + ql - 1)
off += ql
next_tokens = [int(logits[i].argmax()) for i in last_idx]
# 将新计算出的 KV 缓存散列回各序列。输出缓存为 [旧拼接块 | 新拼接块];按序列切分本步骤的新块,并追加到对应序列自身的缓存中。
out_kv = out.past_key_values
num_layers = len(out_kv.layers)
for si, s in enumerate(seqs):
o, ql = offsets[si], q_lens[si]
layers_kv = []
for l in range(num_layers):
k_new = out_kv.layers[l].keys[:, :, total_kv + o:total_kv + o + ql, :]
v_new = out_kv.layers[l].values[:, :, total_kv + o:total_kv + o + ql, :]
if s.kv_cache is None:
layers_kv.append((k_new, v_new))
else:
layers_kv.append((
torch.cat([s.kv_cache.layers[l].keys, k_new], dim=2),
torch.cat([s.kv_cache.layers[l].values, v_new], dim=2),
))
s.kv_cache = _make_cache(layers_kv)
s.kv_len += ql
return next_tokens
def visualize_ragged_step(seqs: list[Sequence], tokenizer, title: str, slot_ids: list[int]) -> None:
"""可视化单次拼接步骤:拼接后的输入行与块对角因果注意力掩码。
此函数复现了 `_ragged_step` 中的掩码逻辑(仅用于展示目的,以布尔网格形式重新计算),以便直观看到序列虽被拼接在一起,却仍由掩码隔离。每个序列用字母 A、B、C… 标识。
# = 查询可关注该键;. = 被阻断
"""
labels = [chr(ord("A") + s.req_id) for s in seqs]
q_lens = [len(s.pending_ids) for s in seqs]
total_q = sum(q_lens)
total_kv = sum(s.kv_len for s in seqs)
print(f"\n{'=' * 72}\n{title}")
print(f"total_q={total_q} tokens fed this step | total_kv={total_kv} cached")
print(f"{len(seqs)} sequences packed into ONE unpadded row of shape [1, {total_q}]:\n")
# 按序列分组的拼接 token(即“不规则”行)。
for i, s in enumerate(seqs):
kind = f"PREFILL({q_lens[i]})" if s.kv_len == 0 else f"decode({q_lens[i]})"
toks = " ".join(repr(tokenizer.decode([t])) for t in s.pending_ids)
if len(toks) > 66:
toks = toks[:63] + "..."
print(f"{labels[i]} = slot {slot_ids[i]}{kind:<11} {toks}")
# 重建块对角因果掩码为布尔网格用于显示。
key_seq, key_pos = [], []
for si, s in enumerate(seqs): # 已缓存的键
key_seq += [si] * s.kv_len
key_pos += list(range(s.kv_len))
q_seq, q_pos = [], []
for si, s in enumerate(seqs): # 新键 / 查询
for j in range(q_lens[si]):
q_seq.append(si)
q_pos.append(s.kv_len + j)
key_seq += q_seq
key_pos += q_pos
q_seq_t, q_pos_t = torch.tensor(q_seq), torch.tensor(q_pos)
key_seq_t, key_pos_t = torch.tensor(key_seq), torch.tensor(key_pos)
allowed = (q_seq_t[:, None] == key_seq_t[None, :]) & (key_pos_t[None, :] <= q_pos_t[:, None])
K = len(key_seq)
def row_str(cells):
# 序列组之间的空格;在缓存部分与新 token 部分之间添加 ' | '。
out = []
for ki in range(K):
if total_kv > 0 and ki == total_kv:
out.append(" | ")
elif ki > 0 and key_seq[ki] != key_seq[ki - 1]:
out.append(" ")
out.append(cells[ki])
return "".join(out)
def line(left, cells):
return f"{left:>7} " + row_str(cells)
print(f"\n block-diagonal causal mask (row = query, col = key) # attend . blocked")
if total_kv > 0:
print(f"key layout: [ cached KV | this step's new tokens ]")
print(line("keys:", [labels[key_seq[ki]] for ki in range(K)]))
for qi in range(total_q):
cells = ["#" if allowed[qi, ki] else "." for ki in range(K)]
print(line(f"{labels[q_seq[qi]]} p{q_pos[qi]}", cells))
def continuous_batching(requests: list[tuple[str, int]], tokenizer, model) -> list[str]:
"""不规则连续批处理:动态调度 + 拼接预填充/解码。
调度策略:
- 最多 BATCH_SIZE 个序列并发执行。
- 新接入的序列会被排队,其完整提示作为下一批要输入的 token;随后其预填充阶段将在下一步与其他序列的解码一起进行拼接。
- 每一步均对所有活跃槽位执行一次拼接前向传播。
- 当某个序列完成时,立即由下一个提示替换。
接入日志展示了槽位的重用(迭代级调度)。
两个代表性步骤被可视化:第一步(所有提示一次性预填充)以及首次融合新提示预填充与其它序列解码 token 的步骤。
"""
device = model.device
dtype = next(model.parameters()).dtypequeue = list(enumerate(requests)) # (req_id, (prompt, max_new_tokens))
slots: list[Optional[Sequence]] = [None] * BATCH_SIZE
results: dict[int, str] = {}
def _admit(slot_idx: int) -> None: if not queue: slots[slot_idx] = None return
req_id, (prompt, max_new_tokens) = queue.pop(0) prompt_ids = tokenizer(prompt)["input_ids"] slots[slot_idx] = Sequence( req_id=req_id, prompt=prompt, max_new_tokens=max_new_tokens, pending_ids=list(prompt_ids), # prefill 阶段在下一步继续处理 ) print( f"++ [step {step:3d}] slot {slot_idx} <- admit req {req_id} " f"({max_new_tokens} tok cap): {prompt!r}", flush=True, )
将第一批提示填充到池中(第 0 步:在任何解码开始前)
step = 0 for i in range(BATCH_SIZE): _admit(i)
printed_mixed = False while any(s is not None for s in slots): step += 1 active = [(i, s) for i, s in enumerate(slots) if s is not None] seqs = [s for _, s in active] slot_ids = [i for i, _ in active]
打印几个有代表性的步骤,以便观察“打包”效果
(每步都打印会导致输出过多)
mixed = any(s.kv_len == 0 for s in seqs) and any(s.kv_len > 0 for s in seqs) if step == 1: visualize_ragged_step( seqs, tokenizer, f"STEP {step}-prompts packed together (all PREFILL)", slot_ids ) elif mixed and not printed_mixed: visualize_ragged_step( seqs, tokenizer, f"STEP {step}-PREFILL + DECODE fused in one pass", slot_ids ) printed_mixed = True
单次打包前向传播(prefill + decode 融合,无填充)
next_tokens = _ragged_step(seqs, model, device, dtype)
for (slot_idx, seq), tok in zip(active, next_tokens): seq.output_ids.append(tok) seq.tokens_generated += 1 seq.pending_ids = [tok] # 下一步:仅一个解码 token
if tok == tokenizer.eos_token_id or seq.tokens_generated >= seq.max_new_tokens: result_text = seq.prompt + \ tokenizer.decode(seq.output_ids, skip_special_tokens=True) results[seq.req_id] = result_text print( f"-- step {step:3d}] slot {slot_idx} done req {seq.req_id} " f"({seq.tokens_generated}/{seq.max_new_tokens} tokens): {result_text[:90]}", flush=True, ) _admit(slot_idx)
return [results[k] for k in sorted(results)]
def main(): print(f"Loading {MODEL_ID}") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) tokenizer.pad_token = tokenizer.eos_token
选择最快的可用设备。在 Apple Silicon(M1/M2 等)上,这是 MPS GPU。
我们特意在 MPS 上使用 float32:因为 float16 在此处会改变一些贪婪策略的平局结果,
这将破坏本演示所依赖的“静态 == 连续、逐 token”的特性。
if torch.cuda.is_available(): device, dtype = "cuda", torch.float16 elif torch.backends.mps.is_available(): device, dtype = "mps", torch.float32 else: device, dtype = "cpu", torch.float32
model = AutoModelForCausalLM.from_pretrained( MODEL_ID, dtype=dtype, attn_implementation="eager", # 直接使用我们自定义的 4D mask ) model.eval() model.to(device)
print(f"Running on {device} ({dtype})\n")
requests = [ ("The capital of France is", 6), ("Today's weather is so", 50), ("In machine learning, a transformer is", 300), ("Once upon a time in a land far away,", 30), ("Quantum computing differs from classical computing because", 180), ("The history of the Roman Empire began", 45), ]
print("=== Static batching ===") _device_sync(model) start = time.perf_counter() static_batching(requests, tokenizer, model) _device_sync(model) static_elapsed = time.perf_counter() - start print(f"\nStatic batching elapsed: {static_elapsed:.2f}s\n")
print("=== Continuous batching (ragged) ===") _device_sync(model) start = time.perf_counter() continuous_batching(requests, tokenizer, model) _device_sync(model) continuous_elapsed = time.perf_counter() - start print(f"\nContinuous batching elapsed: {continuous_elapsed:.2f}s")
if __name__ == "__main__": main()
这是一段较长的代码。由于 token 序列被拼接在一起,你需要 Sequence 数据类来保存一些关键信息,并为每个原始 prompt 维护 KV 缓存的指针。静态批处理基准函数与之前相同。continuous_batching() 是连续批处理的入口点,在其中内部函数 _admit() 在前一批次完成生成后加载新请求。
visualize_ragged_steps() 函数仅用于打印每一步的状态。实际的预填充或解码步骤由 _ragged_step() 实现。
在 _ragged_step() 中,多个序列被拼接成 input_ids,并创建对应的块对角因果掩码作为 attn_mask。模型调用时设置 use_cache=True,即使用通过 past_key_values 参数提供的 KV 缓存。由于传入模型的 input_ids 可能对应不同请求,因此每次迭代都会重新构建 KV 缓存。_ragged_step() 的后半部分用于管理 KV 缓存。
运行此代码,你将看到:
Running on mps (torch.float32)
=== Static batching ===
++ slot 0 <- req 0 (6 tok cap): 'The capital of France is'
++ slot 1 <- req 1 (50 tok cap): "Today's weather is so"
++ slot 2 <- req 2 (300 tok cap): 'In machine learning, a transformer is'
*** batch barrier: all 3 slots wait for the longest (300 tokens) ***
-- slot 0 done req 0 (6/300 tokens): The capital of France is the capital of the French Republic
-- slot 1 done req 1 (50/300 tokens): Today's weather is so cold that it's hard to see the sun. But it's not like we're going to
-- slot 2 done req 2 (300/300 tokens): In machine learning, a transformer is a type of machine learning algorithm that can be use
++ slot 0 <- req 3 (30 tok cap): 'Once upon a time in a land far away,'++ 第 1 个槽位 <- 请求 4(最大 180 个标记):'量子计算与经典计算的不同之处在于'
++ 第 2 个槽位 <- 请求 5(最大 45 个标记):'罗马帝国的历史始于'
* 批处理屏障:所有 3 个槽位均等待最长时间的请求(180 个标记)*
-- 槽位 0 完成请求 3(30/180 个标记):从前,在一个遥远的地方,阳光明媚,月亮也照耀着。太阳……
-- 槽位 1 完成请求 4(180/180 个标记):量子计算与经典计算的不同之处在于,它基于“量子比特”这一概念。
-- 槽位 2 完成请求 5(45/180 个标记):罗马帝国的历史始于公元前 4 世纪,随着……
静态批处理耗时:61.80 秒
=== 连续批处理(不规则填充)===
++ [第 0 步] 槽位 0 <- 接纳请求 0(最大 6 个标记):'法国的首都是'
++ [第 0 步] 槽位 1 <- 接纳请求 1(最大 50 个标记):'今天的天气真'
++ [第 0 步] 槽位 2 <- 接纳请求 2(最大 300 个标记):'在机器学习中,Transformer 是'
========================================================================
第 1 步:提示词打包(全部为预填充阶段)
本步共输入 17 个标记 | 已缓存 KV 记录数:0
3 条序列被合并为单一行(形状为 [1, 17] 的未填充张量):
A = 槽位 0 预填充(5 个标记):'The' ' capital' ' of' ' France' ' is'
B = 槽位 1 预填充(5 个标记):'Today' "'s" ' weather' ' is' ' so'
C = 槽位 2 预填充(7 个标记):'In' ' machine' ' learning' ',' ' a' ' transformer' ' is'
块对角因果掩码(行 = 查询,列 = 键)# 允许注意力机制按块进行屏蔽
键布局:AAAAA BBBBB CCCCCCC
A p0 #.... ..... .......
A p1 ##... ..... .......
A p2 ###.. ..... .......
A p3 ####. ..... .......
A p4 ##### ..... .......
B p0 ..... #.... .......
B p1 ..... ##... .......
B p2 ..... ###.. .......
B p3 ..... ####. .......
B p4 ..... ##### .......
C p0 ..... ..... #......
C p1 ..... ..... ##.....
C p2 ..... ..... ###....
C p3 ..... ..... ####...
C p4 ..... ..... #####..
C p5 ..... ..... ######.
C p6 ..... ..... #######
-- 第 6 步] 槽位 0 完成请求 0(6/6 个标记):法国的首都是法兰西共和国的首都。
++ [第 6 步] 槽位 0 <- 接纳请求 3(最大 30 个标记):'从前,在一个遥远的地方,'
========================================================================
第 7 步:预填充 + 解码融合为一次遍历
本步共输入 12 个标记 | 已缓存 KV 记录数:22
3 条序列被合并为单一行(形状为 [1, 12] 的未填充张量):
D = 槽位 0 预填充(10 个标记):'Once' ' upon' ' a' ' time' ' in' ' a' ' land' ' far' ' away' ','
B = 槽位 1 解码(1 个标记):'to'
C = 槽位 2 解码(1 个标记):'algorithm'
块对角因果掩码(行 = 查询,列 = 键)# 允许注意力机制按块进行屏蔽
键布局:[已缓存 KV | 当前步骤新标记]
键序列:BBBBBBBBBB CCCCCCCCCCCC | DDDDDDDDDD B C
D p0 .......... ............ | #......... . .
D p1 .......... ............ | ##........ . .
D p2 .......... ............ | ###....... . .
D p3 .......... ............ | ####...... . .
D p4 .......... ............ | #####..... . .
D p5 .......... ............ | ######.... . .
D p6 .......... ............ | #######... . .
D p7 .......... ............ | ########.. . .
D p8 .......... ............ | #########. . .
D p9 .......... ............ | ########## . .
B p10 ########## ............ | .......... # .
C p12 .......... ############ | .......... . #
-- 第 36 步] 槽位 0 完成请求 3(30/30 个标记):从前,在一个遥远的地方,阳光明媚,月亮也照耀着。太阳……
++ [第 36 步] 槽位 0 <- 接纳请求 4(最大 180 个标记):'量子计算与经典计算的不同之处在于'
-- 第 50 步] 槽位 1 完成请求 1(50/50 个标记):今天的天气如此寒冷,以至于很难看到太阳。但情况并非我们想象的那样……
++ [第 50 步] 槽位 1 <- 接纳请求 5(最大 45 个标记):'罗马帝国的历史始于'
-- 第 95 步] 槽位 1 完成请求 5(45/45 个标记):罗马帝国的历史始于公元前 4 世纪,随着……
-- 第 216 步] 槽位 0 完成请求 4(180/180 个标记):量子计算与经典计算的不同之处在于,它基于“量子比特”这一概念。
-- 第 300 步] 槽位 2 完成请求 2(300/300 个标记):在机器学习中,Transformer 是一种机器学习算法,可用于……
连续批处理耗时:9.54 秒
你可以观察到每一步中批处理如何变化(相应地,因果掩码也随之改变)。注意,注意力操作的时间复杂度为 $O(N^2)$,其中 $N$ 表示输入标记数量。连续批处理生成时间显著缩短,因为它消除了所有输入中的填充标记,从而提升了生成效率。
进一步阅读
以下是一些你可能觉得有用的相关资源:
- 连续批处理如何在 LLM 推理中实现 23 倍吞吐量提升,同时降低 p50 延迟,作者:Daniel 等人,Anyscale 博客,2023 年
- 静态批处理、动态批处理与连续批处理,LLM 推理手册
- LLM 服务(1):连续批处理,作者:Ludovico Bessi,2025 年
- LLM 推理:连续批处理与 PagedAttention,作者:Insu Jang,2024 年
- LLM 服务中的存在性问题,作者:Kukil,2025 年
- LLM 服务:为何如此困难?,作者:Or Zipori,2026 年
- 连续批处理:通过优化提升 LLM 推理吞吐量,作者:Michael Brenndoerfer,2026 年
- 模型执行与推理流程,vLLM 文档
- vLLM 在线服务场景下的连续批处理技术是否包含“批大小”概念?,GitHub 上 vLLM 问题 #2257,2023 年
- 连续批处理,作者:Reboul 等人,Hugging Face 博客,2025 年
- PagedAttention vs 连续批处理 vs vLLM vs SGLang — 实用对比解析,作者:Varun Rao,2025 年
总结
本文中,我们探讨了静态批处理的两个核心问题:一是短提示需等待同批次中更长的提示完成,二是 GPU 资源被填充 token 浪费。随后,我们构建了一个可行的解决方案,结合动态调度与不规则批处理(ragged batching),以更高效地利用每一分 GPU 计算周期,使其专注于真实 token 的处理。
##### 尚无评论。