11# Copyright 2025 Google LLC
22#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
66#
77# http://www.apache.org/licenses/LICENSE-2.0
88#
1818and provides a simple interface for scheduling ClusterFuzz tasks.
1919"""
2020import collections
21+ import json
22+ import random
2123import threading
2224from typing import Dict
2325from typing import List
2426from typing import Tuple
27+ import urllib .request
2528import uuid
2629
30+ import google .auth .transport .requests
2731from google .cloud import batch_v1 as batch
2832
33+ from clusterfuzz ._internal .base import memoize
2934from clusterfuzz ._internal .base import retry
3035from clusterfuzz ._internal .base import tasks
3136from clusterfuzz ._internal .base import utils
# See https://cloud.google.com/batch/quotas#job_limits
MAX_CONCURRENT_VMS_PER_JOB = 1000

# Maximum number of QUEUED Batch jobs tolerated in a region before
# _get_subconfig treats the region as overloaded and skips it.
MAX_QUEUE_SIZE = 100


class AllRegionsOverloadedError(Exception):
  """Raised when all batch regions are overloaded."""


# Thread-local storage. NOTE(review): its attributes are not set in this
# chunk — presumably holds per-thread client state; confirm against the
# rest of the file.
_local = threading.local()

# By default, failed batch tasks are not retried.
DEFAULT_RETRY_COUNT = 0
@@ -184,14 +196,58 @@ def count_queued_or_scheduled_tasks(project: str,
184196 return (queued , scheduled )
185197
186198
@memoize.wrap(memoize.InMemory(60))
def get_region_load(project: str, region: str) -> int:
  """Gets the current load (number of QUEUED jobs) for a Batch region.

  Queries the GCP Batch `jobs:countByState` endpoint (v1alpha) for QUEUED
  jobs only. Results are memoized for 60 seconds to limit API traffic.

  Args:
    project: GCP project ID.
    region: GCP region name (e.g. 'us-central1').

  Returns:
    The number of QUEUED jobs in the region, or 0 on any failure. Failures
    are deliberately swallowed (best-effort): an unreadable region must not
    block task scheduling.
  """
  try:
    # Credential acquisition/refresh lives inside the try so that auth
    # failures also degrade to "load unknown" (0) instead of propagating
    # and crashing the scheduler.
    creds, _ = credentials.get_default()
    if not creds.valid:
      creds.refresh(google.auth.transport.requests.Request())

    headers = {
        'Authorization': f'Bearer {creds.token}',
        'Content-Type': 'application/json'
    }

    url = (f'https://batch.googleapis.com/v1alpha/projects/{project}/locations/'
           f'{region}/jobs:countByState?states=QUEUED')
    req = urllib.request.Request(url, headers=headers)
    with urllib.request.urlopen(req) as response:
      # Defensive only: urlopen raises HTTPError for non-2xx statuses, which
      # the except clause below already handles.
      if response.status != 200:
        logs.error(
            f'Batch countByState failed: {response.status} {response.read()}')
        return 0

      data = json.loads(response.read())
      logs.info(f'Batch countByState response for {region}: {data}')
      # Expected shape: {"jobCounts": {"QUEUED": "10"}} — counts arrive as
      # JSON strings.
      total = 0
      job_counts = data.get('jobCounts', {})
      for state, count in job_counts.items():
        count = int(count)
        if state == 'QUEUED':
          total += count
        else:
          # We only requested QUEUED; anything else signals an API change.
          logs.error(f'Unknown state: {state}')

      return total
  except Exception as e:
    logs.error(f'Failed to get region load for {region}: {e}')
    return 0
242+
def _get_batch_config():
  """Loads the batch configuration (thin wrapper to simplify mocking)."""
  config = local_config.BatchConfig()
  return config
191247
192248def is_remote_task (command : str , job_name : str ) -> bool :
193249 """Returns whether a task is configured to run remotely on GCP Batch.
194-
250+
195251 This is determined by checking if a valid batch workload specification can
196252 be found for the given command and job type.
197253 """
@@ -242,15 +298,46 @@ def _get_config_names(batch_tasks: List[remote_task_types.RemoteTask]):
242298
243299
def _get_subconfig(batch_config, instance_spec):
  """Picks a subconfig (region variant) for an instance spec.

  When `queue_check_regions` is not configured, falls back to a weighted
  random choice over the instance's subconfigs. Otherwise, each subconfig
  whose region appears in `queue_check_regions` is checked against the
  Batch queue load and skipped when at or above MAX_QUEUE_SIZE; one of the
  remaining (healthy) subconfigs is then chosen uniformly at random.

  Raises:
    AllRegionsOverloadedError: If every candidate subconfig was skipped
      due to queue overload.
  """
  all_subconfigs = batch_config.get('subconfigs', {})
  instance_subconfigs = instance_spec['subconfigs']

  queue_check_regions = batch_config.get('queue_check_regions')
  if not queue_check_regions:
    # Legacy path: no load checking; honor the configured weights.
    logs.info(
        'Skipping batch load check because queue_check_regions is not configured.'
    )
    weighted_subconfigs = [
        WeightedSubconfig(subconfig['name'], subconfig['weight'])
        for subconfig in instance_subconfigs
    ]
    weighted_subconfig = utils.random_weighted_choice(weighted_subconfigs)
    return all_subconfigs[weighted_subconfig.name]

  # Check load for configured regions.
  healthy_subconfigs = []
  project = batch_config.get('project')

  for subconfig in instance_subconfigs:
    name = subconfig['name']
    conf = all_subconfigs[name]
    region = conf['region']

    # Regions not listed in queue_check_regions are always treated as
    # healthy — only listed regions incur the load lookup.
    if region in queue_check_regions:
      load = get_region_load(project, region)
      logs.info(f'Region {region} has {load} queued jobs.')
      if load >= MAX_QUEUE_SIZE:
        logs.info(f'Region {region} overloaded (load={load}). Skipping.')
        continue

    healthy_subconfigs.append(name)

  if not healthy_subconfigs:
    logs.error('All candidate regions are overloaded.')
    raise AllRegionsOverloadedError('All candidate regions are overloaded.')

  # Randomly pick one from healthy regions to avoid thundering herd.
  # NOTE(review): this pick is uniform and ignores subconfig weights, unlike
  # the legacy path above — confirm this asymmetry is intended.
  chosen_name = random.choice(healthy_subconfigs)
  return all_subconfigs[chosen_name]
254341
255342
256343def _get_specs_from_config (
@@ -277,7 +364,6 @@ def _get_specs_from_config(
277364 versioned_images_map = instance_spec .get ('versioned_docker_images' )
278365 if (base_os_version and versioned_images_map and
279366 base_os_version in versioned_images_map ):
280- # New path: Use the versioned image if specified and available.
281367 docker_image_uri = versioned_images_map [base_os_version ]
282368 else :
283369 # Fallback/legacy path: Use the original docker_image key.
@@ -324,7 +410,7 @@ def _get_specs_from_config(
324410
325411class GcpBatchService (remote_task_types .RemoteTaskInterface ):
326412 """A high-level service for creating and managing remote tasks.
327-
413+
328414 This service provides a simple interface for scheduling ClusterFuzz tasks on
329415 GCP Batch. It handles the details of creating batch jobs and tasks, and
330416 provides a way to check if a task is configured to run remotely.
@@ -383,20 +469,27 @@ def create_utask_main_job(self, module: str, job_type: str,
383469 def create_utask_main_jobs (self ,
384470 remote_tasks : List [remote_task_types .RemoteTask ]):
385471 """Creates a batch job for a list of uworker main tasks.
386-
472+
387473 This method groups the tasks by their workload specification and creates a
388474 separate batch job for each group. This allows tasks with similar
389475 requirements to be processed together, which can improve efficiency.
390476 """
391477 job_specs = collections .defaultdict (list )
392- specs = _get_specs_from_config (remote_tasks )
478+ try :
479+ specs = _get_specs_from_config (remote_tasks )
480+
481+ # Return the remote tasks as uncreated task
482+ # if all regions are overloaded
483+ except AllRegionsOverloadedError :
484+ return remote_tasks
485+
393486 for remote_task in remote_tasks :
394487 logs .info (f'Scheduling { remote_task .command } , { remote_task .job_type } .' )
395488 spec = specs [(remote_task .command , remote_task .job_type )]
396489 job_specs [spec ].append (remote_task .input_download_url )
397490
398491 logs .info ('Creating batch jobs.' )
399- logs . info ( 'Batching utask_mains.' )
492+
400493 for spec , input_urls in job_specs .items ():
401494 for input_urls_portion in utils .batched (input_urls ,
402495 MAX_CONCURRENT_VMS_PER_JOB - 1 ):
0 commit comments