Implement LLM Cross-engine Orchestration Patterns

In this tutorial, we introduce how to implement LLM cross-engine orchestration patterns, such as prefill-decode disaggregation, in MLC-LLM via the MicroServing API. Aiming to make disaggregated serving programmable, MicroServing provides a new RISC-style approach to designing LLM serving APIs at the sub-request level. It enables programmable cross-engine serving patterns in a few lines of Python code; in this tutorial we use three such calls: prep_recv, remote_send, and start_generate. For more information on the MicroServing API, check out https://blog.mlc.ai/2025/01/07/microserving-llm-engines.

Below is an example implementation of prefill-decode disaggregation. An LLM cross-engine orchestration pattern is implemented in a router, which dispatches an original OpenAI-style completion request to a chain of microserving API calls. In this code example, we create a subclass of Router (which includes wrappers for calling the microserving APIs) and override its translate_request function. The translate_request function takes in a request and a unique identifier of the request (request_id), and returns an AsyncGenerator of responses. We launch the CustomRouter and two engines, each with tensor parallel degree 2. Engine 0 is the prefill engine and engine 1 is the decode engine.

from mlc_llm.router import Router
from mlc_llm.protocol import openai_api_protocol
from typing import Any, AsyncGenerator
from mlc_llm.serve.entrypoints import microserving_entrypoints
from mlc_llm.interface.router import serve

import aiohttp

class CustomRouter(Router):
    async def translate_request(
        self, request: openai_api_protocol.CompletionRequest, request_id: str
    ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]:
        pass


serve(
    model="/path/to/model", # replace this with actual path
    model_lib="/path/to/model_lib", # replace this with actual path
    router_host="127.0.0.1",
    router_port=9123,
    endpoint_hosts=["127.0.0.1", "127.0.0.1"],
    endpoint_ports=[9124, 9125],
    endpoint_num_gpus=[2, 2],
    enable_prefix_cache=False,
    router_type=CustomRouter,
)

In the translate_request function, we first assign request_id to request.user, so that the request ID is later passed along as an argument in the microserving API calls.

# we will pass request_id as an argument in microserving API calls
request.user = request_id

Next, we call prep_recv on the decode engine to prepare KV entries for receiving from the remote engine. Passing end=decode_start, where decode_start = len(request.prompt) - 1, means that we let the prefill engine prefill everything except the last token, which ensures the prefill engine does not need any sampling logic. prep_recv returns the address for receiving KV from the remote engine and the matched prefix length. For simplicity, we do not enable the prefix cache in this tutorial, so we only need the KV address here.

async with aiohttp.ClientSession(
    timeout=aiohttp.ClientTimeout(total=3 * 3600), trust_env=True
) as session:
    decode_start = len(request.prompt) - 1
    # 1. Ask decode engine to prepare KV entries to receive from prefill engine
    prep_recv_request = microserving_entrypoints.PrepRecvRequest(
        **request.model_dump(), end=decode_start
    )
    (
        kv_addr_info,
        _,
    ) = await self.send_prepare_receive(
        session=session,
        request=prep_recv_request,
        server_url=self.server_urls[1], # engine 0 is the prefill engine, engine 1 is the decode engine; call the decode engine here
    )

Then, we call remote_send on the prefill engine to compute and send the KV to the decode engine. recv_rank=self.device_id_starts[1] means that we are sending the KV to engine 1 (the decode engine).

# 2. Ask prefill engine to send KV to decode engine
remote_send_request = microserving_entrypoints.RemoteSendRequest(
    **request.model_dump(),
    begin=0,
    end=decode_start,
    kv_addr_info=kv_addr_info,
    recv_rank=self.device_id_starts[1], # the rank of decode engine
)
await self.send_remote_send(
    session=session,
    request=remote_send_request,
    server_url=self.server_urls[0], # prefill engine
)

Finally, we call start_generate on the decode engine to start generating tokens. begin=decode_start means that the decode engine prefills the last token of the prompt and then starts decoding. Notably, the decode process of the request may be preempted. In that case, we yield None so that the router reruns the translate_request function.

# 3. Start decoding
start_generate_request = microserving_entrypoints.StartGenerateRequest(
    **request.model_dump(),
    begin=decode_start,
)
async for response in self.send_start_generate(
    session=session,
    request=start_generate_request,
    server_url=self.server_urls[1],
):
    if len(response.choices) > 0:
        finish_reason = response.choices[0].finish_reason
        if finish_reason == "preempt":
            yield None
    yield response
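
To make this preemption contract concrete, below is a minimal, hypothetical sketch of how a caller could drive translate_request; it is not the actual Router implementation, which handles this logic internally. Whenever the generator yields None, the whole chain of microserving calls is rerun from the beginning.

# Hypothetical driver illustrating the retry contract (the real Router
# implements this internally); drive_request is not part of the MLC-LLM API.
async def drive_request(router, request, request_id):
    while True:
        preempted = False
        async for response in router.translate_request(request, request_id):
            if response is None:
                # The decode engine preempted the request: rerun translate_request.
                preempted = True
                break
            ...  # forward the response chunk to the client
        if not preempted:
            return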

Bringing everything together, the complete code is shown below:

from mlc_llm.router import Router
from mlc_llm.protocol import openai_api_protocol
from typing import Any, AsyncGenerator
from mlc_llm.serve.entrypoints import microserving_entrypoints
from mlc_llm.interface.router import serve

import aiohttp

class CustomRouter(Router):
    async def translate_request(
        self, request: openai_api_protocol.CompletionRequest, request_id: str
    ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]:
        # we will pass request_id as an argument in microserving API calls
        request.user = request_id

        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=3 * 3600), trust_env=True
        ) as session:
            decode_start = len(request.prompt) - 1
            # 1. Ask decode engine to prepare KV entries to receive from prefill engine
            prep_recv_request = microserving_entrypoints.PrepRecvRequest(
                **request.model_dump(), end=decode_start
            )
            (
                kv_addr_info,
                _,
            ) = await self.send_prepare_receive(
                session=session,
                request=prep_recv_request,
                server_url=self.server_urls[1], # engine 0 is the prefill engine, engine 1 is the decode engine; call the decode engine here
            )
            # 2. Ask prefill engine to send KV to decode engine
            remote_send_request = microserving_entrypoints.RemoteSendRequest(
                **request.model_dump(),
                begin=0,
                end=decode_start,
                kv_addr_info=kv_addr_info,
                recv_rank=self.device_id_starts[1], # the rank of decode engine
            )
            await self.send_remote_send(
                session=session,
                request=remote_send_request,
                server_url=self.server_urls[0], # prefill engine
            )
            # 3. Start decoding
            start_generate_request = microserving_entrypoints.StartGenerateRequest(
                **request.model_dump(),
                begin=decode_start,
            )
            async for response in self.send_start_generate(
                session=session,
                request=start_generate_request,
                server_url=self.server_urls[1],
            ):
                if len(response.choices) > 0:
                    finish_reason = response.choices[0].finish_reason
                    if finish_reason == "preempt":
                        yield None
                yield response


serve(
    model="/path/to/model", # replace this with actual path
    model_lib="/path/to/model_lib", # replace this with actual path
    router_host="127.0.0.1",
    router_port=9123,
    endpoint_hosts=["127.0.0.1", "127.0.0.1"],
    endpoint_ports=[9124, 9125],
    endpoint_num_gpus=[2, 2],
    enable_prefix_cache=False,
    router_type=CustomRouter,
)
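
After serve() launches the router and both engines, the router can be queried like a regular OpenAI-style completion endpoint. Below is a minimal client sketch, assuming the router exposes the /v1/completions route on the router_host and router_port configured above; adjust the model path and prompt to your setup.

import requests

# Send a completion request to the router; CustomRouter.translate_request
# dispatches it to the prefill engine and then the decode engine.
payload = {
    "model": "/path/to/model",  # same model path as passed to serve()
    "prompt": "What is the capital of France?",
    "max_tokens": 32,
}
response = requests.post("http://127.0.0.1:9123/v1/completions", json=payload)
print(response.json())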