diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 98f596141eed..cc9501583641 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -88,8 +88,10 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcDispatcherClient; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillServer; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillStreamFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.auth.VendoredCredentialsAdapter; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCache; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCachingRemoteStubFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.FailoverChannel; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.IsolationChannel; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactoryFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactoryFactoryImpl; @@ -114,6 +116,8 @@ import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.util.construction.CoderTranslation; import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.auth.MoreCallCredentials; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheStats; @@ -381,7 +385,8 @@ private StreamingWorkerHarnessFactoryOutput createFanOutStreamingEngineWorkerHar MemoryMonitor memoryMonitor, GrpcDispatcherClient dispatcherClient) { WeightedSemaphore maxCommitByteSemaphore = Commits.maxCommitByteSemaphore(); - ChannelCache channelCache = createChannelCache(options, checkNotNull(configFetcher)); + ChannelCache channelCache = + createChannelCache(options, checkNotNull(configFetcher), dispatcherClient); @SuppressWarnings("methodref.receiver.bound") FanOutStreamingEngineWorkerHarness fanOutStreamingEngineWorkerHarness = FanOutStreamingEngineWorkerHarness.create( @@ -804,20 +809,37 @@ private static void validateWorkerOptions(DataflowWorkerHarnessOptions options) } private static ChannelCache createChannelCache( - DataflowWorkerHarnessOptions workerOptions, ComputationConfig.Fetcher configFetcher) { + DataflowWorkerHarnessOptions workerOptions, + ComputationConfig.Fetcher configFetcher, + GrpcDispatcherClient dispatcherClient) { ChannelCache channelCache = ChannelCache.create( (currentFlowControlSettings, serviceAddress) -> { - // IsolationChannel will create and manage separate RPC channels to the same - // serviceAddress. - return IsolationChannel.create( - () -> - remoteChannel( - serviceAddress, - workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), - currentFlowControlSettings), - currentFlowControlSettings.getOnReadyThresholdBytes()); + ManagedChannel primaryChannel = + IsolationChannel.create( + () -> + remoteChannel( + serviceAddress, + workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), + currentFlowControlSettings), + currentFlowControlSettings.getOnReadyThresholdBytes()); + // Create an isolated fallback channel from dispatcher endpoints. + // This ensures both primary and fallback use separate isolated channels. + ManagedChannel fallbackChannel = + IsolationChannel.create( + () -> + remoteChannel( + dispatcherClient.getDispatcherEndpoints().iterator().next(), + workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), + currentFlowControlSettings), + currentFlowControlSettings.getOnReadyThresholdBytes()); + return FailoverChannel.create( + primaryChannel, + fallbackChannel, + MoreCallCredentials.from( + new VendoredCredentialsAdapter(workerOptions.getGcpCredential()))); }); + configFetcher .getGlobalConfigHandle() .registerConfigObserver( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index 75c2b91af603..63ab5379bd49 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -410,15 +410,18 @@ private GlobalDataStreamSender getOrCreateGlobalDataSteam( } private WindmillStreamSender createAndStartWindmillStreamSender(Endpoint endpoint) { + GetWorkRequest.Builder getWorkRequestBuilder = + GetWorkRequest.newBuilder() + .setClientId(jobHeader.getClientId()) + .setJobId(jobHeader.getJobId()) + .setProjectId(jobHeader.getProjectId()) + .setWorkerId(jobHeader.getWorkerId()); + endpoint.workerToken().ifPresent(getWorkRequestBuilder::setBackendWorkerToken); + WindmillStreamSender windmillStreamSender = WindmillStreamSender.create( WindmillConnection.from(endpoint, this::createWindmillStub), - GetWorkRequest.newBuilder() - .setClientId(jobHeader.getClientId()) - .setJobId(jobHeader.getJobId()) - .setProjectId(jobHeader.getProjectId()) - .setWorkerId(jobHeader.getWorkerId()) - .build(), + getWorkRequestBuilder.build(), GetWorkBudget.noBudget(), streamFactory, workItemScheduler, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java index 82e66c4b0d74..0d8f75dd816a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java @@ -128,7 +128,7 @@ public CloudWindmillServiceV1Alpha1Stub getWindmillServiceStub() { : randomlySelectNextStub(windmillServiceStubs)); } - ImmutableSet getDispatcherEndpoints() { + public ImmutableSet getDispatcherEndpoints() { return dispatcherStubs.get().dispatcherEndpoints(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java new file mode 100644 index 000000000000..ea550f361f3e --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannel.java @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.LongSupplier; +import javax.annotation.Nullable; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallCredentials; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallOptions; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ConnectivityState; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Metadata; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.MethodDescriptor; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A {@link ManagedChannel} that wraps a primary and a fallback channel. + * + *

Routes requests to either primary or fallback channel based on two independent failover modes: + * + *

    + *
  • Connection Status Failover: If the primary channel is not ready for 10+ seconds + * (e.g., during network issues), routes to fallback channel. Switches back as soon as the + * primary channel becomes READY again. + *
  • RPC Failover: If primary channel RPC fails with transient errors ({@link + * Status.Code#UNAVAILABLE}, {@link Status.Code#DEADLINE_EXCEEDED}, or {@link + * Status.Code#UNKNOWN}), switches to fallback channel and waits for a 1-hour cooling period + * before retrying primary. + *
+ */ +@Internal +public final class FailoverChannel extends ManagedChannel { + private static final Logger LOG = LoggerFactory.getLogger(FailoverChannel.class); + // Time to wait before retrying the primary channel after an RPC-based fallback. + private static final long FALLBACK_COOLING_PERIOD_NANOS = TimeUnit.HOURS.toNanos(1); + private static final long PRIMARY_NOT_READY_WAIT_NANOS = TimeUnit.SECONDS.toNanos(10); + private final ManagedChannel primary; + @Nullable private final ManagedChannel fallback; + @Nullable private final CallCredentials fallbackCallCredentials; + // Set when primary's connection state has been unavailable for too long. + private final AtomicBoolean useFallbackDueToState = new AtomicBoolean(false); + // Set when an RPC on primary fails with a transient error. + private final AtomicBoolean useFallbackDueToRPC = new AtomicBoolean(false); + private final AtomicLong lastRPCFallbackTimeNanos = new AtomicLong(0); + private final AtomicLong primaryNotReadySinceNanos = new AtomicLong(-1); + private final LongSupplier nanoClock; + private final AtomicBoolean stateChangeListenerRegistered = new AtomicBoolean(false); + + private FailoverChannel( + ManagedChannel primary, + @Nullable ManagedChannel fallback, + @Nullable CallCredentials fallbackCallCredentials, + LongSupplier nanoClock) { + this.primary = primary; + this.fallback = fallback; + this.fallbackCallCredentials = fallbackCallCredentials; + this.nanoClock = nanoClock; + // Register callback to monitor primary channel state changes + registerPrimaryStateChangeListener(); + } + + // Test-only. + public static FailoverChannel create(ManagedChannel primary, ManagedChannel fallback) { + return new FailoverChannel(primary, fallback, null, System::nanoTime); + } + + public static FailoverChannel create( + ManagedChannel primary, ManagedChannel fallback, CallCredentials fallbackCallCredentials) { + return new FailoverChannel(primary, fallback, fallbackCallCredentials, System::nanoTime); + } + + static FailoverChannel forTest( + ManagedChannel primary, + ManagedChannel fallback, + CallCredentials fallbackCallCredentials, + LongSupplier nanoClock) { + return new FailoverChannel(primary, fallback, fallbackCallCredentials, nanoClock); + } + + @Override + public String authority() { + return primary.authority(); + } + + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions callOptions) { + // Check if the RPC-based cooling period has elapsed. + if (useFallbackDueToRPC.get()) { + long timeSinceLastFallback = nanoClock.getAsLong() - lastRPCFallbackTimeNanos.get(); + if (timeSinceLastFallback >= FALLBACK_COOLING_PERIOD_NANOS) { + if (useFallbackDueToRPC.compareAndSet(true, false)) { + LOG.info("Primary channel cooling period elapsed; switching back from fallback."); + } + } + } + + if (fallback != null && (useFallbackDueToRPC.get() || useFallbackDueToState.get())) { + return new FailoverClientCall<>( + fallback.newCall(methodDescriptor, applyFallbackCredentials(callOptions)), + true, + methodDescriptor.getFullMethodName()); + } + + // If primary has not become ready for a sustained period, fail over to fallback. + if (fallback != null && shouldFallBackDueToPrimaryState()) { + if (useFallbackDueToState.compareAndSet(false, true)) { + LOG.warn("Primary connection unavailable. Switching to secondary connection."); + } + return new FailoverClientCall<>( + fallback.newCall(methodDescriptor, applyFallbackCredentials(callOptions)), + true, + methodDescriptor.getFullMethodName()); + } + + return new FailoverClientCall<>( + primary.newCall(methodDescriptor, callOptions), + false, + methodDescriptor.getFullMethodName()); + } + + @Override + public ManagedChannel shutdown() { + primary.shutdown(); + if (fallback != null) { + fallback.shutdown(); + } + return this; + } + + @Override + public ManagedChannel shutdownNow() { + primary.shutdownNow(); + if (fallback != null) { + fallback.shutdownNow(); + } + return this; + } + + @Override + public boolean isShutdown() { + return primary.isShutdown() && (fallback == null || fallback.isShutdown()); + } + + @Override + public boolean isTerminated() { + return primary.isTerminated() && (fallback == null || fallback.isTerminated()); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + long endTimeNanos = nanoClock.getAsLong() + unit.toNanos(timeout); + boolean primaryTerminated = primary.awaitTermination(timeout, unit); + if (fallback != null) { + long remainingNanos = Math.max(0, endTimeNanos - nanoClock.getAsLong()); + return primaryTerminated && fallback.awaitTermination(remainingNanos, TimeUnit.NANOSECONDS); + } + return primaryTerminated; + } + + private boolean shouldFallbackBasedOnRPCStatus(Status status) { + switch (status.getCode()) { + case UNAVAILABLE: + case DEADLINE_EXCEEDED: + case UNKNOWN: + return true; + default: + return false; + } + } + + private boolean hasFallbackChannel() { + return fallback != null; + } + + private CallOptions applyFallbackCredentials(CallOptions callOptions) { + if (fallbackCallCredentials != null && callOptions.getCredentials() == null) { + return callOptions.withCallCredentials(fallbackCallCredentials); + } + return callOptions; + } + + private boolean shouldFallBackDueToPrimaryState() { + ConnectivityState connectivityState = primary.getState(true); + if (connectivityState == ConnectivityState.READY) { + primaryNotReadySinceNanos.set(-1); + return false; + } + long currentTimeNanos = nanoClock.getAsLong(); + if (primaryNotReadySinceNanos.get() < 0) { + primaryNotReadySinceNanos.set(currentTimeNanos); + } + return currentTimeNanos - primaryNotReadySinceNanos.get() > PRIMARY_NOT_READY_WAIT_NANOS; + } + + private void notifyFailure(Status status, boolean isFallback, String methodName) { + if (!status.isOk() + && !isFallback + && hasFallbackChannel() + && shouldFallbackBasedOnRPCStatus(status)) { + if (useFallbackDueToRPC.compareAndSet(false, true)) { + lastRPCFallbackTimeNanos.set(nanoClock.getAsLong()); + LOG.warn( + "Primary connection failed for method: {}. Switching to secondary connection. Status: {}", + methodName, + status.getCode()); + } + } else if (isFallback && !status.isOk()) { + LOG.warn( + "Secondary connection failed for method: {}. Status: {}", methodName, status.getCode()); + } + } + + private final class FailoverClientCall + extends SimpleForwardingClientCall { + private final boolean isFallback; + private final String methodName; + + /** + * @param delegate the underlying ClientCall (either primary or fallback) + * @param isFallback true if {@code delegate} is a fallback channel call, false if it is a + * primary channel call. This flag is inspected by {@link #notifyFailure} to determine + * whether a failure should trigger switching to the fallback channel (only primary failures + * do). + * @param methodName full gRPC method name (for logging) + */ + FailoverClientCall(ClientCall delegate, boolean isFallback, String methodName) { + super(delegate); + this.isFallback = isFallback; + this.methodName = methodName; + } + + @Override + public void start(Listener responseListener, Metadata headers) { + super.start( + new SimpleForwardingClientCallListener(responseListener) { + @Override + public void onClose(Status status, Metadata trailers) { + notifyFailure(status, isFallback, methodName); + super.onClose(status, trailers); + } + }, + headers); + } + } + + /** Registers callback for primary channel state changes. */ + private void registerPrimaryStateChangeListener() { + if (!stateChangeListenerRegistered.getAndSet(true)) { + try { + ConnectivityState currentState = primary.getState(false); + primary.notifyWhenStateChanged(currentState, this::onPrimaryStateChanged); + } catch (Exception e) { + LOG.warn( + "Failed to register channel state monitor. Continuing with fallback detection.", e); + stateChangeListenerRegistered.set(false); + } + } + } + + /** Callback invoked when primary channel connectivity state changes. */ + private void onPrimaryStateChanged() { + if (isShutdown() || isTerminated()) { + return; + } + + // If primary is READY, clear state-based fallback immediately. + if (primary.getState(false) == ConnectivityState.READY) { + if (useFallbackDueToState.compareAndSet(true, false)) { + LOG.info("Primary channel recovered; switching back from fallback."); + } + } + + // Always re-register for next state change (unless shutdown) + if (!isShutdown() && !isTerminated()) { + stateChangeListenerRegistered.set(false); + registerPrimaryStateChangeListener(); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java index 94c8f4b75957..b5f244f77eb6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java @@ -132,7 +132,7 @@ private static WorkItemScheduler noOpProcessWorkItemFn() { getWorkStreamLatencies) -> {}; } - private static GetWorkRequest getWorkRequest(long items, long bytes) { + private static GetWorkRequest getWorkRequest(long items, long bytes, String backendWorkerToken) { return GetWorkRequest.newBuilder() .setJobId(JOB_ID) .setProjectId(PROJECT_ID) @@ -140,6 +140,7 @@ private static GetWorkRequest getWorkRequest(long items, long bytes) { .setClientId(JOB_HEADER.getClientId()) .setMaxItems(items) .setMaxBytes(bytes) + .setBackendWorkerToken(backendWorkerToken) .build(); } @@ -239,9 +240,22 @@ public void testStreamsStartCorrectly() throws InterruptedException { .distributeBudget( any(), eq(GetWorkBudget.builder().setItems(items).setBytes(bytes).build())); - verify(streamFactory, times(2)) + verify(streamFactory, times(1)) .createDirectGetWorkStream( - any(), eq(getWorkRequest(0, 0)), any(), any(), any(), eq(noOpProcessWorkItemFn())); + any(), + eq(getWorkRequest(0, 0, workerToken)), + any(), + any(), + any(), + eq(noOpProcessWorkItemFn())); + verify(streamFactory, times(1)) + .createDirectGetWorkStream( + any(), + eq(getWorkRequest(0, 0, workerToken2)), + any(), + any(), + any(), + eq(noOpProcessWorkItemFn())); verify(streamFactory, times(2)).createDirectGetDataStream(any()); verify(streamFactory, times(2)).createDirectCommitWorkStream(any()); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java new file mode 100644 index 000000000000..c98004943613 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/FailoverChannelTest.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallCredentials; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallOptions; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall.Listener; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ConnectivityState; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Metadata; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.MethodDescriptor; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; + +@RunWith(JUnit4.class) +public class FailoverChannelTest { + + private MethodDescriptor methodDescriptor = + MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNARY) + .setFullMethodName(MethodDescriptor.generateFullMethodName("test", "test")) + .setRequestMarshaller(new IsolationChannelTest.NoopMarshaller()) + .setResponseMarshaller(new IsolationChannelTest.NoopMarshaller()) + .build(); + + @Test + public void testRPCFailureTriggersFallback() throws Exception { + // RPC failure with UNAVAILABLE should switch to fallback channel. + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall underlyingCall = mock(ClientCall.class); + ClientCall fallbackCall = mock(ClientCall.class); + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(fallbackCall); + + FailoverChannel failoverChannel = FailoverChannel.create(mockChannel, mockFallbackChannel); + + ClientCall call1 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + Metadata metadata1 = new Metadata(); + call1.start(new NoopClientCall.NoopClientCallListener<>(), metadata1); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(underlyingCall).start(captor.capture(), same(metadata1)); + captor.getValue().onClose(Status.UNAVAILABLE, new Metadata()); + + ClientCall call2 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + call2.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + verify(mockFallbackChannel, atLeastOnce()).newCall(any(), any()); + } + + @Test + public void testRPCFailureRecoveryAfterCoolingPeriod() throws Exception { + // After RPC failure, channel stays on fallback during cooling period, then returns to primary. + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall underlyingCall = mock(ClientCall.class); + ClientCall fallbackCall = mock(ClientCall.class); + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall, mock(ClientCall.class)); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(fallbackCall); + when(mockChannel.getState(true)).thenReturn(ConnectivityState.READY); + + AtomicLong time = new AtomicLong(0); + FailoverChannel failoverChannel = + FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + + // Trigger RPC failure fallback + ClientCall call1 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + call1.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(underlyingCall).start(captor.capture(), any()); + captor.getValue().onClose(Status.UNAVAILABLE, new Metadata()); + + // Within cooling period: still on fallback + time.addAndGet(TimeUnit.MINUTES.toNanos(30)); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockFallbackChannel, atLeastOnce()).newCall(any(), any()); + + // After cooling period: recovers to primary + time.addAndGet(TimeUnit.MINUTES.toNanos(40)); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel, atLeast(2)).newCall(any(), any()); + } + + @Test + public void testFallbackWithCredentials() throws Exception { + // Fallback channel should receive custom credentials when provided. + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + ClientCall underlyingCall = mock(ClientCall.class); + CallCredentials mockCredentials = mock(CallCredentials.class); + when(mockChannel.newCall(any(), any())).thenReturn(underlyingCall); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + + FailoverChannel failoverChannel = + FailoverChannel.create(mockChannel, mockFallbackChannel, mockCredentials); + + ClientCall call1 = + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + call1.start(new NoopClientCall.NoopClientCallListener<>(), new Metadata()); + ArgumentCaptor> captor = ArgumentCaptor.forClass(ClientCall.Listener.class); + verify(underlyingCall).start(captor.capture(), any()); + captor.getValue().onClose(Status.UNAVAILABLE, new Metadata()); + + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + + ArgumentCaptor optionsCaptor = ArgumentCaptor.forClass(CallOptions.class); + verify(mockFallbackChannel).newCall(same(methodDescriptor), optionsCaptor.capture()); + assertEquals(mockCredentials, optionsCaptor.getValue().getCredentials()); + } + + @Test + public void testStateFallbackAfterPrimaryNotReady() { + // If primary connection is not ready for 10+ seconds, routes to fallback. + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + when(mockChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + when(mockChannel.getState(true)).thenReturn(ConnectivityState.IDLE, ConnectivityState.IDLE); + + AtomicLong time = new AtomicLong(0); + FailoverChannel failoverChannel = + FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + + // Within 10 seconds: still routes to primary + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel).newCall(any(), any()); + + // After 10 seconds: routes to fallback + time.addAndGet(TimeUnit.SECONDS.toNanos(11)); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockFallbackChannel).newCall(any(), any()); + } + + @Test + public void testStateBasedFallbackRecoveryViaCallback() { + // After state-based fallback, recovery to primary is immediate when callback fires with READY. + ManagedChannel mockChannel = mock(ManagedChannel.class); + ManagedChannel mockFallbackChannel = mock(ManagedChannel.class); + when(mockChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + when(mockFallbackChannel.newCall(any(), any())).thenReturn(mock(ClientCall.class)); + // getState(true): IDLE starts timer, IDLE exceeds timer, READY on recovery check + when(mockChannel.getState(true)) + .thenReturn(ConnectivityState.IDLE, ConnectivityState.IDLE, ConnectivityState.READY); + // getState(false): IDLE for constructor registration, READY when callback fires + when(mockChannel.getState(false)) + .thenReturn(ConnectivityState.IDLE, ConnectivityState.READY, ConnectivityState.READY); + + AtomicReference stateChangeCallback = new AtomicReference<>(); + doAnswer( + invocation -> { + stateChangeCallback.set(invocation.getArgument(1)); + return null; + }) + .when(mockChannel) + .notifyWhenStateChanged(any(), any()); + + AtomicLong time = new AtomicLong(0); + FailoverChannel failoverChannel = + FailoverChannel.forTest(mockChannel, mockFallbackChannel, null, time::get); + + // First call - primary not yet timed out, routes to primary + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel).newCall(any(), any()); + + // After 10 seconds: state-based fallback kicks in + time.addAndGet(TimeUnit.SECONDS.toNanos(11)); + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockFallbackChannel).newCall(any(), any()); + + // Callback fires with primary now READY: clears state flag immediately + stateChangeCallback.get().run(); + + // Next call recovers to primary with no waiting + failoverChannel.newCall(methodDescriptor, CallOptions.DEFAULT); + verify(mockChannel, atLeast(2)).newCall(any(), any()); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/IsolationChannelTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/IsolationChannelTest.java index 20321bbd66c3..580bf873d916 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/IsolationChannelTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/IsolationChannelTest.java @@ -49,42 +49,12 @@ import org.mockito.ArgumentCaptor; import org.mockito.InOrder; -/** - * {@link NoopClientCall} is a class that is designed for use in tests. It is designed to be used in - * places where a scriptable call is necessary. By default, all methods are noops, and designed to - * be overridden. - */ -class NoopClientCall extends ClientCall { - - /** - * {@link NoopClientCall.NoopClientCallListener} is a class that is designed for use in tests. It - * is designed to be used in places where a scriptable call listener is necessary. By default, all - * methods are noops, and designed to be overridden. - */ - public static class NoopClientCallListener extends ClientCall.Listener {} - - @Override - public void start(ClientCall.Listener listener, Metadata headers) {} - - @Override - public void request(int numMessages) {} - - @Override - public void cancel(String message, Throwable cause) {} - - @Override - public void halfClose() {} - - @Override - public void sendMessage(ReqT message) {} -} - @RunWith(JUnit4.class) public class IsolationChannelTest { private Supplier channelSupplier = mock(Supplier.class); - private static class NoopMarshaller implements Marshaller { + public static class NoopMarshaller implements Marshaller { @Override public InputStream stream(Object o) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/NoopClientCall.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/NoopClientCall.java new file mode 100644 index 000000000000..93a421e0d618 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/NoopClientCall.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs; + +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Metadata; + +/** + * {@link NoopClientCall} is a class that is designed for use in tests. It is designed to be used in + * places where a scriptable call is necessary. By default, all methods are noops, and designed to + * be overridden. + */ +public class NoopClientCall extends ClientCall { + + /** + * {@link NoopClientCall.NoopClientCallListener} is a class that is designed for use in tests. It + * is designed to be used in places where a scriptable call listener is necessary. By default, all + * methods are noops, and designed to be overridden. + */ + public static class NoopClientCallListener extends ClientCall.Listener {} + + @Override + public void start(ClientCall.Listener listener, Metadata headers) {} + + @Override + public void request(int numMessages) {} + + @Override + public void cancel(String message, Throwable cause) {} + + @Override + public void halfClose() {} + + @Override + public void sendMessage(ReqT message) {} +} diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto index a4b3df906dd9..6286b2d67110 100644 --- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto +++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto @@ -470,6 +470,8 @@ message GetWorkRequest { optional string project_id = 7; optional int64 max_items = 2 [default = 0xffffffff]; optional int64 max_bytes = 3 [default = 0x7fffffffffffffff]; + repeated string computation_id_filter = 8; + optional string backend_worker_token = 9; reserved 6; }