基于llam2学习推理过程
基于llam2学习推理过程
在上一节学习了Transformer的搭建过程,但是对于实际推理的流程,比如参数如何加载?batch中的多个sequence如何并行地推理还不甚清楚。由于vLLM内容繁杂,故使用llama提供的generator来做推理过程的学习,顺便再重温一下llama2的架构。
分布式并行策略
- 数据并行(DP):在数据并行训练中,数据集被分割成几个碎片,每个碎片被分配到一个设备上。每个设备将持有一个完整的模型副
- 模型并行(MP):包括流水线并行(PP)和张量并行(TP)。流水线为层间并行,对模型不同的 Transformer 层间进行分割,张量并行为层内并行,对模型参数层内进行分割
推理流程
在example_chat_completion.py
的main
函数中,dialog
为要推理的句子。generator
的初始化调用build
函数
1 | tokenizer = Tokenizer(model_path=tokenizer_path) |
使用load_state_dict
来加载模型参数,算是解决了上边的疑惑之一。model
是Transformer
对象,其实就是Llama2结构。tokenizer
的作用是将句子转换成不同token对应的数字。
然后就是generate
函数开始推理。
首先是在文本长度上的处理,输入的句子长度不一致,并且推理结束时的generate长度也不一致
1 | min_prompt_len = min(len(t) for t in prompt_tokens) |
params.max_seq_len
是prompt+generate的总长度的最大值,tokens
首先按total_len
的长度初始化成全是pad_id,假设 bsz=2, total_len=6,则 tokens 为 [[-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1]]
。
然后使用prompt_tokens对其赋值。
1 | for k, t in enumerate(prompt_tokens): |
假设 prompt_tokens 为 [[22172, 11148, 3304], [[1058, 526, 366, 29973]]
,则 tokens 变成[[22172, 11148, 3304, -1, -1, -1], [1058, 526, 366, 29973, -1, -1]]
。
注意:这里的填充方式是右填充(也有左填充,区别见https://zhuanlan.zhihu.com/p/675273498)
其实这才是正式推理,它从min_promt_len
开始,继续使用上边示例,min_promt_len
=3
1 | prev_pos = 0 #上次推理的位置,初始肯定是0啦 |
for循环第一轮迭代,cur_pos = 3
,prev_pos = 0
,则把只把前三列[[22172, 11148, 3304], [[1058, 526, 366]]传给 model.forward 方法,返回值logits
的shape是(batch_size, 3, vocab_size)
在使用softmax
或argmax
时只取最后一列
假设该方法返回的 logits 中第一个句子的next token 为 29892,第二个句子的 next token 为 115,此时只更新第一个句子的第四个token,因为第二个句子的第四个token已经存在。经过第一次迭代后,tokens 变成了 [[22172, 11148, 3304, 29892, -1, -1], [1058, 526, 366, 29973, -1, -1]]
。
在此有个疑问,如果batch_size很大并且一个batch中文本长度差异很大,那么prompt_len大的句子要经过很久才开始真正的推理,但是每次都把它参与运算了,这会存在较为严重的浪费算力的情况?
上边的想法应该是错误的!!虽然很长的句子很久开始真正推理,但每次将它纳入批次进行forward生成的KV cache是有用的!不能算作算力浪费
for循环第二轮迭代时,prev_pos = 3
,cur_pos = 4
,则只把第四列(pos=3)[[29892], [29973]]传给model.forward。
在class Transformer.forward()
中,着重关注一下mask
1 | mask = None |
在prefill阶段即for循环第一轮迭代时,mask是一个大小(3, 3)的上三角矩阵,长这个样子
1 | ( [0, -inf, -inf], |
而decode阶段的seq_len=1,只计算当前输入与前边token的Attention-Score,因此不需要mask