From 72efdea6846779b943583fc389c5f3aa8776df62 Mon Sep 17 00:00:00 2001 From: activezhao Date: Wed, 22 Nov 2023 17:36:39 +0800 Subject: [PATCH 1/7] change the configuration of engineArgs in config.pbtxt --- Quick_Deploy/vLLM/README.md | 56 ++-- .../model_repository/vllm_model/1/model.py | 283 ++++++++++++++++++ .../model_repository/vllm_model/config.pbtxt | 98 ++++++ 3 files changed, 409 insertions(+), 28 deletions(-) create mode 100644 Quick_Deploy/vLLM/model_repository/vllm_model/1/model.py create mode 100644 Quick_Deploy/vLLM/model_repository/vllm_model/config.pbtxt diff --git a/Quick_Deploy/vLLM/README.md b/Quick_Deploy/vLLM/README.md index ee48f2af..24e59eab 100644 --- a/Quick_Deploy/vLLM/README.md +++ b/Quick_Deploy/vLLM/README.md @@ -41,39 +41,40 @@ backend. ## Step 1: Prepare your model repository -To use Triton, we need to build a model repository. For this tutorial we will -use the model repository, provided in the [samples](https://github.com/triton-inference-server/vllm_backend/tree/main/samples) -folder of the [vllm_backend](https://github.com/triton-inference-server/vllm_backend/tree/main) -repository. - -The following set of commands will create a `model_repository/vllm_model/1` -directory and copy 2 files: -[`model.json`](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/model_repository/vllm_model/1/model.json) -and -[`config.pbtxt`](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/model_repository/vllm_model/config.pbtxt), -required to serve the [facebook/opt-125m](https://huggingface.co/facebook/opt-125m) model. -``` -mkdir -p model_repository/vllm_model/1 -wget -P model_repository/vllm_model/1 https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/model_repository/vllm_model/1/model.json -wget -P model_repository/vllm_model/ https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/model_repository/vllm_model/config.pbtxt -``` +To use Triton, we need to build a model repository. A sample model repository for deploying `facebook/opt-125m` using vLLM in Triton is +included with this demo as `model_repository` directory. The model repository should look like this: ``` model_repository/ └── vllm_model ├── 1 - │   └── model.json + │ └── model.json └── config.pbtxt ``` -The content of `model.json` is: +The configuration of engineArgs is in config.pbtxt: + +``` +parameters { + key: "model" + value: { + string_value: "facebook/opt-125m", + } +} + +parameters { + key: "disable_log_requests" + value: { + string_value: "true" + } +} -```json -{ - "model": "facebook/opt-125m", - "disable_log_requests": "true", - "gpu_memory_utilization": 0.5 +parameters { + key: "gpu_memory_utilization" + value: { + string_value: "0.8" + } } ``` @@ -84,16 +85,15 @@ and for supported key-value pairs. Inflight batching and paged attention is handled by the vLLM engine. -For multi-GPU support, EngineArgs like `tensor_parallel_size` can be specified -in [`model.json`](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/model_repository/vllm_model/1/model.json). +For multi-GPU support, EngineArgs like `tensor_parallel_size` can be specified in [`config.pbtxt`](model_repository/vllm/config.pbtxt). *Note*: vLLM greedily consume up to 90% of the GPU's memory under default settings. This tutorial updates this behavior by setting `gpu_memory_utilization` to 50%. You can tweak this behavior using fields like `gpu_memory_utilization` and other settings -in [`model.json`](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/model_repository/vllm_model/1/model.json). +in [`config.pbtxt`](model_repository/vllm/config.pbtxt). -Read through the documentation in [`model.py`](https://github.com/triton-inference-server/vllm_backend/blob/main/src/model.py) -to understand how to configure this sample for your use-case. +Read through the documentation in [`model.py`](model_repository/vllm/1/model.py) to understand how +to configure this sample for your use-case. ## Step 2: Launch Triton Inference Server diff --git a/Quick_Deploy/vLLM/model_repository/vllm_model/1/model.py b/Quick_Deploy/vLLM/model_repository/vllm_model/1/model.py new file mode 100644 index 00000000..cd77a0b9 --- /dev/null +++ b/Quick_Deploy/vLLM/model_repository/vllm_model/1/model.py @@ -0,0 +1,283 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import asyncio +import json +import threading +from typing import AsyncGenerator + +import numpy as np +import triton_python_backend_utils as pb_utils +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.utils import random_uuid + + +class TritonPythonModel: + def initialize(self, args): + self.logger = pb_utils.Logger + self.model_config = json.loads(args["model_config"]) + + # assert are in decoupled mode. Currently, Triton needs to use + # decoupled policy for asynchronously forwarding requests to + # vLLM engine. + self.using_decoupled = pb_utils.using_decoupled_model_transaction_policy( + self.model_config + ) + assert ( + self.using_decoupled + ), "vLLM Triton backend must be configured to use decoupled model transaction policy" + + self.model_name = args["model_name"] + assert ( + self.model_name + ), "Parameter of [name] must be configured, and can not be empty in config.pbtxt" + + # Create an AsyncLLMEngine from the config from JSON + self.llm_engine = AsyncLLMEngine.from_engine_args( + AsyncEngineArgs(**self.handle_initializing_config()) + ) + + output_config = pb_utils.get_output_config_by_name( + self.model_config, "text_output" + ) + self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + + # Counter to keep track of ongoing request counts + self.ongoing_request_count = 0 + + # Starting asyncio event loop to process the received requests asynchronously. + self._loop = asyncio.get_event_loop() + self._loop_thread = threading.Thread( + target=self.engine_loop, args=(self._loop,) + ) + self._shutdown_event = asyncio.Event() + self._loop_thread.start() + + def handle_initializing_config(self): + model_params = self.model_config.get("parameters", {}) + model_engine_args = {} + for key, value in model_params.items(): + model_engine_args[key] = value['string_value'] + + bool_keys = ["trust_remote_code", "use_np_weights", "use_dummy_weights", + "worker_use_ray", "disable_log_stats"] + for k in bool_keys: + if k in model_engine_args: + model_engine_args[k] = bool(model_engine_args[k]) + + float_keys = ["gpu_memory_utilization"] + for k in float_keys: + if k in model_engine_args: + model_engine_args[k] = float(model_engine_args[k]) + + int_keys = ["seed", "pipeline_parallel_size", "tensor_parallel_size", "block_size", + "swap_space", "max_num_batched_tokens", "max_num_seqs"] + for k in int_keys: + if k in model_engine_args: + model_engine_args[k] = int(model_engine_args[k]) + + # Check necessary parameter configuration in model config + model_param = model_engine_args["model"] + assert ( + model_param + ), "Parameter of [model] must be configured, and can not be empty in config.pbtxt" + + self.logger.log_info(f"Initialize engineArgs: {model_engine_args}") + return model_engine_args + + def create_task(self, coro): + """ + Creates a task on the engine's event loop which is running on a separate thread. + """ + assert ( + self._shutdown_event.is_set() is False + ), "Cannot create tasks after shutdown has been requested" + + return asyncio.run_coroutine_threadsafe(coro, self._loop) + + def engine_loop(self, loop): + """ + Runs the engine's event loop on a separate thread. + """ + asyncio.set_event_loop(loop) + self._loop.run_until_complete(self.await_shutdown()) + + async def await_shutdown(self): + """ + Primary coroutine running on the engine event loop. This coroutine is responsible for + keeping the engine alive until a shutdown is requested. + """ + # first await the shutdown signal + while self._shutdown_event.is_set() is False: + await asyncio.sleep(5) + + # Wait for the ongoing_requests + while self.ongoing_request_count > 0: + self.logger.log_info( + "[vllm] Awaiting remaining {} requests".format( + self.ongoing_request_count + ) + ) + await asyncio.sleep(5) + + for task in asyncio.all_tasks(loop=self._loop): + if task is not asyncio.current_task(): + task.cancel() + + self.logger.log_info("[vllm] Shutdown complete") + + def get_sampling_params_dict(self, params_json): + """ + This functions parses the dictionary values into their + expected format. + """ + + params_dict = json.loads(params_json) + + # Special parsing for the supported sampling parameters + bool_keys = ["ignore_eos", "skip_special_tokens", "use_beam_search"] + for k in bool_keys: + if k in params_dict: + params_dict[k] = bool(params_dict[k]) + + float_keys = [ + "frequency_penalty", + "length_penalty", + "presence_penalty", + "temperature", + "top_p", + ] + for k in float_keys: + if k in params_dict: + params_dict[k] = float(params_dict[k]) + + int_keys = ["best_of", "max_tokens", "n", "top_k"] + for k in int_keys: + if k in params_dict: + params_dict[k] = int(params_dict[k]) + + return params_dict + + def create_response(self, vllm_output): + """ + Parses the output from the vLLM engine into Triton + response. + """ + prompt = vllm_output.prompt + text_outputs = [ + (prompt + output.text).encode("utf-8") for output in vllm_output.outputs + ] + triton_output_tensor = pb_utils.Tensor( + "text_output", np.asarray(text_outputs, dtype=self.output_dtype) + ) + return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor]) + + async def generate(self, request): + """ + Forwards single request to LLM engine and returns responses. + """ + response_sender = request.get_response_sender() + self.ongoing_request_count += 1 + try: + request_id = random_uuid() + prompt = pb_utils.get_input_tensor_by_name( + request, "text_input" + ).as_numpy()[0] + if isinstance(prompt, bytes): + prompt = prompt.decode("utf-8") + stream = pb_utils.get_input_tensor_by_name(request, "stream") + if stream: + stream = stream.as_numpy()[0] + else: + stream = False + + # Request parameters are not yet supported via + # BLS. Provide an optional mechanism to receive serialized + # parameters as an input tensor until support is added + + parameters_input_tensor = pb_utils.get_input_tensor_by_name( + request, "sampling_parameters" + ) + if parameters_input_tensor: + parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8") + else: + parameters = request.parameters() + + sampling_params_dict = self.get_sampling_params_dict(parameters) + sampling_params = SamplingParams(**sampling_params_dict) + + last_output = None + async for output in self.llm_engine.generate( + prompt, sampling_params, request_id + ): + if stream: + response_sender.send(self.create_response(output)) + else: + last_output = output + + if not stream: + response_sender.send(self.create_response(last_output)) + + except Exception as e: + self.logger.log_info(f"[vllm] Error generating stream: {e}") + error = pb_utils.TritonError(f"Error generating stream: {e}") + triton_output_tensor = pb_utils.Tensor( + "text_output", np.asarray(["N/A"], dtype=self.output_dtype) + ) + response = pb_utils.InferenceResponse( + output_tensors=[triton_output_tensor], error=error + ) + response_sender.send(response) + raise e + finally: + response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + self.ongoing_request_count -= 1 + + def execute(self, requests): + """ + Triton core issues requests to the backend via this method. + + When this method returns, new requests can be issued to the backend. Blocking + this function would prevent the backend from pulling additional requests from + Triton into the vLLM engine. This can be done if the kv cache within vLLM engine + is too loaded. + We are pushing all the requests on vllm and let it handle the full traffic. + """ + for request in requests: + self.create_task(self.generate(request)) + return None + + def finalize(self): + """ + Triton virtual method; called when the model is unloaded. + """ + self.logger.log_info("[vllm] Issuing finalize to vllm backend") + self._shutdown_event.set() + if self._loop_thread is not None: + self._loop_thread.join() + self._loop_thread = None diff --git a/Quick_Deploy/vLLM/model_repository/vllm_model/config.pbtxt b/Quick_Deploy/vLLM/model_repository/vllm_model/config.pbtxt new file mode 100644 index 00000000..764df417 --- /dev/null +++ b/Quick_Deploy/vLLM/model_repository/vllm_model/config.pbtxt @@ -0,0 +1,98 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "vllm" +backend: "python" + +# Disabling batching in Triton, let vLLM handle the batching on its own. +max_batch_size: 0 + +# We need to use decoupled transaction policy for saturating +# vLLM engine for max throughtput. +# TODO [DLIS:5233]: Allow asynchronous execution to lift this +# restriction for cases there is exactly a single response to +# a single request. +model_transaction_policy { + decoupled: True +} + +input [ + { + name: "text_input" + data_type: TYPE_STRING + dims: [ 1 ] + }, + { + name: "stream" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "sampling_parameters" + data_type: TYPE_STRING + dims: [ 1 ] + optional: true + } +] + +output [ + { + name: "text_output" + data_type: TYPE_STRING + dims: [ -1 ] + } +] + +# The usage of device is deferred to the vLLM engine +instance_group [ + { + count: 1 + kind: KIND_MODEL + } +] + +# The configuration of engineArgs +parameters { + key: "model" + value: { + string_value: "facebook/opt-125m", + } +} + +parameters { + key: "disable_log_requests" + value: { + string_value: "true" + } +} + +parameters { + key: "gpu_memory_utilization" + value: { + string_value: "0.8" + } +} \ No newline at end of file From 402720ba8e9bbfb63bc3352dfe438e76a58e56e4 Mon Sep 17 00:00:00 2001 From: activezhao Date: Wed, 22 Nov 2023 17:38:10 +0800 Subject: [PATCH 2/7] change the configuration of engineArgs in config.pbtxt --- Quick_Deploy/vLLM/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Quick_Deploy/vLLM/README.md b/Quick_Deploy/vLLM/README.md index 24e59eab..5e6669cc 100644 --- a/Quick_Deploy/vLLM/README.md +++ b/Quick_Deploy/vLLM/README.md @@ -49,7 +49,7 @@ The model repository should look like this: model_repository/ └── vllm_model ├── 1 - │ └── model.json + │ └── model.py └── config.pbtxt ``` From ad533cda912d4b65517e3d4213f566adecb3179b Mon Sep 17 00:00:00 2001 From: activezhao Date: Wed, 22 Nov 2023 17:43:10 +0800 Subject: [PATCH 3/7] update the model dir name --- Quick_Deploy/vLLM/README.md | 2 +- .../vLLM/model_repository/{vllm_model => vllm}/1/model.py | 0 .../vLLM/model_repository/{vllm_model => vllm}/config.pbtxt | 0 3 files changed, 1 insertion(+), 1 deletion(-) rename Quick_Deploy/vLLM/model_repository/{vllm_model => vllm}/1/model.py (100%) rename Quick_Deploy/vLLM/model_repository/{vllm_model => vllm}/config.pbtxt (100%) diff --git a/Quick_Deploy/vLLM/README.md b/Quick_Deploy/vLLM/README.md index 5e6669cc..2affae0a 100644 --- a/Quick_Deploy/vLLM/README.md +++ b/Quick_Deploy/vLLM/README.md @@ -47,7 +47,7 @@ included with this demo as `model_repository` directory. The model repository should look like this: ``` model_repository/ -└── vllm_model +└── vllm ├── 1 │ └── model.py └── config.pbtxt diff --git a/Quick_Deploy/vLLM/model_repository/vllm_model/1/model.py b/Quick_Deploy/vLLM/model_repository/vllm/1/model.py similarity index 100% rename from Quick_Deploy/vLLM/model_repository/vllm_model/1/model.py rename to Quick_Deploy/vLLM/model_repository/vllm/1/model.py diff --git a/Quick_Deploy/vLLM/model_repository/vllm_model/config.pbtxt b/Quick_Deploy/vLLM/model_repository/vllm/config.pbtxt similarity index 100% rename from Quick_Deploy/vLLM/model_repository/vllm_model/config.pbtxt rename to Quick_Deploy/vLLM/model_repository/vllm/config.pbtxt From a583a65db9ab01fb883429590e349e44bc1982e9 Mon Sep 17 00:00:00 2001 From: activezhao Date: Wed, 22 Nov 2023 18:59:50 +0800 Subject: [PATCH 4/7] update the model dir name --- Quick_Deploy/vLLM/README.md | 2 +- .../vLLM/model_repository/{vllm => vllm_model}/1/model.py | 0 .../vLLM/model_repository/{vllm => vllm_model}/config.pbtxt | 0 3 files changed, 1 insertion(+), 1 deletion(-) rename Quick_Deploy/vLLM/model_repository/{vllm => vllm_model}/1/model.py (100%) rename Quick_Deploy/vLLM/model_repository/{vllm => vllm_model}/config.pbtxt (100%) diff --git a/Quick_Deploy/vLLM/README.md b/Quick_Deploy/vLLM/README.md index 2affae0a..5e6669cc 100644 --- a/Quick_Deploy/vLLM/README.md +++ b/Quick_Deploy/vLLM/README.md @@ -47,7 +47,7 @@ included with this demo as `model_repository` directory. The model repository should look like this: ``` model_repository/ -└── vllm +└── vllm_model ├── 1 │ └── model.py └── config.pbtxt diff --git a/Quick_Deploy/vLLM/model_repository/vllm/1/model.py b/Quick_Deploy/vLLM/model_repository/vllm_model/1/model.py similarity index 100% rename from Quick_Deploy/vLLM/model_repository/vllm/1/model.py rename to Quick_Deploy/vLLM/model_repository/vllm_model/1/model.py diff --git a/Quick_Deploy/vLLM/model_repository/vllm/config.pbtxt b/Quick_Deploy/vLLM/model_repository/vllm_model/config.pbtxt similarity index 100% rename from Quick_Deploy/vLLM/model_repository/vllm/config.pbtxt rename to Quick_Deploy/vLLM/model_repository/vllm_model/config.pbtxt From c2f1f3f8d891551081d752990efb52157392f91a Mon Sep 17 00:00:00 2001 From: activezhao Date: Mon, 4 Dec 2023 16:41:59 +0800 Subject: [PATCH 5/7] create Customization for customizing vLLM --- Quick_Deploy/Customization/vLLM/.gitignore | 6 + Quick_Deploy/Customization/vLLM/README.md | 217 ++++++++++++++++++ .../model_repository/vllm_model/1/model.py | 0 .../model_repository/vllm_model/config.pbtxt | 2 +- Quick_Deploy/vLLM/README.md | 58 ++--- 5 files changed, 253 insertions(+), 30 deletions(-) create mode 100644 Quick_Deploy/Customization/vLLM/.gitignore create mode 100644 Quick_Deploy/Customization/vLLM/README.md rename Quick_Deploy/{ => Customization}/vLLM/model_repository/vllm_model/1/model.py (100%) rename Quick_Deploy/{ => Customization}/vLLM/model_repository/vllm_model/config.pbtxt (99%) diff --git a/Quick_Deploy/Customization/vLLM/.gitignore b/Quick_Deploy/Customization/vLLM/.gitignore new file mode 100644 index 00000000..82559cc4 --- /dev/null +++ b/Quick_Deploy/Customization/vLLM/.gitignore @@ -0,0 +1,6 @@ +Miniconda* +miniconda +model_repository/vllm/vllm_env.tar.gz +model_repository/vllm/triton_python_backend_stub +python_backend +results.txt diff --git a/Quick_Deploy/Customization/vLLM/README.md b/Quick_Deploy/Customization/vLLM/README.md new file mode 100644 index 00000000..2605d744 --- /dev/null +++ b/Quick_Deploy/Customization/vLLM/README.md @@ -0,0 +1,217 @@ + + + +# Deploying a vLLM model in Triton + +The following tutorial demonstrates how to deploy a simple +[facebook/opt-125m](https://huggingface.co/facebook/opt-125m) model on +Triton Inference Server using the Triton's +[Python-based](https://github.com/triton-inference-server/backend/blob/main/docs/python_based_backends.md#python-based-backends) +[vLLM](https://github.com/triton-inference-server/vllm_backend/tree/main) +backend. + +*NOTE*: The tutorial is intended to be a reference example only and has [known limitations](#limitations). + + +## Step 1: Prepare your model repository + +To use Triton, we need to build a model repository. A sample model repository for deploying `facebook/opt-125m` using vLLM in Triton is +included with this demo as `model_repository` directory. + +The model repository should look like this: +``` +model_repository/ +└── vllm_model + ├── 1 + │ └── model.py + └── config.pbtxt +``` + +The configuration of engineArgs is in config.pbtxt: + +``` +parameters { + key: "model" + value: { + string_value: "facebook/opt-125m", + } +} + +parameters { + key: "disable_log_requests" + value: { + string_value: "true" + } +} + +parameters { + key: "gpu_memory_utilization" + value: { + string_value: "0.5" + } +} +``` + +This file can be modified to provide further settings to the vLLM engine. See vLLM +[AsyncEngineArgs](https://github.com/vllm-project/vllm/blob/32b6816e556f69f1672085a6267e8516bcb8e622/vllm/engine/arg_utils.py#L165) +and +[EngineArgs](https://github.com/vllm-project/vllm/blob/32b6816e556f69f1672085a6267e8516bcb8e622/vllm/engine/arg_utils.py#L11) +for supported key-value pairs. Inflight batching and paged attention is handled +by the vLLM engine. + +For multi-GPU support, EngineArgs like `tensor_parallel_size` can be specified in [`config.pbtxt`](model_repository/vllm_model/config.pbtxt). + +*Note*: vLLM greedily consume up to 90% of the GPU's memory under default settings. +This tutorial updates this behavior by setting `gpu_memory_utilization` to 50%. +You can tweak this behavior using fields like `gpu_memory_utilization` and other settings +in [`config.pbtxt`](model_repository/vllm_model/config.pbtxt). + +Read through the documentation in [`model.py`](model_repository/vllm_model/1/model.py) to understand how +to configure this sample for your use-case. + +## Step 2: Launch Triton Inference Server + +Once you have the model repository setup, it is time to launch the triton server. +Starting with 23.10 release, a dedicated container with vLLM pre-installed +is available on [NGC.](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver) +To use this container to launch Triton, you can use the docker command below. +``` +docker run --gpus all -it --net=host --rm -p 8001:8001 --shm-size=1G --ulimit memlock=-1 --ulimit stack=67108864 -v ${PWD}:/work -w /work nvcr.io/nvidia/tritonserver:-vllm-python-py3 tritonserver --model-store ./model_repository +``` +Throughout the tutorial, \ is the version of Triton +that you want to use. Please note, that Triton's vLLM +container was first published in 23.10 release, so any prior version +will not work. + +After you start Triton you will see output on the console showing +the server starting up and loading the model. When you see output +like the following, Triton is ready to accept inference requests. + +``` +I1030 22:33:28.291908 1 grpc_server.cc:2513] Started GRPCInferenceService at 0.0.0.0:8001 +I1030 22:33:28.292879 1 http_server.cc:4497] Started HTTPService at 0.0.0.0:8000 +I1030 22:33:28.335154 1 http_server.cc:270] Started Metrics Service at 0.0.0.0:8002 +``` + +## Step 3: Use a Triton Client to Send Your First Inference Request + +In this tutorial, we will show how to send an inference request to the +[facebook/opt-125m](https://huggingface.co/facebook/opt-125m) model in 2 ways: + +* [Using the generate endpoint](#using-generate-endpoint) +* [Using the gRPC asyncio client](#using-grpc-asyncio-client) + +### Using the Generate Endpoint +After you start Triton with the sample model_repository, +you can quickly run your first inference request with the +[generate](https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_generate.md) +endpoint. + +Start Triton's SDK container with the following command: +``` +docker run -it --net=host -v ${PWD}:/workspace/ nvcr.io/nvidia/tritonserver:-py3-sdk bash +``` + +Now, let's send an inference request: +``` +curl -X POST localhost:8000/v2/models/vllm_model/generate -d '{"text_input": "What is Triton Inference Server?", "parameters": {"stream": false, "temperature": 0}}' +``` + +Upon success, you should see a response from the server like this one: +``` +{"model_name":"vllm_model","model_version":"1","text_output":"What is Triton Inference Server?\n\nTriton Inference Server is a server that is used by many"} +``` + +### Using the gRPC Asyncio Client +Now, we will see how to run the client within Triton's SDK container +to issue multiple async requests using the +[gRPC asyncio client](https://github.com/triton-inference-server/client/blob/main/src/python/library/tritonclient/grpc/aio/__init__.py) +library. + +This method requires a +[client.py](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/client.py) +script and a set of +[prompts](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/prompts.txt), +which are provided in the +[samples](https://github.com/triton-inference-server/vllm_backend/tree/main/samples) +folder of +[vllm_backend](https://github.com/triton-inference-server/vllm_backend/tree/main) +repository. + +Use the following command to download `client.py` and `prompts.txt` to your +current directory: +``` +wget https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/client.py +wget https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/prompts.txt +``` + +Now, we are ready to start Triton's SDK container: +``` +docker run -it --net=host -v ${PWD}:/workspace/ nvcr.io/nvidia/tritonserver:-py3-sdk bash +``` + +Within the container, run +[`client.py`](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/client.py) +with: +``` +python3 client.py +``` + +The client reads prompts from the +[prompts.txt](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/prompts.txt) +file, sends them to Triton server for +inference, and stores the results into a file named `results.txt` by default. + +The output of the client should look like below: + +``` +Loading inputs from `prompts.txt`... +Storing results into `results.txt`... +PASS: vLLM example +``` + +You can inspect the contents of the `results.txt` for the response +from the server. The `--iterations` flag can be used with the client +to increase the load on the server by looping through the list of +provided prompts in +[prompts.txt](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/prompts.txt). + +When you run the client in verbose mode with the `--verbose` flag, +the client will print more details about the request/response transactions. + +## Limitations + +- We use decoupled streaming protocol even if there is exactly 1 response for each request. +- The asyncio implementation is exposed to model.py. +- Does not support providing specific subset of GPUs to be used. +- If you are running multiple instances of Triton server with +a Python-based vLLM backend, you need to specify a different +`shm-region-prefix-name` for each server. See +[here](https://github.com/triton-inference-server/python_backend#running-multiple-instances-of-triton-server) +for more information. diff --git a/Quick_Deploy/vLLM/model_repository/vllm_model/1/model.py b/Quick_Deploy/Customization/vLLM/model_repository/vllm_model/1/model.py similarity index 100% rename from Quick_Deploy/vLLM/model_repository/vllm_model/1/model.py rename to Quick_Deploy/Customization/vLLM/model_repository/vllm_model/1/model.py diff --git a/Quick_Deploy/vLLM/model_repository/vllm_model/config.pbtxt b/Quick_Deploy/Customization/vLLM/model_repository/vllm_model/config.pbtxt similarity index 99% rename from Quick_Deploy/vLLM/model_repository/vllm_model/config.pbtxt rename to Quick_Deploy/Customization/vLLM/model_repository/vllm_model/config.pbtxt index 764df417..17c488f6 100644 --- a/Quick_Deploy/vLLM/model_repository/vllm_model/config.pbtxt +++ b/Quick_Deploy/Customization/vLLM/model_repository/vllm_model/config.pbtxt @@ -93,6 +93,6 @@ parameters { parameters { key: "gpu_memory_utilization" value: { - string_value: "0.8" + string_value: "0.5" } } \ No newline at end of file diff --git a/Quick_Deploy/vLLM/README.md b/Quick_Deploy/vLLM/README.md index 5e6669cc..54f39327 100644 --- a/Quick_Deploy/vLLM/README.md +++ b/Quick_Deploy/vLLM/README.md @@ -41,40 +41,39 @@ backend. ## Step 1: Prepare your model repository -To use Triton, we need to build a model repository. A sample model repository for deploying `facebook/opt-125m` using vLLM in Triton is -included with this demo as `model_repository` directory. +To use Triton, we need to build a model repository. For this tutorial we will +use the model repository, provided in the [samples](https://github.com/triton-inference-server/vllm_backend/tree/main/samples) +folder of the [vllm_backend](https://github.com/triton-inference-server/vllm_backend/tree/main) +repository. + +The following set of commands will create a `model_repository/vllm_model/1` +directory and copy 2 files: +[`model.json`](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/model_repository/vllm_model/1/model.json) +and +[`config.pbtxt`](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/model_repository/vllm_model/config.pbtxt), +required to serve the [facebook/opt-125m](https://huggingface.co/facebook/opt-125m) model. +``` +mkdir -p model_repository/vllm_model/1 +wget -P model_repository/vllm_model/1 https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/model_repository/vllm_model/1/model.json +wget -P model_repository/vllm_model/ https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/model_repository/vllm_model/config.pbtxt +``` The model repository should look like this: ``` model_repository/ └── vllm_model ├── 1 - │ └── model.py + │   └── model.json └── config.pbtxt ``` -The configuration of engineArgs is in config.pbtxt: - -``` -parameters { - key: "model" - value: { - string_value: "facebook/opt-125m", - } -} - -parameters { - key: "disable_log_requests" - value: { - string_value: "true" - } -} +The content of `model.json` is: -parameters { - key: "gpu_memory_utilization" - value: { - string_value: "0.8" - } +```json +{ + "model": "facebook/opt-125m", + "disable_log_requests": "true", + "gpu_memory_utilization": 0.5 } ``` @@ -85,15 +84,16 @@ and for supported key-value pairs. Inflight batching and paged attention is handled by the vLLM engine. -For multi-GPU support, EngineArgs like `tensor_parallel_size` can be specified in [`config.pbtxt`](model_repository/vllm/config.pbtxt). +For multi-GPU support, EngineArgs like `tensor_parallel_size` can be specified +in [`model.json`](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/model_repository/vllm_model/1/model.json). *Note*: vLLM greedily consume up to 90% of the GPU's memory under default settings. This tutorial updates this behavior by setting `gpu_memory_utilization` to 50%. You can tweak this behavior using fields like `gpu_memory_utilization` and other settings -in [`config.pbtxt`](model_repository/vllm/config.pbtxt). +in [`model.json`](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/model_repository/vllm_model/1/model.json). -Read through the documentation in [`model.py`](model_repository/vllm/1/model.py) to understand how -to configure this sample for your use-case. +Read through the documentation in [`model.py`](https://github.com/triton-inference-server/vllm_backend/blob/main/src/model.py) +to understand how to configure this sample for your use-case. ## Step 2: Launch Triton Inference Server @@ -214,4 +214,4 @@ the client will print more details about the request/response transactions. a Python-based vLLM backend, you need to specify a different `shm-region-prefix-name` for each server. See [here](https://github.com/triton-inference-server/python_backend#running-multiple-instances-of-triton-server) -for more information. +for more information. \ No newline at end of file From 9aa49d2091cedd135a478791eca9a913877ac8a3 Mon Sep 17 00:00:00 2001 From: activezhao Date: Fri, 5 Jan 2024 02:00:01 +0800 Subject: [PATCH 6/7] update README and adjust directory structure --- Quick_Deploy/Customization/vLLM/README.md | 49 ++++++++++++++++--- .../{vllm_model/1 => }/model.py | 2 - .../model_repository/vllm_model/config.pbtxt | 2 +- 3 files changed, 43 insertions(+), 10 deletions(-) rename Quick_Deploy/Customization/vLLM/model_repository/{vllm_model/1 => }/model.py (99%) diff --git a/Quick_Deploy/Customization/vLLM/README.md b/Quick_Deploy/Customization/vLLM/README.md index 2605d744..739c4a19 100644 --- a/Quick_Deploy/Customization/vLLM/README.md +++ b/Quick_Deploy/Customization/vLLM/README.md @@ -39,9 +39,18 @@ backend. *NOTE*: The tutorial is intended to be a reference example only and has [known limitations](#limitations). -## Step 1: Prepare your model repository -To use Triton, we need to build a model repository. A sample model repository for deploying `facebook/opt-125m` using vLLM in Triton is +## Step 1: Prepare Triton vllm_backend +[vllm_backend](https://github.com/triton-inference-server/vllm_backend/tree/main) has been released +as [xx.yy-vllm-python-py3](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver/tags) in Triton NGC, +where is the version of Triton vllm_backend, such as `23.10-vllm-python-py3`. + +You can just get the vllm_backend docker image above. + + +## Step 2: Prepare your model repository + +To use Triton vllm_backend, we need to build a model repository. A sample model repository for deploying `facebook/opt-125m` using Triton vllm_backend is included with this demo as `model_repository` directory. The model repository should look like this: @@ -49,7 +58,6 @@ The model repository should look like this: model_repository/ └── vllm_model ├── 1 - │ └── model.py └── config.pbtxt ``` @@ -76,6 +84,7 @@ parameters { string_value: "0.5" } } + ``` This file can be modified to provide further settings to the vLLM engine. See vLLM @@ -92,23 +101,49 @@ This tutorial updates this behavior by setting `gpu_memory_utilization` to 50%. You can tweak this behavior using fields like `gpu_memory_utilization` and other settings in [`config.pbtxt`](model_repository/vllm_model/config.pbtxt). -Read through the documentation in [`model.py`](model_repository/vllm_model/1/model.py) to understand how +Read through the documentation in [`model.py`](model.py) to understand how to configure this sample for your use-case. -## Step 2: Launch Triton Inference Server +## Step 3: Launch Triton Inference Server Once you have the model repository setup, it is time to launch the triton server. Starting with 23.10 release, a dedicated container with vLLM pre-installed is available on [NGC.](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver) To use this container to launch Triton, you can use the docker command below. ``` -docker run --gpus all -it --net=host --rm -p 8001:8001 --shm-size=1G --ulimit memlock=-1 --ulimit stack=67108864 -v ${PWD}:/work -w /work nvcr.io/nvidia/tritonserver:-vllm-python-py3 tritonserver --model-store ./model_repository +docker run -idt -p 8000:8000 -p 8001:8001 -p 8002:8002 --shm-size=2g --ulimit memlock=-1 --ulimit stack=67108864 --gpus all -v ${PWD}:/model_repository nvcr.io/nvidia/tritonserver:-vllm-python-py3 tritonserver /bin/sh ``` Throughout the tutorial, \ is the version of Triton that you want to use. Please note, that Triton's vLLM container was first published in 23.10 release, so any prior version will not work. +Now, you can get the `CONTAINER ID`, and use the command to enter the container like this: +``` +docker exec -it CONTAINER_ID /bin/bash +``` + +Now, you can see the model repository in the container like this: +``` +model_repository/ +└── vllm_model + ├── 1 + └── config.pbtxt +``` + +And, you can see the vllm_backend in the container which path is `/opt/tritonserver/backends/vllm/model.py`, +you need to use [`model.py`](model.py) to replace the model.py in `/opt/tritonserver/backends/vllm`. + +If you want to get a new docker image, you can commit it like this: +``` +docker commit CONTAINER_ID nvcr.io/nvidia/tritonserver:-vllm-new-python-py3 +``` + +You need to start the Triton with command like this: +``` +/opt/tritonserver/bin/tritonserver --model-store=/model_repository +``` + After you start Triton you will see output on the console showing the server starting up and loading the model. When you see output like the following, Triton is ready to accept inference requests. @@ -119,7 +154,7 @@ I1030 22:33:28.292879 1 http_server.cc:4497] Started HTTPService at 0.0.0.0:8000 I1030 22:33:28.335154 1 http_server.cc:270] Started Metrics Service at 0.0.0.0:8002 ``` -## Step 3: Use a Triton Client to Send Your First Inference Request +## Step 4: Use a Triton Client to Send Your First Inference Request In this tutorial, we will show how to send an inference request to the [facebook/opt-125m](https://huggingface.co/facebook/opt-125m) model in 2 ways: diff --git a/Quick_Deploy/Customization/vLLM/model_repository/vllm_model/1/model.py b/Quick_Deploy/Customization/vLLM/model_repository/model.py similarity index 99% rename from Quick_Deploy/Customization/vLLM/model_repository/vllm_model/1/model.py rename to Quick_Deploy/Customization/vLLM/model_repository/model.py index cd77a0b9..fd5dd727 100644 --- a/Quick_Deploy/Customization/vLLM/model_repository/vllm_model/1/model.py +++ b/Quick_Deploy/Customization/vLLM/model_repository/model.py @@ -27,7 +27,6 @@ import asyncio import json import threading -from typing import AsyncGenerator import numpy as np import triton_python_backend_utils as pb_utils @@ -36,7 +35,6 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.utils import random_uuid - class TritonPythonModel: def initialize(self, args): self.logger = pb_utils.Logger diff --git a/Quick_Deploy/Customization/vLLM/model_repository/vllm_model/config.pbtxt b/Quick_Deploy/Customization/vLLM/model_repository/vllm_model/config.pbtxt index 17c488f6..377142af 100644 --- a/Quick_Deploy/Customization/vLLM/model_repository/vllm_model/config.pbtxt +++ b/Quick_Deploy/Customization/vLLM/model_repository/vllm_model/config.pbtxt @@ -95,4 +95,4 @@ parameters { value: { string_value: "0.5" } -} \ No newline at end of file +} From 3098bd90ae02ebd050a88b311feaed9ebc1b1c6c Mon Sep 17 00:00:00 2001 From: activezhao Date: Fri, 5 Jan 2024 02:07:05 +0800 Subject: [PATCH 7/7] adjust directory structure --- Quick_Deploy/Customization/vLLM/{model_repository => }/model.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename Quick_Deploy/Customization/vLLM/{model_repository => }/model.py (100%) diff --git a/Quick_Deploy/Customization/vLLM/model_repository/model.py b/Quick_Deploy/Customization/vLLM/model.py similarity index 100% rename from Quick_Deploy/Customization/vLLM/model_repository/model.py rename to Quick_Deploy/Customization/vLLM/model.py