Changes from all commits
Commits
21 commits
aedf6ae
LLM assisted rules generation code
souravg-db2 Sep 15, 2025
a081731
added change for tests
souravg-db2 Sep 15, 2025
985d80c
Addressed fmt issues
souravg-db2 Sep 15, 2025
40d0b6c
Merge branch 'main' into feature_370_llm_prompt_prep_3
mwojtyczka Sep 16, 2025
353222c
refactor
mwojtyczka Sep 16, 2025
c2ee51e
Merge branch 'main' into feature_370_llm_prompt_prep_3
mwojtyczka Sep 18, 2025
e7b94ce
Merge branch 'main' into feature_370_llm_prompt_prep_3
mwojtyczka Sep 19, 2025
4fddaef
Merge branch 'main' into feature_370_llm_prompt_prep_3
mwojtyczka Sep 30, 2025
0e37f34
Merge branch 'main' into feature_370_llm_prompt_prep_3
mwojtyczka Oct 2, 2025
015bff3
Merge branch 'main' into feature_370_llm_prompt_prep_3
mwojtyczka Oct 3, 2025
82cfcff
Merge branch 'main' into feature_370_llm_prompt_prep_3
mwojtyczka Oct 7, 2025
cdf231c
added few changes
souravg-db2 Oct 8, 2025
032a18c
Removed removed which was not needed
souravg-db2 Oct 9, 2025
34d9782
Merge branch 'main' into feature_370_llm_prompt_prep_3
souravg-db2 Oct 10, 2025
ed42bc9
Added changes to call llm rules generation from generator
souravg-db2 Oct 10, 2025
4993453
Add schema inference for no-schema provided DQ rule generation
vb-dbrks Oct 10, 2025
a4d73e6
Merge branch 'main' into feature_370_llm_prompt_prep_3
mwojtyczka Oct 16, 2025
0602e16
Merge branch 'main' into feature_370_llm_prompt_prep_3
mwojtyczka Oct 16, 2025
ab89982
Addressed some review comments
souravg-db2 Oct 16, 2025
cea51d7
Changes per review comments
souravg-db2 Oct 17, 2025
17482be
Changes per review comments
souravg-db2 Oct 17, 2025
1 change: 1 addition & 0 deletions Makefile
@@ -7,6 +7,7 @@ clean: docs-clean
.venv/bin/python:
pip install hatch
hatch env create
hatch run pip install ".[llm,pii]"

dev: .venv/bin/python
@hatch run which python
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -55,6 +55,7 @@ pii = [
# This may be required for the larger models due to Databricks connect memory limitations.
# The models cannot be declared as a dependency here because PyPI does not support URL-based dependencies, which would prevent releases.
]
llm = ["dspy~=3.0.3"]

[project.entry-points.databricks]
runtime = "databricks.labs.dqx.workflows_runner:main"
@@ -77,7 +78,7 @@ include = ["src"]
path = "src/databricks/labs/dqx/__about__.py"

[tool.hatch.envs.default]
features = ["pii"]
features = ["pii", "llm"]
dependencies = [
"black~=24.8.0",
"chispa~=0.10.1",
11 changes: 11 additions & 0 deletions src/databricks/labs/dqx/llm/__init__.py
@@ -0,0 +1,11 @@
from importlib.util import find_spec

required_specs = [
"dspy",
]

# Check if required llm packages are installed
if not all(find_spec(spec) for spec in required_specs):
raise ImportError(
"llm extras not installed. Install additional dependencies by running `pip install databricks-labs-dqx[llm]`."
)
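
A quick sketch of the intended behavior when the llm extra is not installed (assumed; the error message comes from the guard above):

try:
    import databricks.labs.dqx.llm  # noqa: F401
except ImportError as err:
    # Without `pip install databricks-labs-dqx[llm]`, importing the subpackage fails fast.
    print(err)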
277 changes: 277 additions & 0 deletions src/databricks/labs/dqx/llm/llm_core.py
@@ -0,0 +1,277 @@
import json
import logging
import dspy # type: ignore
from databricks.labs.dqx.llm.llm_utils import create_optimizer_training_set
from databricks.labs.dqx.engine import DQEngineCore

logger = logging.getLogger(__name__)


class SchemaGuesserSignature(dspy.Signature):
"""
Guess a table schema based on business description.

This class defines the schema for inferring complete table structure from
natural language descriptions.
"""

business_description: str = dspy.InputField(
desc=(
"Natural language summary of the dataset and its use. "
"Including some column hints (e.g., id, amount, status, email, dates)."
)
)
guessed_schema_json: str = dspy.OutputField(
desc=(
"Strict JSON with shape: "
'{"columns":[{"name":"<col>","type":"<spark_type>","example":"<opt>"}]}. '
"Prefer: ids:string, money:decimal(18,2), timestamps:timestamp, dates:date. "
"Return one line JSON with no extra text."
)
)
assumptions_bullets: str = dspy.OutputField(
desc=(
"Concise bullet list (1-6 lines) of assumptions made about columns, types, "
"and examples. Keep each bullet short."
)
)


class SchemaGuesser(dspy.Module):
"""
Guess table schema from business description.

This class provides functionality to infer a complete table schema based on
natural language descriptions of the dataset and its intended use.
"""

def __init__(self):
super().__init__()
self.guess = dspy.ChainOfThought(SchemaGuesserSignature)

def forward(self, business_description: str) -> dspy.primitives.prediction.Prediction:
"""
Guess schema based on business description.

Args:
business_description (str): Natural language description of the dataset and its use case.

Returns:
dspy.primitives.prediction.Prediction: A Prediction object containing the guessed schema
and assumptions made during the inference process.
"""
return self.guess(business_description=business_description)


class RuleSignature(dspy.Signature):
"""
Generate data quality rules with improved output format.

This class defines the schema for generating data quality rules based on
schema information, business descriptions, and available functions.
"""

schema_info: str = dspy.InputField(desc="JSON string of table schema with column names, types, and sample data")
business_description: str = dspy.InputField(desc="Natural language description of data quality requirements")
available_functions: str = dspy.InputField(desc="JSON string of available DQX check functions")
quality_rules: str = dspy.OutputField(
desc=(
"Return a valid JSON array of data quality rules. Use double quotes only. "
"Criticality can be error or warn. "
"Check function name and doc to select the appropriate check function. "
"Format: [{\"criticality\":\"error\",\"check\":{\"function\":\"name\",\"arguments\":{\"column\":\"col\"}}}] "
"Example: [{\"criticality\":\"error\",\"check\":{\"function\":\"is_not_null\",\"arguments\":{\"column\":\"customer_id\"}}}]"
)
)
reasoning: str = dspy.OutputField(desc="Explanation of why these rules were chosen")


class DQRuleGeneration(dspy.Module):
"""
Generate data quality rules with improved JSON output reliability.

This class provides functionality to generate data quality rules based on schema information,
business descriptions, and available functions. It can optionally infer the schema from the
business description if schema_info is not provided or is empty.
"""

def __init__(self, schema_guesser: SchemaGuesser = SchemaGuesser()):
"""
Initialize the DQ rule generation module.

Args:
schema_guesser (SchemaGuesser): Schema guesser instance for inferring schema when needed.
"""
super().__init__()
# Use Predict for reliable output
self.generator = dspy.Predict(RuleSignature)
self.schema_guesser = schema_guesser

def forward(
self, schema_info: str, business_description: str, available_functions: str
) -> dspy.primitives.prediction.Prediction:
"""
Generate data quality rules based on schema information, business descriptions, and available functions.

If schema_info is empty, SchemaGuesser is first used to infer the schema from the
business description.

Args:
schema_info (str): JSON string containing table schema with column names, types, and sample data.
If empty, the schema will be inferred from the business description.
business_description (str): Natural language description of data quality requirements.
available_functions (str): JSON string of available DQX check functions.

Returns:
dspy.primitives.prediction.Prediction: A Prediction object containing the generated data quality rules,
reasoning, and optionally guessed_schema_json and assumptions_bullets if schema was inferred.
"""
# Step 1: Infer schema if needed
guessed_schema_json = None
assumptions_bullets = None

if not schema_info or not schema_info.strip():
logger.info("Inferring schema from business description...")
schema_result = self.schema_guesser(business_description=business_description)
schema_info = schema_result.guessed_schema_json
guessed_schema_json = schema_result.guessed_schema_json
assumptions_bullets = schema_result.assumptions_bullets
logger.info(f"Inferred schema: {schema_info}")
else:
logger.debug(f"Using provided schema: {schema_info}")

# Step 2: Generate rules using the schema (provided or inferred)
result = self.generator(
schema_info=schema_info, business_description=business_description, available_functions=available_functions
)

# Validate and clean the JSON output
if result.quality_rules:
try:
# Try to parse the JSON to ensure it's valid
json.loads(result.quality_rules)
except json.JSONDecodeError as e:
logger.warning(f"Generated invalid JSON: {e}. Raw output: {result.quality_rules}")
# Return a fallback empty array if JSON is invalid
result.quality_rules = "[]"

# Add schema inference results to the prediction if they exist
if guessed_schema_json:
result.guessed_schema_json = guessed_schema_json
Contributor

do we need this field? we assigned guessed_schema_json to schema_info before so they are the same

result.assumptions_bullets = assumptions_bullets
result.schema_info = schema_info

# Enhance reasoning to show that schema was inferred
original_reasoning = result.reasoning if hasattr(result, 'reasoning') else ""
result.reasoning = (
f"[Schema Inference] The schema was automatically inferred from the question:\n"
f"{guessed_schema_json}\n\n"
f"Assumptions made:\n{assumptions_bullets}\n\n"
f"[Rule Generation] {original_reasoning}"
)

return result


def _configure_dspy_model(model: str, api_key: str = "", api_base: str = ""):
"""
Configure the Dspy language model.

Args:
model (str): The model to use for the Dspy language model.
api_key (str): The API key for the model. Not required by Databricks foundational models.
api_base (str): The API base URL for the model. Not required by Databricks foundational models.
"""
language_model = dspy.LM(
model=model,
model_type="chat",
api_key=api_key,
api_base=api_base,
max_retries=3,
)
dspy.configure(lm=language_model)


def validate_generated_rules(actual: str) -> float:
"""
Validate generated rules with granular scoring for better optimizer feedback.

Scoring breakdown:
- JSON parsing (20%): Checks if the actual output can be parsed as valid JSON.
- Rules validation (80%): Ensures the rules pass DQX validation checks.

Args:
actual (str): JSON string of the actual generated rules.

Returns:
float: A score between 0.0 and 1.0 representing the quality of the generated rules.
"""
total_score = 0.0

# Score weights
json_weight = 0.2
rule_weight = 0.8

# JSON parsing score (20%)
try:
actual_rules = json.loads(actual)
total_score += json_weight
logger.debug(f"✓ JSON parsing successful (+{json_weight:.1f})")
except json.JSONDecodeError as e:
logger.warning(f"✗ JSON parsing failed: {e}")
logger.debug(f" Raw output: {repr(actual[:200])}")
# Early return if we can't parse JSON at all
return total_score

# Rules validation score (80%)
validation_status = DQEngineCore.validate_checks(actual_rules)
if not validation_status.has_errors:
total_score += rule_weight
logger.debug(f"✓ Rules validation passed (+{rule_weight:.1f})")
else:
logger.warning(f"✗ Rules validation errors: {validation_status.errors}")

logger.debug(f"Final score: {total_score:.2f}")
return total_score


def get_dspy_compiler(
model: str = "databricks/databricks-meta-llama-3-3-70b-instruct",
api_key: str = "",
api_base: str = "",
) -> dspy.Module:
"""
Get the Dspy compiler configured with an optimizer.

This function initializes and configures the Dspy compiler with a training set and an optimizer
to validate and optimize the generated data quality rules.

Args:
model (str): The model to use for the Dspy language model.
api_key (str): The API key for the model. Not required by Databricks foundational models.
api_base (str): The API base URL for the model. Not required by Databricks foundational models.

Returns:
dspy.Module: An optimized Dspy module for generating data quality rules.
"""
_configure_dspy_model(api_key=api_key, api_base=api_base, model=model)

# Use standard DSPy approach with improved prompting
dq_model = DQRuleGeneration()
train_set = create_optimizer_training_set()

# Standard metric for JSON output validation
def json_metric(_example, pred, _trace=None):
if hasattr(pred, 'quality_rules'):
Contributor

Suggested change
if hasattr(pred, 'quality_rules'):
if pred.quality_rules:

don't use reflection if you can access the field directly

return validate_generated_rules(pred.quality_rules)
Contributor

it would be great to have a few passes if the score is 0. A simple loop where we try to generate rules with the validation errors as input would be enough to start with, without going into teacher-student patterns.

Author

DQEngine validates the list[dict] checks. We need to score both valid JSON and valid rules at compile time. I have changed the score, though, to give higher weight to valid rules.
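
A minimal sketch of such a feedback loop, reusing validate_generated_rules from this PR (the helper name and the idea of appending errors to the business description are assumptions, not part of this change):

from databricks.labs.dqx.llm.llm_core import validate_generated_rules

def generate_with_retries(compiler, schema_info, business_description, available_functions, max_passes=3):
    pred = None
    feedback = ""
    for _ in range(max_passes):
        pred = compiler(
            schema_info=schema_info,
            business_description=business_description + feedback,
            available_functions=available_functions,
        )
        if validate_generated_rules(pred.quality_rules) >= 1.0:
            return pred
        # Feed the failing output back as extra context for the next pass.
        feedback = (
            f"\nPrevious attempt was invalid: {pred.quality_rules}. "
            "Return only a valid JSON array of DQX rules."
        )
    return pred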

return 0.0

optimizer = dspy.BootstrapFewShot(
metric=json_metric,
max_bootstrapped_demos=3,
max_labeled_demos=5,
teacher_settings={},
)

optimized_model = optimizer.compile(dq_model, trainset=train_set)
return optimized_model
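For reference, a short sketch of how validate_generated_rules is expected to score outputs under the 0.2/0.8 weights above (illustrative only; actual rule scores depend on DQX validation):

from databricks.labs.dqx.llm.llm_core import validate_generated_rules

valid = '[{"criticality":"error","check":{"function":"is_not_null","arguments":{"column":"customer_id"}}}]'
print(validate_generated_rules(valid))              # expected 1.0: valid JSON (0.2) plus rules passing DQX validation (0.8)
print(validate_generated_rules("not json"))         # 0.0: JSON parsing fails, early return
print(validate_generated_rules('[{"check": {}}]'))  # likely 0.2: parses as JSON but should fail DQX rule validation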
35 changes: 35 additions & 0 deletions src/databricks/labs/dqx/llm/llm_engine.py
@@ -0,0 +1,35 @@
import json
import dspy # type: ignore
from databricks.labs.dqx.llm.llm_core import DQRuleGeneration
from databricks.labs.dqx.llm.llm_utils import _get_required_check_function_info


def get_business_rules_with_llm(
user_input: str, dspy_compiler: DQRuleGeneration, schema_info: str = ""
) -> dspy.primitives.prediction.Prediction:
"""
Get DQX rules based on natural language request with optional schema information.

If schema_info is empty (default), the compiler will automatically infer the schema
from the user_input before generating rules.

Args:
user_input: Natural language description of data quality requirements.
dspy_compiler: The compiled DQRuleGeneration model.
schema_info: Optional JSON string containing table schema with column names and types.
If empty (default), schema inference is triggered.
If provided, uses that schema directly.

Returns:
Prediction: A Prediction object containing:
- quality_rules: The generated DQ rules
- reasoning: Explanation of the rules (includes schema inference info if schema was inferred)
- guessed_schema_json: The inferred schema (if schema was inferred)
- assumptions_bullets: Assumptions made about schema (if schema was inferred)
- schema_info: The final schema used (if schema was inferred)
"""
return dspy_compiler(
schema_info=schema_info,
business_description=user_input,
available_functions=json.dumps(_get_required_check_function_info()),
)
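
A minimal end-to-end usage sketch of the new module (the example description is made up; the default model comes from get_dspy_compiler):

from databricks.labs.dqx.llm.llm_core import get_dspy_compiler
from databricks.labs.dqx.llm.llm_engine import get_business_rules_with_llm

compiler = get_dspy_compiler()  # defaults to databricks/databricks-meta-llama-3-3-70b-instruct
prediction = get_business_rules_with_llm(
    user_input="Customer orders table; customer_id and order_id must not be null, amount must be positive",
    dspy_compiler=compiler,  # schema_info omitted, so the schema is inferred from the description
)
print(prediction.quality_rules)  # JSON array of DQX rules
print(prediction.reasoning)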