diff --git a/src/webapp/config.py b/src/webapp/config.py index ab2b269f..29c38d71 100644 --- a/src/webapp/config.py +++ b/src/webapp/config.py @@ -43,6 +43,9 @@ "GCP_SERVICE_ACCOUNT_EMAIL": "", } +# ENV -> Databricks volume schema (used for /Volumes/{schema}/... paths). +ENV_TO_VOLUME_SCHEMA = {"DEV": "dev_sst_02", "STAGING": "staging_sst_01"} + # databricks vars needed for databricks integration databricks_vars = { # SECRET. diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py index 5820415e..98ec6974 100644 --- a/src/webapp/databricks.py +++ b/src/webapp/databricks.py @@ -14,7 +14,7 @@ from google.cloud import storage from google.api_core import exceptions as gcs_errors from .validation_extension import generate_extension_schema -from .config import databricks_vars, gcs_vars +from .config import ENV_TO_VOLUME_SCHEMA, databricks_vars, env_vars, gcs_vars from .utilities import databricksify_inst_name, SchemaType from typing import List, Any, Dict, Optional from fastapi import HTTPException @@ -48,6 +48,8 @@ class DatabricksInferenceRunRequest(BaseModel): # The email where notifications will get sent. email: str gcp_external_bucket_name: str + # Optional term filter (e.g. cohort labels); serialized as JSON for job params when set. Used for cohort/graduation models. 
+ term_filter: list[str] | None = None class DatabricksInferenceRunResponse(BaseModel): @@ -83,6 +85,78 @@ def _sha256_json(obj: Any) -> str: ).hexdigest() +def _parse_config_toml_to_selection(raw: bytes) -> dict | None: + """Parse TOML bytes and return the [preprocessing.selection] section, or None.""" + try: + try: + import tomllib + + try: + data = tomllib.loads(raw) + except TypeError: + data = tomllib.loads(raw.decode("utf-8")) + except ImportError: + import tomli as tomllib + + data = tomllib.loads(raw.decode("utf-8")) + except (Exception, TypeError): + return None + if not isinstance(data, dict): + return None + preprocessing = data.get("preprocessing") + if not isinstance(preprocessing, dict): + return None + selection = preprocessing.get("selection") + if not isinstance(selection, dict): + return None + return selection + + +def _find_selection_in_toml_under( + w: WorkspaceClient, + directory_path: str, + inst_name: str, +) -> dict | None: + """List directory recursively; find first .toml file with [preprocessing.selection].""" + try: + entries = list(w.files.list_directory_contents(directory_path)) + except Exception as e: + LOGGER.debug( + "read_volume_training_config: could not list %s for %s: %s", + directory_path, + inst_name, + e, + ) + return None + for entry in entries: + if not entry.path: + continue + if entry.is_directory: + selection = _find_selection_in_toml_under(w, entry.path, inst_name) + if selection is not None: + return selection + continue + if entry.name and entry.name.lower().endswith(".toml"): + try: + response = w.files.download(entry.path) + if response.contents is None: + continue + raw = response.contents.read() + except Exception as e: + LOGGER.debug( + "read_volume_training_config: could not read %s for %s: %s", + entry.path, + inst_name, + e, + ) + continue + raw_bytes = raw if isinstance(raw, bytes) else raw.encode("utf-8") + selection = _parse_config_toml_to_selection(raw_bytes) + if selection is not None: + return selection 
+ return None + + L1_RESP_CACHE_TTL = int("600") # seconds L1_VER_CACHE_TTL = int("3600") # seconds L1_RESP_CACHE: Any = TTLCache(maxsize=128, ttl=L1_RESP_CACHE_TTL) @@ -225,44 +299,44 @@ def run_pdp_inference( f"run_pdp_inference(): Job '{pipeline_type}' was not found or has no job_id for '{gcs_vars['GCP_SERVICE_ACCOUNT_EMAIL']}' and '{databricks_vars['DATABRICKS_HOST_URL']}'." ) job_id = job.job_id - LOGGER.info(f"Resolved job ID for '{pipeline_type}': {job_id}") + LOGGER.info("Resolved job ID for '%s': %s", pipeline_type, job_id) except Exception as e: - LOGGER.exception(f"Job lookup failed for '{pipeline_type}'.") + LOGGER.exception("Job lookup failed for '%s': %s", pipeline_type, e) raise ValueError(f"run_pdp_inference(): Failed to find job: {e}") + job_params: Dict[str, str] = { + "cohort_file_name": get_filepath_of_filetype( + req.filepath_to_type, SchemaType.STUDENT + ), + "course_file_name": get_filepath_of_filetype( + req.filepath_to_type, SchemaType.COURSE + ), + "databricks_institution_name": db_inst_name, + "DB_workspace": databricks_vars[ + "DATABRICKS_WORKSPACE" + ], # is this value the same PER environ? dev/staging/prod + "gcp_bucket_name": req.gcp_external_bucket_name, + "model_name": req.model_name, + "notification_email": req.email, + } + if req.term_filter is not None: + job_params["term_filter"] = json.dumps(req.term_filter) try: run_job: Any = w.jobs.run_now( job_id, - job_parameters={ - "cohort_file_name": get_filepath_of_filetype( - req.filepath_to_type, SchemaType.STUDENT - ), - "course_file_name": get_filepath_of_filetype( - req.filepath_to_type, SchemaType.COURSE - ), - "databricks_institution_name": db_inst_name, - "DB_workspace": databricks_vars[ - "DATABRICKS_WORKSPACE" - ], # is this value the same PER environ? dev/staging/prod - "gcp_bucket_name": req.gcp_external_bucket_name, - "model_name": req.model_name, - "notification_email": req.email, - }, + job_parameters=job_params, ) LOGGER.info( - f"Successfully triggered job run. 
Run ID: {run_job.response.run_id}" + "Successfully triggered job run. Run ID: %s", run_job.response.run_id ) except Exception as e: - LOGGER.exception("Failed to run the PDP inference job.") + LOGGER.exception("Failed to run the PDP inference job: %s", e) raise ValueError(f"run_pdp_inference(): Job could not be run: {e}") if not run_job.response or run_job.response.run_id is None: raise ValueError("run_pdp_inference(): Job did not return a valid run_id.") - run_id = run_job.response.run_id - LOGGER.info(f"Successfully triggered job run. Run ID: {run_id}") - - return DatabricksInferenceRunResponse(job_run_id=run_id) + return DatabricksInferenceRunResponse(job_run_id=run_job.response.run_id) def delete_inst(self, inst_name: str) -> None: """Cleanup tasks required on the Databricks side to delete an institution.""" @@ -530,6 +604,74 @@ def fetch_model_version( return latest_version + def read_volume_training_config( + self, inst_name: str, model_run_id: str + ) -> dict | None: + """Read training/preprocessing config from the model's training run in the silver volume. + + Looks for any .toml file under: + /Volumes/{env_schema}/{slug}_silver/silver_volume/{model_run_id}/ + + Uses the first .toml found (any subfolder) that contains [preprocessing.selection]. + inst_name is the institution display name (e.g. from inst.name). The path slug is + derived with databricksify_inst_name(inst_name). model_run_id is the training run + identifier (e.g. 0b2e206732ce48f6b644149090c9614a). env_schema is derived from ENV + (see config.startup_env_vars and ENV_FILE_PATH). Allowed values: DEV -> dev_sst_02, + STAGING -> staging_sst_01; other values (e.g. LOCAL, PROD) return None. + Returns the [preprocessing.selection] section only (dict with student_criteria, etc.) + for use by latest-inference-cohort and related logic. + + Returns that section dict, or None if no suitable file or section is found. 
+ """ + if not inst_name or not str(inst_name).strip(): + LOGGER.warning( + "read_volume_training_config: empty inst_name; cannot build volume path.", + ) + return None + model_run_id_clean = str(model_run_id).strip() if model_run_id else "" + if not model_run_id_clean: + LOGGER.warning( + "read_volume_training_config: empty model_run_id; cannot build volume path.", + ) + return None + env = str(env_vars.get("ENV", "")).strip().upper() + if env not in ENV_TO_VOLUME_SCHEMA: + LOGGER.warning( + "read_volume_training_config: ENV %r not in %s; cannot read config for %s", + env_vars.get("ENV"), + list(ENV_TO_VOLUME_SCHEMA), + inst_name, + ) + return None + env_schema = ENV_TO_VOLUME_SCHEMA[env] + try: + db_inst_name = databricksify_inst_name(inst_name) + except ValueError as e: + LOGGER.warning( + "read_volume_training_config: cannot databricksify inst_name %r: %s", + inst_name, + e, + ) + return None + directory_path = ( + f"/Volumes/{env_schema}/{db_inst_name}_silver/silver_volume/" + f"{model_run_id_clean}" + ) + try: + w = WorkspaceClient( + host=databricks_vars["DATABRICKS_HOST_URL"], + google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"], + ) + except Exception as e: + LOGGER.exception( + "read_volume_training_config: WorkspaceClient failed for %s: %s", + inst_name, + e, + ) + return None + selection = _find_selection_in_toml_under(w, directory_path, inst_name) + return selection + def delete_model(self, catalog_name: str, inst_name: str, model_name: str) -> None: schema = databricksify_inst_name(inst_name) model_name_path = f"{catalog_name}.{schema}_gold.{model_name}" diff --git a/src/webapp/databricks_test.py b/src/webapp/databricks_test.py index 37fa2855..5da23e88 100644 --- a/src/webapp/databricks_test.py +++ b/src/webapp/databricks_test.py @@ -1,6 +1,16 @@ +import json import pytest +from unittest import mock -from .databricks import DatabricksControl +from databricks.sdk.service.files import DirectoryEntry + +from . 
import databricks as databricks_module +from .databricks import ( + DatabricksControl, + DatabricksInferenceRunRequest, + _parse_config_toml_to_selection, +) +from .utilities import SchemaType @pytest.fixture @@ -51,3 +61,248 @@ def test_invalid_regex_is_ignored(ctrl): def test_returns_none_when_no_match(ctrl): mapping = {"student": "student.csv"} assert ctrl.get_key_for_file(mapping, "unknown.csv") is None + + +def test_parse_config_toml_to_selection_returns_preprocessing_selection(): + """_parse_config_toml_to_selection parses TOML bytes and returns [preprocessing.selection] only.""" + toml_bytes = ( + b"[preprocessing]\nsplits = { train = 0.6, test = 0.2, validate = 0.2 }\n" + b"[preprocessing.selection]\n" + b'student_criteria = { enrollment_type = "FIRST-TIME", cohort_term = ["FALL", "SPRING"] }\n' + ) + result = _parse_config_toml_to_selection(toml_bytes) + assert result is not None + assert result == { + "student_criteria": { + "enrollment_type": "FIRST-TIME", + "cohort_term": ["FALL", "SPRING"], + } + } + + +def test_parse_config_toml_to_selection_returns_none_for_invalid_or_missing_section(): + """_parse_config_toml_to_selection returns None when TOML is invalid or section missing.""" + assert _parse_config_toml_to_selection(b"not valid toml {{{") is None + assert _parse_config_toml_to_selection(b"[other]\nx = 1\n") is None + assert _parse_config_toml_to_selection(b"[preprocessing]\nx = 1\n") is None + + +MODEL_RUN_ID_TEST = "0b2e206732ce48f6b644149090c9614a" + + +def test_read_volume_training_config_returns_none_for_empty_inst_name(ctrl): + """read_volume_training_config returns None when inst_name is empty.""" + with mock.patch.dict(databricks_module.env_vars, {"ENV": "DEV"}): + assert ctrl.read_volume_training_config("", MODEL_RUN_ID_TEST) is None + assert ctrl.read_volume_training_config(" ", MODEL_RUN_ID_TEST) is None + + +def test_read_volume_training_config_returns_none_for_empty_model_run_id(ctrl): + """read_volume_training_config returns None when 
model_run_id is empty.""" + with mock.patch.dict(databricks_module.env_vars, {"ENV": "DEV"}): + assert ctrl.read_volume_training_config("Some University", "") is None + assert ctrl.read_volume_training_config("Some University", " ") is None + + +def test_read_volume_training_config_returns_none_when_env_not_dev_or_staging(ctrl): + """read_volume_training_config returns None when ENV is LOCAL (no volume schema).""" + with mock.patch.dict(databricks_module.env_vars, {"ENV": "LOCAL"}): + result = ctrl.read_volume_training_config("Some University", MODEL_RUN_ID_TEST) + assert result is None + + +def test_read_volume_training_config_returns_none_when_databricksify_raises(ctrl): + """read_volume_training_config returns None when databricksify_inst_name raises ValueError.""" + with mock.patch.dict(databricks_module.env_vars, {"ENV": "DEV"}): + with mock.patch.object( + databricks_module, + "databricksify_inst_name", + side_effect=ValueError("invalid chars"), + ): + result = ctrl.read_volume_training_config( + "Bad/Name\\Here", MODEL_RUN_ID_TEST + ) + assert result is None + + +def test_read_volume_training_config_returns_none_when_workspace_client_raises(ctrl): + """read_volume_training_config returns None when WorkspaceClient construction fails.""" + with mock.patch.dict(databricks_module.env_vars, {"ENV": "DEV"}): + with mock.patch.object( + databricks_module, + "WorkspaceClient", + side_effect=Exception("connection refused"), + ): + result = ctrl.read_volume_training_config( + "Some University", MODEL_RUN_ID_TEST + ) + assert result is None + + +def _one_toml_entry( + path: str = "/Volumes/dev_sst_02/some_uni_silver/silver_volume/run_id/training.toml", + name: str = "training.toml", +) -> list[DirectoryEntry]: + """Single .toml file entry as returned by list_directory_contents (any .toml name).""" + return [ + DirectoryEntry(path=path, name=name, is_directory=False), + ] + + +def test_read_volume_training_config_returns_none_when_list_raises(ctrl): + 
"""read_volume_training_config returns None when list_directory_contents raises.""" + mock_client = mock.Mock() + mock_client.files.list_directory_contents.side_effect = Exception("not found") + with mock.patch.dict(databricks_module.env_vars, {"ENV": "DEV"}): + with mock.patch.object( + databricks_module, "WorkspaceClient", return_value=mock_client + ): + result = ctrl.read_volume_training_config( + "Some University", MODEL_RUN_ID_TEST + ) + assert result is None + + +def test_read_volume_training_config_returns_none_when_download_raises(ctrl): + """read_volume_training_config returns None when files.download raises.""" + mock_client = mock.Mock() + mock_client.files.list_directory_contents.return_value = iter(_one_toml_entry()) + mock_client.files.download.side_effect = Exception("file not found") + with mock.patch.dict(databricks_module.env_vars, {"ENV": "DEV"}): + with mock.patch.object( + databricks_module, "WorkspaceClient", return_value=mock_client + ): + result = ctrl.read_volume_training_config( + "Some University", MODEL_RUN_ID_TEST + ) + assert result is None + + +def test_read_volume_training_config_returns_none_when_toml_missing_selection_section( + ctrl, +): + """read_volume_training_config returns None when config file has no [preprocessing.selection].""" + mock_response = mock.Mock() + mock_response.contents.read.return_value = b"[other]\nx = 1\n" + mock_client = mock.Mock() + mock_client.files.list_directory_contents.return_value = iter(_one_toml_entry()) + mock_client.files.download.return_value = mock_response + with mock.patch.dict(databricks_module.env_vars, {"ENV": "DEV"}): + with mock.patch.object( + databricks_module, "WorkspaceClient", return_value=mock_client + ): + result = ctrl.read_volume_training_config( + "Some University", MODEL_RUN_ID_TEST + ) + assert result is None + + +def test_read_volume_training_config_returns_selection_when_toml_found_under_run_dir( + ctrl, +): + """read_volume_training_config returns 
[preprocessing.selection] when any .toml under run dir has it.""" + toml_bytes = ( + b"[preprocessing]\n[preprocessing.selection]\n" + b'student_criteria = { enrollment_type = "FIRST-TIME" }\n' + ) + mock_response = mock.Mock() + mock_response.contents.read.return_value = toml_bytes + mock_client = mock.Mock() + # Any .toml name is accepted (e.g. training.toml, config.toml, preprocessing.toml) + mock_client.files.list_directory_contents.return_value = iter( + _one_toml_entry( + "/Volumes/dev_sst_02/some_uni_silver/silver_volume/run_id/training.toml", + name="training.toml", + ) + ) + mock_client.files.download.return_value = mock_response + with mock.patch.dict(databricks_module.env_vars, {"ENV": "DEV"}): + with mock.patch.object( + databricks_module, "WorkspaceClient", return_value=mock_client + ): + result = ctrl.read_volume_training_config( + "Some University", MODEL_RUN_ID_TEST + ) + assert result is not None + assert result.get("student_criteria") == {"enrollment_type": "FIRST-TIME"} + + +def _minimal_inference_request(term_filter=None): + """Minimal DatabricksInferenceRunRequest with STUDENT and COURSE file types.""" + return DatabricksInferenceRunRequest( + inst_name="Test Inst", + filepath_to_type={ + "/path/cohort.csv": [SchemaType.STUDENT], + "/path/course.csv": [SchemaType.COURSE], + }, + model_name="test_model", + email="test@example.com", + gcp_external_bucket_name="test-bucket", + term_filter=term_filter, + ) + + +def test_run_pdp_inference_omits_term_filter_from_job_params_when_none(ctrl): + """When term_filter is None, job_parameters passed to run_now do not contain term_filter key.""" + req = _minimal_inference_request(term_filter=None) + mock_job = mock.Mock() + mock_job.job_id = 12345 + mock_run_response = mock.Mock() + mock_run_response.response.run_id = 999 + mock_w = mock.Mock() + mock_w.jobs.list.return_value = iter([mock_job]) + mock_w.jobs.run_now.return_value = mock_run_response + with ( + mock.patch.object(databricks_module, 
"WorkspaceClient", return_value=mock_w), + mock.patch.object( + databricks_module, "databricksify_inst_name", return_value="test_inst" + ), + mock.patch.dict( + databricks_module.databricks_vars, + {"DATABRICKS_HOST_URL": "https://x", "DATABRICKS_WORKSPACE": "ws"}, + ), + mock.patch.dict( + databricks_module.gcs_vars, {"GCP_SERVICE_ACCOUNT_EMAIL": "a@b.com"} + ), + ): + result = ctrl.run_pdp_inference(req) + assert result.job_run_id == 999 + mock_w.jobs.run_now.assert_called_once() + call_kwargs = mock_w.jobs.run_now.call_args[1] + job_params = call_kwargs["job_parameters"] + assert "term_filter" not in job_params + + +def test_run_pdp_inference_includes_term_filter_in_job_params_when_set(ctrl): + """When term_filter is set, job_parameters include term_filter as JSON string.""" + req = _minimal_inference_request(term_filter=["fall 2024-25", "spring 2024-25"]) + mock_job = mock.Mock() + mock_job.job_id = 12345 + mock_run_response = mock.Mock() + mock_run_response.response.run_id = 888 + mock_w = mock.Mock() + mock_w.jobs.list.return_value = iter([mock_job]) + mock_w.jobs.run_now.return_value = mock_run_response + with ( + mock.patch.object(databricks_module, "WorkspaceClient", return_value=mock_w), + mock.patch.object( + databricks_module, "databricksify_inst_name", return_value="test_inst" + ), + mock.patch.dict( + databricks_module.databricks_vars, + {"DATABRICKS_HOST_URL": "https://x", "DATABRICKS_WORKSPACE": "ws"}, + ), + mock.patch.dict( + databricks_module.gcs_vars, {"GCP_SERVICE_ACCOUNT_EMAIL": "a@b.com"} + ), + ): + result = ctrl.run_pdp_inference(req) + assert result.job_run_id == 888 + mock_w.jobs.run_now.assert_called_once() + call_kwargs = mock_w.jobs.run_now.call_args[1] + job_params = call_kwargs["job_parameters"] + assert "term_filter" in job_params + assert json.loads(job_params["term_filter"]) == [ + "fall 2024-25", + "spring 2024-25", + ] diff --git a/src/webapp/routers/data.py b/src/webapp/routers/data.py index e1b24496..f463998b 100644 --- 
a/src/webapp/routers/data.py +++ b/src/webapp/routers/data.py @@ -2,7 +2,7 @@ import uuid from datetime import datetime, date -from typing import Annotated, Any, Dict, List, Optional, Tuple, Union +from typing import Annotated, Any, Dict, List, Optional, Tuple, Union, Literal from pydantic import BaseModel, Field from fastapi import APIRouter, Depends, HTTPException, status, Response from sqlalchemy import and_, or_ @@ -28,6 +28,7 @@ DataSource, get_external_bucket_name, decode_url_piece, + SchemaType, ) from ..database import ( @@ -186,6 +187,16 @@ class DataOverview(BaseModel): files: list[DataInfo] +class LatestInferenceCohortResponse(BaseModel): + """Latest inference-ready cohort for an institution (Task 1: single cohort).""" + + cohort_label: Optional[str] = None + valid_student_count: int = 0 + status: Literal["valid", "invalid"] + batch_name: Optional[str] = None + reason: Optional[str] = None + + # Data related operations. Input files mean files sourced from the institution. Output files are generated by SST. @@ -331,6 +342,499 @@ def read_inst_all_output_files( } +def _latest_inference_cohort_validation_response( + inst_id: str, + df_student: pd.DataFrame, + order_error: Optional[str], + default_label: str, + batch_name: Optional[str], +) -> Optional[LatestInferenceCohortResponse]: + """Return an invalid LatestInferenceCohortResponse if order_error or missing id col; else None.""" + if order_error is not None: + LOGGER.warning( + "Latest inference cohort invalid for inst_id=%s: %s", + inst_id, + order_error, + ) + return LatestInferenceCohortResponse( + cohort_label=default_label, + valid_student_count=0, + status="invalid", + batch_name=batch_name, + reason=order_error, + ) + id_col = "student_id" if "student_id" in df_student.columns else "study_id" + if id_col not in df_student.columns: + reason = "Student file missing student id column." 
+ LOGGER.warning( + "Latest inference cohort invalid for inst_id=%s: %s", inst_id, reason + ) + return LatestInferenceCohortResponse( + cohort_label=default_label, + valid_student_count=0, + status="invalid", + batch_name=batch_name, + reason=reason, + ) + return None + + +def _first_valid_cohort_from_ordered_terms( + ordered_cohort_terms: List[Dict[str, Any]], + df_student: pd.DataFrame, + df_course: pd.DataFrame, + selection_config: Dict[str, Any], + id_col: str, + course_id_col: str, + default_label: str, + batch_name: Optional[str], + inst_id: str, +) -> LatestInferenceCohortResponse: + """Try each cohort term in order; return first valid, or invalid if none have course data.""" + for candidate in ordered_cohort_terms: + df_student_cohort, cohort_label = _filter_students_to_cohort_term( + df_student, candidate + ) + if df_student_cohort.empty: + continue + cohort_student_ids = set(df_student_cohort[id_col].drop_duplicates()) + if course_id_col in df_course.columns: + df_course_cohort = df_course[ + df_course[course_id_col].isin(cohort_student_ids) + ] + else: + df_course_cohort = df_course + if df_course_cohort.empty: + continue + valid_student_count, criteria_error = apply_student_criteria_count( + df_student_cohort, selection_config, df_course=df_course_cohort + ) + if criteria_error: + LOGGER.warning( + "Latest inference cohort invalid for inst_id=%s: %s", + inst_id, + criteria_error, + ) + return LatestInferenceCohortResponse( + cohort_label=cohort_label, + valid_student_count=0, + status="invalid", + batch_name=batch_name, + reason=criteria_error, + ) + return LatestInferenceCohortResponse( + cohort_label=cohort_label, + valid_student_count=valid_student_count, + status="valid", + batch_name=batch_name, + reason=None, + ) + reason = "No cohort term in the batch has course data for those students." 
+ LOGGER.warning( + "Latest inference cohort invalid for inst_id=%s: %s", inst_id, reason + ) + return LatestInferenceCohortResponse( + cohort_label=default_label, + valid_student_count=0, + status="invalid", + batch_name=batch_name, + reason=reason, + ) + + +def _resolve_latest_inference_cohort_from_dataframes( + inst_id: str, + df_student: pd.DataFrame, + df_course: pd.DataFrame, + selection_config: Dict[str, Any], + batch_name: Optional[str], +) -> LatestInferenceCohortResponse: + """Resolve latest inference-ready cohort from loaded student and course DataFrames. + + Tries cohort terms from most recent to oldest; returns the first that has course data + and meets selection criteria. Logs with inst_id when returning invalid. + + Args: + inst_id: Institution id (for logging). + df_student: Non-empty STUDENT DataFrame with cohort_term column. + df_course: Non-empty COURSE DataFrame. + selection_config: preprocessing.selection config (student_criteria, etc.). + batch_name: Batch name for response labels. + + Returns: + LatestInferenceCohortResponse with status valid or invalid and reason. 
+ """ + default_label = batch_name or "Latest cohort" + allowed_cohort_terms = _allowed_cohort_terms_from_config(selection_config) + ordered_cohort_terms, order_error = get_ordered_cohort_terms( + df_student, allowed_cohort_terms=allowed_cohort_terms + ) + validation_response = _latest_inference_cohort_validation_response( + inst_id, df_student, order_error, default_label, batch_name + ) + if validation_response is not None: + return validation_response + + id_col = "student_id" if "student_id" in df_student.columns else "study_id" + course_id_col = "student_id" if "student_id" in df_course.columns else "study_id" + + if ordered_cohort_terms: + return _first_valid_cohort_from_ordered_terms( + ordered_cohort_terms, + df_student, + df_course, + selection_config, + id_col, + course_id_col, + default_label, + batch_name, + inst_id, + ) + + LOGGER.warning( + "Latest inference cohort invalid for inst_id=%s: %s", + inst_id, + COHORT_TERM_REQUIRED_MSG, + ) + return LatestInferenceCohortResponse( + cohort_label=default_label, + valid_student_count=0, + status="invalid", + batch_name=batch_name, + reason=COHORT_TERM_REQUIRED_MSG, + ) + + +@router.get( + "/{inst_id}/latest-inference-cohort", + response_model=LatestInferenceCohortResponse, +) +def get_latest_inference_cohort( + inst_id: str, + current_user: Annotated[BaseUser, Depends(get_current_active_user)], + sql_session: Annotated[Session, Depends(get_session)], + storage_control: Annotated[StorageControl, Depends(StorageControl)], + databricks_control: Annotated[DatabricksControl, Depends(DatabricksControl)], + model_name: Optional[str] = None, + batch_name: Optional[str] = None, +) -> LatestInferenceCohortResponse: + """Return the latest inference-ready cohort for the institution. 
+ + Uses the given batch (or the most recent with student + course files if + batch_name omitted), the given model (or the single registered model), + and the model's config to resolve the most recent cohort term that has + course data and meets criteria. Returns valid student count and status so + the user can confirm running inference on that cohort. Always returns 200; + status='invalid' with reason when no batch, missing config, or no cohort + has course data. + + Config is read from the model's training run in the silver volume. model_run_id + is derived from the latest model version for the chosen model. + + Args: + inst_id: Institution UUID from path. + model_name: Optional. Model name; used to find config. If omitted, the + institution must have exactly one registered (valid, non-deleted) model. + batch_name: Optional. Batch name; when provided, cohort selection uses + this batch's data. When omitted, uses the most recent batch with + student + course files. + current_user: Injected; must have access to inst_id. + sql_session: Injected DB session. + storage_control: Injected storage for batch files. + databricks_control: Injected for model version and training config. + + Returns: + LatestInferenceCohortResponse with status valid or invalid, cohort_label, + valid_student_count, batch_name, and reason when invalid. + """ + has_access_to_inst_or_err(inst_id, current_user) + local_session.set(sql_session) + session = local_session.get() + + inst_rows = session.execute( + select(InstTable).where(InstTable.id == str_to_uuid(inst_id)) + ).all() + if not inst_rows or len(inst_rows) != 1: + LOGGER.warning( + "Latest inference cohort invalid for inst_id=%s: Institution not found.", + inst_id, + ) + return LatestInferenceCohortResponse( + status="invalid", reason="Institution not found." 
+ ) + inst = inst_rows[0][0] + + resolved_model_name, model_run_id, error_response = ( + _resolve_model_name_and_run_id_for_inference( + session, inst_id, model_name, inst.name, databricks_control + ) + ) + if error_response is not None: + return error_response + assert model_run_id is not None # guaranteed when error_response is None + + df_student, df_course, selection_config, batch_name_for_response, error_response = ( + _load_batch_and_dataframes_for_inference( + session, + inst_id, + batch_name, + inst.name, + model_run_id, + databricks_control, + storage_control, + ) + ) + if error_response is not None: + return error_response + assert selection_config is not None # guaranteed when error_response is None + + return _resolve_latest_inference_cohort_from_dataframes( + inst_id=inst_id, + df_student=df_student, + df_course=df_course, + selection_config=selection_config, + batch_name=batch_name_for_response, + ) + + +def _resolve_model_name_and_run_id_for_inference( + session: Session, + inst_id: str, + model_name: Optional[str], + inst_name: str, + databricks_control: DatabricksControl, +) -> Tuple[ + Optional[str], + Optional[str], + Optional[LatestInferenceCohortResponse], +]: + """ + Resolve model name (from param or single registered), validate it exists, and fetch model_run_id. + + Returns (resolved_model_name, model_run_id, None) on success, or + (None, None, error_response) when invalid or Databricks lookup fails. 
+ """ + if model_name and str(model_name).strip(): + resolved_model_name = decode_url_piece(str(model_name).strip()) + else: + models_result = session.execute( + select(ModelTable).where( + and_( + ModelTable.inst_id == str_to_uuid(inst_id), + ModelTable.valid == True, # noqa: E712 + or_( + ModelTable.deleted.is_(None), + ModelTable.deleted == False, # noqa: E712 + ), + ) + ) + ).all() + if not models_result: + LOGGER.warning( + "Latest inference cohort invalid for inst_id=%s: No registered model.", + inst_id, + ) + return ( + None, + None, + LatestInferenceCohortResponse( + status="invalid", + reason="No registered model for this institution; approve a model for use or specify model_name.", + ), + ) + if len(models_result) > 1: + LOGGER.warning( + "Latest inference cohort invalid for inst_id=%s: Multiple registered models; model_name required.", + inst_id, + ) + return ( + None, + None, + LatestInferenceCohortResponse( + status="invalid", + reason="Multiple registered models; specify model_name to select which model's config to use.", + ), + ) + resolved_model_name = models_result[0][0].name + + model_rows = session.execute( + select(ModelTable).where( + and_( + ModelTable.inst_id == str_to_uuid(inst_id), + ModelTable.name == resolved_model_name, + ) + ) + ).all() + if not model_rows or len(model_rows) != 1: + LOGGER.warning( + "Latest inference cohort invalid for inst_id=%s model_name=%s: Model not found for institution.", + inst_id, + resolved_model_name, + ) + return ( + None, + None, + LatestInferenceCohortResponse( + status="invalid", + reason="Model not found for this institution.", + ), + ) + + try: + latest_model_version = databricks_control.fetch_model_version( + catalog_name=str(env_vars["CATALOG_NAME"]), + inst_name=inst_name, + model_name=resolved_model_name, + ) + return (resolved_model_name, str(latest_model_version.run_id), None) + except ValueError as e: + LOGGER.warning( + "Latest inference cohort invalid for inst_id=%s model_name=%s: %s", + 
inst_id, + resolved_model_name, + e, + ) + return ( + None, + None, + LatestInferenceCohortResponse( + status="invalid", + reason="Could not get model version for config; no versions or lookup failed.", + ), + ) + + +def _load_batch_and_dataframes_for_inference( + session: Session, + inst_id: str, + batch_name: Optional[str], + inst_name: str, + model_run_id: str, + databricks_control: DatabricksControl, + storage_control: StorageControl, +) -> Tuple[ + Optional[pd.DataFrame], + Optional[pd.DataFrame], + Optional[Dict[str, Any]], + Optional[str], + Optional[LatestInferenceCohortResponse], +]: + """ + Resolve batch, load config, read batch files, and validate STUDENT/COURSE dataframes. + + Returns (df_student, df_course, selection_config, batch_name_for_response, None) on success, + or (None, None, None, optional_batch_name, error_response) on failure. + """ + batch, batch_label = get_batch_for_inference(session, inst_id, batch_name) + if not batch: + LOGGER.warning( + "Latest inference cohort invalid for inst_id=%s: Batch not found or missing student and course files.", + inst_id, + ) + reason = ( + "Batch not found or does not have both student and course files." + if batch_label + else "No completed batch with student and course files." 
+ ) + return ( + None, + None, + None, + batch_label, + LatestInferenceCohortResponse( + status="invalid", + batch_name=batch_label, + reason=reason, + ), + ) + + selection_config = databricks_control.read_volume_training_config( + inst_name, model_run_id + ) + if not selection_config: + LOGGER.warning( + "Latest inference cohort invalid for inst_id=%s: Missing config in Databricks.", + inst_id, + ) + return ( + None, + None, + None, + batch.name, + LatestInferenceCohortResponse( + cohort_label=batch.name or "Latest cohort", + valid_student_count=0, + status="invalid", + batch_name=batch.name, + reason="Missing config in Databricks.", + ), + ) + + try: + file_dataframes = read_batch_files_as_dataframes( + inst_id, batch.files, storage_control + ) + except HTTPException as e: + detail = ( + e.detail if isinstance(e.detail, str) else "Could not load batch files." + ) + LOGGER.warning( + "Latest inference cohort invalid for inst_id=%s: %s", inst_id, detail + ) + return ( + None, + None, + None, + batch.name, + LatestInferenceCohortResponse( + cohort_label=batch.name or "Latest cohort", + status="invalid", + batch_name=batch.name, + reason=detail, + ), + ) + + df_student = file_dataframes.get("STUDENT") + if df_student is None or df_student.empty: + LOGGER.warning( + "Latest inference cohort invalid for inst_id=%s: No STUDENT file in batch.", + inst_id, + ) + return ( + None, + None, + None, + batch.name, + LatestInferenceCohortResponse( + cohort_label=batch.name or "Latest cohort", + status="invalid", + batch_name=batch.name, + reason="No STUDENT file in batch.", + ), + ) + + df_course = file_dataframes.get("COURSE") + if df_course is None or df_course.empty: + LOGGER.warning( + "Latest inference cohort invalid for inst_id=%s: No COURSE file in batch.", + inst_id, + ) + return ( + None, + None, + None, + batch.name, + LatestInferenceCohortResponse( + cohort_label=batch.name or "Latest cohort", + status="invalid", + batch_name=batch.name, + reason="No COURSE file in 
batch.", + ), + ) + + return (df_student, df_course, selection_config, batch.name, None) + + # TODO: rename this function to better reflect its behavior. @router.post("/{inst_id}/update-data") def update_data( @@ -656,6 +1160,437 @@ def read_batch_files_as_dataframes( return result +# Term order within academic year for "most recent cohort term" (FALL < ... < SUMMER) +_COHORT_TERM_ORDER = {"FALL": 1, "WINTER": 2, "SPRING": 3, "SUMMER": 4} +COHORT_TERM_REQUIRED_MSG = ( + "Student file must have cohort_term column for latest inference cohort." +) + + +def _four_digit_year(y: int) -> int: + """Normalize 2-digit year to 4-digit (e.g. 25 -> 2025).""" + return y if y >= 100 else 2000 + y + + +def _cohort_year_end(cohort_str: str) -> Optional[int]: + """Parse cohort string to end year for ordering. E.g. '2024-25' -> 2025, '2021' -> 2021.""" + s = str(cohort_str).strip() + if not s: + return None + if "-" in s: + parts = s.split("-") + if len(parts) == 2 and parts[1].isdigit(): + return _four_digit_year(int(parts[1])) + if parts[0].isdigit(): + return _four_digit_year(int(parts[0])) + if s.isdigit(): + return _four_digit_year(int(s)) + return None + + +def _cohort_term_display_year(cohort_str: str, term_upper: str) -> Optional[int]: + """Year to show in label 'Fall 2024' / 'Spring 2025'. For 2024-25, Fall->2024, Spring->2025.""" + s = str(cohort_str).strip() + if "-" in s: + parts = s.split("-") + if len(parts) == 2 and parts[0].isdigit() and parts[1].isdigit(): + y0, y1 = _four_digit_year(int(parts[0])), _four_digit_year(int(parts[1])) + if term_upper in ("FALL", "WINTER"): + return y0 + return y1 + if parts[0].isdigit(): + return _four_digit_year(int(parts[0])) + if s.isdigit(): + return _four_digit_year(int(s)) + return None + + +def _allowed_cohort_terms_from_config( + selection_config: Optional[Dict[str, Any]], +) -> Optional[List[str]]: + """Extract allowed cohort terms from preprocessing.selection student_criteria.cohort_term. + Returns list of uppercase terms (e.g. 
['FALL', 'SPRING']) or None if not specified. + """ + if not selection_config: + return None + criteria = selection_config.get("student_criteria") + if not isinstance(criteria, dict): + return None + raw = criteria.get("cohort_term") + if raw is None: + return None + if isinstance(raw, list): + return [str(t).strip().upper() for t in raw if t is not None and str(t).strip()] + return [str(raw).strip().upper()] + + +def _build_sorted_cohort_term_candidates( + candidates: pd.DataFrame, + allowed_set: Optional[set], +) -> List[Dict[str, Any]]: + """Filter candidates by allowed_set, sort by (year, term) descending, return list of candidate dicts. + + Args: + candidates: DataFrame with _year_sort, _term_upper, and either _cohort_str or entry_year. + allowed_set: Uppercase allowed term names (e.g. {'FALL', 'SPRING'}) or None for all. + + Returns: + List of dicts with cohort_str/entry_year and term_upper, most recent first. + """ + work = candidates.copy() + if allowed_set is not None: + work = work[work["_term_upper"].isin(allowed_set)] + work["_term_rank"] = work["_term_upper"].map( + lambda t: _COHORT_TERM_ORDER.get(t, 99) + ) + work["_sort_key"] = list(zip(work["_year_sort"], work["_term_rank"])) + work = work.sort_values("_sort_key", ascending=False) + if "_cohort_str" in work.columns: + return [ + {"cohort_str": row["_cohort_str"], "term_upper": row["_term_upper"]} + for _, row in work.iterrows() + ] + return [ + {"entry_year": int(row["entry_year"]), "term_upper": row["_term_upper"]} + for _, row in work.iterrows() + ] + + +def _no_allowed_term_error(allowed_set: Optional[set]) -> str: + """Build error message when no cohort term in config matches student data.""" + allowed_str = ", ".join(sorted(allowed_set)) if allowed_set else "" + return f"No cohort term in preprocessing.selection cohort_term ({allowed_str}) found in student data." 
+
+
+def get_ordered_cohort_terms(
+    df_student: pd.DataFrame,
+    allowed_cohort_terms: Optional[List[str]] = None,
+) -> Tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
+    """Return cohort terms from most recent to oldest for looping (e.g. try until one has course data).
+
+    Requires cohort_term column; supports cohort+cohort_term or entry_year+cohort_term.
+    Each item is a dict with cohort_str/entry_year and term_upper.
+
+    Args:
+        df_student: Student DataFrame with cohort_term and cohort or entry_year.
+        allowed_cohort_terms: Uppercase allowed terms (e.g. ['FALL', 'SPRING']) or None for all.
+
+    Returns:
+        (list of candidate dicts, None) on success, or (None, error_msg) on failure.
+    """
+    if df_student.empty:
+        return None, "No student data."
+    if "cohort_term" not in df_student.columns:
+        return None, COHORT_TERM_REQUIRED_MSG
+    allowed_set: Optional[set] = None
+    if allowed_cohort_terms:
+        allowed_set = {str(t).strip().upper() for t in allowed_cohort_terms if t}
+
+    if "cohort" in df_student.columns:
+        df_clean = df_student[
+            df_student["cohort"].notna() & df_student["cohort_term"].notna()
+        ].copy()
+        if df_clean.empty:
+            return None, "Student file has no cohort and cohort_term values."
+        df_clean["_cohort_str"] = df_clean["cohort"].astype(str)
+        df_clean["_term_upper"] = df_clean["cohort_term"].astype(str).str.upper()
+        df_clean["_year_end"] = df_clean["_cohort_str"].map(_cohort_year_end)
+        if df_clean["_year_end"].isna().all():
+            return None, "Student file cohort values could not be parsed as years."
+        candidates = df_clean.groupby(["_cohort_str", "_term_upper"], as_index=False)[
+            "_year_end"
+        ].first()
+        candidates = candidates.rename(columns={"_year_end": "_year_sort"})
+        result = _build_sorted_cohort_term_candidates(candidates, allowed_set)
+        if not result:
+            return None, _no_allowed_term_error(allowed_set)
+        return result, None
+
+    if "entry_year" in df_student.columns:
+        df_clean = df_student[
+            df_student["entry_year"].notna() & df_student["cohort_term"].notna()
+        ].copy()
+        if df_clean.empty:
+            return None, "Student file has no entry_year and cohort_term values."
+        df_clean["_term_upper"] = df_clean["cohort_term"].astype(str).str.upper()
+        candidates = df_clean[["entry_year", "_term_upper"]].drop_duplicates().copy()
+        candidates["_year_sort"] = candidates["entry_year"]
+        result = _build_sorted_cohort_term_candidates(candidates, allowed_set)
+        if not result:
+            return None, _no_allowed_term_error(allowed_set)
+        return result, None
+
+    return (
+        None,
+        "Student file must have cohort (or entry_year) and cohort_term columns for latest inference cohort.",
+    )
+
+
+def _filter_students_to_cohort_term(
+    df_student: pd.DataFrame, candidate: Dict[str, Any]
+) -> Tuple[pd.DataFrame, str]:
+    """Filter student DataFrame to the cohort term described by candidate. Returns (filtered_df, label)."""
+    if "cohort_str" in candidate and "term_upper" in candidate:
+        cohort_val = candidate["cohort_str"]
+        term_val = candidate["term_upper"]
+        filtered = df_student[
+            (df_student["cohort"].astype(str) == cohort_val)
+            & (df_student["cohort_term"].astype(str).str.upper() == term_val)
+        ].copy()
+        display_year = _cohort_term_display_year(cohort_val, term_val)
+        label = (
+            f"{term_val.capitalize()} {display_year}"
+            if display_year is not None
+            else f"{term_val.capitalize()} {cohort_val}"
+        )
+        return filtered, label
+    if "entry_year" in candidate and "term_upper" in candidate:
+        year_val = candidate["entry_year"]
+        term_val = candidate["term_upper"]
+        filtered = df_student[
+            (df_student["entry_year"] == year_val)
+            & (df_student["cohort_term"].astype(str).str.upper() == term_val)
+        ].copy()
+        label = f"{term_val.capitalize()} {year_val}"
+        return filtered, label
+    raise ValueError("Invalid cohort term candidate.")
+
+
+def filter_to_most_recent_cohort(
+    df_student: pd.DataFrame,
+    allowed_cohort_terms: Optional[List[str]] = None,
+) -> Tuple[pd.DataFrame, Optional[str], Optional[str]]:
+    """Restrict student DataFrame to the single most recent cohort term.
+
+    Requires cohort_term column; uses get_ordered_cohort_terms and takes the first candidate.
+
+    Args:
+        df_student: Student DataFrame with cohort_term and cohort or entry_year.
+        allowed_cohort_terms: Uppercase allowed terms (e.g. ['FALL', 'SPRING']) or None for all.
+
+    Returns:
+        (filtered_df, cohort_label_for_display, None) on success,
+        or (empty DataFrame, None, error_message) on failure.
+    """
+    if df_student.empty:
+        return df_student, None, None
+    if "cohort_term" not in df_student.columns:
+        return pd.DataFrame(), None, COHORT_TERM_REQUIRED_MSG
+
+    ordered, err = get_ordered_cohort_terms(
+        df_student, allowed_cohort_terms=allowed_cohort_terms
+    )
+    if err is not None:
+        return pd.DataFrame(), None, err
+    if not ordered:
+        return pd.DataFrame(), None, COHORT_TERM_REQUIRED_MSG
+    filtered, label = _filter_students_to_cohort_term(df_student, ordered[0])
+    return filtered, label, None
+
+
+def _batch_has_student_and_course(batch: Any) -> bool:
+    """Return True if batch has at least one STUDENT and one COURSE file."""
+    all_schemas: set[str] = set()
+    for f in batch.files:
+        if getattr(f, "schemas", None):
+            for s in f.schemas:
+                all_schemas.add(str(s))
+    return SchemaType.STUDENT in all_schemas and SchemaType.COURSE in all_schemas
+
+
+def get_batch_by_name_with_student_and_course(
+    session: Session, inst_id: str, batch_name: str
+) -> Optional[Any]:
+    """Return the batch for the institution with the given name that has at least
+    one STUDENT and one COURSE file, or None if not found or missing file types.
+    Excludes deleted batches.
+    """
+    from sqlalchemy.orm import selectinload
+
+    stmt = (
+        select(BatchTable)
+        .options(selectinload(BatchTable.files))
+        .where(
+            and_(
+                BatchTable.inst_id == str_to_uuid(inst_id),
+                BatchTable.name == batch_name,
+                or_(
+                    BatchTable.deleted.is_(None),
+                    BatchTable.deleted == False,  # noqa: E712
+                ),
+            )
+        )
+    )
+    rows = session.execute(stmt).all()
+    if not rows or len(rows) != 1:
+        return None
+    batch = rows[0][0]
+    return batch if _batch_has_student_and_course(batch) else None
+
+
+def get_batch_for_inference(
+    session: Session, inst_id: str, batch_name: Optional[str]
+) -> Tuple[Optional[Any], Optional[str]]:
+    """Return (batch, batch_name_for_response) for cohort resolution.
+
+    When batch_name is provided and non-empty (after strip/decode), looks up that
+    batch; otherwise uses the most recent batch with STUDENT and COURSE files.
+    Excludes deleted batches.
+
+    Args:
+        session: Database session.
+        inst_id: Institution UUID.
+        batch_name: Optional batch name from query; when omitted, latest batch is used.
+
+    Returns:
+        (batch, batch_name_for_response). On success batch is the BatchTable instance
+        and batch_name_for_response is its name (or the decoded requested name).
+        On not found, (None, batch_name_for_response) with batch_name_for_response
+        set when a name was requested (for error responses).
+    """
+    if batch_name and str(batch_name).strip():
+        decoded = decode_url_piece(str(batch_name).strip())
+        batch = get_batch_by_name_with_student_and_course(session, inst_id, decoded)
+        return (batch, decoded) if batch else (None, decoded)
+    batch = get_latest_batch_with_student_and_course(session, inst_id)
+    return (batch, batch.name if batch else None) if batch else (None, None)
+
+
+def get_latest_batch_with_student_and_course(
+    session: Session, inst_id: str
+) -> Optional[Any]:
+    """Return the most recent batch for the institution that has at least one
+    STUDENT and one COURSE file, or None if none exists. Does not require
+    batch.completed. Excludes deleted batches.
+    """
+    from sqlalchemy.orm import selectinload
+
+    stmt = (
+        select(BatchTable)
+        .options(selectinload(BatchTable.files))
+        .where(
+            and_(
+                BatchTable.inst_id == str_to_uuid(inst_id),
+                or_(
+                    BatchTable.deleted.is_(None),
+                    BatchTable.deleted == False,  # noqa: E712
+                ),
+            )
+        )
+        .order_by(BatchTable.created_at.desc())
+    )
+    rows = session.execute(stmt).all()
+    for row in rows:
+        batch = row[0]
+        if _batch_has_student_and_course(batch):
+            return batch
+    return None
+
+
+def apply_student_criteria_count(
+    df: pd.DataFrame,
+    selection_config: Dict[str, Any],
+    student_id_col: str = "student_id",
+    df_course: Optional[pd.DataFrame] = None,
+) -> Tuple[int, Optional[str]]:
+    """Filter student DataFrame by preprocessing.selection student_criteria and
+    return (distinct count of students, error_message).
+    selection_config is the [preprocessing.selection] dict.
+    If any criterion column is missing from the DataFrame, returns (0, reason) as failure.
+    If df_course is provided, only students that appear in the course data are counted.
+    """
+    if df.empty:
+        return 0, None
+    id_col = student_id_col if student_id_col in df.columns else "study_id"
+    if id_col not in df.columns:
+        return 0, "Student file missing student id column."
+    criteria = selection_config.get("student_criteria")
+    if not isinstance(criteria, dict):
+        eligible_ids = df[id_col].drop_duplicates()
+    else:
+        missing = [col for col in criteria if col not in df.columns]
+        if missing:
+            return (
+                0,
+                f"Missing columns for selection criteria: {', '.join(sorted(missing))}.",
+            )
+        filtered = df.copy()
+        for col, value in criteria.items():
+            if isinstance(value, list):
+                filtered = filtered[filtered[col].isin(value)]
+            else:
+                filtered = filtered[filtered[col] == value]
+        eligible_ids = filtered[id_col].drop_duplicates()
+
+    if df_course is not None:
+        if df_course.empty:
+            eligible_ids = eligible_ids.iloc[0:0]  # empty, same index type
+        else:
+            course_id_col = (
+                "student_id" if "student_id" in df_course.columns else "study_id"
+            )
+            if course_id_col not in df_course.columns:
+                return 0, "Course file missing student id column."
+            students_with_course = set(df_course[course_id_col].drop_duplicates())
+            eligible_ids = eligible_ids[eligible_ids.isin(students_with_course)]
+
+    return int(eligible_ids.nunique()), None
+
+
+def calculate_gpa_series(
+    df: pd.DataFrame, cohort_years: List[str], grouping_col: str, category_value: str
+) -> List[float]:
+    """Calculate GPA data for one category across cohort years.
+ + Args: + df: DataFrame (cohort data) + cohort_years: List of cohort years + grouping_col: Column to filter by (e.g., 'enrollment_type') + category_value: Specific value to filter for (e.g., 'First-Time') + + Returns: + List of GPA values, one per cohort year + """ + + # Filter by category + filtered = df[df[grouping_col] == category_value] + + # Group by cohort and calculate mean GPA + gpa_by_cohort = ( + pd.to_numeric(filtered["gpa_group_year_1"], errors="coerce") + .groupby(filtered["cohort"]) + .mean() + ) + + # Convert to list aligned with cohort_years + data = [round(gpa_by_cohort.get(year, 0), 1) for year in cohort_years] + + return data + + +def get_term_counts( + df: pd.DataFrame, cohort_years: List[str], term_name: str +) -> List[int]: + """Get student counts for a specific term across cohort years. + + Args: + df: DataFrame (cohort or course data) + cohort_years: List of cohort years + term_name: Term name to filter for (e.g., 'FALL', 'WINTER') + + Returns: + List of student counts, one per cohort year + """ + result_series = ( + df[df["cohort_term"] == term_name] + .groupby("cohort") + .size() + .reindex(cohort_years, fill_value=0) + .astype(int) + ) + return [int(x) for x in result_series.tolist()] # Explicitly convert to List[int] + + @router.get("/{inst_id}/batch/{batch_id}/eda", response_model=EdaDataResponse) def get_eda_data( inst_id: str, diff --git a/src/webapp/routers/data_test.py b/src/webapp/routers/data_test.py index 8eb73fb1..d07071f9 100644 --- a/src/webapp/routers/data_test.py +++ b/src/webapp/routers/data_test.py @@ -23,14 +23,24 @@ FileTable, BatchTable, InstTable, + ModelTable, SchemaRegistryTable, DocType, Base, get_session, ) from ..utilities import uuid_to_str, get_current_active_user, SchemaType -from .data import router, DataOverview, DataInfo +from .data import ( + router, + DataOverview, + DataInfo, + filter_to_most_recent_cohort, + get_ordered_cohort_terms, + get_latest_batch_with_student_and_course, + 
apply_student_criteria_count, +) from ..gcsutil import StorageControl +from ..databricks import DatabricksControl MOCK_STORAGE = mock.Mock() @@ -127,7 +137,7 @@ def session_fixture(): updated_at=DATETIME_TESTING, sst_generated=False, valid=True, - schemas=[SchemaType.UNKNOWN], + schemas=[SchemaType.STUDENT], ) file_3 = FileTable( id=FILE_UUID_3, @@ -175,6 +185,7 @@ def session_fixture(): inst_id=USER_VALID_INST_UUID, name="file_input_two", source="PDP_SFTP", + batches={batch_1}, created_at=DATETIME_TESTING, updated_at=DATETIME_TESTING, sst_generated=False, @@ -183,6 +194,14 @@ def session_fixture(): ), file_3, file_4, + ModelTable( + id=uuid.UUID("b2c3d4e5-f6a7-8901-b234-567890123456"), + inst_id=USER_VALID_INST_UUID, + name="test_model", + created_by=CREATOR_UUID, + valid=False, + deleted=False, + ), ] ) session.commit() @@ -266,7 +285,7 @@ def test_read_inst_all_input_files(client: TestClient) -> Any: { "name": "file_input_two", "data_id": "cb02d06c2a59486a9bddd394a4fcb833", - "batch_ids": [], + "batch_ids": ["5b2420f3103546ab90eb74d5df97de43"], "inst_id": "1d7c75c33eda42949c6675ea8af97b55", "uploader": "", "source": "PDP_SFTP", @@ -381,6 +400,7 @@ def test_read_batch_info(client: TestClient) -> Any: "inst_id": "1d7c75c33eda42949c6675ea8af97b55", "file_names_to_ids": [ {"file_input_one": "f0bb3a206d924254afed6a72f43c562a"}, + {"file_input_two": "cb02d06c2a59486a9bddd394a4fcb833"}, {"file_output_three": "fbe67a2e50e040c7b7b807043cb813a5"}, ], "name": "batch_foo", @@ -420,6 +440,20 @@ def test_read_batch_info(client: TestClient) -> Any: "valid": True, "uploaded_date": "2024-12-24T20:22:20.132022", }, + { + "name": "file_input_two", + "data_id": "cb02d06c2a59486a9bddd394a4fcb833", + "batch_ids": ["5b2420f3103546ab90eb74d5df97de43"], + "inst_id": "1d7c75c33eda42949c6675ea8af97b55", + "uploader": "", + "source": "PDP_SFTP", + "deleted": False, + "deletion_request_time": None, + "retention_days": None, + "sst_generated": False, + "valid": False, + "uploaded_date": 
"2024-12-24T20:22:20.132022", + }, ], }, ) @@ -1439,3 +1473,969 @@ def test_validate_edvise_inst_not_found(edvise_client: TestClient) -> None: ) # Should fail - either 401 (unauthorized) or 404 (not found) assert response.status_code in [401, 404] + + +# ---- Latest inference cohort (Task 1) ---- + + +def test_apply_student_criteria_count_empty_df() -> None: + """apply_student_criteria_count returns (0, None) for empty DataFrame.""" + import pandas as pd + + df = pd.DataFrame(columns=["student_id", "enrollment_type"]) + count, err = apply_student_criteria_count(df, {}) + assert count == 0 + assert err is None + + +def test_apply_student_criteria_count_no_criteria() -> None: + """apply_student_criteria_count returns nunique when no student_criteria.""" + import pandas as pd + + df = pd.DataFrame({"student_id": ["a", "b", "a"], "x": [1, 2, 3]}) + count, err = apply_student_criteria_count(df, {}) + assert count == 2 + assert err is None + + +def test_apply_student_criteria_count_with_criteria() -> None: + """apply_student_criteria_count filters by student_criteria and returns count.""" + import pandas as pd + + df = pd.DataFrame( + { + "student_id": ["a", "b", "c", "d"], + "enrollment_type": ["FIRST-TIME", "FIRST-TIME", "Transfer", "FIRST-TIME"], + } + ) + config = {"student_criteria": {"enrollment_type": "FIRST-TIME"}} + count, err = apply_student_criteria_count(df, config) + assert count == 3 + assert err is None + + +def test_apply_student_criteria_count_list_value() -> None: + """apply_student_criteria_count supports list values (isin).""" + import pandas as pd + + df = pd.DataFrame( + { + "student_id": ["a", "b", "c"], + "cohort_term": ["FALL", "SPRING", "SUMMER"], + } + ) + config = {"student_criteria": {"cohort_term": ["FALL", "SPRING"]}} + count, err = apply_student_criteria_count(df, config) + assert count == 2 + assert err is None + + +def test_apply_student_criteria_count_missing_column_fails() -> None: + """apply_student_criteria_count returns error when 
criterion column is missing."""
+    import pandas as pd
+
+    df = pd.DataFrame(
+        {"student_id": ["a", "b"], "enrollment_type": ["FIRST-TIME", "Transfer"]}
+    )
+    config = {
+        "student_criteria": {
+            "enrollment_type": "FIRST-TIME",
+            "credential_type_sought_year_1": "BACHELOR'S DEGREE",
+        }
+    }
+    count, err = apply_student_criteria_count(df, config)
+    assert count == 0
+    assert err is not None
+    assert "Missing columns for selection criteria" in err
+    assert "credential_type_sought_year_1" in err
+
+
+def test_apply_student_criteria_count_restricts_to_students_with_course_data() -> None:
+    """apply_student_criteria_count only counts students that appear in course data."""
+    import pandas as pd
+
+    df_student = pd.DataFrame(
+        {
+            "study_id": ["a", "b", "c"],
+            "enrollment_type": ["FIRST-TIME", "FIRST-TIME", "FIRST-TIME"],
+        }
+    )
+    df_course = pd.DataFrame(
+        {"study_id": ["a", "c"], "course_id": [1, 2]}
+    )  # b has no course rows
+    config = {"student_criteria": {"enrollment_type": "FIRST-TIME"}}
+    count, err = apply_student_criteria_count(df_student, config, df_course=df_course)
+    assert err is None
+    assert count == 2  # a and c only; b excluded for having no course data
+
+
+def test_apply_student_criteria_count_empty_course_returns_zero() -> None:
+    """When df_course is provided but empty (no course data for cohort), count is 0."""
+    import pandas as pd
+
+    df_student = pd.DataFrame(
+        {"study_id": ["a", "b"], "enrollment_type": ["FIRST-TIME", "FIRST-TIME"]}
+    )
+    df_course = pd.DataFrame(columns=["study_id", "course_id"])  # empty
+    config = {"student_criteria": {"enrollment_type": "FIRST-TIME"}}
+    count, err = apply_student_criteria_count(df_student, config, df_course=df_course)
+    assert err is None
+    assert count == 0
+
+
+def test_filter_to_most_recent_cohort_uses_cohort_term() -> None:
+    """filter_to_most_recent_cohort selects most recent cohort term (e.g. Spring 2025 > Fall 2024)."""
+    import pandas as pd
+
+    df = pd.DataFrame(
+        {
+            "study_id": ["a", "b", "c", "d"],
+            "cohort": ["2024-25", "2024-25", "2024-25", "2023-24"],
+            "cohort_term": ["FALL", "FALL", "SPRING", "SPRING"],
+        }
+    )
+    filtered, label, err = filter_to_most_recent_cohort(df)
+    assert err is None
+    assert label == "Spring 2025"
+    assert len(filtered) == 1
+    assert filtered["study_id"].iloc[0] == "c"
+
+
+def test_filter_to_most_recent_cohort_respects_allowed_cohort_terms() -> None:
+    """filter_to_most_recent_cohort only considers terms in preprocessing.selection cohort_term."""
+    import pandas as pd
+
+    df = pd.DataFrame(
+        {
+            "study_id": ["a", "b", "c"],
+            "cohort": ["2024-25", "2024-25", "2024-25"],
+            "cohort_term": ["FALL", "SPRING", "SUMMER"],
+        }
+    )
+    # Allowed FALL, SPRING -> most recent among those is Spring 2025 (c excluded as SUMMER)
+    filtered, label, err = filter_to_most_recent_cohort(
+        df, allowed_cohort_terms=["FALL", "SPRING"]
+    )
+    assert err is None
+    assert label == "Spring 2025"
+    assert len(filtered) == 1
+    assert filtered["study_id"].iloc[0] == "b"
+
+
+def test_filter_to_most_recent_cohort_allowed_terms_no_match() -> None:
+    """filter_to_most_recent_cohort returns error when no data matches allowed cohort_term."""
+    import pandas as pd
+
+    df = pd.DataFrame(
+        {
+            "study_id": ["a", "b"],
+            "cohort": ["2024-25", "2024-25"],
+            "cohort_term": ["SUMMER", "SUMMER"],
+        }
+    )
+    filtered, label, err = filter_to_most_recent_cohort(
+        df, allowed_cohort_terms=["FALL", "SPRING"]
+    )
+    assert err is not None
+    assert "preprocessing.selection" in err or "cohort_term" in err
+    assert filtered.empty
+
+
+def test_filter_to_most_recent_cohort_requires_cohort_term() -> None:
+    """filter_to_most_recent_cohort returns error when cohort_term column is missing."""
+    import pandas as pd
+
+    df = pd.DataFrame(
+        {"study_id": ["a", "b", "c"], "cohort": ["2023-24", "2024-25", "2024-25"]}
+    )
+    filtered, label, err = filter_to_most_recent_cohort(df)
+    assert err is not None
+    assert "cohort_term" in err.lower()
+    assert filtered.empty
+
+
+def test_filter_to_most_recent_cohort_requires_cohort_term_with_entry_year() -> None:
+    """filter_to_most_recent_cohort returns error when only entry_year (no cohort_term)."""
+    import pandas as pd
+
+    df = pd.DataFrame({"study_id": ["a", "b", "c"], "entry_year": [2022, 2024, 2024]})
+    filtered, label, err = filter_to_most_recent_cohort(df)
+    assert err is not None
+    assert "cohort_term" in err.lower()
+    assert filtered.empty
+
+
+def test_filter_to_most_recent_cohort_missing_column() -> None:
+    """filter_to_most_recent_cohort returns error when no cohort/entry_year column."""
+    import pandas as pd
+
+    df = pd.DataFrame({"study_id": ["a", "b"], "enrollment_type": ["FT", "PT"]})
+    filtered, label, err = filter_to_most_recent_cohort(df)
+    assert err is not None
+    assert "cohort" in err.lower() or "entry_year" in err.lower()
+    assert filtered.empty
+
+
+def test_get_ordered_cohort_terms_cohort_and_cohort_term() -> None:
+    """get_ordered_cohort_terms returns terms most recent first when using cohort+cohort_term."""
+    import pandas as pd
+
+    df = pd.DataFrame(
+        {
+            "study_id": ["a", "b", "c", "d"],
+            "cohort": ["2024-25", "2024-25", "2023-24", "2023-24"],
+            "cohort_term": ["SPRING", "FALL", "SPRING", "FALL"],
+        }
+    )
+    ordered, err = get_ordered_cohort_terms(df)
+    assert err is None
+    assert ordered is not None
+    assert len(ordered) == 4  # Spring 2025, Fall 2024, Spring 2024, Fall 2023
+    assert (
+        ordered[0]["term_upper"] == "SPRING" and ordered[0]["cohort_str"] == "2024-25"
+    )
+    assert ordered[1]["term_upper"] == "FALL" and ordered[1]["cohort_str"] == "2024-25"
+
+
+def test_get_ordered_cohort_terms_entry_year_and_cohort_term() -> None:
+    """get_ordered_cohort_terms returns terms most recent first when using entry_year+cohort_term."""
+    import pandas as pd
+
+    df = pd.DataFrame(
+        {
+            "study_id": ["a", "b", "c"],
+            "entry_year": [2024, 2024, 2023],
+            "cohort_term": ["SPRING",
"FALL", "FALL"], + } + ) + ordered, err = get_ordered_cohort_terms(df) + assert err is None + assert ordered is not None + assert len(ordered) == 3 # Spring 2024, Fall 2024, Fall 2023 + assert ordered[0]["term_upper"] == "SPRING" and ordered[0]["entry_year"] == 2024 + + +def test_get_ordered_cohort_terms_empty_dataframe() -> None: + """get_ordered_cohort_terms returns error when DataFrame is empty.""" + import pandas as pd + + df = pd.DataFrame(columns=["study_id", "cohort", "cohort_term"]) + ordered, err = get_ordered_cohort_terms(df) + assert err is not None + assert "no student data" in err.lower() + assert ordered is None + + +def test_get_ordered_cohort_terms_requires_cohort_term_column() -> None: + """get_ordered_cohort_terms returns error when cohort_term column is missing.""" + import pandas as pd + + df = pd.DataFrame({"study_id": ["a"], "cohort": ["2024-25"]}) + ordered, err = get_ordered_cohort_terms(df) + assert err is not None + assert "cohort_term" in err.lower() + assert ordered is None + + +def test_get_ordered_cohort_terms_respects_allowed_terms() -> None: + """get_ordered_cohort_terms only includes terms in allowed_cohort_terms.""" + import pandas as pd + + df = pd.DataFrame( + { + "study_id": ["a", "b", "c"], + "cohort": ["2024-25", "2024-25", "2024-25"], + "cohort_term": ["FALL", "SPRING", "SUMMER"], + } + ) + ordered, err = get_ordered_cohort_terms(df, allowed_cohort_terms=["FALL", "SPRING"]) + assert err is None + assert ordered is not None + assert len(ordered) == 2 # Spring 2025, Fall 2024 (SUMMER excluded) + terms = [c["term_upper"] for c in ordered] + assert "SUMMER" not in terms + + +def test_get_ordered_cohort_terms_no_cohort_or_entry_year() -> None: + """get_ordered_cohort_terms returns error when neither cohort nor entry_year present.""" + import pandas as pd + + df = pd.DataFrame({"study_id": ["a", "b"], "cohort_term": ["FALL", "FALL"]}) + ordered, err = get_ordered_cohort_terms(df) + assert err is not None + assert ( + "cohort" in 
err.lower() and "entry_year" in err.lower() + ) or "cohort_term" in err.lower() + assert ordered is None + + +def test_get_latest_batch_with_student_and_course_returns_batch( + session: sqlalchemy.orm.Session, +) -> None: + """get_latest_batch_with_student_and_course returns batch when it has STUDENT and COURSE.""" + batch = get_latest_batch_with_student_and_course( + session, uuid_to_str(USER_VALID_INST_UUID) + ) + # Session fixture: batch_1 has file_1 (STUDENT), file_2 (COURSE), file_3 (STUDENT, sst_generated) + assert batch is not None + assert batch.name == "batch_foo" + + +def test_get_latest_batch_with_student_and_course_no_batch( + session: sqlalchemy.orm.Session, +) -> None: + """get_latest_batch_with_student_and_course returns None for inst with no such batch.""" + batch = get_latest_batch_with_student_and_course(session, uuid_to_str(UUID_INVALID)) + assert batch is None + + +def test_get_latest_batch_with_student_and_course_returns_none_when_no_batch_has_both_student_and_course( + session: sqlalchemy.orm.Session, +) -> None: + """get_latest_batch_with_student_and_course returns None when inst has batches but none have both STUDENT and COURSE.""" + inst_only_student_uuid = uuid.UUID("a1b2c3d4-e5f6-4789-a012-345678901234") + batch_student_only = BatchTable( + inst_id=inst_only_student_uuid, + name="batch_student_only", + created_by=CREATOR_UUID, + created_at=DATETIME_TESTING, + updated_at=DATETIME_TESTING, + ) + file_student_only = FileTable( + inst_id=inst_only_student_uuid, + name="file_student_only", + source="MANUAL_UPLOAD", + batches={batch_student_only}, + created_at=DATETIME_TESTING, + updated_at=DATETIME_TESTING, + sst_generated=False, + valid=True, + schemas=[SchemaType.STUDENT], + ) + session.add_all( + [ + InstTable( + id=inst_only_student_uuid, + name="inst_student_only", + created_at=DATETIME_TESTING, + updated_at=DATETIME_TESTING, + ), + batch_student_only, + file_student_only, + ] + ) + session.commit() + batch = 
get_latest_batch_with_student_and_course( + session, uuid_to_str(inst_only_student_uuid) + ) + assert batch is None + + +def test_get_latest_inference_cohort_unauthorized(client: TestClient) -> None: + """GET latest-inference-cohort returns 401 for wrong institution.""" + response = client.get( + "/institutions/" + + uuid_to_str(UUID_INVALID) + + "/latest-inference-cohort?model_name=test_model" + ) + assert response.status_code == 401 + + +def test_get_latest_inference_cohort_invalid_when_no_model_name_and_no_models( + client: TestClient, +) -> None: + """GET latest-inference-cohort returns 200 invalid when model_name omitted and institution has no registered models.""" + # Session fixture has no ModelTable rows for USER_VALID_INST_UUID + response = client.get( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/latest-inference-cohort" + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "invalid" + assert "reason" in data and data["reason"] + assert "registered" in (data["reason"] or "").lower() + + +def test_get_latest_inference_cohort_invalid_when_models_exist_but_none_registered( + client: TestClient, + session: sqlalchemy.orm.Session, +) -> None: + """GET latest-inference-cohort returns 200 invalid when model_name omitted and all models are invalid (not registered).""" + model_uuid = uuid.UUID("a0b1c2d3-e4f5-6789-a012-345678901235") + session.add( + ModelTable( + id=model_uuid, + inst_id=USER_VALID_INST_UUID, + name="unapproved_model", + created_by=CREATOR_UUID, + valid=False, + deleted=False, + ) + ) + session.commit() + response = client.get( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/latest-inference-cohort" + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "invalid" + assert "reason" in data and data["reason"] + assert "registered" in (data["reason"] or "").lower() + + +def test_get_latest_inference_cohort_missing_cohort_column(client: 
TestClient) -> None: + """GET latest-inference-cohort returns 200 invalid when student file has no cohort column.""" + import pandas as pd + + MOCK_DATABRICKS = mock.Mock(spec=DatabricksControl) + MOCK_DATABRICKS.fetch_model_version.return_value = mock.Mock( + run_id="0b2e206732ce48f6b644149090c9614a" + ) + MOCK_DATABRICKS.read_volume_training_config.return_value = { + "student_criteria": {"enrollment_type": "FIRST-TIME"}, + } + # Student file without cohort or entry_year + df_student = pd.DataFrame({"study_id": ["s1"], "enrollment_type": ["FIRST-TIME"]}) + df_course = pd.DataFrame({"study_id": ["s1"], "x": [1]}) + + def storage_read(bucket: str, path: str) -> Any: + if "input_one" in path: + return df_student + if "input_two" in path: + return df_course + return df_student + + MOCK_STORAGE.read_csv_as_dataframe.side_effect = storage_read + app.dependency_overrides[DatabricksControl] = lambda: MOCK_DATABRICKS + try: + response = client.get( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/latest-inference-cohort?model_name=test_model" + ) + finally: + app.dependency_overrides.pop(DatabricksControl, None) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "invalid" + assert "reason" in data and data["reason"] + assert "cohort_term" in (data["reason"] or "").lower() + assert data.get("batch_name") == "batch_foo" + + +def test_get_latest_inference_cohort_missing_cohort_term(client: TestClient) -> None: + """GET latest-inference-cohort returns 200 invalid when student file has cohort but no cohort_term.""" + import pandas as pd + + MOCK_DATABRICKS = mock.Mock(spec=DatabricksControl) + MOCK_DATABRICKS.fetch_model_version.return_value = mock.Mock( + run_id="0b2e206732ce48f6b644149090c9614a" + ) + MOCK_DATABRICKS.read_volume_training_config.return_value = { + "student_criteria": {"enrollment_type": "FIRST-TIME"}, + } + df_student = pd.DataFrame( + { + "study_id": ["s1", "s2"], + "cohort": ["2024-25", "2024-25"], + 
"enrollment_type": ["FIRST-TIME", "FIRST-TIME"], + } + ) + df_course = pd.DataFrame({"study_id": ["s1", "s2"], "x": [1, 2]}) + + def storage_read(bucket: str, path: str) -> Any: + if "input_one" in path: + return df_student + if "input_two" in path: + return df_course + return df_student + + MOCK_STORAGE.read_csv_as_dataframe.side_effect = storage_read + app.dependency_overrides[DatabricksControl] = lambda: MOCK_DATABRICKS + try: + response = client.get( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/latest-inference-cohort?model_name=test_model" + ) + finally: + app.dependency_overrides.pop(DatabricksControl, None) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "invalid" + assert "reason" in data and data["reason"] + assert "cohort_term" in (data["reason"] or "").lower() + assert data.get("batch_name") == "batch_foo" + + +def test_get_latest_inference_cohort_model_version_lookup_fails( + client: TestClient, +) -> None: + """GET latest-inference-cohort returns 200 invalid when fetch_model_version raises (e.g. 
no versions).""" + MOCK_DATABRICKS = mock.Mock(spec=DatabricksControl) + MOCK_DATABRICKS.fetch_model_version.side_effect = ValueError( + "No versions found for model" + ) + + app.dependency_overrides[DatabricksControl] = lambda: MOCK_DATABRICKS + try: + response = client.get( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/latest-inference-cohort?model_name=test_model" + ) + finally: + app.dependency_overrides.pop(DatabricksControl, None) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "invalid" + assert "reason" in data and data["reason"] + assert "model version" in (data["reason"] or "").lower() + + +def test_get_latest_inference_cohort_missing_config(client: TestClient) -> None: + """GET latest-inference-cohort returns 200 invalid when Databricks config is missing.""" + MOCK_DATABRICKS = mock.Mock(spec=DatabricksControl) + MOCK_DATABRICKS.fetch_model_version.return_value = mock.Mock( + run_id="0b2e206732ce48f6b644149090c9614a" + ) + MOCK_DATABRICKS.read_volume_training_config.return_value = None + + app.dependency_overrides[DatabricksControl] = lambda: MOCK_DATABRICKS + try: + response = client.get( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/latest-inference-cohort?model_name=test_model" + ) + finally: + app.dependency_overrides.pop(DatabricksControl, None) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "invalid" + assert "reason" in data and data["reason"] + assert "Missing config" in (data["reason"] or "") + assert data.get("batch_name") == "batch_foo" + + +def test_get_latest_inference_cohort_valid( + client: TestClient, +) -> None: + """GET latest-inference-cohort returns 200 valid when batch has STUDENT and config.""" + import pandas as pd + + MOCK_DATABRICKS = mock.Mock(spec=DatabricksControl) + MOCK_DATABRICKS.fetch_model_version.return_value = mock.Mock( + run_id="0b2e206732ce48f6b644149090c9614a" + ) + 
MOCK_DATABRICKS.read_volume_training_config.return_value = { + "student_criteria": { + "enrollment_type": "FIRST-TIME", + "cohort_term": ["FALL", "SPRING"], + }, + } + + df_student = pd.DataFrame( + { + "study_id": ["s1", "s2", "s3"], + "cohort": ["2024-25", "2024-25", "2023-24"], + "cohort_term": ["FALL", "FALL", "FALL"], + "enrollment_type": ["FIRST-TIME", "FIRST-TIME", "Transfer"], + } + ) + # Course data for s1 and s2 in most recent cohort term (Fall 2024) + df_course = pd.DataFrame({"study_id": ["s1", "s2", "s3"], "x": [1, 2, 3]}) + + def storage_read(bucket: str, path: str) -> Any: + if "input_one" in path: + return df_student + if "input_two" in path: + return df_course + return df_student + + MOCK_STORAGE.read_csv_as_dataframe.side_effect = storage_read + + app.dependency_overrides[DatabricksControl] = lambda: MOCK_DATABRICKS + try: + response = client.get( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/latest-inference-cohort?model_name=test_model" + ) + finally: + app.dependency_overrides.pop(DatabricksControl, None) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "valid" + assert "batch_name" in data and data["batch_name"] == "batch_foo" + assert "cohort_label" in data and data["cohort_label"] == "Fall 2024" + assert ( + "valid_student_count" in data and data["valid_student_count"] == 2 + ) # s1, s2 in Fall 2024 cohort term matching FIRST-TIME + assert data.get("reason") is None + + +def test_get_latest_inference_cohort_valid_with_batch_name( + client: TestClient, +) -> None: + """GET latest-inference-cohort with batch_name uses that batch for cohort selection.""" + import pandas as pd + + MOCK_DATABRICKS = mock.Mock(spec=DatabricksControl) + MOCK_DATABRICKS.fetch_model_version.return_value = mock.Mock( + run_id="0b2e206732ce48f6b644149090c9614a" + ) + MOCK_DATABRICKS.read_volume_training_config.return_value = { + "student_criteria": { + "enrollment_type": "FIRST-TIME", + "cohort_term": ["FALL", 
"SPRING"], + }, + } + + df_student = pd.DataFrame( + { + "study_id": ["s1", "s2"], + "cohort": ["2024-25", "2024-25"], + "cohort_term": ["FALL", "FALL"], + "enrollment_type": ["FIRST-TIME", "FIRST-TIME"], + } + ) + df_course = pd.DataFrame({"study_id": ["s1", "s2"], "x": [1, 2]}) + + def storage_read(bucket: str, path: str) -> Any: + if "input_one" in path: + return df_student + if "input_two" in path: + return df_course + return df_student + + MOCK_STORAGE.read_csv_as_dataframe.side_effect = storage_read + app.dependency_overrides[DatabricksControl] = lambda: MOCK_DATABRICKS + try: + response = client.get( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/latest-inference-cohort?model_name=test_model&batch_name=batch_foo" + ) + finally: + app.dependency_overrides.pop(DatabricksControl, None) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "valid" + assert data["batch_name"] == "batch_foo" + assert data["cohort_label"] == "Fall 2024" + assert data["valid_student_count"] == 2 + + +def test_get_latest_inference_cohort_invalid_when_model_name_not_found( + client: TestClient, +) -> None: + """GET latest-inference-cohort returns invalid when user provides a model name that does not exist for the institution.""" + response = client.get( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/latest-inference-cohort?model_name=nonexistent_model" + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "invalid" + assert "reason" in data and data["reason"] + assert "model" in (data["reason"] or "").lower() + + +def test_get_latest_inference_cohort_invalid_when_batch_name_not_found( + client: TestClient, +) -> None: + """GET latest-inference-cohort with batch_name returns invalid when batch not found.""" + MOCK_DATABRICKS = mock.Mock(spec=DatabricksControl) + MOCK_DATABRICKS.fetch_model_version.return_value = mock.Mock( + run_id="0b2e206732ce48f6b644149090c9614a" + ) + 
app.dependency_overrides[DatabricksControl] = lambda: MOCK_DATABRICKS + try: + response = client.get( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/latest-inference-cohort?model_name=test_model&batch_name=nonexistent_batch" + ) + finally: + app.dependency_overrides.pop(DatabricksControl, None) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "invalid" + assert "reason" in data and data["reason"] + assert "batch" in (data["reason"] or "").lower() + + +def test_get_latest_inference_cohort_loops_back_when_no_course_data_for_most_recent( + client: TestClient, +) -> None: + """When most recent cohort term has no course data, try next cohort term until one has course data.""" + import pandas as pd + + MOCK_DATABRICKS = mock.Mock(spec=DatabricksControl) + MOCK_DATABRICKS.fetch_model_version.return_value = mock.Mock( + run_id="0b2e206732ce48f6b644149090c9614a" + ) + MOCK_DATABRICKS.read_volume_training_config.return_value = { + "student_criteria": { + "enrollment_type": "FIRST-TIME", + "cohort_term": ["FALL", "SPRING"], + }, + } + # Fall 2025: s1, s2. Fall 2024: s3, s4. Course data only for s3, s4 (Fall 2024). 
+ df_student = pd.DataFrame( + { + "study_id": ["s1", "s2", "s3", "s4"], + "cohort": ["2025-26", "2025-26", "2024-25", "2024-25"], + "cohort_term": ["FALL", "FALL", "FALL", "FALL"], + "enrollment_type": ["FIRST-TIME", "FIRST-TIME", "FIRST-TIME", "FIRST-TIME"], + } + ) + df_course = pd.DataFrame({"study_id": ["s3", "s4"], "x": [1, 2]}) + + def storage_read(bucket: str, path: str) -> Any: + if "input_one" in path: + return df_student + if "input_two" in path: + return df_course + return df_student + + MOCK_STORAGE.read_csv_as_dataframe.side_effect = storage_read + app.dependency_overrides[DatabricksControl] = lambda: MOCK_DATABRICKS + try: + response = client.get( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/latest-inference-cohort?model_name=test_model" + ) + finally: + app.dependency_overrides.pop(DatabricksControl, None) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "valid" + assert ( + data["cohort_label"] == "Fall 2024" + ) # Fall 2025 had no course data; used Fall 2024 + assert data["valid_student_count"] == 2 # s3, s4 + + +def test_get_latest_inference_cohort_invalid_when_no_cohort_has_course_data( + client: TestClient, +) -> None: + """When no cohort term in the batch has course data, return invalid.""" + import pandas as pd + + MOCK_DATABRICKS = mock.Mock(spec=DatabricksControl) + MOCK_DATABRICKS.fetch_model_version.return_value = mock.Mock( + run_id="0b2e206732ce48f6b644149090c9614a" + ) + MOCK_DATABRICKS.read_volume_training_config.return_value = { + "student_criteria": {"enrollment_type": "FIRST-TIME", "cohort_term": ["FALL"]}, + } + df_student = pd.DataFrame( + { + "study_id": ["s1", "s2"], + "cohort": ["2025-26", "2025-26"], + "cohort_term": ["FALL", "FALL"], + "enrollment_type": ["FIRST-TIME", "FIRST-TIME"], + } + ) + # Course data for different IDs (no overlap with s1, s2) + df_course = pd.DataFrame({"study_id": ["other1", "other2"], "x": [1, 2]}) + + def storage_read(bucket: str, path: 
str) -> Any: + if "input_one" in path: + return df_student + if "input_two" in path: + return df_course + return df_student + + MOCK_STORAGE.read_csv_as_dataframe.side_effect = storage_read + app.dependency_overrides[DatabricksControl] = lambda: MOCK_DATABRICKS + try: + response = client.get( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/latest-inference-cohort?model_name=test_model" + ) + finally: + app.dependency_overrides.pop(DatabricksControl, None) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "invalid" + assert "reason" in data and data["reason"] + assert "course data" in (data["reason"] or "").lower() + assert data.get("batch_name") == "batch_foo" + + +def test_get_latest_inference_cohort_missing_student_id_column( + client: TestClient, +) -> None: + """GET latest-inference-cohort returns 200 invalid when student file has no student id column (study_id/student_id).""" + import pandas as pd + + MOCK_DATABRICKS = mock.Mock(spec=DatabricksControl) + MOCK_DATABRICKS.fetch_model_version.return_value = mock.Mock( + run_id="0b2e206732ce48f6b644149090c9614a" + ) + MOCK_DATABRICKS.read_volume_training_config.return_value = { + "student_criteria": { + "enrollment_type": "FIRST-TIME", + "cohort_term": ["FALL", "SPRING"], + }, + } + df_student = pd.DataFrame( + { + "cohort": ["2024-25", "2024-25"], + "cohort_term": ["FALL", "FALL"], + "enrollment_type": ["FIRST-TIME", "FIRST-TIME"], + } + ) + df_course = pd.DataFrame({"study_id": ["s1", "s2"], "x": [1, 2]}) + + def storage_read(bucket: str, path: str) -> Any: + if "input_one" in path: + return df_student + if "input_two" in path: + return df_course + return df_student + + MOCK_STORAGE.read_csv_as_dataframe.side_effect = storage_read + app.dependency_overrides[DatabricksControl] = lambda: MOCK_DATABRICKS + try: + response = client.get( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/latest-inference-cohort?model_name=test_model" + ) + finally: + 
app.dependency_overrides.pop(DatabricksControl, None) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "invalid" + assert "reason" in data and data["reason"] + assert "student id column" in (data["reason"] or "").lower() + assert data.get("batch_name") == "batch_foo" + + +def test_get_latest_inference_cohort_valid_student_count_zero_when_no_students_meet_criteria( + client: TestClient, +) -> None: + """GET latest-inference-cohort returns 200 valid with valid_student_count=0 when cohort has course data but no students meet criteria.""" + import pandas as pd + + MOCK_DATABRICKS = mock.Mock(spec=DatabricksControl) + MOCK_DATABRICKS.fetch_model_version.return_value = mock.Mock( + run_id="0b2e206732ce48f6b644149090c9614a" + ) + MOCK_DATABRICKS.read_volume_training_config.return_value = { + "student_criteria": { + "enrollment_type": "FIRST-TIME", + "cohort_term": ["FALL", "SPRING"], + }, + } + # All students in Fall 2024 are Transfer; config requires FIRST-TIME + df_student = pd.DataFrame( + { + "study_id": ["s1", "s2"], + "cohort": ["2024-25", "2024-25"], + "cohort_term": ["FALL", "FALL"], + "enrollment_type": ["Transfer", "Transfer"], + } + ) + df_course = pd.DataFrame({"study_id": ["s1", "s2"], "x": [1, 2]}) + + def storage_read(bucket: str, path: str) -> Any: + if "input_one" in path: + return df_student + if "input_two" in path: + return df_course + return df_student + + MOCK_STORAGE.read_csv_as_dataframe.side_effect = storage_read + app.dependency_overrides[DatabricksControl] = lambda: MOCK_DATABRICKS + try: + response = client.get( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/latest-inference-cohort?model_name=test_model" + ) + finally: + app.dependency_overrides.pop(DatabricksControl, None) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "valid" + assert data["cohort_label"] == "Fall 2024" + assert data["valid_student_count"] == 0 + assert data.get("batch_name") 
== "batch_foo" + + +def test_get_latest_inference_cohort_invalid_when_student_criteria_references_missing_column( + client: TestClient, +) -> None: + """GET latest-inference-cohort returns 200 invalid when student_criteria references a column not in student file.""" + import pandas as pd + + MOCK_DATABRICKS = mock.Mock(spec=DatabricksControl) + MOCK_DATABRICKS.fetch_model_version.return_value = mock.Mock( + run_id="0b2e206732ce48f6b644149090c9614a" + ) + MOCK_DATABRICKS.read_volume_training_config.return_value = { + "student_criteria": { + "enrollment_type": "FIRST-TIME", + "nonexistent_col": "X", + "cohort_term": ["FALL", "SPRING"], + }, + } + df_student = pd.DataFrame( + { + "study_id": ["s1", "s2"], + "cohort": ["2024-25", "2024-25"], + "cohort_term": ["FALL", "FALL"], + "enrollment_type": ["FIRST-TIME", "FIRST-TIME"], + } + ) + df_course = pd.DataFrame({"study_id": ["s1", "s2"], "x": [1, 2]}) + + def storage_read(bucket: str, path: str) -> Any: + if "input_one" in path: + return df_student + if "input_two" in path: + return df_course + return df_student + + MOCK_STORAGE.read_csv_as_dataframe.side_effect = storage_read + app.dependency_overrides[DatabricksControl] = lambda: MOCK_DATABRICKS + try: + response = client.get( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/latest-inference-cohort?model_name=test_model" + ) + finally: + app.dependency_overrides.pop(DatabricksControl, None) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "invalid" + assert "reason" in data and data["reason"] + assert "Missing columns" in (data["reason"] or "") or "nonexistent_col" in ( + data["reason"] or "" + ) + assert data.get("batch_name") == "batch_foo" diff --git a/src/webapp/routers/front_end_tables.py b/src/webapp/routers/front_end_tables.py index 94522eae..6b628fdd 100644 --- a/src/webapp/routers/front_end_tables.py +++ b/src/webapp/routers/front_end_tables.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import Session from 
sqlalchemy.future import select import logging -from ..config import databricks_vars, env_vars, gcs_vars +from ..config import ENV_TO_VOLUME_SCHEMA, databricks_vars, env_vars, gcs_vars import tempfile import pathlib @@ -476,12 +476,11 @@ def get_model_cards( try: env = str(env_vars["ENV"]).strip().upper() - SCHEMAS = {"DEV": "dev_sst_02", "STAGING": "staging_sst_01"} - if env not in SCHEMAS: + if env not in ENV_TO_VOLUME_SCHEMA: raise ValueError( f"Unsupported ENV {env_vars.get('ENV')!r}; expected DEV or STAGING" ) - env_schema = SCHEMAS[env] + env_schema = ENV_TO_VOLUME_SCHEMA[env] volume_path = f"/Volumes/{env_schema}/{databricksify_inst_name(query_result[0][0].name)}_gold/gold_volume/model_cards/{run_id}/model-card-{model_name}.pdf" LOGGER.info(f"Attempting to download from {volume_path}") diff --git a/src/webapp/routers/models.py b/src/webapp/routers/models.py index 02be74ae..31f11b46 100644 --- a/src/webapp/routers/models.py +++ b/src/webapp/routers/models.py @@ -138,6 +138,8 @@ class InferenceRunRequest(BaseModel): # Note: is_pdp is kept for backward compatibility but is ignored. # PDP status is derived from the institution's pdp_id field. is_pdp: bool = False + # Optional term filter (e.g. ["fall 2024-25"]). Used for cohort/graduation models. Omit for pipeline default. + term_filter: list[str] | None = None # Model related operations. Or model specific data. @@ -502,6 +504,8 @@ def trigger_inference_run( """Returns top-level info around all executions of a given model. Only visible to users of that institution or Datakinder access types. + Optional request field term_filter: list of labels (e.g. ["fall 2024-25"]); used for cohort/graduation models. + When omitted, the pipeline uses its config default. 
""" model_name = decode_url_piece(model_name) has_access_to_inst_or_err(inst_id, current_user) @@ -566,7 +570,7 @@ def trigger_inference_run( raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Unexpected number of batches found: Expected 1, got " - + str(len(inst_result)), + + str(len(batch_result)), ) # inst_file_schemas = [x.schemas for x in batch_result[0][0].files] inst_file_schemas = [list({s for f in batch_result[0][0].files for s in f.schemas})] @@ -588,6 +592,26 @@ def trigger_inference_run( status_code=status.HTTP_400_BAD_REQUEST, detail=f"The files in this batch don't conform to the schema configs allowed by this model. For debugging reference - file_schema={inst_file_schemas} and model_schema={schema_configs}", ) + if req.term_filter is not None: + # When term_filter is provided, require at least one non-empty label. + if len(req.term_filter) == 0: + logging.warning( + "run-inference term_filter validation failed: empty list for inst_id=%s", + inst_id, + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="At least one label is required when term_filter is provided for run-inference.", + ) + if any(not label or not str(label).strip() for label in req.term_filter): + logging.warning( + "run-inference term_filter validation failed: empty or whitespace label for inst_id=%s", + inst_id, + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Labels must be non-empty strings for run-inference.", + ) # Note to Datakind: In the long-term, this is where you would have a case block or something that would call different types of pipelines. db_req = DatabricksInferenceRunRequest( inst_name=inst_result[0][0].name, @@ -596,9 +620,10 @@ def trigger_inference_run( gcp_external_bucket_name=get_external_bucket_name(inst_id), # The institution email to which pipeline success/failure notifications will get sent. 
email=cast(str, current_user.email), + term_filter=req.term_filter, ) try: - res = databricks_control.run_pdp_inference(db_req) + inference_run_response = databricks_control.run_pdp_inference(db_req) except Exception as e: tb = traceback.format_exc() logging.error(f"Databricks run failure:\n{tb}") @@ -606,6 +631,12 @@ def trigger_inference_run( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Databricks run_pdp_inference error. Error = {str(e)}", ) from e + logging.info( + "run-inference: user=%s term_filter=%s job_run_id=%s", + current_user.email, + req.term_filter, + inference_run_response.job_run_id, + ) triggered_timestamp = datetime.now() latest_model_version = databricks_control.fetch_model_version( catalog_name=str(env_vars["CATALOG_NAME"]), @@ -613,7 +644,7 @@ def trigger_inference_run( model_name=model_name, ) job = JobTable( - id=res.job_run_id, + id=inference_run_response.job_run_id, triggered_at=triggered_timestamp, created_by=str_to_uuid(current_user.user_id), batch_name=req.batch_name, @@ -626,7 +657,7 @@ def trigger_inference_run( return { "inst_id": inst_id, "m_name": model_name, - "run_id": res.job_run_id, + "run_id": inference_run_response.job_run_id, "created_by": current_user.user_id, "triggered_at": triggered_timestamp, "batch_name": req.batch_name, diff --git a/src/webapp/routers/models_test.py b/src/webapp/routers/models_test.py index cfa57671..47082147 100644 --- a/src/webapp/routers/models_test.py +++ b/src/webapp/routers/models_test.py @@ -385,6 +385,111 @@ def test_trigger_inference_run(client: TestClient) -> None: assert response.json()["created_by"] == uuid_to_str(USER_UUID) assert response.json()["triggered_at"] is not None assert response.json()["batch_name"] == "batch_foo" + # Backward compatible: no term_filter in request; mock was called with term_filter=None + call_args = MOCK_DATABRICKS.run_pdp_inference.call_args + assert call_args is not None + assert call_args[0][0].term_filter is None + + +def 
test_trigger_inference_run_with_term_filter(client: TestClient) -> None: + """Run-inference with term_filter passes list to Databricks and returns 200.""" + MOCK_DATABRICKS.run_pdp_inference.return_value = DatabricksInferenceRunResponse( + job_run_id=456 + ) + response = client.post( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/models/sample_model_for_school_1/run-inference", + json={ + "batch_name": "batch_foo", + "is_pdp": True, + "term_filter": ["fall 2024-25"], + }, + ) + assert response.status_code == 200 + assert response.json()["run_id"] == 456 + call_args = MOCK_DATABRICKS.run_pdp_inference.call_args + assert call_args is not None + assert call_args[0][0].term_filter == ["fall 2024-25"] + + +def test_trigger_inference_run_term_filter_empty_list_rejected( + client: TestClient, +) -> None: + """Run-inference with term_filter=[] returns 400.""" + response = client.post( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/models/sample_model_for_school_1/run-inference", + json={ + "batch_name": "batch_foo", + "is_pdp": True, + "term_filter": [], + }, + ) + assert response.status_code == 400 + assert "At least one label is required" in response.json()["detail"] + + +def test_trigger_inference_run_term_filter_empty_string_rejected( + client: TestClient, +) -> None: + """Run-inference with term_filter containing empty string returns 400.""" + response = client.post( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/models/sample_model_for_school_1/run-inference", + json={ + "batch_name": "batch_foo", + "is_pdp": True, + "term_filter": ["fall 2024-25", ""], + }, + ) + assert response.status_code == 400 + assert "Labels must be non-empty" in response.json()["detail"] + + +def test_trigger_inference_run_term_filter_whitespace_only_rejected( + client: TestClient, +) -> None: + """Run-inference with term_filter containing only whitespace returns 400.""" + response = client.post( + "/institutions/" + + 
uuid_to_str(USER_VALID_INST_UUID) + + "/models/sample_model_for_school_1/run-inference", + json={ + "batch_name": "batch_foo", + "is_pdp": True, + "term_filter": ["fall 2024-25", " \t "], + }, + ) + assert response.status_code == 400 + assert "Labels must be non-empty" in response.json()["detail"] + + +def test_trigger_inference_run_with_multiple_term_filter_labels( + client: TestClient, +) -> None: + """Run-inference with multiple term_filter labels passes list to Databricks and returns 200.""" + MOCK_DATABRICKS.run_pdp_inference.return_value = DatabricksInferenceRunResponse( + job_run_id=789 + ) + labels = ["fall 2024-25", "spring 2024-25"] + response = client.post( + "/institutions/" + + uuid_to_str(USER_VALID_INST_UUID) + + "/models/sample_model_for_school_1/run-inference", + json={ + "batch_name": "batch_foo", + "is_pdp": True, + "term_filter": labels, + }, + ) + assert response.status_code == 200 + assert response.json()["run_id"] == 789 + call_args = MOCK_DATABRICKS.run_pdp_inference.call_args + assert call_args is not None + assert call_args[0][0].term_filter == labels def test_check_file_types_valid_schema_configs():