Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions kalshi/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ def _coerce_decimal(value: Any) -> Decimal:
"""
if isinstance(value, Decimal):
return value
if isinstance(value, bool):
raise TypeError(
"Cannot convert bool to Decimal — bool is an int subclass, "
"so this is almost always a typo (did you mean count=1?)."
)
if isinstance(value, (int, float)):
return Decimal(str(value))
if isinstance(value, str):
Expand All @@ -30,13 +35,13 @@ def _coerce_decimal(value: Any) -> Decimal:

def _decimal_to_str(value: Decimal) -> str:
"""Serialize Decimal back to string for API requests."""
return str(value)
return f"{value:f}"


DollarDecimal = Annotated[
Decimal,
BeforeValidator(_coerce_decimal),
PlainSerializer(_decimal_to_str, return_type=str),
PlainSerializer(_decimal_to_str, return_type=str, when_used="json"),
]
"""A Decimal field that handles bidirectional conversion for Kalshi dollar values.

Expand All @@ -52,7 +57,7 @@ def _decimal_to_str(value: Decimal) -> str:
FixedPointCount = Annotated[
Decimal,
BeforeValidator(_coerce_decimal),
PlainSerializer(_decimal_to_str, return_type=str),
PlainSerializer(_decimal_to_str, return_type=str, when_used="json"),
]
"""A Decimal field that handles bidirectional conversion for Kalshi count/volume values.

Expand All @@ -67,6 +72,11 @@ def to_decimal(value: int | float | str | Decimal) -> Decimal:
Always goes through str() to avoid float representation errors.
e.g., to_decimal(0.65) returns Decimal("0.65"), not Decimal(0.65).
"""
if isinstance(value, bool):
raise TypeError(
"Cannot convert bool to Decimal — bool is an int subclass, "
"so this is almost always a typo (did you mean count=1?)."
)
if isinstance(value, Decimal):
return value
return Decimal(str(value))
Expand Down
7 changes: 4 additions & 3 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,9 @@ def test_market_parses_int_price(self) -> None:

def test_market_model_dump_serializes(self) -> None:
m = Market.model_validate(market_dict(ticker="T", yes_ask_dollars="0.72"))
data = m.model_dump()
assert data["yes_ask"] == "0.72"
# mode='python' preserves Decimal; mode='json' produces wire string.
assert m.model_dump()["yes_ask"] == Decimal("0.72")
assert m.model_dump(mode="json")["yes_ask"] == "0.72"

def test_order_decimal_fields(self) -> None:
o = Order.model_validate(
Expand Down Expand Up @@ -222,7 +223,7 @@ def test_create_order_serializes_with_dollars_alias(self) -> None:
action="buy",
yes_price=Decimal("0.65"),
)
data = req.model_dump(exclude_none=True, by_alias=True)
data = req.model_dump(mode="json", exclude_none=True, by_alias=True)
assert "yes_price_dollars" in data
assert data["yes_price_dollars"] == "0.65"
assert "yes_price" not in data
Expand Down
66 changes: 66 additions & 0 deletions tests/test_page_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,69 @@ def test_top_level_decimal_preserved_alongside_nested(self) -> None:

assert isinstance(df["price"].dtype, pl.Decimal)
assert df["price"].to_list() == [Decimal("0.55"), Decimal("0.42")]


# ---------------------------------------------------------------------------
# Issue #190: DollarDecimal-typed columns must land in the DataFrame as live
# Decimal objects, not strings. Before when_used="json" was set, the
# PlainSerializer fired on mode="python" too and produced str — silently
# breaking df["price"].sum() (string concat) and any numeric reduction.
# ---------------------------------------------------------------------------

from kalshi.types import DollarDecimal # noqa: E402


class _DollarRow(BaseModel):
price: DollarDecimal


def _dollar_items(*values: str) -> list[_DollarRow]:
return [_DollarRow(price=Decimal(v)) for v in values] # type: ignore[arg-type]


class TestPageDataframeDollarDecimal:
def test_page_to_dataframe_preserves_decimal_objects(self) -> None:
pytest.importorskip("pandas")
page: Page[_DollarRow] = Page(items=_dollar_items("0.5600"), cursor=None)

df = page.to_dataframe()

assert isinstance(df["price"].iloc[0], Decimal)
assert df["price"].iloc[0] == Decimal("0.5600")

def test_page_to_dataframe_decimal_sum_is_decimal_not_concatenated_strings(self) -> None:
pytest.importorskip("pandas")
page: Page[_DollarRow] = Page(
items=_dollar_items("0.5600", "0.5600"), cursor=None
)

df = page.to_dataframe()
total = df["price"].sum()

assert total == Decimal("1.1200")
assert not isinstance(total, str)


class TestPageToPolarsDollarDecimal:
def test_page_to_polars_preserves_decimal_objects(self) -> None:
pl = pytest.importorskip("polars")
page: Page[_DollarRow] = Page(items=_dollar_items("0.5600"), cursor=None)

df = page.to_polars()

# polars infers a Decimal column from the Python Decimal objects, not Utf8.
assert isinstance(df["price"].dtype, pl.Decimal)
assert df["price"].to_list()[0] == Decimal("0.5600")

def test_page_to_polars_decimal_sum_is_decimal(self) -> None:
pytest.importorskip("polars")
page: Page[_DollarRow] = Page(
items=_dollar_items("0.5600", "0.5600"), cursor=None
)

df = page.to_polars()
total = df["price"].sum()

# polars Decimal.sum() returns a Python Decimal — never a string concat.
assert isinstance(total, Decimal)
assert total == Decimal("1.1200")
59 changes: 58 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Tests for DollarDecimal / FixedPointCount type-fallback branches in kalshi.types."""
from __future__ import annotations

from decimal import Decimal

import pytest
from pydantic import BaseModel

from kalshi.types import DollarDecimal, FixedPointCount
from kalshi.types import DollarDecimal, FixedPointCount, _decimal_to_str, to_decimal


class _DollarModel(BaseModel):
Expand Down Expand Up @@ -33,3 +35,58 @@ def test_list_input_raises_type_error_with_named_type(self) -> None:
def test_dict_input_raises_type_error_with_named_type(self) -> None:
with pytest.raises(TypeError, match="Cannot convert dict to Decimal"):
_CountModel.model_validate({"x": {"nested": "value"}})

class TestDecimalToStrPositional:
def test_decimal_to_str_positional_for_large_exp(self) -> None:
assert _decimal_to_str(Decimal("1E+10")) == "10000000000"

def test_decimal_to_str_positional_for_small_exp(self) -> None:
assert _decimal_to_str(Decimal("1E-7")) == "0.0000001"

def test_decimal_to_str_preserves_trailing_zero(self) -> None:
assert _decimal_to_str(Decimal("0.5600")) == "0.5600"


class TestCoerceDecimalRejectsBool:
def test_coerce_decimal_rejects_bool_true(self) -> None:
with pytest.raises(TypeError, match="bool"):
_DollarModel.model_validate({"x": True})

def test_coerce_decimal_rejects_bool_false(self) -> None:
with pytest.raises(TypeError, match="bool"):
_DollarModel.model_validate({"x": False})

def test_to_decimal_rejects_bool(self) -> None:
with pytest.raises(TypeError, match="bool"):
to_decimal(True) # type: ignore[arg-type]


class TestDollarDecimalDumpMode:
def test_dollar_decimal_model_dump_python_returns_decimal_not_str(self) -> None:
class M(BaseModel):
price: DollarDecimal

m = M(price=Decimal("0.5600")) # type: ignore[arg-type]
result = m.model_dump(mode="python")
assert isinstance(result["price"], Decimal)
assert result["price"] == Decimal("0.5600")

def test_dollar_decimal_model_dump_json_returns_str(self) -> None:
class M(BaseModel):
price: DollarDecimal

m = M(price=Decimal("1E+10")) # type: ignore[arg-type]
result = m.model_dump(mode="json")
assert isinstance(result["price"], str)
assert result["price"] == "10000000000"


class TestFixedPointCountDumpMode:
def test_fixed_point_count_model_dump_python_returns_decimal_not_str(self) -> None:
class M(BaseModel):
count: FixedPointCount

m = M(count=Decimal("42")) # type: ignore[arg-type]
result = m.model_dump(mode="python")
assert isinstance(result["count"], Decimal)
assert result["count"] == Decimal("42")
Loading