Skip to content

Commit 89f0c2f

Browse files
ishant162noopurintelpayalcha
authored
[WIP][Workflow Interface] Resolving "Response code: StatusCode.RESOURCE_EXHAUSTED" (Issue #1565) in FederatedRuntime (#1572)
* grpc resource issue partial fix Signed-off-by: Ishant Thakare <[email protected]> * update director.proto Signed-off-by: Ishant Thakare <[email protected]> * fix rpc get_flow_state Signed-off-by: Ishant Thakare <[email protected]> * code cleanup Signed-off-by: Ishant Thakare <[email protected]> * code cleanup Signed-off-by: Ishant Thakare <[email protected]> * copyright changes Signed-off-by: Ishant Thakare <[email protected]> * update checkpoint rpc Signed-off-by: Ishant Thakare <[email protected]> * incorporated review comments Signed-off-by: Ishant Thakare <[email protected]> --------- Signed-off-by: Ishant Thakare <[email protected]> Co-authored-by: Noopur <[email protected]> Co-authored-by: Payal Chaurasiya <[email protected]>
1 parent d8568e1 commit 89f0c2f

File tree

6 files changed

+45
-32
lines changed

6 files changed

+45
-32
lines changed

openfl/experimental/workflow/protocols/aggregator.proto

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (C) 2020-2023 Intel Corporation
1+
// Copyright (C) 2020-2025 Intel Corporation
22
// Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.
33

44
syntax = "proto3";
@@ -9,9 +9,9 @@ import "openfl/protocols/base.proto";
99

1010

1111
service Aggregator {
12-
rpc SendTaskResults(TaskResultsRequest) returns (TaskResultsResponse) {}
13-
rpc GetTasks(GetTasksRequest) returns (GetTasksResponse) {}
14-
rpc CallCheckpoint(CheckpointRequest) returns (CheckpointResponse) {}
12+
rpc SendTaskResults(stream DataStream) returns (TaskResultsResponse) {}
13+
rpc GetTasks(GetTasksRequest) returns (stream DataStream) {}
14+
rpc CallCheckpoint(stream DataStream) returns (CheckpointResponse) {}
1515
}
1616

1717
message MessageHeader {

openfl/experimental/workflow/protocols/director.proto

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
// Copyright 2020-2024 Intel Corporation
1+
// Copyright 2020-2025 Intel Corporation
22
// SPDX-License-Identifier: Apache-2.0
33

44
syntax = "proto3";
55

66
package openfl.experimental.workflow.director;
77

8+
import "openfl/protocols/base.proto";
89
import "google/protobuf/timestamp.proto";
910
import "google/protobuf/duration.proto";
1011

@@ -18,7 +19,7 @@ service Director {
1819
//Runtime RPCs
1920
rpc SetNewExperiment(stream ExperimentInfo) returns (SetNewExperimentResponse) {}
2021
rpc GetEnvoys(GetEnvoysRequest) returns (GetEnvoysResponse) {}
21-
rpc GetFlowState(GetFlowStateRequest) returns (GetFlowStateResponse) {}
22+
rpc GetFlowState(GetFlowStateRequest) returns (stream DataStream) {}
2223
rpc ConnectRuntime(SendRuntimeRequest) returns (RuntimeRequestResponse) {}
2324
rpc GetExperimentStdout(GetExperimentStdoutRequest) returns (stream GetExperimentStdoutResponse) {}
2425
}

openfl/experimental/workflow/transport/grpc/aggregator_client.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2020-2024 Intel Corporation
1+
# Copyright 2020-2025 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

44

@@ -12,6 +12,7 @@
1212

1313
from openfl.experimental.workflow.protocols import aggregator_pb2, aggregator_pb2_grpc
1414
from openfl.experimental.workflow.transport.grpc.grpc_channel_options import channel_options
15+
from openfl.protocols.utils import datastream_to_proto, proto_to_datastream
1516

1617

1718
class ConstantBackoff:
@@ -280,7 +281,7 @@ def send_task_results(self, collaborator_name, round_number, next_step, clone_by
280281
execution_environment=clone_bytes,
281282
)
282283

283-
response = self.stub.SendTaskResults(request)
284+
response = self.stub.SendTaskResults(proto_to_datastream(request))
284285
self.validate_response(response, collaborator_name)
285286

286287
return response.header
@@ -291,8 +292,8 @@ def get_tasks(self, collaborator_name):
291292
"""Get tasks from the aggregator."""
292293
self._set_header(collaborator_name)
293294
request = aggregator_pb2.GetTasksRequest(header=self.header)
294-
295-
response = self.stub.GetTasks(request)
295+
response_stream = self.stub.GetTasks(request)
296+
response = datastream_to_proto(aggregator_pb2.GetTasksResponse(), response_stream)
296297
self.validate_response(response, collaborator_name)
297298

298299
return (
@@ -316,7 +317,7 @@ def call_checkpoint(self, collaborator_name, clone_bytes, function, stream_buffe
316317
stream_buffer=stream_buffer,
317318
)
318319

319-
response = self.stub.CallCheckpoint(request)
320+
response = self.stub.CallCheckpoint(proto_to_datastream(request))
320321
self.validate_response(response, collaborator_name)
321322

322323
return response.header

openfl/experimental/workflow/transport/grpc/aggregator_server.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2020-2024 Intel Corporation
1+
# Copyright 2020-2025 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

44

@@ -14,6 +14,7 @@
1414

1515
from openfl.experimental.workflow.protocols import aggregator_pb2, aggregator_pb2_grpc
1616
from openfl.experimental.workflow.transport.grpc.grpc_channel_options import channel_options
17+
from openfl.protocols.utils import datastream_to_proto, proto_to_datastream
1718

1819
logger = logging.getLogger(__name__)
1920

@@ -128,18 +129,20 @@ def check_request(self, request):
128129
)
129130

130131
def SendTaskResults(self, request, context): # NOQA:N802
131-
"""<FIND OUT WHAT COMMENT TO PUT HERE>.
132+
"""Processes a request from a collaborator to retrieve the results of a locally
133+
executed task.
132134
133135
Args:
134136
request: The gRPC message request
135137
context: The gRPC context
136138
"""
137-
self.validate_collaborator(request, context)
138-
self.check_request(request)
139-
collaborator_name = request.header.sender
140-
round_number = (request.round_number,)
141-
next_step = (request.next_step,)
142-
execution_environment = request.execution_environment
139+
proto = datastream_to_proto(aggregator_pb2.TaskResultsRequest(), request)
140+
self.validate_collaborator(proto, context)
141+
self.check_request(proto)
142+
collaborator_name = proto.header.sender
143+
round_number = (proto.round_number,)
144+
next_step = (proto.next_step,)
145+
execution_environment = proto.execution_environment
143146

144147
_ = self.aggregator.send_task_results(
145148
collaborator_name, round_number[0], next_step, execution_environment
@@ -160,14 +163,15 @@ def GetTasks(self, request, context): # NOQA:N802
160163

161164
rn, f, ee, st, q = self.aggregator.get_tasks(request.header.sender)
162165

163-
return aggregator_pb2.GetTasksResponse(
166+
response = aggregator_pb2.GetTasksResponse(
164167
header=self.get_header(collaborator_name),
165168
round_number=rn,
166169
function_name=f,
167170
execution_environment=ee,
168171
sleep_time=st,
169172
quit=q,
170173
)
174+
return proto_to_datastream(response)
171175

172176
def CallCheckpoint(self, request, context): # NOQA:N802
173177
"""Request aggregator to perform a checkpoint for a given function.
@@ -176,12 +180,13 @@ def CallCheckpoint(self, request, context): # NOQA:N802
176180
request: The gRPC message request
177181
context: The gRPC context
178182
"""
179-
self.validate_collaborator(request, context)
180-
self.check_request(request)
181-
collaborator_name = request.header.sender
182-
execution_environment = request.execution_environment
183-
function = request.function
184-
stream_buffer = request.stream_buffer
183+
proto = datastream_to_proto(aggregator_pb2.CheckpointRequest(), request)
184+
self.validate_collaborator(proto, context)
185+
self.check_request(proto)
186+
collaborator_name = proto.header.sender
187+
execution_environment = proto.execution_environment
188+
function = proto.function
189+
stream_buffer = proto.stream_buffer
185190

186191
self.aggregator.call_checkpoint(
187192
collaborator_name, execution_environment, function, stream_buffer

openfl/experimental/workflow/transport/grpc/director_client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2020-2024 Intel Corporation
1+
# Copyright 2020-2025 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

44
"""DirectorClient module."""
@@ -12,6 +12,7 @@
1212

1313
from openfl.experimental.workflow.protocols import director_pb2, director_pb2_grpc
1414
from openfl.experimental.workflow.transport.grpc.exceptions import EnvoyNotFoundError
15+
from openfl.protocols.utils import datastream_to_proto
1516

1617
from .grpc_channel_options import channel_options
1718

@@ -313,8 +314,8 @@ def get_flow_state(self) -> Tuple:
313314
- flspec_obj (object): The FLSpec object containing
314315
details of the updated flow state.
315316
"""
316-
response = self.stub.GetFlowState(director_pb2.GetFlowStateRequest())
317-
317+
response_stream = self.stub.GetFlowState(director_pb2.GetFlowStateRequest())
318+
response = datastream_to_proto(director_pb2.GetFlowStateResponse(), response_stream)
318319
return response.completed, response.flspec_obj
319320

320321
def stream_experiment_stdout(self, experiment_name) -> Iterator[Dict[str, Any]]:

openfl/experimental/workflow/transport/grpc/director_server.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2020-2024 Intel Corporation
1+
# Copyright 2020-2025 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

44
"""DirectorGRPCServer module."""
@@ -15,7 +15,7 @@
1515
from openfl.experimental.workflow.protocols import director_pb2, director_pb2_grpc
1616
from openfl.experimental.workflow.transport.grpc.exceptions import EnvoyNotFoundError
1717
from openfl.experimental.workflow.transport.grpc.grpc_channel_options import channel_options
18-
from openfl.protocols.utils import get_headers
18+
from openfl.protocols.utils import get_headers, proto_to_datastream
1919

2020
logger = logging.getLogger(__name__)
2121

@@ -326,7 +326,12 @@ async def GetFlowState(self, request, context) -> director_pb2.GetFlowStateRespo
326326
director_pb2.GetFlowStateResponse: The response to the request.
327327
"""
328328
status, flspec_obj = await self.director.get_flow_state()
329-
return director_pb2.GetFlowStateResponse(completed=status, flspec_obj=flspec_obj)
329+
response = director_pb2.GetFlowStateResponse(
330+
completed=status,
331+
flspec_obj=flspec_obj,
332+
)
333+
for chunk in proto_to_datastream(response):
334+
await context.write(chunk)
330335

331336
async def GetExperimentStdout(
332337
self, request, context

0 commit comments

Comments
 (0)