
How to Deploy HuggingFace Translation Models on GPU Servers

13 min read
Markus Hennerbichler

Ever since the release of the HuggingFace 🤗 Transformers library, it has been incredibly simple to train, finetune and run state-of-the-art Transformer-based translation models. This has also accelerated the development of our recently launched Translation feature. However, deploying these models in a production setting on GPU servers is still not straightforward, so I want to share how we at Speechmatics were able to deploy a performant real-time translation service for more than 30 languages, open-sourcing part of our solution in the process.

Why Do We Need Inference Servers?

There are countless trained translation models available on the HuggingFace Hub, contributed by academia, companies and even individuals. Using them for inference in Python is straightforward. Depending on the specific model used, doing so is a variation of the following three lines:

from transformers import pipeline
en_fr_translator = pipeline("translation_en_to_fr")
en_fr_translator("How old are you?")

This is easy enough for on-device use cases and wherever a Python interpreter is available, but it doesn't scale well with traffic when using GPUs. The reason for this is that GPUs only unleash their full potential when they do a lot of work in parallel. For translation, one way to build batches is to treat each sentence in a document as one sample and batch over the sentences of the document. In a latency-sensitive application such as real-time translation, however, we can't afford to wait until enough data has accumulated to fill a batch. In this scenario we can only form batches by having a server that batches requests across multiple parallel sessions.
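For illustration, sentence-level batching on a single machine could look like the following sketch (the sentence list and batch_size value here are just placeholder choices):

from transformers import pipeline

en_fr_translator = pipeline("translation_en_to_fr", device=0)  # run on the first GPU

# Treat each sentence of the document as one sample and let the pipeline batch them.
sentences = ["How old are you?", "Where is the station?", "The weather is nice today."]
translations = en_fr_translator(sentences, batch_size=8)
print([t["translation_text"] for t in translations])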

One such inference server is NVIDIA's Triton, which was built specifically for CPU and GPU inference. At Speechmatics, we already rely heavily on Triton for our ASR workloads.

Using the Python Backend

Triton supports multiple neural network inference engines and frameworks as backends and also allows for custom Python backends. Since we are using the transformers Python package, the natural choice is to use the Python backend. A minimal implementation needs a class called TritonPythonModel with an initialize method for loading the model and an execute method for executing a batch. Even though Triton has some support for string input and output types, it is preferable to use integer tokens directly as input, for example to translate into German with NLLB:

from transformers import AutoModelForSeq2SeqLM
import numpy as np
import torch
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        self.model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M").to("cuda")

    def execute(self, requests: list):
        batch_sizes, input_ids, attention_mask = build_input(requests)
        responses = []
        translated_tokens = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            forced_bos_token_id=256042,  # German language token
        ).to("cpu")

        start = 0
        for batch_shape in batch_sizes:
            out_tensor = pb_utils.Tensor(
                "OUTPUT_IDS", translated_tokens[start : start + batch_shape[0], :].numpy().astype(np.int32)
            )
            start += batch_shape[0]
            responses.append(pb_utils.InferenceResponse(output_tensors=[out_tensor]))

        return responses

To make the best possible use of batching, we have to configure dynamic and ragged batching. Ragged batching means we allow Triton to batch together requests with different sequence lengths, and thus we have to manually concatenate and pad the inputs before inference:

def build_input(requests: list):
    batch_sizes = [
        np.shape(pb_utils.get_input_tensor_by_name(request, "INPUT_IDS").as_numpy()) for request in requests
    ]
    max_len = np.max([bs[1] for bs in batch_sizes])
    input_ids = torch.tensor(
        np.concatenate(
            [
                np.pad(
                    pb_utils.get_input_tensor_by_name(request, "INPUT_IDS").as_numpy(),
                    ((0, 0), (0, max_len - batch_size[1])),
                )
                for batch_size, request in zip(batch_sizes, requests)
            ],
            axis=0,
        )
    ).to("cuda")
    attention_mask = torch.tensor(
        (
            np.arange(max_len).repeat(len(requests)).reshape(max_len, len(requests))
            < [bs[1] for bs in batch_sizes]
        ).T
    ).to("cuda")
    return batch_sizes, input_ids, attention_mask

We need to save this Python script as models/nllb/1/model.py. Every Triton model also requires a configuration in a Protobuf text file, which is used to specify the inputs and outputs:

backend: "python"
max_batch_size: 128  # can be optimised based on available GPU memory
name: "nllb"  # needed for reference in the client
input [
  {
    name: "INPUT_IDS"
    data_type: TYPE_INT32
    dims: [ -1 ]
    allow_ragged_batch: true
  }
]
output [
  {
    name: "OUTPUT_IDS"
    data_type: TYPE_INT32
    dims: [ -1 ]
  }
]
instance_group [{ kind: KIND_GPU }]
dynamic_batching {
  max_queue_delay_microseconds: 5000
}

max_queue_delay_microseconds is a parameter for dynamic batching that can be tuned to trade off latency against throughput. This config needs to be saved under models/nllb/config.pbtxt. Now we can start a Triton Server docker container using one GPU:

docker run --gpus=1 --rm -p8000:8000 -p8001:8001 -p8002:8002 -v $(pwd)/models:/models nvcr.io/nvidia/tritonserver:23.07-py3 tritonserver --model-repository=/models
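Once the container is up, we can sanity-check that the model has loaded, for example with the HTTP client (a quick check assuming the default ports mapped above):

import tritonclient.http

client = tritonclient.http.InferenceServerClient("localhost:8000")
assert client.is_server_ready()       # Triton itself is up
assert client.is_model_ready("nllb")  # our Python-backend model has loaded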

Sending Requests

Triton Server provides both a gRPC and an HTTP API. We can either use those directly or use any of the official client libraries. Here is an example of how to make a translation inference request using the Python gRPC client:

import asyncio

import numpy as np
import tritonclient.grpc.aio
from tritonclient.utils import np_to_triton_dtype
from transformers import AutoTokenizer


async def main():
    MODEL_NAME = "nllb"
    client = tritonclient.grpc.aio.InferenceServerClient("127.0.0.1:8001")
    en_text = "Hello World"
    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang="eng_Latn")

    input_ids = tokenizer(en_text, return_attention_mask=False, return_tensors="np").input_ids.astype(np.int32)
    print(f"Tokenised input: {input_ids}")

    inputs = [
        tritonclient.grpc.aio.InferInput("INPUT_IDS", input_ids.shape, np_to_triton_dtype(input_ids.dtype)),
    ]
    inputs[0].set_data_from_numpy(input_ids)
    outputs = [tritonclient.grpc.aio.InferRequestedOutput("OUTPUT_IDS")]

    res = await client.infer(model_name=MODEL_NAME, inputs=inputs, outputs=outputs)
    out_tokens = res.as_numpy("OUTPUT_IDS")
    print(f"Returned tokens: {out_tokens}")
    translated_text = tokenizer.batch_decode(out_tokens)
    print(translated_text)


if __name__ == "__main__":
    asyncio.run(main())

In this example we are using the HuggingFace tokenizer to encode the string into tokens, which we then send as INT32 type, and decode the resulting output tokens back into strings. Multiple clients can now make requests in parallel and Triton will create appropriate batches that are passed to our model's execute method.

Limitations of the Python Backend

Using the Python backend is as simple as wrapping the HuggingFace inference code with Triton-specific methods for constructing input and output. This works well for a small number of clients and models, but has limitations when we want to optimise throughput or use many models on a single server. Python's standard implementation, CPython, guarantees thread safety in the interpreter with the Global Interpreter Lock (GIL)1. This effectively means that only one thread can execute Python bytecode at a time. This doesn't play well with Triton's thread-based architecture, where each model runs in its own thread. To support multiple Python models in parallel, Triton spawns one process per model instance, each allocating a fixed amount of host memory for the interpreter and GPU memory for the CUDA context. This becomes problematic when serving multiple models from the same GPU because of the fixed host and GPU memory overhead per model.

The one-process-per-model design also causes issues with PyTorch, the default underlying ML framework of 🤗 Transformers. PyTorch relies on a caching memory allocator to speed up memory allocations on the GPU. Instead of being freed after use, memory goes back to a reserved pool, which unfortunately isn't shared between the different processes. If we have two model instances, both will keep unused memory allocated and thus reduce the total memory available to either of them. To free the memory we need to call torch.cuda.empty_cache() at the end of every execute call, eliminating the speed improvements the caching allocator provides.
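For completeness, a minimal sketch of that workaround in the Python backend's execute method (do_inference here is a hypothetical helper standing in for the batching and generation code shown earlier):

import torch

class TritonPythonModel:
    ...

    def execute(self, requests: list):
        responses = self.do_inference(requests)  # hypothetical helper wrapping the code from above
        # Hand cached blocks back to the CUDA driver so other per-model processes can use them,
        # at the cost of slower re-allocation on the next call.
        torch.cuda.empty_cache()
        return responses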

Directly Using the PyTorch Models

To address the limitations in the previous section, we can drop the HuggingFace abstraction and go straight to the underlying PyTorch model with the PyTorch backend. The PyTorch backend utilises libtorch, the standalone C++ library that works without a Python interpreter, thus avoiding the Python backend's GIL-related problems. This independence from Python comes with a trade-off: libtorch can only execute models that are converted into TorchScript. TorchScript is a custom language that supports a subset of Python and is just-in-time compiled by libtorch. There are two ways to convert a PyTorch model to TorchScript:

  • Scripting: a TorchScript program is derived straight from the Python source code, which only works if the model uses only the subset of Python that TorchScript supports.
  • Tracing: given the model and sample data, the TorchScript program is derived from the actual PyTorch functions that are recorded while executing the model with the given inputs. All Python language features can be used with tracing, but because only what is executed for a given input gets recorded, it does not support any dynamic control flow (see the toy example below).
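As a toy illustration of the difference (unrelated to the translation models themselves), consider a module with data-dependent control flow:

import torch

class Flip(torch.nn.Module):
    def forward(self, x):
        # Data-dependent branch: preserved by scripting, baked in by tracing.
        if x.sum() > 0:
            return x
        return -x

model = Flip()
scripted = torch.jit.script(model)              # keeps the if-branch
traced = torch.jit.trace(model, torch.ones(3))  # records only the path taken for this example input

print(scripted(torch.tensor([-1.0, -2.0])))  # tensor([1., 2.]): the branch is still evaluated
print(traced(torch.tensor([-1.0, -2.0])))    # tensor([-1., -2.]): the negative branch was never recorded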

Given these two approaches, each with their own drawbacks, it is clear that not every PyTorch model can be converted to TorchScript out of the box, and rewriting a model to be convertible often comes at the cost of code quality. 🤗 Transformers uses dynamic control flow in its various text decoding algorithms, which is generally not written in the allowed Python subset, ruling out both conversion methods unless we are willing to rewrite most of the inference code.

One last option is to dissect the pre-trained model into its encoder and decoder modules and trace those separately. With both modules deployed using the PyTorch backend, we need to write our own equivalent of generate that uses the models on Triton instead of the embedded encoder and decoder modules for text generation. In the simplest case, this means passing the inputs through the encoder and then iteratively calling the decoder with the encoder outputs and the previously translated tokens as input. This can either be done on the client, directly calling the models and passing outputs over the network, or with Triton's Business Logic Scripting (BLS) to avoid unnecessary round trips. BLS allows writing custom logic for orchestrating other models directly in Triton. This approach makes efficient use of PyTorch for neural network inference, but comes with the disadvantage of having to write custom generation logic.
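The following sketch illustrates such custom generation logic with a plain greedy search. For brevity it calls the HuggingFace encoder and decoder locally; in the actual deployment those two calls would be replaced by requests to the traced modules served by Triton (from the client or from a BLS script). The target language token ("deu_Latn") and the maximum length are arbitrary choices:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang="eng_Latn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M").eval()

inputs = tokenizer("Hello World", return_tensors="pt")
with torch.no_grad():
    # 1) Run the encoder once per request (the first Triton model in the real setup).
    encoder_outputs = model.get_encoder()(**inputs)

    # 2) Iteratively run the decoder (the second Triton model), feeding back the tokens generated so far.
    decoder_ids = torch.tensor([[model.config.decoder_start_token_id,
                                 tokenizer.convert_tokens_to_ids("deu_Latn")]])
    for _ in range(64):  # arbitrary maximum length
        logits = model(encoder_outputs=encoder_outputs,
                       attention_mask=inputs.attention_mask,
                       decoder_input_ids=decoder_ids).logits
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy pick
        decoder_ids = torch.cat([decoder_ids, next_token], dim=-1)
        if next_token.item() == model.config.eos_token_id:
            break

print(tokenizer.batch_decode(decoder_ids, skip_special_tokens=True))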

Using the CTranslate2 Library

Both options explored so far use Python and PyTorch, either implicitly or explicitly, to run inference with the translation models. PyTorch is a very capable general tensor and neural network library that supports anything from training large language models on hundreds of GPUs to real-time inference on a Raspberry Pi. However, when optimising for performance it often makes more sense to use a specialised solution. One such specialised solution for neural machine translation with Transformer models is CTranslate2, developed by the OpenNMT team. It comes with an efficient custom Transformer inference engine and a variety of decoding algorithms written in C++ and CUDA for GPU acceleration, which makes it similar in spirit to ggml, the library used by the successful llama.cpp project. CTranslate2 claims 2-4x faster inference on both CPU and GPU while using less memory compared to PyTorch, which we also observed on our own setup. The speedups varied by model, though, and we generally saw bigger improvements for smaller models.
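Outside of Triton, using CTranslate2 directly from Python is already quite simple. A minimal sketch, assuming the NLLB model has been converted into a local directory called nllb_ct2 (the conversion tool is shown further below):

import ctranslate2
from transformers import AutoTokenizer

translator = ctranslate2.Translator("nllb_ct2", device="cuda")  # hypothetical path to a converted model
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang="eng_Latn")

# CTranslate2 works on token strings rather than token ids.
source_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode("Hello World"))
results = translator.translate_batch([source_tokens], target_prefix=[["deu_Latn"]])
target_tokens = results[0].hypotheses[0][1:]  # drop the target-language prefix token
print(tokenizer.decode(tokenizer.convert_tokens_to_ids(target_tokens)))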

Using any inference library with Triton requires either a Python library for the Python backend or a custom backend. Given the problems that come with Python, we decided to implement a backend that wraps CTranslate2 and handles the integration with Triton. We open-sourced our code under a permissive license at speechmatics/ctranslate2_triton_backend.

While the Python and PyTorch backends are officially maintained by NVIDIA, anyone can create their own backend by providing a shared library that implements the backend API. Triton looks up backends based on the backend property in the config, searching for a matching shared library (.so) in the model directory or in the special "backends" directory. The Speechmatics CTranslate2 backend therefore needs to be built and installed first:

TRITON_VERSION=23.07
git clone https://github.com/speechmatics/ctranslate2_triton_backend
mkdir ctranslate2_triton_backend/build
cd ctranslate2_triton_backend/build
cmake .. \
  -DCMAKE_INSTALL_PREFIX:PATH=/opt/tritonserver \
  -DTRITON_ENABLE_GPU=ON -DCMAKE_BUILD_TYPE=Release \
  -DTRITON_COMMON_REPO_TAG=r$TRITON_VERSION \
  -DTRITON_CORE_REPO_TAG=r$TRITON_VERSION \
  -DTRITON_BACKEND_REPO_TAG=r$TRITON_VERSION
make -j install

This installs the backend under the default Triton path /opt/tritonserver/backends.

Deploying a CTranslate2 model is similar to deploying with the Python backend: we need to create a model repository with a Protobuf configuration and the actual model file. CTranslate2 uses a custom file format for model loading and comes with a Python tool to convert models straight from HuggingFace. For example, to convert the NLLB model:

pip install ctranslate2  # only required for model conversion
mkdir $MODEL_DIR/nllb/1  # directory name must match the `name` in the config
ct2-transformers-converter --model facebook/nllb-200-distilled-600M --output_dir $MODEL_DIR/nllb/1/model

The configuration is the same as for the Python backend, with the only modification being the change to backend: "ctranslate2". For multilingual models such as NLLB, however, we also need to add an input for the target language to config.pbtxt, which can then be set in the client:

bos_ids = np.array(tokenizer.lang_code_to_id["deu_Latn"])
inputs.append(
    tritonclient.grpc.aio.InferInput("TARGET_PREFIX", bos_ids.shape, np_to_triton_dtype(bos_ids.dtype))
)
inputs[1].set_data_from_numpy(bos_ids)

Additionally, some inference and decoding options of CTranslate2 are exposed via Parameters; details can be found in the README.md.

Comparing the Performance of Python and CTranslate2 Backends

Here, we compare the performance of the Python and CTranslate2 backends. The specific scenario we investigate is having multiple (smaller) models on a single GPU. This arises, for example, when we want to translate a live stream into multiple languages at once. We simulate this with multiple concurrent sessions, all sending Triton requests to 4 models at once.

The graph above compares throughput (in translated words per second) for the Python and CTranslate2 implementations at different concurrency levels. As expected, the CTranslate2 backend delivers a notable improvement in throughput. At four streams we measured a 2.13x increase with the CTranslate2 backend, and the gains grew with more streams, reaching around 2.87x at 48 streams. This is mostly due to the lower memory use of the CTranslate2 backend, which doesn't waste memory on the multiple cache pools that the independent PyTorch processes create. The CTranslate2 library thus allows us to serve more clients on a single GPU and scales better with more traffic.

Conclusion

As shown, the simple approach of deploying HuggingFace Transformers onto Triton creates scaling problems due to limitations of Python. Unfortunately, using the PyTorch backend directly is also not an option without making significant changes to the HuggingFace source code. In our case we solved the problem by leveraging CTranslate2, a specialised inference library for translation models, which resulted in significant throughput gains on data that is representative of the traffic we see.


  1. This might not be true for much longer. In July 2023 the Python Steering Council agreed to make the GIL optional in future versions. ↩