Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/marketdata/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,13 @@ class InvalidStatusDataError(BaseMarketdataException):
pass


class MinMaxDateValidationError(BaseMarketdataException):
class MinMaxValidationError(BaseMarketdataException):
pass


class MinMaxValueValidationError(MinMaxValidationError):
pass


class MinMaxDateValidationError(MinMaxValidationError):
pass
13 changes: 12 additions & 1 deletion src/marketdata/input_types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel, ConfigDict, Field, field_validator

from marketdata.exceptions import MinMaxDateValidationError
from marketdata.exceptions import MinMaxDateValidationError, MinMaxValueValidationError
from marketdata.utils import check_is_date

BaseModelConfig = ConfigDict(populate_by_name=True, frozen=False)
Expand All @@ -26,6 +26,17 @@ def _validate_min_max_dates(
f"{min_param} must be less than {max_param}"
)

def _validate_min_max_value(
self, min_param: str | None, max_param: str | None
) -> None:
min_value = getattr(self, min_param)
max_value = getattr(self, max_param)

if min_value is not None and max_value is not None and min_value > max_value:
raise MinMaxValueValidationError(
f"{min_param} must be less than or equal to {max_param}"
)


class OutputFormat(str, Enum):
DATAFRAME = "dataframe"
Expand Down
6 changes: 3 additions & 3 deletions src/marketdata/input_types/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,12 @@ def validate_expiration(

@model_validator(mode="after")
def validate_input(self) -> "OptionsChainInput":
params_typles = [
params_tuples = [
("min_bid", "max_bid"),
("min_ask", "max_ask"),
]
for min_param, max_param in params_typles:
self._validate_min_max_dates(min_param, max_param)
for min_param, max_param in params_tuples:
self._validate_min_max_value(min_param, max_param)
return self


Expand Down
63 changes: 62 additions & 1 deletion src/tests/test_input_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@
from pydantic import Field, model_validator

import marketdata.input_types as input_types_pkg
from marketdata.exceptions import MinMaxDateValidationError
from marketdata.exceptions import (
MinMaxDateValidationError,
MinMaxValidationError,
MinMaxValueValidationError,
)
from marketdata.input_types.base import (
BaseInputType,
OutputFormat,
UserUniversalAPIParams,
)
from marketdata.input_types.options import OptionsChainInput
from marketdata.internal_settings import GLOBAL_EXCLUDED_PARAMS


Expand All @@ -26,6 +31,16 @@ def validate_input(self) -> "DummyInput":
return self


class DummyNumericInput(BaseInputType):
min_param: float | None = Field(default=None)
max_param: float | None = Field(default=None)

@model_validator(mode="after")
def validate_input(self) -> "DummyNumericInput":
self._validate_min_max_value("min_param", "max_param")
return self


def _all_input_models() -> list[type[BaseInputType]]:
"""Return every concrete BaseInputType subclass defined in the SDK.

Expand Down Expand Up @@ -95,6 +110,52 @@ def test_base_input_type_min_max_validation():
DummyInput(min_param="2025-01-01", max_param="2024-01-01")


def test_base_input_type_min_max_value_validation():
with pytest.raises(MinMaxValueValidationError):
DummyNumericInput(min_param=5.0, max_param=1.0)


def test_base_input_type_min_max_value_valid_range():
instance = DummyNumericInput(min_param=1.0, max_param=5.0)
assert instance.min_param == 1.0
assert instance.max_param == 5.0


def test_base_input_type_min_max_value_allows_none():
# Either bound missing -> no comparison, no error.
assert DummyNumericInput(min_param=5.0).max_param is None
assert DummyNumericInput(max_param=1.0).min_param is None
assert DummyNumericInput().min_param is None


def test_min_max_errors_share_common_base():
# Both specialized errors must be catchable as the common MinMaxValidationError.
assert issubclass(MinMaxDateValidationError, MinMaxValidationError)
assert issubclass(MinMaxValueValidationError, MinMaxValidationError)


@pytest.mark.parametrize(
"kwargs",
[
{"min_bid": 5.0, "max_bid": 1.0},
{"min_ask": 10.0, "max_ask": 2.0},
],
)
def test_options_chain_input_invalid_price_range(kwargs: dict):
with pytest.raises(MinMaxValueValidationError):
OptionsChainInput(symbol="AAPL", **kwargs)


def test_options_chain_input_valid_price_range():
instance = OptionsChainInput(
symbol="AAPL", min_bid=1.0, max_bid=5.0, min_ask=2.0, max_ask=10.0
)
assert instance.min_bid == 1.0
assert instance.max_bid == 5.0
assert instance.min_ask == 2.0
assert instance.max_ask == 10.0


def test_universal_api_params_api_format():
params = UserUniversalAPIParams(output_format=OutputFormat.DATAFRAME)
assert params.api_format == OutputFormat.JSON
Expand Down
Loading