forked from ClinicianFOCUS/local-llm-container
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathapi_wrapper.py
More file actions
142 lines (119 loc) · 4.49 KB
/
api_wrapper.py
File metadata and controls
142 lines (119 loc) · 4.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
FastAPI proxy server for Ollama API with rate limiting and authentication.
This module implements a proxy server that forwards requests to an Ollama API instance,
adding authentication and rate limiting capabilities.
"""
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.responses import JSONResponse, PlainTextResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import requests
from APIKeyManager import APIKeyManager
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from slowapi import Limiter
# FastAPI application that every route in this module is registered on.
app = FastAPI()

# Holds the session API key: ``verify_api_key`` is used as a route dependency
# on the proxy endpoint, and ``api_key`` is shown once in the banner below.
API_KEY_MANAGER = APIKeyManager()

# Startup banner with the session API key and handling advice. The lines are
# joined and printed in one call; stdout is identical to one print per line.
print("\n".join([
    "",
    "=" * 50,
    " " * 5 + "⚠️ IMPORTANT: API Key Information ⚠️" + " " * 5,
    "=" * 50,
    "",
    " " * 3 + f" Session API Key: {API_KEY_MANAGER.api_key} ",
    "",
    "=" * 50,
    "",
    "NOTE:",
    "- Do not share your API key publicly.",
    "- Avoid committing API keys in code repositories.",
    "- If exposed, reset and replace it immediately.",
    "",
    "=" * 50,
    "",
]))

#: str: Base URL of the Ollama service that proxied requests are sent to.
OLLAMA_URL = "http://ollama:11434"

#: Limiter: slowapi limiter that buckets clients by remote address.
#: NOTE(review): slowapi applies ``default_limits`` through its middleware
#: (``SlowAPIMiddleware``), which this module does not appear to install; the
#: per-route ``@limiter.limit(...)`` decorators are what actually throttle
#: requests here — confirm.
limiter = Limiter(key_func=get_remote_address, default_limits=["1/second"])
@app.middleware("http")
async def rate_limit_middleware(request, call_next):
    """
    Turn a ``RateLimitExceeded`` escaping the downstream stack into a 429.

    Every HTTP request passes through this middleware. The request is handed
    to the rest of the stack unchanged. If ``RateLimitExceeded`` propagates out
    of ``call_next``, it is swallowed and replaced by a plain-text 429 response.

    NOTE(review): slowapi's ``RateLimitExceeded`` subclasses Starlette's
    ``HTTPException``. FastAPI's exception handling presumably converts it to a
    429 response before it ever reaches this middleware, so this branch may
    never run — confirm before relying on the plain-text message.

    Args:
        request (Request): The incoming request.
        call_next (Callable): Invokes the next middleware / route handler.

    Returns:
        Response: The downstream response, or a ``PlainTextResponse`` with
        status 429 when ``RateLimitExceeded`` propagated out of ``call_next``.
    """
    try:
        downstream = await call_next(request)
    except RateLimitExceeded:
        downstream = PlainTextResponse(
            "Rate limit exceeded. Try again later.",
            status_code=429,
        )
    return downstream
@app.get("/health")
@limiter.limit("1/second")
async def health_check(request: Request):
    """
    Liveness probe for the proxy process itself.

    Always answers ``{"status": "ok"}``. It does not contact Ollama, so a
    success only means this FastAPI app is serving requests. The route is
    limited to 1 request per second per client via ``limiter``.

    Args:
        request (Request): Unused by the body. Presumably present because
            ``limiter.limit`` needs access to the request — confirm.

    Returns:
        JSONResponse: ``{"status": "ok"}`` with the default 200 status.
    """
    payload = {"status": "ok"}
    return JSONResponse(content=payload)
#: Request headers that are NOT copied onto the upstream Ollama request:
#:
#: - hop-by-hop headers (RFC 7230 §6.1). They describe the client<->proxy
#:   connection, not the proxy<->Ollama one. Forwarding ``transfer-encoding``
#:   in particular conflicts with the ``Content-Length`` that ``requests``
#:   sets for the re-sent body.
#: - ``host`` and ``content-length``. ``requests`` derives them from the
#:   upstream URL and body.
#: - ``accept-encoding``. Dropping it lets ``requests`` advertise only the
#:   encodings it can decode; the body is decoded and re-encoded below anyway.
#: - ``authorization``. It presumably carries this proxy's API key and should
#:   not be passed on to Ollama.
_DROPPED_REQUEST_HEADERS = frozenset({
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
    "host",
    "content-length",
    "accept-encoding",
    "authorization",
})

#: (connect, read) timeout in seconds for upstream calls. A dead upstream
#: fails fast on connect. Reading stays unbounded, as before, so long-running
#: model responses are not cut off.
_UPSTREAM_TIMEOUT = (10, None)


def _upstream_headers(headers):
    """
    Return the subset of the client's headers that should be sent to Ollama.

    Args:
        headers: The incoming request's headers (any mapping with ``items()``).

    Returns:
        dict: Header name -> value, excluding ``_DROPPED_REQUEST_HEADERS``
        (compared case-insensitively). A repeated header name keeps its last
        value.
    """
    return {
        name: value
        for name, value in headers.items()
        if name.lower() not in _DROPPED_REQUEST_HEADERS
    }


@app.api_route("/{path:path}", methods=["GET", "POST"], dependencies=[Depends(API_KEY_MANAGER.verify_api_key)])
@limiter.limit("10/second")
async def proxy_request(path: str, request: Request):
    """
    Proxy endpoint that forwards requests to the Ollama API.

    Each GET or POST to any path is re-issued against the same path under
    ``OLLAMA_URL``. It carries:

    - the client's query parameters, including repeated keys;
    - the client's headers minus ``_DROPPED_REQUEST_HEADERS``;
    - for POST, the raw request body, byte-for-byte. Non-JSON bodies are
      passed through too.

    The blocking ``requests`` call runs in a worker thread. A slow upstream
    response therefore does not stall the event loop, and with it every
    other request, including ``/health``.

    Args:
        path (str): The path component of the URL to forward to Ollama.
        request (Request): The incoming FastAPI request object.

    Returns:
        JSONResponse: Ollama's status code, together with either its decoded
        JSON body or, if the body is not valid JSON, its raw text as a JSON
        string. If the upstream call itself fails (connection error, connect
        timeout, client disconnect, ...), a 500 response with
        ``{"error": "<message>"}`` is returned instead.

    Examples:
        >>> # GET request
        >>> response = client.get("/api/tags")
        >>> # POST request
        >>> response = client.post("/api/generate", json={"model": "...", "prompt": "Hello"})
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        body = await request.body() if request.method == "POST" else None
        upstream = await run_in_threadpool(
            requests.request,
            request.method,
            f"{OLLAMA_URL}/{path}",
            headers=_upstream_headers(request.headers),
            # multi_items() keeps repeated keys such as ?a=1&a=2. Passing the
            # QueryParams mapping directly would send one value per key.
            params=request.query_params.multi_items(),
            data=body,
            timeout=_UPSTREAM_TIMEOUT,
        )
    except Exception as e:
        # Boundary handler: report any failure to reach Ollama to the client.
        return JSONResponse(
            content={"error": str(e)},
            status_code=500,
        )

    try:
        content = upstream.json()
    except ValueError:
        # requests' JSONDecodeError subclasses ValueError on every requests
        # version. Fall back to relaying the raw text.
        content = upstream.text
    return JSONResponse(
        content=content,
        status_code=upstream.status_code,
    )