Skip to content

Commit dac7482

Browse files
committed
Start work on functions "sqlpage.basic_auth_user" and "sqlpage.basic_auth_password"
1 parent c1edb9a commit dac7482

File tree

5 files changed

+185
-47
lines changed

5 files changed

+185
-47
lines changed

Cargo.lock

Lines changed: 16 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ config = { version = "0.13.3", features = ["json"] }
4242
markdown = { version = "1.0.0-alpha.9", features = ["log"] }
4343
password-hash = "0.5.0"
4444
argon2 = "0.5.0"
45+
actix-web-httpauth = "0.8.0"
4546

4647
[build-dependencies]
4748
ureq = "2.6.2"

src/webserver/database/mod.rs

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
mod sql;
22

3+
use actix_web::http::StatusCode;
4+
use actix_web_httpauth::headers::authorization::Basic;
35
use anyhow::{anyhow, Context};
46
use futures_util::stream::{self, BoxStream, Stream};
57
use futures_util::StreamExt;
68
use serde_json::{Map, Value};
9+
use std::borrow::Cow;
710
use std::fmt::{Display, Formatter};
811
use std::future::ready;
912
use std::path::Path;
@@ -144,29 +147,74 @@ fn bind_parameters<'a>(
144147
) -> anyhow::Result<Query<'a, sqlx::Any, AnyArguments<'a>>> {
145148
let mut arguments = AnyArguments::default();
146149
for param in &stmt.parameters {
147-
let argument = match param {
148-
StmtParam::Get(x) => request.get_variables.get(x),
149-
StmtParam::Post(x) => request.post_variables.get(x),
150-
StmtParam::GetOrPost(x) => request
151-
.post_variables
152-
.get(x)
153-
.or_else(|| request.get_variables.get(x)),
154-
StmtParam::Cookie(x) => request.cookies.get(x),
155-
StmtParam::Header(x) => request.headers.get(x),
156-
StmtParam::Error(x) => anyhow::bail!("{}", x),
157-
};
150+
let argument = extract_req_param(param, request)?;
158151
log::debug!("Binding value {:?} in statement {}", &argument, stmt);
159152
match argument {
160153
None => arguments.add(None::<String>),
161-
Some(SingleOrVec::Single(s)) => arguments.add(s),
162-
Some(SingleOrVec::Vec(v)) => {
163-
arguments.add(serde_json::to_string(v).unwrap_or_default());
164-
}
154+
Some(Cow::Owned(s)) => arguments.add(s),
155+
Some(Cow::Borrowed(v)) => arguments.add(v),
165156
}
166157
}
167158
Ok(stmt.statement.query_with(arguments))
168159
}
169160

161+
fn extract_req_param<'a>(
162+
param: &StmtParam,
163+
request: &'a RequestInfo,
164+
) -> anyhow::Result<Option<Cow<'a, str>>> {
165+
Ok(match param {
166+
StmtParam::Get(x) => request.get_variables.get(x).map(SingleOrVec::as_json_str),
167+
StmtParam::Post(x) => request.post_variables.get(x).map(SingleOrVec::as_json_str),
168+
StmtParam::GetOrPost(x) => request
169+
.post_variables
170+
.get(x)
171+
.or_else(|| request.get_variables.get(x))
172+
.map(SingleOrVec::as_json_str),
173+
StmtParam::Cookie(x) => request.cookies.get(x).map(SingleOrVec::as_json_str),
174+
StmtParam::Header(x) => request.headers.get(x).map(SingleOrVec::as_json_str),
175+
StmtParam::Error(x) => anyhow::bail!("{}", x),
176+
StmtParam::BasicAuthPassword => extract_basic_auth_password(request)
177+
.map(Cow::Borrowed)
178+
.map(Some)?,
179+
StmtParam::BasicAuthUsername => extract_basic_auth_username(request)
180+
.map(Cow::Borrowed)
181+
.map(Some)?,
182+
StmtParam::HashPassword(_) => todo!(),
183+
})
184+
}
185+
186+
#[derive(Debug)]
187+
pub struct ErrorWithStatus {
188+
pub status: StatusCode,
189+
}
190+
impl std::fmt::Display for ErrorWithStatus {
191+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192+
write!(f, "HTTP error with status {}", self.status)
193+
}
194+
}
195+
impl std::error::Error for ErrorWithStatus {}
196+
197+
fn extract_basic_auth(request: &RequestInfo) -> anyhow::Result<&Basic> {
198+
request.basic_auth.as_ref().ok_or_else(|| {
199+
anyhow::Error::new(ErrorWithStatus {
200+
status: StatusCode::UNAUTHORIZED,
201+
})
202+
})
203+
}
204+
205+
fn extract_basic_auth_username(request: &RequestInfo) -> anyhow::Result<&str> {
206+
Ok(extract_basic_auth(request)?.user_id())
207+
}
208+
209+
fn extract_basic_auth_password(request: &RequestInfo) -> anyhow::Result<&str> {
210+
let password = extract_basic_auth(request)?.password().ok_or_else(|| {
211+
anyhow::Error::new(ErrorWithStatus {
212+
status: StatusCode::UNAUTHORIZED,
213+
})
214+
})?;
215+
Ok(password)
216+
}
217+
170218
#[derive(Debug)]
171219
pub enum DbItem {
172220
Row(Value),
@@ -304,6 +352,9 @@ enum StmtParam {
304352
Cookie(String),
305353
Header(String),
306354
Error(String),
355+
BasicAuthPassword,
356+
BasicAuthUsername,
357+
HashPassword(Box<StmtParam>),
307358
}
308359

309360
#[actix_web::test]

src/webserver/database/sql.rs

Lines changed: 88 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -149,34 +149,79 @@ impl ParameterExtractor {
149149
) -> Expr {
150150
#[allow(clippy::single_match_else)]
151151
let placeholder = self.make_placeholder();
152-
let param = match func_name {
153-
"cookie" => extract_single_quoted_string("cookie", &mut arguments)
154-
.map_or_else(StmtParam::Error, StmtParam::Cookie),
155-
"header" => extract_single_quoted_string("header", &mut arguments)
156-
.map_or_else(StmtParam::Error, StmtParam::Header),
157-
unknown_name => {
158-
StmtParam::Error(format!("Unknown function {unknown_name}({arguments:#?})"))
159-
}
160-
};
152+
let param = func_call_to_param(func_name, &mut arguments);
161153
self.parameters.push(param);
162154
placeholder
163155
}
164156
}
165157

158+
fn func_call_to_param(func_name: &str, arguments: &mut [FunctionArg]) -> StmtParam {
159+
match func_name {
160+
"cookie" => extract_single_quoted_string("cookie", arguments)
161+
.map_or_else(StmtParam::Error, StmtParam::Cookie),
162+
"header" => extract_single_quoted_string("header", arguments)
163+
.map_or_else(StmtParam::Error, StmtParam::Header),
164+
"basic_auth_user" => StmtParam::BasicAuthUsername,
165+
"basic_auth_password" => StmtParam::BasicAuthPassword,
166+
"hash_password" => extract_variable_argument("hash_password", arguments)
167+
.map(Box::new)
168+
.map_or_else(StmtParam::Error, StmtParam::HashPassword),
169+
unknown_name => {
170+
StmtParam::Error(format!("Unknown function {unknown_name}({arguments:#?})"))
171+
}
172+
}
173+
}
174+
166175
fn extract_single_quoted_string(
167176
func_name: &'static str,
168177
arguments: &mut [FunctionArg],
169178
) -> Result<String, String> {
170-
if let [FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(Value::SingleQuotedString(
171-
param_value,
172-
))))] = arguments
173-
{
174-
Ok(std::mem::take(param_value))
175-
} else {
176-
Err(format!(
179+
match arguments.first_mut().and_then(function_arg_expr) {
180+
Some(Expr::Value(Value::SingleQuotedString(param_value))) => {
181+
Ok(std::mem::take(param_value))
182+
}
183+
_ => Err(format!(
177184
"{func_name}({args}) is not a valid call. Expected a literal single quoted string.",
178-
args = arguments.iter().map(ToString::to_string).collect::<Vec<_>>().join(", ")
179-
))
185+
args = arguments
186+
.iter()
187+
.map(ToString::to_string)
188+
.collect::<Vec<_>>()
189+
.join(", ")
190+
)),
191+
}
192+
}
193+
194+
fn extract_variable_argument(
195+
func_name: &'static str,
196+
arguments: &mut [FunctionArg],
197+
) -> Result<StmtParam, String> {
198+
match arguments.first_mut().and_then(function_arg_expr) {
199+
Some(Expr::Value(Value::Placeholder(placeholder))) => {
200+
Ok(map_param(std::mem::take(placeholder)))
201+
}
202+
Some(Expr::Function(Function {
203+
name: ObjectName(func_name_parts),
204+
args,
205+
..
206+
})) if is_sqlpage_func(func_name_parts) => Ok(func_call_to_param(
207+
sqlpage_func_name(func_name_parts),
208+
args.as_mut_slice(),
209+
)),
210+
_ => Err(format!(
211+
"{func_name}({args}) is not a valid call. Expected a literal single quoted string.",
212+
args = arguments
213+
.iter()
214+
.map(ToString::to_string)
215+
.collect::<Vec<_>>()
216+
.join(", ")
217+
)),
218+
}
219+
}
220+
221+
fn function_arg_expr(arg: &mut FunctionArg) -> Option<&mut Expr> {
222+
match arg {
223+
FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => Some(expr),
224+
_ => None,
180225
}
181226
}
182227

@@ -206,27 +251,38 @@ impl VisitorMut for ParameterExtractor {
206251
distinct: false,
207252
over: None,
208253
..
209-
}) => {
210-
if let [Ident {
211-
value: func_name_part_0,
212-
..
213-
}, Ident {
214-
value: func_name, ..
215-
}] = &func_name_parts[..]
216-
{
217-
if func_name_part_0 == "sqlpage" {
218-
log::debug!("Handling builtin function: {func_name}");
219-
let arguments = std::mem::take(args);
220-
*value = self.handle_builtin_function(func_name, arguments);
221-
}
222-
}
254+
}) if is_sqlpage_func(func_name_parts) => {
255+
let func_name = sqlpage_func_name(func_name_parts);
256+
log::debug!("Handling builtin function: {func_name}");
257+
let arguments = std::mem::take(args);
258+
*value = self.handle_builtin_function(func_name, arguments);
223259
}
224260
_ => (),
225261
}
226262
ControlFlow::<()>::Continue(())
227263
}
228264
}
229265

266+
fn is_sqlpage_func(func_name_parts: &[Ident]) -> bool {
267+
if let [Ident { value, .. }, Ident { .. }] = func_name_parts {
268+
value == "sqlpage"
269+
} else {
270+
false
271+
}
272+
}
273+
274+
fn sqlpage_func_name(func_name_parts: &[Ident]) -> &str {
275+
if let [Ident { .. }, Ident { value, .. }] = func_name_parts {
276+
value
277+
} else {
278+
debug_assert!(
279+
false,
280+
"sqlpage function name should have been checked by is_sqlpage_func"
281+
);
282+
""
283+
}
284+
}
285+
230286
#[test]
231287
fn test_statement_rewrite() {
232288
let sql = "select $a from t where $x > $a OR $x = sqlpage.cookie('cookoo')";

src/webserver/http.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use actix_web::{
1111
};
1212

1313
use actix_web::body::MessageBody;
14+
use actix_web_httpauth::headers::authorization::{Authorization, Basic};
1415
use anyhow::Context;
1516
use chrono::{DateTime, Utc};
1617
use futures_util::stream::Stream;
@@ -273,6 +274,13 @@ impl SingleOrVec {
273274
SingleOrVec::Vec(v) => mem::take(v),
274275
}
275276
}
277+
278+
pub fn as_json_str(&self) -> Cow<'_, str> {
279+
match self {
280+
SingleOrVec::Single(x) => Cow::Borrowed(x),
281+
SingleOrVec::Vec(v) => Cow::Owned(serde_json::to_string(v).unwrap()),
282+
}
283+
}
276284
}
277285

278286
#[derive(Debug)]
@@ -282,6 +290,7 @@ pub struct RequestInfo {
282290
pub headers: ParamMap,
283291
pub client_ip: Option<IpAddr>,
284292
pub cookies: ParamMap,
293+
pub basic_auth: Option<Basic>,
285294
}
286295

287296
fn param_map<PAIRS: IntoIterator<Item = (String, String)>>(values: PAIRS) -> ParamMap {
@@ -330,12 +339,17 @@ async fn extract_request_info(req: &mut ServiceRequest) -> RequestInfo {
330339
.flat_map(|c| c.iter())
331340
.map(|cookie| (cookie.name().to_string(), cookie.value().to_string()));
332341

342+
let basic_auth = Authorization::<Basic>::parse(req)
343+
.ok()
344+
.map(Authorization::into_scheme);
345+
333346
RequestInfo {
334347
headers: param_map(headers),
335348
get_variables: param_map(get_variables),
336349
post_variables: param_map(post_variables),
337350
client_ip,
338351
cookies: param_map(cookies),
352+
basic_auth,
339353
}
340354
}
341355

0 commit comments

Comments
 (0)