Refactor JWT handling in ServiceAccountAccess.

Avoid reading and parsing the private key file on every invocation of
token() in favor or reading it once when the ServiceAccountAccess is
built. Also avoid unnecessary allocations when signing JWT tokens and
renamed sub to subject to avoid any confusion with the std::ops::Sub
trait.
This commit is contained in:
Glenn Griffin
2019-11-11 14:04:27 -08:00
parent 05f7c10533
commit 060eb92bf7
2 changed files with 94 additions and 105 deletions

View File

@@ -7,7 +7,9 @@ use yup_oauth2::GetToken;
async fn main() {
let creds =
yup_oauth2::service_account_key_from_file(path::Path::new("serviceaccount.json")).unwrap();
let sa = yup_oauth2::ServiceAccountAccess::new(creds).build();
let sa = yup_oauth2::ServiceAccountAccess::new(creds)
.build()
.unwrap();
let scopes = &["https://www.googleapis.com/auth/pubsub"];
let tok = sa.token(scopes).await.unwrap();

View File

@@ -37,11 +37,11 @@ use hyper;
use serde_json;
const GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:jwt-bearer";
const GOOGLE_RS256_HEAD: &str = "{\"alg\":\"RS256\",\"typ\":\"JWT\"}";
const GOOGLE_RS256_HEAD: &str = r#"{"alg":"RS256","typ":"JWT"}"#;
/// Encodes s as Base64
fn encode_base64<T: AsRef<[u8]>>(s: T) -> String {
base64::encode_config(s.as_ref(), base64::URL_SAFE)
fn append_base64<T: AsRef<[u8]> + ?Sized>(s: &T, out: &mut String) {
base64::encode_config_buf(s, base64::URL_SAFE, out)
}
/// Decode a PKCS8 formatted RSA key.
@@ -78,11 +78,11 @@ pub struct ServiceAccountKey {
pub key_type: Option<String>,
pub project_id: Option<String>,
pub private_key_id: Option<String>,
pub private_key: Option<String>,
pub client_email: Option<String>,
pub private_key: String,
pub client_email: String,
pub client_id: Option<String>,
pub auth_uri: Option<String>,
pub token_uri: Option<String>,
pub token_uri: String,
pub auth_provider_x509_cert_url: Option<String>,
pub client_x509_cert_url: Option<String>,
}
@@ -90,52 +90,42 @@ pub struct ServiceAccountKey {
/// Permissions requested for a JWT.
/// See https://developers.google.com/identity/protocols/OAuth2ServiceAccount#authorizingrequests.
#[derive(Serialize, Debug)]
struct Claims {
iss: String,
aud: String,
struct Claims<'a> {
iss: &'a str,
aud: &'a str,
exp: i64,
iat: i64,
sub: Option<String>,
subject: Option<&'a str>,
scope: String,
}
/// A JSON Web Token ready for signing.
struct JWT {
/// The value of GOOGLE_RS256_HEAD.
header: String,
/// A Claims struct, expressing the set of desired permissions etc.
claims: Claims,
}
impl<'a> Claims<'a> {
fn new<T>(key: &'a ServiceAccountKey, scopes: &[T], subject: Option<&'a str>) -> Self
where
T: AsRef<str>,
{
let iat = chrono::Utc::now().timestamp();
let expiry = iat + 3600 - 5; // Max validity is 1h.
impl JWT {
/// Create a new JWT from claims.
fn new(claims: Claims) -> JWT {
JWT {
header: GOOGLE_RS256_HEAD.to_string(),
claims,
let scope = crate::helper::join(scopes, " ");
Claims {
iss: &key.client_email,
aud: &key.token_uri,
exp: expiry,
iat,
subject,
scope,
}
}
}
/// Set JWT header. Default is `{"alg":"RS256","typ":"JWT"}`.
#[allow(dead_code)]
pub fn set_header(&mut self, head: String) {
self.header = head;
}
/// A JSON Web Token ready for signing.
struct JWTSigner {
signer: Box<dyn rustls::sign::Signer>,
}
/// Encodes the first two parts (header and claims) to base64 and assembles them into a form
/// ready to be signed.
fn encode_claims(&self) -> String {
let mut head = encode_base64(&self.header);
let claims = encode_base64(serde_json::to_string(&self.claims).unwrap());
head.push_str(".");
head.push_str(&claims);
head
}
/// Sign a JWT base string with `private_key`, which is a PKCS8 string.
fn sign(&self, private_key: &str) -> Result<String, io::Error> {
let mut jwt_head = self.encode_claims();
impl JWTSigner {
fn new(private_key: &str) -> Result<Self, io::Error> {
let key = decode_rsa_key(private_key)?;
let signing_key = sign::RSASigningKey::new(&key)
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Couldn't initialize signer"))?;
@@ -144,40 +134,25 @@ impl JWT {
.ok_or_else(|| {
io::Error::new(io::ErrorKind::Other, "Couldn't choose signing scheme")
})?;
let signature = signer
.sign(jwt_head.as_bytes())
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("{}", e)))?;
let signature_b64 = encode_base64(signature);
Ok(JWTSigner { signer })
}
fn sign_claims(&self, claims: &Claims) -> Result<String, rustls::TLSError> {
let mut jwt_head = Self::encode_claims(claims);
let signature = self.signer.sign(jwt_head.as_bytes())?;
jwt_head.push_str(".");
jwt_head.push_str(&signature_b64);
append_base64(&signature, &mut jwt_head);
Ok(jwt_head)
}
}
/// Set `iss`, `aud`, `exp`, `iat`, `scope` field in the returned `Claims`.
fn init_claims_from_key<T>(key: &ServiceAccountKey, scopes: &[T]) -> Claims
where
T: AsRef<str>,
{
let iat = chrono::Utc::now().timestamp();
let expiry = iat + 3600 - 5; // Max validity is 1h.
let mut scopes_string = scopes.iter().fold(String::new(), |mut acc, sc| {
acc.push_str(sc.as_ref());
acc.push_str(" ");
acc
});
scopes_string.pop();
Claims {
iss: key.client_email.clone().unwrap(),
aud: key.token_uri.clone().unwrap(),
exp: expiry,
iat,
sub: None,
scope: scopes_string,
/// Encodes the first two parts (header and claims) to base64 and assembles them into a form
/// ready to be signed.
fn encode_claims(claims: &Claims) -> String {
let mut head = String::new();
append_base64(GOOGLE_RS256_HEAD, &mut head);
head.push_str(".");
append_base64(&serde_json::to_string(&claims).unwrap(), &mut head);
head
}
}
@@ -188,7 +163,7 @@ where
pub struct ServiceAccountAccess<C> {
client: C,
key: ServiceAccountKey,
sub: Option<String>,
subject: Option<String>,
}
impl ServiceAccountAccess<DefaultHyperClient> {
@@ -197,7 +172,7 @@ impl ServiceAccountAccess<DefaultHyperClient> {
ServiceAccountAccess {
client: DefaultHyperClient,
key,
sub: None,
subject: None,
}
}
}
@@ -214,21 +189,21 @@ where
ServiceAccountAccess {
client: hyper_client,
key: self.key,
sub: self.sub,
subject: self.subject,
}
}
/// Use the provided sub.
pub fn sub(self, sub: String) -> Self {
/// Use the provided subject.
pub fn subject(self, subject: String) -> Self {
ServiceAccountAccess {
sub: Some(sub),
subject: Some(subject),
..self
}
}
/// Build the configured ServiceAccountAccess.
pub fn build(self) -> impl GetToken {
ServiceAccountAccessImpl::new(self.client.build_hyper_client(), self.key, self.sub)
pub fn build(self) -> Result<impl GetToken, io::Error> {
ServiceAccountAccessImpl::new(self.client.build_hyper_client(), self.key, self.subject)
}
}
@@ -236,20 +211,27 @@ struct ServiceAccountAccessImpl<C> {
client: hyper::Client<C, hyper::Body>,
key: ServiceAccountKey,
cache: Arc<Mutex<MemoryStorage>>,
sub: Option<String>,
subject: Option<String>,
signer: JWTSigner,
}
impl<C> ServiceAccountAccessImpl<C>
where
C: hyper::client::connect::Connect,
{
fn new(client: hyper::Client<C>, key: ServiceAccountKey, sub: Option<String>) -> Self {
ServiceAccountAccessImpl {
fn new(
client: hyper::Client<C>,
key: ServiceAccountKey,
subject: Option<String>,
) -> Result<Self, io::Error> {
let signer = JWTSigner::new(&key.private_key)?;
Ok(ServiceAccountAccessImpl {
client,
key,
cache: Arc::new(Mutex::new(MemoryStorage::default())),
sub,
}
subject,
signer,
})
}
}
@@ -268,25 +250,25 @@ where
/// Send a request for a new Bearer token to the OAuth provider.
async fn request_token<T>(
client: &hyper::client::Client<C>,
sub: Option<&str>,
signer: &JWTSigner,
subject: Option<&str>,
key: &ServiceAccountKey,
scopes: &[T],
) -> Result<Token, RequestError>
where
T: AsRef<str>,
{
let mut claims = init_claims_from_key(&key, scopes);
claims.sub = sub.map(|x| x.to_owned());
let signed = JWT::new(claims)
.sign(key.private_key.as_ref().unwrap())
.map_err(RequestError::LowLevelError)?;
let claims = Claims::new(key, scopes, subject);
let signed = signer.sign_claims(&claims).map_err(|_| {
RequestError::LowLevelError(io::Error::new(
io::ErrorKind::Other,
"unable to sign claims",
))
})?;
let rqbody = form_urlencoded::Serializer::new(String::new())
.extend_pairs(vec![
("grant_type".to_string(), GRANT_TYPE.to_string()),
("assertion".to_string(), signed),
])
.extend_pairs(&[("grant_type", GRANT_TYPE), ("assertion", signed.as_str())])
.finish();
let request = hyper::Request::post(key.token_uri.as_ref().unwrap())
let request = hyper::Request::post(&key.token_uri)
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.body(hyper::Body::from(rqbody))
.unwrap();
@@ -335,7 +317,8 @@ where
}
let token = Self::request_token(
&self.client,
self.sub.as_ref().map(|x| x.as_str()),
&self.signer,
self.subject.as_ref().map(|x| x.as_str()),
&self.key,
scopes,
)
@@ -399,7 +382,7 @@ mod tests {
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/yup-test-sa-1%40yup-test-243420.iam.gserviceaccount.com"
}"#;
let mut key: ServiceAccountKey = serde_json::from_str(client_secret).unwrap();
key.token_uri = Some(format!("{}/token", server_url));
key.token_uri = format!("{}/token", server_url);
let json_response = r#"{
"access_token": "ya29.c.ElouBywiys0LyNaZoLPJcp1Fdi2KjFMxzvYKLXkTdvM-rDfqKlvEq6PiMhGoGHx97t5FAvz3eb_ahdwlBjSStxHtDVQB4ZPRJQ_EOi-iS7PnayahU2S9Jp8S6rk",
@@ -429,7 +412,7 @@ mod tests {
.with_body(json_response)
.expect(1)
.create();
let acc = ServiceAccountAccessImpl::new(client.clone(), key.clone(), None);
let acc = ServiceAccountAccessImpl::new(client.clone(), key.clone(), None).unwrap();
let fut = async {
let tok = acc
.token(&["https://www.googleapis.com/auth/pubsub"])
@@ -472,7 +455,8 @@ mod tests {
.create();
let acc = ServiceAccountAccess::new(key.clone())
.hyper_client(client.clone())
.build();
.build()
.unwrap();
let fut = async {
let result = acc.token(&["https://www.googleapis.com/auth/pubsub"]).await;
assert!(result.is_err());
@@ -494,7 +478,10 @@ mod tests {
let key = service_account_key_from_file(&TEST_PRIVATE_KEY_PATH.to_string()).unwrap();
let https = HttpsConnector::new();
let client = hyper::Client::builder().build(https);
let acc = ServiceAccountAccess::new(key).hyper_client(client).build();
let acc = ServiceAccountAccess::new(key)
.hyper_client(client)
.build()
.unwrap();
let rt = tokio::runtime::Builder::new()
.core_threads(1)
.panic_handler(|e| std::panic::resume_unwind(e))
@@ -512,7 +499,7 @@ mod tests {
fn test_jwt_initialize_claims() {
let key = service_account_key_from_file(TEST_PRIVATE_KEY_PATH).unwrap();
let scopes = vec!["scope1", "scope2", "scope3"];
let claims = super::init_claims_from_key(&key, &scopes);
let claims = Claims::new(&key, &scopes, None);
assert_eq!(
claims.iss,
@@ -532,9 +519,9 @@ mod tests {
fn test_jwt_sign() {
let key = service_account_key_from_file(TEST_PRIVATE_KEY_PATH).unwrap();
let scopes = vec!["scope1", "scope2", "scope3"];
let claims = super::init_claims_from_key(&key, &scopes);
let jwt = super::JWT::new(claims);
let signature = jwt.sign(key.private_key.as_ref().unwrap());
let signer = JWTSigner::new(&key.private_key).unwrap();
let claims = Claims::new(&key, &scopes, None);
let signature = signer.sign_claims(&claims);
assert!(signature.is_ok());