diff --git a/Cargo.lock b/Cargo.lock index d6e0fffa..ee2dcb75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -326,6 +326,12 @@ version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "arrayvec" version = "0.7.6" @@ -449,7 +455,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper", "tokio", - "tower", + "tower 0.5.2", "tower-layer", "tower-service", "tracing", @@ -475,6 +481,30 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-server" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56bac90848f6a9393ac03c63c640925c4b7c8ca21654de40d53f55964667c7d8" +dependencies = [ + "arc-swap", + "bytes", + "futures-util", + "http 1.2.0", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "pin-project-lite", + "rustls", + "rustls-pemfile", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower 0.4.13", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.74" @@ -2366,6 +2396,7 @@ dependencies = [ "average", "aws-sign-v4", "axum", + "axum-server", "base64", "byte-unit", "bytes", @@ -2518,6 +2549,26 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pin-project" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfe2e71e1471fe07709406bf725f710b02927c9c54b2b5b2ec0e8087d97c327d" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6e859e6e5bd50440ab63c47e3ebabc90f26251f7c73c3d3e837b74a1cc3fa67" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "pin-project-lite" version = "0.2.16" @@ -2754,6 +2805,7 @@ version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" dependencies = [ + "aws-lc-rs", "pem", "ring", "rustls-pki-types", @@ -2890,7 +2942,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper", "tokio", - "tower", + "tower 0.5.2", "tower-service", "url", "wasm-bindgen", @@ -3053,6 +3105,15 @@ dependencies = [ "security-framework 3.2.0", ] +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "rustls-pki-types" version = "1.11.0" @@ -3664,6 +3725,21 @@ dependencies = [ "winnow", ] +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower" version = "0.5.2" diff --git a/Cargo.toml b/Cargo.toml index ef2ef85d..d20fc237 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -83,17 +83,21 @@ rlimit = "0.10.1" [dev-dependencies] assert_cmd = "2.0.14" axum = { version = "0.8.1", features = ["http2"] } +axum-server = { version = "0.7.1", features = ["tls-rustls"] } bytes = "1.6" float-cmp = "0.10.0" http-mitm-proxy = "0.12.0" jsonschema = "0.28.1" lazy_static = "1.5.0" predicates = "3.1.0" -rcgen = "0.13.1" +# features = ["aws_lc_rs"] is a workaround for mac & native-tls +# https://github.com/sfackler/rust-native-tls/issues/225 +rcgen = { version = "0.13.1", features = ["aws_lc_rs"] } regex = "1.10.5" +tempfile = "3.10.1" +rustls = "0.23.18" [target.'cfg(unix)'.dev-dependencies] -tempfile = "3.10.1" actix-web = "4" [profile.pgo] diff --git a/README.md b/README.md index b6dfaf89..2d3ee85d 100644 --- a/README.md +++ b/README.md @@ -180,6 +180,12 @@ Options: Lookup only ipv6. --ipv4 Lookup only ipv4. + --cacert + (TLS) Use the specified certificate file to verify the peer. Native certificate store is used even if this argument is specified. + --cert + (TLS) Use the specified client certificate file. --key must be also specified + --key + (TLS) Use the specified client key file. --cert must be also specified --insecure Accept invalid certs. --connect-to diff --git a/src/db.rs b/src/db.rs index f5132c8c..26131774 100644 --- a/src/db.rs +++ b/src/db.rs @@ -93,9 +93,9 @@ mod test_db { #[cfg(feature = "vsock")] vsock_addr: None, #[cfg(feature = "rustls")] - rustls_configs: crate::tls_config::RuslsConfigs::new(false), + rustls_configs: crate::tls_config::RuslsConfigs::new(false, None, None), #[cfg(all(feature = "native-tls", not(feature = "rustls")))] - native_tls_connectors: crate::tls_config::NativeTlsConnectors::new(false), + native_tls_connectors: crate::tls_config::NativeTlsConnectors::new(false, None, None), }; let result = store(&client, ":memory:", start, &test_vec); assert_eq!(result.unwrap(), 2); diff --git a/src/main.rs b/src/main.rs index c8aead42..c4eb6d47 100644 --- a/src/main.rs +++ b/src/main.rs @@ -210,6 +210,21 @@ Note: If qps is specified, burst will be ignored", ipv6: bool, #[arg(help = "Lookup only ipv4.", long = "ipv4")] ipv4: bool, + #[arg( + help = "(TLS) Use the specified certificate file to verify the peer. Native certificate store is used even if this argument is specified.", + long + )] + cacert: Option, + #[arg( + help = "(TLS) Use the specified client certificate file. --key must be also specified", + long + )] + cert: Option, + #[arg( + help = "(TLS) Use the specified client key file. --cert must be also specified", + long + )] + key: Option, #[arg(help = "Accept invalid certs.", long = "insecure")] insecure: bool, #[arg( @@ -520,6 +535,13 @@ async fn run() -> anyhow::Result<()> { let (config, mut resolver_opts) = system_resolv_conf()?; resolver_opts.ip_strategy = ip_strategy; let resolver = hickory_resolver::AsyncResolver::tokio(config, resolver_opts); + let cacert = opts.cacert.as_deref().map(std::fs::read).transpose()?; + let client_auth = match (opts.cert, opts.key) { + (Some(cert), Some(key)) => Some((std::fs::read(cert)?, std::fs::read(key)?)), + (None, None) => None, + // TODO: Ensure it on clap + _ => anyhow::bail!("Both --cert and --key must be specified"), + }; let client = Arc::new(client::Client { aws_config, @@ -542,9 +564,21 @@ async fn run() -> anyhow::Result<()> { #[cfg(feature = "vsock")] vsock_addr: opts.vsock_addr.map(|v| v.0), #[cfg(feature = "rustls")] - rustls_configs: tls_config::RuslsConfigs::new(opts.insecure), + rustls_configs: tls_config::RuslsConfigs::new( + opts.insecure, + cacert.as_deref(), + client_auth + .as_ref() + .map(|(cert, key)| (cert.as_slice(), key.as_slice())), + ), #[cfg(all(feature = "native-tls", not(feature = "rustls")))] - native_tls_connectors: tls_config::NativeTlsConnectors::new(opts.insecure), + native_tls_connectors: tls_config::NativeTlsConnectors::new( + opts.insecure, + cacert.as_deref(), + client_auth + .as_ref() + .map(|(cert, key)| (cert.as_slice(), key.as_slice())), + ), }); if !opts.no_pre_lookup { @@ -595,10 +629,8 @@ async fn run() -> anyhow::Result<()> { match work_mode { WorkMode::Debug => { let mut print_config = print_config; - if let Err(e) = client::work_debug(&mut print_config.output, client).await { - eprintln!("{e}"); - } - std::process::exit(libc::EXIT_SUCCESS) + client::work_debug(&mut print_config.output, client).await?; + return Ok(()); } WorkMode::FixedNumber { n_requests, diff --git a/src/tls_config.rs b/src/tls_config.rs index 3dab1e5b..5eb4fc75 100644 --- a/src/tls_config.rs +++ b/src/tls_config.rs @@ -6,7 +6,12 @@ pub struct RuslsConfigs { #[cfg(feature = "rustls")] impl RuslsConfigs { - pub fn new(insecure: bool) -> Self { + pub fn new( + insecure: bool, + cacert_pem: Option<&[u8]>, + client_auth: Option<(&[u8], &[u8])>, + ) -> Self { + use rustls_pki_types::pem::PemObject; use std::sync::Arc; let mut root_cert_store = rustls::RootCertStore::empty(); @@ -14,9 +19,25 @@ impl RuslsConfigs { { root_cert_store.add(cert).unwrap(); } - let mut config = rustls::ClientConfig::builder() - .with_root_certificates(root_cert_store.clone()) - .with_no_client_auth(); + + if let Some(cacert_pem) = cacert_pem { + for der in rustls_pki_types::CertificateDer::pem_slice_iter(cacert_pem) { + root_cert_store.add(der.unwrap()).unwrap(); + } + } + + let builder = rustls::ClientConfig::builder().with_root_certificates(root_cert_store); + + let mut config = if let Some((cert, key)) = client_auth { + let certs = rustls_pki_types::CertificateDer::pem_slice_iter(cert) + .collect::, _>>() + .unwrap(); + let key = rustls_pki_types::PrivateKeyDer::from_pem_slice(key).unwrap(); + + builder.with_client_auth_cert(certs, key).unwrap() + } else { + builder.with_no_client_auth() + }; if insecure { config .dangerous() @@ -50,15 +71,32 @@ pub struct NativeTlsConnectors { #[cfg(all(feature = "native-tls", not(feature = "rustls")))] impl NativeTlsConnectors { - pub fn new(insecure: bool) -> Self { + pub fn new( + insecure: bool, + cacert_pem: Option<&[u8]>, + client_auth: Option<(&[u8], &[u8])>, + ) -> Self { let new = |is_http2: bool| { let mut connector_builder = native_tls::TlsConnector::builder(); + + if let Some(cacert_pem) = cacert_pem { + let cert = native_tls::Certificate::from_pem(cacert_pem) + .expect("Failed to parse cacert_pem"); + connector_builder.add_root_certificate(cert); + } + if insecure { connector_builder .danger_accept_invalid_certs(true) .danger_accept_invalid_hostnames(true); } + if let Some((cert, key)) = client_auth { + let cert = native_tls::Identity::from_pkcs8(cert, key) + .expect("Failed to parse client_auth cert/key"); + connector_builder.identity(cert); + } + if is_http2 { connector_builder.request_alpns(&["h2"]); } diff --git a/tests/tests.rs b/tests/tests.rs index d6b7ad08..15a05264 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,7 +1,9 @@ use std::{ convert::Infallible, error::Error as StdError, + fs::File, future::Future, + io::Write, net::{Ipv6Addr, SocketAddr}, sync::{atomic::AtomicU16, Arc}, }; @@ -903,3 +905,93 @@ async fn test_json_schema() { panic!("JSON schema validation failed\n{output_json_stats_success_breakdown}"); } } + +fn setup_mtls_server( + dir: &std::path::Path, +) -> (u16, impl Future>) { + let port = PORT.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let addr = SocketAddr::from(([127, 0, 0, 1], port)); + + // build our application with a route + let app = Router::new() + // `GET /` goes to `root` + .route("/", get(|| async { "Hello, World" })); + + let make_cert = || { + // Workaround for mac & native-tls + // https://github.com/sfackler/rust-native-tls/issues/225 + let key_pair = rcgen::KeyPair::generate_for(&rcgen::PKCS_RSA_SHA256).unwrap(); + let cert = rcgen::CertificateParams::new(vec!["localhost".to_string()]) + .unwrap() + .self_signed(&key_pair) + .unwrap(); + rcgen::CertifiedKey { cert, key_pair } + }; + + let server_cert = make_cert(); + let client_cert = make_cert(); + + let mut roots = rustls::RootCertStore::empty(); + roots.add(client_cert.cert.der().clone()).unwrap(); + let verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(roots)) + .build() + .unwrap(); + + let config = rustls::ServerConfig::builder() + .with_client_cert_verifier(verifier) + .with_single_cert( + vec![server_cert.cert.der().clone()], + rustls::pki_types::PrivateKeyDer::Pkcs8(rustls::pki_types::PrivatePkcs8KeyDer::from( + server_cert.key_pair.serialize_der(), + )), + ) + .unwrap(); + + let config = axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(config)); + + File::create(dir.join("server.crt")) + .unwrap() + .write_all(server_cert.cert.pem().as_bytes()) + .unwrap(); + + File::create(dir.join("client.crt")) + .unwrap() + .write_all(client_cert.cert.pem().as_bytes()) + .unwrap(); + + File::create(dir.join("client.key")) + .unwrap() + .write_all(client_cert.key_pair.serialize_pem().as_bytes()) + .unwrap(); + + ( + port, + axum_server::bind_rustls(addr, config).serve(app.into_make_service()), + ) +} + +#[tokio::test] +async fn test_mtls() { + let dir = tempfile::tempdir().unwrap(); + let (port, server) = setup_mtls_server(dir.path()); + + tokio::spawn(server); + + let mut command = Command::cargo_bin("oha").unwrap(); + command + .args([ + "--debug", + "--cacert", + dir.path().join("server.crt").to_str().unwrap(), + "--cert", + dir.path().join("client.crt").to_str().unwrap(), + "--key", + dir.path().join("client.key").to_str().unwrap(), + ]) + .arg(format!("https://localhost:{port}/")); + tokio::task::spawn_blocking(move || { + command.assert().success(); + }) + .await + .unwrap(); +}