
Commit 43b00ca

Adding tokenize endpoint for embeddings calculator (#3717)

### 🛠 Summary

[CVS-175207](https://jira.devtools.intel.com/browse/CVS-175207) Adding /tokenize endpoint to embeddings calculator

### 🧪 Checklist

- [x] Unit tests added.
- [x] The documentation updated.
- [ ] Change follows security best practices.
1 parent 74d4416 commit 43b00ca

File tree

12 files changed: +913 additions, −83 deletions

demos/embeddings/README.md

Lines changed: 31 additions & 1 deletion
````diff
@@ -503,10 +503,40 @@ Results will be stored in `results` folder:
   "kg_co2_emissions": null
 }
 ```
-
 Compare against local HuggingFace execution for reference:
 ```console
 mteb run -m thenlper/gte-small -t Banking77Classification --output_folder results
+```
+
+# Usage of tokenize endpoint (release 2025.4 or weekly)
+
+The `tokenize` endpoint provides a simple API for tokenizing input text using the same tokenizer as the deployed embeddings model. This allows you to see how your text will be split into tokens before feature extraction or inference. The endpoint accepts a string or list of strings and returns the corresponding token IDs and tokenized text.
+
+Example usage:
+```console
+curl http://localhost:8000/v3/tokenize -H "Content-Type: application/json" -d "{ \"model\": \"BAAI/bge-large-en-v1.5\", \"text\": \"hello world\" }"
+```
+Response:
+```json
+{
+"tokens": [101,7592,2088,102]
+}
 ```
 
+It's possible to use additional parameters:
+- pad_to_max_length - whether to pad the sequence to the maximum length. Default is False.
+- max_length - maximum length of the sequence. If None (default), the value will be taken from the IR (where the default value from the original HF/GGUF model is stored).
+- padding_side - side to pad the sequence, can be ‘left’ or ‘right’. Default is None.
+- add_special_tokens - whether to add special tokens like BOS, EOS, PAD. Default is True.
+
+Example usage:
+```console
+curl http://localhost:8000/v3/tokenize -H "Content-Type: application/json" -d "{ \"model\": \"BAAI/bge-large-en-v1.5\", \"text\": \"hello world\", \"max_length\": 10, \"pad_to_max_length\": true, \"padding_side\": \"left\", \"add_special_tokens\": true }"
+```
 
+Response:
+```json
+{
+"tokens":[0,0,0,0,0,0,101,7592,2088,102]
+}
+```
````
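As a reference for client authors, here is a minimal C++ sketch that calls the new endpoint with the same URL and JSON body as the curl examples above. libcurl is an illustration choice here, not part of this commit.

```cpp
// Minimal client sketch for /v3/tokenize (assumes a server started as in the
// README above). Build: g++ tokenize_client.cpp -lcurl
#include <curl/curl.h>

#include <iostream>
#include <string>

// libcurl write callback: append the HTTP response body to a std::string.
static size_t appendToString(char* data, size_t size, size_t nmemb, void* out) {
    static_cast<std::string*>(out)->append(data, size * nmemb);
    return size * nmemb;
}

int main() {
    CURL* curl = curl_easy_init();
    if (!curl)
        return 1;
    // Same request body as the first curl example in the README diff.
    const std::string body =
        R"({ "model": "BAAI/bge-large-en-v1.5", "text": "hello world" })";
    std::string response;
    struct curl_slist* headers = nullptr;
    headers = curl_slist_append(headers, "Content-Type: application/json");
    curl_easy_setopt(curl, CURLOPT_URL, "http://localhost:8000/v3/tokenize");
    curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
    curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str());
    curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, appendToString);
    curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response);
    CURLcode res = curl_easy_perform(curl);
    if (res == CURLE_OK)
        std::cout << response << std::endl;  // e.g. {"tokens":[101,7592,2088,102]}
    curl_slist_free_all(headers);
    curl_easy_cleanup(curl);
    return res == CURLE_OK ? 0 : 1;
}
```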

src/BUILD

Lines changed: 16 additions & 0 deletions
```diff
@@ -2662,6 +2662,7 @@ cc_test(
         ":test_platform_utils",
         "//src/rerank:rerank_api_handler",
         ":embeddings_handler_tests",
+        ":tokenize_parser_tests",
         ":inferencerequest_test",
         ":libtest_environment",
         ":libtest_gpuenvironment",
@@ -3110,6 +3111,21 @@ cc_library(
     linkopts = [],
 )
 
+cc_library(
+    name = "tokenize_parser_tests",
+    linkstatic = 1,
+    alwayslink = True,
+    srcs = ["test/tokenize_parser_test.cpp"],
+    data = [],
+    deps = [
+        "//src/tokenize:tokenize_parser",
+        "@com_google_googletest//:gtest",
+    ],
+    copts = COPTS_TESTS,
+    linkopts = [],
+)
+
+
 cc_library(
     name = "test_llm_output_parser_tests",
     linkstatic = 1,
```

src/embeddings/BUILD

Lines changed: 1 addition & 0 deletions
```diff
@@ -22,6 +22,7 @@ ovms_cc_library(
     hdrs = ["embeddings_api.hpp"],
     srcs = ["embeddings_api.cpp"],
     deps = ["//src:libovmslogging",
+        "//src/tokenize:tokenize_parser",
         "@mediapipe//mediapipe/framework:calculator_framework",
         "//third_party:openvino",
         "@com_github_tencent_rapidjson//:rapidjson",],
```

src/embeddings/embeddings_api.cpp

Lines changed: 16 additions & 61 deletions
```diff
@@ -45,67 +45,26 @@ using namespace rapidjson;
 namespace ovms {
 
 std::variant<EmbeddingsRequest, std::string> EmbeddingsRequest::fromJson(rapidjson::Document* parsedJson) {
-    enum class InputType {
-        NONE,
-        STRING,
-        INT,
-        INT_VEC
-    };
     EmbeddingsRequest request;
-    std::vector<std::string> input_strings;
-    std::vector<std::vector<int64_t>> input_tokens;
-
     if (!parsedJson->IsObject())
         return "Received json is not an object";
 
-    auto it = parsedJson->FindMember("input");
-    if (it != parsedJson->MemberEnd()) {
-        if (it->value.IsString()) {
-            input_strings.push_back(it->value.GetString());
-        } else if (it->value.IsArray()) {
-            if (it->value.GetArray().Size() == 0) {
-                return "input array should not be empty";
-            }
-            InputType input_type = InputType::NONE;
-            for (auto& input : it->value.GetArray()) {
-                if (input.IsArray()) {
-                    if (input_type != InputType::NONE && input_type != InputType::INT_VEC)
-                        return "input must be homogeneous";
-                    input_type = InputType::INT_VEC;
-                    std::vector<int64_t> ints;
-                    ints.reserve(input.GetArray().Size());
-                    for (auto& val : input.GetArray()) {
-                        if (val.IsInt())
-                            ints.push_back(val.GetInt());
-                        else
-                            return "input must be homogeneous";
-                    }
-                    input_tokens.emplace_back(std::move(ints));
-                } else if (input.IsString()) {
-                    if (input_type != InputType::NONE && input_type != InputType::STRING)
-                        return "input must be homogeneous";
-                    input_type = InputType::STRING;
-                    input_strings.push_back(input.GetString());
-                } else if (input.IsInt()) {
-                    if (input_type != InputType::NONE && input_type != InputType::INT)
-                        return "input must be homogeneous";
-                    input_type = InputType::INT;
-                    if (input_tokens.size() == 0) {
-                        input_tokens.push_back(std::vector<int64_t>());
-                    }
-                    input_tokens[0].push_back(input.GetInt());
-                } else {
-                    return "every element in input array should be either string or int";
-                }
-            }
+    auto parsedInput = TokenizeParser::parseInput(*parsedJson, "input");
+
+    if (std::holds_alternative<std::string>(parsedInput)) {
+        return std::get<std::string>(parsedInput);
+    } else {
+        auto inputVariant = std::get<EmbeddingsRequest::InputDataType>(parsedInput);
+        if (std::holds_alternative<std::vector<std::string>>(inputVariant)) {
+            request.input = std::get<std::vector<std::string>>(inputVariant);
+        } else if (std::holds_alternative<std::vector<std::vector<int64_t>>>(inputVariant)) {
+            request.input = std::get<std::vector<std::vector<int64_t>>>(inputVariant);
         } else {
-            return "input should be string, array of strings or array of integers";
+            return "input must be either array of strings or array of array of integers";
         }
-    } else {
-        return "input field is required";
     }
 
-    it = parsedJson->FindMember("encoding_format");
+    auto it = parsedJson->FindMember("encoding_format");
     request.encoding_format = EncodingFormat::FLOAT;
     if (it != parsedJson->MemberEnd()) {
         if (it->value.IsString()) {
@@ -123,13 +82,6 @@ std::variant<EmbeddingsRequest, std::string> EmbeddingsRequest::fromJson(rapidjs
 
     // TODO: dimensions (optional)
     // TODO: user (optional)
-    if (input_strings.size() > 0) {
-        request.input = input_strings;
-    } else if (input_tokens.size() > 0) {
-        request.input = input_tokens;
-    } else {
-        return "input field is required";
-    }
     return request;
 }
 
@@ -149,12 +101,15 @@ absl::Status EmbeddingsHandler::parseRequest() {
     return absl::OkStatus();
 }
 
-std::variant<std::vector<std::string>, std::vector<std::vector<int64_t>>>& EmbeddingsHandler::getInput() {
+TokenizeRequest::InputDataType& EmbeddingsHandler::getInput() {
    return request.input;
 }
 EmbeddingsRequest::EncodingFormat EmbeddingsHandler::getEncodingFormat() const {
     return request.encoding_format;
 }
+ov::AnyMap& EmbeddingsHandler::getParameters() {
+    return request.parameters;
+}
 
 void EmbeddingsHandler::setPromptTokensUsage(int promptTokens) {
     this->promptTokens = promptTokens;
```
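The parsing logic removed from `fromJson` now lives in `TokenizeParser::parseInput` (src/tokenize/tokenize_parser.cpp, which is not part of this diff). A hypothetical sketch of its shape, inferred from the call site above and the removed inline logic; the signature, the `InputDataType` alias, and the body are assumptions:

```cpp
// Hypothetical sketch of TokenizeParser::parseInput; the real implementation
// is in src/tokenize/tokenize_parser.cpp and is not shown in this commit.
// It mirrors the inline validation removed from fromJson() above: on success
// it returns the parsed input variant, on failure an error message string.
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

#include <rapidjson/document.h>

using InputDataType =
    std::variant<std::vector<std::string>, std::vector<std::vector<int64_t>>>;

std::variant<InputDataType, std::string> parseInput(
    const rapidjson::Document& doc, const std::string& field) {
    auto it = doc.FindMember(field.c_str());
    if (it == doc.MemberEnd())
        return std::string{"input field is required"};
    if (it->value.IsString())  // a single string becomes a one-element vector
        return InputDataType{std::vector<std::string>{it->value.GetString()}};
    // Array handling (homogeneous strings, ints, or arrays of ints) would
    // follow the same checks as the removed inline code above.
    return std::string{"input should be string, array of strings or array of integers"};
}
```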

src/embeddings/embeddings_api.hpp

Lines changed: 5 additions & 3 deletions
```diff
@@ -34,19 +34,20 @@
 #include <rapidjson/stringbuffer.h>
 #pragma warning(pop)
 
+#include "../tokenize/tokenize_parser.hpp"
+
 namespace ovms {
 
 enum class PoolingMode {
     CLS,
     LAST
 };
 
-struct EmbeddingsRequest {
+struct EmbeddingsRequest : TokenizeRequest {
     enum class EncodingFormat {
         FLOAT,
         BASE64
     };
-    std::variant<std::vector<std::string>, std::vector<std::vector<int64_t>>> input;
     EncodingFormat encoding_format;
 
     static std::variant<EmbeddingsRequest, std::string> fromJson(rapidjson::Document* request);
@@ -61,8 +62,9 @@ class EmbeddingsHandler {
     EmbeddingsHandler(rapidjson::Document& document) :
         doc(document) {}
 
-    std::variant<std::vector<std::string>, std::vector<std::vector<int64_t>>>& getInput();
+    TokenizeRequest::InputDataType& getInput();
     EmbeddingsRequest::EncodingFormat getEncodingFormat() const;
+    ov::AnyMap& getParameters();
 
     absl::Status parseRequest();
 
```
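`EmbeddingsRequest` now derives from `TokenizeRequest`, which supplies the `input` member and the `parameters` map. The struct is defined in `src/tokenize/tokenize_parser.hpp`, outside this diff; a plausible shape, inferred from the members referenced in this commit, is the following sketch (an assumption, not the actual definition):

```cpp
// Assumed shape of TokenizeRequest (defined in src/tokenize/tokenize_parser.hpp,
// which this excerpt does not include); inferred from the members used above.
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

#include <openvino/openvino.hpp>  // ov::AnyMap

struct TokenizeRequest {
    using InputDataType =
        std::variant<std::vector<std::string>, std::vector<std::vector<int64_t>>>;
    InputDataType input;    // text(s) to tokenize, or pre-tokenized token ids
    ov::AnyMap parameters;  // e.g. max_length, pad_to_max_length, padding_side
};
```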
src/embeddings/embeddings_calculator_ov.cc

Lines changed: 59 additions & 15 deletions
```diff
@@ -49,6 +49,7 @@ class EmbeddingsServable;
 namespace mediapipe {
 
 const std::string EMBEDDINGS_SESSION_SIDE_PACKET_TAG = "EMBEDDINGS_NODE_RESOURCES";
+const std::string EMBEDDINGS_TOKENIZE_ENDPOINT_SUFFIX = "tokenize";
 
 using InputDataType = ovms::HttpPayload;
 using OutputDataType = std::string;
@@ -62,6 +63,13 @@ class EmbeddingsCalculatorOV : public CalculatorBase {
 
     mediapipe::Timestamp timestamp{0};
 
+    absl::Status tokenizeStrings(ov::genai::Tokenizer& tokenizer, const std::vector<std::string>& inputStrings, const ov::AnyMap& parameters, ov::genai::TokenizedInputs& tokens, const size_t& max_context_length) {
+        tokens = tokenizer.encode(inputStrings, parameters);
+        RET_CHECK(tokens.input_ids.get_shape().size() == 2);
+
+        return absl::OkStatus();
+    }
+
 protected:
     std::shared_ptr<ovms::EmbeddingsServable> embeddings_session{nullptr};
 
@@ -104,41 +112,77 @@ class EmbeddingsCalculatorOV : public CalculatorBase {
         InputDataType payload = cc->Inputs().Tag(INPUT_TAG_NAME).Get<InputDataType>();
         SPDLOG_LOGGER_DEBUG(embeddings_calculator_logger, "Request body: {}", payload.body);
         SPDLOG_LOGGER_DEBUG(embeddings_calculator_logger, "Request uri: {}", payload.uri);
-        ovms::EmbeddingsHandler handler(*payload.parsedJson);
-        auto parseRequestStartTime = std::chrono::high_resolution_clock::now();
-        absl::Status status = handler.parseRequest();
-        if (!status.ok()) {
-            return status;
-        }
-        double time = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - parseRequestStartTime).count();
-        SPDLOG_LOGGER_DEBUG(embeddings_calculator_logger, "Embeddings request deserialization time: {} ms", time / 1000);
 
         ov::Tensor embeddingsTensor;
         size_t received_batch_size = 1;
         size_t max_context_length = 1024;  // default allowed input length. Otherwise, it will be read from model config.json file
-        ModelMetricReporter unused(nullptr, nullptr, "unused", 1);
         ov::genai::TokenizedInputs tokens;
         ov::Tensor typeIds;
         if (embeddings_session->getMaxModelLength().has_value()) {
             max_context_length = embeddings_session->getMaxModelLength().value();
         } else {
             SPDLOG_LOGGER_DEBUG(embeddings_calculator_logger, "max_position_embeddings nor max_trained_positions included in config.json. Using default value {}", max_context_length);
         }
+        const int endpoint_len = EMBEDDINGS_TOKENIZE_ENDPOINT_SUFFIX.size();
+        const bool useTokenizeEndpoint = payload.uri.size() >= endpoint_len &&
+                                         payload.uri.compare(payload.uri.size() - endpoint_len, endpoint_len, EMBEDDINGS_TOKENIZE_ENDPOINT_SUFFIX) == 0;
+        if (useTokenizeEndpoint) {
+            ovms::TokenizeRequest tokenizeRequest;
+            absl::Status parsingStatus = ovms::TokenizeParser::parseTokenizeRequest(*payload.parsedJson, tokenizeRequest);
+            if (!parsingStatus.ok()) {
+                return parsingStatus;
+            }
+            auto input = tokenizeRequest.input;
+            if (auto strings = std::get_if<std::vector<std::string>>(&input)) {
+                auto tokenizationStatus = this->tokenizeStrings(embeddings_session->getTokenizer(), *strings, tokenizeRequest.parameters, tokens, max_context_length);
+                if (!tokenizationStatus.ok()) {
+                    return tokenizationStatus;
+                }
+            } else {
+                SPDLOG_LOGGER_DEBUG(embeddings_calculator_logger, "Embeddings tokenize input is of not supported type");
+                return absl::InvalidArgumentError("Input should be string or array of strings");
+            }
+
+            StringBuffer responseBuffer;
+            auto responseStatus = ovms::TokenizeParser::parseTokenizeResponse(responseBuffer, tokens, tokenizeRequest.parameters);
+            if (!responseStatus.ok()) {
+                return responseStatus;
+            }
+            cc->Outputs().Tag(OUTPUT_TAG_NAME).Add(new std::string(responseBuffer.GetString()), timestamp);
+            return absl::OkStatus();
+        }
+        ovms::EmbeddingsHandler handler(*payload.parsedJson);
+        auto parseRequestStartTime = std::chrono::high_resolution_clock::now();
+        absl::Status status = handler.parseRequest();
+
+        if (!status.ok()) {
+            return status;
+        }
+        double time = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - parseRequestStartTime).count();
+        SPDLOG_LOGGER_DEBUG(embeddings_calculator_logger, "Embeddings request deserialization time: {} ms", time / 1000);
+
+        ModelMetricReporter unused(nullptr, nullptr, "unused", 1);
+
         try {
             auto input = handler.getInput();
             if (auto strings = std::get_if<std::vector<std::string>>(&input)) {
+                ov::AnyMap& params = handler.getParameters();
                 received_batch_size = strings->size();
-                ov::AnyMap params = {};
-                if (cc->Options<EmbeddingsCalculatorOVOptions>().truncate()) {
-                    params = {{"max_length", max_context_length}};
+                if (cc->Options<EmbeddingsCalculatorOVOptions>().truncate() && params.find("max_length") == params.end()) {
+                    params["max_length"] = max_context_length;
                 }
-                tokens = embeddings_session->getTokenizer().encode(*strings, params);
-                RET_CHECK(tokens.input_ids.get_shape().size() == 2);
+
+                absl::Status tokenizationStatus = this->tokenizeStrings(embeddings_session->getTokenizer(), *strings, params, tokens, max_context_length);
+                if (!tokenizationStatus.ok()) {
+                    return tokenizationStatus;
+                }
+
                 size_t input_ids_size = tokens.input_ids.get_shape()[1];
                 if (input_ids_size > max_context_length) {
                     SPDLOG_LOGGER_DEBUG(embeddings_calculator_logger, "Input size {} exceeds max_context_length {}", input_ids_size, max_context_length);
-                    return absl::InvalidArgumentError(absl::StrCat("Input length ", input_ids_size, " longer than allowed ", max_context_length));
+                    return absl::InvalidArgumentError("Input length " + std::to_string(input_ids_size) + " longer than allowed " + std::to_string(max_context_length));
                 }
+
                 if (embeddings_session->getNumberOfModelInputs() == 3) {
                     typeIds = ov::Tensor{ov::element::i64, tokens.input_ids.get_shape()};
                     std::fill_n(typeIds.data<int64_t>(), tokens.input_ids.get_size(), 0);
```
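The calculator decides between the embeddings and tokenize paths purely by checking whether the request URI ends with the `tokenize` suffix. A standalone illustration of that check follows; `endsWith` is a hypothetical helper name, since the calculator inlines the comparison:

```cpp
#include <iostream>
#include <string>

// Standalone illustration of the suffix test the calculator uses to decide
// whether a request targets the tokenize endpoint.
static bool endsWith(const std::string& uri, const std::string& suffix) {
    return uri.size() >= suffix.size() &&
           uri.compare(uri.size() - suffix.size(), suffix.size(), suffix) == 0;
}

int main() {
    const std::string suffix = "tokenize";  // EMBEDDINGS_TOKENIZE_ENDPOINT_SUFFIX
    std::cout << endsWith("/v3/tokenize", suffix) << '\n';    // 1 -> tokenize path
    std::cout << endsWith("/v3/embeddings", suffix) << '\n';  // 0 -> embeddings path
}
```

Note that when a request supplies its own `max_length`, the rewritten embeddings path above no longer overwrites it with the model's context length, even when `truncate` is enabled; the option only fills in `max_length` when the caller did not set it.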

src/test/embeddings_handler_test.cpp

Lines changed: 3 additions & 3 deletions
The three hunks below are whitespace-only fixes; the removed and added lines differ only in (invisible) trailing whitespace.

```diff
@@ -252,7 +252,7 @@ TEST(EmbeddingsDeserialization, invalidEncoding) {
         {
             "model": "embeddings",
             "input": ["one", "three"],
-            "encoding_format": "dummy"
+            "encoding_format": "dummy"
         }
     )";
     rapidjson::Document d;
@@ -269,7 +269,7 @@ TEST(EmbeddingsDeserialization, invalidEncodingType) {
         {
             "model": "embeddings",
             "input": ["one", "three"],
-            "encoding_format": 42
+            "encoding_format": 42
         }
     )";
     rapidjson::Document d;
@@ -340,7 +340,7 @@ TEST(EmbeddingsDeserialization, multipleStringInputFloat) {
         {
             "model": "embeddings",
             "input": ["one", "two", "three"],
-            "encoding_format": "float"
+            "encoding_format": "float"
         }
     )";
     rapidjson::Document d;
```
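The new `test/tokenize_parser_test.cpp` registered in src/BUILD above is not included in this excerpt. A hypothetical sketch of one case it might contain, exercising `TokenizeParser::parseInput`; the include path, namespace, and expected behavior are assumptions based on the call sites in embeddings_api.cpp:

```cpp
#include <string>
#include <variant>
#include <vector>

#include <gtest/gtest.h>
#include <rapidjson/document.h>

#include "../tokenize/tokenize_parser.hpp"  // assumed include path

TEST(TokenizeParserTest, StringInputParsesIntoStringVector) {
    rapidjson::Document doc;
    doc.Parse(R"({"input": "hello world"})");
    ASSERT_FALSE(doc.HasParseError());
    auto result = ovms::TokenizeParser::parseInput(doc, "input");
    // On success the variant holds the parsed input, not an error string.
    ASSERT_TRUE(std::holds_alternative<ovms::TokenizeRequest::InputDataType>(result));
    auto& input = std::get<ovms::TokenizeRequest::InputDataType>(result);
    auto* strings = std::get_if<std::vector<std::string>>(&input);
    ASSERT_NE(strings, nullptr);
    EXPECT_EQ((*strings)[0], "hello world");
}
```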
