From 060eb92bf7d2f7f90ae57db136910d0448dfc57e Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Mon, 11 Nov 2019 14:04:27 -0800 Subject: [PATCH] Refactor JWT handling in ServiceAccountAccess. Avoid reading and parsing the private key file on every invocation of token() in favor or reading it once when the ServiceAccountAccess is built. Also avoid unnecessary allocations when signing JWT tokens and renamed sub to subject to avoid any confusion with the std::ops::Sub trait. --- examples/test-svc-acct/src/main.rs | 4 +- src/service_account.rs | 195 ++++++++++++++--------------- 2 files changed, 94 insertions(+), 105 deletions(-) diff --git a/examples/test-svc-acct/src/main.rs b/examples/test-svc-acct/src/main.rs index 4945adc..e5ef33c 100644 --- a/examples/test-svc-acct/src/main.rs +++ b/examples/test-svc-acct/src/main.rs @@ -7,7 +7,9 @@ use yup_oauth2::GetToken; async fn main() { let creds = yup_oauth2::service_account_key_from_file(path::Path::new("serviceaccount.json")).unwrap(); - let sa = yup_oauth2::ServiceAccountAccess::new(creds).build(); + let sa = yup_oauth2::ServiceAccountAccess::new(creds) + .build() + .unwrap(); let scopes = &["https://www.googleapis.com/auth/pubsub"]; let tok = sa.token(scopes).await.unwrap(); diff --git a/src/service_account.rs b/src/service_account.rs index 2c4cb3f..03525a5 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -37,11 +37,11 @@ use hyper; use serde_json; const GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:jwt-bearer"; -const GOOGLE_RS256_HEAD: &str = "{\"alg\":\"RS256\",\"typ\":\"JWT\"}"; +const GOOGLE_RS256_HEAD: &str = r#"{"alg":"RS256","typ":"JWT"}"#; /// Encodes s as Base64 -fn encode_base64>(s: T) -> String { - base64::encode_config(s.as_ref(), base64::URL_SAFE) +fn append_base64 + ?Sized>(s: &T, out: &mut String) { + base64::encode_config_buf(s, base64::URL_SAFE, out) } /// Decode a PKCS8 formatted RSA key. @@ -78,11 +78,11 @@ pub struct ServiceAccountKey { pub key_type: Option, pub project_id: Option, pub private_key_id: Option, - pub private_key: Option, - pub client_email: Option, + pub private_key: String, + pub client_email: String, pub client_id: Option, pub auth_uri: Option, - pub token_uri: Option, + pub token_uri: String, pub auth_provider_x509_cert_url: Option, pub client_x509_cert_url: Option, } @@ -90,52 +90,42 @@ pub struct ServiceAccountKey { /// Permissions requested for a JWT. /// See https://developers.google.com/identity/protocols/OAuth2ServiceAccount#authorizingrequests. #[derive(Serialize, Debug)] -struct Claims { - iss: String, - aud: String, +struct Claims<'a> { + iss: &'a str, + aud: &'a str, exp: i64, iat: i64, - sub: Option, + subject: Option<&'a str>, scope: String, } -/// A JSON Web Token ready for signing. -struct JWT { - /// The value of GOOGLE_RS256_HEAD. - header: String, - /// A Claims struct, expressing the set of desired permissions etc. - claims: Claims, -} +impl<'a> Claims<'a> { + fn new(key: &'a ServiceAccountKey, scopes: &[T], subject: Option<&'a str>) -> Self + where + T: AsRef, + { + let iat = chrono::Utc::now().timestamp(); + let expiry = iat + 3600 - 5; // Max validity is 1h. -impl JWT { - /// Create a new JWT from claims. - fn new(claims: Claims) -> JWT { - JWT { - header: GOOGLE_RS256_HEAD.to_string(), - claims, + let scope = crate::helper::join(scopes, " "); + Claims { + iss: &key.client_email, + aud: &key.token_uri, + exp: expiry, + iat, + subject, + scope, } } +} - /// Set JWT header. Default is `{"alg":"RS256","typ":"JWT"}`. - #[allow(dead_code)] - pub fn set_header(&mut self, head: String) { - self.header = head; - } +/// A JSON Web Token ready for signing. +struct JWTSigner { + signer: Box, +} - /// Encodes the first two parts (header and claims) to base64 and assembles them into a form - /// ready to be signed. - fn encode_claims(&self) -> String { - let mut head = encode_base64(&self.header); - let claims = encode_base64(serde_json::to_string(&self.claims).unwrap()); - - head.push_str("."); - head.push_str(&claims); - head - } - - /// Sign a JWT base string with `private_key`, which is a PKCS8 string. - fn sign(&self, private_key: &str) -> Result { - let mut jwt_head = self.encode_claims(); +impl JWTSigner { + fn new(private_key: &str) -> Result { let key = decode_rsa_key(private_key)?; let signing_key = sign::RSASigningKey::new(&key) .map_err(|_| io::Error::new(io::ErrorKind::Other, "Couldn't initialize signer"))?; @@ -144,40 +134,25 @@ impl JWT { .ok_or_else(|| { io::Error::new(io::ErrorKind::Other, "Couldn't choose signing scheme") })?; - let signature = signer - .sign(jwt_head.as_bytes()) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("{}", e)))?; - let signature_b64 = encode_base64(signature); + Ok(JWTSigner { signer }) + } + fn sign_claims(&self, claims: &Claims) -> Result { + let mut jwt_head = Self::encode_claims(claims); + let signature = self.signer.sign(jwt_head.as_bytes())?; jwt_head.push_str("."); - jwt_head.push_str(&signature_b64); - + append_base64(&signature, &mut jwt_head); Ok(jwt_head) } -} -/// Set `iss`, `aud`, `exp`, `iat`, `scope` field in the returned `Claims`. -fn init_claims_from_key(key: &ServiceAccountKey, scopes: &[T]) -> Claims -where - T: AsRef, -{ - let iat = chrono::Utc::now().timestamp(); - let expiry = iat + 3600 - 5; // Max validity is 1h. - - let mut scopes_string = scopes.iter().fold(String::new(), |mut acc, sc| { - acc.push_str(sc.as_ref()); - acc.push_str(" "); - acc - }); - scopes_string.pop(); - - Claims { - iss: key.client_email.clone().unwrap(), - aud: key.token_uri.clone().unwrap(), - exp: expiry, - iat, - sub: None, - scope: scopes_string, + /// Encodes the first two parts (header and claims) to base64 and assembles them into a form + /// ready to be signed. + fn encode_claims(claims: &Claims) -> String { + let mut head = String::new(); + append_base64(GOOGLE_RS256_HEAD, &mut head); + head.push_str("."); + append_base64(&serde_json::to_string(&claims).unwrap(), &mut head); + head } } @@ -188,7 +163,7 @@ where pub struct ServiceAccountAccess { client: C, key: ServiceAccountKey, - sub: Option, + subject: Option, } impl ServiceAccountAccess { @@ -197,7 +172,7 @@ impl ServiceAccountAccess { ServiceAccountAccess { client: DefaultHyperClient, key, - sub: None, + subject: None, } } } @@ -214,21 +189,21 @@ where ServiceAccountAccess { client: hyper_client, key: self.key, - sub: self.sub, + subject: self.subject, } } - /// Use the provided sub. - pub fn sub(self, sub: String) -> Self { + /// Use the provided subject. + pub fn subject(self, subject: String) -> Self { ServiceAccountAccess { - sub: Some(sub), + subject: Some(subject), ..self } } /// Build the configured ServiceAccountAccess. - pub fn build(self) -> impl GetToken { - ServiceAccountAccessImpl::new(self.client.build_hyper_client(), self.key, self.sub) + pub fn build(self) -> Result { + ServiceAccountAccessImpl::new(self.client.build_hyper_client(), self.key, self.subject) } } @@ -236,20 +211,27 @@ struct ServiceAccountAccessImpl { client: hyper::Client, key: ServiceAccountKey, cache: Arc>, - sub: Option, + subject: Option, + signer: JWTSigner, } impl ServiceAccountAccessImpl where C: hyper::client::connect::Connect, { - fn new(client: hyper::Client, key: ServiceAccountKey, sub: Option) -> Self { - ServiceAccountAccessImpl { + fn new( + client: hyper::Client, + key: ServiceAccountKey, + subject: Option, + ) -> Result { + let signer = JWTSigner::new(&key.private_key)?; + Ok(ServiceAccountAccessImpl { client, key, cache: Arc::new(Mutex::new(MemoryStorage::default())), - sub, - } + subject, + signer, + }) } } @@ -268,25 +250,25 @@ where /// Send a request for a new Bearer token to the OAuth provider. async fn request_token( client: &hyper::client::Client, - sub: Option<&str>, + signer: &JWTSigner, + subject: Option<&str>, key: &ServiceAccountKey, scopes: &[T], ) -> Result where T: AsRef, { - let mut claims = init_claims_from_key(&key, scopes); - claims.sub = sub.map(|x| x.to_owned()); - let signed = JWT::new(claims) - .sign(key.private_key.as_ref().unwrap()) - .map_err(RequestError::LowLevelError)?; + let claims = Claims::new(key, scopes, subject); + let signed = signer.sign_claims(&claims).map_err(|_| { + RequestError::LowLevelError(io::Error::new( + io::ErrorKind::Other, + "unable to sign claims", + )) + })?; let rqbody = form_urlencoded::Serializer::new(String::new()) - .extend_pairs(vec![ - ("grant_type".to_string(), GRANT_TYPE.to_string()), - ("assertion".to_string(), signed), - ]) + .extend_pairs(&[("grant_type", GRANT_TYPE), ("assertion", signed.as_str())]) .finish(); - let request = hyper::Request::post(key.token_uri.as_ref().unwrap()) + let request = hyper::Request::post(&key.token_uri) .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(rqbody)) .unwrap(); @@ -335,7 +317,8 @@ where } let token = Self::request_token( &self.client, - self.sub.as_ref().map(|x| x.as_str()), + &self.signer, + self.subject.as_ref().map(|x| x.as_str()), &self.key, scopes, ) @@ -399,7 +382,7 @@ mod tests { "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/yup-test-sa-1%40yup-test-243420.iam.gserviceaccount.com" }"#; let mut key: ServiceAccountKey = serde_json::from_str(client_secret).unwrap(); - key.token_uri = Some(format!("{}/token", server_url)); + key.token_uri = format!("{}/token", server_url); let json_response = r#"{ "access_token": "ya29.c.ElouBywiys0LyNaZoLPJcp1Fdi2KjFMxzvYKLXkTdvM-rDfqKlvEq6PiMhGoGHx97t5FAvz3eb_ahdwlBjSStxHtDVQB4ZPRJQ_EOi-iS7PnayahU2S9Jp8S6rk", @@ -429,7 +412,7 @@ mod tests { .with_body(json_response) .expect(1) .create(); - let acc = ServiceAccountAccessImpl::new(client.clone(), key.clone(), None); + let acc = ServiceAccountAccessImpl::new(client.clone(), key.clone(), None).unwrap(); let fut = async { let tok = acc .token(&["https://www.googleapis.com/auth/pubsub"]) @@ -472,7 +455,8 @@ mod tests { .create(); let acc = ServiceAccountAccess::new(key.clone()) .hyper_client(client.clone()) - .build(); + .build() + .unwrap(); let fut = async { let result = acc.token(&["https://www.googleapis.com/auth/pubsub"]).await; assert!(result.is_err()); @@ -494,7 +478,10 @@ mod tests { let key = service_account_key_from_file(&TEST_PRIVATE_KEY_PATH.to_string()).unwrap(); let https = HttpsConnector::new(); let client = hyper::Client::builder().build(https); - let acc = ServiceAccountAccess::new(key).hyper_client(client).build(); + let acc = ServiceAccountAccess::new(key) + .hyper_client(client) + .build() + .unwrap(); let rt = tokio::runtime::Builder::new() .core_threads(1) .panic_handler(|e| std::panic::resume_unwind(e)) @@ -512,7 +499,7 @@ mod tests { fn test_jwt_initialize_claims() { let key = service_account_key_from_file(TEST_PRIVATE_KEY_PATH).unwrap(); let scopes = vec!["scope1", "scope2", "scope3"]; - let claims = super::init_claims_from_key(&key, &scopes); + let claims = Claims::new(&key, &scopes, None); assert_eq!( claims.iss, @@ -532,9 +519,9 @@ mod tests { fn test_jwt_sign() { let key = service_account_key_from_file(TEST_PRIVATE_KEY_PATH).unwrap(); let scopes = vec!["scope1", "scope2", "scope3"]; - let claims = super::init_claims_from_key(&key, &scopes); - let jwt = super::JWT::new(claims); - let signature = jwt.sign(key.private_key.as_ref().unwrap()); + let signer = JWTSigner::new(&key.private_key).unwrap(); + let claims = Claims::new(&key, &scopes, None); + let signature = signer.sign_claims(&claims); assert!(signature.is_ok());