diff --git a/Makefile b/Makefile index 49b3a494b..46ca2f881 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,7 @@ clean: docs-clean .venv/bin/python: pip install hatch hatch env create + hatch run pip install ".[llm,pii]" dev: .venv/bin/python @hatch run which python diff --git a/demos/dqx_llm_demo.py b/demos/dqx_llm_demo.py new file mode 100644 index 000000000..9dfe969ce --- /dev/null +++ b/demos/dqx_llm_demo.py @@ -0,0 +1,316 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # Using DQX for LLM-based Primary Key Detection +# MAGIC DQX provides optional LLM-based primary key detection capabilities that can intelligently identify primary key columns from table schema and metadata. This feature uses Large Language Models to analyze table structures and suggest potential primary keys, enhancing the data profiling and quality rules generation process. +# MAGIC +# MAGIC The LLM-based primary key detection is completely optional and only activates when users explicitly request it. Regular DQX functionality works without any LLM dependencies. + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Install DQX with LLM extras +# MAGIC +# MAGIC To enable LLM-based primary key detection, DQX has to be installed with `llm` extras: +# MAGIC +# MAGIC `%pip install databricks-labs-dqx[llm]` + +# COMMAND ---------- + +dbutils.widgets.text("test_library_ref", "", "Test Library Ref") + +if dbutils.widgets.get("test_library_ref") != "": + %pip install 'databricks-labs-dqx[llm] @ {dbutils.widgets.get("test_library_ref")}' +else: + %pip install databricks-labs-dqx[llm] + +# COMMAND ---------- + +dbutils.library.restartPython() + +# COMMAND ---------- + +from databricks.sdk import WorkspaceClient +from databricks.labs.dqx.config import ProfilerConfig, LLMConfig +from databricks.labs.dqx.profiler.profiler import DQProfiler + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Regular Profiling (No LLM Dependencies Required) +# MAGIC +# MAGIC By default, DQX works without any LLM dependencies. Regular profiling functionality is always available. + +# COMMAND ---------- + +# Default configuration - no LLM features +config = ProfilerConfig() +print(f"LLM PK Detection: {config.llm_config.enable_pk_detection}") # False by default + +# This works without any LLM dependencies! +print("✅ Regular profiling works out of the box!") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## LLM-Based Primary Key Detection +# MAGIC +# MAGIC When explicitly requested, DQX can use LLM-based analysis to detect potential primary keys in your tables. 
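+# MAGIC
+# MAGIC Before turning it on, you can sanity-check that the optional LLM dependencies resolve in the current environment. The snippet below is a minimal sketch that mirrors the import-time check performed by the `databricks.labs.dqx.llm` package, which requires the `dspy` and `databricks_langchain` modules:
+# MAGIC
+# MAGIC ```python
+# MAGIC from importlib.util import find_spec
+# MAGIC
+# MAGIC # Both modules must be importable for LLM-based PK detection to work
+# MAGIC print(all(find_spec(m) is not None for m in ("dspy", "databricks_langchain")))
+# MAGIC ```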
+ +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Method 1: Configuration-based enablement + +# COMMAND ---------- + +# Enable LLM-based PK detection via configuration +config = ProfilerConfig(llm_config=LLMConfig(enable_pk_detection=True)) +print(f"LLM PK Detection: {config.llm_config.enable_pk_detection}") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Method 2: Options-based enablement + +# COMMAND ---------- + +ws = WorkspaceClient() +profiler = DQProfiler(ws) + +# Enable via options parameter +summary_stats, dq_rules = profiler.profile_table( + "catalog.schema.table", + options={"llm": True} # Simple LLM enablement +) +print("✅ LLM-based profiling enabled!") + +# Check if primary key was detected +if "llm_primary_key_detection" in summary_stats: + pk_info = summary_stats["llm_primary_key_detection"] + print(f"Detected PK: {pk_info['detected_columns']}") + print(f"Confidence: {pk_info['confidence']}") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Method 3: Direct detection method + +# COMMAND ---------- + +# Direct LLM-based primary key detection +result = profiler.detect_primary_keys_with_llm( + table="customers", + llm=True, # Explicit LLM enablement required + options={ + "llm_pk_detection_endpoint": "databricks-meta-llama-3-1-8b-instruct" + } +) + +if result and result.get("success", False): + print(f"✅ Detected PK: {result['primary_key_columns']}") + print(f"Confidence: {result['confidence']}") + print(f"Reasoning: {result['reasoning']}") +else: + print("❌ Primary key detection failed or returned no results") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Rules Generation with is_unique +# MAGIC +# MAGIC Once primary keys are detected via LLM, DQX can automatically generate `is_unique` data quality rules to validate the uniqueness of those columns. 
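+# MAGIC
+# MAGIC For orientation, `DQGenerator.dq_generate_is_unique` returns a plain check definition. The sketch below illustrates its shape with placeholder column names and only a subset of the `user_metadata` keys shown:
+# MAGIC
+# MAGIC ```python
+# MAGIC {
+# MAGIC     "check": {
+# MAGIC         "function": "is_unique",
+# MAGIC         "arguments": {"columns": ["customer_id", "order_id"], "nulls_distinct": True},
+# MAGIC     },
+# MAGIC     "name": "primary_key_customer_id_order_id_validation",
+# MAGIC     "criticality": "error",
+# MAGIC     "user_metadata": {"pk_detection_confidence": "high", "llm_based_detection": True},
+# MAGIC }
+# MAGIC ```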
+ +# COMMAND ---------- + +from databricks.labs.dqx.profiler.generator import DQGenerator + +# Example: Generate is_unique rule from LLM-detected primary key +detected_pk_columns = ["customer_id", "order_id"] # Example detected PK +confidence = "high" +reasoning = "LLM analysis indicates these columns form a composite primary key based on schema patterns" + +# Generate is_unique rule using the detected primary key +is_unique_rule = DQGenerator.dq_generate_is_unique( + column=",".join(detected_pk_columns), + level="error", + columns=detected_pk_columns, + confidence=confidence, + reasoning=reasoning, + llm_detected=True, + nulls_distinct=True # Default behavior: NULLs are treated as distinct +) + +print("Generated is_unique rule:") +print(f"Rule name: {is_unique_rule['name']}") +print(f"Function: {is_unique_rule['check']['function']}") +print(f"Columns: {is_unique_rule['check']['arguments']['columns']}") +print(f"Criticality: {is_unique_rule['criticality']}") +print(f"LLM detected: {is_unique_rule['user_metadata']['llm_based_detection']}") +print(f"Confidence: {is_unique_rule['user_metadata']['pk_detection_confidence']}") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Integrated Workflow: LLM Detection + Rule Generation +# MAGIC +# MAGIC Here's how to combine LLM-based primary key detection with automatic rule generation: + +# COMMAND ---------- + +# Create sample data for demonstration +sample_data = [ + (1, "A001", "John", "Doe"), + (2, "A002", "Jane", "Smith"), + (3, "A001", "Bob", "Johnson"), # Duplicate customer_id - should fail uniqueness + (4, "A003", "Alice", "Brown") +] + +sample_df = spark.createDataFrame( + sample_data, + ["id", "customer_id", "first_name", "last_name"] +) + +# Display sample data +sample_df.show() + +# COMMAND ---------- + +# Simulate LLM detection result (in practice, this would come from the LLM) +llm_detection_result = { + "success": True, + "primary_key_columns": ["id"], # LLM detected 'id' as primary key + "confidence": "high", + "reasoning": "Column 'id' appears to be an auto-incrementing identifier based on naming patterns and data distribution" +} + +if llm_detection_result["success"]: + # Generate is_unique rule from LLM detection + pk_columns = llm_detection_result["primary_key_columns"] + + generated_rule = DQGenerator.dq_generate_is_unique( + column=",".join(pk_columns), + level="error", + columns=pk_columns, + confidence=llm_detection_result["confidence"], + reasoning=llm_detection_result["reasoning"], + llm_detected=True + ) + + print("✅ Generated is_unique rule from LLM detection:") + print(f" Rule: {generated_rule['name']}") + print(f" Columns: {generated_rule['check']['arguments']['columns']}") + print(f" Metadata: LLM-based detection with {generated_rule['user_metadata']['pk_detection_confidence']} confidence") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Applying the Generated is_unique Rule +# MAGIC +# MAGIC Now let's apply the generated rule to validate data quality: + +# COMMAND ---------- + +from databricks.labs.dqx.engine import DQEngine +from databricks.labs.dqx.rule import DQDatasetRule +from databricks.labs.dqx import check_funcs + +# Convert the generated rule to a DQDatasetRule for execution +pk_columns = generated_rule['check']['arguments']['columns'] +dq_rule = DQDatasetRule( + name=generated_rule['name'], + criticality=generated_rule['criticality'], + check_func=check_funcs.is_unique, + columns=pk_columns, + check_func_kwargs={ + "nulls_distinct": generated_rule['check']['arguments']['nulls_distinct'] + } +) + +# Apply 
the rule using DQEngine +dq_engine = DQEngine(workspace_client=ws) +result_df = dq_engine.apply_checks(sample_df, [dq_rule]) + +print("✅ Applied is_unique rule to sample data") +print("Result columns:", result_df.columns) + +# Show results - the rule should pass since 'id' column has unique values +result_df.select("id", "customer_id", f"dq_check_{generated_rule['name']}").show() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Composite Primary Key Example +# MAGIC +# MAGIC Let's demonstrate with a composite primary key detected by LLM: + +# COMMAND ---------- + +# Sample data with composite key scenario +composite_data = [ + ("store_1", "2024-01-01", 100.0), + ("store_1", "2024-01-02", 150.0), + ("store_2", "2024-01-01", 200.0), + ("store_2", "2024-01-02", 175.0), + ("store_1", "2024-01-01", 120.0), # Duplicate composite key - should fail +] + +composite_df = spark.createDataFrame( + composite_data, + ["store_id", "date", "sales_amount"] +) + +# Simulate LLM detecting composite primary key +composite_llm_result = { + "success": True, + "primary_key_columns": ["store_id", "date"], + "confidence": "medium", + "reasoning": "Combination of store_id and date appears to uniquely identify sales records based on business logic patterns" +} + +# Generate composite is_unique rule +composite_rule = DQGenerator.dq_generate_is_unique( + column=",".join(composite_llm_result["primary_key_columns"]), + level="warn", # Use warning level for this example + columns=composite_llm_result["primary_key_columns"], + confidence=composite_llm_result["confidence"], + reasoning=composite_llm_result["reasoning"], + llm_detected=True +) + +print("Generated composite is_unique rule:") +print(f"Rule name: {composite_rule['name']}") +print(f"Columns: {composite_rule['check']['arguments']['columns']}") +print(f"Criticality: {composite_rule['criticality']}") + +# COMMAND ---------- + +# Apply composite key validation +composite_dq_rule = DQDatasetRule( + name=composite_rule['name'], + criticality=composite_rule['criticality'], + check_func=check_funcs.is_unique, + columns=composite_rule['check']['arguments']['columns'], + check_func_kwargs={ + "nulls_distinct": composite_rule['check']['arguments']['nulls_distinct'] + } +) + +composite_result_df = dq_engine.apply_checks(composite_df, [composite_dq_rule]) + +print("✅ Applied composite is_unique rule") +print("Data with duplicates detected:") +composite_result_df.show() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Key Features +# MAGIC +# MAGIC - **🔧 Completely Optional**: Not activated by default - requires explicit enablement +# MAGIC - **🤖 Intelligent Detection**: Uses LLM analysis of table schema and metadata +# MAGIC - **✨ Multiple Activation Methods**: Various ways to enable when needed +# MAGIC - **🛡️ Graceful Fallback**: Clear messaging when dependencies unavailable +# MAGIC - **📊 Confidence Scoring**: Provides confidence levels and reasoning +# MAGIC - **🔄 Validation**: Optionally validates detected PKs for duplicates +# MAGIC - **⚡ Automatic Rule Generation**: Converts detected PKs into executable `is_unique` rules +# MAGIC - **🔗 End-to-End Workflow**: From LLM detection to data quality validation diff --git a/docs/dqx/docs/demos.mdx b/docs/dqx/docs/demos.mdx index 5434fe266..0adeb2693 100644 --- a/docs/dqx/docs/demos.mdx +++ b/docs/dqx/docs/demos.mdx @@ -23,6 +23,7 @@ Import the following notebooks in the Databricks workspace to try DQX out: ## Use Cases * [DQX for PII Detection 
Notebook](https://github.com/databrickslabs/dqx/blob/v0.9.3/demos/dqx_demo_pii_detection.py) - demonstrates how to use DQX to check data for Personally Identifiable Information (PII). * [DQX for Manufacturing Notebook](https://github.com/databrickslabs/dqx/blob/v0.9.3/demos/dqx_manufacturing_demo.py) - demonstrates how to use DQX to check data quality for Manufacturing Industry datasets. +* [DQX LLM-based Primary Key Detection Notebook](https://github.com/databrickslabs/dqx/blob/v0.9.3/demos/dqx_llm_demo.py) - demonstrates how to use DQX's LLM-based primary key detection capabilities and rule generation with `is_unique` checks.
diff --git a/docs/dqx/docs/guide/data_profiling.mdx b/docs/dqx/docs/guide/data_profiling.mdx index a410ecbcc..6dd08a59a 100644 --- a/docs/dqx/docs/guide/data_profiling.mdx +++ b/docs/dqx/docs/guide/data_profiling.mdx @@ -323,9 +323,182 @@ Example of the configuration file (relevant fields only): filter: "maintenance_type = 'preventive'" ``` +## LLM-Assisted Primary Key Detection + + +DQX provides **optional** LLM-based features that enhance data profiling capabilities. These features are only active when explicitly requested and require additional dependencies. + + +The LLM-based Primary Key Detection uses Large Language Models (via DSPy and Databricks Model Serving) to intelligently identify primary keys from table schema and metadata. This enhances the DQX profiling process by automatically detecting primary keys and generating appropriate uniqueness validation rules. + +### How LLM Primary Key Detection Works + +1. **Schema Analysis**: The detection process examines table structure, column names, data types, and constraints. +2. **LLM Processing**: Uses advanced language models to understand naming patterns and relationships. +3. **Confidence Scoring**: Provides confidence levels (high/medium/low) for detected primary keys. +4. **Duplicate Validation**: Optionally validates that detected columns actually contain unique values. +5. **Rule Generation**: Creates appropriate uniqueness quality checks for validating primary keys. + +### When to Use LLM Primary Key Detection + +- **Data Discovery**: When exploring new datasets without documented primary keys. +- **Data Migration**: When migrating data between systems with different constraint definitions. +- **Data Quality Assessment**: To validate existing primary key assumptions. +- **Automated Profiling**: For large-scale data profiling across multiple tables. +- **Dataset Comparison**: To improve accuracy of dataset comparison operations. 
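+
+Duplicate validation (step 4 in *How LLM Primary Key Detection Works* above) is conceptually a grouped count over the candidate key columns. The following is a minimal sketch of that check, assuming an active Spark session and placeholder table and column names:
+
+```python
+candidate_pk = ["customer_id", "order_id"]  # columns suggested by the LLM
+cols = ", ".join(f"`{c}`" for c in candidate_pk)
+
+# Any returned rows mean the candidate key is not unique, which triggers a retry with extra context
+duplicates = spark.sql(f"""
+    SELECT {cols}, COUNT(*) AS duplicate_count
+    FROM catalog.schema.table
+    GROUP BY {cols}
+    HAVING COUNT(*) > 1
+""")
+print(f"Candidate key is unique: {duplicates.count() == 0}")
+```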
+ +### LLM Installation Requirements + +To use LLM-based features, install DQX with LLM dependencies: + +```bash +# Install DQX with LLM dependencies using extras +pip install databricks-labs-dqx[llm] +``` + +### LLM Usage Examples + +#### Configuration-Based (Profiler Workflows) + + + + ```yaml + # config.yml - Configuration for profiler workflow + run_configs: + - name: "default" + input_config: + location: "catalog.schema.table" + profiler_config: + # Enable LLM-based primary key detection + llm_config: + enable_pk_detection: true + pk_detection_endpoint: "databricks-meta-llama-3-1-8b-instruct" + ``` + + + +#### Programmatic Integration + + + + ```python + from databricks.labs.dqx.profiler.profiler import DQProfiler + from databricks.sdk import WorkspaceClient + + ws = WorkspaceClient() + profiler = DQProfiler(ws) + + # Enable via options parameter + summary_stats, dq_rules = profiler.profile_table( + "catalog.schema.table", + options={ + "llm": True, + "llm_pk_detection_endpoint": "databricks-meta-llama-3-1-8b-instruct" + } + ) + + # Or use the explicit flag + summary_stats, dq_rules = profiler.profile_table( + "catalog.schema.table", + options={"enable_llm_pk_detection": True} + ) + + # Direct LLM-based primary key detection + result, detector = profiler._run_llm_pk_detection( + table="main.sales.customers", + options={ + "llm_pk_detection_endpoint": "databricks-meta-llama-3-1-8b-instruct" + } + ) + + if result and result.get("success", False): + print(f"✅ Detected PK: {result['primary_key_columns']}") + print(f"Confidence: {result['confidence']}") + print(f"Reasoning: {result['reasoning']}") + else: + print("❌ Primary key detection failed or returned no results") + ``` + + + +### LLM Output & Metadata + +#### Summary Statistics with LLM Results + +When LLM primary key detection is enabled, additional metadata is added to the summary statistics: + +```python +summary_stats["llm_primary_key_detection"] = { + "detected_columns": ["customer_id", "order_id"], # Detected PK columns + "confidence": "high", # LLM confidence level + "has_duplicates": False, # Duplicate validation result + "validation_performed": True, # Whether validation was run + "method": "llm_based" # Detection method +} +``` + +#### Generated Quality Rules + +LLM-detected primary keys automatically generate uniqueness validation rules: + +```python +{ + "check": { + "function": "is_unique", + "arguments": { + "columns": ["customer_id", "order_id"], + "nulls_distinct": True + } + }, + "name": "primary_key_customer_id_order_id_validation", + "criticality": "error", + "user_metadata": { + "pk_detection_confidence": "high", + "pk_detection_reasoning": "LLM analysis of schema and metadata", + "detected_primary_key": True, + "llm_based_detection": True, + "detection_method": "llm_analysis", + "requires_llm_dependencies": True + } +} +``` + +### LLM Troubleshooting + +#### Common Issues + +1. **ImportError: No module named 'dspy'** + ```bash + pip install dspy-ai databricks_langchain + ``` + +2. **LLM Detection Not Running** + - Ensure `llm=True` or `enable_llm_pk_detection=True` + - Check that LLM dependencies are installed + +3. **Low Confidence Results** + - Review table schema and metadata quality + - Consider using different LLM endpoints + - Validate results manually + +4. 
**Performance Issues** + - Use sampling for large tables + - Adjust retry limits + - Consider caching results + +#### Debug Mode + +```python +import logging +logging.basicConfig(level=logging.DEBUG) + +# Enable detailed logging +profiler.profile_table(table, options={"llm": True}) +``` + ## Profiling options -The profiler supports extensive configuration options to customize the profiling behavior. +The profiler supports extensive configuration options to customize the profiling behavior, including LLM-based features. ### Profiling options for a single table @@ -362,9 +535,16 @@ The profiler supports extensive configuration options to customize the profiling # Value rounding "round": True, # Round min/max values for cleaner rules - + # Filter Options "filter": None, # No filter (SQL expression) + + # LLM-based Primary Key Detection options + "enable_llm_pk_detection": True, # Enable LLM-based primary key detection + "llm_pk_detection": True, # Another alternative to enable LLM PK detection + "llm_pk_detection_endpoint": "databricks-meta-llama-3-1-8b-instruct", # LLM endpoint + "llm_pk_validate_duplicates": True, # Validate detected primary keys for duplicates + "llm_pk_max_retries": 3, # Maximum retries for LLM prediction } ws = WorkspaceClient() @@ -388,7 +568,10 @@ The profiler supports extensive configuration options to customize the profiling - `sample_fraction`: fraction of data to sample for profiling. - `sample_seed`: seed for reproducible sampling. - `limit`: maximum number of records to analyze. - - `filter`: string SQL expression to filter the input data + - `filter`: string SQL expression to filter the input data. + - `llm_config`: LLM-based feature configuration containing: + - `enable_pk_detection`: enable LLM-based primary key detection (default: false) + - `pk_detection_endpoint`: LLM endpoint for primary key detection (default: "databricks-meta-llama-3-1-8b-instruct") You can open the configuration file to check available run configs and adjust the settings if needed: ```commandline @@ -403,6 +586,9 @@ The profiler supports extensive configuration options to customize the profiling sample_fraction: 0.3 # Fraction of data to sample (30%) limit: 1000 # Maximum number of records to analyze sample_seed: 42 # Seed for reproducible results + llm_config: # LLM-based features (optional) + enable_pk_detection: true + pk_detection_endpoint: "databricks-meta-llama-3-1-8b-instruct" ... ``` More detailed options such as 'num_sigmas' are not configurable when using profiler workflow. They are only available for programmatic integration. diff --git a/pyproject.toml b/pyproject.toml index b24a127f7..70c7aa33b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,10 @@ pii = [ # This may be required for the larger models due to Databricks connect memory limitations. # The models cannot be delcared as dependency here buecase PyPI does not support URL-based dependencies which would prevent releases. 
] +llm = [ # LLM assisted features + "dspy-ai~=3.0.3", + "databricks-langchain~=0.8.0", +] [project.entry-points.databricks] runtime = "databricks.labs.dqx.workflows_runner:main" @@ -77,7 +81,7 @@ include = ["src"] path = "src/databricks/labs/dqx/__about__.py" [tool.hatch.envs.default] -features = ["pii"] +features = ["pii", "llm"] dependencies = [ "black~=24.8.0", "chispa~=0.10.1", @@ -102,6 +106,8 @@ dependencies = [ "dbldatagen~=0.4.0", "pyparsing~=3.2.3", "jmespath~=1.0.1", + "dspy-ai~=3.0.3", + "databricks-langchain~=0.8.0", "psycopg2-binary~=2.9.10", ] @@ -168,10 +174,21 @@ exclude = ['venv', '.venv', 'demos/*', 'tests/e2e/*'] module = ["google", "google.*"] # Google libraries are not type annotated ignore_missing_imports = true +[[tool.mypy.overrides]] +module = ["dspy", "databricks_langchain"] # Optional LLM dependencies +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = ["databricks.labs.dqx.profiler.runner"] # Internal module without py.typed +ignore_missing_imports = true + [tool.pytest.ini_options] addopts = "--no-header" cache_dir = ".venv/pytest-cache" -filterwarnings = ["ignore::DeprecationWarning"] +filterwarnings = [ + "ignore::DeprecationWarning", + "ignore:PII detection uses pandas user-defined functions.*:UserWarning" +] [tool.black] target-version = ["py310"] diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py index 962d550a8..9ce306595 100644 --- a/src/databricks/labs/dqx/check_funcs.py +++ b/src/databricks/labs/dqx/check_funcs.py @@ -1459,7 +1459,6 @@ def compare_datasets( 100 vs 101 → equal (diff = 1, tolerance = 1) 2.001 vs 2.0099 → equal - Returns: Tuple[Column, Callable]: - A Spark Column representing the condition for comparison violations. diff --git a/src/databricks/labs/dqx/config.py b/src/databricks/labs/dqx/config.py index d1df0b93f..808f58927 100644 --- a/src/databricks/labs/dqx/config.py +++ b/src/databricks/labs/dqx/config.py @@ -10,6 +10,7 @@ "InputConfig", "OutputConfig", "ExtraParams", + "LLMConfig", "ProfilerConfig", "BaseChecksStorageConfig", "FileChecksStorageConfig", @@ -43,6 +44,16 @@ class OutputConfig: trigger: dict[str, bool | str] = field(default_factory=dict) +@dataclass +class LLMConfig: + """Configuration class for LLM-assisted features.""" + + # Primary Key Detection Configuration + # Note: LLM-based PK detection requires: pip install databricks-labs-dqx[llm] + enable_pk_detection: bool = False + pk_detection_endpoint: str = "databricks-meta-llama-3-1-8b-instruct" + + @dataclass class ProfilerConfig: """Configuration class for profiler.""" @@ -53,6 +64,9 @@ class ProfilerConfig: limit: int = 1000 # limit the number of records to profile filter: str | None = None # filter to apply to the data before profiling + # LLM-assisted features configuration + llm_config: LLMConfig = field(default_factory=LLMConfig) + @dataclass class RunConfig: diff --git a/src/databricks/labs/dqx/installer/install.py b/src/databricks/labs/dqx/installer/install.py index b525aa0aa..c1fd81961 100644 --- a/src/databricks/labs/dqx/installer/install.py +++ b/src/databricks/labs/dqx/installer/install.py @@ -306,7 +306,8 @@ def config(self): """ Returns the configuration of the workspace installation. - :return: The WorkspaceConfig instance. + Returns: + The WorkspaceConfig instance. """ return self._config @@ -315,7 +316,8 @@ def install_folder(self): """ Returns the installation install_folder path. - :return: The installation install_folder path as a string. 
+ Returns: + The installation install_folder path as a string. """ return self._installation.install_folder() @@ -323,7 +325,8 @@ def run(self) -> bool: """ Runs the workflow installation. - :return: True if the installation finished successfully, False otherwise. + Returns: + True if the installation finished successfully, False otherwise. """ logger.info(f"Installing DQX v{self._product_info.version()}") install_tasks = [self._workflow_installer.create_jobs, self._create_dashboard] diff --git a/src/databricks/labs/dqx/llm/__init__.py b/src/databricks/labs/dqx/llm/__init__.py index e69de29bb..9834285a4 100644 --- a/src/databricks/labs/dqx/llm/__init__.py +++ b/src/databricks/labs/dqx/llm/__init__.py @@ -0,0 +1,14 @@ +from importlib.util import find_spec + +# Check only core libraries at import-time. Model packages are loaded when LLM detection is invoked. +required_specs = [ + "dspy", + "databricks_langchain", +] + +# Check that LLM detection modules are installed +if not all(find_spec(spec) for spec in required_specs): + raise ImportError( + "LLM detection extras not installed; Install additional " + "dependencies by running `pip install databricks-labs-dqx[llm]`" + ) diff --git a/src/databricks/labs/dqx/llm/pk_identifier.py b/src/databricks/labs/dqx/llm/pk_identifier.py new file mode 100644 index 000000000..a9c8a4461 --- /dev/null +++ b/src/databricks/labs/dqx/llm/pk_identifier.py @@ -0,0 +1,702 @@ +"""Primary Key Detection using DSPy with Databricks Model Serving.""" + +import logging +from typing import Any + +import dspy +from databricks_langchain import ChatDatabricks +from langchain_core.messages import HumanMessage +from pyspark.sql import SparkSession + +from databricks.labs.dqx.utils import generate_table_definition_from_dataframe + +logger = logging.getLogger(__name__) + + +class DatabricksLM(dspy.LM): + """Custom DSPy LM adapter for Databricks Model Serving.""" + + def __init__(self, endpoint: str, temperature: float = 0.1, max_tokens: int = 1000, chat_databricks_cls=None): + self.endpoint = endpoint + self.temperature = temperature + self.max_tokens = max_tokens + + # Allow injection of ChatDatabricks class for testing + chat_cls = chat_databricks_cls or ChatDatabricks + self.llm = chat_cls(endpoint=endpoint, temperature=temperature) + super().__init__(model=endpoint) + + def __call__(self, prompt=None, messages=None, **kwargs): + """Call the Databricks model serving endpoint.""" + try: + if messages: + response = self.llm.invoke(messages) + else: + response = self.llm.invoke([HumanMessage(content=prompt)]) + + return [response.content] + + except (ConnectionError, TimeoutError, ValueError) as e: + print(f"Error calling Databricks model: {e}") + return [f"Error: {str(e)}"] + except (AttributeError, TypeError, RuntimeError) as e: + print(f"Unexpected error calling Databricks model: {e}") + return [f"Unexpected error: {str(e)}"] + + +class SparkManager: + """Manages Spark SQL operations for table schema and duplicate checking.""" + + def __init__(self, spark_session=None): + """Initialize with Spark session.""" + self.spark = spark_session + if not self.spark: + try: + self.spark = SparkSession.getActiveSession() + if not self.spark: + self.spark = SparkSession.builder.appName("PKDetection").getOrCreate() + except ImportError: + logger.warning("PySpark not available. 
Some features may not work.") + except RuntimeError as e: + # Handle case where only remote Spark sessions are supported (e.g., in unit tests) + if "Only remote Spark sessions using Databricks Connect are supported" in str(e): + logger.warning( + "Local Spark session not available. Using None - features requiring Spark will not work." + ) + self.spark = None + else: + raise + + def _get_table_columns(self, table: str) -> list[str]: + """Get table column definitions from DESCRIBE TABLE.""" + describe_query = f"DESCRIBE TABLE EXTENDED {table}" + describe_result = self.spark.sql(describe_query) + describe_df = describe_result.toPandas() + + definition_lines = [] + in_column_section = True + + for _, row in describe_df.iterrows(): + col_name = row['col_name'] + data_type = row['data_type'] + comment = row['comment'] if 'comment' in row else '' + + if col_name.startswith('#') or col_name.strip() == '': + in_column_section = False + continue + + if in_column_section and not col_name.startswith('#'): + nullable = "" if "not null" in str(comment).lower() else "" + definition_lines.append(f" {col_name} {data_type}{nullable}") + + return definition_lines + + def _get_existing_primary_key(self, table: str) -> str | None: + """Get existing primary key from table properties.""" + try: + pk_query = f"SHOW TBLPROPERTIES {table}" + pk_result = self.spark.sql(pk_query) + pk_df = pk_result.toPandas() + + for _, row in pk_df.iterrows(): + if 'primary' in str(row.get('key', '')).lower(): + return row.get('value', '') + except (ValueError, RuntimeError, KeyError): + # Silently continue if table properties are not accessible + pass + return None + + @staticmethod + def _build_table_definition_string(definition_lines: list[str], existing_pk: str | None) -> str: + """Build the final table definition string.""" + table_definition = "{\n" + ",\n".join(definition_lines) + "\n}" + if existing_pk: + table_definition += f"\n-- Existing Primary Key: {existing_pk}" + return table_definition + + def get_table_definition(self, table: str) -> str: + """Retrieve table definition using Spark SQL DESCRIBE commands.""" + if not self.spark: + raise ValueError("Spark session not available") + + try: + print(f"🔍 Retrieving schema for table: {table}") + + definition_lines = self._get_table_columns(table) + existing_pk = self._get_existing_primary_key(table) + table_definition = self._build_table_definition_string(definition_lines, existing_pk) + + print("✅ Table definition retrieved successfully") + return table_definition + + except (ValueError, RuntimeError) as e: + logger.error(f"Error retrieving table definition for {table}: {e}") + raise + except (AttributeError, TypeError, KeyError) as e: + logger.error(f"Unexpected error retrieving table definition for {table}: {e}") + raise RuntimeError(f"Failed to retrieve table definition: {e}") from e + + def _execute_duplicate_check_query(self, table: str, pk_columns: list[str]) -> tuple[bool, int, Any]: + """Execute the duplicate check query and return results.""" + pk_cols_str = ", ".join([f"`{col}`" for col in pk_columns]) + print(f"🔍 Checking for duplicates in {table} using columns: {pk_cols_str}") + + duplicate_query = f""" + SELECT {pk_cols_str}, COUNT(*) as duplicate_count + FROM {table} + GROUP BY {pk_cols_str} + HAVING COUNT(*) > 1 + """ + + duplicate_result = self.spark.sql(duplicate_query) + duplicates_df = duplicate_result.toPandas() + return len(duplicates_df) > 0, len(duplicates_df), duplicates_df + + @staticmethod + def _report_duplicate_results( + has_duplicates: bool, 
duplicate_count: int, pk_columns: list[str], duplicates_df=None + ): + """Report the results of duplicate checking.""" + if has_duplicates and duplicates_df is not None: + total_duplicate_records = duplicates_df['duplicate_count'].sum() + logger.warning( + f"Found {duplicate_count} duplicate key combinations affecting {total_duplicate_records} total records" + ) + print(f"⚠️ Found {duplicate_count} duplicate combinations for: {', '.join(pk_columns)}") + if len(duplicates_df) > 0: + print("Sample duplicates:") + print(duplicates_df.head().to_string(index=False)) + else: + logger.info(f"No duplicates found for predicted primary key: {', '.join(pk_columns)}") + print(f"✅ No duplicates found for: {', '.join(pk_columns)}") + + def check_duplicates( + self, + table: str, + pk_columns: list[str], + ) -> tuple[bool, int]: + """Check for duplicates using Spark SQL GROUP BY and HAVING.""" + if not self.spark: + raise ValueError("Spark session not available") + + try: + has_duplicates, duplicate_count, duplicates_df = self._execute_duplicate_check_query(table, pk_columns) + self._report_duplicate_results(has_duplicates, duplicate_count, pk_columns, duplicates_df) + return has_duplicates, duplicate_count + + except (ValueError, RuntimeError) as e: + logger.error(f"Error checking duplicates: {e}") + print(f"❌ Error checking duplicates: {e}") + return False, 0 + except (AttributeError, TypeError, KeyError) as e: + logger.error(f"Unexpected error checking duplicates: {e}") + print(f"❌ Unexpected error checking duplicates: {e}") + return False, 0 + + @staticmethod + def _extract_useful_properties(stats_df) -> list[str]: + """Extract useful properties from table properties DataFrame.""" + metadata_info = [] + for _, row in stats_df.iterrows(): + key = row.get('key', '') + value = row.get('value', '') + if any(keyword in key.lower() for keyword in ('numrows', 'rawdatasize', 'totalsize', 'primary', 'unique')): + metadata_info.append(f"{key}: {value}") + return metadata_info + + def _get_table_properties(self, table: str) -> list[str]: + """Get table properties metadata.""" + try: + stats_query = f"SHOW TBLPROPERTIES {table}" + stats_result = self.spark.sql(stats_query) + stats_df = stats_result.toPandas() + return self._extract_useful_properties(stats_df) + except (ValueError, RuntimeError, KeyError): + # Silently continue if table properties are not accessible + return [] + + @staticmethod + def _categorize_columns_by_type(col_df) -> tuple[list[str], list[str], list[str], list[str]]: + """Categorize columns by their data types.""" + numeric_cols = [] + string_cols = [] + date_cols = [] + timestamp_cols = [] + + for _, row in col_df.iterrows(): + col_name = row.get('col_name', '') + data_type = str(row.get('data_type', '')).lower() + + if col_name.startswith('#') or col_name.strip() == '': + break + + if any(t in data_type for t in ('int', 'long', 'bigint', 'decimal', 'double', 'float')): + numeric_cols.append(col_name) + elif any(t in data_type for t in ('string', 'varchar', 'char')): + string_cols.append(col_name) + elif 'date' in data_type: + date_cols.append(col_name) + elif 'timestamp' in data_type: + timestamp_cols.append(col_name) + + return numeric_cols, string_cols, date_cols, timestamp_cols + + @staticmethod + def _format_column_distribution( + numeric_cols: list[str], string_cols: list[str], date_cols: list[str], timestamp_cols: list[str] + ) -> list[str]: + """Format column type distribution information.""" + metadata_info = [ + "Column type distribution:", + f" Numeric columns 
({len(numeric_cols)}): {', '.join(numeric_cols[:5])}", + f" String columns ({len(string_cols)}): {', '.join(string_cols[:5])}", + f" Date columns ({len(date_cols)}): {', '.join(date_cols)}", + f" Timestamp columns ({len(timestamp_cols)}): {', '.join(timestamp_cols)}", + ] + return metadata_info + + def _get_column_statistics(self, table: str) -> list[str]: + """Get column statistics and type distribution.""" + try: + col_stats_query = f"DESCRIBE TABLE EXTENDED {table}" + col_result = self.spark.sql(col_stats_query) + col_df = col_result.toPandas() + + numeric_cols, string_cols, date_cols, timestamp_cols = self._categorize_columns_by_type(col_df) + return self._format_column_distribution(numeric_cols, string_cols, date_cols, timestamp_cols) + except (ValueError, RuntimeError, KeyError): + # Silently continue if table properties are not accessible + return [] + + def get_table_metadata_info(self, table: str) -> str: + """Get additional metadata information to help with primary key detection.""" + if not self.spark: + return "No metadata available (Spark session not found)" + + try: + metadata_info = [] + + # Get table properties + metadata_info.extend(self._get_table_properties(table)) + + # Get column statistics + metadata_info.extend(self._get_column_statistics(table)) + + return ( + "Metadata information:\n" + "\n".join(metadata_info) if metadata_info else "Limited metadata available" + ) + + except (ValueError, RuntimeError) as e: + return f"Could not retrieve metadata: {e}" + except (AttributeError, TypeError, KeyError) as e: + logger.warning(f"Unexpected error retrieving metadata: {e}") + return f"Could not retrieve metadata due to unexpected error: {e}" + + +def configure_databricks_llm(endpoint: str = "", temperature: float = 0.1, chat_databricks_cls=None): + """Configure DSPy to use Databricks model serving.""" + language_model = DatabricksLM(endpoint=endpoint, temperature=temperature, chat_databricks_cls=chat_databricks_cls) + dspy.configure(lm=language_model) + return language_model + + +def configure_with_tracing(): + """Enable DSPy tracing to see live reasoning.""" + dspy.settings.configure(trace=[]) + return True + + +class PrimaryKeyDetection(dspy.Signature): + """Analyze table schema and metadata step-by-step to identify the most likely primary key columns.""" + + table: str = dspy.InputField(desc="Name of the database table") + table_definition: str = dspy.InputField(desc="Complete table schema definition") + context: str = dspy.InputField(desc="Context about similar tables or patterns") + previous_attempts: str = dspy.InputField(desc="Previous failed attempts and why they failed") + metadata_info: str = dspy.InputField(desc="Table metadata and column statistics to aid analysis") + + primary_key_columns: str = dspy.OutputField(desc="Comma-separated list of primary key columns") + confidence: str = dspy.OutputField(desc="Confidence level: high, medium, or low") + reasoning: str = dspy.OutputField(desc="Step-by-step reasoning for the selection based on metadata analysis") + + +class DatabricksPrimaryKeyDetector: + """Primary Key Detector optimized for Databricks Spark environment.""" + + def __init__( + self, + table: str, + *, + endpoint: str = "databricks-meta-llama-3-1-8b-instruct", + context: str = "", + validate_duplicates: bool = True, + fail_on_duplicates: bool = True, + spark_session=None, + show_live_reasoning: bool = True, + max_retries: int = 3, + chat_databricks_cls=None, + ): + self.table = table + self.context = context + self.llm_provider = "databricks" # Fixed 
to databricks provider + self.endpoint = endpoint + self.validate_duplicates = validate_duplicates + self.fail_on_duplicates = fail_on_duplicates + self.spark = spark_session + self.detector = dspy.ChainOfThought(PrimaryKeyDetection) + self.spark_manager = SparkManager(spark_session) + self.show_live_reasoning = show_live_reasoning + self.max_retries = max_retries + + configure_databricks_llm(endpoint=endpoint, temperature=0.1, chat_databricks_cls=chat_databricks_cls) + configure_with_tracing() + + def detect_primary_keys(self) -> dict[str, Any]: + """ + Detect primary keys for tables and views. + """ + logger.info(f"Starting primary key detection for table: {self.table}") + return self._detect_primary_keys_from_table() + + def _detect_primary_keys_from_table(self) -> dict[str, Any]: + """Detect primary keys from a registered table or view.""" + try: + table_definition = self.spark_manager.get_table_definition(self.table) + metadata_info = self.spark_manager.get_table_metadata_info(self.table) + except (ValueError, RuntimeError, OSError) as e: + return { + 'table': self.table, + 'success': False, + 'error': f"Failed to retrieve table metadata: {str(e)}", + 'retries_attempted': 0, + } + except (AttributeError, TypeError, KeyError) as e: + logger.error(f"Unexpected error during table metadata retrieval: {e}") + return { + 'table': self.table, + 'success': False, + 'error': f"Unexpected error retrieving table metadata: {str(e)}", + 'retries_attempted': 0, + } + + return self._predict_with_retry_logic( + self.table, + table_definition, + self.context, + metadata_info, + self.validate_duplicates, + ) + + def _generate_table_definition_from_dataframe(self, df: Any) -> str: + """Generate a CREATE TABLE statement from a DataFrame schema.""" + return generate_table_definition_from_dataframe(df, self.table) + + def _check_duplicates_and_update_result(self, table: str, pk_columns: list[str], result: dict) -> tuple[bool, int]: + """Check for duplicates and update result with validation info.""" + has_duplicates, duplicate_count = self.spark_manager.check_duplicates(table, pk_columns) + + result['has_duplicates'] = has_duplicates + result['duplicate_count'] = duplicate_count + result['validation_performed'] = True + + return has_duplicates, duplicate_count + + @staticmethod + def _handle_successful_validation( + result: dict, attempt: int, all_attempts: list, previous_attempts: str + ) -> tuple[dict, str, bool]: + """Handle successful validation (no duplicates found).""" + logger.info("No duplicates found - Primary key prediction validated!") + result['retries_attempted'] = attempt + result['all_attempts'] = all_attempts + result['final_status'] = 'success' + + return result, previous_attempts, True # Success, stop retrying + + def _handle_duplicates_found( + self, + pk_columns: list[str], + duplicate_count: int, + attempt: int, + result: dict, + all_attempts: list, + previous_attempts: str, + ) -> tuple[dict, str, bool]: + """Handle case where duplicates are found.""" + logger.info(f"Found {duplicate_count} duplicate groups - Retrying with enhanced context") + + if attempt < self.max_retries: + failed_pk = ", ".join(pk_columns) + previous_attempts += ( + f"\nAttempt {attempt + 1}: Tried [{failed_pk}] but found {duplicate_count} " + f"duplicate key combinations. " + ) + previous_attempts += ( + "This indicates the combination is not unique enough. Need to find additional columns " + "or a different combination that ensures complete uniqueness. 
" + ) + previous_attempts += ( + "Consider adding timestamp fields, sequence numbers, or other differentiating columns " + "that would make each row unique." + ) + return result, previous_attempts, False # Continue retrying + + logger.info(f"Maximum retries ({self.max_retries}) reached. Returning best attempt with duplicates noted.") + + # Check if we should fail when duplicates are found + if hasattr(self, 'fail_on_duplicates') and self.fail_on_duplicates: + result['success'] = False # Mark as failed since duplicates were found + result['error'] = ( + f"Primary key validation failed: Found {duplicate_count} duplicate combinations " + f"in suggested columns {pk_columns}" + ) + else: + # Return best attempt with warning but don't fail + result['success'] = True + result['warning'] = ( + f"Primary key has duplicates: Found {duplicate_count} duplicate combinations " + f"in suggested columns {pk_columns}" + ) + + result['retries_attempted'] = attempt + result['all_attempts'] = all_attempts + result['final_status'] = 'max_retries_reached_with_duplicates' + return result, previous_attempts, True # Stop retrying, max attempts reached + + @staticmethod + def _handle_validation_error( + error: Exception, result: dict, attempt: int, all_attempts: list, previous_attempts: str + ) -> tuple[dict, str, bool]: + """Handle validation errors.""" + logger.error(f"Error during duplicate validation: {error}") + result['validation_error'] = str(error) + result['retries_attempted'] = attempt + result['all_attempts'] = all_attempts + result['final_status'] = 'validation_error' + return result, previous_attempts, True # Stop retrying due to error + + def _validate_pk_duplicates( + self, + table: str, + pk_columns: list[str], + result: dict, + attempt: int, + all_attempts: list, + previous_attempts: str, + ) -> tuple[dict, str, bool]: + """Validate primary key for duplicates and handle retry logic.""" + try: + has_duplicates, duplicate_count = self._check_duplicates_and_update_result(table, pk_columns, result) + + if not has_duplicates: + return self._handle_successful_validation(result, attempt, all_attempts, previous_attempts) + + return self._handle_duplicates_found( + pk_columns, duplicate_count, attempt, result, all_attempts, previous_attempts + ) + + except (ValueError, RuntimeError) as e: + return self._handle_validation_error(e, result, attempt, all_attempts, previous_attempts) + + def _execute_single_prediction( + self, table: str, table_definition: str, context: str, previous_attempts: str, metadata_info: str + ) -> dict[str, Any]: + """Execute a single prediction with the LLM.""" + if self.show_live_reasoning: + with dspy.context(show_guidelines=True): + logger.info("AI is analyzing metadata step by step...") + result = self.detector( + table=table, + table_definition=table_definition, + context=context, + previous_attempts=previous_attempts, + metadata_info=metadata_info, + ) + else: + result = self.detector( + table=table, + table_definition=table_definition, + context=context, + previous_attempts=previous_attempts, + metadata_info=metadata_info, + ) + + pk_columns = [col.strip() for col in result.primary_key_columns.split(',')] + + final_result = { + 'table': table, + 'primary_key_columns': pk_columns, + 'confidence': result.confidence, + 'reasoning': result.reasoning, + 'success': True, + } + + logger.info(f"Primary Key: {', '.join(pk_columns)}") + logger.info(f"Confidence: {result.confidence}") + + return final_result + + def _predict_with_retry_logic( + self, + table: str, + table_definition: str, 
+ context: str, + metadata_info: str, + validate_duplicates: bool, + ) -> dict[str, Any]: + """Handle prediction with retry logic for duplicate validation.""" + + previous_attempts = "" + all_attempts = [] + + for attempt in range(self.max_retries + 1): + logger.info(f"Prediction attempt {attempt + 1}/{self.max_retries + 1}") + + result = self._single_prediction(table, table_definition, context, previous_attempts, metadata_info) + + if not result['success']: + return result + + all_attempts.append(result.copy()) + pk_columns = result['primary_key_columns'] + + if not validate_duplicates: + result['validation_performed'] = False + result['retries_attempted'] = attempt + return result + + logger.info("Validating primary key prediction...") + result, previous_attempts, should_stop = self._validate_pk_duplicates( + table, pk_columns, result, attempt, all_attempts, previous_attempts + ) + + if should_stop: + return result + + # This shouldn't be reached, but just in case + return all_attempts[-1] if all_attempts else {'success': False, 'error': 'No attempts made'} + + def _single_prediction( + self, table: str, table_definition: str, context: str, previous_attempts: str, metadata_info: str + ) -> dict[str, Any]: + """Make a single primary key prediction using metadata.""" + + logger.info("Analyzing table schema and metadata patterns...") + + try: + final_result = self._execute_single_prediction( + table, table_definition, context, previous_attempts, metadata_info + ) + + # Print reasoning if available + if 'reasoning' in final_result: + self._print_reasoning_formatted(final_result['reasoning']) + + self._print_trace_if_available() + + return final_result + + except (ValueError, RuntimeError, AttributeError) as e: + error_msg = f"Error during prediction: {str(e)}" + logger.error(error_msg) + return {'table': table, 'success': False, 'error': error_msg} + except (TypeError, KeyError, ImportError) as e: + error_msg = f"Unexpected error during prediction: {str(e)}" + logger.error(error_msg) + logger.debug("Full traceback:", exc_info=True) + return {'table': table, 'success': False, 'error': error_msg} + + @staticmethod + def _print_reasoning_formatted(reasoning): + """Format and print reasoning step by step.""" + if not reasoning: + print("No reasoning provided") + return + + lines = reasoning.split('\n') + step_counter = 1 + + for line in lines: + line = line.strip() + if not line: + continue + + if line.lower().startswith('step'): + print(f"📝 {line}") + elif line.startswith('-') or line.startswith('•'): + print(f" {line}") + elif len(line) > 10 and any( + word in line.lower() for word in ('analyze', 'consider', 'look', 'notice', 'think') + ): + print(f"📝 Step {step_counter}: {line}") + step_counter += 1 + else: + print(f" {line}") + + @staticmethod + def _print_trace_if_available(): + """Print DSPy trace if available.""" + try: + if hasattr(dspy.settings, 'trace') and dspy.settings.trace: + logger.debug("\n🔬 TRACE INFORMATION:") + logger.debug("-" * 60) + for i, trace_item in enumerate(dspy.settings.trace[-3:]): + logger.debug(f"Trace {i+1}: {str(trace_item)[:200]}...") + except (AttributeError, IndexError): + # Silently continue if trace information is not available + pass + + @staticmethod + def print_pk_detection_summary(result): + """Print summary based on result dictionary.""" + + logger.info("=" * 60) + logger.info("🎯 PRIMARY KEY DETECTION SUMMARY") + logger.info("=" * 60) + logger.info(f"Table: {result['table']}") + logger.info(f"Status: {'✅ SUCCESS' if result['success'] else '❌ FAILED'}") 
+ logger.info(f"Attempts: {result.get('retries_attempted', 0) + 1}") + if result.get('retries_attempted', 0) > 0: + logger.info(f"Retries needed: {result['retries_attempted']}") + logger.info("") + + logger.info("📋 FINAL PRIMARY KEY:") + for col in result['primary_key_columns']: + logger.info(f" • {col}") + logger.info("") + + logger.info(f"🎯 Confidence: {result['confidence'].upper()}") + + if result.get('validation_performed', False): + validation_msg = ( + "No duplicates found" + if not result.get('has_duplicates', True) + else f"Found {result.get('duplicate_count', 0)} duplicates" + ) + logger.info(f"🔍 Validation: {validation_msg}") + else: + logger.info("🔍 Validation: Not performed") + logger.info("") + + if result.get('all_attempts') and len(result['all_attempts']) > 1: + logger.info("📝 ATTEMPT HISTORY:") + for i, attempt in enumerate(result['all_attempts']): + cols_str = ', '.join(attempt['primary_key_columns']) + if i == 0: + logger.info(f" 1st attempt: {cols_str} → Found duplicates") + else: + attempt_num = i + 1 + suffix = "nd" if attempt_num == 2 else "rd" if attempt_num == 3 else "th" + status_msg = "Success!" if i == len(result['all_attempts']) - 1 else "Still had duplicates" + logger.info(f" {attempt_num}{suffix} attempt: {cols_str} → {status_msg}") + logger.info("") + + status = result.get('final_status', 'unknown') + if status == 'success': + logger.info("✅ RECOMMENDATION: Use the detected composite key") + elif status == 'max_retries_reached_with_duplicates': + logger.info("⚠️ RECOMMENDATION: Manual review needed - duplicates persist") + else: + logger.info(f"ℹ️ STATUS: {status}") + + logger.info("=" * 60) diff --git a/src/databricks/labs/dqx/profiler/generator.py b/src/databricks/labs/dqx/profiler/generator.py index f0849b60b..ef17ac995 100644 --- a/src/databricks/labs/dqx/profiler/generator.py +++ b/src/databricks/labs/dqx/profiler/generator.py @@ -163,9 +163,57 @@ def dq_generate_is_not_null_or_empty(column: str, level: str = "error", **params "criticality": level, } + @staticmethod + def dq_generate_is_unique(column: str, level: str = "error", **params: dict): + """Generates a data quality rule to check if specified columns are unique. + + Uses is_unique with nulls_distinct=True for uniqueness validation. + + Args: + column: Comma-separated list of column names that form the primary key. Uses all columns if not provided. + level: The criticality level of the rule (default is "error"). + params: Additional parameters including columns list, confidence, reasoning, etc. + + Returns: + A dictionary representing the data quality rule. 
+ """ + columns = params.get("columns", column.split(",")) + + # Clean up column names (remove whitespace) + columns = [col.strip() for col in columns] + + confidence = params.get("confidence", "unknown") + reasoning = params.get("reasoning", "") + nulls_distinct = params.get("nulls_distinct", True) + llm_detected = params.get("llm_detected", False) + + # Create base metadata + user_metadata = { + "pk_detection_confidence": confidence, + "pk_detection_reasoning": reasoning, + "detected_primary_key": True, + } + + # Add LLM-specific metadata if this was LLM-detected + if llm_detected: + user_metadata.update( + {"llm_based_detection": True, "detection_method": "llm_analysis", "requires_llm_dependencies": True} + ) + + return { + "check": { + "function": "is_unique", + "arguments": {"columns": columns, "nulls_distinct": nulls_distinct}, + }, + "name": f"primary_key_{'_'.join(columns)}_validation", + "criticality": level, + "user_metadata": user_metadata, + } + _checks_mapping = { "is_not_null": dq_generate_is_not_null, "is_in": dq_generate_is_in, "min_max": dq_generate_min_max, "is_not_null_or_empty": dq_generate_is_not_null_or_empty, + "is_unique": dq_generate_is_unique, } diff --git a/src/databricks/labs/dqx/profiler/profiler.py b/src/databricks/labs/dqx/profiler/profiler.py index f3ddc407a..d09b76e17 100644 --- a/src/databricks/labs/dqx/profiler/profiler.py +++ b/src/databricks/labs/dqx/profiler/profiler.py @@ -18,11 +18,22 @@ from databricks.labs.dqx.base import DQEngineBase from databricks.labs.dqx.config import InputConfig -from databricks.labs.dqx.io import read_input_data -from databricks.labs.dqx.utils import list_tables + +from databricks.labs.dqx.utils import ( + list_tables, +) +from databricks.labs.dqx.io import read_input_data, STORAGE_PATH_PATTERN from databricks.labs.dqx.telemetry import telemetry_logger from databricks.labs.dqx.errors import InvalidParameterError +# Optional LLM imports +try: + from databricks.labs.dqx.llm.pk_identifier import DatabricksPrimaryKeyDetector + + HAS_LLM_DETECTOR = True +except ImportError: + HAS_LLM_DETECTOR = False + logger = logging.getLogger(__name__) @@ -116,6 +127,58 @@ def profile( return summary_stats, dq_rules + def detect_primary_keys_with_llm( + self, + table: str, + options: dict[str, Any] | None = None, + llm: bool = False, + ) -> dict[str, Any] | None: + """Detect primary key for a table using LLM-based analysis. + + This method requires LLM dependencies and will only work when explicitly requested. + + Args: + table: Fully qualified table name (e.g., 'catalog.schema.table') + options: Optional dictionary of options for PK detection + llm: Enable LLM-based detection + + Returns: + Dictionary with PK detection results or None if disabled/failed + """ + if not HAS_LLM_DETECTOR: + raise ImportError("LLM detector not available") + + if options is None: + options = {} + + # Check if LLM-based PK detection is explicitly requested + llm_enabled = llm or options.get("enable_llm_pk_detection", False) or options.get("llm", False) + + if not llm_enabled: + logger.debug("LLM-based PK detection not requested. 
Use llm=True to enable.") + return None + + try: + result, detector = self._run_llm_pk_detection(table, options) + + if result.get("success", False): + logger.info(f"✅ LLM-based primary key detected for {table}: {result['primary_key_columns']}") + detector.print_pk_detection_summary(result) + return result + + logger.warning( + f"❌ LLM-based primary key detection failed for {table}: {result.get('error', 'Unknown error')}" + ) + return None + + except (ValueError, RuntimeError, OSError) as e: + logger.error(f"Error during LLM-based primary key detection for {table}: {e}") + return None + except (AttributeError, TypeError, ImportError) as e: + logger.error(f"Unexpected error during LLM-based primary key detection for {table}: {e}") + logger.debug("Full traceback:", exc_info=True) + return None + @telemetry_logger("profiler", "profile_table") def profile_table( self, @@ -136,7 +199,182 @@ def profile_table( """ logger.info(f"Profiling {table} with options: {options}") df = read_input_data(spark=self.spark, input_config=InputConfig(location=table)) - return self.profile(df=df, columns=columns, options=options) + summary_stats, dq_rules = self.profile(df=df, columns=columns, options=options) + + # Add LLM-based primary key detection if explicitly requested + self._add_llm_primary_key_detection(table, options, summary_stats) + + return summary_stats, dq_rules + + def _add_llm_primary_key_detection( + self, table: str, options: dict[str, Any] | None, summary_stats: dict[str, Any] + ) -> None: + """ + Adds LLM-based primary key detection results to summary statistics if enabled. + + Args: + table: The fully-qualified table name (*catalog.schema.table*) + options: Optional dictionary of options for profiling + summary_stats: Summary statistics dictionary to update with PK detection results + """ + llm_enabled = options and ( + options.get("enable_llm_pk_detection", False) + or options.get("llm", False) + or options.get("llm_pk_detection", False) + ) + + if not llm_enabled: + return + + # Parse table name to extract catalog, schema, table (or use full path for files) + # No need to parse table components since we pass the full table name + + pk_result = self.detect_primary_keys_with_llm(table, options, llm=True) + + if pk_result and pk_result.get("success", False): + pk_columns = pk_result.get("primary_key_columns", []) + if pk_columns and pk_columns != ["none"]: + # Add to summary stats (but don't automatically generate rules) + summary_stats["llm_primary_key_detection"] = { + "detected_columns": pk_columns, + "confidence": pk_result.get("confidence", "unknown"), + "has_duplicates": pk_result.get("has_duplicates", False), + "validation_performed": pk_result.get("validation_performed", False), + "method": "llm_based", + } + + def _parse_table_name(self, table: str) -> tuple[str | None, str | None, str]: + """ + Parses a fully-qualified table name into its components. + + Args: + table: The fully-qualified table name (catalog.schema.table or schema.table or table) + + Returns: + A tuple of (catalog, schema, table_name) where catalog and schema can be None + """ + table_parts = table.split(".") + if len(table_parts) == 3: + return table_parts[0], table_parts[1], table_parts[2] + if len(table_parts) == 2: + return None, table_parts[0], table_parts[1] + return None, None, table_parts[0] + + def _is_file_path(self, name: str) -> bool: + """ + Determine if the given name is a file path rather than a table name. 
+ + Args: + name: The name to check + + Returns: + True if it looks like a file path, False if it looks like a table name + """ + return bool(STORAGE_PATH_PATTERN.match(name)) + + def _run_llm_pk_detection(self, table: str, options: dict[str, Any] | None): + """Run LLM-based primary key detection for a table.""" + logger.info(f"🤖 Starting LLM-based primary key detection for {table}") + + if options and options.get("llm_pk_detection_endpoint"): + detector = DatabricksPrimaryKeyDetector( + table=table, + endpoint=options.get("llm_pk_detection_endpoint", ""), + validate_duplicates=options.get("llm_pk_validate_duplicates", True) if options else True, + spark_session=self.spark, + max_retries=options.get("llm_pk_max_retries", 3) if options else 3, + ) + else: # use default endpoint + detector = DatabricksPrimaryKeyDetector( + table=table, + validate_duplicates=options.get("llm_pk_validate_duplicates", True) if options else True, + spark_session=self.spark, + max_retries=options.get("llm_pk_max_retries", 3) if options else 3, + ) + + # Use generic detection method that works for both tables and paths + result = detector.detect_primary_keys() + return result, detector + + def _run_llm_pk_detection_for_dataframe( + self, df: DataFrame, options: dict[str, Any] | None, summary_stats: dict[str, Any] + ) -> None: + """Run LLM-based primary key detection for DataFrame.""" + if not HAS_LLM_DETECTOR: + raise ImportError("LLM detector not available") + + # Create a temporary view from the DataFrame for LLM analysis + temp_view_name = f"temp_dataframe_analysis_{id(df)}" + try: + df.createOrReplaceTempView(temp_view_name) + logger.info("🤖 Starting LLM-based primary key detection for DataFrame") + + detector = DatabricksPrimaryKeyDetector( + table=temp_view_name, + endpoint=(options.get("llm_pk_detection_endpoint") if options else None) + or "databricks-meta-llama-3-1-8b-instruct", + validate_duplicates=options.get("llm_pk_validate_duplicates", True) if options else True, + spark_session=self.spark, + max_retries=options.get("llm_pk_max_retries", 3) if options else 3, + show_live_reasoning=False, + ) + + # Use the public method for primary key detection + pk_result = detector.detect_primary_keys() + finally: + # Clean up the temporary view using SQL (Unity Catalog compatible) + try: + self.spark.sql(f"DROP VIEW IF EXISTS {temp_view_name}") + except Exception: + # Ignore cleanup errors + pass + + if pk_result and pk_result.get("success", False): + pk_columns = pk_result.get("primary_key_columns", []) + if pk_columns and pk_columns != ["none"]: + # Validate that detected columns actually exist in the DataFrame + valid_columns = [col for col in pk_columns if col in df.columns] + if valid_columns: + # Add to summary stats (but don't automatically generate rules) + summary_stats["llm_primary_key_detection"] = { + "detected_columns": valid_columns, + "confidence": pk_result.get("confidence", "unknown"), + "has_duplicates": False, # Not validated for DataFrames by default + "validation_performed": False, # DataFrame validation would require additional logic + "method": "llm_based_dataframe", + } + logger.info(f"✅ LLM-based primary key detected for DataFrame: {valid_columns}") + + def _add_llm_primary_key_detection_for_dataframe( + self, df: DataFrame, options: dict[str, Any] | None, summary_stats: dict[str, Any] + ) -> None: + """ + Adds LLM-based primary key detection results for DataFrames to summary statistics if enabled. 
+ + Args: + df: The DataFrame to analyze + options: Optional dictionary of options for profiling + summary_stats: Summary statistics dictionary to update with PK detection results + """ + llm_enabled = options and ( + options.get("enable_llm_pk_detection", False) + or options.get("llm", False) + or options.get("llm_pk_detection", False) + ) + + if not llm_enabled: + return + + try: + self._run_llm_pk_detection_for_dataframe(df, options, summary_stats) + + except ImportError as e: + logger.warning(str(e)) + except (ValueError, RuntimeError, OSError) as e: + logger.error(f"Error during LLM-based primary key detection for DataFrame: {e}") + except (AttributeError, TypeError, KeyError) as e: + logger.error(f"Unexpected error during LLM-based primary key detection for DataFrame: {e}") + logger.debug("Full traceback:", exc_info=True) @telemetry_logger("profiler", "profile_tables_for_patterns") def profile_tables_for_patterns( @@ -314,6 +552,9 @@ def _profile( self._calculate_metrics(df, dq_rules, field_name, metrics, opts, total_count, typ) + # Add LLM-based primary key detection for DataFrames if enabled + self._add_llm_primary_key_detection_for_dataframe(df, opts, summary_stats) + def _calculate_metrics( self, df: DataFrame, diff --git a/src/databricks/labs/dqx/profiler/profiler_runner.py b/src/databricks/labs/dqx/profiler/profiler_runner.py index 180c08e77..285a7c880 100644 --- a/src/databricks/labs/dqx/profiler/profiler_runner.py +++ b/src/databricks/labs/dqx/profiler/profiler_runner.py @@ -63,16 +63,22 @@ def run( Returns: A tuple containing the generated checks and profile summary statistics. """ + # Build common profiling options + options = { + "sample_fraction": profiler_config.sample_fraction, + "sample_seed": profiler_config.sample_seed, + "limit": profiler_config.limit, + "filter": profiler_config.filter, + # LLM-based Primary Key Detection Options (available for both tables and DataFrames) + "enable_llm_pk_detection": profiler_config.llm_config.enable_pk_detection, + "llm_pk_detection_endpoint": profiler_config.llm_config.pk_detection_endpoint, + "llm_pk_validate_duplicates": True, # Always validate for duplicates + "llm_pk_max_retries": 3, # Fixed to 3 retries for optimal performance + } + df = read_input_data(self.spark, input_config) - summary_stats, profiles = self.profiler.profile( - df, - options={ - "sample_fraction": profiler_config.sample_fraction, - "sample_seed": profiler_config.sample_seed, - "limit": profiler_config.limit, - "filter": profiler_config.filter, - }, - ) + summary_stats, profiles = self.profiler.profile(df, options=options) + checks = self.generator.generate_dq_rules(profiles) # use default criticality level "error" logger.info(f"Using options: \n{profiler_config}") logger.info(f"Generated checks: \n{checks}") diff --git a/src/databricks/labs/dqx/utils.py b/src/databricks/labs/dqx/utils.py index 62dcba6ab..192a6be03 100644 --- a/src/databricks/labs/dqx/utils.py +++ b/src/databricks/labs/dqx/utils.py @@ -5,8 +5,8 @@ import re from typing import Any from fnmatch import fnmatch - from pyspark.sql import Column +from pyspark.sql import types as T # Import spark connect column if spark session is created using spark connect try: @@ -417,19 +417,72 @@ def _filter_tables_by_patterns(tables: list[str], patterns: list[str], exclude_m list[str]: A filtered list of table names based on the matching criteria. 
""" if exclude_matched: - return [table for table in tables if not _match_table_patterns(table, patterns)] - return [table for table in tables if _match_table_patterns(table, patterns)] + return [table for table in tables if not any(fnmatch(table, pattern) for pattern in patterns)] + return [table for table in tables if any(fnmatch(table, pattern) for pattern in patterns)] + +def spark_type_to_sql_type(spark_type: Any) -> str: + """ + Converts Spark data types to SQL-like string representations. + + Args: + spark_type (Any): The Spark DataType to convert -def _match_table_patterns(table: str, patterns: list[str]) -> bool: + Returns: + str: A string representation of the SQL type + """ + # Type mapping for better maintainability + type_mapping = { + T.StringType: "STRING", + T.IntegerType: "INT", + T.LongType: "BIGINT", + T.DoubleType: "DOUBLE", + T.FloatType: "FLOAT", + T.BooleanType: "BOOLEAN", + T.DateType: "DATE", + T.TimestampType: "TIMESTAMP", + } + + # Check simple type mappings first + for spark_type_class, sql_type in type_mapping.items(): + if isinstance(spark_type, spark_type_class): + return sql_type + + # Handle complex types + if isinstance(spark_type, T.DecimalType): + return f"DECIMAL({spark_type.precision},{spark_type.scale})" + if isinstance(spark_type, T.ArrayType): + return f"ARRAY<{spark_type_to_sql_type(spark_type.elementType)}>" + if isinstance(spark_type, T.MapType): + return f"MAP<{spark_type_to_sql_type(spark_type.keyType)},{spark_type_to_sql_type(spark_type.valueType)}>" + if isinstance(spark_type, T.StructType): + return "STRUCT<...>" # Simplified for LLM analysis + + # Default case + return str(spark_type).upper() + + +def generate_table_definition_from_dataframe(df, table: str = "dataframe_analysis") -> str: """ - Checks if a table name matches any of the provided wildcard patterns. + Generate a CREATE TABLE statement from a DataFrame schema. Args: - table (str): The table name to check. - patterns (list[str]): A list of wildcard patterns (e.g., 'catalog.schema.*') to match against the table name. + df (Any): The DataFrame to generate a table definition for + table (str): Name to use in the CREATE TABLE statement Returns: - bool: True if the table name matches any of the patterns, False otherwise. 
+ A string representing a CREATE TABLE statement """ - return any(fnmatch(table, pattern) for pattern in patterns) + table_definition = f"CREATE TABLE {table} (\n" + + column_definitions = [] + for field in df.schema.fields: + # Convert Spark data types to SQL-like representation + sql_type = spark_type_to_sql_type(field.dataType) + nullable = "" if field.nullable else " NOT NULL" + column_definitions.append(f" {field.name} {sql_type}{nullable}") + + table_definition += ",\n".join(column_definitions) + table_definition += "\n)" + + return table_definition diff --git a/tests/integration/test_apply_checks.py b/tests/integration/test_apply_checks.py index 5c572800c..35ca4ebdb 100644 --- a/tests/integration/test_apply_checks.py +++ b/tests/integration/test_apply_checks.py @@ -23,6 +23,15 @@ from databricks.labs.dqx.schema import dq_result_schema from databricks.labs.dqx import check_funcs import databricks.labs.dqx.geo.check_funcs as geo_check_funcs + +# Import for LLM tests (conditional import handled in test) +try: + from databricks.labs.dqx.profiler.profiler import DQProfiler + + HAS_PROFILER = True +except ImportError: + HAS_PROFILER = False + from tests.integration.conftest import REPORTING_COLUMNS, RUN_TIME, EXTRA_PARAMS @@ -7409,6 +7418,129 @@ def test_compare_datasets_check(ws, spark, set_utc_timezone): assert_df_equality(checked.sort(pk_columns), expected.sort(pk_columns), ignore_nullable=True) +def _run_llm_pk_detection_test(ws, src_table): + """Helper function to run LLM primary key detection test logic.""" + if not HAS_PROFILER: + pytest.skip("DQProfiler not available") + + profiler = DQProfiler(ws) + + # Detect primary keys using actual LLM functionality + pk_detection_result = profiler.detect_primary_keys_with_llm( + src_table, + options={ + "enable_llm_pk_detection": True, + "llm_pk_detection_endpoint": "databricks-meta-llama-3-1-8b-instruct", + "llm_pk_validate_duplicates": False, # Skip duplicate validation for test speed + }, + llm=True, + ) + + # Skip test if LLM detection failed (e.g., endpoint not available) + if pk_detection_result is None or not pk_detection_result.get("success", False): + pytest.skip("LLM-based primary key detection not available or failed") + + return pk_detection_result + + +def _create_llm_dataset_check(pk_detection_result, detected_pk_columns, ref_table): + """Helper function to create dataset check with LLM-detected primary key.""" + return DQDatasetRule( + name="llm_detected_pk_compare_datasets", + criticality="error", + check_func=check_funcs.compare_datasets, + columns=detected_pk_columns, + filter="customer_id != 1002", # Filter out the middle record + check_func_kwargs={"ref_columns": detected_pk_columns, "ref_table": ref_table}, + user_metadata={ + "test_tag": "llm_integration", + "llm_detected_pk": True, + "pk_detection_confidence": pk_detection_result["confidence"], + "pk_detection_reasoning": pk_detection_result["reasoning"], + }, + ) + + +def _verify_llm_test_results(checked, detected_pk_columns): + """Helper function to verify LLM test results.""" + # Verify that the check was applied and used the LLM-detected primary key + assert checked is not None + assert "dq_issues" in checked.columns + assert "dq_validations" in checked.columns + + # Check that the metadata includes LLM detection information + issues_rows = checked.filter(checked.dq_issues.isNotNull()).collect() + if len(issues_rows) > 0: + issue_data = issues_rows[0]["dq_issues"][0] + assert issue_data["name"] == "llm_detected_pk_compare_datasets" + assert 
issue_data["user_metadata"]["llm_detected_pk"] is True + assert "pk_detection_confidence" in issue_data["user_metadata"] + assert "pk_detection_reasoning" in issue_data["user_metadata"] + # Verify the columns used match what LLM detected + assert issue_data["columns"] == detected_pk_columns + + +def test_compare_datasets_check_with_llm_pk_detection(ws, spark, set_utc_timezone): + """Test compare_datasets check using LLM-based primary key detection.""" + pytest.importorskip("dspy", reason="dspy not available") + pytest.importorskip("databricks_langchain", reason="databricks_langchain not available") + + dq_engine = DQEngine(workspace_client=ws, extra_params=EXTRA_PARAMS) + + schema = "customer_id long, order_id long, product_name string, order_date date, created_at timestamp, amount float, quantity bigint, is_active boolean" + + src_df = spark.createDataFrame( + [ + [1001, 2001, "Laptop", datetime(2023, 1, 15), datetime(2023, 1, 15, 10, 30, 0), 1299.99, 1, True], + [1002, 2002, "Mouse", datetime(2023, 1, 16), datetime(2023, 1, 16, 14, 45, 0), 29.99, 2, True], + [1003, 2003, "Keyboard", datetime(2023, 1, 17), datetime(2023, 1, 17, 9, 15, 0), 89.99, 1, False], + ], + schema, + ) + + ref_df = spark.createDataFrame( + [ + # diff in amount + [1001, 2001, "Laptop", datetime(2023, 1, 15), datetime(2023, 1, 15, 10, 30, 0), 1399.99, 1, True], + # no diff + [1003, 2003, "Keyboard", datetime(2023, 1, 17), datetime(2023, 1, 17, 9, 15, 0), 89.99, 1, False], + # missing record in src + [1004, 2004, "Monitor", datetime(2023, 1, 18), datetime(2023, 1, 18, 11, 0, 0), 299.99, 1, True], + ], + schema, + ) + + # Create temporary tables for LLM analysis + src_table = "test_orders_src" + ref_table = "test_orders_ref" + + src_df.createOrReplaceTempView(src_table) + ref_df.createOrReplaceTempView(ref_table) + + try: + # Use the profiler to detect primary keys with real LLM + pk_detection_result = _run_llm_pk_detection_test(ws, src_table) + detected_pk_columns = pk_detection_result["primary_key_columns"] + + # Use the detected primary key columns in the compare_datasets check + checks = [_create_llm_dataset_check(pk_detection_result, detected_pk_columns, ref_table)] + + checked = dq_engine.apply_checks(src_df, checks) + _verify_llm_test_results(checked, detected_pk_columns) + + except ImportError as e: + pytest.skip(f"LLM dependencies not available: {e}") + except Exception as e: + pytest.skip(f"LLM-based detection failed (possibly endpoint unavailable): {e}") + finally: + # Clean up temporary views + try: + spark.sql(f"DROP VIEW IF EXISTS {src_table}") + spark.sql(f"DROP VIEW IF EXISTS {ref_table}") + except Exception: + pass + + def test_compare_datasets_check_missing_records(ws, spark, set_utc_timezone): dq_engine = DQEngine(workspace_client=ws, extra_params=EXTRA_PARAMS) diff --git a/tests/unit/test_llm_based_pk_identifier.py b/tests/unit/test_llm_based_pk_identifier.py new file mode 100644 index 000000000..73188492a --- /dev/null +++ b/tests/unit/test_llm_based_pk_identifier.py @@ -0,0 +1,189 @@ +""" +LLM-based primary key identifier. 
+""" + +from unittest.mock import Mock +import pytest + +# Check LLM dependencies availability +try: + from databricks.labs.dqx.llm.pk_identifier import DatabricksPrimaryKeyDetector + + HAS_LLM_DEPS = True +except ImportError: + DatabricksPrimaryKeyDetector = type(None) # type: ignore + HAS_LLM_DEPS = False + + +# Test helper classes +class MockSparkManager: + """Test double for SparkManager.""" + + def __init__(self, table_definition="", metadata_info="", should_raise=False): + self.table_definition = table_definition + self.metadata_info = metadata_info + self.should_raise = should_raise + + def get_table_definition(self, _table_name, _catalog=None, _schema=None): + if self.should_raise: + raise ValueError("Table not found") + return self.table_definition + + def get_table_metadata_info(self, _table_name, _catalog=None, _schema=None): + if self.should_raise: + raise ValueError("Metadata not available") + return self.metadata_info + + def check_duplicates(self, _table_name, _pk_columns, _catalog=None, _schema=None): + # Default: no duplicates found + return False, 0 + + +class MockDetector: + """Test double for DSPy detector.""" + + def __init__(self, primary_key_columns="", confidence="high", reasoning=""): + self.primary_key_columns = primary_key_columns + self.confidence = confidence + self.reasoning = reasoning + + def __call__(self, **kwargs): + result = Mock() + result.primary_key_columns = self.primary_key_columns + result.confidence = self.confidence + result.reasoning = self.reasoning + return result + + +@pytest.mark.skipif(not HAS_LLM_DEPS, reason="LLM dependencies (dspy, databricks_langchain) not installed") +def test_detect_primary_key_simple(): + """Test simple primary key detection with mocked table schema and LLM.""" + + # Mock table definition and metadata + mock_table_definition = """ + CREATE TABLE customers ( + customer_id BIGINT NOT NULL, + first_name STRING, + last_name STRING, + email STRING, + created_at TIMESTAMP + ) + """ + mock_metadata = "Table: customers, Columns: 5, Primary constraints: None" + + # Create mock ChatDatabricks class + mock_chat_databricks = Mock() + + # Create detector instance with injected mock + detector = DatabricksPrimaryKeyDetector( + table="customers", + endpoint="mock-endpoint", + validate_duplicates=False, + show_live_reasoning=False, + spark_session=None, # Explicitly pass None to avoid Spark session creation + chat_databricks_cls=mock_chat_databricks, + ) + + # Inject test doubles + detector.spark_manager = MockSparkManager(mock_table_definition, mock_metadata) + detector.detector = MockDetector( + primary_key_columns="customer_id", + confidence="high", + reasoning="customer_id is a unique identifier, non-nullable bigint typically used as primary key", + ) + + # Test primary key detection + result = detector.detect_primary_keys() + + # Assertions + assert result["success"] is True + assert result["primary_key_columns"] == ["customer_id"] + assert "customer_id" in result["reasoning"] + + +@pytest.mark.skipif(not HAS_LLM_DEPS, reason="LLM dependencies (dspy, databricks_langchain) not installed") +def test_detect_primary_key_composite(): + """Test detection of composite primary key.""" + + # Mock table definition and metadata + mock_table_definition = """ + CREATE TABLE order_items ( + order_id BIGINT NOT NULL, + product_id BIGINT NOT NULL, + quantity INT, + price DECIMAL + ) + """ + mock_metadata = "Table: order_items, Columns: 4, Primary constraints: None" + + # Create mock ChatDatabricks class + mock_chat_databricks = Mock() + + # 
Create detector instance with injected mock + detector = DatabricksPrimaryKeyDetector( + table="order_items", + endpoint="mock-endpoint", + validate_duplicates=False, + show_live_reasoning=False, + spark_session=None, # Explicitly pass None to avoid Spark session creation + chat_databricks_cls=mock_chat_databricks, + ) + + # Inject test doubles + detector.spark_manager = MockSparkManager(mock_table_definition, mock_metadata) + detector.detector = MockDetector( + primary_key_columns="order_id, product_id", + confidence="high", + reasoning="Combination of order_id and product_id forms composite primary key for order items", + ) + + # Test primary key detection + result = detector.detect_primary_keys() + + # Assertions + assert result["success"] is True + assert result["primary_key_columns"] == ["order_id", "product_id"] + + +@pytest.mark.skipif(not HAS_LLM_DEPS, reason="LLM dependencies (dspy, databricks_langchain) not installed") +def test_detect_primary_key_no_clear_key(): + """Test when LLM cannot identify a clear primary key.""" + + # Mock table definition and metadata + mock_table_definition = """ + CREATE TABLE application_logs ( + timestamp TIMESTAMP, + level STRING, + message STRING, + source STRING + ) + """ + mock_metadata = "Table: application_logs, Columns: 4, Primary constraints: None" + + # Create mock ChatDatabricks class + mock_chat_databricks = Mock() + + # Create detector instance with injected mock + detector = DatabricksPrimaryKeyDetector( + table="application_logs", + endpoint="mock-endpoint", + validate_duplicates=False, + show_live_reasoning=False, + spark_session=None, # Explicitly pass None to avoid Spark session creation + chat_databricks_cls=mock_chat_databricks, + ) + + # Inject test doubles + detector.spark_manager = MockSparkManager(mock_table_definition, mock_metadata) + detector.detector = MockDetector( + primary_key_columns="none", + confidence="low", + reasoning="No clear primary key identified - all columns are nullable and none appear to be unique identifiers", + ) + + # Test primary key detection + result = detector.detect_primary_keys() + + # Assertions + assert result["success"] is True + assert result["primary_key_columns"] == ["none"]
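

For reference, the is_unique generator registered in _checks_mapping above emits a plain check dictionary that downstream tooling can serialize as-is. A minimal sketch of its output for an LLM-detected composite key follows; the column names, confidence and reasoning values are illustrative placeholders, not produced by this patch.

# Illustrative only: shape of a check emitted by the new is_unique generator
# for an LLM-detected composite primary key (field values are made up).
llm_detected_pk_check = {
    "check": {
        "function": "is_unique",
        "arguments": {"columns": ["order_id", "product_id"], "nulls_distinct": True},
    },
    "name": "primary_key_order_id_product_id_validation",
    "criticality": "error",
    "user_metadata": {
        "pk_detection_confidence": "high",
        "pk_detection_reasoning": "order_id and product_id together identify an order item",
        "detected_primary_key": True,
        "llm_based_detection": True,
        "detection_method": "llm_analysis",
        "requires_llm_dependencies": True,
    },
}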
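The refactored _filter_tables_by_patterns inlines standard-library fnmatch wildcard matching in place of the removed _match_table_patterns helper. A small illustration of the same matching logic, using only fnmatch; the table names and patterns below are made up for the example.

# Sketch of the wildcard filtering now inlined in _filter_tables_by_patterns.
from fnmatch import fnmatch

tables = ["main.sales.orders", "main.sales.customers", "main.tmp.scratch"]
patterns = ["main.sales.*"]

included = [t for t in tables if any(fnmatch(t, p) for p in patterns)]
excluded = [t for t in tables if not any(fnmatch(t, p) for p in patterns)]

print(included)  # ['main.sales.orders', 'main.sales.customers']
print(excluded)  # ['main.tmp.scratch']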
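The spark_type_to_sql_type helper added to utils.py maps Spark DataType objects to SQL-style strings, recursing into array and map element types; generate_table_definition_from_dataframe then joins the per-field results into a CREATE TABLE statement for LLM analysis. A minimal sketch, assuming the function is importable from databricks.labs.dqx.utils as added in this patch; DataType objects can be constructed without an active Spark session.

# Sketch: exercising the type mapping on hand-built Spark types.
from pyspark.sql import types as T

from databricks.labs.dqx.utils import spark_type_to_sql_type  # added in this patch

print(spark_type_to_sql_type(T.DecimalType(10, 2)))                     # DECIMAL(10,2)
print(spark_type_to_sql_type(T.ArrayType(T.StringType())))              # ARRAY<STRING>
print(spark_type_to_sql_type(T.MapType(T.StringType(), T.LongType())))  # MAP<STRING,BIGINT>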