Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/docs/providers/inference/remote_watsonx.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `<class 'str'>` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx API key |
| `project_id` | `str \| None` | No | | The Project ID key |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx.ai API key |
| `project_id` | `str \| None` | No | | The watsonx.ai project ID |
| `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |

## Sample Configuration
Expand Down
2 changes: 1 addition & 1 deletion llama_stack/core/routers/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ async def stream_tokens_and_compute_metrics_openai_chat(
completion_text += "".join(choice_data["content_parts"])

# Add metrics to the chunk
if self.telemetry and chunk.usage:
if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks unrelated

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a bug around this. I have seen a few PRs which introduced the same logic: #3392, #3422

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, my tests for watsonx.ai wouldn't succeed without this fix (as noted in the PR description). I am fine with letting some other PR put the fix in, but I think this PR should probably wait until that one is in if that's the plan.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does watsonx return usage information? we need to fix/adapt in the adapter, not in the core. putting this in the core obscures the issue.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line was already handling the case where there was no usage information by checking chunk.usage. So I would argue this change is just doing what the line was already doing but in a more robust way. FWIW, the watsonx.ai API does return a usage block, e.g.:

	"usage": {
		"completion_tokens": 54,
		"prompt_tokens": 79,
		"total_tokens": 133
	},

Notably though, it doesn't put the usage information on every chunk. It only puts it on the final chunk in the stream. I can see that when I call the streaming REST API directly. However, I don't see it when I call LiteLLM directly with streaming=True. I do see it when I call LiteLLM without streaming, FWIW. So I think LiteLLM might be dropping the usage information from the last chunk. Here is how I tested this in a notebook:

import litellm, asyncio
import nest_asyncio
nest_asyncio.apply()

async def get_litellm_response():
    return await litellm.acompletion(
        model="watsonx/meta-llama/llama-3-3-70b-instruct",
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        stream=True
    )

async def print_litellm_response():
    response = await get_litellm_response()
    async for chunk in response:
        print(chunk)

asyncio.run(print_litellm_response())

With all that said, even if LiteLLM was correctly including this on the last chunk, we'd still have the issue that it is missing from all of the other chunks (unless LiteLLM put in an explicit None for this field for each other chunk). So I still think we should adopt this change here and let the line handle both missing AND explicit None.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, fair enough 😄

metrics = self._construct_metrics(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
Expand Down
2 changes: 2 additions & 0 deletions llama_stack/distributions/watsonx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .watsonx import get_distribution_template # noqa: F401
41 changes: 15 additions & 26 deletions llama_stack/distributions/watsonx/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,33 @@ distribution_spec:
description: Use watsonx for running LLM inference
providers:
inference:
- provider_id: watsonx
provider_type: remote::watsonx
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
- provider_type: remote::watsonx
- provider_type: inline::sentence-transformers
vector_io:
- provider_id: faiss
provider_type: inline::faiss
- provider_type: inline::faiss
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
- provider_type: inline::llama-guard
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
- provider_type: inline::meta-reference
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
- provider_type: inline::meta-reference
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
- provider_type: inline::meta-reference
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
- provider_id: localfs
provider_type: inline::localfs
- provider_type: remote::huggingface
- provider_type: inline::localfs
scoring:
- provider_id: basic
provider_type: inline::basic
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
- provider_type: inline::basic
- provider_type: inline::llm-as-judge
- provider_type: inline::braintrust
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
files:
- provider_type: inline::localfs
image_type: venv
additional_pip_packages:
- sqlalchemy[asyncio]
- aiosqlite
- aiosqlite
- sqlalchemy[asyncio]
103 changes: 3 additions & 100 deletions llama_stack/distributions/watsonx/run.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ apis:
- agents
- datasetio
- eval
- files
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
- files
providers:
inference:
- provider_id: watsonx
Expand All @@ -19,8 +19,6 @@ providers:
url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
api_key: ${env.WATSONX_API_KEY:=}
project_id: ${env.WATSONX_PROJECT_ID:=}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io:
- provider_id: faiss
provider_type: inline::faiss
Expand Down Expand Up @@ -48,7 +46,7 @@ providers:
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:
Expand Down Expand Up @@ -109,102 +107,7 @@ metadata_store:
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/inference_store.db
models:
- metadata: {}
model_id: meta-llama/llama-3-3-70b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-2-13b-chat
provider_id: watsonx
provider_model_id: meta-llama/llama-2-13b-chat
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-2-13b
provider_id: watsonx
provider_model_id: meta-llama/llama-2-13b-chat
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-1-70b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-70B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-1-8b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-2-11b-vision-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-2-1b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-2-3b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-2-90b-vision-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-guard-3-11b-vision
provider_id: watsonx
provider_model_id: meta-llama/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-Guard-3-11B-Vision
provider_id: watsonx
provider_model_id: meta-llama/llama-guard-3-11b-vision
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
models: []
shields: []
vector_dbs: []
datasets: []
Expand Down
36 changes: 5 additions & 31 deletions llama_stack/distributions/watsonx/watsonx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path

from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES


def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
Expand Down Expand Up @@ -52,15 +46,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
config=WatsonXConfig.sample_run_config(),
)

embedding_provider = Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)

available_models = {
"watsonx": MODEL_ENTRIES,
}
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
Expand All @@ -72,36 +57,25 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
),
]

embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id="sentence-transformers",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
},
)

files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
default_models, _ = get_model_registry(available_models)
return DistributionTemplate(
name=name,
distro_type="remote_hosted",
description="Use watsonx for running LLM inference",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
template_path=None,
providers=providers,
available_models_by_provider=available_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider, embedding_provider],
"inference": [inference_provider],
"files": [files_provider],
},
default_models=default_models + [embedding_model],
default_models=[],
default_tool_groups=default_tool_groups,
),
},
Expand Down
2 changes: 1 addition & 1 deletion llama_stack/providers/registry/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="watsonx",
provider_type="remote::watsonx",
pip_packages=["ibm_watsonx_ai"],
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
Expand Down
11 changes: 2 additions & 9 deletions llama_stack/providers/remote/inference/watsonx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import Inference

from .config import WatsonXConfig


async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
# import dynamically so `llama stack build` does not fail due to missing dependencies
async def get_adapter_impl(config: WatsonXConfig, _deps):
# import dynamically so the import is used only when it is needed
from .watsonx import WatsonXInferenceAdapter

if not isinstance(config, WatsonXConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = WatsonXInferenceAdapter(config)
return adapter


__all__ = ["get_adapter_impl", "WatsonXConfig"]
22 changes: 14 additions & 8 deletions llama_stack/providers/remote/inference/watsonx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@
import os
from typing import Any

from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, ConfigDict, Field, SecretStr

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type


class WatsonXProviderDataValidator(BaseModel):
url: str
api_key: str
project_id: str
model_config = ConfigDict(
from_attributes=True,
extra="forbid",
)
watsonx_api_key: str | None


@json_schema_type
Expand All @@ -25,13 +27,17 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
description="A base url for accessing the watsonx.ai",
)
# This seems like it should be required, but none of the other remote inference
# providers require it, so this is optional here too for consistency.
# The OpenAIConfig uses default=None instead, so this is following that precedent.
api_key: SecretStr | None = Field(
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
description="The watsonx API key",
default=None,
description="The watsonx.ai API key",
)
# As above, this is optional here too for consistency.
project_id: str | None = Field(
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
description="The Project ID key",
default=None,
description="The watsonx.ai project ID",
)
timeout: int = Field(
default=60,
Expand Down
Loading
Loading