Skip to content

Commit 4bb0c2c

Browse files
committed
bot: Start implementing new FastAPI-based server
1 parent 664221f commit 4bb0c2c

File tree

24 files changed

+1132
-53
lines changed

24 files changed

+1132
-53
lines changed

.vscode/settings.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@
1212
"--disable=too-many-return-statements",
1313
"--disable=too-many-branches"
1414
],
15-
"editor.formatOnSave": true
15+
"editor.formatOnSave": true,
16+
"editor.defaultFormatter": "charliermarsh.ruff"
1617
}

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ install:
1010
dev-frontend: config.yml blanco.db
1111
poetry run python -m bot.dev_server
1212

13+
dev-backend: config.yml blanco.db
14+
poetry run python -m bot.api.main
15+
1316
dev: config.yml blanco.db
1417
poetry run python -m bot.main
1518

bot/api/depends/database.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from typing import TYPE_CHECKING
2+
3+
from fastapi import HTTPException, Request
4+
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
5+
6+
if TYPE_CHECKING:
7+
from bot.database import Database
8+
9+
10+
def database_dependency(request: Request) -> 'Database':
11+
"""
12+
FastAPI dependency to get the database object.
13+
14+
Args:
15+
request (web.Request): The request.
16+
17+
Returns:
18+
Database: The database object.
19+
"""
20+
21+
state = request.app.state
22+
if not hasattr(state, 'database'):
23+
raise HTTPException(
24+
status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail='No database connection'
25+
)
26+
27+
database: 'Database' = state.database
28+
if database is None:
29+
raise HTTPException(
30+
status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail='No database connection'
31+
)
32+
33+
return database

bot/api/depends/session.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from typing import TYPE_CHECKING, Optional
2+
3+
from fastapi import Depends, HTTPException, Request
4+
from starlette.status import HTTP_401_UNAUTHORIZED
5+
6+
from .database import database_dependency
7+
8+
if TYPE_CHECKING:
9+
from bot.api.utils.session import SessionManager
10+
from bot.database import Database
11+
from bot.models.oauth import OAuth
12+
13+
14+
EXPECTED_AUTH_SCHEME = 'Bearer'
15+
EXPECTED_AUTH_PARTS = 2
16+
17+
18+
def session_dependency(
19+
request: Request, db: 'Database' = Depends(database_dependency)
20+
) -> 'OAuth':
21+
"""
22+
FastAPI dependency to get the requesting user's info.
23+
24+
Args:
25+
request (web.Request): The request.
26+
27+
Returns:
28+
OAuth: The info for the current Discord user.
29+
"""
30+
31+
authorization = request.headers.get('Authorization')
32+
if authorization is None:
33+
raise HTTPException(
34+
status_code=HTTP_401_UNAUTHORIZED, detail='No authorization header'
35+
)
36+
37+
parts = authorization.split()
38+
if len(parts) != EXPECTED_AUTH_PARTS:
39+
raise HTTPException(
40+
status_code=HTTP_401_UNAUTHORIZED, detail='Invalid authorization header'
41+
)
42+
43+
scheme, token = parts
44+
if scheme != EXPECTED_AUTH_SCHEME:
45+
raise HTTPException(
46+
status_code=HTTP_401_UNAUTHORIZED, detail='Invalid authorization scheme'
47+
)
48+
49+
session_manager: 'SessionManager' = request.app.state.session_manager
50+
session = session_manager.decode_session(token)
51+
if session is None:
52+
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail='Invalid session')
53+
54+
user: Optional['OAuth'] = db.get_oauth('discord', session.user_id)
55+
if user is None:
56+
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail='User not found')
57+
58+
return user

bot/api/extension.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""
2+
Nextcord extension that runs the API server for the bot
3+
"""
4+
5+
from typing import TYPE_CHECKING
6+
7+
from .main import run_app
8+
9+
if TYPE_CHECKING:
10+
from bot.utils.blanco import BlancoBot
11+
12+
13+
def setup(bot: 'BlancoBot'):
14+
"""
15+
Run the API server within the bot's existing event loop.
16+
"""
17+
run_app(bot.loop, bot.database)

bot/api/main.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""
2+
Main module for the API server.
3+
"""
4+
5+
from asyncio import set_event_loop
6+
from contextlib import asynccontextmanager
7+
from logging import INFO
8+
from typing import TYPE_CHECKING, Any, Optional
9+
10+
from fastapi import FastAPI
11+
from uvicorn import Config, Server, run
12+
from uvicorn.config import LOGGING_CONFIG
13+
14+
from bot.database import Database
15+
from bot.utils.config import config as bot_config
16+
from bot.utils.logger import DATE_FMT_STR, LOG_FMT_COLOR, create_logger
17+
18+
from .routes.account import account_router
19+
from .routes.oauth import oauth_router
20+
from .utils.session import SessionManager
21+
22+
if TYPE_CHECKING:
23+
from asyncio import AbstractEventLoop
24+
25+
26+
_database: Optional[Database] = None
27+
28+
29+
@asynccontextmanager
30+
async def lifespan(app: FastAPI):
31+
logger = create_logger('api.lifespan')
32+
33+
if _database is None:
34+
logger.warn('Manually creating database connection')
35+
database = Database(bot_config.db_file)
36+
else:
37+
logger.info('Connecting to database from FastAPI')
38+
database = _database
39+
40+
app.state.database = database
41+
app.state.session_manager = SessionManager(database)
42+
yield
43+
44+
45+
app = FastAPI(lifespan=lifespan)
46+
app.include_router(account_router)
47+
app.include_router(oauth_router)
48+
49+
50+
@app.get('/')
51+
async def health_check():
52+
return {'status': 'ok'}
53+
54+
55+
def _get_log_config() -> dict[str, Any]:
56+
log_config = LOGGING_CONFIG
57+
log_config['formatters']['default']['fmt'] = LOG_FMT_COLOR[INFO]
58+
log_config['formatters']['default']['datefmt'] = DATE_FMT_STR
59+
log_config['formatters']['access']['fmt'] = LOG_FMT_COLOR[INFO]
60+
61+
return log_config
62+
63+
64+
def run_app(loop: 'AbstractEventLoop', db: Database):
65+
"""
66+
Run the API server in the bot's event loop.
67+
"""
68+
global _database # noqa: PLW0603
69+
_database = db
70+
71+
set_event_loop(loop)
72+
73+
config = Config(
74+
app=app,
75+
loop=loop, # type: ignore
76+
host='0.0.0.0',
77+
port=bot_config.server_port,
78+
log_config=_get_log_config(),
79+
)
80+
server = Server(config)
81+
82+
loop.create_task(server.serve())
83+
84+
85+
if __name__ == '__main__':
86+
run(
87+
app='bot.api.main:app',
88+
host='127.0.0.1',
89+
port=bot_config.server_port,
90+
reload=True,
91+
reload_dirs=['bot/api'],
92+
log_config=_get_log_config(),
93+
)

bot/api/models/account.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from typing import Optional
2+
3+
from pydantic import BaseModel, Field
4+
5+
6+
class AccountResponse(BaseModel):
7+
username: str = Field(description="The user's username.")
8+
spotify_logged_in: bool = Field(
9+
description='Whether the user is logged in to Spotify.'
10+
)
11+
spotify_username: Optional[str] = Field(
12+
default=None, description="The user's Spotify username, if logged in."
13+
)
14+
lastfm_logged_in: bool = Field(
15+
description='Whether the user is logged in to Last.fm.'
16+
)
17+
lastfm_username: Optional[str] = Field(
18+
default=None, description="The user's Last.fm username, if logged in."
19+
)

bot/api/models/oauth.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from typing import Optional
2+
3+
from pydantic import BaseModel, Field
4+
5+
6+
class OAuthResponse(BaseModel):
7+
session_id: str = Field(description='The session ID for the user.')
8+
jwt: str = Field(description='The JSON Web Token for the user.')
9+
10+
11+
class DiscordUser(BaseModel):
12+
id: int = Field(description='The user ID.')
13+
username: str = Field(description='The username.')
14+
discriminator: str = Field(description='The discriminator.')
15+
avatar: Optional[str] = Field(
16+
default=None, description='The avatar hash, if the user has one.'
17+
)

bot/api/models/session.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from pydantic import BaseModel
2+
3+
4+
class Session(BaseModel):
5+
user_id: int
6+
session_id: str
7+
expiration_time: int

bot/api/routes/account/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from fastapi import APIRouter
2+
3+
from .login import get_login_url as route_login
4+
from .me import get_logged_in_user as route_me
5+
6+
account_router = APIRouter(prefix='/account', tags=['account'])
7+
account_router.add_api_route('/login', route_login, methods=['GET'])
8+
account_router.add_api_route('/me', route_me, methods=['GET'])

0 commit comments

Comments
 (0)