diff --git a/.flake8 b/.flake8
index 9c7d08f6..011ecfa8 100644
--- a/.flake8
+++ b/.flake8
@@ -1,3 +1,3 @@
 [flake8]
 max-line-length = 160
-ignore = E203, E402, W503
+ignore = E203, E402, W503, B008
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 20901bdc..b57fbfe2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -19,7 +19,7 @@ repos:
   - id: autoflake
     args:
       - --in-place
-      - --imports=sqlalchemy,pydantic
+      - --imports=trapdata,sqlalchemy,pydantic,fastapi
     files: .
     types: [file, python]
diff --git a/poetry.lock b/poetry.lock
index d0939916..5894bb32 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand.
 
 [[package]]
 name = "alembic"
@@ -20,6 +20,27 @@ typing-extensions = ">=4"
 [package.extras]
 tz = ["python-dateutil"]
 
+[[package]]
+name = "anyio"
+version = "3.6.2"
+description = "High level compatibility layer for multiple asynchronous event loop implementations"
+category = "main"
+optional = false
+python-versions = ">=3.6.2"
+files = [
+    {file = "anyio-3.6.2-py3-none-any.whl", hash = "sha256:fbbe32bd270d2a2ef3ed1c5d45041250284e31fc0a4df4a5a6071842051a51e3"},
+    {file = "anyio-3.6.2.tar.gz", hash = "sha256:25ea0d673ae30af41a0c442f81cf3b38c7e79fdc7b60335a4c14e05eb0947421"},
+]
+
+[package.dependencies]
+idna = ">=2.8"
+sniffio = ">=1.1"
+
+[package.extras]
+doc = ["packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"]
+test = ["contextlib2", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (<0.15)", "uvloop (>=0.15)"]
+trio = ["trio (>=0.16,<0.22)"]
+
 [[package]]
 name = "appnope"
 version = "0.1.3"
@@ -406,6 +427,28 @@ files = [
 [package.extras]
 tests = ["asttokens", "littleutils", "pytest", "rich"]
 
+[[package]]
+name = "fastapi"
+version = "0.95.0"
+description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "fastapi-0.95.0-py3-none-any.whl", hash = "sha256:daf73bbe844180200be7966f68e8ec9fd8be57079dff1bacb366db32729e6eb5"},
+    {file = "fastapi-0.95.0.tar.gz", hash = "sha256:99d4fdb10e9dd9a24027ac1d0bd4b56702652056ca17a6c8721eec4ad2f14e18"},
+]
+
+[package.dependencies]
+pydantic = ">=1.6.2,<1.7 || >1.7,<1.7.1 || >1.7.1,<1.7.2 || >1.7.2,<1.7.3 || >1.7.3,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0"
+starlette = ">=0.26.1,<0.27.0"
+
+[package.extras]
+all = ["email-validator (>=1.1.1)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
+dev = ["pre-commit (>=2.17.0,<3.0.0)", "ruff (==0.0.138)", "uvicorn[standard] (>=0.12.0,<0.21.0)"]
+doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-markdownextradata-plugin (>=0.1.7,<0.3.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pyyaml (>=5.3.1,<7.0.0)", "typer-cli (>=0.0.13,<0.0.14)", "typer[all] (>=0.6.1,<0.8.0)"]
+test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==23.1.0)", "coverage[toml] (>=6.5.0,<8.0)", "databases[sqlite] (>=0.3.2,<0.7.0)", "email-validator (>=1.1.1,<2.0.0)", "flask (>=1.1.2,<3.0.0)", "httpx (>=0.23.0,<0.24.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.982)", "orjson (>=3.2.1,<4.0.0)", "passlib[bcrypt] (>=1.7.2,<2.0.0)", "peewee (>=3.13.3,<4.0.0)", "pytest (>=7.1.3,<8.0.0)", "python-jose[cryptography] (>=3.3.0,<4.0.0)", "python-multipart (>=0.0.5,<0.0.7)", "pyyaml (>=5.3.1,<7.0.0)", "ruff (==0.0.138)", "sqlalchemy (>=1.3.18,<1.4.43)", "types-orjson (==3.6.2)", "types-ujson (==5.7.0.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,<6.0.0)"]
+
 [[package]]
 name = "filelock"
 version = "3.10.7"
@@ -513,6 +556,18 @@ files = [
 docs = ["Sphinx", "docutils (<0.18)"]
 test = ["objgraph", "psutil"]
 
+[[package]]
+name = "h11"
+version = "0.14.0"
+description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"},
+    {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
+]
+
 [[package]]
 name = "huggingface-hub"
 version = "0.13.3"
@@ -1980,6 +2035,18 @@ files = [
     {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
 ]
 
+[[package]]
+name = "sniffio"
+version = "1.3.0"
+description = "Sniff out which async library your code is running under"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"},
+    {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"},
+]
+
 [[package]]
 name = "sqlalchemy"
 version = "2.0.8"
@@ -2107,6 +2174,25 @@ pure-eval = "*"
 [package.extras]
 tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
 
+[[package]]
+name = "starlette"
+version = "0.26.1"
+description = "The little ASGI library that shines."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "starlette-0.26.1-py3-none-any.whl", hash = "sha256:e87fce5d7cbdde34b76f0ac69013fd9d190d581d80681493016666e6f96c6d5e"},
+    {file = "starlette-0.26.1.tar.gz", hash = "sha256:41da799057ea8620e4667a3e69a5b1923ebd32b1819c8fa75634bbe8d8bea9bd"},
+]
+
+[package.dependencies]
+anyio = ">=3.4.0,<5"
+typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""}
+
+[package.extras]
+full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"]
+
 [[package]]
 name = "structlog"
 version = "22.3.0"
@@ -2390,6 +2476,25 @@ brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
 secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
 socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
 
+[[package]]
+name = "uvicorn"
+version = "0.21.1"
+description = "The lightning-fast ASGI server."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "uvicorn-0.21.1-py3-none-any.whl", hash = "sha256:e47cac98a6da10cd41e6fd036d472c6f58ede6c5dbee3dbee3ef7a100ed97742"},
+    {file = "uvicorn-0.21.1.tar.gz", hash = "sha256:0fac9cb342ba099e0d582966005f3fdba5b0290579fed4a6266dc702ca7bb032"},
+]
+
+[package.dependencies]
+click = ">=7.0"
+h11 = ">=0.8"
+
+[package.extras]
+standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"]
+
 [[package]]
 name = "wcwidth"
 version = "0.2.6"
@@ -2420,4 +2525,4 @@ test = ["pytest (>=6.0.0)"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "be789d9689e5a8a9feeb55857c26f300acf7b99a95ab555fffe45239a14a0086"
+content-hash = "453accf6c7a42a5ccebff574cedeef2a4a4797d04a8b155375023c901b40fe37"
diff --git a/pyproject.toml b/pyproject.toml
index 1ddde4ce..1409d445 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,8 @@ ipython = "^8.11.0"
 pytest-cov = "^4.0.0"
 pytest-asyncio = "^0.21.0"
 pytest = "*"
+fastapi = "^0.95.0"
+uvicorn = "^0.21.1"
 
 [tool.pytest.ini_options]
diff --git a/trapdata/api/ami.service b/trapdata/api/ami.service
new file mode 100644
index 00000000..8d861f8a
--- /dev/null
+++ b/trapdata/api/ami.service
@@ -0,0 +1,12 @@
+[Unit]
+Description=AMI Data Manager API
+After=network.target
+
+[Service]
+User=debian
+Group=www-data
+WorkingDirectory=/home/debian/ami-data-manager
+ExecStart=/home/debian/miniconda3/bin/gunicorn trapdata.api.main:app --bind 0.0.0.0:8000 --worker-class "uvicorn.workers.UvicornWorker" --log-syslog
+
+[Install]
+WantedBy=multi-user.target
diff --git a/trapdata/api/config.py b/trapdata/api/config.py
new file mode 100644
index 00000000..d4810984
--- /dev/null
+++ b/trapdata/api/config.py
@@ -0,0 +1,39 @@
+import pathlib
+from typing import Any, Dict, List, Optional
+
+from pydantic import HttpUrl, PostgresDsn, validator
+from pydantic.networks import AnyHttpUrl
+
+from trapdata.cli import read_settings
+from trapdata.settings import Settings as BaseSettings
+
+
+class Settings(BaseSettings):
+    PROJECT_NAME: str = "AMI Data Manager"
+
+    SENTRY_DSN: Optional[HttpUrl] = None
+
+    API_PATH: str = "/api/v1"
+
+    ACCESS_TOKEN_EXPIRE_MINUTES: int = 7 * 24 * 60  # 7 days
+
+    BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = []
+
+    # The following variables need to be defined in the environment
+
+    TEST_DATABASE_URL: Optional[PostgresDsn]
+
+    SECRET_KEY: str
+    # END: required environment variables
+
+    # STATIC_ROOT: str = "static"
+
+    # @validator("STATIC_ROOT")
+    # def validate_static_root(cls, v):
+    #     path = cls.user_data_path / v
+    #     path.mkdir(parents=True, exist_ok=True)
+    #     return path
+
+
+# settings = read_settings(SettingsClass=Settings, SECRET_KEY="secret")
+settings = Settings(SECRET_KEY="secret")
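A hedged sketch of how the required SECRET_KEY could be supplied via the environment instead of the hardcoded fallback above, assuming trapdata.settings.Settings is a pydantic BaseSettings subclass that reads environment variables (the key value is a placeholder):

import os

os.environ.setdefault("SECRET_KEY", "change-me")  # placeholder, not a real key

from trapdata.api.config import Settings

settings = Settings()  # SECRET_KEY now resolves from the environment
print(settings.API_PATH)  # -> "/api/v1"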
diff --git a/trapdata/api/deps/db.py b/trapdata/api/deps/db.py
new file mode 100644
index 00000000..c800a29a
--- /dev/null
+++ b/trapdata/api/deps/db.py
@@ -0,0 +1,15 @@
+from typing import Generator
+
+from sqlalchemy import orm
+
+from trapdata.cli import read_settings
+from trapdata.db.base import get_session_class
+
+settings = read_settings()
+
+
+def get_session() -> Generator[orm.Session, None, None]:
+    Session = get_session_class(db_path=settings.database_url)
+    with Session() as session:
+        yield session
+        session.close()
diff --git a/trapdata/api/deps/request_params.py b/trapdata/api/deps/request_params.py
new file mode 100644
index 00000000..113fdbbf
--- /dev/null
+++ b/trapdata/api/deps/request_params.py
@@ -0,0 +1,46 @@
+import json
+from typing import Callable, Optional, Type
+
+from fastapi import HTTPException, Query
+from sqlalchemy import UnaryExpression, asc, desc
+
+from trapdata.api.request_params import RequestParams
+from trapdata.db import Base
+
+
+def parse_react_admin_params(model: Type[Base]) -> Callable:
+    """Parses sort and range parameters coming from a react-admin request"""
+
+    def inner(
+        sort_: Optional[str] = Query(
+            None,
+            alias="sort",
+            description='Format: `["field_name", "direction"]`',
+            example='["id", "ASC"]',
+        ),
+        range_: Optional[str] = Query(
+            None,
+            alias="range",
+            description="Format: `[start, end]`",
+            example="[0, 10]",
+        ),
+    ) -> RequestParams:
+        skip, limit = 0, 10
+        if range_:
+            start, end = json.loads(range_)
+            skip, limit = start, (end - start + 1)
+
+        order_by: UnaryExpression = desc(model.id)
+        if sort_:
+            sort_column, sort_order = json.loads(sort_)
+            if sort_order.lower() == "asc":
+                direction = asc
+            elif sort_order.lower() == "desc":
+                direction = desc
+            else:
+                raise HTTPException(400, f"Invalid sort direction {sort_order}")
+            order_by = direction(model.__table__.c[sort_column])
+
+        return RequestParams(skip=skip, limit=limit, order_by=order_by)
+
+    return inner
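A worked example of the parsing above for a typical react-admin request such as GET /items?sort=["id","ASC"]&range=[0,24] (values are illustrative):

import json

sort_, range_ = '["id", "ASC"]', "[0, 24]"
start, end = json.loads(range_)              # -> 0, 24
skip, limit = start, end - start + 1         # -> 0, 25
sort_column, sort_order = json.loads(sort_)  # -> "id", "ASC"
# The dependency then returns
# RequestParams(skip=0, limit=25, order_by=asc(model.__table__.c["id"]))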
diff --git a/trapdata/api/factory.py b/trapdata/api/factory.py
new file mode 100644
index 00000000..6f190f39
--- /dev/null
+++ b/trapdata/api/factory.py
@@ -0,0 +1,87 @@
+from fastapi import FastAPI
+from fastapi.routing import APIRoute
+from fastapi.staticfiles import StaticFiles
+from starlette.middleware.cors import CORSMiddleware
+from starlette.requests import Request
+from starlette.responses import FileResponse, RedirectResponse
+
+from trapdata.api.config import settings
+from trapdata.api.views import api_router
+
+
+def create_app():
+    description = f"{settings.PROJECT_NAME} API"
+    app = FastAPI(
+        title=settings.PROJECT_NAME,
+        openapi_url=f"{settings.API_PATH}/openapi.json",
+        docs_url="/docs/",
+        description=description,
+        redoc_url="/redoc/",
+    )
+    setup_routers(app)
+    setup_cors_middleware(app)
+    serve_static_app(app)
+    return app
+
+
+def setup_routers(app: FastAPI) -> None:
+    app.include_router(api_router, prefix=settings.API_PATH)
+    # The following operation needs to be at the end of this function
+    use_route_names_as_operation_ids(app)
+
+
+def serve_static_app(app):
+    app.mount(
+        "/static/crops",
+        StaticFiles(directory=settings.user_data_path / "crops"),
+        name="crops",
+    )
+    app.mount(
+        "/static/captures",
+        StaticFiles(directory=settings.image_base_path),
+        name="captures",
+    )
+    app.mount(
+        "/",
+        StaticFiles(directory="trapdata/webui/public"),
+        name="static",
+    )
+
+    @app.middleware("http")
+    async def _add_404_middleware(request: Request, call_next):
+        """Serves static assets on 404"""
+        response = await call_next(request)
+        path = request["path"]
+        if path.startswith(settings.API_PATH) or path.startswith("/docs"):
+            return response
+        if response.status_code == 404:
+            return FileResponse("trapdata/webui/public/index.html")
+        return response
+
+
+def setup_cors_middleware(app):
+    if settings.BACKEND_CORS_ORIGINS:
+        app.add_middleware(
+            CORSMiddleware,
+            allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS],
+            allow_credentials=True,
+            allow_methods=["*"],
+            expose_headers=["Content-Range", "Range"],
+            allow_headers=["Authorization", "Range", "Content-Range"],
+        )
+
+
+def use_route_names_as_operation_ids(app: FastAPI) -> None:
+    """
+    Simplify operation IDs so that generated API clients have simpler function
+    names.
+
+    Should be called only after all routes have been added.
+    """
+    route_names = set()
+    for route in app.routes:
+        if isinstance(route, APIRoute):
+            if route.name in route_names:
+                raise Exception("Route function names should be unique")
+            route.operation_id = route.name
+            route_names.add(route.name)
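The use_route_names_as_operation_ids pass above means OpenAPI-generated clients get function names taken directly from the route handlers. A small standalone sketch of the same idea (the route and names are made up for illustration):

from fastapi import FastAPI
from fastapi.routing import APIRoute

demo = FastAPI()


@demo.get("/ping")
def ping():
    return {"ok": True}


for route in demo.routes:
    if isinstance(route, APIRoute):
        # The OpenAPI operation becomes "ping" instead of an
        # auto-mangled name like "ping_ping_get".
        route.operation_id = route.name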
diff --git a/trapdata/api/gunicorn_conf.py b/trapdata/api/gunicorn_conf.py
new file mode 100644
index 00000000..8cabec15
--- /dev/null
+++ b/trapdata/api/gunicorn_conf.py
@@ -0,0 +1,14 @@
+# gunicorn_conf.py
+from multiprocessing import cpu_count
+
+bind = "0.0.0.0:8000"
+
+# Worker Options
+workers = cpu_count() + 1
+worker_class = 'uvicorn.workers.UvicornWorker'
+timeout = 120
+
+# Logging Options
+loglevel = 'debug'
+accesslog = '/home/debian/logs/access_log'
+errorlog = '/home/debian/logs/error_log'
diff --git a/trapdata/api/main.py b/trapdata/api/main.py
new file mode 100644
index 00000000..1d1cd547
--- /dev/null
+++ b/trapdata/api/main.py
@@ -0,0 +1,20 @@
+from trapdata import logger
+from trapdata.api.factory import create_app
+
+app = create_app()
+
+
+def run():
+    import uvicorn
+
+    logger.info("Starting uvicorn in reload mode")
+    uvicorn.run(
+        "trapdata.api.main:app",
+        host="0.0.0.0",
+        reload=True,
+        port=8000,
+    )
+
+
+if __name__ == "__main__":
+    run()
diff --git a/trapdata/api/request_params.py b/trapdata/api/request_params.py
new file mode 100644
index 00000000..43d5af5b
--- /dev/null
+++ b/trapdata/api/request_params.py
@@ -0,0 +1,9 @@
+from typing import Any
+
+from pydantic.main import BaseModel
+
+
+class RequestParams(BaseModel):
+    skip: int
+    limit: int
+    order_by: Any
diff --git a/trapdata/api/views/__init__.py b/trapdata/api/views/__init__.py
new file mode 100644
index 00000000..fa891a78
--- /dev/null
+++ b/trapdata/api/views/__init__.py
@@ -0,0 +1,21 @@
+from fastapi import APIRouter
+
+from trapdata.api.views import (
+    deployments,
+    events,
+    occurrences,
+    settings,
+    species,
+    stats,
+    status,
+)
+
+api_router = APIRouter()
+
+api_router.include_router(stats.router, tags=["stats"])
+api_router.include_router(status.router, tags=["status"])
+api_router.include_router(deployments.router, tags=["deployments"])
+api_router.include_router(events.router, tags=["events"])
+api_router.include_router(occurrences.router, tags=["occurrences"])
+api_router.include_router(species.router, tags=["species"])
+api_router.include_router(settings.router, tags=["settings"])
diff --git a/trapdata/api/views/deployments.py b/trapdata/api/views/deployments.py
new file mode 100644
index 00000000..c212cec2
--- /dev/null
+++ b/trapdata/api/views/deployments.py
@@ -0,0 +1,41 @@
+from typing import Any, List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException
+from sqlalchemy import func, orm, select
+from starlette.responses import Response
+
+from trapdata.api.config import settings
+from trapdata.api.deps.db import get_session
+from trapdata.api.deps.request_params import parse_react_admin_params
+from trapdata.api.request_params import RequestParams
+from trapdata.db import Base
+from trapdata.db.models.deployments import DeploymentListItem, list_deployments
+from trapdata.db.models.events import update_all_aggregates
+
+router = APIRouter(prefix="/deployments")
+
+
+@router.get("", response_model=List[DeploymentListItem])
+async def get_deployments(
+    response: Response,
+    session: orm.Session = Depends(get_session),
+    # request_params: RequestParams = Depends(parse_react_admin_params(Base)),
+) -> Any:
+    update_all_aggregates(session, settings.image_base_path)
+    deployments = list_deployments(session)
+    return deployments
+
+
+# @router.post("/process", response_model=List[DeploymentListItem])
+# async def process_deployment(
+#     response: Response,
+#     session: orm.Session = Depends(get_session),
+#     # request_params: RequestParams = Depends(parse_react_admin_params(Base)),
+# ) -> Any:
+#     from trapdata.ml.pipeline import start_pipeline
+#
+#     start_pipeline(
+#         session=session, image_base_path=settings.image_base_path, settings=settings
+#     )
+#     deployments = list_deployments(session)
+#     return deployments
diff --git a/trapdata/api/views/events.py b/trapdata/api/views/events.py
new file mode 100644
index 00000000..9f491eb0
--- /dev/null
+++ b/trapdata/api/views/events.py
@@ -0,0 +1,58 @@
+from typing import Any, List
+
+from fastapi import APIRouter, Depends, HTTPException
+from sqlalchemy import orm
+from starlette.responses import Response
+
+from trapdata.api.config import settings
+from trapdata.api.deps.db import get_session
+from trapdata.db.models.events import (
+    MonitoringSessionDetail,
+    MonitoringSessionListItem,
+    get_monitoring_session_by_id,
+    list_monitoring_sessions,
+)
+
+router = APIRouter(prefix="/events")
+
+
+@router.get("", response_model=List[MonitoringSessionListItem])
+async def get_monitoring_sessions(
+    response: Response,
+    session: orm.Session = Depends(get_session),
+    # request_params: RequestParams = Depends(parse_react_admin_params(Base)),
+    limit: int = 100,
+    offset: int = 0,
+) -> Any:
+    items = list_monitoring_sessions(
+        session, settings.image_base_path, limit=limit, offset=offset, media_url_base="/static/"
+    )
+    return items
+
+
+@router.get("/{event_id}", response_model=MonitoringSessionDetail)
+async def get_monitoring_session(
+    event_id: int,
+    response: Response,
+    session: orm.Session = Depends(get_session),
+    # request_params: RequestParams = Depends(parse_react_admin_params(Base)),
+) -> Any:
+    event = get_monitoring_session_by_id(session, event_id, media_url_base="/static/")
+    if not event:
+        raise HTTPException(404)
+    return event
+
+
+# @router.post("/process", response_model=List[DeploymentListItem])
+# async def process_deployment(
+#     response: Response,
+#     session: orm.Session = Depends(get_session),
+#     # request_params: RequestParams = Depends(parse_react_admin_params(Base)),
+# ) -> Any:
+#     from trapdata.ml.pipeline import start_pipeline
+#
+#     start_pipeline(
+#         session=session, image_base_path=settings.image_base_path, settings=settings
+#     )
+#     deployments = list_deployments(session)
+#     return deployments
diff --git a/trapdata/api/views/items.py b/trapdata/api/views/items.py
new file mode 100644
index 00000000..e679e857
--- /dev/null
+++ b/trapdata/api/views/items.py
@@ -0,0 +1,103 @@
+from typing import Any, List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException
+from sqlalchemy import func, select
+from sqlalchemy.ext.asyncio.session import AsyncSession
+from starlette.responses import Response
+
+from app.deps.db import get_async_session
+from app.deps.request_params import parse_react_admin_params
+from app.deps.users import current_user
+from app.models.item import Item
+from app.models.user import User
+from app.schemas.item import Item as ItemSchema
+from app.schemas.item import ItemCreate, ItemUpdate
+from app.schemas.request_params import RequestParams
+
+router = APIRouter(prefix="/items")
+
+
+@router.get("", response_model=List[ItemSchema])
+async def get_items(
+    response: Response,
+    session: AsyncSession = Depends(get_async_session),
+    request_params: RequestParams = Depends(parse_react_admin_params(Item)),
+    user: User = Depends(current_user),
+) -> Any:
+    total = await session.scalar(
+        select(func.count(Item.id).filter(Item.user_id == user.id))
+    )
+    items = (
+        (
+            await session.execute(
+                select(Item)
+                .offset(request_params.skip)
+                .limit(request_params.limit)
+                .order_by(request_params.order_by)
+                .filter(Item.user_id == user.id)
+            )
+        )
+        .scalars()
+        .all()
+    )
+    response.headers[
+        "Content-Range"
+    ] = f"{request_params.skip}-{request_params.skip + len(items)}/{total}"
+    return items
+
+
+@router.post("", response_model=ItemSchema, status_code=201)
+async def create_item(
+    item_in: ItemCreate,
+    session: AsyncSession = Depends(get_async_session),
+    user: User = Depends(current_user),
+) -> Any:
+    item = Item(**item_in.dict())
+    item.user_id = user.id
+    session.add(item)
+    await session.commit()
+    return item
+
+
+@router.put("/{item_id}", response_model=ItemSchema)
+async def update_item(
+    item_id: int,
+    item_in: ItemUpdate,
+    session: AsyncSession = Depends(get_async_session),
+    user: User = Depends(current_user),
+) -> Any:
+    item: Optional[Item] = await session.get(Item, item_id)
+    if not item or item.user_id != user.id:
+        raise HTTPException(404)
+    update_data = item_in.dict(exclude_unset=True)
+    for field, value in update_data.items():
+        setattr(item, field, value)
+    session.add(item)
+    await session.commit()
+    return item
+
+
+@router.get("/{item_id}", response_model=ItemSchema)
+async def get_item(
+    item_id: int,
+    session: AsyncSession = Depends(get_async_session),
+    user: User = Depends(current_user),
+) -> Any:
+    item: Optional[Item] = await session.get(Item, item_id)
+    if not item or item.user_id != user.id:
+        raise HTTPException(404)
+    return item
+
+
+@router.delete("/{item_id}")
+async def delete_item(
+    item_id: int,
+    session: AsyncSession = Depends(get_async_session),
+    user: User = Depends(current_user),
+) -> Any:
+    item: Optional[Item] = await session.get(Item, item_id)
+    if not item or item.user_id != user.id:
+        raise HTTPException(404)
+    await session.delete(item)
+    await session.commit()
+    return {"success": True}
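The Content-Range header set by the handlers above follows react-admin's "<first>-<last>/<total>" convention; a small worked example with hypothetical counts:

# 25 items returned starting at offset 0, out of 103 total rows:
skip, num_items, total = 0, 25, 103
content_range = f"{skip}-{skip + num_items}/{total}"
assert content_range == "0-25/103"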
diff --git a/trapdata/api/views/occurrences.bak.py b/trapdata/api/views/occurrences.bak.py
new file mode 100644
index 00000000..d91a252d
--- /dev/null
+++ b/trapdata/api/views/occurrences.bak.py
@@ -0,0 +1,47 @@
+import random
+from typing import Any, List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException
+from sqlalchemy import func, select
+from sqlalchemy.ext.asyncio.session import AsyncSession
+from starlette.responses import Response
+
+from app.deps.db import get_async_session
+from app.deps.request_params import parse_react_admin_params
+from app.deps.users import current_user
+
+# from app.models.occurrence import Occurrence
+from app.models.user import User
+from app.schemas.occurrence import test_data
+from app.schemas.occurrence import Occurrence
+from app.schemas.request_params import RequestParams
+
+router = APIRouter(prefix="/occurrences")
+
+
+@router.get("", response_model=List[Occurrence])
+async def get_occurrences(
+    response: Response,
+    session: AsyncSession = Depends(get_async_session),
+    request_params: RequestParams = Depends(parse_react_admin_params(Occurrence)),
+) -> Any:
+    occurrences = test_data[
+        request_params.skip : request_params.skip + request_params.limit
+    ]
+    total = len(occurrences)
+    response.headers[
+        "Content-Range"
+    ] = f"{request_params.skip}-{request_params.skip + len(occurrences)}/{total}"
+    return occurrences
+
+
+@router.get("/{occurrence_id}", response_model=Occurrence)
+async def get_occurrence(
+    occurrence_id: int,
+    session: AsyncSession = Depends(get_async_session),
+    user: User = Depends(current_user),
+) -> Any:
+    occurrence: Optional[Occurrence] = await session.get(Occurrence, occurrence_id)
+    if not occurrence or occurrence.user_id != user.id:
+        raise HTTPException(404)
+    return occurrence
diff --git a/trapdata/api/views/occurrences.bak2.py b/trapdata/api/views/occurrences.bak2.py
new file mode 100644
index 00000000..466c9ef2
--- /dev/null
+++ b/trapdata/api/views/occurrences.bak2.py
@@ -0,0 +1,58 @@
+from typing import Any, List, Optional
+
+from app.deps.db import get_async_session
+from app.deps.request_params import parse_react_admin_params
+from app.deps.users import current_user
+from app.models.occurrence import Occurrence
+from app.models.user import User
+from app.schemas.occurrence import Occurrence as OccurrenceSchema
+from app.schemas.occurrence import OccurrenceCreate, OccurrenceUpdate
+from app.schemas.request_params import RequestParams
+from fastapi import APIRouter, Depends, HTTPException
+from sqlalchemy import func, select
+from sqlalchemy.ext.asyncio.session import AsyncSession
+from starlette.responses import Response
+
+router = APIRouter(prefix="/occurrences")
+
+
+@router.get("", response_model=List[OccurrenceSchema])
+async def get_occurrences(
+    response: Response,
+    session: AsyncSession = Depends(get_async_session),
+    request_params: RequestParams = Depends(parse_react_admin_params(Occurrence)),
+) -> Any:
+    total = await session.scalar(
+        select(
+            func.count(Occurrence.id).filter(
+                Occurrence.deployment_id == request_params.deployment_id
+            )
+        )
+    )
+    items = (
+        (
+            await session.execute(
+                select(Occurrence)
+                .offset(request_params.skip)
+                .limit(request_params.limit)
+                .order_by(request_params.order_by)
+                .filter(Occurrence.deployment_id == request_params.deployment_id)
+            )
+        )
+        .scalars()
+        .all()
+    )
+    response.headers[
+        "Content-Range"
+    ] = f"{request_params.skip}-{request_params.skip + len(items)}/{total}"
+    return items
+
+
+@router.get("/{item_id}", response_model=OccurrenceSchema)
+async def get_occurrence(
+    item_id: int,
+    session: AsyncSession = Depends(get_async_session),
+    user: User = Depends(current_user),
+) -> Any:
+    item: Optional[Occurrence] = await session.get(Occurrence, item_id)
+    return item
diff --git a/trapdata/api/views/occurrences.py b/trapdata/api/views/occurrences.py
new file mode 100644
index 00000000..ce3698b8
--- /dev/null
+++ b/trapdata/api/views/occurrences.py
@@ -0,0 +1,67 @@
+from typing import Any, List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException
+from sqlalchemy import func, orm, select
+from starlette.responses import Response
+
+from trapdata.api.config import settings
+from trapdata.api.deps.db import get_session
+from trapdata.api.deps.request_params import parse_react_admin_params
+from trapdata.api.request_params import RequestParams
+from trapdata.db import Base
+from trapdata.db.models.occurrences import OccurrenceListItem, list_occurrences
+
+router = APIRouter(prefix="/occurrences")
+
+
+@router.get("", response_model=List[OccurrenceListItem])
+async def get_occurrences(
+    response: Response,
+    limit: int = 20,
+    offset: int = 0,
+    # request_params: RequestParams = Depends(parse_react_admin_params(Base)),
+) -> Any:
+    occurrences = list_occurrences(
+        settings.database_url,
+        settings.image_base_path,
+        classification_threshold=settings.classification_threshold,
+        media_url_base="/static/",
+        limit=limit,
+        offset=offset,
+    )
+    return occurrences
+
+
+@router.get("/{item_id}", response_model=List[OccurrenceListItem])
+async def get_occurrence(
+    item_id: int,
+    response: Response,
+    # request_params: RequestParams = Depends(parse_react_admin_params(Base)),
+) -> Any:
+    """
+    @TODO placeholder! replace this with an actual get single occurrence method.
+    """
+    occurrences = list_occurrences(
+        settings.database_url,
+        settings.image_base_path,
+        classification_threshold=settings.classification_threshold,
+        media_url_base="/static/",
+        limit=1,
+        offset=item_id,
+    )
+    return occurrences
+
+
+# @router.post("/process", response_model=List[DeploymentListItem])
+# async def process_deployment(
+#     response: Response,
+#     session: orm.Session = Depends(get_session),
+#     # request_params: RequestParams = Depends(parse_react_admin_params(Base)),
+# ) -> Any:
+#     from trapdata.ml.pipeline import start_pipeline
+#
+#     start_pipeline(
+#         session=session, image_base_path=settings.image_base_path, settings=settings
+#     )
+#     deployments = list_deployments(session)
+#     return deployments
diff --git a/trapdata/api/views/settings.py b/trapdata/api/views/settings.py
new file mode 100644
index 00000000..8f02ec86
--- /dev/null
+++ b/trapdata/api/views/settings.py
@@ -0,0 +1,22 @@
+from typing import Any, List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException
+from sqlalchemy import func, orm, select
+from starlette.responses import Response
+
+from trapdata.api.config import settings
+from trapdata.api.deps.db import get_session
+from trapdata.api.deps.request_params import parse_react_admin_params
+from trapdata.api.request_params import RequestParams
+from trapdata.db import Base
+from trapdata.db.models.deployments import DeploymentListItem, list_deployments
+from trapdata.settings import UserSettings
+
+router = APIRouter(prefix="/settings")
+
+
+@router.get("", response_model=UserSettings)
+async def get_settings(
+    response: Response,
+) -> Any:
+    return settings
diff --git a/trapdata/api/views/species.py b/trapdata/api/views/species.py
new file mode 100644
index 00000000..ac163807
--- /dev/null
+++ b/trapdata/api/views/species.py
@@ -0,0 +1,33 @@
+from typing import Any, List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException
+from sqlalchemy import func, orm, select
+from starlette.responses import Response
+
+from trapdata.api.config import settings
+from trapdata.api.deps.db import get_session
+from trapdata.api.deps.request_params import parse_react_admin_params
+from trapdata.api.request_params import RequestParams
+from trapdata.db import Base
+from trapdata.db.models.detections import TaxonListItem, list_species
+
+router = APIRouter(prefix="/species")
+
+
+@router.get("", response_model=List[TaxonListItem])
+async def get_species(
+    response: Response,
+    session: orm.Session = Depends(get_session),
+    # request_params: RequestParams = Depends(parse_react_admin_params(Base)),
+    limit: int = 20,
+    offset: int = 0,
+) -> Any:
+    species = list_species(
+        session=session,
+        image_base_path=settings.image_base_path,
+        classification_threshold=settings.classification_threshold,
+        media_url_base="/static/",
+        limit=limit,
+        offset=offset,
+    )
+    return species
diff --git a/trapdata/api/views/stats.py b/trapdata/api/views/stats.py
new file mode 100644
index 00000000..c3e4d43e
--- /dev/null
+++ b/trapdata/api/views/stats.py
@@ -0,0 +1,20 @@
+from typing import Any
+
+from fastapi import APIRouter
+from pydantic import BaseModel
+
+router = APIRouter(prefix="/stats")
+
+
+class Msg(BaseModel):
+    msg: str
+
+
+@router.get(
+    "/",
+    response_model=Msg,
+    status_code=200,
+    include_in_schema=False,
+)
+def test_hello_world() -> Any:
+    return {"msg": "Hello world!"}
diff --git a/trapdata/api/views/status.py b/trapdata/api/views/status.py
new file mode 100644
index 00000000..d88dab1c
--- /dev/null
+++ b/trapdata/api/views/status.py
@@ -0,0 +1,60 @@
+import datetime
+from typing import Any, List, Optional
+
+from fastapi import APIRouter, Depends
+from pydantic import BaseModel
+from sqlalchemy import orm
+from starlette.responses import Response
+
+from trapdata.api.config import settings
+from trapdata.api.deps.db import get_session
+from trapdata.db.models.deployments import list_deployments
+from trapdata.db.models.detections import num_species_for_deployment
+from trapdata.db.models.events import list_monitoring_sessions
+from trapdata.db.models.queue import QueueListItem, list_queues
+
+router = APIRouter(prefix="/status")
+
+
+@router.get("/queues", response_model=List[QueueListItem])
+async def get_queues(
+    response: Response,
+) -> Any:
+    queues = list_queues(settings.database_url, settings.image_base_path)
+    return queues
+
+
+class SummaryCounts(BaseModel):
+    num_deployments: Optional[int] = 0
+    num_captures: Optional[int] = 0
+    num_sessions: Optional[int] = 0
+    num_detections: Optional[int] = 0
+    num_occurrences: Optional[int] = 0
+    num_species: Optional[int] = 0
+    last_updated: Optional[datetime.datetime] = None
+
+
+@router.get("/summary", response_model=SummaryCounts)
+async def get_summary_counts(
+    response: Response,
+    session: orm.Session = Depends(get_session),
+) -> Any:
+    deployments = list_deployments(session)
+    # events = []
+    # for deployment in deployments:
+    #     events += list_monitoring_sessions(session, deployment.image_base_path)
+    events = list_monitoring_sessions(session, settings.image_base_path)
+
+    summary = SummaryCounts(
+        num_deployments=len(deployments),
+        num_sessions=len(events),
+        num_captures=sum(e.num_captures for e in events),
+        num_detections=sum(e.num_detections for e in events),
+        num_occurrences=sum(e.num_occurrences for e in events),
+        num_species=num_species_for_deployment(
+            session, image_base_path=settings.image_base_path
+        ),
+        last_updated=datetime.datetime.now(),
+    )
+
+    return summary
diff --git a/trapdata/api/views/users.py b/trapdata/api/views/users.py
new file mode 100644
index 00000000..06ca854e
--- /dev/null
+++ b/trapdata/api/views/users.py
@@ -0,0 +1,30 @@
+from typing import Any, List
+
+from fastapi.params import Depends
+from fastapi.routing import APIRouter
+from sqlalchemy import func, select
+from sqlalchemy.ext.asyncio.session import AsyncSession
+from starlette.responses import Response
+
+from app.deps.db import get_async_session
+from app.deps.users import current_superuser
+from app.models.user import User
+from app.schemas.user import UserRead
+
+router = APIRouter()
+
+
+@router.get("/users", response_model=List[UserRead])
+async def get_users(
+    response: Response,
+    session: AsyncSession = Depends(get_async_session),
+    user: User = Depends(current_superuser),
+    skip: int = 0,
+    limit: int = 100,
+) -> Any:
+    total = await session.scalar(select(func.count(User.id)))
+    users = (
+        (await session.execute(select(User).offset(skip).limit(limit))).scalars().all()
+    )
+    response.headers["Content-Range"] = f"{skip}-{skip + len(users)}/{total}"
+    return users
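With the server from main.py running locally, the status routes above can be smoke-tested with the standard library alone; a hedged example (host, port, and API prefix follow the defaults defined elsewhere in this diff):

import json
import urllib.request

with urllib.request.urlopen("http://127.0.0.1:8000/api/v1/status/summary") as resp:
    summary = json.load(resp)

print(summary)  # e.g. {"num_deployments": 1, "num_captures": 420, ...}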
diff --git a/trapdata/cli/base.py b/trapdata/cli/base.py
index b124b86f..771f5139 100644
--- a/trapdata/cli/base.py
+++ b/trapdata/cli/base.py
@@ -30,6 +30,16 @@ def gui():
     run()
 
 
+@cli.command()
+def api():
+    """
+    Launch API server
+    """
+    from trapdata.api.main import run as start_api
+
+    start_api()
+
+
 @cli.command("import")
 def import_data(image_base_path: Optional[pathlib.Path] = None, queue: bool = True):
     """
diff --git a/trapdata/cli/export.py b/trapdata/cli/export.py
index fdc07322..789b2f1b 100644
--- a/trapdata/cli/export.py
+++ b/trapdata/cli/export.py
@@ -92,6 +92,7 @@ def occurrences(
     for event in events:
         occurrences += list_occurrences(
             settings.database_url,
+            settings.image_base_path,
             monitoring_session=event,
             classification_threshold=settings.classification_threshold,
             num_examples=num_examples,
@@ -188,7 +189,7 @@ def sessions(
 
 @cli.command()
 def captures(
-    date: datetime.datetime,
+    date: Optional[datetime.datetime] = None,
     format: ExportFormat = ExportFormat.json,
     outfile: Optional[pathlib.Path] = None,
 ) -> Optional[str]:
@@ -199,16 +200,28 @@ def captures(
     """
     Session = get_session_class(settings.database_url)
     session = Session()
+    if date is not None:
+        event_dates = [date.date()]
+    else:
+        event_dates = [
+            event.day
+            for event in get_monitoring_sessions_from_db(
+                db_path=settings.database_url, base_directory=settings.image_base_path
+            )
+        ]
     events = get_monitoring_session_by_date(
         db_path=settings.database_url,
         base_directory=settings.image_base_path,
-        event_dates=[str(date.date())],
+        event_dates=event_dates,
     )
-    if not len(events):
+    if date and not len(events):
         raise Exception(f"No Monitoring Event with date: {date.date()}")
-    event = events[0]
-    captures = get_monitoring_session_images(settings.database_url, event, limit=100)
+    captures = []
+    for event in events:
+        captures += get_monitoring_session_images(
+            settings.database_url, event, limit=100
+        )
     [session.add(img) for img in captures]
 
     df = pd.DataFrame([img.report_detail().dict() for img in captures])
diff --git a/trapdata/cli/show.py b/trapdata/cli/show.py
index 3d69efea..947a6be1 100644
--- a/trapdata/cli/show.py
+++ b/trapdata/cli/show.py
@@ -15,14 +15,9 @@ from trapdata.db import models
 from trapdata.db.base import get_session_class
 from trapdata.db.models.deployments import list_deployments
-from trapdata.db.models.detections import (
-    get_detected_objects,
-    num_occurrences_for_event,
-    num_species_for_event,
-)
+from trapdata.db.models.detections import get_detected_objects
 from trapdata.db.models.events import (
     get_monitoring_session_by_date,
-    get_monitoring_sessions_from_db,
     update_all_aggregates,
 )
 from trapdata.db.models.occurrences import list_occurrences, list_species
@@ -79,6 +74,8 @@ def deployments():
     update_all_aggregates(session, settings.image_base_path)
     deployments = list_deployments(session)
     table = Table(
+        "ID",
+        "Name",
         "Image Base Path",
         "Sessions",
         "Images",
@@ -134,39 +131,26 @@ def sessions():
     """
     Show all monitoring events that have been interpreted from image timestamps.
     """
+    from trapdata.db.models.events import list_monitoring_sessions
+
     Session = get_session_class(settings.database_url)
     session = Session()
     # image_base_path = str(settings.image_base_path.resolve())
-    update_all_aggregates(session, settings.image_base_path)
-    logger.info(f"Show monitoring events for images in {settings.image_base_path}")
-    events = (
-        session.execute(
-            select(models.MonitoringSession).where(
-                models.MonitoringSession.base_directory == str(settings.image_base_path)
-            )
-        )
-        .unique()
-        .scalars()
-        .all()
-    )
+    events = list_monitoring_sessions(session, settings.image_base_path)
 
-    table = Table("ID", "Day", "Images", "Detections", "Occurrences", "Species")
+    table = Table(
+        "ID", "Day", "Duration", "Captures", "Detections", "Occurrences", "Species"
+    )
     for event in events:
-        event.update_aggregates(session)
-        num_occurrences = num_occurrences_for_event(
-            db_path=settings.database_url, monitoring_session=event
-        )
-        num_species = num_species_for_event(
-            db_path=settings.database_url, monitoring_session=event
-        )
         row_values = [
             event.id,
             event.day,
-            event.num_images,
-            event.num_detected_objects,
-            num_occurrences,
-            num_species,
+            event.duration_label,
+            event.num_captures,
+            event.num_detections,
+            event.num_occurrences,
+            event.num_species,
         ]
         table.add_row(*[str(val) for val in row_values])
     console.print(table)
@@ -216,26 +200,37 @@ def detections(
 
 
 @cli.command()
-def occurrences(limit: Optional[int] = 100, offset: int = 0):
-    events = get_monitoring_sessions_from_db(
-        db_path=settings.database_url, base_directory=settings.image_base_path
-    )
-    occurrences: list[models.occurrences.Occurrence] = []
-    for event in events:
-        occurrences += list_occurrences(
-            settings.database_url,
-            event,
-            classification_threshold=settings.classification_threshold,
-            limit=limit,
-            offset=offset,
+def occurrences(
+    session_day: Optional[datetime.datetime] = None,
+    limit: Optional[int] = 100,
+    offset: int = 0,
+):
+    event = None
+    if session_day:
+        events = get_monitoring_session_by_date(
+            db_path=settings.database_url, event_dates=[session_day]
         )
+        if not events:
+            logger.info(f"No events found for {session_day}")
+            return []
+        else:
+            event = events[0]
+
+    occurrences = list_occurrences(
+        settings.database_url,
+        settings.image_base_path,
+        event,
+        classification_threshold=settings.classification_threshold,
+        limit=limit,
+        offset=offset,
+    )
     table = Table(
         "Session", "ID", "Label", "Detections", "Score", "Appearance", "Duration"
    )
     for occurrence in occurrences:
         table.add_row(
-            occurrence.event,
+            str(occurrence.event.day),
             occurrence.id,
             occurrence.label,
             str(occurrence.num_frames),
diff --git a/trapdata/cli/test.py b/trapdata/cli/test.py
index d63f2841..2e5c96b7 100644
--- a/trapdata/cli/test.py
+++ b/trapdata/cli/test.py
@@ -5,6 +5,7 @@ import typer
 from rich import print
 from sqlalchemy import select
+
 from trapdata.cli import settings
 from trapdata.db.base import check_db, get_session_class
 from trapdata.db.models import MonitoringSession
@@ -54,6 +55,7 @@ def species_by_track(event_day: datetime.datetime):
     print(f"Matched of event: {event}")
     get_unique_species_by_track(
         settings.database_url,
+        image_base_path=event.base_directory,
         monitoring_session=event,
         classification_threshold=0.1,
     )
diff --git a/trapdata/common/filemanagement.py b/trapdata/common/filemanagement.py
index dd81224f..26a6a017 100644
--- a/trapdata/common/filemanagement.py
+++ b/trapdata/common/filemanagement.py
@@ -20,6 +20,7 @@
 from . import constants
 from .logs import logger
+from .schemas import FilePath
 
 APP_NAME_SLUG = "AMI"
 EXIF_DATETIME_STR_FORMAT = "%Y:%m:%d %H:%M:%S"
@@ -72,7 +73,7 @@ def find_timestamped_folders(path):
     def _preprocess(name):
         return name.replace("_", "-")
 
-    dirs = sorted(list(pathlib.Path(path).iterdir()))
+    dirs = sorted(pathlib.Path(path).iterdir())
     for d in dirs:
         # @TODO use yield?
         try:
@@ -286,7 +287,7 @@ def find_images(
         [f.lstrip(".") for f in constants.SUPPORTED_IMAGE_EXTENSIONS]
     )
     pattern = rf"\.({extensions_list})$"
-    for walk_path, dirs, files in os.walk(base_directory):
+    for walk_path, _dirs, files in os.walk(base_directory):
         for name in files:
             if re.search(pattern, name, re.IGNORECASE):
                 relative_path = pathlib.Path(walk_path) / name
@@ -485,7 +486,7 @@ def get_app_dir(app_name: Optional[str] = None) -> pathlib.Path:
     data_dir = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", "~/.config"))
     data_dir = data_dir.expanduser().resolve() / app_name
     if not data_dir.exists():
-        data_dir.mkdir()
+        data_dir.mkdir(parents=True)
     return data_dir
 
@@ -504,3 +505,16 @@ def initial_directory_choice():
     than "." which is the directory of this python package.
     """
     return pathlib.Path("~/")
+
+
+def media_url(
+    local_path: FilePath, delim: str, media_url_base: Optional[str] = None
+) -> str:
+    """
+    Given a local path to a file, return a URL to that file. @TODO rework this and handle slashes better.
+    """
+    relative_path = f"{delim}{local_path.split(delim)[-1]}"
+    if media_url_base:
+        return os.path.join(media_url_base, relative_path)
+    else:
+        return relative_path
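A quick illustration of the new media_url helper above (paths are made up, and this assumes local_path arrives as a plain string):

from trapdata.common.filemanagement import media_url

# Everything after the last occurrence of the delimiter is kept:
media_url("/data/trap1/crops/abc.jpg", "crops")
# -> "crops/abc.jpg"
media_url("/data/trap1/crops/abc.jpg", "crops", media_url_base="/static/")
# -> "/static/crops/abc.jpg"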
diff --git a/trapdata/common/logs.py b/trapdata/common/logs.py
index e0c2f7cb..cf4b4e7e 100644
--- a/trapdata/common/logs.py
+++ b/trapdata/common/logs.py
@@ -3,7 +3,7 @@ import structlog
 
 structlog.configure(
-    wrapper_class=structlog.make_filtering_bound_logger(logging.INFO),
+    wrapper_class=structlog.make_filtering_bound_logger(logging.DEBUG),
 )
diff --git a/trapdata/db/__init__.py b/trapdata/db/__init__.py
index a5493f58..5c390e50 100644
--- a/trapdata/db/__init__.py
+++ b/trapdata/db/__init__.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 import sqlalchemy as sa
 from sqlalchemy import orm
 
@@ -16,4 +18,4 @@
 
 class Base(orm.DeclarativeBase):
-    pass
+    id: Any
diff --git a/trapdata/db/base.py b/trapdata/db/base.py
index 58a3c8a6..f0fa9dc1 100644
--- a/trapdata/db/base.py
+++ b/trapdata/db/base.py
@@ -13,12 +13,14 @@ from trapdata import logger
 from trapdata.common.schemas import DatabaseURL
 
+DATABASE_SCHEMA_NAMESPACE = "trapdata"
+
 DIALECT_CONNECTION_ARGS = {
     "sqlite": {
         "timeout": 10,  # A longer timeout is necessary for SQLite and multiple PyTorch workers
         "check_same_thread": False,
     },
-    "postgresql": {},
+    "postgresql": {"options": f"-csearch_path={DATABASE_SCHEMA_NAMESPACE}"},
 }
 
 SUPPORTED_DIALECTS = list(DIALECT_CONNECTION_ARGS.keys())
@@ -72,6 +74,11 @@ def create_db(db_path: DatabaseURL) -> None:
 
     from . import Base
 
+    if db.dialect.name != "sqlite":
+        with db.connect() as con:
+            if not db.dialect.has_schema(con, DATABASE_SCHEMA_NAMESPACE):
+                con.execute(sqlalchemy.schema.CreateSchema(DATABASE_SCHEMA_NAMESPACE))
+        Base.metadata.schema = DATABASE_SCHEMA_NAMESPACE
     Base.metadata.create_all(db, checkfirst=True)
     alembic_cfg = get_alembic_config(db_path)
     alembic.stamp(alembic_cfg, "head")
@@ -232,3 +239,18 @@ def get_or_create(session, model, defaults=None, **kwargs):
         return instance, False
     else:
         return instance, True
+
+
+from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
+
+
+def get_async_session_class(db_path: str) -> async_sessionmaker[AsyncSession]:
+    async_engine = create_async_engine(db_path, pool_pre_ping=True)
+
+    async_session_maker = async_sessionmaker(
+        async_engine,
+        class_=AsyncSession,
+        expire_on_commit=False,
+        autoflush=False,
+    )
+    return async_session_maker
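A hedged usage sketch for the new get_async_session_class (the connection URL is a placeholder; Postgres needs an async driver such as asyncpg for this to work):

import asyncio

from sqlalchemy import text

from trapdata.db.base import get_async_session_class


async def main():
    Session = get_async_session_class("postgresql+asyncpg://user:pass@localhost/trapdata")
    async with Session() as session:
        result = await session.execute(text("SELECT 1"))
        print(result.scalar_one())


# asyncio.run(main())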
""" stmt = sa.select( - models.MonitoringSession.base_directory.label("name"), - sa.func.count(models.MonitoringSession.id).label("num_events"), + models.MonitoringSession.base_directory.label("image_base_path"), sa.func.sum(models.MonitoringSession.num_images).label("num_source_images"), sa.func.sum(models.MonitoringSession.num_detected_objects).label( "num_detections" ), ).group_by(models.MonitoringSession.base_directory) - deployments = [ - DeploymentListItem(**d._mapping) for d in session.execute(stmt).all() - ] - for deployment in deployments: - deployment.name = deployment_name(deployment.name) + deployments = [] + for deployment in session.execute(stmt).all(): + num_events = ( + session.scalar( + sa.select(sa.func.count(models.MonitoringSession.id)).where( + models.MonitoringSession.base_directory + == str(deployment.image_base_path) + ) + ) + or 0 + ) + deployments.append( + DeploymentListItem( + **deployment._mapping, + num_events=num_events, + name=deployment_name(deployment.image_base_path), + ) + ) return deployments diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index 9ed55330..be44feb2 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -6,9 +6,15 @@ import sqlalchemy as sa from pydantic import BaseModel from sqlalchemy import orm +from sqlalchemy.ext.hybrid import hybrid_property from trapdata import constants, db -from trapdata.common.filemanagement import absolute_path, construct_exif, save_image +from trapdata.common.filemanagement import ( + absolute_path, + construct_exif, + media_url, + save_image, +) from trapdata.common.logs import logger from trapdata.common.schemas import FilePath from trapdata.common.utils import bbox_area, bbox_center, export_report @@ -17,16 +23,18 @@ class DetectionListItem(BaseModel): - id: int - cropped_image_path: Optional[pathlib.Path] - bbox: Optional[tuple[float, float, float, float]] - area_pixels: Optional[float] - last_detected: Optional[datetime.datetime] - label: Optional[str] - score: Optional[int] - model_name: Optional[str] - in_queue: bool - notes: Optional[str] + id: Optional[int] = None + cropped_image_path: Optional[FilePath] = None + bbox: Optional[tuple[float, float, float, float]] = None + area_pixels: Optional[float] = None + width: Optional[int] = None + height: Optional[int] = None + last_detected: Optional[datetime.datetime] = None + label: Optional[str] = None + score: Optional[int] = None + model_name: Optional[str] = None + in_queue: bool = False + notes: Optional[str] = "" class DetectionDetail(DetectionListItem): @@ -153,11 +161,15 @@ def save_cropped_image_data( self.path = str(fpath) return fpath - def width(self): - pass # Use bbox + @hybrid_property + def width(self) -> int: + x1, y1, x2, y2 = self.bbox + return x2 - x1 - def height(self): - pass # Use bbox + @hybrid_property + def height(self) -> int: + x1, y1, x2, y2 = self.bbox + return y2 - y1 def previous_frame_detections( self, session: orm.Session @@ -507,9 +519,7 @@ def get_species_for_image(db_path, image_id): def num_species_for_event( db_path, monitoring_session, classification_threshold: float = 0.6 ) -> int: - query = sa.select( - sa.func.count(DetectedObject.specific_label.distinct()), - ).where( + query = sa.select(sa.func.count(DetectedObject.specific_label.distinct()),).where( (DetectedObject.specific_label_score >= classification_threshold) & (DetectedObject.monitoring_session == monitoring_session) ) @@ -521,9 +531,7 @@ def num_species_for_event( def 
num_occurrences_for_event( db_path, monitoring_session, classification_threshold: float = 0.6 ) -> int: - query = sa.select( - sa.func.count(DetectedObject.sequence_id.distinct()), - ).where( + query = sa.select(sa.func.count(DetectedObject.sequence_id.distinct()),).where( (DetectedObject.specific_label_score >= classification_threshold) & (DetectedObject.monitoring_session == monitoring_session) ) @@ -532,6 +540,146 @@ def num_occurrences_for_event( return sesh.execute(query).scalar_one() +class TaxonOccurrenceListItem(BaseModel): + id: str + cropped_image_path: Optional[FilePath] = None + + +class TaxonListItem(BaseModel): + name: str + genus: Optional[str] = None + family: Optional[str] = None + num_occurrences: Optional[int] = None + num_detections: Optional[int] = None + examples: list[TaxonOccurrenceListItem] = [] + score_stats: Optional[dict[str, float]] = None + training_examples: Optional[int] = None + + +def list_species( + session: orm.Session, + image_base_path: FilePath, + classification_threshold: int = 0, + num_examples: int = 10, + media_url_base: Optional[str] = None, + limit: Optional[int] = None, + offset: int = 0, +) -> list[TaxonListItem]: + """ + Return a list of unique species and example detections. + + @TODO compare this with list_species in occurrences.py + @TODO prefetch related and speed this up + """ + species = session.execute( + sa.select( + DetectedObject.specific_label.label("name"), + sa.func.min(DetectedObject.sequence_id).label("sequence_id"), + sa.func.count(DetectedObject.id).label("num_detections"), + sa.func.count(DetectedObject.sequence_id.distinct()).label( + "num_occurrences" + ), # @TODO handle sequences with None + sa.func.max(DetectedObject.specific_label_score).label("score_max"), + sa.func.min(DetectedObject.specific_label_score).label("score_min"), + sa.func.avg(DetectedObject.specific_label_score).label("score_mean"), + ) + .where( + (models.TrapImage.base_path == str(image_base_path)) + & (models.DetectedObject.specific_label_score >= classification_threshold) + ) + .join(models.TrapImage, models.DetectedObject.image_id == models.TrapImage.id) + .group_by(DetectedObject.specific_label) + .limit(limit) + .offset(offset) + ).all() + + # examples = ( + # session.execute( + # sa.select(DetectedObject) + # .where(DetectedObject.specific_label.in_(sp.name for sp in species])) # @TODO not working! 
+ # .where(models.TrapImage.base_path == str(image_base_path)) + # .where(DetectedObject.specific_label_score >= classification_threshold) + # .join( + # models.TrapImage, models.DetectedObject.image_id == models.TrapImage.id + # ) + # .limit(num_examples) + # .order_by(DetectedObject.specific_label_score.desc()) + # ) + # .unique() + # .scalars() + # .all() + # ) + + metadata_by_name = {} + examples_by_name = {} + for sp in species: + metadata_by_name[sp.name] = sp + # matching_examples = [ex for ex in examples if ex.specific_label == sp.name] + matching_example_ids = session.execute( + sa.select( + DetectedObject.sequence_id, sa.func.min(DetectedObject.id).label("id") + ) + .where(DetectedObject.specific_label == sp.name) + .where(models.TrapImage.base_path == str(image_base_path)) + .where(DetectedObject.specific_label_score >= classification_threshold) + .group_by(DetectedObject.sequence_id) + .join( + models.TrapImage, + models.DetectedObject.image_id == models.TrapImage.id, + ) + ).all() + matching_example_ids = [row.id for row in matching_example_ids] + # .group_by(DetectedObject.sequence_id, DetectedObject.path, DetectedObject.specific_label_score, models.TrapImage.base_path) + matching_examples = session.execute( + sa.select(DetectedObject.sequence_id, DetectedObject.path) + .where(DetectedObject.id.in_(matching_example_ids)) + .limit(num_examples) + ).all() + examples_by_name[sp.name] = matching_examples + print(sp.name, len(matching_examples)) + + taxa = [ + TaxonListItem( + name=name, + num_occurrences=metadata_by_name[name].num_occurrences, + num_detections=metadata_by_name[name].num_detections, + score_stats={ + "max": metadata_by_name[name].score_max, + "min": metadata_by_name[name].score_min, + "mean": metadata_by_name[name].score_mean, + }, + examples=[ + TaxonOccurrenceListItem( + id=detection.sequence_id, + cropped_image_path=media_url( + detection.path, + "crops", + media_url_base=media_url_base, + ), + ) + for detection in examples + ], + ) + for name, examples in examples_by_name.items() + ] + return taxa + + +def num_species_for_deployment(session: orm.Session, image_base_path: FilePath) -> int: + return ( + session.execute( + sa.select(sa.func.count(models.DetectedObject.specific_label.distinct())) + .join( + models.MonitoringSession, + models.MonitoringSession.id + == models.DetectedObject.monitoring_session_id, + ) + .where(models.MonitoringSession.base_directory == str(image_base_path)) + ).scalar() + or 0 + ) + + def get_unique_species( db_path, monitoring_session=None, classification_threshold: float = -1 ): diff --git a/trapdata/db/models/events.py b/trapdata/db/models/events.py index e146cc53..06845ff5 100644 --- a/trapdata/db/models/events.py +++ b/trapdata/db/models/events.py @@ -5,20 +5,51 @@ import sqlalchemy as sa from pydantic import BaseModel from sqlalchemy import orm +from sqlalchemy.ext.hybrid import hybrid_method from sqlalchemy_utils import aggregated -from trapdata.common.filemanagement import find_images, group_images_by_day +from trapdata.common.filemanagement import find_images, group_images_by_day, media_url from trapdata.common.logs import logger from trapdata.common.schemas import FilePath from trapdata.common.utils import export_report from trapdata.db import Base, get_session, models - +from trapdata.db.models.deployments import deployment_name # @TODO Rename to TrapEvent? CapturePeriod? less confusing with other types of Sessions. CaptureSession? Or SurveyEvent or Survey? 
-class Event(BaseModel): + + +class MonitoringSessionNestedCapture(BaseModel): id: str - frames: list[dict] - example_frames: list[dict] + path: pathlib.Path + timestamp: datetime.datetime + # example["source_image_path"] = media_url( + # example["source_image_path"], "captures/", media_url_base + # ) + + +class MonitoringSessionListItem(BaseModel): + id: str + day: datetime.date + # image_base_path: str + deployment: str + num_captures: int + num_detections: int + num_occurrences: int + num_species: int + example_captures: list[MonitoringSessionNestedCapture] + start_time: datetime.datetime + end_time: datetime.datetime + duration: datetime.timedelta + duration_label: str + + +class MonitoringSessionDetail(MonitoringSessionListItem): + notes: Optional[str] + captures: list[ + MonitoringSessionNestedCapture + ] # Too many! @TODO include summary data to generate the timeline instead + # @TODO add more info about the session, like the number of images, the number of detected objects, etc + # @TODO add the number of species detected in this session class MonitoringSession(Base): @@ -41,6 +72,26 @@ def num_images(self): def num_detected_objects(self): return sa.func.count("1") + def num_occurrences(self, session: orm.Session) -> int: + return ( + session.execute( + sa.select( + sa.func.count(models.DetectedObject.sequence_id.distinct()) + ).where(models.DetectedObject.monitoring_session_id == self.id) + ).scalar() + or 0 + ) + + def num_species(self, session: orm.Session) -> int: + return ( + session.execute( + sa.select( + sa.func.count(models.DetectedObject.specific_label.distinct()) + ).where(models.DetectedObject.monitoring_session_id == self.id) + ).scalar() + or 0 + ) + # This runs an expensive/slow query every time an image is updated # @observes("images") # def image_observer(self, images): @@ -101,11 +152,12 @@ def update_aggregates(self, session: orm.Session, commit=True): if commit: session.commit() - def duration(self) -> Optional[datetime.timedelta]: + @hybrid_method + def duration(self) -> datetime.timedelta: if self.start_time and self.end_time: return self.end_time - self.start_time else: - return None + return datetime.timedelta(0) @property def duration_label(self): @@ -344,3 +396,139 @@ def export_monitoring_sessions( ): records = [item.report_data() for item in items] return export_report(records, report_name, directory) + + +def event_response( + session: orm.Session, + event: MonitoringSession, +) -> MonitoringSessionListItem: + """ + Reusable method to create a MonitoringSession Schema from a MonitoringSession model. 
@@ -344,3 +396,139 @@ def export_monitoring_sessions(
 ):
     records = [item.report_data() for item in items]
     return export_report(records, report_name, directory)
+
+
+def event_response(
+    session: orm.Session,
+    event: MonitoringSession,
+) -> MonitoringSessionListItem:
+    """
+    Reusable helper to build a MonitoringSessionListItem schema from a MonitoringSession model.
+
+    @TODO decide whether this is worth reusing in list_monitoring_sessions and get_monitoring_session_by_id
+    """
+
+    event.update_aggregates(session)
+    event_response = MonitoringSessionListItem(
+        id=event.id,
+        day=event.day,
+        deployment=deployment_name(str(event.base_directory)),
+        start_time=event.start_time,
+        end_time=event.end_time,
+        num_captures=event.num_images,
+        num_detections=event.num_detected_objects,
+        num_occurrences=event.num_occurrences(session),
+        num_species=event.num_species(session),
+        example_captures=[],
+        duration=event.duration(),
+        duration_label=event.duration_label,
+    )
+
+    return event_response
+
+
+def get_monitoring_session_by_id(
+    session: orm.Session,
+    event_id: int,
+    media_url_base: str,
+) -> Optional[MonitoringSessionDetail]:
+    event: Optional[MonitoringSession] = session.get(MonitoringSession, event_id)
+    if event:
+        event.update_aggregates(session)
+        captures = session.execute(
+            sa.select(
+                models.TrapImage.id, models.TrapImage.path, models.TrapImage.timestamp
+            )
+            .where(models.TrapImage.monitoring_session_id == event.id)
+            .order_by(models.TrapImage.timestamp)
+        ).all()
+        nested_captures = [
+            MonitoringSessionNestedCapture(
+                id=row.id,
+                path=media_url(row.path, "captures/", media_url_base),
+                timestamp=row.timestamp,
+            )
+            for row in captures
+        ]
+        event_detail = MonitoringSessionDetail(
+            id=event.id,
+            day=event.day,
+            deployment=deployment_name(str(event.base_directory)),
+            start_time=event.start_time,
+            end_time=event.end_time,
+            num_captures=event.num_images,
+            num_detections=event.num_detected_objects,
+            num_occurrences=event.num_occurrences(session),
+            num_species=event.num_species(session),
+            example_captures=[],
+            duration=event.duration(),
+            duration_label=event.duration_label,
+            notes=event.notes,
+            captures=nested_captures,
+        )
+        return event_detail
+    else:
+        return None
+
+
+def list_monitoring_sessions(
+    session: orm.Session,
+    image_base_path: FilePath,
+    limit: Optional[int] = None,
+    offset: int = 0,
+    num_examples: int = 5,
+    media_url_base: Optional[str] = None,
+) -> list[MonitoringSessionListItem]:
+    """
+    List monitoring sessions for the deployment at image_base_path, with a few example captures per session.
+    """
+
+    update_all_aggregates(session, image_base_path)
+    logger.info(f"Fetching monitoring events for images in {image_base_path}")
+    events = (
+        session.execute(
+            sa.select(models.MonitoringSession)
+            .where(models.MonitoringSession.base_directory == str(image_base_path))
+            .order_by(models.MonitoringSession.day)
+            .limit(limit)
+            .offset(offset)
+        )
+        .unique()
+        .scalars()
+        .all()
+    )
+
+    list_items = []
+    for event in events:
+        event.update_aggregates(session)
+        rows = session.execute(
+            sa.select(
+                models.TrapImage.id, models.TrapImage.path, models.TrapImage.timestamp
+            )
+            .where(models.TrapImage.monitoring_session_id == event.id)
+            .order_by(models.TrapImage.filesize.desc())
+            .limit(num_examples)
+        ).all()
+        example_captures = [
+            MonitoringSessionNestedCapture(
+                id=row.id,
+                path=media_url(row.path, "captures/", media_url_base),
+                timestamp=row.timestamp,
+            )
+            for row in rows
+        ]
+        list_items.append(
+            MonitoringSessionListItem(
+                id=event.id,
+                day=event.day,
+                deployment=deployment_name(str(event.base_directory)),
+                start_time=event.start_time,
+                end_time=event.end_time,
+                num_captures=event.num_images,
+                num_detections=event.num_detected_objects,
+                num_occurrences=event.num_occurrences(session),
+                num_species=event.num_species(session),
+                example_captures=example_captures,
+                duration=event.duration(),
+                duration_label=event.duration_label,
+            )
+        )
+    return list_items
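A hypothetical call site for the new `list_monitoring_sessions` helper (the database URL, deployment path, and media host are assumptions for illustration):

```python
from trapdata.db import get_session

with get_session("sqlite:///trapdata.db") as session:
    items = list_monitoring_sessions(
        session,
        image_base_path="/data/vermont/snapshots",  # illustrative deployment path
        limit=20,
        offset=0,
        media_url_base="http://localhost:8000/",  # illustrative media host
    )
    for item in items:
        print(item.day, item.num_captures, item.duration_label)
```

Note that the helper calls `update_all_aggregates` up front and `update_aggregates` once per event, so listing is write-heavy; the `limit`/`offset` pagination keeps that bounded.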
diff --git a/trapdata/db/models/images.py b/trapdata/db/models/images.py
index 87eb1ac9..457f0bca 100644
--- a/trapdata/db/models/images.py
+++ b/trapdata/db/models/images.py
@@ -19,21 +19,26 @@ class CaptureListItem(BaseModel):
     id: int
     timestamp: datetime.datetime
-    source_image: str
-    last_read: Optional[datetime.datetime]
-    last_processed: Optional[datetime.datetime]
+    path: pathlib.Path
     num_detections: Optional[int]
     in_queue: bool
+    url: Optional[str] = None
 
 
 class CaptureDetail(CaptureListItem):
-    id: int
     event: object
+    url: Optional[str] = None
+    event: object
+    deployment: str
     notes: Optional[str]
     detections: list
     filesize: int
     width: int
     height: int
+    last_read: Optional[datetime.datetime]
+    last_processed: Optional[datetime.datetime]
+    next_capture: Optional[CaptureListItem]
+    prev_capture: Optional[CaptureListItem]
 
 
 class TrapImage(Base):
@@ -121,17 +126,19 @@ def report_data(self) -> CaptureListItem:
         return CaptureListItem(
             id=self.id,
             source_image=f"{constants.IMAGE_BASE_URL}vermont/snapshots/{self.path}",
+            path=self.path,
             timestamp=self.timestamp,
             last_read=self.last_read,
             last_processed=self.last_processed,
             in_queue=self.in_queue,
             num_detections=self.num_detected_objects,
+            event=self.monitoring_session.day,
+            deployment=self.monitoring_session.deployment,
         )
 
     def report_detail(self) -> CaptureDetail:
         return CaptureDetail(
             **self.report_data().dict(),
-            event=self.monitoring_session.day,
             width=self.width,
             height=self.height,
             filesize=self.filesize,
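One caveat worth noting: `report_data()` passes `event=` and `deployment=` to `CaptureListItem`, which declares neither field, so under pydantic v1's default `Extra.ignore` behavior those kwargs appear to be dropped silently and will not survive the `.dict()` round trip into `report_detail()`. A minimal sketch of the intended flow, assuming `image` is a `TrapImage` loaded in an active session:

```python
item = image.report_data()      # CaptureListItem: id, path, timestamp, counts, queue state
payload = item.dict()           # pydantic v1 serialization, as used elsewhere in this codebase
detail = image.report_detail()  # CaptureDetail layers on filesize, dimensions, etc.
```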
diff --git a/trapdata/db/models/occurrences.py b/trapdata/db/models/occurrences.py
index 2550ccfe..fc8bb072 100644
--- a/trapdata/db/models/occurrences.py
+++ b/trapdata/db/models/occurrences.py
@@ -14,10 +14,23 @@ from pydantic import BaseModel
 
 from trapdata import db
+from trapdata.common.filemanagement import media_url
+from trapdata.common.schemas import FilePath
 from trapdata.db import models
 
 
-class Occurrence(BaseModel):
+class OccurrenceNestedEvent(BaseModel):
+    id: int
+    day: datetime.date
+    url: Optional[str] = None
+
+
+class OccurrenceNestedDetection(BaseModel):
+    id: int
+    cropped_image_path: str
+
+
+class OccurrenceListItem(BaseModel):
     id: str
     label: str
     best_score: float
@@ -25,14 +38,15 @@
     end_time: datetime.datetime
     duration: datetime.timedelta
     deployment: str
-    event: str
+    event: OccurrenceNestedEvent
     num_frames: int
     # cropped_image_path: pathlib.Path
     # source_image_id: int
     examples: list[dict]
     # detections: list[object]
     # deployment: object
-    # captures: list[object]
+    # captures: list[object]
+    url: Optional[str] = None
 
 
 class SpeciesSummaryListItem(BaseModel):
@@ -43,16 +57,19 @@
 def list_occurrences(
     db_path: str,
-    monitoring_session: models.MonitoringSession,
+    image_base_path: FilePath,
+    monitoring_session: Optional[models.MonitoringSession] = None,
     classification_threshold: float = -1,
-    num_examples: int = 3,
+    num_examples: int = 10,
     limit: Optional[int] = None,
     offset: int = 0,
-) -> list[Occurrence]:
+    media_url_base: Optional[str] = None,
+) -> list[OccurrenceListItem]:
     occurrences = []
     for item in get_unique_species_by_track(
         db_path,
-        monitoring_session,
+        monitoring_session=monitoring_session,
+        image_base_path=image_base_path,
         classification_threshold=classification_threshold,
         num_examples=num_examples,
         limit=limit,
@@ -60,14 +77,53 @@
     ):
         prepped = {k.split("sequence_", 1)[-1]: v for k, v in item.items()}
         if prepped["id"]:
-            prepped["id"] = sequence_display_name(prepped["id"])
-            prepped["event"] = monitoring_session.day.isoformat()
-            prepped["deployment"] = monitoring_session.deployment
-            occur = Occurrence(**prepped)
+            prepped["deployment"] = models.deployments.deployment_name(
+                item["monitoring_session_base_directory"]
+            )
+            if media_url_base:
+                examples = [dict(example) for example in prepped["examples"]]
+                # @TODO use OccurrenceNestedDetection
+                for example in examples:
+                    example["cropped_image_path"] = media_url(
+                        example["cropped_image_path"], "crops", media_url_base
+                    )
+                    example["source_image_path"] = media_url(
+                        example["source_image_path"], "captures/", media_url_base
+                    )
+                prepped["examples"] = examples
+
+            prepped["event"] = OccurrenceNestedEvent(
+                id=item["monitoring_session_id"], day=item["monitoring_session_day"]
+            )
+            occur = OccurrenceListItem(**prepped)
             occurrences.append(occur)
     return occurrences
+
+
+def get_valid_sequence_ids(
+    monitoring_session: Optional[models.MonitoringSession] = None,
+    confidence_threshold: float = 0,
+) -> sa.ScalarSelect:
+    """
+    Sequence IDs that have a detection with a score at or above the confidence threshold.
+
+    Intended to be used as a subquery in a larger query.
+    """
+    stmt = sa.select(
+        models.DetectedObject.sequence_id.distinct().label("id"),
+    ).where(models.DetectedObject.specific_label_score >= confidence_threshold)
+    if monitoring_session:
+        stmt = stmt.where(
+            models.DetectedObject.monitoring_session_id == monitoring_session.id
+        )
+    stmt = (
+        stmt.group_by(models.DetectedObject.sequence_id)
+        .order_by(models.DetectedObject.sequence_id)
+        .scalar_subquery()
+    )
+    return stmt
+
+
 def list_species(
     db_path: str,
     image_base_path: pathlib.Path,
@@ -107,7 +163,8 @@
 def get_unique_species_by_track(
     db_path: str,
-    monitoring_session=None,
+    image_base_path: FilePath,
+    monitoring_session: Optional[models.MonitoringSession] = None,
     classification_threshold: float = -1,
     num_examples: int = 3,
     limit: Optional[int] = None,
@@ -117,9 +174,14 @@
     session = Session()
 
     # Select all sequences where at least one example is above the score threshold
-    sequences = session.execute(
+    stmt = (
         sa.select(
             models.DetectedObject.sequence_id,
+            models.DetectedObject.monitoring_session_id,
+            models.MonitoringSession.day.label("monitoring_session_day"),
+            models.MonitoringSession.base_directory.label(
+                "monitoring_session_base_directory"
+            ),
             sa.func.count(models.DetectedObject.id).label(
                 "sequence_frame_count"
             ),  # frames in track
@@ -129,8 +191,22 @@
             sa.func.min(models.DetectedObject.timestamp).label("sequence_start_time"),
             sa.func.max(models.DetectedObject.timestamp).label("sequence_end_time"),
         )
-        .group_by("sequence_id")
-        .where((models.DetectedObject.monitoring_session_id == monitoring_session.id))
+        .join(
+            models.MonitoringSession,
+            models.DetectedObject.monitoring_session_id == models.MonitoringSession.id,
+        )
+        .where(models.MonitoringSession.base_directory == str(image_base_path))
+    )
+    if monitoring_session:
+        stmt = stmt.where(models.MonitoringSession.id == monitoring_session.id)
+
+    stmt = (
+        stmt.group_by(
+            "sequence_id",
+            "monitoring_session_id",
+            "monitoring_session_day",
+            "monitoring_session_base_directory",
+        )
         .having(
             sa.func.max(models.DetectedObject.specific_label_score)
             >= classification_threshold,
@@ -138,32 +214,40 @@
         .order_by("sequence_id")
         .limit(limit)
         .offset(offset)
-    ).all()
+    )
+    sequences = session.execute(stmt).all()
 
     rows = []
     for sequence in sequences:
-        frames = session.execute(
+        stmt = (
             sa.select(
                 models.DetectedObject.id,
                 models.DetectedObject.image_id.label("source_image_id"),
                 models.TrapImage.path.label("source_image_path"),
+                models.TrapImage.width.label("source_image_width"),
+                models.TrapImage.height.label("source_image_height"),
+                models.TrapImage.filesize.label("source_image_filesize"),
                 models.DetectedObject.specific_label.label("label"),
                 models.DetectedObject.specific_label_score.label("score"),
                 models.DetectedObject.path.label("cropped_image_path"),
                 models.DetectedObject.sequence_id,
                 models.DetectedObject.timestamp,
-            )
-            .where(
-                (models.DetectedObject.monitoring_session_id == monitoring_session.id)
-                & (models.DetectedObject.sequence_id == sequence.sequence_id)
+                models.DetectedObject.bbox,
             )
             .join(
                 models.TrapImage, models.TrapImage.id == models.DetectedObject.image_id
             )
-            # .order_by(sa.func.random())
-            .order_by(sa.desc("score"))
-            .limit(num_examples)
-        ).all()
+            .where(models.DetectedObject.sequence_id == sequence.sequence_id)
+        )
+
+        if monitoring_session:
+            stmt = stmt.where(
+                models.DetectedObject.monitoring_session_id == monitoring_session.id
+            )
+        stmt = stmt.order_by(sa.desc("score")).limit(num_examples)
+
+        frames = session.execute(stmt).all()
+
+        row = dict(sequence._mapping)
         if frames:
             best_example = frames[0]
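`get_valid_sequence_ids` returns a scalar subquery, so it composes directly with `IN` clauses as its docstring suggests. A minimal sketch (variable names are illustrative):

```python
import sqlalchemy as sa

from trapdata.db import models

# Keep only detections belonging to sequences that pass the confidence threshold
valid_ids = get_valid_sequence_ids(confidence_threshold=0.6)
stmt = sa.select(models.DetectedObject).where(
    models.DetectedObject.sequence_id.in_(valid_ids)
)
```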
diff --git a/trapdata/db/models/queue.py b/trapdata/db/models/queue.py
index b80810fb..6b0300a4 100644
--- a/trapdata/db/models/queue.py
+++ b/trapdata/db/models/queue.py
@@ -1,10 +1,12 @@
+import pathlib
 from collections import OrderedDict
 from typing import Sequence, Union
 
 import sqlalchemy as sa
+from pydantic import BaseModel
 
 from trapdata import constants, logger
-from trapdata.common.schemas import FilePath
+from trapdata.common.schemas import DatabaseURL, FilePath
 from trapdata.db import get_session
 from trapdata.db.models.detections import DetectedObject
 from trapdata.db.models.events import MonitoringSession
@@ -655,6 +657,28 @@ def all_queues(db_path, base_directory) -> OrderedDict[str, QueueManager]:
     )
 
 
+class QueueListItem(BaseModel):
+    name: str
+    unprocessed_count: int
+    queue_count: int
+    done_count: int
+
+
+def list_queues(
+    db_path: DatabaseURL, image_base_path: pathlib.Path
+) -> Sequence[QueueListItem]:
+    queues = all_queues(db_path, image_base_path)
+    return [
+        QueueListItem(
+            name=q.name,
+            unprocessed_count=q.unprocessed_count(),
+            queue_count=q.queue_count(),
+            done_count=q.done_count(),
+        )
+        for q in queues.values()
+    ]
+
+
 def add_image_to_queue(db_path, image_id):
     with get_session(db_path) as sesh:
         logger.info(f"Adding image id {image_id} to queue")
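A sketch of how the new `list_queues` helper might back a status view (the database URL and base path are illustrative):

```python
import pathlib

for queue in list_queues(
    "sqlite:///trapdata.db", pathlib.Path("/data/vermont/snapshots")
):
    print(queue.name, queue.unprocessed_count, queue.queue_count, queue.done_count)
```

Each `QueueListItem` is a point-in-time snapshot, since the counts are evaluated eagerly inside the list comprehension.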
diff --git a/trapdata/settings.py b/trapdata/settings.py
index 0785d501..a7fe9d31 100644
--- a/trapdata/settings.py
+++ b/trapdata/settings.py
@@ -13,10 +13,7 @@
 from trapdata.common.schemas import FilePath
 
 
-class Settings(BaseSettings):
-    # Can't use PyDantic DSN validator for database_url if sqlite filepath has spaces, see custom validator below
-    database_url: Union[str, sqlalchemy.engine.URL] = default_database_dsn()
-    user_data_path: pathlib.Path = get_app_dir()
+class UserSettings(BaseSettings):
     image_base_path: Optional[pathlib.Path]
     localization_model: ml.models.ObjectDetectorChoice = Field(
         default=ml.models.DEFAULT_OBJECT_DETECTOR
@@ -31,6 +28,15 @@
         default=ml.models.DEFAULT_FEATURE_EXTRACTOR
     )
     classification_threshold: float = 0.6
+
+    class Config:
+        extra = "ignore"
+
+
+class Settings(UserSettings):
+    # Can't use Pydantic DSN validator for database_url if sqlite filepath has spaces, see custom validator below
+    database_url: Union[str, sqlalchemy.engine.URL] = default_database_dsn()
+    user_data_path: pathlib.Path = get_app_dir()
     localization_batch_size: int = 2
     classification_batch_size: int = 20
     num_workers: int = 1
@@ -199,7 +205,7 @@ def kivy_settings_source(settings: BaseSettings) -> dict[str, str]:
 
 
 @lru_cache
-def read_settings(*args, **kwargs):
+def read_settings(*args, **kwargs) -> Settings:
     try:
         return Settings(*args, **kwargs)
     except ValidationError as e:
diff --git a/trapdata/ui/summary.py b/trapdata/ui/summary.py
index bfb7f6ec..1782f3ef 100644
--- a/trapdata/ui/summary.py
+++ b/trapdata/ui/summary.py
@@ -180,7 +180,8 @@ def load_species(self, ms):
         # )
         classification_summary = get_unique_species_by_track(
             app.db_path,
-            ms,
+            image_base_path=ms.base_directory,
+            monitoring_session=ms,
             classification_threshold=classification_threshold,
             num_examples=NUM_EXAMPLES_PER_ROW,
         )
diff --git a/trapdata/webui/public/index.html b/trapdata/webui/public/index.html
new file mode 100644
index 00000000..f944b384
--- /dev/null
+++ b/trapdata/webui/public/index.html
@@ -0,0 +1 @@
+:)
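The `UserSettings`/`Settings` split lets user-facing options be loaded from the same sources as the full server configuration, with `extra = "ignore"` dropping any keys `UserSettings` does not declare (such as `database_url`). A minimal sketch (values illustrative):

```python
# UserSettings silently ignores server-only keys; Settings sees everything.
user_settings = UserSettings(image_base_path="/data/vermont/snapshots")
settings = read_settings()  # cached; returns the full Settings instance
```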