-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathgemini_router.py
More file actions
140 lines (126 loc) · 4.84 KB
/
gemini_router.py
File metadata and controls
140 lines (126 loc) · 4.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
Gemini API router: sends requests to the local Gemini CLI when the API key consists only of the digit 9, and to the real Gemini API otherwise.
"""
import os
from typing import Any, Dict, List, Optional, Union
import google.generativeai as genai
from google.generativeai.types import GenerationConfig
from gemini_client import GeminiClient
# Module-wide API key recorded by configure() (the explicit argument, or the
# GOOGLE_API_KEY environment variable as a fallback). GenerativeModel reads it
# when constructed to decide between the local Gemini CLI (all-9s key) and the
# real API.
_configured_api_key: Optional[str] = None
def _is_all_nines(api_key: Optional[str]) -> bool:
"""Check if the API key is all 9s."""
if not api_key:
return False
# Remove prefixes or query params if present (though typically just the key string)
return all(c == '9' for c in api_key)
def configure(api_key: Optional[str] = None, **kwargs):
    """
    Record the API key used for routing and, for real keys, configure genai.

    The key comes from ``api_key`` when it is truthy, otherwise from the
    ``GOOGLE_API_KEY`` environment variable. It is stored module-wide so that
    ``GenerativeModel`` instances created afterwards can choose between the
    local Gemini CLI (all-9s key) and the real API.

    Args:
        api_key: Explicit API key; when falsy, ``GOOGLE_API_KEY`` is used.
        **kwargs: Forwarded to ``google.generativeai.configure``; ignored
            when the key is all 9s.
    """
    global _configured_api_key
    resolved_key = api_key if api_key else os.environ.get("GOOGLE_API_KEY")
    _configured_api_key = resolved_key

    # The all-9s sentinel is never a valid real key and the real library may
    # reject it, so genai is left unconfigured in that case.
    if _is_all_nines(resolved_key):
        return
    genai.configure(api_key=resolved_key, **kwargs)
class GenerativeModel:
    """
    Drop-in wrapper around ``google.generativeai.GenerativeModel``.

    If the API key is all 9s, requests go to the local Gemini CLI through
    ``GeminiClient``; otherwise every call is delegated to a real
    ``genai.GenerativeModel`` built from the same constructor arguments.

    The mode is chosen once, in ``__init__``, from the key stored by
    ``configure()``. If ``configure()`` has not stored a key, the
    ``GOOGLE_API_KEY`` environment variable is used instead, matching the
    fallback ``configure()`` itself applies.

    In local mode only the model name, contents and generation config are
    passed to ``GeminiClient``; ``safety_settings``, ``tools``,
    ``tool_config``, ``system_instruction`` and extra keyword arguments are
    not forwarded.
    """

    def __init__(
        self,
        model_name: str,
        generation_config: Optional[GenerationConfig] = None,
        safety_settings: Optional[Any] = None,
        tools: Optional[Any] = None,
        tool_config: Optional[Any] = None,
        system_instruction: Optional[Any] = None
    ):
        """
        Store the model settings and select local-CLI or real-API mode.

        Args:
            model_name: Model identifier, passed to whichever backend is used.
            generation_config: Default generation config for this model.
            safety_settings: Used by the real model only.
            tools: Used by the real model only.
            tool_config: Used by the real model only.
            system_instruction: Used by the real model only.
        """
        self.model_name = model_name
        self.generation_config = generation_config
        self.safety_settings = safety_settings
        self.tools = tools
        self.tool_config = tool_config
        self.system_instruction = system_instruction

        api_key = _configured_api_key
        if api_key is None:
            # configure() has not stored a key: apply the same GOOGLE_API_KEY
            # fallback it would have used, so an all-9s key set only in the
            # environment still selects local mode instead of the real API.
            api_key = os.environ.get("GOOGLE_API_KEY")
        self._is_local_mode = _is_all_nines(api_key)

        if self._is_local_mode:
            self.client = GeminiClient()
            self._real_model = None
        else:
            self.client = None
            self._real_model = genai.GenerativeModel(
                model_name=model_name,
                generation_config=generation_config,
                safety_settings=safety_settings,
                tools=tools,
                tool_config=tool_config,
                system_instruction=system_instruction
            )

    def _local_generation_config(
        self,
        generation_config: Optional[GenerationConfig],
        stream: bool,
    ) -> Optional[GenerationConfig]:
        """
        Validate a local-mode request and return the config to send.

        Raises:
            NotImplementedError: If ``stream`` is True (unsupported locally).
        """
        if stream:
            raise NotImplementedError("Streaming not supported in local mode")
        # A truthy per-call config replaces the model-level default wholesale;
        # individual fields are not merged.
        return generation_config or self.generation_config

    def generate_content(
        self,
        contents: Union[str, List[Dict[str, Any]]],
        generation_config: Optional[GenerationConfig] = None,
        safety_settings: Optional[Any] = None,
        stream: bool = False,
        **kwargs
    ) -> Any:
        """
        Generate content, routing to the local CLI when in local mode.

        Args:
            contents: Prompt text or a list of content dicts.
            generation_config: Per-call config. In local mode it replaces the
                model default when truthy; in API mode it is passed as-is.
            safety_settings: Forwarded to the real model only.
            stream: Forwarded to the real model; unsupported in local mode.
            **kwargs: Forwarded to the real model only.

        Returns:
            In local mode, whatever ``GeminiClient.generate_content`` returns;
            otherwise the real model's response.

        Raises:
            NotImplementedError: If ``stream`` is True in local mode.
        """
        if self._is_local_mode:
            config = self._local_generation_config(generation_config, stream)
            return self.client.generate_content(
                model=self.model_name,
                contents=contents,
                generation_config=config,
                stream=stream
            )
        return self._real_model.generate_content(
            contents,
            generation_config=generation_config,
            safety_settings=safety_settings,
            stream=stream,
            **kwargs
        )

    async def generate_content_async(
        self,
        contents: Union[str, List[Dict[str, Any]]],
        generation_config: Optional[GenerationConfig] = None,
        safety_settings: Optional[Any] = None,
        stream: bool = False,
        **kwargs
    ) -> Any:
        """
        Async counterpart of ``generate_content`` with the same routing rules.

        Raises:
            NotImplementedError: If ``stream`` is True in local mode.
        """
        if self._is_local_mode:
            config = self._local_generation_config(generation_config, stream)
            return await self.client.generate_content_async(
                model=self.model_name,
                contents=contents,
                generation_config=config,
                stream=stream
            )
        return await self._real_model.generate_content_async(
            contents,
            generation_config=generation_config,
            safety_settings=safety_settings,
            stream=stream,
            **kwargs
        )
# Expose other common functions/classes from genai if needed,
# but mostly GenerativeModel and configure are the entry points.