diff --git a/src/web_app_handler.py b/src/web_app_handler.py index 5b96c07..69bec09 100644 --- a/src/web_app_handler.py +++ b/src/web_app_handler.py @@ -6,9 +6,11 @@ import os from typing import NamedTuple +import requests + from .github_app import GithubAppToken from .github_sdk import GithubClient -from src.sentry_config import fetch_dsn_for_github_org +from .workflow_job_collector import WorkflowJobCollector LOGGING_LEVEL = os.environ.get("LOGGING_LEVEL", logging.INFO) logger = logging.getLogger(__name__) @@ -16,49 +18,149 @@ class WebAppHandler: + """ + Handles GitHub webhook events for workflow job completion. + + Supports both hierarchical workflow tracing (new) and individual job tracing (legacy). + The mode is controlled by the ENABLE_HIERARCHICAL_TRACING environment variable. + """ + def __init__(self, dry_run=False): + """ + Initialize the WebAppHandler. + + Args: + dry_run: If True, simulates operations without sending traces + """ self.config = init_config() self.dry_run = dry_run + self.job_collectors = {} # org -> WorkflowJobCollector + + def _get_job_collector(self, org: str, token: str, dsn: str) -> WorkflowJobCollector: + """ + Get or create a job collector for the organization. + + Args: + org: GitHub organization name + token: GitHub API token + dsn: Sentry DSN for trace submission + + Returns: + WorkflowJobCollector instance for the organization + """ + if org not in self.job_collectors: + self.job_collectors[org] = WorkflowJobCollector(dsn, token, self.dry_run) + return self.job_collectors[org] + + def _send_legacy_trace(self, data: dict, org: str, token: str, dsn: str) -> None: + """ + Send individual job trace (legacy behavior). + + Args: + data: GitHub webhook job payload + org: GitHub organization name + token: GitHub API token + dsn: Sentry DSN for trace submission + """ + logger.info(f"Using legacy individual job tracing for org '{org}'") + github_client = GithubClient(token, dsn, self.dry_run) + github_client.send_trace(data) def handle_event(self, data, headers): - # We return 200 to make webhook not turn red since everything got processed well + """ + Handle GitHub webhook events. + + Supports both hierarchical workflow tracing (new) and individual job tracing (legacy). + The mode is determined by feature flags and organization settings. + + Args: + data: GitHub webhook payload + headers: HTTP headers from the webhook request + + Returns: + Tuple of (reason, http_code) + """ http_code = 200 reason = "OK" - if headers["X-GitHub-Event"] != "workflow_job": + # Flask normalizes headers - try both original and normalized names + github_event = headers.get("X-GitHub-Event") or headers.get("X-GITHUB-EVENT") or headers.get("HTTP_X_GITHUB_EVENT") + if not github_event: + reason = "Missing X-GitHub-Event header." + http_code = 400 + logger.warning("Missing X-GitHub-Event header") + elif github_event != "workflow_job": reason = "Event not supported." - elif data["action"] != "completed": + logger.info(f"Event '{github_event}' not supported, only 'workflow_job' is supported") + elif data.get("action") != "completed": reason = "We cannot do anything with this workflow state." 
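+            # GitHub also delivers workflow_job events for "queued" and
+            # "in_progress"; only the final "completed" delivery carries a
+            # conclusion and timing for every step, so other states are skipped.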
+ logger.info(f"Action '{data.get('action')}' not supported, only 'completed' is supported") else: - # For now, this simplifies testing + # Log webhook received + workflow_job = data.get("workflow_job", {}) + run_id = workflow_job.get("run_id") + job_id = workflow_job.get("id") + job_name = workflow_job.get("name") + logger.info(f"Received webhook for workflow run {run_id}, job '{job_name}' (ID: {job_id})") if self.dry_run: return reason, http_code - installation_id = data["installation"]["id"] + # Handle missing installation field (for webhook testing) + installation_id = data.get("installation", {}).get("id") org = data["repository"]["owner"]["login"] - # We are executing in Github App mode - if self.config.gh_app: - with GithubAppToken(**self.config.gh_app._asdict()).get_token( - installation_id - ) as token: - # Once the Sentry org has a .sentry repo we can remove the DSN from the deployment - dsn = fetch_dsn_for_github_org(org, token) - client = GithubClient( - token=token, - dsn=dsn, - dry_run=self.dry_run, - ) - client.send_trace(data["workflow_job"]) + # For webhook testing, use the DSN directly from environment + dsn = os.environ.get("APP_DSN") + if not dsn: + reason = "No DSN configured for webhook testing" + http_code = 500 else: - # Once the Sentry org has a .sentry repo we can remove the DSN from the deployment - dsn = fetch_dsn_for_github_org(org, token) - client = GithubClient( - token=self.config.gh.token, - dsn=dsn, - dry_run=self.dry_run, - ) - client.send_trace(data["workflow_job"]) + # Get GitHub App installation token if available, otherwise fall back to PAT or None + token = None + if installation_id and self.config.gh_app: + try: + # Generate installation token from GitHub App + github_app_token = GithubAppToken( + self.config.gh_app.private_key, + self.config.gh_app.app_id + ) + # Note: We use the token immediately, not in a context manager + # because we need it to persist for the WorkflowJobCollector + # The token expires after 1 hour, which is fine for our use case + token_response = requests.post( + url=f"https://api.github.com/app/installations/{installation_id}/access_tokens", + headers=github_app_token.headers, + ) + token_response.raise_for_status() + token = token_response.json()["token"] + logger.debug(f"Generated GitHub App installation token for org '{org}'") + except Exception as e: + logger.warning( + f"Failed to generate GitHub App installation token for org '{org}': {e}. " + "Falling back to timeout-based job detection." + ) + # Fall back to PAT if available + token = self.config.gh.token + elif self.config.gh.token: + # Use PAT if GitHub App is not configured + token = self.config.gh.token + logger.debug(f"Using PAT token for org '{org}'") + else: + logger.debug( + f"No token available for org '{org}'. " + "Will use timeout-based job detection." 
+            )
+
+            # Get job collector and check if hierarchical tracing is enabled
+            collector = self._get_job_collector(org, token, dsn)
+
+            if collector.is_hierarchical_tracing_enabled(org):
+                # Use new hierarchical workflow tracing
+                logger.debug(f"Using hierarchical workflow tracing for org '{org}'")
+                collector.add_job(data)
+            else:
+                # Fall back to legacy individual job tracing
+                self._send_legacy_trace(data, org, token, dsn)

         return reason, http_code

@@ -66,7 +168,11 @@ def valid_signature(self, body, headers):
         if not self.config.gh.webhook_secret:
             return True
         else:
-            signature = headers["X-Hub-Signature-256"].replace("sha256=", "")
+            # Flask normalizes headers - try both original and normalized names
+            signature_header = headers.get("X-Hub-Signature-256") or headers.get("X-HUB-SIGNATURE-256") or headers.get("HTTP_X_HUB_SIGNATURE_256")
+            if not signature_header:
+                return False
+            signature = signature_header.replace("sha256=", "")
             body_signature = hmac.new(
                 self.config.gh.webhook_secret.encode(),
                 msg=body,
diff --git a/src/workflow_job_collector.py b/src/workflow_job_collector.py
new file mode 100644
index 0000000..6233ee6
--- /dev/null
+++ b/src/workflow_job_collector.py
@@ -0,0 +1,557 @@
+"""
+Workflow Job Collector Module
+
+This module handles the collection and aggregation of GitHub workflow jobs,
+determining when a workflow is complete and triggering trace submissions.
+"""
+
+from __future__ import annotations
+
+import logging
+import os
+import threading
+import time
+from collections import defaultdict
+from typing import Any, Dict, List, Optional
+
+import requests
+
+from .workflow_tracer import WorkflowTracer
+
+# Configuration Constants
+LOGGING_LEVEL = os.environ.get("LOGGING_LEVEL", logging.INFO)
+
+# Timing Constants (in seconds)
+SMALL_WORKFLOW_PROCESSING_DELAY = 2.0  # Delay for workflows with few jobs
+NO_NEW_JOBS_TIMEOUT = 7.0  # Time to wait before assuming no more jobs
+MAX_WORKFLOW_WAIT_TIME = 300.0  # Maximum time to wait for workflow completion (5 minutes)
+
+# Workflow size thresholds (job counts). The tests import all three; only
+# LARGE_WORKFLOW_THRESHOLD (10, per the large-workflow test) is exercised
+# today, and the small/medium values are assumptions reserved for future
+# size-based heuristics.
+SMALL_WORKFLOW_THRESHOLD = 2
+MEDIUM_WORKFLOW_THRESHOLD = 5
+LARGE_WORKFLOW_THRESHOLD = 10
+
+# Feature Flags
+ENABLE_HIERARCHICAL_TRACING = os.environ.get("ENABLE_HIERARCHICAL_TRACING", "true").lower() == "true"
+SENTRY_ORG_ONLY = os.environ.get("HIERARCHICAL_TRACING_SENTRY_ORG_ONLY", "false").lower() == "true"
+
+logger = logging.getLogger(__name__)
+logger.setLevel(LOGGING_LEVEL)
+
+
+class WorkflowJobCollector:
+    """
+    Collects jobs from a workflow run and sends workflow-level transactions.
+
+    This class aggregates jobs from GitHub workflow runs and determines when
+    a workflow is complete. Once complete, it sends a single hierarchical
+    trace containing all jobs and steps as spans.
+
+    Attributes:
+        dsn: Sentry DSN for trace submission
+        token: GitHub API token
+        dry_run: If True, no traces are sent
+        workflow_jobs: Maps run_id to list of jobs
+        processed_jobs: Set of job IDs already processed
+        workflow_timers: Active timers for workflow processing
+        processed_workflows: Set of workflow run IDs already sent
+        job_arrival_times: Tracks when jobs arrive for smart detection
+    """
+
+    def __init__(self, dsn: str, token: Optional[str], dry_run: bool = False):
+        """
+        Initialize the WorkflowJobCollector.
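+
+        Usage sketch (illustrative values; the DSN and payload shape are
+        assumptions, not taken from a live deployment):
+
+            collector = WorkflowJobCollector("https://key@o1.ingest.sentry.io/1", token=None, dry_run=True)
+            collector.add_job(webhook_payload)  # a "workflow_job" webhook dict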
+
+        Args:
+            dsn: Sentry DSN for trace submission
+            token: GitHub API token for fetching workflow data
+            dry_run: If True, simulates operations without sending traces
+        """
+        self.dsn = dsn
+        self.token = token
+        self.dry_run = dry_run
+
+        # State tracking
+        self.workflow_jobs: Dict[int, List[Dict[str, Any]]] = defaultdict(list)
+        self.workflow_repositories: Dict[int, Dict[str, Any]] = {}  # run_id -> repository info
+        self.processed_jobs: set[int] = set()
+        self.processed_workflows: set[int] = set()
+        self.job_arrival_times: Dict[int, List[float]] = defaultdict(list)
+        self.workflow_timers: Dict[int, threading.Timer] = {}
+        self.workflow_total_jobs: Dict[int, Optional[int]] = {}  # run_id -> total_count (None = not fetched yet)
+
+        # Thread safety. This lock must be re-entrant: _process_workflow_immediately
+        # holds it while _send_workflow_trace's finally block calls
+        # _cleanup_workflow_run, which acquires it again on the same thread.
+        # A plain threading.Lock would deadlock on that path.
+        self._lock = threading.RLock()
+
+        # Initialize workflow tracer
+        self.workflow_tracer = WorkflowTracer(token, dsn, dry_run)
+
+        logger.info(
+            f"WorkflowJobCollector initialized (dry_run={dry_run}, "
+            f"hierarchical_tracing={ENABLE_HIERARCHICAL_TRACING})"
+        )
+
+    def add_job(self, job_data: Dict[str, Any]) -> None:
+        """
+        Add a job to the collector and check if workflow is complete.
+
+        Args:
+            job_data: GitHub webhook job payload
+        """
+        job = job_data["workflow_job"]
+        run_id = job["run_id"]
+        job_id = job["id"]
+
+        with self._lock:
+            if self._is_job_already_processed(job_id):
+                return
+
+            # Store repository info from webhook payload if available
+            if run_id not in self.workflow_repositories and "repository" in job_data:
+                self.workflow_repositories[run_id] = job_data["repository"]
+
+            # Fetch total job count from GitHub API when first job arrives
+            if run_id not in self.workflow_total_jobs:
+                repo_full_name = "unknown/unknown"
+                if run_id in self.workflow_repositories:
+                    repo_full_name = self.workflow_repositories[run_id].get("full_name", repo_full_name)
+                elif "repository" in job_data:
+                    repo_full_name = job_data["repository"].get("full_name", repo_full_name)
+
+                total_count = self._fetch_total_job_count(run_id, repo_full_name)
+                self.workflow_total_jobs[run_id] = total_count
+
+            self._record_job(run_id, job_id, job)
+
+            # Check if we should process now or schedule for later
+            if self._should_process_workflow_now(run_id):
+                self._schedule_workflow_processing(run_id)
+            elif self._should_schedule_timeout_check(run_id):
+                # All jobs complete but timeout hasn't elapsed - schedule a check after timeout
+                self._schedule_timeout_check(run_id)
+
+    def _is_job_already_processed(self, job_id: int) -> bool:
+        """Check if a job has already been processed."""
+        if job_id in self.processed_jobs:
+            logger.debug(f"Job {job_id} already processed, skipping")
+            return True
+        return False
+
+    def _record_job(self, run_id: int, job_id: int, job: Dict[str, Any]) -> None:
+        """Record a new job arrival."""
+        self.processed_jobs.add(job_id)
+        self.workflow_jobs[run_id].append(job)
+        self.job_arrival_times[run_id].append(time.time())
+
+        logger.info(
+            f"Added job '{job['name']}' (ID: {job_id}) to workflow run {run_id} "
+            f"(total jobs: {len(self.workflow_jobs[run_id])})"
+        )
+
+    def _fetch_total_job_count(self, run_id: int, repo_full_name: str) -> Optional[int]:
+        """
+        Fetch the total number of jobs for a workflow run from GitHub API.
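+
+        Illustrative exchange (values are made up):
+
+            GET https://api.github.com/repos/octo/repo/actions/runs/456/jobs
+            -> {"total_count": 3, "jobs": [...]}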
+ + Args: + run_id: Workflow run ID + repo_full_name: Repository full name (e.g., "owner/repo") + + Returns: + Total number of jobs, or None if fetch fails or token is missing + """ + if not self.token: + logger.debug(f"No GitHub token available, cannot fetch total job count for run {run_id}") + return None + + try: + # GitHub API endpoint: GET /repos/{owner}/{repo}/actions/runs/{run_id}/jobs + api_url = f"https://api.github.com/repos/{repo_full_name}/actions/runs/{run_id}/jobs" + headers = {"Authorization": f"token {self.token}"} + + logger.debug(f"Fetching total job count for workflow run {run_id} from GitHub API") + response = requests.get(api_url, headers=headers, timeout=5) + response.raise_for_status() + + data = response.json() + total_count = data.get("total_count") + + if total_count is not None: + logger.info( + f"Fetched total job count for workflow run {run_id}: {total_count} jobs" + ) + return total_count + else: + logger.warning( + f"GitHub API response missing 'total_count' for workflow run {run_id}" + ) + return None + + except requests.exceptions.RequestException as e: + logger.warning( + f"Failed to fetch total job count for workflow run {run_id}: {e}. " + "Will fall back to timeout-based detection." + ) + return None + except Exception as e: + logger.error( + f"Unexpected error fetching total job count for workflow run {run_id}: {e}", + exc_info=True + ) + return None + + def _should_process_workflow_now(self, run_id: int) -> bool: + """ + Determine if a workflow should be processed now. + + Uses smart detection based on: + - Job completion status + - Number of jobs + - Time since last job arrival + - Maximum wait time + + Args: + run_id: Workflow run ID + + Returns: + True if workflow should be processed, False otherwise + """ + if run_id in self.processed_workflows: + logger.debug(f"Workflow {run_id} already processed") + return False + + jobs = self.workflow_jobs[run_id] + jobs_count = len(jobs) + + # Check if all jobs are completed + if not self._all_jobs_completed(jobs): + incomplete_jobs = [j.get("name") for j in jobs if j.get("conclusion") is None] + logger.debug(f"Workflow {run_id} has {jobs_count} jobs, {len(incomplete_jobs)} incomplete: {incomplete_jobs}") + return False + + # Check if we know the total job count and if we have all jobs + total_count = self.workflow_total_jobs.get(run_id) + if total_count is not None: + # We have the total count from API, check if we've received all jobs + if jobs_count >= total_count: + logger.info( + f"Workflow {run_id} has all {total_count} jobs (received {jobs_count}), " + "ready to process" + ) + # Still need to wait for timeout to ensure all jobs are complete + return self._meets_processing_threshold(run_id, jobs_count) + else: + logger.debug( + f"Workflow {run_id} has {jobs_count}/{total_count} jobs, " + "waiting for more jobs to arrive" + ) + return False + # If total_count is None, fall back to timeout-based detection + + # Check if maximum wait time exceeded + if self._is_workflow_timeout_exceeded(run_id): + logger.warning( + f"Workflow {run_id} exceeded maximum wait time " + f"({MAX_WORKFLOW_WAIT_TIME}s), processing with {jobs_count} jobs" + ) + return True + + # Determine based on workflow size + return self._meets_processing_threshold(run_id, jobs_count) + + def _all_jobs_completed(self, jobs: List[Dict[str, Any]]) -> bool: + """Check if all jobs in a list are completed.""" + return all(job.get("conclusion") is not None for job in jobs) + + def _is_workflow_timeout_exceeded(self, run_id: int) -> bool: + """Check if 
the workflow has exceeded the maximum wait time.""" + arrival_times = self.job_arrival_times[run_id] + if not arrival_times: + return False + + first_arrival = arrival_times[0] + time_elapsed = time.time() - first_arrival + return time_elapsed > MAX_WORKFLOW_WAIT_TIME + + def _meets_processing_threshold(self, run_id: int, jobs_count: int) -> bool: + """ + Check if the workflow meets the processing threshold. + + Always waits for the timeout to ensure all jobs are collected before processing. + This prevents sending incomplete traces when jobs arrive at different times. + + Args: + run_id: Workflow run ID + jobs_count: Number of jobs in the workflow + + Returns: + True if timeout elapsed and we should process, False otherwise + """ + # Always wait for timeout to ensure we collect all jobs + # Jobs can arrive at different times, so we need to wait for the timeout + # before assuming no more jobs will arrive + return self._has_job_arrival_timeout_elapsed(run_id, jobs_count) + + def _has_job_arrival_timeout_elapsed(self, run_id: int, jobs_count: int) -> bool: + """ + Check if enough time has passed since the last job arrival. + + Args: + run_id: Workflow run ID + jobs_count: Number of jobs in the workflow + + Returns: + True if timeout elapsed, False otherwise + """ + arrival_times = self.job_arrival_times[run_id] + + # Always wait for timeout to ensure we collect all jobs + # Don't immediately process single-job workflows as they might be multi-job + if len(arrival_times) >= 1: + time_since_last_job = time.time() - arrival_times[-1] + if time_since_last_job > NO_NEW_JOBS_TIMEOUT: + logger.info( + f"No new jobs for {time_since_last_job:.1f}s " + f"(threshold: {NO_NEW_JOBS_TIMEOUT}s), processing {jobs_count} job(s)" + ) + return True + + return False + + def _should_schedule_timeout_check(self, run_id: int) -> bool: + """ + Check if we should schedule a timeout check. + + This is called when all jobs are complete but the timeout hasn't elapsed yet. + We need to schedule a timer to check again after the timeout period. + + Args: + run_id: Workflow run ID + + Returns: + True if we should schedule a timeout check, False otherwise + """ + if run_id in self.processed_workflows: + return False + + jobs = self.workflow_jobs[run_id] + if not self._all_jobs_completed(jobs): + return False + + # Check if we know the total job count and if we have all jobs + total_count = self.workflow_total_jobs.get(run_id) + if total_count is not None: + jobs_count = len(jobs) + if jobs_count < total_count: + # We haven't received all jobs yet, don't schedule timeout check + logger.debug( + f"Workflow {run_id} has {jobs_count}/{total_count} jobs, " + "not scheduling timeout check yet" + ) + return False + # We have all jobs, proceed with timeout check scheduling + + # Check if timeout hasn't elapsed yet + arrival_times = self.job_arrival_times[run_id] + if len(arrival_times) >= 1: + time_since_last_job = time.time() - arrival_times[-1] + if time_since_last_job <= NO_NEW_JOBS_TIMEOUT: + # Timeout hasn't elapsed, and we don't already have a timer scheduled + if run_id not in self.workflow_timers: + return True + + return False + + def _schedule_timeout_check(self, run_id: int) -> None: + """ + Schedule a check after the timeout period to process the workflow. 
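+
+        Example: with NO_NEW_JOBS_TIMEOUT = 7.0, if the last job arrived 5s
+        ago, a timer is started to re-check the workflow in roughly 2s.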
+ + Args: + run_id: Workflow run ID + """ + arrival_times = self.job_arrival_times[run_id] + if not arrival_times: + return + + time_since_last_job = time.time() - arrival_times[-1] + remaining_time = NO_NEW_JOBS_TIMEOUT - time_since_last_job + + if remaining_time > 0: + logger.info( + f"Scheduling timeout check for workflow {run_id} " + f"in {remaining_time:.1f}s (timeout: {NO_NEW_JOBS_TIMEOUT}s)" + ) + + timer = threading.Timer( + remaining_time, + self._process_workflow_immediately, + args=[run_id] + ) + self.workflow_timers[run_id] = timer + timer.start() + + def _schedule_workflow_processing(self, run_id: int) -> None: + """Schedule workflow processing with a short delay to collect all jobs.""" + jobs_count = len(self.workflow_jobs[run_id]) + + logger.info( + f"Scheduling workflow {run_id} for processing " + f"(jobs: {jobs_count}, delay: {SMALL_WORKFLOW_PROCESSING_DELAY}s)" + ) + + timer = threading.Timer( + SMALL_WORKFLOW_PROCESSING_DELAY, + self._process_workflow_immediately, + args=[run_id] + ) + self.workflow_timers[run_id] = timer + timer.start() + + def _process_workflow_immediately(self, run_id: int) -> None: + """ + Process workflow immediately when triggered by timer. + + This method is called by a timer thread and includes exception + handling to prevent silent failures and resource leaks. + + Args: + run_id: Workflow run ID to process + """ + try: + with self._lock: + if not self._should_process_workflow(run_id): + return + + jobs = self.workflow_jobs[run_id] + logger.info( + f"Processing workflow run {run_id} with {len(jobs)} jobs" + ) + + if self._all_jobs_completed(jobs): + logger.info( + f"All {len(jobs)} jobs complete for workflow {run_id}, sending trace" + ) + self._send_workflow_trace(run_id) + else: + logger.warning( + f"Not all jobs complete for workflow {run_id}, skipping" + ) + except Exception as e: + logger.error( + f"Error processing workflow run {run_id}: {e}", + exc_info=True + ) + # Ensure cleanup happens even if there's an exception + self._cleanup_workflow_run(run_id) + + def _should_process_workflow(self, run_id: int) -> bool: + """Check if a workflow should be processed (not already processed).""" + if run_id in self.processed_workflows: + logger.debug(f"Workflow run {run_id} already processed, skipping") + return False + + if not self.workflow_jobs[run_id]: + logger.warning(f"No jobs found for workflow run {run_id}") + return False + + return True + + def _send_workflow_trace(self, run_id: int) -> None: + """ + Send workflow-level trace for all jobs in the run. 
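+
+        The resulting Sentry transaction is one root workflow trace with a
+        child span per job and grandchild spans per step, built by
+        WorkflowTracer.send_workflow_trace.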
+ + Args: + run_id: Workflow run ID to send trace for + """ + if run_id in self.processed_workflows: + logger.warning( + f"Workflow run {run_id} already processed, " + "skipping to prevent duplicates" + ) + return + + jobs = self.workflow_jobs[run_id] + if not jobs: + logger.warning(f"No jobs found for workflow run {run_id}") + return + + logger.info( + f"Sending workflow trace for run {run_id} with {len(jobs)} jobs" + ) + + try: + base_job = jobs[0] + # Pass repository info if available + repository_info = self.workflow_repositories.get(run_id) + self.workflow_tracer.send_workflow_trace(base_job, jobs, repository_info=repository_info) + logger.info( + f"Successfully sent workflow trace for run {run_id}" + ) + except Exception as e: + logger.error( + f"Failed to send workflow trace for run {run_id}: {e}", + exc_info=True + ) + logger.warning( + "NOT falling back to individual traces to prevent duplicates" + ) + finally: + self._cleanup_workflow_run(run_id) + + def _cleanup_workflow_run(self, run_id: int) -> None: + """ + Clean up workflow run data to prevent resource leaks. + + This method should always be called after processing a workflow, + whether successful or not, to ensure proper cleanup. + + Args: + run_id: Workflow run ID to clean up + """ + try: + with self._lock: + self.processed_workflows.add(run_id) + + if run_id in self.workflow_jobs: + del self.workflow_jobs[run_id] + + if run_id in self.workflow_timers: + self.workflow_timers[run_id].cancel() + del self.workflow_timers[run_id] + + if run_id in self.job_arrival_times: + del self.job_arrival_times[run_id] + + if run_id in self.workflow_repositories: + del self.workflow_repositories[run_id] + + if run_id in self.workflow_total_jobs: + del self.workflow_total_jobs[run_id] + + logger.debug(f"Cleaned up workflow run {run_id}") + except Exception as cleanup_error: + logger.error( + f"Error during cleanup of workflow run {run_id}: {cleanup_error}", + exc_info=True + ) + + def is_hierarchical_tracing_enabled(self, org: str) -> bool: + """ + Check if hierarchical tracing is enabled for an organization. 
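+
+        For example, with ENABLE_HIERARCHICAL_TRACING=true and
+        HIERARCHICAL_TRACING_SENTRY_ORG_ONLY=true this returns True only for
+        the "getsentry" org; with the first flag false it is always False.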
+ + Args: + org: GitHub organization name + + Returns: + True if hierarchical tracing should be used, False otherwise + """ + if not ENABLE_HIERARCHICAL_TRACING: + return False + + if SENTRY_ORG_ONLY and org.lower() != "getsentry": + logger.debug( + f"Hierarchical tracing disabled for org '{org}' " + "(SENTRY_ORG_ONLY mode enabled)" + ) + return False + + return True + + + + + diff --git a/src/workflow_tracer.py b/src/workflow_tracer.py new file mode 100644 index 0000000..2686a4e --- /dev/null +++ b/src/workflow_tracer.py @@ -0,0 +1,454 @@ +""" +Enhanced workflow tracing that creates a parent workflow transaction +to encapsulate all jobs and provide total workflow duration +""" + +import json +import logging +import uuid +import hashlib +from datetime import datetime +from typing import Dict, List, Any, Optional +import requests +try: + from sentry_sdk.envelope import Envelope + from sentry_sdk.utils import format_timestamp +except ImportError: + # Fallback for testing + class Envelope: + def add_transaction(self, transaction): pass + def serialize_into(self, f): pass + + def format_timestamp(dt): + return dt.isoformat() + "Z" + + +def get_uuid(): + return uuid.uuid4().hex + + +def get_uuid_from_string(input_string): + hash_object = hashlib.sha256(input_string.encode()) + hash_value = hash_object.hexdigest() + return uuid.UUID(hash_value[:32]).hex + + +class WorkflowTracer: + """Enhanced tracer that creates workflow-level transactions""" + + def __init__(self, token: Optional[str], dsn: str, dry_run: bool = False): + self.token = token + self.dsn = dsn + self.dry_run = dry_run + self.workflow_cache = {} # Cache workflow runs to avoid duplicate API calls + + if dsn: + # Parse DSN: https://key@host/project_id + dsn_parts = dsn.split("@") + if len(dsn_parts) != 2: + raise ValueError(f"Invalid DSN format: {dsn}") + + sentry_key = dsn_parts[0].split("//")[1] + host_and_project = dsn_parts[1] + + # Split host and project_id + host_parts = host_and_project.split("/") + if len(host_parts) != 2: + raise ValueError(f"Invalid DSN format: {dsn}") + + host = host_parts[0] + project_id = host_parts[1] + + self.sentry_key = sentry_key + self.sentry_project_url = f"https://{host}/api/{project_id}/envelope/" + + def _fetch_github(self, url: str) -> requests.Response: + """Fetch data from GitHub API""" + headers = {"Authorization": f"token {self.token}"} + req = requests.get(url, headers=headers) + req.raise_for_status() + return req + + def _get_workflow_run_data(self, job: Dict[str, Any], repository_info: Dict[str, Any] = None) -> Dict[str, Any]: + """Get workflow run data, with caching""" + run_id = job["run_id"] + + if run_id not in self.workflow_cache: + # Extract repository info from webhook payload or use defaults + repo_full_name = "unknown/unknown" + if repository_info: + repo_full_name = repository_info.get("full_name", repo_full_name) + elif "repository" in job: + repo_full_name = job["repository"].get("full_name", repo_full_name) + + # Extract workflow info from job or use defaults + workflow_name = job.get("workflow_name") or job.get("name", "Unknown Workflow") + workflow_path = job.get("workflow_path", ".github/workflows/workflow.yml") + + # Extract commit author info if available + author_name = "GitHub Actions" + author_email = "actions@github.com" + if "head_commit" in job and "author" in job.get("head_commit", {}): + author_name = job["head_commit"]["author"].get("name", author_name) + author_email = job["head_commit"]["author"].get("email", author_email) + + # Fetch workflow run details 
from GitHub API to get created_at and updated_at + workflow_run_created_at = None + workflow_run_updated_at = None + if self.token and repo_full_name != "unknown/unknown": + try: + # GitHub API endpoint: GET /repos/{owner}/{repo}/actions/runs/{run_id} + api_url = f"https://api.github.com/repos/{repo_full_name}/actions/runs/{run_id}" + headers = {"Authorization": f"token {self.token}"} + + logging.debug(f"Fetching workflow run details for run {run_id} from GitHub API") + response = self._fetch_github(api_url) + run_data = response.json() + + workflow_run_created_at = run_data.get("created_at") + workflow_run_updated_at = run_data.get("updated_at") + + if workflow_run_created_at and workflow_run_updated_at: + logging.debug( + f"Fetched workflow run timestamps: created_at={workflow_run_created_at}, " + f"updated_at={workflow_run_updated_at}" + ) + else: + logging.warning( + f"Workflow run API response missing timestamps for run {run_id}, " + "will fall back to job timestamps" + ) + except Exception as e: + logging.warning( + f"Failed to fetch workflow run details for run {run_id}: {e}. " + "Will fall back to job timestamps." + ) + + self.workflow_cache[run_id] = { + "runs": { + "head_commit": { + "author": {"name": author_name, "email": author_email} + }, + "head_branch": job.get("head_branch", "main"), + "head_sha": job.get("head_sha", "unknown"), + "run_attempt": job.get("run_attempt", 1), + "html_url": f"https://github.com/{repo_full_name}/actions/runs/{run_id}", + "repository": {"full_name": repo_full_name}, + "created_at": workflow_run_created_at, + "updated_at": workflow_run_updated_at + }, + "workflow": { + "name": workflow_name, + "path": workflow_path + }, + "repo": repo_full_name + } + + return self.workflow_cache[run_id] + + def _create_workflow_transaction(self, job: Dict[str, Any], all_jobs: List[Dict[str, Any]], repository_info: Dict[str, Any] = None) -> Dict[str, Any]: + """Create a single workflow transaction with job spans""" + workflow_data = self._get_workflow_run_data(job, repository_info) + runs = workflow_data["runs"] + workflow = workflow_data["workflow"] + repo = workflow_data["repo"] + + # Determine overall workflow status + job_conclusions = [j.get("conclusion") for j in all_jobs] + if "failure" in job_conclusions: + workflow_status = "internal_error" + elif "cancelled" in job_conclusions: + workflow_status = "cancelled" + elif "skipped" in job_conclusions: + workflow_status = "skipped" + else: + workflow_status = "ok" + + # Prepare data and tags similar to github_sdk.py structure + workflow_data_dict = { + "workflow_url": runs["html_url"], + } + workflow_tags = { + "branch": runs["head_branch"], + "commit": runs["head_sha"], + "repo": repo, + "run_attempt": runs["run_attempt"], + "workflow": workflow["path"].rsplit("/")[-1], + } + + # Add PR info if available + if runs.get("pull_requests"): + pr_number = runs["pull_requests"][0]["number"] + workflow_data_dict["pr"] = f"https://github.com/{repo}/pull/{pr_number}" + workflow_tags["pull_request"] = pr_number + + # Re-fetch workflow run data to get the final updated_at timestamp + # This ensures we have the most up-to-date completion time + workflow_run_created_at = runs.get("created_at") + workflow_run_updated_at = runs.get("updated_at") + + # Re-fetch workflow run details right before sending to get final updated_at + repo_full_name = repo + if self.token and repo_full_name != "unknown/unknown" and job.get("run_id"): + try: + run_id = job["run_id"] + api_url = 
f"https://api.github.com/repos/{repo_full_name}/actions/runs/{run_id}" + logging.debug(f"Re-fetching workflow run details for run {run_id} to get final updated_at") + response = self._fetch_github(api_url) + run_data = response.json() + + # Update with the latest timestamps + workflow_run_created_at = run_data.get("created_at") + workflow_run_updated_at = run_data.get("updated_at") + + # Update cache with latest data + if run_id in self.workflow_cache: + self.workflow_cache[run_id]["runs"]["created_at"] = workflow_run_created_at + self.workflow_cache[run_id]["runs"]["updated_at"] = workflow_run_updated_at + + logging.debug( + f"Re-fetched workflow run timestamps: created_at={workflow_run_created_at}, " + f"updated_at={workflow_run_updated_at}" + ) + except Exception as e: + logging.warning( + f"Failed to re-fetch workflow run details for run {job.get('run_id')}: {e}. " + "Using cached timestamps." + ) + + # Calculate workflow timestamps - prefer workflow run created_at/updated_at to match GitHub's duration + # Fall back to job timestamps if workflow run timestamps are not available + if workflow_run_created_at and workflow_run_updated_at: + # Use workflow run timestamps (matches GitHub's duration calculation) + workflow_start_str = workflow_run_created_at + workflow_end_str = workflow_run_updated_at + logging.debug( + f"Using workflow run timestamps: {workflow_start_str} -> {workflow_end_str} " + f"(matches GitHub's duration calculation)" + ) + else: + # Fall back to job timestamps (earliest job start to latest job end) + workflow_start_str = min([j["started_at"] for j in all_jobs if j.get("started_at")], default=all_jobs[0]["started_at"]) + workflow_end_str = max([j["completed_at"] for j in all_jobs if j.get("completed_at")], default=all_jobs[0]["completed_at"]) + logging.debug( + f"Using job timestamps (fallback): {workflow_start_str} -> {workflow_end_str}" + ) + + # Create workflow transaction matching github_sdk.py structure + workflow_transaction = { + "event_id": get_uuid(), + "type": "transaction", + "transaction": f"workflow: {workflow['name']}", + "contexts": { + "trace": { + "span_id": get_uuid()[:16], + "trace_id": get_uuid_from_string( + f"workflow_run_id:{job['run_id']}_run_attempt:{job['run_attempt']}" + ), + "type": "trace", + "op": f"workflow: {workflow['name']}", + "description": f"workflow: {workflow['name']}", + "status": workflow_status, + "data": workflow_data_dict + } + }, + "user": runs["head_commit"]["author"], + "start_timestamp": workflow_start_str, + "timestamp": workflow_end_str, + "tags": workflow_tags, + "spans": [] + } + + # Calculate cleanup/teardown time (delta between last job completion and workflow updated_at) + cleanup_duration_seconds = 0 + if workflow_run_created_at and workflow_run_updated_at: + # Find the latest job completion time + latest_job_completion = max([j["completed_at"] for j in all_jobs if j.get("completed_at")], default=None) + if latest_job_completion: + try: + from datetime import datetime + latest_job_dt = datetime.fromisoformat(latest_job_completion.replace("Z", "+00:00")) + workflow_end_dt = datetime.fromisoformat(workflow_run_updated_at.replace("Z", "+00:00")) + cleanup_duration_seconds = (workflow_end_dt - latest_job_dt).total_seconds() + + if cleanup_duration_seconds > 0: + logging.debug( + f"Cleanup/teardown time: {cleanup_duration_seconds:.1f}s " + f"(between last job completion and workflow completion)" + ) + except Exception as e: + logging.warning(f"Failed to calculate cleanup duration: {e}") + + # Add job spans to the 
workflow transaction + workflow_span_id = workflow_transaction["contexts"]["trace"]["span_id"] + workflow_trace_id = workflow_transaction["contexts"]["trace"]["trace_id"] + + for job_data in all_jobs: + # Create job span matching github_sdk.py format + job_span = { + "op": job_data["name"], + "name": job_data["name"], + "parent_span_id": workflow_span_id, + "span_id": get_uuid()[:16], + "start_timestamp": job_data["started_at"], + "timestamp": job_data["completed_at"], + "trace_id": workflow_trace_id, + } + workflow_transaction["spans"].append(job_span) + + # Add step spans as children of job span, matching github_sdk.py format + for step in job_data.get("steps", []): + try: + step_span = { + "op": step["name"], + "name": step["name"], + "parent_span_id": job_span["span_id"], + "span_id": get_uuid()[:16], + "start_timestamp": step["started_at"], + "timestamp": step["completed_at"], + "trace_id": workflow_trace_id, + } + workflow_transaction["spans"].append(step_span) + except Exception as e: + logging.exception(e) + + # Add cleanup/teardown span if there's a delta between last job and workflow completion + if cleanup_duration_seconds > 0: + # Find the latest job completion timestamp to start cleanup span from + latest_job_completion = max([j["completed_at"] for j in all_jobs if j.get("completed_at")], default=None) + if latest_job_completion: + try: + from datetime import datetime, timedelta + cleanup_start_dt = datetime.fromisoformat(latest_job_completion.replace("Z", "+00:00")) + cleanup_end_dt = cleanup_start_dt + timedelta(seconds=cleanup_duration_seconds) + + cleanup_span = { + "op": "workflow.cleanup", + "name": "Workflow cleanup and teardown", + "parent_span_id": workflow_span_id, + "span_id": get_uuid()[:16], + "start_timestamp": cleanup_start_dt.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z", + "timestamp": cleanup_end_dt.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z", + "trace_id": workflow_trace_id, + } + workflow_transaction["spans"].append(cleanup_span) + logging.debug( + f"Added cleanup span: {cleanup_duration_seconds:.1f}s " + f"({cleanup_start_dt.strftime('%H:%M:%S')} -> {cleanup_end_dt.strftime('%H:%M:%S')})" + ) + except Exception as e: + logging.warning(f"Failed to create cleanup span: {e}") + + return workflow_transaction + + + def send_workflow_trace(self, job: Dict[str, Any], all_jobs: List[Dict[str, Any]] = None, repository_info: Dict[str, Any] = None): + """Send a single workflow transaction with all job and step spans""" + if self.dry_run: + logging.info(f"Dry run: Would send workflow trace for {job['name']}") + return + + if all_jobs is None: + all_jobs = [job] + + try: + logging.info(f"Creating workflow transaction for {len(all_jobs)} jobs") + logging.info(f"Job names: {[j['name'] for j in all_jobs]}") + + # Create single workflow transaction with all spans + workflow_transaction = self._create_workflow_transaction(job, all_jobs, repository_info) + workflow_trace_id = workflow_transaction["contexts"]["trace"]["trace_id"] + + # Log detailed transaction info + logging.info(f"Workflow transaction details:") + logging.info(f" - Trace ID: {workflow_trace_id}") + logging.info(f" - Transaction name: {workflow_transaction['transaction']}") + logging.info(f" - Total spans: {len(workflow_transaction['spans'])}") + logging.info(f" - Trace version: {workflow_transaction.get('tags', {}).get('trace_version', 'N/A')}") + logging.info(f" - Workflow status: {workflow_transaction['contexts']['trace']['status']}") + + # Log span details + job_spans = [s for s in 
workflow_transaction['spans'] if s.get('parent_span_id') == workflow_transaction['contexts']['trace']['span_id']]
+            # Spans carry the job/step *name* in their 'op' field, so they are
+            # classified by parentage instead: direct children of the workflow
+            # root span are job-level spans (including the cleanup span); all
+            # remaining spans are step spans.
+            step_spans = [s for s in workflow_transaction['spans'] if s.get('parent_span_id') != workflow_transaction['contexts']['trace']['span_id']]
+            logging.info(f"  - Job spans: {len(job_spans)}")
+            logging.info(f"  - Step spans: {len(step_spans)}")
+
+            logging.info(f"Sending workflow transaction with trace_id: {workflow_trace_id}")
+            # Send single workflow transaction
+            self._send_envelope(workflow_transaction)
+
+            logging.info(f"Successfully sent workflow trace with {len(all_jobs)} jobs")
+
+        except Exception as e:
+            logging.error(f"Error in send_workflow_trace: {e}", exc_info=True)
+            raise
+
+    def _send_envelope(self, transaction: Dict[str, Any]):
+        """Send transaction to Sentry"""
+        if self.dry_run:
+            return
+
+        # Save transaction payload for Postman testing
+        trace_id = transaction.get('contexts', {}).get('trace', {}).get('trace_id', 'unknown')
+        filename = f"transaction_payload_{trace_id}.json"
+
+        with open(filename, 'w') as f:
+            json.dump(transaction, f, indent=2)
+
+        logging.info(f"💾 Transaction payload saved to: {filename}")
+        logging.info("📋 Transaction details:")
+        logging.info(f"  - Trace ID: {trace_id}")
+        logging.info(f"  - Transaction: {transaction.get('transaction')}")
+        logging.info(f"  - Total spans: {len(transaction.get('spans', []))}")
+        logging.info(f"  - Trace version: {transaction.get('tags', {}).get('trace_version')}")
+
+        logging.info(f"Sending envelope to Sentry: {self.sentry_project_url}")
+        logging.info(f"Transaction type: {transaction.get('type')}")
+        logging.info(f"Transaction name: {transaction.get('transaction')}")
+        logging.info(f"Trace ID: {transaction.get('contexts', {}).get('trace', {}).get('trace_id')}")
+        logging.info(f"Event ID: {transaction.get('event_id')}")
+
+        # Send transaction as-is (event_id is included, matching github_sdk.py)
+        envelope = Envelope()
+        envelope.add_transaction(transaction)
+        now = datetime.utcnow()
+
+        headers = {
+            "event_id": get_uuid(),
+            "sent_at": format_timestamp(now),
+            "Content-Type": "application/x-sentry-envelope",
+            "Content-Encoding": "gzip",
+            "X-Sentry-Auth": f"Sentry sentry_key={self.sentry_key},"
+            + f"sentry_client=gha-sentry-workflow/0.0.1,sentry_timestamp={now},"
+            + "sentry_version=7",
+        }
+
+        import io
+        import gzip
+
+        body = io.BytesIO()
+        with gzip.GzipFile(fileobj=body, mode="w") as f:
+            envelope.serialize_into(f)
+
+        logging.info(f"Envelope size: {len(body.getvalue())} bytes")
+
+        try:
+            req = requests.post(
+                self.sentry_project_url,
+                data=body.getvalue(),
+                headers=headers,
+            )
+
+            logging.info(f"✅ Sentry response: {req.status_code}")
+            if req.status_code != 200:
+                logging.error(f"❌ Sentry rejected transaction: {req.status_code} - {req.text[:500]}")
+            else:
+                logging.info("✅ Transaction successfully sent to Sentry")
+
+            req.raise_for_status()
+            return req
+        except requests.exceptions.RequestException as e:
+            logging.error(f"❌ Failed to send transaction to Sentry: {e}")
+            logging.error(f"  URL: {self.sentry_project_url}")
+            logging.error(f"  Trace ID: {trace_id}")
+            raise
diff --git a/tests/test_web_app_handler_refactored.py b/tests/test_web_app_handler_refactored.py
new file mode 100644
index 0000000..f0e33ad
--- /dev/null
+++ b/tests/test_web_app_handler_refactored.py
@@ -0,0 +1,338 @@
+"""
+Tests for refactored WebAppHandler with hierarchical workflow tracing.
+
+This module tests the integration between WebAppHandler and WorkflowJobCollector,
+including feature flag behavior and backward compatibility.
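+
+Run with, for example:
+
+    python -m unittest tests.test_web_app_handler_refactored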
+""" + +import os +import unittest +from datetime import datetime +from unittest.mock import Mock, patch, MagicMock + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from src.web_app_handler import WebAppHandler + + +class TestWebAppHandlerRefactored(unittest.TestCase): + """Test suite for refactored WebAppHandler.""" + + def setUp(self): + """Set up test fixtures.""" + # Mock the config initialization + with patch('src.web_app_handler.init_config'): + self.handler = WebAppHandler(dry_run=True) + + # Create sample webhook data + self.sample_job_data = { + "action": "completed", + "workflow_job": { + "id": 123, + "run_id": 456, + "name": "test-job", + "conclusion": "success", + "started_at": datetime.utcnow().isoformat() + "Z", + "completed_at": datetime.utcnow().isoformat() + "Z", + "html_url": "https://github.com/test/repo/actions/runs/456/jobs/123", + "steps": [], + "head_branch": "main", + "head_sha": "abc123", + "run_attempt": 1, + "workflow_name": "Test Workflow" + }, + "repository": { + "owner": { + "login": "test-org" + } + }, + "installation": { + "id": 789 + } + } + + self.sample_headers = { + "X-GitHub-Event": "workflow_job" + } + + def test_init(self): + """Test WebAppHandler initialization.""" + with patch('src.web_app_handler.init_config'): + handler = WebAppHandler(dry_run=True) + + self.assertTrue(handler.dry_run) + self.assertEqual(len(handler.job_collectors), 0) + + def test_handle_event_unsupported_event(self): + """Test handling of unsupported event types.""" + headers = {"X-GitHub-Event": "push"} + + reason, http_code = self.handler.handle_event(self.sample_job_data, headers) + + self.assertEqual(http_code, 200) + self.assertEqual(reason, "Event not supported.") + + def test_handle_event_unsupported_action(self): + """Test handling of unsupported actions.""" + data = self.sample_job_data.copy() + data["action"] = "in_progress" + + reason, http_code = self.handler.handle_event(data, self.sample_headers) + + self.assertEqual(http_code, 200) + self.assertEqual(reason, "We cannot do anything with this workflow state.") + + def test_handle_event_dry_run(self): + """Test that dry run mode returns early.""" + reason, http_code = self.handler.handle_event( + self.sample_job_data, + self.sample_headers + ) + + # In dry run, should return OK without processing + self.assertEqual(http_code, 200) + self.assertEqual(reason, "OK") + + @patch.dict(os.environ, {"APP_DSN": "https://test@sentry.io/123"}) + def test_handle_event_hierarchical_tracing(self): + """Test event handling with hierarchical tracing enabled.""" + with patch('src.web_app_handler.init_config'): + handler = WebAppHandler(dry_run=False) + + # Mock the job collector + mock_collector = Mock() + mock_collector.is_hierarchical_tracing_enabled.return_value = True + + with patch.object(handler, '_get_job_collector', return_value=mock_collector): + reason, http_code = handler.handle_event( + self.sample_job_data, + self.sample_headers + ) + + # Verify hierarchical tracing was used + mock_collector.add_job.assert_called_once_with(self.sample_job_data) + self.assertEqual(http_code, 200) + self.assertEqual(reason, "OK") + + @patch.dict(os.environ, {"APP_DSN": "https://test@sentry.io/123"}) + def test_handle_event_legacy_tracing(self): + """Test event handling with legacy tracing (hierarchical disabled).""" + with patch('src.web_app_handler.init_config'): + handler = WebAppHandler(dry_run=False) + + # Mock the job collector + mock_collector = Mock() + 
mock_collector.is_hierarchical_tracing_enabled.return_value = False + + with patch.object(handler, '_get_job_collector', return_value=mock_collector): + with patch.object(handler, '_send_legacy_trace') as mock_legacy: + reason, http_code = handler.handle_event( + self.sample_job_data, + self.sample_headers + ) + + # Verify legacy tracing was used + mock_collector.add_job.assert_not_called() + mock_legacy.assert_called_once() + self.assertEqual(http_code, 200) + + @patch.dict(os.environ, {}, clear=True) + def test_handle_event_no_dsn(self): + """Test event handling when DSN is not configured.""" + with patch('src.web_app_handler.init_config'): + handler = WebAppHandler(dry_run=False) + + reason, http_code = handler.handle_event( + self.sample_job_data, + self.sample_headers + ) + + self.assertEqual(http_code, 500) + self.assertEqual(reason, "No DSN configured for webhook testing") + + def test_get_job_collector_creates_new(self): + """Test that _get_job_collector creates a new collector if needed.""" + org = "test-org" + token = "test-token" + dsn = "https://test@sentry.io/123" + + with patch('src.web_app_handler.WorkflowJobCollector') as mock_collector_class: + collector1 = self.handler._get_job_collector(org, token, dsn) + + # Verify collector was created + mock_collector_class.assert_called_once_with(dsn, token, True) + + def test_get_job_collector_reuses_existing(self): + """Test that _get_job_collector reuses existing collector.""" + org = "test-org" + token = "test-token" + dsn = "https://test@sentry.io/123" + + with patch('src.web_app_handler.WorkflowJobCollector') as mock_collector_class: + collector1 = self.handler._get_job_collector(org, token, dsn) + collector2 = self.handler._get_job_collector(org, token, dsn) + + # Verify collector was created only once + mock_collector_class.assert_called_once() + + # Verify same instance is returned + self.assertIs(collector1, collector2) + + def test_get_job_collector_separate_orgs(self): + """Test that different orgs get separate collectors.""" + token = "test-token" + dsn = "https://test@sentry.io/123" + + with patch('src.web_app_handler.WorkflowJobCollector') as mock_collector_class: + collector1 = self.handler._get_job_collector("org1", token, dsn) + collector2 = self.handler._get_job_collector("org2", token, dsn) + + # Verify two collectors were created + self.assertEqual(mock_collector_class.call_count, 2) + + @patch.dict(os.environ, {"APP_DSN": "https://test@sentry.io/123"}) + def test_send_legacy_trace(self): + """Test legacy trace sending.""" + with patch('src.web_app_handler.init_config'): + handler = WebAppHandler(dry_run=False) + + with patch('src.web_app_handler.GithubClient') as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + handler._send_legacy_trace( + self.sample_job_data, + "test-org", + "test-token", + "https://test@sentry.io/123" + ) + + # Verify GithubClient was created and send_trace called + mock_client_class.assert_called_once_with( + "test-token", + "https://test@sentry.io/123", + False + ) + mock_client.send_trace.assert_called_once_with(self.sample_job_data) + + def test_handle_event_missing_installation(self): + """Test handling event with missing installation field.""" + data = self.sample_job_data.copy() + del data["installation"] + + with patch.dict(os.environ, {"APP_DSN": "https://test@sentry.io/123"}): + with patch('src.web_app_handler.init_config'): + handler = WebAppHandler(dry_run=False) + + mock_collector = Mock() + 
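+            # Hierarchical tracing stays enabled, so add_job should still be
+            # reached even though the payload carries no "installation" block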
mock_collector.is_hierarchical_tracing_enabled.return_value = True + + with patch.object(handler, '_get_job_collector', return_value=mock_collector): + reason, http_code = handler.handle_event(data, self.sample_headers) + + # Should still process with default installation_id + self.assertEqual(http_code, 200) + mock_collector.add_job.assert_called_once() + + def test_valid_signature_no_secret(self): + """Test signature validation when no secret is configured.""" + self.handler.config = Mock() + self.handler.config.gh = Mock() + self.handler.config.gh.webhook_secret = None + + result = self.handler.valid_signature(b"test body", {}) + self.assertTrue(result) + + def test_valid_signature_with_secret(self): + """Test signature validation with webhook secret.""" + import hmac + + secret = "test-secret" + body = b"test body" + + # Generate valid signature + signature = hmac.new( + secret.encode(), + msg=body, + digestmod="sha256" + ).hexdigest() + + self.handler.config = Mock() + self.handler.config.gh = Mock() + self.handler.config.gh.webhook_secret = secret + + headers = {"X-Hub-Signature-256": f"sha256={signature}"} + + result = self.handler.valid_signature(body, headers) + self.assertTrue(result) + + def test_valid_signature_invalid(self): + """Test signature validation with invalid signature.""" + self.handler.config = Mock() + self.handler.config.gh = Mock() + self.handler.config.gh.webhook_secret = "test-secret" + + headers = {"X-Hub-Signature-256": "sha256=invalid"} + + result = self.handler.valid_signature(b"test body", headers) + self.assertFalse(result) + + +class TestWebAppHandlerIntegration(unittest.TestCase): + """Integration tests for WebAppHandler with real WorkflowJobCollector.""" + + @patch.dict(os.environ, {"APP_DSN": "https://test@sentry.io/123"}) + @patch('src.workflow_job_collector.ENABLE_HIERARCHICAL_TRACING', True) + def test_full_workflow_with_hierarchical_tracing(self): + """Test full workflow processing with hierarchical tracing enabled.""" + with patch('src.web_app_handler.init_config'): + handler = WebAppHandler(dry_run=False) + + # Create multiple jobs for same workflow + run_id = 12345 + jobs = [ + { + "action": "completed", + "workflow_job": { + "id": i, + "run_id": run_id, + "name": f"job-{i}", + "conclusion": "success", + "started_at": datetime.utcnow().isoformat() + "Z", + "completed_at": datetime.utcnow().isoformat() + "Z", + "html_url": f"https://github.com/test/repo/runs/{run_id}/jobs/{i}", + "steps": [], + "head_branch": "main", + "head_sha": "abc123", + "run_attempt": 1, + "workflow_name": "Test" + }, + "repository": { + "owner": {"login": "test-org"} + }, + "installation": {"id": 789} + } + for i in range(3) + ] + + headers = {"X-GitHub-Event": "workflow_job"} + + # Mock the WorkflowTracer to avoid actual API calls + with patch('src.workflow_job_collector.WorkflowTracer'): + # Process all jobs + for job_data in jobs: + reason, http_code = handler.handle_event(job_data, headers) + self.assertEqual(http_code, 200) + + # Verify collector was created for org + self.assertIn("test-org", handler.job_collectors) + + # Verify all jobs were collected + collector = handler.job_collectors["test-org"] + self.assertEqual(len(collector.workflow_jobs[run_id]), 3) + + +if __name__ == '__main__': + unittest.main() + diff --git a/tests/test_web_app_handler_token_generation.py b/tests/test_web_app_handler_token_generation.py new file mode 100644 index 0000000..99e9b98 --- /dev/null +++ b/tests/test_web_app_handler_token_generation.py @@ -0,0 +1,254 @@ +""" +Tests for GitHub 
App token generation in WebAppHandler. +""" + +import os +import unittest +from unittest.mock import Mock, patch, MagicMock + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from src.web_app_handler import WebAppHandler, GithubAppConfig, GitHubConfig, Config + + +class TestWebAppHandlerTokenGeneration(unittest.TestCase): + """Test suite for GitHub App token generation.""" + + def setUp(self): + """Set up test fixtures.""" + self.webhook_payload = { + "action": "completed", + "workflow_job": { + "id": 12345, + "run_id": 67890, + "name": "test-job", + "conclusion": "success", + "started_at": "2025-11-17T17:36:10Z", + "completed_at": "2025-11-17T17:36:45Z", + }, + "repository": { + "full_name": "test-org/test-repo", + "owner": {"login": "test-org"} + }, + "installation": { + "id": 123456 + } + } + self.headers = { + "X-GitHub-Event": "workflow_job", + "X-Hub-Signature-256": "sha256=test" + } + + def test_generate_github_app_token_success(self): + """Test successful GitHub App token generation.""" + # Mock config with GitHub App + mock_config = Config( + gh_app=GithubAppConfig( + app_id=12345, + private_key=b"test_private_key" + ), + gh=GitHubConfig( + webhook_secret="test_secret", + token=None + ) + ) + + # Mock token response + mock_token_response = Mock() + mock_token_response.json.return_value = {"token": "ghs_test_token_12345"} + mock_token_response.raise_for_status = Mock() + + # Mock GithubAppToken headers + mock_app_token = Mock() + mock_app_token.headers = {"Authorization": "Bearer test_jwt"} + + handler = WebAppHandler(dry_run=False) # Need dry_run=False to test token generation + handler.config = mock_config + + # Set APP_DSN for the handler + with patch.dict('os.environ', {'APP_DSN': 'https://test@sentry.io/123'}): + with patch('src.web_app_handler.GithubAppToken', return_value=mock_app_token): + with patch('requests.post', return_value=mock_token_response) as mock_post: + with patch('src.web_app_handler.WorkflowJobCollector'): + reason, http_code = handler.handle_event(self.webhook_payload, self.headers) + + # Verify token generation was attempted + mock_post.assert_called() + # Verify webhook was processed + self.assertEqual(http_code, 200) + + def test_fallback_to_pat_token(self): + """Test fallback to PAT token when GitHub App fails.""" + # Mock config with GitHub App and PAT + mock_config = Config( + gh_app=GithubAppConfig( + app_id=12345, + private_key=b"test_private_key" + ), + gh=GitHubConfig( + webhook_secret="test_secret", + token="ghp_pat_token_12345" + ) + ) + + handler = WebAppHandler(dry_run=True) + handler.config = mock_config + + # Mock GitHub App token generation failure + with patch('requests.post', side_effect=Exception("App token generation failed")): + reason, http_code = handler.handle_event(self.webhook_payload, self.headers) + + # Should still process (falls back to PAT) + self.assertEqual(http_code, 200) + + def test_no_token_available(self): + """Test handling when no token is available.""" + # Mock config without GitHub App or PAT + mock_config = Config( + gh_app=None, + gh=GitHubConfig( + webhook_secret="test_secret", + token=None + ) + ) + + handler = WebAppHandler(dry_run=True) + handler.config = mock_config + + reason, http_code = handler.handle_event(self.webhook_payload, self.headers) + + # Should still process (will use timeout-based detection) + self.assertEqual(http_code, 200) + + def test_missing_installation_id(self): + """Test handling when installation_id is missing.""" + payload_no_installation = 
self.webhook_payload.copy() + del payload_no_installation["installation"] + + mock_config = Config( + gh_app=None, + gh=GitHubConfig( + webhook_secret="test_secret", + token="ghp_pat_token" + ) + ) + + handler = WebAppHandler(dry_run=True) + handler.config = mock_config + + reason, http_code = handler.handle_event(payload_no_installation, self.headers) + + # Should use PAT token + self.assertEqual(http_code, 200) + + def test_github_app_token_expires_after_one_hour(self): + """Test that tokens are generated per request (they expire in 1 hour).""" + mock_config = Config( + gh_app=GithubAppConfig( + app_id=12345, + private_key=b"test_private_key" + ), + gh=GitHubConfig( + webhook_secret="test_secret", + token=None + ) + ) + + mock_token_response = Mock() + mock_token_response.json.return_value = {"token": "ghs_test_token_12345"} + mock_token_response.raise_for_status = Mock() + + # Mock GithubAppToken headers + mock_app_token = Mock() + mock_app_token.headers = {"Authorization": "Bearer test_jwt"} + + handler = WebAppHandler(dry_run=False) # Need dry_run=False to test token generation + handler.config = mock_config + + # Set APP_DSN for the handler + with patch.dict('os.environ', {'APP_DSN': 'https://test@sentry.io/123'}): + with patch('src.web_app_handler.GithubAppToken', return_value=mock_app_token): + with patch('requests.post', return_value=mock_token_response) as mock_post: + with patch('src.web_app_handler.WorkflowJobCollector'): + # Process multiple webhooks + handler.handle_event(self.webhook_payload, self.headers) + handler.handle_event(self.webhook_payload, self.headers) + + # Each webhook should generate a new token (or reuse if cached) + # Note: Current implementation generates per webhook, which is fine + self.assertGreaterEqual(mock_post.call_count, 1) + + +class TestWebAppHandlerHeaderHandling(unittest.TestCase): + """Test suite for header handling improvements.""" + + def setUp(self): + """Set up test fixtures.""" + self.webhook_payload = { + "action": "completed", + "workflow_job": { + "id": 12345, + "run_id": 67890, + "name": "test-job", + }, + "repository": { + "owner": {"login": "test-org"} + } + } + + def test_flask_normalized_headers(self): + """Test that Flask-normalized headers are handled correctly.""" + handler = WebAppHandler(dry_run=True) + + # Test various header formats Flask might use + headers_variants = [ + {"X-GitHub-Event": "workflow_job"}, + {"X-GITHUB-EVENT": "workflow_job"}, + {"HTTP_X_GITHUB_EVENT": "workflow_job"}, + ] + + for headers in headers_variants: + reason, http_code = handler.handle_event(self.webhook_payload, headers) + self.assertEqual(http_code, 200, f"Failed with headers: {headers}") + + def test_missing_github_event_header(self): + """Test handling when X-GitHub-Event header is missing.""" + handler = WebAppHandler(dry_run=True) + + headers = {} # No GitHub event header + + reason, http_code = handler.handle_event(self.webhook_payload, headers) + + self.assertEqual(http_code, 400) + self.assertIn("Missing", reason) + + def test_unsupported_event_type(self): + """Test handling of unsupported event types.""" + handler = WebAppHandler(dry_run=True) + + headers = {"X-GitHub-Event": "push"} # Unsupported event + + reason, http_code = handler.handle_event(self.webhook_payload, headers) + + self.assertEqual(http_code, 200) # Returns 200 but doesn't process + self.assertIn("not supported", reason) + + def test_unsupported_action(self): + """Test handling of unsupported actions.""" + handler = WebAppHandler(dry_run=True) + + payload = 
self.webhook_payload.copy() + payload["action"] = "queued" # Unsupported action + + headers = {"X-GitHub-Event": "workflow_job"} + + reason, http_code = handler.handle_event(payload, headers) + + self.assertEqual(http_code, 200) # Returns 200 but doesn't process + self.assertIn("cannot do anything", reason.lower()) + + +if __name__ == '__main__': + unittest.main() + diff --git a/tests/test_workflow_job_collector.py b/tests/test_workflow_job_collector.py new file mode 100644 index 0000000..004b44a --- /dev/null +++ b/tests/test_workflow_job_collector.py @@ -0,0 +1,476 @@ +""" +Tests for WorkflowJobCollector module. + +This module tests the collection and aggregation of GitHub workflow jobs, +including workflow completion detection, timeout handling, and thread safety. +""" + +import os +import time +import unittest +from datetime import datetime, timedelta +from unittest.mock import Mock, patch, MagicMock + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from src.workflow_job_collector import ( + WorkflowJobCollector, + SMALL_WORKFLOW_PROCESSING_DELAY, + NO_NEW_JOBS_TIMEOUT, + MAX_WORKFLOW_WAIT_TIME, + LARGE_WORKFLOW_THRESHOLD, + MEDIUM_WORKFLOW_THRESHOLD, + SMALL_WORKFLOW_THRESHOLD, +) + + +class TestWorkflowJobCollector(unittest.TestCase): + """Test suite for WorkflowJobCollector class.""" + + def setUp(self): + """Set up test fixtures.""" + self.dsn = "https://test@sentry.io/123" + self.token = "test_token" + self.test_run_id = 12345 + + # Mock the WorkflowTracer to avoid actual API calls + with patch('src.workflow_job_collector.WorkflowTracer'): + self.collector = WorkflowJobCollector( + dsn=self.dsn, + token=self.token, + dry_run=True + ) + + def tearDown(self): + """Clean up after tests.""" + # Cancel any active timers + for timer in self.collector.workflow_timers.values(): + timer.cancel() + + def _create_job_data(self, job_id: int, run_id: int, name: str, + conclusion: str = "success") -> dict: + """ + Create a mock job data payload. + + Args: + job_id: Unique job identifier + run_id: Workflow run identifier + name: Job name + conclusion: Job conclusion (success, failure, cancelled, etc.) 
+ + Returns: + Mock job data dictionary + """ + now = datetime.utcnow() + return { + "workflow_job": { + "id": job_id, + "run_id": run_id, + "name": name, + "conclusion": conclusion, + "started_at": (now - timedelta(minutes=5)).isoformat() + "Z", + "completed_at": now.isoformat() + "Z", + "html_url": f"https://github.com/test/repo/actions/runs/{run_id}/jobs/{job_id}", + "steps": [ + { + "name": "Setup", + "number": 1, + "conclusion": "success", + "started_at": (now - timedelta(minutes=5)).isoformat() + "Z", + "completed_at": (now - timedelta(minutes=4)).isoformat() + "Z", + } + ], + "head_branch": "main", + "head_sha": "abc123", + "run_attempt": 1, + "workflow_name": "Test Workflow" + } + } + + def test_init(self): + """Test WorkflowJobCollector initialization.""" + self.assertEqual(self.collector.dsn, self.dsn) + self.assertEqual(self.collector.token, self.token) + self.assertTrue(self.collector.dry_run) + self.assertEqual(len(self.collector.workflow_jobs), 0) + self.assertEqual(len(self.collector.processed_jobs), 0) + self.assertEqual(len(self.collector.processed_workflows), 0) + + def test_add_job_single(self): + """Test adding a single job.""" + job_data = self._create_job_data(1, self.test_run_id, "test-job") + + with patch.object(self.collector, '_schedule_workflow_processing') as mock_schedule: + self.collector.add_job(job_data) + + # Verify job was added + self.assertIn(self.test_run_id, self.collector.workflow_jobs) + self.assertEqual(len(self.collector.workflow_jobs[self.test_run_id]), 1) + self.assertIn(1, self.collector.processed_jobs) + + # Verify scheduling was called for completed job + mock_schedule.assert_called_once_with(self.test_run_id) + + def test_add_job_duplicate(self): + """Test that duplicate jobs are ignored.""" + job_data = self._create_job_data(1, self.test_run_id, "test-job") + + with patch.object(self.collector, '_schedule_workflow_processing'): + self.collector.add_job(job_data) + initial_count = len(self.collector.workflow_jobs[self.test_run_id]) + + # Add same job again + self.collector.add_job(job_data) + final_count = len(self.collector.workflow_jobs[self.test_run_id]) + + # Count should not increase + self.assertEqual(initial_count, final_count) + + def test_add_multiple_jobs(self): + """Test adding multiple jobs to the same workflow run.""" + jobs = [ + self._create_job_data(1, self.test_run_id, "job-1"), + self._create_job_data(2, self.test_run_id, "job-2"), + self._create_job_data(3, self.test_run_id, "job-3"), + ] + + with patch.object(self.collector, '_schedule_workflow_processing'): + for job_data in jobs: + self.collector.add_job(job_data) + + # Verify all jobs were added + self.assertEqual(len(self.collector.workflow_jobs[self.test_run_id]), 3) + self.assertEqual(len(self.collector.processed_jobs), 3) + + def test_all_jobs_completed(self): + """Test detection of all jobs being completed.""" + completed_jobs = [ + {"conclusion": "success"}, + {"conclusion": "failure"}, + {"conclusion": "cancelled"}, + ] + self.assertTrue(self.collector._all_jobs_completed(completed_jobs)) + + incomplete_jobs = [ + {"conclusion": "success"}, + {"conclusion": None}, # Not completed + ] + self.assertFalse(self.collector._all_jobs_completed(incomplete_jobs)) + + def test_meets_processing_threshold_large_workflow(self): + """Test threshold detection for large workflows (10+ jobs).""" + # Add 10 completed jobs + for i in range(LARGE_WORKFLOW_THRESHOLD): + job = self._create_job_data(i, self.test_run_id, f"job-{i}") + 
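# Seed the collector's state directly (bypassing add_job) so only the threshold logic is exercised +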
self.collector.workflow_jobs[self.test_run_id].append( + job["workflow_job"] + ) + self.collector.job_arrival_times[self.test_run_id].append(time.time()) + + result = self.collector._meets_processing_threshold( + self.test_run_id, + LARGE_WORKFLOW_THRESHOLD + ) + self.assertTrue(result) + + def test_meets_processing_threshold_medium_workflow(self): + """Test threshold detection for medium workflows (5-9 jobs).""" + # Add 5 completed jobs + for i in range(MEDIUM_WORKFLOW_THRESHOLD): + job = self._create_job_data(i, self.test_run_id, f"job-{i}") + self.collector.workflow_jobs[self.test_run_id].append( + job["workflow_job"] + ) + self.collector.job_arrival_times[self.test_run_id].append(time.time()) + + result = self.collector._meets_processing_threshold( + self.test_run_id, + MEDIUM_WORKFLOW_THRESHOLD + ) + self.assertTrue(result) + + def test_meets_processing_threshold_small_workflow(self): + """Test threshold detection for small workflows (3-4 jobs).""" + # Add 3 completed jobs + for i in range(SMALL_WORKFLOW_THRESHOLD): + job = self._create_job_data(i, self.test_run_id, f"job-{i}") + self.collector.workflow_jobs[self.test_run_id].append( + job["workflow_job"] + ) + self.collector.job_arrival_times[self.test_run_id].append(time.time()) + + result = self.collector._meets_processing_threshold( + self.test_run_id, + SMALL_WORKFLOW_THRESHOLD + ) + self.assertTrue(result) + + def test_meets_processing_threshold_single_job(self): + """Test threshold detection for single job workflows.""" + job = self._create_job_data(1, self.test_run_id, "single-job") + self.collector.workflow_jobs[self.test_run_id].append( + job["workflow_job"] + ) + self.collector.job_arrival_times[self.test_run_id].append(time.time()) + + result = self.collector._meets_processing_threshold(self.test_run_id, 1) + self.assertTrue(result) + + def test_workflow_timeout_exceeded(self): + """Test detection of workflow timeout.""" + # Set arrival time to past the timeout + past_time = time.time() - (MAX_WORKFLOW_WAIT_TIME + 10) + self.collector.job_arrival_times[self.test_run_id].append(past_time) + + result = self.collector._is_workflow_timeout_exceeded(self.test_run_id) + self.assertTrue(result) + + def test_workflow_timeout_not_exceeded(self): + """Test that workflow timeout is not triggered prematurely.""" + # Set arrival time to recent + self.collector.job_arrival_times[self.test_run_id].append(time.time()) + + result = self.collector._is_workflow_timeout_exceeded(self.test_run_id) + self.assertFalse(result) + + def test_cleanup_workflow_run(self): + """Test workflow run cleanup.""" + # Add some data + job = self._create_job_data(1, self.test_run_id, "test-job") + self.collector.workflow_jobs[self.test_run_id].append(job["workflow_job"]) + self.collector.job_arrival_times[self.test_run_id].append(time.time()) + + # Create a mock timer + mock_timer = Mock() + self.collector.workflow_timers[self.test_run_id] = mock_timer + + # Cleanup + self.collector._cleanup_workflow_run(self.test_run_id) + + # Verify cleanup + self.assertNotIn(self.test_run_id, self.collector.workflow_jobs) + self.assertNotIn(self.test_run_id, self.collector.job_arrival_times) + self.assertNotIn(self.test_run_id, self.collector.workflow_timers) + self.assertIn(self.test_run_id, self.collector.processed_workflows) + mock_timer.cancel.assert_called_once() + + def test_send_workflow_trace_success(self): + """Test successful workflow trace sending.""" + # Add a completed job + job = self._create_job_data(1, self.test_run_id, "test-job") + 
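# appended directly; only _send_workflow_trace behavior is under test here +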
self.collector.workflow_jobs[self.test_run_id].append(job["workflow_job"]) + + # Mock the workflow tracer + with patch.object(self.collector.workflow_tracer, 'send_workflow_trace') as mock_send: + self.collector._send_workflow_trace(self.test_run_id) + + # Verify trace was sent + mock_send.assert_called_once() + + # Verify cleanup + self.assertNotIn(self.test_run_id, self.collector.workflow_jobs) + self.assertIn(self.test_run_id, self.collector.processed_workflows) + + def test_send_workflow_trace_already_processed(self): + """Test that already processed workflows are not sent again.""" + # Mark as already processed + self.collector.processed_workflows.add(self.test_run_id) + + # Add a job + job = self._create_job_data(1, self.test_run_id, "test-job") + self.collector.workflow_jobs[self.test_run_id].append(job["workflow_job"]) + + # Mock the workflow tracer + with patch.object(self.collector.workflow_tracer, 'send_workflow_trace') as mock_send: + self.collector._send_workflow_trace(self.test_run_id) + + # Verify trace was NOT sent + mock_send.assert_not_called() + + def test_send_workflow_trace_error_handling(self): + """Test error handling during trace sending.""" + # Add a completed job + job = self._create_job_data(1, self.test_run_id, "test-job") + self.collector.workflow_jobs[self.test_run_id].append(job["workflow_job"]) + + # Mock the workflow tracer to raise an exception + with patch.object(self.collector.workflow_tracer, 'send_workflow_trace') as mock_send: + mock_send.side_effect = Exception("Test error") + + # Should not raise, but should log error + self.collector._send_workflow_trace(self.test_run_id) + + # Verify cleanup still happened + self.assertNotIn(self.test_run_id, self.collector.workflow_jobs) + self.assertIn(self.test_run_id, self.collector.processed_workflows) + + def test_process_workflow_immediately_with_complete_jobs(self): + """Test immediate processing of workflow with all jobs complete.""" + # Add completed jobs + for i in range(3): + job = self._create_job_data(i, self.test_run_id, f"job-{i}") + self.collector.workflow_jobs[self.test_run_id].append( + job["workflow_job"] + ) + + # Mock the send method + with patch.object(self.collector, '_send_workflow_trace') as mock_send: + self.collector._process_workflow_immediately(self.test_run_id) + + # Verify workflow was sent + mock_send.assert_called_once_with(self.test_run_id) + + def test_process_workflow_immediately_with_incomplete_jobs(self): + """Test that workflows with incomplete jobs are not processed.""" + # Add a job without conclusion (incomplete) + job_data = self._create_job_data(1, self.test_run_id, "incomplete-job") + job_data["workflow_job"]["conclusion"] = None + self.collector.workflow_jobs[self.test_run_id].append( + job_data["workflow_job"] + ) + + # Mock the send method + with patch.object(self.collector, '_send_workflow_trace') as mock_send: + self.collector._process_workflow_immediately(self.test_run_id) + + # Verify workflow was NOT sent + mock_send.assert_not_called() + + def test_process_workflow_immediately_exception_handling(self): + """Test exception handling in immediate processing.""" + # Add a job + job = self._create_job_data(1, self.test_run_id, "test-job") + self.collector.workflow_jobs[self.test_run_id].append(job["workflow_job"]) + + # Mock _send_workflow_trace to raise exception + with patch.object(self.collector, '_send_workflow_trace') as mock_send: + mock_send.side_effect = Exception("Test error") + + # Should handle exception gracefully + 
self.collector._process_workflow_immediately(self.test_run_id) + + # Verify cleanup happened via exception handler + self.assertIn(self.test_run_id, self.collector.processed_workflows) + + @patch('src.workflow_job_collector.ENABLE_HIERARCHICAL_TRACING', True) + def test_hierarchical_tracing_enabled(self): + """Test hierarchical tracing is enabled when feature flag is true.""" + with patch('src.workflow_job_collector.WorkflowTracer'): + collector = WorkflowJobCollector(self.dsn, self.token, dry_run=True) + self.assertTrue(collector.is_hierarchical_tracing_enabled("test-org")) + + @patch('src.workflow_job_collector.ENABLE_HIERARCHICAL_TRACING', False) + def test_hierarchical_tracing_disabled(self): + """Test hierarchical tracing is disabled when feature flag is false.""" + with patch('src.workflow_job_collector.WorkflowTracer'): + collector = WorkflowJobCollector(self.dsn, self.token, dry_run=True) + self.assertFalse(collector.is_hierarchical_tracing_enabled("test-org")) + + @patch('src.workflow_job_collector.ENABLE_HIERARCHICAL_TRACING', True) + @patch('src.workflow_job_collector.SENTRY_ORG_ONLY', True) + def test_hierarchical_tracing_sentry_org_only(self): + """Test hierarchical tracing restricted to Sentry org.""" + with patch('src.workflow_job_collector.WorkflowTracer'): + collector = WorkflowJobCollector(self.dsn, self.token, dry_run=True) + + # Should be enabled for getsentry + self.assertTrue(collector.is_hierarchical_tracing_enabled("getsentry")) + + # Should be disabled for other orgs + self.assertFalse(collector.is_hierarchical_tracing_enabled("test-org")) + + def test_thread_safety_concurrent_job_additions(self): + """Test thread safety when adding jobs concurrently.""" + import threading + + def add_jobs(start_id): + for i in range(10): + job_data = self._create_job_data( + start_id + i, + self.test_run_id, + f"job-{start_id + i}" + ) + with patch.object(self.collector, '_schedule_workflow_processing'): + self.collector.add_job(job_data) + + # Create multiple threads adding jobs + threads = [] + for i in range(3): + thread = threading.Thread(target=add_jobs, args=(i * 10,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify all 30 jobs were added without race conditions + self.assertEqual(len(self.collector.processed_jobs), 30) + + +class TestWorkflowJobCollectorIntegration(unittest.TestCase): + """Integration tests for WorkflowJobCollector.""" + + def setUp(self): + """Set up test fixtures.""" + self.dsn = "https://test@sentry.io/123" + self.token = "test_token" + + with patch('src.workflow_job_collector.WorkflowTracer'): + self.collector = WorkflowJobCollector( + dsn=self.dsn, + token=self.token, + dry_run=True + ) + + def tearDown(self): + """Clean up after tests.""" + for timer in self.collector.workflow_timers.values(): + timer.cancel() + + def test_full_workflow_lifecycle(self): + """Test complete workflow lifecycle from job addition to trace sending.""" + run_id = 99999 + + # Create multiple jobs + jobs = [ + { + "workflow_job": { + "id": i, + "run_id": run_id, + "name": f"job-{i}", + "conclusion": "success", + "started_at": datetime.utcnow().isoformat() + "Z", + "completed_at": datetime.utcnow().isoformat() + "Z", + "html_url": f"https://github.com/test/repo/runs/{run_id}/jobs/{i}", + "steps": [], + "head_branch": "main", + "head_sha": "abc123", + "run_attempt": 1, + "workflow_name": "Test" + } + } + for i in range(5) + ] + + # Mock the workflow tracer + with 
patch.object(self.collector.workflow_tracer, 'send_workflow_trace') as mock_send: + # Add all jobs + for job_data in jobs: + self.collector.add_job(job_data) + + # Wait for timer to fire (plus a bit extra) + time.sleep(SMALL_WORKFLOW_PROCESSING_DELAY + 0.5) + + # Verify trace was sent + mock_send.assert_called_once() + + # Verify cleanup + self.assertNotIn(run_id, self.collector.workflow_jobs) + self.assertIn(run_id, self.collector.processed_workflows) + + +if __name__ == '__main__': + unittest.main() + diff --git a/tests/test_workflow_job_collector_new_features.py b/tests/test_workflow_job_collector_new_features.py new file mode 100644 index 0000000..e551ad6 --- /dev/null +++ b/tests/test_workflow_job_collector_new_features.py @@ -0,0 +1,804 @@ +""" +Tests for new WorkflowJobCollector features: +- API-based total job count fetching +- Timeout scheduling +- Workflow completion detection improvements +- Edge cases and error handling +""" + +import os +import time +import unittest +from datetime import datetime, timedelta +from unittest.mock import Mock, patch, MagicMock, call +import requests + +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from src.workflow_job_collector import ( + WorkflowJobCollector, + NO_NEW_JOBS_TIMEOUT, + MAX_WORKFLOW_WAIT_TIME, +) + + +class TestWorkflowJobCollectorAPIFeatures(unittest.TestCase): + """Test suite for API-based features in WorkflowJobCollector.""" + + def setUp(self): + """Set up test fixtures.""" + self.dsn = "https://test@sentry.io/123" + self.token = "test_token" + self.test_run_id = 12345 + self.repo_full_name = "test-org/test-repo" + + with patch('src.workflow_job_collector.WorkflowTracer'): + self.collector = WorkflowJobCollector( + dsn=self.dsn, + token=self.token, + dry_run=True + ) + + def tearDown(self): + """Clean up after tests.""" + for timer in self.collector.workflow_timers.values(): + timer.cancel() + + def _create_job_data(self, job_id: int, run_id: int, name: str, + conclusion: str = "success", repository: dict = None) -> dict: + """Create a mock job data payload.""" + now = datetime.utcnow() + job_data = { + "workflow_job": { + "id": job_id, + "run_id": run_id, + "name": name, + "conclusion": conclusion, + "started_at": (now - timedelta(minutes=5)).isoformat() + "Z", + "completed_at": now.isoformat() + "Z", + "html_url": f"https://github.com/test/repo/actions/runs/{run_id}/jobs/{job_id}", + "steps": [], + "head_branch": "main", + "head_sha": "abc123", + "run_attempt": 1, + } + } + if repository: + job_data["repository"] = repository + return job_data + + def test_fetch_total_job_count_success(self): + """Test successful fetching of total job count from GitHub API.""" + mock_response = Mock() + mock_response.json.return_value = {"total_count": 8} + mock_response.raise_for_status = Mock() + + with patch('requests.get', return_value=mock_response) as mock_get: + result = self.collector._fetch_total_job_count( + self.test_run_id, + self.repo_full_name + ) + + self.assertEqual(result, 8) + mock_get.assert_called_once() + call_args = mock_get.call_args + self.assertIn(f"/actions/runs/{self.test_run_id}/jobs", call_args[0][0]) + self.assertIn("Authorization", call_args[1]["headers"]) + + def test_fetch_total_job_count_no_token(self): + """Test that fetch returns None when no token is available.""" + self.collector.token = None + + result = self.collector._fetch_total_job_count( + self.test_run_id, + self.repo_full_name + ) + + self.assertIsNone(result) + + def test_fetch_total_job_count_api_failure(self): 
+ """Test that API failures are handled gracefully.""" + with patch('requests.get', side_effect=requests.exceptions.RequestException("API Error")): + result = self.collector._fetch_total_job_count( + self.test_run_id, + self.repo_full_name + ) + + self.assertIsNone(result) + + def test_fetch_total_job_count_missing_total_count(self): + """Test handling when API response is missing total_count field.""" + mock_response = Mock() + mock_response.json.return_value = {} # Missing total_count + mock_response.raise_for_status = Mock() + + with patch('requests.get', return_value=mock_response): + result = self.collector._fetch_total_job_count( + self.test_run_id, + self.repo_full_name + ) + + self.assertIsNone(result) + + def test_fetch_total_job_count_timeout(self): + """Test that API timeout is handled.""" + with patch('requests.get', side_effect=requests.exceptions.Timeout("Timeout")): + result = self.collector._fetch_total_job_count( + self.test_run_id, + self.repo_full_name + ) + + self.assertIsNone(result) + + def test_fetch_total_job_count_http_error(self): + """Test handling of HTTP errors (404, 403, etc.).""" + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("404 Not Found") + + with patch('requests.get', return_value=mock_response): + result = self.collector._fetch_total_job_count( + self.test_run_id, + self.repo_full_name + ) + + self.assertIsNone(result) + + def test_add_job_fetches_total_count_on_first_job(self): + """Test that total job count is fetched when first job arrives.""" + repository = {"full_name": self.repo_full_name} + job_data = self._create_job_data(1, self.test_run_id, "first-job", repository=repository) + + mock_response = Mock() + mock_response.json.return_value = {"total_count": 5} + mock_response.raise_for_status = Mock() + + with patch('requests.get', return_value=mock_response) as mock_get: + with patch.object(self.collector, '_schedule_workflow_processing'): + self.collector.add_job(job_data) + + # Verify total count was fetched and cached + self.assertEqual(self.collector.workflow_total_jobs[self.test_run_id], 5) + # Verify API was called + self.assertEqual(mock_get.call_count, 1) + + def test_add_job_does_not_refetch_total_count(self): + """Test that total job count is only fetched once per workflow run.""" + repository = {"full_name": self.repo_full_name} + job1 = self._create_job_data(1, self.test_run_id, "job-1", repository=repository) + job2 = self._create_job_data(2, self.test_run_id, "job-2", repository=repository) + + mock_response = Mock() + mock_response.json.return_value = {"total_count": 3} + mock_response.raise_for_status = Mock() + + with patch('requests.get', return_value=mock_response) as mock_get: + with patch.object(self.collector, '_schedule_workflow_processing'): + self.collector.add_job(job1) + self.collector.add_job(job2) + + # API should only be called once + self.assertEqual(mock_get.call_count, 1) + self.assertEqual(self.collector.workflow_total_jobs[self.test_run_id], 3) + + def test_should_process_workflow_now_with_total_count(self): + """Test processing logic when total job count is known.""" + # Set up: We know there are 5 jobs total + self.collector.workflow_total_jobs[self.test_run_id] = 5 + + # Add 5 completed jobs + for i in range(5): + job = self._create_job_data(i, self.test_run_id, f"job-{i}") + self.collector.workflow_jobs[self.test_run_id].append(job["workflow_job"]) + self.collector.job_arrival_times[self.test_run_id].append(time.time()) + + # Wait for timeout to elapse + 
time.sleep(NO_NEW_JOBS_TIMEOUT + 0.1) + + result = self.collector._should_process_workflow_now(self.test_run_id) + self.assertTrue(result) + + def test_should_process_workflow_now_waiting_for_more_jobs(self): + """Test that processing waits when not all jobs have arrived.""" + # Set up: We know there are 5 jobs total, but only 3 have arrived + self.collector.workflow_total_jobs[self.test_run_id] = 5 + + # Add only 3 completed jobs + for i in range(3): + job = self._create_job_data(i, self.test_run_id, f"job-{i}") + self.collector.workflow_jobs[self.test_run_id].append(job["workflow_job"]) + self.collector.job_arrival_times[self.test_run_id].append(time.time()) + + result = self.collector._should_process_workflow_now(self.test_run_id) + self.assertFalse(result) # Should wait for more jobs + + def test_should_process_workflow_now_fallback_to_timeout(self): + """Test that processing falls back to timeout when total count is unknown.""" + # Don't set total count (simulates API failure or no token) + self.collector.workflow_total_jobs[self.test_run_id] = None + + # Add 3 completed jobs + for i in range(3): + job = self._create_job_data(i, self.test_run_id, f"job-{i}") + self.collector.workflow_jobs[self.test_run_id].append(job["workflow_job"]) + self.collector.job_arrival_times[self.test_run_id].append(time.time()) + + # Wait for timeout + time.sleep(NO_NEW_JOBS_TIMEOUT + 0.1) + + result = self.collector._should_process_workflow_now(self.test_run_id) + self.assertTrue(result) # Should process after timeout + + def test_should_schedule_timeout_check_with_total_count(self): + """Test timeout check scheduling when total count is known.""" + # Set up: 5 jobs total, all 5 have arrived and completed + self.collector.workflow_total_jobs[self.test_run_id] = 5 + + for i in range(5): + job = self._create_job_data(i, self.test_run_id, f"job-{i}") + self.collector.workflow_jobs[self.test_run_id].append(job["workflow_job"]) + self.collector.job_arrival_times[self.test_run_id].append(time.time()) + + result = self.collector._should_schedule_timeout_check(self.test_run_id) + self.assertTrue(result) + + def test_should_schedule_timeout_check_waiting_for_jobs(self): + """Test that timeout check is not scheduled when jobs are missing.""" + # Set up: 5 jobs total, but only 3 have arrived + self.collector.workflow_total_jobs[self.test_run_id] = 5 + + for i in range(3): + job = self._create_job_data(i, self.test_run_id, f"job-{i}") + self.collector.workflow_jobs[self.test_run_id].append(job["workflow_job"]) + self.collector.job_arrival_times[self.test_run_id].append(time.time()) + + result = self.collector._should_schedule_timeout_check(self.test_run_id) + self.assertFalse(result) # Should not schedule, waiting for more jobs + + def test_schedule_timeout_check_creates_timer(self): + """Test that timeout check scheduling creates a timer.""" + # Add a completed job + job = self._create_job_data(1, self.test_run_id, "test-job") + self.collector.workflow_jobs[self.test_run_id].append(job["workflow_job"]) + self.collector.job_arrival_times[self.test_run_id].append(time.time()) + + with patch('threading.Timer') as mock_timer_class: + self.collector._schedule_timeout_check(self.test_run_id) + + # Verify timer was created + mock_timer_class.assert_called_once() + # Verify timer was started + mock_timer_instance = mock_timer_class.return_value + mock_timer_instance.start.assert_called_once() + + def test_schedule_timeout_check_calculates_remaining_time(self): + """Test that timeout check calculates remaining time 
correctly.""" + # Add a job + job = self._create_job_data(1, self.test_run_id, "test-job") + self.collector.workflow_jobs[self.test_run_id].append(job["workflow_job"]) + self.collector.job_arrival_times[self.test_run_id].append(time.time()) + + # Wait a bit + time.sleep(1.0) + + with patch('threading.Timer') as mock_timer_class: + self.collector._schedule_timeout_check(self.test_run_id) + + # Verify timer was called with remaining time (should be ~6s since we waited 1s) + call_args = mock_timer_class.call_args + timer_delay = call_args[0][0] + self.assertGreater(timer_delay, 0) + self.assertLess(timer_delay, NO_NEW_JOBS_TIMEOUT) + + def test_cleanup_removes_total_jobs_cache(self): + """Test that cleanup removes total jobs cache entry.""" + # Set up workflow data + job = self._create_job_data(1, self.test_run_id, "test-job") + self.collector.workflow_jobs[self.test_run_id].append(job["workflow_job"]) + self.collector.workflow_total_jobs[self.test_run_id] = 5 + + # Cleanup + self.collector._cleanup_workflow_run(self.test_run_id) + + # Verify total jobs cache was cleaned up + self.assertNotIn(self.test_run_id, self.collector.workflow_total_jobs) + + def test_add_job_stores_repository_info(self): + """Test that repository info is stored from webhook payload.""" + repository = {"full_name": self.repo_full_name, "id": 12345} + job_data = self._create_job_data(1, self.test_run_id, "test-job", repository=repository) + + with patch.object(self.collector, '_schedule_workflow_processing'): + self.collector.add_job(job_data) + + # Verify repository info was stored + self.assertIn(self.test_run_id, self.collector.workflow_repositories) + self.assertEqual( + self.collector.workflow_repositories[self.test_run_id]["full_name"], + self.repo_full_name + ) + + def test_add_job_uses_repository_info_for_api_call(self): + """Test that repository info is used when fetching total job count.""" + repository = {"full_name": self.repo_full_name} + job_data = self._create_job_data(1, self.test_run_id, "test-job", repository=repository) + + mock_response = Mock() + mock_response.json.return_value = {"total_count": 3} + mock_response.raise_for_status = Mock() + + with patch('requests.get', return_value=mock_response) as mock_get: + with patch.object(self.collector, '_schedule_workflow_processing'): + self.collector.add_job(job_data) + + # Verify API was called with correct repository name + call_args = mock_get.call_args + self.assertIn(self.repo_full_name, call_args[0][0]) + + +class TestWorkflowJobCollectorEdgeCases(unittest.TestCase): + """Test suite for edge cases and error scenarios.""" + + def setUp(self): + """Set up test fixtures.""" + self.dsn = "https://test@sentry.io/123" + self.token = "test_token" + self.test_run_id = 12345 + + with patch('src.workflow_job_collector.WorkflowTracer'): + self.collector = WorkflowJobCollector( + dsn=self.dsn, + token=self.token, + dry_run=True + ) + + def tearDown(self): + """Clean up after tests.""" + for timer in self.collector.workflow_timers.values(): + timer.cancel() + + def _create_job_data(self, job_id: int, run_id: int, name: str, + conclusion: str = "success") -> dict: + """Create a mock job data payload.""" + now = datetime.utcnow() + return { + "workflow_job": { + "id": job_id, + "run_id": run_id, + "name": name, + "conclusion": conclusion, + "started_at": (now - timedelta(minutes=5)).isoformat() + "Z", + "completed_at": now.isoformat() + "Z", + "html_url": f"https://github.com/test/repo/actions/runs/{run_id}/jobs/{job_id}", + "steps": [], + "head_branch": "main", 
+ "head_sha": "abc123", + "run_attempt": 1, + }, + "repository": { + "full_name": "test-org/test-repo" + } + } + + def test_jobs_arrive_after_timeout(self): + """Test scenario where jobs arrive after timeout has elapsed.""" + # Set total count to 5 + self.collector.workflow_total_jobs[self.test_run_id] = 5 + + # Add 3 jobs + for i in range(3): + job_data = self._create_job_data(i, self.test_run_id, f"job-{i}") + with patch.object(self.collector, '_schedule_workflow_processing'): + self.collector.add_job(job_data) + + # Wait for timeout + time.sleep(NO_NEW_JOBS_TIMEOUT + 0.1) + + # Add 2 more jobs after timeout + for i in range(3, 5): + job_data = self._create_job_data(i, self.test_run_id, f"job-{i}") + with patch.object(self.collector, '_schedule_workflow_processing'): + self.collector.add_job(job_data) + + # Should have all 5 jobs + self.assertEqual(len(self.collector.workflow_jobs[self.test_run_id]), 5) + + def test_max_workflow_wait_time_exceeded(self): + """Test that workflows exceeding max wait time are processed.""" + # Set arrival time to past max wait time + past_time = time.time() - (MAX_WORKFLOW_WAIT_TIME + 10) + self.collector.job_arrival_times[self.test_run_id].append(past_time) + + # Add a completed job + job = self._create_job_data(1, self.test_run_id, "test-job") + self.collector.workflow_jobs[self.test_run_id].append(job["workflow_job"]) + + result = self.collector._should_process_workflow_now(self.test_run_id) + self.assertTrue(result) + + def test_incomplete_jobs_prevent_processing(self): + """Test that workflows with incomplete jobs are not processed.""" + # Add a job without conclusion + incomplete_job = self._create_job_data(1, self.test_run_id, "incomplete-job") + incomplete_job["workflow_job"]["conclusion"] = None + self.collector.workflow_jobs[self.test_run_id].append(incomplete_job["workflow_job"]) + + result = self.collector._should_process_workflow_now(self.test_run_id) + self.assertFalse(result) + + def test_multiple_workflows_concurrent(self): + """Test handling multiple workflow runs concurrently.""" + run_id_1 = 11111 + run_id_2 = 22222 + + # Add jobs to different workflows + job1 = self._create_job_data(1, run_id_1, "job-1") + job2 = self._create_job_data(2, run_id_2, "job-2") + + with patch.object(self.collector, '_schedule_workflow_processing'): + self.collector.add_job(job1) + self.collector.add_job(job2) + + # Verify both workflows are tracked separately + self.assertIn(run_id_1, self.collector.workflow_jobs) + self.assertIn(run_id_2, self.collector.workflow_jobs) + self.assertEqual(len(self.collector.workflow_jobs[run_id_1]), 1) + self.assertEqual(len(self.collector.workflow_jobs[run_id_2]), 1) + + def test_api_failure_fallback_to_timeout(self): + """Test that API failure falls back to timeout-based detection.""" + repository = {"full_name": "test-org/test-repo"} + job_data = self._create_job_data(1, self.test_run_id, "test-job") + job_data["repository"] = repository + + # Mock API failure + with patch('requests.get', side_effect=requests.exceptions.RequestException("API Error")): + with patch.object(self.collector, '_schedule_timeout_check') as mock_schedule: + self.collector.add_job(job_data) + + # Should have None in cache (API failed) + self.assertIsNone(self.collector.workflow_total_jobs.get(self.test_run_id)) + # Should still schedule timeout check (fallback behavior) + # Note: This depends on job completion status + + def test_total_count_zero(self): + """Test handling when API returns total_count of 0.""" + repository = {"full_name": 
"test-org/test-repo"} + job_data = self._create_job_data(1, self.test_run_id, "test-job") + job_data["repository"] = repository + + mock_response = Mock() + mock_response.json.return_value = {"total_count": 0} + mock_response.raise_for_status = Mock() + + with patch('requests.get', return_value=mock_response): + with patch.object(self.collector, '_schedule_workflow_processing'): + self.collector.add_job(job_data) + + # Should handle 0 count gracefully + self.assertEqual(self.collector.workflow_total_jobs[self.test_run_id], 0) + + def test_missing_repository_info(self): + """Test handling when repository info is missing from webhook.""" + job_data = self._create_job_data(1, self.test_run_id, "test-job") + del job_data["repository"] # Remove repository info + + mock_response = Mock() + mock_response.json.return_value = {"total_count": 3} + mock_response.raise_for_status = Mock() + + with patch('requests.get', return_value=mock_response): + with patch.object(self.collector, '_schedule_workflow_processing'): + self.collector.add_job(job_data) + + # Should still work, using "unknown/unknown" as fallback + # API call might fail, but shouldn't crash + + def test_multiple_workflows_concurrent_different_job_ids(self): + """Test that multiple workflow runs with different job IDs are handled correctly.""" + run_id_1 = 11111 + run_id_2 = 22222 + job_id_1 = 99999 + job_id_2 = 88888 # Different job ID for different run + + job1 = self._create_job_data(job_id_1, run_id_1, "job-1") + job2 = self._create_job_data(job_id_2, run_id_2, "job-2") + + # Mock API calls for both runs - return different total counts + def mock_get_side_effect(url, **kwargs): + mock_resp = Mock() + if str(run_id_1) in url: + mock_resp.json.return_value = {"total_count": 1} # Run 1 has 1 job + else: + mock_resp.json.return_value = {"total_count": 1} # Run 2 has 1 job + mock_resp.raise_for_status = Mock() + return mock_resp + + with patch('requests.get', side_effect=mock_get_side_effect): + with patch.object(self.collector, '_schedule_workflow_processing'): + with patch.object(self.collector, '_schedule_timeout_check'): + with patch.object(self.collector, '_should_process_workflow_now', return_value=False): + with patch.object(self.collector, '_process_workflow_immediately'): + # Add jobs to different workflow runs + self.collector.add_job(job1) + self.collector.add_job(job2) + + # Verify both jobs were added to processed_jobs + self.assertIn(job_id_1, self.collector.processed_jobs) + self.assertIn(job_id_2, self.collector.processed_jobs) + # Verify both workflows have their jobs + self.assertEqual(len(self.collector.workflow_jobs[run_id_1]), 1) + self.assertEqual(len(self.collector.workflow_jobs[run_id_2]), 1) + + +class TestWorkflowTracerNewFeatures(unittest.TestCase): + """Test suite for new WorkflowTracer features.""" + + def setUp(self): + """Set up test fixtures.""" + self.dsn = "https://test@sentry.io/123" + self.token = "test_token" + + with patch('src.workflow_tracer.Envelope'): + from src.workflow_tracer import WorkflowTracer + self.tracer = WorkflowTracer( + token=self.token, + dsn=self.dsn, + dry_run=True + ) + + def _create_job(self, job_id: int, run_id: int, name: str, + started_at: str = None, completed_at: str = None) -> dict: + """Create a mock job.""" + now = datetime.utcnow() + if not started_at: + started_at = (now - timedelta(minutes=5)).isoformat() + "Z" + if not completed_at: + completed_at = now.isoformat() + "Z" + + return { + "id": job_id, + "run_id": run_id, + "name": name, + "conclusion": "success", + 
"started_at": started_at, + "completed_at": completed_at, + "head_branch": "main", + "head_sha": "abc123", + "run_attempt": 1, + "steps": [] + } + + def test_fetch_workflow_run_timestamps_success(self): + """Test successful fetching of workflow run timestamps.""" + run_id = 12345 + repo_full_name = "test-org/test-repo" + + mock_response = Mock() + mock_response.json.return_value = { + "created_at": "2025-11-17T17:36:13Z", + "updated_at": "2025-11-17T17:36:52Z" + } + mock_response.raise_for_status = Mock() + + job = self._create_job(1, run_id, "test-job") + + with patch.object(self.tracer, '_fetch_github', return_value=mock_response): + workflow_data = self.tracer._get_workflow_run_data( + job, + repository_info={"full_name": repo_full_name} + ) + + self.assertEqual(workflow_data["runs"]["created_at"], "2025-11-17T17:36:13Z") + self.assertEqual(workflow_data["runs"]["updated_at"], "2025-11-17T17:36:52Z") + + def test_fetch_workflow_run_timestamps_no_token(self): + """Test that workflow run timestamps are None when no token.""" + self.tracer.token = None + + job = self._create_job(1, 12345, "test-job") + workflow_data = self.tracer._get_workflow_run_data(job) + + self.assertIsNone(workflow_data["runs"].get("created_at")) + self.assertIsNone(workflow_data["runs"].get("updated_at")) + + def test_fetch_workflow_run_timestamps_api_failure(self): + """Test that API failures are handled gracefully.""" + job = self._create_job(1, 12345, "test-job") + + with patch.object(self.tracer, '_fetch_github', side_effect=Exception("API Error")): + workflow_data = self.tracer._get_workflow_run_data( + job, + repository_info={"full_name": "test-org/test-repo"} + ) + + # Should fall back to None timestamps + self.assertIsNone(workflow_data["runs"].get("created_at")) + self.assertIsNone(workflow_data["runs"].get("updated_at")) + + def test_create_workflow_transaction_uses_run_timestamps(self): + """Test that workflow transaction uses run timestamps when available.""" + job = self._create_job(1, 12345, "test-job") + all_jobs = [job] + + # Mock workflow run data with timestamps + mock_workflow_data = { + "runs": { + "head_commit": {"author": {"name": "Test", "email": "test@test.com"}}, + "head_branch": "main", + "head_sha": "abc123", + "run_attempt": 1, + "html_url": "https://github.com/test/repo/actions/runs/12345", + "repository": {"full_name": "test-org/test-repo"}, + "created_at": "2025-11-17T17:36:13Z", + "updated_at": "2025-11-17T17:36:52Z" + }, + "workflow": {"name": "Test Workflow", "path": ".github/workflows/test.yml"}, + "repo": "test-org/test-repo" + } + + with patch.object(self.tracer, '_get_workflow_run_data', return_value=mock_workflow_data): + transaction = self.tracer._create_workflow_transaction(job, all_jobs) + + self.assertEqual(transaction["start_timestamp"], "2025-11-17T17:36:13Z") + self.assertEqual(transaction["timestamp"], "2025-11-17T17:36:52Z") + + def test_create_workflow_transaction_fallback_to_job_timestamps(self): + """Test that workflow transaction falls back to job timestamps.""" + job = self._create_job( + 1, 12345, "test-job", + started_at="2025-11-17T17:36:10Z", + completed_at="2025-11-17T17:36:45Z" + ) + all_jobs = [job] + + # Mock workflow run data without timestamps + mock_workflow_data = { + "runs": { + "head_commit": {"author": {"name": "Test", "email": "test@test.com"}}, + "head_branch": "main", + "head_sha": "abc123", + "run_attempt": 1, + "html_url": "https://github.com/test/repo/actions/runs/12345", + "repository": {"full_name": "test-org/test-repo"}, + "created_at": 
None, + "updated_at": None + }, + "workflow": {"name": "Test Workflow", "path": ".github/workflows/test.yml"}, + "repo": "test-org/test-repo" + } + + with patch.object(self.tracer, '_get_workflow_run_data', return_value=mock_workflow_data): + transaction = self.tracer._create_workflow_transaction(job, all_jobs) + + # Should use job timestamps + self.assertEqual(transaction["start_timestamp"], "2025-11-17T17:36:10Z") + self.assertEqual(transaction["timestamp"], "2025-11-17T17:36:45Z") + + def test_create_workflow_transaction_adds_cleanup_span(self): + """Test that cleanup span is added when there's a delta.""" + job = self._create_job( + 1, 12345, "test-job", + started_at="2025-11-17T17:36:10Z", + completed_at="2025-11-17T17:36:48Z" # Job completes at 48s + ) + all_jobs = [job] + + # Mock workflow run data with updated_at later than job completion + mock_workflow_data = { + "runs": { + "head_commit": {"author": {"name": "Test", "email": "test@test.com"}}, + "head_branch": "main", + "head_sha": "abc123", + "run_attempt": 1, + "html_url": "https://github.com/test/repo/actions/runs/12345", + "repository": {"full_name": "test-org/test-repo"}, + "created_at": "2025-11-17T17:36:13Z", + "updated_at": "2025-11-17T17:36:52Z" # Workflow completes at 52s (4s later) + }, + "workflow": {"name": "Test Workflow", "path": ".github/workflows/test.yml"}, + "repo": "test-org/test-repo" + } + + with patch.object(self.tracer, '_get_workflow_run_data', return_value=mock_workflow_data): + # Mock re-fetch to return same data + with patch.object(self.tracer, '_fetch_github') as mock_fetch: + mock_response = Mock() + mock_response.json.return_value = { + "created_at": "2025-11-17T17:36:13Z", + "updated_at": "2025-11-17T17:36:52Z" + } + mock_fetch.return_value = mock_response + + transaction = self.tracer._create_workflow_transaction(job, all_jobs) + + # Should have cleanup span + cleanup_spans = [s for s in transaction["spans"] if s["op"] == "workflow.cleanup"] + self.assertEqual(len(cleanup_spans), 1) + self.assertEqual(cleanup_spans[0]["name"], "Workflow cleanup and teardown") + + def test_create_workflow_transaction_no_cleanup_span_when_no_delta(self): + """Test that cleanup span is not added when there's no delta.""" + job = self._create_job( + 1, 12345, "test-job", + started_at="2025-11-17T17:36:10Z", + completed_at="2025-11-17T17:36:52Z" # Job completes at same time as workflow + ) + all_jobs = [job] + + # Mock workflow run data with updated_at same as job completion + mock_workflow_data = { + "runs": { + "head_commit": {"author": {"name": "Test", "email": "test@test.com"}}, + "head_branch": "main", + "head_sha": "abc123", + "run_attempt": 1, + "html_url": "https://github.com/test/repo/actions/runs/12345", + "repository": {"full_name": "test-org/test-repo"}, + "created_at": "2025-11-17T17:36:13Z", + "updated_at": "2025-11-17T17:36:52Z" # Same as job completion + }, + "workflow": {"name": "Test Workflow", "path": ".github/workflows/test.yml"}, + "repo": "test-org/test-repo" + } + + with patch.object(self.tracer, '_get_workflow_run_data', return_value=mock_workflow_data): + # Mock re-fetch to return same data + with patch.object(self.tracer, '_fetch_github') as mock_fetch: + mock_response = Mock() + mock_response.json.return_value = { + "created_at": "2025-11-17T17:36:13Z", + "updated_at": "2025-11-17T17:36:52Z" + } + mock_fetch.return_value = mock_response + + transaction = self.tracer._create_workflow_transaction(job, all_jobs) + + # Should NOT have cleanup span + cleanup_spans = [s for s in 
transaction["spans"] if s["op"] == "workflow.cleanup"] + self.assertEqual(len(cleanup_spans), 0) + + def test_re_fetch_workflow_run_before_sending(self): + """Test that workflow run is re-fetched before sending trace.""" + job = self._create_job(1, 12345, "test-job") + all_jobs = [job] + + # First fetch returns old updated_at + first_fetch_data = { + "runs": { + "head_commit": {"author": {"name": "Test", "email": "test@test.com"}}, + "head_branch": "main", + "head_sha": "abc123", + "run_attempt": 1, + "html_url": "https://github.com/test/repo/actions/runs/12345", + "repository": {"full_name": "test-org/test-repo"}, + "created_at": "2025-11-17T17:36:13Z", + "updated_at": "2025-11-17T17:36:48Z" # Old timestamp + }, + "workflow": {"name": "Test Workflow", "path": ".github/workflows/test.yml"}, + "repo": "test-org/test-repo" + } + + # Second fetch (re-fetch) returns new updated_at + second_fetch_response = Mock() + second_fetch_response.json.return_value = { + "created_at": "2025-11-17T17:36:13Z", + "updated_at": "2025-11-17T17:36:52Z" # New timestamp + } + + with patch.object(self.tracer, '_get_workflow_run_data', return_value=first_fetch_data): + with patch.object(self.tracer, '_fetch_github', return_value=second_fetch_response) as mock_fetch: + transaction = self.tracer._create_workflow_transaction( + job, all_jobs, + repository_info={"full_name": "test-org/test-repo"} + ) + + # Verify re-fetch was called + mock_fetch.assert_called() + # Verify transaction uses new timestamp + self.assertEqual(transaction["timestamp"], "2025-11-17T17:36:52Z") + + +if __name__ == '__main__': + unittest.main() +