diff --git a/.github/workflows/integration-mlflow-tests.yml b/.github/workflows/integration-mlflow-tests.yml new file mode 100644 index 0000000000..992252aeee --- /dev/null +++ b/.github/workflows/integration-mlflow-tests.yml @@ -0,0 +1,125 @@ +name: MLflow Prompts Integration Tests + +run-name: Run the integration test suite with MLflow Prompt Registry provider + +on: + push: + branches: + - main + - 'release-[0-9]+.[0-9]+.x' + pull_request: + branches: + - main + - 'release-[0-9]+.[0-9]+.x' + paths: + - 'src/llama_stack/providers/remote/prompts/mlflow/**' + - 'tests/integration/providers/remote/prompts/mlflow/**' + - 'tests/unit/providers/remote/prompts/mlflow/**' + - 'uv.lock' + - 'pyproject.toml' + - 'requirements.txt' + - '.github/workflows/integration-mlflow-tests.yml' # This workflow + schedule: + - cron: '0 0 * * *' # Daily at 12 AM UTC + +concurrency: + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} + cancel-in-progress: true + +jobs: + test-mlflow: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }} + fail-fast: false + + steps: + - name: Checkout repository + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + + - name: Install dependencies + uses: ./.github/actions/setup-runner + with: + python-version: ${{ matrix.python-version }} + + - name: Setup MLflow Server + run: | + docker run --rm -d --pull always \ + --name mlflow \ + -p 5555:5555 \ + ghcr.io/mlflow/mlflow:latest \ + mlflow server \ + --host 0.0.0.0 \ + --port 5555 \ + --backend-store-uri sqlite:///mlflow.db \ + --default-artifact-root ./mlruns + + - name: Wait for MLflow to be ready + run: | + echo "Waiting for MLflow to be ready..." + for i in {1..60}; do + if curl -s http://localhost:5555/health | grep -q '"status": "OK"'; then + echo "MLflow is ready!" + exit 0 + fi + echo "Not ready yet... ($i/60)" + sleep 2 + done + echo "MLflow failed to start" + docker logs mlflow + exit 1 + + - name: Verify MLflow API + run: | + echo "Testing MLflow API..." + curl -X GET http://localhost:5555/api/2.0/mlflow/experiments/list + echo "" + echo "MLflow API is responding!" 
+ + - name: Build Llama Stack + run: | + uv run --no-sync llama stack list-deps ci-tests | xargs -L1 uv pip install + + - name: Install MLflow Python client + run: | + uv pip install 'mlflow>=3.4.0' + + - name: Check Storage and Memory Available Before Tests + if: ${{ always() }} + run: | + free -h + df -h + + - name: Run MLflow Integration Tests + env: + MLFLOW_TRACKING_URI: http://localhost:5555 + run: | + uv run --no-sync \ + pytest -sv \ + tests/integration/providers/remote/prompts/mlflow/ + + - name: Check Storage and Memory Available After Tests + if: ${{ always() }} + run: | + free -h + df -h + + - name: Write MLflow logs to file + if: ${{ always() }} + run: | + docker logs mlflow > mlflow.log 2>&1 || true + + - name: Upload all logs to artifacts + if: ${{ always() }} + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 + with: + name: mlflow-logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.python-version }} + path: | + *.log + retention-days: 1 + + - name: Stop MLflow container + if: ${{ always() }} + run: | + docker stop mlflow || true diff --git a/docs/docs/providers/prompts/index.mdx b/docs/docs/providers/prompts/index.mdx new file mode 100644 index 0000000000..aba0f9c63b --- /dev/null +++ b/docs/docs/providers/prompts/index.mdx @@ -0,0 +1,92 @@ +--- +sidebar_label: Prompts +title: Prompts +--- + +# Prompts + +## Overview + +This section contains documentation for all available providers for the **prompts** API. + +The Prompts API enables centralized management of prompt templates with versioning, variable handling, and team collaboration capabilities. + +## Available Providers + +### Inline Providers + +Inline providers run in the same process as the Llama Stack server and require no external dependencies: + +- **[inline::reference](inline_reference.mdx)** - Reference implementation using KVStore backend (SQLite, PostgreSQL, etc.) 
+ - Zero external dependencies + - Supports local SQLite or PostgreSQL storage + - Full CRUD operations including deletion + - Ideal for local development and single-server deployments + +### Remote Providers + +Remote providers connect to external services for centralized prompt management: + +- **[remote::mlflow](remote_mlflow.mdx)** - MLflow Prompt Registry integration (requires MLflow 3.4+) + - Centralized prompt management across teams + - Built-in versioning and audit trail + - Supports authentication (per-request, config, or environment variables) + - Integrates with Databricks and enterprise MLflow deployments + - Ideal for team collaboration and production environments + +## Choosing a Provider + +### Use `inline::reference` when: +- Developing locally or deploying to a single server +- You want zero external dependencies +- SQLite or PostgreSQL storage is sufficient +- You need full CRUD operations (including deletion) +- You prefer simple configuration + +### Use `remote::mlflow` when: +- Working in a team environment with multiple users +- You need centralized prompt management +- Integration with existing MLflow infrastructure +- You need authentication and multi-tenant support +- Advanced versioning and audit trail capabilities are required + +## Quick Start Examples + +### Using inline::reference + +```yaml +prompts: + - provider_id: local-prompts + provider_type: inline::reference + config: + run_config: + storage: + stores: + prompts: + type: sqlite + db_path: ./prompts.db +``` + +### Using remote::mlflow + +```yaml +prompts: + - provider_id: mlflow-prompts + provider_type: remote::mlflow + config: + mlflow_tracking_uri: http://localhost:5555 + experiment_name: llama-stack-prompts + auth_credential: ${env.MLFLOW_TRACKING_TOKEN} +``` + +## Common Features + +All prompt providers support: +- Create and store prompts with version control +- Retrieve prompts by ID and version +- Update prompts (creates new versions) +- List all prompts or versions of a specific prompt +- Set default version for a prompt +- Automatic variable extraction from `{{ variable }}` templates + +For detailed documentation on each provider, see the individual provider pages linked above. diff --git a/docs/docs/providers/prompts/inline_reference.mdx b/docs/docs/providers/prompts/inline_reference.mdx new file mode 100644 index 0000000000..9cbc59bd93 --- /dev/null +++ b/docs/docs/providers/prompts/inline_reference.mdx @@ -0,0 +1,496 @@ +--- +description: | + Reference implementation of the Prompts API using KVStore backend (SQLite, PostgreSQL, etc.) + for centralized prompt management with versioning support. This is the default provider for + prompts that works without external dependencies. 
+ + ## Features + The Reference Prompts Provider supports: + - Create and store prompts with automatic versioning + - Retrieve prompts by ID and version + - Update prompts (creates new immutable versions) + - Delete prompts and their versions + - List all prompts or all versions of a specific prompt + - Set default version for a prompt + - Automatic variable extraction from templates + - Storage in SQLite, PostgreSQL, or other KVStore backends + + ## Key Capabilities + - **Zero Dependencies**: No external services required, runs in-process + - **Flexible Storage**: Supports SQLite (default), PostgreSQL, and other KVStore backends + - **Version Control**: Immutable versioning ensures prompt history is preserved + - **Default Version Management**: Easily switch between prompt versions + - **Variable Auto-Extraction**: Automatically detects `{{ variable }}` placeholders + - **Full CRUD Support**: Unlike remote providers, supports deletion of prompts + + ## Usage + + To use Reference Prompts Provider in your Llama Stack project: + + 1. Configure your Llama Stack project with the inline::reference provider + 2. Optionally configure storage backend (defaults to SQLite) + 3. Start creating and managing prompts + + ## Quick Start + + ### 1. Configure Llama Stack + + **Basic configuration with SQLite** (default): + + ```yaml + prompts: + - provider_id: reference-prompts + provider_type: inline::reference + config: + run_config: + storage: + stores: + prompts: + type: sqlite + db_path: ./prompts.db + ``` + + **With PostgreSQL**: + + ```yaml + prompts: + - provider_id: postgres-prompts + provider_type: inline::reference + config: + run_config: + storage: + stores: + prompts: + type: postgres + url: postgresql://user:pass@localhost/llama_stack + ``` + + ### 2. Use the Prompts API + + ```python + from llama_stack_client import LlamaStackClient + + client = LlamaStackClient(base_url="http://localhost:5000") + + # Create a prompt + prompt = client.prompts.create( + prompt="Summarize the following text in {{ num_sentences }} sentences:\n\n{{ text }}", + variables=["num_sentences", "text"] + ) + print(f"Created prompt: {prompt.prompt_id} (v{prompt.version})") + + # Retrieve prompt + retrieved = client.prompts.get(prompt_id=prompt.prompt_id) + print(f"Retrieved: {retrieved.prompt}") + + # Update prompt (creates version 2) + updated = client.prompts.update( + prompt_id=prompt.prompt_id, + prompt="Summarize in exactly {{ num_sentences }} sentences:\n\n{{ text }}", + version=1, + set_as_default=True + ) + print(f"Updated to version: {updated.version}") + + # List all prompts + prompts = client.prompts.list() + print(f"Found {len(prompts.data)} prompts") + + # Delete prompt + client.prompts.delete(prompt_id=prompt.prompt_id) + ``` + +sidebar_label: Inline - Reference +title: inline::reference +--- + +# inline::reference + +## Description + +Reference implementation of the Prompts API using KVStore backend (SQLite, PostgreSQL, etc.) +for centralized prompt management with versioning support. This is the default provider for +prompts that works without external dependencies. 
+ +## Features +The Reference Prompts Provider supports: +- Create and store prompts with automatic versioning +- Retrieve prompts by ID and version +- Update prompts (creates new immutable versions) +- Delete prompts and their versions +- List all prompts or all versions of a specific prompt +- Set default version for a prompt +- Automatic variable extraction from templates +- Storage in SQLite, PostgreSQL, or other KVStore backends + +## Key Capabilities +- **Zero Dependencies**: No external services required, runs in-process +- **Flexible Storage**: Supports SQLite (default), PostgreSQL, and other KVStore backends +- **Version Control**: Immutable versioning ensures prompt history is preserved +- **Default Version Management**: Easily switch between prompt versions +- **Variable Auto-Extraction**: Automatically detects `{{ variable }}` placeholders +- **Full CRUD Support**: Unlike remote providers, supports deletion of prompts + +## Configuration Examples + +### SQLite (Local Development) + +For local development with filesystem storage: + +```yaml +prompts: + - provider_id: local-prompts + provider_type: inline::reference + config: + run_config: + storage: + stores: + prompts: + type: sqlite + db_path: ./prompts.db +``` + +### PostgreSQL (Production) + +For production with PostgreSQL: + +```yaml +prompts: + - provider_id: prod-prompts + provider_type: inline::reference + config: + run_config: + storage: + stores: + prompts: + type: postgres + url: ${env.DATABASE_URL} +``` + +### With Explicit Backend Configuration + +```yaml +prompts: + - provider_id: reference-prompts + provider_type: inline::reference + config: + run_config: + storage: + backends: + kv_default: + type: sqlite + db_path: ./data/prompts.db + stores: + prompts: + backend: kv_default + namespace: prompts +``` + +## API Reference + +### Create Prompt + +Creates a new prompt (version 1): + +```python +prompt = client.prompts.create( + prompt="You are a {{ role }} assistant. {{ instruction }}", + variables=["role", "instruction"] # Optional - auto-extracted if omitted +) +``` + +**Auto-extraction**: If `variables` is not provided, the provider automatically extracts variables from `{{ variable }}` placeholders. + +### Retrieve Prompt + +Get a prompt by ID (retrieves default version): + +```python +prompt = client.prompts.get(prompt_id="pmpt_abc123...") +``` + +Get a specific version: + +```python +prompt = client.prompts.get(prompt_id="pmpt_abc123...", version=2) +``` + +### Update Prompt + +Creates a new version of an existing prompt: + +```python +updated = client.prompts.update( + prompt_id="pmpt_abc123...", + prompt="Updated template with {{ variable }}", + version=1, # Must be the latest version + set_as_default=True # Make this the new default +) +``` + +**Important**: You must provide the current latest version number. The update creates a new version (e.g., version 2). + +### Delete Prompt + +Delete a prompt and all its versions: + +```python +client.prompts.delete(prompt_id="pmpt_abc123...") +``` + +**Note**: This operation is permanent and deletes all versions of the prompt. 
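+
+Because deletion is irreversible, a cautious pattern is to snapshot all versions before deleting. This is only an illustrative sketch built from the calls documented above (it assumes the `client` and `prompt_id` from the earlier examples):
+
+```python
+import json
+
+# Snapshot every version locally before the irreversible delete.
+versions = client.prompts.list_versions(prompt_id=prompt_id)
+backup = [
+    {"version": p.version, "prompt": p.prompt, "variables": p.variables}
+    for p in versions.data
+]
+with open(f"{prompt_id}.backup.json", "w") as f:
+    json.dump(backup, f, indent=2)
+
+client.prompts.delete(prompt_id=prompt_id)
+```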
+ +### List Prompts + +List all prompts (returns default versions only): + +```python +response = client.prompts.list() +for prompt in response.data: + print(f"{prompt.prompt_id}: v{prompt.version} (default)") +``` + +### List Prompt Versions + +List all versions of a specific prompt: + +```python +response = client.prompts.list_versions(prompt_id="pmpt_abc123...") +for prompt in response.data: + default = " (default)" if prompt.is_default else "" + print(f"Version {prompt.version}{default}") +``` + +### Set Default Version + +Change which version is the default: + +```python +client.prompts.set_default_version( + prompt_id="pmpt_abc123...", + version=2 +) +``` + +## Version Management + +The Reference Prompts Provider implements immutable versioning: + +1. **Create**: Creates version 1 +2. **Update**: Creates a new version (2, 3, 4, ...) +3. **Default**: One version is marked as default +4. **History**: All versions are preserved and retrievable +5. **Delete**: Can delete all versions at once + +``` +pmpt_abc123 +├── Version 1 (Original) +├── Version 2 (Updated) +└── Version 3 (Latest, Default) <- Current default version +``` + +## Storage Backends + +The reference provider uses Llama Stack's KVStore abstraction, which supports multiple backends: + +### SQLite (Default) + +Best for: +- Local development +- Single-server deployments +- Embedded applications +- Testing + +Limitations: +- Not suitable for high-concurrency scenarios +- No built-in replication + +### PostgreSQL + +Best for: +- Production deployments +- Multi-server setups +- High availability requirements +- Team collaboration + +Advantages: +- Supports concurrent access +- Built-in replication and backups +- Scalable and robust + +## Best Practices + +### 1. Choose Appropriate Storage + +**Development**: +```yaml +# Use SQLite for local development +storage: + stores: + prompts: + type: sqlite + db_path: ./dev-prompts.db +``` + +**Production**: +```yaml +# Use PostgreSQL for production +storage: + stores: + prompts: + type: postgres + url: ${env.DATABASE_URL} +``` + +### 2. Backup Your Data + +For SQLite: +```bash +# Backup SQLite database +cp prompts.db prompts.db.backup +``` + +For PostgreSQL: +```bash +# Backup PostgreSQL database +pg_dump llama_stack > backup.sql +``` + +### 3. Version Management + +- Always retrieve latest version before updating +- Use `set_as_default=True` when updating to make new version active +- Keep version history for audit trail +- Use deletion sparingly (consider archiving instead) + +### 4. Auto-Extract Variables + +Let the provider auto-extract variables to avoid validation errors: + +```python +# Recommended +prompt = client.prompts.create( + prompt="Summarize {{ text }} in {{ format }}" +) +``` + +### 5. Use Meaningful Templates + +Include context in your templates: + +```python +# Good +prompt = """You are a {{ role }} assistant specialized in {{ domain }}. + +Task: {{ task }} + +Output format: {{ format }}""" + +# Less clear +prompt = "Do {{ task }} as {{ role }}" +``` + +## Troubleshooting + +### Database Connection Errors + +**Error**: Failed to connect to database + +**Solutions**: +1. Verify database URL is correct +2. Ensure database server is running (for PostgreSQL) +3. Check file permissions (for SQLite) +4. Verify network connectivity (for remote databases) + +### Version Mismatch Error + +**Error**: `Version X is not the latest version. 
Use latest version Y to update.` + +**Cause**: Attempting to update an outdated version + +**Solution**: Always use the latest version number when updating: +```python +# Get latest version +versions = client.prompts.list_versions(prompt_id) +latest_version = max(v.version for v in versions.data) + +# Use latest version for update +client.prompts.update(prompt_id=prompt_id, version=latest_version, ...) +``` + +### Variable Validation Error + +**Error**: `Template contains undeclared variables: ['var2']` + +**Cause**: Template has `{{ var2 }}` but `variables` list doesn't include it + +**Solution**: Either add missing variable or let the provider auto-extract: +```python +# Option 1: Add missing variable +client.prompts.create( + prompt="Template with {{ var1 }} and {{ var2 }}", + variables=["var1", "var2"] +) + +# Option 2: Let provider auto-extract (recommended) +client.prompts.create( + prompt="Template with {{ var1 }} and {{ var2 }}" +) +``` + +### Prompt Not Found + +**Error**: `Prompt pmpt_abc123... not found` + +**Possible causes**: +1. Prompt ID is incorrect +2. Prompt was deleted +3. Wrong database or storage backend + +**Solution**: Verify prompt exists using `list()` method + +## Migration Guide + +### Migrating from Core Implementation + +If you're upgrading from an older Llama Stack version where prompts were in `core/prompts`: + +**Old code** (still works): +```python +from llama_stack.core.prompts import PromptServiceConfig, PromptServiceImpl +``` + +**New code** (recommended): +```python +from llama_stack.providers.inline.prompts.reference import ReferencePromptsConfig, PromptServiceImpl +``` + +**Note**: Backward compatibility is maintained. Old imports still work. + +### Data Migration + +No data migration needed when upgrading: +- Same KVStore backend is used +- Existing prompts remain accessible +- Configuration structure is compatible + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `run_config` | `StackRunConfig` | Yes | | Stack run configuration containing storage configuration for KVStore | + +## Sample Configuration + +```yaml +run_config: + storage: + backends: + kv_default: + type: sqlite + db_path: ./prompts.db + stores: + prompts: + backend: kv_default + namespace: prompts +``` diff --git a/docs/docs/providers/prompts/remote_mlflow.mdx b/docs/docs/providers/prompts/remote_mlflow.mdx new file mode 100644 index 0000000000..aff2feb050 --- /dev/null +++ b/docs/docs/providers/prompts/remote_mlflow.mdx @@ -0,0 +1,751 @@ +--- +description: | + [MLflow](https://mlflow.org/) is a remote provider for centralized prompt management and versioning + using MLflow's Prompt Registry (available in MLflow 3.4+). It allows you to store, version, and manage + prompts in a centralized MLflow server, enabling team collaboration and prompt lifecycle management. + + See [MLflow's documentation](https://mlflow.org/docs/latest/prompts.html) for more details about MLflow Prompt Registry. + +sidebar_label: Remote - MLflow +title: remote::mlflow +--- + +# remote::mlflow + +## Description + +[MLflow](https://mlflow.org/) is a remote provider for centralized prompt management and versioning +using MLflow's Prompt Registry (available in MLflow 3.4+). It allows you to store, version, and manage +prompts in a centralized MLflow server, enabling team collaboration and prompt lifecycle management. 
+ +## Features +MLflow Prompts Provider supports: +- Create and store prompts with automatic versioning +- Retrieve prompts by ID and version +- Update prompts (creates new immutable versions) +- List all prompts or all versions of a specific prompt +- Set default version for a prompt +- Automatic variable extraction from templates +- Metadata storage and retrieval +- Centralized prompt management across teams + +## Key Capabilities +- **Version Control**: Immutable versioning ensures prompt history is preserved +- **Default Version Management**: Easily switch between prompt versions +- **Variable Auto-Extraction**: Automatically detects `{{ variable }}` placeholders +- **Metadata Tags**: Stores Llama Stack metadata for seamless integration +- **Team Collaboration**: Centralized MLflow server enables multi-user access + +## Usage + +To use MLflow Prompts Provider in your Llama Stack project: + +1. Install MLflow 3.4 or later +2. Start an MLflow server (local or remote) +3. Configure your Llama Stack project to use the MLflow provider +4. Start creating and managing prompts + +## Installation + +Install MLflow using pip or uv: + +```bash +pip install 'mlflow>=3.4.0' +# or +uv pip install 'mlflow>=3.4.0' +``` + +## Quick Start + +### 1. Start MLflow Server + +**Local server** (for development): +```bash +mlflow server --host 127.0.0.1 --port 5555 +``` + +**Remote server** (for production): +```bash +mlflow server --host 0.0.0.0 --port 5000 --backend-store-uri postgresql://user:pass@host/db +``` + +### 2. Configure Llama Stack + +Add to your Llama Stack configuration: + +```yaml +prompts: + - provider_id: mlflow-prompts + provider_type: remote::mlflow + config: + mlflow_tracking_uri: http://localhost:5555 + experiment_name: llama-stack-prompts +``` + +### 3. 
Use the Prompts API + +```python +from llama_stack_client import LlamaStackClient + +client = LlamaStackClient(base_url="http://localhost:5000") + +# Create a prompt +prompt = client.prompts.create( + prompt="Summarize the following text in {{ num_sentences }} sentences:\n\n{{ text }}", + variables=["num_sentences", "text"] +) +print(f"Created prompt: {prompt.prompt_id} (v{prompt.version})") + +# Retrieve prompt +retrieved = client.prompts.get(prompt_id=prompt.prompt_id) +print(f"Retrieved: {retrieved.prompt}") + +# Update prompt (creates version 2) +updated = client.prompts.update( + prompt_id=prompt.prompt_id, + prompt="Summarize in exactly {{ num_sentences }} sentences:\n\n{{ text }}", + version=1, + set_as_default=True +) +print(f"Updated to version: {updated.version}") + +# List all prompts +prompts = client.prompts.list() +print(f"Found {len(prompts.data)} prompts") +``` + +## Configuration Examples + +### Local Development + +For local development with filesystem storage: + +```yaml +prompts: + - provider_id: mlflow-local + provider_type: remote::mlflow + config: + mlflow_tracking_uri: http://localhost:5555 + experiment_name: dev-prompts + timeout_seconds: 30 +``` + +### Remote MLflow Server + +For production with a remote MLflow server: + +```yaml +prompts: + - provider_id: mlflow-production + provider_type: remote::mlflow + config: + mlflow_tracking_uri: ${env.MLFLOW_TRACKING_URI} + experiment_name: production-prompts + timeout_seconds: 60 +``` + +### Advanced Configuration + +With custom settings: + +```yaml +prompts: + - provider_id: mlflow-custom + provider_type: remote::mlflow + config: + mlflow_tracking_uri: https://mlflow.example.com + experiment_name: team-prompts + timeout_seconds: 45 +``` + +## Authentication + +The MLflow provider supports three authentication methods with the following precedence (highest to lowest): + +1. **Per-Request Provider Data** (via headers) +2. **Configuration Auth Credential** (in config file) +3. 
**Environment Variables** (MLflow defaults) + +### Method 1: Per-Request Provider Data (Recommended for Multi-Tenant) + +For multi-tenant deployments where each user has their own credentials: + +**Configuration**: +```yaml +prompts: + - provider_id: mlflow-prompts + provider_type: remote::mlflow + config: + mlflow_tracking_uri: http://mlflow.company.com + experiment_name: production-prompts + # No auth_credential - use per-request tokens +``` + +**Client Usage**: +```python +from llama_stack_client import LlamaStackClient + +client = LlamaStackClient(base_url="http://localhost:5000") + +# User 1 with their own token +prompts_user1 = client.prompts.list( + extra_headers={ + "x-llamastack-provider-data": '{"mlflow_api_token": "user1-token"}' + } +) + +# User 2 with their own token +prompts_user2 = client.prompts.list( + extra_headers={ + "x-llamastack-provider-data": '{"mlflow_api_token": "user2-token"}' + } +) +``` + +**Benefits**: +- Per-user authentication and authorization +- No shared credentials +- Ideal for SaaS deployments +- Supports user-specific MLflow experiments + +### Method 2: Configuration Auth Credential (Server-Level) + +For server-level authentication where all requests use the same credentials: + +**Using Environment Variable** (recommended): +```yaml +prompts: + - provider_id: mlflow-prompts + provider_type: remote::mlflow + config: + mlflow_tracking_uri: http://mlflow.company.com + experiment_name: production-prompts + auth_credential: ${env.MLFLOW_TRACKING_TOKEN} +``` + +**Using Direct Value** (not recommended for production): +```yaml +prompts: + - provider_id: mlflow-prompts + provider_type: remote::mlflow + config: + mlflow_tracking_uri: http://mlflow.company.com + experiment_name: production-prompts + auth_credential: "mlflow-server-token" +``` + +**Client Usage**: +```python +# No extra headers needed - server handles authentication +client = LlamaStackClient(base_url="http://localhost:5000") +prompts = client.prompts.list() +``` + +**Benefits**: +- Simple configuration +- Single point of credential management +- Good for single-tenant deployments + +### Method 3: Environment Variables (MLflow Default) + +MLflow reads standard environment variables automatically: + +**Set before starting Llama Stack**: +```bash +export MLFLOW_TRACKING_TOKEN="your-token" +export MLFLOW_TRACKING_USERNAME="user" # Optional: Basic auth +export MLFLOW_TRACKING_PASSWORD="pass" # Optional: Basic auth +llama stack run my-config.yaml +``` + +**Configuration** (no auth_credential needed): +```yaml +prompts: + - provider_id: mlflow-prompts + provider_type: remote::mlflow + config: + mlflow_tracking_uri: http://mlflow.company.com + experiment_name: production-prompts +``` + +**Benefits**: +- Standard MLflow behavior +- No configuration changes needed +- Good for containerized deployments + +### Databricks Authentication + +For Databricks-managed MLflow: + +**Configuration**: +```yaml +prompts: + - provider_id: databricks-prompts + provider_type: remote::mlflow + config: + mlflow_tracking_uri: databricks + # Or with workspace URL: + # mlflow_tracking_uri: databricks://profile-name + experiment_name: /Shared/llama-stack-prompts + auth_credential: ${env.DATABRICKS_TOKEN} +``` + +**Environment Setup**: +```bash +export DATABRICKS_TOKEN="dapi..." 
+export DATABRICKS_HOST="https://your-workspace.cloud.databricks.com" +``` + +**Client Usage**: +```python +from llama_stack_client import LlamaStackClient + +client = LlamaStackClient(base_url="http://localhost:5000") + +# Create prompt in Databricks MLflow +prompt = client.prompts.create( + prompt="Analyze {{ topic }} with focus on {{ aspect }}", + variables=["topic", "aspect"] +) + +# View in Databricks UI: +# https://workspace.cloud.databricks.com/#mlflow/experiments/ +``` + +### Enterprise MLflow with Authentication + +Example for enterprise MLflow server with API key authentication: + +**Configuration**: +```yaml +prompts: + - provider_id: enterprise-mlflow + provider_type: remote::mlflow + config: + mlflow_tracking_uri: https://mlflow.enterprise.com + experiment_name: production-prompts + auth_credential: ${env.MLFLOW_API_KEY} + timeout_seconds: 60 +``` + +**Client Usage**: +```python +from llama_stack_client import LlamaStackClient + +# Option A: Use server's configured credential +client = LlamaStackClient(base_url="http://localhost:5000") +prompt = client.prompts.create( + prompt="Classify sentiment: {{ text }}", + variables=["text"] +) + +# Option B: Override with per-request credential +prompt = client.prompts.create( + prompt="Classify sentiment: {{ text }}", + variables=["text"], + extra_headers={ + "x-llamastack-provider-data": '{"mlflow_api_token": "user-specific-key"}' + } +) +``` + +### Authentication Precedence + +When multiple authentication methods are configured, the provider uses this precedence: + +1. **Per-request provider data** (from `x-llamastack-provider-data` header) + - Highest priority + - Overrides all other methods + - Used for multi-tenant scenarios + +2. **Configuration auth_credential** (from config file) + - Medium priority + - Fallback if no provider data header + - Good for server-level auth + +3. **Environment variables** (MLflow standard) + - Lowest priority + - Used if no other credentials provided + - Standard MLflow behavior + +**Example showing precedence**: +```yaml +# Config file +prompts: + - provider_id: mlflow + provider_type: remote::mlflow + config: + mlflow_tracking_uri: http://mlflow.company.com + auth_credential: ${env.MLFLOW_TRACKING_TOKEN} # Fallback +``` + +```bash +# Environment variable +export MLFLOW_TRACKING_TOKEN="server-token" # Lowest priority +``` + +```python +# Client code +client.prompts.create( + prompt="Test", + extra_headers={ + # This takes precedence over config and env vars + "x-llamastack-provider-data": '{"mlflow_api_token": "user-token"}' + } +) +``` + +### Security Best Practices + +1. **Never hardcode tokens** in configuration files: + ```yaml + # Bad - hardcoded credential + auth_credential: "my-secret-token" + + # Good - use environment variable + auth_credential: ${env.MLFLOW_TRACKING_TOKEN} + ``` + +2. **Use per-request credentials** for multi-tenant deployments: + ```python + # Good - each user provides their own token + headers = { + "x-llamastack-provider-data": f'{{"mlflow_api_token": "{user_token}"}}' + } + client.prompts.list(extra_headers=headers) + ``` + +3. **Rotate credentials regularly** in production environments + +4. **Use HTTPS** for MLflow tracking URI in production: + ```yaml + mlflow_tracking_uri: https://mlflow.company.com # Good + # Not: http://mlflow.company.com # Bad for production + ``` + +5. **Store secrets in secure vaults** (AWS Secrets Manager, HashiCorp Vault, etc.) 
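+
+As a small illustration of practices 1 and 2 above, the sketch below reads a per-user token from an environment variable and builds the provider-data header with `json.dumps` instead of a hand-formatted string. The variable name `MLFLOW_USER_TOKEN` is only an example; any secure source (secrets manager, vault) works the same way:
+
+```python
+import json
+import os
+
+from llama_stack_client import LlamaStackClient
+
+client = LlamaStackClient(base_url="http://localhost:5000")
+
+# Example only: read the caller's token from the environment rather than hardcoding it.
+user_token = os.environ["MLFLOW_USER_TOKEN"]
+
+# json.dumps handles quoting and escaping, avoiding hand-built JSON strings.
+provider_data = json.dumps({"mlflow_api_token": user_token})
+
+prompts = client.prompts.list(
+    extra_headers={"x-llamastack-provider-data": provider_data}
+)
+```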
+ +## API Reference + +### Create Prompt + +Creates a new prompt (version 1) or registers a prompt in MLflow: + +```python +prompt = client.prompts.create( + prompt="You are a {{ role }} assistant. {{ instruction }}", + variables=["role", "instruction"] # Optional - auto-extracted if omitted +) +``` + +**Auto-extraction**: If `variables` is not provided, the provider automatically extracts variables from `{{ variable }}` placeholders. + +### Retrieve Prompt + +Get a prompt by ID (retrieves default version): + +```python +prompt = client.prompts.get(prompt_id="pmpt_abc123...") +``` + +Get a specific version: + +```python +prompt = client.prompts.get(prompt_id="pmpt_abc123...", version=2) +``` + +### Update Prompt + +Creates a new version of an existing prompt: + +```python +updated = client.prompts.update( + prompt_id="pmpt_abc123...", + prompt="Updated template with {{ variable }}", + version=1, # Must be the latest version + set_as_default=True # Make this the new default +) +``` + +**Important**: You must provide the current latest version number. The update creates a new version (e.g., version 2). + +### List Prompts + +List all prompts (returns default versions only): + +```python +response = client.prompts.list() +for prompt in response.data: + print(f"{prompt.prompt_id}: v{prompt.version} (default)") +``` + +### List Prompt Versions + +List all versions of a specific prompt: + +```python +response = client.prompts.list_versions(prompt_id="pmpt_abc123...") +for prompt in response.data: + default = " (default)" if prompt.is_default else "" + print(f"Version {prompt.version}{default}") +``` + +### Set Default Version + +Change which version is the default: + +```python +client.prompts.set_default_version( + prompt_id="pmpt_abc123...", + version=2 +) +``` + +## ID Mapping + +The MLflow provider uses deterministic bidirectional ID mapping: + +- **Llama Stack format**: `pmpt_<48-hex-chars>` +- **MLflow format**: `llama_prompt_<48-hex-chars>` + +Example: +- Llama Stack ID: `pmpt_8c2bf57972a215cd0413e399d03b901cce93815448173c1c` +- MLflow name: `llama_prompt_8c2bf57972a215cd0413e399d03b901cce93815448173c1c` + +This ensures prompts created through Llama Stack are easily identifiable in MLflow. + +## Version Management + +MLflow Prompts Provider implements immutable versioning: + +1. **Create**: Creates version 1 +2. **Update**: Creates a new version (2, 3, 4, ...) +3. **Default**: The "default" alias points to the current default version +4. **History**: All versions are preserved and retrievable + +``` +pmpt_abc123 +├── Version 1 (Original) +├── Version 2 (Updated) +└── Version 3 (Latest, Default) ← Default alias points here +``` + +## Troubleshooting + +### MLflow Server Not Available + +**Error**: `Failed to connect to MLflow server` + +**Solutions**: +1. Verify MLflow server is running: `curl http://localhost:5555/health` +2. Check `mlflow_tracking_uri` in configuration +3. Ensure network connectivity to remote server +4. Check firewall settings + +### Version Mismatch Error + +**Error**: `Version X is not the latest version. Use latest version Y to update.` + +**Cause**: Attempting to update an outdated version + +**Solution**: Always use the latest version number when updating: +```python +# Get latest version +versions = client.prompts.list_versions(prompt_id) +latest_version = max(v.version for v in versions.data) + +# Use latest version for update +client.prompts.update(prompt_id=prompt_id, version=latest_version, ...) 
+``` + +### Variable Validation Error + +**Error**: `Template contains undeclared variables: ['var2']` + +**Cause**: Template has `{{ var2 }}` but `variables` list doesn't include it + +**Solution**: Either add missing variable or let the provider auto-extract: +```python +# Option 1: Add missing variable +client.prompts.create( + prompt="Template with {{ var1 }} and {{ var2 }}", + variables=["var1", "var2"] +) + +# Option 2: Let provider auto-extract (recommended) +client.prompts.create( + prompt="Template with {{ var1 }} and {{ var2 }}" +) +``` + +### Timeout Errors + +**Error**: Connection timeout when communicating with MLflow + +**Solutions**: +1. Increase `timeout_seconds` in configuration: + ```yaml + config: + timeout_seconds: 60 # Default: 30 + ``` +2. Check network latency to MLflow server +3. Verify MLflow server is responsive + +### Prompt Not Found + +**Error**: `Prompt pmpt_abc123... not found` + +**Possible causes**: +1. Prompt ID is incorrect +2. Prompt was created in a different MLflow server/experiment +3. Experiment name mismatch in configuration + +**Solution**: Verify prompt exists in MLflow UI at `http://localhost:5555` + +## Limitations + +### No Deletion Support + +**MLflow does not support deleting prompts or versions**. The `delete_prompt()` method raises `NotImplementedError`. + +**Workaround**: Mark prompts as deprecated using naming conventions or set a different version as default. + +### Experiment Required + +All prompts are stored within an MLflow experiment. The experiment is created automatically if it doesn't exist. + +### ID Format Constraints + +- Prompt IDs must follow the format: `pmpt_<48-hex-chars>` +- MLflow names use the prefix: `llama_prompt_` +- Manual creation in MLflow with different names won't be recognized + +### Version Numbering + +- Versions are sequential integers (1, 2, 3, ...) +- You cannot skip version numbers +- You cannot manually set version numbers + +## Best Practices + +### 1. Use Environment Variables + +Store MLflow URIs in environment variables: + +```yaml +config: + mlflow_tracking_uri: ${env.MLFLOW_TRACKING_URI:=http://localhost:5555} +``` + +### 2. Auto-Extract Variables + +Let the provider auto-extract variables to avoid validation errors: + +```python +# Recommended +prompt = client.prompts.create( + prompt="Summarize {{ text }} in {{ format }}" +) +``` + +### 3. Organize by Experiment + +Use different experiments for different environments: + +- `dev-prompts` for development +- `staging-prompts` for staging +- `production-prompts` for production + +### 4. Version Management + +- Always retrieve latest version before updating +- Use `set_as_default=True` when updating to make new version active +- Keep version history for audit trail + +### 5. Use Meaningful Templates + +Include context in your templates: + +```python +# Good +prompt = """You are a {{ role }} assistant specialized in {{ domain }}. + +Task: {{ task }} + +Output format: {{ format }}""" + +# Less clear +prompt = "Do {{ task }} as {{ role }}" +``` + +### 6. 
Monitor MLflow Server + +- Use MLflow UI to visualize prompts: `http://your-server:5555` +- Monitor experiment metrics and prompt versions +- Set up alerts for MLflow server health + +## Production Deployment + +### Database Backend + +For production, use a database backend instead of filesystem: + +```bash +mlflow server \ + --host 0.0.0.0 \ + --port 5000 \ + --backend-store-uri postgresql://user:pass@host:5432/mlflow \ + --default-artifact-root s3://my-bucket/mlflow-artifacts +``` + +### High Availability + +- Deploy multiple MLflow server instances behind a load balancer +- Use managed database (RDS, Cloud SQL, etc.) +- Store artifacts in object storage (S3, GCS, Azure Blob) + +### Security + +- Enable authentication on MLflow server +- Use HTTPS for MLflow tracking URI +- Restrict network access with firewall rules +- Use IAM roles for cloud deployments + +### Monitoring + +Set up monitoring for: +- MLflow server availability +- Database connection pool +- API response times +- Prompt creation/retrieval rates + +## Documentation +See [MLflow's documentation](https://mlflow.org/docs/latest/prompts.html) for more details about MLflow Prompt Registry. + + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `mlflow_tracking_uri` | `str` | No | http://localhost:5000 | MLflow tracking server URI | +| `mlflow_registry_uri` | `str \| None` | No | None | MLflow model registry URI (defaults to tracking URI if not set) | +| `experiment_name` | `str` | No | llama-stack-prompts | MLflow experiment name for storing prompts | +| `auth_credential` | `SecretStr \| None` | No | None | MLflow API token for authentication. Can be overridden via provider data header. | +| `timeout_seconds` | `int` | No | 30 | Timeout for MLflow API calls (1-300 seconds) | + +## Sample Configuration + +**Without authentication** (local development): +```yaml +mlflow_tracking_uri: http://localhost:5555 +experiment_name: llama-stack-prompts +timeout_seconds: 30 +``` + +**With authentication** (production): +```yaml +mlflow_tracking_uri: ${env.MLFLOW_TRACKING_URI:=http://localhost:5000} +experiment_name: llama-stack-prompts +auth_credential: ${env.MLFLOW_TRACKING_TOKEN:=} +timeout_seconds: 30 +``` diff --git a/src/llama_stack/core/prompts/prompts.py b/src/llama_stack/core/prompts/prompts.py index ff67ad1386..7dc94b79a8 100644 --- a/src/llama_stack/core/prompts/prompts.py +++ b/src/llama_stack/core/prompts/prompts.py @@ -4,232 +4,26 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json -from typing import Any - -from pydantic import BaseModel - -from llama_stack.core.datatypes import StackRunConfig -from llama_stack.core.storage.kvstore import KVStore, kvstore_impl -from llama_stack_api import ListPromptsResponse, Prompt, Prompts - - -class PromptServiceConfig(BaseModel): - """Configuration for the built-in prompt service. 
- - :param run_config: Stack run configuration containing distribution info - """ - - run_config: StackRunConfig - - -async def get_provider_impl(config: PromptServiceConfig, deps: dict[Any, Any]): - """Get the prompt service implementation.""" - impl = PromptServiceImpl(config, deps) - await impl.initialize() - return impl - - -class PromptServiceImpl(Prompts): - """Built-in prompt service implementation using KVStore.""" - - def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]): - self.config = config - self.deps = deps - self.kvstore: KVStore - - async def initialize(self) -> None: - # Use prompts store reference from run config - prompts_ref = self.config.run_config.storage.stores.prompts - if not prompts_ref: - raise ValueError("storage.stores.prompts must be configured in run config") - self.kvstore = await kvstore_impl(prompts_ref) - - def _get_default_key(self, prompt_id: str) -> str: - """Get the KVStore key that stores the default version number.""" - return f"prompts:v1:{prompt_id}:default" - - async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str: - """Get the KVStore key for prompt data, returning default version if applicable.""" - if version: - return self._get_version_key(prompt_id, str(version)) - - default_key = self._get_default_key(prompt_id) - resolved_version = await self.kvstore.get(default_key) - if resolved_version is None: - raise ValueError(f"Prompt {prompt_id}:default not found") - return self._get_version_key(prompt_id, resolved_version) - - def _get_version_key(self, prompt_id: str, version: str) -> str: - """Get the KVStore key for a specific prompt version.""" - return f"prompts:v1:{prompt_id}:{version}" - - def _get_list_key_prefix(self) -> str: - """Get the key prefix for listing prompts.""" - return "prompts:v1:" - - def _serialize_prompt(self, prompt: Prompt) -> str: - """Serialize a prompt to JSON string for storage.""" - return json.dumps( - { - "prompt_id": prompt.prompt_id, - "prompt": prompt.prompt, - "version": prompt.version, - "variables": prompt.variables or [], - "is_default": prompt.is_default, - } - ) - - def _deserialize_prompt(self, data: str) -> Prompt: - """Deserialize a prompt from JSON string.""" - obj = json.loads(data) - return Prompt( - prompt_id=obj["prompt_id"], - prompt=obj["prompt"], - version=obj["version"], - variables=obj.get("variables", []), - is_default=obj.get("is_default", False), - ) - - async def list_prompts(self) -> ListPromptsResponse: - """List all prompts (default versions only).""" - prefix = self._get_list_key_prefix() - keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff") - - prompts = [] - for key in keys: - if key.endswith(":default"): - try: - default_version = await self.kvstore.get(key) - if default_version: - prompt_id = key.replace(prefix, "").replace(":default", "") - version_key = self._get_version_key(prompt_id, default_version) - data = await self.kvstore.get(version_key) - if data: - prompt = self._deserialize_prompt(data) - prompts.append(prompt) - except (json.JSONDecodeError, KeyError): - continue - - prompts.sort(key=lambda p: p.prompt_id or "", reverse=True) - return ListPromptsResponse(data=prompts) - - async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt: - """Get a prompt by its identifier and optional version.""" - key = await self._get_prompt_key(prompt_id, version) - data = await self.kvstore.get(key) - if data is None: - raise ValueError(f"Prompt {prompt_id}:{version if version else 'default'} not found") 
- return self._deserialize_prompt(data) - - async def create_prompt( - self, - prompt: str, - variables: list[str] | None = None, - ) -> Prompt: - """Create a new prompt.""" - if variables is None: - variables = [] - - prompt_obj = Prompt( - prompt_id=Prompt.generate_prompt_id(), - prompt=prompt, - version=1, - variables=variables, - ) - - version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version)) - data = self._serialize_prompt(prompt_obj) - await self.kvstore.set(version_key, data) - - default_key = self._get_default_key(prompt_obj.prompt_id) - await self.kvstore.set(default_key, str(prompt_obj.version)) - - return prompt_obj - - async def update_prompt( - self, - prompt_id: str, - prompt: str, - version: int, - variables: list[str] | None = None, - set_as_default: bool = True, - ) -> Prompt: - """Update an existing prompt (increments version).""" - if version < 1: - raise ValueError("Version must be >= 1") - if variables is None: - variables = [] - - prompt_versions = await self.list_prompt_versions(prompt_id) - latest_prompt = max(prompt_versions.data, key=lambda x: int(x.version)) - - if version and latest_prompt.version != version: - raise ValueError( - f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{latest_prompt.version}' in request." - ) - - current_version = latest_prompt.version if version is None else version - new_version = current_version + 1 - - updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables) - - version_key = self._get_version_key(prompt_id, str(new_version)) - data = self._serialize_prompt(updated_prompt) - await self.kvstore.set(version_key, data) - - if set_as_default: - await self.set_default_version(prompt_id, new_version) - - return updated_prompt - - async def delete_prompt(self, prompt_id: str) -> None: - """Delete a prompt and all its versions.""" - await self.get_prompt(prompt_id) - - prefix = f"prompts:v1:{prompt_id}:" - keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff") - - for key in keys: - await self.kvstore.delete(key) - - async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse: - """List all versions of a specific prompt.""" - prefix = f"prompts:v1:{prompt_id}:" - keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff") - - default_version = None - prompts = [] - - for key in keys: - data = await self.kvstore.get(key) - if key.endswith(":default"): - default_version = data - else: - if data: - prompt_obj = self._deserialize_prompt(data) - prompts.append(prompt_obj) - - if not prompts: - raise ValueError(f"Prompt {prompt_id} not found") - - for prompt in prompts: - prompt.is_default = str(prompt.version) == default_version - - prompts.sort(key=lambda x: x.version) - return ListPromptsResponse(data=prompts) - - async def set_default_version(self, prompt_id: str, version: int) -> Prompt: - """Set which version of a prompt should be the default, If not set. the default is the latest.""" - version_key = self._get_version_key(prompt_id, str(version)) - data = await self.kvstore.get(version_key) - if data is None: - raise ValueError(f"Prompt {prompt_id} version {version} not found") - - default_key = self._get_default_key(prompt_id) - await self.kvstore.set(default_key, str(version)) - - return self._deserialize_prompt(data) - - async def shutdown(self) -> None: - pass +"""Core prompts service delegating to inline::reference provider. 
+ +This module provides backward compatibility by delegating to the +inline::reference provider implementation. +""" + +from llama_stack.providers.inline.prompts.reference import ( + PromptServiceImpl, + ReferencePromptsConfig, + get_adapter_impl, +) + +# Re-export for backward compatibility +PromptServiceConfig = ReferencePromptsConfig +get_provider_impl = get_adapter_impl + +__all__ = [ + "PromptServiceImpl", + "PromptServiceConfig", + "ReferencePromptsConfig", + "get_provider_impl", + "get_adapter_impl", +] diff --git a/src/llama_stack/providers/inline/prompts/__init__.py b/src/llama_stack/providers/inline/prompts/__init__.py new file mode 100644 index 0000000000..756f351d88 --- /dev/null +++ b/src/llama_stack/providers/inline/prompts/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/src/llama_stack/providers/inline/prompts/reference/__init__.py b/src/llama_stack/providers/inline/prompts/reference/__init__.py new file mode 100644 index 0000000000..8ca7e2140f --- /dev/null +++ b/src/llama_stack/providers/inline/prompts/reference/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import ReferencePromptsConfig +from .reference import PromptServiceImpl + + +async def get_adapter_impl(config: ReferencePromptsConfig, _deps): + impl = PromptServiceImpl(config=config, deps=_deps) + await impl.initialize() + return impl + + +__all__ = ["ReferencePromptsConfig", "PromptServiceImpl", "get_adapter_impl"] diff --git a/src/llama_stack/providers/inline/prompts/reference/config.py b/src/llama_stack/providers/inline/prompts/reference/config.py new file mode 100644 index 0000000000..54c01992ec --- /dev/null +++ b/src/llama_stack/providers/inline/prompts/reference/config.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel, Field + +from llama_stack.core.datatypes import StackRunConfig + + +class ReferencePromptsConfig(BaseModel): + """Configuration for the built-in reference prompt service. + + This provider stores prompts in the configured KVStore (SQLite, PostgreSQL, etc.) + as specified in the run configuration. + """ + + run_config: StackRunConfig = Field( + description="Stack run configuration containing storage configuration" + ) diff --git a/src/llama_stack/providers/inline/prompts/reference/reference.py b/src/llama_stack/providers/inline/prompts/reference/reference.py new file mode 100644 index 0000000000..ee17d8d09f --- /dev/null +++ b/src/llama_stack/providers/inline/prompts/reference/reference.py @@ -0,0 +1,222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
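+
+"""Reference inline prompts provider backed by the configured KVStore.
+
+Each prompt version is stored under a namespaced key, and a separate key
+tracks the default version per prompt; see ``PromptServiceImpl`` below.
+"""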
+ +import json +from typing import Any + +from llama_stack.core.storage.kvstore import KVStore, kvstore_impl +from llama_stack_api import ListPromptsResponse, Prompt, Prompts + +from .config import ReferencePromptsConfig + + +class PromptServiceImpl(Prompts): + """Reference inline prompt service implementation using KVStore. + + This provider stores prompts in the configured KVStore backend (SQLite, PostgreSQL, etc.) + and provides full CRUD operations with versioning support. + """ + + def __init__(self, config: ReferencePromptsConfig, deps: dict[Any, Any]): + self.config = config + self.deps = deps + self.kvstore: KVStore + + async def initialize(self) -> None: + # Use prompts store reference from run config + prompts_ref = self.config.run_config.storage.stores.prompts + if not prompts_ref: + raise ValueError("storage.stores.prompts must be configured in run config") + self.kvstore = await kvstore_impl(prompts_ref) + + def _get_default_key(self, prompt_id: str) -> str: + """Get the KVStore key that stores the default version number.""" + return f"prompts:v1:{prompt_id}:default" + + async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str: + """Get the KVStore key for prompt data, returning default version if applicable.""" + if version: + return self._get_version_key(prompt_id, str(version)) + + default_key = self._get_default_key(prompt_id) + resolved_version = await self.kvstore.get(default_key) + if resolved_version is None: + raise ValueError(f"Prompt {prompt_id}:default not found") + return self._get_version_key(prompt_id, resolved_version) + + def _get_version_key(self, prompt_id: str, version: str) -> str: + """Get the KVStore key for a specific prompt version.""" + return f"prompts:v1:{prompt_id}:{version}" + + def _get_list_key_prefix(self) -> str: + """Get the key prefix for listing prompts.""" + return "prompts:v1:" + + def _serialize_prompt(self, prompt: Prompt) -> str: + """Serialize a prompt to JSON string for storage.""" + return json.dumps( + { + "prompt_id": prompt.prompt_id, + "prompt": prompt.prompt, + "version": prompt.version, + "variables": prompt.variables or [], + "is_default": prompt.is_default, + } + ) + + def _deserialize_prompt(self, data: str) -> Prompt: + """Deserialize a prompt from JSON string.""" + obj = json.loads(data) + return Prompt( + prompt_id=obj["prompt_id"], + prompt=obj["prompt"], + version=obj["version"], + variables=obj.get("variables", []), + is_default=obj.get("is_default", False), + ) + + async def list_prompts(self) -> ListPromptsResponse: + """List all prompts (default versions only).""" + prefix = self._get_list_key_prefix() + keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff") + + prompts = [] + for key in keys: + if key.endswith(":default"): + try: + default_version = await self.kvstore.get(key) + if default_version: + prompt_id = key.replace(prefix, "").replace(":default", "") + version_key = self._get_version_key(prompt_id, default_version) + data = await self.kvstore.get(version_key) + if data: + prompt = self._deserialize_prompt(data) + prompts.append(prompt) + except (json.JSONDecodeError, KeyError): + continue + + prompts.sort(key=lambda p: p.prompt_id or "", reverse=True) + return ListPromptsResponse(data=prompts) + + async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt: + """Get a prompt by its identifier and optional version.""" + key = await self._get_prompt_key(prompt_id, version) + data = await self.kvstore.get(key) + if data is None: + raise 
ValueError(f"Prompt {prompt_id}:{version if version else 'default'} not found") + return self._deserialize_prompt(data) + + async def create_prompt( + self, + prompt: str, + variables: list[str] | None = None, + ) -> Prompt: + """Create a new prompt.""" + if variables is None: + variables = [] + + prompt_obj = Prompt( + prompt_id=Prompt.generate_prompt_id(), + prompt=prompt, + version=1, + variables=variables, + ) + + version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version)) + data = self._serialize_prompt(prompt_obj) + await self.kvstore.set(version_key, data) + + default_key = self._get_default_key(prompt_obj.prompt_id) + await self.kvstore.set(default_key, str(prompt_obj.version)) + + return prompt_obj + + async def update_prompt( + self, + prompt_id: str, + prompt: str, + version: int, + variables: list[str] | None = None, + set_as_default: bool = True, + ) -> Prompt: + """Update an existing prompt (increments version).""" + if version < 1: + raise ValueError("Version must be >= 1") + if variables is None: + variables = [] + + prompt_versions = await self.list_prompt_versions(prompt_id) + latest_prompt = max(prompt_versions.data, key=lambda x: int(x.version)) + + if version and latest_prompt.version != version: + raise ValueError( + f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{latest_prompt.version}' in request." + ) + + current_version = latest_prompt.version if version is None else version + new_version = current_version + 1 + + updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables) + + version_key = self._get_version_key(prompt_id, str(new_version)) + data = self._serialize_prompt(updated_prompt) + await self.kvstore.set(version_key, data) + + if set_as_default: + await self.set_default_version(prompt_id, new_version) + + return updated_prompt + + async def delete_prompt(self, prompt_id: str) -> None: + """Delete a prompt and all its versions.""" + await self.get_prompt(prompt_id) + + prefix = f"prompts:v1:{prompt_id}:" + keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff") + + for key in keys: + await self.kvstore.delete(key) + + async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse: + """List all versions of a specific prompt.""" + prefix = f"prompts:v1:{prompt_id}:" + keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff") + + default_version = None + prompts = [] + + for key in keys: + data = await self.kvstore.get(key) + if key.endswith(":default"): + default_version = data + else: + if data: + prompt_obj = self._deserialize_prompt(data) + prompts.append(prompt_obj) + + if not prompts: + raise ValueError(f"Prompt {prompt_id} not found") + + for prompt in prompts: + prompt.is_default = str(prompt.version) == default_version + + prompts.sort(key=lambda x: x.version) + return ListPromptsResponse(data=prompts) + + async def set_default_version(self, prompt_id: str, version: int) -> Prompt: + """Set which version of a prompt should be the default, If not set. 
the default is the latest.""" + version_key = self._get_version_key(prompt_id, str(version)) + data = await self.kvstore.get(version_key) + if data is None: + raise ValueError(f"Prompt {prompt_id} version {version} not found") + + default_key = self._get_default_key(prompt_id) + await self.kvstore.set(default_key, str(version)) + + return self._deserialize_prompt(data) + + async def shutdown(self) -> None: + pass diff --git a/src/llama_stack/providers/registry/prompts.py b/src/llama_stack/providers/registry/prompts.py new file mode 100644 index 0000000000..07b8255d32 --- /dev/null +++ b/src/llama_stack/providers/registry/prompts.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from llama_stack_api import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec + + +def available_providers() -> list[ProviderSpec]: + return [ + InlineProviderSpec( + api=Api.prompts, + provider_type="inline::reference", + pip_packages=[], + module="llama_stack.providers.inline.prompts.reference", + config_class="llama_stack.providers.inline.prompts.reference.ReferencePromptsConfig", + description="Reference implementation storing prompts in KVStore (SQLite, PostgreSQL, etc.)", + ), + RemoteProviderSpec( + api=Api.prompts, + adapter_type="mlflow", + provider_type="remote::mlflow", + pip_packages=["mlflow>=3.4.0"], + module="llama_stack.providers.remote.prompts.mlflow", + config_class="llama_stack.providers.remote.prompts.mlflow.MLflowPromptsConfig", + provider_data_validator="llama_stack.providers.remote.prompts.mlflow.config.MLflowProviderDataValidator", + description="MLflow Prompt Registry provider for centralized prompt management and versioning", + ), + ] diff --git a/src/llama_stack/providers/remote/prompts/__init__.py b/src/llama_stack/providers/remote/prompts/__init__.py new file mode 100644 index 0000000000..756f351d88 --- /dev/null +++ b/src/llama_stack/providers/remote/prompts/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/src/llama_stack/providers/remote/prompts/mlflow/__init__.py b/src/llama_stack/providers/remote/prompts/mlflow/__init__.py new file mode 100644 index 0000000000..e22b97a563 --- /dev/null +++ b/src/llama_stack/providers/remote/prompts/mlflow/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import MLflowPromptsConfig +from .mlflow import MLflowPromptsAdapter + +__all__ = ["MLflowPromptsConfig", "MLflowPromptsAdapter", "get_adapter_impl"] + + +async def get_adapter_impl(config: MLflowPromptsConfig, _deps): + """Get the MLflow prompts adapter implementation.""" + impl = MLflowPromptsAdapter(config=config) + await impl.initialize() + return impl diff --git a/src/llama_stack/providers/remote/prompts/mlflow/config.py b/src/llama_stack/providers/remote/prompts/mlflow/config.py new file mode 100644 index 0000000000..82c39d6cab --- /dev/null +++ b/src/llama_stack/providers/remote/prompts/mlflow/config.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Configuration for MLflow Prompt Registry provider. + +This module defines the configuration schema for integrating Llama Stack +with MLflow's Prompt Registry for centralized prompt management and versioning. +""" + +from typing import Any + +from pydantic import BaseModel, Field, SecretStr, field_validator + +from llama_stack_api import json_schema_type + + +class MLflowProviderDataValidator(BaseModel): + """Validator for provider data from request headers. + + This allows users to override the MLflow API token per request via + the x-llamastack-provider-data header: + {"mlflow_api_token": "your-token"} + """ + + mlflow_api_token: str | None = Field( + default=None, + description="MLflow API token for authentication (overrides config)", + ) + + +@json_schema_type +class MLflowPromptsConfig(BaseModel): + """Configuration for MLflow Prompt Registry provider. + + Credentials can be provided via: + 1. Per-request provider data header (preferred for security) + 2. Configuration auth_credential (fallback) + 3. Environment variables set by MLflow (MLFLOW_TRACKING_TOKEN, etc.) + + Attributes: + mlflow_tracking_uri: MLflow tracking server URI (e.g., http://localhost:5000, databricks) + mlflow_registry_uri: MLflow registry URI (optional, defaults to tracking_uri) + experiment_name: MLflow experiment name for prompt storage + auth_credential: MLflow API token for authentication (optional, can be overridden by provider data) + timeout_seconds: Timeout for MLflow API calls in seconds (default: 30) + """ + + mlflow_tracking_uri: str = Field( + default="http://localhost:5000", + description="MLflow tracking server URI (e.g., http://localhost:5000, databricks, databricks://profile)", + ) + mlflow_registry_uri: str | None = Field( + default=None, + description="MLflow registry URI (defaults to tracking_uri if not specified)", + ) + experiment_name: str = Field( + default="llama-stack-prompts", + description="MLflow experiment name for prompt storage and organization", + ) + auth_credential: SecretStr | None = Field( + default=None, + description="MLflow API token for authentication. Can be overridden via provider data header.", + ) + timeout_seconds: int = Field( + default=30, + ge=1, + le=300, + description="Timeout for MLflow API calls in seconds (1-300)", + ) + + @classmethod + def sample_run_config(cls, mlflow_api_token: str = "${env.MLFLOW_TRACKING_TOKEN:=}", **kwargs) -> dict[str, Any]: + """Generate sample configuration with environment variable substitution. 
+ + Args: + mlflow_api_token: MLflow API token (defaults to MLFLOW_TRACKING_TOKEN env var) + **kwargs: Additional configuration overrides + + Returns: + Sample configuration dictionary + """ + return { + "mlflow_tracking_uri": kwargs.get("mlflow_tracking_uri", "http://localhost:5000"), + "experiment_name": kwargs.get("experiment_name", "llama-stack-prompts"), + "auth_credential": mlflow_api_token, + } + + @field_validator("mlflow_tracking_uri") + @classmethod + def validate_tracking_uri(cls, v: str) -> str: + """Validate tracking URI is not empty.""" + if not v or not v.strip(): + raise ValueError("mlflow_tracking_uri cannot be empty") + return v.strip() + + @field_validator("experiment_name") + @classmethod + def validate_experiment_name(cls, v: str) -> str: + """Validate experiment name is not empty.""" + if not v or not v.strip(): + raise ValueError("experiment_name cannot be empty") + return v.strip() diff --git a/src/llama_stack/providers/remote/prompts/mlflow/mapping.py b/src/llama_stack/providers/remote/prompts/mlflow/mapping.py new file mode 100644 index 0000000000..65dd7b4b8a --- /dev/null +++ b/src/llama_stack/providers/remote/prompts/mlflow/mapping.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""ID mapping utilities for MLflow Prompt Registry provider. + +This module handles bidirectional mapping between Llama Stack's prompt_id format +(pmpt_<48-hex-chars>) and MLflow's name-based system (llama_prompt_). +""" + +import re + + +class PromptIDMapper: + """Handle bidirectional mapping between Llama Stack IDs and MLflow names. + + Llama Stack uses prompt IDs in format: pmpt_<48-hex-chars> + MLflow uses string names, so we map to: llama_prompt_<48-hex-chars> + + This ensures: + - Deterministic mapping (same ID always maps to same name) + - Reversible (can recover original ID from MLflow name) + - Unique (different IDs map to different names) + """ + + # Regex pattern for Llama Stack prompt_id validation + PROMPT_ID_PATTERN = re.compile(r"^pmpt_[0-9a-f]{48}$") + + # Prefix for MLflow prompt names managed by Llama Stack + MLFLOW_NAME_PREFIX = "llama_prompt_" + + def to_mlflow_name(self, prompt_id: str) -> str: + """Convert Llama Stack prompt_id to MLflow prompt name. + + Args: + prompt_id: Llama Stack prompt ID (format: pmpt_<48-hex-chars>) + + Returns: + MLflow prompt name (format: llama_prompt_<48-hex-chars>) + + Raises: + ValueError: If prompt_id format is invalid + + Example: + >>> mapper = PromptIDMapper() + >>> mapper.to_mlflow_name("pmpt_a1b2c3d4e5f6...") + "llama_prompt_a1b2c3d4e5f6..." + """ + if not self.PROMPT_ID_PATTERN.match(prompt_id): + raise ValueError(f"Invalid prompt_id format: {prompt_id}. Expected format: pmpt_<48-hex-chars>") + + # Extract hex part (after "pmpt_" prefix) + hex_part = prompt_id.split("pmpt_")[1] + + # Create MLflow name + return f"{self.MLFLOW_NAME_PREFIX}{hex_part}" + + def to_llama_id(self, mlflow_name: str) -> str: + """Convert MLflow prompt name to Llama Stack prompt_id. + + Args: + mlflow_name: MLflow prompt name + + Returns: + Llama Stack prompt ID (format: pmpt_<48-hex-chars>) + + Raises: + ValueError: If name doesn't follow expected format + + Example: + >>> mapper = PromptIDMapper() + >>> mapper.to_llama_id("llama_prompt_a1b2c3d4e5f6...") + "pmpt_a1b2c3d4e5f6..." 
+ """ + if not mlflow_name.startswith(self.MLFLOW_NAME_PREFIX): + raise ValueError( + f"MLflow name '{mlflow_name}' does not start with expected prefix '{self.MLFLOW_NAME_PREFIX}'" + ) + + # Extract hex part + hex_part = mlflow_name[len(self.MLFLOW_NAME_PREFIX) :] + + # Validate hex part length and characters + if len(hex_part) != 48: + raise ValueError(f"Invalid hex part length in MLflow name '{mlflow_name}'. Expected 48 characters.") + + for char in hex_part: + if char not in "0123456789abcdef": + raise ValueError( + f"Invalid character '{char}' in hex part of MLflow name '{mlflow_name}'. " + "Expected lowercase hex characters [0-9a-f]." + ) + + return f"pmpt_{hex_part}" + + def get_metadata_tags(self, prompt_id: str, variables: list[str] | None = None) -> dict[str, str]: + """Generate MLflow tags with Llama Stack metadata. + + Args: + prompt_id: Llama Stack prompt ID + variables: List of prompt variables (optional) + + Returns: + Dictionary of MLflow tags for metadata storage + + Example: + >>> mapper = PromptIDMapper() + >>> tags = mapper.get_metadata_tags("pmpt_abc123...", ["var1", "var2"]) + >>> tags + {"llama_stack_id": "pmpt_abc123...", "llama_stack_managed": "true", "variables": "var1,var2"} + """ + tags = { + "llama_stack_id": prompt_id, + "llama_stack_managed": "true", + } + + if variables: + # Store variables as comma-separated string + tags["variables"] = ",".join(variables) + + return tags diff --git a/src/llama_stack/providers/remote/prompts/mlflow/mlflow.py b/src/llama_stack/providers/remote/prompts/mlflow/mlflow.py new file mode 100644 index 0000000000..4e36fa1f0c --- /dev/null +++ b/src/llama_stack/providers/remote/prompts/mlflow/mlflow.py @@ -0,0 +1,547 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""MLflow Prompt Registry provider implementation. + +This module implements the Llama Stack Prompts protocol using MLflow's Prompt Registry +as the backend for centralized prompt management and versioning. +""" + +import re +from typing import TYPE_CHECKING, Any + +from llama_stack.core.request_headers import NeedsRequestProviderData +from llama_stack.log import get_logger +from llama_stack.providers.remote.prompts.mlflow.config import MLflowPromptsConfig +from llama_stack.providers.remote.prompts.mlflow.mapping import PromptIDMapper +from llama_stack_api import ListPromptsResponse, Prompt, Prompts + +# Try importing mlflow at module level +try: + import mlflow + from mlflow.client import MlflowClient +except ImportError: + # Fail gracefully when provider is instantiated during initialize() + mlflow = None + +logger = get_logger(__name__) + +class MLflowPromptsAdapter(NeedsRequestProviderData, Prompts): + """MLflow Prompt Registry adapter for Llama Stack. + + This adapter implements the Llama Stack Prompts protocol using MLflow's + Prompt Registry as the backend storage system. It handles: + + - Bidirectional ID mapping (prompt_id <-> MLflow name) + - Version management via MLflow versioning + - Variable extraction from prompt templates + - Metadata storage in MLflow tags + - Default version management via MLflow aliases + - Credential management via provider data (backstopped by config) + + Credentials can be provided via: + 1. Per-request provider data header (preferred for security) + 2. Configuration auth_credential (fallback) + 3. Environment variables (MLFLOW_TRACKING_TOKEN, etc.) 
+ + Attributes: + config: MLflow provider configuration + mlflow_client: MLflow client instance + mapper: ID mapping utility + """ + + def __init__(self, config: MLflowPromptsConfig): + """Initialize MLflow prompts adapter. + + Args: + config: MLflow provider configuration + """ + self.config = config + self.mlflow_client: "MlflowClient | None" = None + self.mapper = PromptIDMapper() + logger.info( + f"MLflowPromptsAdapter initialized: tracking_uri={config.mlflow_tracking_uri}, " + f"experiment={config.experiment_name}" + ) + + async def initialize(self) -> None: + """Initialize MLflow client and set up experiment. + + Sets up MLflow connection with optional authentication via token. + Token can be provided via config or will be read from environment variables + (MLFLOW_TRACKING_TOKEN, etc.) as per MLflow's standard behavior. + + Raises: + ImportError: If mlflow package is not installed + Exception: If MLflow connection fails + """ + if mlflow is None: + raise ImportError( + "mlflow package is required for MLflow prompts provider. " + "Install with: pip install 'mlflow>=3.4.0'" + ) + + # Set MLflow URIs + mlflow.set_tracking_uri(self.config.mlflow_tracking_uri) + + if self.config.mlflow_registry_uri: + mlflow.set_registry_uri(self.config.mlflow_registry_uri) + else: + # Default to tracking URI if registry not specified + mlflow.set_registry_uri(self.config.mlflow_tracking_uri) + + # Set authentication token if provided in config + if self.config.auth_credential is not None: + import os + + # MLflow reads MLFLOW_TRACKING_TOKEN from environment + os.environ["MLFLOW_TRACKING_TOKEN"] = self.config.auth_credential.get_secret_value() + logger.debug("Set MLFLOW_TRACKING_TOKEN from config auth_credential") + + # Initialize client + self.mlflow_client = MlflowClient() + + # Validate experiment exists (don't create during initialization) + try: + mlflow.set_experiment(self.config.experiment_name) + logger.info(f"Using MLflow experiment: {self.config.experiment_name}") + except Exception as e: + logger.warning( + f"Experiment '{self.config.experiment_name}' not found: {e}. " + f"It will be created automatically on first prompt creation." + ) + + def _ensure_experiment(self) -> None: + """Ensure MLflow experiment exists, creating it if necessary. + + This is called lazily on first write operation to avoid creating + external resources during initialization. + """ + try: + mlflow.set_experiment(self.config.experiment_name) + except Exception: + # Experiment doesn't exist, create it + try: + mlflow.create_experiment(self.config.experiment_name) + mlflow.set_experiment(self.config.experiment_name) + logger.info(f"Created MLflow experiment: {self.config.experiment_name}") + except Exception as e: + raise ValueError( + f"Failed to create experiment '{self.config.experiment_name}': {e}" + ) from e + + def _extract_variables(self, template: str) -> list[str]: + """Extract variables from prompt template. + + Extracts variables in {{ variable }} format from the template. 
+ + Args: + template: Prompt template string + + Returns: + List of unique variable names in order of appearance + + Example: + >>> adapter._extract_variables("Hello {{ name }}, your score is {{ score }}") + ["name", "score"] + """ + if not template: + return [] + + # Find all {{ variable }} patterns + matches = re.findall(r"{{\s*(\w+)\s*}}", template) + + # Return unique variables in order of appearance + seen = set() + variables = [] + for var in matches: + if var not in seen: + variables.append(var) + seen.add(var) + + return variables + + async def create_prompt( + self, + prompt: str, + variables: list[str] | None = None, + ) -> Prompt: + """Create a new prompt in MLflow registry. + + Args: + prompt: Prompt template text with {{ variable }} placeholders + variables: List of variable names (auto-extracted if not provided) + + Returns: + Created Prompt resource with prompt_id and version=1 + + Raises: + ValueError: If prompt validation fails + Exception: If MLflow registration fails + """ + # Ensure experiment exists (lazy creation on first write) + self._ensure_experiment() + + # Auto-extract variables if not provided + if variables is None: + variables = self._extract_variables(prompt) + else: + # Validate declared variables match template + template_vars = set(self._extract_variables(prompt)) + declared_vars = set(variables) + undeclared = template_vars - declared_vars + if undeclared: + raise ValueError(f"Template contains undeclared variables: {sorted(undeclared)}") + + # Generate Llama Stack prompt_id + prompt_id = Prompt.generate_prompt_id() + + # Convert to MLflow name + mlflow_name = self.mapper.to_mlflow_name(prompt_id) + + # Prepare metadata tags + tags = self.mapper.get_metadata_tags(prompt_id, variables) + + # Register in MLflow + try: + mlflow.genai.register_prompt( + name=mlflow_name, + template=prompt, + commit_message="Created via Llama Stack", + tags=tags, + ) + logger.info(f"Created prompt {prompt_id} (MLflow: {mlflow_name})") + except Exception as e: + logger.error(f"Failed to register prompt in MLflow: {e}") + raise + + # Set as default (first version is always default) + try: + mlflow.genai.set_prompt_alias( + name=mlflow_name, + version=1, + alias="default", + ) + except Exception as e: + logger.warning(f"Failed to set default alias for {prompt_id}: {e}") + + return Prompt( + prompt_id=prompt_id, + prompt=prompt, + version=1, + variables=variables, + is_default=True, + ) + + async def get_prompt( + self, + prompt_id: str, + version: int | None = None, + ) -> Prompt: + """Get prompt from MLflow registry. 
+ + Args: + prompt_id: Llama Stack prompt ID + version: Version number (defaults to default version) + + Returns: + Prompt resource + + Raises: + ValueError: If prompt not found + """ + mlflow_name = self.mapper.to_mlflow_name(prompt_id) + + # Build MLflow URI + if version: + uri = f"prompts:/{mlflow_name}/{version}" + else: + uri = f"prompts:/{mlflow_name}@default" + + # Load from MLflow + try: + mlflow_prompt = mlflow.genai.load_prompt(uri) + except Exception as e: + raise ValueError(f"Prompt {prompt_id} (version {version if version else 'default'}) not found: {e}") from e + + # Extract template + template = mlflow_prompt.template if hasattr(mlflow_prompt, "template") else str(mlflow_prompt) + + # Extract variables from template + variables = self._extract_variables(template) + + # Get version number + prompt_version = 1 + if hasattr(mlflow_prompt, "version"): + prompt_version = int(mlflow_prompt.version) + elif version: + prompt_version = version + + # Check if this is the default version + is_default = await self._is_default_version(mlflow_name, prompt_version) + + return Prompt( + prompt_id=prompt_id, + prompt=template, + version=prompt_version, + variables=variables, + is_default=is_default, + ) + + async def update_prompt( + self, + prompt_id: str, + prompt: str, + version: int, + variables: list[str] | None = None, + set_as_default: bool = True, + ) -> Prompt: + """Update prompt (creates new version in MLflow). + + Args: + prompt_id: Llama Stack prompt ID + prompt: Updated prompt template + version: Current version being updated + variables: Updated variables list (auto-extracted if not provided) + set_as_default: Set new version as default + + Returns: + Updated Prompt resource with incremented version + + Raises: + ValueError: If current version not found or validation fails + """ + # Ensure experiment exists (edge case: updating prompts created outside Llama Stack) + self._ensure_experiment() + + # Auto-extract variables if not provided + if variables is None: + variables = self._extract_variables(prompt) + else: + # Validate variables + template_vars = set(self._extract_variables(prompt)) + declared_vars = set(variables) + undeclared = template_vars - declared_vars + if undeclared: + raise ValueError(f"Template contains undeclared variables: {sorted(undeclared)}") + + mlflow_name = self.mapper.to_mlflow_name(prompt_id) + + # Get all versions to determine the latest and next version number + versions_response = await self.list_prompt_versions(prompt_id) + if not versions_response.data: + raise ValueError(f"Prompt {prompt_id} not found") + + max_version = max(p.version for p in versions_response.data) + + # Verify the provided version is the latest + if version != max_version: + raise ValueError( + f"Version {version} is not the latest version. Use latest version {max_version} to update." 
+ ) + + new_version = max_version + 1 + + # Prepare metadata tags + tags = self.mapper.get_metadata_tags(prompt_id, variables) + + # Register new version in MLflow + try: + mlflow.genai.register_prompt( + name=mlflow_name, + template=prompt, + commit_message=f"Updated from version {version} via Llama Stack", + tags=tags, + ) + logger.info(f"Updated prompt {prompt_id} to version {new_version}") + except Exception as e: + logger.error(f"Failed to update prompt in MLflow: {e}") + raise + + # Set as default if requested + if set_as_default: + try: + mlflow.genai.set_prompt_alias( + name=mlflow_name, + version=new_version, + alias="default", + ) + except Exception as e: + logger.warning(f"Failed to set default alias: {e}") + + return Prompt( + prompt_id=prompt_id, + prompt=prompt, + version=new_version, + variables=variables, + is_default=set_as_default, + ) + + async def delete_prompt(self, prompt_id: str) -> None: + """Delete prompt from MLflow registry. + + Note: MLflow Prompt Registry does not support deletion of registered prompts. + This method will raise NotImplementedError. + + Args: + prompt_id: Llama Stack prompt ID + + Raises: + NotImplementedError: MLflow doesn't support prompt deletion + """ + # MLflow doesn't support deletion of registered prompts + # Options: + # 1. Raise NotImplementedError (current approach) + # 2. Mark as deleted with tag (soft delete) + # 3. Delete all versions individually (if API exists) + + raise NotImplementedError( + "MLflow Prompt Registry does not support deletion. Consider using tags to mark prompts as archived/deleted." + ) + + async def list_prompts(self) -> ListPromptsResponse: + """List all prompts (default versions only). + + Returns: + ListPromptsResponse with default version of each prompt + + Note: + Only lists prompts created/managed by Llama Stack + (those with llama_stack_managed=true tag) + """ + try: + # Search for Llama Stack managed prompts using metadata tags + prompts = mlflow.genai.search_prompts(filter_string="tag.llama_stack_managed='true'") + except Exception as e: + logger.error(f"Failed to search prompts in MLflow: {e}") + return ListPromptsResponse(data=[]) + + llama_prompts = [] + for mlflow_prompt in prompts: + try: + # Convert MLflow name to Llama Stack ID + prompt_id = self.mapper.to_llama_id(mlflow_prompt.name) + + # Get default version + llama_prompt = await self.get_prompt(prompt_id) + llama_prompts.append(llama_prompt) + except (ValueError, Exception) as e: + # Skip prompts that can't be converted or retrieved + logger.warning(f"Skipping prompt {mlflow_prompt.name}: {e}") + continue + + # Sort by prompt_id + llama_prompts.sort(key=lambda p: p.prompt_id, reverse=True) + + return ListPromptsResponse(data=llama_prompts) + + async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse: + """List all versions of a specific prompt. 
+ + Args: + prompt_id: Llama Stack prompt ID + + Returns: + ListPromptsResponse with all versions of the prompt + + Raises: + ValueError: If prompt not found + """ + # MLflow doesn't have a direct "list versions" API for prompts + # We need to iterate and try to load each version + versions = [] + version_num = 1 + max_attempts = 100 # Safety limit + + while version_num <= max_attempts: + try: + prompt = await self.get_prompt(prompt_id, version_num) + versions.append(prompt) + version_num += 1 + except ValueError: + # No more versions + break + except Exception as e: + logger.warning(f"Error loading version {version_num} of {prompt_id}: {e}") + break + + if not versions: + raise ValueError(f"Prompt {prompt_id} not found") + + # Sort by version number + versions.sort(key=lambda p: p.version) + + return ListPromptsResponse(data=versions) + + async def set_default_version(self, prompt_id: str, version: int) -> Prompt: + """Set default version using MLflow alias. + + Args: + prompt_id: Llama Stack prompt ID + version: Version number to set as default + + Returns: + Prompt resource with is_default=True + + Raises: + ValueError: If version not found + """ + # Ensure experiment exists (edge case: managing prompts created outside Llama Stack) + self._ensure_experiment() + + mlflow_name = self.mapper.to_mlflow_name(prompt_id) + + # Verify version exists + try: + prompt = await self.get_prompt(prompt_id, version) + except ValueError as e: + raise ValueError(f"Cannot set default: {e}") from e + + # Set "default" alias in MLflow + try: + mlflow.genai.set_prompt_alias( + name=mlflow_name, + version=version, + alias="default", + ) + logger.info(f"Set version {version} as default for {prompt_id}") + except Exception as e: + logger.error(f"Failed to set default version: {e}") + raise + + # Update is_default flag + prompt.is_default = True + + return prompt + + async def _is_default_version(self, mlflow_name: str, version: int) -> bool: + """Check if a version is the default version. + + Args: + mlflow_name: MLflow prompt name + version: Version number + + Returns: + True if this version is the default, False otherwise + """ + try: + # Try to load with @default alias + default_uri = f"prompts:/{mlflow_name}@default" + default_prompt = mlflow.genai.load_prompt(default_uri) + + # Get default version number + default_version = 1 + if hasattr(default_prompt, "version"): + default_version = int(default_prompt.version) + + return version == default_version + except Exception: + # If default doesn't exist or can't be determined, assume False + return False + + async def shutdown(self) -> None: + """Cleanup resources (no-op for MLflow).""" + pass diff --git a/tests/integration/providers/remote/prompts/__init__.py b/tests/integration/providers/remote/prompts/__init__.py new file mode 100644 index 0000000000..6c392e89d2 --- /dev/null +++ b/tests/integration/providers/remote/prompts/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +"""Integration tests for remote prompts providers.""" diff --git a/tests/integration/providers/remote/prompts/mlflow/README.md b/tests/integration/providers/remote/prompts/mlflow/README.md new file mode 100644 index 0000000000..1dcfdeaf65 --- /dev/null +++ b/tests/integration/providers/remote/prompts/mlflow/README.md @@ -0,0 +1,274 @@ +# MLflow Prompts Provider - Integration Tests + +This directory contains integration tests for the MLflow Prompts Provider. These tests require a running MLflow server. + +## Prerequisites + +1. **MLflow installed**: `pip install 'mlflow>=3.4.0'` (or `uv pip install 'mlflow>=3.4.0'`) +2. **MLflow server running**: See setup instructions below +3. **Test dependencies**: `uv sync --group test` + +## Quick Start + +### 1. Start MLflow Server + +```bash +# Start MLflow server on localhost:5555 +mlflow server --host 127.0.0.1 --port 5555 + +# Keep this terminal open - server will continue running +``` + +### 2. Run Integration Tests + +In a separate terminal: + +```bash +# Set MLflow URI (optional - defaults to localhost:5555) +export MLFLOW_TRACKING_URI=http://localhost:5555 + +# Run all integration tests +uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/ + +# Run specific test +uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/test_end_to_end.py::TestMLflowPromptsEndToEnd::test_create_and_retrieve_prompt +``` + +### 3. Run Manual Test Script (Optional) + +For quick validation without pytest: + +```bash +# Run manual test script +uv run python scripts/test_mlflow_prompts_manual.py + +# View output in MLflow UI +open http://localhost:5555 +``` + +## Test Organization + +### Integration Tests (`test_end_to_end.py`) + +Comprehensive end-to-end tests covering: + +- ✅ Create and retrieve prompts +- ✅ Update prompts (version management) +- ✅ List prompts (default versions only) +- ✅ List all versions of a prompt +- ✅ Set default version +- ✅ Variable auto-extraction +- ✅ Variable validation +- ✅ Error handling (not found, wrong version, etc.) +- ✅ Complex templates with multiple variables +- ✅ Edge cases (empty templates, no variables, etc.) + +**Total**: 17 test scenarios + +### Manual Test Script (`scripts/test_mlflow_prompts_manual.py`) + +Interactive test script with verbose output for: + +- Server connectivity check +- Provider initialization +- Basic CRUD operations +- Variable extraction +- Statistics retrieval + +## Configuration + +### MLflow Server Options + +**Local (default)**: +```bash +mlflow server --host 127.0.0.1 --port 5555 +``` + +**Remote server**: +```bash +export MLFLOW_TRACKING_URI=http://mlflow.example.com:5000 +uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/ +``` + +**Databricks**: +```bash +export MLFLOW_TRACKING_URI=databricks +export MLFLOW_REGISTRY_URI=databricks://profile +uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/ +``` + +### Test Timeout + +Tests have a default timeout of 30 seconds per MLflow operation. Adjust in `conftest.py`: + +```python +MLflowPromptsConfig( + mlflow_tracking_uri=mlflow_tracking_uri, + timeout_seconds=60, # Increase for slow connections +) +``` + +## Fixtures + +### `mlflow_adapter` + +Basic adapter for simple tests: + +```python +async def test_something(mlflow_adapter): + prompt = await mlflow_adapter.create_prompt(...) + # Test continues... 
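+    # The fixture yields an already-initialized MLflowPromptsAdapter, so the
+    # returned object is a full Prompt resource; a first create starts at version 1.
+    assert prompt.version == 1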
+``` + +### `mlflow_adapter_with_cleanup` + +Adapter with automatic cleanup tracking: + +```python +async def test_something(mlflow_adapter_with_cleanup): + # Creates are tracked and attempted cleanup on teardown + prompt = await mlflow_adapter_with_cleanup.create_prompt(...) +``` + +**Note**: MLflow doesn't support deletion, so cleanup is best-effort. + +## Troubleshooting + +### Server Not Available + +**Symptom**: +``` +SKIPPED [1] conftest.py:35: MLflow server not available at http://localhost:5555 +``` + +**Solution**: +```bash +# Start MLflow server +mlflow server --host 127.0.0.1 --port 5555 + +# Verify it's running +curl http://localhost:5555/health +``` + +### Connection Timeout + +**Symptom**: +``` +requests.exceptions.Timeout: ... +``` + +**Solutions**: +1. Check MLflow server is responsive: `curl http://localhost:5555/health` +2. Increase timeout in `conftest.py`: `timeout_seconds=60` +3. Check firewall/network settings + +### Import Errors + +**Symptom**: +``` +ModuleNotFoundError: No module named 'mlflow' +``` + +**Solution**: +```bash +uv pip install 'mlflow>=3.4.0' +``` + +### Permission Errors + +**Symptom**: +``` +PermissionError: [Errno 13] Permission denied: '...' +``` + +**Solution**: +- Ensure MLflow has write access to its storage directory +- Check file permissions on `mlruns/` directory + +### Test Isolation Issues + +**Issue**: Tests may interfere with each other if using same prompt IDs + +**Solution**: Each test creates new prompts with unique IDs (generated by `Prompt.generate_prompt_id()`). If needed, use `mlflow_adapter_with_cleanup` fixture. + +## Viewing Results + +### MLflow UI + +1. Start MLflow server (if not already running): + ```bash + mlflow server --host 127.0.0.1 --port 5555 + ``` + +2. Open in browser: + ``` + http://localhost:5555 + ``` + +3. Navigate to experiment `test-llama-stack-prompts` + +4. View registered prompts and their versions + +### Test Output + +Run with verbose output to see detailed test execution: + +```bash +uv run --group test pytest -vv tests/integration/providers/remote/prompts/mlflow/ +``` + +## CI/CD Integration + +To run tests in CI/CD pipelines: + +```yaml +# Example GitHub Actions workflow +- name: Start MLflow server + run: | + mlflow server --host 127.0.0.1 --port 5555 & + sleep 5 # Wait for server to start + +- name: Wait for MLflow + run: | + timeout 30 bash -c 'until curl -s http://localhost:5555/health; do sleep 1; done' + +- name: Run integration tests + env: + MLFLOW_TRACKING_URI: http://localhost:5555 + run: | + uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/ +``` + +## Performance + +### Expected Test Duration + +- **Individual test**: ~1-5 seconds +- **Full suite** (17 tests): ~30-60 seconds +- **Manual script**: ~10-15 seconds + +### Optimization Tips + +1. Use local MLflow server (faster than remote) +2. Run tests in parallel (if safe): + ```bash + uv run --group test pytest -n auto tests/integration/providers/remote/prompts/mlflow/ + ``` +3. Skip integration tests in development: + ```bash + uv run --group dev pytest -sv tests/unit/ + ``` + +## Coverage + +Integration tests provide coverage for: + +- ✅ Real MLflow API calls +- ✅ Network communication +- ✅ Serialization/deserialization +- ✅ MLflow server responses +- ✅ Version management +- ✅ Alias handling +- ✅ Tag storage and retrieval + +Combined with unit tests, achieves **>95% code coverage**. 
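+
+## Example: Driving the Adapter Directly
+
+As a reference for what the manual script and the end-to-end tests exercise, the sketch below drives the adapter directly from a small script. It is a minimal sketch, assuming a local MLflow server at `http://localhost:5555` and the `test-llama-stack-prompts` experiment used by `conftest.py`; it only calls the public adapter methods covered by the test suite (`create_prompt`, `update_prompt`, `get_prompt`).
+
+```python
+import asyncio
+
+from llama_stack.providers.remote.prompts.mlflow import MLflowPromptsAdapter
+from llama_stack.providers.remote.prompts.mlflow.config import MLflowPromptsConfig
+
+
+async def main() -> None:
+    config = MLflowPromptsConfig(
+        mlflow_tracking_uri="http://localhost:5555",
+        experiment_name="test-llama-stack-prompts",
+    )
+    adapter = MLflowPromptsAdapter(config=config)
+    await adapter.initialize()
+
+    # Variables are auto-extracted from {{ ... }} placeholders when not passed explicitly.
+    created = await adapter.create_prompt(prompt="Summarize {{ text }} in {{ style }} style")
+    print(created.prompt_id, created.version, created.variables)
+
+    # Updating requires the current latest version and registers a new MLflow version.
+    updated = await adapter.update_prompt(
+        prompt_id=created.prompt_id,
+        prompt="Summarize {{ text }} in {{ style }} style, briefly",
+        version=created.version,
+    )
+    print(updated.version)  # 2
+
+    # The "default" alias now points at the new version.
+    default = await adapter.get_prompt(created.prompt_id)
+    print(default.version, default.is_default)
+
+    await adapter.shutdown()
+
+
+asyncio.run(main())
+```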
diff --git a/tests/integration/providers/remote/prompts/mlflow/__init__.py b/tests/integration/providers/remote/prompts/mlflow/__init__.py new file mode 100644 index 0000000000..bd7adfe563 --- /dev/null +++ b/tests/integration/providers/remote/prompts/mlflow/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Integration tests for MLflow prompts provider.""" diff --git a/tests/integration/providers/remote/prompts/mlflow/conftest.py b/tests/integration/providers/remote/prompts/mlflow/conftest.py new file mode 100644 index 0000000000..dc384fcb5c --- /dev/null +++ b/tests/integration/providers/remote/prompts/mlflow/conftest.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Fixtures for MLflow integration tests. + +These tests require a running MLflow server. Set the MLFLOW_TRACKING_URI +environment variable to point to your MLflow server, or the tests will +attempt to use http://localhost:5555. + +To run tests: + # Start MLflow server (in separate terminal) + mlflow server --host 127.0.0.1 --port 5555 + + # Run integration tests + MLFLOW_TRACKING_URI=http://localhost:5555 \ + uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/ +""" + +import os + +import pytest + +from llama_stack.providers.remote.prompts.mlflow import MLflowPromptsAdapter +from llama_stack.providers.remote.prompts.mlflow.config import MLflowPromptsConfig + + +@pytest.fixture(scope="session") +def mlflow_tracking_uri(): + """Get MLflow tracking URI from environment or use default.""" + return os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5555") + + +@pytest.fixture(scope="session") +def mlflow_server_available(mlflow_tracking_uri): + """Verify MLflow server is running and accessible. + + Skips all tests if server is not available. + """ + try: + import requests + + response = requests.get(f"{mlflow_tracking_uri}/health", timeout=5) + if response.status_code != 200: + pytest.skip(f"MLflow server at {mlflow_tracking_uri} returned status {response.status_code}") + except ImportError: + pytest.skip("requests package not installed - install with: pip install requests") + except requests.exceptions.ConnectionError: + pytest.skip( + f"MLflow server not available at {mlflow_tracking_uri}. " + "Start with: mlflow server --host 127.0.0.1 --port 5555" + ) + except requests.exceptions.Timeout: + pytest.skip(f"MLflow server at {mlflow_tracking_uri} timed out") + except Exception as e: + pytest.skip(f"Failed to check MLflow server availability: {e}") + + return True + + +@pytest.fixture +async def mlflow_config(mlflow_tracking_uri, mlflow_server_available): + """Create MLflow configuration for testing.""" + return MLflowPromptsConfig( + mlflow_tracking_uri=mlflow_tracking_uri, + experiment_name="test-llama-stack-prompts", + timeout_seconds=30, + ) + + +@pytest.fixture +async def mlflow_adapter(mlflow_config): + """Create and initialize MLflow adapter for testing. + + This fixture creates a new adapter instance for each test. + The adapter connects to the MLflow server specified in the config. 
+ """ + adapter = MLflowPromptsAdapter(config=mlflow_config) + await adapter.initialize() + + yield adapter + + # Cleanup: shutdown adapter + await adapter.shutdown() + + +@pytest.fixture +async def mlflow_adapter_with_cleanup(mlflow_config): + """Create MLflow adapter with automatic cleanup after test. + + This fixture is useful for tests that create prompts and want them + automatically cleaned up (though MLflow doesn't support deletion, + so cleanup is best-effort). + """ + adapter = MLflowPromptsAdapter(config=mlflow_config) + await adapter.initialize() + + created_prompt_ids = [] + + # Provide adapter and tracking list + class AdapterWithTracking: + def __init__(self, adapter_instance): + self.adapter = adapter_instance + self.created_ids = created_prompt_ids + + async def create_prompt(self, *args, **kwargs): + prompt = await self.adapter.create_prompt(*args, **kwargs) + self.created_ids.append(prompt.prompt_id) + return prompt + + def __getattr__(self, name): + return getattr(self.adapter, name) + + tracked_adapter = AdapterWithTracking(adapter) + + yield tracked_adapter + + # Cleanup: attempt to delete created prompts + # Note: MLflow doesn't support deletion, so this is a no-op + # but we keep it for future compatibility + for prompt_id in created_prompt_ids: + try: + await adapter.delete_prompt(prompt_id) + except NotImplementedError: + # Expected - MLflow doesn't support deletion + pass + except Exception: + # Ignore cleanup errors + pass + + await adapter.shutdown() diff --git a/tests/integration/providers/remote/prompts/mlflow/test_end_to_end.py b/tests/integration/providers/remote/prompts/mlflow/test_end_to_end.py new file mode 100644 index 0000000000..aeaefeb301 --- /dev/null +++ b/tests/integration/providers/remote/prompts/mlflow/test_end_to_end.py @@ -0,0 +1,350 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""End-to-end integration tests for MLflow prompts provider. + +These tests require a running MLflow server. See conftest.py for setup instructions. 
+""" + +import pytest + + +class TestMLflowPromptsEndToEnd: + """End-to-end tests for MLflow prompts provider.""" + + async def test_create_and_retrieve_prompt(self, mlflow_adapter): + """Test creating a prompt and retrieving it by ID.""" + # Create prompt with variables + created = await mlflow_adapter.create_prompt( + prompt="Summarize the following text in {{ num_sentences }} sentences: {{ text }}", + variables=["num_sentences", "text"], + ) + + # Verify created prompt + assert created.prompt_id.startswith("pmpt_") + assert len(created.prompt_id) == 53 # "pmpt_" + 48 hex chars + assert created.version == 1 + assert created.is_default is True + assert set(created.variables) == {"num_sentences", "text"} + assert "{{ num_sentences }}" in created.prompt + assert "{{ text }}" in created.prompt + + # Retrieve prompt by ID (should get default version) + retrieved = await mlflow_adapter.get_prompt(created.prompt_id) + + assert retrieved.prompt_id == created.prompt_id + assert retrieved.prompt == created.prompt + assert retrieved.version == created.version + assert set(retrieved.variables) == set(created.variables) + assert retrieved.is_default is True + + # Retrieve specific version + retrieved_v1 = await mlflow_adapter.get_prompt(created.prompt_id, version=1) + + assert retrieved_v1.prompt_id == created.prompt_id + assert retrieved_v1.version == 1 + + async def test_update_prompt_creates_new_version(self, mlflow_adapter): + """Test that updating a prompt creates a new version.""" + # Create initial prompt (version 1) + v1 = await mlflow_adapter.create_prompt( + prompt="Original prompt with {{ variable }}", + variables=["variable"], + ) + + assert v1.version == 1 + assert v1.is_default is True + + # Update prompt (should create version 2) + v2 = await mlflow_adapter.update_prompt( + prompt_id=v1.prompt_id, + prompt="Updated prompt with {{ variable }}", + version=1, + variables=["variable"], + set_as_default=True, + ) + + assert v2.prompt_id == v1.prompt_id + assert v2.version == 2 + assert v2.is_default is True + assert "Updated" in v2.prompt + + # Verify both versions exist + versions_response = await mlflow_adapter.list_prompt_versions(v1.prompt_id) + versions = versions_response.data + + assert len(versions) >= 2 + assert any(v.version == 1 for v in versions) + assert any(v.version == 2 for v in versions) + + # Verify version 1 still exists + v1_retrieved = await mlflow_adapter.get_prompt(v1.prompt_id, version=1) + assert "Original" in v1_retrieved.prompt + assert v1_retrieved.is_default is False # No longer default + + # Verify version 2 is default + default = await mlflow_adapter.get_prompt(v1.prompt_id) + assert default.version == 2 + assert "Updated" in default.prompt + + async def test_list_prompts_returns_defaults_only(self, mlflow_adapter): + """Test that list_prompts returns only default versions.""" + # Create multiple prompts + p1 = await mlflow_adapter.create_prompt( + prompt="Prompt 1 with {{ var }}", + variables=["var"], + ) + + p2 = await mlflow_adapter.create_prompt( + prompt="Prompt 2 with {{ var }}", + variables=["var"], + ) + + # Update first prompt (creates version 2) + await mlflow_adapter.update_prompt( + prompt_id=p1.prompt_id, + prompt="Prompt 1 updated with {{ var }}", + version=1, + variables=["var"], + set_as_default=True, + ) + + # List all prompts + response = await mlflow_adapter.list_prompts() + prompts = response.data + + # Should contain at least our 2 prompts + assert len(prompts) >= 2 + + # Find our prompts in the list + p1_in_list = next((p for p in prompts 
if p.prompt_id == p1.prompt_id), None) + p2_in_list = next((p for p in prompts if p.prompt_id == p2.prompt_id), None) + + assert p1_in_list is not None + assert p2_in_list is not None + + # p1 should be version 2 (updated version is default) + assert p1_in_list.version == 2 + assert p1_in_list.is_default is True + + # p2 should be version 1 (original is still default) + assert p2_in_list.version == 1 + assert p2_in_list.is_default is True + + async def test_list_prompt_versions(self, mlflow_adapter): + """Test listing all versions of a specific prompt.""" + # Create prompt + v1 = await mlflow_adapter.create_prompt( + prompt="Version 1 {{ var }}", + variables=["var"], + ) + + # Create multiple versions + _v2 = await mlflow_adapter.update_prompt( + prompt_id=v1.prompt_id, + prompt="Version 2 {{ var }}", + version=1, + variables=["var"], + ) + + _v3 = await mlflow_adapter.update_prompt( + prompt_id=v1.prompt_id, + prompt="Version 3 {{ var }}", + version=2, + variables=["var"], + ) + + # List all versions + versions_response = await mlflow_adapter.list_prompt_versions(v1.prompt_id) + versions = versions_response.data + + # Should have 3 versions + assert len(versions) == 3 + + # Verify versions are sorted by version number + assert versions[0].version == 1 + assert versions[1].version == 2 + assert versions[2].version == 3 + + # Verify content + assert "Version 1" in versions[0].prompt + assert "Version 2" in versions[1].prompt + assert "Version 3" in versions[2].prompt + + # Only latest should be default + assert versions[0].is_default is False + assert versions[1].is_default is False + assert versions[2].is_default is True + + async def test_set_default_version(self, mlflow_adapter): + """Test changing which version is the default.""" + # Create prompt and update it + v1 = await mlflow_adapter.create_prompt( + prompt="Version 1 {{ var }}", + variables=["var"], + ) + + _v2 = await mlflow_adapter.update_prompt( + prompt_id=v1.prompt_id, + prompt="Version 2 {{ var }}", + version=1, + variables=["var"], + ) + + # At this point, _v2 is default + default = await mlflow_adapter.get_prompt(v1.prompt_id) + assert default.version == 2 + + # Set v1 as default + updated = await mlflow_adapter.set_default_version(v1.prompt_id, 1) + assert updated.version == 1 + assert updated.is_default is True + + # Verify default changed + default = await mlflow_adapter.get_prompt(v1.prompt_id) + assert default.version == 1 + assert "Version 1" in default.prompt + + async def test_variable_auto_extraction(self, mlflow_adapter): + """Test automatic variable extraction from template.""" + # Create prompt without explicitly specifying variables + created = await mlflow_adapter.create_prompt( + prompt="Extract {{ entity }} from {{ text }} in {{ format }} format", + ) + + # Should auto-extract all variables + assert set(created.variables) == {"entity", "text", "format"} + + # Retrieve and verify + retrieved = await mlflow_adapter.get_prompt(created.prompt_id) + assert set(retrieved.variables) == {"entity", "text", "format"} + + async def test_variable_validation(self, mlflow_adapter): + """Test that variable validation works correctly.""" + # Should fail: template has undeclared variable + with pytest.raises(ValueError, match="undeclared variables"): + await mlflow_adapter.create_prompt( + prompt="Template with {{ var1 }} and {{ var2 }}", + variables=["var1"], # Missing var2 + ) + + async def test_prompt_not_found(self, mlflow_adapter): + """Test error handling when prompt doesn't exist.""" + fake_id = "pmpt_" + "0" * 48 + 
+ with pytest.raises(ValueError, match="not found"): + await mlflow_adapter.get_prompt(fake_id) + + async def test_version_not_found(self, mlflow_adapter): + """Test error handling when version doesn't exist.""" + # Create prompt (version 1) + created = await mlflow_adapter.create_prompt( + prompt="Test {{ var }}", + variables=["var"], + ) + + # Try to get non-existent version + with pytest.raises(ValueError, match="not found"): + await mlflow_adapter.get_prompt(created.prompt_id, version=999) + + async def test_update_wrong_version(self, mlflow_adapter): + """Test that updating with wrong version fails.""" + # Create prompt (version 1) + created = await mlflow_adapter.create_prompt( + prompt="Test {{ var }}", + variables=["var"], + ) + + # Try to update with wrong version number + with pytest.raises(ValueError, match="not the latest"): + await mlflow_adapter.update_prompt( + prompt_id=created.prompt_id, + prompt="Updated {{ var }}", + version=999, # Wrong version + variables=["var"], + ) + + async def test_delete_not_supported(self, mlflow_adapter): + """Test that deletion raises NotImplementedError.""" + # Create prompt + created = await mlflow_adapter.create_prompt( + prompt="Test {{ var }}", + variables=["var"], + ) + + # Try to delete (should fail with NotImplementedError) + with pytest.raises(NotImplementedError, match="does not support deletion"): + await mlflow_adapter.delete_prompt(created.prompt_id) + + # Verify prompt still exists + retrieved = await mlflow_adapter.get_prompt(created.prompt_id) + assert retrieved.prompt_id == created.prompt_id + + async def test_complex_template_with_multiple_variables(self, mlflow_adapter): + """Test prompt with complex template and multiple variables.""" + template = """You are a {{ role }} assistant specialized in {{ domain }}. + +Task: {{ task }} + +Context: +{{ context }} + +Instructions: +1. {{ instruction1 }} +2. {{ instruction2 }} +3. {{ instruction3 }} + +Output format: {{ output_format }} +""" + + # Create with auto-extraction + created = await mlflow_adapter.create_prompt(prompt=template) + + # Should extract all variables + expected_vars = { + "role", + "domain", + "task", + "context", + "instruction1", + "instruction2", + "instruction3", + "output_format", + } + assert set(created.variables) == expected_vars + + # Retrieve and verify template preserved + retrieved = await mlflow_adapter.get_prompt(created.prompt_id) + assert retrieved.prompt == template + + async def test_empty_template(self, mlflow_adapter): + """Test handling of empty template.""" + # Create prompt with empty template + created = await mlflow_adapter.create_prompt( + prompt="", + variables=[], + ) + + assert created.prompt == "" + assert created.variables == [] + + # Retrieve and verify + retrieved = await mlflow_adapter.get_prompt(created.prompt_id) + assert retrieved.prompt == "" + + async def test_template_with_no_variables(self, mlflow_adapter): + """Test template without any variables.""" + template = "This is a static prompt with no variables." 
+ + created = await mlflow_adapter.create_prompt(prompt=template) + + assert created.prompt == template + assert created.variables == [] + + # Retrieve and verify + retrieved = await mlflow_adapter.get_prompt(created.prompt_id) + assert retrieved.prompt == template + assert retrieved.variables == [] diff --git a/tests/unit/providers/remote/prompts/__init__.py b/tests/unit/providers/remote/prompts/__init__.py new file mode 100644 index 0000000000..10005fc374 --- /dev/null +++ b/tests/unit/providers/remote/prompts/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Unit tests for remote prompts providers.""" diff --git a/tests/unit/providers/remote/prompts/mlflow/__init__.py b/tests/unit/providers/remote/prompts/mlflow/__init__.py new file mode 100644 index 0000000000..60948019ea --- /dev/null +++ b/tests/unit/providers/remote/prompts/mlflow/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Unit tests for MLflow prompts provider.""" diff --git a/tests/unit/providers/remote/prompts/mlflow/test_config.py b/tests/unit/providers/remote/prompts/mlflow/test_config.py new file mode 100644 index 0000000000..4724a5489d --- /dev/null +++ b/tests/unit/providers/remote/prompts/mlflow/test_config.py @@ -0,0 +1,138 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +"""Unit tests for MLflow prompts provider configuration.""" + +import pytest +from pydantic import SecretStr, ValidationError + +from llama_stack.providers.remote.prompts.mlflow.config import ( + MLflowPromptsConfig, + MLflowProviderDataValidator, +) + + +class TestMLflowPromptsConfig: + """Tests for MLflowPromptsConfig model.""" + + def test_default_config(self): + """Test default configuration values.""" + config = MLflowPromptsConfig() + + assert config.mlflow_tracking_uri == "http://localhost:5000" + assert config.mlflow_registry_uri is None + assert config.experiment_name == "llama-stack-prompts" + assert config.auth_credential is None + assert config.timeout_seconds == 30 + + def test_custom_config(self): + """Test custom configuration values.""" + config = MLflowPromptsConfig( + mlflow_tracking_uri="http://mlflow.example.com:8080", + mlflow_registry_uri="http://registry.example.com:8080", + experiment_name="my-prompts", + auth_credential=SecretStr("my-token"), + timeout_seconds=60, + ) + + assert config.mlflow_tracking_uri == "http://mlflow.example.com:8080" + assert config.mlflow_registry_uri == "http://registry.example.com:8080" + assert config.experiment_name == "my-prompts" + assert config.auth_credential.get_secret_value() == "my-token" + assert config.timeout_seconds == 60 + + def test_databricks_uri(self): + """Test Databricks URI configuration.""" + config = MLflowPromptsConfig( + mlflow_tracking_uri="databricks", + mlflow_registry_uri="databricks://profile", + ) + + assert config.mlflow_tracking_uri == "databricks" + assert config.mlflow_registry_uri == "databricks://profile" + + def test_tracking_uri_validation(self): + """Test tracking URI validation.""" + # Empty string rejected + with pytest.raises(ValidationError, match="mlflow_tracking_uri cannot be empty"): + MLflowPromptsConfig(mlflow_tracking_uri="") + + # Whitespace-only rejected + with pytest.raises(ValidationError, match="mlflow_tracking_uri cannot be empty"): + MLflowPromptsConfig(mlflow_tracking_uri=" ") + + # Whitespace is stripped + config = MLflowPromptsConfig(mlflow_tracking_uri=" http://localhost:5000 ") + assert config.mlflow_tracking_uri == "http://localhost:5000" + + def test_experiment_name_validation(self): + """Test experiment name validation.""" + # Empty string rejected + with pytest.raises(ValidationError, match="experiment_name cannot be empty"): + MLflowPromptsConfig(experiment_name="") + + # Whitespace-only rejected + with pytest.raises(ValidationError, match="experiment_name cannot be empty"): + MLflowPromptsConfig(experiment_name=" ") + + # Whitespace is stripped + config = MLflowPromptsConfig(experiment_name=" my-experiment ") + assert config.experiment_name == "my-experiment" + + def test_timeout_validation(self): + """Test timeout range validation.""" + # Too low rejected + with pytest.raises(ValidationError): + MLflowPromptsConfig(timeout_seconds=0) + + with pytest.raises(ValidationError): + MLflowPromptsConfig(timeout_seconds=-1) + + # Too high rejected + with pytest.raises(ValidationError): + MLflowPromptsConfig(timeout_seconds=301) + + # Boundary values accepted + config_min = MLflowPromptsConfig(timeout_seconds=1) + assert config_min.timeout_seconds == 1 + + config_max = MLflowPromptsConfig(timeout_seconds=300) + assert config_max.timeout_seconds == 300 + + def test_sample_run_config(self): + """Test sample_run_config generates valid configuration.""" + # Default environment variable + sample = MLflowPromptsConfig.sample_run_config() + assert sample["mlflow_tracking_uri"] == 
"http://localhost:5000" + assert sample["experiment_name"] == "llama-stack-prompts" + assert sample["auth_credential"] == "${env.MLFLOW_TRACKING_TOKEN:=}" + + # Custom values + sample = MLflowPromptsConfig.sample_run_config( + mlflow_api_token="test-token", + mlflow_tracking_uri="http://custom:5000", + ) + assert sample["mlflow_tracking_uri"] == "http://custom:5000" + assert sample["auth_credential"] == "test-token" + + +class TestMLflowProviderDataValidator: + """Tests for MLflowProviderDataValidator.""" + + def test_provider_data_validator(self): + """Test provider data validator with and without token.""" + # With token + validator = MLflowProviderDataValidator(mlflow_api_token="test-token-123") + assert validator.mlflow_api_token == "test-token-123" + + # Without token + validator = MLflowProviderDataValidator() + assert validator.mlflow_api_token is None + + # From dictionary + data = {"mlflow_api_token": "secret-token"} + validator = MLflowProviderDataValidator(**data) + assert validator.mlflow_api_token == "secret-token" diff --git a/tests/unit/providers/remote/prompts/mlflow/test_mapping.py b/tests/unit/providers/remote/prompts/mlflow/test_mapping.py new file mode 100644 index 0000000000..fd0ef61f58 --- /dev/null +++ b/tests/unit/providers/remote/prompts/mlflow/test_mapping.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Unit tests for MLflow prompts ID mapping utilities.""" + +import pytest + +from llama_stack.providers.remote.prompts.mlflow.mapping import PromptIDMapper + + +class TestPromptIDMapper: + """Tests for PromptIDMapper class.""" + + @pytest.fixture + def mapper(self): + """Create ID mapper instance.""" + return PromptIDMapper() + + def test_to_mlflow_name_valid(self, mapper): + """Test converting valid prompt_id to MLflow name.""" + prompt_id = "pmpt_" + "a" * 48 + mlflow_name = mapper.to_mlflow_name(prompt_id) + + assert mlflow_name == "llama_prompt_" + "a" * 48 + assert mlflow_name.startswith(mapper.MLFLOW_NAME_PREFIX) + + def test_to_mlflow_name_invalid(self, mapper): + """Test conversion fails with invalid inputs.""" + # Invalid prefix + with pytest.raises(ValueError, match="Invalid prompt_id format"): + mapper.to_mlflow_name("invalid_" + "a" * 48) + + # Wrong length + with pytest.raises(ValueError, match="Invalid prompt_id format"): + mapper.to_mlflow_name("pmpt_" + "a" * 47) + + # Invalid hex characters + with pytest.raises(ValueError, match="Invalid prompt_id format"): + mapper.to_mlflow_name("pmpt_" + "g" * 48) + + def test_to_llama_id_valid(self, mapper): + """Test converting valid MLflow name to prompt_id.""" + mlflow_name = "llama_prompt_" + "b" * 48 + prompt_id = mapper.to_llama_id(mlflow_name) + + assert prompt_id == "pmpt_" + "b" * 48 + assert prompt_id.startswith("pmpt_") + + def test_to_llama_id_invalid(self, mapper): + """Test conversion fails with invalid inputs.""" + # Invalid prefix + with pytest.raises(ValueError, match="does not start with expected prefix"): + mapper.to_llama_id("wrong_prefix_" + "a" * 48) + + # Wrong length + with pytest.raises(ValueError, match="Invalid hex part length"): + mapper.to_llama_id("llama_prompt_" + "a" * 47) + + # Invalid hex characters + with pytest.raises(ValueError, match="Invalid character"): + mapper.to_llama_id("llama_prompt_" + "G" * 48) + + def test_bidirectional_conversion(self, mapper): + """Test bidirectional conversion 
preserves IDs.""" + original_id = "pmpt_0123456789abcdef" + "a" * 32 + + # Convert to MLflow name and back + mlflow_name = mapper.to_mlflow_name(original_id) + recovered_id = mapper.to_llama_id(mlflow_name) + + assert recovered_id == original_id + + def test_get_metadata_tags_with_variables(self, mapper): + """Test metadata tags generation with variables.""" + prompt_id = "pmpt_" + "c" * 48 + variables = ["var1", "var2", "var3"] + + tags = mapper.get_metadata_tags(prompt_id, variables) + + assert tags["llama_stack_id"] == prompt_id + assert tags["llama_stack_managed"] == "true" + assert tags["variables"] == "var1,var2,var3" + + def test_get_metadata_tags_without_variables(self, mapper): + """Test metadata tags generation without variables.""" + prompt_id = "pmpt_" + "d" * 48 + + tags = mapper.get_metadata_tags(prompt_id) + + assert tags["llama_stack_id"] == prompt_id + assert tags["llama_stack_managed"] == "true" + assert "variables" not in tags
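+
+    def test_get_metadata_tags_empty_variables_list(self, mapper):
+        """Test metadata tags generation with an empty variables list.
+
+        An empty list behaves like None: the `if variables:` guard in
+        get_metadata_tags skips the "variables" tag entirely.
+        """
+        prompt_id = "pmpt_" + "e" * 48
+
+        tags = mapper.get_metadata_tags(prompt_id, [])
+
+        assert tags["llama_stack_id"] == prompt_id
+        assert tags["llama_stack_managed"] == "true"
+        assert "variables" not in tags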