Skip to content

Commit 8e17f2b

Browse files
authored
Implement job limiter for GCP Batch (#5158)
It implements a job limiter for the GCP Batch adapter for remote tasks. It uses a private API to check the availability of the regions for scheduling jobs; if all of them are overloaded, the tasks are returned as unscheduled tasks and sent back to the queue. Signed-off-by: Javan Lacerda <javanlacerda@google.com>
1 parent 7375211 commit 8e17f2b

File tree

4 files changed

+294
-22
lines changed

4 files changed

+294
-22
lines changed

configs/test/batch/batch.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ mapping:
7676
name: east4-network2
7777
weight: 1
7878
project: 'test-clusterfuzz'
79+
queue_check_regions:
80+
- us-central1
81+
- us-east4
7982
subconfigs:
8083
central1-network1:
8184
region: 'us-central1'

src/clusterfuzz/_internal/base/memoize.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import functools
1818
import json
1919
import threading
20+
import time
2021

2122
from clusterfuzz._internal.base import persistent_cache
2223
from clusterfuzz._internal.metrics import logs
@@ -89,6 +90,30 @@ def get_key(self, func, args, kwargs):
8990
return _default_key(func, args, kwargs)
9091

9192

93+
class InMemory(FifoInMemory):
  """FIFO in-memory cache whose entries expire after a fixed TTL.

  Each value is stored together with an absolute expiration timestamp;
  reading an expired entry behaves like a cache miss.
  """

  def __init__(self, ttl_in_seconds, capacity=1000):
    super().__init__(capacity)
    self.ttl_in_seconds = ttl_in_seconds

  def put(self, key, value):
    """Put (key, value) into cache, stamped with its expiration time."""
    deadline = time.time() + self.ttl_in_seconds
    super().put(key, (value, deadline))

  def get(self, key):
    """Get the value for `key`, or None if it is absent or expired."""
    stored = super().get(key)
    if stored is None:
      return None

    value, deadline = stored
    # An expired entry is reported as a miss; it is not evicted here and
    # ages out of the underlying FIFO naturally.
    if deadline < time.time():
      return None

    return value
115+
116+
92117
class FifoOnDisk:
93118
"""On-disk caching engine."""
94119

src/clusterfuzz/_internal/batch/service.py

Lines changed: 109 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Copyright 2025 Google LLC
22
#
3-
# Licensed under the Apache License, Version 2.0
4-
# (the "License"); you may not use this file except in compliance with
5-
# the License. You may obtain a copy of the License at
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
66
#
77
# http://www.apache.org/licenses/LICENSE-2.0
88
#
@@ -18,14 +18,19 @@
1818
and provides a simple interface for scheduling ClusterFuzz tasks.
1919
"""
2020
import collections
21+
import json
22+
import random
2123
import threading
2224
from typing import Dict
2325
from typing import List
2426
from typing import Tuple
27+
import urllib.request
2528
import uuid
2629

30+
import google.auth.transport.requests
2731
from google.cloud import batch_v1 as batch
2832

33+
from clusterfuzz._internal.base import memoize
2934
from clusterfuzz._internal.base import retry
3035
from clusterfuzz._internal.base import tasks
3136
from clusterfuzz._internal.base import utils
@@ -65,6 +70,13 @@
6570
# See https://cloud.google.com/batch/quotas#job_limits
6671
MAX_CONCURRENT_VMS_PER_JOB = 1000
6772

73+
MAX_QUEUE_SIZE = 100
74+
75+
76+
class AllRegionsOverloadedError(Exception):
77+
"""Raised when all batch regions are overloaded."""
78+
79+
6880
_local = threading.local()
6981

7082
DEFAULT_RETRY_COUNT = 0
@@ -184,14 +196,58 @@ def count_queued_or_scheduled_tasks(project: str,
184196
return (queued, scheduled)
185197

186198

199+
@memoize.wrap(memoize.InMemory(60))
def get_region_load(project: str, region: str) -> int:
  """Gets the current load (queued jobs) for a region.

  Queries the Batch `jobs:countByState` API (v1alpha) for the number of
  QUEUED jobs in `region`. Results are memoized for 60 seconds to limit
  API traffic. On any failure the region is treated as unloaded and 0 is
  returned, so scheduling degrades to the previous (uncapped) behavior.

  Args:
    project: GCP project id.
    region: GCP region name (e.g. 'us-central1').

  Returns:
    The number of QUEUED Batch jobs in the region, or 0 on error.
  """
  creds, _ = credentials.get_default()
  if not creds.valid:
    creds.refresh(google.auth.transport.requests.Request())

  headers = {
      'Authorization': f'Bearer {creds.token}',
      'Content-Type': 'application/json'
  }

  url = (f'https://batch.googleapis.com/v1alpha/projects/{project}/locations/'
         f'{region}/jobs:countByState?states=QUEUED')
  try:
    req = urllib.request.Request(url, headers=headers)
    with urllib.request.urlopen(req) as response:
      data = json.loads(response.read())
  except urllib.error.HTTPError as e:
    # urlopen() raises HTTPError for non-2xx responses, so HTTP failures
    # must be handled here; checking `response.status` after a successful
    # urlopen() would be dead code.
    logs.error(f'Batch countByState failed: {e.code} {e.read()}')
    return 0
  except Exception as e:
    # Best effort: a load-check failure should never block scheduling.
    logs.error(f'Failed to get region load for {region}: {e}')
    return 0

  logs.info(f'Batch countByState response for {region}: {data}')
  # Expected response shape: {"jobCounts": {"QUEUED": "10"}} — counts are
  # returned as strings.
  total = 0
  job_counts = data.get('jobCounts', {})
  for state, count in job_counts.items():
    if state == 'QUEUED':
      total += int(count)
    else:
      logs.error(f'Unknown state: {state}')

  return total
241+
242+
187243
def _get_batch_config():
188244
"""Returns the batch config. This function was made to make mocking easier."""
189245
return local_config.BatchConfig()
190246

191247

192248
def is_remote_task(command: str, job_name: str) -> bool:
193249
"""Returns whether a task is configured to run remotely on GCP Batch.
194-
250+
195251
This is determined by checking if a valid batch workload specification can
196252
be found for the given command and job type.
197253
"""
@@ -242,15 +298,46 @@ def _get_config_names(batch_tasks: List[remote_task_types.RemoteTask]):
242298

243299

244300
def _get_subconfig(batch_config, instance_spec):
  """Selects a subconfig for the instance, avoiding overloaded regions.

  When `queue_check_regions` is configured, subconfigs whose region reports
  a queued-job load at or above MAX_QUEUE_SIZE are excluded and one of the
  remaining subconfigs is chosen uniformly at random. Without that setting,
  a weighted random choice over all of the instance's subconfigs is made.

  Raises:
    AllRegionsOverloadedError: if every candidate region is overloaded.
  """
  all_subconfigs = batch_config.get('subconfigs', {})
  instance_subconfigs = instance_spec['subconfigs']

  regions_to_check = batch_config.get('queue_check_regions')
  if not regions_to_check:
    logs.info(
        'Skipping batch load check because queue_check_regions is not configured.'
    )
    candidates = [
        WeightedSubconfig(entry['name'], entry['weight'])
        for entry in instance_subconfigs
    ]
    picked = utils.random_weighted_choice(candidates)
    return all_subconfigs[picked.name]

  # Keep only subconfigs whose region is below the queue-size threshold.
  project = batch_config.get('project')
  available = []

  for entry in instance_subconfigs:
    subconfig_name = entry['name']
    region = all_subconfigs[subconfig_name]['region']

    if region in regions_to_check:
      load = get_region_load(project, region)
      logs.info(f'Region {region} has {load} queued jobs.')
      if load >= MAX_QUEUE_SIZE:
        logs.info(f'Region {region} overloaded (load={load}). Skipping.')
        continue

    available.append(subconfig_name)

  if not available:
    logs.error('All candidate regions are overloaded.')
    raise AllRegionsOverloadedError('All candidate regions are overloaded.')

  # Randomly pick one from healthy regions to avoid thundering herd.
  # NOTE(review): configured subconfig weights are ignored on this path —
  # confirm that is intended.
  return all_subconfigs[random.choice(available)]
254341

255342

256343
def _get_specs_from_config(
@@ -277,7 +364,6 @@ def _get_specs_from_config(
277364
versioned_images_map = instance_spec.get('versioned_docker_images')
278365
if (base_os_version and versioned_images_map and
279366
base_os_version in versioned_images_map):
280-
# New path: Use the versioned image if specified and available.
281367
docker_image_uri = versioned_images_map[base_os_version]
282368
else:
283369
# Fallback/legacy path: Use the original docker_image key.
@@ -324,7 +410,7 @@ def _get_specs_from_config(
324410

325411
class GcpBatchService(remote_task_types.RemoteTaskInterface):
326412
"""A high-level service for creating and managing remote tasks.
327-
413+
328414
This service provides a simple interface for scheduling ClusterFuzz tasks on
329415
GCP Batch. It handles the details of creating batch jobs and tasks, and
330416
provides a way to check if a task is configured to run remotely.
@@ -383,20 +469,27 @@ def create_utask_main_job(self, module: str, job_type: str,
383469
def create_utask_main_jobs(self,
384470
remote_tasks: List[remote_task_types.RemoteTask]):
385471
"""Creates a batch job for a list of uworker main tasks.
386-
472+
387473
This method groups the tasks by their workload specification and creates a
388474
separate batch job for each group. This allows tasks with similar
389475
requirements to be processed together, which can improve efficiency.
390476
"""
391477
job_specs = collections.defaultdict(list)
392-
specs = _get_specs_from_config(remote_tasks)
478+
try:
479+
specs = _get_specs_from_config(remote_tasks)
480+
481+
# Return the remote tasks as uncreated task
482+
# if all regions are overloaded
483+
except AllRegionsOverloadedError:
484+
return remote_tasks
485+
393486
for remote_task in remote_tasks:
394487
logs.info(f'Scheduling {remote_task.command}, {remote_task.job_type}.')
395488
spec = specs[(remote_task.command, remote_task.job_type)]
396489
job_specs[spec].append(remote_task.input_download_url)
397490

398491
logs.info('Creating batch jobs.')
399-
logs.info('Batching utask_mains.')
492+
400493
for spec, input_urls in job_specs.items():
401494
for input_urls_portion in utils.batched(input_urls,
402495
MAX_CONCURRENT_VMS_PER_JOB - 1):

0 commit comments

Comments
 (0)