diff --git a/api-usage.ipynb b/api-usage.ipynb index 07db467..074f800 100644 --- a/api-usage.ipynb +++ b/api-usage.ipynb @@ -71,9 +71,9 @@ "outputs": [], "source": [ "r = requests.post(\n", - " f\"{gas_station_api}/users\",\n", + " f\"{gas_station_api}/sponsors\",\n", " json = {\n", - " \"address\": alice.public_key_hash(),\n", + " \"tezos_address\": alice.public_key_hash(),\n", " \"name\": \"alice\"\n", " }\n", ")" @@ -96,7 +96,7 @@ "metadata": {}, "outputs": [], "source": [ - "r = requests.get(f\"{gas_station_api}/users/{alice.public_key_hash()}\")" + "r = requests.get(f\"{gas_station_api}/sponsors/{alice.public_key_hash()}\")" ] }, { @@ -108,8 +108,8 @@ { "data": { "text/plain": [ - "{'address': 'tz1VSUr8wwNhLAzempoch5d6hLRiTh8Cjcjb',\n", - " 'id': '06d44229-4b75-4df5-9bac-df3b53285859',\n", + "{'tezos_address': 'tz1VSUr8wwNhLAzempoch5d6hLRiTh8Cjcjb',\n", + " 'id': '2a5e9326-e725-4e60-a8cb-816bad6a4f7b',\n", " 'name': 'alice',\n", " 'withdraw_counter': 0}" ] @@ -133,9 +133,9 @@ { "data": { "text/plain": [ - "{'id': '958caaa8-ed25-4c26-a062-42bd78182399',\n", + "{'id': '0cfb5c6d-9655-46c2-bddc-42b772b05367',\n", " 'amount': 0,\n", - " 'owner_id': '06d44229-4b75-4df5-9bac-df3b53285859'}" + " 'owner_id': '2a5e9326-e725-4e60-a8cb-816bad6a4f7b'}" ] }, "execution_count": 9, @@ -144,7 +144,7 @@ } ], "source": [ - "r = requests.get(f\"{gas_station_api}/credits/{alice_user['address']}\")\n", + "r = requests.get(f\"{gas_station_api}/credits/{alice_user['tezos_address']}\")\n", "credits = r.json()[0]\n", "assert r.status_code == 200\n", "credits" @@ -248,20 +248,9 @@ "execution_count": 16, "id": "bcba5181", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "200" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "r.status_code" + "assert r.status_code == 200" ] }, { @@ -273,9 +262,9 @@ { "data": { "text/plain": [ - "[{'id': '958caaa8-ed25-4c26-a062-42bd78182399',\n", + "[{'id': '0cfb5c6d-9655-46c2-bddc-42b772b05367',\n", " 'amount': 1000000,\n", - " 'owner_id': '06d44229-4b75-4df5-9bac-df3b53285859'}]" + " 'owner_id': '2a5e9326-e725-4e60-a8cb-816bad6a4f7b'}]" ] }, "execution_count": 17, @@ -310,7 +299,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "14\n" + "25\n" ] } ], @@ -374,7 +363,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 22, "id": "8ff9a023", "metadata": {}, "outputs": [ @@ -382,7 +371,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "15\n" + "25\n" ] } ], @@ -395,7 +384,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 23, "id": "5508e9bc", "metadata": {}, "outputs": [], @@ -408,7 +397,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 24, "id": "0cbecd7f", "metadata": {}, "outputs": [], @@ -427,7 +416,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 25, "id": "33cbdddf", "metadata": {}, "outputs": [ @@ -437,7 +426,7 @@ "" ] }, - "execution_count": 27, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } @@ -456,7 +445,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 26, "id": "202834f6", "metadata": {}, "outputs": [ @@ -466,7 +455,7 @@ "200" ] }, - "execution_count": 29, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -493,7 +482,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 27, "id": "b34ffa8d", "metadata": {}, "outputs": [ @@ -503,7 +492,7 @@ "400" ] }, - "execution_count": 30, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } @@ -530,7 +519,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 28, "id": "166eb0a0", "metadata": {}, "outputs": [ @@ -540,7 +529,7 @@ "200" ] }, - "execution_count": 31, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -559,7 +548,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 29, "id": "9daae29e", "metadata": {}, "outputs": [ @@ -567,17 +556,17 @@ "data": { "text/plain": [ "[{'type': 'MAX_CALLS_PER_SPONSEE',\n", - " 'vault_id': '958caaa8-ed25-4c26-a062-42bd78182399',\n", - " 'created_at': '2024-03-12T18:00:11.362107+00:00',\n", - " 'id': '928a842a-68da-48c6-b9fa-cccd71ccacb1',\n", - " 'contract_id': '4b0a9fdf-d36a-4ce8-af4c-1facc8ae1371',\n", + " 'vault_id': '0cfb5c6d-9655-46c2-bddc-42b772b05367',\n", + " 'created_at': '2024-03-15T10:55:20.172575+00:00',\n", + " 'is_active': True,\n", + " 'id': 'c86187e1-f184-4a2d-8dda-094cc581820f',\n", + " 'contract_id': 'a7f3f5ee-a790-4784-b57b-b9e486d88838',\n", " 'entrypoint_id': None,\n", " 'max': 1,\n", - " 'current': 2,\n", - " 'is_active': True}]" + " 'current': 2}]" ] }, - "execution_count": 32, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -587,14 +576,6 @@ " f\"{gas_station_api}/condition/{credits['id']}\"\n", ").json()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "01e26527", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/examples/main.py b/examples/main.py new file mode 100644 index 0000000..dd9ea08 --- /dev/null +++ b/examples/main.py @@ -0,0 +1,113 @@ +import asyncio +from typing import Any +import os +import requests + +from fastapi import FastAPI, APIRouter, HTTPException, Request, status +from fastapi.responses import JSONResponse +from fastapi.encoders import jsonable_encoder +from pydantic import BaseModel + +from Crypto.PublicKey import RSA +import jwt + +api_url = os.getenv("GAS_STATION_URL") +gs_user_address = "tz1VSUr8wwNhLAzempoch5d6hLRiTh8Cjcjb" +router = APIRouter() + + +class Operation(BaseModel): + """Data sent when posting an operation. The sender is mandatory.""" + + sender_address: str + operations: list[dict[str, Any]] + + +# TODO: the signature should include the action decided by the sponsor API +class Receipt(BaseModel): + """Signature of an operation to be posted on-chain.""" + gas_station_action: str + signature: str + + +try: + skey = "".join(open("./private.pem").readlines()) + pkey = "".join(open("./public.pem").readlines()) +except FileNotFoundError: + mykey = RSA.generate(1024) + skey = mykey.export_key() + pkey = mykey.public_key().export_key() + with open("./private.pem", "w") as f: + f.write(skey) + with open("./public.pem", "w") as f: + f.write(pkey) + + +# We will assume this API runs in only one thread, and ignore race conditions +# for this example. +# Shared database of senders, to limit the number of sponsored operations +# per user. This could be implemented as a condition directly in the gas +# station as well. +seen_senders = dict() + + +def validate(sender): + seen = seen_senders.get(sender, 0) + print("SEEN", seen, "TIMES") + if seen > 1: + return False + else: + seen_senders[sender] = seen + 1 + return True + + +@router.post("/operation", response_model=Receipt) +async def sign_operation(request: Request): + raw_operation = await request.json() + try: + operation = Operation.parse_obj(raw_operation) + except: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"Could not parse operation", + ) + + if validate(operation.sender_address): + signature = jwt.encode( + raw_operation, key=skey, algorithm="RS256" + ) + return Receipt( + gas_station_action="post_operation", + signature=signature + ) + else: + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + content=jsonable_encoder( + { + "detail": "Invalid call", + "body": "FIXME", + "custom msg": "" + } + ) + ) + +# TODO +# - implement "operation_posted" + +# Register to the API +r = requests.get(f"{api_url}/sponsors/{gs_user_address}") +gs_user = r.json() +requests.put( + f"{api_url}/sponsor_api", + json = { + "sponsor_id": gs_user["id"], + "api_url": "http://localhost:8005", # This API + "public_key": pkey + } +) + +app = FastAPI() +app.include_router(router) + +loop = asyncio.get_event_loop() diff --git a/requirements.txt b/requirements.txt index 2c62e10..0ebf1ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,9 +2,10 @@ pytezos python-multipart uvicorn[standard] fastapi -python-jose[cryptography] +pyjwt[cryptography] python-dotenv pydantic sqlalchemy psycopg2 -alembic \ No newline at end of file +alembic +aiohttp diff --git a/sql/init_db.sql b/sql/init_db.sql index 1a93b5c..0707832 100644 --- a/sql/init_db.sql +++ b/sql/init_db.sql @@ -1,4 +1,4 @@ -INSERT INTO users("id", "name", "address", "withdraw_counter") +INSERT INTO sponsors("id", "name", "address", "withdraw_counter") VALUES ('164b18e2-205b-47fa-8fa5-e9961b3a8437', 'Alfred', 'tz1VLKbNYhmfyQSZzsdLWrbtVbyjsRf9qEjN', 0), ('b8c23360-9a81-4450-93d8-ea32a2d7467e', 'Quentin', 'tz1YdFws2E182i25ezpHvEvcn4vh74XcMDFi', 0); @@ -22,4 +22,4 @@ INSERT INTO entrypoints("id", "name", "is_enabled", "contract_id") ('dd225743-65b8-465d-849b-be5f795b0e3e', 'permit', true, 'c8b3f63a-9453-4e9f-98b3-855a0de682aa'), ('18e7ee0d-6e16-4392-9cc7-1609d6f84c0c', 'stake', true, 'f08660dc-34a8-4575-b53c-19d362296ead'), ('33532a9c-51e7-4f88-b60f-67530122c349', 'unstake', true, 'f08660dc-34a8-4575-b53c-19d362296ead'), - ('764d5857-1201-4a69-bbe8-137e0326a830', 'dummy', false, '4dfbd6f2-ca41-48d0-adc5-9c0bef8127d1'); \ No newline at end of file + ('764d5857-1201-4a69-bbe8-137e0326a830', 'dummy', false, '4dfbd6f2-ca41-48d0-adc5-9c0bef8127d1'); diff --git a/src/crud.py b/src/crud.py index af6dc62..be1609d 100644 --- a/src/crud.py +++ b/src/crud.py @@ -10,51 +10,56 @@ ContractNotFound, CreditNotFound, EntrypointNotFound, - UserNotFound, + SponsorNotFound, ) from . import models, schemas from sqlalchemy.exc import NoResultFound -def get_user(db: Session, uuid: UUID4): +def get_sponsor(db: Session, uuid: UUID4): """ - Return a models.User or raise UserNotFound exception + Return a models.Sponsor or raise SponsorNotFound exception """ - db_user: Optional[models.User] = db.query(models.User).get(uuid) - if db_user is None: - raise UserNotFound() - return db_user + db_sponsor: Optional[models.Sponsor] = db.query(models.Sponsor).get(uuid) + if db_sponsor is None: + raise SponsorNotFound() + return db_sponsor -def get_user_by_address(db: Session, address: str): +def get_sponsor_by_address(db: Session, tezos_address: str): """ - Return a models.User or raise UserNotFound exception + Return a models.Sponsor or raise SponsorNotFound exception """ try: - return db.query(models.User).filter(models.User.address == address).one() + db_sponsor = ( + db.query(models.Sponsor) + .filter(models.Sponsor.tezos_address == tezos_address) + .one() + ) + return db_sponsor except NoResultFound as e: - raise UserNotFound() from e + raise SponsorNotFound() from e -def create_user(db: Session, user: schemas.UserCreation): - db_user = models.User(**user.model_dump()) - db.add(db_user) +def create_sponsor(db: Session, sponsor: schemas.SponsorCreation): + db_sponsor = models.Sponsor(**sponsor.model_dump()) + db.add(db_sponsor) db.commit() - db.refresh(db_user) - return db_user + db.refresh(db_sponsor) + return db_sponsor -def get_contracts_by_user(db: Session, user_address: str): +def get_contracts_by_sponsor(db: Session, sponsor_address: str): """ - Return a list of models.Contracts or raise UserNotFound exception + Return a list of models.Contracts or raise SponsorNotFound exception """ - user = get_user_by_address(db, user_address) - return user.contracts + sponsor = get_sponsor_by_address(db, sponsor_address) + return sponsor.contracts def get_contracts_by_credit(db: Session, credit_id: str): """ - Return a list of models.Contracts or raise UserNotFound exception + Return a list of models.Contracts or raise SponsorNotFound exception """ return ( db.query(models.Contract).filter(models.Contract.credit_id == credit_id).all() @@ -127,6 +132,25 @@ def create_contract(db: Session, contract: schemas.ContractCreation): return db_contract +def update_sponsor_api(db: Session, api_update: schemas.SponsorAPIUpdate): + db_sponsor_api = models.SponsorAPI(**{ + "url": api_update.api_url, + "public_key": api_update.public_key + }) + db.add(db_sponsor_api) + db.commit() + sponsor = ( + db + .query(models.Sponsor) + .filter(models.Sponsor.id == api_update.sponsor_id) + .update({ + "api_id": db_sponsor_api.id + }) + ) + db.commit() + return sponsor + + def update_entrypoints(db: Session, entrypoints: list[schemas.EntrypointUpdate]): for e in entrypoints: db.query(models.Entrypoint).filter(models.Entrypoint.id == e.id).update( @@ -141,28 +165,28 @@ def update_entrypoints(db: Session, entrypoints: list[schemas.EntrypointUpdate]) ) -def get_user_credits(db: Session, user_id: str): +def get_sponsor_credits(db: Session, sponsor_id: str): """ - Get credits from a user. + Get credits from a sponsor. """ - db_credits = db.query(models.Credit).filter(models.Credit.owner_id == user_id).all() + db_credits = db.query(models.Credit).filter(models.Credit.owner_id == sponsor_id).all() return db_credits -def update_user_withdraw_counter(db: Session, user_id: str, withdraw_counter: int): +def update_sponsor_withdraw_counter(db: Session, sponsor_id: str, withdraw_counter: int): try: - db_user: Optional[models.User] = db.query(models.User).get(user_id) - if db_user is None: - raise UserNotFound() + db_sponsor: Optional[models.Sponsor] = db.query(models.Sponsor).get(sponsor_id) + if db_sponsor is None: + raise SponsorNotFound() - db.query(models.User).filter(models.User.id == user_id).update( + db.query(models.Sponsor).filter(models.Sponsor.id == sponsor_id).update( {"withdraw_counter": withdraw_counter} ) db.commit() - return db_user.withdraw_counter + return db_sponsor.withdraw_counter except NoResultFound as e: - raise UserNotFound from e + raise SponsorNotFound from e def create_credits(db: Session, credit: schemas.CreditCreation): @@ -170,15 +194,15 @@ def create_credits(db: Session, credit: schemas.CreditCreation): Creates credits for a given owner and returns a models.Credit. """ try: - # Check if the user exists - _ = db.query(models.User).get(credit.owner_id) + # Check if the sponsor exists + _ = db.query(models.Sponsor).get(credit.owner_id) credit = models.Credit(**credit.model_dump()) db.add(credit) db.commit() # db.refresh(credit) return credit except NoResultFound as e: - raise UserNotFound() from e + raise SponsorNotFound() from e def update_credits(db: Session, credit_update: schemas.CreditUpdate): @@ -222,7 +246,7 @@ def update_credits_from_contract_address(db: Session, amount: int, address: str) def get_credits(db: Session, uuid: UUID4): """ - Return a models.Credit or raise UserNotFound exception + Return a models.Credit or raise SponsorNotFound exception """ db_credit = db.query(models.Credit).get(uuid) if db_credit is None: diff --git a/src/models.py b/src/models.py index 42edec7..f4512e4 100644 --- a/src/models.py +++ b/src/models.py @@ -9,6 +9,7 @@ String, ) from sqlalchemy.orm import relationship +from sqlalchemy.sql import func from sqlalchemy.dialects.postgresql import UUID import uuid @@ -17,25 +18,48 @@ import datetime -# ------- USER ------- # -class User(Base): - __tablename__ = "users" +# ------- SPONSOR CLASSES ------- # +# Sponsors +# - must have a tezos address, which they may use to provide credits +# - may have an API, to which user operations are transmitted, and +# which may post these operations themselves, or just return a signed +# receipt to the gas station (which then posts the operations itself) +# If the sponsor API returns the operation to the GS, then the sponsor must +# have deposited credits. +class SponsorAPI(Base): + __tablename__ = "sponsor_apis" def __repr__(self): - return "User(id='{}', name='{}', address='{}', counter='{}')".format( - self.id, self.name, self.address, self.withdraw_counter + return "API(id='{}', url='{}')".format( + self.id, self.url + ) + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + url = Column(String, unique=True) + public_key = Column(String, nullable=False) + + +class Sponsor(Base): + __tablename__ = "sponsors" + + def __repr__(self): + return "Sponsor(id='{}', name='{}', address='{}', counter='{}')".format( + self.id, self.name, self.tezos_address, self.withdraw_counter ) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) name = Column(String) - address = Column(String, unique=True) + tezos_address = Column(String, unique=True) withdraw_counter = Column(Integer, default=0) - + api_id = Column(UUID(as_uuid=True), ForeignKey("sponsor_apis.id")) contracts = relationship("Contract", back_populates="owner") credits = relationship("Credit", back_populates="owner") + sponsor_api = relationship("SponsorAPI") # ------- CONTRACT ------- # +# TODO: contract sponsored by several owners +# Do not require contracts to be tied to a specific credit class Contract(Base): __tablename__ = "contracts" @@ -47,13 +71,13 @@ def __repr__(self): id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) name = Column(String) address = Column(String, unique=True) - owner_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) + owner_id = Column(UUID(as_uuid=True), ForeignKey("sponsors.id")) credit_id = Column(UUID(as_uuid=True), ForeignKey("credits.id")) max_calls_per_month = Column( Integer, default=-1 ) # TODO must be > 0 ; -1 means disabled - owner = relationship("User", back_populates="contracts") + owner = relationship("Sponsor", back_populates="contracts") entrypoints = relationship("Entrypoint", back_populates="contract") credit = relationship("Credit", back_populates="contracts") operations = relationship("Operation", back_populates="contract") @@ -97,9 +121,9 @@ def __repr__(self): id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) amount = Column(Integer, default=0) - owner_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) + owner_id = Column(UUID(as_uuid=True), ForeignKey("sponsors.id")) - owner = relationship("User", back_populates="credits") + owner = relationship("Sponsor", back_populates="credits") contracts = relationship("Contract", back_populates="credit") conditions = relationship("Condition", back_populates="vault") @@ -117,7 +141,7 @@ class Operation(Base): entrypoint_id = Column(UUID(as_uuid=True), ForeignKey("entrypoints.id")) hash = Column(String) status = Column(String) # TODO Enum - created_at = Column(DateTime(timezone=True), default=datetime.datetime.utcnow()) + created_at = Column(DateTime(timezone=True), server_default=func.now()) contract = relationship("Contract", back_populates="operations") entrypoint = relationship("Entrypoint", back_populates="operations") @@ -155,7 +179,7 @@ class Condition(Base): max = Column(Integer, nullable=False) current = Column(Integer, nullable=False) created_at = Column( - DateTime(timezone=True), default=datetime.datetime.utcnow(), nullable=False + DateTime(timezone=True), server_default=func.now(), nullable=False ) is_active = Column(Boolean, nullable=False) contract = relationship("Contract", back_populates="conditions") diff --git a/src/routes.py b/src/routes.py index e0038ea..efea74d 100644 --- a/src/routes.py +++ b/src/routes.py @@ -1,5 +1,7 @@ from fastapi import APIRouter, HTTPException, status, Depends import asyncio +import aiohttp +import jwt from sqlalchemy.orm import Session from . import tezos, crud, schemas, database @@ -15,7 +17,7 @@ EntrypointNotFound, TooManyCallsForThisMonth, NotEnoughFunds, - UserNotFound, + SponsorNotFound, OperationNotFound, ) from .config import logging @@ -34,13 +36,13 @@ async def root(): # POST endpoints -@router.post("/users", response_model=schemas.User) -async def create_user( - user: schemas.UserCreation, db: Session = Depends(database.get_db) +@router.post("/sponsors", response_model=schemas.Sponsor) +async def create_sponsor( + sponsor: schemas.SponsorCreation, db: Session = Depends(database.get_db) ): - user = crud.create_user(db, user) - crud.create_credits(db, schemas.CreditCreation(owner_id=user.id)) - return user + sponsor = crud.create_sponsor(db, sponsor) + crud.create_credits(db, schemas.CreditCreation(owner_id=sponsor.id)) + return sponsor @router.post("/contracts", response_model=schemas.Contract) @@ -58,9 +60,20 @@ async def create_contract( # PUT endpoints +# FIXME: we obviously need to protect this in some way, but we'll rework +# security in a later upgrade +@router.put("/sponsor_api", response_model=bool) +async def update_sponsor_api( + api_update: schemas.SponsorAPIUpdate, + db: Session = Depends(database.get_db) +): + return crud.update_sponsor_api(db, api_update) + + @router.put("/entrypoints", response_model=list[schemas.Entrypoint]) async def update_entrypoints( - entrypoints: list[schemas.EntrypointUpdate], db: Session = Depends(database.get_db) + entrypoints: list[schemas.EntrypointUpdate], + db: Session = Depends(database.get_db) ): return crud.update_entrypoints(db, entrypoints) @@ -70,7 +83,7 @@ async def update_credits( credits: schemas.CreditUpdate, db: Session = Depends(database.get_db) ): try: - payer_address = crud.get_credits(db, credits.id).owner.address + payer_address = crud.get_credits(db, credits.id).owner.tezos_address op_hash = credits.operation_hash amount = credits.amount is_confirmed = await tezos.confirm_deposit(op_hash, payer_address, amount) @@ -127,8 +140,8 @@ async def withdraw_credits( status_code=status.HTTP_400_BAD_REQUEST, detail="Bad withdraw counter." ) - owner_address = credits.owner.address - user = crud.get_user_by_address(db, owner_address) + owner_address = credits.owner.tezos_address + sponsor = crud.get_sponsor_by_address(db, owner_address) public_key = tezos.get_public_key(owner_address) is_valid = tezos.check_signature( withdraw.to_micheline_pair(), withdraw.micheline_signature, public_key @@ -140,8 +153,8 @@ async def withdraw_credits( ) # We increment the counter even if the withdraw fails to prevent # the counter from being used again immediately. - counter = crud.update_user_withdraw_counter( - db, str(user.id), withdraw.withdraw_counter + 1 + counter = crud.update_sponsor_withdraw_counter( + db, str(sponsor.id), withdraw.withdraw_counter + 1 ) result = await tezos.withdraw(tezos.tezos_manager, owner_address, withdraw.amount) if result["result"] == "ok": @@ -150,55 +163,55 @@ async def withdraw_credits( # has been confirmed asyncio.create_task( tezos.confirm_withdraw( - result["transaction_hash"], db, str(user.id), withdraw + result["transaction_hash"], db, str(sponsor.id), withdraw ) ) return {**result, "counter": counter} -# Users and credits getters -@router.get("/users/{address_or_id}", response_model=schemas.User) -async def get_user(address_or_id: str, db: Session = Depends(database.get_db)): +# Sponsors and credits getters +@router.get("/sponsors/{address_or_id}", response_model=schemas.Sponsor) +async def get_sponsor(address_or_id: str, db: Session = Depends(database.get_db)): try: if is_address(address_or_id) and address_or_id.startswith("tz"): - return crud.get_user_by_address(db, address_or_id) + return crud.get_sponsor_by_address(db, address_or_id) else: - return crud.get_user(db, address_or_id) - except UserNotFound: - logging.warning(f"User {address_or_id} not found") + return crud.get_sponsor(db, address_or_id) + except SponsorNotFound: + logging.warning(f"Sponsor {address_or_id} not found") raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"User not found.", + detail=f"Sponsor not found.", ) -@router.get("/credits/{user_address_or_id}", response_model=list[schemas.Credit]) -async def credits_for_user( - user_address_or_id: str, db: Session = Depends(database.get_db) +@router.get("/credits/{sponsor_address_or_id}", response_model=list[schemas.Credit]) +async def credits_for_sponsor( + sponsor_address_or_id: str, db: Session = Depends(database.get_db) ): try: - if is_address(user_address_or_id) and user_address_or_id.startswith("tz"): - return crud.get_user_by_address(db, user_address_or_id).credits + if is_address(sponsor_address_or_id) and sponsor_address_or_id.startswith("tz"): + return crud.get_sponsor_by_address(db, sponsor_address_or_id).credits else: - return crud.get_user(db, user_address_or_id).credits - except UserNotFound: - logging.warning(f"User {user_address_or_id} not found") + return crud.get_sponsor(db, sponsor_address_or_id).credits + except SponsorNotFound: + logging.warning(f"Sponsor {sponsor_address_or_id} not found") raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"User not found.", + detail=f"Sponsor not found.", ) # Contracts -@router.get("/contracts/user/{user_address}", response_model=list[schemas.Contract]) -async def get_user_contracts(user_address: str, db: Session = Depends(database.get_db)): +@router.get("/contracts/sponsor/{sponsor_address}", response_model=list[schemas.Contract]) +async def get_sponsor_contracts(sponsor_address: str, db: Session = Depends(database.get_db)): try: - return crud.get_contracts_by_user(db, user_address) - except UserNotFound: - logging.warning(f"User {user_address} not found.") + return crud.get_contracts_by_sponsor(db, sponsor_address) + except SponsorNotFound: + logging.warning(f"Sponsor {sponsor_address} not found.") raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"User not found." + status_code=status.HTTP_404_NOT_FOUND, detail=f"Sponsor not found." ) @@ -265,6 +278,78 @@ async def get_entrypoint( ) +def _check_contract(operation, db): + """Checks that a contract used in the operation is registered and active + in the database. Returns the contract object.""" + contract_address = str(operation["destination"]) + + # Transfers to implicit accounts are always refused + if not contract_address.startswith("KT"): + logging.warning(f"Target {contract_address} is not allowed") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Target {contract_address} is not allowed", + ) + try: + contract = crud.get_contract_by_address(db, contract_address) + except ContractNotFound: + logging.warning(f"{contract_address} is not found") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"{contract_address} is not found", + ) + + return contract + + +def _check_conditions(sender, contract_id, entrypoint_id, credit_id, db): + if not crud.check_conditions( + db, + schemas.CheckConditions( + sponsee_address=sender, + contract_id=contract_id, + entrypoint_id=entrypoint_id, + vault_id=credit_id, + ), + ): + logging.warning(f"A condition exceed the maximum defined.") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"A condition exceed the maximum defined.", + ) + + +def _check_entrypoint(operation, contract, db): + """Checks that the target entrypoint is registered and active.""" + entrypoint_name = operation["parameters"]["entrypoint"] + try: + entrypoint = crud.get_entrypoint( + db, + str(contract.address), + entrypoint_name + ) + if not entrypoint.is_enabled: + logging.warning(f"Entrypoint {entrypoint_name} is disabled.") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Entrypoint {entrypoint_name} is disabled.", + ) + return entrypoint + except EntrypointNotFound: + logging.warning(f"Entrypoint {entrypoint_name} is not found") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Entrypoint {entrypoint_name} is not found", + ) + + +def _check_receipt(pkey, receipt, json_data): + decoded = jwt.decode( + receipt["signature"], key=pkey, algorithms=["RS256"] + ) + return True + + # Operations @router.post("/operation") async def post_operation( @@ -276,67 +361,48 @@ async def post_operation( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Empty operations list", ) - # TODO: check that amount=0? + # Build a set of all the sponsors who have an API that we should query + # and make checks on the contracts, entrypoints + sponsors_apis = dict() for operation in call_data.operations: - contract_address = str(operation["destination"]) - - # Transfers to implicit accounts are always refused - if not contract_address.startswith("KT"): - logging.warning(f"Target {contract_address} is not allowed") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Target {contract_address} is not allowed", - ) - try: - contract = crud.get_contract_by_address(db, contract_address) - except ContractNotFound: - logging.warning(f"{contract_address} is not found") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"{contract_address} is not found", - ) - - entrypoint_name = operation["parameters"]["entrypoint"] - + contract = _check_contract(operation, db) + entrypoint = _check_entrypoint(operation, contract, db) + _check_conditions( + call_data.sender_address, + contract.id, + entrypoint.id, + contract.credit_id, + db + ) + sponsor_api = contract.owner.sponsor_api + if sponsor_api is not None: + sponsors_apis[sponsor_api.url] = sponsor_api.public_key + + for sponsor_url in sponsors_apis: + # FIXME concurrent requests + async with aiohttp.ClientSession() as session: + async with session.post( + sponsor_url + "/operation", + json = call_data.model_dump() + ) as response: + receipt = await response.json() + pkey = sponsors_apis[sponsor_url] try: - entrypoint = crud.get_entrypoint(db, str(contract.address), entrypoint_name) - if not entrypoint.is_enabled: - raise EntrypointDisabled() - - if not crud.check_conditions( - db, - schemas.CheckConditions( - sponsee_address=call_data.sender_address, - contract_id=contract.id, - entrypoint_id=entrypoint.id, - vault_id=contract.credit_id, - ), - ): - raise ConditionExceed() - except EntrypointNotFound: - logging.warning(f"Entrypoint {entrypoint_name} is not found") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Entrypoint {entrypoint_name} is not found", - ) - except EntrypointDisabled: - logging.warning(f"Entrypoint {entrypoint_name} is disabled.") + parsed_receipt = schemas.Receipt(**receipt) + _check_receipt(pkey, receipt, call_data.model_dump()) + if parsed_receipt.gas_station_action.lower() == "refuse": + raise ValueError() + elif parsed_receipt.gas_station_action.lower() == "accepted": + return receipt + except ValueError as e: + print(e) raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"Entrypoint {entrypoint_name} is disabled.", - ) - except ConditionExceed: - logging.warning(f"A condition exceed the maximum defined.") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"A condition exceed the maximum defined.", + status_code=status.HTTP_404_NOT_FOUND, + detail="A sponsor refused the operation" ) - try: # Simulate the operation alone without sending it - # TODO: log the result op = tezos.simulate_transaction(call_data.operations) - logging.debug(f"Result of operation simulation : {op}") op_estimated_fees = [(int(x["fee"]), x["destination"]) for x in op.contents] @@ -347,12 +413,14 @@ async def post_operation( if not tezos.check_credits(db, estimated_fees): logging.warning(f"Not enough funds to pay estimated fees.") raise NotEnoughFunds( - f"Estimated fees : {estimated_fees[str(contract.address)]} mutez" + f"Estimated fees: {estimated_fees[str(contract.address)]} mutez" ) if not crud.check_calls_per_month(db, contract.id): # type: ignore logging.warning(f"Too many calls made for this contract this month.") raise TooManyCallsForThisMonth() + # Adds the operation to a queue and updates the credits in the + # database. result = await tezos.tezos_manager.queue_operation(call_data.sender_address, op) crud.create_operation( @@ -396,7 +464,7 @@ async def post_operation( async def signed_operation( call_data: schemas.SignedCall, db: Session = Depends(database.get_db) ): - # In order for the user to sign Micheline, we need to + # In order to check the signed Micheline # FIXME: this is a serious issue, we should sign the contract address too. signed_data = [x["parameters"]["value"] for x in call_data.operations] if not tezos.check_signature( diff --git a/src/schemas.py b/src/schemas.py index 8143cfa..44de6a8 100644 --- a/src/schemas.py +++ b/src/schemas.py @@ -4,7 +4,7 @@ from typing import List, Any, Optional -# -- UTILITY TYPES -- +# Utility types class ConditionType(enum.Enum): # Max number of calls to a given entrypoint, for all sponsee MAX_CALLS_PER_ENTRYPOINT = "MAX_CALLS_PER_ENTRYPOINT" @@ -12,18 +12,32 @@ class ConditionType(enum.Enum): MAX_CALLS_PER_SPONSEE = "MAX_CALLS_PER_SPONSEE" -# Users -class UserBase(BaseModel): - address: str +class Receipt(BaseModel): + """Receipts returned by the sponsor APIs""" + gas_station_action: str + signature: str + + +# Sonsor APIs +class SponsorAPIUpdate(BaseModel): + sponsor_id: UUID4 + api_url: str + public_key: str + + +# Sponsors +class SponsorBase(BaseModel): + tezos_address: str -class User(UserBase): +class Sponsor(SponsorBase): id: UUID4 name: str withdraw_counter: int + api_id: UUID4 | None = None -class UserCreation(UserBase): +class SponsorCreation(SponsorBase): name: str diff --git a/src/tezos.py b/src/tezos.py index 1a7f557..c25071a 100644 --- a/src/tezos.py +++ b/src/tezos.py @@ -99,13 +99,14 @@ async def confirm_deposit(tx_hash, payer, amount: Union[int, str]): return False -async def confirm_withdraw(tx_hash, db, user_id, withdraw): - """Ensure withdraw transaction is successful to update credits user. \n - Can raise an OperationNotFound exception if transaction is not found. +async def confirm_withdraw(tx_hash, db, sponsor_id, withdraw): + """Ensures the withdraw transaction is successful and updates the sponsor's + credit. Raises an OperationNotFound exception if transaction is not found. """ await find_transaction(tx_hash) credit_update = schemas.CreditUpdate( - id=withdraw.id, amount=-withdraw.amount, owner_id=user_id, operation_hash="" + id=withdraw.id, amount=-withdraw.amount, owner_id=sponsor_id, + operation_hash="" ) crud.update_credits(db, credit_update) diff --git a/src/utils.py b/src/utils.py index 4192b48..3ce3333 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,10 +1,4 @@ -# -- EXCEPTIONS -- - - -import enum - - -class UserNotFound(Exception): +class SponsorNotFound(Exception): pass