@@ -3,6 +3,11 @@ use crate::Error;
3
3
use std:: convert:: TryFrom ;
4
4
use std:: path:: PathBuf ;
5
5
use std:: str:: FromStr ;
6
+ #[ cfg( any(
7
+ feature = "postgres" ,
8
+ feature = "tokio-postgres" ,
9
+ feature = "tiberius-config"
10
+ ) ) ]
6
11
use std:: { borrow:: Cow , collections:: HashMap } ;
7
12
use url:: Url ;
8
13
@@ -35,7 +40,8 @@ impl Config {
35
40
db_user : None ,
36
41
db_pass : None ,
37
42
db_name : None ,
38
- use_tls : None ,
43
+ #[ cfg( any( feature = "postgres" , feature = "tokio-postgres" ) ) ]
44
+ use_tls : false ,
39
45
#[ cfg( feature = "tiberius-config" ) ]
40
46
trust_cert : false ,
41
47
} ,
@@ -141,7 +147,8 @@ impl Config {
141
147
self . main . db_port . as_deref ( )
142
148
}
143
149
144
- pub fn use_tls ( & self ) -> Option < bool > {
150
+ #[ cfg( any( feature = "postgres" , feature = "tokio-postgres" ) ) ]
151
+ pub fn use_tls ( & self ) -> bool {
145
152
self . main . use_tls
146
153
}
147
154
@@ -189,6 +196,16 @@ impl Config {
189
196
} ,
190
197
}
191
198
}
199
+
200
+ #[ cfg( any( feature = "postgres" , feature = "tokio-postgres" ) ) ]
201
+ pub fn set_use_tls ( self , use_tls : bool ) -> Config {
202
+ Config {
203
+ main : Main {
204
+ use_tls,
205
+ ..self . main
206
+ } ,
207
+ }
208
+ }
192
209
}
193
210
194
211
impl TryFrom < Url > for Config {
@@ -209,6 +226,11 @@ impl TryFrom<Url> for Config {
209
226
}
210
227
} ;
211
228
229
+ #[ cfg( any(
230
+ feature = "postgres" ,
231
+ feature = "tokio-postgres" ,
232
+ feature = "tiberius-config"
233
+ ) ) ]
212
234
let query_params = url
213
235
. query_pairs ( )
214
236
. collect :: < HashMap < Cow < ' _ , str > , Cow < ' _ , str > > > ( ) ;
@@ -228,12 +250,10 @@ impl TryFrom<Url> for Config {
228
250
}
229
251
}
230
252
231
- let use_tls = match query_params
232
- . get ( "sslmode" )
233
- . unwrap_or ( & Cow :: Borrowed ( "disable" ) )
234
- {
235
- & Cow :: Borrowed ( "disable" ) => false ,
236
- & Cow :: Borrowed ( "require" ) => true ,
253
+ #[ cfg( any( feature = "postgres" , feature = "tokio-postgres" ) ) ]
254
+ let use_tls = match query_params. get ( "sslmode" ) {
255
+ Some ( & Cow :: Borrowed ( "require" ) ) => true ,
256
+ Some ( & Cow :: Borrowed ( "disable" ) ) | None => false ,
237
257
_ => {
238
258
return Err ( Error :: new (
239
259
Kind :: ConfigError ( "Invalid sslmode value, please use disable/require" . into ( ) ) ,
@@ -257,7 +277,8 @@ impl TryFrom<Url> for Config {
257
277
db_user : Some ( url. username ( ) . to_string ( ) ) ,
258
278
db_pass : url. password ( ) . map ( |r| r. to_string ( ) ) ,
259
279
db_name : Some ( url. path ( ) . trim_start_matches ( '/' ) . to_string ( ) ) ,
260
- use_tls : Some ( use_tls) ,
280
+ #[ cfg( any( feature = "postgres" , feature = "tokio-postgres" ) ) ]
281
+ use_tls,
261
282
#[ cfg( feature = "tiberius-config" ) ]
262
283
trust_cert,
263
284
} ,
@@ -290,7 +311,9 @@ struct Main {
290
311
db_user : Option < String > ,
291
312
db_pass : Option < String > ,
292
313
db_name : Option < String > ,
293
- use_tls : Option < bool > ,
314
+ #[ cfg( any( feature = "postgres" , feature = "tokio-postgres" ) ) ]
315
+ #[ serde( default ) ]
316
+ use_tls : bool ,
294
317
#[ cfg( feature = "tiberius-config" ) ]
295
318
#[ serde( default ) ]
296
319
trust_cert : bool ,
@@ -474,18 +497,40 @@ mod tests {
474
497
) ;
475
498
}
476
499
500
+ #[ cfg( any( feature = "postgres" , feature = "tokio-postgres" ) ) ]
477
501
#[ test]
478
502
fn builds_from_sslmode_str ( ) {
479
- let config =
503
+ use crate :: config:: ConfigDbType ;
504
+
505
+ let config_disable =
480
506
Config :: from_str ( "postgres://root:1234@localhost:5432/refinery?sslmode=disable" )
481
507
. unwrap ( ) ;
482
- assert ! ( config . use_tls( ) . is_some ( ) ) ;
483
- assert ! ( !config . use_tls ( ) . unwrap ( ) ) ;
484
- let config =
508
+ assert ! ( !config_disable . use_tls( ) ) ;
509
+
510
+ let config_require =
485
511
Config :: from_str ( "postgres://root:1234@localhost:5432/refinery?sslmode=require" )
486
512
. unwrap ( ) ;
487
- assert ! ( config. use_tls( ) . is_some( ) ) ;
488
- assert ! ( config. use_tls( ) . unwrap( ) ) ;
513
+ assert ! ( config_require. use_tls( ) ) ;
514
+
515
+ // Verify that manually created config matches parsed URL config
516
+ let manual_config_disable = Config :: new ( ConfigDbType :: Postgres )
517
+ . set_db_user ( "root" )
518
+ . set_db_pass ( "1234" )
519
+ . set_db_host ( "localhost" )
520
+ . set_db_port ( "5432" )
521
+ . set_db_name ( "refinery" )
522
+ . set_use_tls ( false ) ;
523
+ assert_eq ! ( config_disable. use_tls( ) , manual_config_disable. use_tls( ) ) ;
524
+
525
+ let manual_config_require = Config :: new ( ConfigDbType :: Postgres )
526
+ . set_db_user ( "root" )
527
+ . set_db_pass ( "1234" )
528
+ . set_db_host ( "localhost" )
529
+ . set_db_port ( "5432" )
530
+ . set_db_name ( "refinery" )
531
+ . set_use_tls ( true ) ;
532
+ assert_eq ! ( config_require. use_tls( ) , manual_config_require. use_tls( ) ) ;
533
+
489
534
let config =
490
535
Config :: from_str ( "postgres://root:1234@localhost:5432/refinery?sslmode=invalidvalue" ) ;
491
536
assert ! ( config. is_err( ) ) ;
0 commit comments