Skip to content

Commit 809727b

Browse files
committed
feat: move to inference endpoints OpenAI api
* fix: avoid parsing rate-limited responses
* feat: add `/index-issue` to index a specific issue by number
1 parent e6d5f7b commit 809727b

File tree

6 files changed, with 233 additions and 46 deletions.

TODO.md

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,8 @@
2828

2929
- [x] fix: delete associated comments, reviews & review comments
3030

31+
- [ ] check the token length of the prompt, because of the switch to vLLM
32+
3133
## Infra resilience tasks
3234

3335
- [x] helm chart for the issue bot

issue-bot/src/embeddings/inference_endpoints.rs

Lines changed: 29 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -2,27 +2,28 @@ use std::time::Duration;
22

33
use reqwest::{
44
header::{HeaderMap, HeaderValue, AUTHORIZATION},
5-
Client,
5+
Client, StatusCode,
66
};
7-
use serde::Serialize;
7+
use serde::{Deserialize, Serialize};
88
use tracing::warn;
99

1010
use crate::{config::EmbeddingApiConfig, APP_USER_AGENT};
1111

1212
use super::EmbeddingError;
1313

1414
#[derive(Serialize)]
15-
enum TruncateDirection {
16-
#[allow(unused)]
17-
Left,
18-
Right,
15+
struct OAIEmbedRequest {
16+
input: String,
1917
}
2018

21-
#[derive(Serialize)]
22-
struct EmbedRequest {
23-
inputs: String,
24-
truncate: bool,
25-
truncate_direction: TruncateDirection,
19+
#[derive(Deserialize)]
20+
struct OAIEmbedResponse {
21+
data: Vec<OAIEmbedData>,
22+
}
23+
24+
#[derive(Deserialize)]
25+
struct OAIEmbedData {
26+
embedding: Vec<f32>,
2627
}
2728

2829
#[derive(Clone)]
@@ -52,11 +53,9 @@ impl EmbeddingApi {
5253
loop {
5354
let res = self
5455
.client
55-
.post(&self.cfg.url)
56-
.json(&EmbedRequest {
57-
inputs: text.clone(),
58-
truncate: true,
59-
truncate_direction: TruncateDirection::Right,
56+
.post(format!("{}/v1/embeddings", self.cfg.url))
57+
.json(&OAIEmbedRequest {
58+
input: text.clone(),
6059
})
6160
.send()
6261
.await;
@@ -75,7 +74,17 @@ impl EmbeddingApi {
7574
}
7675
Ok(res) => res,
7776
};
78-
if res.status() != 200 {
77+
let status = res.status();
78+
// Short-circuit on client errors (4xx): return the error instead of retrying
79+
if status.is_client_error() {
80+
let response_content = res.text().await?;
81+
warn!(
82+
"[status: {}] Embedding API returned: '{}'",
83+
status, response_content
84+
);
85+
return Err(EmbeddingError::HttpClientError(status));
86+
}
87+
if res.status() != StatusCode::OK {
7988
let status = res.status();
8089
let response_content = res.text().await?;
8190
warn!(
@@ -90,9 +99,11 @@ impl EmbeddingApi {
9099
continue;
91100
}
92101
return res
93-
.json::<Vec<Vec<f32>>>()
102+
.json::<OAIEmbedResponse>()
94103
.await?
104+
.data
95105
.pop()
106+
.map(|d| d.embedding)
96107
.ok_or(EmbeddingError::MissingEmbedding);
97108
}
98109
}

issue-bot/src/embeddings/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
use reqwest::StatusCode;
12
use thiserror::Error;
23

34
pub mod inference_endpoints;
@@ -9,6 +10,8 @@ pub enum EmbeddingError {
910
// Candle(#[from] candle::Error),
1011
// #[error("hf hub error: {0}")]
1112
// HfHub(#[from] hf_hub::api::tokio::ApiError),
13+
#[error("http client error: {0}")]
14+
HttpClientError(StatusCode),
1215
#[error("invalid header value: {0}")]
1316
InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue),
1417
#[error("io error: {0}")]
@@ -17,7 +20,7 @@ pub enum EmbeddingError {
1720
Join(#[from] tokio::task::JoinError),
1821
#[error("maximum retries ({0}) exceeded")]
1922
MaxRetriesExceeded(u32),
20-
#[error("embedding is missing from API response")]
23+
#[error("no embedding was returned from the API")]
2124
MissingEmbedding,
2225
#[error("reqwest error: {0}")]
2326
Reqwest(#[from] reqwest::Error),

issue-bot/src/github.rs

Lines changed: 52 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -5,12 +5,12 @@ use chrono::Utc;
55
use futures::Stream;
66
use reqwest::{
77
header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, LINK},
8-
Client,
8+
Client, StatusCode,
99
};
1010
use serde::{Deserialize, Serialize};
1111
use thiserror::Error;
1212
use tokio::time::sleep;
13-
use tracing::info;
13+
use tracing::{error, info};
1414

1515
use crate::{
1616
config::{GithubApiConfig, MessageConfig},
@@ -38,6 +38,8 @@ pub enum GithubApiError {
3838
TaskJoin(#[from] tokio::task::JoinError),
3939
#[error("to str error: {0}")]
4040
ToStr(#[from] axum::http::header::ToStrError),
41+
#[error("unsuccesful response: {0}")]
42+
UnsuccesfulResponse(StatusCode),
4143
}
4244

4345
#[derive(Debug, Deserialize)]
@@ -178,6 +180,28 @@ impl GithubApi {
178180
Ok(())
179181
}
180182

183+
pub(crate) async fn get_issue(
184+
&self,
185+
number: i32,
186+
repository_full_name: &str,
187+
) -> Result<IssueWithComments, GithubApiError> {
188+
let url = format!(
189+
"https://api.github.com/repos/{}/issues/{}",
190+
repository_full_name, number
191+
);
192+
let issue = self.client.get(&url).send().await?.json::<Issue>().await?;
193+
let comments = self
194+
.client
195+
.get(&issue.comments_url)
196+
.query(&[("direction", "asc")])
197+
.send()
198+
.await?
199+
.json::<Vec<Comment>>()
200+
.await?;
201+
202+
Ok(IssueWithComments::new(issue, comments))
203+
}
204+
181205
pub(crate) fn get_issues(
182206
&self,
183207
from_page: i32,
@@ -202,23 +226,30 @@ impl GithubApi {
202226
let link_header = res.headers().get(LINK).cloned();
203227
let ratelimit_remaining = res.headers().get(X_RATELIMIT_REMAINING).cloned();
204228
let ratelimit_reset = res.headers().get(X_RATELIMIT_RESET).cloned();
229+
if handle_ratelimit(ratelimit_remaining, ratelimit_reset).await? {
230+
continue;
231+
}
205232
let issues = res.json::<Vec<Issue>>().await?;
206233
info!("fetched {} issues from page {}, getting comments for each issue next", issues.len(), page);
207-
handle_ratelimit(ratelimit_remaining, ratelimit_reset).await?;
208234
let page_issue_count = issues.len();
209235
for (i, issue) in issues.into_iter().enumerate() {
210-
let res = client
211-
.get(&issue.comments_url)
212-
.query(&[("direction", "asc")])
213-
.send()
214-
.await?;
215-
let ratelimit_remaining = res.headers().get(X_RATELIMIT_REMAINING).cloned();
216-
let ratelimit_reset = res.headers().get(X_RATELIMIT_RESET).cloned();
217-
handle_ratelimit(ratelimit_remaining, ratelimit_reset).await?;
218-
let comments = res
219-
.json::<Vec<Comment>>()
220-
.await?;
221-
yield (IssueWithComments::new(issue, comments), (i + 1 == page_issue_count).then_some(page));
236+
loop {
237+
let res = client
238+
.get(&issue.comments_url)
239+
.query(&[("direction", "asc")])
240+
.send()
241+
.await?;
242+
let ratelimit_remaining = res.headers().get(X_RATELIMIT_REMAINING).cloned();
243+
let ratelimit_reset = res.headers().get(X_RATELIMIT_RESET).cloned();
244+
if handle_ratelimit(ratelimit_remaining, ratelimit_reset).await? {
245+
continue;
246+
}
247+
let comments = res
248+
.json::<Vec<Comment>>()
249+
.await?;
250+
yield (IssueWithComments::new(issue, comments), (i + 1 == page_issue_count).then_some(page));
251+
break;
252+
}
222253
}
223254
if get_next_page(link_header)?.is_none() {
224255
break;
@@ -229,23 +260,23 @@ impl GithubApi {
229260
}
230261
}
231262

263+
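/// Returns `true` if the rate limit is exhausted (`remaining == 0`), after sleeping until the reset time plus a 2-second margin; returns `false` otherwise.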
232264
async fn handle_ratelimit(
233265
remaining: Option<HeaderValue>,
234266
reset: Option<HeaderValue>,
235-
) -> Result<(), GithubApiError> {
267+
) -> Result<bool, GithubApiError> {
236268
match (remaining, reset) {
237269
(Some(remaining), Some(reset)) => {
238270
let remaining: i32 = remaining.to_str()?.parse()?;
239271
let reset: i64 = reset.to_str()?.parse()?;
240-
if remaining == 0 {
272+
let rate_limited = remaining == 0;
273+
if rate_limited {
241274
let duration = Duration::from_secs((reset - Utc::now().timestamp() + 2) as u64);
242275
info!("rate limit reached, sleeping for {}s", duration.as_secs());
243276
sleep(duration).await;
244277
}
278+
Ok(rate_limited)
245279
}
246-
(remaining, reset) => {
247-
return Err(GithubApiError::MissingRateLimitHeaders(remaining, reset))
248-
}
280+
(remaining, reset) => Err(GithubApiError::MissingRateLimitHeaders(remaining, reset)),
249281
}
250-
Ok(())
251282
}

0 commit comments

Comments
 (0)