|
5 | 5 | # the root directory of this source tree.
|
6 | 6 |
|
7 | 7 | import json
|
| 8 | +from collections.abc import Iterable |
8 | 9 | from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
9 | 10 |
|
10 | 11 | import pytest
|
@@ -498,6 +499,129 @@ def get_base_url(self):
|
498 | 499 | return "default-base-url"
|
499 | 500 |
|
500 | 501 |
|
| 502 | +class CustomListProviderModelIdsImplementation(OpenAIMixinImpl): |
| 503 | + """Test implementation with custom list_provider_model_ids override""" |
| 504 | + |
| 505 | + def __init__(self, custom_model_ids): |
| 506 | + self._custom_model_ids = custom_model_ids |
| 507 | + |
| 508 | + async def list_provider_model_ids(self) -> Iterable[str]: |
| 509 | + """Return custom model IDs list""" |
| 510 | + return self._custom_model_ids |
| 511 | + |
| 512 | + |
| 513 | +class TestOpenAIMixinCustomListProviderModelIds: |
| 514 | + """Test cases for custom list_provider_model_ids() implementation functionality""" |
| 515 | + |
| 516 | + @pytest.fixture |
| 517 | + def custom_model_ids_list(self): |
| 518 | + """Create a list of custom model ID strings""" |
| 519 | + return ["custom-model-1", "custom-model-2", "custom-embedding"] |
| 520 | + |
| 521 | + @pytest.fixture |
| 522 | + def adapter(self, custom_model_ids_list): |
| 523 | + """Create mixin instance with custom list_provider_model_ids implementation""" |
| 524 | + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=custom_model_ids_list) |
| 525 | + mixin.embedding_model_metadata = {"custom-embedding": {"embedding_dimension": 768, "context_length": 512}} |
| 526 | + return mixin |
| 527 | + |
| 528 | + async def test_is_used(self, adapter, custom_model_ids_list): |
| 529 | + """Test that custom list_provider_model_ids() implementation is used instead of client.models.list()""" |
| 530 | + result = await adapter.list_models() |
| 531 | + |
| 532 | + assert result is not None |
| 533 | + assert len(result) == 3 |
| 534 | + |
| 535 | + assert set(custom_model_ids_list) == {m.identifier for m in result} |
| 536 | + |
| 537 | + async def test_populates_cache(self, adapter, custom_model_ids_list): |
| 538 | + """Test that custom list_provider_model_ids() results are cached""" |
| 539 | + assert len(adapter._model_cache) == 0 |
| 540 | + |
| 541 | + await adapter.list_models() |
| 542 | + |
| 543 | + assert set(custom_model_ids_list) == set(adapter._model_cache.keys()) |
| 544 | + |
| 545 | + async def test_respects_allowed_models(self): |
| 546 | + """Test that custom list_provider_model_ids() respects allowed_models filtering""" |
| 547 | + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=["model-1", "model-2", "model-3"]) |
| 548 | + mixin.allowed_models = ["model-1"] |
| 549 | + |
| 550 | + result = await mixin.list_models() |
| 551 | + |
| 552 | + assert result is not None |
| 553 | + assert len(result) == 1 |
| 554 | + assert result[0].identifier == "model-1" |
| 555 | + |
| 556 | + async def test_with_empty_list(self): |
| 557 | + """Test that custom list_provider_model_ids() handles empty list correctly""" |
| 558 | + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[]) |
| 559 | + |
| 560 | + result = await mixin.list_models() |
| 561 | + |
| 562 | + assert result is not None |
| 563 | + assert len(result) == 0 |
| 564 | + assert len(mixin._model_cache) == 0 |
| 565 | + |
| 566 | + async def test_wrong_type_raises_error(self): |
| 567 | + """Test that list_provider_model_ids() returning unhashable items results in an error""" |
| 568 | + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[["nested", "list"], {"key": "value"}]) |
| 569 | + |
| 570 | + with pytest.raises(TypeError, match="unhashable type"): |
| 571 | + await mixin.list_models() |
| 572 | + |
| 573 | + async def test_non_iterable_raises_error(self): |
| 574 | + """Test that list_provider_model_ids() returning non-iterable type raises error""" |
| 575 | + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=42) |
| 576 | + |
| 577 | + with pytest.raises( |
| 578 | + TypeError, |
| 579 | + match=r"Failed to list models: CustomListProviderModelIdsImplementation\.list_provider_model_ids\(\) must return an iterable.*but returned int", |
| 580 | + ): |
| 581 | + await mixin.list_models() |
| 582 | + |
| 583 | + async def test_with_none_items_raises_error(self): |
| 584 | + """Test that list_provider_model_ids() returning list with None items causes error""" |
| 585 | + mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[None, "valid-model", None]) |
| 586 | + |
| 587 | + with pytest.raises(Exception, match="Input should be a valid string"): |
| 588 | + await mixin.list_models() |
| 589 | + |
| 590 | + async def test_accepts_various_iterables(self): |
| 591 | + """Test that list_provider_model_ids() accepts tuples, sets, generators, etc.""" |
| 592 | + |
| 593 | + class TupleAdapter(OpenAIMixinImpl): |
| 594 | + async def list_provider_model_ids(self) -> Iterable[str] | None: |
| 595 | + return ("model-1", "model-2", "model-3") |
| 596 | + |
| 597 | + mixin = TupleAdapter() |
| 598 | + result = await mixin.list_models() |
| 599 | + assert result is not None |
| 600 | + assert len(result) == 3 |
| 601 | + |
| 602 | + class GeneratorAdapter(OpenAIMixinImpl): |
| 603 | + async def list_provider_model_ids(self) -> Iterable[str] | None: |
| 604 | + def gen(): |
| 605 | + yield "gen-model-1" |
| 606 | + yield "gen-model-2" |
| 607 | + |
| 608 | + return gen() |
| 609 | + |
| 610 | + mixin = GeneratorAdapter() |
| 611 | + result = await mixin.list_models() |
| 612 | + assert result is not None |
| 613 | + assert len(result) == 2 |
| 614 | + |
| 615 | + class SetAdapter(OpenAIMixinImpl): |
| 616 | + async def list_provider_model_ids(self) -> Iterable[str] | None: |
| 617 | + return {"set-model-1", "set-model-2"} |
| 618 | + |
| 619 | + mixin = SetAdapter() |
| 620 | + result = await mixin.list_models() |
| 621 | + assert result is not None |
| 622 | + assert len(result) == 2 |
| 623 | + |
| 624 | + |
501 | 625 | class TestOpenAIMixinProviderDataApiKey:
|
502 | 626 | """Test cases for provider_data_api_key_field functionality"""
|
503 | 627 |
|
|
0 commit comments