Skip to content

Commit 34736ae

Browse files
[ML] Migrate model_version to model_id when parsing persistent elser inference endpoints (elastic#124769) (elastic#124794)
* Handling model_version for prexisting endpoints * Update docs/changelog/124769.yaml (cherry picked from commit bf53f97) # Conflicts: # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
1 parent 18b4540 commit 34736ae

File tree

3 files changed

+58
-2
lines changed

3 files changed

+58
-2
lines changed

docs/changelog/124769.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
pr: 124769
2+
summary: Migrate `model_version` to `model_id` when parsing persistent elser inference
3+
endpoints
4+
area: Machine Learning
5+
type: bug
6+
issues:
7+
- 124675

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.common.logging.DeprecationCategory;
1717
import org.elasticsearch.common.logging.DeprecationLogger;
1818
import org.elasticsearch.core.Nullable;
19+
import org.elasticsearch.core.Strings;
1920
import org.elasticsearch.core.TimeValue;
2021
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
2122
import org.elasticsearch.inference.ChunkingOptions;
@@ -91,6 +92,13 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
9192
private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);
9293
private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class);
9394

95+
/**
96+
* Fix for https://github.com/elastic/elasticsearch/issues/124675
97+
* In 8.13.0 we transitioned from model_version to model_id. Any elser inference endpoints created prior to 8.13.0 will still use
98+
* service_settings.model_version.
99+
*/
100+
private static final String OLD_MODEL_ID_FIELD_NAME = "model_version";
101+
94102
public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) {
95103
super(context);
96104
}
@@ -433,14 +441,18 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
433441
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
434442
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
435443

444+
migrateModelVersionToModelId(serviceSettingsMap);
445+
436446
ChunkingSettings chunkingSettings = null;
437447
if (TaskType.TEXT_EMBEDDING.equals(taskType) || TaskType.SPARSE_EMBEDDING.equals(taskType)) {
438448
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
439449
}
440450

441451
String modelId = (String) serviceSettingsMap.get(ElasticsearchInternalServiceSettings.MODEL_ID);
442452
if (modelId == null) {
443-
throw new IllegalArgumentException("Error parsing request config, model id is missing");
453+
throw new IllegalArgumentException(
454+
Strings.format("Error parsing request config, model id is missing for inference id: %s", inferenceEntityId)
455+
);
444456
}
445457

446458
if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) {
@@ -472,6 +484,18 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
472484
}
473485
}
474486

487+
/**
488+
* Fix for https://github.com/elastic/elasticsearch/issues/124675
489+
* In 8.13.0 we transitioned from model_version to model_id. Any elser inference endpoints created prior to 8.13.0 will still use
490+
* service_settings.model_version. We need to look for that key and migrate it to model_id.
491+
*/
492+
private void migrateModelVersionToModelId(Map<String, Object> serviceSettingsMap) {
493+
if (serviceSettingsMap.containsKey(OLD_MODEL_ID_FIELD_NAME)) {
494+
String modelId = ServiceUtils.removeAsType(serviceSettingsMap, OLD_MODEL_ID_FIELD_NAME, String.class);
495+
serviceSettingsMap.put(ElserInternalServiceSettings.MODEL_ID, modelId);
496+
}
497+
}
498+
475499
@Override
476500
public void checkModelConfig(Model model, ActionListener<Model> listener) {
477501
if (model instanceof CustomElandEmbeddingModel elandModel && elandModel.getTaskType() == TaskType.TEXT_EMBEDDING) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,30 @@ private ActionListener<Model> getElserModelVerificationActionListener(
688688

689689
public void testParsePersistedConfig() {
690690

691+
// Parsing a persistent configuration using model_version succeeds
692+
{
693+
var service = createService(mock(Client.class));
694+
var settings = new HashMap<String, Object>();
695+
settings.put(
696+
ModelConfigurations.SERVICE_SETTINGS,
697+
new HashMap<>(
698+
Map.of(
699+
ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS,
700+
1,
701+
ElasticsearchInternalServiceSettings.NUM_THREADS,
702+
4,
703+
"model_version",
704+
".elser_model_2"
705+
)
706+
)
707+
);
708+
709+
var model = service.parsePersistedConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings);
710+
assertThat(model, instanceOf(ElserInternalModel.class));
711+
ElserInternalModel elserInternalModel = (ElserInternalModel) model;
712+
assertThat(elserInternalModel.getServiceSettings().modelId(), is(".elser_model_2"));
713+
}
714+
691715
// Null model variant
692716
{
693717
var service = createService(mock(Client.class));
@@ -706,11 +730,12 @@ public void testParsePersistedConfig() {
706730
)
707731
);
708732

709-
expectThrows(
733+
var exception = expectThrows(
710734
IllegalArgumentException.class,
711735
() -> service.parsePersistedConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings)
712736
);
713737

738+
assertThat(exception.getMessage(), containsString(randomInferenceEntityId));
714739
}
715740

716741
// Invalid model variant

0 commit comments

Comments
 (0)