Executor and Worker in vLLM

The executor folder defines the different executor types, such as NeuronExecutor, CPUExecutor, RayGPUExecutor, and GPUExecutor.

The executor class to use is determined in LLMEngine.from_engine_args(). Note that parallel_config.world_size = pipeline_parallel_size * tensor_parallel_size, so whenever pipeline parallelism (PP) or tensor parallelism (TP) is enabled, world_size is greater than 1 and RayGPUExecutor must be used. For example, tensor_parallel_size=4 with pipeline_parallel_size=2 gives world_size=8.

# Initialize the cluster and specify the executor class.
if engine_config.device_config.device_type == "neuron":
    from vllm.executor.neuron_executor import NeuronExecutor
    executor_class = NeuronExecutor
elif engine_config.device_config.device_type == "cpu":
    from vllm.executor.cpu_executor import CPUExecutor
    executor_class = CPUExecutor
elif engine_config.parallel_config.worker_use_ray:
    initialize_ray_cluster(engine_config.parallel_config)
    from vllm.executor.ray_gpu_executor import RayGPUExecutor
    executor_class = RayGPUExecutor
else:
    assert engine_config.parallel_config.world_size == 1, (
        "Ray is required if parallel_config.world_size > 1.")
    from vllm.executor.gpu_executor import GPUExecutor
    executor_class = GPUExecutor

The engine is then created via engine = cls(...), and model_executor is created in its __init__():

self.model_executor = executor_class(
    model_config=model_config,
    cache_config=cache_config,
    parallel_config=parallel_config,
    scheduler_config=scheduler_config,
    device_config=device_config,
    lora_config=lora_config,
    vision_language_config=vision_language_config,
    speculative_config=speculative_config,
    load_config=load_config,
)

Next, let's study the two most commonly used executors: GPUExecutor and RayGPUExecutor.

ExecutorBase

The base class defines what an executor is: something that runs the model on a specific class of hardware, such as CPU, GPU, Neuron, or multiple devices.

Focus on the implementations of _init_executor, determine_num_available_blocks, initialize_cache, and execute_model.
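For orientation, the interface looks roughly like the following. This is a simplified sketch based on the four methods above, not a verbatim copy of vllm/executor/executor_base.py; the exact signatures may differ slightly.

from abc import ABC, abstractmethod
from typing import List, Tuple

class ExecutorBase(ABC):
    """Executes the model on a specific device type (CPU, GPU, Neuron, ...)."""

    @abstractmethod
    def _init_executor(self) -> None:
        """Create the worker(s) and load the model."""
        ...

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profile memory and return (num_gpu_blocks, num_cpu_blocks)."""
        ...

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Allocate the KV cache with the given numbers of blocks."""
        ...

    @abstractmethod
    def execute_model(self, execute_model_req) -> List["SamplerOutput"]:
        """Run one scheduled batch and return the sampler outputs."""
        ...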

GPUExecutor

def _init_executor(self) -> None:
    """Initialize the worker and load the model.

    If speculative decoding is enabled, we instead create the speculative
    worker.
    """
    if self.speculative_config is None:
        self._init_non_spec_worker()
    else:
        self._init_spec_worker()

def _init_non_spec_worker(self):
    assert self.parallel_config.world_size == 1, (
        "GPUExecutor only supports single GPU.")

    self.driver_worker = self._create_worker()
    self.driver_worker.init_device()
    self.driver_worker.load_model()

In the common case (no speculative decoding), the non-speculative worker is created: the driver_worker is instantiated first, then the device is initialized and the model is loaded.

def _create_worker(self,
                   local_rank: int = 0,
                   rank: int = 0,
                   distributed_init_method: Optional[str] = None):
    wrapper = WorkerWrapperBase(
        worker_module_name="vllm.worker.worker",
        worker_class_name="Worker",
    )
    wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
                                                  distributed_init_method))
    return wrapper.worker

In the single-GPU case, rank is 0 and distributed_init_method is likewise None. worker_class_name="Worker" specifies the class to load. wrapper.init_worker() initializes the worker held inside the wrapper, i.e. it instantiates the class named by worker_class_name.

As for why a WorkerWrapperBase class wraps the worker at all, vLLM's comments say it is so that the worker can be initialized lazily.

The Worker class is another important structure and is covered later; a minimal sketch of the wrapper pattern follows.
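To make the lazy-initialization idea concrete, here is a minimal sketch of the wrapper pattern, assuming only that the wrapper stores module and class names and imports the real class on demand (this is not vLLM's exact implementation):

import importlib

class WorkerWrapperBase:
    """Holds only module/class names; the real worker is created lazily."""

    def __init__(self, worker_module_name: str, worker_class_name: str):
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        self.worker = None  # not created yet

    def init_worker(self, **kwargs):
        # Import and instantiate the real Worker only at this point, so that
        # heavyweight imports (e.g. CUDA initialization) happen inside the
        # process or Ray actor that will actually own the worker.
        module = importlib.import_module(self.worker_module_name)
        worker_class = getattr(module, self.worker_class_name)
        self.worker = worker_class(**kwargs)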

GPUExecutor's implementations of determine_num_available_blocks, initialize_cache, and execute_model also simply call the corresponding methods on driver_worker:

def determine_num_available_blocks(self) -> Tuple[int, int]:
    """Determine the number of available KV blocks by invoking the
    underlying worker.
    """
    return self.driver_worker.determine_num_available_blocks()

def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
    """Initialize the KV cache by invoking the underlying worker.
    """
    # NOTE: This is logged in the executor because there can be >1 worker
    # with other executors. We could log in the engine level, but work
    # remains to abstract away the device for non-GPU configurations.
    logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                num_cpu_blocks)

    self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

def execute_model(
        self,
        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
    output = self.driver_worker.execute_model(execute_model_req)
    return output
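For context, this is roughly the order in which LLMEngine drives these three executor methods. It is a simplified sketch of the call sequence (in vLLM the first two calls live in LLMEngine._initialize_kv_caches() and the last one in the per-step execution path), not actual engine code:

def startup_and_step(model_executor, execute_model_req):
    # At engine startup: profile memory once and size the KV cache.
    num_gpu_blocks, num_cpu_blocks = model_executor.determine_num_available_blocks()
    model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    # On every scheduling step: run the model on the scheduled batch.
    return model_executor.execute_model(execute_model_req)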

RayGPUExecutor

When a single GPU cannot handle the inference workload, parallelism is needed. vLLM uses Ray for distributed computation.

Ray is an open-source unified framework for scaling AI and Python applications such as machine learning. It provides the compute layer for parallel processing.

First, from_engine_args() calls initialize_ray_cluster().

The ray.init() call initializes the Ray runtime (by default the dashboard runs on port 8265). If parallel_config.placement_group is not already set and the process is not running inside an existing placement group, execution falls into the else branch: num_gpus_in_cluster is the number of GPUs available in the current cluster, and it must not be fewer than what world_size requires.

A new placement group is then created via ray.util.placement_group(), with one bundle per worker and each bundle containing exactly one GPU.

A placement group reserves resources from the cluster.

def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: Optional[str] = None,
):
    """Initialize the distributed cluster with Ray.

    it will connect to the Ray cluster and create a placement group
    for the workers, which includes the specification of the resources
    for each distributed worker.

    Args:
        parallel_config: The configurations for parallel execution.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.
    """
    if ray is None:
        raise ImportError(
            "Ray is not installed. Please install Ray to use distributed "
            "serving.")

    # Connect to a ray cluster (initialize the Ray runtime).
    if is_hip():  # AMD GPUs
        ray.init(address=ray_address,
                 ignore_reinit_error=True,
                 num_gpus=parallel_config.world_size)
    else:
        ray.init(address=ray_address, ignore_reinit_error=True)

    if parallel_config.placement_group:
        # Placement group is already set.
        return

    # Create placement group for worker processes
    current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:
        # We are in a placement group
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group.
        gpu_bundles = 0
        for bundle in bundles:
            bundle_gpus = bundle.get("GPU", 0)
            if bundle_gpus > 1:
                raise ValueError(
                    "Placement group bundle cannot have more than 1 GPU.")
            if bundle_gpus:
                gpu_bundles += 1
        if parallel_config.world_size > gpu_bundles:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the placement group.")
    else:
        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
        if parallel_config.world_size > num_gpus_in_cluster:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the cluster.")
        # Create a new placement group
        placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
        current_placement_group = ray.util.placement_group(
            placement_group_specs)
        # Wait until PG is ready - this will block until all
        # requested resources are available, and will timeout
        # if they cannot be provisioned.
        ray.get(current_placement_group.ready(), timeout=1800)

    # Set the placement group in the parallel config
    parallel_config.placement_group = current_placement_group
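To see what this amounts to in isolation, here is a standalone sketch using Ray's public API for, say, world_size=4 (outside of vLLM; the numbers are made up for illustration):

import ray

ray.init()  # the dashboard runs on port 8265 by default

world_size = 4  # e.g. tensor_parallel_size=4, pipeline_parallel_size=1

# One bundle per worker, each bundle reserving exactly one GPU.
pg = ray.util.placement_group([{"GPU": 1}] * world_size)

# Blocks until the cluster can actually provide all four GPUs.
ray.get(pg.ready(), timeout=1800)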

Next, the RayGPUExecutor object is instantiated. It inherits from DistributedGPUExecutor, which in turn inherits from GPUExecutor.

The parallel GPU workers are created by the call to _init_workers_ray(placement_group):

def _init_executor(self) -> None:
    assert (not self.speculative_config
            ), "Speculative decoding not yet supported for RayGPU backend."

    assert self.parallel_config.worker_use_ray
    placement_group = self.parallel_config.placement_group

    # Disable Ray usage stats collection.
    ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
    if ray_usage != "1":
        os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

    # Create the parallel GPU workers.
    self._init_workers_ray(placement_group)

    self.forward_dag = None
    if USE_RAY_COMPILED_DAG:
        self.forward_dag = self._compiled_ray_dag()

In the loop over placement_group.bundle_specs inside _init_workers_ray(), each bundle gets a PlacementGroupSchedulingStrategy as its resource scheduling strategy; ray.remote() takes that strategy and produces a remote actor class, and (RayWorkerWrapper).remote(...) instantiates the remote actor (the underlying worker is not yet initialized).

The IP address of the node each worker lives on is then fetched via worker.get_node_ip.remote() and compared with the IP of the host running the service, in order to designate one worker as the driver worker.

Two points are worth emphasizing:

First, self.workers and self.driver_dummy_worker are both RayWorkerWrapper instances, which inherit from WorkerWrapperBase. As the comments note, the wrapper class around the worker exists to defer worker initialization, so initializing the wrapper must be distinguished from initializing the real worker inside it.

Second, when setting up the driver, driver_dummy_worker is assigned the remote actor directly (although, since its IP matches the head node's IP, it is effectively still a local worker), while driver_worker is constructed as a plain RayWorkerWrapper. My understanding is that the plain driver_worker uses GPU resources on the local node, and the dummy remote worker's resources are also on the local node, so there is no substantive difference between the two.

def _init_workers_ray(self, placement_group: "PlacementGroup",
                      **ray_remote_kwargs):
    ......
    self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
    # The remaining workers are the actual ray actors.
    self.workers: List[RayWorkerWrapper] = []
    # Create the workers.
    driver_ip = get_ip()
    for bundle_id, bundle in enumerate(placement_group.bundle_specs):
        if not bundle.get("GPU", 0):
            continue
        scheduling_strategy = PlacementGroupSchedulingStrategy(
            placement_group=placement_group,
            placement_group_capture_child_tasks=True,
            placement_group_bundle_index=bundle_id,
        )
        worker = ray.remote(
            num_cpus=0,
            num_gpus=num_gpus,
            scheduling_strategy=scheduling_strategy,
            **ray_remote_kwargs,
        )(RayWorkerWrapper).remote(
            worker_module_name="vllm.worker.worker",
            worker_class_name="Worker",
            trust_remote_code=self.model_config.trust_remote_code,
        )

        worker_ip = ray.get(worker.get_node_ip.remote())
        if worker_ip == driver_ip and self.driver_dummy_worker is None:
            # If the worker is on the same node as the driver, we use it
            # as the resource holder for the driver process.
            self.driver_dummy_worker = worker
            self.driver_worker = RayWorkerWrapper(
                worker_module_name="vllm.worker.worker",
                worker_class_name="Worker",
                trust_remote_code=self.model_config.trust_remote_code,
            )
        else:
            # Else, added to the list of workers.
            self.workers.append(worker)

    # Get the set of GPU IDs used on each node.
    worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                use_dummy_driver=True)

    node_workers = defaultdict(list)  # key: node id, value: indices of the workers running on that node
    node_gpus = defaultdict(list)     # key: node id, value: GPU ids in use on that node

    for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
        node_workers[node_id].append(i)
        node_gpus[node_id].extend(gpu_ids)
    for node_id, gpu_ids in node_gpus.items():
        node_gpus[node_id] = sorted(gpu_ids)

    ......
    # Initialize the actual workers inside worker wrapper.
    init_worker_all_kwargs = [
        self._get_worker_kwargs(
            local_rank=node_workers[node_id].index(rank),  # this rank's index within node_workers[node_id], i.e. the worker's local rank on its node
            rank=rank,  # global rank: the worker's index across the whole cluster
            distributed_init_method=distributed_init_method,
        ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
    ]
    self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

    self._run_workers("init_device")
    self._run_workers("load_model",
                      max_concurrent_workers=self.parallel_config.
                      max_parallel_loading_workers)

The last three calls go through the unified execution interface _run_workers(), passing the name of the method to run: init_worker, init_device, and load_model. Inside _run_workers(), the method is first launched on the Ray workers, then executed on the driver worker, and finally the outputs are merged.
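A simplified sketch of that dispatch pattern, written as a free function over hypothetical driver_worker and workers objects (vLLM's real _run_workers also handles per-worker kwargs, concurrency limits, and the compiled-DAG path):

import ray

def run_workers(driver_worker, workers, method: str, *args, **kwargs):
    """Simplified version of the _run_workers dispatch pattern (sketch only)."""
    # 1. Kick off the method asynchronously on every remote Ray actor.
    ray_futures = [
        w.execute_method.remote(method, *args, **kwargs) for w in workers
    ]
    # 2. Run the same method synchronously on the local driver worker.
    driver_output = driver_worker.execute_method(method, *args, **kwargs)
    # 3. Wait for the remote results and merge them with the driver's output.
    return [driver_output] + ray.get(ray_futures)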

Summary

The structure can be understood as three layers: the Executor layer on top, the Wrapper layer in the middle, and the Worker layer at the bottom. Concrete model operations such as init_device and load_model are all controlled inside the Worker.

Worker

In the previous part we followed the call chain through the two executors I consider most common: GPUExecutor and RayGPUExecutor. Both delegate the concrete execution and scheduling of the model to the Worker class, which we analyze next.

During executor initialization, the operations performed on the Worker are: init_worker, init_device, and load_model.

  • init_worker first instantiates a ModelRunner object; the cache engine, which is responsible for KV-cache management, is left uninitialized at this point, and gpu_cache will hold the KV-cache storage (a sketch of that later initialization step follows the constructor code below)
  • init_device has nothing noteworthy
  • load_model calls model_runner.load_model()
def __init__(
    self,
    model_config: ModelConfig,
    parallel_config: ParallelConfig,
    scheduler_config: SchedulerConfig,
    device_config: DeviceConfig,
    cache_config: CacheConfig,
    load_config: LoadConfig,
    local_rank: int,
    rank: int,
    distributed_init_method: str,
    lora_config: Optional[LoRAConfig] = None,
    vision_language_config: Optional[VisionLanguageConfig] = None,
    is_driver_worker: bool = False,
) -> None:
    ......
    self.model_runner = ModelRunner(
        model_config,
        parallel_config,
        scheduler_config,
        device_config,
        load_config=load_config,
        lora_config=self.lora_config,
        kv_cache_dtype=self.cache_config.cache_dtype,
        is_driver_worker=is_driver_worker,
        vision_language_config=vision_language_config,
    )
    # Uninitialized cache engine. Will be initialized by
    # initialize_cache.
    self.cache_engine: CacheEngine
    self.gpu_cache: List[torch.Tensor]
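The cache_engine and gpu_cache attributes are only declared here; they are filled in later, when the executor calls initialize_cache on the worker. A rough sketch of what that step amounts to inside the worker, assuming the CacheEngine constructor takes the cache/model/parallel configs (simplified; the real code also validates block counts and warms up the model):

def _init_cache_engine(self):
    # Allocate the KV cache on GPU (and the CPU swap space) now that the
    # number of blocks is known from determine_num_available_blocks().
    self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                    self.parallel_config)
    self.gpu_cache = self.cache_engine.gpu_cache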

For world_size=1, i.e. single-GPU inference, load_model has nothing noteworthy. But I am curious how, in the multi-GPU TP case, the model parameters are distributed evenly across the GPUs. That code lives in the model_loader folder and I don't fully understand it yet, so I'll leave it as a hole to fill later.
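As a placeholder intuition until that hole is filled (this is a generic Megatron-style illustration, independent of what vLLM's model_loader actually does): each tensor-parallel rank typically loads only a slice of each large weight matrix, so the parameters end up split evenly across the GPUs. A toy example:

import torch

def load_column_parallel_weight(full_weight: torch.Tensor,
                                tp_rank: int, tp_size: int) -> torch.Tensor:
    """Return the slice of a [out_features, in_features] weight owned by tp_rank."""
    out_features = full_weight.shape[0]
    assert out_features % tp_size == 0
    shard = out_features // tp_size
    # Rank i keeps rows [i * shard, (i + 1) * shard): each GPU holds 1/tp_size
    # of the parameters, so memory is split evenly across the TP group.
    return full_weight[tp_rank * shard:(tp_rank + 1) * shard].clone()

# Example: an 8x4 weight split across tp_size=2 ranks -> each rank holds 4x4.
w = torch.randn(8, 4)
w0 = load_column_parallel_weight(w, tp_rank=0, tp_size=2)
w1 = load_column_parallel_weight(w, tp_rank=1, tp_size=2)
assert w0.shape == (4, 4) and w1.shape == (4, 4)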

The next section analyzes the KV-cache-related initialization, whose code starts at class LLMEngine -> __init__() -> self._initialize_kv_caches().