diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 92d76c9..d7b0c02 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -6,8 +6,8 @@ use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient; use std::sync::Arc; use tokio::net::TcpListener; use tokio::net::TcpStream; -use tokio_rustls::rustls::{self, Certificate, OwnedTrustAnchor, RootCertStore}; -use tokio_rustls::{webpki, TlsAcceptor, TlsConnector}; +use tokio_rustls::rustls::{self, RootCertStore}; +use tokio_rustls::{TlsAcceptor, TlsConnector}; use tarpc::context::Context; use tarpc::serde_transport as transport; @@ -32,7 +32,7 @@ impl PingService for Service { // certs were generated with openssl 3 https://github.com/rustls/rustls/tree/main/test-ca // used on client-side for server tls -const END_CHAIN: &[u8] = include_bytes!("certs/eddsa/end.chain"); +const END_CHAIN: &str = include_str!("certs/eddsa/end.chain"); // used on client-side for client-auth const CLIENT_PRIVATEKEY_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.key"); const CLIENT_CERT_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.cert"); @@ -43,6 +43,14 @@ const END_PRIVATEKEY: &str = include_str!("certs/eddsa/end.key"); // used on server-side for client-auth const CLIENT_CHAIN_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.chain"); +pub fn load_certs(data: &str) -> Vec { + certs(&mut BufReader::new(Cursor::new(data))) + .unwrap() + .into_iter() + .map(rustls::Certificate) + .collect() +} + pub fn load_private_key(key: &str) -> rustls::PrivateKey { let mut reader = BufReader::new(Cursor::new(key)); loop { @@ -62,22 +70,13 @@ async fn main() -> anyhow::Result<()> { // -------------------- start here to setup tls tcp tokio stream -------------------------- // ref certs and loading from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/tests/test.rs // ref basic tls server setup from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/server/src/main.rs - let cert = certs(&mut BufReader::new(Cursor::new(END_CERT))) - .unwrap() - .into_iter() - .map(rustls::Certificate) - .collect(); + let cert = load_certs(END_CERT); let key = load_private_key(END_PRIVATEKEY); let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), 5000); // ------------- server side client_auth cert loading start - let roots: Vec = certs(&mut BufReader::new(Cursor::new(CLIENT_CHAIN_CLIENT_AUTH))) - .unwrap() - .into_iter() - .map(rustls::Certificate) - .collect(); let mut client_auth_roots = RootCertStore::empty(); - for root in roots { + for root in load_certs(CLIENT_CHAIN_CLIENT_AUTH) { client_auth_roots.add(&root).unwrap(); } let client_auth = AllowAnyAuthenticatedClient::new(client_auth_roots); @@ -96,7 +95,6 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(async move { loop { let (stream, _peer_addr) = listener.accept().await.unwrap(); - let acceptor = acceptor.clone(); let tls_stream = acceptor.accept(stream).await.unwrap(); let framed = codec_builder.new_framed(tls_stream); @@ -108,26 +106,14 @@ async fn main() -> anyhow::Result<()> { }); // ---------------------- client connection --------------------- - // cert loading from: https://github.com/tokio-rs/tls/blob/357bc562483dcf04c1f8d08bd1a831b144bf7d4c/tokio-rustls/tests/test.rs#L113 // tls client connection from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs - let chain = certs(&mut std::io::Cursor::new(END_CHAIN)).unwrap(); let mut root_store = rustls::RootCertStore::empty(); - root_store.add_server_trust_anchors(chain.iter().map(|cert| { - let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); + for root in load_certs(END_CHAIN) { + root_store.add(&root).unwrap(); + } let client_auth_private_key = load_private_key(CLIENT_PRIVATEKEY_CLIENT_AUTH); - let client_auth_certs: Vec = - certs(&mut BufReader::new(Cursor::new(CLIENT_CERT_CLIENT_AUTH))) - .unwrap() - .into_iter() - .map(rustls::Certificate) - .collect(); + let client_auth_certs = load_certs(CLIENT_CERT_CLIENT_AUTH); let config = rustls::ClientConfig::builder() .with_safe_defaults()