From 56aa9b20f5129afc4092a7ac5cfc9f873098a322 Mon Sep 17 00:00:00 2001 From: OMGeeky Date: Sun, 24 Sep 2023 18:54:41 +0200 Subject: [PATCH] initial commit --- .gitignore | 3 + Cargo.toml | 14 ++++ src/lib.rs | 208 +++++++++++++++++++++++++++++++++++++++++++++++++ src/prelude.rs | 6 ++ 4 files changed, 231 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 src/lib.rs create mode 100644 src/prelude.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3ab5292 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/target +/Cargo.lock +.idea diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..ce8de9b --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "reqwest_backoff" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +reqwest = "0.11" +thiserror = "1.0" +tokio = "1.32" +url = "2.4.1" +tracing = "0.1" +chrono = "0.4.31" diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..0bcb2c0 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,208 @@ +use std::ops::Deref; + +use chrono::{DateTime, NaiveDateTime, Utc}; +use reqwest::{Error, Request, Response}; +use url::Host; + +use prelude::*; + +const MAX_BACKOFF_ATTEMPTS: u32 = 50; +const MAX_BACKOFF_ATTEMPTS_GOOGLE: u32 = 50; +const MAX_BACKOFF_ATTEMPTS_TWITCH: u32 = 50; + +const GOOGLE_BASE_BACKOFF_TIME_S: u64 = 2; +const GOOGLE_MAX_BACKOFF_TIME_S: u64 = 3600; + +pub mod prelude; + +#[derive(Debug, thiserror::Error)] +pub enum ReqwestBackoffError { + #[error("Reqwest error")] + Reqwest(#[from] Error), + #[error("Other error")] + Other(#[from] Box), + #[error("Backoff error after {backoff_attempts} attempts")] + BackoffExceeded { backoff_attempts: u32 }, +} + +#[derive(Debug, Clone)] +pub struct ReqwestClient { + client: reqwest::Client, +} + +impl Deref for ReqwestClient { + type Target = reqwest::Client; + + fn deref(&self) -> &Self::Target { + &self.client + } +} + +impl From for ReqwestClient { + fn from(client: reqwest::Client) -> Self { + Self { client } + } +} + +impl Default for ReqwestClient { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum HostType { + Twitch, + Google, + Youtube, + Other, +} + +impl ReqwestClient { + pub fn new() -> Self { + Self { + client: reqwest::Client::new(), + } + } + #[tracing::instrument] + pub async fn execute_with_backoff(&self, request: Request) -> Result { + let host: HostType = get_host_from_request(&request); + + let request_clone = request.try_clone(); + if let Some(request_clone) = request_clone { + self.execute_with_backoff_inner(request_clone, host).await + } else { + warn!("Failed to clone request. No backoff possible."); + Ok(self + .client + .execute(request) + .await + .map_err(ReqwestBackoffError::Reqwest)?) + } + } + + /// Execute a request with backoff if the response indicates that it should. + /// + /// # Arguments + /// + /// * `self` - The client to use for the request. + /// * `request` - The request to execute. This needs to be cloneable otherwise the function will panic. (not cloneable requests can't be retried) + /// * `host` - The host of the request. This is used to determine the backoff time. + async fn execute_with_backoff_inner( + &self, + request: Request, + host: HostType, + ) -> Result { + let mut attempt: u32 = 1; + let mut response = self + .execute(request.try_clone().unwrap()) + .await + .map_err(ReqwestBackoffError::Reqwest)?; + while check_response_is_backoff(&response, host) { + if is_backoff_limit_reached(attempt, host) { + return Err(ReqwestBackoffError::BackoffExceeded { + backoff_attempts: attempt, + }); + } + let sleep_duration = get_backoff_time(&response, host, attempt)?; + info!("Sleeping for {} seconds", sleep_duration); + tokio::time::sleep(std::time::Duration::from_secs(sleep_duration)).await; + attempt += 1; + info!("Backoff attempt #{}", attempt); + response = self + .client + .execute(request.try_clone().unwrap()) + .await + .map_err(ReqwestBackoffError::Reqwest)?; + } + Ok(response) + } +} + +#[tracing::instrument] +fn get_host_from_request(request: &Request) -> HostType { + if let Some(Host::Domain(domain)) = request.url().host() { + match domain { + "twitch.tv" => HostType::Twitch, + "google.com" => HostType::Google, + "youtube.com" => HostType::Youtube, + _ => HostType::Other, + } + } else { + HostType::Other + } +} + +#[tracing::instrument] +fn is_backoff_limit_reached(attempt: u32, host: HostType) -> bool { + match host { + HostType::Twitch => attempt > MAX_BACKOFF_ATTEMPTS_TWITCH, + HostType::Google | HostType::Youtube => attempt > MAX_BACKOFF_ATTEMPTS_GOOGLE, + HostType::Other => attempt > MAX_BACKOFF_ATTEMPTS, + } +} + +#[tracing::instrument] +fn check_response_is_backoff(response: &Response, host: HostType) -> bool { + dbg!(response, host); + let code = response.status(); + if code.is_success() { + return false; + } + let code = code.as_u16(); + match host { + HostType::Twitch => code == 429, + HostType::Google | HostType::Youtube => { + if !(code == 403 || code == 400) { + return false; + } + warn!("check_response_is_backoff->code: {}", code); + warn!("check_response_is_backoff->response: {:?}", response); + true + } + HostType::Other => false, + } +} + +#[tracing::instrument] +fn get_backoff_time(response: &Response, host: HostType, attempt: u32) -> Result { + dbg!(response, host); + Ok(match host { + HostType::Twitch => { + let timestamp = get_twitch_rate_limit_value(response)?; + let duration = chrono::Local::now().naive_utc().and_utc() - timestamp; + let duration = duration.num_seconds() as u64; + if duration > 0 { + duration + } else { + 1 + } + } + HostType::Google | HostType::Youtube => { + let backoff_time = GOOGLE_BASE_BACKOFF_TIME_S.pow(attempt); + if backoff_time > GOOGLE_MAX_BACKOFF_TIME_S { + GOOGLE_MAX_BACKOFF_TIME_S + } else { + backoff_time + } + } + HostType::Other => 5, + }) +} + +#[tracing::instrument] +fn get_twitch_rate_limit_value(response: &Response) -> Result> { + let timestamp = response + .headers() + .get("Ratelimit-Reset") + .unwrap() + .to_str() + .map_err(|e| ReqwestBackoffError::Other(e.into()))? + .to_string() + .parse::() + .map_err(|e| ReqwestBackoffError::Other(e.into()))?; + let timestamp = NaiveDateTime::from_timestamp_opt(timestamp, 0).ok_or( + ReqwestBackoffError::Other("Could not convert the provided timestamp".into()), + )?; + Ok(timestamp.and_utc()) +} diff --git a/src/prelude.rs b/src/prelude.rs new file mode 100644 index 0000000..48a3605 --- /dev/null +++ b/src/prelude.rs @@ -0,0 +1,6 @@ +pub(crate) use std::error::Error as StdError; + +pub use crate::ReqwestBackoffError; + +pub(crate) use tracing::{info, warn}; +pub(crate) type Result = std::result::Result;