diff --git a/examples/test-svc-acct/src/main.rs b/examples/test-svc-acct/src/main.rs index 4945adc..e5ef33c 100644 --- a/examples/test-svc-acct/src/main.rs +++ b/examples/test-svc-acct/src/main.rs @@ -7,7 +7,9 @@ use yup_oauth2::GetToken; async fn main() { let creds = yup_oauth2::service_account_key_from_file(path::Path::new("serviceaccount.json")).unwrap(); - let sa = yup_oauth2::ServiceAccountAccess::new(creds).build(); + let sa = yup_oauth2::ServiceAccountAccess::new(creds) + .build() + .unwrap(); let scopes = &["https://www.googleapis.com/auth/pubsub"]; let tok = sa.token(scopes).await.unwrap(); diff --git a/src/service_account.rs b/src/service_account.rs index 2c4cb3f..03525a5 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -37,11 +37,11 @@ use hyper; use serde_json; const GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:jwt-bearer"; -const GOOGLE_RS256_HEAD: &str = "{\"alg\":\"RS256\",\"typ\":\"JWT\"}"; +const GOOGLE_RS256_HEAD: &str = r#"{"alg":"RS256","typ":"JWT"}"#; /// Encodes s as Base64 -fn encode_base64>(s: T) -> String { - base64::encode_config(s.as_ref(), base64::URL_SAFE) +fn append_base64 + ?Sized>(s: &T, out: &mut String) { + base64::encode_config_buf(s, base64::URL_SAFE, out) } /// Decode a PKCS8 formatted RSA key. @@ -78,11 +78,11 @@ pub struct ServiceAccountKey { pub key_type: Option, pub project_id: Option, pub private_key_id: Option, - pub private_key: Option, - pub client_email: Option, + pub private_key: String, + pub client_email: String, pub client_id: Option, pub auth_uri: Option, - pub token_uri: Option, + pub token_uri: String, pub auth_provider_x509_cert_url: Option, pub client_x509_cert_url: Option, } @@ -90,52 +90,42 @@ pub struct ServiceAccountKey { /// Permissions requested for a JWT. /// See https://developers.google.com/identity/protocols/OAuth2ServiceAccount#authorizingrequests. #[derive(Serialize, Debug)] -struct Claims { - iss: String, - aud: String, +struct Claims<'a> { + iss: &'a str, + aud: &'a str, exp: i64, iat: i64, - sub: Option, + subject: Option<&'a str>, scope: String, } -/// A JSON Web Token ready for signing. -struct JWT { - /// The value of GOOGLE_RS256_HEAD. - header: String, - /// A Claims struct, expressing the set of desired permissions etc. - claims: Claims, -} +impl<'a> Claims<'a> { + fn new(key: &'a ServiceAccountKey, scopes: &[T], subject: Option<&'a str>) -> Self + where + T: AsRef, + { + let iat = chrono::Utc::now().timestamp(); + let expiry = iat + 3600 - 5; // Max validity is 1h. -impl JWT { - /// Create a new JWT from claims. - fn new(claims: Claims) -> JWT { - JWT { - header: GOOGLE_RS256_HEAD.to_string(), - claims, + let scope = crate::helper::join(scopes, " "); + Claims { + iss: &key.client_email, + aud: &key.token_uri, + exp: expiry, + iat, + subject, + scope, } } +} - /// Set JWT header. Default is `{"alg":"RS256","typ":"JWT"}`. - #[allow(dead_code)] - pub fn set_header(&mut self, head: String) { - self.header = head; - } +/// A JSON Web Token ready for signing. +struct JWTSigner { + signer: Box, +} - /// Encodes the first two parts (header and claims) to base64 and assembles them into a form - /// ready to be signed. - fn encode_claims(&self) -> String { - let mut head = encode_base64(&self.header); - let claims = encode_base64(serde_json::to_string(&self.claims).unwrap()); - - head.push_str("."); - head.push_str(&claims); - head - } - - /// Sign a JWT base string with `private_key`, which is a PKCS8 string. - fn sign(&self, private_key: &str) -> Result { - let mut jwt_head = self.encode_claims(); +impl JWTSigner { + fn new(private_key: &str) -> Result { let key = decode_rsa_key(private_key)?; let signing_key = sign::RSASigningKey::new(&key) .map_err(|_| io::Error::new(io::ErrorKind::Other, "Couldn't initialize signer"))?; @@ -144,40 +134,25 @@ impl JWT { .ok_or_else(|| { io::Error::new(io::ErrorKind::Other, "Couldn't choose signing scheme") })?; - let signature = signer - .sign(jwt_head.as_bytes()) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("{}", e)))?; - let signature_b64 = encode_base64(signature); + Ok(JWTSigner { signer }) + } + fn sign_claims(&self, claims: &Claims) -> Result { + let mut jwt_head = Self::encode_claims(claims); + let signature = self.signer.sign(jwt_head.as_bytes())?; jwt_head.push_str("."); - jwt_head.push_str(&signature_b64); - + append_base64(&signature, &mut jwt_head); Ok(jwt_head) } -} -/// Set `iss`, `aud`, `exp`, `iat`, `scope` field in the returned `Claims`. -fn init_claims_from_key(key: &ServiceAccountKey, scopes: &[T]) -> Claims -where - T: AsRef, -{ - let iat = chrono::Utc::now().timestamp(); - let expiry = iat + 3600 - 5; // Max validity is 1h. - - let mut scopes_string = scopes.iter().fold(String::new(), |mut acc, sc| { - acc.push_str(sc.as_ref()); - acc.push_str(" "); - acc - }); - scopes_string.pop(); - - Claims { - iss: key.client_email.clone().unwrap(), - aud: key.token_uri.clone().unwrap(), - exp: expiry, - iat, - sub: None, - scope: scopes_string, + /// Encodes the first two parts (header and claims) to base64 and assembles them into a form + /// ready to be signed. + fn encode_claims(claims: &Claims) -> String { + let mut head = String::new(); + append_base64(GOOGLE_RS256_HEAD, &mut head); + head.push_str("."); + append_base64(&serde_json::to_string(&claims).unwrap(), &mut head); + head } } @@ -188,7 +163,7 @@ where pub struct ServiceAccountAccess { client: C, key: ServiceAccountKey, - sub: Option, + subject: Option, } impl ServiceAccountAccess { @@ -197,7 +172,7 @@ impl ServiceAccountAccess { ServiceAccountAccess { client: DefaultHyperClient, key, - sub: None, + subject: None, } } } @@ -214,21 +189,21 @@ where ServiceAccountAccess { client: hyper_client, key: self.key, - sub: self.sub, + subject: self.subject, } } - /// Use the provided sub. - pub fn sub(self, sub: String) -> Self { + /// Use the provided subject. + pub fn subject(self, subject: String) -> Self { ServiceAccountAccess { - sub: Some(sub), + subject: Some(subject), ..self } } /// Build the configured ServiceAccountAccess. - pub fn build(self) -> impl GetToken { - ServiceAccountAccessImpl::new(self.client.build_hyper_client(), self.key, self.sub) + pub fn build(self) -> Result { + ServiceAccountAccessImpl::new(self.client.build_hyper_client(), self.key, self.subject) } } @@ -236,20 +211,27 @@ struct ServiceAccountAccessImpl { client: hyper::Client, key: ServiceAccountKey, cache: Arc>, - sub: Option, + subject: Option, + signer: JWTSigner, } impl ServiceAccountAccessImpl where C: hyper::client::connect::Connect, { - fn new(client: hyper::Client, key: ServiceAccountKey, sub: Option) -> Self { - ServiceAccountAccessImpl { + fn new( + client: hyper::Client, + key: ServiceAccountKey, + subject: Option, + ) -> Result { + let signer = JWTSigner::new(&key.private_key)?; + Ok(ServiceAccountAccessImpl { client, key, cache: Arc::new(Mutex::new(MemoryStorage::default())), - sub, - } + subject, + signer, + }) } } @@ -268,25 +250,25 @@ where /// Send a request for a new Bearer token to the OAuth provider. async fn request_token( client: &hyper::client::Client, - sub: Option<&str>, + signer: &JWTSigner, + subject: Option<&str>, key: &ServiceAccountKey, scopes: &[T], ) -> Result where T: AsRef, { - let mut claims = init_claims_from_key(&key, scopes); - claims.sub = sub.map(|x| x.to_owned()); - let signed = JWT::new(claims) - .sign(key.private_key.as_ref().unwrap()) - .map_err(RequestError::LowLevelError)?; + let claims = Claims::new(key, scopes, subject); + let signed = signer.sign_claims(&claims).map_err(|_| { + RequestError::LowLevelError(io::Error::new( + io::ErrorKind::Other, + "unable to sign claims", + )) + })?; let rqbody = form_urlencoded::Serializer::new(String::new()) - .extend_pairs(vec![ - ("grant_type".to_string(), GRANT_TYPE.to_string()), - ("assertion".to_string(), signed), - ]) + .extend_pairs(&[("grant_type", GRANT_TYPE), ("assertion", signed.as_str())]) .finish(); - let request = hyper::Request::post(key.token_uri.as_ref().unwrap()) + let request = hyper::Request::post(&key.token_uri) .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(rqbody)) .unwrap(); @@ -335,7 +317,8 @@ where } let token = Self::request_token( &self.client, - self.sub.as_ref().map(|x| x.as_str()), + &self.signer, + self.subject.as_ref().map(|x| x.as_str()), &self.key, scopes, ) @@ -399,7 +382,7 @@ mod tests { "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/yup-test-sa-1%40yup-test-243420.iam.gserviceaccount.com" }"#; let mut key: ServiceAccountKey = serde_json::from_str(client_secret).unwrap(); - key.token_uri = Some(format!("{}/token", server_url)); + key.token_uri = format!("{}/token", server_url); let json_response = r#"{ "access_token": "ya29.c.ElouBywiys0LyNaZoLPJcp1Fdi2KjFMxzvYKLXkTdvM-rDfqKlvEq6PiMhGoGHx97t5FAvz3eb_ahdwlBjSStxHtDVQB4ZPRJQ_EOi-iS7PnayahU2S9Jp8S6rk", @@ -429,7 +412,7 @@ mod tests { .with_body(json_response) .expect(1) .create(); - let acc = ServiceAccountAccessImpl::new(client.clone(), key.clone(), None); + let acc = ServiceAccountAccessImpl::new(client.clone(), key.clone(), None).unwrap(); let fut = async { let tok = acc .token(&["https://www.googleapis.com/auth/pubsub"]) @@ -472,7 +455,8 @@ mod tests { .create(); let acc = ServiceAccountAccess::new(key.clone()) .hyper_client(client.clone()) - .build(); + .build() + .unwrap(); let fut = async { let result = acc.token(&["https://www.googleapis.com/auth/pubsub"]).await; assert!(result.is_err()); @@ -494,7 +478,10 @@ mod tests { let key = service_account_key_from_file(&TEST_PRIVATE_KEY_PATH.to_string()).unwrap(); let https = HttpsConnector::new(); let client = hyper::Client::builder().build(https); - let acc = ServiceAccountAccess::new(key).hyper_client(client).build(); + let acc = ServiceAccountAccess::new(key) + .hyper_client(client) + .build() + .unwrap(); let rt = tokio::runtime::Builder::new() .core_threads(1) .panic_handler(|e| std::panic::resume_unwind(e)) @@ -512,7 +499,7 @@ mod tests { fn test_jwt_initialize_claims() { let key = service_account_key_from_file(TEST_PRIVATE_KEY_PATH).unwrap(); let scopes = vec!["scope1", "scope2", "scope3"]; - let claims = super::init_claims_from_key(&key, &scopes); + let claims = Claims::new(&key, &scopes, None); assert_eq!( claims.iss, @@ -532,9 +519,9 @@ mod tests { fn test_jwt_sign() { let key = service_account_key_from_file(TEST_PRIVATE_KEY_PATH).unwrap(); let scopes = vec!["scope1", "scope2", "scope3"]; - let claims = super::init_claims_from_key(&key, &scopes); - let jwt = super::JWT::new(claims); - let signature = jwt.sign(key.private_key.as_ref().unwrap()); + let signer = JWTSigner::new(&key.private_key).unwrap(); + let claims = Claims::new(&key, &scopes, None); + let signature = signer.sign_claims(&claims); assert!(signature.is_ok());