Skip to content

Commit 0061d13

Browse files
committed
small code reorganization
make functions available in a separate module
1 parent 62105c4 commit 0061d13

File tree

7 files changed

+155
-136
lines changed

7 files changed

+155
-136
lines changed

src/filesystem.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::webserver::database::ErrorWithStatus;
1+
use crate::webserver::ErrorWithStatus;
22
use crate::webserver::{make_placeholder, Database};
33
use crate::AppState;
44
use anyhow::Context;

src/webserver/database/mod.rs

Lines changed: 4 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
mod sql;
2+
mod sql_pseudofunctions;
23

3-
use actix_web::http::StatusCode;
4-
use actix_web_httpauth::headers::authorization::Basic;
54
use anyhow::{anyhow, Context};
65
use futures_util::stream::{self, BoxStream, Stream};
76
use futures_util::StreamExt;
@@ -15,7 +14,8 @@ use std::time::Duration;
1514
use crate::app_config::AppConfig;
1615
pub use crate::file_cache::FileCache;
1716
use crate::utils::add_value_to_map;
18-
use crate::webserver::http::{RequestInfo, SingleOrVec};
17+
use crate::webserver::database::sql_pseudofunctions::extract_req_param;
18+
use crate::webserver::http::RequestInfo;
1919
use crate::MIGRATIONS_DIR;
2020
pub use sql::make_placeholder;
2121
pub use sql::ParsedSqlFile;
@@ -162,86 +162,6 @@ fn bind_parameters<'a>(
162162
Ok(stmt.statement.query_with(arguments))
163163
}
164164

165-
fn extract_req_param<'a>(
166-
param: &StmtParam,
167-
request: &'a RequestInfo,
168-
) -> anyhow::Result<Option<Cow<'a, str>>> {
169-
Ok(match param {
170-
StmtParam::Get(x) => request.get_variables.get(x).map(SingleOrVec::as_json_str),
171-
StmtParam::Post(x) => request.post_variables.get(x).map(SingleOrVec::as_json_str),
172-
StmtParam::GetOrPost(x) => request
173-
.post_variables
174-
.get(x)
175-
.or_else(|| request.get_variables.get(x))
176-
.map(SingleOrVec::as_json_str),
177-
StmtParam::Cookie(x) => request.cookies.get(x).map(SingleOrVec::as_json_str),
178-
StmtParam::Header(x) => request.headers.get(x).map(SingleOrVec::as_json_str),
179-
StmtParam::Error(x) => anyhow::bail!("{}", x),
180-
StmtParam::BasicAuthPassword => extract_basic_auth_password(request)
181-
.map(Cow::Borrowed)
182-
.map(Some)?,
183-
StmtParam::BasicAuthUsername => extract_basic_auth_username(request)
184-
.map(Cow::Borrowed)
185-
.map(Some)?,
186-
StmtParam::HashPassword(inner) => extract_req_param(inner, request)?
187-
.map_or(Ok(None), |x| hash_password(&x).map(Cow::Owned).map(Some))?,
188-
StmtParam::RandomString(len) => Some(Cow::Owned(random_string(*len))),
189-
})
190-
}
191-
192-
fn random_string(len: usize) -> String {
193-
use rand::{distributions::Alphanumeric, Rng};
194-
password_hash::rand_core::OsRng
195-
.sample_iter(&Alphanumeric)
196-
.take(len)
197-
.map(char::from)
198-
.collect()
199-
}
200-
201-
fn hash_password(password: &str) -> anyhow::Result<String> {
202-
let phf = argon2::Argon2::default();
203-
let salt = password_hash::SaltString::generate(&mut password_hash::rand_core::OsRng);
204-
let password_hash = &password_hash::PasswordHash::generate(phf, password, &salt)
205-
.map_err(|e| anyhow!("Unable to hash password: {}", e))?;
206-
Ok(password_hash.to_string())
207-
}
208-
209-
#[derive(Debug)]
210-
pub struct ErrorWithStatus {
211-
pub status: StatusCode,
212-
}
213-
impl std::fmt::Display for ErrorWithStatus {
214-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215-
write!(f, "{}", self.status)
216-
}
217-
}
218-
impl std::error::Error for ErrorWithStatus {}
219-
220-
fn extract_basic_auth(request: &RequestInfo) -> anyhow::Result<&Basic> {
221-
request
222-
.basic_auth
223-
.as_ref()
224-
.ok_or_else(|| {
225-
anyhow::Error::new(ErrorWithStatus {
226-
status: StatusCode::UNAUTHORIZED,
227-
})
228-
})
229-
.with_context(|| "Expected the user to be authenticated with HTTP basic auth")
230-
}
231-
232-
fn extract_basic_auth_username(request: &RequestInfo) -> anyhow::Result<&str> {
233-
Ok(extract_basic_auth(request)?.user_id())
234-
}
235-
236-
fn extract_basic_auth_password(request: &RequestInfo) -> anyhow::Result<&str> {
237-
let password = extract_basic_auth(request)?.password().ok_or_else(|| {
238-
anyhow::Error::new(ErrorWithStatus {
239-
status: StatusCode::UNAUTHORIZED,
240-
})
241-
})?;
242-
Ok(password)
243-
}
244-
245165
#[derive(Debug)]
246166
pub enum DbItem {
247167
Row(Value),
@@ -371,7 +291,7 @@ fn set_custom_connect_options(options: &mut AnyConnectOptions, config: &AppConfi
371291
}
372292
struct PreparedStatement {
373293
statement: AnyStatement<'static>,
374-
parameters: Vec<StmtParam>,
294+
parameters: Vec<sql_pseudofunctions::StmtParam>,
375295
}
376296

377297
impl Display for PreparedStatement {
@@ -380,20 +300,6 @@ impl Display for PreparedStatement {
380300
}
381301
}
382302

383-
#[derive(Debug, PartialEq, Eq)]
384-
enum StmtParam {
385-
Get(String),
386-
Post(String),
387-
GetOrPost(String),
388-
Cookie(String),
389-
Header(String),
390-
Error(String),
391-
BasicAuthPassword,
392-
BasicAuthUsername,
393-
HashPassword(Box<StmtParam>),
394-
RandomString(usize),
395-
}
396-
397303
#[actix_web::test]
398304
async fn test_row_to_json() -> anyhow::Result<()> {
399305
use sqlx::Connection;

src/webserver/database/sql.rs

Lines changed: 16 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
use super::sql_pseudofunctions::{func_call_to_param, StmtParam};
12
use super::PreparedStatement;
23
use crate::file_cache::AsyncFromStrWithState;
3-
use crate::webserver::database::StmtParam;
44
use crate::{AppState, Database};
55
use async_trait::async_trait;
66
use sqlparser::ast::{
@@ -155,27 +155,22 @@ impl ParameterExtractor {
155155
}
156156
}
157157

158-
fn func_call_to_param(func_name: &str, arguments: &mut [FunctionArg]) -> StmtParam {
159-
match func_name {
160-
"cookie" => extract_single_quoted_string("cookie", arguments)
161-
.map_or_else(StmtParam::Error, StmtParam::Cookie),
162-
"header" => extract_single_quoted_string("header", arguments)
163-
.map_or_else(StmtParam::Error, StmtParam::Header),
164-
"basic_auth_username" => StmtParam::BasicAuthUsername,
165-
"basic_auth_password" => StmtParam::BasicAuthPassword,
166-
"hash_password" => extract_variable_argument("hash_password", arguments)
167-
.map(Box::new)
168-
.map_or_else(StmtParam::Error, StmtParam::HashPassword),
169-
"random_string" => extract_integer("random_string", arguments)
170-
.map_or_else(StmtParam::Error, StmtParam::RandomString),
171-
unknown_name => StmtParam::Error(format!(
172-
"Unknown function {unknown_name}({})",
173-
FormatArguments(arguments)
174-
)),
158+
/** This is a helper struct to format a list of arguments for an error message. */
159+
pub(super) struct FormatArguments<'a>(pub &'a [FunctionArg]);
160+
impl std::fmt::Display for FormatArguments<'_> {
161+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162+
let mut args = self.0.iter();
163+
if let Some(arg) = args.next() {
164+
write!(f, "{arg}")?;
165+
}
166+
for arg in args {
167+
write!(f, ", {arg}")?;
168+
}
169+
Ok(())
175170
}
176171
}
177172

178-
fn extract_single_quoted_string(
173+
pub(super) fn extract_single_quoted_string(
179174
func_name: &'static str,
180175
arguments: &mut [FunctionArg],
181176
) -> Result<String, String> {
@@ -190,7 +185,7 @@ fn extract_single_quoted_string(
190185
}
191186
}
192187

193-
fn extract_integer(
188+
pub(super) fn extract_integer(
194189
func_name: &'static str,
195190
arguments: &mut [FunctionArg],
196191
) -> Result<usize, String> {
@@ -205,22 +200,7 @@ fn extract_integer(
205200
}
206201
}
207202

208-
/** This is a helper struct to format a list of arguments for an error message. */
209-
struct FormatArguments<'a>(&'a [FunctionArg]);
210-
impl std::fmt::Display for FormatArguments<'_> {
211-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212-
let mut args = self.0.iter();
213-
if let Some(arg) = args.next() {
214-
write!(f, "{arg}")?;
215-
}
216-
for arg in args {
217-
write!(f, ", {arg}")?;
218-
}
219-
Ok(())
220-
}
221-
}
222-
223-
fn extract_variable_argument(
203+
pub(super) fn extract_variable_argument(
224204
func_name: &'static str,
225205
arguments: &mut [FunctionArg],
226206
) -> Result<StmtParam, String> {
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
use std::borrow::Cow;
2+
3+
use actix_web::http::StatusCode;
4+
use actix_web_httpauth::headers::authorization::Basic;
5+
use sqlparser::ast::FunctionArg;
6+
7+
use crate::webserver::{
8+
http::{RequestInfo, SingleOrVec},
9+
ErrorWithStatus,
10+
};
11+
12+
use super::sql::{
13+
extract_integer, extract_single_quoted_string, extract_variable_argument, FormatArguments,
14+
};
15+
use anyhow::{anyhow, Context};
16+
17+
#[derive(Debug, PartialEq, Eq)]
18+
pub(super) enum StmtParam {
19+
Get(String),
20+
Post(String),
21+
GetOrPost(String),
22+
Cookie(String),
23+
Header(String),
24+
Error(String),
25+
BasicAuthPassword,
26+
BasicAuthUsername,
27+
HashPassword(Box<StmtParam>),
28+
RandomString(usize),
29+
}
30+
31+
pub(super) fn func_call_to_param(func_name: &str, arguments: &mut [FunctionArg]) -> StmtParam {
32+
match func_name {
33+
"cookie" => extract_single_quoted_string("cookie", arguments)
34+
.map_or_else(StmtParam::Error, StmtParam::Cookie),
35+
"header" => extract_single_quoted_string("header", arguments)
36+
.map_or_else(StmtParam::Error, StmtParam::Header),
37+
"basic_auth_username" => StmtParam::BasicAuthUsername,
38+
"basic_auth_password" => StmtParam::BasicAuthPassword,
39+
"hash_password" => extract_variable_argument("hash_password", arguments)
40+
.map(Box::new)
41+
.map_or_else(StmtParam::Error, StmtParam::HashPassword),
42+
"random_string" => extract_integer("random_string", arguments)
43+
.map_or_else(StmtParam::Error, StmtParam::RandomString),
44+
unknown_name => StmtParam::Error(format!(
45+
"Unknown function {unknown_name}({})",
46+
FormatArguments(arguments)
47+
)),
48+
}
49+
}
50+
51+
pub(super) fn extract_req_param<'a>(
52+
param: &StmtParam,
53+
request: &'a RequestInfo,
54+
) -> anyhow::Result<Option<Cow<'a, str>>> {
55+
Ok(match param {
56+
StmtParam::Get(x) => request.get_variables.get(x).map(SingleOrVec::as_json_str),
57+
StmtParam::Post(x) => request.post_variables.get(x).map(SingleOrVec::as_json_str),
58+
StmtParam::GetOrPost(x) => request
59+
.post_variables
60+
.get(x)
61+
.or_else(|| request.get_variables.get(x))
62+
.map(SingleOrVec::as_json_str),
63+
StmtParam::Cookie(x) => request.cookies.get(x).map(SingleOrVec::as_json_str),
64+
StmtParam::Header(x) => request.headers.get(x).map(SingleOrVec::as_json_str),
65+
StmtParam::Error(x) => anyhow::bail!("{}", x),
66+
StmtParam::BasicAuthPassword => extract_basic_auth_password(request)
67+
.map(Cow::Borrowed)
68+
.map(Some)?,
69+
StmtParam::BasicAuthUsername => extract_basic_auth_username(request)
70+
.map(Cow::Borrowed)
71+
.map(Some)?,
72+
StmtParam::HashPassword(inner) => extract_req_param(inner, request)?
73+
.map_or(Ok(None), |x| hash_password(&x).map(Cow::Owned).map(Some))?,
74+
StmtParam::RandomString(len) => Some(Cow::Owned(random_string(*len))),
75+
})
76+
}
77+
78+
fn random_string(len: usize) -> String {
79+
use rand::{distributions::Alphanumeric, Rng};
80+
password_hash::rand_core::OsRng
81+
.sample_iter(&Alphanumeric)
82+
.take(len)
83+
.map(char::from)
84+
.collect()
85+
}
86+
87+
fn hash_password(password: &str) -> anyhow::Result<String> {
88+
let phf = argon2::Argon2::default();
89+
let salt = password_hash::SaltString::generate(&mut password_hash::rand_core::OsRng);
90+
let password_hash = &password_hash::PasswordHash::generate(phf, password, &salt)
91+
.map_err(|e| anyhow!("Unable to hash password: {}", e))?;
92+
Ok(password_hash.to_string())
93+
}
94+
95+
fn extract_basic_auth_username(request: &RequestInfo) -> anyhow::Result<&str> {
96+
Ok(extract_basic_auth(request)?.user_id())
97+
}
98+
99+
fn extract_basic_auth_password(request: &RequestInfo) -> anyhow::Result<&str> {
100+
let password = extract_basic_auth(request)?.password().ok_or_else(|| {
101+
anyhow::Error::new(ErrorWithStatus {
102+
status: StatusCode::UNAUTHORIZED,
103+
})
104+
})?;
105+
Ok(password)
106+
}
107+
108+
fn extract_basic_auth(request: &RequestInfo) -> anyhow::Result<&Basic> {
109+
request
110+
.basic_auth
111+
.as_ref()
112+
.ok_or_else(|| {
113+
anyhow::Error::new(ErrorWithStatus {
114+
status: StatusCode::UNAUTHORIZED,
115+
})
116+
})
117+
.with_context(|| "Expected the user to be authenticated with HTTP basic auth")
118+
}

src/webserver/error_with_status.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
use actix_web::http::StatusCode;
2+
3+
#[derive(Debug)]
4+
pub struct ErrorWithStatus {
5+
pub status: StatusCode,
6+
}
7+
impl std::fmt::Display for ErrorWithStatus {
8+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
9+
write!(f, "{}", self.status)
10+
}
11+
}
12+
impl std::error::Error for ErrorWithStatus {}

src/webserver/http.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::render::{HeaderContext, PageContext, RenderContext};
2-
use crate::webserver::database::{stream_query_results, DbItem, ErrorWithStatus};
2+
use crate::webserver::database::{stream_query_results, DbItem};
3+
use crate::webserver::ErrorWithStatus;
34
use crate::{AppState, Config, ParsedSqlFile};
45
use actix_web::dev::{fn_service, ServiceFactory, ServiceRequest};
56
use actix_web::error::{ErrorInternalServerError, ErrorNotFound};

src/webserver/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
pub mod database;
2+
pub mod error_with_status;
23
pub mod http;
34

45
pub use database::Database;
6+
pub use error_with_status::ErrorWithStatus;
57

68
pub use database::apply_migrations;
79
pub use database::make_placeholder;

0 commit comments

Comments
 (0)