mirror of
https://github.com/OMGeeky/yup-oauth2.git
synced 2025-12-26 16:27:25 +01:00
Refactor JWT handling in ServiceAccountAccess.
Avoid reading and parsing the private key file on every invocation of token() in favor or reading it once when the ServiceAccountAccess is built. Also avoid unnecessary allocations when signing JWT tokens and renamed sub to subject to avoid any confusion with the std::ops::Sub trait.
This commit is contained in:
@@ -7,7 +7,9 @@ use yup_oauth2::GetToken;
|
||||
async fn main() {
|
||||
let creds =
|
||||
yup_oauth2::service_account_key_from_file(path::Path::new("serviceaccount.json")).unwrap();
|
||||
let sa = yup_oauth2::ServiceAccountAccess::new(creds).build();
|
||||
let sa = yup_oauth2::ServiceAccountAccess::new(creds)
|
||||
.build()
|
||||
.unwrap();
|
||||
let scopes = &["https://www.googleapis.com/auth/pubsub"];
|
||||
|
||||
let tok = sa.token(scopes).await.unwrap();
|
||||
|
||||
@@ -37,11 +37,11 @@ use hyper;
|
||||
use serde_json;
|
||||
|
||||
const GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:jwt-bearer";
|
||||
const GOOGLE_RS256_HEAD: &str = "{\"alg\":\"RS256\",\"typ\":\"JWT\"}";
|
||||
const GOOGLE_RS256_HEAD: &str = r#"{"alg":"RS256","typ":"JWT"}"#;
|
||||
|
||||
/// Encodes s as Base64
|
||||
fn encode_base64<T: AsRef<[u8]>>(s: T) -> String {
|
||||
base64::encode_config(s.as_ref(), base64::URL_SAFE)
|
||||
fn append_base64<T: AsRef<[u8]> + ?Sized>(s: &T, out: &mut String) {
|
||||
base64::encode_config_buf(s, base64::URL_SAFE, out)
|
||||
}
|
||||
|
||||
/// Decode a PKCS8 formatted RSA key.
|
||||
@@ -78,11 +78,11 @@ pub struct ServiceAccountKey {
|
||||
pub key_type: Option<String>,
|
||||
pub project_id: Option<String>,
|
||||
pub private_key_id: Option<String>,
|
||||
pub private_key: Option<String>,
|
||||
pub client_email: Option<String>,
|
||||
pub private_key: String,
|
||||
pub client_email: String,
|
||||
pub client_id: Option<String>,
|
||||
pub auth_uri: Option<String>,
|
||||
pub token_uri: Option<String>,
|
||||
pub token_uri: String,
|
||||
pub auth_provider_x509_cert_url: Option<String>,
|
||||
pub client_x509_cert_url: Option<String>,
|
||||
}
|
||||
@@ -90,52 +90,42 @@ pub struct ServiceAccountKey {
|
||||
/// Permissions requested for a JWT.
|
||||
/// See https://developers.google.com/identity/protocols/OAuth2ServiceAccount#authorizingrequests.
|
||||
#[derive(Serialize, Debug)]
|
||||
struct Claims {
|
||||
iss: String,
|
||||
aud: String,
|
||||
struct Claims<'a> {
|
||||
iss: &'a str,
|
||||
aud: &'a str,
|
||||
exp: i64,
|
||||
iat: i64,
|
||||
sub: Option<String>,
|
||||
subject: Option<&'a str>,
|
||||
scope: String,
|
||||
}
|
||||
|
||||
/// A JSON Web Token ready for signing.
|
||||
struct JWT {
|
||||
/// The value of GOOGLE_RS256_HEAD.
|
||||
header: String,
|
||||
/// A Claims struct, expressing the set of desired permissions etc.
|
||||
claims: Claims,
|
||||
}
|
||||
impl<'a> Claims<'a> {
|
||||
fn new<T>(key: &'a ServiceAccountKey, scopes: &[T], subject: Option<&'a str>) -> Self
|
||||
where
|
||||
T: AsRef<str>,
|
||||
{
|
||||
let iat = chrono::Utc::now().timestamp();
|
||||
let expiry = iat + 3600 - 5; // Max validity is 1h.
|
||||
|
||||
impl JWT {
|
||||
/// Create a new JWT from claims.
|
||||
fn new(claims: Claims) -> JWT {
|
||||
JWT {
|
||||
header: GOOGLE_RS256_HEAD.to_string(),
|
||||
claims,
|
||||
let scope = crate::helper::join(scopes, " ");
|
||||
Claims {
|
||||
iss: &key.client_email,
|
||||
aud: &key.token_uri,
|
||||
exp: expiry,
|
||||
iat,
|
||||
subject,
|
||||
scope,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Set JWT header. Default is `{"alg":"RS256","typ":"JWT"}`.
|
||||
#[allow(dead_code)]
|
||||
pub fn set_header(&mut self, head: String) {
|
||||
self.header = head;
|
||||
}
|
||||
/// A JSON Web Token ready for signing.
|
||||
struct JWTSigner {
|
||||
signer: Box<dyn rustls::sign::Signer>,
|
||||
}
|
||||
|
||||
/// Encodes the first two parts (header and claims) to base64 and assembles them into a form
|
||||
/// ready to be signed.
|
||||
fn encode_claims(&self) -> String {
|
||||
let mut head = encode_base64(&self.header);
|
||||
let claims = encode_base64(serde_json::to_string(&self.claims).unwrap());
|
||||
|
||||
head.push_str(".");
|
||||
head.push_str(&claims);
|
||||
head
|
||||
}
|
||||
|
||||
/// Sign a JWT base string with `private_key`, which is a PKCS8 string.
|
||||
fn sign(&self, private_key: &str) -> Result<String, io::Error> {
|
||||
let mut jwt_head = self.encode_claims();
|
||||
impl JWTSigner {
|
||||
fn new(private_key: &str) -> Result<Self, io::Error> {
|
||||
let key = decode_rsa_key(private_key)?;
|
||||
let signing_key = sign::RSASigningKey::new(&key)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Couldn't initialize signer"))?;
|
||||
@@ -144,40 +134,25 @@ impl JWT {
|
||||
.ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::Other, "Couldn't choose signing scheme")
|
||||
})?;
|
||||
let signature = signer
|
||||
.sign(jwt_head.as_bytes())
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("{}", e)))?;
|
||||
let signature_b64 = encode_base64(signature);
|
||||
Ok(JWTSigner { signer })
|
||||
}
|
||||
|
||||
fn sign_claims(&self, claims: &Claims) -> Result<String, rustls::TLSError> {
|
||||
let mut jwt_head = Self::encode_claims(claims);
|
||||
let signature = self.signer.sign(jwt_head.as_bytes())?;
|
||||
jwt_head.push_str(".");
|
||||
jwt_head.push_str(&signature_b64);
|
||||
|
||||
append_base64(&signature, &mut jwt_head);
|
||||
Ok(jwt_head)
|
||||
}
|
||||
}
|
||||
|
||||
/// Set `iss`, `aud`, `exp`, `iat`, `scope` field in the returned `Claims`.
|
||||
fn init_claims_from_key<T>(key: &ServiceAccountKey, scopes: &[T]) -> Claims
|
||||
where
|
||||
T: AsRef<str>,
|
||||
{
|
||||
let iat = chrono::Utc::now().timestamp();
|
||||
let expiry = iat + 3600 - 5; // Max validity is 1h.
|
||||
|
||||
let mut scopes_string = scopes.iter().fold(String::new(), |mut acc, sc| {
|
||||
acc.push_str(sc.as_ref());
|
||||
acc.push_str(" ");
|
||||
acc
|
||||
});
|
||||
scopes_string.pop();
|
||||
|
||||
Claims {
|
||||
iss: key.client_email.clone().unwrap(),
|
||||
aud: key.token_uri.clone().unwrap(),
|
||||
exp: expiry,
|
||||
iat,
|
||||
sub: None,
|
||||
scope: scopes_string,
|
||||
/// Encodes the first two parts (header and claims) to base64 and assembles them into a form
|
||||
/// ready to be signed.
|
||||
fn encode_claims(claims: &Claims) -> String {
|
||||
let mut head = String::new();
|
||||
append_base64(GOOGLE_RS256_HEAD, &mut head);
|
||||
head.push_str(".");
|
||||
append_base64(&serde_json::to_string(&claims).unwrap(), &mut head);
|
||||
head
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,7 +163,7 @@ where
|
||||
pub struct ServiceAccountAccess<C> {
|
||||
client: C,
|
||||
key: ServiceAccountKey,
|
||||
sub: Option<String>,
|
||||
subject: Option<String>,
|
||||
}
|
||||
|
||||
impl ServiceAccountAccess<DefaultHyperClient> {
|
||||
@@ -197,7 +172,7 @@ impl ServiceAccountAccess<DefaultHyperClient> {
|
||||
ServiceAccountAccess {
|
||||
client: DefaultHyperClient,
|
||||
key,
|
||||
sub: None,
|
||||
subject: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -214,21 +189,21 @@ where
|
||||
ServiceAccountAccess {
|
||||
client: hyper_client,
|
||||
key: self.key,
|
||||
sub: self.sub,
|
||||
subject: self.subject,
|
||||
}
|
||||
}
|
||||
|
||||
/// Use the provided sub.
|
||||
pub fn sub(self, sub: String) -> Self {
|
||||
/// Use the provided subject.
|
||||
pub fn subject(self, subject: String) -> Self {
|
||||
ServiceAccountAccess {
|
||||
sub: Some(sub),
|
||||
subject: Some(subject),
|
||||
..self
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the configured ServiceAccountAccess.
|
||||
pub fn build(self) -> impl GetToken {
|
||||
ServiceAccountAccessImpl::new(self.client.build_hyper_client(), self.key, self.sub)
|
||||
pub fn build(self) -> Result<impl GetToken, io::Error> {
|
||||
ServiceAccountAccessImpl::new(self.client.build_hyper_client(), self.key, self.subject)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -236,20 +211,27 @@ struct ServiceAccountAccessImpl<C> {
|
||||
client: hyper::Client<C, hyper::Body>,
|
||||
key: ServiceAccountKey,
|
||||
cache: Arc<Mutex<MemoryStorage>>,
|
||||
sub: Option<String>,
|
||||
subject: Option<String>,
|
||||
signer: JWTSigner,
|
||||
}
|
||||
|
||||
impl<C> ServiceAccountAccessImpl<C>
|
||||
where
|
||||
C: hyper::client::connect::Connect,
|
||||
{
|
||||
fn new(client: hyper::Client<C>, key: ServiceAccountKey, sub: Option<String>) -> Self {
|
||||
ServiceAccountAccessImpl {
|
||||
fn new(
|
||||
client: hyper::Client<C>,
|
||||
key: ServiceAccountKey,
|
||||
subject: Option<String>,
|
||||
) -> Result<Self, io::Error> {
|
||||
let signer = JWTSigner::new(&key.private_key)?;
|
||||
Ok(ServiceAccountAccessImpl {
|
||||
client,
|
||||
key,
|
||||
cache: Arc::new(Mutex::new(MemoryStorage::default())),
|
||||
sub,
|
||||
}
|
||||
subject,
|
||||
signer,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -268,25 +250,25 @@ where
|
||||
/// Send a request for a new Bearer token to the OAuth provider.
|
||||
async fn request_token<T>(
|
||||
client: &hyper::client::Client<C>,
|
||||
sub: Option<&str>,
|
||||
signer: &JWTSigner,
|
||||
subject: Option<&str>,
|
||||
key: &ServiceAccountKey,
|
||||
scopes: &[T],
|
||||
) -> Result<Token, RequestError>
|
||||
where
|
||||
T: AsRef<str>,
|
||||
{
|
||||
let mut claims = init_claims_from_key(&key, scopes);
|
||||
claims.sub = sub.map(|x| x.to_owned());
|
||||
let signed = JWT::new(claims)
|
||||
.sign(key.private_key.as_ref().unwrap())
|
||||
.map_err(RequestError::LowLevelError)?;
|
||||
let claims = Claims::new(key, scopes, subject);
|
||||
let signed = signer.sign_claims(&claims).map_err(|_| {
|
||||
RequestError::LowLevelError(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"unable to sign claims",
|
||||
))
|
||||
})?;
|
||||
let rqbody = form_urlencoded::Serializer::new(String::new())
|
||||
.extend_pairs(vec![
|
||||
("grant_type".to_string(), GRANT_TYPE.to_string()),
|
||||
("assertion".to_string(), signed),
|
||||
])
|
||||
.extend_pairs(&[("grant_type", GRANT_TYPE), ("assertion", signed.as_str())])
|
||||
.finish();
|
||||
let request = hyper::Request::post(key.token_uri.as_ref().unwrap())
|
||||
let request = hyper::Request::post(&key.token_uri)
|
||||
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
|
||||
.body(hyper::Body::from(rqbody))
|
||||
.unwrap();
|
||||
@@ -335,7 +317,8 @@ where
|
||||
}
|
||||
let token = Self::request_token(
|
||||
&self.client,
|
||||
self.sub.as_ref().map(|x| x.as_str()),
|
||||
&self.signer,
|
||||
self.subject.as_ref().map(|x| x.as_str()),
|
||||
&self.key,
|
||||
scopes,
|
||||
)
|
||||
@@ -399,7 +382,7 @@ mod tests {
|
||||
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/yup-test-sa-1%40yup-test-243420.iam.gserviceaccount.com"
|
||||
}"#;
|
||||
let mut key: ServiceAccountKey = serde_json::from_str(client_secret).unwrap();
|
||||
key.token_uri = Some(format!("{}/token", server_url));
|
||||
key.token_uri = format!("{}/token", server_url);
|
||||
|
||||
let json_response = r#"{
|
||||
"access_token": "ya29.c.ElouBywiys0LyNaZoLPJcp1Fdi2KjFMxzvYKLXkTdvM-rDfqKlvEq6PiMhGoGHx97t5FAvz3eb_ahdwlBjSStxHtDVQB4ZPRJQ_EOi-iS7PnayahU2S9Jp8S6rk",
|
||||
@@ -429,7 +412,7 @@ mod tests {
|
||||
.with_body(json_response)
|
||||
.expect(1)
|
||||
.create();
|
||||
let acc = ServiceAccountAccessImpl::new(client.clone(), key.clone(), None);
|
||||
let acc = ServiceAccountAccessImpl::new(client.clone(), key.clone(), None).unwrap();
|
||||
let fut = async {
|
||||
let tok = acc
|
||||
.token(&["https://www.googleapis.com/auth/pubsub"])
|
||||
@@ -472,7 +455,8 @@ mod tests {
|
||||
.create();
|
||||
let acc = ServiceAccountAccess::new(key.clone())
|
||||
.hyper_client(client.clone())
|
||||
.build();
|
||||
.build()
|
||||
.unwrap();
|
||||
let fut = async {
|
||||
let result = acc.token(&["https://www.googleapis.com/auth/pubsub"]).await;
|
||||
assert!(result.is_err());
|
||||
@@ -494,7 +478,10 @@ mod tests {
|
||||
let key = service_account_key_from_file(&TEST_PRIVATE_KEY_PATH.to_string()).unwrap();
|
||||
let https = HttpsConnector::new();
|
||||
let client = hyper::Client::builder().build(https);
|
||||
let acc = ServiceAccountAccess::new(key).hyper_client(client).build();
|
||||
let acc = ServiceAccountAccess::new(key)
|
||||
.hyper_client(client)
|
||||
.build()
|
||||
.unwrap();
|
||||
let rt = tokio::runtime::Builder::new()
|
||||
.core_threads(1)
|
||||
.panic_handler(|e| std::panic::resume_unwind(e))
|
||||
@@ -512,7 +499,7 @@ mod tests {
|
||||
fn test_jwt_initialize_claims() {
|
||||
let key = service_account_key_from_file(TEST_PRIVATE_KEY_PATH).unwrap();
|
||||
let scopes = vec!["scope1", "scope2", "scope3"];
|
||||
let claims = super::init_claims_from_key(&key, &scopes);
|
||||
let claims = Claims::new(&key, &scopes, None);
|
||||
|
||||
assert_eq!(
|
||||
claims.iss,
|
||||
@@ -532,9 +519,9 @@ mod tests {
|
||||
fn test_jwt_sign() {
|
||||
let key = service_account_key_from_file(TEST_PRIVATE_KEY_PATH).unwrap();
|
||||
let scopes = vec!["scope1", "scope2", "scope3"];
|
||||
let claims = super::init_claims_from_key(&key, &scopes);
|
||||
let jwt = super::JWT::new(claims);
|
||||
let signature = jwt.sign(key.private_key.as_ref().unwrap());
|
||||
let signer = JWTSigner::new(&key.private_key).unwrap();
|
||||
let claims = Claims::new(&key, &scopes, None);
|
||||
let signature = signer.sign_claims(&claims);
|
||||
|
||||
assert!(signature.is_ok());
|
||||
|
||||
|
||||
Reference in New Issue
Block a user