Python API

Note

This page introduces the Python API with MLCEngine in MLC LLM. If you want to check out the old Python API which uses mlc_llm.ChatModule, please go to Python API (Chat Module)

MLC LLM provides a Python API through the classes mlc_llm.MLCEngine and mlc_llm.AsyncMLCEngine, which support the full OpenAI API for easy integration into other Python projects.

This page introduces how to use the engines in MLC LLM. The Python API is a part of the MLC-LLM package, for which we have prepared pre-built pip wheels; see the installation page.

Verify Installation

python -c "from mlc_llm import MLCEngine; print(MLCEngine)"

You are expected to see the output <class 'mlc_llm.serve.engine.MLCEngine'>.

If the command above results in an error, follow Install MLC LLM Python Package to install prebuilt pip packages or build MLC LLM from source.

Run MLCEngine

mlc_llm.MLCEngine provides the synchronous OpenAI chat completion interface. Due to its synchronous design, mlc_llm.MLCEngine does not batch concurrent requests; please use mlc_llm.AsyncMLCEngine if you need request batching.

Stream Response. In Quick Start and Introduction to MLC LLM, we introduced the basic use of mlc_llm.MLCEngine.

from mlc_llm import MLCEngine

# Create engine
model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC"
engine = MLCEngine(model)

# Run chat completion in OpenAI API.
for response in engine.chat.completions.create(
    messages=[{"role": "user", "content": "What is the meaning of life?"}],
    model=model,
    stream=True,
):
    for choice in response.choices:
        print(choice.delta.content, end="", flush=True)
print("\n")

engine.terminate()

This code example first creates an mlc_llm.MLCEngine instance with the 8B Llama-3 model. The Python API of mlc_llm.MLCEngine is designed to align with the OpenAI API, which means you can use mlc_llm.MLCEngine in the same way you use OpenAI’s Python package for both synchronous and asynchronous generation.
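
For illustration, here is the equivalent streaming call made with OpenAI’s official Python client (a minimal sketch, assuming the openai package is installed and an API key is configured; the model name is an arbitrary placeholder):

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
for chunk in client.chat.completions.create(
    messages=[{"role": "user", "content": "What is the meaning of life?"}],
    model="gpt-4o-mini",  # placeholder model name
    stream=True,
):
    for choice in chunk.choices:
        print(choice.delta.content or "", end="", flush=True)
print("\n")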

Non-stream Response. The code example above uses the synchronous chat completion interface and iterates over the streamed responses. If you want to run without streaming, you can run

response = engine.chat.completions.create(
    messages=[{"role": "user", "content": "What is the meaning of life?"}],
    model=model,
    stream=False,
)
print(response)
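
The returned object follows the OpenAI chat completion response schema, so the generated text can be read from the first choice (a minimal sketch):

# Print only the generated text of the first choice.
print(response.choices[0].message.content)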

Please refer to OpenAI’s Python package and OpenAI chat completion API for the complete chat completion interface.

Run AsyncMLCEngine

mlc_llm.AsyncMLCEngine provides the OpenAI chat completion interface with asynchronous features. We recommend using mlc_llm.AsyncMLCEngine to batch concurrent requests for better throughput.

Stream Response. The core use of mlc_llm.AsyncMLCEngine for stream responses is as follows.

async for response in await engine.chat.completions.create(
  messages=[{"role": "user", "content": "What is the meaning of life?"}],
  model=model,
  stream=True,
):
  for choice in response.choices:
      print(choice.delta.content, end="", flush=True)

Below is a complete runnable example of AsyncMLCEngine in Python.

import asyncio
from typing import Dict

from mlc_llm.serve import AsyncMLCEngine

model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC"
prompts = [
    "Write a three-day travel plan to Pittsburgh.",
    "What is the meaning of life?",
]


async def test_completion():
    # Create engine
    async_engine = AsyncMLCEngine(model=model)

    num_requests = len(prompts)
    output_texts: Dict[str, str] = {}

    async def generate_task(prompt: str):
        async for response in await async_engine.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=model,
            stream=True,
        ):
            if response.id not in output_texts:
                output_texts[response.id] = ""
            output_texts[response.id] += response.choices[0].delta.content

    tasks = [asyncio.create_task(generate_task(prompts[i])) for i in range(num_requests)]
    await asyncio.gather(*tasks)

    # Print output.
    for request_id, output in output_texts.items():
        print(f"Output of request {request_id}:\n{output}\n")

    async_engine.terminate()


asyncio.run(test_completion())

Non-stream Response. Similarly, mlc_llm.AsyncMLCEngine provides the non-stream response interface.

response = await engine.chat.completions.create(
  messages=[{"role": "user", "content": "What is the meaning of life?"}],
  model=model,
  stream=False,
)
print(response)

Please refer to OpenAI’s Python package and OpenAI chat completion API for the complete chat completion interface.

Engine Mode

To ease the engine configuration, the constructors of mlc_llm.MLCEngine and mlc_llm.AsyncMLCEngine have an optional argument mode, which takes one of three options: "local", "interactive" or "server". The default mode is "local".

Each mode denotes a pre-defined configuration of the engine to satisfy different use cases. The choice of mode controls the request concurrency of the engine as well as the engine's KV cache token capacity (in other words, the maximum number of tokens that the engine's KV cache can hold), and thereby affects the GPU memory usage of the engine.

In short,

  • mode "local" uses low request concurrency and low KV cache capacity. It is suitable for cases where there are not many concurrent requests and the user wants to save GPU memory.

  • mode "interactive" uses a request concurrency of 1 and low KV cache capacity. It is designed for interactive use cases such as chats and conversations.

  • mode "server" uses as much request concurrency and KV cache capacity as possible. This mode aims to fully utilize the GPU memory for large server scenarios that serve many concurrent requests.

For system benchmarking, please select mode "server". Please refer to API Reference for detailed documentation of the engine mode.
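
As an illustration, the mode is passed directly to the engine constructor. A minimal sketch, reusing the Llama-3 model path from the earlier examples:

from mlc_llm import MLCEngine

model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC"
# Select the "server" preset to maximize request concurrency and KV cache capacity.
engine = MLCEngine(model, mode="server")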

Deploy Your Own Model with Python API

The introduction page describes how to deploy your own models with MLC LLM. This section introduces how to use the model weights you converted and the model library you built in mlc_llm.MLCEngine and mlc_llm.AsyncMLCEngine.

We use Phi-2 as the example model.

Specify Model Weight Path. Assuming you have converted the model weights for your own model, you can construct an mlc_llm.MLCEngine as follows:

from mlc_llm import MLCEngine

model = "models/phi-2"  # Assuming the converted phi-2 model weights are under "models/phi-2"
engine = MLCEngine(model)

Specify Model Library Path. Further, if you have built the model library on your own, you can use it in mlc_llm.MLCEngine by passing the library path through the argument model_lib.

from mlc_llm import MLCEngine

model = "models/phi-2"
model_lib = "models/phi-2/lib.so"  # Assuming the phi-2 model library is built at "models/phi-2/lib.so"
engine = MLCEngine(model, model_lib=model_lib)

The same applies to mlc_llm.AsyncMLCEngine.
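
For example, a minimal sketch with the same hypothetical Phi-2 paths as above:

from mlc_llm import AsyncMLCEngine

model = "models/phi-2"
model_lib = "models/phi-2/lib.so"
# The asynchronous engine accepts the same model and model_lib arguments.
async_engine = AsyncMLCEngine(model, model_lib=model_lib)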

API Reference

The mlc_llm.MLCEngine and mlc_llm.AsyncMLCEngine classes provide the following constructors.

The MLCEngine and AsyncMLCEngine provide the full OpenAI API. Please refer to OpenAI’s Python package and OpenAI chat completion API for the complete chat completion interface.

class mlc_llm.MLCEngine(model: str, device: Union[str, Device] = 'auto', *, model_lib: Optional[str] = None, mode: Literal['local', 'interactive', 'server'] = 'local', additional_models: Optional[List[str]] = None, max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, speculative_mode: Literal['disable', 'small_draft', 'eagle'] = 'disable', spec_draft_length: int = 4, enable_tracing: bool = False, verbose: bool = True)

Bases: MLCEngineBase

The MLCEngine in MLC LLM that provides the synchronous interfaces with regard to OpenAI API.

Parameters:
  • model (str) – A path to mlc-chat-config.json, or an MLC model directory that contains mlc-chat-config.json. It can also be a link to an HF repository pointing to an MLC compiled model.

  • device (Union[str, Device]) – The device used to deploy the model, such as “cuda” or “cuda:0”. Defaults to “auto”, which detects the device from locally available GPUs.

  • model_lib (Optional[str]) – The full path to the model library file to use (e.g. a .so file). If unspecified, we will use the provided model to search over possible paths. If the model lib is not found, it will be compiled in a JIT manner.

  • mode (Literal["local", "interactive", "server"]) –

    The engine mode in MLC LLM. We provide three preset modes: “local”, “interactive” and “server”. The default mode is “local”. The choice of mode decides the values of “max_batch_size”, “max_total_sequence_length” and “prefill_chunk_size” when they are not explicitly specified.

    1. Mode “local” refers to local server deployment, which has low request concurrency. The max batch size is set to 4, and the max total sequence length and prefill chunk size are set to the context window size (or sliding window size) of the model.

    2. Mode “interactive” refers to interactive use of the server, which has at most 1 concurrent request. The max batch size is set to 1, and the max total sequence length and prefill chunk size are set to the context window size (or sliding window size) of the model.

    3. Mode “server” refers to the large server use case, which may handle many concurrent requests and wants to use as much GPU memory as possible. In this mode, we automatically infer the largest possible max batch size and max total sequence length.

    You can manually specify the arguments “max_batch_size”, “max_total_sequence_length” and “prefill_chunk_size” to override the automatically inferred values; see the sketch after this parameter list.

  • additional_models (Optional[List[str]]) – The model paths and (optional) model library paths of additional models (other than the main model). When the engine enables speculative decoding, additional models are needed. Each string in the list is either in the form “model_path” or “model_path:model_lib”. When the model lib of a model is not given, JIT model compilation will be activated to compile the model automatically.

  • max_batch_size (Optional[int]) – The maximum allowed batch size, i.e., the maximum number of concurrent requests that the KV cache is set up to support.

  • max_total_sequence_length (Optional[int]) – The KV cache total token capacity, i.e., the maximum total number of tokens that the KV cache supports. This decides the GPU memory size that the KV cache consumes. If not specified, the system will automatically estimate the maximum capacity based on the VRAM size of the GPU.

  • prefill_chunk_size (Optional[int]) – The maximum number of tokens the model passes for prefill each time. It should not exceed the prefill chunk size in the model config. If not specified, this defaults to the prefill chunk size in the model config.

  • max_history_size (Optional[int]) – The maximum history for RNN state.

  • gpu_memory_utilization (Optional[float]) – A number in (0, 1) denoting the fraction of GPU memory used by the server in total. It is used to infer the maximum possible KV cache capacity. When it is unspecified, it defaults to 0.85. Under mode “local” or “interactive”, the actual memory usage may be significantly smaller than this number. Under mode “server”, the actual memory usage may be slightly larger than this number.

  • speculative_mode (Literal["disable", "small_draft", "eagle"]) – The speculative decoding mode. “disable” means speculative decoding is disabled. “small_draft” means the normal speculative decoding (small draft) mode. “eagle” means EAGLE-style speculative decoding.

  • spec_draft_length (int) – The number of tokens to generate in speculative proposal (draft).

  • enable_tracing (bool) – A boolean indicating whether to enable event logging for requests.

  • verbose (bool) – A boolean indicating whether to print logging info in the engine.
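
The sketch below shows how several of these constructor arguments can be combined to override the automatically inferred values; the specific numbers are arbitrary illustrations, not recommendations.

from mlc_llm import MLCEngine

model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC"
engine = MLCEngine(
    model,
    device="cuda:0",                  # explicit device instead of "auto"
    mode="server",
    max_batch_size=32,                # override the preset request concurrency
    max_total_sequence_length=16384,  # cap the KV cache token capacity
    gpu_memory_utilization=0.9,       # fraction of GPU memory the engine may use
)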

__init__(model: str, device: Union[str, Device] = 'auto', *, model_lib: Optional[str] = None, mode: Literal['local', 'interactive', 'server'] = 'local', additional_models: Optional[List[str]] = None, max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, speculative_mode: Literal['disable', 'small_draft', 'eagle'] = 'disable', spec_draft_length: int = 4, enable_tracing: bool = False, verbose: bool = True) None
abort(request_id: str) None

Generation abortion interface.

request_id : str

The id of the request to abort.
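
As a hedged sketch of how abort can be used: the request id surfaced on streaming chunks (response.id, as in the streaming examples above) can be passed back to stop generation early.

# Abort a streaming request after consuming the first 10 chunks.
num_chunks = 0
for response in engine.chat.completions.create(
    messages=[{"role": "user", "content": "Write a very long story."}],
    model=model,
    stream=True,
):
    num_chunks += 1
    if num_chunks >= 10:
        engine.abort(response.id)  # abort the in-flight request by its id
        break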

class mlc_llm.AsyncMLCEngine(model: str, device: Union[str, Device] = 'auto', *, model_lib: Optional[str] = None, mode: Literal['local', 'interactive', 'server'] = 'local', additional_models: Optional[List[str]] = None, max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, speculative_mode: Literal['disable', 'small_draft', 'eagle'] = 'disable', spec_draft_length: int = 4, enable_tracing: bool = False, verbose: bool = True)

Bases: MLCEngineBase

The AsyncMLCEngine in MLC LLM that provides the asynchronous interfaces with regard to OpenAI API.

Parameters:
  • model (str) – A path to mlc-chat-config.json, or an MLC model directory that contains mlc-chat-config.json. It can also be a link to an HF repository pointing to an MLC compiled model.

  • device (Union[str, Device]) – The device used to deploy the model, such as “cuda” or “cuda:0”. Defaults to “auto”, which detects the device from locally available GPUs.

  • model_lib (Optional[str]) – The full path to the model library file to use (e.g. a .so file). If unspecified, we will use the provided model to search over possible paths. If the model lib is not found, it will be compiled in a JIT manner.

  • mode (Literal["local", "interactive", "server"]) –

    The engine mode in MLC LLM. We provide three preset modes: “local”, “interactive” and “server”. The default mode is “local”. The choice of mode decides the values of “max_batch_size”, “max_total_sequence_length” and “prefill_chunk_size” when they are not explicitly specified.

    1. Mode “local” refers to local server deployment, which has low request concurrency. The max batch size is set to 4, and the max total sequence length and prefill chunk size are set to the context window size (or sliding window size) of the model.

    2. Mode “interactive” refers to interactive use of the server, which has at most 1 concurrent request. The max batch size is set to 1, and the max total sequence length and prefill chunk size are set to the context window size (or sliding window size) of the model.

    3. Mode “server” refers to the large server use case, which may handle many concurrent requests and wants to use as much GPU memory as possible. In this mode, we automatically infer the largest possible max batch size and max total sequence length.

    You can manually specify the arguments “max_batch_size”, “max_total_sequence_length” and “prefill_chunk_size” to override the automatically inferred values.

  • additional_models (Optional[List[str]]) – The model paths and (optional) model library paths of additional models (other than the main model). When the engine enables speculative decoding, additional models are needed. Each string in the list is either in the form “model_path” or “model_path:model_lib”. When the model lib of a model is not given, JIT model compilation will be activated to compile the model automatically.

  • max_batch_size (Optional[int]) – The maximum allowed batch size, i.e., the maximum number of concurrent requests that the KV cache is set up to support.

  • max_total_sequence_length (Optional[int]) – The KV cache total token capacity, i.e., the maximum total number of tokens that the KV cache supports. This decides the GPU memory size that the KV cache consumes. If not specified, the system will automatically estimate the maximum capacity based on the VRAM size of the GPU.

  • prefill_chunk_size (Optional[int]) – The maximum number of tokens the model passes for prefill each time. It should not exceed the prefill chunk size in the model config. If not specified, this defaults to the prefill chunk size in the model config.

  • max_history_size (Optional[int]) – The maximum history for RNN state.

  • gpu_memory_utilization (Optional[float]) – A number in (0, 1) denoting the fraction of GPU memory used by the server in total. It is used to infer the maximum possible KV cache capacity. When it is unspecified, it defaults to 0.85. Under mode “local” or “interactive”, the actual memory usage may be significantly smaller than this number. Under mode “server”, the actual memory usage may be slightly larger than this number.

  • speculative_mode (Literal["disable", "small_draft", "eagle"]) – The speculative decoding mode. “disable” means speculative decoding is disabled. “small_draft” means the normal speculative decoding (small draft) mode. “eagle” means EAGLE-style speculative decoding.

  • spec_draft_length (int) – The number of tokens to generate in speculative proposal (draft).

  • enable_tracing (bool) – A boolean indicating whether to enable event logging for requests.

  • verbose (bool) – A boolean indicating whether to print logging info in the engine.

__init__(model: str, device: Union[str, Device] = 'auto', *, model_lib: Optional[str] = None, mode: Literal['local', 'interactive', 'server'] = 'local', additional_models: Optional[List[str]] = None, max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, speculative_mode: Literal['disable', 'small_draft', 'eagle'] = 'disable', spec_draft_length: int = 4, enable_tracing: bool = False, verbose: bool = True) None
async abort(request_id: str) None

Generation abortion interface.

request_id : str

The id of the request to abort.