Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/aleph/sdk/client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os.path
import ssl
import time
from decimal import Decimal
from io import BytesIO
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -543,8 +544,16 @@ async def get_program_price(self, item_hash: str) -> PriceResponse:
try:
resp.raise_for_status()
response_json = await resp.json()
token_value = response_json.get(
"cost", response_json.get("required_tokens")
)
if isinstance(token_value, str):
required_tokens = Decimal(token_value)
else:
required_tokens = Decimal(str(token_value))

return PriceResponse(
required_tokens=response_json["required_tokens"],
required_tokens=required_tokens,
payment_type=response_json["payment_type"],
)
except aiohttp.ClientResponseError as e:
Expand Down
3 changes: 1 addition & 2 deletions src/aleph/sdk/client/services/crn.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,8 @@ async def get_crns_list(self, only_active: bool = True) -> CrnList:
dict
The parsed JSON response from /crns.json.
"""
# We want filter_inactive = (not only_active)
# Convert bool to string for the query parameter
filter_inactive_str = str(not only_active).lower()
filter_inactive_str = str(only_active).lower()
params = {"filter_inactive": filter_inactive_str}

# Create a new session for external domain requests
Expand Down
2 changes: 1 addition & 1 deletion src/aleph/sdk/query/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class MessagesResponse(PaginationResponse):
class PriceResponse(BaseModel):
"""Response from an aleph.im node API on the path /api/v0/price/{item_hash}"""

required_tokens: float
required_tokens: Decimal
payment_type: str


Expand Down
2 changes: 1 addition & 1 deletion src/aleph/sdk/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ class Ports(BaseModel):
ports: Dict[int, PortFlags]


AllForwarders = RootModel[Dict[ItemHash, Ports]]
AllForwarders = RootModel[Dict[ItemHash, Optional[Ports]]]


class DictLikeModel(BaseModel):
Expand Down
47 changes: 47 additions & 0 deletions tests/unit/test_price.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from decimal import Decimal

import pytest

from aleph.sdk.exceptions import InvalidHashError
Expand All @@ -21,6 +23,51 @@ async def test_get_program_price_valid():
assert response == expected


@pytest.mark.asyncio
async def test_get_program_price_cost_and_required_token():
"""
Test that the get_program_price method returns the correct PriceResponse
when
1 ) cost & required_token is here (priority to cost) who is a string that convert to decimal
2 ) When only required_token is here who is a float that now would be to be convert to decimal
"""
# Case 1
expected = {
"required_tokens": 0.001527777777777778,
"cost": "0.001527777777777777",
"payment_type": "credit",
}

# Case 2
expected_old = {
"required_tokens": 0.001527777777777778,
"payment_type": "credit",
}

# Expected model using the cost field as the source of truth
expected_model = PriceResponse(
required_tokens=Decimal("0.001527777777777777"),
payment_type=expected["payment_type"],
)

# Expected model for the old format
expected_model_old = PriceResponse(
required_tokens=Decimal(str(expected_old["required_tokens"])),
payment_type=expected_old["payment_type"],
)

mock_session = make_mock_get_session(expected)
mock_session_old = make_mock_get_session(expected_old)

async with mock_session:
response = await mock_session.get_program_price("cacacacacacaca")
assert response == expected_model

async with mock_session_old:
response = await mock_session_old.get_program_price("cacacacacacaca")
assert response == expected_model_old


@pytest.mark.asyncio
async def test_get_program_price_invalid():
"""
Expand Down
Loading