diff --git a/polaris-core/src/main/java/org/apache/polaris/core/PolarisCallContext.java b/polaris-core/src/main/java/org/apache/polaris/core/PolarisCallContext.java index 8878b9ef6a..187a594d44 100644 --- a/polaris-core/src/main/java/org/apache/polaris/core/PolarisCallContext.java +++ b/polaris-core/src/main/java/org/apache/polaris/core/PolarisCallContext.java @@ -80,17 +80,4 @@ public RealmConfig getRealmConfig() { public PolarisCallContext getPolarisCallContext() { return this; } - - @Override - public PolarisCallContext copy() { - // The realm context is a request scoped bean injected by CDI, - // which will be closed after the http request. This copy is currently - // only used by TaskExecutor right before the task is handled, since the - // task is executed outside the active request scope, we need to make a - // copy of the RealmContext to ensure the access during the task executor. - String realmId = this.realmContext.getRealmIdentifier(); - RealmContext realmContext = () -> realmId; - return new PolarisCallContext( - realmContext, this.metaStore, this.diagServices, this.configurationStore); - } } diff --git a/polaris-core/src/main/java/org/apache/polaris/core/context/CallContext.java b/polaris-core/src/main/java/org/apache/polaris/core/context/CallContext.java index 28ec97ba6c..693ea6b04b 100644 --- a/polaris-core/src/main/java/org/apache/polaris/core/context/CallContext.java +++ b/polaris-core/src/main/java/org/apache/polaris/core/context/CallContext.java @@ -30,9 +30,6 @@ * underlying nature of the persistence layer may differ between different realms. */ public interface CallContext { - /** Copy the {@link CallContext}. */ - CallContext copy(); - RealmContext getRealmContext(); /** diff --git a/runtime/service/src/main/java/org/apache/polaris/service/context/DefaultRealmContextResolver.java b/runtime/service/src/main/java/org/apache/polaris/service/context/DefaultRealmContextResolver.java index 7d779ef85e..977f8917f1 100644 --- a/runtime/service/src/main/java/org/apache/polaris/service/context/DefaultRealmContextResolver.java +++ b/runtime/service/src/main/java/org/apache/polaris/service/context/DefaultRealmContextResolver.java @@ -21,10 +21,7 @@ import io.smallrye.common.annotation.Identifier; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionStage; import java.util.function.Function; -import org.apache.polaris.core.context.RealmContext; @ApplicationScoped @Identifier("default") @@ -38,17 +35,8 @@ public DefaultRealmContextResolver(RealmContextConfiguration configuration) { } @Override - public CompletionStage resolveRealmContext( + public String resolveRealmId( String requestURL, String method, String path, Function headers) { - try { - String realm = resolveRealmIdentifier(headers); - return CompletableFuture.completedFuture(() -> realm); - } catch (Exception e) { - return CompletableFuture.failedFuture(e); - } - } - - private String resolveRealmIdentifier(Function headers) { String realm = headers.apply(configuration.headerName()); if (realm != null) { if (!configuration.realms().contains(realm)) { diff --git a/runtime/service/src/main/java/org/apache/polaris/service/context/RealmContextResolver.java b/runtime/service/src/main/java/org/apache/polaris/service/context/RealmContextResolver.java index 0aea6e9bd1..da42b85909 100644 --- a/runtime/service/src/main/java/org/apache/polaris/service/context/RealmContextResolver.java +++ b/runtime/service/src/main/java/org/apache/polaris/service/context/RealmContextResolver.java @@ -19,6 +19,7 @@ package org.apache.polaris.service.context; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.function.Function; import org.apache.commons.collections.map.CaseInsensitiveMap; @@ -43,20 +44,48 @@ public interface RealmContextResolver { /** - * Resolves the realm context for the given request, and returns a {@link CompletionStage} that + * Resolves the realm context for the given realm ID, and returns a {@link CompletionStage} that * completes with the resolved realm context. * * @return a {@link CompletionStage} that completes with the resolved realm context */ - CompletionStage resolveRealmContext( + default CompletionStage resolveRealmContext(String realmId) { + return CompletableFuture.completedFuture(() -> realmId); + } + + /** + * Resolves the realm context for the given HTTP/REST request parameters, and returns a {@link + * CompletionStage} that completes with the resolved realm identifier. + * + * @return a {@link CompletionStage} that completes with the resolved realm identifier + */ + String resolveRealmId( String requestURL, String method, String path, Function headers); + /** + * Resolves the realm context for the given request, and returns a {@link CompletionStage} that + * completes with the resolved realm context. + * + *

This is a convenience function combining {@link #resolveRealmContext(String)} with the + * result of {@link #resolveRealmId(String, String, String, Function)}. + * + * @return a {@link CompletionStage} that completes with the resolved realm context or an + * exception. + * @throws RuntimeException propagated directly from {@link #resolveRealmId(String, String, + * String, Function)} + */ + default CompletionStage resolveRealmContext( + String requestURL, String method, String path, Function headers) { + return resolveRealmContext(resolveRealmId(requestURL, method, path, headers)); + } + /** * Resolves the realm context for the given request, and returns a {@link CompletionStage} that * completes with the resolved realm context. * * @return a {@link CompletionStage} that completes with the resolved realm context */ + @Deprecated(forRemoval = true) // Only used in tests default CompletionStage resolveRealmContext( String requestURL, String method, String path, Map headers) { CaseInsensitiveMap caseInsensitiveMap = new CaseInsensitiveMap(headers); diff --git a/runtime/service/src/main/java/org/apache/polaris/service/context/TestRealmContextResolver.java b/runtime/service/src/main/java/org/apache/polaris/service/context/TestRealmContextResolver.java index 79a9149546..e80c86c6bd 100644 --- a/runtime/service/src/main/java/org/apache/polaris/service/context/TestRealmContextResolver.java +++ b/runtime/service/src/main/java/org/apache/polaris/service/context/TestRealmContextResolver.java @@ -24,10 +24,7 @@ import jakarta.inject.Inject; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionStage; import java.util.function.Function; -import org.apache.polaris.core.context.RealmContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -52,7 +49,7 @@ public TestRealmContextResolver(RealmContextConfiguration configuration) { } @Override - public CompletionStage resolveRealmContext( + public String resolveRealmId( String requestURL, String method, String path, Function headers) { // Since this default resolver is strictly for use in test/dev environments, we'll consider // it safe to log all contents. Any "real" resolver used in a prod environment should make @@ -73,8 +70,7 @@ public CompletionStage resolveRealmContext( configuration.defaultRealm()); parsedProperties.put(REALM_PROPERTY_KEY, configuration.defaultRealm()); } - String realmId = parsedProperties.get(REALM_PROPERTY_KEY); - return CompletableFuture.completedFuture(() -> realmId); + return parsedProperties.get(REALM_PROPERTY_KEY); } /** diff --git a/runtime/service/src/main/java/org/apache/polaris/service/task/TaskExecutorImpl.java b/runtime/service/src/main/java/org/apache/polaris/service/task/TaskExecutorImpl.java index 6ee681ead7..2c2e1a352a 100644 --- a/runtime/service/src/main/java/org/apache/polaris/service/task/TaskExecutorImpl.java +++ b/runtime/service/src/main/java/org/apache/polaris/service/task/TaskExecutorImpl.java @@ -18,6 +18,7 @@ */ package org.apache.polaris.service.task; +import com.google.common.annotations.VisibleForTesting; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; @@ -37,12 +38,17 @@ import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import org.apache.polaris.core.PolarisCallContext; +import org.apache.polaris.core.PolarisDiagnostics; import org.apache.polaris.core.context.CallContext; +import org.apache.polaris.core.context.RealmContext; import org.apache.polaris.core.entity.PolarisBaseEntity; import org.apache.polaris.core.entity.PolarisEntityType; import org.apache.polaris.core.entity.TaskEntity; +import org.apache.polaris.core.persistence.BasePersistence; import org.apache.polaris.core.persistence.MetaStoreManagerFactory; import org.apache.polaris.core.persistence.PolarisMetaStoreManager; +import org.apache.polaris.service.context.RealmContextResolver; import org.apache.polaris.service.events.AfterTaskAttemptedEvent; import org.apache.polaris.service.events.BeforeTaskAttemptedEvent; import org.apache.polaris.service.events.PolarisEventListener; @@ -65,11 +71,12 @@ public class TaskExecutorImpl implements TaskExecutor { private final TaskFileIOSupplier fileIOSupplier; private final List taskHandlers = new CopyOnWriteArrayList<>(); private final PolarisEventListener polarisEventListener; + private final RealmContextResolver realmContextResolver; @Nullable private final Tracer tracer; @SuppressWarnings("unused") // Required by CDI protected TaskExecutorImpl() { - this(null, null, null, null, null, null); + this(null, null, null, null, null, null, null); } @Inject @@ -79,12 +86,14 @@ public TaskExecutorImpl( MetaStoreManagerFactory metaStoreManagerFactory, TaskFileIOSupplier fileIOSupplier, PolarisEventListener polarisEventListener, + RealmContextResolver realmContextResolver, @Nullable Tracer tracer) { this.executor = executor; this.clock = clock; this.metaStoreManagerFactory = metaStoreManagerFactory; this.fileIOSupplier = fileIOSupplier; this.polarisEventListener = polarisEventListener; + this.realmContextResolver = realmContextResolver; this.tracer = tracer; } @@ -122,27 +131,58 @@ public void addTaskHandlerContext(long taskEntityId, CallContext callContext) { // the task is still running. // Note: PolarisCallContext has request-scoped beans as well, and must be cloned. // FIXME replace with context propagation? - CallContext clone = callContext.copy(); - tryHandleTask(taskEntityId, clone, null, 1); + tryHandleTask(taskEntityId, new TaskContext(callContext), null, 1); + } + + record TaskContext(String realmId, PolarisDiagnostics diagnostics) { + TaskContext(PolarisCallContext polarisCallContext) { + this( + polarisCallContext.getRealmContext().getRealmIdentifier(), + polarisCallContext.getDiagServices()); + } + + TaskContext(CallContext callContext) { + this(callContext.getPolarisCallContext()); + } + } + + @VisibleForTesting + PolarisCallContext newPolarisCallContext(TaskContext taskContext) { + try { + RealmContext realmContext = + realmContextResolver + .resolveRealmContext(taskContext.realmId()) + .toCompletableFuture() + .get(); + BasePersistence metaStore = metaStoreManagerFactory.getOrCreateSession(realmContext); + return new PolarisCallContext(realmContext, metaStore, taskContext.diagnostics()); + } catch (Exception e) { + LOGGER.error( + "Error while creating PolarisCallContext for task context for realm {}", + taskContext.realmId(), + e); + throw new RuntimeException(e); + } } private @Nonnull CompletableFuture tryHandleTask( - long taskEntityId, CallContext callContext, Throwable e, int attempt) { + long taskEntityId, TaskContext taskContext, Throwable e, int attempt) { if (attempt > 3) { return CompletableFuture.failedFuture(e); } return CompletableFuture.runAsync( - () -> handleTaskWithTracing(taskEntityId, callContext, attempt), executor) + () -> handleTaskWithTracing(taskEntityId, taskContext, attempt), executor) .exceptionallyComposeAsync( (t) -> { LOGGER.warn("Failed to handle task entity id {}", taskEntityId, t); - return tryHandleTask(taskEntityId, callContext, t, attempt + 1); + return tryHandleTask(taskEntityId, taskContext, t, attempt + 1); }, CompletableFuture.delayedExecutor( TASK_RETRY_DELAY * (long) attempt, TimeUnit.MILLISECONDS, executor)); } - protected void handleTask(long taskEntityId, CallContext ctx, int attempt) { + void handleTask(long taskEntityId, TaskContext taskContext, int attempt) { + PolarisCallContext ctx = newPolarisCallContext(taskContext); polarisEventListener.onBeforeTaskAttempted( new BeforeTaskAttemptedEvent(taskEntityId, ctx, attempt)); @@ -192,22 +232,20 @@ protected void handleTask(long taskEntityId, CallContext ctx, int attempt) { } } - protected void handleTaskWithTracing(long taskEntityId, CallContext callContext, int attempt) { + void handleTaskWithTracing(long taskEntityId, TaskContext taskContext, int attempt) { if (tracer == null) { - handleTask(taskEntityId, callContext, attempt); + handleTask(taskEntityId, taskContext, attempt); } else { Span span = tracer .spanBuilder("polaris.task") .setParent(Context.current()) - .setAttribute( - TracingFilter.REALM_ID_ATTRIBUTE, - callContext.getRealmContext().getRealmIdentifier()) + .setAttribute(TracingFilter.REALM_ID_ATTRIBUTE, taskContext.realmId()) .setAttribute("polaris.task.entity.id", taskEntityId) .setAttribute("polaris.task.attempt", attempt) .startSpan(); try (Scope ignored = span.makeCurrent()) { - handleTask(taskEntityId, callContext, attempt); + handleTask(taskEntityId, taskContext, attempt); } finally { span.end(); } diff --git a/runtime/service/src/test/java/org/apache/polaris/service/context/DefaultRealmIdResolverTest.java b/runtime/service/src/test/java/org/apache/polaris/service/context/DefaultRealmIdResolverTest.java index e7d1ab70db..e68a54bc98 100644 --- a/runtime/service/src/test/java/org/apache/polaris/service/context/DefaultRealmIdResolverTest.java +++ b/runtime/service/src/test/java/org/apache/polaris/service/context/DefaultRealmIdResolverTest.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; +import java.util.function.Function; import org.apache.polaris.core.context.RealmContext; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -41,18 +42,22 @@ void setUp() { when(config.defaultRealm()).thenCallRealMethod(); } + static Function realmHeader(String realmId) { + return (key) -> "Polaris-Header".equalsIgnoreCase(key) ? realmId : null; + } + @Test void headerPresentSuccess() { DefaultRealmContextResolver resolver = new DefaultRealmContextResolver(config); RealmContext RealmContext1 = resolver - .resolveRealmContext("requestURL", "method", "path", Map.of("Polaris-Header", "realm1")) + .resolveRealmContext("requestURL", "method", "path", realmHeader("realm1")) .toCompletableFuture() .join(); assertThat(RealmContext1.getRealmIdentifier()).isEqualTo("realm1"); RealmContext RealmContext2 = resolver - .resolveRealmContext("requestURL", "method", "path", Map.of("Polaris-Header", "realm2")) + .resolveRealmContext("requestURL", "method", "path", realmHeader("realm2")) .toCompletableFuture() .join(); assertThat(RealmContext2.getRealmIdentifier()).isEqualTo("realm2"); @@ -64,11 +69,9 @@ void headerPresentFailure() { assertThatThrownBy( () -> resolver - .resolveRealmContext( - "requestURL", "method", "path", Map.of("Polaris-Header", "realm3")) + .resolveRealmContext("requestURL", "method", "path", realmHeader("realm3")) .toCompletableFuture() .join()) - .rootCause() .isInstanceOf(IllegalArgumentException.class) .hasMessage("Unknown realm: realm3"); } @@ -79,7 +82,7 @@ void headerNotPresentSuccess() { DefaultRealmContextResolver resolver = new DefaultRealmContextResolver(config); RealmContext RealmContext1 = resolver - .resolveRealmContext("requestURL", "method", "path", Map.of()) + .resolveRealmContext("requestURL", "method", "path", k -> null) .toCompletableFuture() .join(); assertThat(RealmContext1.getRealmIdentifier()).isEqualTo("realm1"); @@ -92,14 +95,15 @@ void headerNotPresentFailure() { assertThatThrownBy( () -> resolver - .resolveRealmContext("requestURL", "method", "path", Map.of()) + .resolveRealmContext("requestURL", "method", "path", k -> null) .toCompletableFuture() .join()) - .rootCause() .isInstanceOf(IllegalArgumentException.class) .hasMessage("Missing required realm header: Polaris-Header"); } + @SuppressWarnings("removal") + @Deprecated(forRemoval = true) @Test void headerCaseInsensitive() { DefaultRealmContextResolver resolver = new DefaultRealmContextResolver(config); diff --git a/runtime/service/src/test/java/org/apache/polaris/service/task/TaskExecutorImplTest.java b/runtime/service/src/test/java/org/apache/polaris/service/task/TaskExecutorImplTest.java index 03f9c88a3d..c3432a6695 100644 --- a/runtime/service/src/test/java/org/apache/polaris/service/task/TaskExecutorImplTest.java +++ b/runtime/service/src/test/java/org/apache/polaris/service/task/TaskExecutorImplTest.java @@ -18,6 +18,9 @@ */ package org.apache.polaris.service.task; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.concurrent.atomic.AtomicReference; import org.apache.polaris.core.PolarisCallContext; import org.apache.polaris.core.context.CallContext; import org.apache.polaris.core.context.RealmContext; @@ -61,6 +64,7 @@ void testEventsAreEmitted() { int attempt = 1; + AtomicReference newCallContext = new AtomicReference<>(); TaskExecutorImpl executor = new TaskExecutorImpl( Runnable::run, @@ -68,7 +72,17 @@ void testEventsAreEmitted() { testServices.metaStoreManagerFactory(), new TaskFileIOSupplier(testServices.fileIOFactory()), testServices.polarisEventListener(), - null); + (requestURL, method, path, headers) -> { + throw new UnsupportedOperationException("mustn't be called in this test"); + }, + null) { + @Override + PolarisCallContext newPolarisCallContext(TaskContext taskContext) { + var ctx = super.newPolarisCallContext(taskContext); + newCallContext.set(ctx); + return ctx; + } + }; executor.addTaskHandler( new TaskHandler() { @@ -88,11 +102,12 @@ public boolean handleTask(TaskEntity task, CallContext callContext) { } }); - executor.handleTask(taskEntity.getId(), polarisCallCtx, attempt); + executor.handleTask( + taskEntity.getId(), new TaskExecutorImpl.TaskContext(polarisCallCtx), attempt); var afterAttemptTaskEvent = testPolarisEventListener.getLatest(AfterTaskAttemptedEvent.class); Assertions.assertEquals(taskEntity.getId(), afterAttemptTaskEvent.taskEntityId()); - Assertions.assertEquals(polarisCallCtx, afterAttemptTaskEvent.callContext()); + assertThat(newCallContext.get()).isNotNull().isSameAs(afterAttemptTaskEvent.callContext()); Assertions.assertEquals(attempt, afterAttemptTaskEvent.attempt()); Assertions.assertTrue(afterAttemptTaskEvent.success()); } diff --git a/runtime/service/src/test/java/org/apache/polaris/service/test/PolarisIntegrationTestFixture.java b/runtime/service/src/test/java/org/apache/polaris/service/test/PolarisIntegrationTestFixture.java index eb5dd948dc..c317b49e42 100644 --- a/runtime/service/src/test/java/org/apache/polaris/service/test/PolarisIntegrationTestFixture.java +++ b/runtime/service/src/test/java/org/apache/polaris/service/test/PolarisIntegrationTestFixture.java @@ -102,7 +102,11 @@ private PolarisPrincipalSecrets fetchAdminSecrets() { RealmContext realmContext = helper .realmContextResolver - .resolveRealmContext(baseUri.toString(), "GET", "/", Map.of(REALM_PROPERTY_KEY, realm)) + .resolveRealmContext( + baseUri.toString(), + "GET", + "/", + k -> REALM_PROPERTY_KEY.equalsIgnoreCase(k) ? realm : null) .toCompletableFuture() .join();