diff --git a/benchmarks/README.md b/benchmarks/README.md index af72d16d2ad4b..c5b8f5b9d2321 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -152,11 +152,10 @@ exit Grab the async profiler from https://github.com/jvm-profiling-tools/async-profiler and run `prof async` like so: ``` -gradlew -p benchmarks/ run --args 'LongKeyedBucketOrdsBenchmark.multiBucket -prof "async:libPath=/home/nik9000/Downloads/async-profiler-3.0-29ee888-linux-x64/lib/libasyncProfiler.so;dir=/tmp/prof;output=flamegraph"' +gradlew -p benchmarks/ run --args 'LongKeyedBucketOrdsBenchmark.multiBucket -prof "async:libPath=/home/nik9000/Downloads/async-profiler-4.0-linux-x64/lib/libasyncProfiler.so;dir=/tmp/prof;output=flamegraph"' ``` -Note: As of January 2025 the latest release of async profiler doesn't work - with our JDK but the nightly is fine. +Note: As of July 2025 the 4.0 release of the async profiler works well. If you are on Mac, this'll warn you that you downloaded the shared library from the internet. You'll need to go to settings and allow it to run. diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/ValuesSourceReaderBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/ValuesSourceReaderBenchmark.java index e6f6226111888..94483a136a5d2 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/ValuesSourceReaderBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/ValuesSourceReaderBenchmark.java @@ -24,8 +24,10 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.NumericUtils; import org.elasticsearch.common.breaker.NoopCircuitBreaker; +import org.elasticsearch.common.logging.LogConfigurator; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BytesRefBlock; @@ -85,6 +87,10 @@ @State(Scope.Thread) @Fork(1) public class ValuesSourceReaderBenchmark { + static { + LogConfigurator.configureESLogging(); + } + private static final String[] SUPPORTED_LAYOUTS = new String[] { "in_order", "shuffled", "shuffled_singles" }; private static final String[] SUPPORTED_NAMES = new String[] { "long", @@ -345,6 +351,7 @@ public FieldNamesFieldMapper.FieldNamesFieldType fieldNames() { public void benchmark() { ValuesSourceReaderOperator op = new ValuesSourceReaderOperator( blockFactory, + ByteSizeValue.ofMb(1).getBytes(), fields(name), List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> { throw new UnsupportedOperationException("can't load _source here"); diff --git a/docs/changelog/130092.yaml b/docs/changelog/130092.yaml new file mode 100644 index 0000000000000..0e54e5f013d23 --- /dev/null +++ b/docs/changelog/130092.yaml @@ -0,0 +1,5 @@ +pr: 130092 +summary: "Added Llama provider support to the Inference Plugin" +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/changelog/130855.yaml b/docs/changelog/130855.yaml new file mode 100644 index 0000000000000..ee95181f033de --- /dev/null +++ b/docs/changelog/130855.yaml @@ -0,0 +1,6 @@ +pr: 130855 +summary: Add checks that optimizers do not modify the layout +area: ES|QL +type: enhancement +issues: + - 125576 diff --git a/docs/changelog/131053.yaml b/docs/changelog/131053.yaml new file mode 100644 index 0000000000000..b30a7c8ee8cc5 --- /dev/null +++ 
b/docs/changelog/131053.yaml @@ -0,0 +1,5 @@ +pr: 131053 +summary: Split large pages on load sometimes +area: ES|QL +type: bug +issues: [] diff --git a/docs/changelog/131395.yaml b/docs/changelog/131395.yaml new file mode 100644 index 0000000000000..500b761be1472 --- /dev/null +++ b/docs/changelog/131395.yaml @@ -0,0 +1,5 @@ +pr: 131395 +summary: Enable failure store for newly created OTel data streams +area: Data streams +type: enhancement +issues: [] diff --git a/docs/changelog/131426.yaml b/docs/changelog/131426.yaml new file mode 100644 index 0000000000000..4f79415ba069d --- /dev/null +++ b/docs/changelog/131426.yaml @@ -0,0 +1,6 @@ +pr: 131426 +summary: Disallow remote enrich after lookup join +area: ES|QL +type: bug +issues: + - 129372 diff --git a/docs/changelog/131442.yaml b/docs/changelog/131442.yaml new file mode 100644 index 0000000000000..23d00cd7d028d --- /dev/null +++ b/docs/changelog/131442.yaml @@ -0,0 +1,5 @@ +pr: 131442 +summary: Track inference deployments +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/changelog/131510.yaml b/docs/changelog/131510.yaml new file mode 100644 index 0000000000000..ccdd727fdc818 --- /dev/null +++ b/docs/changelog/131510.yaml @@ -0,0 +1,5 @@ +pr: 131510 +summary: Upgrade apm-agent to 1.55.0 +area: Infra/Metrics +type: upgrade +issues: [] diff --git a/docs/changelog/131525.yaml b/docs/changelog/131525.yaml new file mode 100644 index 0000000000000..233c4ff643643 --- /dev/null +++ b/docs/changelog/131525.yaml @@ -0,0 +1,6 @@ +pr: 131525 +summary: Fix semantic highlighting bug on flat quantized fields +area: Highlighting +type: bug +issues: + - 131443 diff --git a/docs/changelog/131541.yaml b/docs/changelog/131541.yaml new file mode 100644 index 0000000000000..5c658c194385a --- /dev/null +++ b/docs/changelog/131541.yaml @@ -0,0 +1,5 @@ +pr: 131541 +summary: Added Sample operator `NamedWritable` to plugin +area: ES|QL +type: bug +issues: [] diff --git a/docs/reference/elasticsearch/index-settings/slow-log.md b/docs/reference/elasticsearch/index-settings/slow-log.md index 20b416360a1ca..cb8f6d05e8fbf 100644 --- a/docs/reference/elasticsearch/index-settings/slow-log.md +++ b/docs/reference/elasticsearch/index-settings/slow-log.md @@ -20,6 +20,7 @@ Events that meet the specified threshold are emitted into [{{es}} logging](docs- * If [{{es}} monitoring](docs-content://deploy-manage/monitor/stack-monitoring.md) is enabled, from [Stack Monitoring](docs-content://deploy-manage/monitor/monitoring-data/visualizing-monitoring-data.md). Slow log events have a `logger` value of `index.search.slowlog` or `index.indexing.slowlog`. * From local {{es}} service logs directory. Slow log files have a suffix of `_index_search_slowlog.json` or `_index_indexing_slowlog.json`. +See [this video](https://www.youtube.com/watch?v=ulUPJshB5bU) for a walkthrough of setting and reviewing slow logs. ## Slow log format [slow-log-format] diff --git a/docs/reference/elasticsearch/jvm-settings.md b/docs/reference/elasticsearch/jvm-settings.md index eb334fd6ee2be..838a3ad312a07 100644 --- a/docs/reference/elasticsearch/jvm-settings.md +++ b/docs/reference/elasticsearch/jvm-settings.md @@ -87,10 +87,18 @@ To override the default heap size, set the minimum and maximum heap size setting The heap size should be based on the available RAM: -* Set `Xms` and `Xmx` to no more than 50% of your total memory. {{es}} requires memory for purposes other than the JVM heap. 
For example, {{es}} uses off-heap buffers for efficient network communication and relies on the operating system’s filesystem cache for efficient access to files. The JVM itself also requires some memory. It’s normal for {{es}} to use more memory than the limit configured with the `Xmx` setting. +* Set `Xms` and `Xmx` to no more than 50% of the total memory available to each {{es}} node. {{es}} requires memory for purposes other than the JVM heap. For example, {{es}} uses off-heap buffers for efficient network communication and relies on the operating system’s filesystem cache for efficient access to files. The JVM itself also requires some memory. It’s normal for {{es}} to use more memory than the limit configured with the `Xmx` setting. ::::{note} - When running in a container, such as [Docker](docs-content://deploy-manage/deploy/self-managed/install-elasticsearch-with-docker.md), total memory is defined as the amount of memory visible to the container, not the total system memory on the host. + When running in a container, such as [Docker](docs-content://deploy-manage/deploy/self-managed/install-elasticsearch-with-docker.md), the total memory available to {{es}} means the amount of memory available within the container, not the total system memory on the host. + + If you are running multiple {{es}} nodes on the same host, or in the same container, the total of all the nodes' heap sizes should not exceed 50% of the total available memory. + + Account for the memory usage of other processes running on the same host, or in the same container, when computing the total memory available to {{es}}. + + The 50% guideline is intended as a safe upper bound on the heap size. You may find that heap sizes smaller than this maximum offer better performance, for instance by allowing your operating system to use a larger filesystem cache. + + If you set the heap size too large, {{es}} may perform poorly and nodes may be terminated by the operating system. :::: * Set `Xms` and `Xmx` to no more than the threshold for compressed ordinary object pointers (oops). The exact threshold varies but 26GB is safe on most systems and can be as large as 30GB on some systems. To verify you are under the threshold, check the {{es}} log for an entry like this: diff --git a/docs/reference/elasticsearch/mapping-reference/mapping-field-meta.md b/docs/reference/elasticsearch/mapping-reference/mapping-field-meta.md index 33a467e6fbc43..cdcafd43a0877 100644 --- a/docs/reference/elasticsearch/mapping-reference/mapping-field-meta.md +++ b/docs/reference/elasticsearch/mapping-reference/mapping-field-meta.md @@ -24,7 +24,8 @@ PUT my-index-000001 ``` ::::{note} -Field metadata enforces at most 5 entries, that keys have a length that is less than or equal to 20, and that values are strings whose length is less than or equal to 50. +Field metadata enforces at most 5 entries, that keys have a length that is less than or equal to 20, and that values are strings whose length is less than or equal to 500. +The value length limit is configurable via the `index.mapping.meta.length_limit` index setting. :::: diff --git a/docs/reference/elasticsearch/security-privileges.md b/docs/reference/elasticsearch/security-privileges.md index 6e0fd02b2c3b8..ae3438cec16c6 100644 --- a/docs/reference/elasticsearch/security-privileges.md +++ b/docs/reference/elasticsearch/security-privileges.md @@ -286,22 +286,20 @@ This section lists the privileges that you can assign to a role. `create` : Privilege to index documents. 
- :::{admonition} Deprecated in 8.0 - Also grants the permission to update the index mapping (but not the data streams mapping), using the [updating mapping API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-mapping) or by relying on [dynamic field mapping](docs-content://manage-data/data-store/mapping/dynamic-mapping.md). In a future major release, this privilege will not grant any mapping update permissions. - ::: - ::::{note} This privilege does not restrict the index operation to the creation of documents but instead restricts API use to the index API. The index API allows a user to overwrite a previously indexed document. See the `create_doc` privilege for an alternative. :::: + :::{important} + Starting from 8.0, this privilege no longer grants the permission to update index mappings. + In earlier versions, it implicitly permitted index mapping updates (excluding data stream mappings) via the [updating mapping API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-mapping) or through [dynamic field mapping](docs-content://manage-data/data-store/mapping/dynamic-mapping.md). + Mapping update capabilities will be fully removed in a future major release. + ::: + `create_doc` : Privilege to index documents. It does not grant the permission to update or overwrite existing documents. - :::{admonition} Deprecated in 8.0 - Also grants the permission to update the index mapping (but not the data streams mapping), using the [updating mapping API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-mapping) or by relying on [dynamic field mapping](docs-content://manage-data/data-store/mapping/dynamic-mapping.md). In a future major release, this privilege will not grant any mapping update permissions. - ::: - ::::{note} This privilege relies on the `op_type` of indexing requests ([Index](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-create) and [Bulk](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-bulk)). When ingesting documents as a user who has the `create_doc` privilege (and no higher privilege such as `index` or `write`), you must ensure that *op_type* is set to *create* through one of the following: @@ -311,6 +309,12 @@ This section lists the privileges that you can assign to a role. :::: + :::{important} + Starting from 8.0, this privilege no longer grants the permission to update index mappings. + In earlier versions, it implicitly permitted index mapping updates (excluding data stream mappings) via the [updating mapping API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-mapping) or through [dynamic field mapping](docs-content://manage-data/data-store/mapping/dynamic-mapping.md). + Mapping update capabilities will be fully removed in a future major release. + ::: + `create_index` : Privilege to create an index or data stream. A create index request may contain aliases to be added to the index once created. In that case the request requires the `manage` privilege as well, on both the index and the aliases names. @@ -340,8 +344,10 @@ This section lists the privileges that you can assign to a role. `index` : Privilege to index and update documents. 
- :::{admonition} Deprecated in 8.0 - Also grants the permission to update the index mapping (but not the data streams mapping), using the [updating mapping API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-mapping) or by relying on [dynamic field mapping](docs-content://manage-data/data-store/mapping/dynamic-mapping.md). In a future major release, this privilege will not grant any mapping update permissions. + :::{important} + Starting from 8.0, this privilege no longer grants the permission to update index mappings. + In earlier versions, it implicitly permitted index mapping updates (excluding data stream mappings) via the [updating mapping API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-mapping) or through [dynamic field mapping](docs-content://manage-data/data-store/mapping/dynamic-mapping.md). + Mapping update capabilities will be fully removed in a future major release. ::: `maintenance` @@ -389,8 +395,10 @@ This section lists the privileges that you can assign to a role. `write` : Privilege to perform all write operations to documents, which includes the permission to index, update, and delete documents as well as performing bulk operations, while also allowing to dynamically update the index mapping. - :::{admonition} Deprecated in 8.0 - It also grants the permission to update the index mapping (but not the data streams mapping), using the [updating mapping API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-mapping). This will be retracted in a future major release. + :::{important} + Starting from 8.0, this privilege no longer grants the permission to update index mappings. + In earlier versions, it implicitly permitted index mapping updates (excluding data stream mappings) via the [updating mapping API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-mapping) or through [dynamic field mapping](docs-content://manage-data/data-store/mapping/dynamic-mapping.md). + Mapping update capabilities will be fully removed in a future major release. ::: ## Run as privilege [_run_as_privilege] diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/score.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/score.md index 511ced1094f91..59fdd3c54e1dd 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/parameters/score.md +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/score.md @@ -3,5 +3,5 @@ **Parameters** `query` -: (combinations of) full text function(s). +: Boolean expression that contains full text function(s) to be scored. 
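To illustrate the mapping-permission change documented in `security-privileges.md` above, here is a minimal sketch in console syntax (the role, index, and field names are hypothetical, not taken from the patch): a role that grants only the `index` privilege can still index and update documents, but under the 8.0+ semantics described above an explicit mapping update made as that user is rejected and instead requires a privilege such as `manage` on the index.

```console
# Hypothetical role: may index and update documents in my-index-000001.
POST /_security/role/docs_writer
{
  "indices": [
    {
      "names": [ "my-index-000001" ],
      "privileges": [ "index" ]
    }
  ]
}

# As a user holding only docs_writer, indexing documents succeeds:
POST /my-index-000001/_doc
{
  "email": "user@example.com"
}

# ...but an explicit mapping update is denied under the 8.0+ semantics
# described above; it requires a broader privilege such as manage:
PUT /my-index-000001/_mapping
{
  "properties": {
    "email": { "type": "keyword" }
  }
}
```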
diff --git a/docs/reference/query-languages/esql/_snippets/operators/detailedDescription/rlike.md b/docs/reference/query-languages/esql/_snippets/operators/detailedDescription/rlike.md index ed855deb23910..3dd13763aeb2e 100644 --- a/docs/reference/query-languages/esql/_snippets/operators/detailedDescription/rlike.md +++ b/docs/reference/query-languages/esql/_snippets/operators/detailedDescription/rlike.md @@ -18,7 +18,7 @@ ROW message = "foo ( bar" ``` ```{applies_to} -stack: ga 9.1 +stack: ga 9.2 serverless: ga ``` diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/score.json b/docs/reference/query-languages/esql/kibana/definition/functions/score.json index c9b5e22a02e4c..4772093e349d9 100644 --- a/docs/reference/query-languages/esql/kibana/definition/functions/score.json +++ b/docs/reference/query-languages/esql/kibana/definition/functions/score.json @@ -10,7 +10,7 @@ "name" : "query", "type" : "boolean", "optional" : false, - "description" : "(combinations of) full text function(s)." + "description" : "Boolean expression that contains full text function(s) to be scored." } ], "variadic" : false, diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml index 0eedfc37ca1a9..bb4ae5da279fb 100644 --- a/gradle/verification-metadata.xml +++ b/gradle/verification-metadata.xml @@ -71,9 +71,9 @@ - - - + + + diff --git a/libs/x-content/src/main/java/org/elasticsearch/xcontent/ConstructingObjectParser.java b/libs/x-content/src/main/java/org/elasticsearch/xcontent/ConstructingObjectParser.java index be13207702627..c3b322db0e3a5 100644 --- a/libs/x-content/src/main/java/org/elasticsearch/xcontent/ConstructingObjectParser.java +++ b/libs/x-content/src/main/java/org/elasticsearch/xcontent/ConstructingObjectParser.java @@ -220,6 +220,27 @@ public <T> void declareField(BiConsumer<Value, T> consumer, ContextParser<Context, T> parser, ParseField parseField, ValueType type) { + public <T> void declareObjectArrayOrNull( + BiConsumer<Value, List<T>> consumer, + ContextParser<Context, T> objectParser, + ParseField field + ) { + declareField( + consumer, + (p, c) -> p.currentToken() == XContentParser.Token.VALUE_NULL ? 
null : parseArray(p, c, objectParser), + field, + ValueType.OBJECT_ARRAY_OR_NULL + ); + } + @Override public <T> void declareNamedObject( BiConsumer<Value, T> consumer, diff --git a/modules/apm/build.gradle b/modules/apm/build.gradle index 86d06258bcbca..37d42e4b3fb0c 100644 --- a/modules/apm/build.gradle +++ b/modules/apm/build.gradle @@ -20,7 +20,7 @@ dependencies { implementation "io.opentelemetry:opentelemetry-api:${otelVersion}" implementation "io.opentelemetry:opentelemetry-context:${otelVersion}" implementation "io.opentelemetry:opentelemetry-semconv:${otelSemconvVersion}" - runtimeOnly "co.elastic.apm:elastic-apm-agent:1.52.2" + runtimeOnly "co.elastic.apm:elastic-apm-agent:1.55.0" javaRestTestImplementation project(':modules:apm') javaRestTestImplementation project(':test:framework') diff --git a/modules/kibana/src/main/java/org/elasticsearch/kibana/KibanaPlugin.java b/modules/kibana/src/main/java/org/elasticsearch/kibana/KibanaPlugin.java index e5877d1f2eb2a..550ac2482a0b2 100644 --- a/modules/kibana/src/main/java/org/elasticsearch/kibana/KibanaPlugin.java +++ b/modules/kibana/src/main/java/org/elasticsearch/kibana/KibanaPlugin.java @@ -38,6 +38,13 @@ public class KibanaPlugin extends Plugin implements SystemIndexPlugin { .setAllowedElasticProductOrigins(KIBANA_PRODUCT_ORIGIN) .build(); + public static final SystemIndexDescriptor ONECHAT_INDEX_DESCRIPTOR = SystemIndexDescriptor.builder() + .setIndexPattern(".chat-*") + .setDescription("Onechat system index") + .setType(Type.EXTERNAL_UNMANAGED) + .setAllowedElasticProductOrigins(KIBANA_PRODUCT_ORIGIN) + .build(); + public static final SystemIndexDescriptor APM_AGENT_CONFIG_INDEX_DESCRIPTOR = SystemIndexDescriptor.builder() .setIndexPattern(".apm-agent-configuration*") .setDescription("system index for APM agent configuration") @@ -57,6 +64,7 @@ public Collection<SystemIndexDescriptor> getSystemIndexDescriptors(Settings sett return List.of( KIBANA_INDEX_DESCRIPTOR, REPORTING_INDEX_DESCRIPTOR, + ONECHAT_INDEX_DESCRIPTOR, APM_AGENT_CONFIG_INDEX_DESCRIPTOR, APM_CUSTOM_LINK_INDEX_DESCRIPTOR ); diff --git a/modules/kibana/src/test/java/org/elasticsearch/kibana/KibanaPluginTests.java b/modules/kibana/src/test/java/org/elasticsearch/kibana/KibanaPluginTests.java index aa883c83eecf6..73709b2e48704 100644 --- a/modules/kibana/src/test/java/org/elasticsearch/kibana/KibanaPluginTests.java +++ b/modules/kibana/src/test/java/org/elasticsearch/kibana/KibanaPluginTests.java @@ -20,7 +20,7 @@ public class KibanaPluginTests extends ESTestCase { public void testKibanaIndexNames() { assertThat( new KibanaPlugin().getSystemIndexDescriptors(Settings.EMPTY).stream().map(SystemIndexDescriptor::getIndexPattern).toList(), - contains(".kibana_*", ".reporting-*", ".apm-agent-configuration*", ".apm-custom-link*") + contains(".kibana_*", ".reporting-*", ".chat-*", ".apm-agent-configuration*", ".apm-custom-link*") ); } } diff --git a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java index 910789b5e6d83..f7b910bfb2a32 100644 --- a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java +++ b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java @@ -830,7 +830,13 @@ void run(BytesReference expected, BytesReference updated, ActionListener<OptionalBytesReference> listener) .<Void>newForked(l -> ensureOtherUploadsComplete(uploadId, uploadIndex, currentUploads, l)) - // Step 4: Read the current register value. 
+ // Step 4: Read the current register value. Note that getRegister only has read-after-write semantics but that's ok here as: + // - all earlier uploads are now complete, + // - our upload is not completing yet, and + // - later uploads can only be completing if they have already aborted ours. + // Thus if our operation ultimately succeeds then there cannot have been any concurrent writes in flight, so this read + // cannot have observed a stale value, whereas if our operation ultimately fails then it doesn't matter what this read + // observes. .andThen(l -> getRegister(purpose, rawKey, l)) diff --git a/muted-tests.yml b/muted-tests.yml index 6f920d95d4189..aacd33f73abbb 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -382,12 +382,6 @@ tests: - class: org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeansTests method: testHKmeans issue: https://github.com/elastic/elasticsearch/issues/130497 -- class: org.elasticsearch.xpack.esql.action.EsqlActionBreakerIT - method: testProjectWhere - issue: https://github.com/elastic/elasticsearch/issues/130504 -- class: org.elasticsearch.xpack.esql.action.EsqlActionBreakerIT - method: testTopNPushedToLucene - issue: https://github.com/elastic/elasticsearch/issues/130505 - class: org.elasticsearch.gradle.LoggedExecFuncTest method: failed tasks output logged to console when spooling true issue: https://github.com/elastic/elasticsearch/issues/119509 @@ -406,12 +400,6 @@ tests: - class: org.elasticsearch.search.SearchWithRejectionsIT method: testOpenContextsAfterRejections issue: https://github.com/elastic/elasticsearch/issues/130821 -- class: org.elasticsearch.xpack.esql.action.EsqlActionBreakerIT - method: testExtractFields - issue: https://github.com/elastic/elasticsearch/issues/130501 -- class: org.elasticsearch.xpack.esql.action.EsqlActionBreakerIT - method: testRowStatsProjectGroupByInt - issue: https://github.com/elastic/elasticsearch/issues/131024 - class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT method: test {lookup-join.MvJoinKeyOnFromAfterStats ASYNC} issue: https://github.com/elastic/elasticsearch/issues/131148 @@ -469,6 +457,42 @@ tests: - class: org.elasticsearch.packaging.test.DockerTests method: test072RunEsAsDifferentUserAndGroup issue: https://github.com/elastic/elasticsearch/issues/131412 +- class: org.elasticsearch.xpack.esql.heap_attack.HeapAttackIT + method: testLookupExplosionNoFetch + issue: https://github.com/elastic/elasticsearch/issues/128720 +- class: org.elasticsearch.test.rest.yaml.RcsCcsCommonYamlTestSuiteIT + method: test {p0=vector-tile/20_aggregations/stats agg} + issue: https://github.com/elastic/elasticsearch/issues/131484 +- class: org.elasticsearch.packaging.test.DockerTests + method: test050BasicApiTests + issue: https://github.com/elastic/elasticsearch/issues/120911 +- class: org.elasticsearch.xpack.esql.action.EsqlActionBreakerIT + method: testFromEvalStats + issue: https://github.com/elastic/elasticsearch/issues/131503 +- class: org.elasticsearch.xpack.downsample.DownsampleWithBasicRestIT + method: test {p0=downsample-with-security/10_basic/Downsample index} + issue: https://github.com/elastic/elasticsearch/issues/131513 +- class: org.elasticsearch.xpack.search.CrossClusterAsyncSearchIT + method: testCancellationViaTimeoutWithAllowPartialResultsSetToFalse + issue: https://github.com/elastic/elasticsearch/issues/131248 +- class: org.elasticsearch.xpack.esql.qa.multi_node.GenerativeIT + method: test + issue: https://github.com/elastic/elasticsearch/issues/131508 +- class: 
org.elasticsearch.action.admin.cluster.node.tasks.CancellableTasksIT + method: testRemoveBanParentsOnDisconnect + issue: https://github.com/elastic/elasticsearch/issues/131562 +- class: org.elasticsearch.xpack.esql.action.CrossClusterQueryWithPartialResultsIT + method: testPartialResults + issue: https://github.com/elastic/elasticsearch/issues/131481 +- class: org.elasticsearch.packaging.test.DockerTests + method: test010Install + issue: https://github.com/elastic/elasticsearch/issues/131376 +- class: org.elasticsearch.test.rest.yaml.RcsCcsCommonYamlTestSuiteIT + method: test {p0=search/40_indices_boost/Indices boost with alias} + issue: https://github.com/elastic/elasticsearch/issues/131598 +- class: org.elasticsearch.compute.lucene.read.SortedSetOrdinalsBuilderTests + method: testReader + issue: https://github.com/elastic/elasticsearch/issues/131573 # Examples: # diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/indices.recovery.json b/rest-api-spec/src/main/resources/rest-api-spec/api/indices.recovery.json index b1174b89df0bd..fe66541fa9b0b 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/indices.recovery.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/indices.recovery.json @@ -41,6 +41,28 @@ "type":"boolean", "description":"Display only those recoveries that are currently on-going", "default":false + }, + "ignore_unavailable":{ + "type":"boolean", + "description":"Whether specified concrete indices should be ignored when unavailable (missing or closed)", + "default":false + }, + "allow_no_indices":{ + "type":"boolean", + "description":"Whether to ignore if a wildcard indices expression resolves into no concrete indices. (This includes `_all` string or when no indices have been specified)", + "default":true + }, + "expand_wildcards":{ + "type":"enum", + "options":[ + "open", + "closed", + "hidden", + "none", + "all" + ], + "default":"open", + "description":"Whether to expand wildcard expression to concrete indices that are open, closed or both." } } } diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/msearch.json b/rest-api-spec/src/main/resources/rest-api-spec/api/msearch.json index 359d1e67b07e5..0d8223a71c79d 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/msearch.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/msearch.json @@ -69,6 +69,38 @@ "type":"boolean", "description":"Indicates whether network round-trips should be minimized as part of cross-cluster search requests execution", "default":"true" + }, + "index":{ + "type":"list", + "description":"A comma-separated list of index names to use as default" + }, + "ignore_unavailable":{ + "type":"boolean", + "description":"Whether specified concrete indices should be ignored when unavailable (missing or closed)" + }, + "ignore_throttled":{ + "type":"boolean", + "description":"Whether specified concrete, expanded or aliased indices should be ignored when throttled", + "deprecated":true + }, + "allow_no_indices":{ + "type":"boolean", + "description":"Whether to ignore if a wildcard indices expression resolves into no concrete indices. (This includes `_all` string or when no indices have been specified)" + }, + "expand_wildcards":{ + "type":"enum", + "options": ["open", "closed", "hidden", "none", "all"], + "default":"open", + "description":"Whether to expand wildcard expression to concrete indices that are open, closed or both." 
+ }, + "routing":{ + "type":"list", + "description":"A comma-separated list of specific routing values" + }, + "include_named_queries_score":{ + "type":"boolean", + "description":"Indicates whether hit.matched_queries should be rendered as a map that includes the name of the matched query associated with its score (true) or as an array containing the name of the matched queries (false)", + "default": false } }, "body":{ diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.allocation_explain/10_basic.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.allocation_explain/10_basic.yml index 0835fe7de6af1..1f5e9d80b7702 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.allocation_explain/10_basic.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.allocation_explain/10_basic.yml @@ -102,19 +102,6 @@ cluster.allocation_explain: body: { "index": "test", "primary": true } ---- -"cluster shard allocation explanation test with numerical index parameter in the body": - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /x_content_parse_exception/ - cluster.allocation_explain: - body: { "index": 0, "shard": 0, "primary": true } - --- "cluster shard allocation explanation test with incorrect index parameter in the body": - do: @@ -142,102 +129,6 @@ cluster.allocation_explain: body: { "index": "test", "shard": 2147483647, "primary": true } ---- -"cluster shard allocation explanation test with long shard value": - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /x_content_parse_exception/ - cluster.allocation_explain: - body: { "index": "test", "shard": 214748364777, "primary": true } - ---- -"cluster shard allocation explanation test with float shard value": - - do: - indices.create: - index: test - body: { "settings": { "index.number_of_shards": 2, "index.number_of_replicas": 0 } } - - - match: { acknowledged: true } - - - do: - cluster.allocation_explain: - body: { "index": "test", "shard": 1.0, "primary": true } - - - match: { current_state: "started" } - - is_true: current_node.id - - match: { index: "test" } - - match: { shard: 1 } - - match: { primary: true } - - is_true: can_remain_on_current_node - - is_true: can_rebalance_cluster - - is_true: can_rebalance_to_other_node - - is_true: rebalance_explanation - ---- -"cluster shard allocation explanation test with double shard value": - - do: - indices.create: - index: test - body: { "settings": { "index.number_of_shards": 2, "index.number_of_replicas": 0 } } - - - match: { acknowledged: true } - - - do: - cluster.allocation_explain: - body: { "index": "test", "shard": 1.1234567891234567, "primary": true } - - - match: { current_state: "started" } - - is_true: current_node.id - - match: { index: "test" } - - match: { shard: 1 } - - match: { primary: true } - - is_true: can_remain_on_current_node - - is_true: can_rebalance_cluster - - is_true: can_rebalance_to_other_node - - is_true: rebalance_explanation - ---- -"cluster shard allocation explanation test with an invalid, string shard parameter in the body": - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /x_content_parse_exception/ - cluster.allocation_explain: - body: { "index": "test", "shard": "wrong", "primary": true } - ---- -"cluster shard allocation explanation test with a valid, string shard parameter in the body": - - do: - indices.create: - index: test 
- - - match: { acknowledged: true } - - - do: - cluster.allocation_explain: - body: { "index": "test", "shard": "0", "primary": true } - - - match: { current_state: "started" } - - is_true: current_node.id - - match: { index: "test" } - - match: { shard: 0 } - - match: { primary: true } - - is_true: can_remain_on_current_node - - is_true: can_rebalance_cluster - - is_true: can_rebalance_to_other_node - - is_true: rebalance_explanation - - - --- "cluster shard allocation explanation test with three valid body parameters": - do: @@ -260,210 +151,6 @@ - is_true: can_rebalance_to_other_node - is_true: rebalance_explanation ---- -"cluster shard allocation explanation test with numerical primary parameter in the body": - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /x_content_parse_exception/ - cluster.allocation_explain: - body: { "index": "test", "shard": 0, "primary": 0 } - ---- -"cluster shard allocation explanation test with invalid primary parameter in the body": - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /x_content_parse_exception/ - cluster.allocation_explain: - body: { "index": "test", "shard": 0, "primary": truee } - ---- -"cluster shard allocation explanation test with a valid, string primary parameter in the body": - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - cluster.allocation_explain: - body: { "index": "test", "shard": 0, "primary": "true" } - - - match: { current_state: "started" } - - is_true: current_node.id - - match: { index: "test" } - - match: { shard: 0 } - - match: { primary: true } - - is_true: can_remain_on_current_node - - is_true: can_rebalance_cluster - - is_true: can_rebalance_to_other_node - - is_true: rebalance_explanation - ---- -"cluster shard allocation explanation test with an invalid, string primary parameter in the body": - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /x_content_parse_exception/ - cluster.allocation_explain: - body: { "index": "test", "shard": 0, "primary": "truee" } - ---- -"cluster shard allocation explanation test with numerical current node parameter in the body": - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /x_content_parse_exception/ - cluster.allocation_explain: - body: { "index": "test", "shard": 0, "primary": true, "current_node": 1 } - ---- -"cluster shard allocation explanation test with invalid include_disk_info parameter": - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /illegal_argument_exception/ - cluster.allocation_explain: - body: { "index": "test", "shard": 0, "primary": true } - include_disk_info: truee - ---- -"cluster shard allocation explanation test with numerical include_disk_info parameter": - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /illegal_argument_exception/ - cluster.allocation_explain: - body: { "index": "test", "shard": 0, "primary": true } - include_disk_info: 0 - ---- -"cluster shard allocation explanation test with a valid, string include_disk_info parameter": - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - cluster.allocation_explain: - body: { "index": "test", "shard": 0, "primary": true } - include_disk_info: "true" - - - match: { current_state: "started" } - - is_true: current_node.id - - match: { index: "test" } - - match: { 
shard: 0 } - - match: { primary: true } - - is_true: can_remain_on_current_node - - is_true: can_rebalance_cluster - - is_true: can_rebalance_to_other_node - - is_true: rebalance_explanation - ---- -"cluster shard allocation explanation test with an invalid, string include_disk_info parameter": - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /illegal_argument_exception/ - cluster.allocation_explain: - body: { "index": "test", "shard": 0, "primary": true } - include_disk_info: "truee" - ---- -"cluster shard allocation explanation test with invalid include_yes_decisions parameter": - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /illegal_argument_exception/ - cluster.allocation_explain: - body: { "index": "test", "shard": 0, "primary": true } - include_yes_decisions: truee - ---- -"cluster shard allocation explanation test with numerical include_yes_decisions parameter": - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /illegal_argument_exception/ - cluster.allocation_explain: - body: { "index": "test", "shard": 0, "primary": true } - include_yes_decisions: 0 - ---- -"cluster shard allocation explanation test with a valid, string include_yes_decisions parameter": - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - cluster.allocation_explain: - body: { "index": "test", "shard": 0, "primary": true } - include_yes_decisions: "true" - - - match: { current_state: "started" } - - is_true: current_node.id - - match: { index: "test" } - - match: { shard: 0 } - - match: { primary: true } - - is_true: can_remain_on_current_node - - is_true: can_rebalance_cluster - - is_true: can_rebalance_to_other_node - - is_true: rebalance_explanation - ---- -"cluster shard allocation explanation test with an invalid, string include_yes_decisions parameter": - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /illegal_argument_exception/ - cluster.allocation_explain: - body: { "index": "test", "shard": 0, "primary": true } - include_yes_decisions: "truee" - --- "cluster shard allocation explanation test with 3 body parameters and all query parameters": - do: @@ -803,38 +490,6 @@ index: "test" body: { "shard": 0, "primary": true } ---- -"cluster shard allocation explanation test with numerical index parameter passed in URL": - - requires: - capabilities: - - method: GET - path: /_cluster/allocation/explain - capabilities: [ query_parameter_support ] - test_runner_features: [ capabilities ] - reason: "Query parameter support was added in version 9.2.0" - - - do: - indices.create: - index: "0" - - - match: { acknowledged: true } - - - do: - cluster.allocation_explain: - index: 0 - shard: 0 - primary: true - - - match: { current_state: "started" } - - is_true: current_node.id - - match: { index: "0" } - - match: { shard: 0 } - - match: { primary: true } - - is_true: can_remain_on_current_node - - is_true: can_rebalance_cluster - - is_true: can_rebalance_to_other_node - - is_true: rebalance_explanation - --- "cluster shard allocation explanation test with incorrect index parameter passed in URL": - requires: @@ -857,206 +512,3 @@ index: "test2" shard: 0 primary: true - ---- -"cluster shard allocation explanation test with an invalid, string shard parameter passed in URL": - - requires: - capabilities: - - method: GET - path: /_cluster/allocation/explain - capabilities: [ query_parameter_support ] - 
test_runner_features: [ capabilities ] - reason: "Query parameter support was added in version 9.2.0" - - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /illegal_argument_exception/ - cluster.allocation_explain: - index: "test" - shard: "wrong" - primary: true - ---- -"cluster shard allocation explanation test with a valid, string shard parameter passed in URL": - - requires: - capabilities: - - method: GET - path: /_cluster/allocation/explain - capabilities: [ query_parameter_support ] - test_runner_features: [ capabilities ] - reason: "Query parameter support was added in version 9.2.0" - - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - cluster.allocation_explain: - index: "test" - shard: "0" - primary: true - - - match: { current_state: "started" } - - is_true: current_node.id - - match: { index: "test" } - - match: { shard: 0 } - - match: { primary: true } - - is_true: can_remain_on_current_node - - is_true: can_rebalance_cluster - - is_true: can_rebalance_to_other_node - - is_true: rebalance_explanation - ---- -"cluster shard allocation explanation test with float shard parameter passed in URL": - - requires: - capabilities: - - method: GET - path: /_cluster/allocation/explain - capabilities: [ query_parameter_support ] - test_runner_features: [ capabilities ] - reason: "Query parameter support was added in version 9.2.0" - - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /illegal_argument_exception/ - cluster.allocation_explain: - index: "test" - shard: 1.0 - primary: true - ---- -"cluster shard allocation explanation test with numerical primary parameter passed in URL": - - requires: - capabilities: - - method: GET - path: /_cluster/allocation/explain - capabilities: [ query_parameter_support ] - test_runner_features: [ capabilities ] - reason: "Query parameter support was added in version 9.2.0" - - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /illegal_argument_exception/ - cluster.allocation_explain: - index: "test" - shard: 0 - primary: 0 - ---- -"cluster shard allocation explanation test with invalid primary parameter passed in URL": - - requires: - capabilities: - - method: GET - path: /_cluster/allocation/explain - capabilities: [ query_parameter_support ] - test_runner_features: [ capabilities ] - reason: "Query parameter support was added in version 9.2.0" - - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /illegal_argument_exception/ - cluster.allocation_explain: - index: "test" - shard: 0 - primary: truee - ---- -"cluster shard allocation explanation test with a valid, string primary parameter passed in URL": - - requires: - capabilities: - - method: GET - path: /_cluster/allocation/explain - capabilities: [ query_parameter_support ] - test_runner_features: [ capabilities ] - reason: "Query parameter support was added in version 9.2.0" - - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - cluster.allocation_explain: - index: "test" - shard: 0 - primary: "true" - - - match: { current_state: "started" } - - is_true: current_node.id - - match: { index: "test" } - - match: { shard: 0 } - - match: { primary: true } - - is_true: can_remain_on_current_node - - is_true: can_rebalance_cluster - - is_true: can_rebalance_to_other_node - - is_true: rebalance_explanation - ---- -"cluster shard allocation explanation test with an 
invalid, string primary parameter passed in URL": - - requires: - capabilities: - - method: GET - path: /_cluster/allocation/explain - capabilities: [ query_parameter_support ] - test_runner_features: [ capabilities ] - reason: "Query parameter support was added in version 9.2.0" - - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /illegal_argument_exception/ - cluster.allocation_explain: - index: "test" - shard: 0 - primary: "truee" - ---- -"cluster shard allocation explanation test with numerical current node parameter passed in URL": - - requires: - capabilities: - - method: GET - path: /_cluster/allocation/explain - capabilities: [ query_parameter_support ] - test_runner_features: [ capabilities ] - reason: "Query parameter support was added in version 9.2.0" - - - do: - indices.create: - index: test - - - match: { acknowledged: true } - - - do: - catch: /illegal_argument_exception/ - cluster.allocation_explain: - index: "test" - shard: 0 - primary: true - current_node: 1 diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.recovery/10_basic.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.recovery/10_basic.yml index 8ad06910ebe4d..0724f3831aeab 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.recovery/10_basic.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.recovery/10_basic.yml @@ -17,6 +17,9 @@ indices.recovery: index: [test_1] human: true + ignore_unavailable: false + allow_no_indices: true + expand_wildcards: open - match: { test_1.shards.0.type: "EMPTY_STORE" } - match: { test_1.shards.0.stage: "DONE" } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/msearch/10_basic.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/msearch/10_basic.yml index 1052508ca2b88..8ac4ee60f2bbc 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/msearch/10_basic.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/msearch/10_basic.yml @@ -1,5 +1,7 @@ --- setup: + - requires: + test_runner_features: allowed_warnings - do: index: @@ -67,6 +69,12 @@ setup: rest_total_hits_as_int: true max_concurrent_shard_requests: 1 max_concurrent_searches: 1 + ignore_unavailable: false + ignore_throttled: false + allow_no_indices: false + expand_wildcards: open + include_named_queries_score: false + index: index_* body: - index: index_* - query: @@ -83,6 +91,8 @@ setup: - {} - query: match_all: {} + allowed_warnings: + - "[ignore_throttled] parameter is deprecated because frozen indices have been deprecated. Consider cold or frozen tiers in place of frozen indices." 
- match: { responses.0.hits.total: 2 } - match: { responses.1.hits.total: 1 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/msearch/40_routing.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/msearch/40_routing.yml new file mode 100644 index 0000000000000..5b69a4da98418 --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/msearch/40_routing.yml @@ -0,0 +1,25 @@ +--- +setup: + - do: + index: + index: index_1 + routing: "1" + id: "1" + body: { foo: bar } + + - do: + indices.refresh: {} + +--- +"Routing": + + - do: + msearch: + rest_total_hits_as_int: true + routing: "1" + body: + - {} + - query: + match_all: {} + + - match: { responses.0.hits.total: 1 } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/SpecificMasterNodesIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/SpecificMasterNodesIT.java index 5f86111d352a9..5d5f2082fb71f 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/SpecificMasterNodesIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/SpecificMasterNodesIT.java @@ -60,13 +60,15 @@ public void testElectOnlyBetweenMasterNodes() throws Exception { logger.info("--> start master node (1)"); final String masterNodeName = internalCluster().startMasterOnlyNode(); - awaitMasterNode(internalCluster().getNonMasterNodeName(), masterNodeName); - awaitMasterNode(internalCluster().getMasterName(), masterNodeName); + for (var nodeName : internalCluster().getNodeNames()) { + awaitMasterNode(nodeName, masterNodeName); + } logger.info("--> start master node (2)"); final String nextMasterEligableNodeName = internalCluster().startMasterOnlyNode(); - awaitMasterNode(internalCluster().getNonMasterNodeName(), masterNodeName); - awaitMasterNode(internalCluster().getMasterName(), masterNodeName); + for (var nodeName : internalCluster().getNodeNames()) { + awaitMasterNode(nodeName, masterNodeName); + } logger.info("--> closing master node (1)"); client().execute( @@ -74,12 +76,14 @@ public void testElectOnlyBetweenMasterNodes() throws Exception { new AddVotingConfigExclusionsRequest(TEST_REQUEST_TIMEOUT, masterNodeName) ).get(); // removing the master from the voting configuration immediately triggers the master to step down - awaitMasterNode(internalCluster().getNonMasterNodeName(), nextMasterEligableNodeName); - awaitMasterNode(internalCluster().getMasterName(), nextMasterEligableNodeName); + for (var nodeName : internalCluster().getNodeNames()) { + awaitMasterNode(nodeName, nextMasterEligableNodeName); + } internalCluster().stopNode(masterNodeName); - awaitMasterNode(internalCluster().getNonMasterNodeName(), nextMasterEligableNodeName); - awaitMasterNode(internalCluster().getMasterName(), nextMasterEligableNodeName); + for (var nodeName : internalCluster().getNodeNames()) { + awaitMasterNode(nodeName, nextMasterEligableNodeName); + } } public void testAliasFilterValidation() { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java index 25ae21964ba0e..5d389ad5ef11a 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java @@ -27,6 +27,7 @@ import org.elasticsearch.cluster.NodeUsageStatsForThreadPools; import org.elasticsearch.cluster.NodeUsageStatsForThreadPoolsCollector; import 
org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.cluster.routing.RecoverySource; @@ -104,6 +105,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Predicate; import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; import static com.carrotsearch.randomizedtesting.RandomizedTest.randomAsciiLettersOfLength; @@ -355,6 +357,62 @@ public void testNodeWriteLoadsArePresent() { } } + public void testShardWriteLoadsArePresent() { + // Create some indices and some write-load + final int numIndices = randomIntBetween(1, 5); + final String indexPrefix = randomIdentifier(); + IntStream.range(0, numIndices).forEach(i -> { + final String indexName = indexPrefix + "_" + i; + createIndex(indexName, Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 3)).build()); + IntStream.range(0, randomIntBetween(1, 500)) + .forEach(j -> prepareIndex(indexName).setSource("foo", randomIdentifier(), "bar", randomIdentifier()).get()); + }); + + final InternalClusterInfoService clusterInfoService = (InternalClusterInfoService) getInstanceFromNode(ClusterInfoService.class); + + // Not collecting stats yet because allocation write load stats collection is disabled by default. + { + ClusterInfoServiceUtils.refresh(clusterInfoService); + final Map<ShardId, Double> shardWriteLoads = clusterInfoService.getClusterInfo().getShardWriteLoads(); + assertNotNull(shardWriteLoads); + assertTrue(shardWriteLoads.isEmpty()); + } + + // Turn on collection of write-load stats. + updateClusterSettings( + Settings.builder() + .put( + WriteLoadConstraintSettings.WRITE_LOAD_DECIDER_ENABLED_SETTING.getKey(), + WriteLoadConstraintSettings.WriteLoadDeciderStatus.ENABLED + ) + .build() + ); + + try { + // Force a ClusterInfo refresh to run collection of the write-load stats. + ClusterInfoServiceUtils.refresh(clusterInfoService); + final Map<ShardId, Double> shardWriteLoads = clusterInfoService.getClusterInfo().getShardWriteLoads(); + + // Verify that each shard has write-load reported. 
+ final ClusterState state = getInstanceFromNode(ClusterService.class).state(); + assertEquals(state.projectState(ProjectId.DEFAULT).metadata().getTotalNumberOfShards(), shardWriteLoads.size()); + double maximumLoadRecorded = 0; + for (IndexMetadata indexMetadata : state.projectState(ProjectId.DEFAULT).metadata()) { + for (int i = 0; i < indexMetadata.getNumberOfShards(); i++) { + final ShardId shardId = new ShardId(indexMetadata.getIndex(), i); + assertTrue(shardWriteLoads.containsKey(shardId)); + maximumLoadRecorded = Math.max(shardWriteLoads.get(shardId), maximumLoadRecorded); + } + } + // And that at least one is greater than zero + assertThat(maximumLoadRecorded, greaterThan(0.0)); + } finally { + updateClusterSettings( + Settings.builder().putNull(WriteLoadConstraintSettings.WRITE_LOAD_DECIDER_ENABLED_SETTING.getKey()).build() + ); + } + } + public void testIndexCanChangeCustomDataPath() throws Exception { final String index = "test-custom-data-path"; final Path sharedDataPath = getInstanceFromNode(Environment.class).sharedDataDir().resolve(randomAsciiLettersOfLength(10)); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java index 02e17e3395760..b14f067992ba0 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java @@ -22,6 +22,7 @@ import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.InternalSettingsPlugin; import org.elasticsearch.test.MockLog; import org.elasticsearch.test.junit.annotations.TestLogging; @@ -41,6 +42,7 @@ import static org.hamcrest.Matchers.is; @LuceneTestCase.SuppressCodecs("*") // only use our own codecs +@ESTestCase.WithoutEntitlements // requires entitlement delegation ES-10920 public class DirectIOIT extends ESIntegTestCase { @BeforeClass diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/sort/FieldSortIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/sort/FieldSortIT.java index 2cd610f204d9e..97da362eebe82 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/sort/FieldSortIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/sort/FieldSortIT.java @@ -9,6 +9,7 @@ package org.elasticsearch.search.sort; +import org.apache.http.util.EntityUtils; import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.UnicodeUtil; @@ -20,6 +21,9 @@ import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.ShardSearchFailure; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.client.RestClient; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Strings; @@ -36,6 +40,7 @@ import org.elasticsearch.search.SearchHits; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.InternalSettingsPlugin; +import org.elasticsearch.test.junit.annotations.TestIssueLogging; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; 
import org.elasticsearch.xcontent.XContentType; @@ -84,6 +89,12 @@ import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; +@TestIssueLogging( + issueUrl = "https://github.com/elastic/elasticsearch/issues/129445", + value = "org.elasticsearch.action.search.SearchQueryThenFetchAsyncAction:DEBUG," + + "org.elasticsearch.action.search.SearchPhaseController:DEBUG," + + "org.elasticsearch.search:TRACE" +) public class FieldSortIT extends ESIntegTestCase { public static class CustomScriptPlugin extends MockScriptPlugin { @Override @@ -112,6 +123,10 @@ protected Collection<Class<? extends Plugin>> nodePlugins() { return Arrays.asList(InternalSettingsPlugin.class, CustomScriptPlugin.class); } + protected boolean addMockHttpTransport() { return false; } + public void testIssue8226() { int numIndices = between(5, 10); final boolean useMapping = randomBoolean(); @@ -2145,7 +2160,7 @@ public void testLongSortOptimizationCorrectResults() { ); } - public void testSortMixedFieldTypes() { + public void testSortMixedFieldTypes() throws IOException { assertAcked( prepareCreate("index_long").setMapping("foo", "type=long"), prepareCreate("index_integer").setMapping("foo", "type=integer"), @@ -2159,6 +2174,16 @@ public void testSortMixedFieldTypes() { prepareIndex("index_keyword").setId("1").setSource("foo", "123").get(); refresh(); + // for debugging, we try to see where the documents are located + try (RestClient restClient = createRestClient()) { + Request checkShardsRequest = new Request( + "GET", + "/_cat/shards/index_long,index_double,index_keyword?format=json&h=index,node,shard,prirep,state,docs,index" + ); + Response response = restClient.performRequest(checkShardsRequest); + logger.info("FieldSortIT#testSortMixedFieldTypes document distribution: " + EntityUtils.toString(response.getEntity())); + } + { // mixing long and integer types is ok, as we convert integer sort to long sort assertNoFailures(prepareSearch("index_long", "index_integer").addSort(new FieldSortBuilder("foo")).setSize(10)); } diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index a6f6f622070f1..90cd3c669a52c 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -484,4 +484,5 @@ exports org.elasticsearch.index.codec.perfield; exports org.elasticsearch.index.codec.vectors to org.elasticsearch.test.knn; exports org.elasticsearch.index.codec.vectors.es818 to org.elasticsearch.test.knn; + exports org.elasticsearch.inference.telemetry; } diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 11a0103cd22e0..57da7f348bb47 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -330,6 +330,7 @@ static TransportVersion def(int id) { public static final TransportVersion MAPPINGS_IN_DATA_STREAMS = def(9_112_0_00); public static final TransportVersion ESQL_SPLIT_ON_BIG_VALUES_9_1 = def(9_112_0_01); public static final TransportVersion ESQL_FIXED_INDEX_LIKE_9_1 = def(9_112_0_02); + public static final TransportVersion ESQL_SAMPLE_OPERATOR_STATUS_9_1 = def(9_112_0_03); // Below is the first version in 9.2 and NOT in 9.1. 
public static final TransportVersion PROJECT_STATE_REGISTRY_RECORDS_DELETIONS = def(9_113_0_00); public static final TransportVersion ESQL_SERIALIZE_TIMESERIES_FIELD_TYPE = def(9_114_0_00); @@ -342,6 +343,10 @@ static TransportVersion def(int id) { public static final TransportVersion NODE_USAGE_STATS_FOR_THREAD_POOLS_IN_CLUSTER_INFO = def(9_121_0_00); public static final TransportVersion ESQL_CATEGORIZE_OPTIONS = def(9_122_0_00); public static final TransportVersion ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED = def(9_123_0_00); + public static final TransportVersion PROJECT_STATE_REGISTRY_ENTRY = def(9_124_0_00); + public static final TransportVersion ML_INFERENCE_LLAMA_ADDED = def(9_125_0_00); + public static final TransportVersion SHARD_WRITE_LOAD_IN_CLUSTER_INFO = def(9_126_0_00); + public static final TransportVersion ESQL_SAMPLE_OPERATOR_STATUS = def(9_127_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index 7339135fe93dc..2df4c05722908 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -9,6 +9,8 @@ package org.elasticsearch.action.search; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.Sort; @@ -72,6 +74,8 @@ public final class SearchPhaseController { + private static final Logger logger = LogManager.getLogger(SearchPhaseController.class); + private final BiFunction< Supplier, AggregatorFactories.Builder, @@ -153,17 +157,22 @@ static TopDocs mergeTopDocs(List results, int topN, int from) { return topDocs; } final TopDocs mergedTopDocs; - if (topDocs instanceof TopFieldGroups firstTopDocs) { - final Sort sort = validateSameSortTypesAndMaybeRewrite(results, firstTopDocs.fields); - TopFieldGroups[] shardTopDocs = topDocsList.toArray(new TopFieldGroups[0]); - mergedTopDocs = TopFieldGroups.merge(sort, from, topN, shardTopDocs, false); - } else if (topDocs instanceof TopFieldDocs firstTopDocs) { - TopFieldDocs[] shardTopDocs = topDocsList.toArray(new TopFieldDocs[0]); - final Sort sort = validateSameSortTypesAndMaybeRewrite(results, firstTopDocs.fields); - mergedTopDocs = TopDocs.merge(sort, from, topN, shardTopDocs); - } else { - final TopDocs[] shardTopDocs = topDocsList.toArray(new TopDocs[0]); - mergedTopDocs = TopDocs.merge(from, topN, shardTopDocs); + try { + if (topDocs instanceof TopFieldGroups firstTopDocs) { + final Sort sort = validateSameSortTypesAndMaybeRewrite(results, firstTopDocs.fields); + TopFieldGroups[] shardTopDocs = topDocsList.toArray(new TopFieldGroups[0]); + mergedTopDocs = TopFieldGroups.merge(sort, from, topN, shardTopDocs, false); + } else if (topDocs instanceof TopFieldDocs firstTopDocs) { + TopFieldDocs[] shardTopDocs = topDocsList.toArray(new TopFieldDocs[0]); + final Sort sort = validateSameSortTypesAndMaybeRewrite(results, firstTopDocs.fields); + mergedTopDocs = TopDocs.merge(sort, from, topN, shardTopDocs); + } else { + final TopDocs[] shardTopDocs = topDocsList.toArray(new TopDocs[0]); + mergedTopDocs = TopDocs.merge(from, topN, shardTopDocs); + } + } catch (IllegalArgumentException e) { + logger.debug("Failed to merge top docs: ", e); + throw e; } return mergedTopDocs; } diff --git 
a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index f0be39208c902..8d763698c63c0 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -457,6 +457,7 @@ protected void doRun(Map shardIndexMap) { executeAsSingleRequest(routing, request.shards.getFirst()); return; } + String nodeId = routing.nodeId(); final Transport.Connection connection; try { connection = getConnection(routing.clusterAlias(), routing.nodeId()); @@ -508,6 +509,7 @@ public void handleResponse(NodeQueryResponse response) { @Override public void handleException(TransportException e) { Exception cause = (Exception) ExceptionsHelper.unwrapCause(e); + logger.debug("handling node search exception coming from [" + nodeId + "]", cause); if (e instanceof SendRequestTransportException || cause instanceof TaskCancelledException) { // two possible special cases here where we do not want to fail the phase: // failure to send out the request -> handle things the same way a shard would fail with unbatched execution diff --git a/server/src/main/java/org/elasticsearch/cluster/ClusterInfo.java b/server/src/main/java/org/elasticsearch/cluster/ClusterInfo.java index 6d11700500c24..33172e30fb107 100644 --- a/server/src/main/java/org/elasticsearch/cluster/ClusterInfo.java +++ b/server/src/main/java/org/elasticsearch/cluster/ClusterInfo.java @@ -41,7 +41,7 @@ /** * ClusterInfo is an object representing a map of nodes to {@link DiskUsage} - * and a map of shard ids to shard sizes, see + * and a map of shard ids to shard sizes and shard write-loads, see * InternalClusterInfoService.shardIdentifierFromRouting(String) * for the key used in the shardSizes map */ @@ -59,9 +59,10 @@ public class ClusterInfo implements ChunkedToXContent, Writeable { final Map reservedSpace; final Map estimatedHeapUsages; final Map nodeUsageStatsForThreadPools; + final Map shardWriteLoads; protected ClusterInfo() { - this(Map.of(), Map.of(), Map.of(), Map.of(), Map.of(), Map.of(), Map.of(), Map.of()); + this(Map.of(), Map.of(), Map.of(), Map.of(), Map.of(), Map.of(), Map.of(), Map.of(), Map.of()); } /** @@ -85,7 +86,8 @@ public ClusterInfo( Map dataPath, Map reservedSpace, Map estimatedHeapUsages, - Map nodeUsageStatsForThreadPools + Map nodeUsageStatsForThreadPools, + Map shardWriteLoads ) { this.leastAvailableSpaceUsage = Map.copyOf(leastAvailableSpaceUsage); this.mostAvailableSpaceUsage = Map.copyOf(mostAvailableSpaceUsage); @@ -95,6 +97,7 @@ public ClusterInfo( this.reservedSpace = Map.copyOf(reservedSpace); this.estimatedHeapUsages = Map.copyOf(estimatedHeapUsages); this.nodeUsageStatsForThreadPools = Map.copyOf(nodeUsageStatsForThreadPools); + this.shardWriteLoads = Map.copyOf(shardWriteLoads); } public ClusterInfo(StreamInput in) throws IOException { @@ -116,6 +119,11 @@ public ClusterInfo(StreamInput in) throws IOException { } else { this.nodeUsageStatsForThreadPools = Map.of(); } + if (in.getTransportVersion().onOrAfter(TransportVersions.SHARD_WRITE_LOAD_IN_CLUSTER_INFO)) { + this.shardWriteLoads = in.readImmutableMap(ShardId::new, StreamInput::readDouble); + } else { + this.shardWriteLoads = Map.of(); + } } @Override @@ -136,6 +144,9 @@ public void writeTo(StreamOutput out) throws IOException { if 
(out.getTransportVersion().onOrAfter(TransportVersions.NODE_USAGE_STATS_FOR_THREAD_POOLS_IN_CLUSTER_INFO)) { out.writeMap(this.nodeUsageStatsForThreadPools, StreamOutput::writeWriteable); } + if (out.getTransportVersion().onOrAfter(TransportVersions.SHARD_WRITE_LOAD_IN_CLUSTER_INFO)) { + out.writeMap(this.shardWriteLoads, StreamOutput::writeWriteable, StreamOutput::writeDouble); + } } /** @@ -216,7 +227,7 @@ public Iterator toXContentChunked(ToXContent.Params params return builder.endObject(); // NodeAndPath }), endArray() // end "reserved_sizes" - // NOTE: We don't serialize estimatedHeapUsages/nodeUsageStatsForThreadPools at this stage, to avoid + // NOTE: We don't serialize estimatedHeapUsages/nodeUsageStatsForThreadPools/shardWriteLoads at this stage, to avoid // committing to API payloads until the features are settled ); } @@ -255,6 +266,16 @@ public Map getNodeMostAvailableDiskUsages() { return this.mostAvailableSpaceUsage; } + /** + * Returns a map of shard IDs to the write-loads for use in balancing. The write-loads can be interpreted + * as the average number of threads that ingestion to the shard will consume. + * This information may be partial or missing altogether under some circumstances. The absence of a shard + * write load from the map should be interpreted as "unknown". + */ + public Map getShardWriteLoads() { + return shardWriteLoads; + } + /** * Returns the shard size for the given shardId or null if that metric is not available. */ @@ -331,7 +352,9 @@ public boolean equals(Object o) { && shardDataSetSizes.equals(that.shardDataSetSizes) && dataPath.equals(that.dataPath) && reservedSpace.equals(that.reservedSpace) - && nodeUsageStatsForThreadPools.equals(that.nodeUsageStatsForThreadPools); + && estimatedHeapUsages.equals(that.estimatedHeapUsages) + && nodeUsageStatsForThreadPools.equals(that.nodeUsageStatsForThreadPools) + && shardWriteLoads.equals(that.shardWriteLoads); } @Override @@ -343,7 +366,9 @@ public int hashCode() { shardDataSetSizes, dataPath, reservedSpace, - nodeUsageStatsForThreadPools + estimatedHeapUsages, + nodeUsageStatsForThreadPools, + shardWriteLoads ); } @@ -466,6 +491,7 @@ public static class Builder { private Map reservedSpace = Map.of(); private Map estimatedHeapUsages = Map.of(); private Map nodeUsageStatsForThreadPools = Map.of(); + private Map shardWriteLoads = Map.of(); public ClusterInfo build() { return new ClusterInfo( @@ -476,7 +502,8 @@ public ClusterInfo build() { dataPath, reservedSpace, estimatedHeapUsages, - nodeUsageStatsForThreadPools + nodeUsageStatsForThreadPools, + shardWriteLoads ); } @@ -519,5 +546,10 @@ public Builder nodeUsageStatsForThreadPools(Map shardWriteLoads) { + this.shardWriteLoads = shardWriteLoads; + return this; + } } } diff --git a/server/src/main/java/org/elasticsearch/cluster/ClusterInfoSimulator.java b/server/src/main/java/org/elasticsearch/cluster/ClusterInfoSimulator.java index 7e995404191d6..fd9c62daebd29 100644 --- a/server/src/main/java/org/elasticsearch/cluster/ClusterInfoSimulator.java +++ b/server/src/main/java/org/elasticsearch/cluster/ClusterInfoSimulator.java @@ -159,7 +159,8 @@ public ClusterInfo getClusterInfo() { dataPath, Map.of(), estimatedHeapUsages, - nodeThreadPoolUsageStats + nodeThreadPoolUsageStats, + allocation.clusterInfo().getShardWriteLoads() ); } } diff --git a/server/src/main/java/org/elasticsearch/cluster/InternalClusterInfoService.java b/server/src/main/java/org/elasticsearch/cluster/InternalClusterInfoService.java index 89394c8fa8ba8..d4ecec83ebc8c 100644 --- 
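Before moving on to `InternalClusterInfoService`: the `getShardWriteLoads` javadoc above promises "absent means unknown" semantics, which callers have to honor explicitly. A small consumer sketch follows; the helper name and fallback policy are assumptions for illustration, not code from the PR.

```java
import org.elasticsearch.cluster.ClusterInfo;
import org.elasticsearch.index.shard.ShardId;

// Hypothetical helper: treat a missing entry as "unknown" and let the caller pick
// the fallback, rather than silently assuming a load of 0.0.
final class WriteLoadLookup {
    static double writeLoadOrElse(ClusterInfo info, ShardId shardId, double unknownFallback) {
        Double load = info.getShardWriteLoads().get(shardId);
        return load != null ? load : unknownFallback;
    }
}
```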
a/server/src/main/java/org/elasticsearch/cluster/InternalClusterInfoService.java +++ b/server/src/main/java/org/elasticsearch/cluster/InternalClusterInfoService.java @@ -38,6 +38,7 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.shard.IndexingStats; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.store.StoreStats; import org.elasticsearch.threadpool.ThreadPool; @@ -215,7 +216,7 @@ void execute() { logger.trace("starting async refresh"); try (var ignoredRefs = fetchRefs) { - maybeFetchIndicesStats(diskThresholdEnabled); + maybeFetchIndicesStats(diskThresholdEnabled || writeLoadConstraintEnabled == WriteLoadDeciderStatus.ENABLED); maybeFetchNodeStats(diskThresholdEnabled || estimatedHeapThresholdEnabled); maybeFetchNodesEstimatedHeapUsage(estimatedHeapThresholdEnabled); maybeFetchNodesUsageStatsForThreadPools(writeLoadConstraintEnabled); @@ -301,7 +302,14 @@ public void onFailure(Exception e) { private void fetchIndicesStats() { final IndicesStatsRequest indicesStatsRequest = new IndicesStatsRequest(); indicesStatsRequest.clear(); - indicesStatsRequest.store(true); + if (diskThresholdEnabled) { + // This returns the shard sizes on disk + indicesStatsRequest.store(true); + } + if (writeLoadConstraintEnabled == WriteLoadDeciderStatus.ENABLED) { + // This returns the shard write-loads + indicesStatsRequest.indexing(true); + } indicesStatsRequest.indicesOptions(IndicesOptions.STRICT_EXPAND_OPEN_CLOSED_HIDDEN); indicesStatsRequest.timeout(fetchTimeout); client.admin() @@ -350,6 +358,7 @@ public void onResponse(IndicesStatsResponse indicesStatsResponse) { } final ShardStats[] stats = indicesStatsResponse.getShards(); + final Map shardWriteLoadByIdentifierBuilder = new HashMap<>(); final Map shardSizeByIdentifierBuilder = new HashMap<>(); final Map shardDataSetSizeBuilder = new HashMap<>(); final Map dataPath = new HashMap<>(); @@ -357,6 +366,7 @@ public void onResponse(IndicesStatsResponse indicesStatsResponse) { new HashMap<>(); buildShardLevelInfo( adjustShardStats(stats), + shardWriteLoadByIdentifierBuilder, shardSizeByIdentifierBuilder, shardDataSetSizeBuilder, dataPath, @@ -370,7 +380,8 @@ public void onResponse(IndicesStatsResponse indicesStatsResponse) { Map.copyOf(shardSizeByIdentifierBuilder), Map.copyOf(shardDataSetSizeBuilder), Map.copyOf(dataPath), - Map.copyOf(reservedSpace) + Map.copyOf(reservedSpace), + Map.copyOf(shardWriteLoadByIdentifierBuilder) ); } @@ -527,8 +538,6 @@ public ClusterInfo getClusterInfo() { estimatedHeapUsages.put(nodeId, new EstimatedHeapUsage(nodeId, maxHeapSize.getBytes(), estimatedHeapUsage)); } }); - final Map nodeThreadPoolUsageStats = new HashMap<>(); - nodeThreadPoolUsageStatsPerNode.forEach((nodeId, nodeWriteLoad) -> { nodeThreadPoolUsageStats.put(nodeId, nodeWriteLoad); }); return new ClusterInfo( leastAvailableSpaceUsages, mostAvailableSpaceUsages, @@ -537,7 +546,8 @@ public ClusterInfo getClusterInfo() { indicesStatsSummary.dataPath, indicesStatsSummary.reservedSpace, estimatedHeapUsages, - nodeThreadPoolUsageStats + nodeThreadPoolUsageStatsPerNode, + indicesStatsSummary.shardWriteLoads() ); } @@ -567,6 +577,7 @@ public void addListener(Consumer clusterInfoConsumer) { static void buildShardLevelInfo( ShardStats[] stats, + Map shardWriteLoads, Map shardSizes, Map shardDataSetSizeBuilder, Map dataPathByShard, @@ -577,25 +588,31 @@ static void buildShardLevelInfo( 
dataPathByShard.put(ClusterInfo.NodeAndShard.from(shardRouting), s.getDataPath()); final StoreStats storeStats = s.getStats().getStore(); - if (storeStats == null) { - continue; - } - final long size = storeStats.sizeInBytes(); - final long dataSetSize = storeStats.totalDataSetSizeInBytes(); - final long reserved = storeStats.reservedSizeInBytes(); - - final String shardIdentifier = ClusterInfo.shardIdentifierFromRouting(shardRouting); - logger.trace("shard: {} size: {} reserved: {}", shardIdentifier, size, reserved); - shardSizes.put(shardIdentifier, size); - if (dataSetSize > shardDataSetSizeBuilder.getOrDefault(shardRouting.shardId(), -1L)) { - shardDataSetSizeBuilder.put(shardRouting.shardId(), dataSetSize); + if (storeStats != null) { + final long size = storeStats.sizeInBytes(); + final long dataSetSize = storeStats.totalDataSetSizeInBytes(); + final long reserved = storeStats.reservedSizeInBytes(); + + final String shardIdentifier = ClusterInfo.shardIdentifierFromRouting(shardRouting); + logger.trace("shard: {} size: {} reserved: {}", shardIdentifier, size, reserved); + shardSizes.put(shardIdentifier, size); + if (dataSetSize > shardDataSetSizeBuilder.getOrDefault(shardRouting.shardId(), -1L)) { + shardDataSetSizeBuilder.put(shardRouting.shardId(), dataSetSize); + } + if (reserved != StoreStats.UNKNOWN_RESERVED_BYTES) { + final ClusterInfo.ReservedSpace.Builder reservedSpaceBuilder = reservedSpaceByShard.computeIfAbsent( + new ClusterInfo.NodeAndPath(shardRouting.currentNodeId(), s.getDataPath()), + t -> new ClusterInfo.ReservedSpace.Builder() + ); + reservedSpaceBuilder.add(shardRouting.shardId(), reserved); + } } - if (reserved != StoreStats.UNKNOWN_RESERVED_BYTES) { - final ClusterInfo.ReservedSpace.Builder reservedSpaceBuilder = reservedSpaceByShard.computeIfAbsent( - new ClusterInfo.NodeAndPath(shardRouting.currentNodeId(), s.getDataPath()), - t -> new ClusterInfo.ReservedSpace.Builder() - ); - reservedSpaceBuilder.add(shardRouting.shardId(), reserved); + final IndexingStats indexingStats = s.getStats().getIndexing(); + if (indexingStats != null) { + final double shardWriteLoad = indexingStats.getTotal().getPeakWriteLoad(); + if (shardWriteLoad > shardWriteLoads.getOrDefault(shardRouting.shardId(), -1.0)) { + shardWriteLoads.put(shardRouting.shardId(), shardWriteLoad); + } } } } @@ -623,9 +640,10 @@ private record IndicesStatsSummary( Map shardSizes, Map shardDataSetSizes, Map dataPath, - Map reservedSpace + Map reservedSpace, + Map shardWriteLoads ) { - static final IndicesStatsSummary EMPTY = new IndicesStatsSummary(Map.of(), Map.of(), Map.of(), Map.of()); + static final IndicesStatsSummary EMPTY = new IndicesStatsSummary(Map.of(), Map.of(), Map.of(), Map.of(), Map.of()); } } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java index 76319091f4911..d4bc58c299435 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java @@ -835,10 +835,6 @@ private Iterator toXContentChunkedWithSingleProjectFormat( ); } - private static final DiffableUtils.KeySerializer PROJECT_ID_SERIALIZER = DiffableUtils.getWriteableKeySerializer( - ProjectId.READER - ); - private static class MetadataDiff implements Diff { private final long version; @@ -880,7 +876,7 @@ private static class MetadataDiff implements Diff { multiProject = null; } else { singleProject = null; - multiProject = 
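A note on the aggregation rule that `buildShardLevelInfo` applies above: each shard copy reports its own peak write load, and the builder keeps the per-shard maximum, using `-1.0` as a below-any-real-value sentinel so the first report always wins. A self-contained sketch of just that rule, with made-up values and a `String` standing in for `ShardId`:

```java
import java.util.HashMap;
import java.util.Map;

final class MaxPerKeyExample {
    public static void main(String[] args) {
        Map<String, Double> shardWriteLoads = new HashMap<>();
        String shardId = "index-0[0]"; // stand-in for a ShardId
        for (double reported : new double[] { 0.4, 1.7, 0.9 }) { // e.g. primary + replicas
            if (reported > shardWriteLoads.getOrDefault(shardId, -1.0)) {
                shardWriteLoads.put(shardId, reported);
            }
        }
        System.out.println(shardWriteLoads.get(shardId)); // 1.7, the peak across copies
    }
}
```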
DiffableUtils.diff(before.projectMetadata, after.projectMetadata, PROJECT_ID_SERIALIZER); + multiProject = DiffableUtils.diff(before.projectMetadata, after.projectMetadata, ProjectId.PROJECT_ID_SERIALIZER); } if (empty) { @@ -1004,7 +1000,7 @@ private MetadataDiff(StreamInput in) throws IOException { singleProject = null; multiProject = DiffableUtils.readJdkMapDiff( in, - PROJECT_ID_SERIALIZER, + ProjectId.PROJECT_ID_SERIALIZER, ProjectMetadata::readFrom, ProjectMetadata.ProjectMetadataDiff::new ); @@ -1059,7 +1055,7 @@ public void writeTo(StreamOutput out) throws IOException { if (multiProject != null) { multiProject.writeTo(out); } else { - DiffableUtils.singleEntryDiff(DEFAULT_PROJECT_ID, singleProject, PROJECT_ID_SERIALIZER).writeTo(out); + DiffableUtils.singleEntryDiff(DEFAULT_PROJECT_ID, singleProject, ProjectId.PROJECT_ID_SERIALIZER).writeTo(out); } } } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/ProjectId.java b/server/src/main/java/org/elasticsearch/cluster/metadata/ProjectId.java index 94fa0164b5fbe..88f314ea6cbfe 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/ProjectId.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/ProjectId.java @@ -9,6 +9,7 @@ package org.elasticsearch.cluster.metadata; +import org.elasticsearch.cluster.DiffableUtils; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -26,6 +27,7 @@ public class ProjectId implements Writeable, ToXContent { private static final String DEFAULT_STRING = "default"; public static final ProjectId DEFAULT = new ProjectId(DEFAULT_STRING); public static final Reader READER = ProjectId::readFrom; + public static final DiffableUtils.KeySerializer PROJECT_ID_SERIALIZER = DiffableUtils.getWriteableKeySerializer(READER); private static final int MAX_LENGTH = 128; private final String id; diff --git a/server/src/main/java/org/elasticsearch/cluster/project/ProjectStateRegistry.java b/server/src/main/java/org/elasticsearch/cluster/project/ProjectStateRegistry.java index 2876ebc13c70c..014ee37724cbc 100644 --- a/server/src/main/java/org/elasticsearch/cluster/project/ProjectStateRegistry.java +++ b/server/src/main/java/org/elasticsearch/cluster/project/ProjectStateRegistry.java @@ -13,15 +13,23 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.cluster.AbstractNamedDiffable; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterState.Custom; +import org.elasticsearch.cluster.Diff; +import org.elasticsearch.cluster.Diffable; +import org.elasticsearch.cluster.DiffableUtils; import org.elasticsearch.cluster.NamedDiff; +import org.elasticsearch.cluster.NamedDiffable; +import org.elasticsearch.cluster.SimpleDiffable; import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.common.collect.ImmutableOpenMap; import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; import java.util.Collections; @@ -30,22 +38,29 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; /** * Represents 
a registry for managing and retrieving project-specific state in the cluster state. */ -public class ProjectStateRegistry extends AbstractNamedDiffable implements ClusterState.Custom { +public class ProjectStateRegistry extends AbstractNamedDiffable implements Custom, NamedDiffable { public static final String TYPE = "projects_registry"; public static final ProjectStateRegistry EMPTY = new ProjectStateRegistry(Collections.emptyMap(), Collections.emptySet(), 0); + private static final Entry EMPTY_ENTRY = new Entry(Settings.EMPTY); - private final Map projectsSettings; + private final Map projectsEntries; // Projects that have been marked for deletion based on their file-based setting private final Set projectsMarkedForDeletion; // A counter that is incremented each time one or more projects are marked for deletion. private final long projectsMarkedForDeletionGeneration; public ProjectStateRegistry(StreamInput in) throws IOException { - projectsSettings = in.readMap(ProjectId::readFrom, Settings::readSettingsFromStream); + if (in.getTransportVersion().onOrAfter(TransportVersions.PROJECT_STATE_REGISTRY_ENTRY)) { + projectsEntries = in.readMap(ProjectId::readFrom, Entry::readFrom); + } else { + Map settingsMap = in.readMap(ProjectId::readFrom, Settings::readSettingsFromStream); + projectsEntries = settingsMap.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> new Entry(e.getValue()))); + } if (in.getTransportVersion().onOrAfter(TransportVersions.PROJECT_STATE_REGISTRY_RECORDS_DELETIONS)) { projectsMarkedForDeletion = in.readCollectionAsImmutableSet(ProjectId::readFrom); projectsMarkedForDeletionGeneration = in.readVLong(); @@ -56,11 +71,11 @@ public ProjectStateRegistry(StreamInput in) throws IOException { } private ProjectStateRegistry( - Map projectsSettings, + Map projectEntries, Set projectsMarkedForDeletion, long projectsMarkedForDeletionGeneration ) { - this.projectsSettings = projectsSettings; + this.projectsEntries = projectEntries; this.projectsMarkedForDeletion = projectsMarkedForDeletion; this.projectsMarkedForDeletionGeneration = projectsMarkedForDeletionGeneration; } @@ -75,7 +90,11 @@ private ProjectStateRegistry( */ public static Settings getProjectSettings(ProjectId projectId, ClusterState clusterState) { ProjectStateRegistry registry = clusterState.custom(TYPE, EMPTY); - return registry.projectsSettings.getOrDefault(projectId, Settings.EMPTY); + return registry.getProjectSettings(projectId); + } + + public Settings getProjectSettings(ProjectId projectId) { + return projectsEntries.getOrDefault(projectId, EMPTY_ENTRY).settings; } public boolean isProjectMarkedForDeletion(ProjectId projectId) { @@ -91,12 +110,10 @@ public Iterator toXContentChunked(ToXContent.Params params return Iterators.concat( Iterators.single((builder, p) -> builder.startArray("projects")), - Iterators.map(projectsSettings.entrySet().iterator(), entry -> (builder, p) -> { + Iterators.map(projectsEntries.entrySet().iterator(), entry -> (builder, p) -> { builder.startObject(); builder.field("id", entry.getKey()); - builder.startObject("settings"); - entry.getValue().toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("flat_settings", "true"))); - builder.endObject(); + entry.getValue().toXContent(builder, params); builder.field("marked_for_deletion", projectsMarkedForDeletion.contains(entry.getKey())); return builder.endObject(); }), @@ -105,8 +122,19 @@ public Iterator toXContentChunked(ToXContent.Params params ); } - public static NamedDiff readDiffFrom(StreamInput in) 
throws IOException { - return readDiffFrom(ClusterState.Custom.class, TYPE, in); + public static NamedDiff readDiffFrom(StreamInput in) throws IOException { + if (in.getTransportVersion().onOrAfter(TransportVersions.PROJECT_STATE_REGISTRY_ENTRY)) { + return new ProjectStateRegistryDiff(in); + } + return readDiffFrom(Custom.class, TYPE, in); + } + + @Override + public Diff diff(Custom previousState) { + if (this.equals(previousState)) { + return SimpleDiffable.empty(); + } + return new ProjectStateRegistryDiff((ProjectStateRegistry) previousState, this); } @Override @@ -121,7 +149,14 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeMap(projectsSettings); + if (out.getTransportVersion().onOrAfter(TransportVersions.PROJECT_STATE_REGISTRY_ENTRY)) { + out.writeMap(projectsEntries); + } else { + Map settingsMap = projectsEntries.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().settings())); + out.writeMap(settingsMap); + } if (out.getTransportVersion().onOrAfter(TransportVersions.PROJECT_STATE_REGISTRY_RECORDS_DELETIONS)) { out.writeCollection(projectsMarkedForDeletion); out.writeVLong(projectsMarkedForDeletionGeneration); @@ -133,7 +168,7 @@ public void writeTo(StreamOutput out) throws IOException { } public int size() { - return projectsSettings.size(); + return projectsEntries.size(); } public long getProjectsMarkedForDeletionGeneration() { @@ -141,15 +176,15 @@ public long getProjectsMarkedForDeletionGeneration() { } // visible for testing - Map getProjectsSettings() { - return Collections.unmodifiableMap(projectsSettings); + Set knownProjects() { + return projectsEntries.keySet(); } @Override public String toString() { return "ProjectStateRegistry[" - + "projectsSettings=" - + projectsSettings + + "entities=" + + projectsEntries + ", projectsMarkedForDeletion=" + projectsMarkedForDeletion + ", projectsMarkedForDeletionGeneration=" @@ -163,13 +198,13 @@ public boolean equals(Object o) { if (o instanceof ProjectStateRegistry == false) return false; ProjectStateRegistry that = (ProjectStateRegistry) o; return projectsMarkedForDeletionGeneration == that.projectsMarkedForDeletionGeneration - && Objects.equals(projectsSettings, that.projectsSettings) + && Objects.equals(projectsEntries, that.projectsEntries) && Objects.equals(projectsMarkedForDeletion, that.projectsMarkedForDeletion); } @Override public int hashCode() { - return Objects.hash(projectsSettings, projectsMarkedForDeletion, projectsMarkedForDeletionGeneration); + return Objects.hash(projectsEntries, projectsMarkedForDeletion, projectsMarkedForDeletionGeneration); } public static Builder builder(ClusterState original) { @@ -185,26 +220,86 @@ public static Builder builder() { return new Builder(); } + static class ProjectStateRegistryDiff implements NamedDiff { + private static final DiffableUtils.DiffableValueReader VALUE_READER = new DiffableUtils.DiffableValueReader<>( + Entry::readFrom, + Entry.EntryDiff::readFrom + ); + + private final DiffableUtils.MapDiff> projectsEntriesDiff; + private final Set projectsMarkedForDeletion; + private final long projectsMarkedForDeletionGeneration; + + ProjectStateRegistryDiff(StreamInput in) throws IOException { + projectsEntriesDiff = DiffableUtils.readJdkMapDiff(in, ProjectId.PROJECT_ID_SERIALIZER, VALUE_READER); + projectsMarkedForDeletion = in.readCollectionAsImmutableSet(ProjectId.READER); + projectsMarkedForDeletionGeneration = in.readVLong(); + } + + 
ProjectStateRegistryDiff(ProjectStateRegistry previousState, ProjectStateRegistry currentState) { + projectsEntriesDiff = DiffableUtils.diff( + previousState.projectsEntries, + currentState.projectsEntries, + ProjectId.PROJECT_ID_SERIALIZER, + VALUE_READER + ); + projectsMarkedForDeletion = currentState.projectsMarkedForDeletion; + projectsMarkedForDeletionGeneration = currentState.projectsMarkedForDeletionGeneration; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.PROJECT_STATE_REGISTRY_ENTRY; + } + + @Override + public Custom apply(Custom part) { + return new ProjectStateRegistry( + projectsEntriesDiff.apply(((ProjectStateRegistry) part).projectsEntries), + projectsMarkedForDeletion, + projectsMarkedForDeletionGeneration + ); + } + + @Override + public String getWriteableName() { + return TYPE; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + projectsEntriesDiff.writeTo(out); + out.writeCollection(projectsMarkedForDeletion); + out.writeVLong(projectsMarkedForDeletionGeneration); + } + } + public static class Builder { - private final ImmutableOpenMap.Builder projectsSettings; + private final ImmutableOpenMap.Builder projectsEntries; private final Set projectsMarkedForDeletion; private final long projectsMarkedForDeletionGeneration; private boolean newProjectMarkedForDeletion = false; private Builder() { - this.projectsSettings = ImmutableOpenMap.builder(); + this.projectsEntries = ImmutableOpenMap.builder(); projectsMarkedForDeletion = new HashSet<>(); projectsMarkedForDeletionGeneration = 0; } private Builder(ProjectStateRegistry original) { - this.projectsSettings = ImmutableOpenMap.builder(original.projectsSettings); + this.projectsEntries = ImmutableOpenMap.builder(original.projectsEntries); this.projectsMarkedForDeletion = new HashSet<>(original.projectsMarkedForDeletion); this.projectsMarkedForDeletionGeneration = original.projectsMarkedForDeletionGeneration; } public Builder putProjectSettings(ProjectId projectId, Settings settings) { - projectsSettings.put(projectId, settings); + Entry entry = projectsEntries.get(projectId); + if (entry == null) { + entry = new Entry(settings); + } else { + entry = entry.withSettings(settings); + } + projectsEntries.put(projectId, entry); return this; } @@ -216,17 +311,63 @@ public Builder markProjectForDeletion(ProjectId projectId) { } public ProjectStateRegistry build() { - final var unknownButUnderDeletion = Sets.difference(projectsMarkedForDeletion, projectsSettings.keys()); + final var unknownButUnderDeletion = Sets.difference(projectsMarkedForDeletion, projectsEntries.keys()); if (unknownButUnderDeletion.isEmpty() == false) { throw new IllegalArgumentException( "Cannot mark projects for deletion that are not in the registry: " + unknownButUnderDeletion ); } return new ProjectStateRegistry( - projectsSettings.build(), + projectsEntries.build(), projectsMarkedForDeletion, newProjectMarkedForDeletion ? 
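The essential contract behind `ProjectStateRegistryDiff` above is the usual `Diffable` round trip: diffing two states and applying the result to the older one must reproduce the newer one, and equal states short-circuit to an empty diff. A sketch of that contract using only builder methods shown in this diff; the assertion style is illustrative, not a test from the PR.

```java
ProjectStateRegistry previous = ProjectStateRegistry.builder().build();
ProjectStateRegistry current = ProjectStateRegistry.builder()
    .putProjectSettings(ProjectId.DEFAULT, Settings.EMPTY)
    .build();

var diff = current.diff(previous);            // a ProjectStateRegistryDiff on the wire
assert diff.apply(previous).equals(current);  // the round trip must hold
```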
projectsMarkedForDeletionGeneration + 1 : projectsMarkedForDeletionGeneration ); } } + + private record Entry(Settings settings) implements Writeable, Diffable { + + public static Entry readFrom(StreamInput in) throws IOException { + return new Entry(Settings.readSettingsFromStream(in)); + } + + public Entry withSettings(Settings settings) { + return new Entry(settings); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeWriteable(settings); + } + + public void toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject("settings"); + settings.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("flat_settings", "true"))); + builder.endObject(); + } + + @Override + public Diff diff(Entry previousState) { + if (this == previousState) { + return SimpleDiffable.empty(); + } + return new EntryDiff(settings.diff(previousState.settings)); + } + + private record EntryDiff(Diff settingsDiff) implements Diff { + public static EntryDiff readFrom(StreamInput in) throws IOException { + return new EntryDiff(Settings.readSettingsDiffFromStream(in)); + } + + @Override + public Entry apply(Entry part) { + return part.withSettings(settingsDiff.apply(part.settings)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeWriteable(settingsDiff); + } + } + } } diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/IndexRouting.java b/server/src/main/java/org/elasticsearch/cluster/routing/IndexRouting.java index 3df0d2d65b657..050181802af8d 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/IndexRouting.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/IndexRouting.java @@ -15,6 +15,8 @@ import org.elasticsearch.action.RoutingMissingException; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.IndexReshardingMetadata; +import org.elasticsearch.cluster.metadata.IndexReshardingState; import org.elasticsearch.cluster.metadata.MappingMetadata; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.Strings; @@ -73,11 +75,13 @@ public static IndexRouting fromIndexMetadata(IndexMetadata metadata) { protected final String indexName; private final int routingNumShards; private final int routingFactor; + private final IndexReshardingMetadata indexReshardingMetadata; private IndexRouting(IndexMetadata metadata) { this.indexName = metadata.getIndex().getName(); this.routingNumShards = metadata.getRoutingNumShards(); this.routingFactor = metadata.getRoutingFactor(); + this.indexReshardingMetadata = metadata.getReshardingMetadata(); } /** @@ -149,6 +153,23 @@ private static int effectiveRoutingToHash(String effectiveRouting) { */ public void checkIndexSplitAllowed() {} + /** + * If this index is in the process of resharding, and the shard to which this request is being routed, + * is a target shard that is not yet in HANDOFF state, then route it to the source shard. 
+ * @param shardId shardId to which the current document is routed based on hashing + * @return Updated shardId + */ + protected final int rerouteIfResharding(int shardId) { + if (indexReshardingMetadata != null && indexReshardingMetadata.getSplit().isTargetShard(shardId)) { + assert indexReshardingMetadata.isSplit() : "Index resharding state is not a split"; + if (indexReshardingMetadata.getSplit() + .targetStateAtLeast(shardId, IndexReshardingState.Split.TargetShardState.HANDOFF) == false) { + return indexReshardingMetadata.getSplit().sourceShard(shardId); + } + } + return shardId; + } + private abstract static class IdAndRoutingOnly extends IndexRouting { private final boolean routingRequired; private final IndexVersion creationVersion; @@ -195,19 +216,22 @@ public int indexShard(String id, @Nullable String routing, XContentType sourceTy throw new IllegalStateException("id is required and should have been set by process"); } checkRoutingRequired(id, routing); - return shardId(id, routing); + int shardId = shardId(id, routing); + return rerouteIfResharding(shardId); } @Override public int updateShard(String id, @Nullable String routing) { checkRoutingRequired(id, routing); - return shardId(id, routing); + int shardId = shardId(id, routing); + return rerouteIfResharding(shardId); } @Override public int deleteShard(String id, @Nullable String routing) { checkRoutingRequired(id, routing); - return shardId(id, routing); + int shardId = shardId(id, routing); + return rerouteIfResharding(shardId); } @Override @@ -314,7 +338,8 @@ public int indexShard(String id, @Nullable String routing, XContentType sourceTy assert Transports.assertNotTransportThread("parsing the _source can get slow"); checkNoRouting(routing); hash = hashSource(sourceType, source).buildHash(IndexRouting.ExtractFromSource::defaultOnEmpty); - return hashToShardId(hash); + int shardId = hashToShardId(hash); + return (rerouteIfResharding(shardId)); } public String createId(XContentType sourceType, BytesReference source, byte[] suffix) { @@ -454,13 +479,15 @@ public int updateShard(String id, @Nullable String routing) { @Override public int deleteShard(String id, @Nullable String routing) { checkNoRouting(routing); - return idToHash(id); + int shardId = idToHash(id); + return (rerouteIfResharding(shardId)); } @Override public int getShard(String id, @Nullable String routing) { checkNoRouting(routing); - return idToHash(id); + int shardId = idToHash(id); + return (rerouteIfResharding(shardId)); } private void checkNoRouting(@Nullable String routing) { diff --git a/server/src/main/java/org/elasticsearch/common/blobstore/BlobContainer.java b/server/src/main/java/org/elasticsearch/common/blobstore/BlobContainer.java index c5e9089f93900..340123456435f 100644 --- a/server/src/main/java/org/elasticsearch/common/blobstore/BlobContainer.java +++ b/server/src/main/java/org/elasticsearch/common/blobstore/BlobContainer.java @@ -304,7 +304,10 @@ default void copyBlob( /** * Atomically sets the value stored at the given key to {@code updated} if the {@code current value == expected}. - * Keys not yet used start at initial value 0. Returns the current value (before it was updated). + * If a key has not yet been used as a register, its initial value is an empty {@link BytesReference}. + *
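The routing change above is easiest to read as a two-step pattern, now applied at every `IndexRouting` entry point. This is a restatement of code from the diff, not a new API:

```java
// Pattern used by indexShard/updateShard/deleteShard/getShard after this change:
int shardId = shardId(id, routing);   // 1. pick the shard by routing hash, as before
return rerouteIfResharding(shardId);  // 2. during a split, a target shard that has not
                                      //    reached HANDOFF redirects to its source shard
```

As a hypothetical walk-through: while an index is splitting and shard 2 is a not-yet-HANDOFF target whose source is shard 0, a document hashing to shard 2 is still indexed on shard 0; once shard 2 reaches HANDOFF, the reroute becomes a no-op.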
+ * This operation, together with {@link #compareAndSetRegister}, must have linearizable semantics: a collection of such operations must + * act as if they operate serially, with each operation taking place at some instant in between its invocation and its completion. * * @param purpose The purpose of the operation * @param key key of the value to update @@ -323,9 +326,12 @@ void compareAndExchangeRegister( /** * Atomically sets the value stored at the given key to {@code updated} if the {@code current value == expected}. - * Keys not yet used start at initial value 0. + * If a key has not yet been used as a register, its initial value is an empty {@link BytesReference}. + *
+ * This operation, together with {@link #compareAndExchangeRegister}, must have linearizable semantics: a collection of such operations + * must act as if they operate serially, with each operation taking place at some instant in between its invocation and its completion. * - * @param purpose + * @param purpose The purpose of the operation * @param key key of the value to update * @param expected the expected value * @param updated the new value @@ -350,7 +356,12 @@ default void compareAndSetRegister( /** * Gets the value set by {@link #compareAndSetRegister} or {@link #compareAndExchangeRegister} for a given key. - * If a key has not yet been used, the initial value is an empty {@link BytesReference}. + * If a key has not yet been used as a register, its initial value is an empty {@link BytesReference}. + *
+ * This operation has read-after-write consistency with respect to writes performed using {@link #compareAndExchangeRegister} and + * {@link #compareAndSetRegister}, but does not guarantee full linearizability. In particular, a {@code getRegister} performed during + * one of these write operations may return either the old or the new value, and a caller may therefore observe the old value + * after observing the new value, as long as both such read operations take place before the write operation completes. * * @param purpose The purpose of the operation * @param key key of the value to get diff --git a/server/src/main/java/org/elasticsearch/common/blobstore/support/BlobContainerUtils.java b/server/src/main/java/org/elasticsearch/common/blobstore/support/BlobContainerUtils.java index 5019f41a01a4f..32e6852febf8c 100644 --- a/server/src/main/java/org/elasticsearch/common/blobstore/support/BlobContainerUtils.java +++ b/server/src/main/java/org/elasticsearch/common/blobstore/support/BlobContainerUtils.java @@ -33,9 +33,9 @@ public static void ensureValidRegisterContent(BytesReference bytesReference) { } /** - * Many blob stores have consistent (linearizable/atomic) read semantics and in these casees it is safe to implement {@link - * BlobContainer#getRegister} by simply reading the blob using this utility. - * + * Many blob stores have consistent read-after-write semantics and in these cases it is safe to implement + * {@link BlobContainer#getRegister} by simply reading the blob using this utility. + *
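To make the register semantics spelled out in these javadoc changes concrete, here is a self-contained, in-memory model of a linearizable compare-and-exchange register. It is an illustration only, not the `BlobContainer` API; real implementations sit on top of a blob store.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

// Models the documented semantics: unused keys start "empty", and compare-and-exchange
// returns the witnessed value; the write took effect iff that value equals `expected`.
final class RegisterModel {
    private static final String EMPTY = "";
    private final ConcurrentMap<String, String> registers = new ConcurrentHashMap<>();

    String compareAndExchange(String key, String expected, String updated) {
        String[] witness = new String[1];
        // compute() is atomic per key, giving the linearizable behaviour the javadoc
        // requires: each operation appears to take place at a single instant.
        registers.compute(key, (k, current) -> {
            String value = current == null ? EMPTY : current; // unused keys start empty
            witness[0] = value;
            return value.equals(expected) ? updated : value;
        });
        return witness[0];
    }
}
```

By contrast, the weaker guarantee documented for `getRegister` means a plain read that races with such a write may legitimately return either the old or the new value.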
* NB it is not safe for the supplied stream to resume a partial downloads, because the resumed stream may see a different state from * the original. */ diff --git a/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java b/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java index 9f4c5b80ccf23..68406bb6730a0 100644 --- a/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java +++ b/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java @@ -175,6 +175,7 @@ public final class IndexScopedSettings extends AbstractScopedSettings { MapperService.INDEX_MAPPING_DIMENSION_FIELDS_LIMIT_SETTING, MapperService.INDEX_MAPPING_FIELD_NAME_LENGTH_LIMIT_SETTING, MapperService.INDEX_MAPPER_DYNAMIC_SETTING, + MapperService.INDEX_MAPPING_META_LENGTH_LIMIT_SETTING, BitsetFilterCache.INDEX_LOAD_RANDOM_ACCESS_FILTERS_EAGERLY_SETTING, IndexModule.INDEX_STORE_TYPE_SETTING, IndexModule.INDEX_STORE_PRE_LOAD_SETTING, @@ -206,6 +207,7 @@ public final class IndexScopedSettings extends AbstractScopedSettings { IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING, IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING, InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT, + // IndexSettings.INDEX_MAPPING_META_LENGTH_LIMIT_SETTING, // validate that built-in similarities don't get redefined Setting.groupSetting("index.similarity.", (s) -> { diff --git a/server/src/main/java/org/elasticsearch/health/node/selection/HealthNodeTaskExecutor.java b/server/src/main/java/org/elasticsearch/health/node/selection/HealthNodeTaskExecutor.java index 0a7451702ec66..2eeb8c470b5d8 100644 --- a/server/src/main/java/org/elasticsearch/health/node/selection/HealthNodeTaskExecutor.java +++ b/server/src/main/java/org/elasticsearch/health/node/selection/HealthNodeTaskExecutor.java @@ -16,12 +16,14 @@ import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.persistent.AllocatedPersistentTask; @@ -133,10 +135,11 @@ protected HealthNode createTask( * Returns the node id from the eligible health nodes */ @Override - public PersistentTasksCustomMetadata.Assignment getAssignment( + protected PersistentTasksCustomMetadata.Assignment doGetAssignment( HealthNodeTaskParams params, Collection candidateNodes, - ClusterState clusterState + ClusterState clusterState, + @Nullable ProjectId projectId ) { DiscoveryNode discoveryNode = selectLeastLoadedNode(clusterState, candidateNodes, DiscoveryNode::canContainData); if (discoveryNode == null) { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapper.java index 22b198b10a7ad..f419d87d008fe 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapper.java +++ 
b/server/src/main/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapper.java @@ -98,11 +98,11 @@ protected void writeExtent(BlockLoader.IntBuilder builder, Extent extent) { public BlockLoader.AllReader reader(LeafReaderContext context) throws IOException { return new BlockLoader.AllReader() { @Override - public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs) throws IOException { + public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs, int offset) throws IOException { var binaryDocValues = context.reader().getBinaryDocValues(fieldName); var reader = new GeometryDocValueReader(); - try (var builder = factory.ints(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + try (var builder = factory.ints(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { read(binaryDocValues, docs.get(i), reader, builder); } return builder.build(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index 410679fa9cfd5..f95e35a5d0845 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -124,10 +124,10 @@ private static class SingletonLongs extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.LongBuilder builder = factory.longsFromDocValues(docs.count())) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.LongBuilder builder = factory.longsFromDocValues(docs.count() - offset)) { int lastDoc = -1; - for (int i = 0; i < docs.count(); i++) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < lastDoc) { throw new IllegalStateException("docs within same block must be in order"); @@ -173,9 +173,9 @@ private static class Longs extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.LongBuilder builder = factory.longsFromDocValues(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.LongBuilder builder = factory.longsFromDocValues(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < this.docID) { throw new IllegalStateException("docs within same block must be in order"); @@ -259,10 +259,10 @@ private static class SingletonInts extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.IntBuilder builder = factory.intsFromDocValues(docs.count())) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.IntBuilder builder = factory.intsFromDocValues(docs.count() - offset)) { int lastDoc = -1; - for (int i = 0; i < docs.count(); i++) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < lastDoc) { throw new IllegalStateException("docs within same block must be in order"); @@ -308,9 +308,9 @@ private static class Ints extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try 
(BlockLoader.IntBuilder builder = factory.intsFromDocValues(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.IntBuilder builder = factory.intsFromDocValues(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < this.docID) { throw new IllegalStateException("docs within same block must be in order"); @@ -408,10 +408,10 @@ private static class SingletonDoubles extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.DoubleBuilder builder = factory.doublesFromDocValues(docs.count())) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.DoubleBuilder builder = factory.doublesFromDocValues(docs.count() - offset)) { int lastDoc = -1; - for (int i = 0; i < docs.count(); i++) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < lastDoc) { throw new IllegalStateException("docs within same block must be in order"); @@ -461,9 +461,9 @@ private static class Doubles extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.DoubleBuilder builder = factory.doublesFromDocValues(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.DoubleBuilder builder = factory.doublesFromDocValues(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < this.docID) { throw new IllegalStateException("docs within same block must be in order"); @@ -544,10 +544,10 @@ private static class DenseVectorValuesBlockReader extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { // Doubles from doc values ensures that the values are in order - try (BlockLoader.FloatBuilder builder = factory.denseVectors(docs.count(), dimensions)) { - for (int i = 0; i < docs.count(); i++) { + try (BlockLoader.FloatBuilder builder = factory.denseVectors(docs.count() - offset, dimensions)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < iterator.docID()) { throw new IllegalStateException("docs within same block must be in order"); @@ -645,19 +645,19 @@ private BlockLoader.Block readSingleDoc(BlockFactory factory, int docId) throws if (ordinals.advanceExact(docId)) { BytesRef v = ordinals.lookupOrd(ordinals.ordValue()); // the returned BytesRef can be reused - return factory.constantBytes(BytesRef.deepCopyOf(v)); + return factory.constantBytes(BytesRef.deepCopyOf(v), 1); } else { - return factory.constantNulls(); + return factory.constantNulls(1); } } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - if (docs.count() == 1) { - return readSingleDoc(factory, docs.get(0)); + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + if (docs.count() - offset == 1) { + return readSingleDoc(factory, docs.get(offset)); } - try (var builder = factory.singletonOrdinalsBuilder(ordinals, docs.count())) { - for (int i = 0; i < docs.count(); i++) { + try 
(var builder = factory.singletonOrdinalsBuilder(ordinals, docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < ordinals.docID()) { throw new IllegalStateException("docs within same block must be in order"); @@ -700,12 +700,12 @@ private static class Ordinals extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - if (docs.count() == 1) { - return readSingleDoc(factory, docs.get(0)); + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + if (docs.count() - offset == 1) { + return readSingleDoc(factory, docs.get(offset)); } - try (var builder = factory.sortedSetOrdinalsBuilder(ordinals, docs.count())) { - for (int i = 0; i < docs.count(); i++) { + try (var builder = factory.sortedSetOrdinalsBuilder(ordinals, docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < ordinals.docID()) { throw new IllegalStateException("docs within same block must be in order"); @@ -736,12 +736,12 @@ public void read(int docId, BlockLoader.StoredFields storedFields, Builder build private BlockLoader.Block readSingleDoc(BlockFactory factory, int docId) throws IOException { if (ordinals.advanceExact(docId) == false) { - return factory.constantNulls(); + return factory.constantNulls(1); } int count = ordinals.docValueCount(); if (count == 1) { BytesRef v = ordinals.lookupOrd(ordinals.nextOrd()); - return factory.constantBytes(BytesRef.deepCopyOf(v)); + return factory.constantBytes(BytesRef.deepCopyOf(v), 1); } try (var builder = factory.bytesRefsFromDocValues(count)) { builder.beginPositionEntry(); @@ -816,9 +816,9 @@ private static class BytesRefsFromBinary extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.BytesRefBuilder builder = factory.bytesRefs(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.BytesRefBuilder builder = factory.bytesRefs(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < docID) { throw new IllegalStateException("docs within same block must be in order"); @@ -915,9 +915,9 @@ private static class DenseVectorFromBinary extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.FloatBuilder builder = factory.denseVectors(docs.count(), dimensions)) { - for (int i = 0; i < docs.count(); i++) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.FloatBuilder builder = factory.denseVectors(docs.count() - offset, dimensions)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < docID) { throw new IllegalStateException("docs within same block must be in order"); @@ -999,10 +999,10 @@ private static class SingletonBooleans extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.BooleanBuilder builder = factory.booleansFromDocValues(docs.count())) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.BooleanBuilder builder = factory.booleansFromDocValues(docs.count() - 
offset)) { int lastDoc = -1; - for (int i = 0; i < docs.count(); i++) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < lastDoc) { throw new IllegalStateException("docs within same block must be in order"); @@ -1048,9 +1048,9 @@ private static class Booleans extends BlockDocValuesReader { } @Override - public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException { - try (BlockLoader.BooleanBuilder builder = factory.booleansFromDocValues(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (BlockLoader.BooleanBuilder builder = factory.booleansFromDocValues(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < this.docID) { throw new IllegalStateException("docs within same block must be in order"); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockLoader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockLoader.java index fc5d28509b3b6..a4a498e4048db 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockLoader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockLoader.java @@ -43,7 +43,7 @@ interface ColumnAtATimeReader extends Reader { /** * Reads the values of all documents in {@code docs}. */ - BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException; + BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throws IOException; } interface RowStrideReader extends Reader { @@ -149,8 +149,8 @@ public String toString() { */ class ConstantNullsReader implements AllReader { @Override - public Block read(BlockFactory factory, Docs docs) throws IOException { - return factory.constantNulls(); + public Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + return factory.constantNulls(docs.count() - offset); } @Override @@ -183,8 +183,8 @@ public Builder builder(BlockFactory factory, int expectedCount) { public ColumnAtATimeReader columnAtATimeReader(LeafReaderContext context) { return new ColumnAtATimeReader() { @Override - public Block read(BlockFactory factory, Docs docs) { - return factory.constantBytes(value); + public Block read(BlockFactory factory, Docs docs, int offset) { + return factory.constantBytes(value, docs.count() - offset); } @Override @@ -261,8 +261,8 @@ public ColumnAtATimeReader columnAtATimeReader(LeafReaderContext context) throws } return new ColumnAtATimeReader() { @Override - public Block read(BlockFactory factory, Docs docs) throws IOException { - return reader.read(factory, docs); + public Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + return reader.read(factory, docs, offset); } @Override @@ -408,13 +408,13 @@ interface BlockFactory { /** * Build a block that contains only {@code null}. */ - Block constantNulls(); + Block constantNulls(int count); /** * Build a block that contains {@code value} repeated * {@code size} times. 
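One more aside before the remaining reader changes: every `read` implementation in this diff follows the same adjusted contract, shown in isolation below. This is a restatement of the `LongBuilder` variant from the diff, not a new API: the builder is sized to `docs.count() - offset` and iteration starts at `offset`, so a reader that ignored `offset` would emit oversized blocks. The positions before `offset` are presumably covered by an earlier block when a large page is split on load (see changelog entry 131053).

```java
@Override
public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs, int offset) throws IOException {
    // Size for the remaining docs only; positions [0, offset) belong to a prior block.
    try (BlockLoader.LongBuilder builder = factory.longs(docs.count() - offset)) {
        for (int i = offset; i < docs.count(); i++) {
            read(docs.get(i), builder); // per-doc append, as in the script readers above
        }
        return builder.build();
    }
}
```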
*/ - Block constantBytes(BytesRef value); + Block constantBytes(BytesRef value, int count); /** * Build a reader for reading {@link SortedDocValues} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BooleanScriptBlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BooleanScriptBlockDocValuesReader.java index a3b10ea901395..3a1a805a25b64 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BooleanScriptBlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BooleanScriptBlockDocValuesReader.java @@ -49,10 +49,10 @@ public int docId() { } @Override - public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs) throws IOException { + public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs, int offset) throws IOException { // Note that we don't emit falses before trues so we conform to the doc values contract and can use booleansFromDocValues - try (BlockLoader.BooleanBuilder builder = factory.booleans(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + try (BlockLoader.BooleanBuilder builder = factory.booleans(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { read(docs.get(i), builder); } return builder.build(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/DateScriptBlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/DateScriptBlockDocValuesReader.java index fb97b0f84c50f..0ec899e19a1cd 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/DateScriptBlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/DateScriptBlockDocValuesReader.java @@ -49,10 +49,10 @@ public int docId() { } @Override - public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs) throws IOException { + public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs, int offset) throws IOException { // Note that we don't sort the values sort, so we can't use factory.longsFromDocValues - try (BlockLoader.LongBuilder builder = factory.longs(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + try (BlockLoader.LongBuilder builder = factory.longs(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { read(docs.get(i), builder); } return builder.build(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/DoubleScriptBlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/DoubleScriptBlockDocValuesReader.java index d762acda9f7e4..f01cc65775e6e 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/DoubleScriptBlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/DoubleScriptBlockDocValuesReader.java @@ -49,10 +49,10 @@ public int docId() { } @Override - public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs) throws IOException { + public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs, int offset) throws IOException { // Note that we don't sort the values sort, so we can't use factory.doublesFromDocValues - try (BlockLoader.DoubleBuilder builder = factory.doubles(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + try (BlockLoader.DoubleBuilder builder = factory.doubles(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { read(docs.get(i), builder); } return builder.build(); diff --git 
a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java index a43575b8f990c..847f4740e21fb 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java @@ -1298,7 +1298,7 @@ public static Parameter> metaParam() { "meta", true, Map::of, - (n, c, o) -> TypeParsers.parseMeta(n, o), + (n, c, o) -> TypeParsers.parseMeta(n, o, c), m -> m.fieldType().meta(), XContentBuilder::stringStringMap, Objects::toString diff --git a/server/src/main/java/org/elasticsearch/index/mapper/IpScriptBlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/IpScriptBlockDocValuesReader.java index 48d78129b8781..b232a8e1fc45a 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/IpScriptBlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/IpScriptBlockDocValuesReader.java @@ -49,10 +49,10 @@ public int docId() { } @Override - public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs) throws IOException { + public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs, int offset) throws IOException { // Note that we don't pre-sort our output so we can't use bytesRefsFromDocValues - try (BlockLoader.BytesRefBuilder builder = factory.bytesRefs(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + try (BlockLoader.BytesRefBuilder builder = factory.bytesRefs(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { read(docs.get(i), builder); } return builder.build(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/KeywordScriptBlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/KeywordScriptBlockDocValuesReader.java index cfc7045a55513..220bba3d3c079 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/KeywordScriptBlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/KeywordScriptBlockDocValuesReader.java @@ -51,10 +51,10 @@ public int docId() { } @Override - public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs) throws IOException { + public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs, int offset) throws IOException { // Note that we don't pre-sort our output so we can't use bytesRefsFromDocValues - try (BlockLoader.BytesRefBuilder builder = factory.bytesRefs(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + try (BlockLoader.BytesRefBuilder builder = factory.bytesRefs(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { read(docs.get(i), builder); } return builder.build(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/LongScriptBlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/LongScriptBlockDocValuesReader.java index 0a1a8a86154ab..9c947a17de7b6 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/LongScriptBlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/LongScriptBlockDocValuesReader.java @@ -49,10 +49,10 @@ public int docId() { } @Override - public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs) throws IOException { + public BlockLoader.Block read(BlockLoader.BlockFactory factory, BlockLoader.Docs docs, int offset) throws IOException { // Note that we don't pre-sort our output so we can't 
use longsFromDocValues - try (BlockLoader.LongBuilder builder = factory.longs(docs.count())) { - for (int i = 0; i < docs.count(); i++) { + try (BlockLoader.LongBuilder builder = factory.longs(docs.count() - offset)) { + for (int i = offset; i < docs.count(); i++) { read(docs.get(i), builder); } return builder.build(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperService.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperService.java index 7958fd8e51525..603e61d9ff4ba 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperService.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperService.java @@ -175,6 +175,13 @@ public boolean isAutoUpdate() { Property.IndexScope, Property.IndexSettingDeprecatedInV7AndRemovedInV8 ); + public static final Setting INDEX_MAPPING_META_LENGTH_LIMIT_SETTING = Setting.intSetting( + "index.mapping.meta.length_limit", + 500, + 0, + Property.Dynamic, + Property.IndexScope + ); private final IndexAnalyzers indexAnalyzers; private final MappingParser mappingParser; diff --git a/server/src/main/java/org/elasticsearch/index/mapper/TypeParsers.java b/server/src/main/java/org/elasticsearch/index/mapper/TypeParsers.java index 7be9d658297ca..67bff1f9bd7e4 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/TypeParsers.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/TypeParsers.java @@ -24,6 +24,7 @@ import static org.elasticsearch.common.xcontent.support.XContentMapValues.isArray; import static org.elasticsearch.common.xcontent.support.XContentMapValues.nodeStringValue; +import static org.elasticsearch.index.mapper.MapperService.INDEX_MAPPING_META_LENGTH_LIMIT_SETTING; public class TypeParsers { private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(TypeParsers.class); @@ -31,7 +32,7 @@ public class TypeParsers { /** * Parse the {@code meta} key of the mapping. 
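* <p>
* The value-length check is now driven by the dynamic index setting
* {@code index.mapping.meta.length_limit} (default 500 code points, see
* {@code MapperService#INDEX_MAPPING_META_LENGTH_LIMIT_SETTING}). An illustrative
* sketch of the behavior under the default limit:
* <pre>{@code
* "meta": { "unit": "ms" }                       // accepted
* "meta": { "notes": "...501 code points..." }   // throws MapperParsingException
* }</pre>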
*/ - public static Map parseMeta(String name, Object metaObject) { + public static Map parseMeta(String name, Object metaObject, MappingParserContext parserContext) { if (metaObject instanceof Map == false) { throw new MapperParsingException( "[meta] must be an object, got " + metaObject.getClass().getSimpleName() + "[" + metaObject + "] for field [" + name + "]" @@ -52,11 +53,18 @@ public static Map parseMeta(String name, Object metaObject) { ); } } + int metaValueLengthLimit = INDEX_MAPPING_META_LENGTH_LIMIT_SETTING.get(parserContext.getIndexSettings().getSettings()); for (Object value : meta.values()) { if (value instanceof String sValue) { - if (sValue.codePointCount(0, sValue.length()) > 50) { + if (sValue.codePointCount(0, sValue.length()) > metaValueLengthLimit) { throw new MapperParsingException( - "[meta] values can't be longer than 50 chars, but got [" + value + "] for field [" + name + "]" + "[meta] values can't be longer than " + + metaValueLengthLimit + + " chars, but got [" + + value + + "] for field [" + + name + + "]" ); } } else if (value == null) { diff --git a/server/src/main/java/org/elasticsearch/indices/IndicesService.java b/server/src/main/java/org/elasticsearch/indices/IndicesService.java index 8fdc53e6b795f..528601f201fee 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndicesService.java +++ b/server/src/main/java/org/elasticsearch/indices/IndicesService.java @@ -950,6 +950,7 @@ public synchronized void verifyIndexMetadata(IndexMetadata metadata, IndexMetada @Override public void createShard( + final ProjectId projectId, final ShardRouting shardRouting, final PeerRecoveryTargetService recoveryTargetService, final PeerRecoveryTargetService.RecoveryListener recoveryListener, @@ -968,26 +969,29 @@ public void createShard( RecoveryState recoveryState = indexService.createRecoveryState(shardRouting, targetNode, sourceNode); IndexShard indexShard = indexService.createShard(shardRouting, globalCheckpointSyncer, retentionLeaseSyncer); indexShard.addShardFailureCallback(onShardFailure); - indexShard.startRecovery( - recoveryState, - recoveryTargetService, - postRecoveryMerger.maybeMergeAfterRecovery(indexService.getMetadata(), shardRouting, recoveryListener), - repositoriesService, - (mapping, listener) -> { - assert recoveryState.getRecoverySource().getType() == RecoverySource.Type.LOCAL_SHARDS - : "mapping update consumer only required by local shards recovery"; - AcknowledgedRequest putMappingRequestAcknowledgedRequest = new PutMappingRequest() - // concrete index - no name clash, it uses uuid - .setConcreteIndex(shardRouting.index()) - .source(mapping.source().string(), XContentType.JSON); - client.execute( - TransportAutoPutMappingAction.TYPE, - putMappingRequestAcknowledgedRequest.ackTimeout(TimeValue.MAX_VALUE).masterNodeTimeout(TimeValue.MAX_VALUE), - new RefCountAwareThreadedActionListener<>(threadPool.generic(), listener.map(ignored -> null)) - ); - }, - this, - clusterStateVersion + projectResolver.executeOnProject( + projectId, + () -> indexShard.startRecovery( + recoveryState, + recoveryTargetService, + postRecoveryMerger.maybeMergeAfterRecovery(indexService.getMetadata(), shardRouting, recoveryListener), + repositoriesService, + (mapping, listener) -> { + assert recoveryState.getRecoverySource().getType() == RecoverySource.Type.LOCAL_SHARDS + : "mapping update consumer only required by local shards recovery"; + AcknowledgedRequest putMappingRequestAcknowledgedRequest = new PutMappingRequest() + // concrete index - no name clash, it uses uuid + 
.setConcreteIndex(shardRouting.index()) + .source(mapping.source().string(), XContentType.JSON); + client.execute( + TransportAutoPutMappingAction.TYPE, + putMappingRequestAcknowledgedRequest.ackTimeout(TimeValue.MAX_VALUE).masterNodeTimeout(TimeValue.MAX_VALUE), + new RefCountAwareThreadedActionListener<>(threadPool.generic(), listener.map(ignored -> null)) + ); + }, + this, + clusterStateVersion + ) ); } diff --git a/server/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java b/server/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java index 9862ff9d30338..95c462072ae5a 100644 --- a/server/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java +++ b/server/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java @@ -26,6 +26,7 @@ import org.elasticsearch.cluster.ClusterStateApplier; import org.elasticsearch.cluster.action.shard.ShardStateAction; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; @@ -781,6 +782,7 @@ private void createShardWhenLockAvailable( try { logger.debug("{} creating shard with primary term [{}], iteration [{}]", shardRouting.shardId(), primaryTerm, iteration); indicesService.createShard( + originalState.metadata().projectFor(shardRouting.index()).id(), shardRouting, recoveryTargetService, new RecoveryListener(shardRouting, primaryTerm), @@ -1330,6 +1332,7 @@ void removeIndex( /** * Creates a shard for the specified shard routing and starts recovery. * + * @param projectId the project for the shard * @param shardRouting the shard routing * @param recoveryTargetService recovery service for the target * @param recoveryListener a callback when recovery changes state (finishes or fails) @@ -1343,6 +1346,7 @@ void removeIndex( * @throws IOException if an I/O exception occurs when creating the shard */ void createShard( + ProjectId projectId, ShardRouting shardRouting, PeerRecoveryTargetService recoveryTargetService, PeerRecoveryTargetService.RecoveryListener recoveryListener, diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java index 3274bf571d10a..a6857b82a747f 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java @@ -12,6 +12,7 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.threadpool.ThreadPool; import java.util.List; @@ -23,7 +24,13 @@ public interface InferenceServiceExtension { List getInferenceServiceFactories(); - record InferenceServiceFactoryContext(Client client, ThreadPool threadPool, ClusterService clusterService, Settings settings) {} + record InferenceServiceFactoryContext( + Client client, + ThreadPool threadPool, + ClusterService clusterService, + Settings settings, + InferenceStats inferenceStats + ) {} interface Factory { /** diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java 
b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index db31aafc8c190..b6f724e69d40f 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -121,7 +121,7 @@ public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Para * - Key: {@link #MODEL_FIELD}, Value: modelId * - Key: {@link #MAX_COMPLETION_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()} */ - public static Params withMaxCompletionTokensTokens(String modelId, Params params) { + public static Params withMaxCompletionTokens(String modelId, Params params) { return new DelegatingMapParams( Map.ofEntries(Map.entry(MODEL_ID_PARAM, modelId), Map.entry(MAX_TOKENS_PARAM, MAX_COMPLETION_TOKENS_FIELD)), params diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java b/server/src/main/java/org/elasticsearch/inference/telemetry/InferenceStats.java similarity index 65% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java rename to server/src/main/java/org/elasticsearch/inference/telemetry/InferenceStats.java index 17c91b81233fb..e73b1ad9c5ff6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java +++ b/server/src/main/java/org/elasticsearch/inference/telemetry/InferenceStats.java @@ -1,11 +1,13 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
*/ -package org.elasticsearch.xpack.inference.telemetry; +package org.elasticsearch.inference.telemetry; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.core.Nullable; @@ -14,17 +16,17 @@ import org.elasticsearch.telemetry.metric.LongCounter; import org.elasticsearch.telemetry.metric.LongHistogram; import org.elasticsearch.telemetry.metric.MeterRegistry; -import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; import java.util.HashMap; import java.util.Map; import java.util.Objects; -public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) { +public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration, LongHistogram deploymentDuration) { public InferenceStats { Objects.requireNonNull(requestCount); Objects.requireNonNull(inferenceDuration); + Objects.requireNonNull(deploymentDuration); } public static InferenceStats create(MeterRegistry meterRegistry) { @@ -38,6 +40,11 @@ public static InferenceStats create(MeterRegistry meterRegistry) { "es.inference.requests.time", "Inference API request counts for a particular service, task type, model ID", "ms" + ), + meterRegistry.registerLongHistogram( + "es.inference.trained_model.deployment.time", + "Inference API time spent waiting for Trained Model Deployments", + "ms" ) ); } @@ -54,8 +61,8 @@ public static Map modelAttributes(Model model) { return modelAttributesMap; } - public static Map routingAttributes(BaseInferenceActionRequest request, String nodeIdHandlingRequest) { - return Map.of("rerouted", request.hasBeenRerouted(), "node_id", nodeIdHandlingRequest); + public static Map routingAttributes(boolean hasBeenRerouted, String nodeIdHandlingRequest) { + return Map.of("rerouted", hasBeenRerouted, "node_id", nodeIdHandlingRequest); } public static Map modelAttributes(UnparsedModel model) { @@ -73,4 +80,11 @@ public static Map responseAttributes(@Nullable Throwable throwab return Map.of("error.type", throwable.getClass().getSimpleName()); } + + public static Map modelAndResponseAttributes(Model model, @Nullable Throwable throwable) { + var metricAttributes = new HashMap(); + metricAttributes.putAll(modelAttributes(model)); + metricAttributes.putAll(responseAttributes(throwable)); + return metricAttributes; + } } diff --git a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksClusterService.java b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksClusterService.java index ceeb1a4e27f1b..f3a25caf79bb6 100644 --- a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksClusterService.java +++ b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksClusterService.java @@ -171,9 +171,9 @@ public ClusterState execute(ClusterState currentState) { assert (projectId == null && taskExecutor.scope() == PersistentTasksExecutor.Scope.CLUSTER) || (projectId != null && taskExecutor.scope() == PersistentTasksExecutor.Scope.PROJECT) : "inconsistent project-id [" + projectId + "] and task scope [" + taskExecutor.scope() + "]"; - taskExecutor.validate(taskParams, currentState); + taskExecutor.validate(taskParams, currentState, projectId); - Assignment assignment = createAssignment(taskName, taskParams, currentState); + Assignment assignment = createAssignment(taskName, taskParams, currentState, projectId); logger.debug("creating {} persistent task [{}] with assignment [{}]", taskTypeString(projectId), taskName, assignment); return builder.addTask(taskId, taskName, taskParams, 
assignment).buildAndUpdate(currentState, projectId); } @@ -449,7 +449,8 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState) private Assignment createAssignment( final String taskName, final Params taskParams, - final ClusterState currentState + final ClusterState currentState, + @Nullable final ProjectId projectId ) { PersistentTasksExecutor persistentTasksExecutor = registry.getPersistentTaskExecutorSafe(taskName); @@ -468,7 +469,7 @@ private Assignment createAssignment( // Task assignment should not rely on node order Randomness.shuffle(candidateNodes); - final Assignment assignment = persistentTasksExecutor.getAssignment(taskParams, candidateNodes, currentState); + final Assignment assignment = persistentTasksExecutor.getAssignment(taskParams, candidateNodes, currentState, projectId); assert assignment != null : "getAssignment() should always return an Assignment object, containing a node or a reason why not"; assert (assignment.getExecutorNode() == null || currentState.metadata().nodeShutdowns().contains(assignment.getExecutorNode()) == false) @@ -540,8 +541,8 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState) * persistent tasks changed. */ boolean shouldReassignPersistentTasks(final ClusterChangedEvent event) { - final List allTasks = PersistentTasks.getAllTasks(event.state()).map(Tuple::v2).toList(); - if (allTasks.isEmpty()) { + var projectIdToTasksIterator = PersistentTasks.getAllTasks(event.state()).iterator(); + if (projectIdToTasksIterator.hasNext() == false) { return false; } @@ -553,10 +554,16 @@ boolean shouldReassignPersistentTasks(final ClusterChangedEvent event) { || event.metadataChanged() || masterChanged) { - for (PersistentTasks tasks : allTasks) { - for (PersistentTask task : tasks.tasks()) { + while (projectIdToTasksIterator.hasNext()) { + var projectIdToTasks = projectIdToTasksIterator.next(); + for (PersistentTask task : projectIdToTasks.v2().tasks()) { if (needsReassignment(task.getAssignment(), event.state().nodes())) { - Assignment assignment = createAssignment(task.getTaskName(), task.getParams(), event.state()); + Assignment assignment = createAssignment( + task.getTaskName(), + task.getParams(), + event.state(), + projectIdToTasks.v1() + ); if (Objects.equals(assignment, task.getAssignment()) == false) { return true; } @@ -602,7 +609,7 @@ private ClusterState reassignClusterOrSingleProjectTasks(@Nullable final Project // We need to check if removed nodes were running any of the tasks and reassign them for (PersistentTask task : tasks.tasks()) { if (needsReassignment(task.getAssignment(), nodes)) { - Assignment assignment = createAssignment(task.getTaskName(), task.getParams(), clusterState); + Assignment assignment = createAssignment(task.getTaskName(), task.getParams(), clusterState, projectId); if (Objects.equals(assignment, task.getAssignment()) == false) { logger.trace( "reassigning {} task {} from node {} to node {}", diff --git a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksExecutor.java b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksExecutor.java index b58ef7523bf99..96c0767fe65f8 100644 --- a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksExecutor.java +++ b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksExecutor.java @@ -10,6 +10,7 @@ package org.elasticsearch.persistent; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.ProjectId; import 
org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Tuple; @@ -63,7 +64,31 @@ public Scope scope() { *
<p>
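* Callers now pass the (possibly null) project-id explicitly, as
* {@code PersistentTasksClusterService#createAssignment} does:
* <pre>{@code
* Assignment assignment = persistentTasksExecutor.getAssignment(taskParams, candidateNodes, currentState, projectId);
* }</pre>
* <p>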
* The default implementation returns the least loaded data node from amongst the collection of candidate nodes */ - public Assignment getAssignment(Params params, Collection candidateNodes, ClusterState clusterState) { + public final Assignment getAssignment( + Params params, + Collection candidateNodes, + ClusterState clusterState, + @Nullable ProjectId projectId + ) { + assert (scope() == Scope.PROJECT && projectId != null) || (scope() == Scope.CLUSTER && projectId == null) + : "inconsistent project-id [" + projectId + "] and task scope [" + scope() + "]"; + return doGetAssignment(params, candidateNodes, clusterState, projectId); + } + + /** + * Returns the node id where the params has to be executed, + *
<p>
+ * The default implementation returns the least loaded data node from amongst the collection of candidate nodes. + *
<p>
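+ * A hypothetical project-scoped override could delegate to this default, e.g.
+ * (sketch only, not part of this change):
+ * <pre>{@code
+ * @Override
+ * protected Assignment doGetAssignment(Params params, Collection<DiscoveryNode> candidateNodes,
+ *                                      ClusterState clusterState, @Nullable ProjectId projectId) {
+ *     assert projectId != null; // non-null whenever scope() == Scope.PROJECT, see below
+ *     return super.doGetAssignment(params, candidateNodes, clusterState, projectId);
+ * }
+ * }</pre>
+ *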
+ * If {@link #scope()} returns CLUSTER, then {@link ProjectId} will be null. + * If {@link #scope()} returns PROJECT, then {@link ProjectId} will not be null. + */ + protected Assignment doGetAssignment( + Params params, + Collection candidateNodes, + ClusterState clusterState, + @Nullable ProjectId projectId + ) { DiscoveryNode discoveryNode = selectLeastLoadedNode(clusterState, candidateNodes, DiscoveryNode::canContainData); if (discoveryNode == null) { return NO_NODE_FOUND; @@ -105,7 +130,7 @@ protected DiscoveryNode selectLeastLoadedNode( *
<p>
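* A project-scoped executor might begin its override along these lines (an
* illustrative sketch; the assertion mirrors the scope contract of
* {@code getAssignment} above):
* <pre>{@code
* @Override
* public void validate(Params params, ClusterState clusterState, @Nullable ProjectId projectId) {
*     assert projectId != null : "PROJECT scope implies a project-id";
*     // project-specific checks go here
* }
* }</pre>
* <p>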
* Throws an exception if the supplied params cannot be executed on the cluster in the current state. */ - public void validate(Params params, ClusterState clusterState) {} + public void validate(Params params, ClusterState clusterState, @Nullable ProjectId projectId) {} /** * Creates a AllocatedPersistentTask for communicating with task manager diff --git a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java index cab15fffa3fd0..b0408ac3c60cc 100644 --- a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java +++ b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java @@ -433,6 +433,16 @@ private void maybeLogSlowMessage(boolean success) { } }); } catch (RuntimeException ex) { + logger.error( + Strings.format( + "unexpected exception calling sendMessage for transport message [%s] of size [%d] on [%s]", + messageDescription.get(), + messageSize, + channel + ), + ex + ); + assert Thread.currentThread().getName().startsWith("TEST-") : ex; channel.setCloseException(ex); Releasables.closeExpectNoException(() -> listener.onFailure(ex), () -> CloseableChannel.closeChannel(channel)); throw ex; diff --git a/server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java b/server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java index fdb597b47c137..d1816c7fc1687 100644 --- a/server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java +++ b/server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java @@ -247,16 +247,16 @@ public Map groupIndices(Set remoteClusterNames, } public Map groupIndices(IndicesOptions indicesOptions, String[] indices, boolean returnLocalAll) { - return groupIndices(getRemoteClusterNames(), indicesOptions, indices, returnLocalAll); + return groupIndices(getRegisteredRemoteClusterNames(), indicesOptions, indices, returnLocalAll); } public Map groupIndices(IndicesOptions indicesOptions, String[] indices) { - return groupIndices(getRemoteClusterNames(), indicesOptions, indices, true); + return groupIndices(getRegisteredRemoteClusterNames(), indicesOptions, indices, true); } @Override public Set getConfiguredClusters() { - return getRemoteClusterNames(); + return getRegisteredRemoteClusterNames(); } /** @@ -270,7 +270,6 @@ boolean isRemoteClusterRegistered(String clusterName) { * Returns the registered remote cluster names. 
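* <p>
* The returned set is unmodifiable, so callers only ever read from it; typical usage,
* as in {@code getRemoteClusterClient} further down:
* <pre>{@code
* if (remoteClusterService.getRegisteredRemoteClusterNames().contains(clusterAlias) == false) {
*     throw new NoSuchRemoteClusterException(clusterAlias);
* }
* }</pre>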
*/ public Set getRegisteredRemoteClusterNames() { - // remoteClusters is unmodifiable so its key set will be unmodifiable too return remoteClusters.keySet(); } @@ -355,10 +354,6 @@ public RemoteClusterConnection getRemoteClusterConnection(String cluster) { return connection; } - Set getRemoteClusterNames() { - return this.remoteClusters.keySet(); - } - @Override public void listenForUpdates(ClusterSettings clusterSettings) { super.listenForUpdates(clusterSettings); @@ -648,7 +643,7 @@ public RemoteClusterClient getRemoteClusterClient( "this node does not have the " + DiscoveryNodeRole.REMOTE_CLUSTER_CLIENT_ROLE.roleName() + " role" ); } - if (transportService.getRemoteClusterService().getRemoteClusterNames().contains(clusterAlias) == false) { + if (transportService.getRemoteClusterService().getRegisteredRemoteClusterNames().contains(clusterAlias) == false) { throw new NoSuchRemoteClusterException(clusterAlias); } return new RemoteClusterAwareClient(transportService, clusterAlias, responseExecutor, switch (disconnectedStrategy) { diff --git a/server/src/test/java/org/elasticsearch/cluster/ClusterInfoTests.java b/server/src/test/java/org/elasticsearch/cluster/ClusterInfoTests.java index 814aa102ce284..e0e749aaa2360 100644 --- a/server/src/test/java/org/elasticsearch/cluster/ClusterInfoTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/ClusterInfoTests.java @@ -44,10 +44,20 @@ public static ClusterInfo randomClusterInfo() { randomRoutingToDataPath(), randomReservedSpace(), randomNodeHeapUsage(), - randomNodeUsageStatsForThreadPools() + randomNodeUsageStatsForThreadPools(), + randomShardWriteLoad() ); } + private static Map randomShardWriteLoad() { + final int numEntries = randomIntBetween(0, 128); + final Map builder = new HashMap<>(numEntries); + for (int i = 0; i < numEntries; i++) { + builder.put(randomShardId(), randomDouble()); + } + return builder; + } + private static Map randomNodeHeapUsage() { int numEntries = randomIntBetween(0, 128); Map nodeHeapUsage = new HashMap<>(numEntries); diff --git a/server/src/test/java/org/elasticsearch/cluster/DiskUsageTests.java b/server/src/test/java/org/elasticsearch/cluster/DiskUsageTests.java index 80c2395ae9644..3551971f0daa0 100644 --- a/server/src/test/java/org/elasticsearch/cluster/DiskUsageTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/DiskUsageTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.cluster.routing.ShardRoutingHelper; import org.elasticsearch.cluster.routing.UnassignedInfo; import org.elasticsearch.index.Index; +import org.elasticsearch.index.shard.IndexingStats; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardPath; import org.elasticsearch.index.store.StoreStats; @@ -107,6 +108,7 @@ public void testFillShardLevelInfo() { Path test0Path = createTempDir().resolve("indices").resolve(index.getUUID()).resolve("0"); CommonStats commonStats0 = new CommonStats(); commonStats0.store = new StoreStats(100, 101, 0L); + commonStats0.indexing = randomIndexingStats(); ShardRouting test_1 = ShardRouting.newUnassigned( new ShardId(index, 1), false, @@ -119,8 +121,10 @@ public void testFillShardLevelInfo() { Path test1Path = createTempDir().resolve("indices").resolve(index.getUUID()).resolve("1"); CommonStats commonStats1 = new CommonStats(); commonStats1.store = new StoreStats(1000, 1001, 0L); + commonStats1.indexing = randomIndexingStats(); CommonStats commonStats2 = new CommonStats(); commonStats2.store = new StoreStats(1000, 999, 0L); + commonStats2.indexing = 
randomIndexingStats(); ShardStats[] stats = new ShardStats[] { new ShardStats(test_0, new ShardPath(false, test0Path, test0Path, test_0.shardId()), commonStats0, null, null, null, false, 0), new ShardStats(test_1, new ShardPath(false, test1Path, test1Path, test_1.shardId()), commonStats1, null, null, null, false, 0), @@ -135,9 +139,17 @@ public void testFillShardLevelInfo() { 0 ) }; Map shardSizes = new HashMap<>(); + Map shardWriteLoads = new HashMap<>(); Map shardDataSetSizes = new HashMap<>(); Map routingToPath = new HashMap<>(); - InternalClusterInfoService.buildShardLevelInfo(stats, shardSizes, shardDataSetSizes, routingToPath, new HashMap<>()); + InternalClusterInfoService.buildShardLevelInfo( + stats, + shardWriteLoads, + shardSizes, + shardDataSetSizes, + routingToPath, + new HashMap<>() + ); assertThat( shardSizes, @@ -158,6 +170,41 @@ public void testFillShardLevelInfo() { hasEntry(ClusterInfo.NodeAndShard.from(test_1), test1Path.getParent().getParent().getParent().toAbsolutePath().toString()) ) ); + + assertThat( + shardWriteLoads, + equalTo( + Map.of( + test_0.shardId(), + commonStats0.indexing.getTotal().getPeakWriteLoad(), + test_1.shardId(), + Math.max(commonStats1.indexing.getTotal().getPeakWriteLoad(), commonStats2.indexing.getTotal().getPeakWriteLoad()) + ) + ) + ); + } + + private IndexingStats randomIndexingStats() { + return new IndexingStats( + new IndexingStats.Stats( + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomMillisUpToYear9999(), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomBoolean(), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeLong(), + randomDoubleBetween(0d, 10d, true), + randomDoubleBetween(0d, 10d, true) + ) + ); } public void testLeastAndMostAvailableDiskSpace() { diff --git a/server/src/test/java/org/elasticsearch/cluster/project/ProjectStateRegistrySerializationTests.java b/server/src/test/java/org/elasticsearch/cluster/project/ProjectStateRegistrySerializationTests.java index 97b8b86ea6a85..7a474f528897c 100644 --- a/server/src/test/java/org/elasticsearch/cluster/project/ProjectStateRegistrySerializationTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/project/ProjectStateRegistrySerializationTests.java @@ -9,6 +9,8 @@ package org.elasticsearch.cluster.project; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.cluster.ClusterModule; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.Diff; @@ -17,10 +19,13 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.SimpleDiffableWireSerializationTestCase; +import org.elasticsearch.test.TransportVersionUtils; import java.io.IOException; import java.util.stream.IntStream; +import static org.hamcrest.Matchers.equalTo; + public class ProjectStateRegistrySerializationTests extends SimpleDiffableWireSerializationTestCase { @Override @@ -56,7 +61,7 @@ protected ClusterState.Custom mutateInstance(ClusterState.Custom instance) throw private ProjectStateRegistry mutate(ProjectStateRegistry instance) { if (randomBoolean() && instance.size() > 0) { // Remove or mutate a project's settings or deletion flag - var projectId = randomFrom(instance.getProjectsSettings().keySet()); + var projectId = randomFrom(instance.knownProjects()); var builder = 
ProjectStateRegistry.builder(instance); builder.putProjectSettings(projectId, randomSettings()); if (randomBoolean()) { @@ -86,4 +91,11 @@ public static Settings randomSettings() { IntStream.range(0, randomIntBetween(1, 5)).forEach(i -> builder.put(randomIdentifier(), randomIdentifier())); return builder.build(); } + + public void testProjectStateRegistryBwcSerialization() throws IOException { + ProjectStateRegistry projectStateRegistry = randomProjectStateRegistry(); + TransportVersion oldVersion = TransportVersionUtils.getPreviousVersion(TransportVersions.PROJECT_STATE_REGISTRY_ENTRY); + ClusterState.Custom serialized = copyInstance(projectStateRegistry, oldVersion); + assertThat(serialized, equalTo(projectStateRegistry)); + } } diff --git a/server/src/test/java/org/elasticsearch/cluster/project/ProjectStateRegistryTests.java b/server/src/test/java/org/elasticsearch/cluster/project/ProjectStateRegistryTests.java index a4d4dd6f2b154..5fdc73e9a6cf8 100644 --- a/server/src/test/java/org/elasticsearch/cluster/project/ProjectStateRegistryTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/project/ProjectStateRegistryTests.java @@ -9,10 +9,13 @@ package org.elasticsearch.cluster.project; +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.ESTestCase; -import org.hamcrest.Matchers; import static org.elasticsearch.cluster.project.ProjectStateRegistrySerializationTests.randomSettings; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.sameInstance; public class ProjectStateRegistryTests extends ESTestCase { @@ -26,22 +29,63 @@ public void testBuilder() { ); var projectStateRegistry = builder.build(); var gen1 = projectStateRegistry.getProjectsMarkedForDeletionGeneration(); - assertThat(gen1, Matchers.equalTo(projectsUnderDeletion.isEmpty() ? 0L : 1L)); + assertThat(gen1, equalTo(projectsUnderDeletion.isEmpty() ? 
0L : 1L)); projectStateRegistry = ProjectStateRegistry.builder(projectStateRegistry).markProjectForDeletion(randomFrom(projects)).build(); var gen2 = projectStateRegistry.getProjectsMarkedForDeletionGeneration(); - assertThat(gen2, Matchers.equalTo(gen1 + 1)); + assertThat(gen2, equalTo(gen1 + 1)); if (projectsUnderDeletion.isEmpty() == false) { // re-adding the same projectId should not change the generation projectStateRegistry = ProjectStateRegistry.builder(projectStateRegistry) .markProjectForDeletion(randomFrom(projectsUnderDeletion)) .build(); - assertThat(projectStateRegistry.getProjectsMarkedForDeletionGeneration(), Matchers.equalTo(gen2)); + assertThat(projectStateRegistry.getProjectsMarkedForDeletionGeneration(), equalTo(gen2)); } var unknownProjectId = randomUniqueProjectId(); var throwingBuilder = ProjectStateRegistry.builder(projectStateRegistry).markProjectForDeletion(unknownProjectId); assertThrows(IllegalArgumentException.class, throwingBuilder::build); } + + public void testDiff() { + ProjectStateRegistry originalRegistry = ProjectStateRegistry.builder() + .putProjectSettings(randomUniqueProjectId(), randomSettings()) + .putProjectSettings(randomUniqueProjectId(), randomSettings()) + .putProjectSettings(randomUniqueProjectId(), randomSettings()) + .build(); + + ProjectId newProjectId = randomUniqueProjectId(); + Settings newSettings = randomSettings(); + ProjectId projectToMarkForDeletion = randomFrom(originalRegistry.knownProjects()); + ProjectId projectToModifyId = randomFrom(originalRegistry.knownProjects()); + Settings modifiedSettings = randomSettings(); + + ProjectStateRegistry modifiedRegistry = ProjectStateRegistry.builder(originalRegistry) + .putProjectSettings(newProjectId, newSettings) + .markProjectForDeletion(projectToMarkForDeletion) + .putProjectSettings(projectToModifyId, modifiedSettings) + .build(); + + var diff = modifiedRegistry.diff(originalRegistry); + var appliedRegistry = (ProjectStateRegistry) diff.apply(originalRegistry); + + assertThat(appliedRegistry, equalTo(modifiedRegistry)); + assertThat(appliedRegistry.size(), equalTo(originalRegistry.size() + 1)); + assertTrue(appliedRegistry.knownProjects().contains(newProjectId)); + assertTrue(appliedRegistry.isProjectMarkedForDeletion(projectToMarkForDeletion)); + assertThat(appliedRegistry.getProjectSettings(newProjectId), equalTo(newSettings)); + assertThat(appliedRegistry.getProjectSettings(projectToModifyId), equalTo(modifiedSettings)); + } + + public void testDiffNoChanges() { + ProjectStateRegistry originalRegistry = ProjectStateRegistry.builder() + .putProjectSettings(randomUniqueProjectId(), randomSettings()) + .build(); + + var diff = originalRegistry.diff(originalRegistry); + var appliedRegistry = (ProjectStateRegistry) diff.apply(originalRegistry); + + assertThat(appliedRegistry, sameInstance(originalRegistry)); + } } diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconcilerTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconcilerTests.java index 37646d376f8fd..1f8d59a958bfe 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconcilerTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconcilerTests.java @@ -621,6 +621,7 @@ public void testUnassignedAllocationPredictsDiskUsage() { ImmutableOpenMap.of(), ImmutableOpenMap.of(), ImmutableOpenMap.of(), + ImmutableOpenMap.of(), 
ImmutableOpenMap.of() ); diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java index f85c2678e04e7..c4ca84e6e977f 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java @@ -1406,7 +1406,17 @@ static class DevNullClusterInfo extends ClusterInfo { Map shardSizes, Map reservedSpace ) { - super(leastAvailableSpaceUsage, mostAvailableSpaceUsage, shardSizes, Map.of(), Map.of(), reservedSpace, Map.of(), Map.of()); + super( + leastAvailableSpaceUsage, + mostAvailableSpaceUsage, + shardSizes, + Map.of(), + Map.of(), + reservedSpace, + Map.of(), + Map.of(), + Map.of() + ); } @Override diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderUnitTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderUnitTests.java index 117afe0cec877..debb4343931d7 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderUnitTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderUnitTests.java @@ -110,6 +110,7 @@ public void testCanAllocateUsesMaxAvailableSpace() { Map.of(), Map.of(), Map.of(), + Map.of(), Map.of() ); RoutingAllocation allocation = new RoutingAllocation( @@ -183,6 +184,7 @@ private void doTestCannotAllocateDueToLackOfDiskResources(boolean testMaxHeadroo Map.of(), Map.of(), Map.of(), + Map.of(), Map.of() ); RoutingAllocation allocation = new RoutingAllocation( @@ -330,6 +332,7 @@ private void doTestCanRemainUsesLeastAvailableSpace(boolean testMaxHeadroom) { shardRoutingMap, Map.of(), Map.of(), + Map.of(), Map.of() ); RoutingAllocation allocation = new RoutingAllocation( diff --git a/server/src/test/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapperTests.java index 73d76ad48c955..6a81a93923abc 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapperTests.java @@ -125,7 +125,7 @@ private static void testBoundsBlockLoaderAux( for (int j : array) { expected.add(visitor.apply(geometries.get(j + currentIndex)).get()); } - try (var block = (TestBlock) loader.reader(leaf).read(TestBlock.factory(leafReader.numDocs()), TestBlock.docs(array))) { + try (var block = (TestBlock) loader.reader(leaf).read(TestBlock.factory(), TestBlock.docs(array), 0)) { for (int i = 0; i < block.size(); i++) { intArrayResults.add(block.get(i)); } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/BlockSourceReaderTests.java b/server/src/test/java/org/elasticsearch/index/mapper/BlockSourceReaderTests.java index 357ada3ad656d..1fa9c85a5c738 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/BlockSourceReaderTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/BlockSourceReaderTests.java @@ -59,7 +59,7 @@ private void loadBlock(LeafReaderContext ctx, Consumer test) throws I StoredFieldLoader.fromSpec(loader.rowStrideStoredFieldSpec()).getLoader(ctx, null), 
loader.rowStrideStoredFieldSpec().requiresSource() ? SourceLoader.FROM_STORED_SOURCE.leaf(ctx.reader(), null) : null ); - BlockLoader.Builder builder = loader.builder(TestBlock.factory(ctx.reader().numDocs()), 1); + BlockLoader.Builder builder = loader.builder(TestBlock.factory(), 1); storedFields.advanceTo(0); reader.read(0, storedFields, builder); TestBlock block = (TestBlock) builder.build(); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/BooleanScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/BooleanScriptFieldTypeTests.java index ce9a9bc0688f3..54656ab1af3ee 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/BooleanScriptFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/BooleanScriptFieldTypeTests.java @@ -446,7 +446,8 @@ public void testBlockLoader() throws IOException { try (DirectoryReader reader = iw.getReader()) { BooleanScriptFieldType fieldType = build("xor_param", Map.of("param", false), OnScriptError.FAIL); List expected = List.of(false, true); - assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType), equalTo(expected)); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 0), equalTo(expected)); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 1), equalTo(expected.subList(1, 2))); assertThat(blockLoaderReadValuesFromRowStrideReader(reader, fieldType), equalTo(expected)); } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/DateScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/DateScriptFieldTypeTests.java index 3d8ed5ea60262..1eb0ba07d58e2 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/DateScriptFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/DateScriptFieldTypeTests.java @@ -493,9 +493,10 @@ public void testBlockLoader() throws IOException { try (DirectoryReader reader = iw.getReader()) { DateScriptFieldType fieldType = build("add_days", Map.of("days", 1), OnScriptError.FAIL); assertThat( - blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType), + blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 0), equalTo(List.of(1595518581354L, 1595518581355L)) ); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 1), equalTo(List.of(1595518581355L))); assertThat(blockLoaderReadValuesFromRowStrideReader(reader, fieldType), equalTo(List.of(1595518581354L, 1595518581355L))); } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/DoubleScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/DoubleScriptFieldTypeTests.java index 140137015d98a..b1cda53876993 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/DoubleScriptFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/DoubleScriptFieldTypeTests.java @@ -262,7 +262,8 @@ public void testBlockLoader() throws IOException { ); try (DirectoryReader reader = iw.getReader()) { DoubleScriptFieldType fieldType = build("add_param", Map.of("param", 1), OnScriptError.FAIL); - assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType), equalTo(List.of(2d, 3d))); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 0), equalTo(List.of(2d, 3d))); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 1), equalTo(List.of(3d))); assertThat(blockLoaderReadValuesFromRowStrideReader(reader, fieldType), equalTo(List.of(2d, 
3d))); } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/IpScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/IpScriptFieldTypeTests.java index 281d2993fa29c..7e9a236f6cc74 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/IpScriptFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/IpScriptFieldTypeTests.java @@ -273,7 +273,8 @@ public void testBlockLoader() throws IOException { new BytesRef(InetAddressPoint.encode(InetAddresses.forString("192.168.0.1"))), new BytesRef(InetAddressPoint.encode(InetAddresses.forString("192.168.1.1"))) ); - assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType), equalTo(expected)); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 0), equalTo(expected)); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 1), equalTo(expected.subList(1, 2))); assertThat(blockLoaderReadValuesFromRowStrideReader(reader, fieldType), equalTo(expected)); } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/KeywordScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/KeywordScriptFieldTypeTests.java index 57d52991a6442..ccc8ccac4deb4 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/KeywordScriptFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/KeywordScriptFieldTypeTests.java @@ -409,9 +409,10 @@ public void testBlockLoader() throws IOException { try (DirectoryReader reader = iw.getReader()) { KeywordScriptFieldType fieldType = build("append_param", Map.of("param", "-Suffix"), OnScriptError.FAIL); assertThat( - blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType), + blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 0), equalTo(List.of(new BytesRef("1-Suffix"), new BytesRef("2-Suffix"))) ); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 1), equalTo(List.of(new BytesRef("2-Suffix")))); assertThat( blockLoaderReadValuesFromRowStrideReader(reader, fieldType), equalTo(List.of(new BytesRef("1-Suffix"), new BytesRef("2-Suffix"))) diff --git a/server/src/test/java/org/elasticsearch/index/mapper/LongScriptFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/LongScriptFieldTypeTests.java index a8cb4d51c5efa..01f96a1a4b1be 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/LongScriptFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/LongScriptFieldTypeTests.java @@ -295,7 +295,8 @@ public void testBlockLoader() throws IOException { ); try (DirectoryReader reader = iw.getReader()) { LongScriptFieldType fieldType = build("add_param", Map.of("param", 1), OnScriptError.FAIL); - assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType), equalTo(List.of(2L, 3L))); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 0), equalTo(List.of(2L, 3L))); + assertThat(blockLoaderReadValuesFromColumnAtATimeReader(reader, fieldType, 1), equalTo(List.of(3L))); assertThat(blockLoaderReadValuesFromRowStrideReader(reader, fieldType), equalTo(List.of(2L, 3L))); } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/TypeParsersTests.java b/server/src/test/java/org/elasticsearch/index/mapper/TypeParsersTests.java index 9a30e7d696b68..25f22e131414c 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/TypeParsersTests.java +++ 
b/server/src/test/java/org/elasticsearch/index/mapper/TypeParsersTests.java @@ -38,6 +38,7 @@ import static org.elasticsearch.index.analysis.AnalysisRegistry.DEFAULT_ANALYZER_NAME; import static org.elasticsearch.index.analysis.AnalysisRegistry.DEFAULT_SEARCH_ANALYZER_NAME; import static org.elasticsearch.index.analysis.AnalysisRegistry.DEFAULT_SEARCH_QUOTED_ANALYZER_NAME; +import static org.elasticsearch.index.mapper.MapperService.INDEX_MAPPING_META_LENGTH_LIMIT_SETTING; import static org.hamcrest.core.IsEqual.equalTo; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -52,48 +53,32 @@ private static Map defaultAnalyzers() { return analyzers; } - public void testMultiFieldWithinMultiField() throws IOException { - - XContentBuilder mapping = XContentFactory.jsonBuilder() - .startObject() - .field("type", "keyword") - .startObject("fields") - .startObject("sub-field") - .field("type", "keyword") - .startObject("fields") - .startObject("sub-sub-field") - .field("type", "keyword") - .endObject() - .endObject() - .endObject() - .endObject() - .endObject(); + private Settings buildSettings() { + return Settings.builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .build(); + } + private MappingParserContext createParserContext(Settings settings) { Mapper.TypeParser typeParser = KeywordFieldMapper.PARSER; MapperService mapperService = mock(MapperService.class); IndexAnalyzers indexAnalyzers = IndexAnalyzers.of(defaultAnalyzers()); when(mapperService.getIndexAnalyzers()).thenReturn(indexAnalyzers); - Settings settings = Settings.builder() - .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()) - .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1) - .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) - .build(); IndexMetadata metadata = IndexMetadata.builder("test").settings(settings).build(); IndexSettings indexSettings = new IndexSettings(metadata, Settings.EMPTY); when(mapperService.getIndexSettings()).thenReturn(indexSettings); - // For indices created in 8.0 or later, we should throw an error. - Map fieldNodeCopy = XContentHelper.convertToMap(BytesReference.bytes(mapping), true, mapping.contentType()).v2(); - IndexVersion version = IndexVersionUtils.randomVersionBetween(random(), IndexVersions.V_8_0_0, IndexVersion.current()); TransportVersion transportVersion = TransportVersionUtils.randomVersionBetween( random(), TransportVersions.V_8_0_0, TransportVersion.current() ); - MappingParserContext context = new MappingParserContext( + return new MappingParserContext( null, type -> typeParser, type -> null, @@ -108,6 +93,29 @@ public void testMultiFieldWithinMultiField() throws IOException { throw new UnsupportedOperationException(); } ); + } + + public void testMultiFieldWithinMultiField() throws IOException { + + XContentBuilder mapping = XContentFactory.jsonBuilder() + .startObject() + .field("type", "keyword") + .startObject("fields") + .startObject("sub-field") + .field("type", "keyword") + .startObject("fields") + .startObject("sub-sub-field") + .field("type", "keyword") + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + // For indices created in 8.0 or later, we should throw an error. 
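+ // Convert the XContent mapping into the raw Map form that the type parser consumes.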
+ Map fieldNodeCopy = XContentHelper.convertToMap(BytesReference.bytes(mapping), true, mapping.contentType()).v2(); + + MappingParserContext context = createParserContext(buildSettings()); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> { TextFieldMapper.PARSER.parse("textField", fieldNodeCopy, context); @@ -122,49 +130,80 @@ public void testMultiFieldWithinMultiField() throws IOException { } public void testParseMeta() { + MappingParserContext parserContext = createParserContext(buildSettings()); + { - MapperParsingException e = expectThrows(MapperParsingException.class, () -> TypeParsers.parseMeta("foo", 3)); + MapperParsingException e = expectThrows(MapperParsingException.class, () -> TypeParsers.parseMeta("foo", 3, parserContext)); assertEquals("[meta] must be an object, got Integer[3] for field [foo]", e.getMessage()); } { MapperParsingException e = expectThrows( MapperParsingException.class, - () -> TypeParsers.parseMeta("foo", Map.of("veryloooooooooooongkey", 3L)) + () -> TypeParsers.parseMeta("foo", Map.of("veryloooooooooooongkey", 3L), parserContext) ); assertEquals("[meta] keys can't be longer than 20 chars, but got [veryloooooooooooongkey] for field [foo]", e.getMessage()); } { Map mapping = Map.of("foo1", 3L, "foo2", 4L, "foo3", 5L, "foo4", 6L, "foo5", 7L, "foo6", 8L); - MapperParsingException e = expectThrows(MapperParsingException.class, () -> TypeParsers.parseMeta("foo", mapping)); + MapperParsingException e = expectThrows( + MapperParsingException.class, + () -> TypeParsers.parseMeta("foo", mapping, parserContext) + ); assertEquals("[meta] can't have more than 5 entries, but got 6 on field [foo]", e.getMessage()); } { Map mapping = Map.of("foo", Map.of("bar", "baz")); - MapperParsingException e = expectThrows(MapperParsingException.class, () -> TypeParsers.parseMeta("foo", mapping)); + MapperParsingException e = expectThrows( + MapperParsingException.class, + () -> TypeParsers.parseMeta("foo", mapping, parserContext) + ); assertEquals("[meta] values can only be strings, but got Map1[{bar=baz}] for field [foo]", e.getMessage()); } { Map mapping = Map.of("bar", "baz", "foo", 3); - MapperParsingException e = expectThrows(MapperParsingException.class, () -> TypeParsers.parseMeta("foo", mapping)); + MapperParsingException e = expectThrows( + MapperParsingException.class, + () -> TypeParsers.parseMeta("foo", mapping, parserContext) + ); assertEquals("[meta] values can only be strings, but got Integer[3] for field [foo]", e.getMessage()); } { Map meta = new HashMap<>(); meta.put("foo", null); - MapperParsingException e = expectThrows(MapperParsingException.class, () -> TypeParsers.parseMeta("foo", meta)); + MapperParsingException e = expectThrows(MapperParsingException.class, () -> TypeParsers.parseMeta("foo", meta, parserContext)); assertEquals("[meta] values can't be null (field [foo])", e.getMessage()); } { - String longString = IntStream.range(0, 51).mapToObj(Integer::toString).collect(Collectors.joining()); + String longString = IntStream.range(0, 501).mapToObj(Integer::toString).collect(Collectors.joining()); Map mapping = Map.of("foo", longString); - MapperParsingException e = expectThrows(MapperParsingException.class, () -> TypeParsers.parseMeta("foo", mapping)); - assertThat(e.getMessage(), Matchers.startsWith("[meta] values can't be longer than 50 chars")); + MapperParsingException e = expectThrows( + MapperParsingException.class, + () -> TypeParsers.parseMeta("foo", mapping, parserContext) + ); + assertThat(e.getMessage(), 
Matchers.startsWith("[meta] values can't be longer than 500 chars")); + } + + { + Settings otherSettings = Settings.builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(INDEX_MAPPING_META_LENGTH_LIMIT_SETTING.getKey(), 300) + .build(); + MappingParserContext otherParserContext = createParserContext(otherSettings); + String longString = IntStream.range(0, 301).mapToObj(Integer::toString).collect(Collectors.joining()); + Map mapping = Map.of("foo", longString); + MapperParsingException e = expectThrows( + MapperParsingException.class, + () -> TypeParsers.parseMeta("foo", mapping, otherParserContext) + ); + assertThat(e.getMessage(), Matchers.startsWith("[meta] values can't be longer than 300 chars")); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceStatsTests.java b/server/src/test/java/org/elasticsearch/inference/telemetry/InferenceStatsTests.java similarity index 87% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceStatsTests.java rename to server/src/test/java/org/elasticsearch/inference/telemetry/InferenceStatsTests.java index f3800f91d9a54..0d71165823e89 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceStatsTests.java +++ b/server/src/test/java/org/elasticsearch/inference/telemetry/InferenceStatsTests.java @@ -1,11 +1,13 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
*/ -package org.elasticsearch.xpack.inference.telemetry; +package org.elasticsearch.inference.telemetry; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.inference.Model; @@ -22,9 +24,9 @@ import java.util.HashMap; import java.util.Map; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.create; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; +import static org.elasticsearch.inference.telemetry.InferenceStats.create; +import static org.elasticsearch.inference.telemetry.InferenceStats.modelAttributes; +import static org.elasticsearch.inference.telemetry.InferenceStats.responseAttributes; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; import static org.mockito.ArgumentMatchers.assertArg; @@ -35,9 +37,13 @@ public class InferenceStatsTests extends ESTestCase { + public static InferenceStats mockInferenceStats() { + return new InferenceStats(mock(), mock(), mock()); + } + public void testRecordWithModel() { var longCounter = mock(LongCounter.class); - var stats = new InferenceStats(longCounter, mock()); + var stats = new InferenceStats(longCounter, mock(), mock()); stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, "modelId"))); @@ -49,7 +55,7 @@ public void testRecordWithModel() { public void testRecordWithoutModel() { var longCounter = mock(LongCounter.class); - var stats = new InferenceStats(longCounter, mock()); + var stats = new InferenceStats(longCounter, mock(), mock()); stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, null))); @@ -63,7 +69,7 @@ public void testCreation() { public void testRecordDurationWithoutError() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); Map metricAttributes = new HashMap<>(); metricAttributes.putAll(modelAttributes(model("service", TaskType.ANY, "modelId"))); @@ -88,7 +94,7 @@ public void testRecordDurationWithoutError() { public void testRecordDurationWithElasticsearchStatusException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var statusCode = RestStatus.BAD_REQUEST; var exception = new ElasticsearchStatusException("hello", statusCode); var expectedError = String.valueOf(statusCode.getStatus()); @@ -116,7 +122,7 @@ public void testRecordDurationWithElasticsearchStatusException() { public void testRecordDurationWithOtherException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var exception = new IllegalStateException("ahh"); var expectedError = exception.getClass().getSimpleName(); @@ -138,7 +144,7 @@ public void testRecordDurationWithOtherException() { public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var statusCode = 
RestStatus.BAD_REQUEST; var exception = new ElasticsearchStatusException("hello", statusCode); var expectedError = String.valueOf(statusCode.getStatus()); @@ -163,7 +169,7 @@ public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException() public void testRecordDurationWithUnparsedModelAndOtherException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var exception = new IllegalStateException("ahh"); var expectedError = exception.getClass().getSimpleName(); @@ -187,7 +193,7 @@ public void testRecordDurationWithUnparsedModelAndOtherException() { public void testRecordDurationWithUnknownModelAndElasticsearchStatusException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var statusCode = RestStatus.BAD_REQUEST; var exception = new ElasticsearchStatusException("hello", statusCode); var expectedError = String.valueOf(statusCode.getStatus()); @@ -206,7 +212,7 @@ public void testRecordDurationWithUnknownModelAndElasticsearchStatusException() public void testRecordDurationWithUnknownModelAndOtherException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var exception = new IllegalStateException("ahh"); var expectedError = exception.getClass().getSimpleName(); diff --git a/server/src/test/java/org/elasticsearch/persistent/PersistentTasksClusterServiceTests.java b/server/src/test/java/org/elasticsearch/persistent/PersistentTasksClusterServiceTests.java index b79f2f6517189..c354f7a7d1991 100644 --- a/server/src/test/java/org/elasticsearch/persistent/PersistentTasksClusterServiceTests.java +++ b/server/src/test/java/org/elasticsearch/persistent/PersistentTasksClusterServiceTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.metadata.NodesShutdownMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata; import org.elasticsearch.cluster.node.DiscoveryNode; @@ -1087,7 +1088,12 @@ public Scope scope() { } @Override - public Assignment getAssignment(P params, Collection candidateNodes, ClusterState clusterState) { + protected Assignment doGetAssignment( + P params, + Collection candidateNodes, + ClusterState clusterState, + ProjectId projectId + ) { return fn.apply(params, candidateNodes, clusterState); } diff --git a/server/src/test/java/org/elasticsearch/persistent/TestPersistentTasksPlugin.java b/server/src/test/java/org/elasticsearch/persistent/TestPersistentTasksPlugin.java index e3189de94b1a6..a6e059444e4da 100644 --- a/server/src/test/java/org/elasticsearch/persistent/TestPersistentTasksPlugin.java +++ b/server/src/test/java/org/elasticsearch/persistent/TestPersistentTasksPlugin.java @@ -25,6 +25,7 @@ import org.elasticsearch.client.internal.ElasticsearchClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.ProjectId; import 
org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; @@ -326,12 +327,17 @@ public static void setNonClusterStateCondition(boolean nonClusterStateCondition) } @Override - public Assignment getAssignment(TestParams params, Collection candidateNodes, ClusterState clusterState) { + protected Assignment doGetAssignment( + TestParams params, + Collection candidateNodes, + ClusterState clusterState, + ProjectId projectId + ) { if (nonClusterStateCondition == false) { return new Assignment(null, "non cluster state condition prevents assignment"); } if (params == null || params.getExecutorNodeAttr() == null) { - return super.getAssignment(params, candidateNodes, clusterState); + return super.doGetAssignment(params, candidateNodes, clusterState, projectId); } else { DiscoveryNode executorNode = selectLeastLoadedNode( clusterState, diff --git a/server/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java b/server/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java index b8de02293f734..76e280c987ae1 100644 --- a/server/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java +++ b/server/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java @@ -170,7 +170,7 @@ public void testGroupClusterIndices() throws IOException { assertFalse(service.isRemoteClusterRegistered("foo")); { Map> perClusterIndices = service.groupClusterIndices( - service.getRemoteClusterNames(), + service.getRegisteredRemoteClusterNames(), new String[] { "cluster_1:bar", "cluster_2:foo:bar", @@ -191,7 +191,7 @@ public void testGroupClusterIndices() throws IOException { expectThrows( NoSuchRemoteClusterException.class, () -> service.groupClusterIndices( - service.getRemoteClusterNames(), + service.getRegisteredRemoteClusterNames(), new String[] { "foo:bar", "cluster_1:bar", "cluster_2:foo:bar", "cluster_1:test", "cluster_2:foo*", "foo" } ) ); @@ -199,7 +199,7 @@ public void testGroupClusterIndices() throws IOException { expectThrows( NoSuchRemoteClusterException.class, () -> service.groupClusterIndices( - service.getRemoteClusterNames(), + service.getRegisteredRemoteClusterNames(), new String[] { "cluster_1:bar", "cluster_2:foo:bar", "cluster_1:test", "cluster_2:foo*", "does_not_exist:*" } ) ); @@ -208,7 +208,10 @@ public void testGroupClusterIndices() throws IOException { { String[] indices = shuffledList(List.of("cluster*:foo*", "foo", "-cluster_1:*", "*:boo")).toArray(new String[0]); - Map> perClusterIndices = service.groupClusterIndices(service.getRemoteClusterNames(), indices); + Map> perClusterIndices = service.groupClusterIndices( + service.getRegisteredRemoteClusterNames(), + indices + ); assertEquals(2, perClusterIndices.size()); List localIndexes = perClusterIndices.get(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY); assertNotNull(localIndexes); @@ -223,7 +226,10 @@ public void testGroupClusterIndices() throws IOException { { String[] indices = shuffledList(List.of("*:*", "-clu*_1:*", "*:boo")).toArray(new String[0]); - Map> perClusterIndices = service.groupClusterIndices(service.getRemoteClusterNames(), indices); + Map> perClusterIndices = service.groupClusterIndices( + service.getRegisteredRemoteClusterNames(), + indices + ); assertEquals(1, perClusterIndices.size()); List cluster2 = perClusterIndices.get("cluster_2"); @@ -236,7 +242,10 @@ public void testGroupClusterIndices() throws IOException { new String[0] ); - Map> perClusterIndices = 
service.groupClusterIndices(service.getRemoteClusterNames(), indices); + Map> perClusterIndices = service.groupClusterIndices( + service.getRegisteredRemoteClusterNames(), + indices + ); assertEquals(1, perClusterIndices.size()); List localIndexes = perClusterIndices.get(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY); assertNotNull(localIndexes); @@ -246,7 +255,10 @@ public void testGroupClusterIndices() throws IOException { { String[] indices = shuffledList(List.of("cluster*:*", "foo", "-*:*")).toArray(new String[0]); - Map> perClusterIndices = service.groupClusterIndices(service.getRemoteClusterNames(), indices); + Map> perClusterIndices = service.groupClusterIndices( + service.getRegisteredRemoteClusterNames(), + indices + ); assertEquals(1, perClusterIndices.size()); List localIndexes = perClusterIndices.get(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY); assertNotNull(localIndexes); @@ -257,7 +269,7 @@ public void testGroupClusterIndices() throws IOException { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, () -> service.groupClusterIndices( - service.getRemoteClusterNames(), + service.getRegisteredRemoteClusterNames(), // -cluster_1:foo* is not allowed, only -cluster_1:* new String[] { "cluster_1:bar", "-cluster_2:foo*", "cluster_1:test", "cluster_2:foo*", "foo" } ) @@ -271,7 +283,7 @@ public void testGroupClusterIndices() throws IOException { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, () -> service.groupClusterIndices( - service.getRemoteClusterNames(), + service.getRegisteredRemoteClusterNames(), // -cluster_1:* will fail since cluster_1 was never included in order to qualify to be excluded new String[] { "-cluster_1:*", "cluster_2:foo*", "foo" } ) @@ -287,7 +299,7 @@ public void testGroupClusterIndices() throws IOException { { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> service.groupClusterIndices(service.getRemoteClusterNames(), new String[] { "-cluster_1:*" }) + () -> service.groupClusterIndices(service.getRegisteredRemoteClusterNames(), new String[] { "-cluster_1:*" }) ); assertThat( e.getMessage(), @@ -300,7 +312,7 @@ public void testGroupClusterIndices() throws IOException { { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> service.groupClusterIndices(service.getRemoteClusterNames(), new String[] { "-*:*" }) + () -> service.groupClusterIndices(service.getRegisteredRemoteClusterNames(), new String[] { "-*:*" }) ); assertThat( e.getMessage(), @@ -315,7 +327,7 @@ public void testGroupClusterIndices() throws IOException { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> service.groupClusterIndices(service.getRemoteClusterNames(), indices) + () -> service.groupClusterIndices(service.getRegisteredRemoteClusterNames(), indices) ); assertThat( e.getMessage(), @@ -394,7 +406,7 @@ public void testGroupIndices() throws IOException { expectThrows( NoSuchRemoteClusterException.class, () -> service.groupClusterIndices( - service.getRemoteClusterNames(), + service.getRegisteredRemoteClusterNames(), new String[] { "foo:bar", "cluster_1:bar", "cluster_2:foo:bar", "cluster_1:test", "cluster_2:foo*", "foo" } ) ); diff --git a/test/external-modules/error-query/src/javaRestTest/java/org/elasticsearch/test/esql/EsqlPartialResultsIT.java b/test/external-modules/error-query/src/javaRestTest/java/org/elasticsearch/test/esql/EsqlPartialResultsIT.java index 5c3a8b2c50e15..5f85dc8f3bec1 100644 --- 
a/test/external-modules/error-query/src/javaRestTest/java/org/elasticsearch/test/esql/EsqlPartialResultsIT.java +++ b/test/external-modules/error-query/src/javaRestTest/java/org/elasticsearch/test/esql/EsqlPartialResultsIT.java @@ -214,10 +214,6 @@ public void testFailureFromRemote() throws Exception { } public void testAllShardsFailed() throws Exception { - assumeTrue( - "fail functionality is not enabled", - clusterHasCapability("POST", "/_query", List.of(), List.of("fail_if_all_shards_fail")).orElse(false) - ); setupRemoteClusters(); populateIndices(); try { @@ -236,26 +232,6 @@ public void testAllShardsFailed() throws Exception { } } - public void testAllShardsFailedOldBehavior() throws Exception { - // TODO: drop this once we no longer support the old behavior - assumeFalse( - "fail functionality is enabled", - clusterHasCapability("POST", "/_query", List.of(), List.of("fail_if_all_shards_fail")).orElse(false) - ); - setupRemoteClusters(); - populateIndices(); - try { - Request request = new Request("POST", "/_query"); - request.setJsonEntity("{\"query\": \"FROM " + "*:failing*" + " | LIMIT 100\"}"); - request.addParameter("allow_partial_results", "true"); - Response resp = client().performRequest(request); - Map results = entityAsMap(resp); - assertThat(results.get("is_partial"), equalTo(true)); - } finally { - removeRemoteCluster(); - } - } - private void setupRemoteClusters() throws IOException { String settings = String.format(Locale.ROOT, """ { diff --git a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/Clusters.java b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/Clusters.java index 1c237404a78cc..bc5e1f123fe81 100644 --- a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/Clusters.java +++ b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/Clusters.java @@ -22,6 +22,7 @@ static ElasticsearchCluster buildCluster() { .setting("xpack.security.enabled", "false") .setting("xpack.license.self_generated.type", "trial") .setting("esql.query.allow_partial_results", "false") + .setting("logger.org.elasticsearch.compute.lucene.read", "DEBUG") .jvmArg("-Xmx512m"); String javaVersion = JvmInfo.jvmInfo().version(); if (javaVersion.equals("20") || javaVersion.equals("21")) { diff --git a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java index 3912a63ef1514..893acbd22cc23 100644 --- a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java +++ b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java @@ -570,7 +570,7 @@ protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOE } public void testFetchManyBigFields() throws IOException { - initManyBigFieldsIndex(100); + initManyBigFieldsIndex(100, "keyword"); Map response = fetchManyBigFields(100); ListMatcher columns = matchesList(); for (int f = 0; f < 1000; f++) { @@ -580,7 +580,7 @@ public void testFetchManyBigFields() throws IOException { } public void testFetchTooManyBigFields() throws IOException { - initManyBigFieldsIndex(500); + initManyBigFieldsIndex(500, "keyword"); // 500 docs is 
plenty to circuit break on most nodes assertCircuitBreaks(attempt -> fetchManyBigFields(attempt * 500)); } @@ -594,6 +594,58 @@ private Map fetchManyBigFields(int docs) throws IOException { return responseAsMap(query(query.toString(), "columns")); } + public void testAggManyBigTextFields() throws IOException { + int docs = 100; + int fields = 100; + initManyBigFieldsIndex(docs, "text"); + Map response = aggManyBigFields(fields); + ListMatcher columns = matchesList().item(matchesMap().entry("name", "sum").entry("type", "long")); + assertMap( + response, + matchesMap().entry("columns", columns).entry("values", matchesList().item(matchesList().item(1024 * fields * docs))) + ); + } + + /** + * Aggregates documents containing many fields that are {@code 1kb} each. + */ + private Map aggManyBigFields(int fields) throws IOException { + StringBuilder query = startQuery(); + query.append("FROM manybigfields | STATS sum = SUM("); + query.append("LENGTH(f").append(String.format(Locale.ROOT, "%03d", 0)).append(")"); + for (int f = 1; f < fields; f++) { + query.append(" + LENGTH(f").append(String.format(Locale.ROOT, "%03d", f)).append(")"); + } + query.append(")\"}"); + return responseAsMap(query(query.toString(), "columns,values")); + } + + /** + * Aggregates on the {@code LENGTH} of a giant text field. Without + * splitting pages on load (#131053), this throws a {@link CircuitBreakingException} + * when it tries to load a giant field. With that change, it finishes + * after loading many single-row pages. + */ + public void testAggGiantTextField() throws IOException { + int docs = 100; + initGiantTextField(docs); + Map response = aggGiantTextField(); + ListMatcher columns = matchesList().item(matchesMap().entry("name", "sum").entry("type", "long")); + assertMap( + response, + matchesMap().entry("columns", columns).entry("values", matchesList().item(matchesList().item(1024 * 1024 * 5 * docs))) + ); + } + + /** + * Aggregates documents that each contain a {@code 5mb} text field. 
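+ * A sketch of the request body this helper builds, assuming {@code startQuery()} + * emits the opening {@code {"query": "} fragment (implied by the trailing {@code "}} append): + * {@code {"query": "FROM bigtext | STATS sum = SUM(LENGTH(f))"}}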
+ */ + private Map aggGiantTextField() throws IOException { + StringBuilder query = startQuery(); + query.append("FROM bigtext | STATS sum = SUM(LENGTH(f))\"}"); + return responseAsMap(query(query.toString(), "columns,values")); + } + public void testAggMvLongs() throws IOException { int fieldValues = 100; initMvLongsIndex(1, 3, fieldValues); @@ -788,7 +840,7 @@ private void initSingleDocIndex() throws IOException { """); } - private void initManyBigFieldsIndex(int docs) throws IOException { + private void initManyBigFieldsIndex(int docs, String type) throws IOException { logger.info("loading many documents with many big fields"); int docsPerBulk = 5; int fields = 1000; @@ -799,7 +851,7 @@ private void initManyBigFieldsIndex(int docs) throws IOException { config.startObject("settings").field("index.mapping.total_fields.limit", 10000).endObject(); config.startObject("mappings").startObject("properties"); for (int f = 0; f < fields; f++) { - config.startObject("f" + String.format(Locale.ROOT, "%03d", f)).field("type", "keyword").endObject(); + config.startObject("f" + String.format(Locale.ROOT, "%03d", f)).field("type", type).endObject(); } config.endObject().endObject(); request.setJsonEntity(Strings.toString(config.endObject())); @@ -831,6 +883,37 @@ private void initManyBigFieldsIndex(int docs) throws IOException { initIndex("manybigfields", bulk.toString()); } + private void initGiantTextField(int docs) throws IOException { + logger.info("loading many documents with one big text field"); + int docsPerBulk = 3; + int fieldSize = Math.toIntExact(ByteSizeValue.ofMb(5).getBytes()); + + Request request = new Request("PUT", "/bigtext"); + XContentBuilder config = JsonXContent.contentBuilder().startObject(); + config.startObject("mappings").startObject("properties"); + config.startObject("f").field("type", "text").endObject(); + config.endObject().endObject(); + request.setJsonEntity(Strings.toString(config.endObject())); + Response response = client().performRequest(request); + assertThat( + EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8), + equalTo("{\"acknowledged\":true,\"shards_acknowledged\":true,\"index\":\"bigtext\"}") + ); + + StringBuilder bulk = new StringBuilder(); + for (int d = 0; d < docs; d++) { + bulk.append("{\"create\":{}}\n"); + bulk.append("{\"f\":\""); + bulk.append(Integer.toString(d % 10).repeat(fieldSize)); + bulk.append("\"}\n"); + if (d % docsPerBulk == docsPerBulk - 1 && d != docs - 1) { + bulk("bigtext", bulk.toString()); + bulk.setLength(0); + } + } + initIndex("bigtext", bulk.toString()); + } + private void initMvLongsIndex(int docs, int fields, int fieldValues) throws IOException { logger.info("loading documents with many multivalued longs"); int docsPerBulk = 100; diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/AbstractScriptFieldTypeTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/AbstractScriptFieldTypeTestCase.java index 1c785d58f9804..f099aaac463db 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/AbstractScriptFieldTypeTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/AbstractScriptFieldTypeTestCase.java @@ -420,13 +420,12 @@ public final void testCacheable() throws IOException { } } - protected final List blockLoaderReadValuesFromColumnAtATimeReader(DirectoryReader reader, MappedFieldType fieldType) + protected final List blockLoaderReadValuesFromColumnAtATimeReader(DirectoryReader reader, MappedFieldType fieldType, int offset) throws 
IOException { BlockLoader loader = fieldType.blockLoader(blContext()); List all = new ArrayList<>(); for (LeafReaderContext ctx : reader.leaves()) { - TestBlock block = (TestBlock) loader.columnAtATimeReader(ctx) - .read(TestBlock.factory(ctx.reader().numDocs()), TestBlock.docs(ctx)); + TestBlock block = (TestBlock) loader.columnAtATimeReader(ctx).read(TestBlock.factory(), TestBlock.docs(ctx), offset); for (int i = 0; i < block.size(); i++) { all.add(block.get(i)); } @@ -440,7 +439,7 @@ protected final List blockLoaderReadValuesFromRowStrideReader(DirectoryR List all = new ArrayList<>(); for (LeafReaderContext ctx : reader.leaves()) { BlockLoader.RowStrideReader blockReader = loader.rowStrideReader(ctx); - BlockLoader.Builder builder = loader.builder(TestBlock.factory(ctx.reader().numDocs()), ctx.reader().numDocs()); + BlockLoader.Builder builder = loader.builder(TestBlock.factory(), ctx.reader().numDocs()); for (int i = 0; i < ctx.reader().numDocs(); i++) { blockReader.read(i, null, builder); } diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/BlockLoaderTestRunner.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/BlockLoaderTestRunner.java index e35a53c0ecca8..eeb1a349d8bbc 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/BlockLoaderTestRunner.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/BlockLoaderTestRunner.java @@ -36,6 +36,8 @@ import static org.apache.lucene.tests.util.LuceneTestCase.newDirectory; import static org.apache.lucene.tests.util.LuceneTestCase.random; import static org.elasticsearch.index.mapper.BlockLoaderTestRunner.PrettyEqual.prettyEqualTo; +import static org.elasticsearch.test.ESTestCase.between; +import static org.elasticsearch.test.ESTestCase.randomBoolean; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; @@ -69,7 +71,11 @@ private Object setupAndInvokeBlockLoader(MapperService mapperService, XContentBu ); LuceneDocument doc = mapperService.documentMapper().parse(source).rootDoc(); - iw.addDocument(doc); + /* + * Add three documents with doc ids 0, 1, and 2. The real document is 1; + * the other two are empty documents. + */ + iw.addDocuments(List.of(List.of(), doc, List.of())); iw.close(); try (DirectoryReader reader = DirectoryReader.open(directory)) { @@ -83,9 +89,32 @@ private Object load(BlockLoader blockLoader, LeafReaderContext context, MapperSe // `columnAtATimeReader` is tried first; we mimic `ValuesSourceReaderOperator` var columnAtATimeReader = blockLoader.columnAtATimeReader(context); if (columnAtATimeReader != null) { + int[] docArray; + int offset; + if (randomBoolean()) { + // Half the time we load a single document. Nice and simple. + docArray = new int[] { 1 }; + offset = 0; + } else { + /* + * The other half of the time we emulate loading a larger page, + * starting partway through the page. 
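+ * For example, with docArray.length == 5 and offset == 2 the ids are + * [0, 0, 1, 2, 2]: filler before the offset, the real document at the + * offset, and trailing filler after it.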
+ */ + docArray = new int[between(2, 10)]; + offset = between(0, docArray.length - 1); + for (int i = 0; i < docArray.length; i++) { + if (i < offset) { + docArray[i] = 0; + } else if (i == offset) { + docArray[i] = 1; + } else { + docArray[i] = 2; + } + } + } + BlockLoader.Docs docs = TestBlock.docs(docArray); + var block = (TestBlock) columnAtATimeReader.read(TestBlock.factory(), docs, offset); + assertThat(block.size(), equalTo(docArray.length - offset)); return block.get(0); } @@ -102,10 +131,10 @@ private Object load(BlockLoader blockLoader, LeafReaderContext context, MapperSe StoredFieldLoader.fromSpec(storedFieldsSpec).getLoader(context, null), leafSourceLoader ); - storedFieldsLoader.advanceTo(0); + storedFieldsLoader.advanceTo(1); - BlockLoader.Builder builder = blockLoader.builder(TestBlock.factory(context.reader().numDocs()), 1); - blockLoader.rowStrideReader(context).read(0, storedFieldsLoader, builder); + BlockLoader.Builder builder = blockLoader.builder(TestBlock.factory(), 1); + blockLoader.rowStrideReader(context).read(1, storedFieldsLoader, builder); var block = (TestBlock) builder.build(); assertThat(block.size(), equalTo(1)); diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/TestBlock.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/TestBlock.java index 9b5a66765adbe..cb73dc96f69b2 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/TestBlock.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/TestBlock.java @@ -13,6 +13,7 @@ import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.index.SortedSetDocValues; import org.apache.lucene.util.BytesRef; +import org.hamcrest.Matcher; import java.io.IOException; import java.io.UncheckedIOException; @@ -20,11 +21,14 @@ import java.util.HashMap; import java.util.List; +import static org.elasticsearch.test.ESTestCase.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; public class TestBlock implements BlockLoader.Block { - public static BlockLoader.BlockFactory factory(int pageSize) { + public static BlockLoader.BlockFactory factory() { return new BlockLoader.BlockFactory() { @Override public BlockLoader.BooleanBuilder booleansFromDocValues(int expectedCount) { @@ -34,6 +38,10 @@ public BlockLoader.BooleanBuilder booleansFromDocValues(int expectedCount) { @Override public BlockLoader.BooleanBuilder booleans(int expectedCount) { class BooleansBuilder extends TestBlock.Builder implements BlockLoader.BooleanBuilder { + private BooleansBuilder() { + super(expectedCount); + } + @Override public BooleansBuilder appendBoolean(boolean value) { add(value); @@ -45,12 +53,41 @@ public BooleansBuilder appendBoolean(boolean value) { @Override public BlockLoader.BytesRefBuilder bytesRefsFromDocValues(int expectedCount) { - return bytesRefs(expectedCount); + class BytesRefsFromDocValuesBuilder extends TestBlock.Builder implements BlockLoader.BytesRefBuilder { + private BytesRefsFromDocValuesBuilder() { + super(1); + } + + @Override + public BytesRefsFromDocValuesBuilder appendBytesRef(BytesRef value) { + add(BytesRef.deepCopyOf(value)); + return this; + } + + @Override + public TestBlock build() { + TestBlock result = super.build(); + List r; + if (result.values.get(0) instanceof List l) { + r = l; + } else { + r = List.of(result.values.get(0)); + } + assertThat(r, hasSize(expectedCount)); + return result; + } + + 
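+ // Note: doc-values builders expect a single position; build() flattens that + // position and asserts it holds exactly expectedCount values.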
} + return new BytesRefsFromDocValuesBuilder(); } @Override public BlockLoader.BytesRefBuilder bytesRefs(int expectedCount) { class BytesRefsBuilder extends TestBlock.Builder implements BlockLoader.BytesRefBuilder { + private BytesRefsBuilder() { + super(expectedCount); + } + @Override public BytesRefsBuilder appendBytesRef(BytesRef value) { add(BytesRef.deepCopyOf(value)); @@ -68,6 +105,10 @@ public BlockLoader.DoubleBuilder doublesFromDocValues(int expectedCount) { @Override public BlockLoader.DoubleBuilder doubles(int expectedCount) { class DoublesBuilder extends TestBlock.Builder implements BlockLoader.DoubleBuilder { + private DoublesBuilder() { + super(expectedCount); + } + @Override public DoublesBuilder appendDouble(double value) { add(value); @@ -82,6 +123,10 @@ public BlockLoader.FloatBuilder denseVectors(int expectedCount, int dimensions) class FloatsBuilder extends TestBlock.Builder implements BlockLoader.FloatBuilder { int numElements = 0; + private FloatsBuilder() { + super(expectedCount); + } + @Override public BlockLoader.FloatBuilder appendFloat(float value) { add(value); @@ -118,6 +163,10 @@ public BlockLoader.IntBuilder intsFromDocValues(int expectedCount) { @Override public BlockLoader.IntBuilder ints(int expectedCount) { class IntsBuilder extends TestBlock.Builder implements BlockLoader.IntBuilder { + private IntsBuilder() { + super(expectedCount); + } + @Override public IntsBuilder appendInt(int value) { add(value); @@ -135,6 +184,10 @@ public BlockLoader.LongBuilder longsFromDocValues(int expectedCount) { @Override public BlockLoader.LongBuilder longs(int expectedCount) { class LongsBuilder extends TestBlock.Builder implements BlockLoader.LongBuilder { + private LongsBuilder() { + super(expectedCount); + } + @Override public LongsBuilder appendLong(long value) { add(value); @@ -150,26 +203,30 @@ public BlockLoader.Builder nulls(int expectedCount) { } @Override - public BlockLoader.Block constantNulls() { - BlockLoader.LongBuilder builder = longs(pageSize); - for (int i = 0; i < pageSize; i++) { + public BlockLoader.Block constantNulls(int count) { + BlockLoader.LongBuilder builder = longs(count); + for (int i = 0; i < count; i++) { builder.appendNull(); } return builder.build(); } @Override - public BlockLoader.Block constantBytes(BytesRef value) { - BlockLoader.BytesRefBuilder builder = bytesRefs(pageSize); - for (int i = 0; i < pageSize; i++) { + public BlockLoader.Block constantBytes(BytesRef value, int count) { + BlockLoader.BytesRefBuilder builder = bytesRefs(count); + for (int i = 0; i < count; i++) { builder.appendBytesRef(value); } return builder.build(); } @Override - public BlockLoader.SingletonOrdinalsBuilder singletonOrdinalsBuilder(SortedDocValues ordinals, int count) { + public BlockLoader.SingletonOrdinalsBuilder singletonOrdinalsBuilder(SortedDocValues ordinals, int expectedCount) { class SingletonOrdsBuilder extends TestBlock.Builder implements BlockLoader.SingletonOrdinalsBuilder { + private SingletonOrdsBuilder() { + super(expectedCount); + } + @Override public SingletonOrdsBuilder appendOrd(int value) { try { @@ -184,12 +241,16 @@ public SingletonOrdsBuilder appendOrd(int value) { } @Override - public BlockLoader.SortedSetOrdinalsBuilder sortedSetOrdinalsBuilder(SortedSetDocValues ordinals, int count) { + public BlockLoader.SortedSetOrdinalsBuilder sortedSetOrdinalsBuilder(SortedSetDocValues ordinals, int expectedSize) { class SortedSetOrdinalBuilder extends TestBlock.Builder implements BlockLoader.SortedSetOrdinalsBuilder { + private 
SortedSetOrdinalBuilder() { + super(expectedSize); + } + @Override public SortedSetOrdinalBuilder appendOrd(int value) { try { - add(ordinals.lookupOrd(value)); + add(BytesRef.deepCopyOf(ordinals.lookupOrd(value))); return this; } catch (IOException e) { throw new UncheckedIOException(e); @@ -199,9 +260,8 @@ public SortedSetOrdinalBuilder appendOrd(int value) { return new SortedSetOrdinalBuilder(); } - @Override - public BlockLoader.AggregateMetricDoubleBuilder aggregateMetricDoubleBuilder(int count) { - return new AggregateMetricDoubleBlockBuilder(); + public BlockLoader.AggregateMetricDoubleBuilder aggregateMetricDoubleBuilder(int expectedSize) { + return new AggregateMetricDoubleBlockBuilder(expectedSize); } }; } @@ -256,8 +316,14 @@ public void close() { private abstract static class Builder implements BlockLoader.Builder { private final List values = new ArrayList<>(); + private Matcher expectedSize; + private List currentPosition = null; + private Builder(int expectedSize) { + this.expectedSize = equalTo(expectedSize); + } + @Override public Builder appendNull() { assertNull(currentPosition); @@ -286,6 +352,7 @@ protected void add(Object value) { @Override public TestBlock build() { + assertThat(values, hasSize(expectedSize)); return new TestBlock(values); } @@ -300,12 +367,23 @@ public void close() { * The implementation here is fairly close to the production one. */ private static class AggregateMetricDoubleBlockBuilder implements BlockLoader.AggregateMetricDoubleBuilder { - private final DoubleBuilder min = new DoubleBuilder(); - private final DoubleBuilder max = new DoubleBuilder(); - private final DoubleBuilder sum = new DoubleBuilder(); - private final IntBuilder count = new IntBuilder(); + private final DoubleBuilder min; + private final DoubleBuilder max; + private final DoubleBuilder sum; + private final IntBuilder count; + + private AggregateMetricDoubleBlockBuilder(int expectedSize) { + min = new DoubleBuilder(expectedSize); + max = new DoubleBuilder(expectedSize); + sum = new DoubleBuilder(expectedSize); + count = new IntBuilder(expectedSize); + } private static class DoubleBuilder extends TestBlock.Builder implements BlockLoader.DoubleBuilder { + private DoubleBuilder(int expectedSize) { + super(expectedSize); + } + @Override public BlockLoader.DoubleBuilder appendDouble(double value) { add(value); @@ -314,6 +392,10 @@ public BlockLoader.DoubleBuilder appendDouble(double value) { } private static class IntBuilder extends TestBlock.Builder implements BlockLoader.IntBuilder { + private IntBuilder(int expectedSize) { + super(expectedSize); + } + @Override public BlockLoader.IntBuilder appendInt(int value) { add(value); diff --git a/test/framework/src/main/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java b/test/framework/src/main/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java index 03d6ac6342b42..0c1e381f69c4e 100644 --- a/test/framework/src/main/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.node.DiscoveryNode; import 
org.elasticsearch.cluster.routing.IndexShardRoutingTable; @@ -242,6 +243,7 @@ public MockIndexService indexService(Index index) { @Override public void createShard( + final ProjectId projectId, final ShardRouting shardRouting, final PeerRecoveryTargetService recoveryTargetService, final PeerRecoveryTargetService.RecoveryListener recoveryListener, diff --git a/x-pack/plugin/autoscaling/src/main/java/org/elasticsearch/xpack/autoscaling/storage/ReactiveStorageDeciderService.java b/x-pack/plugin/autoscaling/src/main/java/org/elasticsearch/xpack/autoscaling/storage/ReactiveStorageDeciderService.java index 6c4066a447b67..c76a88b0da2f9 100644 --- a/x-pack/plugin/autoscaling/src/main/java/org/elasticsearch/xpack/autoscaling/storage/ReactiveStorageDeciderService.java +++ b/x-pack/plugin/autoscaling/src/main/java/org/elasticsearch/xpack/autoscaling/storage/ReactiveStorageDeciderService.java @@ -961,6 +961,7 @@ private ExtendedClusterInfo(Map extraShardSizes, ClusterInfo info) Map.of(), Map.of(), Map.of(), + Map.of(), Map.of() ); this.delegate = info; diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutor.java b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutor.java index 029ea6dcd6871..717ec4761c87e 100644 --- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutor.java +++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutor.java @@ -32,6 +32,7 @@ import org.elasticsearch.cluster.metadata.AliasMetadata; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.MappingMetadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.cluster.routing.ShardRouting; @@ -43,6 +44,7 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.core.CheckedConsumer; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexNotFoundException; @@ -118,7 +120,7 @@ public ShardFollowTasksExecutor(Client client, ThreadPool threadPool, ClusterSer } @Override - public void validate(ShardFollowTask params, ClusterState clusterState) { + public void validate(ShardFollowTask params, ClusterState clusterState, @Nullable ProjectId projectId) { final IndexRoutingTable routingTable = clusterState.getRoutingTable().index(params.getFollowShardId().getIndex()); final ShardRouting primaryShard = routingTable.shard(params.getFollowShardId().id()).primaryShard(); if (primaryShard.active() == false) { @@ -129,10 +131,11 @@ public void validate(ShardFollowTask params, ClusterState clusterState) { private static final Assignment NO_ASSIGNMENT = new Assignment(null, "no nodes found with data and remote cluster client roles"); @Override - public Assignment getAssignment( + protected Assignment doGetAssignment( final ShardFollowTask params, - Collection candidateNodes, - final ClusterState clusterState + final Collection candidateNodes, + final ClusterState clusterState, + @Nullable final ProjectId projectId ) { final DiscoveryNode node = selectLeastLoadedNode( clusterState, diff --git 
a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutorAssignmentTests.java b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutorAssignmentTests.java index 630aab4c78f43..7cb549df52301 100644 --- a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutorAssignmentTests.java +++ b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutorAssignmentTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; @@ -93,7 +94,8 @@ private void runAssignmentTest( final Assignment assignment = executor.getAssignment( mock(ShardFollowTask.class), clusterStateBuilder.nodes().getAllNodes(), - clusterStateBuilder.build() + clusterStateBuilder.build(), + ProjectId.DEFAULT ); consumer.accept(theSpecial, assignment); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/store/KibanaOwnedReservedRoleDescriptors.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/store/KibanaOwnedReservedRoleDescriptors.java index 38b78ad357bf5..f5e3c239dadcd 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/store/KibanaOwnedReservedRoleDescriptors.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/store/KibanaOwnedReservedRoleDescriptors.java @@ -109,7 +109,7 @@ static RoleDescriptor kibanaSystem(String name) { new RoleDescriptor.IndicesPrivileges[] { // System indices defined in KibanaPlugin RoleDescriptor.IndicesPrivileges.builder() - .indices(".kibana*", ".reporting-*") + .indices(".kibana*", ".reporting-*", ".chat-*") .privileges("all") .allowRestrictedIndices(true) .build(), diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java index 56dc2a6d0212a..fd5632606867e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java @@ -119,7 +119,7 @@ public void testParseAllFields() throws IOException { assertThat(request, is(expected)); assertThat( - Strings.toString(request, UnifiedCompletionRequest.withMaxCompletionTokensTokens("gpt-4o", ToXContent.EMPTY_PARAMS)), + Strings.toString(request, UnifiedCompletionRequest.withMaxCompletionTokens("gpt-4o", ToXContent.EMPTY_PARAMS)), is(XContentHelper.stripWhitespace(requestJson)) ); } diff --git a/x-pack/plugin/downsample/src/internalClusterTest/java/org/elasticsearch/xpack/downsample/DownsampleIT.java b/x-pack/plugin/downsample/src/internalClusterTest/java/org/elasticsearch/xpack/downsample/DownsampleIT.java index 0b815fce21b04..6eb3efcdeb735 100644 --- a/x-pack/plugin/downsample/src/internalClusterTest/java/org/elasticsearch/xpack/downsample/DownsampleIT.java +++ b/x-pack/plugin/downsample/src/internalClusterTest/java/org/elasticsearch/xpack/downsample/DownsampleIT.java @@ -7,6 +7,7 @@ package 
org.elasticsearch.xpack.downsample; +import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesRequest; import org.elasticsearch.action.admin.indices.delete.DeleteIndexRequest; import org.elasticsearch.action.admin.indices.delete.TransportDeleteIndexAction; import org.elasticsearch.action.admin.indices.rollover.RolloverRequest; @@ -17,6 +18,7 @@ import org.elasticsearch.cluster.metadata.DataStreamAction; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.rest.RestRequest; import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval; import org.elasticsearch.test.ClusterServiceUtils; import org.elasticsearch.xcontent.XContentBuilder; @@ -34,6 +36,7 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.xpack.downsample.DownsampleDataStreamTests.TIMEOUT; +import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.METRICS_COMMAND; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -215,6 +218,12 @@ public void testAggMetricInEsqlTSAfterDownsampling() throws Exception { }; bulkIndex(dataStreamName, nextSourceSupplier, 100); + // check that TS command is available + var response = clusterAdmin().nodesCapabilities( + new NodesCapabilitiesRequest().method(RestRequest.Method.POST).path("/_query").capabilities(METRICS_COMMAND.capabilityName()) + ).actionGet(); + assumeTrue("TS command must be available for this test", response.isSupported().orElse(Boolean.FALSE)); + // Since the downsampled field (cpu) is downsampled in one index and not in the other, we want to confirm // first that the field is unsupported and has 2 original types - double and aggregate_metric_double try (var resp = esqlCommand("TS " + dataStreamName + " | KEEP @timestamp, host, cluster, cpu")) { diff --git a/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/DownsampleShardPersistentTaskExecutor.java b/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/DownsampleShardPersistentTaskExecutor.java index 5f91fb18fd58e..76615876c5255 100644 --- a/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/DownsampleShardPersistentTaskExecutor.java +++ b/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/DownsampleShardPersistentTaskExecutor.java @@ -22,6 +22,7 @@ import org.elasticsearch.action.support.TransportAction; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.routing.IndexShardRoutingTable; import org.elasticsearch.cluster.routing.ShardRouting; @@ -29,6 +30,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.Nullable; import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.mapper.TimeSeriesIdFieldMapper; import org.elasticsearch.index.shard.ShardId; @@ -116,7 +118,7 @@ protected AllocatedPersistentTask createTask( } @Override - public void validate(DownsampleShardTaskParams params, ClusterState clusterState) { + public void validate(DownsampleShardTaskParams params, ClusterState clusterState, @Nullable ProjectId projectId) { // 
This is just a pre-check; it doesn't remove the need to abort the task if the source index disappears after the persistent task was initially created. var indexShardRouting = findShardRoutingTable(params.shardId(), clusterState); @@ -126,10 +128,11 @@ public void validate(DownsampleShardTaskParams params, ClusterState clusterState } @Override - public PersistentTasksCustomMetadata.Assignment getAssignment( + protected PersistentTasksCustomMetadata.Assignment doGetAssignment( final DownsampleShardTaskParams params, final Collection candidateNodes, - final ClusterState clusterState + final ClusterState clusterState, + @Nullable final ProjectId projectId ) { // NOTE: downsampling works by running a task for each shard of the source index. // Here we make sure we assign the task to the actual node holding the shard identified by diff --git a/x-pack/plugin/downsample/src/test/java/org/elasticsearch/xpack/downsample/DownsampleShardPersistentTaskExecutorTests.java b/x-pack/plugin/downsample/src/test/java/org/elasticsearch/xpack/downsample/DownsampleShardPersistentTaskExecutorTests.java index 39e92f06ada16..5a4e14dc24015 100644 --- a/x-pack/plugin/downsample/src/test/java/org/elasticsearch/xpack/downsample/DownsampleShardPersistentTaskExecutorTests.java +++ b/x-pack/plugin/downsample/src/test/java/org/elasticsearch/xpack/downsample/DownsampleShardPersistentTaskExecutorTests.java @@ -96,7 +96,7 @@ public void testGetAssignment() { Strings.EMPTY_ARRAY, Strings.EMPTY_ARRAY ); - var result = executor.getAssignment(params, Set.of(node), clusterState); + var result = executor.getAssignment(params, Set.of(node), clusterState, projectId); assertThat(result.getExecutorNode(), equalTo(node.getId())); } @@ -128,7 +128,7 @@ public void testGetAssignmentMissingIndex() { Strings.EMPTY_ARRAY, Strings.EMPTY_ARRAY ); - var result = executor.getAssignment(params, Set.of(node), clusterState); + var result = executor.getAssignment(params, Set.of(node), clusterState, projectId); assertThat(result.getExecutorNode(), equalTo(node.getId())); assertThat(result.getExplanation(), equalTo("a node to fail and stop this persistent task")); } @@ -165,7 +165,7 @@ public void testGetStatelessAssignment() { Strings.EMPTY_ARRAY, Strings.EMPTY_ARRAY ); - var result = executor.getAssignment(params, Set.of(indexNode, searchNode), clusterState); + var result = executor.getAssignment(params, Set.of(indexNode, searchNode), clusterState, projectId); assertThat(result.getExecutorNode(), nullValue()); // Assign a copy of the shard to a search node @@ -185,7 +185,7 @@ public void testGetStatelessAssignment() { .build() ) .build(); - result = executor.getAssignment(params, Set.of(indexNode, searchNode), clusterState); + result = executor.getAssignment(params, Set.of(indexNode, searchNode), clusterState, projectId); assertThat(result.getExecutorNode(), equalTo(searchNode.getId())); } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Attribute.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Attribute.java index 1af98f4b21dc5..6b700f0ee6a7f 100--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Attribute.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Attribute.java @@ -134,4 +134,19 @@ public String nodeString() { } protected abstract String label(); + + /** + * Compares the sizes and data types of two lists of attributes for equality. 
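+ * Presumably intended for layout checks: two attribute lists match when they + * agree position-by-position on data type, even if names or ids differ.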
+ */ + public static boolean dataTypeEquals(List left, List right) { + if (left.size() != right.size()) { + return false; + } + for (int i = 0; i < left.size(); i++) { + if (left.get(i).dataType() != right.get(i).dataType()) { + return false; + } + } + return true; + } } diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index 042de44d7ebe3..8dc6f594ca47a 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -204,9 +204,9 @@ private TypeSpec type() { for (ClassName groupIdClass : GROUP_IDS_CLASSES) { builder.addMethod(addRawInputLoop(groupIdClass, blockType(aggParam.type()))); builder.addMethod(addRawInputLoop(groupIdClass, vectorType(aggParam.type()))); + builder.addMethod(addIntermediateInput(groupIdClass)); } builder.addMethod(selectedMayContainUnseenGroups()); - builder.addMethod(addIntermediateInput()); builder.addMethod(evaluateIntermediate()); builder.addMethod(evaluateFinal()); builder.addMethod(toStringMethod()); @@ -583,11 +583,12 @@ private MethodSpec selectedMayContainUnseenGroups() { return builder.build(); } - private MethodSpec addIntermediateInput() { + private MethodSpec addIntermediateInput(TypeName groupsType) { + boolean groupsIsBlock = groupsType.toString().endsWith("Block"); MethodSpec.Builder builder = MethodSpec.methodBuilder("addIntermediateInput"); builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC); builder.addParameter(TypeName.INT, "positionOffset"); - builder.addParameter(INT_VECTOR, "groups"); + builder.addParameter(groupsType, "groups"); builder.addParameter(PAGE, "page"); builder.addStatement("state.enableGroupIdTracking(new $T.Empty())", SEEN_GROUP_IDS); @@ -613,7 +614,18 @@ private MethodSpec addIntermediateInput() { } builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)"); { - builder.addStatement("int groupId = groups.getInt(groupPosition)"); + if (groupsIsBlock) { + builder.beginControlFlow("if (groups.isNull(groupPosition))"); + builder.addStatement("continue"); + builder.endControlFlow(); + builder.addStatement("int groupStart = groups.getFirstValueIndex(groupPosition)"); + builder.addStatement("int groupEnd = groupStart + groups.getValueCount(groupPosition)"); + builder.beginControlFlow("for (int g = groupStart; g < groupEnd; g++)"); + builder.addStatement("int groupId = groups.getInt(g)"); + } else { + builder.addStatement("int groupId = groups.getInt(groupPosition)"); + } + if (aggState.declaredType().isPrimitive()) { if (warnExceptions.isEmpty()) { assert intermediateState.size() == 2; @@ -664,6 +676,9 @@ private MethodSpec addIntermediateInput() { declarationType ); } + if (groupsIsBlock) { + builder.endControlFlow(); + } builder.endControlFlow(); } return builder.build(); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBooleanGroupingAggregatorFunction.java index 5113d70c6e55e..9d04612e02511 100644 --- 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBooleanGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBooleanGroupingAggregatorFunction.java @@ -139,6 +139,34 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block fbitUncast = page.getBlock(channels.get(0)); + if (fbitUncast.areAllValuesNull()) { + return; + } + BooleanVector fbit = ((BooleanBlock) fbitUncast).asVector(); + Block tbitUncast = page.getBlock(channels.get(1)); + if (tbitUncast.areAllValuesNull()) { + return; + } + BooleanVector tbit = ((BooleanBlock) tbitUncast).asVector(); + assert fbit.getPositionCount() == tbit.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + CountDistinctBooleanAggregator.combineIntermediate(state, groupId, fbit.getBoolean(groupPosition + positionOffset), tbit.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -171,6 +199,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVec } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block fbitUncast = page.getBlock(channels.get(0)); + if (fbitUncast.areAllValuesNull()) { + return; + } + BooleanVector fbit = ((BooleanBlock) fbitUncast).asVector(); + Block tbitUncast = page.getBlock(channels.get(1)); + if (tbitUncast.areAllValuesNull()) { + return; + } + BooleanVector tbit = ((BooleanBlock) tbitUncast).asVector(); + assert fbit.getPositionCount() == tbit.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + CountDistinctBooleanAggregator.combineIntermediate(state, groupId, fbit.getBoolean(groupPosition + positionOffset), tbit.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -192,11 +248,6 @@ private void addRawInput(int positionOffset, IntVector groups, BooleanVector val } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public 
void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -218,6 +269,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBytesRefGroupingAggregatorFunction.java index f8d893db7e064..e73d20887e29e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBytesRefGroupingAggregatorFunction.java @@ -144,6 +144,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + CountDistinctBytesRefAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -178,6 +201,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + CountDistinctBytesRefAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int 
groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -201,11 +247,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -222,6 +263,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctDoubleGroupingAggregatorFunction.java index 3c3952e4a41df..9011e9ea7de07 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctDoubleGroupingAggregatorFunction.java @@ -144,6 +144,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + CountDistinctDoubleAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -176,6 +199,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + 
groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + CountDistinctDoubleAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -197,11 +243,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -218,6 +259,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctFloatGroupingAggregatorFunction.java index 1000f916d1e5f..6296aac243bcc 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctFloatGroupingAggregatorFunction.java @@ -144,6 +144,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + CountDistinctFloatAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -176,6 +199,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if 
(hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + CountDistinctFloatAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -197,11 +243,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -218,6 +259,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctIntGroupingAggregatorFunction.java index 0b89c3867d0cf..8ff5b6636bc57 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctIntGroupingAggregatorFunction.java @@ -143,6 +143,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + CountDistinctIntAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -175,6 
+198,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + CountDistinctIntAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -196,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -217,6 +258,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunction.java index acdae1a648d78..e6c746887f6f9 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunction.java @@ -144,6 +144,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + 
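
Aside: for the `count_distinct` functions above, the single intermediate channel (`hll`) appears to carry a serialized HyperLogLog++ sketch per row, and `combineIntermediate` merges it into the group's sketch. The merge is a union, so intermediate slices can arrive in any order. A toy illustration of that merge contract, using a `BitSet` in place of the real sketch (assumed shape, not the ES implementation):

```
import java.util.BitSet;
import java.util.HashMap;
import java.util.Map;

// Toy per-group distinct state: union-merge of serialized per-row sketches.
final class ToyDistinctState {
    private final Map<Integer, BitSet> perGroup = new HashMap<>();

    /** `serialized` stands in for the BytesRef sketch read from the hll channel. */
    void combineIntermediate(int groupId, BitSet serialized) {
        perGroup.computeIfAbsent(groupId, k -> new BitSet()).or(serialized); // union, like an HLL merge
    }

    long cardinality(int groupId) {
        BitSet s = perGroup.get(groupId);
        return s == null ? 0 : s.cardinality();
    }
}
```
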
CountDistinctLongAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -176,6 +199,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block hllUncast = page.getBlock(channels.get(0)); + if (hllUncast.areAllValuesNull()) { + return; + } + BytesRefVector hll = ((BytesRefBlock) hllUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + CountDistinctLongAggregator.combineIntermediate(state, groupId, hll.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -197,11 +243,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -218,6 +259,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleGroupingAggregatorFunction.java index b3fc1a93e2d3e..08e11f0ddb3d6 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleGroupingAggregatorFunction.java @@ -149,6 +149,34 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block 
valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + FirstOverTimeDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -184,6 +212,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + FirstOverTimeDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -208,11 +264,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -234,6 +285,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeFloatGroupingAggregatorFunction.java index 61b8ba51a6a29..f17f5facc8c85 100644 --- 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeFloatGroupingAggregatorFunction.java @@ -149,6 +149,34 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + FloatBlock values = (FloatBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + FirstOverTimeFloatAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -184,6 +212,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + FloatBlock values = (FloatBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + FirstOverTimeFloatAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -208,11 +264,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -234,6 +285,11 @@ public void addIntermediateInput(int 
positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeIntGroupingAggregatorFunction.java index f8959ef8c4014..a973f01dcda3a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeIntGroupingAggregatorFunction.java @@ -148,6 +148,34 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + IntBlock values = (IntBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + FirstOverTimeIntAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -183,6 +211,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + IntBlock values = (IntBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + FirstOverTimeIntAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + 
private void addRawInput(int positionOffset, IntVector groups, IntBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -207,11 +263,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values, } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -233,6 +284,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeLongGroupingAggregatorFunction.java index d54cab4d0d853..0d88b3190f1f3 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeLongGroupingAggregatorFunction.java @@ -147,6 +147,34 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + LongBlock values = (LongBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + FirstOverTimeLongAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -182,6 +210,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if 
(valuesUncast.areAllValuesNull()) { + return; + } + LongBlock values = (LongBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + FirstOverTimeLongAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -206,11 +262,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -232,6 +283,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeDoubleGroupingAggregatorFunction.java index 4530c5313fa2c..ad935063a95b4 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeDoubleGroupingAggregatorFunction.java @@ -149,6 +149,34 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + LastOverTimeDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values, 
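
Aside: the first/last-over-time functions above carry a two-channel intermediate state, a `LongBlock` of timestamps plus a value block, and their `combineIntermediate` keeps the sample that wins on timestamp. A toy per-group version under that assumption; the array-backed state is illustrative, not the ES `state` class:

```
// Toy first-over-time merge: keep the earliest-timestamped sample per group.
final class FirstOverTimeSketch {
    private final long[] timestamps;  // per-group timestamp of the winning sample
    private final double[] values;    // per-group value of the winning sample
    private final boolean[] seen;

    FirstOverTimeSketch(int maxGroups) {
        timestamps = new long[maxGroups];
        values = new double[maxGroups];
        seen = new boolean[maxGroups];
    }

    /** Merge one intermediate row; a last-over-time variant would flip the comparison. */
    void combineIntermediate(int groupId, long timestamp, double value) {
        if (seen[groupId] == false || timestamp < timestamps[groupId]) {
            timestamps[groupId] = timestamp;
            values[groupId] = value;
            seen[groupId] = true;
        }
    }
}
```
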
LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -184,6 +212,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + LastOverTimeDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -208,11 +264,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -234,6 +285,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeFloatGroupingAggregatorFunction.java index 0e54c799320a9..249b27bf7ee70 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeFloatGroupingAggregatorFunction.java @@ -149,6 +149,34 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + FloatBlock values = 
(FloatBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + LastOverTimeFloatAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -184,6 +212,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + FloatBlock values = (FloatBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + LastOverTimeFloatAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -208,11 +264,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -234,6 +285,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeIntGroupingAggregatorFunction.java index ae5ac000d00d1..fe25154290aac 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeIntGroupingAggregatorFunction.java +++ 
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeIntGroupingAggregatorFunction.java @@ -148,6 +148,34 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + IntBlock values = (IntBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + LastOverTimeIntAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -183,6 +211,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + IntBlock values = (IntBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + LastOverTimeIntAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -207,11 +263,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values, } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -233,6 +284,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + 
state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeLongGroupingAggregatorFunction.java index 77cd4a6b3146c..3772f8bf186c1 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LastOverTimeLongGroupingAggregatorFunction.java @@ -147,6 +147,34 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + LongBlock values = (LongBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + LastOverTimeLongAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -182,6 +210,34 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + LongBlock values = (LongBlock) valuesUncast; + assert timestamps.getPositionCount() == values.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + LastOverTimeLongAggregator.combineIntermediate(state, groupId, timestamps, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values, LongVector timestamps) { for (int groupPosition = 0; 
groupPosition < groups.getPositionCount(); groupPosition++) { @@ -206,11 +262,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -232,6 +283,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanGroupingAggregatorFunction.java index 266c918e5c080..ab6177f82e6e4 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanGroupingAggregatorFunction.java @@ -139,6 +139,36 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + BooleanVector max = ((BooleanBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxBooleanAggregator.combine(state.getOrDefault(groupId), max.getBoolean(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -171,6 +201,36 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVec } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + BooleanVector max = ((BooleanBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + 
return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxBooleanAggregator.combine(state.getOrDefault(groupId), max.getBoolean(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -192,11 +252,6 @@ private void addRawInput(int positionOffset, IntVector groups, BooleanVector val } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -220,6 +275,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBytesRefGroupingAggregatorFunction.java index 67623ef20da16..588144c23162f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBytesRefGroupingAggregatorFunction.java @@ -144,6 +144,35 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + BytesRefVector max = ((BytesRefBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MaxBytesRefAggregator.combineIntermediate(state, groupId, max.getBytesRef(groupPosition + positionOffset, scratch), seen.getBoolean(groupPosition + 
positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -178,6 +207,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + BytesRefVector max = ((BytesRefBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MaxBytesRefAggregator.combineIntermediate(state, groupId, max.getBytesRef(groupPosition + positionOffset, scratch), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -201,11 +259,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -228,6 +281,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleGroupingAggregatorFunction.java index b4e080c882b7d..cf06006a24150 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleGroupingAggregatorFunction.java @@ -141,6 +141,36 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + DoubleVector 
max = ((DoubleBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxDoubleAggregator.combine(state.getOrDefault(groupId), max.getDouble(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -173,6 +203,36 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + DoubleVector max = ((DoubleBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxDoubleAggregator.combine(state.getOrDefault(groupId), max.getDouble(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -194,11 +254,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -222,6 +277,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxFloatGroupingAggregatorFunction.java 
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxFloatGroupingAggregatorFunction.java index f7b957312228b..5d1ac766b590d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxFloatGroupingAggregatorFunction.java @@ -141,6 +141,36 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + FloatVector max = ((FloatBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxFloatAggregator.combine(state.getOrDefault(groupId), max.getFloat(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -173,6 +203,36 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + FloatVector max = ((FloatBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxFloatAggregator.combine(state.getOrDefault(groupId), max.getFloat(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -194,11 +254,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector 
value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -222,6 +277,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunction.java index b5364e060f87e..ee501aed26bc2 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunction.java @@ -140,6 +140,36 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + IntVector max = ((IntBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxIntAggregator.combine(state.getOrDefault(groupId), max.getInt(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -172,6 +202,36 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + IntVector max = ((IntBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); 
groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxIntAggregator.combine(state.getOrDefault(groupId), max.getInt(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -193,11 +253,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -221,6 +276,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIpGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIpGroupingAggregatorFunction.java index 4d01eda88bf14..cfc13a77b2984 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIpGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIpGroupingAggregatorFunction.java @@ -144,6 +144,35 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + BytesRefVector max = ((BytesRefBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MaxIpAggregator.combineIntermediate(state, groupId, max.getBytesRef(groupPosition + positionOffset, scratch), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ 
-178,6 +207,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + BytesRefVector max = ((BytesRefBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MaxIpAggregator.combineIntermediate(state, groupId, max.getBytesRef(groupPosition + positionOffset, scratch), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -201,11 +259,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -228,6 +281,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunction.java index 928caefc40846..36e2101baaae7 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunction.java @@ -141,6 +141,36 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + LongVector max = ((LongBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); 
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxLongAggregator.combine(state.getOrDefault(groupId), max.getLong(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -173,6 +203,36 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + LongVector max = ((LongBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MaxLongAggregator.combine(state.getOrDefault(groupId), max.getLong(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -194,11 +254,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -222,6 +277,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleGroupingAggregatorFunction.java index 02ec5aaacef70..bdc7ebfeb03f2 100644 --- 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleGroupingAggregatorFunction.java @@ -141,6 +141,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MedianAbsoluteDeviationDoubleAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -173,6 +196,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MedianAbsoluteDeviationDoubleAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -194,11 +240,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -215,6 +256,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void 
evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationFloatGroupingAggregatorFunction.java index df542b6171bca..b789cae8704a3 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationFloatGroupingAggregatorFunction.java @@ -141,6 +141,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MedianAbsoluteDeviationFloatAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -173,6 +196,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MedianAbsoluteDeviationFloatAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -194,11 +240,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - 
state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -215,6 +256,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunction.java index a389be2aea830..6cc6271982921 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunction.java @@ -140,6 +140,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MedianAbsoluteDeviationIntAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -172,6 +195,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MedianAbsoluteDeviationIntAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, 
scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -193,11 +239,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -214,6 +255,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunction.java index 00b4316cd6add..cccdec47b3030 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunction.java @@ -141,6 +141,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MedianAbsoluteDeviationLongAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -173,6 +196,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; 
groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MedianAbsoluteDeviationLongAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -194,11 +240,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -215,6 +256,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanGroupingAggregatorFunction.java index fc108b7417f56..52231c0e8975e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanGroupingAggregatorFunction.java @@ -139,6 +139,36 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + BooleanVector min = ((BooleanBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MinBooleanAggregator.combine(state.getOrDefault(groupId), min.getBoolean(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if 
(groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -171,6 +201,36 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVec } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + BooleanVector min = ((BooleanBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MinBooleanAggregator.combine(state.getOrDefault(groupId), min.getBoolean(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -192,11 +252,6 @@ private void addRawInput(int positionOffset, IntVector groups, BooleanVector val } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -220,6 +275,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBytesRefGroupingAggregatorFunction.java index f07fef1298d18..e7baef1459eb8 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBytesRefGroupingAggregatorFunction.java @@ -144,6 +144,35 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + BytesRefVector min = ((BytesRefBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + 
BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MinBytesRefAggregator.combineIntermediate(state, groupId, min.getBytesRef(groupPosition + positionOffset, scratch), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -178,6 +207,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + BytesRefVector min = ((BytesRefBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + MinBytesRefAggregator.combineIntermediate(state, groupId, min.getBytesRef(groupPosition + positionOffset, scratch), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -201,11 +259,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -228,6 +281,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleGroupingAggregatorFunction.java index 4f85003bbb5e0..ea1ecf6c1f271 100644 --- 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleGroupingAggregatorFunction.java @@ -141,6 +141,36 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + DoubleVector min = ((DoubleBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MinDoubleAggregator.combine(state.getOrDefault(groupId), min.getDouble(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -173,6 +203,36 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + DoubleVector min = ((DoubleBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MinDoubleAggregator.combine(state.getOrDefault(groupId), min.getDouble(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -194,11 +254,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public 
void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
     state.enableGroupIdTracking(new SeenGroupIds.Empty());
@@ -222,6 +277,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page
     }
   }

+  @Override
+  public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
+    state.enableGroupIdTracking(seenGroupIds);
+  }
+
   @Override
   public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
     state.toIntermediate(blocks, offset, selected, driverContext);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinFloatGroupingAggregatorFunction.java
index 14f83dd6a9af5..bf489b7bf6dc9 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinFloatGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinFloatGroupingAggregatorFunction.java
@@ -141,6 +141,36 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v
     }
   }

+  @Override
+  public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
+    assert channels.size() == intermediateBlockCount();
+    Block minUncast = page.getBlock(channels.get(0));
+    if (minUncast.areAllValuesNull()) {
+      return;
+    }
+    FloatVector min = ((FloatBlock) minUncast).asVector();
+    Block seenUncast = page.getBlock(channels.get(1));
+    if (seenUncast.areAllValuesNull()) {
+      return;
+    }
+    BooleanVector seen = ((BooleanBlock) seenUncast).asVector();
+    assert min.getPositionCount() == seen.getPositionCount();
+    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+      if (groups.isNull(groupPosition)) {
+        continue;
+      }
+      int groupStart = groups.getFirstValueIndex(groupPosition);
+      int groupEnd = groupStart + groups.getValueCount(groupPosition);
+      for (int g = groupStart; g < groupEnd; g++) {
+        int groupId = groups.getInt(g);
+        if (seen.getBoolean(groupPosition + positionOffset)) {
+          state.set(groupId, MinFloatAggregator.combine(state.getOrDefault(groupId), min.getFloat(groupPosition + positionOffset)));
+        }
+      }
+    }
+  }
+
   private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) {
@@ -173,6 +203,36 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto
     }
   }

+  @Override
+  public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
+    assert channels.size() == intermediateBlockCount();
+    Block minUncast = page.getBlock(channels.get(0));
+    if (minUncast.areAllValuesNull()) {
+      return;
+    }
+    FloatVector min = ((FloatBlock) minUncast).asVector();
+    Block seenUncast = page.getBlock(channels.get(1));
+    if (seenUncast.areAllValuesNull()) {
+      return;
+    }
+    BooleanVector seen = ((BooleanBlock) seenUncast).asVector();
+    assert min.getPositionCount() == seen.getPositionCount();
+    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+      if (groups.isNull(groupPosition)) {
+        continue;
+      }
+      int groupStart = groups.getFirstValueIndex(groupPosition);
+      int groupEnd = groupStart + groups.getValueCount(groupPosition);
+      for (int g = groupStart; g < groupEnd; g++) {
+        int groupId = groups.getInt(g);
+        if (seen.getBoolean(groupPosition + positionOffset)) {
+          state.set(groupId, MinFloatAggregator.combine(state.getOrDefault(groupId), min.getFloat(groupPosition + positionOffset)));
+        }
+      }
+    }
+  }
+
   private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (values.isNull(groupPosition + positionOffset)) {
@@ -194,11 +254,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value
     }
   }

-  @Override
-  public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
-    state.enableGroupIdTracking(seenGroupIds);
-  }
-
   @Override
   public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
     state.enableGroupIdTracking(new SeenGroupIds.Empty());
@@ -222,6 +277,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page
     }
   }

+  @Override
+  public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
+    state.enableGroupIdTracking(seenGroupIds);
+  }
+
   @Override
   public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
     state.toIntermediate(blocks, offset, selected, driverContext);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunction.java
index af555d4eedb63..51102c5dff22a 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunction.java
@@ -140,6 +140,36 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val
     }
   }

+  @Override
+  public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
+    assert channels.size() == intermediateBlockCount();
+    Block minUncast = page.getBlock(channels.get(0));
+    if (minUncast.areAllValuesNull()) {
+      return;
+    }
+    IntVector min = ((IntBlock) minUncast).asVector();
+    Block seenUncast = page.getBlock(channels.get(1));
+    if (seenUncast.areAllValuesNull()) {
+      return;
+    }
+    BooleanVector seen = ((BooleanBlock) seenUncast).asVector();
+    assert min.getPositionCount() == seen.getPositionCount();
+    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+      if (groups.isNull(groupPosition)) {
+        continue;
+      }
+      int groupStart = groups.getFirstValueIndex(groupPosition);
+      int groupEnd = groupStart + groups.getValueCount(groupPosition);
+      for (int g = groupStart; g < groupEnd; g++) {
+        int groupId = groups.getInt(g);
+        if (seen.getBoolean(groupPosition + positionOffset)) {
+          state.set(groupId, MinIntAggregator.combine(state.getOrDefault(groupId), min.getInt(groupPosition + positionOffset)));
+        }
+      }
+    }
+  }
+
   private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) {
@@ -172,6 +202,36 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector
     }
   }

+  @Override
+  public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
+    assert channels.size() == intermediateBlockCount();
+    Block minUncast = page.getBlock(channels.get(0));
+    if (minUncast.areAllValuesNull()) {
+      return;
+    }
+    IntVector min = ((IntBlock) minUncast).asVector();
+    Block seenUncast = page.getBlock(channels.get(1));
+    if (seenUncast.areAllValuesNull()) {
+      return;
+    }
+    BooleanVector seen = ((BooleanBlock) seenUncast).asVector();
+    assert min.getPositionCount() == seen.getPositionCount();
+    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+      if (groups.isNull(groupPosition)) {
+        continue;
+      }
+      int groupStart = groups.getFirstValueIndex(groupPosition);
+      int groupEnd = groupStart + groups.getValueCount(groupPosition);
+      for (int g = groupStart; g < groupEnd; g++) {
+        int groupId = groups.getInt(g);
+        if (seen.getBoolean(groupPosition + positionOffset)) {
+          state.set(groupId, MinIntAggregator.combine(state.getOrDefault(groupId), min.getInt(groupPosition + positionOffset)));
+        }
+      }
+    }
+  }
+
   private void addRawInput(int positionOffset, IntVector groups, IntBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (values.isNull(groupPosition + positionOffset)) {
@@ -193,11 +253,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values)
     }
   }

-  @Override
-  public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
-    state.enableGroupIdTracking(seenGroupIds);
-  }
-
   @Override
   public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
     state.enableGroupIdTracking(new SeenGroupIds.Empty());
@@ -221,6 +276,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page
     }
   }

+  @Override
+  public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
+    state.enableGroupIdTracking(seenGroupIds);
+  }
+
   @Override
   public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
     state.toIntermediate(blocks, offset, selected, driverContext);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIpGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIpGroupingAggregatorFunction.java
index 51bd9e85fdcad..542f744c04a8a 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIpGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIpGroupingAggregatorFunction.java
@@ -144,6 +144,35 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto
     }
   }

+  @Override
+  public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
+    assert channels.size() == intermediateBlockCount();
+    Block maxUncast = page.getBlock(channels.get(0));
+    if (maxUncast.areAllValuesNull()) {
+      return;
+    }
+    BytesRefVector max = ((BytesRefBlock) maxUncast).asVector();
+    Block seenUncast = page.getBlock(channels.get(1));
+    if (seenUncast.areAllValuesNull()) {
+      return;
+    }
+    BooleanVector seen = ((BooleanBlock) seenUncast).asVector();
+    assert max.getPositionCount() == seen.getPositionCount();
+    BytesRef scratch = new BytesRef();
+    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+      if (groups.isNull(groupPosition)) {
+        continue;
+      }
+      int groupStart = groups.getFirstValueIndex(groupPosition);
+      int groupEnd = groupStart + groups.getValueCount(groupPosition);
+      for (int g = groupStart; g < groupEnd; g++) {
+        int groupId = groups.getInt(g);
+        MinIpAggregator.combineIntermediate(state, groupId, max.getBytesRef(groupPosition + positionOffset, scratch), seen.getBoolean(groupPosition + positionOffset));
+      }
+    }
+  }
+
   private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) {
     BytesRef scratch = new BytesRef();
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
@@ -178,6 +207,35 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe
     }
   }

+  @Override
+  public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
+    assert channels.size() == intermediateBlockCount();
+    Block maxUncast = page.getBlock(channels.get(0));
+    if (maxUncast.areAllValuesNull()) {
+      return;
+    }
+    BytesRefVector max = ((BytesRefBlock) maxUncast).asVector();
+    Block seenUncast = page.getBlock(channels.get(1));
+    if (seenUncast.areAllValuesNull()) {
+      return;
+    }
+    BooleanVector seen = ((BooleanBlock) seenUncast).asVector();
+    assert max.getPositionCount() == seen.getPositionCount();
+    BytesRef scratch = new BytesRef();
+    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+      if (groups.isNull(groupPosition)) {
+        continue;
+      }
+      int groupStart = groups.getFirstValueIndex(groupPosition);
+      int groupEnd = groupStart + groups.getValueCount(groupPosition);
+      for (int g = groupStart; g < groupEnd; g++) {
+        int groupId = groups.getInt(g);
+        MinIpAggregator.combineIntermediate(state, groupId, max.getBytesRef(groupPosition + positionOffset, scratch), seen.getBoolean(groupPosition + positionOffset));
+      }
+    }
+  }
+
   private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) {
     BytesRef scratch = new BytesRef();
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
@@ -201,11 +259,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va
     }
   }

-  @Override
-  public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
-    state.enableGroupIdTracking(seenGroupIds);
-  }
-
   @Override
   public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
     state.enableGroupIdTracking(new SeenGroupIds.Empty());
@@ -228,6 +281,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page
     }
   }

+  @Override
+  public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
+    state.enableGroupIdTracking(seenGroupIds);
+  }
+
   @Override
   public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
     state.toIntermediate(blocks, offset, selected, driverContext);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunction.java
index 590ee9c440e43..e5683a154285d 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunction.java
@@ -141,6 +141,36 @@
private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + LongVector min = ((LongBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MinLongAggregator.combine(state.getOrDefault(groupId), min.getLong(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -173,6 +203,36 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + LongVector min = ((LongBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MinLongAggregator.combine(state.getOrDefault(groupId), min.getLong(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -194,11 +254,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -222,6 +277,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void 
selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileDoubleGroupingAggregatorFunction.java index 7260e8fcd0546..4e88aa944f6b5 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileDoubleGroupingAggregatorFunction.java @@ -144,6 +144,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + PercentileDoubleAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -176,6 +199,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + PercentileDoubleAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -197,11 +243,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void 
selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -218,6 +259,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileFloatGroupingAggregatorFunction.java index dbff522fe6d12..04f057ff87cb8 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileFloatGroupingAggregatorFunction.java @@ -144,6 +144,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + PercentileFloatAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -176,6 +199,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + PercentileFloatAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, 
scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -197,11 +243,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -218,6 +259,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileIntGroupingAggregatorFunction.java index d1d16b9938942..402c928970893 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileIntGroupingAggregatorFunction.java @@ -143,6 +143,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + PercentileIntAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -175,6 +198,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if 
(groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + PercentileIntAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -196,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -217,6 +258,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunction.java index b9221f27db944..8509057d6202f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunction.java @@ -144,6 +144,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + PercentileLongAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -176,6 +199,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new 
SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block quartUncast = page.getBlock(channels.get(0)); + if (quartUncast.areAllValuesNull()) { + return; + } + BytesRefVector quart = ((BytesRefBlock) quartUncast).asVector(); + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + PercentileLongAggregator.combineIntermediate(state, groupId, quart.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -197,11 +243,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -218,6 +259,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunction.java index 0e256a1ed0012..79db8bb3401a1 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunction.java @@ -152,6 +152,44 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + Block sampleCountsUncast = page.getBlock(channels.get(2)); + if (sampleCountsUncast.areAllValuesNull()) { + return; + } + IntVector sampleCounts = ((IntBlock) sampleCountsUncast).asVector(); + Block resetsUncast = page.getBlock(channels.get(3)); + if (resetsUncast.areAllValuesNull()) { + return; + } + DoubleVector resets = ((DoubleBlock) resetsUncast).asVector(); + assert timestamps.getPositionCount() == values.getPositionCount() && timestamps.getPositionCount() == sampleCounts.getPositionCount() && 
timestamps.getPositionCount() == resets.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + RateDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, sampleCounts.getInt(groupPosition + positionOffset), resets.getDouble(groupPosition + positionOffset), groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -187,6 +225,44 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + Block sampleCountsUncast = page.getBlock(channels.get(2)); + if (sampleCountsUncast.areAllValuesNull()) { + return; + } + IntVector sampleCounts = ((IntBlock) sampleCountsUncast).asVector(); + Block resetsUncast = page.getBlock(channels.get(3)); + if (resetsUncast.areAllValuesNull()) { + return; + } + DoubleVector resets = ((DoubleBlock) resetsUncast).asVector(); + assert timestamps.getPositionCount() == values.getPositionCount() && timestamps.getPositionCount() == sampleCounts.getPositionCount() && timestamps.getPositionCount() == resets.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + RateDoubleAggregator.combineIntermediate(state, groupId, timestamps, values, sampleCounts.getInt(groupPosition + positionOffset), resets.getDouble(groupPosition + positionOffset), groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -211,11 +287,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -247,6 +318,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int 
offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateFloatGroupingAggregatorFunction.java index 9c9da99547262..892ed9c2eb25f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateFloatGroupingAggregatorFunction.java @@ -154,6 +154,44 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + FloatBlock values = (FloatBlock) valuesUncast; + Block sampleCountsUncast = page.getBlock(channels.get(2)); + if (sampleCountsUncast.areAllValuesNull()) { + return; + } + IntVector sampleCounts = ((IntBlock) sampleCountsUncast).asVector(); + Block resetsUncast = page.getBlock(channels.get(3)); + if (resetsUncast.areAllValuesNull()) { + return; + } + DoubleVector resets = ((DoubleBlock) resetsUncast).asVector(); + assert timestamps.getPositionCount() == values.getPositionCount() && timestamps.getPositionCount() == sampleCounts.getPositionCount() && timestamps.getPositionCount() == resets.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + RateFloatAggregator.combineIntermediate(state, groupId, timestamps, values, sampleCounts.getInt(groupPosition + positionOffset), resets.getDouble(groupPosition + positionOffset), groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -189,6 +227,44 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + FloatBlock values = (FloatBlock) valuesUncast; + Block sampleCountsUncast = page.getBlock(channels.get(2)); + if (sampleCountsUncast.areAllValuesNull()) { + return; + } + IntVector sampleCounts = ((IntBlock) sampleCountsUncast).asVector(); + Block resetsUncast = 
page.getBlock(channels.get(3)); + if (resetsUncast.areAllValuesNull()) { + return; + } + DoubleVector resets = ((DoubleBlock) resetsUncast).asVector(); + assert timestamps.getPositionCount() == values.getPositionCount() && timestamps.getPositionCount() == sampleCounts.getPositionCount() && timestamps.getPositionCount() == resets.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + RateFloatAggregator.combineIntermediate(state, groupId, timestamps, values, sampleCounts.getInt(groupPosition + positionOffset), resets.getDouble(groupPosition + positionOffset), groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -213,11 +289,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -249,6 +320,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateIntGroupingAggregatorFunction.java index 60b33edb71b30..bc8445cc9f069 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateIntGroupingAggregatorFunction.java @@ -152,6 +152,44 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + IntBlock values = (IntBlock) valuesUncast; + Block sampleCountsUncast = page.getBlock(channels.get(2)); + if (sampleCountsUncast.areAllValuesNull()) { + return; + } + IntVector sampleCounts = ((IntBlock) sampleCountsUncast).asVector(); + Block resetsUncast = page.getBlock(channels.get(3)); + if (resetsUncast.areAllValuesNull()) { + return; + } + DoubleVector resets = ((DoubleBlock) resetsUncast).asVector(); + assert timestamps.getPositionCount() == 
values.getPositionCount() && timestamps.getPositionCount() == sampleCounts.getPositionCount() && timestamps.getPositionCount() == resets.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + RateIntAggregator.combineIntermediate(state, groupId, timestamps, values, sampleCounts.getInt(groupPosition + positionOffset), resets.getDouble(groupPosition + positionOffset), groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -187,6 +225,44 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + IntBlock values = (IntBlock) valuesUncast; + Block sampleCountsUncast = page.getBlock(channels.get(2)); + if (sampleCountsUncast.areAllValuesNull()) { + return; + } + IntVector sampleCounts = ((IntBlock) sampleCountsUncast).asVector(); + Block resetsUncast = page.getBlock(channels.get(3)); + if (resetsUncast.areAllValuesNull()) { + return; + } + DoubleVector resets = ((DoubleBlock) resetsUncast).asVector(); + assert timestamps.getPositionCount() == values.getPositionCount() && timestamps.getPositionCount() == sampleCounts.getPositionCount() && timestamps.getPositionCount() == resets.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + RateIntAggregator.combineIntermediate(state, groupId, timestamps, values, sampleCounts.getInt(groupPosition + positionOffset), resets.getDouble(groupPosition + positionOffset), groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -211,11 +287,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values, } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -247,6 +318,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + 
state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java index 69e8d04fe55c4..16d8c8b0b2fa8 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java @@ -152,6 +152,44 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + LongBlock values = (LongBlock) valuesUncast; + Block sampleCountsUncast = page.getBlock(channels.get(2)); + if (sampleCountsUncast.areAllValuesNull()) { + return; + } + IntVector sampleCounts = ((IntBlock) sampleCountsUncast).asVector(); + Block resetsUncast = page.getBlock(channels.get(3)); + if (resetsUncast.areAllValuesNull()) { + return; + } + DoubleVector resets = ((DoubleBlock) resetsUncast).asVector(); + assert timestamps.getPositionCount() == values.getPositionCount() && timestamps.getPositionCount() == sampleCounts.getPositionCount() && timestamps.getPositionCount() == resets.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + RateLongAggregator.combineIntermediate(state, groupId, timestamps, values, sampleCounts.getInt(groupPosition + positionOffset), resets.getDouble(groupPosition + positionOffset), groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -187,6 +225,44 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block timestampsUncast = page.getBlock(channels.get(0)); + if (timestampsUncast.areAllValuesNull()) { + return; + } + LongBlock timestamps = (LongBlock) timestampsUncast; + Block valuesUncast = page.getBlock(channels.get(1)); + if (valuesUncast.areAllValuesNull()) { + return; + } + LongBlock values = (LongBlock) valuesUncast; + Block sampleCountsUncast = page.getBlock(channels.get(2)); + if (sampleCountsUncast.areAllValuesNull()) { + return; + } + 
IntVector sampleCounts = ((IntBlock) sampleCountsUncast).asVector(); + Block resetsUncast = page.getBlock(channels.get(3)); + if (resetsUncast.areAllValuesNull()) { + return; + } + DoubleVector resets = ((DoubleBlock) resetsUncast).asVector(); + assert timestamps.getPositionCount() == values.getPositionCount() && timestamps.getPositionCount() == sampleCounts.getPositionCount() && timestamps.getPositionCount() == resets.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + RateLongAggregator.combineIntermediate(state, groupId, timestamps, values, sampleCounts.getInt(groupPosition + positionOffset), resets.getDouble(groupPosition + positionOffset), groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -211,11 +287,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -247,6 +318,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBooleanGroupingAggregatorFunction.java index b52c485449f99..5ddd5fc7e4a30 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBooleanGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBooleanGroupingAggregatorFunction.java @@ -143,6 +143,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleBooleanAggregator.combineIntermediate(state, groupId, sample, 
groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -175,6 +198,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVec } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleBooleanAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -196,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, BooleanVector val } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -217,6 +258,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBytesRefGroupingAggregatorFunction.java index 555204535fe4f..0ce4aa997db6e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBytesRefGroupingAggregatorFunction.java @@ -144,6 +144,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if 
(groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleBytesRefAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -178,6 +201,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleBytesRefAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -201,11 +247,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -222,6 +263,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleDoubleGroupingAggregatorFunction.java index 610a292fe253e..05e1dc8ba1783 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleDoubleGroupingAggregatorFunction.java @@ -143,6 +143,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + 
if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleDoubleAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -175,6 +198,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleDoubleAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -196,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -217,6 +258,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleIntGroupingAggregatorFunction.java index 8d138067d4d6f..9935a7bfe654d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleIntGroupingAggregatorFunction.java @@ -142,6 +142,29 @@ private void addRawInput(int positionOffset, IntArrayBlock 
groups, IntVector val } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleIntAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -174,6 +197,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleIntAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -195,11 +241,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -216,6 +257,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleLongGroupingAggregatorFunction.java index 73664865c1af8..7570225ce5f38 100644 --- 
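Note that even the double/int/long Sample variants read their intermediate state through a `BytesRefBlock`: the per-group sample travels between pipeline stages as serialized bytes rather than as typed values. The patch does not show what `combineIntermediate` does with those bytes; purely for intuition, here is a toy bounded-reservoir step applied to one incoming partial, where `LIMIT`, the `seen` counter, and the replacement rule are all invented for illustration and are not the actual `SampleBytesRefAggregator` logic:

```java
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

// Toy model only: classic reservoir sampling applied to one incoming partial.
// A faithful merge of two reservoirs would also weight by each partial's own
// seen count; that detail is elided here.
final class SampleMergeSketch {
    static final int LIMIT = 10;

    /** Returns the updated count of values this group has seen across all partials. */
    static long combineIntermediate(List<byte[]> groupState, long seen, List<byte[]> incoming) {
        for (byte[] value : incoming) {
            seen++;
            if (groupState.size() < LIMIT) {
                groupState.add(value); // reservoir not yet full
            } else {
                long slot = ThreadLocalRandom.current().nextLong(seen); // uniform in [0, seen)
                if (slot < LIMIT) {
                    groupState.set((int) slot, value); // keep with probability LIMIT/seen
                }
            }
        }
        return seen;
    }
}
```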
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleLongGroupingAggregatorFunction.java @@ -143,6 +143,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleLongAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -175,6 +198,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleLongAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -196,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -217,6 +258,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java index fdbec04aaa8b6..4dd4649472948 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevDoubleGroupingAggregatorFunction.java @@ -142,6 +142,39 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDevDoubleAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -174,6 +207,39 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = 
groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDevDoubleAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -195,11 +261,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -226,6 +287,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java index 2a3bdae362f9d..c78ea039edb63 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevFloatGroupingAggregatorFunction.java @@ -144,6 +144,39 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDevFloatAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); 
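The StdDev intermediate layout is a `(mean, m2, count)` triple per group, i.e. Welford partial states, and merging two such states is the standard parallel-variance combine (Chan et al.). A self-contained sketch of that merge, with field names ours rather than the actual state class's:

```java
// Sketch of merging two Welford (mean, m2, count) partial states, which is
// what the (mean, m2, count) intermediate layout above implies.
final class WelfordSketch {
    double mean, m2;
    long count;

    void combineIntermediate(double otherMean, double otherM2, long otherCount) {
        if (otherCount == 0) {
            return; // empty partial contributes nothing
        }
        long total = count + otherCount;
        double delta = otherMean - mean;
        // Chan et al. parallel merge: shift the mean by the weighted delta,
        // then fold the cross-term into m2.
        mean += delta * otherCount / total;
        m2 += otherM2 + delta * delta * ((double) count * otherCount / total);
        count = total;
    }

    double variance() { return count < 2 ? 0 : m2 / (count - 1); } // sample variance
    double stdDev()   { return Math.sqrt(variance()); }
}
```

The cross-term `delta^2 * count_a * count_b / total` is what keeps the merge numerically equivalent to having streamed every value through a single accumulator.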
groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -176,6 +209,39 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDevFloatAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -197,11 +263,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -228,6 +289,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java index 05eae387ac49a..32839ee533cbf 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevIntGroupingAggregatorFunction.java @@ -143,6 +143,39 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = 
page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDevIntAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -175,6 +208,39 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDevIntAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -196,11 +262,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -227,6 +288,11 
@@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java index 755eea896377b..e06207363bbc6 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/StdDevLongGroupingAggregatorFunction.java @@ -142,6 +142,39 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDevLongAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -174,6 +207,39 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block meanUncast = page.getBlock(channels.get(0)); + if (meanUncast.areAllValuesNull()) { + return; + } + DoubleVector mean = ((DoubleBlock) meanUncast).asVector(); + Block m2Uncast = page.getBlock(channels.get(1)); + if (m2Uncast.areAllValuesNull()) { + return; + } + DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector(); + Block countUncast = page.getBlock(channels.get(2)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert 
mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + StdDevLongAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -195,11 +261,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -226,6 +287,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumDoubleGroupingAggregatorFunction.java index a3ffe05c41256..88d147c3fd451 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumDoubleGroupingAggregatorFunction.java @@ -142,6 +142,39 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valueUncast = page.getBlock(channels.get(0)); + if (valueUncast.areAllValuesNull()) { + return; + } + DoubleVector value = ((DoubleBlock) valueUncast).asVector(); + Block deltaUncast = page.getBlock(channels.get(1)); + if (deltaUncast.areAllValuesNull()) { + return; + } + DoubleVector delta = ((DoubleBlock) deltaUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(2)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert value.getPositionCount() == delta.getPositionCount() && value.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = 
groups.getInt(g); + SumDoubleAggregator.combineIntermediate(state, groupId, value.getDouble(groupPosition + positionOffset), delta.getDouble(groupPosition + positionOffset), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -174,6 +207,39 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valueUncast = page.getBlock(channels.get(0)); + if (valueUncast.areAllValuesNull()) { + return; + } + DoubleVector value = ((DoubleBlock) valueUncast).asVector(); + Block deltaUncast = page.getBlock(channels.get(1)); + if (deltaUncast.areAllValuesNull()) { + return; + } + DoubleVector delta = ((DoubleBlock) deltaUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(2)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert value.getPositionCount() == delta.getPositionCount() && value.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SumDoubleAggregator.combineIntermediate(state, groupId, value.getDouble(groupPosition + positionOffset), delta.getDouble(groupPosition + positionOffset), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -195,11 +261,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -226,6 +287,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumFloatGroupingAggregatorFunction.java index 61dc158374e27..d7f0f6185d318 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumFloatGroupingAggregatorFunction.java +++ 
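SumDouble and SumFloat carry a `(value, delta, seen)` triple: a compensated, Kahan-style sum in which `delta` holds the rounding error accumulated so far and `seen` distinguishes an empty partial from a genuine sum of zero. A sketch of a combine consistent with that layout, with field names illustrative rather than taken from the real state class:

```java
// Sketch of a Kahan-style compensated sum matching the (value, delta, seen)
// intermediate layout in SumDouble/SumFloatGroupingAggregatorFunction.
final class KahanSumSketch {
    double value;  // running sum
    double delta;  // low-order bits lost to rounding so far
    boolean seen;  // distinguishes "no rows" from a sum of 0.0

    void combineIntermediate(double otherValue, double otherDelta, boolean otherSeen) {
        if (otherSeen == false) {
            return; // empty partials never touch state
        }
        add(otherValue);
        add(otherDelta); // fold the partial's own compensation back in
        seen = true;
    }

    private void add(double v) {
        double correctedNext = v + delta;
        double newValue = value + correctedNext;
        // whatever part of correctedNext did not survive the add is kept for later
        delta = correctedNext - (newValue - value);
        value = newValue;
    }
}
```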
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumFloatGroupingAggregatorFunction.java @@ -144,6 +144,39 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valueUncast = page.getBlock(channels.get(0)); + if (valueUncast.areAllValuesNull()) { + return; + } + DoubleVector value = ((DoubleBlock) valueUncast).asVector(); + Block deltaUncast = page.getBlock(channels.get(1)); + if (deltaUncast.areAllValuesNull()) { + return; + } + DoubleVector delta = ((DoubleBlock) deltaUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(2)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert value.getPositionCount() == delta.getPositionCount() && value.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SumFloatAggregator.combineIntermediate(state, groupId, value.getDouble(groupPosition + positionOffset), delta.getDouble(groupPosition + positionOffset), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -176,6 +209,39 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valueUncast = page.getBlock(channels.get(0)); + if (valueUncast.areAllValuesNull()) { + return; + } + DoubleVector value = ((DoubleBlock) valueUncast).asVector(); + Block deltaUncast = page.getBlock(channels.get(1)); + if (deltaUncast.areAllValuesNull()) { + return; + } + DoubleVector delta = ((DoubleBlock) deltaUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(2)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert value.getPositionCount() == delta.getPositionCount() && value.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SumFloatAggregator.combineIntermediate(state, groupId, value.getDouble(groupPosition + positionOffset), delta.getDouble(groupPosition + positionOffset), seen.getBoolean(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; 
groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -197,11 +263,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -228,6 +289,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunction.java index c8b12478255d2..05b29459d1e02 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunction.java @@ -142,6 +142,36 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sumUncast = page.getBlock(channels.get(0)); + if (sumUncast.areAllValuesNull()) { + return; + } + LongVector sum = ((LongBlock) sumUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert sum.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, SumIntAggregator.combine(state.getOrDefault(groupId), sum.getLong(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -174,6 +204,36 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sumUncast = page.getBlock(channels.get(0)); + if (sumUncast.areAllValuesNull()) { + return; + } + LongVector sum = ((LongBlock) sumUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if 
(seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert sum.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, SumIntAggregator.combine(state.getOrDefault(groupId), sum.getLong(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -195,11 +255,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -223,6 +278,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java index 550f504e61c23..31779335e80c0 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java @@ -141,6 +141,36 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sumUncast = page.getBlock(channels.get(0)); + if (sumUncast.areAllValuesNull()) { + return; + } + LongVector sum = ((LongBlock) sumUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert sum.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), sum.getLong(groupPosition + positionOffset))); + } + } + 
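The integer sums need no compensation; their intermediate is just `(sum, seen)`, and the generated combine only touches state when `seen` is true, so empty partials never create a group entry. A minimal stand-in, with a plain `long[]` in place of the real `LongState` and the overflow behavior being our assumption rather than `SumIntAggregator`/`SumLongAggregator`'s documented contract:

```java
// Sketch mirroring the seen-guarded combine in the SumInt/SumLong functions.
final class LongSumSketch {
    static void combineIntermediate(long[] state, int groupId, long sum, boolean seen) {
        if (seen) {
            // Math.addExact is our choice here; it throws ArithmeticException on overflow.
            state[groupId] = Math.addExact(state[groupId], sum);
        }
    }
}
```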
} + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -173,6 +203,36 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sumUncast = page.getBlock(channels.get(0)); + if (sumUncast.areAllValuesNull()) { + return; + } + LongVector sum = ((LongBlock) sumUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert sum.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), sum.getLong(groupPosition + positionOffset))); + } + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -194,11 +254,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -222,6 +277,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanGroupingAggregatorFunction.java index 77c61fce3cc37..f6238670a776a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanGroupingAggregatorFunction.java @@ -145,6 +145,28 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + 
BooleanBlock top = (BooleanBlock) topUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + TopBooleanAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -177,6 +199,28 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVec } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + BooleanBlock top = (BooleanBlock) topUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + TopBooleanAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -198,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, BooleanVector val } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -218,6 +257,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunction.java index 02fbb1bda8383..12f1456327264 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunction.java @@ -148,6 +148,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + 
state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + BytesRefBlock top = (BytesRefBlock) topUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + TopBytesRefAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -182,6 +205,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + BytesRefBlock top = (BytesRefBlock) topUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + TopBytesRefAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -205,11 +251,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -226,6 +267,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleGroupingAggregatorFunction.java index cf53719168731..11ba0cbea0d6b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleGroupingAggregatorFunction.java @@ -145,6 
+145,28 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector
         }
     }

+    @Override
+    public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) {
+        state.enableGroupIdTracking(new SeenGroupIds.Empty());
+        assert channels.size() == intermediateBlockCount();
+        Block topUncast = page.getBlock(channels.get(0));
+        if (topUncast.areAllValuesNull()) {
+            return;
+        }
+        DoubleBlock top = (DoubleBlock) topUncast;
+        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+            if (groups.isNull(groupPosition)) {
+                continue;
+            }
+            int groupStart = groups.getFirstValueIndex(groupPosition);
+            int groupEnd = groupStart + groups.getValueCount(groupPosition);
+            for (int g = groupStart; g < groupEnd; g++) {
+                int groupId = groups.getInt(g);
+                TopDoubleAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset);
+            }
+        }
+    }
+
     private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) {
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
             if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) {
@@ -177,6 +199,28 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect
         }
     }

+    @Override
+    public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) {
+        state.enableGroupIdTracking(new SeenGroupIds.Empty());
+        assert channels.size() == intermediateBlockCount();
+        Block topUncast = page.getBlock(channels.get(0));
+        if (topUncast.areAllValuesNull()) {
+            return;
+        }
+        DoubleBlock top = (DoubleBlock) topUncast;
+        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+            if (groups.isNull(groupPosition)) {
+                continue;
+            }
+            int groupStart = groups.getFirstValueIndex(groupPosition);
+            int groupEnd = groupStart + groups.getValueCount(groupPosition);
+            for (int g = groupStart; g < groupEnd; g++) {
+                int groupId = groups.getInt(g);
+                TopDoubleAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset);
+            }
+        }
+    }
+
     private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) {
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
             if (values.isNull(groupPosition + positionOffset)) {
@@ -198,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu
         }
     }

-    @Override
-    public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
-        state.enableGroupIdTracking(seenGroupIds);
-    }
-
     @Override
     public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
         state.enableGroupIdTracking(new SeenGroupIds.Empty());
@@ -218,6 +257,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page
         }
     }

+    @Override
+    public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
+        state.enableGroupIdTracking(seenGroupIds);
+    }
+
     @Override
     public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
         state.toIntermediate(blocks, offset, selected, driverContext);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatGroupingAggregatorFunction.java
index 30a8c5242c20a..32dfcaaffcde9 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatGroupingAggregatorFunction.java
@@ -145,6 +145,28 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v
         }
     }

+    @Override
+    public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) {
+        state.enableGroupIdTracking(new SeenGroupIds.Empty());
+        assert channels.size() == intermediateBlockCount();
+        Block topUncast = page.getBlock(channels.get(0));
+        if (topUncast.areAllValuesNull()) {
+            return;
+        }
+        FloatBlock top = (FloatBlock) topUncast;
+        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+            if (groups.isNull(groupPosition)) {
+                continue;
+            }
+            int groupStart = groups.getFirstValueIndex(groupPosition);
+            int groupEnd = groupStart + groups.getValueCount(groupPosition);
+            for (int g = groupStart; g < groupEnd; g++) {
+                int groupId = groups.getInt(g);
+                TopFloatAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset);
+            }
+        }
+    }
+
     private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) {
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
             if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) {
@@ -177,6 +199,28 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto
         }
     }

+    @Override
+    public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) {
+        state.enableGroupIdTracking(new SeenGroupIds.Empty());
+        assert channels.size() == intermediateBlockCount();
+        Block topUncast = page.getBlock(channels.get(0));
+        if (topUncast.areAllValuesNull()) {
+            return;
+        }
+        FloatBlock top = (FloatBlock) topUncast;
+        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+            if (groups.isNull(groupPosition)) {
+                continue;
+            }
+            int groupStart = groups.getFirstValueIndex(groupPosition);
+            int groupEnd = groupStart + groups.getValueCount(groupPosition);
+            for (int g = groupStart; g < groupEnd; g++) {
+                int groupId = groups.getInt(g);
+                TopFloatAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset);
+            }
+        }
+    }
+
     private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) {
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
             if (values.isNull(groupPosition + positionOffset)) {
@@ -198,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value
         }
     }

-    @Override
-    public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
-        state.enableGroupIdTracking(seenGroupIds);
-    }
-
     @Override
     public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
         state.enableGroupIdTracking(new SeenGroupIds.Empty());
@@ -218,6 +257,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page
         }
     }

+    @Override
+    public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
+        state.enableGroupIdTracking(seenGroupIds);
+    }
+
     @Override
     public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
         state.toIntermediate(blocks, offset, selected, driverContext);
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntGroupingAggregatorFunction.java
index 1b7678fb3bd0e..1a0dea4b8d0eb 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntGroupingAggregatorFunction.java
@@ -144,6 +144,28 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val
         }
     }

+    @Override
+    public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) {
+        state.enableGroupIdTracking(new SeenGroupIds.Empty());
+        assert channels.size() == intermediateBlockCount();
+        Block topUncast = page.getBlock(channels.get(0));
+        if (topUncast.areAllValuesNull()) {
+            return;
+        }
+        IntBlock top = (IntBlock) topUncast;
+        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+            if (groups.isNull(groupPosition)) {
+                continue;
+            }
+            int groupStart = groups.getFirstValueIndex(groupPosition);
+            int groupEnd = groupStart + groups.getValueCount(groupPosition);
+            for (int g = groupStart; g < groupEnd; g++) {
+                int groupId = groups.getInt(g);
+                TopIntAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset);
+            }
+        }
+    }
+
     private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) {
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
             if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) {
@@ -176,6 +198,28 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector
         }
     }

+    @Override
+    public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) {
+        state.enableGroupIdTracking(new SeenGroupIds.Empty());
+        assert channels.size() == intermediateBlockCount();
+        Block topUncast = page.getBlock(channels.get(0));
+        if (topUncast.areAllValuesNull()) {
+            return;
+        }
+        IntBlock top = (IntBlock) topUncast;
+        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+            if (groups.isNull(groupPosition)) {
+                continue;
+            }
+            int groupStart = groups.getFirstValueIndex(groupPosition);
+            int groupEnd = groupStart + groups.getValueCount(groupPosition);
+            for (int g = groupStart; g < groupEnd; g++) {
+                int groupId = groups.getInt(g);
+                TopIntAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset);
+            }
+        }
+    }
+
     private void addRawInput(int positionOffset, IntVector groups, IntBlock values) {
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
             if (values.isNull(groupPosition + positionOffset)) {
@@ -197,11 +241,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values)
         }
     }

-    @Override
-    public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
-        state.enableGroupIdTracking(seenGroupIds);
-    }
-
     @Override
     public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
         state.enableGroupIdTracking(new SeenGroupIds.Empty());
@@ -217,6 +256,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page
         }
     }

+    @Override
+    public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
+        state.enableGroupIdTracking(seenGroupIds);
+    }
+
     @Override
     public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
         state.toIntermediate(blocks, offset, selected, driverContext);
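Every generated class in this change gains the same pair of `addIntermediateInput` overloads, one per multi-valued group-block type, differing only in the block types and the per-type `combineIntermediate` call. A minimal sketch of that shared shape, assuming a hypothetical `GroupIdList` stand-in for the `IntArrayBlock`/`IntBigArrayBlock` group blocks and a toy `combineIntermediate` (the real aggregators merge typed intermediate state instead):

```
// Sketch only: GroupIdList and the long[]-based "state"/"top" are hypothetical
// stand-ins for the real group blocks, aggregator state, and typed value blocks.
public class AddIntermediateInputSketch {
    /** Hypothetical multi-valued group block: each position maps to 0..n group ids. */
    interface GroupIdList {
        int getPositionCount();
        boolean isNull(int position);
        int getFirstValueIndex(int position);
        int getValueCount(int position);
        int getInt(int valueIndex);
    }

    /** Toy merge; the generated code calls e.g. TopDoubleAggregator.combineIntermediate. */
    static void combineIntermediate(long[] state, int groupId, long[] top, int valuePosition) {
        state[groupId] = Math.max(state[groupId], top[valuePosition]);
    }

    static void addIntermediateInput(int positionOffset, GroupIdList groups, long[] top, long[] state) {
        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
            if (groups.isNull(groupPosition)) {
                continue; // no group ids at this position, nothing to merge
            }
            int groupStart = groups.getFirstValueIndex(groupPosition);
            int groupEnd = groupStart + groups.getValueCount(groupPosition);
            for (int g = groupStart; g < groupEnd; g++) {
                int groupId = groups.getInt(g);
                // the same intermediate value is folded into every group id at this position
                combineIntermediate(state, groupId, top, groupPosition + positionOffset);
            }
        }
    }
}
```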
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunction.java
index 56a8de3d94d04..ad0e75b625e3d 100644
--- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunction.java
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunction.java
@@ -148,6 +148,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto
         }
     }

+    @Override
+    public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) {
+        state.enableGroupIdTracking(new SeenGroupIds.Empty());
+        assert channels.size() == intermediateBlockCount();
+        Block topUncast = page.getBlock(channels.get(0));
+        if (topUncast.areAllValuesNull()) {
+            return;
+        }
+        BytesRefBlock top = (BytesRefBlock) topUncast;
+        BytesRef scratch = new BytesRef();
+        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+            if (groups.isNull(groupPosition)) {
+                continue;
+            }
+            int groupStart = groups.getFirstValueIndex(groupPosition);
+            int groupEnd = groupStart + groups.getValueCount(groupPosition);
+            for (int g = groupStart; g < groupEnd; g++) {
+                int groupId = groups.getInt(g);
+                TopIpAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset);
+            }
+        }
+    }
+
     private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) {
         BytesRef scratch = new BytesRef();
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
@@ -182,6 +205,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe
         }
     }

+    @Override
+    public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) {
+        state.enableGroupIdTracking(new SeenGroupIds.Empty());
+        assert channels.size() == intermediateBlockCount();
+        Block topUncast = page.getBlock(channels.get(0));
+        if (topUncast.areAllValuesNull()) {
+            return;
+        }
+        BytesRefBlock top = (BytesRefBlock) topUncast;
+        BytesRef scratch = new BytesRef();
+        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+            if (groups.isNull(groupPosition)) {
+                continue;
+            }
+            int groupStart = groups.getFirstValueIndex(groupPosition);
+            int groupEnd = groupStart + groups.getValueCount(groupPosition);
+            for (int g = groupStart; g < groupEnd; g++) {
+                int groupId = groups.getInt(g);
+                TopIpAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset);
+            }
+        }
+    }
+
     private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) {
         BytesRef scratch = new BytesRef();
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
@@ -205,11 +251,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va
         }
     }

-    @Override
-    public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
-        state.enableGroupIdTracking(seenGroupIds);
-    }
-
     @Override
     public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
         state.enableGroupIdTracking(new SeenGroupIds.Empty());
@@ -226,6 +267,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page
         }
     }

+    @Override
+    public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
+
state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongGroupingAggregatorFunction.java index 74b9824739085..71e17e29be5fb 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongGroupingAggregatorFunction.java @@ -145,6 +145,28 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + TopLongAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -177,6 +199,28 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + TopLongAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -198,11 +242,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -218,6 +257,11 @@ public void 
addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBooleanGroupingAggregatorFunction.java index fe0c0632deba7..896d037cf68fb 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBooleanGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBooleanGroupingAggregatorFunction.java @@ -138,6 +138,28 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + BooleanBlock values = (BooleanBlock) valuesUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + ValuesBooleanAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -170,6 +192,28 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVec } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + BooleanBlock values = (BooleanBlock) valuesUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + ValuesBooleanAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -191,11 +235,6 @@ private void addRawInput(int positionOffset, IntVector groups, BooleanVector val } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - 
state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -211,6 +250,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java index e240452b3a18d..da8e93f9cf61a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java @@ -143,6 +143,29 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + BytesRefBlock values = (BytesRefBlock) valuesUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + ValuesBytesRefAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -177,6 +200,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + BytesRefBlock values = (BytesRefBlock) valuesUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + ValuesBytesRefAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; 
groupPosition < groups.getPositionCount(); groupPosition++) { @@ -200,11 +246,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -221,6 +262,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java index 8408526adb449..3a35f48fee5f8 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesDoubleGroupingAggregatorFunction.java @@ -138,6 +138,28 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + ValuesDoubleAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -170,6 +192,28 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVect } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + DoubleBlock values = (DoubleBlock) valuesUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + ValuesDoubleAggregator.combineIntermediate(state, groupId, 
values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -191,11 +235,6 @@ private void addRawInput(int positionOffset, IntVector groups, DoubleVector valu } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -211,6 +250,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java index 0be289943b38f..4917f61a23f8d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesFloatGroupingAggregatorFunction.java @@ -138,6 +138,28 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector v } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + FloatBlock values = (FloatBlock) valuesUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + ValuesFloatAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -170,6 +192,28 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + FloatBlock values = (FloatBlock) valuesUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); 
+ int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + ValuesFloatAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -191,11 +235,6 @@ private void addRawInput(int positionOffset, IntVector groups, FloatVector value } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -211,6 +250,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java index f7a8b45d47c8e..d0e094099af4e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesIntGroupingAggregatorFunction.java @@ -137,6 +137,28 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + IntBlock values = (IntBlock) valuesUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + ValuesIntAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -169,6 +191,28 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + IntBlock values = (IntBlock) valuesUncast; + for (int 
groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + ValuesIntAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -190,11 +234,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -210,6 +249,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java index b8793982223dc..287013a1dc136 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesLongGroupingAggregatorFunction.java @@ -138,6 +138,28 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + LongBlock values = (LongBlock) valuesUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + ValuesLongAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -170,6 +192,28 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == 
intermediateBlockCount(); + Block valuesUncast = page.getBlock(channels.get(0)); + if (valuesUncast.areAllValuesNull()) { + return; + } + LongBlock values = (LongBlock) valuesUncast; + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + ValuesLongAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -191,11 +235,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -211,6 +250,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointDocValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointDocValuesGroupingAggregatorFunction.java index d2f1d88125321..5116ea389510a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointDocValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointDocValuesGroupingAggregatorFunction.java @@ -148,6 +148,49 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block xValUncast = page.getBlock(channels.get(0)); + if (xValUncast.areAllValuesNull()) { + return; + } + DoubleVector xVal = ((DoubleBlock) xValUncast).asVector(); + Block xDelUncast = page.getBlock(channels.get(1)); + if (xDelUncast.areAllValuesNull()) { + return; + } + DoubleVector xDel = ((DoubleBlock) xDelUncast).asVector(); + Block yValUncast = page.getBlock(channels.get(2)); + if (yValUncast.areAllValuesNull()) { + return; + } + DoubleVector yVal = ((DoubleBlock) yValUncast).asVector(); + Block yDelUncast = page.getBlock(channels.get(3)); + if (yDelUncast.areAllValuesNull()) { + return; + } + DoubleVector yDel = ((DoubleBlock) yDelUncast).asVector(); + Block countUncast = page.getBlock(channels.get(4)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert 
xVal.getPositionCount() == xDel.getPositionCount() && xVal.getPositionCount() == yVal.getPositionCount() && xVal.getPositionCount() == yDel.getPositionCount() && xVal.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialCentroidCartesianPointDocValuesAggregator.combineIntermediate(state, groupId, xVal.getDouble(groupPosition + positionOffset), xDel.getDouble(groupPosition + positionOffset), yVal.getDouble(groupPosition + positionOffset), yDel.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -180,6 +223,49 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block xValUncast = page.getBlock(channels.get(0)); + if (xValUncast.areAllValuesNull()) { + return; + } + DoubleVector xVal = ((DoubleBlock) xValUncast).asVector(); + Block xDelUncast = page.getBlock(channels.get(1)); + if (xDelUncast.areAllValuesNull()) { + return; + } + DoubleVector xDel = ((DoubleBlock) xDelUncast).asVector(); + Block yValUncast = page.getBlock(channels.get(2)); + if (yValUncast.areAllValuesNull()) { + return; + } + DoubleVector yVal = ((DoubleBlock) yValUncast).asVector(); + Block yDelUncast = page.getBlock(channels.get(3)); + if (yDelUncast.areAllValuesNull()) { + return; + } + DoubleVector yDel = ((DoubleBlock) yDelUncast).asVector(); + Block countUncast = page.getBlock(channels.get(4)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert xVal.getPositionCount() == xDel.getPositionCount() && xVal.getPositionCount() == yVal.getPositionCount() && xVal.getPositionCount() == yDel.getPositionCount() && xVal.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialCentroidCartesianPointDocValuesAggregator.combineIntermediate(state, groupId, xVal.getDouble(groupPosition + positionOffset), xDel.getDouble(groupPosition + positionOffset), yVal.getDouble(groupPosition + positionOffset), yDel.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -201,11 +287,6 @@ private void addRawInput(int positionOffset, 
IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -242,6 +323,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointSourceValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointSourceValuesGroupingAggregatorFunction.java index 7f1d838cda18c..e0508288abbe3 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidCartesianPointSourceValuesGroupingAggregatorFunction.java @@ -153,6 +153,49 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block xValUncast = page.getBlock(channels.get(0)); + if (xValUncast.areAllValuesNull()) { + return; + } + DoubleVector xVal = ((DoubleBlock) xValUncast).asVector(); + Block xDelUncast = page.getBlock(channels.get(1)); + if (xDelUncast.areAllValuesNull()) { + return; + } + DoubleVector xDel = ((DoubleBlock) xDelUncast).asVector(); + Block yValUncast = page.getBlock(channels.get(2)); + if (yValUncast.areAllValuesNull()) { + return; + } + DoubleVector yVal = ((DoubleBlock) yValUncast).asVector(); + Block yDelUncast = page.getBlock(channels.get(3)); + if (yDelUncast.areAllValuesNull()) { + return; + } + DoubleVector yDel = ((DoubleBlock) yDelUncast).asVector(); + Block countUncast = page.getBlock(channels.get(4)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert xVal.getPositionCount() == xDel.getPositionCount() && xVal.getPositionCount() == yVal.getPositionCount() && xVal.getPositionCount() == yDel.getPositionCount() && xVal.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialCentroidCartesianPointSourceValuesAggregator.combineIntermediate(state, groupId, xVal.getDouble(groupPosition + positionOffset), xDel.getDouble(groupPosition + positionOffset), yVal.getDouble(groupPosition + positionOffset), yDel.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void 
addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -187,6 +230,49 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block xValUncast = page.getBlock(channels.get(0)); + if (xValUncast.areAllValuesNull()) { + return; + } + DoubleVector xVal = ((DoubleBlock) xValUncast).asVector(); + Block xDelUncast = page.getBlock(channels.get(1)); + if (xDelUncast.areAllValuesNull()) { + return; + } + DoubleVector xDel = ((DoubleBlock) xDelUncast).asVector(); + Block yValUncast = page.getBlock(channels.get(2)); + if (yValUncast.areAllValuesNull()) { + return; + } + DoubleVector yVal = ((DoubleBlock) yValUncast).asVector(); + Block yDelUncast = page.getBlock(channels.get(3)); + if (yDelUncast.areAllValuesNull()) { + return; + } + DoubleVector yDel = ((DoubleBlock) yDelUncast).asVector(); + Block countUncast = page.getBlock(channels.get(4)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert xVal.getPositionCount() == xDel.getPositionCount() && xVal.getPositionCount() == yVal.getPositionCount() && xVal.getPositionCount() == yDel.getPositionCount() && xVal.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialCentroidCartesianPointSourceValuesAggregator.combineIntermediate(state, groupId, xVal.getDouble(groupPosition + positionOffset), xDel.getDouble(groupPosition + positionOffset), yVal.getDouble(groupPosition + positionOffset), yDel.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -210,11 +296,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -251,6 +332,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointDocValuesGroupingAggregatorFunction.java 
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointDocValuesGroupingAggregatorFunction.java index 5ba0a22f09c96..23936066d214b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointDocValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointDocValuesGroupingAggregatorFunction.java @@ -148,6 +148,49 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block xValUncast = page.getBlock(channels.get(0)); + if (xValUncast.areAllValuesNull()) { + return; + } + DoubleVector xVal = ((DoubleBlock) xValUncast).asVector(); + Block xDelUncast = page.getBlock(channels.get(1)); + if (xDelUncast.areAllValuesNull()) { + return; + } + DoubleVector xDel = ((DoubleBlock) xDelUncast).asVector(); + Block yValUncast = page.getBlock(channels.get(2)); + if (yValUncast.areAllValuesNull()) { + return; + } + DoubleVector yVal = ((DoubleBlock) yValUncast).asVector(); + Block yDelUncast = page.getBlock(channels.get(3)); + if (yDelUncast.areAllValuesNull()) { + return; + } + DoubleVector yDel = ((DoubleBlock) yDelUncast).asVector(); + Block countUncast = page.getBlock(channels.get(4)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert xVal.getPositionCount() == xDel.getPositionCount() && xVal.getPositionCount() == yVal.getPositionCount() && xVal.getPositionCount() == yDel.getPositionCount() && xVal.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialCentroidGeoPointDocValuesAggregator.combineIntermediate(state, groupId, xVal.getDouble(groupPosition + positionOffset), xDel.getDouble(groupPosition + positionOffset), yVal.getDouble(groupPosition + positionOffset), yDel.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -180,6 +223,49 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block xValUncast = page.getBlock(channels.get(0)); + if (xValUncast.areAllValuesNull()) { + return; + } + DoubleVector xVal = ((DoubleBlock) xValUncast).asVector(); + Block xDelUncast = page.getBlock(channels.get(1)); + if (xDelUncast.areAllValuesNull()) { + return; + } + DoubleVector xDel = ((DoubleBlock) xDelUncast).asVector(); + Block yValUncast = 
page.getBlock(channels.get(2)); + if (yValUncast.areAllValuesNull()) { + return; + } + DoubleVector yVal = ((DoubleBlock) yValUncast).asVector(); + Block yDelUncast = page.getBlock(channels.get(3)); + if (yDelUncast.areAllValuesNull()) { + return; + } + DoubleVector yDel = ((DoubleBlock) yDelUncast).asVector(); + Block countUncast = page.getBlock(channels.get(4)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert xVal.getPositionCount() == xDel.getPositionCount() && xVal.getPositionCount() == yVal.getPositionCount() && xVal.getPositionCount() == yDel.getPositionCount() && xVal.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialCentroidGeoPointDocValuesAggregator.combineIntermediate(state, groupId, xVal.getDouble(groupPosition + positionOffset), xDel.getDouble(groupPosition + positionOffset), yVal.getDouble(groupPosition + positionOffset), yDel.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -201,11 +287,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -242,6 +323,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointSourceValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointSourceValuesGroupingAggregatorFunction.java index f316218141411..2707f5c78cf62 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialCentroidGeoPointSourceValuesGroupingAggregatorFunction.java @@ -153,6 +153,49 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block xValUncast = page.getBlock(channels.get(0)); + if (xValUncast.areAllValuesNull()) { + return; + } + 
DoubleVector xVal = ((DoubleBlock) xValUncast).asVector(); + Block xDelUncast = page.getBlock(channels.get(1)); + if (xDelUncast.areAllValuesNull()) { + return; + } + DoubleVector xDel = ((DoubleBlock) xDelUncast).asVector(); + Block yValUncast = page.getBlock(channels.get(2)); + if (yValUncast.areAllValuesNull()) { + return; + } + DoubleVector yVal = ((DoubleBlock) yValUncast).asVector(); + Block yDelUncast = page.getBlock(channels.get(3)); + if (yDelUncast.areAllValuesNull()) { + return; + } + DoubleVector yDel = ((DoubleBlock) yDelUncast).asVector(); + Block countUncast = page.getBlock(channels.get(4)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert xVal.getPositionCount() == xDel.getPositionCount() && xVal.getPositionCount() == yVal.getPositionCount() && xVal.getPositionCount() == yDel.getPositionCount() && xVal.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialCentroidGeoPointSourceValuesAggregator.combineIntermediate(state, groupId, xVal.getDouble(groupPosition + positionOffset), xDel.getDouble(groupPosition + positionOffset), yVal.getDouble(groupPosition + positionOffset), yDel.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -187,6 +230,49 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block xValUncast = page.getBlock(channels.get(0)); + if (xValUncast.areAllValuesNull()) { + return; + } + DoubleVector xVal = ((DoubleBlock) xValUncast).asVector(); + Block xDelUncast = page.getBlock(channels.get(1)); + if (xDelUncast.areAllValuesNull()) { + return; + } + DoubleVector xDel = ((DoubleBlock) xDelUncast).asVector(); + Block yValUncast = page.getBlock(channels.get(2)); + if (yValUncast.areAllValuesNull()) { + return; + } + DoubleVector yVal = ((DoubleBlock) yValUncast).asVector(); + Block yDelUncast = page.getBlock(channels.get(3)); + if (yDelUncast.areAllValuesNull()) { + return; + } + DoubleVector yDel = ((DoubleBlock) yDelUncast).asVector(); + Block countUncast = page.getBlock(channels.get(4)); + if (countUncast.areAllValuesNull()) { + return; + } + LongVector count = ((LongBlock) countUncast).asVector(); + assert xVal.getPositionCount() == xDel.getPositionCount() && xVal.getPositionCount() == yVal.getPositionCount() && xVal.getPositionCount() == yDel.getPositionCount() && xVal.getPositionCount() == count.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < 
groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialCentroidGeoPointSourceValuesAggregator.combineIntermediate(state, groupId, xVal.getDouble(groupPosition + positionOffset), xDel.getDouble(groupPosition + positionOffset), yVal.getDouble(groupPosition + positionOffset), yDel.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -210,11 +296,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -251,6 +332,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointDocValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointDocValuesGroupingAggregatorFunction.java index 59b25a0e675ec..17c887a5e0035 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointDocValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointDocValuesGroupingAggregatorFunction.java @@ -146,6 +146,44 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minXUncast = page.getBlock(channels.get(0)); + if (minXUncast.areAllValuesNull()) { + return; + } + IntVector minX = ((IntBlock) minXUncast).asVector(); + Block maxXUncast = page.getBlock(channels.get(1)); + if (maxXUncast.areAllValuesNull()) { + return; + } + IntVector maxX = ((IntBlock) maxXUncast).asVector(); + Block maxYUncast = page.getBlock(channels.get(2)); + if (maxYUncast.areAllValuesNull()) { + return; + } + IntVector maxY = ((IntBlock) maxYUncast).asVector(); + Block minYUncast = page.getBlock(channels.get(3)); + if (minYUncast.areAllValuesNull()) { + return; + } + IntVector minY = ((IntBlock) minYUncast).asVector(); + assert minX.getPositionCount() == maxX.getPositionCount() && minX.getPositionCount() == maxY.getPositionCount() && minX.getPositionCount() == minY.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { 
+ int groupId = groups.getInt(g); + SpatialExtentCartesianPointDocValuesAggregator.combineIntermediate(state, groupId, minX.getInt(groupPosition + positionOffset), maxX.getInt(groupPosition + positionOffset), maxY.getInt(groupPosition + positionOffset), minY.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -178,6 +216,44 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minXUncast = page.getBlock(channels.get(0)); + if (minXUncast.areAllValuesNull()) { + return; + } + IntVector minX = ((IntBlock) minXUncast).asVector(); + Block maxXUncast = page.getBlock(channels.get(1)); + if (maxXUncast.areAllValuesNull()) { + return; + } + IntVector maxX = ((IntBlock) maxXUncast).asVector(); + Block maxYUncast = page.getBlock(channels.get(2)); + if (maxYUncast.areAllValuesNull()) { + return; + } + IntVector maxY = ((IntBlock) maxYUncast).asVector(); + Block minYUncast = page.getBlock(channels.get(3)); + if (minYUncast.areAllValuesNull()) { + return; + } + IntVector minY = ((IntBlock) minYUncast).asVector(); + assert minX.getPositionCount() == maxX.getPositionCount() && minX.getPositionCount() == maxY.getPositionCount() && minX.getPositionCount() == minY.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentCartesianPointDocValuesAggregator.combineIntermediate(state, groupId, minX.getInt(groupPosition + positionOffset), maxX.getInt(groupPosition + positionOffset), maxY.getInt(groupPosition + positionOffset), minY.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -199,11 +275,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -235,6 +306,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointSourceValuesGroupingAggregatorFunction.java 
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointSourceValuesGroupingAggregatorFunction.java index c70eee3ae8803..1c4169263e9f0 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianPointSourceValuesGroupingAggregatorFunction.java @@ -149,6 +149,44 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minXUncast = page.getBlock(channels.get(0)); + if (minXUncast.areAllValuesNull()) { + return; + } + IntVector minX = ((IntBlock) minXUncast).asVector(); + Block maxXUncast = page.getBlock(channels.get(1)); + if (maxXUncast.areAllValuesNull()) { + return; + } + IntVector maxX = ((IntBlock) maxXUncast).asVector(); + Block maxYUncast = page.getBlock(channels.get(2)); + if (maxYUncast.areAllValuesNull()) { + return; + } + IntVector maxY = ((IntBlock) maxYUncast).asVector(); + Block minYUncast = page.getBlock(channels.get(3)); + if (minYUncast.areAllValuesNull()) { + return; + } + IntVector minY = ((IntBlock) minYUncast).asVector(); + assert minX.getPositionCount() == maxX.getPositionCount() && minX.getPositionCount() == maxY.getPositionCount() && minX.getPositionCount() == minY.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentCartesianPointSourceValuesAggregator.combineIntermediate(state, groupId, minX.getInt(groupPosition + positionOffset), maxX.getInt(groupPosition + positionOffset), maxY.getInt(groupPosition + positionOffset), minY.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -183,6 +221,44 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minXUncast = page.getBlock(channels.get(0)); + if (minXUncast.areAllValuesNull()) { + return; + } + IntVector minX = ((IntBlock) minXUncast).asVector(); + Block maxXUncast = page.getBlock(channels.get(1)); + if (maxXUncast.areAllValuesNull()) { + return; + } + IntVector maxX = ((IntBlock) maxXUncast).asVector(); + Block maxYUncast = page.getBlock(channels.get(2)); + if (maxYUncast.areAllValuesNull()) { + return; + } + IntVector maxY = ((IntBlock) maxYUncast).asVector(); + Block minYUncast = page.getBlock(channels.get(3)); + if (minYUncast.areAllValuesNull()) { + return; + } + IntVector minY = ((IntBlock) minYUncast).asVector(); + assert minX.getPositionCount() == 
maxX.getPositionCount() && minX.getPositionCount() == maxY.getPositionCount() && minX.getPositionCount() == minY.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentCartesianPointSourceValuesAggregator.combineIntermediate(state, groupId, minX.getInt(groupPosition + positionOffset), maxX.getInt(groupPosition + positionOffset), maxY.getInt(groupPosition + positionOffset), minY.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -206,11 +282,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -242,6 +313,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeDocValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeDocValuesGroupingAggregatorFunction.java index 386ac52e66288..d9d834d96c2c6 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeDocValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeDocValuesGroupingAggregatorFunction.java @@ -136,6 +136,44 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val // This type does not support vectors because all values are multi-valued } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minXUncast = page.getBlock(channels.get(0)); + if (minXUncast.areAllValuesNull()) { + return; + } + IntVector minX = ((IntBlock) minXUncast).asVector(); + Block maxXUncast = page.getBlock(channels.get(1)); + if (maxXUncast.areAllValuesNull()) { + return; + } + IntVector maxX = ((IntBlock) maxXUncast).asVector(); + Block maxYUncast = page.getBlock(channels.get(2)); + if (maxYUncast.areAllValuesNull()) { + return; + } + IntVector maxY = ((IntBlock) maxYUncast).asVector(); + Block minYUncast = page.getBlock(channels.get(3)); + if (minYUncast.areAllValuesNull()) { + return; + } + IntVector minY = ((IntBlock) minYUncast).asVector(); + assert minX.getPositionCount() == 
maxX.getPositionCount() && minX.getPositionCount() == maxY.getPositionCount() && minX.getPositionCount() == minY.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentCartesianShapeDocValuesAggregator.combineIntermediate(state, groupId, minX.getInt(groupPosition + positionOffset), maxX.getInt(groupPosition + positionOffset), maxY.getInt(groupPosition + positionOffset), minY.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -160,6 +198,44 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector // This type does not support vectors because all values are multi-valued } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minXUncast = page.getBlock(channels.get(0)); + if (minXUncast.areAllValuesNull()) { + return; + } + IntVector minX = ((IntBlock) minXUncast).asVector(); + Block maxXUncast = page.getBlock(channels.get(1)); + if (maxXUncast.areAllValuesNull()) { + return; + } + IntVector maxX = ((IntBlock) maxXUncast).asVector(); + Block maxYUncast = page.getBlock(channels.get(2)); + if (maxYUncast.areAllValuesNull()) { + return; + } + IntVector maxY = ((IntBlock) maxYUncast).asVector(); + Block minYUncast = page.getBlock(channels.get(3)); + if (minYUncast.areAllValuesNull()) { + return; + } + IntVector minY = ((IntBlock) minYUncast).asVector(); + assert minX.getPositionCount() == maxX.getPositionCount() && minX.getPositionCount() == maxY.getPositionCount() && minX.getPositionCount() == minY.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentCartesianShapeDocValuesAggregator.combineIntermediate(state, groupId, minX.getInt(groupPosition + positionOffset), maxX.getInt(groupPosition + positionOffset), maxY.getInt(groupPosition + positionOffset), minY.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -180,11 +256,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) // This type does not support vectors because all values are multi-valued } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new 
SeenGroupIds.Empty()); @@ -216,6 +287,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeSourceValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeSourceValuesGroupingAggregatorFunction.java index 61988aa3c06d9..c568de2dbd6be 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeSourceValuesGroupingAggregatorFunction.java @@ -149,6 +149,44 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minXUncast = page.getBlock(channels.get(0)); + if (minXUncast.areAllValuesNull()) { + return; + } + IntVector minX = ((IntBlock) minXUncast).asVector(); + Block maxXUncast = page.getBlock(channels.get(1)); + if (maxXUncast.areAllValuesNull()) { + return; + } + IntVector maxX = ((IntBlock) maxXUncast).asVector(); + Block maxYUncast = page.getBlock(channels.get(2)); + if (maxYUncast.areAllValuesNull()) { + return; + } + IntVector maxY = ((IntBlock) maxYUncast).asVector(); + Block minYUncast = page.getBlock(channels.get(3)); + if (minYUncast.areAllValuesNull()) { + return; + } + IntVector minY = ((IntBlock) minYUncast).asVector(); + assert minX.getPositionCount() == maxX.getPositionCount() && minX.getPositionCount() == maxY.getPositionCount() && minX.getPositionCount() == minY.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentCartesianShapeSourceValuesAggregator.combineIntermediate(state, groupId, minX.getInt(groupPosition + positionOffset), maxX.getInt(groupPosition + positionOffset), maxY.getInt(groupPosition + positionOffset), minY.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -183,6 +221,44 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minXUncast = page.getBlock(channels.get(0)); + if (minXUncast.areAllValuesNull()) { + 
return; + } + IntVector minX = ((IntBlock) minXUncast).asVector(); + Block maxXUncast = page.getBlock(channels.get(1)); + if (maxXUncast.areAllValuesNull()) { + return; + } + IntVector maxX = ((IntBlock) maxXUncast).asVector(); + Block maxYUncast = page.getBlock(channels.get(2)); + if (maxYUncast.areAllValuesNull()) { + return; + } + IntVector maxY = ((IntBlock) maxYUncast).asVector(); + Block minYUncast = page.getBlock(channels.get(3)); + if (minYUncast.areAllValuesNull()) { + return; + } + IntVector minY = ((IntBlock) minYUncast).asVector(); + assert minX.getPositionCount() == maxX.getPositionCount() && minX.getPositionCount() == maxY.getPositionCount() && minX.getPositionCount() == minY.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentCartesianShapeSourceValuesAggregator.combineIntermediate(state, groupId, minX.getInt(groupPosition + positionOffset), maxX.getInt(groupPosition + positionOffset), maxY.getInt(groupPosition + positionOffset), minY.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -206,11 +282,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -242,6 +313,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointDocValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointDocValuesGroupingAggregatorFunction.java index 3bf7f0cdf043e..e80e6d4391dc3 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointDocValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointDocValuesGroupingAggregatorFunction.java @@ -148,6 +148,54 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector va } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntVector top = ((IntBlock) topUncast).asVector(); + Block bottomUncast = 
page.getBlock(channels.get(1)); + if (bottomUncast.areAllValuesNull()) { + return; + } + IntVector bottom = ((IntBlock) bottomUncast).asVector(); + Block negLeftUncast = page.getBlock(channels.get(2)); + if (negLeftUncast.areAllValuesNull()) { + return; + } + IntVector negLeft = ((IntBlock) negLeftUncast).asVector(); + Block negRightUncast = page.getBlock(channels.get(3)); + if (negRightUncast.areAllValuesNull()) { + return; + } + IntVector negRight = ((IntBlock) negRightUncast).asVector(); + Block posLeftUncast = page.getBlock(channels.get(4)); + if (posLeftUncast.areAllValuesNull()) { + return; + } + IntVector posLeft = ((IntBlock) posLeftUncast).asVector(); + Block posRightUncast = page.getBlock(channels.get(5)); + if (posRightUncast.areAllValuesNull()) { + return; + } + IntVector posRight = ((IntBlock) posRightUncast).asVector(); + assert top.getPositionCount() == bottom.getPositionCount() && top.getPositionCount() == negLeft.getPositionCount() && top.getPositionCount() == negRight.getPositionCount() && top.getPositionCount() == posLeft.getPositionCount() && top.getPositionCount() == posRight.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentGeoPointDocValuesAggregator.combineIntermediate(state, groupId, top.getInt(groupPosition + positionOffset), bottom.getInt(groupPosition + positionOffset), negLeft.getInt(groupPosition + positionOffset), negRight.getInt(groupPosition + positionOffset), posLeft.getInt(groupPosition + positionOffset), posRight.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -180,6 +228,54 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntVector top = ((IntBlock) topUncast).asVector(); + Block bottomUncast = page.getBlock(channels.get(1)); + if (bottomUncast.areAllValuesNull()) { + return; + } + IntVector bottom = ((IntBlock) bottomUncast).asVector(); + Block negLeftUncast = page.getBlock(channels.get(2)); + if (negLeftUncast.areAllValuesNull()) { + return; + } + IntVector negLeft = ((IntBlock) negLeftUncast).asVector(); + Block negRightUncast = page.getBlock(channels.get(3)); + if (negRightUncast.areAllValuesNull()) { + return; + } + IntVector negRight = ((IntBlock) negRightUncast).asVector(); + Block posLeftUncast = page.getBlock(channels.get(4)); + if (posLeftUncast.areAllValuesNull()) { + return; + } + IntVector posLeft = ((IntBlock) posLeftUncast).asVector(); + Block posRightUncast = page.getBlock(channels.get(5)); + if (posRightUncast.areAllValuesNull()) { + return; + } + IntVector posRight = ((IntBlock) posRightUncast).asVector(); + assert top.getPositionCount() == bottom.getPositionCount() && 
top.getPositionCount() == negLeft.getPositionCount() && top.getPositionCount() == negRight.getPositionCount() && top.getPositionCount() == posLeft.getPositionCount() && top.getPositionCount() == posRight.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentGeoPointDocValuesAggregator.combineIntermediate(state, groupId, top.getInt(groupPosition + positionOffset), bottom.getInt(groupPosition + positionOffset), negLeft.getInt(groupPosition + positionOffset), negRight.getInt(groupPosition + positionOffset), posLeft.getInt(groupPosition + positionOffset), posRight.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -201,11 +297,6 @@ private void addRawInput(int positionOffset, IntVector groups, LongVector values } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -247,6 +338,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointSourceValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointSourceValuesGroupingAggregatorFunction.java index 16466431b6725..43a2662a229c4 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoPointSourceValuesGroupingAggregatorFunction.java @@ -151,6 +151,54 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntVector top = ((IntBlock) topUncast).asVector(); + Block bottomUncast = page.getBlock(channels.get(1)); + if (bottomUncast.areAllValuesNull()) { + return; + } + IntVector bottom = ((IntBlock) bottomUncast).asVector(); + Block negLeftUncast = page.getBlock(channels.get(2)); + if (negLeftUncast.areAllValuesNull()) { + return; + } + IntVector negLeft = ((IntBlock) negLeftUncast).asVector(); + Block negRightUncast = page.getBlock(channels.get(3)); + if 
(negRightUncast.areAllValuesNull()) { + return; + } + IntVector negRight = ((IntBlock) negRightUncast).asVector(); + Block posLeftUncast = page.getBlock(channels.get(4)); + if (posLeftUncast.areAllValuesNull()) { + return; + } + IntVector posLeft = ((IntBlock) posLeftUncast).asVector(); + Block posRightUncast = page.getBlock(channels.get(5)); + if (posRightUncast.areAllValuesNull()) { + return; + } + IntVector posRight = ((IntBlock) posRightUncast).asVector(); + assert top.getPositionCount() == bottom.getPositionCount() && top.getPositionCount() == negLeft.getPositionCount() && top.getPositionCount() == negRight.getPositionCount() && top.getPositionCount() == posLeft.getPositionCount() && top.getPositionCount() == posRight.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentGeoPointSourceValuesAggregator.combineIntermediate(state, groupId, top.getInt(groupPosition + positionOffset), bottom.getInt(groupPosition + positionOffset), negLeft.getInt(groupPosition + positionOffset), negRight.getInt(groupPosition + positionOffset), posLeft.getInt(groupPosition + positionOffset), posRight.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -185,6 +233,54 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntVector top = ((IntBlock) topUncast).asVector(); + Block bottomUncast = page.getBlock(channels.get(1)); + if (bottomUncast.areAllValuesNull()) { + return; + } + IntVector bottom = ((IntBlock) bottomUncast).asVector(); + Block negLeftUncast = page.getBlock(channels.get(2)); + if (negLeftUncast.areAllValuesNull()) { + return; + } + IntVector negLeft = ((IntBlock) negLeftUncast).asVector(); + Block negRightUncast = page.getBlock(channels.get(3)); + if (negRightUncast.areAllValuesNull()) { + return; + } + IntVector negRight = ((IntBlock) negRightUncast).asVector(); + Block posLeftUncast = page.getBlock(channels.get(4)); + if (posLeftUncast.areAllValuesNull()) { + return; + } + IntVector posLeft = ((IntBlock) posLeftUncast).asVector(); + Block posRightUncast = page.getBlock(channels.get(5)); + if (posRightUncast.areAllValuesNull()) { + return; + } + IntVector posRight = ((IntBlock) posRightUncast).asVector(); + assert top.getPositionCount() == bottom.getPositionCount() && top.getPositionCount() == negLeft.getPositionCount() && top.getPositionCount() == negRight.getPositionCount() && top.getPositionCount() == posLeft.getPositionCount() && top.getPositionCount() == posRight.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + 
int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentGeoPointSourceValuesAggregator.combineIntermediate(state, groupId, top.getInt(groupPosition + positionOffset), bottom.getInt(groupPosition + positionOffset), negLeft.getInt(groupPosition + positionOffset), negRight.getInt(groupPosition + positionOffset), posLeft.getInt(groupPosition + positionOffset), posRight.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -208,11 +304,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -254,6 +345,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeDocValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeDocValuesGroupingAggregatorFunction.java index 8e1da49050ff0..6ad4a92e83c7e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeDocValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeDocValuesGroupingAggregatorFunction.java @@ -138,6 +138,54 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector val // This type does not support vectors because all values are multi-valued } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntVector top = ((IntBlock) topUncast).asVector(); + Block bottomUncast = page.getBlock(channels.get(1)); + if (bottomUncast.areAllValuesNull()) { + return; + } + IntVector bottom = ((IntBlock) bottomUncast).asVector(); + Block negLeftUncast = page.getBlock(channels.get(2)); + if (negLeftUncast.areAllValuesNull()) { + return; + } + IntVector negLeft = ((IntBlock) negLeftUncast).asVector(); + Block negRightUncast = page.getBlock(channels.get(3)); + if (negRightUncast.areAllValuesNull()) { + return; + } + IntVector negRight = ((IntBlock) negRightUncast).asVector(); + Block posLeftUncast = page.getBlock(channels.get(4)); + if (posLeftUncast.areAllValuesNull()) { + return; + } + IntVector posLeft = ((IntBlock) posLeftUncast).asVector(); + Block posRightUncast = page.getBlock(channels.get(5)); + if (posRightUncast.areAllValuesNull()) 
{ + return; + } + IntVector posRight = ((IntBlock) posRightUncast).asVector(); + assert top.getPositionCount() == bottom.getPositionCount() && top.getPositionCount() == negLeft.getPositionCount() && top.getPositionCount() == negRight.getPositionCount() && top.getPositionCount() == posLeft.getPositionCount() && top.getPositionCount() == posRight.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentGeoShapeDocValuesAggregator.combineIntermediate(state, groupId, top.getInt(groupPosition + positionOffset), bottom.getInt(groupPosition + positionOffset), negLeft.getInt(groupPosition + positionOffset), negRight.getInt(groupPosition + positionOffset), posLeft.getInt(groupPosition + positionOffset), posRight.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition) || values.isNull(groupPosition + positionOffset)) { @@ -162,6 +210,54 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector // This type does not support vectors because all values are multi-valued } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntVector top = ((IntBlock) topUncast).asVector(); + Block bottomUncast = page.getBlock(channels.get(1)); + if (bottomUncast.areAllValuesNull()) { + return; + } + IntVector bottom = ((IntBlock) bottomUncast).asVector(); + Block negLeftUncast = page.getBlock(channels.get(2)); + if (negLeftUncast.areAllValuesNull()) { + return; + } + IntVector negLeft = ((IntBlock) negLeftUncast).asVector(); + Block negRightUncast = page.getBlock(channels.get(3)); + if (negRightUncast.areAllValuesNull()) { + return; + } + IntVector negRight = ((IntBlock) negRightUncast).asVector(); + Block posLeftUncast = page.getBlock(channels.get(4)); + if (posLeftUncast.areAllValuesNull()) { + return; + } + IntVector posLeft = ((IntBlock) posLeftUncast).asVector(); + Block posRightUncast = page.getBlock(channels.get(5)); + if (posRightUncast.areAllValuesNull()) { + return; + } + IntVector posRight = ((IntBlock) posRightUncast).asVector(); + assert top.getPositionCount() == bottom.getPositionCount() && top.getPositionCount() == negLeft.getPositionCount() && top.getPositionCount() == negRight.getPositionCount() && top.getPositionCount() == posLeft.getPositionCount() && top.getPositionCount() == posRight.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentGeoShapeDocValuesAggregator.combineIntermediate(state, groupId, top.getInt(groupPosition + positionOffset), 
bottom.getInt(groupPosition + positionOffset), negLeft.getInt(groupPosition + positionOffset), negRight.getInt(groupPosition + positionOffset), posLeft.getInt(groupPosition + positionOffset), posRight.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (values.isNull(groupPosition + positionOffset)) { @@ -182,11 +278,6 @@ private void addRawInput(int positionOffset, IntVector groups, IntVector values) // This type does not support vectors because all values are multi-valued } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -228,6 +319,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeSourceValuesGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeSourceValuesGroupingAggregatorFunction.java index 0529db1776270..7d8f8fefc722b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeSourceValuesGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentGeoShapeSourceValuesGroupingAggregatorFunction.java @@ -151,6 +151,54 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVecto } } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntVector top = ((IntBlock) topUncast).asVector(); + Block bottomUncast = page.getBlock(channels.get(1)); + if (bottomUncast.areAllValuesNull()) { + return; + } + IntVector bottom = ((IntBlock) bottomUncast).asVector(); + Block negLeftUncast = page.getBlock(channels.get(2)); + if (negLeftUncast.areAllValuesNull()) { + return; + } + IntVector negLeft = ((IntBlock) negLeftUncast).asVector(); + Block negRightUncast = page.getBlock(channels.get(3)); + if (negRightUncast.areAllValuesNull()) { + return; + } + IntVector negRight = ((IntBlock) negRightUncast).asVector(); + Block posLeftUncast = page.getBlock(channels.get(4)); + if (posLeftUncast.areAllValuesNull()) { + return; + } + IntVector posLeft = ((IntBlock) posLeftUncast).asVector(); + Block posRightUncast = page.getBlock(channels.get(5)); + if (posRightUncast.areAllValuesNull()) { + return; + } + IntVector posRight = ((IntBlock) posRightUncast).asVector(); + assert top.getPositionCount() == bottom.getPositionCount() && top.getPositionCount() == negLeft.getPositionCount() && top.getPositionCount() == 
negRight.getPositionCount() && top.getPositionCount() == posLeft.getPositionCount() && top.getPositionCount() == posRight.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentGeoShapeSourceValuesAggregator.combineIntermediate(state, groupId, top.getInt(groupPosition + positionOffset), bottom.getInt(groupPosition + positionOffset), negLeft.getInt(groupPosition + positionOffset), negRight.getInt(groupPosition + positionOffset), posLeft.getInt(groupPosition + positionOffset), posRight.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -185,6 +233,54 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVe } } + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntVector top = ((IntBlock) topUncast).asVector(); + Block bottomUncast = page.getBlock(channels.get(1)); + if (bottomUncast.areAllValuesNull()) { + return; + } + IntVector bottom = ((IntBlock) bottomUncast).asVector(); + Block negLeftUncast = page.getBlock(channels.get(2)); + if (negLeftUncast.areAllValuesNull()) { + return; + } + IntVector negLeft = ((IntBlock) negLeftUncast).asVector(); + Block negRightUncast = page.getBlock(channels.get(3)); + if (negRightUncast.areAllValuesNull()) { + return; + } + IntVector negRight = ((IntBlock) negRightUncast).asVector(); + Block posLeftUncast = page.getBlock(channels.get(4)); + if (posLeftUncast.areAllValuesNull()) { + return; + } + IntVector posLeft = ((IntBlock) posLeftUncast).asVector(); + Block posRightUncast = page.getBlock(channels.get(5)); + if (posRightUncast.areAllValuesNull()) { + return; + } + IntVector posRight = ((IntBlock) posRightUncast).asVector(); + assert top.getPositionCount() == bottom.getPositionCount() && top.getPositionCount() == negLeft.getPositionCount() && top.getPositionCount() == negRight.getPositionCount() && top.getPositionCount() == posLeft.getPositionCount() && top.getPositionCount() == posRight.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SpatialExtentGeoShapeSourceValuesAggregator.combineIntermediate(state, groupId, top.getInt(groupPosition + positionOffset), bottom.getInt(groupPosition + positionOffset), negLeft.getInt(groupPosition + positionOffset), negRight.getInt(groupPosition + positionOffset), posLeft.getInt(groupPosition + positionOffset), posRight.getInt(groupPosition + positionOffset)); + } + } + } + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock 
values) { BytesRef scratch = new BytesRef(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { @@ -208,11 +304,6 @@ private void addRawInput(int positionOffset, IntVector groups, BytesRefVector va } } - @Override - public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { - state.enableGroupIdTracking(seenGroupIds); - } - @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { state.enableGroupIdTracking(new SeenGroupIds.Empty()); @@ -254,6 +345,11 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page } } + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + @Override public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { state.toIntermediate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java index d1de1eaad99a5..f5b7a73a54a1d 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java @@ -199,6 +199,48 @@ public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size(); + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + LongVector count = page.<LongBlock>getBlock(channels.get(0)).asVector(); + BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector(); + assert count.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + state.increment(groupId, count.getLong(groupPosition + positionOffset)); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size(); + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + LongVector count = page.<LongBlock>getBlock(channels.get(0)).asVector(); + BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector(); + assert count.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + state.increment(groupId, count.getLong(groupPosition + positionOffset)); + } + } + } + @Override public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
assert channels.size() == intermediateBlockCount(); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FilteredGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FilteredGroupingAggregatorFunction.java index 5011d3d75816b..121d8e213dcbd 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FilteredGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FilteredGroupingAggregatorFunction.java @@ -100,6 +100,16 @@ public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { // nothing to do - we already put the underlying agg into this state } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { + next.addIntermediateInput(positionOffset, groupIdVector, page); + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page) { + next.addIntermediateInput(positionOffset, groupIdVector, page); + } + @Override public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { next.addIntermediateInput(positionOffset, groupIdVector, page); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FromPartialGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FromPartialGroupingAggregatorFunction.java index 4bc47839bcb4f..d87ca338c6589 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FromPartialGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/FromPartialGroupingAggregatorFunction.java @@ -75,6 +75,18 @@ public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { delegate.selectedMayContainUnseenGroups(seenGroupIds); } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { + final CompositeBlock inputBlock = page.getBlock(inputChannel); + delegate.addIntermediateInput(positionOffset, groupIdVector, inputBlock.asPage()); + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page) { + final CompositeBlock inputBlock = page.getBlock(inputChannel); + delegate.addIntermediateInput(positionOffset, groupIdVector, inputBlock.asPage()); + } + @Override public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { final CompositeBlock inputBlock = page.getBlock(inputChannel); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java index e21f6205fb690..e84560a39cd4f 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java @@ -11,7 +11,6 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.IntArrayBlock; import org.elasticsearch.compute.data.IntBigArrayBlock; -import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -42,19 
+41,14 @@ public int evaluateBlockCount() { public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, Page page) { if (mode.isInputPartial()) { return new GroupingAggregatorFunction.AddInput() { - @Override - public void add(int positionOffset, IntBlock groupIds) { - throw new IllegalStateException("Intermediate group id must not have nulls"); - } - @Override public void add(int positionOffset, IntArrayBlock groupIds) { - throw new IllegalStateException("Intermediate group id must not have nulls"); + aggregatorFunction.addIntermediateInput(positionOffset, groupIds, page); } @Override public void add(int positionOffset, IntBigArrayBlock groupIds) { - throw new IllegalStateException("Intermediate group id must not have nulls"); + aggregatorFunction.addIntermediateInput(positionOffset, groupIds, page); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java index ff8bdab4dd803..a60bcb1523ffc 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java @@ -123,6 +123,16 @@ default void add(int positionOffset, IntBlock groupIds) { */ void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds); + /** + * Add data produced by {@link #evaluateIntermediate}. + */ + void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page); + + /** + * Add data produced by {@link #evaluateIntermediate}. + */ + void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page); + /** * Add data produced by {@link #evaluateIntermediate}. 
*/ diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ToPartialGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ToPartialGroupingAggregatorFunction.java index e8748ad33c5c2..5aa489f6e2fd9 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ToPartialGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ToPartialGroupingAggregatorFunction.java @@ -10,6 +10,8 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.CompositeBlock; import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.core.Releasables; @@ -64,6 +66,18 @@ public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { delegate.selectedMayContainUnseenGroups(seenGroupIds); } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groupIdVector, Page page) { + final CompositeBlock inputBlock = page.getBlock(channels.get(0)); + delegate.addIntermediateInput(positionOffset, groupIdVector, inputBlock.asPage()); + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIdVector, Page page) { + final CompositeBlock inputBlock = page.getBlock(channels.get(0)); + delegate.addIntermediateInput(positionOffset, groupIdVector, inputBlock.asPage()); + } + @Override public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) { final CompositeBlock inputBlock = page.getBlock(channels.get(0)); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocVector.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocVector.java index 20ca4ed70e3f8..ccd0f82343401 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocVector.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/DocVector.java @@ -138,7 +138,6 @@ private boolean checkIfSingleSegmentNonDecreasing() { prev = v; } return true; - } /** diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperator.java index 089846f9939ae..ba6da814542e4 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperator.java @@ -183,16 +183,20 @@ void readDocsForNextPage() throws IOException { for (LeafIterator leaf : oneTsidQueue) { leaf.reinitializeIfNeeded(executingThread); } - do { - PriorityQueue<LeafIterator> sub = subQueueForNextTsid(); - if (sub.size() == 0) { - break; - } - tsHashesBuilder.appendNewTsid(sub.top().timeSeriesHash); - if (readValuesForOneTsid(sub)) { - break; - } - } while (mainQueue.size() > 0); + if (mainQueue.size() + oneTsidQueue.size() == 1) { + readValuesFromSingleRemainingLeaf(); + } else { + do { + PriorityQueue<LeafIterator> sub = subQueueForNextTsid(); + if (sub.size() == 0) { + break; + } + tsHashesBuilder.appendNewTsid(sub.top().timeSeriesHash); + if (readValuesForOneTsid(sub)) { + break; + } + } while (mainQueue.size() > 0); +
} } private boolean readValuesForOneTsid(PriorityQueue sub) throws IOException { @@ -236,6 +240,38 @@ private PriorityQueue subQueueForNextTsid() { return oneTsidQueue; } + private void readValuesFromSingleRemainingLeaf() throws IOException { + if (oneTsidQueue.size() == 0) { + oneTsidQueue.add(getMainQueue().pop()); + tsidsLoaded++; + } + final LeafIterator sub = oneTsidQueue.top(); + int lastTsid = -1; + do { + currentPagePos++; + remainingDocs--; + docCollector.collect(sub.segmentOrd, sub.docID); + if (lastTsid != sub.lastTsidOrd) { + tsHashesBuilder.appendNewTsid(sub.timeSeriesHash); + lastTsid = sub.lastTsidOrd; + } + tsHashesBuilder.appendOrdinal(); + timestampsBuilder.appendLong(sub.timestamp); + if (sub.nextDoc() == false) { + if (sub.docID == DocIdSetIterator.NO_MORE_DOCS) { + oneTsidQueue.clear(); + return; + } else { + ++tsidsLoaded; + } + } + } while (remainingDocs > 0 && currentPagePos < maxPageSize); + } + + private PriorityQueue getMainQueue() { + return mainQueue; + } + boolean completed() { return mainQueue.size() == 0 && oneTsidQueue.size() == 0; } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ComputeBlockLoaderFactory.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ComputeBlockLoaderFactory.java index f7f5f541c747f..20e7ffc4ca2cb 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ComputeBlockLoaderFactory.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ComputeBlockLoaderFactory.java @@ -14,18 +14,16 @@ import org.elasticsearch.core.Releasable; class ComputeBlockLoaderFactory extends DelegatingBlockLoaderFactory implements Releasable { - private final int pageSize; private Block nullBlock; - ComputeBlockLoaderFactory(BlockFactory factory, int pageSize) { + ComputeBlockLoaderFactory(BlockFactory factory) { super(factory); - this.pageSize = pageSize; } @Override - public Block constantNulls() { + public Block constantNulls(int count) { if (nullBlock == null) { - nullBlock = factory.newConstantNullBlock(pageSize); + nullBlock = factory.newConstantNullBlock(count); } nullBlock.incRef(); return nullBlock; @@ -39,7 +37,7 @@ public void close() { } @Override - public BytesRefBlock constantBytes(BytesRef value) { - return factory.newConstantBytesRefBlockWith(value, pageSize); + public BytesRefBlock constantBytes(BytesRef value, int count) { + return factory.newConstantBytesRefBlockWith(value, count); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/TimeSeriesExtractFieldOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/TimeSeriesExtractFieldOperator.java index 9ec5802b43f98..e197861e9b701 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/TimeSeriesExtractFieldOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/TimeSeriesExtractFieldOperator.java @@ -198,12 +198,12 @@ static class BlockLoaderFactory extends DelegatingBlockLoaderFactory { } @Override - public BlockLoader.Block constantNulls() { + public BlockLoader.Block constantNulls(int count) { throw new UnsupportedOperationException("must not be used by column readers"); } @Override - public BlockLoader.Block constantBytes(BytesRef value) { + public BlockLoader.Block constantBytes(BytesRef value, int count) { throw new UnsupportedOperationException("must not be used by column readers"); 
} @@ -254,7 +254,8 @@ static final class ShardLevelFieldsReader implements Releasable { this.storedFieldsSpec = storedFieldsSpec; this.dimensions = new boolean[fields.size()]; for (int i = 0; i < fields.size(); i++) { - dimensions[i] = shardContext.fieldType(fields.get(i).name()).isDimension(); + final var mappedFieldType = shardContext.fieldType(fields.get(i).name()); + dimensions[i] = mappedFieldType != null && mappedFieldType.isDimension(); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesFromManyReader.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesFromManyReader.java index 7ff6e7211b7f2..6f00e97a1f9f2 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesFromManyReader.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesFromManyReader.java @@ -16,6 +16,8 @@ import org.elasticsearch.index.mapper.BlockLoader; import org.elasticsearch.index.mapper.BlockLoaderStoredFieldsFromLeafLoader; import org.elasticsearch.index.mapper.SourceLoader; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.search.fetch.StoredFieldsSpec; import java.io.IOException; @@ -24,6 +26,8 @@ * Loads values from a many leaves. Much less efficient than {@link ValuesFromSingleReader}. */ class ValuesFromManyReader extends ValuesReader { + private static final Logger log = LogManager.getLogger(ValuesFromManyReader.class); + private final int[] forwards; private final int[] backwards; private final BlockLoader.RowStrideReader[] rowStride; @@ -35,6 +39,7 @@ class ValuesFromManyReader extends ValuesReader { forwards = docs.shardSegmentDocMapForwards(); backwards = docs.shardSegmentDocMapBackwards(); rowStride = new BlockLoader.RowStrideReader[operator.fields.length]; + log.debug("initializing {} positions", docs.getPositionCount()); } @Override @@ -70,9 +75,7 @@ void run(int offset) throws IOException { builders[f] = new Block.Builder[operator.shardContexts.size()]; converters[f] = new BlockLoader[operator.shardContexts.size()]; } - try ( - ComputeBlockLoaderFactory loaderBlockFactory = new ComputeBlockLoaderFactory(operator.blockFactory, docs.getPositionCount()) - ) { + try (ComputeBlockLoaderFactory loaderBlockFactory = new ComputeBlockLoaderFactory(operator.blockFactory)) { int p = forwards[offset]; int shard = docs.shards().getInt(p); int segment = docs.segments().getInt(p); @@ -84,7 +87,9 @@ void run(int offset) throws IOException { read(firstDoc, shard); int i = offset + 1; - while (i < forwards.length) { + long estimated = estimatedRamBytesUsed(); + long dangerZoneBytes = Long.MAX_VALUE; // TODO danger_zone if ascending + while (i < forwards.length && estimated < dangerZoneBytes) { p = forwards[i]; shard = docs.shards().getInt(p); segment = docs.segments().getInt(p); @@ -96,8 +101,17 @@ void run(int offset) throws IOException { verifyBuilders(loaderBlockFactory, shard); read(docs.docs().getInt(p), shard); i++; + estimated = estimatedRamBytesUsed(); + log.trace("{}: bytes loaded {}/{}", p, estimated, dangerZoneBytes); } buildBlocks(); + if (log.isDebugEnabled()) { + long actual = 0; + for (Block b : target) { + actual += b.ramBytesUsed(); + } + log.debug("loaded {} positions total estimated/actual {}/{} bytes", p, estimated, actual); + } } } @@ -115,6 +129,9 @@ private void buildBlocks() { } operator.sanityCheckBlock(rowStride[f], backwards.length, target[f], f); } + if 
(target[0].getPositionCount() != docs.getPositionCount()) { + throw new IllegalStateException("partial pages not yet supported"); + } } private void verifyBuilders(ComputeBlockLoaderFactory loaderBlockFactory, int shard) { @@ -141,6 +158,18 @@ public void close() { Releasables.closeExpectNoException(builders[f]); } } + + private long estimatedRamBytesUsed() { + long estimated = 0; + for (Block.Builder[] builders : this.builders) { + for (Block.Builder builder : builders) { + if (builder != null) { + estimated += builder.estimatedBytes(); + } + } + } + return estimated; + } } private void fieldsMoved(LeafReaderContext ctx, int shard) throws IOException { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesFromSingleReader.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesFromSingleReader.java index 1bee68160e024..d47a015c24578 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesFromSingleReader.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesFromSingleReader.java @@ -16,6 +16,8 @@ import org.elasticsearch.index.mapper.BlockLoader; import org.elasticsearch.index.mapper.BlockLoaderStoredFieldsFromLeafLoader; import org.elasticsearch.index.mapper.SourceLoader; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.search.fetch.StoredFieldsSpec; import java.io.IOException; @@ -26,6 +28,8 @@ * Loads values from a single leaf. Much more efficient than {@link ValuesFromManyReader}. */ class ValuesFromSingleReader extends ValuesReader { + private static final Logger log = LogManager.getLogger(ValuesFromSingleReader.class); + /** * Minimum number of documents for which it is more efficient to use a * sequential stored field reader when reading stored fields. 
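The javadoc above describes the heuristic for choosing a sequential stored-field reader; the comparison `range * storedFieldsSequentialProportion <= count` appears in the `useSequentialStoredFieldsReader` hunk below. A minimal sketch of that trade-off, assuming non-decreasing doc ids — the standalone class and method names here are illustrative, not the actual implementation:
```
// Illustrative sketch, not the real method: the actual implementation takes a
// BlockLoader.Docs and a per-shard storedFieldsSequentialProportion setting.
class SequentialReadHeuristic {
    static boolean useSequentialReader(int[] docIds, double sequentialProportion) {
        int count = docIds.length;
        if (count == 0) {
            return false;
        }
        // Span of doc ids this page touches; ids are assumed non-decreasing.
        int range = docIds[count - 1] - docIds[0];
        // When the requested ids are dense within their span, one forward scan
        // over the stored fields is cheaper than seeking to each doc.
        return range * sequentialProportion <= count;
    }

    public static void main(String[] args) {
        int[] dense = java.util.stream.IntStream.range(0, 90).toArray();
        System.out.println(useSequentialReader(dense, 0.2)); // true: 89 * 0.2 <= 90
    }
}
```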
@@ -45,39 +49,27 @@ class ValuesFromSingleReader extends ValuesReader { super(operator, docs); this.shard = docs.shards().getInt(0); this.segment = docs.segments().getInt(0); + log.debug("initialized {} positions", docs.getPositionCount()); } @Override protected void load(Block[] target, int offset) throws IOException { - assert offset == 0; // TODO allow non-0 offset to support splitting pages if (docs.singleSegmentNonDecreasing()) { - loadFromSingleLeaf(target, new BlockLoader.Docs() { - @Override - public int count() { - return docs.getPositionCount(); - } - - @Override - public int get(int i) { - return docs.docs().getInt(i); - } - }); + loadFromSingleLeaf(operator.jumboBytes, target, new ValuesReaderDocs(docs), offset); return; } + if (offset != 0) { + throw new IllegalStateException("can only load partial pages with single-segment non-decreasing pages"); + } int[] forwards = docs.shardSegmentDocMapForwards(); Block[] unshuffled = new Block[target.length]; try { - loadFromSingleLeaf(unshuffled, new BlockLoader.Docs() { - @Override - public int count() { - return docs.getPositionCount(); - } - - @Override - public int get(int i) { - return docs.docs().getInt(forwards[i]); - } - }); + loadFromSingleLeaf( + Long.MAX_VALUE, // Effectively disable splitting pages when we're not loading in order + unshuffled, + new ValuesReaderDocs(docs).mapped(forwards), + 0 + ); final int[] backwards = docs.shardSegmentDocMapBackwards(); for (int i = 0; i < unshuffled.length; i++) { target[i] = unshuffled[i].filter(backwards); @@ -89,24 +81,25 @@ public int get(int i) { } } - private void loadFromSingleLeaf(Block[] target, BlockLoader.Docs docs) throws IOException { - int firstDoc = docs.get(0); + private void loadFromSingleLeaf(long jumboBytes, Block[] target, ValuesReaderDocs docs, int offset) throws IOException { + int firstDoc = docs.get(offset); operator.positionFieldWork(shard, segment, firstDoc); StoredFieldsSpec storedFieldsSpec = StoredFieldsSpec.NO_REQUIREMENTS; - List rowStrideReaders = new ArrayList<>(operator.fields.length); LeafReaderContext ctx = operator.ctx(shard, segment); - try (ComputeBlockLoaderFactory loaderBlockFactory = new ComputeBlockLoaderFactory(operator.blockFactory, docs.count())) { + + List columnAtATimeReaders = new ArrayList<>(operator.fields.length); + List rowStrideReaders = new ArrayList<>(operator.fields.length); + try (ComputeBlockLoaderFactory loaderBlockFactory = new ComputeBlockLoaderFactory(operator.blockFactory)) { for (int f = 0; f < operator.fields.length; f++) { ValuesSourceReaderOperator.FieldWork field = operator.fields[f]; BlockLoader.ColumnAtATimeReader columnAtATime = field.columnAtATime(ctx); if (columnAtATime != null) { - target[f] = (Block) columnAtATime.read(loaderBlockFactory, docs); - operator.sanityCheckBlock(columnAtATime, docs.count(), target[f], f); + columnAtATimeReaders.add(new ColumnAtATimeWork(columnAtATime, f)); } else { rowStrideReaders.add( new RowStrideReaderWork( field.rowStride(ctx), - (Block.Builder) field.loader.builder(loaderBlockFactory, docs.count()), + (Block.Builder) field.loader.builder(loaderBlockFactory, docs.count() - offset), field.loader, f ) @@ -116,7 +109,18 @@ private void loadFromSingleLeaf(Block[] target, BlockLoader.Docs docs) throws IO } if (rowStrideReaders.isEmpty() == false) { - loadFromRowStrideReaders(target, storedFieldsSpec, rowStrideReaders, ctx, docs); + loadFromRowStrideReaders(jumboBytes, target, storedFieldsSpec, rowStrideReaders, ctx, docs, offset); + } + for (ColumnAtATimeWork r : columnAtATimeReaders) 
{ + target[r.idx] = (Block) r.reader.read(loaderBlockFactory, docs, offset); + operator.sanityCheckBlock(r.reader, docs.count() - offset, target[r.idx], r.idx); + } + if (log.isDebugEnabled()) { + long total = 0; + for (Block b : target) { + total += b.ramBytesUsed(); + } + log.debug("loaded {} positions total ({} bytes)", target[0].getPositionCount(), total); } } finally { Releasables.close(rowStrideReaders); @@ -124,11 +128,13 @@ private void loadFromSingleLeaf(Block[] target, BlockLoader.Docs docs) throws IO } private void loadFromRowStrideReaders( + long jumboBytes, Block[] target, StoredFieldsSpec storedFieldsSpec, List rowStrideReaders, LeafReaderContext ctx, - BlockLoader.Docs docs + ValuesReaderDocs docs, + int offset ) throws IOException { SourceLoader sourceLoader = null; ValuesSourceReaderOperator.ShardContext shardContext = operator.shardContexts.get(shard); @@ -153,18 +159,29 @@ private void loadFromRowStrideReaders( storedFieldLoader.getLoader(ctx, null), sourceLoader != null ? sourceLoader.leaf(ctx.reader(), null) : null ); - int p = 0; - while (p < docs.count()) { + int p = offset; + long estimated = 0; + while (p < docs.count() && estimated < jumboBytes) { int doc = docs.get(p++); storedFields.advanceTo(doc); for (RowStrideReaderWork work : rowStrideReaders) { work.read(doc, storedFields); } + estimated = estimatedRamBytesUsed(rowStrideReaders); + log.trace("{}: bytes loaded {}/{}", p, estimated, jumboBytes); } for (RowStrideReaderWork work : rowStrideReaders) { - target[work.offset] = work.build(); - operator.sanityCheckBlock(work.reader, p, target[work.offset], work.offset); + target[work.idx] = work.build(); + operator.sanityCheckBlock(work.reader, p - offset, target[work.idx], work.idx); } + if (log.isDebugEnabled()) { + long actual = 0; + for (RowStrideReaderWork work : rowStrideReaders) { + actual += target[work.idx].ramBytesUsed(); + } + log.debug("loaded {} positions row stride estimated/actual {}/{} bytes", p - offset, estimated, actual); + } + docs.setCount(p); } /** @@ -180,7 +197,21 @@ private boolean useSequentialStoredFieldsReader(BlockLoader.Docs docs, double st return range * storedFieldsSequentialProportion <= count; } - private record RowStrideReaderWork(BlockLoader.RowStrideReader reader, Block.Builder builder, BlockLoader loader, int offset) + /** + * Work for building a column-at-a-time. 
+ * @param reader reads the values + * @param idx destination in array of {@linkplain Block}s we build + */ + private record ColumnAtATimeWork(BlockLoader.ColumnAtATimeReader reader, int idx) {} + + /** + * Work for building a block row-by-row. + * @param reader reads the values + * @param builder accumulates the values the reader produces + * @param loader the {@link BlockLoader} the reader and builder came from + * @param idx destination in array of {@linkplain Block}s we build + */ + private record RowStrideReaderWork(BlockLoader.RowStrideReader reader, Block.Builder builder, BlockLoader loader, int idx) + implements Releasable { void read(int doc, BlockLoaderStoredFieldsFromLeafLoader storedFields) throws IOException { @@ -196,4 +227,12 @@ public void close() { builder.close(); } } + + private long estimatedRamBytesUsed(List rowStrideReaders) { + long estimated = 0; + for (RowStrideReaderWork r : rowStrideReaders) { + estimated += r.builder.estimatedBytes(); + } + return estimated; + } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesReader.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesReader.java index ebfac0cb24f7f..d3b8b0edcec3d 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesReader.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesReader.java @@ -36,9 +36,6 @@ public Block[] next() { boolean success = false; try { load(target, offset); - if (target[0].getPositionCount() != docs.getPositionCount()) { - throw new IllegalStateException("partial pages not yet supported"); - } success = true; for (Block b : target) { operator.valuesLoaded += b.getTotalValueCount(); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesReaderDocs.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesReaderDocs.java new file mode 100644 index 0000000000000..2e138dc2d0446 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesReaderDocs.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.lucene.read; + +import org.elasticsearch.compute.data.DocVector; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.BlockLoader; + +/** + * Implementation of {@link BlockLoader.Docs} for ESQL. It's important that + * only this implementation, and the implementation returned by {@link #mapped} + * exist. This allows the jvm to inline the {@code invokevirtual}s to call + * the interface in hot, hot code. + *
<p>
+ * We've investigated moving the {@code offset} parameter from the + * {@link BlockLoader.ColumnAtATimeReader#read} into this. That's more + * readable, but a clock cycle slower. + * </p>
+ * <p>
+ * When we tried having a {@link Nullable} map member instead of a subclass + * that was also slower. + * </p>
+ */ +class ValuesReaderDocs implements BlockLoader.Docs { + private final DocVector docs; + private int count; + + ValuesReaderDocs(DocVector docs) { + this.docs = docs; + this.count = docs.getPositionCount(); + } + + final Mapped mapped(int[] forwards) { + return new Mapped(docs, forwards); + } + + public final void setCount(int count) { + this.count = count; + } + + @Override + public final int count() { + return count; + } + + @Override + public int get(int i) { + return docs.docs().getInt(i); + } + + private class Mapped extends ValuesReaderDocs { + private final int[] forwards; + + private Mapped(DocVector docs, int[] forwards) { + super(docs); + this.forwards = forwards; + } + + @Override + public int get(int i) { + return super.get(forwards[i]); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesSourceReaderOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesSourceReaderOperator.java index 2fd4784224087..6d0ebb9c312d0 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesSourceReaderOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/read/ValuesSourceReaderOperator.java @@ -9,6 +9,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.DocBlock; @@ -42,7 +43,9 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingToIteratorOpe * @param shardContexts per-shard loading information * @param docChannel the channel containing the shard, leaf/segment and doc id */ - public record Factory(List fields, List shardContexts, int docChannel) implements OperatorFactory { + public record Factory(ByteSizeValue jumboSize, List fields, List shardContexts, int docChannel) + implements + OperatorFactory { public Factory { if (fields.isEmpty()) { throw new IllegalStateException("ValuesSourceReaderOperator doesn't support empty fields"); @@ -51,7 +54,7 @@ public record Factory(List fields, List shardContexts, @Override public Operator get(DriverContext driverContext) { - return new ValuesSourceReaderOperator(driverContext.blockFactory(), fields, shardContexts, docChannel); + return new ValuesSourceReaderOperator(driverContext.blockFactory(), jumboSize.getBytes(), fields, shardContexts, docChannel); } @Override @@ -85,10 +88,21 @@ public record FieldInfo(String name, ElementType type, IntFunction public record ShardContext(IndexReader reader, Supplier newSourceLoader, double storedFieldsSequentialProportion) {} + final BlockFactory blockFactory; + /** + * When the loaded fields {@link Block}s' estimated size grows larger than this, + * we finish loading the {@linkplain Page} and return it, even if + * the {@linkplain Page} is shorter than the incoming {@linkplain Page}. + *
<p>
+ * NOTE: This only applies when loading single segment non-descending + * row stride bytes. This is the most common way to get giant fields, + * but it isn't all the ways. + * </p>
+ */ + final long jumboBytes; final FieldWork[] fields; final List shardContexts; private final int docChannel; - final BlockFactory blockFactory; private final Map readersBuilt = new TreeMap<>(); long valuesLoaded; @@ -101,14 +115,21 @@ public record ShardContext(IndexReader reader, Supplier newSourceL * @param fields fields to load * @param docChannel the channel containing the shard, leaf/segment and doc id */ - public ValuesSourceReaderOperator(BlockFactory blockFactory, List fields, List shardContexts, int docChannel) { + public ValuesSourceReaderOperator( + BlockFactory blockFactory, + long jumboBytes, + List fields, + List shardContexts, + int docChannel + ) { if (fields.isEmpty()) { throw new IllegalStateException("ValuesSourceReaderOperator doesn't support empty fields"); } + this.blockFactory = blockFactory; + this.jumboBytes = jumboBytes; this.fields = fields.stream().map(FieldWork::new).toArray(FieldWork[]::new); this.shardContexts = shardContexts; this.docChannel = docChannel; - this.blockFactory = blockFactory; } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SampleOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SampleOperator.java index 56ba95f66f5fa..1395a0d0ad73c 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SampleOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SampleOperator.java @@ -66,11 +66,11 @@ public String describe() { private final RandomSamplingQuery.RandomSamplingIterator randomSamplingIterator; private boolean finished; - private int pagesProcessed = 0; - private int rowsReceived = 0; - private int rowsEmitted = 0; private long collectNanos; private long emitNanos; + private int pagesProcessed = 0; + private long rowsReceived = 0; + private long rowsEmitted = 0; private SampleOperator(double probability, int seed) { finished = false; @@ -109,7 +109,7 @@ private void createOutputPage(Page page) { final int[] sampledPositions = new int[page.getPositionCount()]; int sampledIdx = 0; for (int i = randomSamplingIterator.docID(); i - rowsReceived < page.getPositionCount(); i = randomSamplingIterator.nextDoc()) { - sampledPositions[sampledIdx++] = i - rowsReceived; + sampledPositions[sampledIdx++] = Math.toIntExact(i - rowsReceived); } if (sampledIdx > 0) { outputPages.add(page.filter(Arrays.copyOf(sampledPositions, sampledIdx))); @@ -167,7 +167,7 @@ public Operator.Status status() { return new Status(collectNanos, emitNanos, pagesProcessed, rowsReceived, rowsEmitted); } - private record Status(long collectNanos, long emitNanos, int pagesProcessed, int rowsReceived, int rowsEmitted) + public record Status(long collectNanos, long emitNanos, int pagesProcessed, long rowsReceived, long rowsEmitted) implements Operator.Status { @@ -178,7 +178,13 @@ private record Status(long collectNanos, long emitNanos, int pagesProcessed, int ); Status(StreamInput streamInput) throws IOException { - this(streamInput.readVLong(), streamInput.readVLong(), streamInput.readVInt(), streamInput.readVInt(), streamInput.readVInt()); + this( + streamInput.readVLong(), + streamInput.readVLong(), + streamInput.readVInt(), + streamInput.readVLong(), + streamInput.readVLong() + ); } @Override @@ -186,8 +192,8 @@ public void writeTo(StreamOutput out) throws IOException { out.writeVLong(collectNanos); out.writeVLong(emitNanos); out.writeVInt(pagesProcessed); - out.writeVInt(rowsReceived); - out.writeVInt(rowsEmitted); 
+ out.writeVLong(rowsReceived); + out.writeVLong(rowsEmitted); } @Override @@ -236,7 +242,14 @@ public String toString() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ZERO; + assert false : "must not be called when overriding supportsVersion"; + throw new UnsupportedOperationException("must not be called when overriding supportsVersion"); + } + + @Override + public boolean supportsVersion(TransportVersion version) { + return version.onOrAfter(TransportVersions.ESQL_SAMPLE_OPERATOR_STATUS) + || version.isPatchFrom(TransportVersions.ESQL_SAMPLE_OPERATOR_STATUS_9_1); } } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java index db4febdf8ddca..8185b045029b3 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java @@ -172,6 +172,7 @@ public void testPushRoundToToQuery() throws IOException { LuceneOperator.NO_LIMIT ); ValuesSourceReaderOperator.Factory load = new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of( new ValuesSourceReaderOperator.FieldInfo("v", ElementType.LONG, f -> new BlockDocValuesReader.LongsBlockLoader("v")) ), @@ -198,7 +199,6 @@ public void testPushRoundToToQuery() throws IOException { boolean sawSecondMax = false; boolean sawThirdMax = false; for (Page page : pages) { - logger.error("ADFA {}", page); LongVector group = page.getBlock(1).asVector(); LongVector value = page.getBlock(2).asVector(); for (int p = 0; p < page.getPositionCount(); p++) { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java index f040b86850133..dda9671b3b242 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java @@ -759,16 +759,52 @@ public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { delegate.selectedMayContainUnseenGroups(seenGroupIds); } + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groupIds, Page page) { + addIntermediateInputInternal(positionOffset, groupIds, page); + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groupIds, Page page) { + addIntermediateInputInternal(positionOffset, groupIds, page); + } + @Override public void addIntermediateInput(int positionOffset, IntVector groupIds, Page page) { + addIntermediateInputInternal(positionOffset, groupIds.asBlock(), page); + } + + public void addIntermediateInputInternal(int positionOffset, IntBlock groupIds, Page page) { + BlockFactory blockFactory = TestBlockFactory.getNonBreakingInstance(); int[] chunk = new int[emitChunkSize]; - for (int offset = 0; offset < groupIds.getPositionCount(); offset += emitChunkSize) { - int count = 0; - for (int i = offset; i < Math.min(groupIds.getPositionCount(), offset + emitChunkSize); i++) { - chunk[count++] = groupIds.getInt(i); + int chunkPosition = 0; + int offset = 0; + for (int position = 0; position < groupIds.getPositionCount(); position++) { + if (groupIds.isNull(position)) { + 
continue; } - BlockFactory blockFactory = TestBlockFactory.getNonBreakingInstance(); // TODO: just for compile - delegate.addIntermediateInput(positionOffset + offset, blockFactory.newIntArrayVector(chunk, count), page); + int firstValueIndex = groupIds.getFirstValueIndex(position); + int valueCount = groupIds.getValueCount(position); + assert valueCount == 1; // Multi-values make chunking more complex, and it's not a real case yet + + int groupId = groupIds.getInt(firstValueIndex); + chunk[chunkPosition++] = groupId; + if (chunkPosition == emitChunkSize) { + delegate.addIntermediateInput( + positionOffset + offset, + blockFactory.newIntArrayVector(chunk, chunkPosition), + page + ); + chunkPosition = 0; + offset = position + 1; + } + } + if (chunkPosition > 0) { + delegate.addIntermediateInput( + positionOffset + offset, + blockFactory.newIntArrayVector(chunk, chunkPosition), + page + ); } } @@ -846,9 +882,7 @@ public void add(Page page, GroupingAggregatorFunction.AddInput addInput) { blockHash.add(page, new GroupingAggregatorFunction.AddInput() { @Override public void add(int positionOffset, IntBlock groupIds) { - IntBlock newGroupIds = aggregatorMode.isInputPartial() - ? groupIds - : BlockTypeRandomizer.randomizeBlockType(groupIds); + IntBlock newGroupIds = BlockTypeRandomizer.randomizeBlockType(groupIds); addInput.add(positionOffset, newGroupIds); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluatorTests.java index 655f7b54c61c0..2ef64623daa74 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluatorTests.java @@ -23,6 +23,7 @@ import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.store.BaseDirectoryWrapper; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.compute.OperatorTests; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BytesRefBlock; @@ -201,6 +202,7 @@ private List runQuery(Set values, Query query, boolean shuffleDocs operators.add( new ValuesSourceReaderOperator( blockFactory, + ByteSizeValue.ofGb(1).getBytes(), List.of( new ValuesSourceReaderOperator.FieldInfo( FIELD, diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/ValueSourceReaderTypeConversionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/ValueSourceReaderTypeConversionTests.java index 2bd5cc95dd804..5a1f2ee7cc949 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/ValueSourceReaderTypeConversionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/ValueSourceReaderTypeConversionTests.java @@ -34,6 +34,7 @@ import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.compute.data.Block; @@ -241,12 +242,17 @@ private static Operator.OperatorFactory factory( ElementType elementType, BlockLoader loader ) { - return new 
ValuesSourceReaderOperator.Factory(List.of(new ValuesSourceReaderOperator.FieldInfo(name, elementType, shardIdx -> { - if (shardIdx < 0 || shardIdx >= INDICES.size()) { - fail("unexpected shardIdx [" + shardIdx + "]"); - } - return loader; - })), shardContexts, 0); + return new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), + List.of(new ValuesSourceReaderOperator.FieldInfo(name, elementType, shardIdx -> { + if (shardIdx < 0 || shardIdx >= INDICES.size()) { + fail("unexpected shardIdx [" + shardIdx + "]"); + } + return loader; + })), + shardContexts, + 0 + ); } protected SourceOperator simpleInput(DriverContext context, int size) { @@ -493,6 +499,7 @@ public void testManySingleDocPages() { // TODO: Add index2 operators.add( new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of(testCase.info, fieldInfo(mapperService(indexKey).fieldType("key"), ElementType.INT)), shardContexts, 0 @@ -600,6 +607,7 @@ private void loadSimpleAndAssert( List operators = new ArrayList<>(); operators.add( new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of( fieldInfo(mapperService("index1").fieldType("key"), ElementType.INT), fieldInfo(mapperService("index1").fieldType("indexKey"), ElementType.BYTES_REF) @@ -614,7 +622,9 @@ private void loadSimpleAndAssert( cases.removeAll(b); tests.addAll(b); operators.add( - new ValuesSourceReaderOperator.Factory(b.stream().map(i -> i.info).toList(), shardContexts, 0).get(driverContext) + new ValuesSourceReaderOperator.Factory(ByteSizeValue.ofGb(1), b.stream().map(i -> i.info).toList(), shardContexts, 0).get( + driverContext + ) ); } List results = drive(operators, input.iterator(), driverContext); @@ -718,7 +728,7 @@ private void testLoadAllStatus(boolean allInOnePage) { Block.MvOrdering.DEDUPLICATED_AND_SORTED_ASCENDING ); List operators = cases.stream() - .map(i -> new ValuesSourceReaderOperator.Factory(List.of(i.info), shardContexts, 0).get(driverContext)) + .map(i -> new ValuesSourceReaderOperator.Factory(ByteSizeValue.ofGb(1), List.of(i.info), shardContexts, 0).get(driverContext)) .toList(); if (allInOnePage) { input = List.of(CannedSourceOperator.mergePages(input)); @@ -1390,6 +1400,7 @@ public void testNullsShared() { simpleInput(driverContext, 10), List.of( new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of( new ValuesSourceReaderOperator.FieldInfo("null1", ElementType.NULL, shardIdx -> BlockLoader.CONSTANT_NULLS), new ValuesSourceReaderOperator.FieldInfo("null2", ElementType.NULL, shardIdx -> BlockLoader.CONSTANT_NULLS) @@ -1424,6 +1435,7 @@ public void testDescriptionOfMany() throws IOException { List cases = infoAndChecksForEachType(ordering, ordering); ValuesSourceReaderOperator.Factory factory = new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), cases.stream().map(c -> c.info).toList(), List.of(new ValuesSourceReaderOperator.ShardContext(reader(indexKey), () -> SourceLoader.FROM_STORED_SOURCE, 0.2)), 0 @@ -1469,6 +1481,7 @@ public void testManyShards() throws IOException { // TODO add index2 MappedFieldType ft = mapperService(indexKey).fieldType("key"); var readerFactory = new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of(new ValuesSourceReaderOperator.FieldInfo("key", ElementType.INT, shardIdx -> { seenShards.add(shardIdx); return ft.blockLoader(blContext()); @@ -1676,8 +1689,8 @@ public ColumnAtATimeReader columnAtATimeReader(LeafReaderContext context) throws } return new ColumnAtATimeReader() { @Override - public Block read(BlockFactory factory, Docs 
docs) throws IOException { - Block block = reader.read(factory, docs); + public Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + Block block = reader.read(factory, docs, offset); Page page = new Page((org.elasticsearch.compute.data.Block) block); return convertEvaluator.eval(page); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/ValuesSourceReaderOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/ValuesSourceReaderOperatorTests.java index 0c227b5411e25..19a645c146242 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/ValuesSourceReaderOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/read/ValuesSourceReaderOperatorTests.java @@ -30,6 +30,7 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BooleanBlock; @@ -37,6 +38,7 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.DocBlock; +import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; @@ -99,6 +101,7 @@ import static org.elasticsearch.test.MapMatcher.matchesMap; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.not; @@ -150,12 +153,14 @@ public static Operator.OperatorFactory factory(IndexReader reader, MappedFieldTy } static Operator.OperatorFactory factory(IndexReader reader, String name, ElementType elementType, BlockLoader loader) { - return new ValuesSourceReaderOperator.Factory(List.of(new ValuesSourceReaderOperator.FieldInfo(name, elementType, shardIdx -> { - if (shardIdx != 0) { - fail("unexpected shardIdx [" + shardIdx + "]"); - } - return loader; - })), + return new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), + List.of(new ValuesSourceReaderOperator.FieldInfo(name, elementType, shardIdx -> { + if (shardIdx != 0) { + fail("unexpected shardIdx [" + shardIdx + "]"); + } + return loader; + })), List.of( new ValuesSourceReaderOperator.ShardContext( reader, @@ -401,7 +406,7 @@ private IndexReader initIndexLongField(Directory directory, int size, int commit for (int d = 0; d < size; d++) { XContentBuilder source = JsonXContent.contentBuilder(); source.startObject(); - source.field("long_source_text", Integer.toString(d).repeat(100 * 1024)); + source.field("long_source_text", d + "#" + "a".repeat(100 * 1024)); source.endObject(); ParsedDocument doc = mapperService.documentParser() .parseDocument( @@ -489,6 +494,7 @@ public void testManySingleDocPages() { ); operators.add( new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of(testCase.info, fieldInfo(mapperService.fieldType("key"), ElementType.INT)), List.of( new ValuesSourceReaderOperator.ShardContext( @@ -608,6 +614,7 @@ private void loadSimpleAndAssert( List operators = new 
ArrayList<>(); operators.add( new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of(fieldInfo(mapperService.fieldType("key"), ElementType.INT)), List.of( new ValuesSourceReaderOperator.ShardContext( @@ -626,6 +633,7 @@ private void loadSimpleAndAssert( tests.addAll(b); operators.add( new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), b.stream().map(i -> i.info).toList(), List.of( new ValuesSourceReaderOperator.ShardContext( @@ -724,6 +732,7 @@ private void testLoadAllStatus(boolean allInOnePage) { List operators = cases.stream() .map( i -> new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of(i.info), List.of( new ValuesSourceReaderOperator.ShardContext( @@ -928,7 +937,6 @@ public void testLoadLongShuffledManySegments() throws IOException { private void testLoadLong(boolean shuffle, boolean manySegments) throws IOException { int numDocs = between(10, 500); initMapping(); - keyToTags.clear(); reader = initIndexLongField(directory, numDocs, manySegments ? commitEvery(numDocs) : numDocs, manySegments == false); DriverContext driverContext = driverContext(); @@ -941,6 +949,7 @@ private void testLoadLong(boolean shuffle, boolean manySegments) throws IOExcept if (shuffle) { input = input.stream().map(this::shuffle).toList(); } + boolean willSplit = loadLongWillSplit(input); Checks checks = new Checks(Block.MvOrdering.DEDUPLICATED_AND_SORTED_ASCENDING, Block.MvOrdering.DEDUPLICATED_AND_SORTED_ASCENDING); @@ -956,6 +965,7 @@ private void testLoadLong(boolean shuffle, boolean manySegments) throws IOExcept List operators = cases.stream() .map( i -> new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of(i.info), List.of( new ValuesSourceReaderOperator.ShardContext( @@ -968,12 +978,55 @@ private void testLoadLong(boolean shuffle, boolean manySegments) throws IOExcept ).get(driverContext) ) .toList(); - drive(operators, input.iterator(), driverContext); + List result = drive(operators, input.iterator(), driverContext); + + boolean[] found = new boolean[numDocs]; + for (Page page : result) { + BytesRefVector bytes = page.getBlock(1).asVector(); + BytesRef scratch = new BytesRef(); + for (int p = 0; p < bytes.getPositionCount(); p++) { + BytesRef v = bytes.getBytesRef(p, scratch); + int d = Integer.valueOf(v.utf8ToString().split("#")[0]); + assertFalse("found a duplicate " + d, found[d]); + found[d] = true; + } + } + List missing = new ArrayList<>(); + for (int d = 0; d < numDocs; d++) { + if (found[d] == false) { + missing.add(d); + } + } + assertThat(missing, hasSize(0)); + assertThat(result, hasSize(willSplit ? greaterThanOrEqualTo(input.size()) : equalTo(input.size()))); + for (int i = 0; i < cases.size(); i++) { ValuesSourceReaderOperatorStatus status = (ValuesSourceReaderOperatorStatus) operators.get(i).status(); assertThat(status.pagesReceived(), equalTo(input.size())); - assertThat(status.pagesEmitted(), equalTo(input.size())); + assertThat(status.pagesEmitted(), willSplit ? 
greaterThanOrEqualTo(input.size()) : equalTo(input.size())); + } + } + + private boolean loadLongWillSplit(List input) { + int nextDoc = -1; + for (Page page : input) { + DocVector doc = page.getBlock(0).asVector(); + for (int p = 0; p < doc.getPositionCount(); p++) { + if (doc.shards().getInt(p) != 0) { + return false; + } + if (doc.segments().getInt(p) != 0) { + return false; + } + if (nextDoc == -1) { + nextDoc = doc.docs().getInt(p); + } else if (doc.docs().getInt(p) != nextDoc) { + return false; + } + nextDoc++; + } } + return true; } record Checks(Block.MvOrdering booleanAndNumericalDocValuesMvOrdering, Block.MvOrdering bytesRefDocValuesMvOrdering) { @@ -1565,6 +1618,7 @@ public void testNullsShared() { simpleInput(driverContext.blockFactory(), 10), List.of( new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of( new ValuesSourceReaderOperator.FieldInfo("null1", ElementType.NULL, shardIdx -> BlockLoader.CONSTANT_NULLS), new ValuesSourceReaderOperator.FieldInfo("null2", ElementType.NULL, shardIdx -> BlockLoader.CONSTANT_NULLS) @@ -1616,6 +1670,7 @@ private void testSequentialStoredFields(boolean sequential, int docCount) throws assertThat(source, hasSize(1)); // We want one page for simpler assertions, and we want them all in one segment assertTrue(source.get(0).getBlock(0).asVector().singleSegmentNonDecreasing()); Operator op = new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of( fieldInfo(mapperService.fieldType("key"), ElementType.INT), fieldInfo(storedTextField("stored_text"), ElementType.BYTES_REF) @@ -1653,6 +1708,7 @@ public void testDescriptionOfMany() throws IOException { List cases = infoAndChecksForEachType(ordering, ordering); ValuesSourceReaderOperator.Factory factory = new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), cases.stream().map(c -> c.info).toList(), List.of( new ValuesSourceReaderOperator.ShardContext( @@ -1706,6 +1762,7 @@ public void testManyShards() throws IOException { ); MappedFieldType ft = mapperService.fieldType("key"); var readerFactory = new ValuesSourceReaderOperator.Factory( + ByteSizeValue.ofGb(1), List.of(new ValuesSourceReaderOperator.FieldInfo("key", ElementType.INT, shardIdx -> { seenShards.add(shardIdx); return ft.blockLoader(blContext()); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java index 0e9c0e33d22cd..a4072754fae10 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java @@ -28,6 +28,7 @@ import java.util.Arrays; import java.util.Comparator; import java.util.List; +import java.util.function.Function; import java.util.stream.LongStream; import static java.util.stream.IntStream.range; @@ -254,4 +255,96 @@ public void testTopNNullsFirst() { outputPage.releaseBlocks(); } } + + /** + * When in intermediate/final mode, it will receive intermediate outputs that may have to be discarded + * (TopN in the datanode but not acceptable in the coordinator). + *
<p>
+ * This test ensures that such discarding works correctly. + * </p>
+ */ + public void testTopNNullsIntermediateDiscards() { + boolean ascOrder = randomBoolean(); + var groups = new Long[] { 0L, 10L, 20L, 30L, 40L, 50L }; + if (ascOrder) { + Arrays.sort(groups, Comparator.reverseOrder()); + } + var groupChannel = 0; + + // Supplier of operators to ensure that they're identical, simulating a datanode/coordinator connection + Function makeAggWithMode = (mode) -> { + var sumAggregatorChannels = mode.isInputPartial() ? List.of(1, 2) : List.of(1); + var maxAggregatorChannels = mode.isInputPartial() ? List.of(3, 4) : List.of(1); + + return new HashAggregationOperator.HashAggregationOperatorFactory( + List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, false, 3))), + mode, + List.of( + new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, sumAggregatorChannels), + new MaxLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, maxAggregatorChannels) + ), + randomPageSize(), + null + ).get(driverContext()); + }; + + // The operator that will collect all the results + try (var collectingOperator = makeAggWithMode.apply(AggregatorMode.FINAL)) { + // First datanode, sending a suitable TopN set of data + try (var datanodeOperator = makeAggWithMode.apply(AggregatorMode.INITIAL)) { + var page = new Page( + BlockUtils.fromList(blockFactory(), List.of(List.of(groups[4], 1L), List.of(groups[3], 2L), List.of(groups[2], 4L))) + ); + datanodeOperator.addInput(page); + datanodeOperator.finish(); + + var outputPage = datanodeOperator.getOutput(); + collectingOperator.addInput(outputPage); + } + + // Second datanode, sending an outdated TopN, as the coordinator has better top values already + try (var datanodeOperator = makeAggWithMode.apply(AggregatorMode.INITIAL)) { + var page = new Page( + BlockUtils.fromList( + blockFactory(), + List.of( + List.of(groups[5], 8L), + List.of(groups[3], 16L), + List.of(groups[1], 32L) // This group is worse than the worst group in the coordinator + ) + ) + ); + datanodeOperator.addInput(page); + datanodeOperator.finish(); + + var outputPage = datanodeOperator.getOutput(); + collectingOperator.addInput(outputPage); + } + + collectingOperator.finish(); + + var outputPage = collectingOperator.getOutput(); + + var groupsBlock = (LongBlock) outputPage.getBlock(0); + var sumBlock = (LongBlock) outputPage.getBlock(1); + var maxBlock = (LongBlock) outputPage.getBlock(2); + + assertThat(groupsBlock.getPositionCount(), equalTo(3)); + assertThat(sumBlock.getPositionCount(), equalTo(3)); + assertThat(maxBlock.getPositionCount(), equalTo(3)); + + assertThat(groupsBlock.getTotalValueCount(), equalTo(3)); + assertThat(sumBlock.getTotalValueCount(), equalTo(3)); + assertThat(maxBlock.getTotalValueCount(), equalTo(3)); + + assertThat( + BlockTestUtils.valuesAtPositions(groupsBlock, 0, 3), + equalTo(Arrays.asList(List.of(groups[4]), List.of(groups[3]), List.of(groups[5]))) + ); + assertThat(BlockTestUtils.valuesAtPositions(sumBlock, 0, 3), equalTo(List.of(List.of(1L), List.of(18L), List.of(8L)))); + assertThat(BlockTestUtils.valuesAtPositions(maxBlock, 0, 3), equalTo(List.of(List.of(1L), List.of(16L), List.of(8L)))); + + outputPage.releaseBlocks(); + } + } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SampleOperatorStatusTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SampleOperatorStatusTests.java new file mode 100644 index 0000000000000..50f3f456f3745 --- /dev/null +++ 
b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SampleOperatorStatusTests.java @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.test.ESTestCase; + +import static org.hamcrest.Matchers.equalTo; + +public class SampleOperatorStatusTests extends AbstractWireSerializingTestCase { + public static SampleOperator.Status simple() { + return new SampleOperator.Status(500012, 200012, 123, 111, 222); + } + + public static String simpleToJson() { + return """ + { + "collect_nanos" : 500012, + "collect_time" : "500micros", + "emit_nanos" : 200012, + "emit_time" : "200micros", + "pages_processed" : 123, + "rows_received" : 111, + "rows_emitted" : 222 + }"""; + } + + public void testToXContent() { + assertThat(Strings.toString(simple(), true, true), equalTo(simpleToJson())); + } + + @Override + protected Writeable.Reader instanceReader() { + return SampleOperator.Status::new; + } + + @Override + public SampleOperator.Status createTestInstance() { + return new SampleOperator.Status( + randomNonNegativeLong(), + randomNonNegativeLong(), + randomNonNegativeInt(), + randomNonNegativeLong(), + randomNonNegativeLong() + ); + } + + @Override + protected SampleOperator.Status mutateInstance(SampleOperator.Status instance) { + long collectNanos = instance.collectNanos(); + long emitNanos = instance.emitNanos(); + int pagesProcessed = instance.pagesProcessed(); + long rowsReceived = instance.rowsReceived(); + long rowsEmitted = instance.rowsEmitted(); + switch (between(0, 4)) { + case 0 -> collectNanos = randomValueOtherThan(collectNanos, ESTestCase::randomNonNegativeLong); + case 1 -> emitNanos = randomValueOtherThan(emitNanos, ESTestCase::randomNonNegativeLong); + case 2 -> pagesProcessed = randomValueOtherThan(pagesProcessed, ESTestCase::randomNonNegativeInt); + case 3 -> rowsReceived = randomValueOtherThan(rowsReceived, ESTestCase::randomNonNegativeLong); + case 4 -> rowsEmitted = randomValueOtherThan(rowsEmitted, ESTestCase::randomNonNegativeLong); + default -> throw new UnsupportedOperationException(); + } + return new SampleOperator.Status(collectNanos, emitNanos, pagesProcessed, rowsReceived, rowsEmitted); + } +} diff --git a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java index 705bfca2e903e..a943e917e0335 100644 --- a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java +++ b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java @@ -126,7 +126,10 @@ public MultiClusterSpecIT( "NullifiedJoinKeyToPurgeTheJoin", "SortBeforeAndAfterJoin", "SortEvalBeforeLookup", - "SortBeforeAndAfterMultipleJoinAndMvExpand" + "SortBeforeAndAfterMultipleJoinAndMvExpand", + "LookupJoinAfterTopNAndRemoteEnrich", + // Lookup join after LIMIT is not supported in CCS yet + "LookupJoinAfterLimitAndRemoteEnrich" ); 
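The `SampleOperatorStatusTests` added above ultimately checks that a `Status` survives a wire round trip, which is exactly what the `rowsReceived`/`rowsEmitted` widening from `int` to `long` must preserve. A minimal sketch of that round trip, assuming the package-private `Status(StreamInput)` constructor from this diff (so the fragment would live in the same package as `SampleOperator`, in a method that may throw IOException):
```
// Round-trip sketch using the stream helpers; values mirror simple() above.
try (BytesStreamOutput out = new BytesStreamOutput()) {
    SampleOperator.Status status = new SampleOperator.Status(500012, 200012, 123, 111, 222);
    status.writeTo(out);
    try (StreamInput in = out.bytes().streamInput()) {
        // Records get equals() for free, so the round trip is a direct comparison.
        assert status.equals(new SampleOperator.Status(in));
    }
}
```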
@Override diff --git a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/EsqlSpecIT.java b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/EsqlSpecIT.java index d01e1c9fb7f56..3484f19afa451 100644 --- a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/EsqlSpecIT.java +++ b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/EsqlSpecIT.java @@ -9,13 +9,21 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters; +import org.elasticsearch.client.Request; +import org.elasticsearch.common.Strings; import org.elasticsearch.test.TestClustersThreadFilter; import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.esql.CsvSpecReader.CsvTestCase; +import org.elasticsearch.xpack.esql.planner.PhysicalSettings; import org.elasticsearch.xpack.esql.plugin.ComputeService; import org.elasticsearch.xpack.esql.qa.rest.EsqlSpecTestCase; +import org.junit.Before; import org.junit.ClassRule; +import java.io.IOException; + @ThreadLeakFilters(filters = TestClustersThreadFilter.class) public class EsqlSpecIT extends EsqlSpecTestCase { @ClassRule @@ -50,4 +58,14 @@ protected boolean enableRoundingDoubleValuesOnAsserting() { protected boolean supportsSourceFieldMapping() { return cluster.getNumNodes() == 1; } + + @Before + public void configureChunks() throws IOException { + boolean smallChunks = randomBoolean(); + Request request = new Request("PUT", "/_cluster/settings"); + XContentBuilder builder = JsonXContent.contentBuilder().startObject().startObject("persistent"); + builder.field(PhysicalSettings.VALUES_LOADING_JUMBO_SIZE.getKey(), smallChunks ? 
"1kb" : null); + request.setJsonEntity(Strings.toString(builder.endObject().endObject())); + assertOK(client().performRequest(request)); + } } diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/EsqlQueryGenerator.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/EsqlQueryGenerator.java index ef02d4a1f8c98..9073b8bb81333 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/EsqlQueryGenerator.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/EsqlQueryGenerator.java @@ -76,6 +76,46 @@ public static CommandGenerator randomPipeCommandGenerator() { return randomFrom(PIPE_COMMANDS); } + public interface Executor { + void run(CommandGenerator generator, CommandGenerator.CommandDescription current); + + List previousCommands(); + + boolean continueExecuting(); + + List currentSchema(); + + } + + public static void generatePipeline( + final int depth, + CommandGenerator commandGenerator, + final CommandGenerator.QuerySchema schema, + Executor executor + ) { + CommandGenerator.CommandDescription desc = commandGenerator.generate(List.of(), List.of(), schema); + executor.run(commandGenerator, desc); + if (executor.continueExecuting() == false) { + return; + } + + for (int j = 0; j < depth; j++) { + if (executor.currentSchema().isEmpty()) { + break; + } + commandGenerator = EsqlQueryGenerator.randomPipeCommandGenerator(); + desc = commandGenerator.generate(executor.previousCommands(), executor.currentSchema(), schema); + if (desc == CommandGenerator.EMPTY_DESCRIPTION) { + continue; + } + + executor.run(commandGenerator, desc); + if (executor.continueExecuting() == false) { + break; + } + } + } + public static String booleanExpression(List previousOutput) { // TODO LIKE, RLIKE, functions etc. return switch (randomIntBetween(0, 3)) { diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeRestTest.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeRestTest.java index 41b8658302bd7..baa42ab58e62b 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeRestTest.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeRestTest.java @@ -48,9 +48,9 @@ public abstract class GenerativeRestTest extends ESRestTestCase { // Awaiting fixes for query failure "Unknown column \\[\\]", // https://github.com/elastic/elasticsearch/issues/121741, "Plan \\[ProjectExec\\[\\[.* optimized incorrectly due to missing references", // https://github.com/elastic/elasticsearch/issues/125866 - "optimized incorrectly due to missing references", // https://github.com/elastic/elasticsearch/issues/116781 "The incoming YAML document exceeds the limit:", // still to investigate, but it seems to be specific to the test framework "Data too large", // Circuit breaker exceptions eg. 
https://github.com/elastic/elasticsearch/issues/130072 + "optimized incorrectly due to missing references", // https://github.com/elastic/elasticsearch/issues/131509 // Awaiting fixes for correctness "Expecting at most \\[.*\\] columns, got \\[.*\\]" // https://github.com/elastic/elasticsearch/issues/129561 @@ -87,43 +87,53 @@ public void test() throws IOException { List lookupIndices = lookupIndices(); List policies = availableEnrichPolicies(); CommandGenerator.QuerySchema mappingInfo = new CommandGenerator.QuerySchema(indices, lookupIndices, policies); - EsqlQueryGenerator.QueryExecuted previousResult = null; + for (int i = 0; i < ITERATIONS; i++) { - List previousCommands = new ArrayList<>(); - CommandGenerator commandGenerator = EsqlQueryGenerator.sourceCommand(); - CommandGenerator.CommandDescription desc = commandGenerator.generate(List.of(), List.of(), mappingInfo); - String command = desc.commandString(); - EsqlQueryGenerator.QueryExecuted result = execute(command, 0); - if (result.exception() != null) { - checkException(result); - continue; - } - if (checkResults(List.of(), commandGenerator, desc, null, result).success() == false) { - continue; - } - previousResult = result; - previousCommands.add(desc); - for (int j = 0; j < MAX_DEPTH; j++) { - if (result.outputSchema().isEmpty()) { - break; + var exec = new EsqlQueryGenerator.Executor() { + @Override + public void run(CommandGenerator generator, CommandGenerator.CommandDescription current) { + previousCommands.add(current); + final String command = current.commandString(); + + final EsqlQueryGenerator.QueryExecuted result = previousResult == null + ? execute(command, 0) + : execute(previousResult.query() + command, previousResult.depth()); + previousResult = result; + + final boolean hasException = result.exception() != null; + if (hasException || checkResults(List.of(), generator, current, previousResult, result).success() == false) { + if (hasException) { + checkException(result); + } + continueExecuting = false; + currentSchema = List.of(); + } else { + continueExecuting = true; + currentSchema = result.outputSchema(); + } } - commandGenerator = EsqlQueryGenerator.randomPipeCommandGenerator(); - desc = commandGenerator.generate(previousCommands, result.outputSchema(), mappingInfo); - if (desc == CommandGenerator.EMPTY_DESCRIPTION) { - continue; + + @Override + public List previousCommands() { + return previousCommands; } - command = desc.commandString(); - result = execute(result.query() + command, result.depth() + 1); - if (result.exception() != null) { - checkException(result); - break; + + @Override + public boolean continueExecuting() { + return continueExecuting; } - if (checkResults(previousCommands, commandGenerator, desc, previousResult, result).success() == false) { - break; + + @Override + public List currentSchema() { + return currentSchema; } - previousCommands.add(desc); - previousResult = result; - } + + boolean continueExecuting; + List currentSchema; + final List previousCommands = new ArrayList<>(); + EsqlQueryGenerator.QueryExecuted previousResult; + }; + EsqlQueryGenerator.generatePipeline(MAX_DEPTH, EsqlQueryGenerator.sourceCommand(), mappingInfo, exec); } } @@ -163,7 +173,7 @@ private void checkException(EsqlQueryGenerator.QueryExecuted query) { } @SuppressWarnings("unchecked") - private EsqlQueryGenerator.QueryExecuted execute(String command, int depth) { + public static EsqlQueryGenerator.QueryExecuted execute(String command, int depth) { try { Map a = RestEsqlTestCase.runEsql( new 
RestEsqlTestCase.RequestObjectBuilder().query(command).build(), @@ -183,7 +193,7 @@ private EsqlQueryGenerator.QueryExecuted execute(String command, int depth) { } @SuppressWarnings("unchecked") - private List outputSchema(Map a) { + private static List outputSchema(Map a) { List> cols = (List>) a.get("columns"); if (cols == null) { return null; diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/ForkGenerator.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/ForkGenerator.java index bb82a7e4f7e22..84bf417c1a732 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/ForkGenerator.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/ForkGenerator.java @@ -8,10 +8,13 @@ package org.elasticsearch.xpack.esql.qa.rest.generative.command.pipe; import org.elasticsearch.xpack.esql.qa.rest.generative.EsqlQueryGenerator; +import org.elasticsearch.xpack.esql.qa.rest.generative.GenerativeRestTest; import org.elasticsearch.xpack.esql.qa.rest.generative.command.CommandGenerator; +import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import static org.elasticsearch.test.ESTestCase.randomIntBetween; @@ -28,17 +31,105 @@ public CommandDescription generate( ) { // FORK can only be allowed once - so we skip adding another FORK if we already have one // otherwise, most generated queries would only result in a validation error + StringBuilder completeCommand = new StringBuilder(); for (CommandDescription command : previousCommands) { if (command.commandName().equals(FORK)) { - return new CommandDescription(FORK, this, " ", Map.of()); + return EMPTY_DESCRIPTION; } + + completeCommand.append(command.commandString()); } - int n = randomIntBetween(2, 3); + final int branchCount = randomIntBetween(2, 3); + final int branchToRetain = randomIntBetween(1, branchCount); + + StringBuilder forkCmd = new StringBuilder(" | FORK "); + for (int i = 0; i < branchCount; i++) { + var expr = WhereGenerator.randomExpression(randomIntBetween(1, 2), previousOutput); + if (expr == null) { + expr = "true"; + } + forkCmd.append(" (").append("where ").append(expr); + + var exec = new EsqlQueryGenerator.Executor() { + @Override + public void run(CommandGenerator generator, CommandDescription current) { + final String command = current.commandString(); + + // Try appending new command to parent of Fork. If we successfully execute (without exception) AND still retain the same + // schema, we append the command. Enforcing the same schema is stricter than the Fork needs (it only needs types to be + // the same on columns which are present), but given we currently generate independent sub-pipelines, this way we can + // generate more valid Fork queries. + final EsqlQueryGenerator.QueryExecuted result = previousResult == null + ? 
GenerativeRestTest.execute(command, 0) + : GenerativeRestTest.execute(previousResult.query() + command, previousResult.depth()); + previousResult = result; + + continueExecuting = result.exception() == null && result.outputSchema().equals(previousOutput); + if (continueExecuting) { + previousCommands.add(current); + } + } + + @Override + public List previousCommands() { + return previousCommands; + } + + @Override + public boolean continueExecuting() { + return continueExecuting; + } - String cmd = " | FORK " + "( WHERE true ) ".repeat(n) + " | WHERE _fork == \"fork" + randomIntBetween(1, n) + "\" | DROP _fork"; + @Override + public List currentSchema() { + return previousOutput; + } + + final List previousCommands = new ArrayList<>(); + boolean continueExecuting; + EsqlQueryGenerator.QueryExecuted previousResult; + }; + + var gen = new CommandGenerator() { + @Override + public CommandDescription generate( + List previousCommands, + List previousOutput, + QuerySchema schema + ) { + return new CommandDescription(FORK, this, completeCommand.toString(), Map.of()); + } + + @Override + public ValidationResult validateOutput( + List previousCommands, + CommandDescription command, + List previousColumns, + List> previousOutput, + List columns, + List> output + ) { + return VALIDATION_OK; + } + }; + + EsqlQueryGenerator.generatePipeline(3, gen, schema, exec); + if (exec.previousCommands().size() > 1) { + String previousCmd = exec.previousCommands() + .stream() + .skip(1) + .map(CommandDescription::commandString) + .collect(Collectors.joining(" ")); + forkCmd.append(previousCmd); + } + + forkCmd.append(")"); + } + forkCmd.append(" | WHERE _fork == \"fork").append(branchToRetain).append("\" | DROP _fork"); - return new CommandDescription(FORK, this, cmd, Map.of()); + // System.out.println("Generated fork command: " + forkCmd); + return new CommandDescription(FORK, this, forkCmd.toString(), Map.of()); } @Override diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/WhereGenerator.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/WhereGenerator.java index 9bba468de0412..28d9f563896d9 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/WhereGenerator.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/WhereGenerator.java @@ -21,20 +21,15 @@ public class WhereGenerator implements CommandGenerator { public static final String WHERE = "where"; public static final CommandGenerator INSTANCE = new WhereGenerator(); - @Override - public CommandDescription generate( - List previousCommands, - List previousOutput, - QuerySchema schema - ) { + public static String randomExpression(final int nConditions, List previousOutput) { // TODO more complex conditions - StringBuilder result = new StringBuilder(" | where "); - int nConditions = randomIntBetween(1, 5); + var result = new StringBuilder(); + for (int i = 0; i < nConditions; i++) { String exp = EsqlQueryGenerator.booleanExpression(previousOutput); if (exp == null) { - // cannot generate expressions, just skip - return EMPTY_DESCRIPTION; + // Cannot generate expressions, just skip. + return null; } if (i > 0) { result.append(randomBoolean() ? 
" AND " : " OR "); @@ -45,8 +40,20 @@ public CommandDescription generate( result.append(exp); } - String cmd = result.toString(); - return new CommandDescription(WHERE, this, cmd, Map.of()); + return result.toString(); + } + + @Override + public CommandDescription generate( + List previousCommands, + List previousOutput, + QuerySchema schema + ) { + String expression = randomExpression(randomIntBetween(1, 5), previousOutput); + if (expression == null) { + return EMPTY_DESCRIPTION; + } + return new CommandDescription(WHERE, this, " | where " + expression, Map.of()); } @Override diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/enrich.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/enrich.csv-spec index 9bfb08eb82b45..2aa6189a957ec 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/enrich.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/enrich.csv-spec @@ -661,3 +661,104 @@ from * author.keyword:keyword|book_no:keyword|scalerank:integer|street:keyword|bytes_in:ul|@timestamp:unsupported|abbrev:keyword|city_location:geo_point|distance:double|description:unsupported|birth_date:date|language_code:integer|intersects:boolean|client_ip:unsupported|event_duration:long|version:version|language_name:keyword Fyodor Dostoevsky |1211 |null |null |null |null |null |null |null |null |null |null |null |null |null |null |null ; + + +statsAfterRemoteEnrich +required_capability: enrich_load + +FROM sample_data +| KEEP message +| WHERE message IN ("Connected to 10.1.0.1", "Connected to 10.1.0.2") +| EVAL language_code = "1" +| ENRICH _remote:languages_policy ON language_code +| STATS messages = count_distinct(message) BY language_name +; + +messages:long | language_name:keyword +2 | English +; + + +enrichAfterRemoteEnrich +required_capability: enrich_load + +FROM sample_data +| KEEP message +| WHERE message IN ("Connected to 10.1.0.1") +| EVAL language_code = "1" +| ENRICH _remote:languages_policy ON language_code +| RENAME language_name AS first_language_name +| ENRICH languages_policy ON language_code +; + +message:keyword | language_code:keyword | first_language_name:keyword | language_name:keyword +Connected to 10.1.0.1 | 1 | English | English +; + + +coordinatorEnrichAfterRemoteEnrich +required_capability: enrich_load + +FROM sample_data +| KEEP message +| WHERE message IN ("Connected to 10.1.0.1") +| EVAL language_code = "1" +| ENRICH _remote:languages_policy ON language_code +| RENAME language_name AS first_language_name +| ENRICH _coordinator:languages_policy ON language_code +; + +message:keyword | language_code:keyword | first_language_name:keyword | language_name:keyword +Connected to 10.1.0.1 | 1 | English | English +; + + +doubleRemoteEnrich +required_capability: enrich_load + +FROM sample_data +| KEEP message +| WHERE message IN ("Connected to 10.1.0.1") +| EVAL language_code = "1" +| ENRICH _remote:languages_policy ON language_code +| RENAME language_name AS first_language_name +| ENRICH _remote:languages_policy ON language_code +; + +message:keyword | language_code:keyword | first_language_name:keyword | language_name:keyword +Connected to 10.1.0.1 | 1 | English | English +; + + +enrichAfterCoordinatorEnrich +required_capability: enrich_load + +FROM sample_data +| KEEP message +| WHERE message IN ("Connected to 10.1.0.1") +| EVAL language_code = "1" +| ENRICH _coordinator:languages_policy ON language_code +| RENAME language_name AS first_language_name +| ENRICH languages_policy ON language_code +; + +message:keyword | 
language_code:keyword | first_language_name:keyword | language_name:keyword +Connected to 10.1.0.1 | 1 | English | English +; + + +doubleCoordinatorEnrich +required_capability: enrich_load + +FROM sample_data +| KEEP message +| WHERE message IN ("Connected to 10.1.0.1") +| EVAL language_code = "1" +| ENRICH _coordinator:languages_policy ON language_code +| RENAME language_name AS first_language_name +| ENRICH _coordinator:languages_policy ON language_code +; + +message:keyword | language_code:keyword | first_language_name:keyword | language_name:keyword +Connected to 10.1.0.1 | 1 | English | English +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup-join.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup-join.csv-spec index bdf0413a03d02..c71bf34cafd1a 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup-join.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup-join.csv-spec @@ -4773,3 +4773,101 @@ FROM sample_data_ts_nanos 2023-10-23T12:27:28.948123456Z | 172.21.2.113 | 2764889 | Connected to 10.1.0.2 2023-10-23T12:15:03.360123456Z | 172.21.2.162 | 3450233 | Connected to 10.1.0.3 ; + +############################################### +# LOOKUP JOIN and ENRICH +############################################### + +enrichAfterLookupJoin +required_capability: join_lookup_v12 + +FROM sample_data +| KEEP message +| WHERE message == "Connected to 10.1.0.1" +| EVAL language_code = "1" +| LOOKUP JOIN message_types_lookup ON message +| ENRICH languages_policy ON language_code +; + +message:keyword | language_code:keyword | type:keyword | language_name:keyword +Connected to 10.1.0.1 | 1 | Success | English +; + + +lookupJoinAfterEnrich +required_capability: join_lookup_v12 + +FROM sample_data +| KEEP message +| WHERE message == "Connected to 10.1.0.1" +| EVAL language_code = "1" +| ENRICH languages_policy ON language_code +| LOOKUP JOIN message_types_lookup ON message +; + +message:keyword | language_code:keyword | language_name:keyword | type:keyword +Connected to 10.1.0.1 | 1 | English | Success +; + + +lookupJoinAfterRemoteEnrich +required_capability: join_lookup_v12 + +FROM sample_data +| KEEP message +| WHERE message == "Connected to 10.1.0.1" +| EVAL language_code = "1" +| ENRICH _remote:languages_policy ON language_code +| LOOKUP JOIN message_types_lookup ON message +; + +message:keyword | language_code:keyword | language_name:keyword | type:keyword +Connected to 10.1.0.1 | 1 | English | Success +; + + +lookupJoinAfterLimitAndRemoteEnrich +required_capability: join_lookup_v12 + +FROM sample_data +| KEEP message +| WHERE message == "Connected to 10.1.0.1" +| EVAL language_code = "1" +| LIMIT 1 +| ENRICH _remote:languages_policy ON language_code +| EVAL enrich_language_name = language_name, language_code = language_code::integer +| LOOKUP JOIN languages_lookup_non_unique_key ON language_code +| KEEP message, enrich_language_name, language_name, country.keyword +| SORT language_name, country.keyword +; + +message:keyword | enrich_language_name:keyword | language_name:keyword | country.keyword:keyword +Connected to 10.1.0.1 | English | English | Canada +Connected to 10.1.0.1 | English | English | United States of America +Connected to 10.1.0.1 | English | English | null +Connected to 10.1.0.1 | English | null | United Kingdom +; + + +lookupJoinAfterTopNAndRemoteEnrich +required_capability: join_lookup_v12 + +FROM sample_data +| KEEP message +| WHERE message == "Connected to 10.1.0.1" +| EVAL language_code = "1" 
+| SORT message +| LIMIT 1 +| ENRICH _remote:languages_policy ON language_code +| EVAL enrich_language_name = language_name, language_code = language_code::integer +| LOOKUP JOIN languages_lookup_non_unique_key ON language_code +| KEEP message, enrich_language_name, language_name, country.keyword +| SORT language_name, country.keyword +; + +message:keyword | enrich_language_name:keyword | language_name:keyword | country.keyword:keyword +Connected to 10.1.0.1 | English | English | Canada +Connected to 10.1.0.1 | English | English | United States of America +Connected to 10.1.0.1 | English | English | null +Connected to 10.1.0.1 | English | null | United Kingdom +; diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java index 1d63a2bcf5373..e25cb82f29851 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java @@ -60,6 +60,7 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.enrich.LookupFromIndexOperator; import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders; +import org.elasticsearch.xpack.esql.planner.PhysicalSettings; import org.elasticsearch.xpack.esql.planner.PlannerUtils; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; @@ -198,6 +199,7 @@ private void runLookup(DataType keyType, PopulateIndices populateIndices) throws false // no scoring ); ValuesSourceReaderOperator.Factory reader = new ValuesSourceReaderOperator.Factory( + PhysicalSettings.VALUES_LOADING_JUMBO_SIZE.getDefault(Settings.EMPTY), List.of( new ValuesSourceReaderOperator.FieldInfo( "key", diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesIT.java index 4a54861df3d3a..f7833b917b746 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesIT.java @@ -10,6 +10,7 @@ import org.elasticsearch.Build; import org.elasticsearch.common.Randomness; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.compute.lucene.TimeSeriesSourceOperator; import org.elasticsearch.compute.operator.DriverProfile; import org.elasticsearch.compute.operator.TimeSeriesAggregationOperator; @@ -56,7 +57,7 @@ public void testEmpty() { run("TS empty_index | LIMIT 1").close(); } - record Doc(String host, String cluster, long timestamp, int requestCount, double cpu) {} + record Doc(String host, String cluster, long timestamp, int requestCount, double cpu, ByteSizeValue memory) {} final List docs = new ArrayList<>(); @@ -84,7 +85,6 @@ static Double computeRate(List values) { @Before public void populateIndex() { - // this can be expensive, do one Settings settings = Settings.builder().put("mode", "time_series").putList("routing_path", List.of("host", "cluster")).build(); client().admin() .indices() @@ -99,6 +99,8 @@ public void populateIndex() { "type=keyword,time_series_dimension=true", "cpu", 
"type=double,time_series_metric=gauge", + "memory", + "type=long,time_series_metric=gauge", "request_count", "type=integer,time_series_metric=counter" ) @@ -123,7 +125,8 @@ public void populateIndex() { } }); int cpu = randomIntBetween(0, 100); - docs.add(new Doc(host, hostToClusters.get(host), timestamp, requestCount, cpu)); + ByteSizeValue memory = ByteSizeValue.ofBytes(randomIntBetween(1024, 1024 * 1024)); + docs.add(new Doc(host, hostToClusters.get(host), timestamp, requestCount, cpu, memory)); } } Randomness.shuffle(docs); @@ -138,6 +141,8 @@ public void populateIndex() { doc.cluster, "cpu", doc.cpu, + "memory", + doc.memory.getBytes(), "request_count", doc.requestCount ) @@ -417,6 +422,63 @@ public void testIndexMode() { assertThat(failure.getMessage(), containsString("Unknown index [hosts-old]")); } + public void testFieldDoesNotExist() { + // the old-hosts index doesn't have the cpu field + Settings settings = Settings.builder().put("mode", "time_series").putList("routing_path", List.of("host", "cluster")).build(); + client().admin() + .indices() + .prepareCreate("old-hosts") + .setSettings(settings) + .setMapping( + "@timestamp", + "type=date", + "host", + "type=keyword,time_series_dimension=true", + "cluster", + "type=keyword,time_series_dimension=true", + "memory", + "type=long,time_series_metric=gauge", + "request_count", + "type=integer,time_series_metric=counter" + ) + .get(); + Randomness.shuffle(docs); + for (Doc doc : docs) { + client().prepareIndex("old-hosts") + .setSource( + "@timestamp", + doc.timestamp, + "host", + doc.host, + "cluster", + doc.cluster, + "memory", + doc.memory.getBytes(), + "request_count", + doc.requestCount + ) + .get(); + } + client().admin().indices().prepareRefresh("old-hosts").get(); + try (var resp1 = run(""" + TS hosts,old-hosts + | STATS sum(rate(request_count)), max(last_over_time(cpu)), max(last_over_time(memory)) BY cluster, host + | SORT cluster, host + | DROP `sum(rate(request_count))` + """)) { + try (var resp2 = run(""" + TS hosts + | STATS sum(rate(request_count)), max(last_over_time(cpu)), max(last_over_time(memory)) BY cluster, host + | SORT cluster, host + | DROP `sum(rate(request_count))` + """)) { + List> values1 = EsqlTestUtils.getValuesList(resp1); + List> values2 = EsqlTestUtils.getValuesList(resp2); + assertThat(values1, equalTo(values2)); + } + } + } + public void testProfile() { EsqlQueryRequest request = new EsqlQueryRequest(); request.profile(true); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 5c97879dd6a6d..733f0cabb2e22 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1276,11 +1276,6 @@ public enum Cap { */ NO_BRACKETS_IN_UNQUOTED_INDEX_NAMES, - /** - * Fail if all shards fail - */ - FAIL_IF_ALL_SHARDS_FAIL(Build.current().isSnapshot()), - /** * Cosine vector similarity function */ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/AbstractLookupService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/AbstractLookupService.java index 6d5630b0e6581..dd305f09c12dc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/AbstractLookupService.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/AbstractLookupService.java @@ -450,6 +450,7 @@ private static Operator extractFieldsOperator( } return new ValuesSourceReaderOperator( driverContext.blockFactory(), + Long.MAX_VALUE, fields, List.of( new ValuesSourceReaderOperator.ShardContext( diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java index 620d86650abf4..414e1f372ea3f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java @@ -21,6 +21,8 @@ import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; +import org.elasticsearch.xpack.esql.optimizer.LogicalPlanPreOptimizer; +import org.elasticsearch.xpack.esql.optimizer.LogicalPreOptimizerContext; import org.elasticsearch.xpack.esql.planner.mapper.Mapper; import org.elasticsearch.xpack.esql.plugin.TransportActionServices; import org.elasticsearch.xpack.esql.querylog.EsqlQueryLog; @@ -85,6 +87,7 @@ public void esql( indexResolver, enrichPolicyResolver, preAnalyzer, + new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldContext)), functionRegistry, new LogicalPlanOptimizer(new LogicalOptimizerContext(cfg, foldContext)), mapper, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/regex/RLike.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/regex/RLike.java index b5ade85db06e0..e03af271b30d7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/regex/RLike.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/regex/RLike.java @@ -44,7 +44,7 @@ Matching special characters (eg. `.`, `*`, `(`...) will require escaping. 
<> ```{applies_to} - stack: ga 9.1 + stack: ga 9.2 serverless: ga ``` diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java index 39f37f952ae02..3749aef7488ad 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java @@ -88,7 +88,7 @@ private static Batch localOperators() { public LogicalPlan localOptimize(LogicalPlan plan) { LogicalPlan optimized = execute(plan); - Failures failures = verifier.verify(optimized, true); + Failures failures = verifier.verify(optimized, true, plan.output()); if (failures.hasFailures()) { throw new VerificationException(failures); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java index 836eab9bb9590..af36963ac54a3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java @@ -9,6 +9,7 @@ import org.elasticsearch.xpack.esql.VerificationException; import org.elasticsearch.xpack.esql.common.Failures; +import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.EnableSpatialDistancePushdown; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.InsertFieldExtraction; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.ParallelizeTimeSeriesSource; @@ -42,15 +43,15 @@ public LocalPhysicalPlanOptimizer(LocalPhysicalOptimizerContext context) { } public PhysicalPlan localOptimize(PhysicalPlan plan) { - return verify(execute(plan)); + return verify(execute(plan), plan.output()); } - PhysicalPlan verify(PhysicalPlan plan) { - Failures failures = verifier.verify(plan, true); + PhysicalPlan verify(PhysicalPlan optimizedPlan, List expectedOutputAttributes) { + Failures failures = verifier.verify(optimizedPlan, true, expectedOutputAttributes); if (failures.hasFailures()) { throw new VerificationException(failures); } - return plan; + return optimizedPlan; } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index ca117bfff34d6..dac533f872022 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -113,7 +113,7 @@ public LogicalPlanOptimizer(LogicalOptimizerContext optimizerContext) { public LogicalPlan optimize(LogicalPlan verified) { var optimized = execute(verified); - Failures failures = verifier.verify(optimized, false); + Failures failures = verifier.verify(optimized, false, verified.output()); if (failures.hasFailures()) { throw new VerificationException(failures); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java new 
file mode 100644 index 0000000000000..fdd8e1318f636 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; + +/** + * This class is responsible for invoking any steps that need to be applied to the logical plan + * before it is optimized. + * <p> + * This is useful, especially if you need to execute some async tasks before the plan is optimized. + * </p>
+ */ +public class LogicalPlanPreOptimizer { + + private final LogicalPreOptimizerContext preOptimizerContext; + + public LogicalPlanPreOptimizer(LogicalPreOptimizerContext preOptimizerContext) { + this.preOptimizerContext = preOptimizerContext; + } + + /** + * Pre-optimize a logical plan. + * + * @param plan the analyzed logical plan to pre-optimize + * @param listener the listener returning the pre-optimized plan when pre-optimization is complete + */ + public void preOptimize(LogicalPlan plan, ActionListener listener) { + if (plan.analyzed() == false) { + listener.onFailure(new IllegalStateException("Expected analyzed plan")); + return; + } + + doPreOptimize(plan, listener.delegateFailureAndWrap((l, preOptimized) -> { + preOptimized.setPreOptimized(); + listener.onResponse(preOptimized); + })); + } + + private void doPreOptimize(LogicalPlan plan, ActionListener listener) { + // this is where we will be executing async tasks + listener.onResponse(plan); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPreOptimizerContext.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPreOptimizerContext.java new file mode 100644 index 0000000000000..d082bd56fc46d --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPreOptimizerContext.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer; + +import org.elasticsearch.xpack.esql.core.expression.FoldContext; + +import java.util.Objects; + +public class LogicalPreOptimizerContext { + + private final FoldContext foldCtx; + + public LogicalPreOptimizerContext(FoldContext foldCtx) { + this.foldCtx = foldCtx; + } + + public FoldContext foldCtx() { + return foldCtx; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) return true; + if (obj == null || obj.getClass() != this.getClass()) return false; + var that = (LogicalPreOptimizerContext) obj; + return this.foldCtx.equals(that.foldCtx); + } + + @Override + public int hashCode() { + return Objects.hash(foldCtx); + } + + @Override + public String toString() { + return "LogicalPreOptimizerContext[foldCtx=" + foldCtx + ']'; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java index 6751ae4cd2d80..4a04b46be295a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java @@ -13,27 +13,28 @@ import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; -public final class LogicalVerifier { +public final class LogicalVerifier extends PostOptimizationPhasePlanVerifier { public static final LogicalVerifier INSTANCE = new LogicalVerifier(); private LogicalVerifier() {} - /** Verifies the optimized logical plan. 
*/ - public Failures verify(LogicalPlan plan, boolean skipRemoteEnrichVerification) { - Failures failures = new Failures(); - Failures dependencyFailures = new Failures(); - + @Override + boolean skipVerification(LogicalPlan optimizedPlan, boolean skipRemoteEnrichVerification) { if (skipRemoteEnrichVerification) { // AwaitsFix https://github.com/elastic/elasticsearch/issues/118531 - var enriches = plan.collectFirstChildren(Enrich.class::isInstance); + var enriches = optimizedPlan.collectFirstChildren(Enrich.class::isInstance); if (enriches.isEmpty() == false && ((Enrich) enriches.get(0)).mode() == Enrich.Mode.REMOTE) { - return failures; + return true; } } + return false; + } - plan.forEachUp(p -> { - PlanConsistencyChecker.checkPlan(p, dependencyFailures); + @Override + void checkPlanConsistency(LogicalPlan optimizedPlan, Failures failures, Failures depFailures) { + optimizedPlan.forEachUp(p -> { + PlanConsistencyChecker.checkPlan(p, depFailures); if (failures.hasFailures() == false) { if (p instanceof PostOptimizationVerificationAware pova) { @@ -46,11 +47,5 @@ public Failures verify(LogicalPlan plan, boolean skipRemoteEnrichVerification) { }); } }); - - if (dependencyFailures.hasFailures()) { - throw new IllegalStateException(dependencyFailures.toString()); - } - - return failures; } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java index ab6bea5ffddac..6d60c547f47d6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java @@ -9,6 +9,7 @@ import org.elasticsearch.xpack.esql.VerificationException; import org.elasticsearch.xpack.esql.common.Failures; +import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.optimizer.rules.physical.ProjectAwayColumns; import org.elasticsearch.xpack.esql.plan.physical.FragmentExec; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; @@ -34,15 +35,15 @@ public PhysicalPlanOptimizer(PhysicalOptimizerContext context) { } public PhysicalPlan optimize(PhysicalPlan plan) { - return verify(execute(plan)); + return verify(execute(plan), plan.output()); } - PhysicalPlan verify(PhysicalPlan plan) { - Failures failures = verifier.verify(plan, false); + PhysicalPlan verify(PhysicalPlan optimizedPlan, List expectedOutputAttributes) { + Failures failures = verifier.verify(optimizedPlan, false, expectedOutputAttributes); if (failures.hasFailures()) { throw new VerificationException(failures); } - return plan; + return optimizedPlan; } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalVerifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalVerifier.java index 607aa11575bcb..781a8f5263c1f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalVerifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalVerifier.java @@ -20,26 +20,27 @@ import static org.elasticsearch.xpack.esql.common.Failure.fail; /** Physical plan verifier. 
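* Extends {@link PostOptimizationPhasePlanVerifier}, which owns the shared scaffolding (plan consistency checks plus output
* verification); this class contributes only the physical-plan specifics.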
*/ -public final class PhysicalVerifier { +public final class PhysicalVerifier extends PostOptimizationPhasePlanVerifier { public static final PhysicalVerifier INSTANCE = new PhysicalVerifier(); private PhysicalVerifier() {} - /** Verifies the physical plan. */ - public Failures verify(PhysicalPlan plan, boolean skipRemoteEnrichVerification) { - Failures failures = new Failures(); - Failures depFailures = new Failures(); - + @Override + boolean skipVerification(PhysicalPlan optimizedPlan, boolean skipRemoteEnrichVerification) { if (skipRemoteEnrichVerification) { // AwaitsFix https://github.com/elastic/elasticsearch/issues/118531 - var enriches = plan.collectFirstChildren(EnrichExec.class::isInstance); + var enriches = optimizedPlan.collectFirstChildren(EnrichExec.class::isInstance); if (enriches.isEmpty() == false && ((EnrichExec) enriches.get(0)).mode() == Enrich.Mode.REMOTE) { - return failures; + return true; } } + return false; + } - plan.forEachDown(p -> { + @Override + void checkPlanConsistency(PhysicalPlan optimizedPlan, Failures failures, Failures depFailures) { + optimizedPlan.forEachDown(p -> { if (p instanceof FieldExtractExec fieldExtractExec) { Attribute sourceAttribute = fieldExtractExec.sourceAttribute(); if (sourceAttribute == null) { @@ -66,11 +67,5 @@ public Failures verify(PhysicalPlan plan, boolean skipRemoteEnrichVerification) }); } }); - - if (depFailures.hasFailures()) { - throw new IllegalStateException(depFailures.toString()); - } - - return failures; } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PostOptimizationPhasePlanVerifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PostOptimizationPhasePlanVerifier.java new file mode 100644 index 0000000000000..647dafe649984 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PostOptimizationPhasePlanVerifier.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer; + +import org.elasticsearch.xpack.esql.common.Failures; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.ProjectAwayColumns; +import org.elasticsearch.xpack.esql.plan.QueryPlan; +import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec; + +import java.util.List; + +import static org.elasticsearch.index.IndexMode.LOOKUP; +import static org.elasticsearch.xpack.esql.common.Failure.fail; +import static org.elasticsearch.xpack.esql.core.expression.Attribute.dataTypeEquals; + +/** + * Verifies the plan after optimization. + * This is invoked immediately after a Plan Optimizer completes its work. + * Currently, it is called after LogicalPlanOptimizer, PhysicalPlanOptimizer, + * LocalLogicalPlanOptimizer, and LocalPhysicalPlanOptimizer. + * Note: Logical and Physical optimizers may override methods in this class to perform different checks. + */ +public abstract class PostOptimizationPhasePlanVerifier
<P extends QueryPlan<P>
> { + + /** Verifies the optimized plan */ + public Failures verify(P optimizedPlan, boolean skipRemoteEnrichVerification, List expectedOutputAttributes) { + Failures failures = new Failures(); + Failures depFailures = new Failures(); + if (skipVerification(optimizedPlan, skipRemoteEnrichVerification)) { + return failures; + } + + checkPlanConsistency(optimizedPlan, failures, depFailures); + + verifyOutputNotChanged(optimizedPlan, expectedOutputAttributes, failures); + + if (depFailures.hasFailures()) { + throw new IllegalStateException(depFailures.toString()); + } + + return failures; + } + + abstract boolean skipVerification(P optimizedPlan, boolean skipRemoteEnrichVerification); + + abstract void checkPlanConsistency(P optimizedPlan, Failures failures, Failures depFailures); + + private static void verifyOutputNotChanged(QueryPlan optimizedPlan, List expectedOutputAttributes, Failures failures) { + if (dataTypeEquals(expectedOutputAttributes, optimizedPlan.output()) == false) { + // If the output level is empty we add a column called ProjectAwayColumns.ALL_FIELDS_PROJECTED + // We will ignore such cases for output verification + // TODO: this special casing is required due to https://github.com/elastic/elasticsearch/issues/121741, remove when fixed. + boolean hasProjectAwayColumns = optimizedPlan.output() + .stream() + .anyMatch(x -> x.name().equals(ProjectAwayColumns.ALL_FIELDS_PROJECTED)); + // LookupJoinExec represents the lookup index with EsSourceExec and this is turned into EsQueryExec by + // ReplaceSourceAttributes. Because InsertFieldExtractions doesn't apply to lookup indices, the + // right hand side will only have the EsQueryExec providing the _doc attribute and nothing else. + // We perform an optimizer run on every fragment. LookupJoinExec also contains such a fragment, + // and currently it only contains an EsQueryExec after optimization. + boolean hasLookupJoinExec = optimizedPlan instanceof EsQueryExec esQueryExec && esQueryExec.indexMode() == LOOKUP; + boolean ignoreError = hasProjectAwayColumns || hasLookupJoinExec; + if (ignoreError == false) { + failures.add( + fail( + optimizedPlan, + "Output has changed from [{}] to [{}]. ", + expectedOutputAttributes.toString(), + optimizedPlan.output().toString() + ) + ); + } + } + } + +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/ProjectAwayColumns.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/ProjectAwayColumns.java index 887fb039a14cb..189fc5e4c7415 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/ProjectAwayColumns.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/ProjectAwayColumns.java @@ -36,6 +36,7 @@ * extraction. 
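* When the projection leaves no output at all, a synthetic constant column named {@code ALL_FIELDS_PROJECTED} is added instead,
* so that post-optimization output verification knows to ignore it.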
*/ public class ProjectAwayColumns extends Rule<PhysicalPlan, PhysicalPlan> { + public static String ALL_FIELDS_PROJECTED = ""; @@ -94,7 +95,7 @@ public PhysicalPlan apply(PhysicalPlan plan) { // add a synthetic field (so it doesn't clash with the user defined one) to return a constant // to avoid the block from being trimmed if (output.isEmpty()) { - var alias = new Alias(logicalFragment.source(), "", Literal.NULL, null, true); + var alias = new Alias(logicalFragment.source(), ALL_FIELDS_PROJECTED, Literal.NULL, null, true); List<Alias> fields = singletonList(alias); logicalFragment = new Eval(logicalFragment.source(), logicalFragment, fields); output = Expressions.asAttributes(fields); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java index 11e9a57064e5b..7307fd8efad39 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java @@ -35,6 +35,7 @@ import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.plan.GeneratingPlan; +import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin; import java.io.IOException; import java.util.ArrayList; @@ -295,23 +296,43 @@ public BiConsumer postAnalysisPlanVerification() { * retaining the originating cluster and restructuring pages for routing, which might be complicated. */ private static void checkRemoteEnrich(LogicalPlan plan, Failures failures) { - boolean[] agg = { false }; - boolean[] enrichCoord = { false }; + // First look for remote ENRICH, and then look at its children. Going over the whole plan once is trickier as remote ENRICHs can be + // in separate FORK branches which are valid by themselves. + plan.forEachUp(Enrich.class, enrich -> checkForPlansForbiddenBeforeRemoteEnrich(enrich, failures)); + } + + /** + * For a given remote {@link Enrich}, check if there are any forbidden plans upstream. + */ + private static void checkForPlansForbiddenBeforeRemoteEnrich(Enrich enrich, Failures failures) { + if (enrich.mode != Mode.REMOTE) { + return; + } + + // TODO: shouldn't we also include FORK? Everything downstream from FORK should be coordinator-only. 
+ // https://github.com/elastic/elasticsearch/issues/131445 + boolean[] aggregate = { false }; + boolean[] coordinatorOnlyEnrich = { false }; + boolean[] lookupJoin = { false }; - plan.forEachUp(UnaryPlan.class, u -> { + enrich.forEachUp(LogicalPlan.class, u -> { if (u instanceof Aggregate) { - agg[0] = true; - } else if (u instanceof Enrich enrich && enrich.mode() == Enrich.Mode.COORDINATOR) { - enrichCoord[0] = true; - } - if (u instanceof Enrich enrich && enrich.mode() == Enrich.Mode.REMOTE) { - if (agg[0]) { - failures.add(fail(enrich, "ENRICH with remote policy can't be executed after STATS")); - } - if (enrichCoord[0]) { - failures.add(fail(enrich, "ENRICH with remote policy can't be executed after another ENRICH with coordinator policy")); - } + aggregate[0] = true; + } else if (u instanceof Enrich upstreamEnrich && upstreamEnrich.mode() == Enrich.Mode.COORDINATOR) { + coordinatorOnlyEnrich[0] = true; + } else if (u instanceof LookupJoin) { + lookupJoin[0] = true; } }); + + if (aggregate[0]) { + failures.add(fail(enrich, "ENRICH with remote policy can't be executed after STATS")); + } + if (coordinatorOnlyEnrich[0]) { + failures.add(fail(enrich, "ENRICH with remote policy can't be executed after another ENRICH with coordinator policy")); + } + if (lookupJoin[0]) { + failures.add(fail(enrich, "ENRICH with remote policy can't be executed after LOOKUP JOIN")); + } } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LogicalPlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LogicalPlan.java index ac4baea8bc853..762b22389ae24 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LogicalPlan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LogicalPlan.java @@ -25,6 +25,7 @@ public enum Stage { PARSED, PRE_ANALYZED, ANALYZED, + PRE_OPTIMIZED, OPTIMIZED; } @@ -52,6 +53,14 @@ public void setAnalyzed() { stage = Stage.ANALYZED; } + public boolean preOptimized() { + return stage.ordinal() >= Stage.PRE_OPTIMIZED.ordinal(); + } + + public void setPreOptimized() { + stage = Stage.PRE_OPTIMIZED; + } + public boolean optimized() { return stage.ordinal() >= Stage.OPTIMIZED.ordinal(); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java index 7abe84d99e5f2..e0b570267899b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java @@ -22,7 +22,6 @@ import org.elasticsearch.compute.aggregation.GroupingAggregator; import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.lucene.DataPartitioning; import org.elasticsearch.compute.lucene.LuceneCountOperator; import org.elasticsearch.compute.lucene.LuceneOperator; import org.elasticsearch.compute.lucene.LuceneSliceQueue; @@ -140,17 +139,17 @@ public boolean hasReferences() { } private final List shardContexts; - private final DataPartitioning defaultDataPartitioning; + private final PhysicalSettings physicalSettings; public EsPhysicalOperationProviders( FoldContext foldContext, List shardContexts, AnalysisRegistry analysisRegistry, - DataPartitioning defaultDataPartitioning + 
PhysicalSettings physicalSettings ) { super(foldContext, analysisRegistry); this.shardContexts = shardContexts; - this.defaultDataPartitioning = defaultDataPartitioning; + this.physicalSettings = physicalSettings; } @Override @@ -175,7 +174,10 @@ public final PhysicalOperation fieldExtractPhysicalOperation(FieldExtractExec fi // TODO: consolidate with ValuesSourceReaderOperator return source.with(new TimeSeriesExtractFieldOperator.Factory(fields, shardContexts), layout.build()); } else { - return source.with(new ValuesSourceReaderOperator.Factory(fields, readers, docChannel), layout.build()); + return source.with( + new ValuesSourceReaderOperator.Factory(physicalSettings.valuesLoadingJumboSize(), fields, readers, docChannel), + layout.build() + ); } } @@ -278,7 +280,7 @@ public final PhysicalOperation sourcePhysicalOperation(EsQueryExec esQueryExec, luceneFactory = new LuceneTopNSourceOperator.Factory( shardContexts, querySupplier(esQueryExec.query()), - context.queryPragmas().dataPartitioning(defaultDataPartitioning), + context.queryPragmas().dataPartitioning(physicalSettings.defaultDataPartitioning()), context.queryPragmas().taskConcurrency(), context.pageSize(rowEstimatedSize), limit, @@ -289,7 +291,7 @@ public final PhysicalOperation sourcePhysicalOperation(EsQueryExec esQueryExec, luceneFactory = new LuceneSourceOperator.Factory( shardContexts, querySupplier(esQueryExec.query()), - context.queryPragmas().dataPartitioning(defaultDataPartitioning), + context.queryPragmas().dataPartitioning(physicalSettings.defaultDataPartitioning()), context.queryPragmas().taskConcurrency(), context.pageSize(rowEstimatedSize), limit, @@ -341,7 +343,7 @@ public LuceneCountOperator.Factory countSource(LocalExecutionPlannerContext cont return new LuceneCountOperator.Factory( shardContexts, querySupplier(queryBuilder), - context.queryPragmas().dataPartitioning(defaultDataPartitioning), + context.queryPragmas().dataPartitioning(physicalSettings.defaultDataPartitioning()), context.queryPragmas().taskConcurrency(), limit == null ? NO_LIMIT : (Integer) limit.fold(context.foldCtx()) ); @@ -530,8 +532,8 @@ public ColumnAtATimeReader columnAtATimeReader(LeafReaderContext context) throws } return new ColumnAtATimeReader() { @Override - public Block read(BlockFactory factory, Docs docs) throws IOException { - Block block = reader.read(factory, docs); + public Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + Block block = reader.read(factory, docs, offset); return typeConverter.convert((org.elasticsearch.compute.data.Block) block); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PhysicalSettings.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PhysicalSettings.java new file mode 100644 index 0000000000000..4276eeaf39f9b --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PhysicalSettings.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.planner; + +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.unit.MemorySizeValue; +import org.elasticsearch.compute.lucene.DataPartitioning; +import org.elasticsearch.monitor.jvm.JvmInfo; + +/** + * Values for cluster level settings used in physical planning. + */ +public class PhysicalSettings { + public static final Setting DEFAULT_DATA_PARTITIONING = Setting.enumSetting( + DataPartitioning.class, + "esql.default_data_partitioning", + DataPartitioning.AUTO, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting VALUES_LOADING_JUMBO_SIZE = new Setting<>("esql.values_loading_jumbo_size", settings -> { + long proportional = JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() / 1024; + return ByteSizeValue.ofBytes(Math.max(proportional, ByteSizeValue.ofMb(1).getBytes())).getStringRep(); + }, + s -> MemorySizeValue.parseBytesSizeValueOrHeapRatio(s, "esql.values_loading_jumbo_size"), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + private volatile DataPartitioning defaultDataPartitioning; + private volatile ByteSizeValue valuesLoadingJumboSize; + + /** + * Ctor for prod that listens for updates from the {@link ClusterService}. + */ + public PhysicalSettings(ClusterService clusterService) { + clusterService.getClusterSettings().initializeAndWatch(DEFAULT_DATA_PARTITIONING, v -> this.defaultDataPartitioning = v); + clusterService.getClusterSettings().initializeAndWatch(VALUES_LOADING_JUMBO_SIZE, v -> this.valuesLoadingJumboSize = v); + } + + /** + * Ctor for testing. + */ + public PhysicalSettings(DataPartitioning defaultDataPartitioning, ByteSizeValue valuesLoadingJumboSize) { + this.defaultDataPartitioning = defaultDataPartitioning; + this.valuesLoadingJumboSize = valuesLoadingJumboSize; + } + + public DataPartitioning defaultDataPartitioning() { + return defaultDataPartitioning; + } + + public ByteSizeValue valuesLoadingJumboSize() { + return valuesLoadingJumboSize; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java index 4d1d65d63932d..bf6f0b89efbec 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java @@ -87,7 +87,7 @@ private PhysicalPlan mapUnary(UnaryPlan unary) { PhysicalPlan mappedChild = map(unary.child()); // - // TODO - this is hard to follow and needs reworking + // TODO - this is hard to follow, causes bugs and needs reworking // https://github.com/elastic/elasticsearch/issues/115897 // if (unary instanceof Enrich enrich && enrich.mode() == Enrich.Mode.REMOTE) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index 1ae6a4634644d..d12799ab8b170 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -19,7 +19,6 @@ import org.elasticsearch.common.util.concurrent.RunOnce; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.Page; 
-import org.elasticsearch.compute.lucene.DataPartitioning; import org.elasticsearch.compute.operator.DriverCompletionInfo; import org.elasticsearch.compute.operator.DriverTaskRunner; import org.elasticsearch.compute.operator.FailureCollector; @@ -47,7 +46,6 @@ import org.elasticsearch.transport.RemoteClusterAware; import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.esql.action.EsqlCapabilities; import org.elasticsearch.xpack.esql.action.EsqlExecutionInfo; import org.elasticsearch.xpack.esql.action.EsqlQueryAction; import org.elasticsearch.xpack.esql.core.expression.Attribute; @@ -61,6 +59,7 @@ import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders; import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner; +import org.elasticsearch.xpack.esql.planner.PhysicalSettings; import org.elasticsearch.xpack.esql.planner.PlannerUtils; import org.elasticsearch.xpack.esql.session.Configuration; import org.elasticsearch.xpack.esql.session.EsqlCCSUtils; @@ -139,8 +138,7 @@ public class ComputeService { private final DataNodeComputeHandler dataNodeComputeHandler; private final ClusterComputeHandler clusterComputeHandler; private final ExchangeService exchangeService; - - private volatile DataPartitioning defaultDataPartitioning; + private final PhysicalSettings physicalSettings; @SuppressWarnings("this-escape") public ComputeService( @@ -179,7 +177,7 @@ public ComputeService( esqlExecutor, dataNodeComputeHandler ); - clusterService.getClusterSettings().initializeAndWatch(EsqlPlugin.DEFAULT_DATA_PARTITIONING, v -> this.defaultDataPartitioning = v); + this.physicalSettings = new PhysicalSettings(clusterService); } public void execute( @@ -549,9 +547,6 @@ private static void updateExecutionInfoAfterCoordinatorOnlyQuery(EsqlExecutionIn * which doesn't consider the failures from the remote clusters when skip_unavailable is true. 
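* Note that the guard below returns early as soon as any final page contains rows, so a failure is only surfaced when every
* result is empty.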
*/ static void failIfAllShardsFailed(EsqlExecutionInfo execInfo, List finalResults) { - if (EsqlCapabilities.Cap.FAIL_IF_ALL_SHARDS_FAIL.isEnabled() == false) { - return; - } // do not fail if any final result has results if (finalResults.stream().anyMatch(p -> p.getPositionCount() > 0)) { return; @@ -612,7 +607,7 @@ public SourceProvider createSourceProvider() { context.foldCtx(), contexts, searchService.getIndicesService().getAnalysis(), - defaultDataPartitioning + physicalSettings ); try { LocalExecutionPlanner planner = new LocalExecutionPlanner( diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java index 7cba5eeb56278..776874fbf90f6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java @@ -21,7 +21,6 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BlockFactoryProvider; -import org.elasticsearch.compute.lucene.DataPartitioning; import org.elasticsearch.compute.lucene.LuceneOperator; import org.elasticsearch.compute.lucene.TimeSeriesSourceOperator; import org.elasticsearch.compute.lucene.read.ValuesSourceReaderOperatorStatus; @@ -33,6 +32,7 @@ import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.LimitOperator; import org.elasticsearch.compute.operator.MvExpandOperator; +import org.elasticsearch.compute.operator.SampleOperator; import org.elasticsearch.compute.operator.exchange.ExchangeService; import org.elasticsearch.compute.operator.exchange.ExchangeSinkOperator; import org.elasticsearch.compute.operator.exchange.ExchangeSourceOperator; @@ -75,6 +75,7 @@ import org.elasticsearch.xpack.esql.io.stream.ExpressionQueryBuilder; import org.elasticsearch.xpack.esql.io.stream.PlanStreamWrapperQueryBuilder; import org.elasticsearch.xpack.esql.plan.PlanWritables; +import org.elasticsearch.xpack.esql.planner.PhysicalSettings; import org.elasticsearch.xpack.esql.querydsl.query.SingleValueQuery; import org.elasticsearch.xpack.esql.querylog.EsqlQueryLog; import org.elasticsearch.xpack.esql.session.IndexResolver; @@ -160,14 +161,6 @@ public class EsqlPlugin extends Plugin implements ActionPlugin, ExtensiblePlugin Setting.Property.Dynamic ); - public static final Setting DEFAULT_DATA_PARTITIONING = Setting.enumSetting( - DataPartitioning.class, - "esql.default_data_partitioning", - DataPartitioning.AUTO, - Setting.Property.NodeScope, - Setting.Property.Dynamic - ); - /** * Tuning parameter for deciding when to use the "merge" stored field loader. 
* Think of it as "how similar to a sequential block of documents do I have to @@ -263,7 +256,8 @@ public List<Setting<?>> getSettings() { ESQL_QUERYLOG_THRESHOLD_INFO_SETTING, ESQL_QUERYLOG_THRESHOLD_WARN_SETTING, ESQL_QUERYLOG_INCLUDE_USER_SETTING, - DEFAULT_DATA_PARTITIONING, + PhysicalSettings.DEFAULT_DATA_PARTITIONING, + PhysicalSettings.VALUES_LOADING_JUMBO_SIZE, STORED_FIELDS_SEQUENTIAL_PROPORTION, EsqlFlags.ESQL_STRING_LIKE_ON_INDEX ); @@ -328,6 +322,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() { entries.add(AsyncOperator.Status.ENTRY); entries.add(EnrichLookupOperator.Status.ENTRY); entries.add(LookupFromIndexOperator.Status.ENTRY); + entries.add(SampleOperator.Status.ENTRY); entries.add(ExpressionQueryBuilder.ENTRY); entries.add(PlanStreamWrapperQueryBuilder.ENTRY); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java index 345bf3b8767ef..bdd0e382c3fd3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java @@ -22,6 +22,7 @@ import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.planner.PhysicalSettings; import java.io.IOException; import java.util.Locale; @@ -45,7 +46,7 @@ public final class QueryPragmas implements Writeable { * the enum {@link DataPartitioning} which has more documentation. Not an * {@link Setting#enumSetting} because those can't have {@code null} defaults. * {@code null} here means "use the default from the cluster setting - * named {@link EsqlPlugin#DEFAULT_DATA_PARTITIONING}." + * named {@link PhysicalSettings#DEFAULT_DATA_PARTITIONING}."
*/ public static final Setting DATA_PARTITIONING = Setting.simpleString("data_partitioning"); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/rule/RuleExecutor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/rule/RuleExecutor.java index 7df5a029d724e..e4a5423d35f8e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/rule/RuleExecutor.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/rule/RuleExecutor.java @@ -175,7 +175,7 @@ protected final ExecutionInfo executeWithInfo(TreeType plan) { if (tf.hasChanged()) { hasChanged = true; if (log.isTraceEnabled()) { - log.trace("Rule {} applied\n{}", rule, NodeUtils.diffString(tf.before, tf.after)); + log.trace("Rule {} applied with change\n{}", rule, NodeUtils.diffString(tf.before, tf.after)); } } else { if (log.isTraceEnabled()) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java index df18051bcf721..4590cb8c6d3bd 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java @@ -71,6 +71,7 @@ import org.elasticsearch.xpack.esql.inference.InferenceResolution; import org.elasticsearch.xpack.esql.inference.InferenceRunner; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; +import org.elasticsearch.xpack.esql.optimizer.LogicalPlanPreOptimizer; import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.PhysicalPlanOptimizer; import org.elasticsearch.xpack.esql.parser.EsqlParser; @@ -147,6 +148,7 @@ public interface PlanRunner { private final PreAnalyzer preAnalyzer; private final Verifier verifier; private final EsqlFunctionRegistry functionRegistry; + private final LogicalPlanPreOptimizer logicalPlanPreOptimizer; private final LogicalPlanOptimizer logicalPlanOptimizer; private final PreMapper preMapper; @@ -168,6 +170,7 @@ public EsqlSession( IndexResolver indexResolver, EnrichPolicyResolver enrichPolicyResolver, PreAnalyzer preAnalyzer, + LogicalPlanPreOptimizer logicalPlanPreOptimizer, EsqlFunctionRegistry functionRegistry, LogicalPlanOptimizer logicalPlanOptimizer, Mapper mapper, @@ -181,6 +184,7 @@ public EsqlSession( this.indexResolver = indexResolver; this.enrichPolicyResolver = enrichPolicyResolver; this.preAnalyzer = preAnalyzer; + this.logicalPlanPreOptimizer = logicalPlanPreOptimizer; this.verifier = verifier; this.functionRegistry = functionRegistry; this.mapper = mapper; @@ -212,11 +216,10 @@ public void execute(EsqlQueryRequest request, EsqlExecutionInfo executionInfo, P analyzedPlan(parsed, executionInfo, request.filter(), new EsqlCCSUtils.CssPartialErrorsActionListener(executionInfo, listener) { @Override public void onResponse(LogicalPlan analyzedPlan) { - LogicalPlan optimizedPlan = optimizedPlan(analyzedPlan); - preMapper.preMapper( - optimizedPlan, - listener.delegateFailureAndWrap((l, p) -> executeOptimizedPlan(request, executionInfo, planRunner, p, l)) - ); + SubscribableListener.newForked(l -> preOptimizedPlan(analyzedPlan, l)) + .andThen((l, p) -> preMapper.preMapper(optimizedPlan(p), l)) + .andThen((l, p) -> executeOptimizedPlan(request, executionInfo, planRunner, p, l)) + .addListener(listener); } }); } @@ -1043,14 +1046,18 @@ private PhysicalPlan logicalPlanToPhysicalPlan(LogicalPlan 
optimizedPlan, EsqlQu } public LogicalPlan optimizedPlan(LogicalPlan logicalPlan) { - if (logicalPlan.analyzed() == false) { - throw new IllegalStateException("Expected analyzed plan"); + if (logicalPlan.preOptimized() == false) { + throw new IllegalStateException("Expected pre-optimized plan"); } var plan = logicalPlanOptimizer.optimize(logicalPlan); LOGGER.debug("Optimized logicalPlan plan:\n{}", plan); return plan; } + public void preOptimizedPlan(LogicalPlan logicalPlan, ActionListener listener) { + logicalPlanPreOptimizer.preOptimize(logicalPlan, listener); + } + public PhysicalPlan physicalPlan(LogicalPlan optimizedPlan) { if (optimizedPlan.optimized() == false) { throw new IllegalStateException("Expected optimized plan"); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index 62280a38ba608..b38b3089823d5 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -73,6 +73,8 @@ import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; +import org.elasticsearch.xpack.esql.optimizer.LogicalPlanPreOptimizer; +import org.elasticsearch.xpack.esql.optimizer.LogicalPreOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.TestLocalPhysicalPlanOptimizer; import org.elasticsearch.xpack.esql.parser.EsqlParser; import org.elasticsearch.xpack.esql.plan.logical.Enrich; @@ -582,6 +584,7 @@ private ActualResults executePlan(BigArrays bigArrays) throws Exception { null, null, null, + new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldCtx)), functionRegistry, new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration, foldCtx)), mapper, @@ -594,24 +597,27 @@ private ActualResults executePlan(BigArrays bigArrays) throws Exception { PlainActionFuture listener = new PlainActionFuture<>(); - session.executeOptimizedPlan( - new EsqlQueryRequest(), - new EsqlExecutionInfo(randomBoolean()), - planRunner(bigArrays, foldCtx, physicalOperationProviders), - session.optimizedPlan(analyzed), - listener.delegateFailureAndWrap( - // Wrap so we can capture the warnings in the calling thread - (next, result) -> next.onResponse( - new ActualResults( - result.schema().stream().map(Attribute::name).toList(), - result.schema().stream().map(a -> Type.asType(a.dataType().nameUpper())).toList(), - result.schema().stream().map(Attribute::dataType).toList(), - result.pages(), - threadPool.getThreadContext().getResponseHeaders() + session.preOptimizedPlan(analyzed, listener.delegateFailureAndWrap((l, preOptimized) -> { + session.executeOptimizedPlan( + new EsqlQueryRequest(), + new EsqlExecutionInfo(randomBoolean()), + planRunner(bigArrays, foldCtx, physicalOperationProviders), + session.optimizedPlan(preOptimized), + listener.delegateFailureAndWrap( + // Wrap so we can capture the warnings in the calling thread + (next, result) -> next.onResponse( + new ActualResults( + result.schema().stream().map(Attribute::name).toList(), + result.schema().stream().map(a -> Type.asType(a.dataType().nameUpper())).toList(), + result.schema().stream().map(Attribute::dataType).toList(), + result.pages(), + threadPool.getThreadContext().getResponseHeaders() + ) ) ) - ) - ); + ); + })); + return listener.get(); } diff --git 
a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java index cbb825ca9581b..fbfa18dccc477 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java @@ -36,9 +36,9 @@ import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.GEO_MATCH_TYPE; import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.MATCH_TYPE; import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.RANGE_TYPE; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG; import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER; import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution; public final class AnalyzerTestUtils { @@ -61,27 +61,36 @@ public static Analyzer analyzer(IndexResolution indexResolution, Map lookupResolution, Verifier verifier) { + return analyzer(indexResolution, lookupResolution, defaultEnrichResolution(), verifier); + } + + public static Analyzer analyzer( + IndexResolution indexResolution, + Map lookupResolution, + EnrichResolution enrichResolution, + Verifier verifier + ) { + return analyzer(indexResolution, lookupResolution, enrichResolution, verifier, TEST_CFG); + } + + public static Analyzer analyzer( + IndexResolution indexResolution, + Map lookupResolution, + EnrichResolution enrichResolution, + Verifier verifier, + Configuration config + ) { return new Analyzer( new AnalyzerContext( - EsqlTestUtils.TEST_CFG, + config, new EsqlFunctionRegistry(), indexResolution, lookupResolution, - defaultEnrichResolution(), + enrichResolution, defaultInferenceResolution() ), verifier @@ -89,17 +98,7 @@ public static Analyzer analyzer(IndexResolution indexResolution, Map query("FROM test,remote:test | EVAL language_code = languages | LOOKUP JOIN languages_lookup ON language_code") ); assertThat(e.getMessage(), containsString("remote clusters are not supported with LOOKUP JOIN")); + } + public void testRemoteEnrichAfterLookupJoin() { + EnrichResolution enrichResolution = new EnrichResolution(); + loadEnrichPolicyResolution( + enrichResolution, + Enrich.Mode.REMOTE, + MATCH_TYPE, + "languages", + "language_code", + "languages_idx", + "mapping-languages.json" + ); + var analyzer = AnalyzerTestUtils.analyzer( + loadMapping("mapping-default.json", "test"), + defaultLookupResolution(), + enrichResolution, + TEST_VERIFIER + ); + + String lookupCommand = randomBoolean() ? 
"LOOKUP JOIN test_lookup ON languages" : "LOOKUP JOIN languages_lookup ON language_code"; + + query(Strings.format(""" + FROM test + | EVAL language_code = languages + | ENRICH _remote:languages ON language_code + | %s + """, lookupCommand), analyzer); + + String err = error(Strings.format(""" + FROM test + | EVAL language_code = languages + | %s + | ENRICH _remote:languages ON language_code + """, lookupCommand), analyzer); + assertThat(err, containsString("4:3: ENRICH with remote policy can't be executed after LOOKUP JOIN")); + + err = error(Strings.format(""" + FROM test + | EVAL language_code = languages + | %s + | ENRICH _remote:languages ON language_code + | %s + """, lookupCommand, lookupCommand), analyzer); + assertThat(err, containsString("4:3: ENRICH with remote policy can't be executed after LOOKUP JOIN")); + + err = error(Strings.format(""" + FROM test + | EVAL language_code = languages + | %s + | EVAL x = 1 + | MV_EXPAND language_code + | ENRICH _remote:languages ON language_code + """, lookupCommand), analyzer); + assertThat(err, containsString("6:3: ENRICH with remote policy can't be executed after LOOKUP JOIN")); + } + + public void testRemoteEnrichAfterCoordinatorOnlyPlans() { + EnrichResolution enrichResolution = new EnrichResolution(); + loadEnrichPolicyResolution( + enrichResolution, + Enrich.Mode.REMOTE, + MATCH_TYPE, + "languages", + "language_code", + "languages_idx", + "mapping-languages.json" + ); + loadEnrichPolicyResolution( + enrichResolution, + Enrich.Mode.COORDINATOR, + MATCH_TYPE, + "languages", + "language_code", + "languages_idx", + "mapping-languages.json" + ); + var analyzer = AnalyzerTestUtils.analyzer( + loadMapping("mapping-default.json", "test"), + defaultLookupResolution(), + enrichResolution, + TEST_VERIFIER + ); + + query(""" + FROM test + | EVAL language_code = languages + | ENRICH _remote:languages ON language_code + | STATS count(*) BY language_name + """, analyzer); + + String err = error(""" + FROM test + | EVAL language_code = languages + | STATS count(*) BY language_code + | ENRICH _remote:languages ON language_code + """, analyzer); + assertThat(err, containsString("4:3: ENRICH with remote policy can't be executed after STATS")); + + err = error(""" + FROM test + | EVAL language_code = languages + | STATS count(*) BY language_code + | EVAL x = 1 + | MV_EXPAND language_code + | ENRICH _remote:languages ON language_code + """, analyzer); + assertThat(err, containsString("6:3: ENRICH with remote policy can't be executed after STATS")); + + query(""" + FROM test + | EVAL language_code = languages + | ENRICH _remote:languages ON language_code + | ENRICH _coordinator:languages ON language_code + """, analyzer); + + err = error(""" + FROM test + | EVAL language_code = languages + | ENRICH _coordinator:languages ON language_code + | ENRICH _remote:languages ON language_code + """, analyzer); + assertThat(err, containsString("4:3: ENRICH with remote policy can't be executed after another ENRICH with coordinator policy")); + + err = error(""" + FROM test + | EVAL language_code = languages + | ENRICH _coordinator:languages ON language_code + | EVAL x = 1 + | MV_EXPAND language_name + | DISSECT language_name "%{foo}" + | ENRICH _remote:languages ON language_code + """, analyzer); + assertThat(err, containsString("7:3: ENRICH with remote policy can't be executed after another ENRICH with coordinator policy")); } private void checkFullTextFunctionsInStats(String functionInvocation) { diff --git 
a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/ScoreTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/ScoreTests.java index 346b1cafa02f4..74b2dffe2e4c4 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/ScoreTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/ScoreTests.java @@ -10,17 +10,17 @@ import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.expression.function.FunctionName; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; -import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import org.junit.BeforeClass; import java.util.ArrayList; import java.util.List; import java.util.function.Supplier; -import static org.elasticsearch.xpack.esql.SerializationTestUtils.serializeDeserialize; import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; import static org.hamcrest.Matchers.equalTo; @@ -28,6 +28,11 @@ @FunctionName("score") public class ScoreTests extends AbstractMatchFullTextFunctionTests { + @BeforeClass + public static void init() { + assumeTrue("can run this only when score() function is enabled", EsqlCapabilities.Cap.SCORE_FUNCTION.isEnabled()); + } + public ScoreTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @@ -55,18 +60,4 @@ protected Expression build(Source source, List args) { return new Score(source, args.getFirst()); } - /** - * Copy of the overridden method that doesn't check for children size, as the {@code options} child isn't serialized in Match. - */ - @Override - protected Expression serializeDeserializeExpression(Expression expression) { - Expression newExpression = serializeDeserialize( - expression, - PlanStreamOutput::writeNamedWriteable, - in -> in.readNamedWriteable(Expression.class), - testCase.getConfiguration() // The configuration query should be == to the source text of the function for this to work - ); - // Fields use synthetic sources, which can't be serialized. So we use the originals instead. 
- return newExpression.replaceChildren(expression.children()); - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java index f89c79079c876..0cf86378a0f70 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.index.IndexMode; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.EsqlTestUtils; +import org.elasticsearch.xpack.esql.VerificationException; import org.elasticsearch.xpack.esql.analysis.Analyzer; import org.elasticsearch.xpack.esql.analysis.AnalyzerContext; import org.elasticsearch.xpack.esql.core.expression.Alias; @@ -30,6 +31,7 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.core.type.InvalidMappedField; +import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.expression.Order; import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.expression.function.aggregate.Min; @@ -43,6 +45,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.index.IndexResolution; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.InferIsNotNull; import org.elasticsearch.xpack.esql.parser.EsqlParser; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; @@ -59,10 +62,12 @@ import org.elasticsearch.xpack.esql.plan.logical.local.EmptyLocalSupplier; import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; +import org.elasticsearch.xpack.esql.rule.RuleExecutor; import org.elasticsearch.xpack.esql.stats.SearchStats; import org.hamcrest.Matchers; import org.junit.BeforeClass; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.Map; @@ -88,6 +93,9 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; +import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; +import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.DOWN; +import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.UP; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -780,7 +788,7 @@ public void testGroupingByMissingFields() { as(eval.child(), EsRelation.class); } - public void testPlanSanityCheck() throws Exception { + public void testVerifierOnMissingReferences() throws Exception { var plan = localPlan(""" from test | stats a = min(salary) by emp_no @@ -806,6 +814,103 @@ public void testPlanSanityCheck() throws Exception { assertThat(e.getMessage(), containsString(" optimized incorrectly due to missing references 
[salary")); } + private LocalLogicalPlanOptimizer getCustomRulesLocalLogicalPlanOptimizer(List> batches) { + LocalLogicalOptimizerContext context = new LocalLogicalOptimizerContext( + EsqlTestUtils.TEST_CFG, + FoldContext.small(), + TEST_SEARCH_STATS + ); + LocalLogicalPlanOptimizer customOptimizer = new LocalLogicalPlanOptimizer(context) { + @Override + protected List> batches() { + return batches; + } + }; + return customOptimizer; + } + + public void testVerifierOnAdditionalAttributeAdded() throws Exception { + var plan = localPlan(""" + from test + | stats a = min(salary) by emp_no + """); + + var limit = as(plan, Limit.class); + var aggregate = as(limit.child(), Aggregate.class); + var min = as(Alias.unwrap(aggregate.aggregates().get(0)), Min.class); + var salary = as(min.field(), NamedExpression.class); + assertThat(salary.name(), is("salary")); + Holder appliedCount = new Holder<>(0); + // use a custom rule that adds another output attribute + var customRuleBatch = new RuleExecutor.Batch<>( + "CustomRuleBatch", + RuleExecutor.Limiter.ONCE, + new OptimizerRules.ParameterizedOptimizerRule(UP) { + + @Override + protected LogicalPlan rule(Aggregate plan, LocalLogicalOptimizerContext context) { + // This rule adds a missing attribute to the plan output + // We only want to apply it once, so we use a static counter + if (appliedCount.get() == 0) { + appliedCount.set(appliedCount.get() + 1); + Literal additionalLiteral = new Literal(Source.EMPTY, "additional literal", INTEGER); + return new Eval(plan.source(), plan, List.of(new Alias(Source.EMPTY, "additionalAttribute", additionalLiteral))); + } + return plan; + } + + } + ); + LocalLogicalPlanOptimizer customRulesLocalLogicalPlanOptimizer = getCustomRulesLocalLogicalPlanOptimizer(List.of(customRuleBatch)); + Exception e = expectThrows(VerificationException.class, () -> customRulesLocalLogicalPlanOptimizer.localOptimize(plan)); + assertThat(e.getMessage(), containsString("Output has changed from")); + assertThat(e.getMessage(), containsString("additionalAttribute")); + } + + public void testVerifierOnAttributeDatatypeChanged() { + var plan = localPlan(""" + from test + | stats a = min(salary) by emp_no + """); + + var limit = as(plan, Limit.class); + var aggregate = as(limit.child(), Aggregate.class); + var min = as(Alias.unwrap(aggregate.aggregates().get(0)), Min.class); + var salary = as(min.field(), NamedExpression.class); + assertThat(salary.name(), is("salary")); + Holder appliedCount = new Holder<>(0); + // use a custom rule that changes the datatype of an output attribute + var customRuleBatch = new RuleExecutor.Batch<>( + "CustomRuleBatch", + RuleExecutor.Limiter.ONCE, + new OptimizerRules.ParameterizedOptimizerRule(DOWN) { + @Override + protected LogicalPlan rule(LogicalPlan plan, LocalLogicalOptimizerContext context) { + // We only want to apply it once, so we use a static counter + if (appliedCount.get() == 0) { + appliedCount.set(appliedCount.get() + 1); + Limit limit = as(plan, Limit.class); + Limit newLimit = new Limit(plan.source(), limit.limit(), limit.child()) { + @Override + public List output() { + List oldOutput = super.output(); + List newOutput = new ArrayList<>(oldOutput); + newOutput.set(0, oldOutput.get(0).withDataType(DataType.DATETIME)); + return newOutput; + } + }; + return newLimit; + } + return plan; + } + + } + ); + LocalLogicalPlanOptimizer customRulesLocalLogicalPlanOptimizer = getCustomRulesLocalLogicalPlanOptimizer(List.of(customRuleBatch)); + Exception e = expectThrows(VerificationException.class, () -> 
customRulesLocalLogicalPlanOptimizer.localOptimize(plan)); + assertThat(e.getMessage(), containsString("Output has changed from")); + } + private IsNotNull isNotNull(Expression field) { return new IsNotNull(EMPTY, field); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index a604e1d26d313..cd6371e4d4d5e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -34,12 +34,14 @@ import org.elasticsearch.xpack.core.enrich.EnrichPolicy; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.EsqlTestUtils.TestSearchStats; +import org.elasticsearch.xpack.esql.VerificationException; import org.elasticsearch.xpack.esql.action.EsqlCapabilities; import org.elasticsearch.xpack.esql.analysis.Analyzer; import org.elasticsearch.xpack.esql.analysis.AnalyzerContext; import org.elasticsearch.xpack.esql.analysis.EnrichResolution; import org.elasticsearch.xpack.esql.analysis.Verifier; import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; @@ -100,6 +102,7 @@ import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.elasticsearch.xpack.esql.querydsl.query.SingleValueQuery; import org.elasticsearch.xpack.esql.rule.Rule; +import org.elasticsearch.xpack.esql.rule.RuleExecutor; import org.elasticsearch.xpack.esql.session.Configuration; import org.elasticsearch.xpack.esql.stats.SearchContextStats; import org.elasticsearch.xpack.esql.stats.SearchStats; @@ -2389,20 +2392,119 @@ public void testVerifierOnMissingReferences() throws Exception { // We want to verify that the localOptimize detects the missing attribute. // However, it also throws an error in one of the rules before we get to the verifier. // So we use an implementation of LocalPhysicalPlanOptimizer that does not have any rules. 
+ LocalPhysicalPlanOptimizer optimizerWithNoRules = getCustomRulesLocalPhysicalPlanOptimizer(List.of()); + Exception e = expectThrows(IllegalStateException.class, () -> optimizerWithNoRules.localOptimize(topNExec)); + assertThat(e.getMessage(), containsString(" optimized incorrectly due to missing references [missing attr")); + } + + private LocalPhysicalPlanOptimizer getCustomRulesLocalPhysicalPlanOptimizer(List> batches) { LocalPhysicalOptimizerContext context = new LocalPhysicalOptimizerContext( new EsqlFlags(true), config, FoldContext.small(), SearchStats.EMPTY ); - LocalPhysicalPlanOptimizer optimizerWithNoopExecute = new LocalPhysicalPlanOptimizer(context) { + LocalPhysicalPlanOptimizer localPhysicalPlanOptimizer = new LocalPhysicalPlanOptimizer(context) { @Override protected List> batches() { - return List.of(); + return batches; } }; - Exception e = expectThrows(IllegalStateException.class, () -> optimizerWithNoopExecute.localOptimize(topNExec)); - assertThat(e.getMessage(), containsString(" optimized incorrectly due to missing references [missing attr")); + return localPhysicalPlanOptimizer; + } + + public void testVerifierOnAdditionalAttributeAdded() throws Exception { + + PhysicalPlan plan = plannerOptimizer.plan(""" + from test + | stats a = min(salary) by emp_no + """); + + var limit = as(plan, LimitExec.class); + var aggregate = as(limit.child(), AggregateExec.class); + var min = as(Alias.unwrap(aggregate.aggregates().get(0)), Min.class); + var salary = as(min.field(), NamedExpression.class); + assertThat(salary.name(), is("salary")); + Holder appliedCount = new Holder<>(0); + // use a custom rule that adds another output attribute + var customRuleBatch = new RuleExecutor.Batch<>( + "CustomRuleBatch", + RuleExecutor.Limiter.ONCE, + new PhysicalOptimizerRules.ParameterizedOptimizerRule() { + @Override + public PhysicalPlan rule(PhysicalPlan plan, LocalPhysicalOptimizerContext context) { + // This rule adds a missing attribute to the plan output + // We only want to apply it once, so we use a static counter + if (appliedCount.get() == 0) { + appliedCount.set(appliedCount.get() + 1); + Literal additionalLiteral = new Literal(Source.EMPTY, "additional literal", INTEGER); + return new EvalExec( + plan.source(), + plan, + List.of(new Alias(Source.EMPTY, "additionalAttribute", additionalLiteral)) + ); + } + return plan; + } + } + ); + LocalPhysicalPlanOptimizer customRulesLocalPhysicalPlanOptimizer = getCustomRulesLocalPhysicalPlanOptimizer( + List.of(customRuleBatch) + ); + Exception e = expectThrows(VerificationException.class, () -> customRulesLocalPhysicalPlanOptimizer.localOptimize(plan)); + assertThat(e.getMessage(), containsString("Output has changed from")); + assertThat(e.getMessage(), containsString("additionalAttribute")); + } + + public void testVerifierOnAttributeDatatypeChanged() throws Exception { + + PhysicalPlan plan = plannerOptimizer.plan(""" + from test + | stats a = min(salary) by emp_no + """); + + var limit = as(plan, LimitExec.class); + var aggregate = as(limit.child(), AggregateExec.class); + var min = as(Alias.unwrap(aggregate.aggregates().get(0)), Min.class); + var salary = as(min.field(), NamedExpression.class); + assertThat(salary.name(), is("salary")); + Holder appliedCount = new Holder<>(0); + // use a custom rule that changes the datatype of an output attribute + var customRuleBatch = new RuleExecutor.Batch<>( + "CustomRuleBatch", + RuleExecutor.Limiter.ONCE, + new PhysicalOptimizerRules.ParameterizedOptimizerRule() { + @Override + public PhysicalPlan 
rule(PhysicalPlan plan, LocalPhysicalOptimizerContext context) { + // We only want to apply it once, so we use a static counter + if (appliedCount.get() == 0) { + appliedCount.set(appliedCount.get() + 1); + LimitExec limit = as(plan, LimitExec.class); + LimitExec newLimit = new LimitExec( + plan.source(), + limit.child(), + new Literal(Source.EMPTY, 1000, INTEGER), + randomEstimatedRowSize() + ) { + @Override + public List output() { + List oldOutput = super.output(); + List newOutput = new ArrayList<>(oldOutput); + newOutput.set(0, oldOutput.get(0).withDataType(DataType.DATETIME)); + return newOutput; + } + }; + return newLimit; + } + return plan; + } + } + ); + LocalPhysicalPlanOptimizer customRulesLocalPhysicalPlanOptimizer = getCustomRulesLocalPhysicalPlanOptimizer( + List.of(customRuleBatch) + ); + Exception e = expectThrows(VerificationException.class, () -> customRulesLocalPhysicalPlanOptimizer.localOptimize(plan)); + assertThat(e.getMessage(), containsString("Output has changed from")); } private boolean isMultiTypeEsField(Expression e) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index e301c1610bd7b..a0dd67105097d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -132,6 +132,7 @@ import org.elasticsearch.xpack.esql.plan.logical.local.EmptyLocalSupplier; import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; +import org.elasticsearch.xpack.esql.rule.RuleExecutor; import java.time.Duration; import java.util.ArrayList; @@ -179,6 +180,8 @@ import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.GTE; import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LT; import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LTE; +import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.DOWN; +import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.UP; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.contains; @@ -2902,7 +2905,7 @@ public void testPruneRedundantSortClausesUsingAlias() { public void testInsist_fieldDoesNotExist_createsUnmappedFieldInRelation() { assumeTrue("Requires UNMAPPED FIELDS", EsqlCapabilities.Cap.UNMAPPED_FIELDS.isEnabled()); - LogicalPlan plan = optimizedPlan("FROM test | INSIST_🐔 foo"); + LogicalPlan plan = optimizedPlan("FROM test | INSIST_\uD83D\uDC14 foo"); var project = as(plan, Project.class); var limit = as(project.child(), Limit.class); @@ -2913,7 +2916,7 @@ public void testInsist_fieldDoesNotExist_createsUnmappedFieldInRelation() { public void testInsist_multiIndexFieldPartiallyExistsAndIsKeyword_castsAreNotSupported() { assumeTrue("Requires UNMAPPED FIELDS", EsqlCapabilities.Cap.UNMAPPED_FIELDS.isEnabled()); - var plan = planMultiIndex("FROM multi_index | INSIST_🐔 partial_type_keyword"); + var plan = planMultiIndex("FROM multi_index | 
INSIST_\uD83D\uDC14 partial_type_keyword"); var project = as(plan, Project.class); var limit = as(project.child(), Limit.class); var relation = as(limit.child(), EsRelation.class); @@ -2924,7 +2927,7 @@ public void testInsist_multiIndexFieldPartiallyExistsAndIsKeyword_castsAreNotSup public void testInsist_multipleInsistClauses_insistsAreFolded() { assumeTrue("Requires UNMAPPED FIELDS", EsqlCapabilities.Cap.UNMAPPED_FIELDS.isEnabled()); - var plan = planMultiIndex("FROM multi_index | INSIST_🐔 partial_type_keyword | INSIST_🐔 foo"); + var plan = planMultiIndex("FROM multi_index | INSIST_\uD83D\uDC14 partial_type_keyword | INSIST_\uD83D\uDC14 foo"); var project = as(plan, Project.class); var limit = as(project.child(), Limit.class); var relation = as(limit.child(), EsRelation.class); @@ -5561,7 +5564,7 @@ public void testPushShadowingGeneratingPlanPastProject() { List initialGeneratedExprs = ((GeneratingPlan) initialPlan).generatedAttributes(); LogicalPlan optimizedPlan = testCase.rule.apply(initialPlan); - Failures inconsistencies = LogicalVerifier.INSTANCE.verify(optimizedPlan, false); + Failures inconsistencies = LogicalVerifier.INSTANCE.verify(optimizedPlan, false, initialPlan.output()); assertFalse(inconsistencies.hasFailures()); Project project = as(optimizedPlan, Project.class); @@ -5612,7 +5615,7 @@ public void testPushShadowingGeneratingPlanPastRenamingProject() { List initialGeneratedExprs = ((GeneratingPlan) initialPlan).generatedAttributes(); LogicalPlan optimizedPlan = testCase.rule.apply(initialPlan); - Failures inconsistencies = LogicalVerifier.INSTANCE.verify(optimizedPlan, false); + Failures inconsistencies = LogicalVerifier.INSTANCE.verify(optimizedPlan, false, initialPlan.output()); assertFalse(inconsistencies.hasFailures()); Project project = as(optimizedPlan, Project.class); @@ -5668,7 +5671,7 @@ public void testPushShadowingGeneratingPlanPastRenamingProjectWithResolution() { // This ensures that our generating plan doesn't use invalid references, resp. that any rename from the Project has // been propagated into the generating plan. 
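The `verify(...)` calls updated here thread the plan's pre-optimization output through as a third argument, so the verifier can catch layout changes as well as missing references. A rough sketch of that invariant, simplified from the real `LogicalVerifier` (the exception type and exact message shape are assumptions based on the tests in this change):

```
// Simplified sketch: optimizer rules may rewrite the tree, but the output
// attributes' names and data types must survive optimization unchanged.
static void checkOutputUnchanged(List<Attribute> expected, List<Attribute> actual) {
    boolean changed = expected.size() != actual.size();
    for (int i = 0; changed == false && i < expected.size(); i++) {
        changed = expected.get(i).name().equals(actual.get(i).name()) == false
            || expected.get(i).dataType() != actual.get(i).dataType();
    }
    if (changed) {
        throw new IllegalStateException("Output has changed from " + expected + " to " + actual);
    }
}
```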
- Failures inconsistencies = LogicalVerifier.INSTANCE.verify(optimizedPlan, false); + Failures inconsistencies = LogicalVerifier.INSTANCE.verify(optimizedPlan, false, initialPlan.output()); assertFalse(inconsistencies.hasFailures()); Project project = as(optimizedPlan, Project.class); @@ -8026,4 +8029,97 @@ public void testMultipleKnnQueriesInPrefilters() { assertThat(secondKnnFilters.size(), equalTo(1)); assertTrue(secondKnnFilters.contains(firstOr.right())); } + + private LogicalPlanOptimizer getCustomRulesLogicalPlanOptimizer(List> batches) { + LogicalOptimizerContext context = new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG, FoldContext.small()); + LogicalPlanOptimizer customOptimizer = new LogicalPlanOptimizer(context) { + @Override + protected List> batches() { + return batches; + } + }; + return customOptimizer; + } + + public void testVerifierOnAdditionalAttributeAdded() throws Exception { + var plan = optimizedPlan(""" + from test + | stats a = min(salary) by emp_no + """); + + var limit = as(plan, Limit.class); + var aggregate = as(limit.child(), Aggregate.class); + var min = as(Alias.unwrap(aggregate.aggregates().get(0)), Min.class); + var salary = as(min.field(), NamedExpression.class); + assertThat(salary.name(), is("salary")); + Holder appliedCount = new Holder<>(0); + // use a custom rule that adds another output attribute + var customRuleBatch = new RuleExecutor.Batch<>( + "CustomRuleBatch", + RuleExecutor.Limiter.ONCE, + new OptimizerRules.ParameterizedOptimizerRule(UP) { + @Override + protected LogicalPlan rule(Aggregate plan, LogicalOptimizerContext context) { + // This rule adds a missing attribute to the plan output + // We only want to apply it once, so we use a static counter + if (appliedCount.get() == 0) { + appliedCount.set(appliedCount.get() + 1); + Literal additionalLiteral = new Literal(Source.EMPTY, "additional literal", INTEGER); + return new Eval(plan.source(), plan, List.of(new Alias(Source.EMPTY, "additionalAttribute", additionalLiteral))); + } + return plan; + } + + } + ); + LogicalPlanOptimizer customRulesLogicalPlanOptimizer = getCustomRulesLogicalPlanOptimizer(List.of(customRuleBatch)); + Exception e = expectThrows(VerificationException.class, () -> customRulesLogicalPlanOptimizer.optimize(plan)); + assertThat(e.getMessage(), containsString("Output has changed from")); + assertThat(e.getMessage(), containsString("additionalAttribute")); + } + + public void testVerifierOnAttributeDatatypeChanged() { + var plan = optimizedPlan(""" + from test + | stats a = min(salary) by emp_no + """); + + var limit = as(plan, Limit.class); + var aggregate = as(limit.child(), Aggregate.class); + var min = as(Alias.unwrap(aggregate.aggregates().get(0)), Min.class); + var salary = as(min.field(), NamedExpression.class); + assertThat(salary.name(), is("salary")); + Holder appliedCount = new Holder<>(0); + // use a custom rule that changes the datatype of an output attribute + var customRuleBatch = new RuleExecutor.Batch<>( + "CustomRuleBatch", + RuleExecutor.Limiter.ONCE, + new OptimizerRules.ParameterizedOptimizerRule(DOWN) { + @Override + protected LogicalPlan rule(LogicalPlan plan, LogicalOptimizerContext context) { + // We only want to apply it once, so we use a static counter + if (appliedCount.get() == 0) { + appliedCount.set(appliedCount.get() + 1); + Limit limit = as(plan, Limit.class); + Limit newLimit = new Limit(plan.source(), limit.limit(), limit.child()) { + @Override + public List output() { + List oldOutput = super.output(); + List newOutput = new 
ArrayList<>(oldOutput); + newOutput.set(0, oldOutput.get(0).withDataType(DataType.DATETIME)); + return newOutput; + } + }; + return newLimit; + } + return plan; + } + + } + ); + LogicalPlanOptimizer customRulesLogicalPlanOptimizer = getCustomRulesLogicalPlanOptimizer(List.of(customRuleBatch)); + Exception e = expectThrows(VerificationException.class, () -> customRulesLogicalPlanOptimizer.optimize(plan)); + assertThat(e.getMessage(), containsString("Output has changed from")); + } + } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java new file mode 100644 index 0000000000000..8e573dd1cf3c9 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer; + +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.EsqlTestUtils; +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; +import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.Project; + +import java.util.List; + +import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.fieldAttribute; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.of; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.notNullValue; + +public class LogicalPlanPreOptimizerTests extends ESTestCase { + + public void testPlanIsMarkedAsPreOptimized() throws Exception { + for (int round = 0; round < 100; round++) { + // We want to make sure that the pre-optimizer works for a wide range of plans + preOptimizedPlan(randomPlan()); + } + } + + public void testPreOptimizeFailsIfPlanIsNotAnalyzed() throws Exception { + LogicalPlan plan = EsqlTestUtils.relation(); + SetOnce<Exception> exceptionHolder = new SetOnce<>(); + + preOptimizer().preOptimize(plan, ActionListener.wrap(r -> fail("Should have failed"), exceptionHolder::set)); + assertBusy(() -> { + assertThat(exceptionHolder.get(), notNullValue()); + IllegalStateException e = as(exceptionHolder.get(), IllegalStateException.class); + assertThat(e.getMessage(), equalTo("Expected analyzed plan")); + }); + } + + public LogicalPlan preOptimizedPlan(LogicalPlan plan) throws Exception { + // mark the plan as pre-optimized + plan.setPreOptimized(); + + SetOnce<LogicalPlan> resultHolder = new SetOnce<>(); + SetOnce<Exception> exceptionHolder = new SetOnce<>(); + + preOptimizer().preOptimize(plan, ActionListener.wrap(resultHolder::set,
exceptionHolder::set)); + + if (exceptionHolder.get() != null) { + throw exceptionHolder.get(); + } + + assertThat(resultHolder.get(), notNullValue()); + assertThat(resultHolder.get().preOptimized(), equalTo(true)); + + return resultHolder.get(); + } + + private LogicalPlanPreOptimizer preOptimizer() { + LogicalPreOptimizerContext preOptimizerContext = new LogicalPreOptimizerContext(FoldContext.small()); + return new LogicalPlanPreOptimizer(preOptimizerContext); + } + + private LogicalPlan randomPlan() { + LogicalPlan plan = EsqlTestUtils.relation(); + int numCommands = between(0, 100); + + for (int i = 0; i < numCommands; i++) { + plan = switch (randomInt(3)) { + case 0 -> new Eval(Source.EMPTY, plan, List.of(new Alias(Source.EMPTY, randomIdentifier(), randomExpression()))); + case 1 -> new Limit(Source.EMPTY, of(randomInt()), plan); + case 2 -> new Filter(Source.EMPTY, plan, randomCondition()); + default -> new Project(Source.EMPTY, plan, List.of(new Alias(Source.EMPTY, randomIdentifier(), fieldAttribute()))); + }; + } + return plan; + } + + private Expression randomExpression() { + return switch (randomInt(3)) { + case 0 -> of(randomInt()); + case 1 -> of(randomIdentifier()); + case 2 -> new Add(Source.EMPTY, of(randomInt()), of(randomDouble())); + default -> new Concat(Source.EMPTY, of(randomIdentifier()), randomList(1, 10, () -> of(randomIdentifier()))); + }; + } + + private Expression randomCondition() { + if (randomBoolean()) { + return EsqlTestUtils.equalsOf(randomExpression(), randomExpression()); + } + + return EsqlTestUtils.greaterThanOf(randomExpression(), randomExpression()); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java index a609a1e494e54..6850e052eda9e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.geo.ShapeRelation; import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.compute.aggregation.AggregatorMode; @@ -64,11 +65,13 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.EsField; +import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.enrich.ResolvedEnrichPolicy; import org.elasticsearch.xpack.esql.expression.Order; import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.esql.expression.function.aggregate.Min; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialExtent; @@ -135,6 +138,7 @@ import org.elasticsearch.xpack.esql.plan.physical.UnaryExec; import 
org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders; import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner; +import org.elasticsearch.xpack.esql.planner.PhysicalSettings; import org.elasticsearch.xpack.esql.planner.PlannerUtils; import org.elasticsearch.xpack.esql.planner.mapper.Mapper; import org.elasticsearch.xpack.esql.plugin.EsqlFlags; @@ -142,6 +146,7 @@ import org.elasticsearch.xpack.esql.querydsl.query.EqualsSyntheticSourceDelegate; import org.elasticsearch.xpack.esql.querydsl.query.SingleValueQuery; import org.elasticsearch.xpack.esql.querydsl.query.SpatialRelatesQuery; +import org.elasticsearch.xpack.esql.rule.RuleExecutor; import org.elasticsearch.xpack.esql.session.Configuration; import org.elasticsearch.xpack.esql.stats.SearchStats; import org.junit.Before; @@ -187,9 +192,11 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_SHAPE; import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_POINT; import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_SHAPE; +import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; import static org.elasticsearch.xpack.esql.core.util.TestUtils.stripThrough; import static org.elasticsearch.xpack.esql.parser.ExpressionBuilder.MAX_EXPRESSION_DEPTH; import static org.elasticsearch.xpack.esql.parser.LogicalPlanBuilder.MAX_QUERY_DEPTH; +import static org.elasticsearch.xpack.esql.plan.physical.AbstractPhysicalPlanSerializationTests.randomEstimatedRowSize; import static org.elasticsearch.xpack.esql.planner.mapper.MapperUtils.hasScoreAttribute; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.contains; @@ -2873,7 +2880,7 @@ public void testFieldExtractWithoutSourceAttributes() { ) ); - var e = expectThrows(VerificationException.class, () -> physicalPlanOptimizer.verify(badPlan)); + var e = expectThrows(VerificationException.class, () -> physicalPlanOptimizer.verify(badPlan, verifiedPlan.output())); assertThat( e.getMessage(), containsString( @@ -2888,7 +2895,7 @@ public void testVerifierOnMissingReferences() { | stats s = sum(salary) by emp_no | where emp_no > 10 """); - + final var planBeforeModification = plan; plan = plan.transformUp( AggregateExec.class, a -> new AggregateExec( @@ -2902,7 +2909,7 @@ public void testVerifierOnMissingReferences() { ) ); final var finalPlan = plan; - var e = expectThrows(IllegalStateException.class, () -> physicalPlanOptimizer.verify(finalPlan)); + var e = expectThrows(IllegalStateException.class, () -> physicalPlanOptimizer.verify(finalPlan, planBeforeModification.output())); assertThat(e.getMessage(), containsString(" > 10[INTEGER]]] optimized incorrectly due to missing references [emp_no{f}#")); } @@ -2920,7 +2927,7 @@ public void testVerifierOnMissingReferencesWithBinaryPlans() throws Exception { var planWithInvalidJoinLeftSide = plan.transformUp(LookupJoinExec.class, join -> join.replaceChildren(join.right(), join.right())); - var e = expectThrows(IllegalStateException.class, () -> physicalPlanOptimizer.verify(planWithInvalidJoinLeftSide)); + var e = expectThrows(IllegalStateException.class, () -> physicalPlanOptimizer.verify(planWithInvalidJoinLeftSide, plan.output())); assertThat(e.getMessage(), containsString(" optimized incorrectly due to missing references from left hand side [languages")); var planWithInvalidJoinRightSide = plan.transformUp( @@ -2937,7 +2944,7 @@ public void testVerifierOnMissingReferencesWithBinaryPlans() throws Exception { ) ); - e = expectThrows(IllegalStateException.class, () 
-> physicalPlanOptimizer.verify(planWithInvalidJoinRightSide)); + e = expectThrows(IllegalStateException.class, () -> physicalPlanOptimizer.verify(planWithInvalidJoinRightSide, plan.output())); assertThat(e.getMessage(), containsString(" optimized incorrectly due to missing references from right hand side [language_code")); } @@ -2947,7 +2954,7 @@ public void testVerifierOnDuplicateOutputAttributes() { | stats s = sum(salary) by emp_no | where emp_no > 10 """); - + final var planBeforeModification = plan; plan = plan.transformUp(AggregateExec.class, a -> { List intermediates = new ArrayList<>(a.intermediateAttributes()); intermediates.add(intermediates.get(0)); @@ -2962,7 +2969,7 @@ public void testVerifierOnDuplicateOutputAttributes() { ); }); final var finalPlan = plan; - var e = expectThrows(IllegalStateException.class, () -> physicalPlanOptimizer.verify(finalPlan)); + var e = expectThrows(IllegalStateException.class, () -> physicalPlanOptimizer.verify(finalPlan, planBeforeModification.output())); assertThat( e.getMessage(), containsString("Plan [LimitExec[1000[INTEGER],null]] optimized incorrectly due to duplicate output attribute emp_no{f}#") @@ -7888,7 +7895,12 @@ private LocalExecutionPlanner.LocalExecutionPlan physicalOperationsFromPhysicalP null, null, null, - new EsPhysicalOperationProviders(FoldContext.small(), List.of(), null, DataPartitioning.AUTO), + new EsPhysicalOperationProviders( + FoldContext.small(), + List.of(), + null, + new PhysicalSettings(DataPartitioning.AUTO, ByteSizeValue.ofMb(1)) + ), List.of() ); @@ -8308,6 +8320,107 @@ private QueryBuilder sv(QueryBuilder builder, String fieldName) { return sv.next(); } + private PhysicalPlanOptimizer getCustomRulesPhysicalPlanOptimizer(List> batches) { + PhysicalOptimizerContext context = new PhysicalOptimizerContext(config); + PhysicalPlanOptimizer PhysicalPlanOptimizer = new PhysicalPlanOptimizer(context) { + @Override + protected List> batches() { + return batches; + } + }; + return PhysicalPlanOptimizer; + } + + public void testVerifierOnAdditionalAttributeAdded() throws Exception { + + PhysicalPlan plan = physicalPlan(""" + from test + | stats a = min(salary) by emp_no + """); + + var limit = as(plan, LimitExec.class); + var aggregate = as(limit.child(), AggregateExec.class); + var min = as(Alias.unwrap(aggregate.aggregates().get(0)), Min.class); + var salary = as(min.field(), NamedExpression.class); + assertThat(salary.name(), is("salary")); + Holder appliedCount = new Holder<>(0); + // use a custom rule that adds another output attribute + var customRuleBatch = new RuleExecutor.Batch<>( + "CustomRuleBatch", + RuleExecutor.Limiter.ONCE, + new PhysicalOptimizerRules.ParameterizedOptimizerRule() { + @Override + public PhysicalPlan rule(PhysicalPlan plan, PhysicalOptimizerContext context) { + // This rule adds a missing attribute to the plan output + // We only want to apply it once, so we use a static counter + if (appliedCount.get() == 0) { + appliedCount.set(appliedCount.get() + 1); + Literal additionalLiteral = new Literal(Source.EMPTY, "additional literal", INTEGER); + return new EvalExec( + plan.source(), + plan, + List.of(new Alias(Source.EMPTY, "additionalAttribute", additionalLiteral)) + ); + } + return plan; + } + } + ); + PhysicalPlanOptimizer customRulesPhysicalPlanOptimizer = getCustomRulesPhysicalPlanOptimizer(List.of(customRuleBatch)); + Exception e = expectThrows(VerificationException.class, () -> customRulesPhysicalPlanOptimizer.optimize(plan)); + assertThat(e.getMessage(), containsString("Output has 
changed from")); + assertThat(e.getMessage(), containsString("additionalAttribute")); + } + + public void testVerifierOnAttributeDatatypeChanged() throws Exception { + + PhysicalPlan plan = physicalPlan(""" + from test + | stats a = min(salary) by emp_no + """); + + var limit = as(plan, LimitExec.class); + var aggregate = as(limit.child(), AggregateExec.class); + var min = as(Alias.unwrap(aggregate.aggregates().get(0)), Min.class); + var salary = as(min.field(), NamedExpression.class); + assertThat(salary.name(), is("salary")); + Holder appliedCount = new Holder<>(0); + // use a custom rule that changes the datatype of an output attribute + var customRuleBatch = new RuleExecutor.Batch<>( + "CustomRuleBatch", + RuleExecutor.Limiter.ONCE, + new PhysicalOptimizerRules.ParameterizedOptimizerRule() { + @Override + public PhysicalPlan rule(PhysicalPlan plan, PhysicalOptimizerContext context) { + // We only want to apply it once, so we use a static counter + if (appliedCount.get() == 0) { + appliedCount.set(appliedCount.get() + 1); + LimitExec limit = as(plan, LimitExec.class); + LimitExec newLimit = new LimitExec( + plan.source(), + limit.child(), + new Literal(Source.EMPTY, 1000, INTEGER), + randomEstimatedRowSize() + ) { + @Override + public List output() { + List oldOutput = super.output(); + List newOutput = new ArrayList<>(oldOutput); + newOutput.set(0, oldOutput.get(0).withDataType(DataType.DATETIME)); + return newOutput; + } + }; + return newLimit; + } + return plan; + } + } + ); + PhysicalPlanOptimizer customRulesPhysicalPlanOptimizer = getCustomRulesPhysicalPlanOptimizer(List.of(customRuleBatch)); + Exception e = expectThrows(VerificationException.class, () -> customRulesPhysicalPlanOptimizer.optimize(plan)); + assertThat(e.getMessage(), containsString("Output has changed from")); + } + @Override protected List filteredWarnings() { return withDefaultLimitWarning(super.filteredWarnings()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java index 6749f03bedde7..b56f4a3a4898b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java @@ -19,6 +19,7 @@ import org.apache.lucene.tests.index.RandomIndexWriter; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.compute.lucene.DataPartitioning; import org.elasticsearch.compute.lucene.LuceneSourceOperator; @@ -340,7 +341,12 @@ private Configuration config() { } private EsPhysicalOperationProviders esPhysicalOperationProviders(List shardContexts) { - return new EsPhysicalOperationProviders(FoldContext.small(), shardContexts, null, DataPartitioning.AUTO); + return new EsPhysicalOperationProviders( + FoldContext.small(), + shardContexts, + null, + new PhysicalSettings(DataPartitioning.AUTO, ByteSizeValue.ofMb(1)) + ); } private List createShardContexts() throws IOException { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java 
b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index 8e40bba8b32f7..1eb530ac1bb9e 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -350,7 +350,8 @@ private ElasticInferenceService createElasticInferenceService() { createWithEmptySettings(threadPool), ElasticInferenceServiceSettingsTests.create(gatewayUrl), modelRegistry, - new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool) + new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool), + mockClusterServiceEmpty() ); } } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index e56782bd00ef5..22aebee72df0c 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -31,6 +31,7 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStatsTests; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.test.ESSingleNodeTestCase; @@ -129,7 +130,8 @@ public void testGetModel() throws Exception { mock(Client.class), mock(ThreadPool.class), mock(ClusterService.class), - Settings.EMPTY + Settings.EMPTY, + InferenceStatsTests.mockInferenceStats() ) ); ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel) elserService.parsePersistedConfigWithSecrets( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index fe31ae71ba8c1..00f40e903d1ff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -42,6 +42,7 @@ public class InferenceFeatures implements FeatureSpecification { ); private static final NodeFeature SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER = new NodeFeature("semantic_text.match_all_highlighter"); private static final NodeFeature COHERE_V2_API = new NodeFeature("inference.cohere.v2"); + public static final NodeFeature SEMANTIC_TEXT_HIGHLIGHTING_FLAT = new NodeFeature("semantic_text.highlighter.flat_index_options"); @Override public Set getTestFeatures() { @@ -72,7 +73,8 @@ public Set getTestFeatures() { SEMANTIC_TEXT_INDEX_OPTIONS, COHERE_V2_API, SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS, - SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX + SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX, + SEMANTIC_TEXT_HIGHLIGHTING_FLAT ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 5028cc6873cbb..6fd07cd4c2831 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -106,6 +106,8 @@ import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings; @@ -175,6 +177,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() { addJinaAINamedWriteables(namedWriteables); addVoyageAINamedWriteables(namedWriteables); addCustomNamedWriteables(namedWriteables); + addLlamaNamedWriteables(namedWriteables); addUnifiedNamedWriteables(namedWriteables); @@ -274,8 +277,25 @@ private static void addMistralNamedWriteables(List MistralChatCompletionServiceSettings::new ) ); + // no task settings for Mistral + } - // note - no task settings for Mistral embeddings... + private static void addLlamaNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + LlamaEmbeddingsServiceSettings.NAME, + LlamaEmbeddingsServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + LlamaChatCompletionServiceSettings.NAME, + LlamaChatCompletionServiceSettings::new + ) + ); + // no task settings for Llama } private static void addAzureAiStudioNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index de31f9d6cefc8..bbb1bd1a2fec2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -30,6 +30,7 @@ import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.License; import org.elasticsearch.license.LicensedFeature; import org.elasticsearch.license.XPackLicenseState; @@ -132,6 +133,7 @@ import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService; import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService; import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService; +import org.elasticsearch.xpack.inference.services.llama.LlamaService; import org.elasticsearch.xpack.inference.services.mistral.MistralService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
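For context on the Llama entries registered above: every ServiceSettings implementation contributes one entry keyed by a stable wire name, and deserialization dispatches on that name back to the matching reader. A minimal self-contained analogue of that dispatch mechanism (toy reader type and illustrative name strings, not the real NamedWriteableRegistry API or the actual NAME constants):

```
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Toy analogue of a named-writeable registry: entries map a stable wire name to
// a reader so a serialized payload can be routed back to the right constructor.
class ToyNamedRegistry {
    private final Map<String, Function<byte[], Object>> readers = new HashMap<>();

    void register(String name, Function<byte[], Object> reader) {
        if (readers.putIfAbsent(name, reader) != null) {
            throw new IllegalArgumentException("duplicate named writeable [" + name + "]");
        }
    }

    Object read(String name, byte[] payload) {
        Function<byte[], Object> reader = readers.get(name);
        if (reader == null) {
            throw new IllegalArgumentException("unknown named writeable [" + name + "]");
        }
        return reader.apply(payload);
    }

    public static void main(String[] args) {
        ToyNamedRegistry registry = new ToyNamedRegistry();
        // Mirrors the two Llama entries: one stable name per settings class, registered once.
        registry.register("llama_embeddings_service_settings", bytes -> "embeddings settings");
        registry.register("llama_completion_service_settings", bytes -> "completion settings");
        System.out.println(registry.read("llama_embeddings_service_settings", new byte[0]));
    }
}
```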
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient; @@ -140,7 +142,6 @@ import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder; import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas; import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.ArrayList; import java.util.Collection; @@ -311,7 +312,8 @@ public Collection createComponents(PluginServices services) { serviceComponents.get(), inferenceServiceSettings, modelRegistry.get(), - authorizationHandler + authorizationHandler, + context ), context -> new SageMakerService( new SageMakerModelBuilder(sageMakerSchemas), @@ -321,16 +323,22 @@ public Collection createComponents(PluginServices services) { ), sageMakerSchemas, services.threadPool(), - sageMakerConfigurations::getOrCompute + sageMakerConfigurations::getOrCompute, + context ) ) ); + var meterRegistry = services.telemetryProvider().getMeterRegistry(); + var inferenceStats = InferenceStats.create(meterRegistry); + var inferenceStatsBinding = new PluginComponentBinding<>(InferenceStats.class, inferenceStats); + var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext( services.client(), services.threadPool(), services.clusterService(), - settings + settings, + inferenceStats ); // This must be done after the HttpRequestSenderFactory is created so that the services can get the @@ -342,10 +350,6 @@ public Collection createComponents(PluginServices services) { } inferenceServiceRegistry.set(serviceRegistry); - var meterRegistry = services.telemetryProvider().getMeterRegistry(); - var inferenceStats = InferenceStats.create(meterRegistry); - var inferenceStatsBinding = new PluginComponentBinding<>(InferenceStats.class, inferenceStats); - var actionFilter = new ShardBulkInferenceActionFilter( services.clusterService(), serviceRegistry, @@ -383,24 +387,25 @@ public void loadExtensions(ExtensionLoader loader) { public List getInferenceServiceFactories() { return List.of( - context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()), - context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()), - context -> new OpenAiService(httpFactory.get(), serviceComponents.get()), - context -> new CohereService(httpFactory.get(), serviceComponents.get()), - context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()), - context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get()), - context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get()), - context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get()), - context -> new MistralService(httpFactory.get(), serviceComponents.get()), - context -> new AnthropicService(httpFactory.get(), serviceComponents.get()), - context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()), - context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()), - context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()), - context -> new JinaAIService(httpFactory.get(), serviceComponents.get()), - context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()), - context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()), + context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get(), context), + context -> new 
HuggingFaceService(httpFactory.get(), serviceComponents.get(), context), + context -> new OpenAiService(httpFactory.get(), serviceComponents.get(), context), + context -> new CohereService(httpFactory.get(), serviceComponents.get(), context), + context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get(), context), + context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get(), context), + context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get(), context), + context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get(), context), + context -> new MistralService(httpFactory.get(), serviceComponents.get(), context), + context -> new AnthropicService(httpFactory.get(), serviceComponents.get(), context), + context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get(), context), + context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get(), context), + context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get(), context), + context -> new JinaAIService(httpFactory.get(), serviceComponents.get(), context), + context -> new VoyageAIService(httpFactory.get(), serviceComponents.get(), context), + context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context), + context -> new LlamaService(httpFactory.get(), serviceComponents.get(), context), ElasticsearchInternalService::new, - context -> new CustomService(httpFactory.get(), serviceComponents.get()) + context -> new CustomService(httpFactory.get(), serviceComponents.get(), context) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index dec6d0d928b97..269e0f27fd461 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -26,6 +26,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; @@ -42,7 +43,6 @@ import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; import java.io.IOException; @@ -57,10 +57,11 @@ import static org.elasticsearch.ExceptionsHelper.unwrapCause; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.inference.telemetry.InferenceStats.modelAndResponseAttributes; +import static org.elasticsearch.inference.telemetry.InferenceStats.modelAttributes; +import static org.elasticsearch.inference.telemetry.InferenceStats.responseAttributes; +import static org.elasticsearch.inference.telemetry.InferenceStats.routingAttributes; import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; -import static 
org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.routingAttributes; /** * Base class for transport actions that handle inference requests. @@ -274,15 +275,11 @@ public InferenceAction.Response read(StreamInput in) throws IOException { } private void recordRequestDurationMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) { - try { - Map metricAttributes = new HashMap<>(); - metricAttributes.putAll(modelAttributes(model)); - metricAttributes.putAll(responseAttributes(unwrapCause(t))); - - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes); - } catch (Exception e) { - log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics"); - } + Map metricAttributes = new HashMap<>(); + metricAttributes.putAll(modelAttributes(model)); + metricAttributes.putAll(responseAttributes(unwrapCause(t))); + + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes); } private void inferOnServiceWithMetrics( @@ -369,7 +366,7 @@ protected Flow.Publisher streamErrorHandler(Flow.Publisher upstream) { private void recordRequestCountMetrics(Model model, Request request, String localNodeId) { Map requestCountAttributes = new HashMap<>(); requestCountAttributes.putAll(modelAttributes(model)); - requestCountAttributes.putAll(routingAttributes(request, localNodeId)); + requestCountAttributes.putAll(routingAttributes(request.hasBeenRerouted(), localNodeId)); inferenceStats.requestCount().incrementBy(1, requestCountAttributes); } @@ -381,16 +378,11 @@ private void recordRequestDurationMetrics( String localNodeId, @Nullable Throwable t ) { - try { - Map metricAttributes = new HashMap<>(); - metricAttributes.putAll(modelAttributes(model)); - metricAttributes.putAll(routingAttributes(request, localNodeId)); - metricAttributes.putAll(responseAttributes(unwrapCause(t))); - - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes); - } catch (Exception e) { - log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics"); - } + Map metricAttributes = new HashMap<>(); + metricAttributes.putAll(modelAndResponseAttributes(model, unwrapCause(t))); + metricAttributes.putAll(routingAttributes(request.hasBeenRerouted(), localNodeId)); + + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes); } private void inferOnService(Model model, Request request, InferenceService service, ActionListener listener) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 7d24b7766baa3..f14d679ba7d26 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.license.XPackLicenseState; import 
org.elasticsearch.threadpool.ThreadPool; @@ -24,7 +25,6 @@ import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; public class TransportInferenceAction extends BaseTransportInferenceAction { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java index bfa8141d312cf..d0eef677ca1d3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -17,6 +17,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; @@ -29,7 +30,6 @@ import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.concurrent.Flow; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 3127361de6d11..ecf73ed004194 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -46,6 +46,7 @@ import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; @@ -63,7 +64,6 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextUtils; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.io.IOException; import java.util.ArrayList; @@ -76,11 +76,10 @@ import java.util.Map; import java.util.stream.Collectors; +import static org.elasticsearch.inference.telemetry.InferenceStats.modelAndResponseAttributes; import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; -import 
static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; /** * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified @@ -459,8 +458,7 @@ public void onFailure(Exception exc) { private void recordRequestCountMetrics(Model model, int incrementBy, Throwable throwable) { Map requestCountAttributes = new HashMap<>(); - requestCountAttributes.putAll(modelAttributes(model)); - requestCountAttributes.putAll(responseAttributes(throwable)); + requestCountAttributes.putAll(modelAndResponseAttributes(model, throwable)); requestCountAttributes.put("inference_source", "semantic_text_bulk"); inferenceStats.requestCount().incrementBy(incrementBy, requestCountAttributes); } @@ -637,7 +635,9 @@ private boolean incrementIndexingPressure(IndexRequestWithIndexingPressure index addInferenceResponseFailure( itemIndex, new InferenceException( - "Insufficient memory available to update source on document [" + indexRequest.getIndexRequest().id() + "]", + "Unable to insert inference results into document [" + + indexRequest.getIndexRequest().id() + + "] due to memory pressure. Please retry the bulk request with fewer documents or smaller document sizes.", e ) ); @@ -749,7 +749,9 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons item.abort( item.index(), new InferenceException( - "Insufficient memory available to insert inference results into document [" + indexRequest.id() + "]", + "Unable to insert inference results into document [" + + indexRequest.id() + + "] due to memory pressure. Please retry the bulk request with fewer documents or smaller document sizes.", e ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java index 92333a10c4d08..8e55cc9c222b5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java @@ -32,6 +32,7 @@ import org.elasticsearch.search.fetch.subphase.highlight.HighlightField; import org.elasticsearch.search.fetch.subphase.highlight.HighlightUtils; import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; +import org.elasticsearch.search.vectors.DenseVectorQuery; import org.elasticsearch.search.vectors.SparseVectorQueryWrapper; import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.xcontent.Text; @@ -273,6 +274,8 @@ public void visitLeaf(Query query) { queries.add(fieldType.createExactKnnQuery(VectorData.fromBytes(knnQuery.getTargetCopy()), null)); } else if (query instanceof MatchAllDocsQuery) { queries.add(new MatchAllDocsQuery()); + } else if (query instanceof DenseVectorQuery.Floats floatsQuery) { + queries.add(fieldType.createExactKnnQuery(VectorData.fromFloats(floatsQuery.getQuery()), null)); } } }); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index ff8ae6fd5aac3..5074749c1cd9f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -9,6 +9,7 
@@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; @@ -42,11 +43,13 @@ public abstract class SenderService implements InferenceService { protected static final Set COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION); private final Sender sender; private final ServiceComponents serviceComponents; + private final ClusterService clusterService; - public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { Objects.requireNonNull(factory); sender = factory.createSender(); this.serviceComponents = Objects.requireNonNull(serviceComponents); + this.clusterService = Objects.requireNonNull(clusterService); } public Sender getSender() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index adbec49328804..f5f1074bfbb86 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Tuple; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; @@ -304,6 +305,12 @@ public static String invalidSettingError(String settingName, String scope) { return Strings.format("[%s] does not allow the setting [%s]", scope, settingName); } + public static URI extractUri(Map map, String fieldName, ValidationException validationException) { + String parsedUrl = extractRequiredString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); + + return convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); + } + public static URI convertToUri(@Nullable String url, String settingName, String settingScope, ValidationException validationException) { try { return createOptionalUri(url); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 7897317736c72..da608779fee0a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -19,6 +20,7 @@ import 
org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -85,8 +87,20 @@ public class AlibabaCloudSearchService extends SenderService { InputType.INTERNAL_SEARCH ); - public AlibabaCloudSearchService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AlibabaCloudSearchService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public AlibabaCloudSearchService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + ClusterService clusterService + ) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 591607953ea1a..c2b0ae8e69c37 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; @@ -20,6 +21,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -93,9 +95,19 @@ public class AmazonBedrockService extends SenderService { public AmazonBedrockService( HttpRequestSender.Factory httpSenderFactory, AmazonBedrockRequestSender.Factory amazonBedrockFactory, - ServiceComponents serviceComponents + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(httpSenderFactory, serviceComponents); + this(httpSenderFactory, amazonBedrockFactory, serviceComponents, context.clusterService()); + } + + public AmazonBedrockService( + HttpRequestSender.Factory httpSenderFactory, + AmazonBedrockRequestSender.Factory amazonBedrockFactory, + ServiceComponents serviceComponents, + ClusterService clusterService + ) { + super(httpSenderFactory, serviceComponents, clusterService); this.amazonBedrockSender = amazonBedrockFactory.createSender(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index 791518ccc9168..8cf5446f8b6d5 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -11,12 +11,14 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -58,8 +60,16 @@ public class AnthropicService extends SenderService { private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.COMPLETION); - public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AnthropicService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 17c7cbd6bdf0e..4a5a8be8b6633 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -84,8 +86,16 @@ public class AzureAiStudioService extends SenderService { InputType.INTERNAL_SEARCH ); - public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AzureAiStudioService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public 
AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index e9ff97c1ba725..3d9a3dd516a2d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -69,8 +71,16 @@ public class AzureOpenAiService extends SenderService { private static final String SERVICE_NAME = "Azure OpenAI"; private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION); - public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AzureOpenAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index c2f1221763165..fb6c630bd60c9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import 
org.elasticsearch.inference.Model; @@ -84,8 +86,16 @@ public class CohereService extends SenderService { // The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are instantiated // on every request - public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public CohereService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 4e81d37ead3ad..5f5078affa9d3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -74,8 +76,16 @@ public class CustomService extends SenderService { TaskType.COMPLETION ); - public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public CustomService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java index 56719199e094f..8a77efbd604d2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -10,12 +10,14 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import 
org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -58,8 +60,16 @@ public class DeepSeekService extends SenderService { ); private static final EnumSet SUPPORTED_TASK_TYPES_FOR_STREAMING = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); - public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public DeepSeekService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 36712ed922e95..58e964bb5c25f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; @@ -22,6 +23,7 @@ import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.MinimalServiceSettings; @@ -139,9 +141,28 @@ public ElasticInferenceService( ServiceComponents serviceComponents, ElasticInferenceServiceSettings elasticInferenceServiceSettings, ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, + InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents); + this( + factory, + serviceComponents, + elasticInferenceServiceSettings, + modelRegistry, + authorizationRequestHandler, + context.clusterService() + ); + } + + public ElasticInferenceService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + ElasticInferenceServiceSettings elasticInferenceServiceSettings, + ModelRegistry modelRegistry, + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, + ClusterService clusterService + ) { + super(factory, serviceComponents, clusterService); 
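The pair of ElasticInferenceService constructors here is the same shape every service gains in this patch: the overload taking InferenceServiceExtension.InferenceServiceFactoryContext is what the plugin's factory lambda calls, and it delegates to a narrower overload taking the ClusterService directly, so tests can inject one without building a full factory context. A stripped-down sketch of the pattern, using placeholder types rather than the inference API:

```
import java.util.Objects;

// Toy sketch of the delegating-constructor pattern: the context overload is the
// production entry point, while the narrow overload is the single place the
// dependency is validated and stored (and the one unit tests call directly).
class DelegatingCtorSketch {
    record FactoryContext(Object clusterService) {}

    static class ExampleService {
        private final Object clusterService;

        ExampleService(FactoryContext context) {
            this(context.clusterService()); // production path: unpack the factory context
        }

        ExampleService(Object clusterService) {
            this.clusterService = Objects.requireNonNull(clusterService); // test path
        }
    }

    public static void main(String[] args) {
        Object fakeClusterService = new Object();
        new ExampleService(new FactoryContext(fakeClusterService)); // as the plugin would
        new ExampleService(fakeClusterService);                     // as a test would
    }
}
```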
this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents( elasticInferenceServiceSettings.getElasticInferenceServiceUrl() ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java index b45d4449251f4..007dc820c629f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java @@ -28,7 +28,7 @@ public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInpu @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxCompletionTokensTokens(modelId, params)); + unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxCompletionTokens(modelId, params)); builder.endObject(); return builder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index ee4221157388e..4aaf3c2db2e61 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -11,7 +11,6 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ElasticsearchTimeoutException; -import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.SubscribableListener; @@ -23,6 +22,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.MachineLearningField; @@ -38,13 +38,16 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; import java.io.IOException; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.function.Consumer; +import static org.elasticsearch.ExceptionsHelper.unwrapCause; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.inference.telemetry.InferenceStats.modelAndResponseAttributes; import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -55,6 +58,7 @@ public abstract class 
BaseElasticsearchInternalService implements InferenceServi protected final ExecutorService inferenceExecutor; protected final Consumer<ActionListener<PreferredModelVariant>> preferredModelVariantFn; private final ClusterService clusterService; + private final InferenceStats inferenceStats; public enum PreferredModelVariant { LINUX_X86_OPTIMIZED, @@ -69,10 +73,11 @@ public BaseElasticsearchInternalService(InferenceServiceExtension.InferenceServi this.inferenceExecutor = context.threadPool().executor(InferencePlugin.UTILITY_THREAD_POOL_NAME); this.preferredModelVariantFn = this::preferredVariantFromPlatformArchitecture; this.clusterService = context.clusterService(); + this.inferenceStats = context.inferenceStats(); } // For testing. - // platformArchFn enables similating different architectures + // platformArchFn enables simulating different architectures // without extensive mocking on the client to simulate the nodes info response. // TODO make package private once the elser service is moved to the Elasticsearch // service package. @@ -85,6 +90,7 @@ public BaseElasticsearchInternalService( this.inferenceExecutor = context.threadPool().executor(InferencePlugin.UTILITY_THREAD_POOL_NAME); this.preferredModelVariantFn = preferredModelVariantFn; this.clusterService = context.clusterService(); + this.inferenceStats = context.inferenceStats(); } @Override @@ -103,6 +109,7 @@ public void start(Model model, TimeValue timeout, ActionListener finalL return; } + var timer = InferenceTimer.start(); // instead of a subscribable listener, use some way to wait for the first one. var subscribableListener = SubscribableListener.newForked( forkedListener -> { isBuiltinModelPut(model, forkedListener); } ); @@ -118,21 +125,25 @@ client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener); }); subscribableListener.addTimeout(timeout, threadPool, inferenceExecutor); - subscribableListener.addListener(finalListener.delegateResponse((l, e) -> { + subscribableListener.addListener(ActionListener.wrap(started -> { + inferenceStats.deploymentDuration().record(timer.elapsedMillis(), modelAndResponseAttributes(model, null)); + finalListener.onResponse(started); + }, e -> { if (e instanceof ElasticsearchTimeoutException) { - l.onFailure( - new ModelDeploymentTimeoutException( - format( - "Timed out after [%s] waiting for trained model deployment for inference endpoint [%s] to start. " - + "The inference endpoint can not be used to perform inference until the deployment has started. " - + "Use the trained model stats API to track the state of the deployment.", - timeout, - model.getInferenceEntityId() - ) + var timeoutException = new ModelDeploymentTimeoutException( + format( + "Timed out after [%s] waiting for trained model deployment for inference endpoint [%s] to start. " + + "The inference endpoint can not be used to perform inference until the deployment has started. 
" + + "Use the trained model stats API to track the state of the deployment.", + timeout, + model.getInferenceEntityId() ) ); + inferenceStats.deploymentDuration().record(timer.elapsedMillis(), modelAndResponseAttributes(model, timeoutException)); + finalListener.onFailure(timeoutException); } else { - l.onFailure(e); + inferenceStats.deploymentDuration().record(timer.elapsedMillis(), modelAndResponseAttributes(model, unwrapCause(e))); + finalListener.onFailure(e); } })); @@ -323,7 +334,7 @@ protected void maybeStartDeployment( InferModelAction.Request request, ActionListener listener ) { - if (isDefaultId(model.getInferenceEntityId()) && ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + if (isDefaultId(model.getInferenceEntityId()) && unwrapCause(e) instanceof ResourceNotFoundException) { this.start(model, request.getInferenceTimeout(), listener.delegateFailureAndWrap((l, started) -> { client.execute(InferModelAction.INSTANCE, request, listener); })); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 9841ea64370c3..4c8997f35555b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -82,8 +84,16 @@ public class GoogleAiStudioService extends SenderService { InputType.INTERNAL_SEARCH ); - public GoogleAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public GoogleAiStudioService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public GoogleAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 3b59e999125e5..2c2c667cd6eee 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -97,8 +99,16 @@ public Set supportedStreamingTasks() { return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); } - public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public GoogleVertexAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java index b0d40b41914d5..325f88c8904a3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java @@ -8,9 +8,11 @@ package org.elasticsearch.xpack.inference.services.huggingface; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -44,8 +46,16 @@ public abstract class HuggingFaceBaseService extends SenderService { */ static final int EMBEDDING_MAX_BATCH_SIZE = 20; - public HuggingFaceBaseService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public HuggingFaceBaseService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public HuggingFaceBaseService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index d10fb77290c6b..bc64e832d182a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -11,10 +11,12 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -71,8 +73,16 @@ public class HuggingFaceService extends HuggingFaceBaseService { OpenAiChatCompletionResponseEntity::fromResponse ); - public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public HuggingFaceService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java index 7429153835ee3..91735d39f3973 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java @@ -31,11 +31,10 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; public class HuggingFaceServiceSettings extends FilteredXContentObject implements ServiceSettings, HuggingFaceRateLimitServiceSettings { public static final String NAME = "hugging_face_service_settings"; @@ -70,12 +69,6 @@ public static HuggingFaceServiceSettings fromMap(Map map, Config return new 
HuggingFaceServiceSettings(uri, similarityMeasure, dims, maxInputTokens, rateLimitSettings); } - public static URI extractUri(Map map, String fieldName, ValidationException validationException) { - String parsedUrl = extractRequiredString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); - - return convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); - } - private final URI uri; private final SimilarityMeasure similarity; private final Integer dimensions; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java index cdc2529428bed..64da6e32bc1f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java @@ -31,7 +31,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; -import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; /** * Settings for the Hugging Face chat completion service. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index e61995aac91f3..5f9288bb99c24 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -11,12 +11,14 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -57,8 +59,16 @@ public class HuggingFaceElserService extends HuggingFaceBaseService { private static final String SERVICE_NAME = "Hugging Face ELSER"; private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING); - public HuggingFaceElserService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public HuggingFaceElserService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + 
InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public HuggingFaceElserService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java index b1d3297fc6328..ad771e72b6b35 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java @@ -28,7 +28,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; -import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; public class HuggingFaceElserServiceSettings extends FilteredXContentObject implements diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java index b0b21b26395af..57c103bbbf3b9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java @@ -27,7 +27,7 @@ import java.util.Objects; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; -import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; public class HuggingFaceRerankServiceSettings extends FilteredXContentObject implements diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 9bc63be1f9e7e..9617bff0d3f3d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import 
org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -83,8 +85,16 @@ public class IbmWatsonxService extends SenderService { OpenAiChatCompletionResponseEntity::fromResponse ); - public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public IbmWatsonxService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index c2e88cb6cdc7c..00e1aede95a2b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -76,8 +78,16 @@ public class JinaAIService extends SenderService { InputType.INTERNAL_SEARCH ); - public JinaAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public JinaAIService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public JinaAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java new file mode 100644 index 0000000000000..3e24d058d8540 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java @@ -0,0 +1,89 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama; + +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; +import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; +import java.util.Objects; + +/** + * Abstract class representing a Llama model for inference. + * This class extends RateLimitGroupingModel and provides common functionality for Llama models. + */ +public abstract class LlamaModel extends RateLimitGroupingModel { + protected URI uri; + protected RateLimitSettings rateLimitSettings; + + /** + * Constructor for creating a LlamaModel with specified configurations and secrets. + * + * @param configurations the model configurations + * @param secrets the secret settings for the model + */ + protected LlamaModel(ModelConfigurations configurations, ModelSecrets secrets) { + super(configurations, secrets); + } + + /** + * Constructor for creating a LlamaModel from an existing model with overridden service settings. + * @param model the existing model to copy configurations from + * @param serviceSettings the new settings for the inference service + */ + protected LlamaModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + public URI uri() { + return this.uri; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + @Override + public int rateLimitGroupingHash() { + return Objects.hash(getServiceSettings().modelId(), uri, getSecretSettings()); + } + + // Needed for testing only + public void setURI(String newUri) { + try { + this.uri = new URI(newUri); + } catch (URISyntaxException e) { + // swallow any error + } + } + + /** + * Retrieves the secret settings from the provided map of secrets. + * If the map is non-null but empty, it returns EmptySecretSettings; otherwise the map is parsed as DefaultSecretSettings. + * This is because the Llama model doesn't have out-of-the-box security settings and can be used without authentication. + * + * @param secrets the map containing secret settings + * @return an instance of SecretSettings + */ + protected static SecretSettings retrieveSecretSettings(Map<String, Object> secrets) { + return (secrets != null && secrets.isEmpty()) ? EmptySecretSettings.INSTANCE : DefaultSecretSettings.fromMap(secrets); + } + + protected abstract ExecutableAction accept(LlamaActionVisitor creator); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java new file mode 100644 index 0000000000000..bd6b3c91fc9e9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java @@ -0,0 +1,423 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements.
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionCreator; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.llama.request.completion.LlamaChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION; +import static org.elasticsearch.inference.TaskType.COMPLETION; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; 
+import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +/** + * LlamaService is an inference service for Llama models, supporting text embedding, completion, and chat completion tasks. + * It extends SenderService to handle HTTP requests and responses for Llama models. + */ +public class LlamaService extends SenderService { + public static final String NAME = "llama"; + private static final String SERVICE_NAME = "Llama"; + /** + * The optimal batch size depends on the hardware the model is deployed on. + * For Llama, use a conservatively small max batch size, as it is + * unknown how the model is deployed. + */ + static final int EMBEDDING_MAX_BATCH_SIZE = 20; + private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION); + private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new LlamaChatCompletionResponseHandler( + "llama chat completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); + + /** + * Constructor for creating a LlamaService with specified HTTP request sender factory and service components. + * + * @param factory the factory to create HTTP request senders + * @param serviceComponents the components required for the inference service + * @param context the context for the inference service factory + */ + public LlamaService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public LlamaService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); + } + + @Override + protected void doInfer( + Model model, + InferenceInputs inputs, + Map<String, Object> taskSettings, + TimeValue timeout, + ActionListener<InferenceServiceResults> listener + ) { + var actionCreator = new LlamaActionCreator(getSender(), getServiceComponents()); + if (model instanceof LlamaModel llamaModel) { + llamaModel.accept(actionCreator).execute(inputs, timeout, listener); + } else { + listener.onFailure(createInvalidModelException(model)); + } + } + + @Override + protected void validateInputType(InputType inputType, Model model, ValidationException validationException) { + ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException); + } + + /** + * Creates a LlamaModel based on the provided parameters.
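+     * For example, {@code TEXT_EMBEDDING} yields a {@code LlamaEmbeddingsModel}, while {@code COMPLETION} and
+     * {@code CHAT_COMPLETION} yield a {@code LlamaChatCompletionModel}; any other task type fails with a
+     * {@code BAD_REQUEST} {@code ElasticsearchStatusException}.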
+ * + * @param inferenceId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param serviceSettings the settings for the inference service + * @param chunkingSettings the settings for chunking, if applicable + * @param secretSettings the secret settings for the model, such as API keys or tokens + * @param failureMessage the message to use in case of failure + * @param context the context for parsing configuration settings + * @return a new instance of LlamaModel based on the provided parameters + */ + protected LlamaModel createModel( + String inferenceId, + TaskType taskType, + Map serviceSettings, + ChunkingSettings chunkingSettings, + Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + switch (taskType) { + case TEXT_EMBEDDING: + return new LlamaEmbeddingsModel(inferenceId, taskType, NAME, serviceSettings, chunkingSettings, secretSettings, context); + case CHAT_COMPLETION, COMPLETION: + return new LlamaChatCompletionModel(inferenceId, taskType, NAME, serviceSettings, secretSettings, context); + default: + throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + } + } + + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof LlamaEmbeddingsModel embeddingsModel) { + var serviceSettings = embeddingsModel.getServiceSettings(); + var similarityFromModel = serviceSettings.similarity(); + var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel; + + var updatedServiceSettings = new LlamaEmbeddingsServiceSettings( + serviceSettings.modelId(), + serviceSettings.uri(), + embeddingSize, + similarityToUse, + serviceSettings.maxInputTokens(), + serviceSettings.rateLimitSettings() + ); + + return new LlamaEmbeddingsModel(embeddingsModel, updatedServiceSettings); + } else { + throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass()); + } + } + + @Override + protected void doChunkedInfer( + Model model, + EmbeddingsInput inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + if (model instanceof LlamaEmbeddingsModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + var llamaModel = (LlamaEmbeddingsModel) model; + var actionCreator = new LlamaActionCreator(getSender(), getServiceComponents()); + + List batchedRequests = new EmbeddingRequestChunker<>( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + llamaModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = llamaModel.accept(actionCreator); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + } + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof LlamaChatCompletionModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + var llamaChatCompletionModel = (LlamaChatCompletionModel) model; + var overriddenModel = LlamaChatCompletionModel.of(llamaChatCompletionModel, inputs.getRequest()); + var manager = new GenericRequestManager<>( + getServiceComponents().threadPool(), + overriddenModel, + UNIFIED_CHAT_COMPLETION_HANDLER, + unifiedChatInput -> new 
LlamaChatCompletionRequest(unifiedChatInput, overriddenModel), + UnifiedChatInput.class + ); + var errorMessage = LlamaActionCreator.buildErrorMessage(CHAT_COMPLETION, model.getInferenceEntityId()); + var action = new SenderExecutableAction(getSender(), manager, errorMessage); + + action.execute(inputs, timeout, listener); + } + + @Override + public Set supportedStreamingTasks() { + return EnumSet.of(COMPLETION, CHAT_COMPLETION); + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return SUPPORTED_TASK_TYPES; + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + + LlamaModel model = createModel( + modelId, + taskType, + serviceSettingsMap, + chunkingSettings, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + @Override + public Model parsePersistedConfigWithSecrets( + String modelId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return createModelFromPersistent( + modelId, + taskType, + serviceSettingsMap, + chunkingSettings, + secretSettingsMap, + parsePersistedConfigErrorMsg(modelId, NAME) + ); + } + + private LlamaModel createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + ChunkingSettings chunkingSettings, + Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + chunkingSettings, + secretSettings, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + @Override + public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return createModelFromPersistent( + modelId, + taskType, + serviceSettingsMap, + chunkingSettings, + null, + 
parsePersistedConfigErrorMsg(modelId, NAME) + ); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_LLAMA_ADDED; + } + + @Override + public boolean hideFromConfigurationApi() { + // The Llama service is very configurable so we're going to hide it from being exposed in the service API. + return true; + } + + /** + * Configuration class for the Llama inference service. + * It provides the settings and configurations required for the service. + */ + public static class Configuration { + public static InferenceServiceConfiguration get() { + return CONFIGURATION.getOrCompute(); + } + + private Configuration() {} + + private static final LazyInitializable CONFIGURATION = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + + configurationMap.put( + URL, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.") + .setLabel("URL") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + configurationMap.put( + MODEL_ID, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription( + "Refer to the Llama models documentation for the list of available models." + ) + .setLabel("Model") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(SUPPORTED_TASK_TYPES) + .setConfigurations(configurationMap) + .build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java new file mode 100644 index 0000000000000..52e284ba7ccca --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.action; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsResponseHandler; +import org.elasticsearch.xpack.inference.services.llama.request.completion.LlamaChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.llama.request.embeddings.LlamaEmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; + +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.common.Truncator.truncate; + +/** + * Creates actions for Llama inference requests, handling both embeddings and completions. + * This class implements the {@link LlamaActionVisitor} interface to provide specific action creation methods. + */ +public class LlamaActionCreator implements LlamaActionVisitor { + + private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = "Failed to send Llama %s request from inference entity id [%s]"; + private static final String COMPLETION_ERROR_PREFIX = "Llama completions"; + private static final String USER_ROLE = "user"; + + private static final ResponseHandler EMBEDDINGS_HANDLER = new LlamaEmbeddingsResponseHandler( + "llama text embedding", + HuggingFaceEmbeddingsResponseEntity::fromResponse + ); + private static final ResponseHandler COMPLETION_HANDLER = new LlamaCompletionResponseHandler( + "llama completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); + + private final Sender sender; + private final ServiceComponents serviceComponents; + + /** + * Constructs a new LlamaActionCreator with the specified sender and service components. 
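+     * <p>An illustrative usage sketch, mirroring how {@code LlamaService#doInfer} wires things up
+     * (the variable names here are assumed for the example):
+     * <pre><code>
+     * var actionCreator = new LlamaActionCreator(sender, serviceComponents);
+     * ExecutableAction action = llamaModel.accept(actionCreator); // visitor dispatch picks embeddings vs. completion
+     * action.execute(inputs, timeout, listener);
+     * </code></pre>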
+ * + * @param sender the sender to use for executing actions + * @param serviceComponents the service components providing necessary services + */ + public LlamaActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(LlamaEmbeddingsModel model) { + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + EMBEDDINGS_HANDLER, + embeddingsInput -> new LlamaEmbeddingsRequest( + serviceComponents.truncator(), + truncate(embeddingsInput.getStringInputs(), model.getServiceSettings().maxInputTokens()), + model + ), + EmbeddingsInput.class + ); + + var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, model.getInferenceEntityId()); + return new SenderExecutableAction(sender, manager, errorMessage); + } + + @Override + public ExecutableAction create(LlamaChatCompletionModel model) { + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + COMPLETION_HANDLER, + inputs -> new LlamaChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), + ChatCompletionInput.class + ); + + var errorMessage = buildErrorMessage(TaskType.COMPLETION, model.getInferenceEntityId()); + return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX); + } + + /** + * Builds an error message for failed requests. + * + * @param requestType the type of request that failed + * @param inferenceId the inference entity ID associated with the request + * @return a formatted error message + */ + public static String buildErrorMessage(TaskType requestType, String inferenceId) { + return format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, requestType.toString(), inferenceId); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionVisitor.java new file mode 100644 index 0000000000000..1521b83b668c7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionVisitor.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.action; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel; + +/** + * Visitor interface for creating executable actions for Llama inference models. + * This interface defines methods to create actions for both embeddings and chat completion models. + */ +public interface LlamaActionVisitor { + /** + * Creates an executable action for the given Llama embeddings model. + * + * @param model the Llama embeddings model + * @return an executable action for the embeddings model + */ + ExecutableAction create(LlamaEmbeddingsModel model); + + /** + * Creates an executable action for the given Llama chat completion model. 
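+     * <p>For instance (illustrative):
+     * <pre><code>
+     * ExecutableAction action = chatCompletionModel.accept(visitor);
+     * </code></pre>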
+ * + * @param model the Llama chat completion model + * @return an executable action for the chat completion model + */ + ExecutableAction create(LlamaChatCompletionModel model); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModel.java new file mode 100644 index 0000000000000..a1a38f1eae326 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModel.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.llama.LlamaModel; +import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionVisitor; + +import java.util.Map; + +/** + * Represents a Llama chat completion model for inference. + * This class extends the LlamaModel and provides specific configurations and settings for chat completion tasks. + */ +public class LlamaChatCompletionModel extends LlamaModel { + + /** + * Constructor for creating a LlamaChatCompletionModel with specified parameters. + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to chat completion + * @param secrets the secret settings for the model, such as API keys or tokens + * @param context the context for parsing configuration settings + */ + public LlamaChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + LlamaChatCompletionServiceSettings.fromMap(serviceSettings, context), + retrieveSecretSettings(secrets) + ); + } + + /** + * Constructor for creating a LlamaChatCompletionModel with specified parameters. 
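+     * <p>An illustrative construction (the endpoint id, model id, and URL are assumed example values):
+     * <pre><code>
+     * var settings = new LlamaChatCompletionServiceSettings("llama3.2:3b", "https://llama.example.com/v1/chat/completions", null);
+     * var model = new LlamaChatCompletionModel("my-llama-endpoint", TaskType.CHAT_COMPLETION, "llama", settings, EmptySecretSettings.INSTANCE);
+     * </code></pre>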
+ * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to chat completion + * @param secrets the secret settings for the model, such as API keys or tokens + */ + public LlamaChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + LlamaChatCompletionServiceSettings serviceSettings, + SecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE), + new ModelSecrets(secrets) + ); + setPropertiesFromServiceSettings(serviceSettings); + } + + /** + * Factory method to create a LlamaChatCompletionModel with overridden model settings based on the request. + * If the request does not specify a model, the original model is returned. + * + * @param model the original LlamaChatCompletionModel + * @param request the UnifiedCompletionRequest containing potential overrides + * @return a new LlamaChatCompletionModel with overridden settings or the original model if no overrides are specified + */ + public static LlamaChatCompletionModel of(LlamaChatCompletionModel model, UnifiedCompletionRequest request) { + if (request.model() == null) { + // If no model id is specified in the request, return the original model + return model; + } + + var originalModelServiceSettings = model.getServiceSettings(); + var overriddenServiceSettings = new LlamaChatCompletionServiceSettings( + request.model(), + originalModelServiceSettings.uri(), + originalModelServiceSettings.rateLimitSettings() + ); + + return new LlamaChatCompletionModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getConfigurations().getService(), + overriddenServiceSettings, + model.getSecretSettings() + ); + } + + private void setPropertiesFromServiceSettings(LlamaChatCompletionServiceSettings serviceSettings) { + this.uri = serviceSettings.uri(); + this.rateLimitSettings = serviceSettings.rateLimitSettings(); + } + + /** + * Returns the service settings specific to Llama chat completion. + * + * @return the LlamaChatCompletionServiceSettings associated with this model + */ + @Override + public LlamaChatCompletionServiceSettings getServiceSettings() { + return (LlamaChatCompletionServiceSettings) super.getServiceSettings(); + } + + /** + * Accepts a visitor that creates an executable action for this Llama chat completion model. + * + * @param creator the visitor that creates the executable action + * @return an ExecutableAction representing this model + */ + @Override + public ExecutableAction accept(LlamaActionVisitor creator) { + return creator.create(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandler.java new file mode 100644 index 0000000000000..85d60308d77d3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandler.java @@ -0,0 +1,180 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.llama.response.LlamaErrorResponse; +import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; + +import java.util.Locale; +import java.util.Optional; + +import static org.elasticsearch.core.Strings.format; + +/** + * Handles streaming chat completion responses and error parsing for Llama inference endpoints. + * This handler is designed to work with the unified Llama chat completion API. + */ +public class LlamaChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { + + private static final String LLAMA_ERROR = "llama_error"; + private static final String STREAM_ERROR = "stream_error"; + + /** + * Constructor for creating a LlamaChatCompletionResponseHandler with specified request type and response parser. + * + * @param requestType the type of request this handler will process + * @param parseFunction the function to parse the response + */ + public LlamaChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, LlamaErrorResponse::fromResponse); + } + + /** + * Builds an exception from an error response returned by a streaming Llama chat completion request. + * + * @param message the error message to include in the exception + * @param request the request that caused the error + * @param result the HTTP result containing the response + * @param errorResponse the error response parsed from the HTTP result + * @return an exception representing the error, specific to Llama chat completion + */ + @Override + protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + assert request.isStreaming() : "Only streaming requests support this format"; + var responseStatusCode = result.response().getStatusLine().getStatusCode(); + if (request.isStreaming()) { + var errorMessage = constructErrorMessage(message, request, errorResponse, responseStatusCode); + var restStatus = toRestStatus(responseStatusCode); + return errorResponse instanceof LlamaErrorResponse + ? new UnifiedChatCompletionException(restStatus, errorMessage, LLAMA_ERROR, restStatus.name().toLowerCase(Locale.ROOT)) + : new UnifiedChatCompletionException( + restStatus, + errorMessage, + createErrorType(errorResponse), + restStatus.name().toLowerCase(Locale.ROOT) + ); + } else { + return super.buildError(message, request, result, errorResponse); + } + } + + /** + * Builds an exception for mid-stream errors encountered during Llama chat completion requests.
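+     * For example, a mid-stream payload of the form {@code {"error":{"message":"..."}}} is surfaced as a
+     * {@code UnifiedChatCompletionException} with error type {@code llama_error} and error code {@code stream_error}.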
+ * + * @param request the request that caused the error + * @param message the error message + * @param e the exception that occurred, if any + * @return a UnifiedChatCompletionException representing the error + */ + @Override + protected Exception buildMidStreamError(Request request, String message, Exception e) { + var errorResponse = StreamingLlamaErrorResponseEntity.fromString(message); + if (errorResponse instanceof StreamingLlamaErrorResponseEntity) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. Error message: [%s]", + SERVER_ERROR_OBJECT, + request.getInferenceEntityId(), + errorResponse.getErrorMessage() + ), + LLAMA_ERROR, + STREAM_ERROR + ); + } else if (e != null) { + return UnifiedChatCompletionException.fromThrowable(e); + } else { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), + createErrorType(errorResponse), + STREAM_ERROR + ); + } + } + + /** + * StreamingLlamaErrorResponseEntity allows creation of {@link ErrorResponse} from a JSON string. + * This entity is used to parse error responses from streaming Llama requests. + * For non-streaming requests {@link LlamaErrorResponse} should be used. + * Example error response for Bad Request error would look like: + *
+     * <pre><code>
+     *  {
+     *      "error": {
+     *          "message": "400: Invalid value: Model 'llama3.12:3b' not found"
+     *      }
+     *  }
+     * </code></pre>
+ */ + private static class StreamingLlamaErrorResponseEntity extends ErrorResponse { + private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>( + LLAMA_ERROR, + true, + args -> Optional.ofNullable((LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity) args[0]) + ); + private static final ConstructingObjectParser< + LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity, + Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>( + LLAMA_ERROR, + true, + args -> new LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity( + args[0] != null ? (String) args[0] : "unknown" + ) + ); + + static { + ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("message")); + + ERROR_PARSER.declareObjectOrNull( + ConstructingObjectParser.optionalConstructorArg(), + ERROR_BODY_PARSER, + null, + new ParseField("error") + ); + } + + /** + * Parses a streaming Llama error response from a JSON string. + * + * @param response the raw JSON string representing an error + * @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails + */ + private static ErrorResponse fromString(String response) { + try ( + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response) + ) { + return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); + } catch (Exception e) { + // swallow the error + } + + return ErrorResponse.UNDEFINED_ERROR; + } + + /** + * Constructs a StreamingLlamaErrorResponseEntity with the specified error message. + * + * @param errorMessage the error message to include in the response entity + */ + StreamingLlamaErrorResponseEntity(String errorMessage) { + super(errorMessage); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettings.java new file mode 100644 index 0000000000000..7917a8cba5b48 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettings.java @@ -0,0 +1,183 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.llama.LlamaService; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; + +/** + * Represents the settings for a Llama chat completion service. + * This class encapsulates the model ID, URI, and rate limit settings for the Llama chat completion service. + */ +public class LlamaChatCompletionServiceSettings extends FilteredXContentObject implements ServiceSettings { + public static final String NAME = "llama_completion_service_settings"; + // There is no default rate limit for Llama, so we set a reasonable default of 3000 requests per minute + protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); + + private final String modelId; + private final URI uri; + private final RateLimitSettings rateLimitSettings; + + /** + * Creates a new instance of LlamaChatCompletionServiceSettings from a map of settings. + * + * @param map the map containing the service settings + * @param context the context for parsing configuration settings + * @return a new instance of LlamaChatCompletionServiceSettings + * @throws ValidationException if required fields are missing or invalid + */ + public static LlamaChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + var model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = extractUri(map, URL, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + LlamaService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new LlamaChatCompletionServiceSettings(model, uri, rateLimitSettings); + } + + /** + * Constructs a new LlamaChatCompletionServiceSettings from a StreamInput. 
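+     * The wire order mirrors {@code writeTo}: the model ID string, then the URI string, then the rate limit settings.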
+ * + * @param in the StreamInput to read from + * @throws IOException if an I/O error occurs during reading + */ + public LlamaChatCompletionServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.uri = createUri(in.readString()); + this.rateLimitSettings = new RateLimitSettings(in); + } + + /** + * Constructs a new LlamaChatCompletionServiceSettings with the specified model ID, URI, and rate limit settings. + * + * @param modelId the ID of the model + * @param uri the URI of the service + * @param rateLimitSettings the rate limit settings for the service + */ + public LlamaChatCompletionServiceSettings(String modelId, URI uri, @Nullable RateLimitSettings rateLimitSettings) { + this.modelId = modelId; + this.uri = uri; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + /** + * Constructs a new LlamaChatCompletionServiceSettings with the specified model ID and URL. + * The rate limit settings will be set to the default value. + * + * @param modelId the ID of the model + * @param url the URL of the service + */ + public LlamaChatCompletionServiceSettings(String modelId, String url, @Nullable RateLimitSettings rateLimitSettings) { + this(modelId, createUri(url), rateLimitSettings); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_LLAMA_ADDED; + } + + @Override + public String modelId() { + return this.modelId; + } + + /** + * Returns the URI of the Llama chat completion service. + * + * @return the URI of the service + */ + public URI uri() { + return this.uri; + } + + /** + * Returns the rate limit settings for the Llama chat completion service. 
+ * + * @return the rate limit settings + */ + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeString(uri.toString()); + rateLimitSettings.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentFragmentOfExposedFields(builder, params); + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID, modelId); + builder.field(URL, uri.toString()); + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LlamaChatCompletionServiceSettings that = (LlamaChatCompletionServiceSettings) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(uri, that.uri) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, uri, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaCompletionResponseHandler.java new file mode 100644 index 0000000000000..8e3b5b10df900 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaCompletionResponseHandler.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.services.llama.response.LlamaErrorResponse; +import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler; + +/** + * Handles non-streaming completion responses for Llama models, extending the OpenAI completion response handler. + * This class is specifically designed to handle Llama's error response format. + */ +public class LlamaCompletionResponseHandler extends OpenAiChatCompletionResponseHandler { + + /** + * Constructs a LlamaCompletionResponseHandler with the specified request type and response parser. + * + * @param requestType The type of request being handled (e.g., "llama completions"). + * @param parseFunction The function to parse the response. 
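+ * Error bodies are handled by {@link LlamaErrorResponse#fromResponse}, which keeps the raw
+ * payload as the error message instead of assuming the OpenAI error schema.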
+ */ + public LlamaCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, LlamaErrorResponse::fromResponse); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModel.java new file mode 100644 index 0000000000000..ebf0b7e8132c1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModel.java @@ -0,0 +1,119 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.embeddings; + +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.llama.LlamaModel; +import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionVisitor; + +import java.util.Map; + +/** + * Represents a Llama embeddings model for inference. + * This class extends the LlamaModel and provides specific configurations and settings for embeddings tasks. + */ +public class LlamaEmbeddingsModel extends LlamaModel { + + /** + * Constructor for creating a LlamaEmbeddingsModel with specified parameters. + * + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to embeddings + * @param secrets the secret settings for the model, such as API keys or tokens + * @param context the context for parsing configuration settings + */ + public LlamaEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + ChunkingSettings chunkingSettings, + Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + LlamaEmbeddingsServiceSettings.fromMap(serviceSettings, context), + chunkingSettings, + retrieveSecretSettings(secrets) + ); + } + + /** + * Constructor for creating a LlamaEmbeddingsModel with specified parameters. + * + * @param model the base LlamaEmbeddingsModel to copy properties from + * @param serviceSettings the settings for the inference service, specific to embeddings + */ + public LlamaEmbeddingsModel(LlamaEmbeddingsModel model, LlamaEmbeddingsServiceSettings serviceSettings) { + super(model, serviceSettings); + setPropertiesFromServiceSettings(serviceSettings); + } + + /** + * Sets properties from the provided LlamaEmbeddingsServiceSettings. 
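+ * Specifically, the uri and rate limit settings are copied onto the model so that request
+ * builders and rate limiting can read them without going back through the service settings.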
+ * + * @param serviceSettings the service settings to extract properties from + */ + private void setPropertiesFromServiceSettings(LlamaEmbeddingsServiceSettings serviceSettings) { + this.uri = serviceSettings.uri(); + this.rateLimitSettings = serviceSettings.rateLimitSettings(); + } + + /** + * Constructor for creating a LlamaEmbeddingsModel with specified parameters. + * + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to embeddings + * @param chunkingSettings the chunking settings for processing input data + * @param secrets the secret settings for the model, such as API keys or tokens + */ + public LlamaEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + LlamaEmbeddingsServiceSettings serviceSettings, + ChunkingSettings chunkingSettings, + SecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings), + new ModelSecrets(secrets) + ); + setPropertiesFromServiceSettings(serviceSettings); + } + + @Override + public LlamaEmbeddingsServiceSettings getServiceSettings() { + return (LlamaEmbeddingsServiceSettings) super.getServiceSettings(); + } + + /** + * Accepts a visitor to create an executable action for this Llama embeddings model. + * + * @param creator the visitor that creates the executable action + * @return an ExecutableAction representing the Llama embeddings model + */ + @Override + public ExecutableAction accept(LlamaActionVisitor creator) { + return creator.create(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsResponseHandler.java new file mode 100644 index 0000000000000..240ccf46c7482 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsResponseHandler.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.embeddings; + +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.services.llama.response.LlamaErrorResponse; +import org.elasticsearch.xpack.inference.services.openai.OpenAiResponseHandler; + +/** + * Handles responses for Llama embeddings requests, parsing the response and handling errors. + * This class extends OpenAiResponseHandler to provide specific functionality for Llama embeddings. + */ +public class LlamaEmbeddingsResponseHandler extends OpenAiResponseHandler { + + /** + * Constructs a new LlamaEmbeddingsResponseHandler with the specified request type and response parser. 
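+ * As with the completion handler, error bodies are delegated to {@link LlamaErrorResponse#fromResponse}.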
+ * + * @param requestType the type of request this handler will process + * @param parseFunction the function to parse the response + */ + public LlamaEmbeddingsResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, LlamaErrorResponse::fromResponse, false); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..a14146070247a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettings.java @@ -0,0 +1,257 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.llama.LlamaService; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; + +/** + * Settings for the Llama embeddings service. + * This class encapsulates the configuration settings required to use Llama for generating embeddings. 
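+ * A complete service settings payload parsed by {@link #fromMap} would look like the following
+ * (values are illustrative, and the rate limit shape is an assumption based on {@link RateLimitSettings}):
+ * <pre><code>
+ *  {
+ *      "model_id": "all-MiniLM-L6-v2",
+ *      "url": "http://localhost:8321/v1/openai/v1/embeddings",
+ *      "dimensions": 384,
+ *      "similarity": "cosine",
+ *      "max_input_tokens": 512,
+ *      "rate_limit": {
+ *          "requests_per_minute": 3000
+ *      }
+ *  }
+ * </code></pre>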
+ */ +public class LlamaEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings { + public static final String NAME = "llama_embeddings_service_settings"; + // There is no default rate limit for Llama, so we set a reasonable default of 3000 requests per minute + protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); + + private final String modelId; + private final URI uri; + private final Integer dimensions; + private final SimilarityMeasure similarity; + private final Integer maxInputTokens; + private final RateLimitSettings rateLimitSettings; + + /** + * Creates a new instance of LlamaEmbeddingsServiceSettings from a map of settings. + * + * @param map the map containing the settings + * @param context the context for parsing configuration settings + * @return a new instance of LlamaEmbeddingsServiceSettings + * @throws ValidationException if any required fields are missing or invalid + */ + public static LlamaEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + var model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = extractUri(map, URL, validationException); + var dimensions = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); + var similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + var maxInputTokens = extractOptionalPositiveInteger( + map, + MAX_INPUT_TOKENS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + var rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException, LlamaService.NAME, context); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new LlamaEmbeddingsServiceSettings(model, uri, dimensions, similarity, maxInputTokens, rateLimitSettings); + } + + /** + * Constructs a new LlamaEmbeddingsServiceSettings from a StreamInput. + * + * @param in the StreamInput to read from + * @throws IOException if an I/O error occurs during reading + */ + public LlamaEmbeddingsServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.uri = createUri(in.readString()); + this.dimensions = in.readOptionalVInt(); + this.similarity = in.readOptionalEnum(SimilarityMeasure.class); + this.maxInputTokens = in.readOptionalVInt(); + this.rateLimitSettings = new RateLimitSettings(in); + } + + /** + * Constructs a new LlamaEmbeddingsServiceSettings with the specified parameters. 
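+ * A null rateLimitSettings falls back to {@link #DEFAULT_RATE_LIMIT_SETTINGS} of 3000 requests per minute.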
+ * + * @param modelId the identifier for the model + * @param uri the URI of the Llama service + * @param dimensions the number of dimensions for the embeddings, can be null + * @param similarity the similarity measure to use, can be null + * @param maxInputTokens the maximum number of input tokens, can be null + * @param rateLimitSettings the rate limit settings for the service, can be null + */ + public LlamaEmbeddingsServiceSettings( + String modelId, + URI uri, + @Nullable Integer dimensions, + @Nullable SimilarityMeasure similarity, + @Nullable Integer maxInputTokens, + @Nullable RateLimitSettings rateLimitSettings + ) { + this.modelId = modelId; + this.uri = uri; + this.dimensions = dimensions; + this.similarity = similarity; + this.maxInputTokens = maxInputTokens; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + /** + * Constructs a new LlamaEmbeddingsServiceSettings with the specified parameters. + * + * @param modelId the identifier for the model + * @param url the URL of the Llama service + * @param dimensions the number of dimensions for the embeddings, can be null + * @param similarity the similarity measure to use, can be null + * @param maxInputTokens the maximum number of input tokens, can be null + * @param rateLimitSettings the rate limit settings for the service, can be null + */ + public LlamaEmbeddingsServiceSettings( + String modelId, + String url, + @Nullable Integer dimensions, + @Nullable SimilarityMeasure similarity, + @Nullable Integer maxInputTokens, + @Nullable RateLimitSettings rateLimitSettings + ) { + this(modelId, createUri(url), dimensions, similarity, maxInputTokens, rateLimitSettings); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_LLAMA_ADDED; + } + + @Override + public String modelId() { + return this.modelId; + } + + public URI uri() { + return this.uri; + } + + @Override + public Integer dimensions() { + return this.dimensions; + } + + @Override + public SimilarityMeasure similarity() { + return this.similarity; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + /** + * Returns the maximum number of input tokens allowed for this service. + * + * @return the maximum input tokens, or null if not specified + */ + public Integer maxInputTokens() { + return this.maxInputTokens; + } + + /** + * Returns the rate limit settings for this service. 
+ * + * @return the rate limit settings, never null + */ + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeString(uri.toString()); + out.writeOptionalVInt(dimensions); + out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); + out.writeOptionalVInt(maxInputTokens); + rateLimitSettings.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentFragmentOfExposedFields(builder, params); + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID, modelId); + builder.field(URL, uri.toString()); + + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LlamaEmbeddingsServiceSettings that = (LlamaEmbeddingsServiceSettings) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(uri, that.uri) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(similarity, that.similarity) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, uri, dimensions, maxInputTokens, similarity, rateLimitSettings); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequest.java new file mode 100644 index 0000000000000..3bb01f215087e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequest.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.request.completion;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.ByteArrayEntity;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+
+import java.net.URI;
+import java.nio.charset.StandardCharsets;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
+
+/**
+ * Llama Chat Completion Request
+ * This class is responsible for creating a request to the Llama chat completion model.
+ * It constructs an HTTP POST request with the necessary headers and body content.
+ */
+public class LlamaChatCompletionRequest implements Request {
+
+    private final LlamaChatCompletionModel model;
+    private final UnifiedChatInput chatInput;
+
+    /**
+     * Constructs a new LlamaChatCompletionRequest with the specified chat input and model.
+     *
+     * @param chatInput the chat input containing the messages and parameters for the completion request
+     * @param model the Llama chat completion model to be used for the request
+     */
+    public LlamaChatCompletionRequest(UnifiedChatInput chatInput, LlamaChatCompletionModel model) {
+        this.chatInput = Objects.requireNonNull(chatInput);
+        this.model = Objects.requireNonNull(model);
+    }
+
+    /**
+     * Creates the HTTP POST request for this chat completion call, serializing the chat input
+     * as the JSON body and attaching the content-type and optional bearer-auth headers.
+     *
+     * @return the constructed {@link HttpRequest}
+     */
+    @Override
+    public HttpRequest createHttpRequest() {
+        HttpPost httpPost = new HttpPost(model.uri());
+
+        ByteArrayEntity byteEntity = new ByteArrayEntity(
+            Strings.toString(new LlamaChatCompletionRequestEntity(chatInput, model)).getBytes(StandardCharsets.UTF_8)
+        );
+        httpPost.setEntity(byteEntity);
+
+        httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
+        if (model.getSecretSettings() instanceof DefaultSecretSettings secretSettings) {
+            httpPost.setHeader(createAuthBearerHeader(secretSettings.apiKey()));
+        }
+
+        return new HttpRequest(httpPost, getInferenceEntityId());
+    }
+
+    @Override
+    public URI getURI() {
+        return model.uri();
+    }
+
+    @Override
+    public Request truncate() {
+        // No truncation for Llama chat completions
+        return this;
+    }
+
+    @Override
+    public boolean[] getTruncationInfo() {
+        // No truncation for Llama chat completions
+        return null;
+    }
+
+    @Override
+    public String getInferenceEntityId() {
+        return model.getInferenceEntityId();
+    }
+
+    @Override
+    public boolean isStreaming() {
+        return chatInput.stream();
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntity.java
new file mode 100644
index 0000000000000..fc80dab09f6f5
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntity.java
@@ -0,0 +1,47 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.request.completion;
+
+import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
+import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
+import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * LlamaChatCompletionRequestEntity is responsible for creating the request entity for Llama chat completion.
+ * It implements ToXContentObject to allow serialization to XContent format.
+ */
+public class LlamaChatCompletionRequestEntity implements ToXContentObject {
+
+    private final LlamaChatCompletionModel model;
+    private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
+
+    /**
+     * Constructs a LlamaChatCompletionRequestEntity with the specified unified chat input and model.
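+     * The serialized body follows the OpenAI-compatible unified chat schema produced by
+     * {@link UnifiedChatCompletionRequestEntity}, with the model id taken from the service settings
+     * via {@link UnifiedCompletionRequest#withMaxTokens}. A minimal sketch of the resulting JSON
+     * (the exact field set is determined by the shared schema, so treat this as illustrative):
+     * <pre><code>
+     *  {
+     *      "model": "llama3.2:3b",
+     *      "messages": [
+     *          {
+     *              "role": "user",
+     *              "content": "Hello"
+     *          }
+     *      ],
+     *      "stream": true
+     *  }
+     * </code></pre>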
+     *
+     * @param unifiedChatInput the unified chat input containing messages and parameters for the completion request
+     * @param model the Llama chat completion model to be used for the request
+     */
+    public LlamaChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, LlamaChatCompletionModel model) {
+        this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
+        this.model = Objects.requireNonNull(model);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(model.getServiceSettings().modelId(), params));
+        builder.endObject();
+        return builder;
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequest.java
new file mode 100644
index 0000000000000..5883880dbb812
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequest.java
@@ -0,0 +1,96 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.request.embeddings;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.ByteArrayEntity;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.common.Truncator;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+
+import java.net.URI;
+import java.nio.charset.StandardCharsets;
+
+import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
+
+/**
+ * Llama Embeddings Request
+ * This class is responsible for creating a request to the Llama embeddings model.
+ * It constructs an HTTP POST request with the necessary headers and body content.
+ */
+public class LlamaEmbeddingsRequest implements Request {
+    private final URI uri;
+    private final LlamaEmbeddingsModel model;
+    private final String inferenceEntityId;
+    private final Truncator.TruncationResult truncationResult;
+    private final Truncator truncator;
+
+    /**
+     * Constructs a new LlamaEmbeddingsRequest with the specified truncator, input, and model.
+     *
+     * @param truncator the truncator to handle input truncation
+     * @param input the input to be truncated
+     * @param model the Llama embeddings model to be used for the request
+     */
+    public LlamaEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, LlamaEmbeddingsModel model) {
+        this.uri = model.uri();
+        this.model = model;
+        this.inferenceEntityId = model.getInferenceEntityId();
+        this.truncator = truncator;
+        this.truncationResult = input;
+    }
+
+    /**
+     * Creates the HTTP POST request for this embeddings call, serializing the (possibly truncated)
+     * input as the JSON body and attaching the content-type and optional bearer-auth headers.
+     *
+     * @return the constructed {@link HttpRequest}
+     */
+    @Override
+    public HttpRequest createHttpRequest() {
+        HttpPost httpPost = new HttpPost(this.uri);
+
+        ByteArrayEntity byteEntity = new ByteArrayEntity(
+            Strings.toString(new LlamaEmbeddingsRequestEntity(model.getServiceSettings().modelId(), truncationResult.input()))
+                .getBytes(StandardCharsets.UTF_8)
+        );
+        httpPost.setEntity(byteEntity);
+
+        httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
+        if (model.getSecretSettings() instanceof DefaultSecretSettings secretSettings) {
+            httpPost.setHeader(createAuthBearerHeader(secretSettings.apiKey()));
+        }
+
+        return new HttpRequest(httpPost, getInferenceEntityId());
+    }
+
+    @Override
+    public URI getURI() {
+        return uri;
+    }
+
+    @Override
+    public Request truncate() {
+        var truncatedInput = truncator.truncate(truncationResult.input());
+        return new LlamaEmbeddingsRequest(truncator, truncatedInput, model);
+    }
+
+    @Override
+    public boolean[] getTruncationInfo() {
+        return truncationResult.truncated().clone();
+    }
+
+    @Override
+    public String getInferenceEntityId() {
+        return inferenceEntityId;
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntity.java
new file mode 100644
index 0000000000000..3f734bacec87d
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntity.java
@@ -0,0 +1,51 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.llama.request.embeddings;
+
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * LlamaEmbeddingsRequestEntity is responsible for creating the request entity for Llama embeddings.
+ * It implements ToXContentObject to allow serialization to XContent format.
+ */
+public record LlamaEmbeddingsRequestEntity(String modelId, List<String> contents) implements ToXContentObject {
+
+    public static final String CONTENTS_FIELD = "contents";
+    public static final String MODEL_ID_FIELD = "model_id";
+
+    /**
+     * Constructs a LlamaEmbeddingsRequestEntity with the specified model ID and contents.
+     *
+     * @param modelId the ID of the model to use for embeddings
+     * @param contents the list of contents to generate embeddings for
+     */
+    public LlamaEmbeddingsRequestEntity {
+        Objects.requireNonNull(modelId);
+        Objects.requireNonNull(contents);
+    }
+
+    /**
+     * Serializes this request entity, writing the model id and the contents list.
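+     * For example (values are illustrative):
+     * <pre><code>
+     *  {
+     *      "model_id": "all-MiniLM-L6-v2",
+     *      "contents": ["The quick brown fox"]
+     *  }
+     * </code></pre>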
+ */ + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(MODEL_ID_FIELD, modelId); + builder.field(CONTENTS_FIELD, contents); + + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponse.java new file mode 100644 index 0000000000000..727231209fdf1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponse.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.response; + +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; + +import java.nio.charset.StandardCharsets; + +/** + * LlamaErrorResponse is responsible for handling error responses from Llama inference services. + * It extends ErrorResponse to provide specific functionality for Llama errors. + * An example error response for Not Found error would look like: + *
+ * <pre><code>
+ *  {
+ *      "detail": "Not Found"
+ *  }
+ * </code></pre>
+ * An example error response for Bad Request error would look like: + *
+ * <pre><code>
+ *  {
+ *     "error": {
+ *         "detail": {
+ *             "errors": [
+ *                 {
+ *                     "loc": [
+ *                         "body",
+ *                         "model"
+ *                     ],
+ *                     "msg": "Field required",
+ *                     "type": "missing"
+ *                 }
+ *             ]
+ *         }
+ *     }
+ *  }
+ * </code></pre>
+ */ +public class LlamaErrorResponse extends ErrorResponse { + + public LlamaErrorResponse(String message) { + super(message); + } + + public static ErrorResponse fromResponse(HttpResult response) { + try { + String errorMessage = new String(response.body(), StandardCharsets.UTF_8); + return new LlamaErrorResponse(errorMessage); + } catch (Exception e) { + // swallow the error + } + return ErrorResponse.UNDEFINED_ERROR; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralModel.java index 57219a03b3bdb..55a5b4fe71047 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralModel.java @@ -10,19 +10,21 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; +import org.elasticsearch.xpack.inference.services.mistral.action.MistralActionVisitor; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.net.URI; import java.net.URISyntaxException; +import java.util.Objects; /** * Represents a Mistral model that can be used for inference tasks. * This class extends RateLimitGroupingModel to handle rate limiting based on model and API key. */ public abstract class MistralModel extends RateLimitGroupingModel { - protected String model; protected URI uri; protected RateLimitSettings rateLimitSettings; @@ -34,10 +36,6 @@ protected MistralModel(RateLimitGroupingModel model, ServiceSettings serviceSett super(model, serviceSettings); } - public String model() { - return this.model; - } - public URI uri() { return this.uri; } @@ -49,7 +47,7 @@ public RateLimitSettings rateLimitSettings() { @Override public int rateLimitGroupingHash() { - return 0; + return Objects.hash(getServiceSettings().modelId(), getSecretSettings().apiKey()); } // Needed for testing only @@ -65,4 +63,6 @@ public void setURI(String newUri) { public DefaultSecretSettings getSecretSettings() { return (DefaultSecretSettings) super.getSecretSettings(); } + + public abstract ExecutableAction accept(MistralActionVisitor creator); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index b11feb117d761..c1eee5eb27338 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import 
org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -84,8 +86,16 @@ public class MistralService extends SenderService { OpenAiChatCompletionResponseEntity::fromResponse ); - public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public MistralService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override @@ -98,16 +108,10 @@ protected void doInfer( ) { var actionCreator = new MistralActionCreator(getSender(), getServiceComponents()); - switch (model) { - case MistralEmbeddingsModel mistralEmbeddingsModel: - mistralEmbeddingsModel.accept(actionCreator, taskSettings).execute(inputs, timeout, listener); - break; - case MistralChatCompletionModel mistralChatCompletionModel: - mistralChatCompletionModel.accept(actionCreator).execute(inputs, timeout, listener); - break; - default: - listener.onFailure(createInvalidModelException(model)); - break; + if (model instanceof MistralModel mistralModel) { + mistralModel.accept(actionCreator).execute(inputs, timeout, listener); + } else { + listener.onFailure(createInvalidModelException(model)); } } @@ -162,7 +166,7 @@ protected void doChunkedInfer( ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { - var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings); + var action = mistralEmbeddingsModel.accept(actionCreator); action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } else { @@ -207,7 +211,6 @@ public void parseRequestConfig( modelId, taskType, serviceSettingsMap, - taskSettingsMap, chunkingSettings, serviceSettingsMap, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), @@ -232,7 +235,7 @@ public MistralModel parsePersistedConfigWithSecrets( Map secrets ) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); ChunkingSettings chunkingSettings = null; @@ -244,7 +247,6 @@ public MistralModel parsePersistedConfigWithSecrets( modelId, taskType, serviceSettingsMap, - taskSettingsMap, chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(modelId, NAME) @@ -254,7 +256,7 @@ public MistralModel parsePersistedConfigWithSecrets( @Override public MistralModel parsePersistedConfig(String modelId, TaskType taskType, Map config) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + removeFromMapOrDefaultEmpty(config, 
ModelConfigurations.TASK_SETTINGS); ChunkingSettings chunkingSettings = null; if (TaskType.TEXT_EMBEDDING.equals(taskType)) { @@ -265,7 +267,6 @@ public MistralModel parsePersistedConfig(String modelId, TaskType taskType, Map< modelId, taskType, serviceSettingsMap, - taskSettingsMap, chunkingSettings, null, parsePersistedConfigErrorMsg(modelId, NAME) @@ -286,7 +287,6 @@ private static MistralModel createModel( String modelId, TaskType taskType, Map serviceSettings, - Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, @@ -294,16 +294,7 @@ private static MistralModel createModel( ) { switch (taskType) { case TEXT_EMBEDDING: - return new MistralEmbeddingsModel( - modelId, - taskType, - NAME, - serviceSettings, - taskSettings, - chunkingSettings, - secretSettings, - context - ); + return new MistralEmbeddingsModel(modelId, taskType, NAME, serviceSettings, chunkingSettings, secretSettings, context); case CHAT_COMPLETION, COMPLETION: return new MistralChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context); default: @@ -315,7 +306,6 @@ private MistralModel createModelFromPersistent( String inferenceEntityId, TaskType taskType, Map serviceSettings, - Map taskSettings, ChunkingSettings chunkingSettings, Map secretSettings, String failureMessage @@ -324,7 +314,6 @@ private MistralModel createModelFromPersistent( inferenceEntityId, taskType, serviceSettings, - taskSettings, chunkingSettings, secretSettings, failureMessage, @@ -359,10 +348,10 @@ public MistralEmbeddingsModel updateModelWithEmbeddingDetails(Model model, int e */ public static class Configuration { public static InferenceServiceConfiguration get() { - return configuration.getOrCompute(); + return CONFIGURATION.getOrCompute(); } - private static final LazyInitializable configuration = new LazyInitializable<>( + private static final LazyInitializable CONFIGURATION = new LazyInitializable<>( () -> { var configurationMap = new HashMap(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java index fbf842f4fb789..ba7377c3209e2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java @@ -24,7 +24,6 @@ import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest; import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; -import java.util.Map; import java.util.Objects; import static org.elasticsearch.core.Strings.format; @@ -51,7 +50,7 @@ public MistralActionCreator(Sender sender, ServiceComponents serviceComponents) } @Override - public ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map taskSettings) { + public ExecutableAction create(MistralEmbeddingsModel embeddingsModel) { var requestManager = new MistralEmbeddingsRequestManager( embeddingsModel, serviceComponents.truncator(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java index 
5f494e4d65477..e1c4b12883c56 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java @@ -11,8 +11,6 @@ import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel; -import java.util.Map; - /** * Interface for creating {@link ExecutableAction} instances for Mistral models. *

@@ -25,10 +23,9 @@ public interface MistralActionVisitor { * Creates an {@link ExecutableAction} for the given {@link MistralEmbeddingsModel}. * * @param embeddingsModel The model to create the action for. - * @param taskSettings The task settings to use. * @return An {@link ExecutableAction} for the given model. */ - ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map taskSettings); + ExecutableAction create(MistralEmbeddingsModel embeddingsModel); /** * Creates an {@link ExecutableAction} for the given {@link MistralChatCompletionModel}. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java index 03fe502a82807..876c46edcb70d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java @@ -22,7 +22,6 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.Map; -import java.util.Objects; import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.API_COMPLETIONS_PATH; @@ -95,23 +94,17 @@ public MistralChatCompletionModel( DefaultSecretSettings secrets ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings()), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE), new ModelSecrets(secrets) ); setPropertiesFromServiceSettings(serviceSettings); } private void setPropertiesFromServiceSettings(MistralChatCompletionServiceSettings serviceSettings) { - this.model = serviceSettings.modelId(); this.rateLimitSettings = serviceSettings.rateLimitSettings(); setEndpointUrl(); } - @Override - public int rateLimitGroupingHash() { - return Objects.hash(model, getSecretSettings().apiKey()); - } - private void setEndpointUrl() { try { this.uri = new URI(API_COMPLETIONS_PATH); @@ -131,6 +124,7 @@ public MistralChatCompletionServiceSettings getServiceSettings() { * @param creator The visitor that creates the executable action. * @return An ExecutableAction that can be executed. 
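+ * With this override in place, {@link MistralService#doInfer} dispatches through {@link MistralModel#accept} instead of switching on concrete model types.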
*/ + @Override public ExecutableAction accept(MistralActionVisitor creator) { return creator.create(this); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java index 48d2fecc5ce13..8ac186ac9d642 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java @@ -12,7 +12,6 @@ import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; -import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; @@ -37,7 +36,6 @@ public MistralEmbeddingsModel( TaskType taskType, String service, Map serviceSettings, - Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secrets, ConfigurationParseContext context @@ -47,7 +45,6 @@ public MistralEmbeddingsModel( taskType, service, MistralEmbeddingsServiceSettings.fromMap(serviceSettings, context), - EmptyTaskSettings.INSTANCE, // no task settings for Mistral embeddings chunkingSettings, DefaultSecretSettings.fromMap(secrets) ); @@ -59,7 +56,6 @@ public MistralEmbeddingsModel(MistralEmbeddingsModel model, MistralEmbeddingsSer } private void setPropertiesFromServiceSettings(MistralEmbeddingsServiceSettings serviceSettings) { - this.model = serviceSettings.modelId(); this.rateLimitSettings = serviceSettings.rateLimitSettings(); setEndpointUrl(); } @@ -77,12 +73,11 @@ public MistralEmbeddingsModel( TaskType taskType, String service, MistralEmbeddingsServiceSettings serviceSettings, - TaskSettings taskSettings, ChunkingSettings chunkingSettings, DefaultSecretSettings secrets ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings), new ModelSecrets(secrets) ); setPropertiesFromServiceSettings(serviceSettings); @@ -93,7 +88,8 @@ public MistralEmbeddingsServiceSettings getServiceSettings() { return (MistralEmbeddingsServiceSettings) super.getServiceSettings(); } - public ExecutableAction accept(MistralActionVisitor creator, Map taskSettings) { - return creator.create(this, taskSettings); + @Override + public ExecutableAction accept(MistralActionVisitor creator) { + return creator.create(this); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java index 6b1c7d36a9fe6..4cf1fef3c92c4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java @@ -178,12 +178,13 @@ public boolean 
equals(Object o) { return Objects.equals(model, that.model) && Objects.equals(dimensions, that.dimensions) && Objects.equals(maxInputTokens, that.maxInputTokens) - && Objects.equals(similarity, that.similarity); + && Objects.equals(similarity, that.similarity) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); } @Override public int hashCode() { - return Objects.hash(model, dimensions, maxInputTokens, similarity); + return Objects.hash(model, dimensions, maxInputTokens, similarity, rateLimitSettings); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequest.java index 8b772d4b8f2ed..b7d3866bcebfd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequest.java @@ -42,7 +42,7 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(this.uri); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new MistralEmbeddingsRequestEntity(embeddingsModel.model(), truncationResult.input())) + Strings.toString(new MistralEmbeddingsRequestEntity(embeddingsModel.getServiceSettings().modelId(), truncationResult.input())) .getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index edff1dfc08cba..b9e9e34c44736 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -91,8 +93,16 @@ public class OpenAiService extends SenderService { OpenAiChatCompletionResponseEntity::fromResponse ); - public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public OpenAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff 
--git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java index 3120f1ff92e48..957203b5ee802 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java @@ -198,7 +198,7 @@ private static class DeltaParser { PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(CONTENT_FIELD)); PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(REFUSAL_FIELD)); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ROLE_FIELD)); - PARSER.declareObjectArray( + PARSER.declareObjectArrayOrNull( ConstructingObjectParser.optionalConstructorArg(), (p, c) -> ChatCompletionChunkParser.ToolCallParser.parse(p), new ParseField(TOOL_CALLS_FIELD) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java index 928ed3ff444e6..2ae70cb52b565 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java @@ -34,7 +34,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); unifiedRequestEntity.toXContent( builder, - UnifiedCompletionRequest.withMaxCompletionTokensTokens(model.getServiceSettings().modelId(), params) + UnifiedCompletionRequest.withMaxCompletionTokens(model.getServiceSettings().modelId(), params) ); if (Strings.isNullOrEmpty(model.getTaskSettings().user()) == false) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java index aafd6c46857fc..653c4288263f9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java @@ -12,6 +12,7 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.CheckedSupplier; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -20,6 +21,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -37,6 +39,7 @@ import 
java.util.EnumSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import static org.elasticsearch.core.Strings.format; @@ -55,13 +58,26 @@ public class SageMakerService implements InferenceService { private final SageMakerSchemas schemas; private final ThreadPool threadPool; private final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration; + private final ClusterService clusterService; public SageMakerService( SageMakerModelBuilder modelBuilder, SageMakerClient client, SageMakerSchemas schemas, ThreadPool threadPool, - CheckedSupplier<Map<String, Object>, RuntimeException> configurationMap + CheckedSupplier<Map<String, Object>, RuntimeException> configurationMap, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(modelBuilder, client, schemas, threadPool, configurationMap, context.clusterService()); + } + + public SageMakerService( + SageMakerModelBuilder modelBuilder, + SageMakerClient client, + SageMakerSchemas schemas, + ThreadPool threadPool, + CheckedSupplier<Map<String, Object>, RuntimeException> configurationMap, + ClusterService clusterService ) { this.modelBuilder = modelBuilder; this.client = client; @@ -74,6 +90,7 @@ public SageMakerService( .setConfigurations(configurationMap.get()) .build() ); + this.clusterService = Objects.requireNonNull(clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 0ffec057dc2b4..9698ee4c0d4bb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -96,8 +98,16 @@ public class VoyageAIService extends SenderService { InputType.INTERNAL_SEARCH ); - public VoyageAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public VoyageAIService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public VoyageAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java index 70499c7987965..812cd1e3c6d7f 100644 ---
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -21,6 +21,8 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; +import org.elasticsearch.inference.telemetry.InferenceStatsTests; import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; @@ -33,7 +35,6 @@ import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -88,7 +89,7 @@ public void setUp() throws Exception { licenseState = mock(); modelRegistry = mock(); serviceRegistry = mock(); - inferenceStats = new InferenceStats(mock(), mock()); + inferenceStats = InferenceStatsTests.mockInferenceStats(); streamingTaskManager = mock(); action = createAction( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java index 4d986cf0a837f..547078d93acc4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportException; @@ -22,7 +23,6 @@ import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.common.RateLimitAssignment; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.List; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java index f26d0675487a5..9e6f4a6260936 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.threadpool.ThreadPool; @@ -20,7 
+21,6 @@ import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.Optional; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 5b4925d8fb0a3..e96fda569aa12 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -49,6 +49,8 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.inference.telemetry.InferenceStats; +import org.elasticsearch.inference.telemetry.InferenceStatsTests; import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -66,7 +68,6 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import org.junit.After; import org.junit.Before; import org.mockito.stubbing.Answer; @@ -148,7 +149,7 @@ public void tearDownThreadPool() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testFilterNoop() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(), @@ -181,7 +182,7 @@ public void testFilterNoop() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testLicenseInvalidForInference() throws InterruptedException { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -227,7 +228,7 @@ public void testLicenseInvalidForInference() throws InterruptedException { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testInferenceNotFound() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -275,7 +276,7 @@ public void testInferenceNotFound() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testItemFailures() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -364,7 +365,7 @@ public void 
testItemFailures() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testExplicitNull() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success"))); @@ -440,7 +441,7 @@ public void testExplicitNull() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testHandleEmptyInput() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -495,7 +496,7 @@ public void testHandleEmptyInput() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testManyRandomDocs() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); Map inferenceModelMap = new HashMap<>(); int numModels = randomIntBetween(1, 3); for (int i = 0; i < numModels; i++) { @@ -559,7 +560,7 @@ public void testManyRandomDocs() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testIndexingPressure() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(Settings.EMPTY); final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); final StaticModel denseModel = StaticModel.createRandomInstance(TaskType.TEXT_EMBEDDING); @@ -677,7 +678,7 @@ public void testIndexingPressure() throws Exception { @SuppressWarnings("unchecked") public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure( Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), "1b").build() ); @@ -709,7 +710,10 @@ public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Excep BulkItemResponse.Failure doc1Failure = doc1Response.getFailure(); assertThat( doc1Failure.getCause().getMessage(), - containsString("Insufficient memory available to update source on document [doc_1]") + containsString( + "Unable to insert inference results into document [doc_1]" + + " due to memory pressure. Please retry the bulk request with fewer documents or smaller document sizes." 
+ ) ); assertThat(doc1Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class)); assertThat(doc1Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS)); @@ -762,7 +766,7 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (length(doc1Source) + 1) + "b").build() ); - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); sparseModel.putResult("bar", randomChunkedInferenceEmbedding(sparseModel, List.of("bar"))); @@ -791,7 +795,10 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except BulkItemResponse.Failure doc1Failure = doc1Response.getFailure(); assertThat( doc1Failure.getCause().getMessage(), - containsString("Insufficient memory available to insert inference results into document [doc_1]") + containsString( + "Unable to insert inference results into document [doc_1]" + + " due to memory pressure. Please retry the bulk request with fewer documents or smaller document sizes." + ) ); assertThat(doc1Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class)); assertThat(doc1Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS)); @@ -875,7 +882,7 @@ public void testIndexingPressurePartialFailure() throws Exception { .build() ); - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); final ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(sparseModel.getInferenceEntityId(), sparseModel), @@ -902,7 +909,10 @@ public void testIndexingPressurePartialFailure() throws Exception { BulkItemResponse.Failure doc2Failure = doc2Response.getFailure(); assertThat( doc2Failure.getCause().getMessage(), - containsString("Insufficient memory available to insert inference results into document [doc_2]") + containsString( + "Unable to insert inference results into document [doc_2]" + + " due to memory pressure. Please retry the bulk request with fewer documents or smaller document sizes." 
+ ) ); assertThat(doc2Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class)); assertThat(doc2Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java index aeb09af03ebab..4a4c59f091abf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java @@ -23,7 +23,6 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.services.custom.CustomModel; import org.junit.After; import org.junit.Assume; import org.junit.Before; @@ -141,7 +140,7 @@ public boolean isEnabled() { return true; } - protected abstract CustomModel createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure); + protected abstract Model createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure); } private static final UpdateModelConfiguration DISABLED_UPDATE_MODEL_TESTS = new UpdateModelConfiguration() { @@ -151,7 +150,7 @@ public boolean isEnabled() { } @Override - protected CustomModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { + protected Model createEmbeddingModel(SimilarityMeasure similarityMeasure) { throw new UnsupportedOperationException("Update model tests are disabled"); } }; @@ -351,11 +350,17 @@ public void testParsePersistedConfigWithSecrets_ThrowsUnsupportedModelType() thr assertThat( exception.getMessage(), - containsString(Strings.format("service does not support task type [%s]", parseConfigTestConfig.unsupportedTaskType)) + containsString( + Strings.format(fetchPersistedConfigTaskTypeParsingErrorMessageFormat(), parseConfigTestConfig.unsupportedTaskType) + ) ); } } + protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() { + return "service does not support task type [%s]"; + } + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { var parseConfigTestConfig = testConfiguration.commonConfig; @@ -374,7 +379,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists persistedConfigMap.secrets() ); - parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); } } @@ -396,7 +401,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInServ persistedConfigMap.secrets() ); - parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); } } @@ -413,7 +418,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInTask var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets()); - parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); } } @@ -430,7 +435,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInSecr var model = 
service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets()); - parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); } } @@ -468,7 +473,7 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO () -> service.updateModelWithEmbeddingDetails(getInvalidModel("id", "service"), randomNonNegativeInt()) ); - assertThat(exception.getMessage(), containsString("Can't update embedding details for model of type:")); + assertThat(exception.getMessage(), containsString("Can't update embedding details for model")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 5d7a6a149f941..7457859a64603 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; @@ -36,6 +37,7 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -64,7 +66,7 @@ public void testStart_InitializesTheSender() throws IOException { var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); - try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) { + try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.start(mock(Model.class), listener); @@ -84,7 +86,7 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); - try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) { + try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.start(mock(Model.class), listener); listener.actionGet(TIMEOUT); @@ -102,8 +104,8 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep } private static final class TestSenderService extends SenderService { - TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index 8fbbd33d569e4..f0258e9f66ed5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -91,7 +91,13 @@ public void shutdown() throws IOException { } public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { ActionListener modelVerificationListener = ActionListener.wrap(model -> { assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -116,7 +122,13 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException } public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { ActionListener modelVerificationListener = ActionListener.wrap(model -> { assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -143,7 +155,13 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsP } public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { ActionListener modelVerificationListener = ActionListener.wrap(model -> { assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -169,7 +187,13 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsN } public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var model = service.parsePersistedConfig( "id", TaskType.TEXT_EMBEDDING, @@ -190,7 +214,13 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting } public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + 
mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var model = service.parsePersistedConfig( "id", TaskType.TEXT_EMBEDDING, @@ -210,7 +240,13 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting } public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var persistedConfig = getPersistedConfigMap( AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), @@ -235,7 +271,13 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun } public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var persistedConfig = getPersistedConfigMap( AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), @@ -262,7 +304,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = OpenAiChatCompletionModelTests.createCompletionModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -279,7 +321,7 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO public void testUpdateModelWithEmbeddingDetails_UpdatesEmbeddingSizeAndSimilarity() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = AlibabaCloudSearchEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -316,7 +358,7 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType_TextEmbedding() t taskSettingsMap, secretSettingsMap ); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -360,7 +402,7 @@ public void 
testInfer_ThrowsValidationExceptionForInvalidInputType_SparseEmbeddi taskSettingsMap, secretSettingsMap ); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -404,7 +446,7 @@ public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOExc taskSettingsMap, secretSettingsMap ); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -452,7 +494,7 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = createModelForTaskType(taskType, chunkingSettings); PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>(); @@ -482,7 +524,13 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettin @SuppressWarnings("checkstyle:LineLength") public void testGetConfiguration() throws Exception { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { String content = XContentHelper.stripWhitespace( """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index a014f27e7f0cc..c3b1cab4b4e0a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -959,7 +959,14 @@ public void testInfer_ThrowsErrorWhenModelIsNotAmazonBedrockModel() throws IOExc ); var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -1007,7 +1014,12 @@ public void testInfer_SendsRequest_ForTitanEmbeddingsModel() throws IOException ); try ( - var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool)); + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); var
requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender() ) { var results = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))); @@ -1042,7 +1054,14 @@ public void testInfer_SendsRequest_ForCohereEmbeddingsModel() throws IOException mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { var results = new TextEmbeddingFloatResults( List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) @@ -1088,7 +1107,14 @@ public void testInfer_SendsRequest_ForChatCompletionModel() throws IOException { mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { var mockResults = new ChatCompletionResults(List.of(new ChatCompletionResults.Result("test result"))); requestSender.enqueue(mockResults); @@ -1132,7 +1158,14 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var model = AmazonBedrockChatCompletionModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -1166,7 +1199,14 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var embeddingSize = randomNonNegativeInt(); var provider = randomFrom(AmazonBedrockProvider.values()); var model = AmazonBedrockEmbeddingsModelTests.createModel( @@ -1205,7 +1245,12 @@ public void testInfer_UnauthorizedResponse() throws IOException { ); try ( - var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool)); + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender() ) { requestSender.enqueue( @@ -1240,7 +1285,7 @@ public void testInfer_UnauthorizedResponse() throws IOException { } public void testSupportsStreaming() throws IOException { - try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()))) { + try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); 
assertFalse(service.canStream(TaskType.ANY)); } @@ -1284,7 +1329,14 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { { var mockResults1 = new TextEmbeddingFloatResults( @@ -1345,7 +1397,12 @@ private AmazonBedrockService createAmazonBedrockService() { ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), mockClusterServiceEmpty() ); - return new AmazonBedrockService(mock(HttpRequestSender.Factory.class), amazonBedrockFactory, createWithEmptySettings(threadPool)); + return new AmazonBedrockService( + mock(HttpRequestSender.Factory.class), + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index a3f0b01901009..9111866d29c88 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -453,7 +453,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AnthropicService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AnthropicService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -486,7 +486,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException public void testInfer_SendsCompletionRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", @@ -579,7 +579,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AnthropicChatCompletionModelTests.createChatCompletionModel( getUrl(webServer), "secret", @@ -679,13 +679,13 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()))) { + try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()), 
mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } } private AnthropicService createServiceWithMockSender() { - return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index fee2adcf664ec..3383762a9f332 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -1073,7 +1073,7 @@ public void testParsePersistedConfig_WithoutSecretsCreatesRerankModel() throws I public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureAiStudioChatCompletionModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -1098,7 +1098,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = AzureAiStudioEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -1124,7 +1124,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testUpdateModelWithChatCompletionDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureAiStudioEmbeddingsModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -1152,7 +1152,7 @@ public void testUpdateModelWithChatCompletionDetails_NonNullSimilarityInOriginal private void testUpdateModelWithChatCompletionDetails_Successful(Integer maxNewTokens) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = 
AzureAiStudioChatCompletionModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -1185,7 +1185,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOExc var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -1223,7 +1223,7 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOExcept var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -1293,7 +1293,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1379,7 +1379,7 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep public void testInfer_WithChatCompletionModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testChatCompletionResultJson)); var model = AzureAiStudioChatCompletionModelTests.createModel( @@ -1416,7 +1416,7 @@ public void testInfer_WithChatCompletionModel() throws IOException { public void testInfer_WithRerankModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testRerankTokenResponseJson)); var model = AzureAiStudioRerankModelTests.createModel( @@ -1457,7 +1457,7 @@ public void testInfer_WithRerankModel() throws IOException { public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1534,7 +1534,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureAiStudioChatCompletionModelTests.createModel( "id", getUrl(webServer), @@ -1666,7 +1666,7 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()))) { + try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1675,7 +1675,11 @@ public void testSupportsStreaming() throws IOException { // ---------------------------------------------------------------- private AzureAiStudioService createService() { - return new AzureAiStudioService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new AzureAiStudioService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index de2e9ae9a21b8..f3d65c5589169 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -752,7 +752,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AzureOpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -785,7 +785,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep public void testInfer_SendsRequest() throws IOException, URISyntaxException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -844,7 +844,7 @@ public void testInfer_SendsRequest() throws IOException, URISyntaxException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureOpenAiCompletionModelTests.createModelWithRandomValues(); assertThrows( ElasticsearchStatusException.class, @@ -864,7 +864,7 @@ public void 
testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = AzureOpenAiEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -891,7 +891,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_UnauthorisedResponse() throws IOException, URISyntaxException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -952,7 +952,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException, URISyn private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOException, URISyntaxException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1065,7 +1065,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureOpenAiCompletionModelTests.createCompletionModel( "resource", "deployment", @@ -1209,14 +1209,18 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()))) { + try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } } private AzureOpenAiService createAzureOpenAiService() { - return new AzureOpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new AzureOpenAiService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 52e4f904a4de0..8f189baa33b20 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -779,7 +779,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new CohereService(factory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -812,7 +812,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException public void testInfer_SendsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -886,7 +886,7 @@ public void testInfer_SendsRequest() throws IOException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = CohereCompletionModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThrows( ElasticsearchStatusException.class, @@ -906,7 +906,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var embeddingType = randomFrom(CohereEmbeddingType.values()); var model = CohereEmbeddingsModelTests.createModel( @@ -933,7 +933,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -975,7 +975,7 @@ public void testInfer_UnauthorisedResponse() throws IOException { public void testInfer_SetsInputTypeToIngest_FromInferParameter_WhenTaskSettingsAreEmpty() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1051,7 +1051,7 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs throws IOException { var 
senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1125,7 +1125,7 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest_v1API() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1200,7 +1200,7 @@ public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspec public void testInfer_DefaultsInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest_v2API() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1297,7 +1297,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { // Batching will call the service with 2 inputs String responseJson = """ @@ -1387,7 +1387,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { // Batching will call the service with 2 inputs String responseJson = """ @@ -1507,7 +1507,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1591,7 +1591,7 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new CohereService(mock(), createWithEmptySettings(mock()))) { + try (var service = new CohereService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { 
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1632,7 +1632,7 @@ private Map getRequestConfigMap(Map serviceSetti } private CohereService createCohereService() { - return new CohereService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new CohereService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index cc1bb4471c0a9..a707030a34189 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -53,6 +53,7 @@ import static org.elasticsearch.xpack.inference.Utils.TIMEOUT; import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -148,7 +149,7 @@ private static void assertCompletionModel(Model model) { public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - return new CustomService(senderFactory, createWithEmptySettings(threadPool)); + return new CustomService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } private static Map createServiceSettingsMap(TaskType taskType) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index af38ee38e1eff..908451b8e681f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -360,7 +360,8 @@ public void testDoChunkedInferAlwaysFails() throws IOException { private DeepSeekService createService() { return new DeepSeekService( HttpRequestSenderTests.createSenderFactory(threadPool, clientManager), - createWithEmptySettings(threadPool) + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 6ce484954d3ce..94d1e064648ff 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -1427,7 +1427,8 @@ private ElasticInferenceService 
createServiceWithMockSender(ElasticInferenceServ createWithEmptySettings(threadPool), new ElasticInferenceServiceSettings(Settings.EMPTY), modelRegistry, - mockAuthHandler + mockAuthHandler, + mockClusterServiceEmpty() ); } @@ -1456,7 +1457,8 @@ private ElasticInferenceService createService( createWithEmptySettings(threadPool), ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), modelRegistry, - mockAuthHandler + mockAuthHandler, + mockClusterServiceEmpty() ); } @@ -1469,7 +1471,8 @@ private ElasticInferenceService createServiceWithAuthHandler( createWithEmptySettings(threadPool), ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), modelRegistry, - new ElasticInferenceServiceAuthorizationRequestHandler(elasticInferenceServiceURL, threadPool) + new ElasticInferenceServiceAuthorizationRequestHandler(elasticInferenceServiceURL, threadPool), + mockClusterServiceEmpty() ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index d2c22cdcf6f57..6c01145701d92 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.Level; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ElasticsearchTimeoutException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.LatchedActionListener; import org.elasticsearch.action.support.ActionTestUtils; @@ -37,6 +38,8 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.telemetry.InferenceStats; +import org.elasticsearch.inference.telemetry.InferenceStatsTests; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -57,10 +60,12 @@ import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.ModelDeploymentTimeoutException; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentTests; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests; @@ -98,6 +103,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import static 
org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; @@ -113,6 +119,7 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.assertArg; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; @@ -124,12 +131,16 @@ public class ElasticsearchInternalServiceTests extends ESTestCase { - String randomInferenceEntityId = randomAlphaOfLength(10); + private String randomInferenceEntityId; + private InferenceStats inferenceStats; private static ThreadPool threadPool; @Before - public void setUpThreadPool() { + public void setUp() throws Exception { + super.setUp(); + randomInferenceEntityId = randomAlphaOfLength(10); + inferenceStats = InferenceStatsTests.mockInferenceStats(); threadPool = createThreadPool(InferencePlugin.inferenceUtilityExecutor(Settings.EMPTY)); } @@ -1813,7 +1824,8 @@ public void testUpdateWithoutMlEnabled() throws IOException, InterruptedExceptio mock(), threadPool, cs, - Settings.builder().put("xpack.ml.enabled", false).build() + Settings.builder().put("xpack.ml.enabled", false).build(), + inferenceStats ); try (var service = new ElasticsearchInternalService(context)) { var models = List.of(mock(Model.class)); @@ -1855,7 +1867,8 @@ public void testUpdateWithMlEnabled() throws IOException, InterruptedException { client, threadPool, cs, - Settings.builder().put("xpack.ml.enabled", true).build() + Settings.builder().put("xpack.ml.enabled", true).build(), + inferenceStats ); try (var service = new ElasticsearchInternalService(context)) { List models = List.of(model); @@ -1869,7 +1882,82 @@ public void testUpdateWithMlEnabled() throws IOException, InterruptedException { } public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException { - var model = new ElserInternalModel( + var model = mockModel(); + + var client = mockClientForStart( + listener -> listener.onFailure(new ElasticsearchStatusException("failed", RestStatus.GATEWAY_TIMEOUT)) + ); + + try (var service = createService(client)) { + var actionListener = new PlainActionFuture(); + service.start(model, TimeValue.timeValueSeconds(30), actionListener); + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> actionListener.actionGet(TimeValue.timeValueSeconds(30)) + ); + + assertThat(exception.getMessage(), is("failed")); + verify(inferenceStats.deploymentDuration()).record(anyLong(), assertArg(attributes -> { + assertNotNull(attributes); + assertThat(attributes.get("error.type"), is("504")); + })); + } + } + + public void testStart_OnFailure_WhenDeploymentTimeoutOccurs() throws IOException { + var model = mockModel(); + + var client = mockClientForStart( + listener -> listener.onFailure(new ElasticsearchTimeoutException("failed", RestStatus.GATEWAY_TIMEOUT)) + ); + + try (var service = createService(client)) { + var actionListener = new PlainActionFuture(); + service.start(model, TimeValue.timeValueSeconds(30), actionListener); + var exception = expectThrows( + ModelDeploymentTimeoutException.class, + () -> actionListener.actionGet(TimeValue.timeValueSeconds(30)) + ); + + assertThat( + exception.getMessage(), + is( + "Timed out after [30s] waiting for trained model deployment for inference endpoint [inference_id] to start. 
" + + "The inference endpoint can not be used to perform inference until the deployment has started. " + + "Use the trained model stats API to track the state of the deployment." + ) + ); + verify(inferenceStats.deploymentDuration()).record(anyLong(), assertArg(attributes -> { + assertNotNull(attributes); + assertThat(attributes.get("error.type"), is("408")); + })); + } + } + + public void testStart() throws IOException { + var model = mockModel(); + + var client = mockClientForStart(listener -> { + var response = mock(CreateTrainedModelAssignmentAction.Response.class); + when(response.getTrainedModelAssignment()).thenReturn(TrainedModelAssignmentTests.randomInstance()); + listener.onResponse(response); + }); + + try (var service = createService(client)) { + var actionListener = new PlainActionFuture(); + service.start(model, TimeValue.timeValueSeconds(30), actionListener); + assertTrue(actionListener.actionGet(TimeValue.timeValueSeconds(30))); + + verify(inferenceStats.deploymentDuration()).record(anyLong(), assertArg(attributes -> { + assertNotNull(attributes); + assertNull(attributes.get("error.type")); + assertThat(attributes.get("status_code"), is(200)); + })); + } + } + + private ElserInternalModel mockModel() { + return new ElserInternalModel( "inference_id", TaskType.SPARSE_EMBEDDING, "elasticsearch", @@ -1879,7 +1967,9 @@ public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException { new ElserMlNodeTaskSettings(), null ); + } + private Client mockClientForStart(Consumer> startModelListener) { var client = mock(Client.class); when(client.threadPool()).thenReturn(threadPool); @@ -1895,27 +1985,18 @@ public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException { doAnswer(invocationOnMock -> { ActionListener listener = invocationOnMock.getArgument(2); - listener.onFailure(new ElasticsearchStatusException("failed", RestStatus.GATEWAY_TIMEOUT)); + startModelListener.accept(listener); return Void.TYPE; }).when(client).execute(eq(StartTrainedModelDeploymentAction.INSTANCE), any(), any()); - try (var service = createService(client)) { - var actionListener = new PlainActionFuture(); - service.start(model, TimeValue.timeValueSeconds(30), actionListener); - var exception = expectThrows( - ElasticsearchStatusException.class, - () -> actionListener.actionGet(TimeValue.timeValueSeconds(30)) - ); - - assertThat(exception.getMessage(), is("failed")); - } + return client; } private ElasticsearchInternalService createService(Client client) { var cs = mock(ClusterService.class); var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); when(cs.getClusterSettings()).thenReturn(cSettings); - var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool, cs, Settings.EMPTY); + var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool, cs, Settings.EMPTY, inferenceStats); return new ElasticsearchInternalService(context); } @@ -1924,7 +2005,8 @@ private ElasticsearchInternalService createService(Client client, BaseElasticsea client, threadPool, mock(ClusterService.class), - Settings.EMPTY + Settings.EMPTY, + inferenceStats ); return new ElasticsearchInternalService(context, l -> l.onResponse(modelVariant)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java 
index 41175581df1cf..435ea9de5911b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -658,7 +658,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotGoogleAiStudioModel() throws IOEx var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -696,7 +696,7 @@ public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForModelThatD var model = GoogleAiStudioEmbeddingsModelTests.createModel("model", getUrl(webServer), "secret"); - try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -730,7 +730,7 @@ public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForModelThatD public void testInfer_SendsCompletionRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { "candidates": [ @@ -818,7 +818,7 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { "embeddings": [ @@ -897,7 +897,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { "embeddings": [ @@ -998,7 +998,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed public void testInfer_ResourceNotFound() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1033,7 +1033,7 @@ public void testInfer_ResourceNotFound() throws IOException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try 
(var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = GoogleAiStudioCompletionModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThrows( ElasticsearchStatusException.class, @@ -1052,7 +1052,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = GoogleAiStudioEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -1124,7 +1124,7 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()))) { + try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1171,6 +1171,10 @@ private Map getRequestConfigMap( } private GoogleAiStudioService createGoogleAiStudioService() { - return new GoogleAiStudioService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new GoogleAiStudioService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 99a09b983787d..26fd076e72462 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -1043,7 +1043,7 @@ public void testGetConfiguration() throws Exception { private GoogleVertexAiService createGoogleVertexAiService() { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - return new GoogleVertexAiService(senderFactory, createWithEmptySettings(threadPool)); + return new GoogleVertexAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java index 3be4b72c1237f..2cdf3f5263751 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java @@ -29,6 +29,7 
@@ import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.CoreMatchers.is; import static org.mockito.Mockito.mock; @@ -92,7 +93,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotHuggingFaceModel() throws IOExcep private static final class TestService extends HuggingFaceService { TestService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + super(factory, serviceComponents, mockClusterServiceEmpty()); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java index 814d533129439..93156d4331263 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java @@ -81,7 +81,7 @@ public void shutdown() throws IOException { public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceElserService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceElserService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ [ @@ -137,7 +137,8 @@ public void testGetConfiguration() throws Exception { try ( var service = new HuggingFaceElserService( HttpRequestSenderTests.createSenderFactory(threadPool, clientManager), - createWithEmptySettings(threadPool) + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() ) ) { String content = XContentHelper.stripWhitespace(""" diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index e2850910ac64a..c770672c5d5f2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -258,7 +258,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION); - try (var service = new HuggingFaceService(factory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -328,7 +328,7 @@ public void testUnifiedCompletionInfer() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new 
HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -357,7 +357,7 @@ public void testUnifiedCompletionNonStreamingError() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); var latch = new CountDownLatch(1); service.unifiedCompletionInfer( @@ -486,7 +486,7 @@ public void testUnifiedCompletionMalformedError() throws Exception { private void testStreamError(String expectedResponse) throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -548,7 +548,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -621,7 +621,7 @@ public void testInfer_StreamRequestRetry() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new HuggingFaceService(mock(), createWithEmptySettings(mock()))) { + try (var service = new HuggingFaceService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1009,7 +1009,7 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInTaskSetti public void testInfer_SendsEmbeddingsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1060,7 +1060,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { var model = 
HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret"); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -1087,7 +1087,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { public void testInfer_SendsElserRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ [ @@ -1139,7 +1139,7 @@ public void testInfer_SendsElserRequest() throws IOException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceElserModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThrows( ElasticsearchStatusException.class, @@ -1158,7 +1158,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = HuggingFaceEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -1179,7 +1179,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1233,7 +1233,7 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th public void testChunkedInfer() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ [ @@ -1340,7 +1340,11 @@ public void testGetConfiguration() throws Exception { } private HuggingFaceService createHuggingFaceService() { - return new HuggingFaceService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new 
HuggingFaceService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 3295ecfd4ece5..ddc62b5a412b9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -597,7 +597,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotIbmWatsonxModel() throws IOExcept var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool))) { + try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -635,7 +635,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { var model = IbmWatsonxEmbeddingsModelTests.createModel(modelId, projectId, URI.create(url), apiVersion, apiKey, getUrl(webServer)); - try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool))) { + try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( @@ -1018,12 +1018,12 @@ private Map getRequestConfigMap( } private IbmWatsonxService createIbmWatsonxService() { - return new IbmWatsonxService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new IbmWatsonxService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } private static class IbmWatsonxServiceWithoutAuth extends IbmWatsonxService { IbmWatsonxServiceWithoutAuth(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + super(factory, serviceComponents, mockClusterServiceEmpty()); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index eca76bc1a702a..d36c574e0aa99 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -778,7 +778,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotJinaAIModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new JinaAIService(factory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -819,7 +819,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var 
senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var embeddingType = randomFrom(JinaAIEmbeddingType.values()); var model = JinaAIEmbeddingsModelTests.createModel( @@ -846,7 +846,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_Embedding_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -889,7 +889,7 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException { public void testInfer_Rerank_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -923,7 +923,7 @@ public void testInfer_Rerank_UnauthorisedResponse() throws IOException { public void testInfer_Embedding_Get_Response_Ingest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -994,7 +994,7 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { public void testInfer_Embedding_Get_Response_Search() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1065,7 +1065,7 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { public void testInfer_Embedding_Get_Response_clustering() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ {"model":"jina-clip-v2","object":"list","usage":{"total_tokens":5,"prompt_tokens":5}, @@ -1120,7 +1120,7 @@ public void testInfer_Embedding_Get_Response_clustering() throws IOException { public void testInfer_Embedding_Get_Response_NullInputType() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), 
mockClusterServiceEmpty())) { String responseJson = """ { @@ -1210,7 +1210,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, false); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1295,7 +1295,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, false); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1392,7 +1392,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, null); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1475,7 +1475,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, true); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1540,7 +1540,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1637,7 +1637,7 @@ public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOExcept private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new 
JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { // Batching will call the service with 2 input String responseJson = """ @@ -1800,7 +1800,7 @@ public void testGetConfiguration() throws Exception { } public void testDoesNotSupportsStreaming() throws IOException { - try (var service = new JinaAIService(mock(), createWithEmptySettings(mock()))) { + try (var service = new JinaAIService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertFalse(service.canStream(TaskType.COMPLETION)); assertFalse(service.canStream(TaskType.ANY)); } @@ -1841,7 +1841,7 @@ private Map getRequestConfigMap(Map serviceSetti } private JinaAIService createJinaAIService() { - return new JinaAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new JinaAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java new file mode 100644 index 0000000000000..dd68c43f5e62d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java @@ -0,0 +1,840 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionTestUtils; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import 
org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; +import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.hamcrest.CoreMatchers; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.ExceptionsHelper.unwrapCause; +import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests.createChatCompletionModel; +import static org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionServiceSettingsTests.getServiceSettingsMap; +import static org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettingsTests.buildServiceSettingsMap; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static 
org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.isA; +import static org.mockito.Mockito.mock; + +public class LlamaServiceTests extends AbstractInferenceServiceTests { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + public LlamaServiceTests() { + super(createTestConfiguration()); + } + + private static TestConfiguration createTestConfiguration() { + return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) { + + @Override + protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + return LlamaServiceTests.createService(threadPool, clientManager); + } + + @Override + protected Map createServiceSettingsMap(TaskType taskType) { + return LlamaServiceTests.createServiceSettingsMap(taskType); + } + + @Override + protected Map createTaskSettingsMap() { + return new HashMap<>(); + } + + @Override + protected Map createSecretSettingsMap() { + return LlamaServiceTests.createSecretSettingsMap(); + } + + @Override + protected void assertModel(Model model, TaskType taskType) { + LlamaServiceTests.assertModel(model, taskType); + } + + @Override + protected EnumSet supportedStreamingTasks() { + return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); + } + }).enableUpdateModelTests(new UpdateModelConfiguration() { + @Override + protected LlamaEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { + return createInternalEmbeddingModel(similarityMeasure); + } + }).build(); + } + + private static void assertModel(Model model, TaskType taskType) { + switch (taskType) { + case TEXT_EMBEDDING -> assertTextEmbeddingModel(model); + case COMPLETION -> assertCompletionModel(model); + case CHAT_COMPLETION -> assertChatCompletionModel(model); + default -> fail("unexpected task type [" + taskType + "]"); + } + } + + private static void assertTextEmbeddingModel(Model model) { + var llamaModel = assertCommonModelFields(model); + + assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.TEXT_EMBEDDING)); + } + + private static LlamaModel assertCommonModelFields(Model model) { + assertThat(model, instanceOf(LlamaModel.class)); + + var llamaModel = (LlamaModel) model; + assertThat(llamaModel.getServiceSettings().modelId(), is("model_id")); + assertThat(llamaModel.uri.toString(), Matchers.is("http://www.abc.com")); + assertThat(llamaModel.getTaskSettings(), Matchers.is(EmptyTaskSettings.INSTANCE)); + assertThat( + ((DefaultSecretSettings) llamaModel.getSecretSettings()).apiKey(), + Matchers.is(new SecureString("secret".toCharArray())) + ); + + return llamaModel; + } + + private static void assertCompletionModel(Model model) { + var llamaModel = assertCommonModelFields(model); + assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.COMPLETION)); + } + + private static void assertChatCompletionModel(Model model) { + var llamaModel = assertCommonModelFields(model); + assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.CHAT_COMPLETION)); + } + + public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + return new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); + } + + private static Map createServiceSettingsMap(TaskType 
taskType) { + Map settingsMap = new HashMap<>( + Map.of(ServiceFields.URL, "http://www.abc.com", ServiceFields.MODEL_ID, "model_id") + ); + + if (taskType == TaskType.TEXT_EMBEDDING) { + settingsMap.putAll( + Map.of( + ServiceFields.SIMILARITY, + SimilarityMeasure.COSINE.toString(), + ServiceFields.DIMENSIONS, + 1536, + ServiceFields.MAX_INPUT_TOKENS, + 512 + ) + ); + } + + return settingsMap; + } + + private static Map createSecretSettingsMap() { + return new HashMap<>(Map.of("api_key", "secret")); + } + + private static LlamaEmbeddingsModel createInternalEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure) { + var inferenceId = "inference_id"; + + return new LlamaEmbeddingsModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + LlamaService.NAME, + new LlamaEmbeddingsServiceSettings( + "model_id", + "http://www.abc.com", + 1536, + similarityMeasure, + 512, + new RateLimitSettings(10_000) + ), + ChunkingSettingsTests.createRandomChunkingSettings(), + new DefaultSecretSettings(new SecureString("secret".toCharArray())) + ); + } + + protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() { + return "Failed to parse stored model [id] for [llama] service, please delete and add the service again"; + } + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + try (var service = createService()) { + ActionListener modelVerificationActionListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(LlamaEmbeddingsModel.class)); + + var embeddingsModel = (LlamaEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(((DefaultSecretSettings) (embeddingsModel.getSecretSettings())).apiKey().toString(), is("secret")); + }, e -> fail("parse request should not fail " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getServiceSettingsMap("model", "url"), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelVerificationActionListener + ); + } + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + try (var service = createService()) { + ActionListener modelVerificationActionListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(LlamaEmbeddingsModel.class)); + + var embeddingsModel = (LlamaEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(((DefaultSecretSettings) (embeddingsModel.getSecretSettings())).apiKey().toString(), is("secret")); + }, e -> fail("parse request should not fail " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getServiceSettingsMap("model", "url"), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + 
modelVerificationActionListener
+            );
+        }
+    }
+
+    public void testParseRequestConfig_ThrowsException_WithoutModelId() throws IOException {
+        var url = "url";
+        var secret = "secret";
+
+        try (var service = createService()) {
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(m -> {
+                assertThat(m, instanceOf(LlamaChatCompletionModel.class));
+
+                var chatCompletionModel = (LlamaChatCompletionModel) m;
+
+                assertThat(chatCompletionModel.getServiceSettings().uri().toString(), is(url));
+                assertNull(chatCompletionModel.getServiceSettings().modelId());
+                assertThat(((DefaultSecretSettings) (chatCompletionModel.getSecretSettings())).apiKey().toString(), is("secret"));
+
+            }, exception -> {
+                assertThat(exception, instanceOf(ValidationException.class));
+                assertThat(
+                    exception.getMessage(),
+                    is("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];")
+                );
+            });
+
+            service.parseRequestConfig(
+                "id",
+                TaskType.CHAT_COMPLETION,
+                getRequestConfigMap(getServiceSettingsMap(null, url), getSecretSettingsMap(secret)),
+                modelVerificationListener
+            );
+        }
+    }
+
+    public void testParseRequestConfig_ThrowsException_WithoutUrl() throws IOException {
+        var model = "model";
+        var secret = "secret";
+
+        try (var service = createService()) {
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(m -> {
+                assertThat(m, instanceOf(LlamaChatCompletionModel.class));
+
+                var chatCompletionModel = (LlamaChatCompletionModel) m;
+
+                assertThat(chatCompletionModel.getServiceSettings().modelId(), is(model));
+                assertNull(chatCompletionModel.getServiceSettings().uri());
+                assertThat(((DefaultSecretSettings) (chatCompletionModel.getSecretSettings())).apiKey().toString(), is("secret"));
+
+            }, exception -> {
+                assertThat(exception, instanceOf(ValidationException.class));
+                assertThat(
+                    exception.getMessage(),
+                    is("Validation Failed: 1: [service_settings] does not contain the required setting [url];")
+                );
+            });
+
+            service.parseRequestConfig(
+                "id",
+                TaskType.CHAT_COMPLETION,
+                getRequestConfigMap(getServiceSettingsMap(model, null), getSecretSettingsMap(secret)),
+                modelVerificationListener
+            );
+        }
+    }
+
+    public void testUnifiedCompletionInfer() throws Exception {
+        // The escapes are because the streaming response must be on a single line
+        String responseJson = """
+            data: {\
+                "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26",\
+                "choices": [{\
+                        "delta": {\
+                            "content": "Deep",\
+                            "function_call": null,\
+                            "refusal": null,\
+                            "role": "assistant",\
+                            "tool_calls": null\
+                        },\
+                        "finish_reason": null,\
+                        "index": 0,\
+                        "logprobs": null\
+                    }\
+                ],\
+                "created": 1750158492,\
+                "model": "llama3.2:3b",\
+                "object": "chat.completion.chunk",\
+                "service_tier": null,\
+                "system_fingerprint": "fp_ollama",\
+                "usage": null\
+            }
+
+            """;
+        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
+            var model = createChatCompletionModel("model", getUrl(webServer), "secret");
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            service.unifiedCompletionInfer(
+                model,
+                UnifiedCompletionRequest.of(
+                    List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
+                ),
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var result =
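+            // Note how the asserted event keeps only id, delta.content, delta.role, index, model
+            // and object: null fields and provider extras from the chunk above (created,
+            // service_tier, system_fingerprint, usage, ...) are not carried through.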
listener.actionGet(TIMEOUT); + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(XContentHelper.stripWhitespace(""" + { + "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26", + "choices": [{ + "delta": { + "content": "Deep", + "role": "assistant" + }, + "index": 0 + } + ], + "model": "llama3.2:3b", + "object": "chat.completion.chunk" + } + """)); + } + } + + public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception { + String responseJson = """ + { + "detail": "Not Found" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = LlamaChatCompletionModelTests.createChatCompletionModel("model", getUrl(webServer), "secret"); + var latch = new CountDownLatch(1); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> { + try (var builder = XContentFactory.jsonBuilder()) { + var t = unwrapCause(e); + assertThat(t, isA(UnifiedChatCompletionException.class)); + ((UnifiedChatCompletionException) t).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + assertThat(json, is(String.format(Locale.ROOT, XContentHelper.stripWhitespace(""" + { + "error" : { + "code" : "not_found", + "message" : "Resource not found at [%s] for request from inference entity id [id] status \ + [404]. 
Error message: [{\\n \\"detail\\": \\"Not Found\\"\\n}\\n]", + "type" : "llama_error" + } + }"""), getUrl(webServer)))); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }), latch::countDown) + ); + assertTrue(latch.await(30, TimeUnit.SECONDS)); + } + } + + public void testMidStreamUnifiedCompletionError() throws Exception { + String responseJson = """ + data: {"error": {"message": "400: Invalid value: Model 'llama3.12:3b' not found"}} + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(XContentHelper.stripWhitespace(""" + { + "error": { + "code": "stream_error", + "message": "Received an error response for request from inference entity id [id].\ + Error message: [400: Invalid value: Model 'llama3.12:3b' not found]", + "type": "llama_error" + } + } + """)); + } + + public void testInfer_StreamRequest() throws Exception { + String responseJson = """ + data: {\ + "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26",\ + "choices": [{\ + "delta": {\ + "content": "Deep",\ + "function_call": null,\ + "refusal": null,\ + "role": "assistant",\ + "tool_calls": null\ + },\ + "finish_reason": null,\ + "index": 0,\ + "logprobs": null\ + }\ + ],\ + "created": 1750158492,\ + "model": "llama3.2:3b",\ + "object": "chat.completion.chunk",\ + "service_tier": null,\ + "system_fingerprint": "fp_ollama",\ + "usage": null\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + streamCompletion().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"Deep"}]}"""); + } + + private void testStreamError(String expectedResponse) throws Exception { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = LlamaChatCompletionModelTests.createChatCompletionModel("model", getUrl(webServer), "secret"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> { + e = unwrapCause(e); + assertThat(e, isA(UnifiedChatCompletionException.class)); + try (var builder = XContentFactory.jsonBuilder()) { + ((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + + assertThat(json, is(expectedResponse)); + } + }); + } + } + + public void testInfer_StreamRequest_ErrorResponse() { + String responseJson = """ + { + "detail": "Not Found" + }"""; + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); + + var e = assertThrows(ElasticsearchStatusException.class, this::streamCompletion); + assertThat(e.status(), equalTo(RestStatus.NOT_FOUND)); + assertThat(e.getMessage(), equalTo(String.format(Locale.ROOT, """ + Resource not found at [%s] for request from inference entity id [id] status [404]. 
Error message: [{ + "detail": "Not Found" + }]""", getUrl(webServer)))); + } + + public void testInfer_StreamRequestRetry() throws Exception { + webServer.enqueue(new MockResponse().setResponseCode(503).setBody(""" + { + "error": { + "message": "server busy" + } + }""")); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(""" + data: {\ + "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26",\ + "choices": [{\ + "delta": {\ + "content": "Deep",\ + "function_call": null,\ + "refusal": null,\ + "role": "assistant",\ + "tool_calls": null\ + },\ + "finish_reason": null,\ + "index": 0,\ + "logprobs": null\ + }\ + ],\ + "created": 1750158492,\ + "model": "llama3.2:3b",\ + "object": "chat.completion.chunk",\ + "service_tier": null,\ + "system_fingerprint": "fp_ollama",\ + "usage": null\ + } + + """)); + + streamCompletion().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"Deep"}]}"""); + } + + public void testSupportsStreaming() throws IOException { + try (var service = new LlamaService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException { + try (var service = createService()) { + var secretSettings = getSecretSettingsMap("secret"); + secretSettings.put("extra_key", "value"); + + var config = getRequestConfigMap(getEmbeddingsServiceSettingsMap(), secretSettings); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Configuration contains settings [{extra_key=value}] unknown to the [llama] service") + ); + } + ); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); + } + } + + public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { + var model = LlamaEmbeddingsModelTests.createEmbeddingsModel("id", "url", "api_key"); + model.setURI(getUrl(webServer)); + + testChunkedInfer(model); + } + + public void testChunkedInfer_ChunkingSettingsSet() throws IOException { + var model = LlamaEmbeddingsModelTests.createEmbeddingsModelWithChunkingSettings("id", "url", "api_key"); + model.setURI(getUrl(webServer)); + + testChunkedInfer(model); + } + + public void testChunkedInfer(LlamaEmbeddingsModel model) throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + + String responseJson = """ + { + "embeddings": [ + [ + 0.010060793, + -0.0017529363 + ], + [ + 0.110060793, + -0.1017529363 + ] + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(new ChunkInferenceInput("abc"), new ChunkInferenceInput("def")), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + + assertThat(results, hasSize(2)); + { + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); + var 
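+            // Two inputs produce two ChunkedInferenceEmbedding results; with inputs this short
+            // each result holds a single chunk wrapping one float embedding from the mocked
+            // response above.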
floatResult = (ChunkedInferenceEmbedding) results.get(0); + assertThat(floatResult.chunks(), hasSize(1)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertTrue( + Arrays.equals( + new float[] { 0.010060793f, -0.0017529363f }, + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ) + ); + } + { + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); + var floatResult = (ChunkedInferenceEmbedding) results.get(1); + assertThat(floatResult.chunks(), hasSize(1)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertTrue( + Arrays.equals( + new float[] { 0.110060793f, -0.1017529363f }, + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ) + ); + } + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer api_key")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), Matchers.is(2)); + assertThat(requestMap.get("contents"), Matchers.is(List.of("abc", "def"))); + assertThat(requestMap.get("model_id"), Matchers.is("id")); + } + } + + public void testGetConfiguration() throws Exception { + try (var service = createService()) { + String content = XContentHelper.stripWhitespace(""" + { + "service": "llama", + "name": "Llama", + "task_types": ["text_embedding", "completion", "chat_completion"], + "configurations": { + "api_key": { + "description": "API Key for the provider you're connecting to.", + "label": "API Key", + "required": true, + "sensitive": true, + "updatable": true, + "type": "str", + "supported_task_types": ["text_embedding", "completion", "chat_completion"] + }, + "model_id": { + "description": "Refer to the Llama models documentation for the list of available models.", + "label": "Model", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "completion", "chat_completion"] + }, + "rate_limit.requests_per_minute": { + "description": "Minimize the number of rate limit errors.", + "label": "Rate Limit", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "completion", "chat_completion"] + }, + "url": { + "description": "The URL endpoint to use for the requests.", + "label": "URL", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "completion", "chat_completion"] + } + } + } + """); + InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( + new BytesArray(content), + XContentType.JSON + ); + boolean humanReadable = true; + BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); + InferenceServiceConfiguration serviceConfiguration = service.getConfiguration(); + assertToXContentEquivalent( + originalBytes, + toXContent(serviceConfiguration, XContentType.JSON, humanReadable), + XContentType.JSON + ); + } + } + + private InferenceEventsAssertion streamCompletion() 
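+    // Drives the plain streaming path (service.infer with stream == true) rather than
+    // unifiedCompletionInfer; callers chain the returned assertion, e.g.
+    // streamCompletion().hasNoErrors().hasEvent("{\"completion\":[{\"delta\":\"Deep\"}]}").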
throws Exception { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = LlamaChatCompletionModelTests.createCompletionModel("model", getUrl(webServer), "secret"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("abc"), + true, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); + } + } + + private LlamaService createService() { + return new LlamaService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + + private Map getRequestConfigMap(Map serviceSettings, Map secretSettings) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings)); + } + + private static Map getEmbeddingsServiceSettingsMap() { + return buildServiceSettingsMap("id", "url", SimilarityMeasure.COSINE.toString(), null, null, null); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java new file mode 100644 index 0000000000000..366e0926f0daa --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java @@ -0,0 +1,283 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.action; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests; +import org.elasticsearch.xpack.inference.InputTypeTests; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModelTests; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; +import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class LlamaActionCreatorTests extends ESTestCase { + + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + 
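+        // Tear down in reverse order of init(): close the HTTP client manager first, then the
+        // thread pool, and the mock web server last.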
clientManager.close();
+        terminate(threadPool);
+        webServer.close();
+    }
+
+    public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var sender = createSender(senderFactory)) {
+            sender.start();
+
+            String responseJson = """
+                {
+                    "embeddings": [
+                        [
+                            -0.0123,
+                            0.123
+                        ]
+                    ]
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            PlainActionFuture<InferenceServiceResults> listener = createEmbeddingsFuture(sender, createWithEmptySettings(threadPool));
+
+            var result = listener.actionGet(TIMEOUT);
+
+            assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
+
+            assertEmbeddingsRequest();
+        }
+    }
+
+    public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws IOException {
+        var settings = buildSettingsWithRetryFields(
+            TimeValue.timeValueMillis(1),
+            TimeValue.timeValueMinutes(1),
+            TimeValue.timeValueSeconds(0)
+        );
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
+
+        try (var sender = createSender(senderFactory)) {
+            sender.start();
+
+            String responseJson = """
+                [
+                    {
+                        "embeddings": [
+                            [
+                                -0.0123,
+                                0.123
+                            ]
+                        ]
+                    }
+                ]
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            PlainActionFuture<InferenceServiceResults> listener = createEmbeddingsFuture(sender, createWithEmptySettings(threadPool));
+
+            var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+            assertThat(
+                thrownException.getMessage(),
+                is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]")
+            );
+
+            assertEmbeddingsRequest();
+        }
+    }
+
+    public void testExecute_ReturnsSuccessfulResponse_ForCompletionAction() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var sender = createSender(senderFactory)) {
+            sender.start();
+
+            String responseJson = """
+                {
+                    "id": "chatcmpl-03e70a75-efb6-447d-b661-e5ed0bd59ce9",
+                    "choices": [
+                        {
+                            "finish_reason": "length",
+                            "index": 0,
+                            "logprobs": null,
+                            "message": {
+                                "content": "Hello there, how may I assist you today?",
+                                "refusal": null,
+                                "role": "assistant",
+                                "annotations": null,
+                                "audio": null,
+                                "function_call": null,
+                                "tool_calls": null
+                            }
+                        }
+                    ],
+                    "created": 1750157476,
+                    "model": "llama3.2:3b",
+                    "object": "chat.completion",
+                    "service_tier": null,
+                    "system_fingerprint": "fp_ollama",
+                    "usage": {
+                        "completion_tokens": 10,
+                        "prompt_tokens": 30,
+                        "total_tokens": 40,
+                        "completion_tokens_details": null,
+                        "prompt_tokens_details": null
+                    }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            PlainActionFuture<InferenceServiceResults> listener = createCompletionFuture(sender, createWithEmptySettings(threadPool));
+
+            var result = listener.actionGet(TIMEOUT);
+
+            assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?"))));
+
+            assertCompletionRequest();
+        }
+    }
+
+    public void testExecute_FailsFromInvalidResponseFormat_ForCompletionAction() throws IOException {
+        var settings = buildSettingsWithRetryFields(
+            TimeValue.timeValueMillis(1),
+            TimeValue.timeValueMinutes(1),
+            TimeValue.timeValueSeconds(0)
+        );
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
+
+        try (var
sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "invalid_field": "unexpected" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = createCompletionFuture( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator()) + ); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("Failed to send Llama completion request from inference entity id [id]. Cause: Required [choices]") + ); + + assertCompletionRequest(); + } + } + + private PlainActionFuture createEmbeddingsFuture(Sender sender, ServiceComponents threadPool) { + var model = LlamaEmbeddingsModelTests.createEmbeddingsModel("model", getUrl(webServer), "secret"); + var actionCreator = new LlamaActionCreator(sender, threadPool); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + return listener; + } + + private PlainActionFuture createCompletionFuture(Sender sender, ServiceComponents threadPool) { + var model = LlamaChatCompletionModelTests.createCompletionModel("model", getUrl(webServer), "secret"); + var actionCreator = new LlamaActionCreator(sender, threadPool); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + return listener; + } + + private void assertCompletionRequest() throws IOException { + assertCommonRequestProperties(); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(4)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); + } + + @SuppressWarnings("unchecked") + private void assertEmbeddingsRequest() throws IOException { + assertCommonRequestProperties(); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("contents"), instanceOf(List.class)); + var inputList = (List) requestMap.get("contents"); + assertThat(inputList, contains("abc")); + } + + private void assertCommonRequestProperties() { + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModelTests.java new file mode 100644 index 0000000000000..844d17addac6d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModelTests.java @@ -0,0 +1,142 
@@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.List; + +import static org.hamcrest.Matchers.is; + +public class LlamaChatCompletionModelTests extends ESTestCase { + + public static LlamaChatCompletionModel createCompletionModel(String modelId, String url, String apiKey) { + return new LlamaChatCompletionModel( + "id", + TaskType.COMPLETION, + "llama", + new LlamaChatCompletionServiceSettings(modelId, url, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static LlamaChatCompletionModel createChatCompletionModel(String modelId, String url, String apiKey) { + return new LlamaChatCompletionModel( + "id", + TaskType.CHAT_COMPLETION, + "llama", + new LlamaChatCompletionServiceSettings(modelId, url, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static LlamaChatCompletionModel createChatCompletionModelNoAuth(String modelId, String url) { + return new LlamaChatCompletionModel( + "id", + TaskType.CHAT_COMPLETION, + "llama", + new LlamaChatCompletionServiceSettings(modelId, url, null), + EmptySecretSettings.INSTANCE + ); + } + + public void testOverrideWith_UnifiedCompletionRequest_KeepsSameModelId() { + var model = createCompletionModel("model_name", "url", "api_key"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "model_name", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = LlamaChatCompletionModel.of(model, request); + + assertThat(overriddenModel, is(model)); + } + + public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() { + var model = createCompletionModel("model_name", "url", "api_key"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "different_model", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = LlamaChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); + } + + public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() { + var model = createCompletionModel(null, "url", "api_key"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "different_model", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = LlamaChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); + } + + public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() { + var model = createCompletionModel(null, "url", "api_key"); + var 
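+        // Neither the model nor the request carries a model id here; of(...) should copy an id
+        // over only when the request supplies one, so null is expected to survive the override.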
request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + null, + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = LlamaChatCompletionModel.of(model, request); + + assertNull(overriddenModel.getServiceSettings().modelId()); + } + + public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { + var model = createCompletionModel("model_name", "url", "api_key"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + null, // not overriding model + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = LlamaChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandlerTests.java new file mode 100644 index 0000000000000..c9b6069d383ed --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandlerTests.java @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; + +import static org.elasticsearch.ExceptionsHelper.unwrapCause; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class LlamaChatCompletionResponseHandlerTests extends ESTestCase { + private final LlamaChatCompletionResponseHandler responseHandler = new LlamaChatCompletionResponseHandler( + "chat completions", + (a, b) -> mock() + ); + + public void testFailNotFound() throws IOException { + var responseJson = XContentHelper.stripWhitespace(""" + { + "detail": "Not Found" + } + """); + + var errorJson = invalidResponseJson(responseJson, 404); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + { + "error" : { + "code" : "not_found", + "message" : "Resource not found at [https://api.llama.ai/v1/chat/completions] for request from inference entity id [id] \ + status [404]. 
Error message: [{\\"detail\\":\\"Not Found\\"}]", + "type" : "llama_error" + } + }"""))); + } + + public void testFailBadRequest() throws IOException { + var responseJson = XContentHelper.stripWhitespace(""" + { + "error": { + "detail": { + "errors": [{ + "loc": [ + "body", + "messages" + ], + "msg": "Field required", + "type": "missing" + } + ] + } + } + } + """); + + var errorJson = invalidResponseJson(responseJson, 400); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + { + "error": { + "code": "bad_request", + "message": "Received a bad request status code for request from inference entity id [id] status [400].\ + Error message: [{\\"error\\":{\\"detail\\":{\\"errors\\":[{\\"loc\\":[\\"body\\",\\"messages\\"],\\"msg\\":\\"Field\ + required\\",\\"type\\":\\"missing\\"}]}}}]", + "type": "llama_error" + } + } + """))); + } + + public void testFailValidationWithInvalidJson() throws IOException { + var responseJson = """ + what? this isn't a json + """; + + var errorJson = invalidResponseJson(responseJson, 500); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + { + "error": { + "code": "bad_request", + "message": "Received a server error status code for request from inference entity id [id] status [500]. Error message: \ + [what? this isn't a json\\n]", + "type": "llama_error" + } + } + """))); + } + + private String invalidResponseJson(String responseJson, int statusCode) throws IOException { + var exception = invalidResponse(responseJson, statusCode); + assertThat(exception, isA(RetryException.class)); + assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class)); + return toJson((UnifiedChatCompletionException) unwrapCause(exception)); + } + + private Exception invalidResponse(String responseJson, int statusCode) { + return expectThrows( + RetryException.class, + () -> responseHandler.validateResponse( + mock(), + mock(), + mockRequest(), + new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)), + true + ) + ); + } + + private static Request mockRequest() throws URISyntaxException { + var request = mock(Request.class); + when(request.getInferenceEntityId()).thenReturn("id"); + when(request.isStreaming()).thenReturn(true); + when(request.getURI()).thenReturn(new URI("https://api.llama.ai/v1/chat/completions")); + return request; + } + + private static HttpResponse mockErrorResponse(int statusCode) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var response = mock(HttpResponse.class); + when(response.getStatusLine()).thenReturn(statusLine); + + return response; + } + + private String toJson(UnifiedChatCompletionException e) throws IOException { + try (var builder = XContentFactory.jsonBuilder()) { + e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettingsTests.java new file mode 100644 index 0000000000000..21b42453d9c39 --- /dev/null +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettingsTests.java @@ -0,0 +1,198 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class LlamaChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase { + + public static final String MODEL_ID = "some model"; + public static final String CORRECT_URL = "https://www.elastic.co"; + public static final int RATE_LIMIT = 2; + + public void testFromMap_AllFields_Success() { + var serviceSettings = LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.URL, + CORRECT_URL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat(serviceSettings, is(new LlamaChatCompletionServiceSettings(MODEL_ID, CORRECT_URL, new RateLimitSettings(RATE_LIMIT)))); + } + + public void testFromMap_MissingModelId_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + CORRECT_URL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];") + ); + } + + public void testFromMap_MissingUrl_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [url];") + ); + } + + public void testFromMap_MissingRateLimit_Success() { + var 
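+        // rate_limit is optional at parse time; null here means the service default applies
+        // (serialized as requests_per_minute: 3000 in the default-rate-limit XContent test below).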
serviceSettings = LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL)), + ConfigurationParseContext.PERSISTENT + ); + + assertThat(serviceSettings, is(new LlamaChatCompletionServiceSettings(MODEL_ID, CORRECT_URL, null))); + } + + public void testToXContent_WritesAllValues() throws IOException { + var serviceSettings = LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.URL, + CORRECT_URL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + var expected = XContentHelper.stripWhitespace(""" + { + "model_id": "some model", + "url": "https://www.elastic.co", + "rate_limit": { + "requests_per_minute": 2 + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException { + var serviceSettings = LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL)), + ConfigurationParseContext.PERSISTENT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + var expected = XContentHelper.stripWhitespace(""" + { + "model_id": "some model", + "url": "https://www.elastic.co", + "rate_limit": { + "requests_per_minute": 3000 + } + } + """); + assertThat(xContentResult, is(expected)); + } + + @Override + protected Writeable.Reader instanceReader() { + return LlamaChatCompletionServiceSettings::new; + } + + @Override + protected LlamaChatCompletionServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected LlamaChatCompletionServiceSettings mutateInstance(LlamaChatCompletionServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, LlamaChatCompletionServiceSettingsTests::createRandom); + } + + @Override + protected LlamaChatCompletionServiceSettings mutateInstanceForVersion( + LlamaChatCompletionServiceSettings instance, + TransportVersion version + ) { + return instance; + } + + private static LlamaChatCompletionServiceSettings createRandom() { + var modelId = randomAlphaOfLength(8); + var url = randomAlphaOfLength(15); + return new LlamaChatCompletionServiceSettings(modelId, ServiceUtils.createUri(url), RateLimitSettingsTests.createRandom()); + } + + public static Map getServiceSettingsMap(String model, String url) { + var map = new HashMap(); + + map.put(ServiceFields.MODEL_ID, model); + map.put(ServiceFields.URL, url); + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModelTests.java new file mode 100644 index 0000000000000..4e75cab196a6d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModelTests.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.embeddings; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; + +public class LlamaEmbeddingsModelTests extends ESTestCase { + public static LlamaEmbeddingsModel createEmbeddingsModel(String modelId, String url, String apiKey) { + return new LlamaEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "llama", + new LlamaEmbeddingsServiceSettings(modelId, url, null, null, null, null), + null, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static LlamaEmbeddingsModel createEmbeddingsModelWithChunkingSettings(String modelId, String url, String apiKey) { + return new LlamaEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "llama", + new LlamaEmbeddingsServiceSettings(modelId, url, null, null, null, null), + createRandomChunkingSettings(), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static LlamaEmbeddingsModel createEmbeddingsModelNoAuth(String modelId, String url) { + return new LlamaEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "llama", + new LlamaEmbeddingsServiceSettings(modelId, url, null, null, null, null), + null, + EmptySecretSettings.INSTANCE + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..5fd3ce704540c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettingsTests.java @@ -0,0 +1,479 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.ByteArrayStreamInput; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; +import org.hamcrest.CoreMatchers; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class LlamaEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase { + private static final String MODEL_ID = "some model"; + private static final String CORRECT_URL = "https://www.elastic.co"; + private static final int DIMENSIONS = 384; + private static final SimilarityMeasure SIMILARITY_MEASURE = SimilarityMeasure.DOT_PRODUCT; + private static final int MAX_INPUT_TOKENS = 128; + private static final int RATE_LIMIT = 2; + + public void testFromMap_AllFields_Success() { + var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT) + ) + ) + ); + } + + public void testFromMap_NoModelId_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + null, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];") + ); + } + + public void testFromMap_NoUrl_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + null, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [url];") + ); + } + + public void testFromMap_EmptyUrl_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> 
LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + "", + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value empty string. [url] must be a non-empty string;") + ); + } + + public void testFromMap_InvalidUrl_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + "^^^", + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString( + "Validation Failed: 1: [service_settings] Invalid url [^^^] received for field [url]. " + + "Error: unable to parse url [^^^]. Reason: Illegal character in path;" + ) + ); + } + + public void testFromMap_NoSimilarity_Success() { + var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + null, + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + null, + MAX_INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT) + ) + ) + ); + } + + public void testFromMap_InvalidSimilarity_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + "by_size", + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString( + "Validation Failed: 1: [service_settings] Invalid value [by_size] received. " + + "[similarity] must be one of [cosine, dot_product, l2_norm];" + ) + ); + } + + public void testFromMap_NoDimensions_Success() { + var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + null, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + null, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT) + ) + ) + ); + } + + public void testFromMap_ZeroDimensions_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + 0, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. 
[dimensions] must be a positive integer;") + ); + } + + public void testFromMap_NegativeDimensions_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + -10, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [-10]. [dimensions] must be a positive integer;") + ); + } + + public void testFromMap_NoInputTokens_Success() { + var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + null, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + null, + new RateLimitSettings(RATE_LIMIT) + ) + ) + ); + } + + public void testFromMap_ZeroInputTokens_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + 0, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [max_input_tokens] must be a positive integer;") + ); + } + + public void testFromMap_NegativeInputTokens_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + -10, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [-10]. 
[max_input_tokens] must be a positive integer;") + ); + } + + public void testFromMap_NoRateLimit_Success() { + var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap(MODEL_ID, CORRECT_URL, SIMILARITY_MEASURE.toString(), DIMENSIONS, MAX_INPUT_TOKENS, null), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(3000) + ) + ) + ); + } + + public void testToXContent_WritesAllValues() throws IOException { + var entity = new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(3) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(XContentHelper.stripWhitespace(""" + { + "model_id": "some model", + "url": "https://www.elastic.co", + "dimensions": 384, + "similarity": "dot_product", + "max_input_tokens": 128, + "rate_limit": { + "requests_per_minute": 3 + } + } + """))); + } + + public void testStreamInputAndOutput_WritesValuesCorrectly() throws IOException { + var outputBuffer = new BytesStreamOutput(); + var settings = new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(3) + ); + settings.writeTo(outputBuffer); + + var outputBufferRef = outputBuffer.bytes(); + var inputBuffer = new ByteArrayStreamInput(outputBufferRef.array()); + + var settingsFromBuffer = new LlamaEmbeddingsServiceSettings(inputBuffer); + + assertEquals(settings, settingsFromBuffer); + } + + @Override + protected Writeable.Reader<LlamaEmbeddingsServiceSettings> instanceReader() { + return LlamaEmbeddingsServiceSettings::new; + } + + @Override + protected LlamaEmbeddingsServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected LlamaEmbeddingsServiceSettings mutateInstance(LlamaEmbeddingsServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, LlamaEmbeddingsServiceSettingsTests::createRandom); + } + + private static LlamaEmbeddingsServiceSettings createRandom() { + var modelId = randomAlphaOfLength(8); + var url = randomAlphaOfLength(15); + var similarityMeasure = randomFrom(SimilarityMeasure.values()); + var dimensions = randomIntBetween(32, 256); + var maxInputTokens = randomIntBetween(128, 256); + return new LlamaEmbeddingsServiceSettings( + modelId, + url, + dimensions, + similarityMeasure, + maxInputTokens, + RateLimitSettingsTests.createRandom() + ); + } + + public static HashMap<String, Object> buildServiceSettingsMap( + @Nullable String modelId, + @Nullable String url, + @Nullable String similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + @Nullable HashMap<String, Object> rateLimitSettings + ) { + HashMap<String, Object> result = new HashMap<>(); + if (modelId != null) { + result.put(ServiceFields.MODEL_ID, modelId); + } + if (url != null) { + result.put(ServiceFields.URL, url); + } + if (similarity != null) { + result.put(ServiceFields.SIMILARITY, similarity); + } + if (dimensions != null) { + result.put(ServiceFields.DIMENSIONS, dimensions); + } + if (maxInputTokens != null) { + result.put(ServiceFields.MAX_INPUT_TOKENS, maxInputTokens); + } + if (rateLimitSettings != null) { + result.put(RateLimitSettings.FIELD_NAME, rateLimitSettings); + } + return
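// For orientation, a minimal sketch of the map this helper returns on the happy path, assuming the ServiceFields
// constants resolve to their usual snake_case keys (illustrative only, mirroring the test constants used above,
// with the rate limit map passed through as-is):
//
//   { "model_id": "some model", "url": "https://www.elastic.co", "similarity": "dot_product",
//     "dimensions": 384, "max_input_tokens": 128, "rate_limit": { "requests_per_minute": ... } }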
result; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..dd8b3d7dfa38c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntityTests.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.request.completion; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests; + +import java.io.IOException; +import java.util.ArrayList; + +public class LlamaChatCompletionRequestEntityTests extends ESTestCase { + private static final String ROLE = "user"; + + public void testModelUserFieldsSerialization() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE, + null, + null + ); + var messageList = new ArrayList<UnifiedCompletionRequest.Message>(); + messageList.add(message); + + var unifiedRequest = UnifiedCompletionRequest.of(messageList); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + LlamaChatCompletionModel model = LlamaChatCompletionModelTests.createChatCompletionModel("model", "url", "api-key"); + + LlamaChatCompletionRequestEntity entity = new LlamaChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + String expectedJson = """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestTests.java new file mode 100644 index 0000000000000..6f0701a810fb1 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestTests.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements.
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.request.completion; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class LlamaChatCompletionRequestTests extends ESTestCase { + + public void testCreateRequest_WithStreaming() throws IOException { + String input = randomAlphaOfLength(15); + var request = createRequest("model", "url", "secret", input, true); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(request.getURI().toString(), is("url")); + assertThat(requestMap.get("stream"), is(true)); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true))); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input)))); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + } + + public void testCreateRequest_NoStreaming_NoAuthorization() throws IOException { + String input = randomAlphaOfLength(15); + var request = createRequestWithNoAuth("model", "url", input, false); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(request.getURI().toString(), is("url")); + assertThat(requestMap.get("stream"), is(false)); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertNull(requestMap.get("stream_options")); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input)))); + assertNull(httpPost.getFirstHeader("Authorization")); + } + + public void testTruncate_DoesNotReduceInputTextSize() { + String input = randomAlphaOfLength(5); + var request = createRequest("model", "url", "secret", input, true); + assertThat(request.truncate(), is(request)); + } + + public void testTruncationInfo_ReturnsNull() { + var request = createRequest("model", "url", "secret", randomAlphaOfLength(5), true); + assertNull(request.getTruncationInfo()); + } + + public static LlamaChatCompletionRequest createRequest(String modelId, String url, String apiKey, String input, boolean stream) { + var chatCompletionModel = LlamaChatCompletionModelTests.createChatCompletionModel(modelId, url, apiKey); + return new LlamaChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel); + } + + public static LlamaChatCompletionRequest createRequestWithNoAuth(String modelId, String url, String input, boolean stream) { + var chatCompletionModel = 
LlamaChatCompletionModelTests.createChatCompletionModelNoAuth(modelId, url); + return new LlamaChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..a055a0870e30d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntityTests.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.request.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; + +public class LlamaEmbeddingsRequestEntityTests extends ESTestCase { + + public void testXContent_Success() throws IOException { + var entity = new LlamaEmbeddingsRequestEntity("llama-embed", List.of("ABDC")); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(XContentHelper.stripWhitespace(""" + { + "model_id": "llama-embed", + "contents": ["ABDC"] + } + """))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestTests.java new file mode 100644 index 0000000000000..ab24fa9a0bc56 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestTests.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.request.embeddings; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModelTests; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class LlamaEmbeddingsRequestTests extends ESTestCase { + + public void testCreateRequest_WithAuth_Success() throws IOException { + var request = createRequest(); + var httpRequest = request.createHttpRequest(); + var httpPost = validateRequestUrlAndContentType(httpRequest); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get("contents"), is(List.of("ABCD"))); + assertThat(requestMap.get("model_id"), is("llama-embed")); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer apikey")); + } + + public void testCreateRequest_NoAuth_Success() throws IOException { + var request = createRequestNoAuth(); + var httpRequest = request.createHttpRequest(); + var httpPost = validateRequestUrlAndContentType(httpRequest); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get("contents"), is(List.of("ABCD"))); + assertThat(requestMap.get("model_id"), is("llama-embed")); + assertNull(httpPost.getFirstHeader("Authorization")); + } + + public void testTruncate_ReducesInputTextSizeByHalf() throws IOException { + var request = createRequest(); + var truncatedRequest = request.truncate(); + + var httpRequest = truncatedRequest.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get("contents"), is(List.of("AB"))); + assertThat(requestMap.get("model_id"), is("llama-embed")); + } + + public void testIsTruncated_ReturnsTrue() { + var request = createRequest(); + assertFalse(request.getTruncationInfo()[0]); + + var truncatedRequest = request.truncate(); + assertTrue(truncatedRequest.getTruncationInfo()[0]); + } + + private HttpPost validateRequestUrlAndContentType(HttpRequest request) { + assertThat(request.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) request.httpRequestBase(); + assertThat(httpPost.getURI().toString(), is("url")); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters())); + return httpPost; + } + + private static LlamaEmbeddingsRequest createRequest() { + var embeddingsModel = LlamaEmbeddingsModelTests.createEmbeddingsModel("llama-embed", "url", "apikey"); + return new LlamaEmbeddingsRequest( + TruncatorTests.createTruncator(), + new Truncator.TruncationResult(List.of("ABCD"), new boolean[] { false }), + embeddingsModel + ); + } + + 
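// A rough sketch of the truncation behaviour exercised above, assuming the shared Truncator from
// TruncatorTests keeps roughly the first half of each input (illustrative only, not asserted here):
//
//   createRequest()            -> contents ["ABCD"], truncationInfo [false]
//   createRequest().truncate() -> contents ["AB"],   truncationInfo [true]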
private static LlamaEmbeddingsRequest createRequestNoAuth() { + var embeddingsModel = LlamaEmbeddingsModelTests.createEmbeddingsModelNoAuth("llama-embed", "url"); + return new LlamaEmbeddingsRequest( + TruncatorTests.createTruncator(), + new Truncator.TruncationResult(List.of("ABCD"), new boolean[] { false }), + embeddingsModel + ); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponseTests.java new file mode 100644 index 0000000000000..aa3c6f6c20b6e --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponseTests.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.nio.charset.StandardCharsets; + +import static org.mockito.Mockito.mock; + +public class LlamaErrorResponseTests extends ESTestCase { + + public static final String ERROR_RESPONSE_JSON = """ + { + "error": "A valid user token is required" + } + """; + + public void testFromResponse() { + var errorResponse = LlamaErrorResponse.fromResponse( + new HttpResult(mock(HttpResponse.class), ERROR_RESPONSE_JSON.getBytes(StandardCharsets.UTF_8)) + ); + assertNotNull(errorResponse); + assertEquals(ERROR_RESPONSE_JSON, errorResponse.getErrorMessage()); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 4ba9b8aa24394..8e170b25393e6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -249,7 +249,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION); - try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -308,7 +308,7 @@ public void testUnifiedCompletionInfer() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = MistralChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -353,7 
+353,7 @@ public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = MistralChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); var latch = new CountDownLatch(1); service.unifiedCompletionInfer( @@ -421,7 +421,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = MistralChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -459,7 +459,7 @@ public void testInfer_StreamRequest_ErrorResponse() { } public void testSupportsStreaming() throws IOException { - try (var service = new MistralService(mock(), createWithEmptySettings(mock()))) { + try (var service = new MistralService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -942,7 +942,7 @@ public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenC public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = new Model(ModelConfigurationsTests.createRandomInstance()); assertThrows( @@ -962,7 +962,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = MistralEmbeddingModelTests.createModel( randomAlphaOfLength(10), @@ -990,7 +990,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws I var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -1028,7 +1028,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws 
IOException { var model = MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", null, null, null, null); - try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( @@ -1086,7 +1086,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1173,7 +1173,7 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1276,7 +1276,7 @@ public void testGetConfiguration() throws Exception { // ---------------------------------------------------------------- private MistralService createService() { - return new MistralService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new MistralService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java index 6f8b40fd7f19c..9aa076e224efe 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -37,7 +36,6 @@ public static MistralEmbeddingsModel createModel( TaskType.TEXT_EMBEDDING, "mistral", new MistralEmbeddingsServiceSettings(model, dimensions, maxTokens, similarity, rateLimitSettings), - EmptyTaskSettings.INSTANCE, chunkingSettings, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); @@ -57,7 +55,6 @@ public static MistralEmbeddingsModel createModel( TaskType.TEXT_EMBEDDING, "mistral", new MistralEmbeddingsServiceSettings(model, dimensions, maxTokens, similarity, rateLimitSettings), - EmptyTaskSettings.INSTANCE, null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java index 4a70861932d28..2c8fb4fd48698 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java @@ -49,7 +49,7 @@ public void testTruncate_DoesNotReduceInputTextSize() throws IOException { var requestMap = entityAsMap(httpPost.getEntity().getContent()); assertThat(requestMap, aMapWithSize(4)); - // We do not truncate for Hugging Face chat completions + // We do not truncate for Mistral chat completions assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input)))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("n"), is(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index c19eb664e88ac..83455861198d3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -847,7 +847,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -885,7 +885,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user"); - try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -924,7 +924,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException { var mockModel = getInvalidModel("model_id", "service_name", TaskType.SPARSE_EMBEDDING); - try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -965,7 +965,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION); - try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); 
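// These tests now construct each service with a ClusterService as a third argument; a minimal sketch of the
// pattern, assuming Utils.mockClusterServiceEmpty() supplies a mock backed by empty cluster settings
// (as its name suggests):
//
//   try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
//       // exercise the service under test...
//   }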
service.infer( mockModel, @@ -1003,7 +1003,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws public void testInfer_SendsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1099,7 +1099,7 @@ public void testUnifiedCompletionInfer() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -1132,7 +1132,7 @@ public void testUnifiedCompletionError() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); var latch = new CountDownLatch(1); service.unifiedCompletionInfer( @@ -1189,7 +1189,7 @@ public void testMidStreamUnifiedCompletionError() throws Exception { private void testStreamError(String expectedResponse) throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -1267,7 +1267,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = OpenAiChatCompletionModelTests.createCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1344,7 +1344,7 @@ public void testInfer_StreamRequestRetry() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()))) { + try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { 
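// A sketch of the streaming contract asserted below: streaming support is declared per task type, so the
// wildcard TaskType.ANY does not report as streamable even though both completion task types do:
//
//   supportedStreamingTasks() == EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)
//   canStream(TaskType.ANY)   == false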
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1400,7 +1400,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1485,7 +1485,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { // response with 2 embeddings String responseJson = """ @@ -1656,6 +1656,6 @@ public void testGetConfiguration() throws Exception { } private OpenAiService createOpenAiService() { - return new OpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new OpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java index d7d9473f18084..bf883a6345398 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java @@ -47,6 +47,7 @@ import static org.elasticsearch.action.support.ActionTestUtils.assertNoSuccessListener; import static org.elasticsearch.core.TimeValue.THIRTY_SECONDS; import static org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequestTests.randomUnifiedCompletionRequest; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -84,7 +85,7 @@ public void init() { ThreadPool threadPool = mock(); when(threadPool.executor(anyString())).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); - sageMakerService = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of); + sageMakerService = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of, mockClusterServiceEmpty()); } public void testSupportedTaskTypes() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 8602621e9eb78..72a3b530ab647 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -718,7 +718,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotVoyageAIModel() throws IOExceptio var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -763,7 +763,7 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOExcept "voyage-3-large" ); - try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( @@ -806,7 +806,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = VoyageAIEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -831,7 +831,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_Embedding_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -873,7 +873,7 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException { public void testInfer_Rerank_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -907,7 +907,7 @@ public void testInfer_Rerank_UnauthorisedResponse() throws IOException { public void testInfer_Embedding_Get_Response_Ingest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -989,7 +989,7 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { public void testInfer_Embedding_Get_Response_Search() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new 
VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1071,7 +1071,7 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { public void testInfer_Embedding_Get_Response_NullInputType() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1163,7 +1163,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, false, false); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1251,7 +1251,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, false, false); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1345,7 +1345,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, null, null); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1423,7 +1423,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept """; var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, true, true); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1490,7 +1490,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() throws IOException { var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1599,7 +1599,7 @@ public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOExcept private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { // Batching will call the service with 2 input String responseJson = """ @@ -1745,7 +1745,7 @@ public void testGetConfiguration() throws Exception { } public void testDoesNotSupportsStreaming() throws IOException { - try (var service = new VoyageAIService(mock(), createWithEmptySettings(mock()))) { + try (var service = new VoyageAIService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertFalse(service.canStream(TaskType.COMPLETION)); assertFalse(service.canStream(TaskType.ANY)); } @@ -1786,7 +1786,7 @@ private Map getRequestConfigMap(Map serviceSetti } private VoyageAIService createVoyageAIService() { - return new VoyageAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new VoyageAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter.yml index 021dfe320d78e..60dea800ca624 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter.yml @@ -35,6 +35,23 @@ setup: } } + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id-compatible-with-bbq + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 64, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + - do: indices.create: index: test-sparse-index @@ -70,7 +87,7 @@ setup: id: doc_1 body: title: "Elasticsearch" - body: ["ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!"] + body: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" 
] refresh: true - do: @@ -89,14 +106,14 @@ setup: index: test-dense-index body: query: - match_all: {} + match_all: { } highlight: fields: - another_body: {} + another_body: { } - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } - - not_exists: hits.hits.0.highlight.another_body + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - not_exists: hits.hits.0.highlight.another_body --- "Highlighting using a sparse embedding model": @@ -114,10 +131,10 @@ setup: type: "semantic" number_of_fragments: 1 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 1 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - do: search: @@ -133,11 +150,11 @@ setup: type: "semantic" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } - do: search: @@ -154,10 +171,10 @@ setup: order: "score" number_of_fragments: 1 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 1 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - do: search: @@ -196,10 +213,10 @@ setup: type: "semantic" number_of_fragments: 1 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 1 } - - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } - do: search: @@ -215,11 +232,11 @@ setup: type: "semantic" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
} - - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } - do: search: @@ -236,10 +253,10 @@ setup: order: "score" number_of_fragments: 1 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 1 } - - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } - do: search: @@ -256,17 +273,17 @@ setup: order: "score" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } - - match: { hits.hits.0.highlight.body.1: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.1: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } --- "Default highlighter for fields": - requires: - cluster_features: "semantic_text.highlighter.default" - reason: semantic text field defaults to the semantic highlighter + cluster_features: "semantic_text.highlighter.default" + reason: semantic text field defaults to the semantic highlighter - do: search: @@ -281,11 +298,11 @@ setup: order: "score" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } - - match: { hits.hits.0.highlight.body.1: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.1: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
} --- "semantic highlighter ignores non-inference fields": @@ -306,8 +323,8 @@ setup: type: semantic number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - not_exists: hits.hits.0.highlight.title --- @@ -333,7 +350,7 @@ setup: index: test-multi-chunk-index id: doc_1 body: - semantic_text_field: ["some test data", " ", "now with chunks"] + semantic_text_field: [ "some test data", " ", "now with chunks" ] refresh: true - do: @@ -367,25 +384,25 @@ setup: index: test-sparse-index body: query: - match_all: {} + match_all: { } highlight: fields: body: type: "semantic" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } - do: search: index: test-dense-index body: query: - match_all: {} + match_all: { } highlight: fields: body: @@ -432,18 +449,18 @@ setup: index: test-index-sparse body: query: - match_all: {} + match_all: { } highlight: fields: semantic_text_field: type: "semantic" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.semantic_text_field: 2 } - - match: { hits.hits.0.highlight.semantic_text_field.0: "some test data" } - - match: { hits.hits.0.highlight.semantic_text_field.1: "now with chunks" } + - match: { hits.hits.0.highlight.semantic_text_field.0: "some test data" } + - match: { hits.hits.0.highlight.semantic_text_field.1: "now with chunks" } - do: indices.create: @@ -473,7 +490,7 @@ setup: index: test-index-dense body: query: - match_all: {} + match_all: { } highlight: fields: semantic_text_field: @@ -485,3 +502,172 @@ setup: - length: { hits.hits.0.highlight.semantic_text_field: 2 } - match: { hits.hits.0.highlight.semantic_text_field.0: "some test data" } - match: { hits.hits.0.highlight.semantic_text_field.1: "now with chunks" } + +--- +"Highlighting with flat quantization index options": + - requires: + cluster_features: "semantic_text.highlighter.flat_index_options" + reason: semantic highlighter fix for flat index options + + - do: + indices.create: + index: test-dense-index-flat + body: + settings: + index.mapping.semantic_text.use_legacy_format: false + mappings: + properties: + flat_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: flat + int4_flat_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: int4_flat + int8_flat_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: int8_flat + bbq_flat_field: + type: semantic_text + inference_id: dense-inference-id-compatible-with-bbq + index_options: + dense_vector: + type: bbq_flat + + + - do: + index: + 
index: test-dense-index-flat + id: doc_1 + body: + flat_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + int4_flat_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + int8_flat_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + bbq_flat_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + refresh: true + + - do: + search: + index: test-dense-index-flat + body: + query: + match_all: { } + highlight: + fields: + flat_field: + type: "semantic" + number_of_fragments: 1 + int4_flat_field: + type: "semantic" + number_of_fragments: 1 + int8_flat_field: + type: "semantic" + number_of_fragments: 1 + bbq_flat_field: + type: "semantic" + number_of_fragments: 1 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight: 4 } + - length: { hits.hits.0.highlight.flat_field: 1 } + - match: { hits.hits.0.highlight.flat_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.int4_flat_field: 1 } + - match: { hits.hits.0.highlight.int4_flat_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.int8_flat_field: 1 } + - match: { hits.hits.0.highlight.int8_flat_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.bbq_flat_field: 1 } + - match: { hits.hits.0.highlight.bbq_flat_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + +--- +"Highlighting with HNSW quantization index options": + - requires: + cluster_features: "semantic_text.highlighter.flat_index_options" + reason: semantic highlighter fix for flat index options + + - do: + indices.create: + index: test-dense-index-hnsw + body: + settings: + index.mapping.semantic_text.use_legacy_format: false + mappings: + properties: + hnsw_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: hnsw + int4_hnsw_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: int4_hnsw + int8_hnsw_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: int8_hnsw + bbq_hnsw_field: + type: semantic_text + inference_id: dense-inference-id-compatible-with-bbq + index_options: + dense_vector: + type: bbq_hnsw + + + - do: + index: + index: test-dense-index-hnsw + id: doc_1 + body: + hnsw_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" 
] + int4_hnsw_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + int8_hnsw_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + bbq_hnsw_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + refresh: true + + - do: + search: + index: test-dense-index-hnsw + body: + query: + match_all: { } + highlight: + fields: + hnsw_field: + type: "semantic" + number_of_fragments: 1 + int4_hnsw_field: + type: "semantic" + number_of_fragments: 1 + int8_hnsw_field: + type: "semantic" + number_of_fragments: 1 + bbq_hnsw_field: + type: "semantic" + number_of_fragments: 1 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight: 4 } + - length: { hits.hits.0.highlight.hnsw_field: 1 } + - match: { hits.hits.0.highlight.hnsw_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.int4_hnsw_field: 1 } + - match: { hits.hits.0.highlight.int4_hnsw_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.int8_hnsw_field: 1 } + - match: { hits.hits.0.highlight.int8_hnsw_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.bbq_hnsw_field: 1 } + - match: { hits.hits.0.highlight.bbq_hnsw_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter_bwc.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter_bwc.yml index 1e874d60a016c..4675977842973 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter_bwc.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter_bwc.yml @@ -35,6 +35,23 @@ setup: } } + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id-compatible-with-bbq + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 64, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + - do: indices.create: index: test-sparse-index @@ -65,12 +82,12 @@ setup: --- "Highlighting empty field": - do: - index: - index: test-dense-index - id: doc_1 - body: - body: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" 
] - refresh: true + index: + index: test-dense-index + id: doc_1 + body: + body: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + refresh: true - match: { result: created } @@ -79,14 +96,14 @@ setup: index: test-dense-index body: query: - match_all: {} + match_all: { } highlight: fields: - another_body: {} + another_body: { } - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } - - not_exists: hits.hits.0.highlight.another_body + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - not_exists: hits.hits.0.highlight.another_body --- "Highlighting using a sparse embedding model": @@ -95,7 +112,7 @@ setup: index: test-sparse-index id: doc_1 body: - body: ["ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!"] + body: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] refresh: true - match: { result: created } @@ -114,10 +131,10 @@ setup: type: "semantic" number_of_fragments: 1 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 1 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - do: search: @@ -133,11 +150,11 @@ setup: type: "semantic" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } - do: search: @@ -154,10 +171,10 @@ setup: order: "score" number_of_fragments: 1 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 1 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
} - do: search: @@ -187,7 +204,7 @@ setup: index: test-dense-index id: doc_1 body: - body: ["ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!"] + body: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] refresh: true - match: { result: created } @@ -206,10 +223,10 @@ setup: type: "semantic" number_of_fragments: 1 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 1 } - - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } - do: search: @@ -225,11 +242,11 @@ setup: type: "semantic" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } - do: search: @@ -246,10 +263,10 @@ setup: order: "score" number_of_fragments: 1 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 1 } - - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } - do: search: @@ -266,11 +283,11 @@ setup: order: "score" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } - - match: { hits.hits.0.highlight.body.1: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.1: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
} --- "Highlighting and multi chunks with empty input": @@ -295,7 +312,7 @@ setup: index: test-multi-chunk-index id: doc_1 body: - semantic_text_field: ["some test data", " ", "now with chunks"] + semantic_text_field: [ "some test data", " ", "now with chunks" ] refresh: true - do: @@ -337,18 +354,18 @@ setup: index: test-sparse-index body: query: - match_all: {} + match_all: { } highlight: fields: body: type: "semantic" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.body: 2 } - - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } - do: index: @@ -363,7 +380,7 @@ setup: index: test-dense-index body: query: - match_all: {} + match_all: { } highlight: fields: body: @@ -410,18 +427,18 @@ setup: index: test-index-sparse body: query: - match_all: {} + match_all: { } highlight: fields: semantic_text_field: type: "semantic" number_of_fragments: 2 - - match: { hits.total.value: 1 } - - match: { hits.hits.0._id: "doc_1" } + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.semantic_text_field: 2 } - - match: { hits.hits.0.highlight.semantic_text_field.0: "some test data" } - - match: { hits.hits.0.highlight.semantic_text_field.1: "now with chunks" } + - match: { hits.hits.0.highlight.semantic_text_field.0: "some test data" } + - match: { hits.hits.0.highlight.semantic_text_field.1: "now with chunks" } - do: indices.create: @@ -451,7 +468,7 @@ setup: index: test-index-dense body: query: - match_all: {} + match_all: { } highlight: fields: semantic_text_field: @@ -464,3 +481,173 @@ setup: - match: { hits.hits.0.highlight.semantic_text_field.0: "some test data" } - match: { hits.hits.0.highlight.semantic_text_field.1: "now with chunks" } +--- +"Highlighting with flat quantization index options": + - requires: + cluster_features: "semantic_text.highlighter.flat_index_options" + reason: semantic highlighter fix for flat index options + + - do: + indices.create: + index: test-dense-index-flat + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + flat_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: flat + int4_flat_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: int4_flat + int8_flat_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: int8_flat + bbq_flat_field: + type: semantic_text + inference_id: dense-inference-id-compatible-with-bbq + index_options: + dense_vector: + type: bbq_flat + + + - do: + index: + index: test-dense-index-flat + id: doc_1 + body: + flat_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" 
] + int4_flat_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + int8_flat_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + bbq_flat_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + refresh: true + + - do: + search: + index: test-dense-index-flat + body: + query: + match_all: { } + highlight: + fields: + flat_field: + type: "semantic" + number_of_fragments: 1 + int4_flat_field: + type: "semantic" + number_of_fragments: 1 + int8_flat_field: + type: "semantic" + number_of_fragments: 1 + bbq_flat_field: + type: "semantic" + number_of_fragments: 1 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight: 4 } + - length: { hits.hits.0.highlight.flat_field: 1 } + - match: { hits.hits.0.highlight.flat_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.int4_flat_field: 1 } + - match: { hits.hits.0.highlight.int4_flat_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.int8_flat_field: 1 } + - match: { hits.hits.0.highlight.int8_flat_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.bbq_flat_field: 1 } + - match: { hits.hits.0.highlight.bbq_flat_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + +--- +"Highlighting with HNSW quantization index options": + - requires: + cluster_features: "semantic_text.highlighter.flat_index_options" + reason: semantic highlighter fix for flat index options + + - do: + indices.create: + index: test-dense-index-hnsw + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + hnsw_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: hnsw + int4_hnsw_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: int4_hnsw + int8_hnsw_field: + type: semantic_text + inference_id: dense-inference-id + index_options: + dense_vector: + type: int8_hnsw + bbq_hnsw_field: + type: semantic_text + inference_id: dense-inference-id-compatible-with-bbq + index_options: + dense_vector: + type: bbq_hnsw + + + - do: + index: + index: test-dense-index-hnsw + id: doc_1 + body: + hnsw_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + int4_hnsw_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" 
] + int8_hnsw_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + bbq_hnsw_field: [ "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!" ] + refresh: true + + - do: + search: + index: test-dense-index-hnsw + body: + query: + match_all: { } + highlight: + fields: + hnsw_field: + type: "semantic" + number_of_fragments: 1 + int4_hnsw_field: + type: "semantic" + number_of_fragments: 1 + int8_hnsw_field: + type: "semantic" + number_of_fragments: 1 + bbq_hnsw_field: + type: "semantic" + number_of_fragments: 1 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight: 4 } + - length: { hits.hits.0.highlight.hnsw_field: 1 } + - match: { hits.hits.0.highlight.hnsw_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.int4_hnsw_field: 1 } + - match: { hits.hits.0.highlight.int4_hnsw_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.int8_hnsw_field: 1 } + - match: { hits.hits.0.highlight.int8_hnsw_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.bbq_hnsw_field: 1 } + - match: { hits.hits.0.highlight.bbq_hnsw_field.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
} + + + diff --git a/x-pack/plugin/mapper-aggregate-metric/src/main/java/org/elasticsearch/xpack/aggregatemetric/mapper/AggregateMetricDoubleFieldMapper.java b/x-pack/plugin/mapper-aggregate-metric/src/main/java/org/elasticsearch/xpack/aggregatemetric/mapper/AggregateMetricDoubleFieldMapper.java index 5eafb858eacbe..3658313642700 100644 --- a/x-pack/plugin/mapper-aggregate-metric/src/main/java/org/elasticsearch/xpack/aggregatemetric/mapper/AggregateMetricDoubleFieldMapper.java +++ b/x-pack/plugin/mapper-aggregate-metric/src/main/java/org/elasticsearch/xpack/aggregatemetric/mapper/AggregateMetricDoubleFieldMapper.java @@ -571,20 +571,24 @@ public String toString() { } @Override - public Block read(BlockFactory factory, Docs docs) throws IOException { - try (var builder = factory.aggregateMetricDoubleBuilder(docs.count())) { - copyDoubleValuesToBuilder(docs, builder.min(), minValues); - copyDoubleValuesToBuilder(docs, builder.max(), maxValues); - copyDoubleValuesToBuilder(docs, builder.sum(), sumValues); - copyIntValuesToBuilder(docs, builder.count(), valueCountValues); + public Block read(BlockFactory factory, Docs docs, int offset) throws IOException { + try (var builder = factory.aggregateMetricDoubleBuilder(docs.count() - offset)) { + copyDoubleValuesToBuilder(docs, offset, builder.min(), minValues); + copyDoubleValuesToBuilder(docs, offset, builder.max(), maxValues); + copyDoubleValuesToBuilder(docs, offset, builder.sum(), sumValues); + copyIntValuesToBuilder(docs, offset, builder.count(), valueCountValues); return builder.build(); } } - private void copyDoubleValuesToBuilder(Docs docs, BlockLoader.DoubleBuilder builder, NumericDocValues values) - throws IOException { + private void copyDoubleValuesToBuilder( + Docs docs, + int offset, + BlockLoader.DoubleBuilder builder, + NumericDocValues values + ) throws IOException { int lastDoc = -1; - for (int i = 0; i < docs.count(); i++) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < lastDoc) { throw new IllegalStateException("docs within same block must be in order"); @@ -600,10 +604,10 @@ private void copyDoubleValuesToBuilder(Docs docs, BlockLoader.DoubleBuilder buil } } - private void copyIntValuesToBuilder(Docs docs, BlockLoader.IntBuilder builder, NumericDocValues values) + private void copyIntValuesToBuilder(Docs docs, int offset, BlockLoader.IntBuilder builder, NumericDocValues values) throws IOException { int lastDoc = -1; - for (int i = 0; i < docs.count(); i++) { + for (int i = offset; i < docs.count(); i++) { int doc = docs.get(i); if (doc < lastDoc) { throw new IllegalStateException("docs within same block must be in order"); diff --git a/x-pack/plugin/mapper-constant-keyword/src/test/java/org/elasticsearch/xpack/constantkeyword/mapper/ConstantKeywordFieldMapperTests.java b/x-pack/plugin/mapper-constant-keyword/src/test/java/org/elasticsearch/xpack/constantkeyword/mapper/ConstantKeywordFieldMapperTests.java index c0c2db53b97e9..1a94ca1b8d40a 100644 --- a/x-pack/plugin/mapper-constant-keyword/src/test/java/org/elasticsearch/xpack/constantkeyword/mapper/ConstantKeywordFieldMapperTests.java +++ b/x-pack/plugin/mapper-constant-keyword/src/test/java/org/elasticsearch/xpack/constantkeyword/mapper/ConstantKeywordFieldMapperTests.java @@ -276,7 +276,7 @@ public FieldNamesFieldMapper.FieldNamesFieldType fieldNames() { iw.close(); try (DirectoryReader reader = DirectoryReader.open(directory)) { TestBlock block = (TestBlock) loader.columnAtATimeReader(reader.leaves().get(0)) - 
.read(TestBlock.factory(reader.numDocs()), new BlockLoader.Docs() { + .read(TestBlock.factory(), new BlockLoader.Docs() { @Override public int count() { return 1; @@ -286,7 +286,7 @@ public int count() { public int get(int i) { return 0; } - }); + }, 0); assertThat(block.get(0), nullValue()); } } diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/system_indices/task/SystemIndexMigrationExecutor.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/system_indices/task/SystemIndexMigrationExecutor.java index e15a1d36bdb9f..dad16d3cfa83b 100644 --- a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/system_indices/task/SystemIndexMigrationExecutor.java +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/system_indices/task/SystemIndexMigrationExecutor.java @@ -9,10 +9,12 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.IndexScopedSettings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.indices.SystemIndices; import org.elasticsearch.persistent.AllocatedPersistentTask; import org.elasticsearch.persistent.PersistentTaskParams; @@ -86,10 +88,11 @@ protected AllocatedPersistentTask createTask( } @Override - public PersistentTasksCustomMetadata.Assignment getAssignment( + protected PersistentTasksCustomMetadata.Assignment doGetAssignment( SystemIndexMigrationTaskParams params, Collection candidateNodes, - ClusterState clusterState + ClusterState clusterState, + @Nullable ProjectId projectId ) { // This should select from master-eligible nodes because we already require all master-eligible nodes to have all plugins installed. // However, due to a misunderstanding, this code as-written needs to run on the master node in particular. 
This is not a fundamental diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java index a15a733cac6c7..be905caeacba0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -22,6 +22,7 @@ import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.project.ProjectResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -30,6 +31,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.injection.guice.Inject; @@ -690,10 +692,11 @@ protected AllocatedPersistentTask createTask( } @Override - public PersistentTasksCustomMetadata.Assignment getAssignment( + protected PersistentTasksCustomMetadata.Assignment doGetAssignment( TaskParams params, Collection candidateNodes, - @SuppressWarnings("HiddenField") ClusterState clusterState + @SuppressWarnings("HiddenField") ClusterState clusterState, + @Nullable ProjectId projectId ) { boolean isMemoryTrackerRecentlyRefreshed = memoryTracker.isRecentlyRefreshed(); Optional optionalAssignment = getPotentialAssignment( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java index f45c92d3466c6..7a636e18017e1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java @@ -21,6 +21,7 @@ import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.project.ProjectResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -494,10 +495,11 @@ public StartDatafeedPersistentTasksExecutor( } @Override - public PersistentTasksCustomMetadata.Assignment getAssignment( + protected PersistentTasksCustomMetadata.Assignment doGetAssignment( StartDatafeedAction.DatafeedParams params, Collection candidateNodes, - ClusterState clusterState + ClusterState clusterState, + @Nullable ProjectId projectId ) { return new DatafeedNodeSelector( clusterState, @@ -510,7 +512,7 @@ public PersistentTasksCustomMetadata.Assignment getAssignment( } @Override - public void validate(StartDatafeedAction.DatafeedParams params, ClusterState clusterState) { + public void validate(StartDatafeedAction.DatafeedParams params, ClusterState clusterState, @Nullable ProjectId projectId) { new DatafeedNodeSelector( 
clusterState, resolver, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradeTaskExecutor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradeTaskExecutor.java index 42f722e330a19..00370dde3e089 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradeTaskExecutor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradeTaskExecutor.java @@ -15,9 +15,11 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.persistent.AllocatedPersistentTask; import org.elasticsearch.persistent.PersistentTaskState; @@ -88,10 +90,11 @@ public SnapshotUpgradeTaskExecutor( } @Override - public PersistentTasksCustomMetadata.Assignment getAssignment( + protected PersistentTasksCustomMetadata.Assignment doGetAssignment( SnapshotUpgradeTaskParams params, Collection candidateNodes, - ClusterState clusterState + ClusterState clusterState, + @Nullable ProjectId projectId ) { boolean isMemoryTrackerRecentlyRefreshed = memoryTracker.isRecentlyRefreshed(); Optional optionalAssignment = getPotentialAssignment( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutor.java index 0e517b63f6f60..5621da489da7d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutor.java @@ -17,9 +17,11 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.engine.DocumentMissingException; import org.elasticsearch.license.XPackLicenseState; @@ -121,7 +123,12 @@ public OpenJobPersistentTasksExecutor( } @Override - public Assignment getAssignment(OpenJobAction.JobParams params, Collection candidateNodes, ClusterState clusterState) { + protected Assignment doGetAssignment( + OpenJobAction.JobParams params, + Collection candidateNodes, + ClusterState clusterState, + @Nullable ProjectId projectId + ) { Job job = params.getJob(); // If the task parameters do not have a job field then the job // was first opened on a pre v6.6 node and has not been migrated @@ -210,13 +217,13 @@ static void validateJobAndId(String jobId, Job job) { } @Override - public void validate(OpenJobAction.JobParams params, ClusterState clusterState) { + public void validate(OpenJobAction.JobParams params, ClusterState clusterState, @Nullable ProjectId projectId) { final Job job = 
params.getJob(); final String jobId = params.getJobId(); validateJobAndId(jobId, job); // If we already know that we can't find an ml node because all ml nodes are running at capacity or // simply because there are no ml nodes in the cluster then we fail quickly here: - PersistentTasksCustomMetadata.Assignment assignment = getAssignment(params, clusterState.nodes().getAllNodes(), clusterState); + var assignment = getAssignment(params, clusterState.nodes().getAllNodes(), clusterState, projectId); if (assignment.equals(AWAITING_UPGRADE)) { throw makeCurrentlyBeingUpgradedException(logger, params.getJobId()); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java index 33fae40f80db6..550352954bfbc 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; @@ -62,7 +63,7 @@ public void testGetAssignment_UpgradeModeIsEnabled() { .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().isUpgradeMode(true).build())) .build(); - Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState); + Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState, ProjectId.DEFAULT); assertThat(assignment.getExecutorNode(), is(nullValue())); assertThat(assignment.getExplanation(), is(equalTo("persistent task cannot be assigned while upgrade mode is enabled."))); } @@ -75,7 +76,7 @@ public void testGetAssignment_NoNodes() { .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().build())) .build(); - Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState); + Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState, ProjectId.DEFAULT); assertThat(assignment.getExecutorNode(), is(nullValue())); assertThat(assignment.getExplanation(), is(emptyString())); } @@ -94,7 +95,7 @@ public void testGetAssignment_NoMlNodes() { ) .build(); - Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState); + Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState, ProjectId.DEFAULT); assertThat(assignment.getExecutorNode(), is(nullValue())); assertThat( assignment.getExplanation(), @@ -116,7 +117,7 @@ public void testGetAssignment_MlNodeIsNewerThanTheMlJobButTheAssignmentSuceeds() .nodes(DiscoveryNodes.builder().add(createNode(0, true, Version.V_7_10_0, MlConfigVersion.V_7_10_0))) .build(); - Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState); + Assignment assignment = executor.getAssignment(params, clusterState.nodes().getAllNodes(), clusterState, ProjectId.DEFAULT); assertThat(assignment.getExecutorNode(), 
is(equalTo("_node_id0"))); assertThat(assignment.getExplanation(), is(emptyString())); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java index d88e1235241d8..4b1ed557ef287 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.cluster.metadata.AliasMetadata; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.cluster.routing.IndexShardRoutingTable; import org.elasticsearch.cluster.routing.OperationRouting; @@ -173,7 +174,7 @@ public void testGetAssignment_GivenUnavailableIndicesWithLazyNode() { assertEquals( "Not opening [unavailable_index_with_lazy_node], " + "because not all primary shards are active for the following indices [.ml-state]", - executor.getAssignment(params, csBuilder.nodes().getAllNodes(), csBuilder.build()).getExplanation() + executor.getAssignment(params, csBuilder.nodes().getAllNodes(), csBuilder.build(), ProjectId.DEFAULT).getExplanation() ); } @@ -195,7 +196,8 @@ public void testGetAssignment_GivenLazyJobAndNoGlobalLazyNodes() { PersistentTasksCustomMetadata.Assignment assignment = executor.getAssignment( params, csBuilder.nodes().getAllNodes(), - csBuilder.build() + csBuilder.build(), + ProjectId.DEFAULT ); assertNotNull(assignment); assertNull(assignment.getExecutorNode()); @@ -216,7 +218,8 @@ public void testGetAssignment_GivenResetInProgress() { PersistentTasksCustomMetadata.Assignment assignment = executor.getAssignment( params, csBuilder.nodes().getAllNodes(), - csBuilder.build() + csBuilder.build(), + ProjectId.DEFAULT ); assertNotNull(assignment); assertNull(assignment.getExecutorNode()); diff --git a/x-pack/plugin/otel-data/src/main/resources/component-templates/otel@settings.yaml b/x-pack/plugin/otel-data/src/main/resources/component-templates/otel@settings.yaml new file mode 100644 index 0000000000000..0d3b1b25a7ea9 --- /dev/null +++ b/x-pack/plugin/otel-data/src/main/resources/component-templates/otel@settings.yaml @@ -0,0 +1,8 @@ +version: ${xpack.oteldata.template.version} +_meta: + description: Default settings for all OpenTelemetry data streams + managed: true +template: + data_stream_options: + failure_store: + enabled: true diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/logs-otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/logs-otel@template.yaml index 6772ec5bc65d4..929d26e1c30af 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/logs-otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/logs-otel@template.yaml @@ -11,6 +11,7 @@ composed_of: - logs@mappings - logs@settings - otel@mappings + - otel@settings - logs-otel@mappings - semconv-resource-to-ecs@mappings - logs@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-otel@template.yaml index f8489605ad1bf..a042fc77e6fa3 100644 --- 
a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-otel@template.yaml @@ -10,6 +10,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.10m@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.10m@template.yaml index f5033135120bc..60739559cc9eb 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.10m@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.10m@template.yaml @@ -11,6 +11,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.1m@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.1m@template.yaml index 9168062f30bfb..9464936f5e1e5 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.1m@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.1m@template.yaml @@ -10,6 +10,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.60m@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.60m@template.yaml index 47c2d7d014322..888a2145073fd 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.60m@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_destination.60m@template.yaml @@ -11,6 +11,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.10m.otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.10m.otel@template.yaml index c9438e8c27402..36be8cb78d851 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.10m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.10m.otel@template.yaml @@ -11,6 +11,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.1m.otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.1m.otel@template.yaml index b29caa3fe34a7..20d1e3ca65e88 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.1m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.1m.otel@template.yaml @@ -10,6 +10,7 @@ _meta: composed_of: - metrics@tsdb-settings - 
otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.60m.otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.60m.otel@template.yaml index 4cab3e41a1dfa..9bb62ae9edd3b 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.60m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_summary.60m.otel@template.yaml @@ -11,6 +11,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.10m.otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.10m.otel@template.yaml index 037f3546205d6..ff4780744e216 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.10m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.10m.otel@template.yaml @@ -11,6 +11,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.1m.otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.1m.otel@template.yaml index 303ac2c406fd0..b1037535754f3 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.1m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.1m.otel@template.yaml @@ -10,6 +10,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.60m.otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.60m.otel@template.yaml index ea42079ced4dd..15088a2198abc 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.60m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-service_transaction.60m.otel@template.yaml @@ -11,6 +11,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.10m.otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.10m.otel@template.yaml index 81e70cc3361fc..2f6f7e28ffc22 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.10m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.10m.otel@template.yaml @@ -11,6 +11,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git 
a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.1m.otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.1m.otel@template.yaml index c54b90bf8b683..5cc1828d3285b 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.1m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.1m.otel@template.yaml @@ -10,6 +10,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.60m.otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.60m.otel@template.yaml index 8afe8b87951c0..906e535e2c05b 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.60m.otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/metrics-transaction.60m.otel@template.yaml @@ -11,6 +11,7 @@ _meta: composed_of: - metrics@tsdb-settings - otel@mappings + - otel@settings - metrics-otel@mappings - semconv-resource-to-ecs@mappings - metrics@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/index-templates/traces-otel@template.yaml b/x-pack/plugin/otel-data/src/main/resources/index-templates/traces-otel@template.yaml index 370b9351c16f5..c2e9a68bc72ad 100644 --- a/x-pack/plugin/otel-data/src/main/resources/index-templates/traces-otel@template.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/index-templates/traces-otel@template.yaml @@ -11,6 +11,7 @@ composed_of: - traces@mappings - traces@settings - otel@mappings + - otel@settings - traces-otel@mappings - semconv-resource-to-ecs@mappings - traces@custom diff --git a/x-pack/plugin/otel-data/src/main/resources/resources.yaml b/x-pack/plugin/otel-data/src/main/resources/resources.yaml index 6aadfde1683dc..608dc369c34eb 100644 --- a/x-pack/plugin/otel-data/src/main/resources/resources.yaml +++ b/x-pack/plugin/otel-data/src/main/resources/resources.yaml @@ -1,10 +1,11 @@ # "version" holds the version of the templates and ingest pipelines installed # by xpack-plugin otel-data. This must be increased whenever an existing template is # changed, in order for it to be updated on Elasticsearch upgrade. -version: 9 +version: 10 component-templates: - otel@mappings + - otel@settings - logs-otel@mappings - semconv-resource-to-ecs@mappings - metrics-otel@mappings diff --git a/x-pack/plugin/otel-data/src/yamlRestTest/resources/rest-api-spec/test/20_logs_failure_store_test.yml b/x-pack/plugin/otel-data/src/yamlRestTest/resources/rest-api-spec/test/20_logs_failure_store_test.yml new file mode 100644 index 0000000000000..dfc6d0fc050b0 --- /dev/null +++ b/x-pack/plugin/otel-data/src/yamlRestTest/resources/rest-api-spec/test/20_logs_failure_store_test.yml @@ -0,0 +1,73 @@ +--- +setup: + - do: + cluster.health: + wait_for_events: languid +--- +teardown: + - do: + indices.delete_data_stream: + name: logs-generic.otel-default + ignore: 404 +--- +"Test logs-*.otel-* data streams have failure store enabled by default": + # Index a valid document (message in the 'text' field). + - do: + index: + index: logs-generic.otel-default + refresh: true + body: + '@timestamp': '2023-01-01T12:00:00Z' + severity_text: "INFO" + text: "Application started successfully" + - match: { result: created } + + # Assert empty failure store. 
+ - do: + indices.get_data_stream: + name: logs-generic.otel-default + - match: { data_streams.0.name: logs-generic.otel-default } + - length: { data_streams.0.indices: 1 } + - match: { data_streams.0.failure_store.enabled: true } + - length: { data_streams.0.failure_store.indices: 0 } + + # Index a document that writes to an alias field, causing an error. + - do: + index: + index: logs-generic.otel-default + refresh: true + body: + '@timestamp': '2023-01-01T12:01:00Z' + severity_text: "ERROR" + message: "Application started successfully" + - match: { result: created } + - match: { failure_store: used } + + # Assert failure store contains 1 item. + - do: + indices.get_data_stream: + name: logs-generic.otel-default + - length: { data_streams.0.failure_store.indices: 1 } + + # Assert valid document. + - do: + search: + index: logs-generic.otel-default::data + body: + query: + match_all: {} + - length: { hits.hits: 1 } + - match: { hits.hits.0._source.severity_text: "INFO" } + - match: { hits.hits.0._source.text: "Application started successfully" } + + # Assert invalid document. + - do: + search: + index: logs-generic.otel-default::failures + body: + query: + match_all: {} + - length: { hits.hits: 1 } + - match: { hits.hits.0._source.document.source.severity_text: "ERROR" } + - match: { hits.hits.0._source.document.source.message: "Application started successfully" } + - match: { hits.hits.0._source.error.type: "document_parsing_exception" } diff --git a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java index 784f1c1fbe23e..10c1a4321f1e5 100644 --- a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java +++ b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java @@ -18,6 +18,7 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateListener; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; @@ -166,13 +167,14 @@ protected TaskExecutor(Client client, ClusterService clusterService, ThreadPool } @Override - public PersistentTasksCustomMetadata.Assignment getAssignment( + protected PersistentTasksCustomMetadata.Assignment doGetAssignment( TestTaskParams params, Collection candidateNodes, - ClusterState clusterState + ClusterState clusterState, + ProjectId projectId ) { candidates.set(candidateNodes); - return super.getAssignment(params, candidateNodes, clusterState); + return super.doGetAssignment(params, candidateNodes, clusterState, projectId); } @Override diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/analyze/ContendedRegisterAnalyzeAction.java b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/analyze/ContendedRegisterAnalyzeAction.java index 13ffea8943f3b..fc03be7853cfa 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/analyze/ContendedRegisterAnalyzeAction.java +++ 
b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/analyze/ContendedRegisterAnalyzeAction.java @@ -156,6 +156,8 @@ public void onFailure(Exception e) { }; if (request.getInitialRead() > request.getRequestCount()) { + // This is just the initial read, so we can use getRegister() despite its weaker read-after-write semantics: all subsequent + // operations of this action use compareAndExchangeRegister() and do not rely on this value being accurate. blobContainer.getRegister(OperationPurpose.REPOSITORY_ANALYSIS, registerName, initialValueListener.delegateFailure((l, r) -> { if (r.isPresent()) { l.onResponse(r); diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalyzeAction.java b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalyzeAction.java index 35e11ae40d51a..1623546b1c494 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalyzeAction.java +++ b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalyzeAction.java @@ -651,6 +651,7 @@ private Runnable finalRegisterValueVerifier(String registerName, int expectedFin case 0 -> new CheckedConsumer, Exception>() { @Override public void accept(ActionListener listener) { + // All register operations have completed by this point so getRegister is safe getBlobContainer().getRegister(OperationPurpose.REPOSITORY_ANALYSIS, registerName, listener); } diff --git a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutor.java b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutor.java index b7bd434194b80..495a0db966343 100644 --- a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutor.java +++ b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutor.java @@ -21,6 +21,7 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.cluster.service.ClusterService; @@ -113,10 +114,11 @@ public TransformPersistentTasksExecutor( } @Override - public PersistentTasksCustomMetadata.Assignment getAssignment( + protected PersistentTasksCustomMetadata.Assignment doGetAssignment( TransformTaskParams params, Collection candidateNodes, - ClusterState clusterState + ClusterState clusterState, + @Nullable ProjectId projectId ) { /* Note: * diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java index fa509143f9ba9..ec4122b3da7f2 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java +++ 
diff --git a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutor.java b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutor.java
index b7bd434194b80..495a0db966343 100644
--- a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutor.java
+++ b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutor.java
@@ -21,6 +21,7 @@
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.metadata.ProjectId;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.routing.IndexRoutingTable;
 import org.elasticsearch.cluster.service.ClusterService;
@@ -113,10 +114,11 @@ public TransformPersistentTasksExecutor(
     }
 
     @Override
-    public PersistentTasksCustomMetadata.Assignment getAssignment(
+    protected PersistentTasksCustomMetadata.Assignment doGetAssignment(
         TransformTaskParams params,
         Collection<DiscoveryNode> candidateNodes,
-        ClusterState clusterState
+        ClusterState clusterState,
+        @Nullable ProjectId projectId
     ) {
         /* Note:
          *
diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java
index fa509143f9ba9..ec4122b3da7f2 100644
--- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java
+++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java
@@ -12,10 +12,14 @@
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.metadata.ProjectId;
+import org.elasticsearch.cluster.metadata.ProjectMetadata;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
+import org.elasticsearch.cluster.project.TestProjectResolvers;
 import org.elasticsearch.cluster.routing.IndexRoutingTable;
 import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
 import org.elasticsearch.cluster.routing.RecoverySource;
@@ -83,6 +87,7 @@ public class TransformPersistentTasksExecutorTests extends ESTestCase {
 
     private static ThreadPool threadPool;
     private TransformConfigAutoMigration autoMigration;
+    private ProjectId projectId;
 
     @BeforeClass
     public static void setUpThreadPool() {
@@ -106,13 +111,15 @@ public static void tearDownThreadPool() {
     }
 
     @Before
-    public void initMocks() {
+    public void setUp() throws Exception {
+        super.setUp();
         autoMigration = mock();
         doAnswer(ans -> {
             ActionListener<TransformConfig> listener = ans.getArgument(1);
             listener.onResponse(ans.getArgument(0));
             return null;
         }).when(autoMigration).migrateAndSave(any(), any());
+        projectId = randomUniqueProjectId();
     }
 
     public void testNodeVersionAssignment() {
@@ -124,7 +131,8 @@ public void testNodeVersionAssignment() {
             executor.getAssignment(
                 new TransformTaskParams("new-task-id", TransformConfigVersion.CURRENT, null, true),
                 cs.nodes().getAllNodes(),
-                cs
+                cs,
+                projectId
             ).getExecutorNode(),
             equalTo("current-data-node-with-1-tasks")
         );
@@ -132,7 +140,8 @@
             executor.getAssignment(
                 new TransformTaskParams("new-task-id", TransformConfigVersion.CURRENT, null, false),
                 cs.nodes().getAllNodes(),
-                cs
+                cs,
+                projectId
             ).getExecutorNode(),
             equalTo("current-data-node-with-0-tasks-transform-remote-disabled")
         );
@@ -140,7 +149,8 @@
             executor.getAssignment(
                 new TransformTaskParams("new-old-task-id", TransformConfigVersion.V_7_7_0, null, true),
                 cs.nodes().getAllNodes(),
-                cs
+                cs,
+                projectId
            ).getExecutorNode(),
             equalTo("past-data-node-1")
         );
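Every remaining `getAssignment` call site in this test changes the same mechanical way, gaining the `projectId` created in `setUp()`. The fixture hunks further down are the more interesting part: metadata, routing table, and index-name resolution all become keyed by one randomly chosen project. Condensed from those hunks into one place (no new API, just the calls the diff itself introduces; `routingTable` is built via `addIndices` as elsewhere in the test):

```
// Condensed from the hunks below: every fixture is scoped to the same random project.
ProjectId projectId = randomUniqueProjectId();

// The metadata carries a ProjectMetadata entry for that project...
Metadata.Builder metadata = Metadata.builder().put(ProjectMetadata.builder(projectId));

// ...the routing table is registered under the same project rather than cluster-wide...
ClusterState.Builder csBuilder = ClusterState.builder(new ClusterName("_name"));
csBuilder.putRoutingTable(projectId, routingTable.build());
csBuilder.metadata(metadata);

// ...and the resolver is pinned so it only ever sees that one project.
IndexNameExpressionResolver resolver =
    TestIndexNameExpressionResolver.newInstance(TestProjectResolvers.singleProjectOnly(projectId));
```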
@@ -154,7 +164,8 @@ public void testNodeAssignmentProblems() {
         Assignment assignment = executor.getAssignment(
             new TransformTaskParams("new-task-id", TransformConfigVersion.CURRENT, null, false),
             List.of(),
-            cs
+            cs,
+            projectId
         );
         assertNull(assignment.getExecutorNode());
         assertThat(
@@ -173,7 +184,8 @@
         assignment = executor.getAssignment(
             new TransformTaskParams("new-task-id", TransformConfigVersion.CURRENT, null, false),
             List.of(),
-            cs
+            cs,
+            projectId
         );
         assertNull(assignment.getExecutorNode());
         assertThat(
@@ -189,7 +201,8 @@
         assignment = executor.getAssignment(
             new TransformTaskParams("new-task-id", TransformConfigVersion.CURRENT, null, false),
             cs.nodes().getAllNodes(),
-            cs
+            cs,
+            projectId
         );
         assertNull(assignment.getExecutorNode());
         assertThat(
@@ -205,7 +218,8 @@
         assignment = executor.getAssignment(
             new TransformTaskParams("new-task-id", TransformConfigVersion.CURRENT, null, false),
             cs.nodes().getAllNodes(),
-            cs
+            cs,
+            projectId
         );
         assertNotNull(assignment.getExecutorNode());
         assertThat(assignment.getExecutorNode(), equalTo("dedicated-transform-node"));
@@ -218,7 +232,8 @@
         assignment = executor.getAssignment(
             new TransformTaskParams("new-task-id", TransformConfigVersion.V_8_0_0, null, false),
             cs.nodes().getAllNodes(),
-            cs
+            cs,
+            projectId
         );
         assertNull(assignment.getExecutorNode());
         assertThat(
@@ -235,7 +250,8 @@
         assignment = executor.getAssignment(
             new TransformTaskParams("new-task-id", TransformConfigVersion.V_7_5_0, null, false),
             cs.nodes().getAllNodes(),
-            cs
+            cs,
+            projectId
         );
         assertNotNull(assignment.getExecutorNode());
         assertThat(assignment.getExecutorNode(), equalTo("past-data-node-1"));
@@ -248,7 +264,8 @@
         assignment = executor.getAssignment(
             new TransformTaskParams("new-task-id", TransformConfigVersion.V_7_5_0, null, true),
             cs.nodes().getAllNodes(),
-            cs
+            cs,
+            projectId
         );
         assertNull(assignment.getExecutorNode());
         assertThat(
@@ -264,7 +281,8 @@
         assignment = executor.getAssignment(
             new TransformTaskParams("new-task-id", TransformConfigVersion.CURRENT, null, false),
             cs.nodes().getAllNodes(),
-            cs
+            cs,
+            projectId
         );
         assertNotNull(assignment.getExecutorNode());
         assertThat(assignment.getExecutorNode(), equalTo("current-data-node-with-0-tasks-transform-remote-disabled"));
@@ -277,7 +295,8 @@
         assignment = executor.getAssignment(
             new TransformTaskParams("new-task-id", TransformConfigVersion.V_7_5_0, null, true),
             cs.nodes().getAllNodes(),
-            cs
+            cs,
+            projectId
         );
         assertNull(assignment.getExecutorNode());
         assertThat(
@@ -299,29 +318,27 @@
         assignment = executor.getAssignment(
             new TransformTaskParams("new-task-id", TransformConfigVersion.V_7_5_0, null, true),
             cs.nodes().getAllNodes(),
-            cs
+            cs,
+            projectId
         );
         assertNotNull(assignment.getExecutorNode());
         assertThat(assignment.getExecutorNode(), equalTo("past-data-node-1"));
     }
 
     public void testVerifyIndicesPrimaryShardsAreActive() {
-        Metadata.Builder metadata = Metadata.builder();
+        Metadata.Builder metadata = metadataWithProject();
         RoutingTable.Builder routingTable = RoutingTable.builder();
         addIndices(metadata, routingTable);
         ClusterState.Builder csBuilder = ClusterState.builder(new ClusterName("_name"));
-        csBuilder.routingTable(routingTable.build());
+        csBuilder.putRoutingTable(projectId, routingTable.build());
         csBuilder.metadata(metadata);
 
         ClusterState cs = csBuilder.build();
-        assertEquals(
-            0,
-            TransformPersistentTasksExecutor.verifyIndicesPrimaryShardsAreActive(cs, TestIndexNameExpressionResolver.newInstance()).size()
-        );
+        assertEquals(0, TransformPersistentTasksExecutor.verifyIndicesPrimaryShardsAreActive(cs, indexNameExpressionResolver()).size());
 
         metadata = Metadata.builder(cs.metadata());
-        routingTable = new RoutingTable.Builder(cs.routingTable());
+        routingTable = new RoutingTable.Builder(cs.routingTable(projectId));
         String indexToRemove = TransformInternalIndexConstants.LATEST_INDEX_NAME;
         if (randomBoolean()) {
             routingTable.remove(indexToRemove);
@@ -342,11 +359,11 @@ public void testVerifyIndicesPrimaryShardsAreActive() {
         }
 
         csBuilder = ClusterState.builder(cs);
-        csBuilder.routingTable(routingTable.build());
+        csBuilder.putRoutingTable(projectId, routingTable.build());
         csBuilder.metadata(metadata);
 
         List<String> result = TransformPersistentTasksExecutor.verifyIndicesPrimaryShardsAreActive(
             csBuilder.build(),
-            TestIndexNameExpressionResolver.newInstance()
+            indexNameExpressionResolver()
         );
         assertEquals(1, result.size());
         assertEquals(indexToRemove, result.get(0));
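The helper hunks below make the write side symmetrical with the project-keyed reads above: anything that used to be written straight to the cluster-scoped `Metadata.Builder` now goes through the project's builder. A before/after fragment condensed from the `addIndices` and `buildClusterState` hunks that follow (variables as in the surrounding test, collected here only for contrast):

```
// Before: cluster-scoped writes.
metadata.put(indexMetadata);
metadata.putCustom(PersistentTasksCustomMetadata.TYPE, pTasks);

// After: the same writes, scoped to the test's project.
metadata.getProject(projectId).put(indexMetadata);
metadata.getProject(projectId).putCustom(PersistentTasksCustomMetadata.TYPE, pTasks);
```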
@@ -441,7 +458,7 @@ private void addIndices(Metadata.Builder metadata, RoutingTable.Builder routingT
         for (String indexName : indices) {
             IndexMetadata.Builder indexMetadata = IndexMetadata.builder(indexName);
             indexMetadata.settings(indexSettings(IndexVersion.current(), 1, 0).put(IndexMetadata.SETTING_INDEX_UUID, "_uuid"));
-            metadata.put(indexMetadata);
+            metadata.getProject(projectId).put(indexMetadata);
             Index index = new Index(indexName, "_uuid");
             ShardId shardId = new ShardId(index, 0);
             ShardRouting shardRouting = ShardRouting.newUnassigned(
@@ -556,7 +573,7 @@ private DiscoveryNodes.Builder buildNodes(
     }
 
     private ClusterState buildClusterState(DiscoveryNodes.Builder nodes) {
-        Metadata.Builder metadata = Metadata.builder().clusterUUID("cluster-uuid");
+        Metadata.Builder metadata = metadataWithProject().clusterUUID("cluster-uuid");
         RoutingTable.Builder routingTable = RoutingTable.builder();
         addIndices(metadata, routingTable);
         PersistentTasksCustomMetadata.Builder pTasksBuilder = PersistentTasksCustomMetadata.builder()
@@ -580,15 +597,19 @@ private ClusterState buildClusterState(DiscoveryNodes.Builder nodes) {
         );
 
         PersistentTasksCustomMetadata pTasks = pTasksBuilder.build();
-        metadata.putCustom(PersistentTasksCustomMetadata.TYPE, pTasks);
+        metadata.getProject(projectId).putCustom(PersistentTasksCustomMetadata.TYPE, pTasks);
 
         ClusterState.Builder csBuilder = ClusterState.builder(new ClusterName("_name")).nodes(nodes);
-        csBuilder.routingTable(routingTable.build());
+        csBuilder.putRoutingTable(projectId, routingTable.build());
         csBuilder.metadata(metadata);
 
         return csBuilder.build();
     }
 
+    private Metadata.Builder metadataWithProject() {
+        return Metadata.builder().put(ProjectMetadata.builder(projectId));
+    }
+
     private TransformPersistentTasksExecutor buildTaskExecutor() {
         var transformServices = transformServices(
             new InMemoryTransformConfigManager(),
@@ -622,11 +643,15 @@ private TransformPersistentTasksExecutor buildTaskExecutor(TransformServices tra
             clusterService(),
             Settings.EMPTY,
             new DefaultTransformExtension(),
-            TestIndexNameExpressionResolver.newInstance(),
+            indexNameExpressionResolver(),
             autoMigration
         );
     }
 
+    private IndexNameExpressionResolver indexNameExpressionResolver() {
+        return TestIndexNameExpressionResolver.newInstance(TestProjectResolvers.singleProjectOnly(projectId));
+    }
+
     private ClusterService clusterService() {
         var clusterService = mock(ClusterService.class);
         var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(Transform.NUM_FAILURE_RETRIES_SETTING));
diff --git a/x-pack/qa/smoke-test-plugins-ssl/build.gradle b/x-pack/qa/smoke-test-plugins-ssl/build.gradle
index cbd837fc2ccf6..03e67bdf0dd4b 100644
--- a/x-pack/qa/smoke-test-plugins-ssl/build.gradle
+++ b/x-pack/qa/smoke-test-plugins-ssl/build.gradle
@@ -83,6 +83,8 @@ testClusters.matching { it.name == "yamlRestTest" }.configureEach {
   user username: "test_user", password: "x-pack-test-password"
   user username: "monitoring_agent", password: "x-pack-test-password", role: "remote_monitoring_agent"
 
+  systemProperty 'es.queryable_built_in_roles_enabled', 'false'
+
   pluginPaths.each { pluginPath ->
     plugin pluginPath