diff --git a/Makefile b/Makefile index 49b3a494b..46ca2f881 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,7 @@ clean: docs-clean .venv/bin/python: pip install hatch hatch env create + hatch run pip install ".[llm,pii]" dev: .venv/bin/python @hatch run which python diff --git a/pyproject.toml b/pyproject.toml index b24a127f7..f40dc64ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ pii = [ # This may be required for the larger models due to Databricks connect memory limitations. # The models cannot be delcared as dependency here buecase PyPI does not support URL-based dependencies which would prevent releases. ] +llm = ["dspy~=3.0.3"] [project.entry-points.databricks] runtime = "databricks.labs.dqx.workflows_runner:main" @@ -77,7 +78,7 @@ include = ["src"] path = "src/databricks/labs/dqx/__about__.py" [tool.hatch.envs.default] -features = ["pii"] +features = ["pii", "llm"] dependencies = [ "black~=24.8.0", "chispa~=0.10.1", diff --git a/src/databricks/labs/dqx/llm/__init__.py b/src/databricks/labs/dqx/llm/__init__.py index e69de29bb..2e2379a9d 100644 --- a/src/databricks/labs/dqx/llm/__init__.py +++ b/src/databricks/labs/dqx/llm/__init__.py @@ -0,0 +1,11 @@ +from importlib.util import find_spec + +required_specs = [ + "dspy", +] + +# Check if required llm packages are installed +if not all(find_spec(spec) for spec in required_specs): + raise ImportError( + "llm extras not installed. Install additional dependencies by running `pip install databricks-labs-dqx[llm]`." + ) diff --git a/src/databricks/labs/dqx/llm/llm_core.py b/src/databricks/labs/dqx/llm/llm_core.py new file mode 100644 index 000000000..e1abf3e9b --- /dev/null +++ b/src/databricks/labs/dqx/llm/llm_core.py @@ -0,0 +1,277 @@ +import json +import logging +import dspy # type: ignore +from databricks.labs.dqx.llm.llm_utils import create_optimizer_training_set +from databricks.labs.dqx.engine import DQEngineCore + +logger = logging.getLogger(__name__) + + +class SchemaGuesserSignature(dspy.Signature): + """ + Guess a table schema based on business description. + + This class defines the schema for inferring complete table structure from + natural language descriptions. + """ + + business_description: str = dspy.InputField( + desc=( + "Natural language summary of the dataset and its use. " + "Including some column hints (e.g., id, amount, status, email, dates)." + ) + ) + guessed_schema_json: str = dspy.OutputField( + desc=( + "Strict JSON with shape: " + '{"columns":[{"name":"","type":"","example":""}]}. ' + "Prefer: ids:string, money:decimal(18,2), timestamps:timestamp, dates:date. " + "Return one line JSON with no extra text." + ) + ) + assumptions_bullets: str = dspy.OutputField( + desc=( + "Concise bullet list (1-6 lines) of assumptions made about columns, types, " + "and examples. Keep each bullet short." + ) + ) + + +class SchemaGuesser(dspy.Module): + """ + Guess table schema from business description. + + This class provides functionality to infer a complete table schema based on + natural language descriptions of the dataset and its intended use. + """ + + def __init__(self): + super().__init__() + self.guess = dspy.ChainOfThought(SchemaGuesserSignature) + + def forward(self, business_description: str) -> dspy.primitives.prediction.Prediction: + """ + Guess schema based on business description. + + Args: + business_description (str): Natural language description of the dataset and its use case. 
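# Illustrative usage sketch for SchemaGuesser (not taken verbatim from the patch): the
# endpoint name and business description below are examples; assumes dspy can reach a
# Databricks foundation-model serving endpoint.
import dspy
from databricks.labs.dqx.llm.llm_core import SchemaGuesser

dspy.configure(lm=dspy.LM(model="databricks/databricks-meta-llama-3-3-70b-instruct", model_type="chat"))

guesser = SchemaGuesser()
pred = guesser(business_description="Retail orders with an order id, customer email, amount and order date")
print(pred.guessed_schema_json)   # one-line JSON, e.g. {"columns":[{"name":"order_id","type":"string","example":"O-1001"}, ...]}
print(pred.assumptions_bullets)   # short bullet list of typing assumptions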
+ + Returns: + dspy.primitives.prediction.Prediction: A Prediction object containing the guessed schema + and assumptions made during the inference process. + """ + return self.guess(business_description=business_description) + + +class RuleSignature(dspy.Signature): + """ + Generate data quality rules with improved output format. + + This class defines the schema for generating data quality rules based on + schema information, business descriptions, and available functions. + """ + + schema_info: str = dspy.InputField(desc="JSON string of table schema with column names, types, and sample data") + business_description: str = dspy.InputField(desc="Natural language description of data quality requirements") + available_functions: str = dspy.InputField(desc="JSON string of available DQX check functions") + quality_rules: str = dspy.OutputField( + desc=( + "Return a valid JSON array of data quality rules. Use double quotes only. " + "Criticality can be error or warn. " + "Check function name and doc to select the appropriate check function. " + "Format: [{\"criticality\":\"error\",\"check\":{\"function\":\"name\",\"arguments\":{\"column\":\"col\"}}}] " + "Example: [{\"criticality\":\"error\",\"check\":{\"function\":\"is_not_null\",\"arguments\":{\"column\":\"customer_id\"}}}]" + ) + ) + reasoning: str = dspy.OutputField(desc="Explanation of why these rules were chosen") + + +class DQRuleGeneration(dspy.Module): + """ + Generate data quality rules with improved JSON output reliability. + + This class provides functionality to generate data quality rules based on schema information, + business descriptions, and available functions. It can optionally infer the schema from the + business description if schema_info is not provided or is empty. + """ + + def __init__(self, schema_guesser: SchemaGuesser = SchemaGuesser()): + """ + Initialize the DQ rule generation module. + + Args: + schema_guesser (SchemaGuesser): Schema guesser instance for inferring schema when needed. + """ + super().__init__() + # Use Predict for reliable output + self.generator = dspy.Predict(RuleSignature) + self.schema_guesser = schema_guesser + + def forward( + self, schema_info: str, business_description: str, available_functions: str + ) -> dspy.primitives.prediction.Prediction: + """ + Generate data quality rules based on schema information, business descriptions, and available functions. + + If schema_info is empty and enable_schema_inference is True, it will first use SchemaGuesser + to infer the schema from the business description. + + Args: + schema_info (str): JSON string containing table schema with column names, types, and sample data. + If empty and enable_schema_inference=True, schema will be inferred. + business_description (str): Natural language description of data quality requirements. + available_functions (str): JSON string of available DQX check functions. + + Returns: + dspy.primitives.prediction.Prediction: A Prediction object containing the generated data quality rules, + reasoning, and optionally guessed_schema_json and assumptions_bullets if schema was inferred. 
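# Illustrative call sketch for DQRuleGeneration with an explicit schema (all literals are
# made-up examples; assumes dspy.configure(...) has already been run as above).
import json
from databricks.labs.dqx.llm.llm_core import DQRuleGeneration
from databricks.labs.dqx.llm.llm_utils import _get_required_check_function_info

rule_gen = DQRuleGeneration()
pred = rule_gen(
    schema_info='{"columns":[{"name":"customer_id","type":"string"}]}',
    business_description="customer_id must never be null",
    available_functions=json.dumps(_get_required_check_function_info()),
)
rules = json.loads(pred.quality_rules)  # falls back to "[]" when the model emitted invalid JSON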
+ """ + # Step 1: Infer schema if needed + guessed_schema_json = None + assumptions_bullets = None + + if not schema_info or not schema_info.strip(): + logger.info("Inferring schema from business description...") + schema_result = self.schema_guesser(business_description=business_description) + schema_info = schema_result.guessed_schema_json + guessed_schema_json = schema_result.guessed_schema_json + assumptions_bullets = schema_result.assumptions_bullets + logger.info(f"Inferred schema: {schema_info}") + else: + logger.debug(f"Using provided schema: {schema_info}") + + # Step 2: Generate rules using the schema (provided or inferred) + result = self.generator( + schema_info=schema_info, business_description=business_description, available_functions=available_functions + ) + + # Validate and clean the JSON output + if result.quality_rules: + try: + # Try to parse the JSON to ensure it's valid + json.loads(result.quality_rules) + except json.JSONDecodeError as e: + logger.warning(f"Generated invalid JSON: {e}. Raw output: {result.quality_rules}") + # Return a fallback empty array if JSON is invalid + result.quality_rules = "[]" + + # Add schema inference results to the prediction if they exist + if guessed_schema_json: + result.guessed_schema_json = guessed_schema_json + result.assumptions_bullets = assumptions_bullets + result.schema_info = schema_info + + # Enhance reasoning to show that schema was inferred + original_reasoning = result.reasoning if hasattr(result, 'reasoning') else "" + result.reasoning = ( + f"[Schema Inference] The schema was automatically inferred from the question:\n" + f"{guessed_schema_json}\n\n" + f"Assumptions made:\n{assumptions_bullets}\n\n" + f"[Rule Generation] {original_reasoning}" + ) + + return result + + +def _configure_dspy_model(model: str, api_key: str = "", api_base: str = ""): + """ + Configure the Dspy language model. + + Args: + model (str): The model to use for the Dspy language model. + api_key (str): The API key for the model. Not required by Databricks foundational models. + api_base (str): The API base URL for the model. Not required by Databricks foundational models. + """ + language_model = dspy.LM( + model=model, + model_type="chat", + api_key=api_key, + api_base=api_base, + max_retries=3, + ) + dspy.configure(lm=language_model) + + +def validate_generated_rules(actual: str) -> float: + """ + Validate generated rules with granular scoring for better optimizer feedback. + + Scoring breakdown: + - JSON parsing (40%): Checks if the actual output can be parsed as valid JSON. + - Rules validation (60%): Ensures the rules pass DQX validation checks. + + Args: + actual (str): JSON string of the actual generated rules. + + Returns: + float: A score between 0.0 and 1.0 representing the quality of the generated rules. 
+ """ + total_score = 0.0 + + # Score weights + json_weight = 0.2 + rule_weight = 0.8 + + # Json parsing score (40%) + try: + actual_rules = json.loads(actual) + total_score += json_weight + logger.debug(f"✓ JSON parsing successful (+{json_weight:.1f})") + except json.JSONDecodeError as e: + logger.warning(f"✗ JSON parsing failed: {e}") + logger.debug(f" Raw output: {repr(actual[:200])}") + # Early return if we can't parse JSON at all + return total_score + + # Rules validation score (60%) + validation_status = DQEngineCore.validate_checks(actual_rules) + if not validation_status.has_errors: + total_score += rule_weight + logger.debug(f"✓ Rules validation passed (+{rule_weight:.1f})") + else: + logger.warning(f"✗ Rules validation errors: {validation_status.errors}") + + logger.debug(f"Final score: {total_score:.2f}") + return total_score + + +def get_dspy_compiler( + model: str = "databricks/databricks-meta-llama-3-3-70b-instruct", + api_key: str = "", + api_base: str = "", +) -> dspy.Module: + """ + Get the Dspy compiler configured with an optimizer. + + This function initializes and configures the Dspy compiler with a training set and an optimizer + to validate and optimize the generated data quality rules. + + Args: + model (str): The model to use for the Dspy language model. + api_key (str): The API key for the model. Not required by Databricks foundational models. + api_base (str): The API base URL for the model. Not required by Databricks foundational models. + + Returns: + dspy.Module: An optimized Dspy module for generating data quality rules. + """ + _configure_dspy_model(api_key=api_key, api_base=api_base, model=model) + + # Use standard DSPy approach with improved prompting + dq_model = DQRuleGeneration() + train_set = create_optimizer_training_set() + + # Standard metric for JSON output validation + def json_metric(_example, pred, _trace=None): + if hasattr(pred, 'quality_rules'): + return validate_generated_rules(pred.quality_rules) + return 0.0 + + optimizer = dspy.BootstrapFewShot( + metric=json_metric, + max_bootstrapped_demos=3, + max_labeled_demos=5, + teacher_settings={}, + ) + + optimized_model = optimizer.compile(dq_model, trainset=train_set) + return optimized_model diff --git a/src/databricks/labs/dqx/llm/llm_engine.py b/src/databricks/labs/dqx/llm/llm_engine.py new file mode 100644 index 000000000..b12fe771a --- /dev/null +++ b/src/databricks/labs/dqx/llm/llm_engine.py @@ -0,0 +1,35 @@ +import json +import dspy # type: ignore +from databricks.labs.dqx.llm.llm_core import DQRuleGeneration +from databricks.labs.dqx.llm.llm_utils import _get_required_check_function_info + + +def get_business_rules_with_llm( + user_input: str, dspy_compiler: DQRuleGeneration, schema_info: str = "" +) -> dspy.primitives.prediction.Prediction: + """ + Get DQX rules based on natural language request with optional schema information. + + If schema_info is empty (default) and the dspy_compiler has schema inference enabled, + it will automatically infer the schema from the user_input before generating rules. + + Args: + user_input: Natural language description of data quality requirements. + dspy_compiler: The compiled DQRuleGeneration model. + schema_info: Optional JSON string containing table schema with column names and types. + If empty (default), triggers schema inference if enabled. + If provided, uses that schema directly. 
+ + Returns: + Prediction: A Prediction object containing: + - quality_rules: The generated DQ rules + - reasoning: Explanation of the rules (includes schema inference info if schema was inferred) + - guessed_schema_json: The inferred schema (if schema was inferred) + - assumptions_bullets: Assumptions made about schema (if schema was inferred) + - schema_info: The final schema used (if schema was inferred) + """ + return dspy_compiler( + schema_info=schema_info, + business_description=user_input, + available_functions=json.dumps(_get_required_check_function_info()), + ) diff --git a/src/databricks/labs/dqx/llm/llm_utils.py b/src/databricks/labs/dqx/llm/llm_utils.py new file mode 100644 index 000000000..14a6f37f3 --- /dev/null +++ b/src/databricks/labs/dqx/llm/llm_utils.py @@ -0,0 +1,128 @@ +import logging +import inspect +from collections.abc import Callable +from importlib.resources import files +from pathlib import Path +from typing import Any +import json +import yaml +import dspy # type: ignore +from pyspark.sql import SparkSession +from databricks.labs.dqx.checks_resolver import resolve_check_function +from databricks.labs.dqx.rule import CHECK_FUNC_REGISTRY + +logger = logging.getLogger(__name__) + + +def get_check_function_definition(custom_check_functions: dict[str, Callable] | None = None) -> list[dict[str, str]]: + """ + A utility function to get the definition of all check functions. + This function is primarily used to generate a prompt for the LLM to generate check functions. + + If provided, the function will use the custom check functions to resolve the check function. + If not provided, the function will use only the built-in check functions. + + Args: + custom_check_functions: A dictionary of custom check functions. + + Returns: + list[dict]: A list of dictionaries, each containing the definition of a check function. + """ + function_docs: list[dict[str, str]] = [] + for name, func_type in CHECK_FUNC_REGISTRY.items(): + func = resolve_check_function(name, custom_check_functions, fail_on_missing=False) + if func is None: + logger.warning(f"Check function {name} not found in the registry") + continue + sig = inspect.signature(func) + doc = inspect.getdoc(func) + function_docs.append( + { + "name": name, + "type": func_type, + "doc": doc or "", + "signature": str(sig), + "parameters": str(sig.parameters), + "implementation": inspect.getsource(func), + } + ) + return function_docs + + +def get_column_metadata(table_name: str, spark: SparkSession) -> str: + """ + Get the column metadata for a given table. + + Args: + table_name (str): The name of the table to retrieve metadata for. + spark (SparkSession): The Spark session used to access the table. + + Returns: + str: A JSON string containing the column metadata with columns wrapped in a "columns" key. + """ + df = spark.table(table_name) + columns = [{"name": field.name, "type": field.dataType.simpleString()} for field in df.schema.fields] + schema_info = {"columns": columns} + return json.dumps(schema_info) + + +def load_training_examples() -> list[dict[str, Any]]: + """A function to Load the training_examples.yml file from the llm/resources folder. + + Returns: + list[dict[str, Any]]: Training examples as a list of dictionaries. 
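# Quick sketch of the JSON shape get_column_metadata produces; the temp-view name and
# columns are illustrative only.
from pyspark.sql import SparkSession
from databricks.labs.dqx.llm.llm_utils import get_column_metadata

spark = SparkSession.builder.getOrCreate()
spark.createDataFrame([(1, "a@b.com")], "order_id int, email string").createOrReplaceTempView("orders_demo")
print(get_column_metadata("orders_demo", spark))
# {"columns": [{"name": "order_id", "type": "int"}, {"name": "email", "type": "string"}]}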
+ """ + resource = Path(str(files("databricks.labs.dqx.llm.resources") / "training_examples.yml")) + + training_examples_as_text = resource.read_text(encoding="utf-8") + training_examples = yaml.safe_load(training_examples_as_text) + if not isinstance(training_examples, list): + raise ValueError("YAML file must contain a list at the root level.") + + return training_examples + + +def _get_required_check_function_info() -> list[dict[str, str]]: + """ + Extract only required function information (name and doc). + + Returns: + list[dict[str, str]]: A list of dictionaries containing the name, doc, type, signature, and parameters of each function. + """ + required_function_docs: list[dict[str, str]] = [] + for func in get_check_function_definition(): + required_func_info = { + "check_function_name": func.get("name", ""), + "parameters": func.get("parameters", ""), + } + required_function_docs.append(required_func_info) + return required_function_docs + + +def create_optimizer_training_set() -> list[dspy.Example]: + """ + Get examples for the dspy optimizer. + + Returns: + list[dspy.Example]: A list of dspy.Example objects created from training examples. + """ + training_examples = load_training_examples() + + examples = [] + available_functions = json.dumps(_get_required_check_function_info()) + + for example_data in training_examples: + # Convert schema_info to JSON string format expected by dspy.Example + schema_info_json = json.dumps(example_data["schema_info"]) + + example = dspy.Example( + schema_info=schema_info_json, + business_description=example_data["business_description"], + available_functions=available_functions, + quality_rules=example_data["quality_rules"], + reasoning=example_data["reasoning"], + ).with_inputs("schema_info", "business_description", "available_functions") + + examples.append(example) + + return examples diff --git a/src/databricks/labs/dqx/llm/resources/training_examples.yml b/src/databricks/labs/dqx/llm/resources/training_examples.yml new file mode 100644 index 000000000..ce602d6b2 --- /dev/null +++ b/src/databricks/labs/dqx/llm/resources/training_examples.yml @@ -0,0 +1,93 @@ +- name: "product_code_not_null_or_empty" + schema_info: + columns: + - name: "product_code" + type: "string" + business_description: "Product code must always be present and not empty" + quality_rules: '[{"criticality":"error","check":{"function":"is_not_null_and_not_empty","arguments":{"column":"product_code"}}}]' + reasoning: "Product code is a key identifier and cannot be missing or blank" + +- name: "status_in_allowed_values" + schema_info: + columns: + - name: "status" + type: "integer" + business_description: "Status must be one of the allowed values: 1=Active, 2=Inactive, 3=Pending" + quality_rules: '[{"criticality":"error","check":{"function":"is_in_list","arguments":{"column":"status","allowed":[1,2,3]}}}]' + reasoning: "Restrict status to predefined set of values for data consistency" + +- name: "order_quantity_in_range" + schema_info: + columns: + - name: "quantity" + type: "integer" + business_description: "Order quantity must be between 1 and 1000" + quality_rules: '[{"criticality":"error","check":{"function":"is_in_range","arguments":{"column":"quantity","min_limit":1,"max_limit":1000}}}]' + reasoning: "Quantities outside this range are invalid business values" + +- name: "order_date_valid_range" + schema_info: + columns: + - name: "order_date" + type: "date" + business_description: "Order date must be between January 1, 2025 and December 31, 2025" + quality_rules: 
'[{"criticality":"error","check":{"function":"is_in_range","arguments":{"column":"order_date","min_limit":"2025-01-01","max_limit":"2025-12-31"}}}]' + reasoning: "Orders can only be created within the current business year" + +- name: "unique_customer_email" + schema_info: + columns: + - name: "email" + type: "string" + business_description: "Customer email must be unique across all records" + quality_rules: '[{"criticality":"error","check":{"function":"is_unique","arguments":{"columns":["email"]}}}]' + reasoning: "Emails uniquely identify customers and duplicates must not exist" + +- name: "unique_customer_order" + schema_info: + columns: + - name: "customer_id" + type: "string" + - name: "order_id" + type: "string" + business_description: "Each customer can only have a unique order id" + quality_rules: '[{"criticality":"error","check":{"function":"is_unique","arguments":{"columns":["customer_id","order_id"]}}}]' + reasoning: "The combination of customer and order id should not be duplicated" + +- name: "description_no_pii" + schema_info: + columns: + - name: "description" + type: "string" + business_description: "Free-text description must not contain personally identifiable information (PII)" + quality_rules: '[{"criticality":"error","check":{"function":"does_not_contain_pii","arguments":{"column":"description"}}}]' + reasoning: "To comply with privacy regulations, no PII should be stored in descriptions" + +- name: "customer_id_must_exist_in_reference" + schema_info: + columns: + - name: "customer_id" + type: "string" + business_description: "Each customer_id in orders must exist in the customers table" + quality_rules: '[{"criticality":"error","check":{"function":"foreign_key","arguments":{"columns":["customer_id"],"ref_columns":["id"],"ref_table":"sales.customers"}}}]' + reasoning: "Foreign key relationship ensures data integrity between orders and customers" + +- name: "discount_not_greater_than_price" + schema_info: + columns: + - name: "price" + type: "decimal" + - name: "discount" + type: "decimal" + business_description: "Discount value must never exceed price" + quality_rules: '[{"criticality":"error","check":{"function":"sql_expression","arguments":{"expression":"discount <= price","msg":"Discount is greater than price"}}}]' + reasoning: "A discount higher than price is invalid business logic" + +- name: "event_timestamp_recent" + schema_info: + columns: + - name: "event_time" + type: "timestamp" + business_description: "Event timestamp must be within the last 24 hours" + quality_rules: '[{"criticality":"error","check":{"function":"is_data_fresh","arguments":{"column":"event_time","max_age_minutes":1440}}}]' + reasoning: "Events must reflect recent activity, stale records are invalid" diff --git a/src/databricks/labs/dqx/llm/resources/yaml_checks_examples.yml b/src/databricks/labs/dqx/llm/resources/yaml_checks_examples.yml deleted file mode 100644 index f127b75f2..000000000 --- a/src/databricks/labs/dqx/llm/resources/yaml_checks_examples.yml +++ /dev/null @@ -1,838 +0,0 @@ -- criticality: error - check: - function: is_not_null - arguments: - column: col1 -- criticality: error - check: - function: is_not_empty - arguments: - column: col1 -- criticality: error - check: - function: is_not_null_and_not_empty - arguments: - column: col1 - trim_strings: true -- criticality: error - check: - function: is_in_list - arguments: - column: col2 - allowed: - - 1 - - 2 - - 3 -- criticality: error - check: - function: is_not_null_and_is_in_list - arguments: - column: col2 - allowed: - - 1 - 
- 2 - - 3 -- criticality: error - check: - function: is_not_null_and_not_empty_array - arguments: - column: col4 -- criticality: error - check: - function: is_in_range - arguments: - column: col2 - min_limit: 1 - max_limit: 10 -- criticality: error - check: - function: is_in_range - arguments: - column: col5 - min_limit: 2025-01-01 - max_limit: 2025-02-24 -- criticality: error - check: - function: is_in_range - arguments: - column: col6 - min_limit: 2025-01-01 00:00:00 - max_limit: 2025-02-24 01:00:00 -- criticality: error - check: - function: is_in_range - arguments: - column: col3 - min_limit: col2 - max_limit: col2 * 2 -- criticality: error - check: - function: is_not_in_range - arguments: - column: col2 - min_limit: 11 - max_limit: 20 -- criticality: error - check: - function: is_not_in_range - arguments: - column: col5 - min_limit: 2025-02-25 - max_limit: 2025-02-26 -- criticality: error - check: - function: is_not_in_range - arguments: - column: col6 - min_limit: 2025-02-25 00:00:00 - max_limit: 2025-02-26 01:00:00 -- criticality: error - check: - function: is_not_in_range - arguments: - column: col3 - min_limit: col2 + 10 - max_limit: col2 * 10 -- criticality: error - check: - function: is_equal_to - arguments: - column: col10 - value: 2 -- criticality: error - check: - function: is_equal_to - arguments: - column: col3 - value: col2 -- criticality: error - check: - function: is_not_equal_to - arguments: - column: col1 - value: '''unknown''' -- criticality: error - check: - function: is_not_equal_to - arguments: - column: col5 - value: 2025-02-24 -- criticality: error - check: - function: is_not_equal_to - arguments: - column: col6 - value: 2025-02-24 01:00:00 -- criticality: error - check: - function: is_not_equal_to - arguments: - column: col3 - value: col2 + 5 -- criticality: error - check: - function: is_not_less_than - arguments: - column: col2 - limit: 0 -- criticality: error - check: - function: is_not_less_than - arguments: - column: col5 - limit: 2025-01-01 -- criticality: error - check: - function: is_not_less_than - arguments: - column: col6 - limit: 2025-01-01 01:00:00 -- criticality: error - check: - function: is_not_less_than - arguments: - column: col3 - limit: col2 - 10 -- criticality: error - check: - function: is_not_greater_than - arguments: - column: col2 - limit: 10 -- criticality: error - check: - function: is_not_greater_than - arguments: - column: col5 - limit: 2025-03-01 -- criticality: error - check: - function: is_not_greater_than - arguments: - column: col6 - limit: 2025-03-24 01:00:00 -- criticality: error - check: - function: is_not_greater_than - arguments: - column: col3 - limit: col2 + 10 -- criticality: error - check: - function: is_valid_date - arguments: - column: col5 -- criticality: error - name: col5_is_not_valid_date2 - check: - function: is_valid_date - arguments: - column: col5 - date_format: yyyy-MM-dd -- criticality: error - check: - function: is_valid_timestamp - arguments: - column: col6 - timestamp_format: yyyy-MM-dd HH:mm:ss -- criticality: error - name: col6_is_not_valid_timestamp2 - check: - function: is_valid_timestamp - arguments: - column: col6 -- criticality: error - check: - function: is_not_in_future - arguments: - column: col6 - offset: 86400 -- criticality: error - check: - function: is_not_in_near_future - arguments: - column: col6 - offset: 36400 -- criticality: error - check: - function: is_older_than_n_days - arguments: - column: col5 - days: 10000 -- criticality: error - check: - function: is_older_than_col2_for_n_days - 
arguments: - column1: col5 - column2: col6 - days: 2 -- criticality: error - check: - function: regex_match - arguments: - column: col2 - regex: '[0-9]+' - negate: false -- criticality: error - check: - function: is_valid_ipv4_address - arguments: - column: col2 -- criticality: error - check: - function: is_ipv4_address_in_cidr - arguments: - column: col2 - cidr_block: 192.168.1.0/24 -- criticality: error - check: - function: is_valid_ipv6_address - arguments: - column: col_ipv6 -- criticality: error - check: - function: is_ipv6_address_in_cidr - arguments: - column: col_ipv6 - cidr_block: 2001:0db8:85a3:08d3:0000:0000:0000:0000/64 -- criticality: error - check: - function: sql_expression - arguments: - expression: col3 >= col2 and col3 <= 10 - msg: col3 is less than col2 and col3 is greater than 10 - name: custom_output_name - negate: false -- criticality: error - check: - function: sql_expression - arguments: - expression: col3 >= col2 and col3 <= 10 - msg: col3 is less than col2 and col3 is greater than 10 - columns: - - col2 - - col3 -- criticality: error - check: - function: does_not_contain_pii - arguments: - column: col1 - threshold: 0.7 - language: en -- criticality: error - check: - function: does_not_contain_pii - arguments: - column: col1 - threshold: 0.8 - entities: - - PERSON - - EMAIL_ADDRESS - - PHONE_NUMBER - nlp_engine_config: - nlp_engine_name: spacy - models: - - lang_code: en - model_name: en_core_web_md -- criticality: error - check: - function: is_data_fresh - arguments: - column: col5 - max_age_minutes: 15 - base_timestamp: col6 -- criticality: error - check: - function: is_data_fresh - arguments: - column: col6 - max_age_minutes: 1440 - base_timestamp: 2025-01-02 10:00:00 -- criticality: error - check: - function: is_data_fresh - arguments: - column: col5 - max_age_minutes: 15 -- criticality: error - check: - function: is_not_null - for_each_column: - - col3 - - col5 -- criticality: error - check: - function: is_latitude - arguments: - column: col2 -- criticality: error - check: - function: is_longitude - arguments: - column: col2 -- criticality: error - check: - function: is_geometry - arguments: - column: point_geom -- criticality: error - check: - function: is_geography - arguments: - column: point_geom -- criticality: error - check: - function: is_point - arguments: - column: point_geom -- criticality: error - check: - function: is_linestring - arguments: - column: linestring_geom -- criticality: error - check: - function: is_polygon - arguments: - column: polygon_geom -- criticality: error - check: - function: is_multipoint - arguments: - column: multipoint_geom -- criticality: error - check: - function: is_multilinestring - arguments: - column: multilinestring_geom -- criticality: error - check: - function: is_multipolygon - arguments: - column: multipolygon_geom -- criticality: error - check: - function: is_geometrycollection - arguments: - column: geometrycollection_geom -- criticality: error - check: - function: is_ogc_valid - arguments: - column: point_geom -- criticality: error - check: - function: is_non_empty_geometry - arguments: - column: point_geom -- criticality: error - check: - function: has_dimension - arguments: - column: polygon_geom - dimension: 2 -- criticality: error - check: - function: has_x_coordinate_between - arguments: - column: polygon_geom - min_value: 0.0 - max_value: 10.0 -- criticality: error - check: - function: has_y_coordinate_between - arguments: - column: polygon_geom - min_value: 0.0 - max_value: 10.0 -- criticality: error - 
check: - function: is_not_null - arguments: - column: col8.field1 -- criticality: error - check: - function: is_not_null - arguments: - column: try_element_at(col7, 'key1') -- criticality: error - check: - function: is_not_null - arguments: - column: try_element_at(col4, 1) -- criticality: error - check: - function: is_equal_to - arguments: - column: col8.field1 - value: 1 -- criticality: error - check: - function: is_not_equal_to - arguments: - column: try_element_at(col7, 'key1') - value: col10 -- criticality: error - check: - function: is_not_less_than - arguments: - column: array_min(col4) - limit: 1 -- criticality: error - check: - function: is_not_greater_than - arguments: - column: array_max(col4) - limit: 10 -- criticality: error - check: - function: sql_expression - arguments: - expression: try_element_at(col7, 'key1') >= 10 - msg: col7 element 'key1' is less than 10 - name: col7_element_key1_less_than_10 - negate: false -- criticality: error - check: - function: sql_expression - arguments: - expression: not exists(col4, x -> x > 10) - msg: array col4 has an element greater than 10 - name: col4_all_elements_less_than_10 - negate: false -- criticality: error - check: - function: is_not_null - for_each_column: - - col1 - - col8.field1 - - try_element_at(col7, 'key1') - - try_element_at(col4, 1) -- criticality: error - check: - function: is_unique - arguments: - columns: - - col1 -- criticality: error - name: composite_key_col1_and_col2_is_not_unique - check: - function: is_unique - arguments: - columns: - - col1 - - col2 -- criticality: error - name: composite_key_col1_and_col2_is_not_unique_not_nulls_distinct - check: - function: is_unique - arguments: - columns: - - col1 - - col2 - nulls_distinct: false -- criticality: error - check: - function: is_aggr_not_greater_than - arguments: - column: '*' - aggr_type: count - limit: 10 -- criticality: error - check: - function: is_aggr_not_greater_than - arguments: - column: col2 - aggr_type: count - limit: 10 -- criticality: error - check: - function: is_aggr_not_greater_than - arguments: - column: col2 - aggr_type: count - group_by: - - col3 - limit: 10 -- criticality: error - check: - function: is_aggr_not_less_than - arguments: - column: '*' - aggr_type: count - limit: 1 -- criticality: error - check: - function: is_aggr_not_less_than - arguments: - column: col2 - aggr_type: count - limit: 1 -- criticality: error - check: - function: is_aggr_not_less_than - arguments: - column: col2 - aggr_type: count - group_by: - - col3 - limit: 1 -- criticality: error - check: - function: is_aggr_equal - arguments: - column: '*' - aggr_type: count - limit: 3 -- criticality: error - check: - function: is_aggr_equal - arguments: - column: col2 - aggr_type: avg - limit: 10.5 -- criticality: error - check: - function: is_aggr_equal - arguments: - column: col2 - aggr_type: sum - group_by: - - col3 - limit: 100 -- criticality: error - check: - function: is_aggr_not_equal - arguments: - column: '*' - aggr_type: count - limit: 5 -- criticality: error - check: - function: is_aggr_not_equal - arguments: - column: col2 - aggr_type: avg - limit: 15.2 -- criticality: error - check: - function: is_aggr_not_equal - arguments: - column: col2 - aggr_type: sum - group_by: - - col3 - limit: 200 -- criticality: error - check: - function: foreign_key - arguments: - columns: - - col1 - ref_columns: - - ref_col1 - ref_df_name: ref_df_key -- criticality: error - check: - function: foreign_key - arguments: - columns: - - col1 - ref_columns: - - ref_col1 - ref_table: 
catalog1.schema1.ref_table -- criticality: error - check: - function: foreign_key - arguments: - columns: - - col1 - - col2 - ref_columns: - - ref_col1 - - ref_col2 - ref_df_name: ref_df_key -- criticality: error - check: - function: foreign_key - arguments: - columns: - - col1 - ref_columns: - - ref_col1 - ref_df_name: ref_df_key - negate: true -- criticality: error - check: - function: sql_query - arguments: - query: SELECT col1, col2, SUM(col3) = 0 AS condition FROM {{ input_view }} GROUP - BY col1, col2 - input_placeholder: input_view - merge_columns: - - col1 - - col2 - condition_column: condition - msg: sql query check failed - name: sql_query_violation - negate: false -- criticality: error - check: - function: compare_datasets - arguments: - columns: - - col1 - - col2 - ref_columns: - - ref_col1 - - ref_col2 - ref_df_name: ref_df_key -- criticality: error - check: - function: compare_datasets - arguments: - columns: - - col1 - - col2 - ref_columns: - - ref_col1 - - ref_col2 - ref_table: catalog1.schema1.ref_table - exclude_columns: - - col7 - check_missing_records: true - null_safe_row_matching: true - null_safe_column_value_matching: true -- criticality: error - check: - function: is_data_fresh_per_time_window - arguments: - column: col6 - window_minutes: 1 - min_records_per_window: 1 - lookback_windows: 3 -- criticality: error - check: - function: has_valid_schema - arguments: - expected_schema: id INT, name STRING, age INT -- criticality: error - check: - function: has_valid_schema - arguments: - expected_schema: 'id INT, name STRING, age INT, contact_info STRUCT' - strict: true -- criticality: warn - check: - function: has_valid_schema - arguments: - expected_schema: 'id INT, name STRING, age INT, contact_info STRUCT' - columns: - - id - - name -- criticality: error - check: - function: is_unique - for_each_column: - - - col3 - - col5 - - - col1 -- criticality: error - check: - function: sql_expression - arguments: - expression: col1 NOT LIKE '%foo' - msg: col1 ends with 'foo' -- criticality: error - check: - function: sql_expression - arguments: - expression: col1 <= col2 - msg: col1 is greater than col2 -- criticality: error - name: sensor_reading_exceeded - check: - function: sql_expression - arguments: - expression: MAX(reading_value) OVER (PARTITION BY sensor_id) > 100 - msg: one of the sensor reading is greater than 100 - negate: true -- criticality: error - filter: col2 > 0 - check: - function: is_not_null - arguments: - column: col1 -- criticality: error - filter: col2 > 0 - check: - function: is_unique - arguments: - columns: - - col1 -- criticality: error - check: - function: does_not_contain_pii - arguments: - column: description -- criticality: error - check: - function: does_not_contain_pii - arguments: - column: description - threshold: 0.8 - entities: - - PERSON - - EMAIL_ADDRESS -- criticality: warn - check: - function: is_not_null_and_not_empty - arguments: - column: col3 -- criticality: error - check: - function: is_not_null - for_each_column: - - col1 - - col2 -- criticality: warn - filter: col1 < 3 - check: - function: is_not_null_and_not_empty - arguments: - column: col4 -- criticality: warn - check: - function: is_not_null_and_not_empty - arguments: - column: col5 - user_metadata: - check_category: completeness - responsible_data_steward: someone@email.com -- criticality: warn - check: - function: is_in_list - arguments: - column: col1 - allowed: - - 1 - - 2 -- check: - function: is_not_null - arguments: - column: col7.field1 -- criticality: error - check: 
- function: is_not_null - arguments: - column: try_element_at(col5, 'key1') -- criticality: error - check: - function: is_not_null - arguments: - column: try_element_at(col6, 1) -- criticality: error - check: - function: is_unique - arguments: - columns: - - col1 - - col2 -- criticality: error - check: - function: is_aggr_not_greater_than - arguments: - column: col1 - aggr_type: count - group_by: - - col2 - limit: 10 -- criticality: error - check: - function: is_aggr_not_less_than - arguments: - column: col1 - aggr_type: avg - group_by: - - col2 - limit: 1.2 -- criticality: error - check: - function: is_aggr_equal - arguments: - column: col1 - aggr_type: count - group_by: - - col2 - limit: 5 -- criticality: error - check: - function: is_aggr_not_equal - arguments: - column: col1 - aggr_type: avg - group_by: - - col2 - limit: 10.5 diff --git a/src/databricks/labs/dqx/llm/utils.py b/src/databricks/labs/dqx/llm/utils.py deleted file mode 100644 index ea82224ef..000000000 --- a/src/databricks/labs/dqx/llm/utils.py +++ /dev/null @@ -1,65 +0,0 @@ -import logging -import inspect -from collections.abc import Callable -from importlib.resources import files -from pathlib import Path - -import yaml - -from databricks.labs.dqx.checks_resolver import resolve_check_function -from databricks.labs.dqx.errors import InvalidParameterError -from databricks.labs.dqx.rule import CHECK_FUNC_REGISTRY - -logger = logging.getLogger(__name__) - - -def get_check_function_definition(custom_check_functions: dict[str, Callable] | None = None) -> list[dict[str, str]]: - """ - A utility function to get the definition of all check functions. - This function is primarily used to generate a prompt for the LLM to generate check functions. - - If provided, the function will use the custom check functions to resolve the check function. - If not provided, the function will use only the built-in check functions. - - Args: - custom_check_functions: A dictionary of custom check functions. - - Returns: - list[dict]: A list of dictionaries, each containing the definition of a check function. - """ - function_docs: list[dict[str, str]] = [] - for name, func_type in CHECK_FUNC_REGISTRY.items(): - func = resolve_check_function(name, custom_check_functions, fail_on_missing=False) - if func is None: - logger.warning(f"Check function {name} not found in the registry") - continue - sig = inspect.signature(func) - doc = inspect.getdoc(func) - function_docs.append( - { - "name": name, - "type": func_type, - "doc": doc or "", - "signature": str(sig), - "parameters": str(sig.parameters), - "implementation": inspect.getsource(func), - } - ) - return function_docs - - -def load_yaml_checks_examples() -> str: - """ - Load yaml_checks_examples.yml file from the llm/resources folder. - - Returns: - checks examples as yaml string. 
- """ - resource = Path(str(files("databricks.labs.dqx.llm.resources") / "yaml_checks_examples.yml")) - - yaml_checks_as_text = resource.read_text(encoding="utf-8") - parsed = yaml.safe_load(yaml_checks_as_text) - if not isinstance(parsed, list): - raise InvalidParameterError("YAML file must contain a list at the root level.") - - return yaml_checks_as_text diff --git a/src/databricks/labs/dqx/pii/__init__.py b/src/databricks/labs/dqx/pii/__init__.py index 6fb0376ec..ccd8c97e4 100644 --- a/src/databricks/labs/dqx/pii/__init__.py +++ b/src/databricks/labs/dqx/pii/__init__.py @@ -9,6 +9,6 @@ # Check that PII detection modules are installed if not all(find_spec(spec) for spec in required_specs): raise ImportError( - "PII detection extras not installed; Install additional " - "dependencies by running `pip install databricks-labs-dqx[pii]`" + "PII detection extras not installed." + "Install additional dependencies by running `pip install databricks-labs-dqx[pii]`." ) diff --git a/src/databricks/labs/dqx/profiler/generator.py b/src/databricks/labs/dqx/profiler/generator.py index f0849b60b..b3b7af95d 100644 --- a/src/databricks/labs/dqx/profiler/generator.py +++ b/src/databricks/labs/dqx/profiler/generator.py @@ -1,15 +1,52 @@ import logging +import json +from pyspark.sql import SparkSession +from databricks.sdk import WorkspaceClient from databricks.labs.dqx.base import DQEngineBase from databricks.labs.dqx.engine import DQEngine from databricks.labs.dqx.profiler.common import val_maybe_to_str from databricks.labs.dqx.profiler.profiler import DQProfile from databricks.labs.dqx.telemetry import telemetry_logger +from databricks.labs.dqx.errors import MissingParameterError + +# Conditional imports for LLM functionality +try: + from databricks.labs.dqx.llm.llm_core import get_dspy_compiler + from databricks.labs.dqx.llm.llm_engine import get_business_rules_with_llm + from databricks.labs.dqx.llm.llm_utils import get_column_metadata + + LLM_ENABLED = True +except ImportError: + LLM_ENABLED = False logger = logging.getLogger(__name__) class DQGenerator(DQEngineBase): + def __init__( + self, + workspace_client: WorkspaceClient, + spark: SparkSession | None = None, + model: str = "databricks/databricks-meta-llama-3-3-70b-instruct", + api_key: str = "", + api_base: str = "", + ): + super().__init__(workspace_client=workspace_client) + self.spark = SparkSession.builder.getOrCreate() if spark is None else spark + + # Initialize DSPy compiler during init + if not LLM_ENABLED: + logger.warning("LLM dependencies not installed. DSPy compiler not available.") + self.dspy_compiler = None + else: + try: + self.dspy_compiler = get_dspy_compiler(model=model, api_key=api_key, api_base=api_base) + logger.info(f"DSPy compiler initialized with model: {model}") + except Exception as e: + logger.error(f"Failed to initialize DSPy compiler: {e}") + self.dspy_compiler = None + @telemetry_logger("generator", "generate_dq_rules") def generate_dq_rules(self, profiles: list[DQProfile] | None = None, level: str = "error") -> list[dict]: """ @@ -45,6 +82,49 @@ def generate_dq_rules(self, profiles: list[DQProfile] | None = None, level: str return dq_rules + @telemetry_logger("generator", "generate_dq_rules_with_llm") + def generate_dq_rules_with_llm(self, user_input: str, table_name: str = "") -> list[dict]: + """ + Generates data quality rules using LLM based on natural language input. + + Args: + user_input: Natural language description of data quality requirements. + table_name: Optional fully qualified table name. 
+ If not provided, LLM will be used to guess the table schema. + + Returns: + A list of dictionaries representing the generated data quality rules. + + Raises: + MissingParameterError: If DSPy compiler is not available. + """ + # Check if DSPy compiler is available + if self.dspy_compiler is None: + raise MissingParameterError( + "DSPy compiler not available. Make sure LLM dependencies are installed: " + "pip install 'databricks-labs-dqx[llm]'" + ) + + schema_info = get_column_metadata(table_name, self.spark) if table_name else "" + + logger.info(f"Generating DQ rules with LLM for input: '{user_input}'") + + # Generate rules using pre-initialized LLM compiler + prediction = get_business_rules_with_llm( + user_input=user_input, dspy_compiler=self.dspy_compiler, schema_info=schema_info + ) + + # Validate the generated rules using DQEngine + dq_rules = json.loads(prediction.quality_rules) + status = DQEngine.validate_checks(dq_rules) + if status.has_errors: + logger.warning(f"Generated rules have validation errors: {status.errors}") + else: + logger.info(f"Generated {len(dq_rules)} rules with LLM:{dq_rules}") + logger.info(f"LLM reasoning: {prediction.reasoning}") + + return dq_rules + @staticmethod def dq_generate_is_in(column: str, level: str = "error", **params: dict): """ diff --git a/tests/unit/test_llm_utils.py b/tests/unit/test_llm_utils.py index d9c700fe0..8193b1da2 100644 --- a/tests/unit/test_llm_utils.py +++ b/tests/unit/test_llm_utils.py @@ -1,7 +1,15 @@ import inspect +from unittest.mock import Mock +import json +import dspy # type: ignore import pyspark.sql.functions as F +from pyspark.sql.types import StructField, StringType, IntegerType from databricks.labs.dqx.check_funcs import make_condition, register_rule -from databricks.labs.dqx.llm.utils import get_check_function_definition, load_yaml_checks_examples +from databricks.labs.dqx.llm.llm_utils import ( + get_check_function_definition, + get_column_metadata, + create_optimizer_training_set, +) @register_rule("row") @@ -57,6 +65,58 @@ def test_get_check_function_definition_with_custom_check_functions_missing_speci assert not result -def test_load_yaml_checks_examples(): - yaml_examples = load_yaml_checks_examples() - assert yaml_examples +def test_column_metadata(): + mock_spark = Mock() + mock_df = Mock() + mock_df.schema.fields = [ + StructField("customer_id", StringType(), True), + StructField("first_name", StringType(), True), + StructField("last_name", StringType(), True), + StructField("age", IntegerType(), True), + ] + mock_spark.table.return_value = mock_df + + result = get_column_metadata("test_table", mock_spark) + expected_result = { + "columns": [ + {"name": "customer_id", "type": "string"}, + {"name": "first_name", "type": "string"}, + {"name": "last_name", "type": "string"}, + {"name": "age", "type": "int"}, + ] + } + assert result == json.dumps(expected_result) + + +def test_get_training_examples(): + """Test that get_training_examples returns properly formatted dspy.Example objects.""" + + examples = create_optimizer_training_set() + + # Verify it returns a list + assert isinstance(examples, list) + + # Verify it has at least one example + assert len(examples) >= 1 + + # Verify all items are dspy.Example objects + for example in examples: + assert isinstance(example, dspy.Example) + + # Verify required attributes exist + assert hasattr(example, 'schema_info') + assert hasattr(example, 'business_description') + assert hasattr(example, 'available_functions') + assert hasattr(example, 'quality_rules') + assert 
hasattr(example, 'reasoning') + + schema_info = json.loads(example.schema_info) + assert isinstance(schema_info, dict) + assert "columns" in schema_info + + # Verify available_functions is valid JSON + available_functions = json.loads(example.available_functions) + assert isinstance(available_functions, list) + + # Verify quality_rules is a JSON string of rules + assert isinstance(example.quality_rules, str)
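# A possible companion test (sketch; assumes the llm and pii extras are installed so every
# referenced check function resolves): each bundled training example should itself pass DQX
# validation, so the optimizer never bootstraps from an invalid demo.
from databricks.labs.dqx.engine import DQEngineCore


def test_training_example_rules_pass_dqx_validation():
    for example in create_optimizer_training_set():
        rules = json.loads(example.quality_rules)
        assert not DQEngineCore.validate_checks(rules).has_errors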