Skip to content

Commit 95015f1

Browse files
committed
feat: Add custom parameter support to model config
1 parent a2cc966 commit 95015f1

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

ldai/client.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@ class ModelConfig:
2121
Configuration related to the model.
2222
"""
2323

24-
def __init__(self, id: str, parameters: Optional[Dict[str, Any]] = None):
24+
def __init__(self, id: str, parameters: Optional[Dict[str, Any]] = None, custom: Optional[Dict[str, Any]] = None):
2525
"""
2626
:param id: The ID of the model.
2727
:param parameters: Additional model-specific parameters.
28+
:param custom: Additional customer provided data.
2829
"""
2930
self._id = id
3031
self._parameters = parameters
32+
self._custom = custom
3133

3234
@property
3335
def id(self) -> str:
@@ -51,6 +53,15 @@ def get_parameter(self, key: str) -> Any:
5153

5254
return self._parameters.get(key)
5355

56+
def get_custom(self, key: str) -> Any:
57+
"""
58+
Retrieve customer provided data.
59+
"""
60+
if self._custom is None:
61+
return None
62+
63+
return self._custom.get(key)
64+
5465

5566
class ProviderConfig:
5667
"""
@@ -128,9 +139,11 @@ def config(
128139
model = None
129140
if 'model' in variation and isinstance(variation['model'], dict):
130141
parameters = variation['model'].get('parameters', None)
142+
custom = variation['model'].get('custom', None)
131143
model = ModelConfig(
132144
id=variation['model']['id'],
133-
parameters=parameters
145+
parameters=parameters,
146+
custom=custom
134147
)
135148

136149
enabled = variation.get('_ldMeta', {}).get('enabled', False)

ldai/testing/test_model_config.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def td() -> TestData:
1313
td.flag('model-config')
1414
.variations(
1515
{
16-
'model': {'id': 'fakeModel', 'parameters': {'temperature': 0.5, 'maxTokens': 4096}},
16+
'model': {'id': 'fakeModel', 'parameters': {'temperature': 0.5, 'maxTokens': 4096}, 'custom': {'extra-attribute': 'value'}},
1717
'provider': {'id': 'fakeProvider'},
1818
'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
1919
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
@@ -117,6 +117,14 @@ def test_model_config_delegates_to_properties():
117117
assert model.id == model.get_parameter('id')
118118

119119

120+
def test_model_config_handles_custom():
121+
model = ModelConfig('fakeModel', custom={'extra-attribute': 'value'})
122+
assert model.id == 'fakeModel'
123+
assert model.get_parameter('extra-attribute') is None
124+
assert model.get_custom('non-existent') is None
125+
assert model.get_custom('id') is None
126+
127+
120128
def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
121129
context = Context.create('user-key')
122130
default_value = AIConfig(

0 commit comments

Comments
 (0)