In the previous article, I explored why vLLM is gaining popularity and the process of setting up an OpenAI-compatible server when using vllm serve. While the first article focused on the architectural foundations and server initialization process, in this article, I want to dive deeper into the runtime behavior and request processing pipeline.
The /v1/chat/completions endpoint has become the de facto standard for conversational AI applications, powering everything from customer service chatbots to sophisticated AI assistants. Unlike the legacy /v1/completions endpoint, which operates on simple text completion, the chat completions endpoint provides structured message handling, role-based conversations, and built-in context management.
Through this deep dive, I’ll walk you through:
Endpoint Comparison: Detailed comparison between /v1/completions and /v1/chat/completions
Request Processing: Step-by-step breakdown of how chat messages are preprocessed and transformed
Chat Template System: How vLLM applies model-specific chat templates to structure conversations
Internal Pipeline: Deep dive into the inference process, from message parsing to response generation
Performance Considerations: Understanding token efficiency and memory management in chat contexts
By examining vLLM’s implementation of the OpenAI-compatible chat completions endpoint, I’ll uncover the sophisticated engineering that enables high-performance conversational AI serving while maintaining full API compatibility.
Theoretical Background
/v1/completions vs. /v1/chat/completions
As seen in the previous article, the OpenAI-compatible server exposes the two endpoints shown below.
$ vllm serve Qwen/Qwen3-0.6B --max-model-len 8192
...
INFO 06-09 23:16:17 [launcher.py:36] Route: /v1/chat/completions, Methods: POST
INFO 06-09 23:16:17 [launcher.py:36] Route: /v1/completions, Methods: POST
...
Let me walk you through the differences between these two endpoints.
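The request body itself is not reproduced here; an illustrative request (the prompt and max_tokens values are my assumptions, chosen to match the shape of the response shown next) looks like this:

import requests

# Illustrative request; the exact prompt behind the response below is not
# included in the original post.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "Qwen/Qwen3-0.6B",
        "prompt": "Hello, my name is",  # assumed prompt
        "max_tokens": 16,
    },
)
print(resp.json())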
{ "id":"cmpl-bc9fa340e282468eb41d47ea9db57bfd", "object":"text_completion", "created":1750076839, "model":"Qwen/Qwen3-0.6B", "choices":[ { "index":0, "text":" My name is Alex. I am a software engineer with a passion for coding and", "logprobs":null, "finish_reason":"length", "stop_reason":null, "prompt_logprobs":null } ], "usage":{ "prompt_tokens":4, "total_tokens":20, "completion_tokens":16, "prompt_tokens_details":null }, "kv_transfer_params":null }
As a result, it responds with a plain continuation of the input "prompt" rather than a chat-style response.
In contrast, /v1/chat/completions applies a chat template to the structured messages in the request and feeds the rendered prompt to the LLM.
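Again as an illustration (the single "Hello, World!" user message is inferred from the response below; the original request body is not shown):

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "Qwen/Qwen3-0.6B",
        "messages": [{"role": "user", "content": "Hello, World!"}],
    },
)
print(resp.json())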
{ "id":"chatcmpl-dab79c6ebcb24ff58b4e032f6f83b888", "object":"chat.completion", "created":1750076956, "model":"Qwen/Qwen3-0.6B", "choices":[ { "index":0, "message":{ "role":"assistant", "reasoning_content":null, "content":"<think>\nOkay, the user said \"Hello, World!\" and I need to respond. First, I should acknowledge their message. Since it's a simple greeting, a straightforward response is best. I can say \"Hello, World!\" as well, but maybe add a friendly note to keep it engaging. Let me check if there's any context I'm missing, but the message is pretty basic. Just a greeting. Alright, I'll respond with a friendly message to reinforce the exchange.\n</think>\n\nHello, World! 😊 What's interesting about you?", "tool_calls":[] }, "logprobs":null, "finish_reason":"stop", "stop_reason":null } ], "usage":{ "prompt_tokens":12, "total_tokens":125, "completion_tokens":113, "prompt_tokens_details":null }, "prompt_logprobs":null, "kv_transfer_params":null }
As a result, the response comes back in chat format. Unless a separate --chat-template option is specified, the template applied here is the chat_template defined in the model's tokenizer_config.json.
Qwen/Qwen3-0.6B/tokenizer_config.json
... "chat_template":"{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if 
enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}", ...
Chat template testing can be performed as follows:
>>> import transformers
>>> tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
>>> messages = [
...     { "role": "system", "content": "You are a helpful assistant." },
...     { "role": "user", "content": "What is the capital of France?" },
...     { "role": "assistant", "content": "The capital of France is Paris." },
...     { "role": "user", "content": "Tell me more about it." }
... ]
>>> print(tokenizer.apply_chat_template(messages, tokenize=False))
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is the capital of France?<|im_end|>
<|im_start|>assistant
The capital of France is Paris.<|im_end|>
<|im_start|>user
Tell me more about it.<|im_end|>
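For generation, vLLM's chat endpoint renders the template with add_generation_prompt=True by default, which, per the template above, appends the assistant header so the model knows to start a new turn. Continuing the same session:

>>> print(tokenizer.apply_chat_template(messages, tokenize=False,
...                                     add_generation_prompt=True))
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is the capital of France?<|im_end|>
<|im_start|>assistant
The capital of France is Paris.<|im_end|>
<|im_start|>user
Tell me more about it.<|im_end|>
<|im_start|>assistant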
Request/Response Schema of /v1/chat/completions
Now that I understand the fundamental differences between the endpoints, let me examine the detailed structure of the /v1/chat/completions request and response schemas. Understanding these schemas is crucial for effective API integration and troubleshooting, as they define the contract between client applications and vLLM’s serving infrastructure.
My analysis here is based on vLLM’s source code implementation, providing insights into both OpenAI-compatible fields and vLLM-specific extensions that enhance functionality beyond the standard API specification.
Request Schema
The ChatCompletionRequest class in vLLM implements the complete OpenAI Chat Completions API specification while adding several vLLM-specific extensions for advanced sampling and optimization features.
The schema is carefully organized to match the official OpenAI API documentation order, ensuring maximum compatibility with existing OpenAI client libraries and tools.
| Field | Type | Required | Default | Description |
|---|---|---|---|---|
| model | Optional[str] | ❌ | None | Model name to use (made optional by vllm-project/vllm#13568) |
| frequency_penalty | Optional[float] | ❌ | 0.0 | Frequency-based token penalty (-2.0 ~ 2.0) |
| logit_bias | Optional[dict[str, float]] | ❌ | None | Bias for specific tokens' logits |
| logprobs | Optional[bool] | ❌ | False | Whether to return log probabilities |
| top_logprobs | Optional[int] | ❌ | 0 | Number of top log probabilities to return (0-20) |
| max_tokens | Optional[int] | ❌ | None | Maximum number of tokens to generate |
| n | Optional[int] | ❌ | 1 | Number of completions to generate |
| presence_penalty | Optional[float] | ❌ | 0.0 | Presence-based token penalty (-2.0 ~ 2.0) |
| response_format | Optional[AnyResponseFormat] | ❌ | None | Response format specification (JSON mode) |
| seed | Optional[int] | ❌ | None | Seed for reproducible output |
| stop | Optional[Union[str, list[str]]] | ❌ | [] | Stop strings for generation |
| stream | Optional[bool] | ❌ | False | Whether to stream responses |
| temperature | Optional[float] | ❌ | None | Sampling temperature (0.0 ~ 2.0) |
| top_p | Optional[float] | ❌ | None | Nucleus sampling probability |
| tools | Optional[list[ChatCompletionToolsParam]] | ❌ | None | Function call tool definitions |
| tool_choice | Optional[Union[Literal, NamedToolChoice]] | ❌ | "none" | Tool selection strategy |
| user | Optional[str] | ❌ | None | User identifier |
| best_of | Optional[int] | ❌ | None | Number of generations to select best from |
| use_beam_search | bool | ❌ | False | Whether to use beam search |
| top_k | Optional[int] | ❌ | None | Consider only top-k tokens |
| min_p | Optional[float] | ❌ | None | Minimum probability threshold |
| repetition_penalty | Optional[float] | ❌ | None | Repetition penalty |
| min_tokens | int | ❌ | 0 | Minimum number of tokens to generate |
| skip_special_tokens | bool | ❌ | True | Whether to skip special tokens in output |
| spaces_between_special_tokens | bool | ❌ | True | Whether to add spaces between special tokens |
| truncate_prompt_tokens | Optional[int] | ❌ | None | Truncate prompt to specified token count |
| prompt_logprobs | Optional[int] | ❌ | None | Number of prompt log probabilities to return |
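Because the vLLM-specific fields above live directly on ChatCompletionRequest, they can be sent in the request body alongside the standard OpenAI fields; with the official openai Python client they would go through extra_body. A sketch with arbitrary values:

import requests

payload = {
    "model": "Qwen/Qwen3-0.6B",
    "messages": [{"role": "user", "content": "Write a haiku about GPUs."}],
    # standard OpenAI fields
    "temperature": 0.7,
    "top_p": 0.9,
    "max_tokens": 128,
    # vLLM extensions beyond the OpenAI spec (values are arbitrary examples)
    "top_k": 20,
    "min_p": 0.05,
    "repetition_penalty": 1.05,
    "min_tokens": 8,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])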
Message Object
The message object structure supports both simple text conversations and complex multimodal interactions. vLLM extends the standard OpenAI message format to support custom roles and enhanced tool integration.
...
class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""

    role: Required[str]
    """The role of the message's author."""

    content: Union[str, list[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of
    the same role.
    """

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""
Response Schema
The response schema follows the OpenAI specification closely while incorporating vLLM-specific enhancements for advanced use cases like KV caching optimization and detailed logging.
| Field | Type | Description |
|---|---|---|
| object | str | Object type (chat.completion or chat.completion.chunk) |
| created | int | Creation time represented as Unix timestamp |
| model | str | Model name used |
| choices | list[ChatCompletionResponseChoice] | Array of generated completion choices |
| usage | UsageInfo | Token usage information |
| prompt_logprobs | Optional[list[Optional[dict[int, Logprob]]]] | Prompt log probability information |
| kv_transfer_params | Optional[dict[str, Any]] | KVTransfer parameters |
Choice Object
Each choice represents a single completion generated by the model. The choice object contains the actual generated content along with metadata about the generation process.
vllm/entrypoints/openai/protocol.py
...
class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    # per OpenAI spec this is the default
    finish_reason: Optional[str] = "stop"
    # not part of the OpenAI spec but included in vLLM for legacy reasons
    stop_reason: Optional[Union[int, str]] = None
...
The stop_reason field is a vLLM legacy field outside the OpenAI spec; it provides information similar to finish_reason.
Usage Object
The usage object provides detailed token consumption metrics, essential for billing, monitoring, and optimization purposes.
vllm/entrypoints/openai/protocol.py
class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
    prompt_tokens_details: Optional[PromptTokenUsageInfo] = None
| Field | Type | Description |
|---|---|---|
| prompt_tokens | int | Number of tokens used in prompt |
| total_tokens | int | Total tokens (prompt + completion) |
| completion_tokens | Optional[int] | Number of tokens generated in completion |
| prompt_tokens_details | Optional[PromptTokenUsageInfo] | Detailed prompt token usage information |
Router
vLLM’s OpenAI-compatible server is built on FastAPI, providing a robust and high-performance web framework for serving LLM requests. When a user sends a POST request to /v1/chat/completions, FastAPI’s routing system directs the request to the following function, which serves as the main entry point for chat completion requests.
I can see that the handler is retrieved through the chat() function, which returns the openai_serving_chat instance that was registered in app.state during server initialization, as sketched below.
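The route handler itself is not reproduced here; in vllm/entrypoints/openai/api_server.py it follows roughly the pattern below. This is an abridged sketch written from memory, not a verbatim copy, and details vary by version:

from typing import Optional

from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse, StreamingResponse

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionResponse,
                                              ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

router = APIRouter()


def chat(request: Request) -> Optional[OpenAIServingChat]:
    # The handler instance is stored on app.state by init_app_state()
    # during server startup.
    return request.app.state.openai_serving_chat


@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    handler = chat(raw_request)

    generator = await handler.create_chat_completion(request, raw_request)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if isinstance(generator, ChatCompletionResponse):
        return JSONResponse(content=generator.model_dump())

    # Streaming case: an async generator of "data: ..." SSE strings.
    return StreamingResponse(content=generator, media_type="text/event-stream")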
The Request object is a class included in the Starlette framework, and it inherits the app property from its parent class HTTPConnection. This design provides access to the application state and configuration throughout the request lifecycle.
starlette/requests.py
...
class Request(HTTPConnection):
    ...
The app property provides access to the FastAPI application instance, while scope contains ASGI (Asynchronous Server Gateway Interface) information about the current request. This architecture follows the ASGI specification, enabling efficient handling of asynchronous web requests.
starlette/requests.py
...
class HTTPConnection(Mapping[str, Any]):
    """
    A base class for incoming HTTP connections, that is used to provide
    any functionality that is common to both `Request` and `WebSocket`.
    """
    ...
    @property
    def app(self) -> Any:
        return self.scope["app"]
...
Application State Initialization
Looking at the initialization of state.openai_serving_chat, it takes place in the init_app_state() function during server startup, ensuring that all necessary components are ready before any request is handled.
The app.state mechanism can be tested with the following example. This demonstrates how FastAPI’s application state works in practice and how components are shared across request handlers.
from random import random
from typing import Optional

import uvicorn
import uvloop
from fastapi import FastAPI, Request
from fastapi.datastructures import State
from loguru import logger
from pydantic import BaseModel
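The rest of the example was omitted above; a reconstruction along the following lines (hypothetical, not the author's original code) produces logs like the ones shown next:

# (continuing the imports above; this body is a reconstruction written to
# reproduce the logged behavior, not the original example)


class ChatCompletionRequest(BaseModel):
    model: Optional[str] = None
    messages: list[dict] = []


class OpenAIServingChat:
    def __init__(self) -> None:
        logger.info("Init: OpenAIServingChat")

    async def create_chat_completion(self, request: ChatCompletionRequest) -> dict:
        logger.info("Run: OpenAIServingChat.create_chat_completion")
        return {"id": f"chatcmpl-{random()}", "object": "chat.completion", "choices": []}


def init_app_state(state: State) -> None:
    # Mimics vLLM: the handler is registered on app.state before startup.
    state.openai_serving_chat = OpenAIServingChat()


app = FastAPI()
init_app_state(app.state)


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
    logger.info(f"raw_request={raw_request}")
    handler = raw_request.app.state.openai_serving_chat
    return await handler.create_chat_completion(request)


if __name__ == "__main__":
    uvloop.install()
    uvicorn.run(app, host="0.0.0.0", port=8000)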
Examining the server logs reveals the initialization sequence: the OpenAIServingChat instance is initialized before FastAPI starts running. When a request arrives, the handler is retrieved from request.app.state.openai_serving_chat and executed.
This pattern demonstrates FastAPI’s application lifecycle management, where:
Initialization Phase: Critical components are set up during server startup
Request Phase: Pre-initialized components are accessed through the application state
Processing Phase: The actual request handling occurs with the retrieved handler
2025-06-16 23:38:46.972 | INFO | __main__:__init__:16 - Init: OpenAIServingChat
INFO: Started server process [52024]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
2025-06-16 23:38:49.021 | INFO | __main__:create_chat_completion:38 - raw_request=<starlette.requests.Request object at 0x105a80a50>
2025-06-16 23:38:49.021 | INFO | __main__:create_chat_completion:19 - Run: OpenAIServingChat.create_chat_completion
INFO: 127.0.0.1:61279 - "POST /v1/chat/completions HTTP/1.1" 200 OK
vllm/entrypoints/openai/serving_chat.py
...
class OpenAIServingChat(OpenAIServing):
    ...
    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
               ErrorResponse]:
        """
        Chat Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        Chat Completion API.
        """
        ...
Chat Completion Processing Pipeline
As I observed in the router’s create_chat_completion() function above, all preprocessing, LLM inference, and postprocessing for /v1/chat/completions requests are performed within the following method.
vllm/entrypoints/openai/serving_chat.py
...
class OpenAIServingChat(OpenAIServing):
    ...
    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
               ErrorResponse]:
        """
        Chat Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        Chat Completion API.
        """
        ...
How does the complete processing flow work? Let’s examine the step-by-step process:
vllm/entrypoints/openai/serving_chat.py
...
class OpenAIServingChat(OpenAIServing):
    ...
    async def create_chat_completion(
        ...
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret
        ...
Model Validation: The OpenAIServing._check_model() method validates that the request’s "model" name is correctly configured.
vllm/entrypoints/openai/serving_chat.py
...
class OpenAIServingChat(OpenAIServing):
    ...
    async def create_chat_completion(
        ...
        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error
        ...
Engine Health Check: If the engine process has died, its dead_error is raised immediately. As the comment explains, this matters for streaming requests, where a success status is returned before text generation actually begins.
Mistral Tokenizer Handling: As of v0.9.0.1, there are Pydantic-related issues with MistralTokenizer (vllm-project/vllm#9951, pydantic/pydantic#9467, pydantic/pydantic#9541) that require the special handling shown below.
vllm/entrypoints/openai/serving_chat.py
...
class OpenAIServingChat(OpenAIServing):
    ...
    async def create_chat_completion(
        ...
        try:
            ...
            if isinstance(tokenizer, MistralTokenizer):
                # because of issues with pydantic we need to potentially
                # re-serialize the tool_calls field of the request
                # for more info: see comment in `maybe_serialize_tool_calls`
                maybe_serialize_tool_calls(request)
                truncate_tool_call_ids(request)
                validate_request_params(request)
        ...
Tool Configuration: When the request’s tool_choice is "auto", it undergoes validation and generates tool_dicts.
Content Format and Conversation Setup: Prepares resolved_content_format (the content format for chat templates, derived from the tools and model configuration), conversation (the parsed conversation messages, including multimodal data handling), and mm_data_future (a future for asynchronous multimodal data processing), then merges the user-specified chat_template_kwargs into _chat_template_kwargs, the internal chat template configuration dictionary.
Chat Template Application: The request_prompt is obtained based on the tokenizer type: models using MistralTokenizer go through the apply_mistral_chat_template() function, while all other models use the apply_hf_chat_template() function to generate the request_prompt.
Process tool parsing if enabled: When a tool parser is configured and the tool choice is not "none", the system determines whether tool parsing should be performed. If tools are being used, the request is adjusted through the tool parser to handle function calling capabilities. This step ensures that the model can correctly interpret and respond to tool-related requests.
vllm/entrypoints/openai/serving_engine.py
...
class OpenAIServing:
    ...
    async def _preprocess_chat(
        ...
        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the LLM
        should_parse_tools = tool_parser is not None and (hasattr(
            request, "tool_choice") and request.tool_choice != "none")

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)

            request = tool_parser(tokenizer).adjust_request(  # type: ignore
                request=request)
        ...
Tokenize the request prompt: Convert the string-based prompt into token format for model processing. For string prompts, the system uses asynchronous tokenization with optional prompt truncation and special token handling through the OpenAIServing._tokenize_prompt_input_async() method, which performs tokenization in a thread pool to prevent blocking the main event loop. For MistralTokenizer, token IDs are already provided, so the system creates a TextTokensPrompt object containing both the decoded text and the token IDs.
...
class OpenAIServing:
    ...
    async def _preprocess_chat(
        ...
        if isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids")
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt)
        ...
Create the engine prompt: Construct the final EngineTokensPrompt object that will be passed to the inference engine. This includes the tokenized prompt, multimodal data (if present), multimodal processor kwargs, and cache salt for caching optimization. The function returns the processed conversation, request prompt, and engine prompt for the next stage of processing.
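As a rough mental model of the structure just described, the engine prompt can be pictured as a TypedDict along these lines. This is a hypothetical simplification, not vLLM's actual definition; see vLLM's own input TypedDicts for the authoritative fields:

from typing import Any, TypedDict

from typing_extensions import NotRequired


class EngineTokensPromptSketch(TypedDict):
    """Hypothetical simplification of the engine prompt fields described above."""

    prompt_token_ids: list[int]
    multi_modal_data: NotRequired[Any]                # only for multimodal requests
    mm_processor_kwargs: NotRequired[dict[str, Any]]  # multimodal processor kwargs
    cache_salt: NotRequired[str]                      # cache salt for caching optimization


# Illustrative instance for a text-only request (token ids are made up).
engine_prompt: EngineTokensPromptSketch = {"prompt_token_ids": [151644, 872, 198]}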
Inference is performed through the OpenAIServingChat(OpenAIServing).engine_client.generate() method. In this document, I’m using AsyncLLM(EngineClient) as the engine_client, so let me examine the AsyncLLM(EngineClient).generate() method.
Engine Client
Initialize the output handler: the AsyncLLM(EngineClient)._run_output_handler() method is called to create and store the AsyncLLM(EngineClient).output_handler background task.
...
class AsyncLLM(EngineClient):
    ...
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the Detokenizer.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of generate() iterates the returned AsyncGenerator,
        returning the RequestOutput back to the caller.
        """
        try:
            # We start the output_handler on the first call to generate() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()
            ...
The output_handler executes in the following order:
Pull EngineCoreOutputs from the EngineCore: Continuously polls the engine core for outputs using await engine_core.get_output_async() and processes them in chunks to avoid blocking the event loop.
Process EngineCoreOutputs: Each output chunk is processed through output_processor.process_outputs() which converts raw engine outputs into formatted request outputs and pushes them to appropriate async streams.
Handle request aborts: Processes any requests that need to be aborted due to stop strings or other completion conditions via await engine_core.abort_requests_async().
Performance logging: Records scheduler statistics and iteration metrics for monitoring and debugging purposes.
...
class AsyncLLM(EngineClient):
    ...
    def _run_output_handler(self):
        """Background loop: pulls from EngineCore and pushes to AsyncStreams."""

        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        stat_loggers = self.stat_loggers if log_stats else None

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    num_outputs = len(outputs.outputs)

                    iteration_stats = IterationStats() if (
                        log_stats and num_outputs) else None

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
                        slices = (outputs.outputs, )
                    else:
                        slices = np.array_split(
                            outputs.outputs,
                            cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))

                    for i, outputs_slice in enumerate(slices):
                        # 2) Process EngineCoreOutputs.
                        processed_outputs = output_processor.process_outputs(
                            outputs_slice, outputs.timestamp, iteration_stats)
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed_outputs.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if i + 1 < len(slices):
                            await asyncio.sleep(0)

                    # 3) Abort any reqs that finished due to stop strings.
                    await engine_core.abort_requests_async(
                        processed_outputs.reqs_to_abort)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if stat_loggers:
                        assert outputs.scheduler_stats is not None
                        AsyncLLM._record_stats(
                            stat_loggers[outputs.engine_index],
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)
AsyncLLM(EngineClient).add_request() operates as follows:
Process input and create request: Converts the input prompt and parameters into an internal request object using self.processor.process_inputs(), which handles tokenization, parameter validation, and request formatting.
Send request to core engine: The AsyncLLM(EngineClient)._add_request() method calls the AsyncMPClient(MPClient).add_request_async() method to send an EngineCoreRequestType.ADD request to the core engine, enabling asynchronous communication between the client and the engine process for efficient request queuing and processing.
Process the request through the busy loop: The submitted request is picked up by EngineCoreProc via the busy loop shown below and handed to the EngineCoreProc(EngineCore).scheduler.
...
class EngineCore:
    """Inner loop of vLLM's Engine."""
    ...
    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()
            # 2) Step the engine core and return the outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed."""

        waited = False
        while not self.engines_running and not (self.scheduler.has_requests()):
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                waited = True
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
            logger.debug("EngineCore loop active.")

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

    def _process_engine_step(self):
        """Called only when there are unfinished local requests."""

        # Step the engine core.
        outputs = self.step_fn()
        # Put EngineCoreOutputs into the output queue.
        if outputs is not None:
            self.output_queue.put_nowait(outputs)
...
class EngineCore:
    """Inner loop of vLLM's Engine."""
    ...
    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler."""

        if request.mm_hashes is not None:
            # Here, if hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                request.mm_inputs, request.mm_hashes)

        ...  # (the EngineCoreRequest is converted into an internal Request object, req)

        if req.kv_transfer_params is not None and (
                not self.scheduler.get_kv_connector()):
            logger.warning("Got kv_transfer_params, but no KVConnector found. "
                           "Disabling KVTransfer for this request.")

        self.scheduler.add_request(req)
    ...
Within the busy loop, each engine step proceeds as follows:
Determine the step function and run scheduling: Based on the EngineCoreProc(EngineCore).model_executor.max_concurrent_batches value, EngineCoreProc(EngineCore).step_fn is set to one of the two methods below; when EngineCoreProc(EngineCore)._process_engine_step() runs, the chosen step function internally calls the Scheduler(SchedulerInterface).schedule() method.
Scheduling logic: The scheduler determines which requests to process next based on factors like priority, available resources, sequence length, and batching constraints. It creates batched sequences for efficient GPU utilization and manages the transition of requests between different states (waiting, running, swapped).
...
class EngineCore:
    ...
    def step(self) -> EngineCoreOutputs:
        """Schedule, execute, and make output."""

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return EngineCoreOutputs(
                outputs=[],
                scheduler_stats=self.scheduler.make_stats(),
            )
        scheduler_output = self.scheduler.schedule()
        model_output = self.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output)  # type: ignore

        return engine_core_outputs

    def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
           If a new batch is scheduled, directly return an empty engine core
           output. In other words, fulfilling the batch queue has a higher
           priority than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
           is full or no other requests can be scheduled, we block until the
           first batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        if not self.batch_queue.full():
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)

        # If no more requests can be scheduled and the job queue is not empty,
        # block until the first batch in the job queue is finished.
        # TODO(comaniac): Ideally we should peek the first batch in the
        # job queue to check if it's finished before scheduling a new batch,
        # but peeking the first element in a queue is not thread-safe,
        # so we need more work.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()
            # Blocking until the first result is available.
            model_output = future.result()
            self.batch_queue.task_done()
            engine_core_outputs = self.scheduler.update_from_output(
                scheduler_output, model_output)

        return engine_core_outputs
    ...
Executor
Execute model with scheduler output: The EngineCoreProc(EngineCore).model_executor.execute_model() method is executed using the SchedulerOutput (which contains batched sequences, execution metadata, and resource allocation information) from the Scheduler(SchedulerInterface).schedule() method output.
vllm/v1/engine/core.py
...
class EngineCore:
    ...
    def execute_model(self, scheduler_output: SchedulerOutput):
        try:
            return self.model_executor.execute_model(scheduler_output)
        except BaseException as err:
            # NOTE: This method is exception-free
            dump_engine_exception(self.vllm_config, scheduler_output,
                                  self.scheduler.make_stats())
            # Re-raise exception
            raise err
...
Send model inference request: The model inference request is sent through the UniProcExecutor(UniProcExecutorV0, Executor).collective_rpc() method.
Execute model inference: The Worker(WorkerBase) that receives the request executes the execute_model() method and performs actual model inference through the GPUModelRunner(LoRAModelRunnerMixin).execute_model() method.
...
class AsyncLLM(EngineClient):
    ...
    async def generate(
        ...
        try:
            ...
            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out
        ...
Postprocessing
The process of preparing the response that users will receive is very complex, so the code for this section has been excluded.
Buffered Response
Method Initialization
The method accepts parameters including ChatCompletionRequest, AsyncIterator[RequestOutput], request metadata, etc.
Records the current timestamp with created_time = int(time.time())
Initializes final_res: Optional[RequestOutput] = None to store the final result
Result Generation Loop
Iterates through result_generator using async for res in result_generator:
Continuously updates final_res = res to get the final output
Handles exceptions:
asyncio.CancelledError: Returns error response for client disconnection
ValueError: Returns error response with the exception message
Streaming Response
Creates final choice with appropriate finish_reason
Sets finish_reason_sent[i] = True
Chunk Creation and Yielding
Creates ChatCompletionStreamResponse chunk
Adds continuous usage stats if requested
Yields formatted chunk: f"data: {data}\n\n"
Final Usage Statistics
If include_usage is True:
Calculates total completion tokens
Creates UsageInfo with final statistics
Adds prompt token details if enabled
Yields final usage chunk
Metadata and Error Handling
Sets request_metadata.final_usage_info with aggregate usage
Exception Handling: Catches all exceptions and yields error response
Final Response: Yields "data: [DONE]\n\n" to signal completion
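To make the streaming contract concrete, here is a minimal client sketch (using the requests library and the same endpoint and model as the earlier examples; it is not taken from the article's source) that consumes chunks in exactly the format described above:

import json

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "Qwen/Qwen3-0.6B",
        "messages": [{"role": "user", "content": "Hello, World!"}],
        "stream": True,
        "stream_options": {"include_usage": True},
    },
    stream=True,
)

for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue                      # skip blank keep-alive lines
    payload = line[len("data: "):]
    if payload == "[DONE]":           # final sentinel described above
        break
    chunk = json.loads(payload)       # a ChatCompletionStreamResponse chunk
    if chunk["choices"]:              # the usage-only chunk has no choices
        delta = chunk["choices"][0]["delta"].get("content") or ""
        print(delta, end="", flush=True)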
Conclusion
This comprehensive analysis of vLLM’s /v1/chat/completions endpoint reveals the sophisticated architecture powering OpenAI-compatible inference serving. The journey from a simple HTTP request to a complete chat response involves multiple layers of abstraction, each meticulously optimized for performance, scalability, and reliability.
Below is a sequence diagram summarizing this article:
sequenceDiagram
participant Client
participant FastAPI
participant OpenAIServingChat as OpenAIServingChat(OpenAIServing)
participant AsyncLLM as AsyncLLM(EngineClient)
participant AsyncMPClient as AsyncMPClient(MPClient)
participant ZMQ as ZeroMQ
participant EngineCoreProc as EngineCoreProc(EngineCore)
participant Scheduler as Scheduler(SchedulerInterface)
participant UniProcExecutor as UniProcExecutor(UniProcExecutorV0|Executor)
participant Worker as Worker(WorkerBase)
participant GPUModelRunner as GPUModelRunner(LoRAModelRunnerMixin)
participant OutputProcessor
EngineCoreProc-->>EngineCoreProc: run_busy_loop()
Client->>FastAPI: POST /v1/chat/completions
FastAPI->>OpenAIServingChat: create_chat_completion(ChatCompletionRequest)
OpenAIServingChat->>OpenAIServingChat: _check_model, _preprocess_chat, etc.
OpenAIServingChat->>AsyncLLM: generate()
AsyncLLM->>AsyncMPClient: add_request(EngineCoreRequest)
AsyncMPClient->>ZMQ: add_request_async(EngineCoreRequest)
EngineCoreProc->>ZMQ: _handle_client_request(EngineCoreRequestType)
ZMQ-->>EngineCoreProc: add_request(EngineCoreRequest)
EngineCoreProc->>Scheduler: add_request(Request)
rect rgb(255,128,128)
note over EngineCoreProc: step_fn()
EngineCoreProc->>Scheduler: schedule()
Scheduler-->>EngineCoreProc: SchedulerOutput
EngineCoreProc->>UniProcExecutor: execute_model(SchedulerOutput)
UniProcExecutor->>Worker: collective_rpc("execute_model")
Worker->>GPUModelRunner: execute_model(SchedulerOutput)
GPUModelRunner-->>Worker: ModelRunnerOutput | IntermediateTensors
Worker-->>UniProcExecutor: ModelRunnerOutput
UniProcExecutor-->>EngineCoreProc: ModelRunnerOutput
EngineCoreProc->>Scheduler: update_from_output(SchedulerOutput, ModelRunnerOutput)
Scheduler->>EngineCoreProc: EngineCoreOutputs
end
EngineCoreProc-->>EngineCoreProc: put_nowait(EngineCoreOutputs)
EngineCoreProc->>ZMQ: process_output_socket()
rect rgb(128,128,255)
note over AsyncLLM: output_handler()
AsyncLLM->>AsyncMPClient: get_output_async()
AsyncMPClient->>ZMQ: process_outputs_socket()
ZMQ-->>AsyncLLM: EngineCoreOutputs
AsyncLLM->>OutputProcessor: process_outputs()
OutputProcessor-->>AsyncLLM: OutputProcessorOutput
end
AsyncLLM-->>OpenAIServingChat: AsyncGenerator[RequestOutput, None]
OpenAIServingChat-->>FastAPI: ChatCompletionResponse / AsyncGenerator
FastAPI-->>Client: JSONResponse / StreamingResponse
The structure turned out to be much more complex than I expected, making this article quite lengthy with many parts omitted. In future articles, I’ll take a closer look at core components like EngineCoreProc(EngineCore), Scheduler(SchedulerInterface), and GPUModelRunner(LoRAModelRunnerMixin).