Skip to content

Commit 59fd500

Browse files
authored
Merge pull request #14 from RISE-UNIBAS/mistral-ai
mistral-ai
2 parents 40307f6 + b54a74e commit 59fd500

File tree

4 files changed

+73
-30
lines changed

4 files changed

+73
-30
lines changed

benchmarks/benchmarks_tests.csv

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@ T18,metadata_extraction,anthropic,claude-3-5-sonnet-20241022,Document,0.0,You ar
2020
T19,metadata_extraction,genai,gemini-2.5-pro-exp-03-25,Document,0.0,You are a historian with keyword knowledge and an expert in the field of 20th century Swiss history,prompt.txt,false
2121
T20,metadata_extraction,genai,gemini-2.0-flash-lite,Document,0.0,You are a historian with keyword knowledge and an expert in the field of 20th century Swiss history,prompt.txt,false
2222
T21,metadata_extraction,genai,gemini-2.0-pro-exp-02-05,Document,0.0,You are a historian with keyword knowledge and an expert in the field of 20th century Swiss history,prompt.txt,false
23-
T22,fraktur,genai,gemini-2.5-pro-exp-03-25,"",0.0,You are a historian with keyword knowledge and an expert in the field of 20th century Swiss history,prompt.txt,false
23+
T22,fraktur,genai,gemini-2.5-pro-exp-03-25,"",0.0,You are a historian with keyword knowledge and an expert in the field of 20th century Swiss history,prompt.txt,false
24+
T23,metadata_extraction,mistral,pixtral-large-latest,Document,0.0,You are a historian with keyword knowledge and an expert in the field of 20th century Swiss history. You only return valid JSON and no other text.,prompt.txt,false

scripts/__init__.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
"""Package-level logging setup: log to a timestamped file and to the console.

Importing this package configures the root logger once, writing every record
both to ``logs/<YYYYmmdd-HHMMSS>.log`` and to stderr.
"""

import logging
import os
import time

# Ensure the logs directory exists. exist_ok=True avoids the
# check-then-create race of `if not os.path.exists(...): os.makedirs(...)`.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)

# Configure logging. The level is overridable via the LOG_LEVEL environment
# variable (e.g. DEBUG); it defaults to INFO, so existing behavior is kept.
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO").upper(),
    format="%(asctime)s %(levelname)s:%(name)s:%(message)s",
    handlers=[
        # One log file per process start, named by the start timestamp.
        logging.FileHandler(
            os.path.join(log_dir, f"{time.strftime('%Y%m%d-%H%M%S')}.log")
        ),
        # Mirror everything to the console as well.
        logging.StreamHandler(),
    ],
)

logger = logging.getLogger(__name__)

scripts/benchmark_base.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def is_runnable(self) -> bool:
5959
if not os.path.exists(os.path.join(self.benchmark_dir, "ground_truths")):
6060
logging.error(f"Ground truths directory not found: {self.benchmark_dir}")
6161
return False
62-
if not self.provider in ["openai", "genai", "anthropic"]:
62+
if not self.provider in ["openai", "genai", "anthropic", "mistral"]:
6363
logging.error(f"Invalid provider: {self.provider}")
6464
return False
6565
if not self.model:
@@ -74,7 +74,7 @@ def load_prompt(self) -> str:
7474
logging.debug(f"Loaded prompt from {prompt_path}")
7575
if self.has_file_information:
7676
try:
77-
kwargs = {} # Add file information here
77+
kwargs = {} # Add file information here
7878
return prompt.format(**kwargs)
7979
except KeyError as e:
8080
return prompt
@@ -106,7 +106,6 @@ def load_ground_truth(self,
106106
return {"error": "Invalid JSON format."}
107107
return {"response_text": ground_truth_text}
108108

109-
110109
def ask_llm(self,
111110
image_paths: list[str]) -> dict:
112111
""" Ask the language model a question. """
@@ -135,7 +134,7 @@ def get_request_answer_path(self):
135134

136135
def get_request_answer_file_name(self, image_name):
137136
""" Get the path to the answer file. """
138-
return os.path.join(self.get_request_answer_path(), self.get_request_name(image_name)+".json")
137+
return os.path.join(self.get_request_answer_path(), self.get_request_name(image_name) + ".json")
139138

140139
def get_request_render_path(self):
141140
date_str = datetime.now().strftime('%Y-%m-%d')
@@ -159,7 +158,7 @@ def save_request_answer(self,
159158
logging.info(f"Saved answer to {file_name}")
160159

161160
def save_benchmark_score(self,
162-
score: dict) -> None:
161+
score: dict) -> None:
163162
""" Save the benchmark score to a file. """
164163
date_str = datetime.now().strftime('%Y-%m-%d')
165164
save_path = os.path.join('..', "results", date_str, self.id, "scoring.json")
@@ -237,7 +236,7 @@ def run(self, regenerate_existing_results=True):
237236
image_paths = [os.path.join(images_dir, img) for img in img_files]
238237

239238
if (regenerate_existing_results and os.path.exists(self.get_request_answer_file_name(image_name))) or \
240-
(not os.path.exists(self.get_request_answer_file_name(image_name))):
239+
(not os.path.exists(self.get_request_answer_file_name(image_name))):
241240
logging.info(f"Processing {self.id}, {image_name}...")
242241
answer = self.ask_llm(image_paths)
243242
self.save_request_answer(image_name, answer)
@@ -255,7 +254,6 @@ def run(self, regenerate_existing_results=True):
255254
benchmark_score = self.score_benchmark(benchmark_scores)
256255
self.save_benchmark_score(benchmark_score)
257256

258-
259257
def get_request_name(self, image_name: str) -> str:
260258
""" Get the name of the request. """
261259
return f"request_{self.id}_{os.path.splitext(image_name)[0]}"
@@ -271,9 +269,9 @@ def create_request_render(self,
271269

272270
@abstractmethod
273271
def score_request_answer(self,
274-
image_name: str,
275-
response: dict,
276-
ground_truth: dict) -> dict:
272+
image_name: str,
273+
response: dict,
274+
ground_truth: dict) -> dict:
277275
""" Score the response. """
278276
pass
279277

@@ -327,18 +325,18 @@ def score_benchmark(self, all_scores):
327325
return {"score": "niy"}
328326

329327
def score_request_answer(self,
330-
image_name: str,
331-
response: dict,
332-
ground_truth: dict) -> dict:
328+
image_name: str,
329+
response: dict,
330+
ground_truth: dict) -> dict:
333331
""" Score the response. """
334332
return {}
335333

336334
def create_request_render(self,
337-
image_name: str,
338-
result: dict,
339-
score: dict,
340-
truth) -> str:
341-
""" Create a markdown render of the request. """
342-
return ("### Result for image: {image_name}"
343-
"\n\n"
344-
"no details available")
335+
image_name: str,
336+
result: dict,
337+
score: dict,
338+
truth) -> str:
339+
""" Create a markdown render of the request. """
340+
return ("### Result for image: {image_name}"
341+
"\n\n"
342+
"no details available")

scripts/simple_ai_clients.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Simple AI API client for OpenAI, GenAI, and Anthropic."""
1+
"""Simple AI API client for OpenAI, GenAI, Anthropic, and Mistral AI."""
22
import base64
33
from dataclasses import asdict
44
from datetime import datetime
@@ -7,13 +7,16 @@
77
import google.generativeai as genai
88
from openai import OpenAI
99
from anthropic import Anthropic
10+
from mistralai import Mistral
11+
1012

1113
class AiApiClient:
1214
"""Simple AI API client for OpenAI, GenAI, and Anthropic."""
1315

1416
SUPPORTED_APIS = ['openai',
1517
'genai',
16-
'anthropic']
18+
'anthropic',
19+
'mistral']
1720

1821
api_client = None
1922
image_resources = []
@@ -52,6 +55,11 @@ def init_client(self):
5255
api_key=self.api_key,
5356
)
5457

58+
if self.api == 'mistral':
59+
self.api_client = Mistral(
60+
api_key=self.api_key
61+
)
62+
5563
@property
5664
def elapsed_time(self):
5765
"""Return the elapsed time since the client was initialized."""
@@ -144,6 +152,28 @@ def prompt(self, model, prompt):
144152
)
145153
answer = message
146154

155+
if self.api == 'mistral':
156+
content = [{"type": "text", "text": prompt}]
157+
for img_path in self.image_resources:
158+
with open(img_path, "rb") as image_file:
159+
base64_image = base64.b64encode(image_file.read()).decode("utf-8")
160+
data_uri = f"data:image/jpeg;base64,{base64_image}"
161+
content.append({
162+
"type": "image_url",
163+
"image_url": {
164+
"url": data_uri
165+
}
166+
})
167+
168+
message = self.api_client.chat.complete(
169+
messages=[{
170+
"role": "user",
171+
"content": content,
172+
}],
173+
model=model,
174+
)
175+
answer = message
176+
147177
end_time = time.time()
148178
elapsed_time = end_time - prompt_start
149179
return self.create_answer(answer, elapsed_time, model)
@@ -169,6 +199,8 @@ def create_answer(self, response, elapsed_time, model):
169199
answer['response_text'] = response.text
170200
elif self.api == 'anthropic':
171201
answer['response_text'] = response.content[0].text
202+
elif self.api == 'mistral':
203+
answer['response_text'] = response.choices[0].message.content
172204

173205
return answer
174206

@@ -184,4 +216,7 @@ def get_model_list(self):
184216
return genai.list_models()
185217

186218
if self.api == 'anthropic':
187-
return self.api_client.models.list()
219+
return self.api_client.models.list()
220+
221+
if self.api == 'mistral':
222+
return self.api_client.models.list()

0 commit comments

Comments
 (0)