Skip to content

Remove CallContext.copy() #2294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,4 @@ public RealmConfig getRealmConfig() {
public PolarisCallContext getPolarisCallContext() {
// This object is itself the PolarisCallContext, so it is returned directly.
return this;
}

@Override
public PolarisCallContext copy() {
  // The injected realm context is a request-scoped CDI bean and is closed once the HTTP
  // request finishes. TaskExecutor is currently the only caller of this method: it copies
  // the context right before handling a task, because the task runs outside the active
  // request scope. Capture the realm identifier eagerly and wrap it in a standalone
  // RealmContext so it stays readable while the task executes.
  final String capturedRealmId = this.realmContext.getRealmIdentifier();
  final RealmContext detachedRealmContext = () -> capturedRealmId;
  return new PolarisCallContext(
      detachedRealmContext, this.metaStore, this.diagServices, this.configurationStore);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@
* underlying nature of the persistence layer may differ between different realms.
*/
public interface CallContext {
/** Copy the {@link CallContext}. */
CallContext copy();

RealmContext getRealmContext();

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@
import io.smallrye.common.annotation.Identifier;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.Function;
import org.apache.polaris.core.context.RealmContext;

@ApplicationScoped
@Identifier("default")
Expand All @@ -38,17 +35,8 @@ public DefaultRealmContextResolver(RealmContextConfiguration configuration) {
}

@Override
public CompletionStage<RealmContext> resolveRealmContext(
public String resolveRealmId(
String requestURL, String method, String path, Function<String, String> headers) {
try {
String realm = resolveRealmIdentifier(headers);
return CompletableFuture.completedFuture(() -> realm);
} catch (Exception e) {
return CompletableFuture.failedFuture(e);
}
}

private String resolveRealmIdentifier(Function<String, String> headers) {
String realm = headers.apply(configuration.headerName());
if (realm != null) {
if (!configuration.realms().contains(realm)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.polaris.service.context;

import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.Function;
import org.apache.commons.collections.map.CaseInsensitiveMap;
Expand All @@ -43,20 +44,48 @@
public interface RealmContextResolver {

/**
* Resolves the realm context for the given request, and returns a {@link CompletionStage} that
* Resolves the realm context for the given realm ID, and returns a {@link CompletionStage} that
* completes with the resolved realm context.
*
* @return a {@link CompletionStage} that completes with the resolved realm context
*/
CompletionStage<RealmContext> resolveRealmContext(
default CompletionStage<RealmContext> resolveRealmContext(String realmId) {
return CompletableFuture.completedFuture(() -> realmId);
}

/**
* Resolves the realm identifier for the given HTTP/REST request parameters.
*
* <p>The identifier is returned directly as a {@code String}; it is not wrapped in a {@link
* CompletionStage}.
*
* @param requestURL the request URL
* @param method the request method
* @param path the request path
* @param headers function that is applied to a header name to obtain that header's value
* @return the resolved realm identifier
*/
String resolveRealmId(
String requestURL, String method, String path, Function<String, String> headers);

/**
 * Resolves the realm context for the given request and returns a {@link CompletionStage} that
 * completes with it.
 *
 * <p>Convenience method: the realm identifier is first obtained from {@link
 * #resolveRealmId(String, String, String, Function)} and then handed to {@link
 * #resolveRealmContext(String)}.
 *
 * @return a {@link CompletionStage} that completes with the resolved realm context or an
 *     exception.
 * @throws RuntimeException propagated directly from {@link #resolveRealmId(String, String,
 *     String, Function)}
 */
default CompletionStage<RealmContext> resolveRealmContext(
    String requestURL, String method, String path, Function<String, String> headers) {
  String resolvedRealmId = resolveRealmId(requestURL, method, path, headers);
  return resolveRealmContext(resolvedRealmId);
}

/**
* Resolves the realm context for the given request, and returns a {@link CompletionStage} that
* completes with the resolved realm context.
*
* @return a {@link CompletionStage} that completes with the resolved realm context
*/
@Deprecated(forRemoval = true) // Only used in tests
default CompletionStage<RealmContext> resolveRealmContext(
String requestURL, String method, String path, Map<String, String> headers) {
CaseInsensitiveMap caseInsensitiveMap = new CaseInsensitiveMap(headers);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@
import jakarta.inject.Inject;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.Function;
import org.apache.polaris.core.context.RealmContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -52,7 +49,7 @@ public TestRealmContextResolver(RealmContextConfiguration configuration) {
}

@Override
public CompletionStage<RealmContext> resolveRealmContext(
public String resolveRealmId(
String requestURL, String method, String path, Function<String, String> headers) {
// Since this default resolver is strictly for use in test/dev environments, we'll consider
// it safe to log all contents. Any "real" resolver used in a prod environment should make
Expand All @@ -73,8 +70,7 @@ public CompletionStage<RealmContext> resolveRealmContext(
configuration.defaultRealm());
parsedProperties.put(REALM_PROPERTY_KEY, configuration.defaultRealm());
}
String realmId = parsedProperties.get(REALM_PROPERTY_KEY);
return CompletableFuture.completedFuture(() -> realmId);
return parsedProperties.get(REALM_PROPERTY_KEY);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.polaris.service.task;

import com.google.common.annotations.VisibleForTesting;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.Tracer;
import io.opentelemetry.context.Context;
Expand All @@ -37,12 +38,17 @@
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.polaris.core.PolarisCallContext;
import org.apache.polaris.core.PolarisDiagnostics;
import org.apache.polaris.core.context.CallContext;
import org.apache.polaris.core.context.RealmContext;
import org.apache.polaris.core.entity.PolarisBaseEntity;
import org.apache.polaris.core.entity.PolarisEntityType;
import org.apache.polaris.core.entity.TaskEntity;
import org.apache.polaris.core.persistence.BasePersistence;
import org.apache.polaris.core.persistence.MetaStoreManagerFactory;
import org.apache.polaris.core.persistence.PolarisMetaStoreManager;
import org.apache.polaris.service.context.RealmContextResolver;
import org.apache.polaris.service.events.AfterTaskAttemptedEvent;
import org.apache.polaris.service.events.BeforeTaskAttemptedEvent;
import org.apache.polaris.service.events.PolarisEventListener;
Expand All @@ -65,11 +71,12 @@ public class TaskExecutorImpl implements TaskExecutor {
private final TaskFileIOSupplier fileIOSupplier;
private final List<TaskHandler> taskHandlers = new CopyOnWriteArrayList<>();
private final PolarisEventListener polarisEventListener;
private final RealmContextResolver realmContextResolver;
@Nullable private final Tracer tracer;

@SuppressWarnings("unused") // Required by CDI
protected TaskExecutorImpl() {
this(null, null, null, null, null, null);
this(null, null, null, null, null, null, null);
}

@Inject
Expand All @@ -79,12 +86,14 @@ public TaskExecutorImpl(
MetaStoreManagerFactory metaStoreManagerFactory,
TaskFileIOSupplier fileIOSupplier,
PolarisEventListener polarisEventListener,
RealmContextResolver realmContextResolver,
@Nullable Tracer tracer) {
this.executor = executor;
this.clock = clock;
this.metaStoreManagerFactory = metaStoreManagerFactory;
this.fileIOSupplier = fileIOSupplier;
this.polarisEventListener = polarisEventListener;
this.realmContextResolver = realmContextResolver;
this.tracer = tracer;
}

Expand Down Expand Up @@ -122,27 +131,58 @@ public void addTaskHandlerContext(long taskEntityId, CallContext callContext) {
// the task is still running.
// Note: PolarisCallContext has request-scoped beans as well, and must be cloned.
// FIXME replace with context propagation?
CallContext clone = callContext.copy();
tryHandleTask(taskEntityId, clone, null, 1);
tryHandleTask(taskEntityId, new TaskContext(callContext), null, 1);
}

record TaskContext(String realmId, PolarisDiagnostics diagnostics) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have we done some testing (at least manually) to verify the task is still executing correctly? As far as I know, we currently do not have a good integration test that actually helps verify the background task. The regtests t_pyspark/test_spark_sql_s3_with_privileges.py contains a test that helps verify the background purge task, but it is not running in the CI today.

Can we follow the README here https://github.com/apache/polaris/blob/main/regtests/README.md to run the test against AWS to verify things are still working?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, @gh-yzou's comment made me think that we probably need to start a new request context for the task via @ActivateRequestContext as in #1817 and also block context propagation on the task thread pool. That should give us proper CDI context isolation. WDYT?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CDI context propagation does not work across different JVMs for the tasks-proposals.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

have we done some testing

I mean, sure? There are tests in the code base for this.

The regtests t_pyspark/test_spark_sql_s3_with_privileges.py ... not running in the CI today.

I'd suggest fixing that independently. Mind taking a stab at that, @gh-yzou? MinIO, Azurite and the Google emulator help there.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, sure? There are tests in the code base for this.

Our task tests are not very complete right now, so this doesn't give me much confidence.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

t_pyspark/test_spark_sql_s3_with_privileges.py can only be run locally with your available S3 account; @jbonofre is working on getting an S3 account for cli testing.

However, I am not sure what extra benefit this approach would give us. The copy() method is used to propagate the whole CallContext and RealmContext to the task execution, and the id lookup also seems to be used only by task execution and requires an extra reconstruction step. Furthermore, this approach doesn't seem scalable if we add more information to CallContext or RealmContext in the future (such as user-specific information); for RealmContext in particular, it may require us to implement a RealmContextManager for lookup.

If the concern is that this function is only used by the task executor, I recall @dimas-b was working on a POC that leverages a CDI feature to propagate the CallContext to a background thread, which seems a cleaner way to do that. @dimas-b maybe we can resume that work?

/**
* Builds a task context from a {@link PolarisCallContext} by reading its realm identifier and
* its diagnostics services and passing them to the canonical record constructor.
*/
TaskContext(PolarisCallContext polarisCallContext) {
this(
polarisCallContext.getRealmContext().getRealmIdentifier(),
polarisCallContext.getDiagServices());
}

/**
* Builds a task context from a generic {@link CallContext} by delegating to the {@link
* PolarisCallContext} constructor with {@code callContext.getPolarisCallContext()}.
*/
TaskContext(CallContext callContext) {
this(callContext.getPolarisCallContext());
}
}

/**
 * Builds a new {@link PolarisCallContext} from the values held by the given {@link TaskContext}.
 *
 * <p>The realm context is obtained from the realm context resolver using the task's realm ID;
 * this call blocks until the resolver's result is available. A persistence session for that
 * realm is then obtained from the meta store manager factory, and both are combined with the
 * task's diagnostics into the returned call context.
 *
 * @param taskContext the realm ID and diagnostics to build the call context from
 * @return a newly constructed call context for the task's realm
 * @throws RuntimeException wrapping any failure while resolving the realm context or obtaining
 *     the persistence session; if the failure was an {@link InterruptedException}, the thread's
 *     interrupt status is restored before this exception is thrown
 */
@VisibleForTesting
PolarisCallContext newPolarisCallContext(TaskContext taskContext) {
  try {
    RealmContext realmContext =
        realmContextResolver
            .resolveRealmContext(taskContext.realmId())
            .toCompletableFuture()
            .get();
    BasePersistence metaStore = metaStoreManagerFactory.getOrCreateSession(realmContext);
    return new PolarisCallContext(realmContext, metaStore, taskContext.diagnostics());
  } catch (Exception e) {
    if (e instanceof InterruptedException) {
      // get() clears the interrupt status when it throws; restore it so that the executor
      // (or any caller further up the stack) can still observe the interruption.
      Thread.currentThread().interrupt();
    }
    LOGGER.error(
        "Error while creating PolarisCallContext for task context for realm {}",
        taskContext.realmId(),
        e);
    throw new RuntimeException(e);
  }
}

private @Nonnull CompletableFuture<Void> tryHandleTask(
long taskEntityId, CallContext callContext, Throwable e, int attempt) {
long taskEntityId, TaskContext taskContext, Throwable e, int attempt) {
if (attempt > 3) {
return CompletableFuture.failedFuture(e);
}
return CompletableFuture.runAsync(
() -> handleTaskWithTracing(taskEntityId, callContext, attempt), executor)
() -> handleTaskWithTracing(taskEntityId, taskContext, attempt), executor)
.exceptionallyComposeAsync(
(t) -> {
LOGGER.warn("Failed to handle task entity id {}", taskEntityId, t);
return tryHandleTask(taskEntityId, callContext, t, attempt + 1);
return tryHandleTask(taskEntityId, taskContext, t, attempt + 1);
},
CompletableFuture.delayedExecutor(
TASK_RETRY_DELAY * (long) attempt, TimeUnit.MILLISECONDS, executor));
}

protected void handleTask(long taskEntityId, CallContext ctx, int attempt) {
void handleTask(long taskEntityId, TaskContext taskContext, int attempt) {
PolarisCallContext ctx = newPolarisCallContext(taskContext);
polarisEventListener.onBeforeTaskAttempted(
new BeforeTaskAttemptedEvent(taskEntityId, ctx, attempt));

Expand Down Expand Up @@ -192,22 +232,20 @@ protected void handleTask(long taskEntityId, CallContext ctx, int attempt) {
}
}

protected void handleTaskWithTracing(long taskEntityId, CallContext callContext, int attempt) {
void handleTaskWithTracing(long taskEntityId, TaskContext taskContext, int attempt) {
if (tracer == null) {
handleTask(taskEntityId, callContext, attempt);
handleTask(taskEntityId, taskContext, attempt);
} else {
Span span =
tracer
.spanBuilder("polaris.task")
.setParent(Context.current())
.setAttribute(
TracingFilter.REALM_ID_ATTRIBUTE,
callContext.getRealmContext().getRealmIdentifier())
.setAttribute(TracingFilter.REALM_ID_ATTRIBUTE, taskContext.realmId())
.setAttribute("polaris.task.entity.id", taskEntityId)
.setAttribute("polaris.task.attempt", attempt)
.startSpan();
try (Scope ignored = span.makeCurrent()) {
handleTask(taskEntityId, callContext, attempt);
handleTask(taskEntityId, taskContext, attempt);
} finally {
span.end();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.apache.polaris.core.context.RealmContext;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -41,18 +42,22 @@ void setUp() {
when(config.defaultRealm()).thenCallRealMethod();
}

/**
 * Creates a header lookup function for tests: it returns {@code realmId} when asked for the
 * {@code Polaris-Header} header (compared case-insensitively) and {@code null} for any other
 * name.
 */
static Function<String, String> realmHeader(String realmId) {
  return headerName -> {
    if ("Polaris-Header".equalsIgnoreCase(headerName)) {
      return realmId;
    }
    return null;
  };
}

@Test
void headerPresentSuccess() {
DefaultRealmContextResolver resolver = new DefaultRealmContextResolver(config);
RealmContext RealmContext1 =
resolver
.resolveRealmContext("requestURL", "method", "path", Map.of("Polaris-Header", "realm1"))
.resolveRealmContext("requestURL", "method", "path", realmHeader("realm1"))
.toCompletableFuture()
.join();
assertThat(RealmContext1.getRealmIdentifier()).isEqualTo("realm1");
RealmContext RealmContext2 =
resolver
.resolveRealmContext("requestURL", "method", "path", Map.of("Polaris-Header", "realm2"))
.resolveRealmContext("requestURL", "method", "path", realmHeader("realm2"))
.toCompletableFuture()
.join();
assertThat(RealmContext2.getRealmIdentifier()).isEqualTo("realm2");
Expand All @@ -64,11 +69,9 @@ void headerPresentFailure() {
assertThatThrownBy(
() ->
resolver
.resolveRealmContext(
"requestURL", "method", "path", Map.of("Polaris-Header", "realm3"))
.resolveRealmContext("requestURL", "method", "path", realmHeader("realm3"))
.toCompletableFuture()
.join())
.rootCause()
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Unknown realm: realm3");
}
Expand All @@ -79,7 +82,7 @@ void headerNotPresentSuccess() {
DefaultRealmContextResolver resolver = new DefaultRealmContextResolver(config);
RealmContext RealmContext1 =
resolver
.resolveRealmContext("requestURL", "method", "path", Map.of())
.resolveRealmContext("requestURL", "method", "path", k -> null)
.toCompletableFuture()
.join();
assertThat(RealmContext1.getRealmIdentifier()).isEqualTo("realm1");
Expand All @@ -92,14 +95,15 @@ void headerNotPresentFailure() {
assertThatThrownBy(
() ->
resolver
.resolveRealmContext("requestURL", "method", "path", Map.of())
.resolveRealmContext("requestURL", "method", "path", k -> null)
.toCompletableFuture()
.join())
.rootCause()
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Missing required realm header: Polaris-Header");
}

@SuppressWarnings("removal")
@Deprecated(forRemoval = true)
@Test
void headerCaseInsensitive() {
DefaultRealmContextResolver resolver = new DefaultRealmContextResolver(config);
Expand Down
Loading