diff --git a/kalshi/types.py b/kalshi/types.py index 0f87c6c..281815e 100644 --- a/kalshi/types.py +++ b/kalshi/types.py @@ -21,6 +21,11 @@ def _coerce_decimal(value: Any) -> Decimal: """ if isinstance(value, Decimal): return value + if isinstance(value, bool): + raise TypeError( + "Cannot convert bool to Decimal — bool is an int subclass, " + "so this is almost always a typo (did you mean count=1?)." + ) if isinstance(value, (int, float)): return Decimal(str(value)) if isinstance(value, str): @@ -30,13 +35,13 @@ def _coerce_decimal(value: Any) -> Decimal: def _decimal_to_str(value: Decimal) -> str: """Serialize Decimal back to string for API requests.""" - return str(value) + return f"{value:f}" DollarDecimal = Annotated[ Decimal, BeforeValidator(_coerce_decimal), - PlainSerializer(_decimal_to_str, return_type=str), + PlainSerializer(_decimal_to_str, return_type=str, when_used="json"), ] """A Decimal field that handles bidirectional conversion for Kalshi dollar values. @@ -52,7 +57,7 @@ def _decimal_to_str(value: Decimal) -> str: FixedPointCount = Annotated[ Decimal, BeforeValidator(_coerce_decimal), - PlainSerializer(_decimal_to_str, return_type=str), + PlainSerializer(_decimal_to_str, return_type=str, when_used="json"), ] """A Decimal field that handles bidirectional conversion for Kalshi count/volume values. @@ -67,6 +72,11 @@ def to_decimal(value: int | float | str | Decimal) -> Decimal: Always goes through str() to avoid float representation errors. e.g., to_decimal(0.65) returns Decimal("0.65"), not Decimal(0.65). """ + if isinstance(value, bool): + raise TypeError( + "Cannot convert bool to Decimal — bool is an int subclass, " + "so this is almost always a typo (did you mean count=1?)." + ) if isinstance(value, Decimal): return value return Decimal(str(value)) diff --git a/tests/test_models.py b/tests/test_models.py index 989378a..b9e805b 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -79,8 +79,9 @@ def test_market_parses_int_price(self) -> None: def test_market_model_dump_serializes(self) -> None: m = Market.model_validate(market_dict(ticker="T", yes_ask_dollars="0.72")) - data = m.model_dump() - assert data["yes_ask"] == "0.72" + # mode='python' preserves Decimal; mode='json' produces wire string. + assert m.model_dump()["yes_ask"] == Decimal("0.72") + assert m.model_dump(mode="json")["yes_ask"] == "0.72" def test_order_decimal_fields(self) -> None: o = Order.model_validate( @@ -222,7 +223,7 @@ def test_create_order_serializes_with_dollars_alias(self) -> None: action="buy", yes_price=Decimal("0.65"), ) - data = req.model_dump(exclude_none=True, by_alias=True) + data = req.model_dump(mode="json", exclude_none=True, by_alias=True) assert "yes_price_dollars" in data assert data["yes_price_dollars"] == "0.65" assert "yes_price" not in data diff --git a/tests/test_page_dataframe.py b/tests/test_page_dataframe.py index 2b551aa..c7ab96f 100644 --- a/tests/test_page_dataframe.py +++ b/tests/test_page_dataframe.py @@ -255,3 +255,69 @@ def test_top_level_decimal_preserved_alongside_nested(self) -> None: assert isinstance(df["price"].dtype, pl.Decimal) assert df["price"].to_list() == [Decimal("0.55"), Decimal("0.42")] + + +# --------------------------------------------------------------------------- +# Issue #190: DollarDecimal-typed columns must land in the DataFrame as live +# Decimal objects, not strings. Before when_used="json" was set, the +# PlainSerializer fired on mode="python" too and produced str — silently +# breaking df["price"].sum() (string concat) and any numeric reduction. +# --------------------------------------------------------------------------- + +from kalshi.types import DollarDecimal # noqa: E402 + + +class _DollarRow(BaseModel): + price: DollarDecimal + + +def _dollar_items(*values: str) -> list[_DollarRow]: + return [_DollarRow(price=Decimal(v)) for v in values] # type: ignore[arg-type] + + +class TestPageDataframeDollarDecimal: + def test_page_to_dataframe_preserves_decimal_objects(self) -> None: + pytest.importorskip("pandas") + page: Page[_DollarRow] = Page(items=_dollar_items("0.5600"), cursor=None) + + df = page.to_dataframe() + + assert isinstance(df["price"].iloc[0], Decimal) + assert df["price"].iloc[0] == Decimal("0.5600") + + def test_page_to_dataframe_decimal_sum_is_decimal_not_concatenated_strings(self) -> None: + pytest.importorskip("pandas") + page: Page[_DollarRow] = Page( + items=_dollar_items("0.5600", "0.5600"), cursor=None + ) + + df = page.to_dataframe() + total = df["price"].sum() + + assert total == Decimal("1.1200") + assert not isinstance(total, str) + + +class TestPageToPolarsDollarDecimal: + def test_page_to_polars_preserves_decimal_objects(self) -> None: + pl = pytest.importorskip("polars") + page: Page[_DollarRow] = Page(items=_dollar_items("0.5600"), cursor=None) + + df = page.to_polars() + + # polars infers a Decimal column from the Python Decimal objects, not Utf8. + assert isinstance(df["price"].dtype, pl.Decimal) + assert df["price"].to_list()[0] == Decimal("0.5600") + + def test_page_to_polars_decimal_sum_is_decimal(self) -> None: + pytest.importorskip("polars") + page: Page[_DollarRow] = Page( + items=_dollar_items("0.5600", "0.5600"), cursor=None + ) + + df = page.to_polars() + total = df["price"].sum() + + # polars Decimal.sum() returns a Python Decimal — never a string concat. + assert isinstance(total, Decimal) + assert total == Decimal("1.1200") diff --git a/tests/test_types.py b/tests/test_types.py index 8d234f3..c59024e 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1,10 +1,12 @@ """Tests for DollarDecimal / FixedPointCount type-fallback branches in kalshi.types.""" from __future__ import annotations +from decimal import Decimal + import pytest from pydantic import BaseModel -from kalshi.types import DollarDecimal, FixedPointCount +from kalshi.types import DollarDecimal, FixedPointCount, _decimal_to_str, to_decimal class _DollarModel(BaseModel): @@ -33,3 +35,58 @@ def test_list_input_raises_type_error_with_named_type(self) -> None: def test_dict_input_raises_type_error_with_named_type(self) -> None: with pytest.raises(TypeError, match="Cannot convert dict to Decimal"): _CountModel.model_validate({"x": {"nested": "value"}}) + +class TestDecimalToStrPositional: + def test_decimal_to_str_positional_for_large_exp(self) -> None: + assert _decimal_to_str(Decimal("1E+10")) == "10000000000" + + def test_decimal_to_str_positional_for_small_exp(self) -> None: + assert _decimal_to_str(Decimal("1E-7")) == "0.0000001" + + def test_decimal_to_str_preserves_trailing_zero(self) -> None: + assert _decimal_to_str(Decimal("0.5600")) == "0.5600" + + +class TestCoerceDecimalRejectsBool: + def test_coerce_decimal_rejects_bool_true(self) -> None: + with pytest.raises(TypeError, match="bool"): + _DollarModel.model_validate({"x": True}) + + def test_coerce_decimal_rejects_bool_false(self) -> None: + with pytest.raises(TypeError, match="bool"): + _DollarModel.model_validate({"x": False}) + + def test_to_decimal_rejects_bool(self) -> None: + with pytest.raises(TypeError, match="bool"): + to_decimal(True) # type: ignore[arg-type] + + +class TestDollarDecimalDumpMode: + def test_dollar_decimal_model_dump_python_returns_decimal_not_str(self) -> None: + class M(BaseModel): + price: DollarDecimal + + m = M(price=Decimal("0.5600")) # type: ignore[arg-type] + result = m.model_dump(mode="python") + assert isinstance(result["price"], Decimal) + assert result["price"] == Decimal("0.5600") + + def test_dollar_decimal_model_dump_json_returns_str(self) -> None: + class M(BaseModel): + price: DollarDecimal + + m = M(price=Decimal("1E+10")) # type: ignore[arg-type] + result = m.model_dump(mode="json") + assert isinstance(result["price"], str) + assert result["price"] == "10000000000" + + +class TestFixedPointCountDumpMode: + def test_fixed_point_count_model_dump_python_returns_decimal_not_str(self) -> None: + class M(BaseModel): + count: FixedPointCount + + m = M(count=Decimal("42")) # type: ignore[arg-type] + result = m.model_dump(mode="python") + assert isinstance(result["count"], Decimal) + assert result["count"] == Decimal("42")