Use chrono::Duration directly with serde attributes

This commit is contained in:
philippeitis
2022-10-07 20:34:40 -07:00
parent 444b610ddc
commit 05df68de32
2 changed files with 73 additions and 75 deletions

View File

@@ -849,39 +849,6 @@ pub mod types {
use std::str::FromStr;
use serde::{Deserialize, Deserializer, Serializer};
// https://github.com/protocolbuffers/protobuf-go/blob/6875c3d7242d1a3db910ce8a504f124cb840c23a/types/known/durationpb/duration.pb.go#L148
#[derive(Deserialize)]
#[serde(try_from = "IntermediateDuration")]
pub struct Duration {
pub seconds: i64,
pub nanoseconds: i32,
}
impl From<Duration> for chrono::Duration {
fn from(duration: Duration) -> chrono::Duration {
chrono::Duration::seconds(duration.seconds) + chrono::Duration::nanoseconds(duration.nanoseconds as i64)
}
}
#[derive(Deserialize)]
struct IntermediateDuration<'a>(&'a str);
impl serde::Serialize for Duration {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
if self.nanoseconds != 0 {
if self.seconds == 0 && self.nanoseconds.is_negative() {
serializer.serialize_str(&format!("-0.{}s", self.nanoseconds.abs()))
} else {
serializer.serialize_str(&format!("{}.{}s", self.seconds, self.nanoseconds.abs()))
}
} else {
serializer.serialize_str(&format!("{}s", self.seconds))
}
}
}
#[derive(Debug)]
enum ParseDurationError {
@@ -892,6 +859,45 @@ pub mod types {
SecondUnderflow { seconds: i64, min_seconds: i64 }
}
fn parse_duration_from_str(s: &str) -> Result<chrono::Duration, ParseDurationError> {
let abs_duration = 315576000000i64;
// TODO: Test strings like -.s, -0.0s
let value = match s.strip_suffix('s') {
None => return Err(ParseDurationError::MissingSecondSuffix),
Some(v) => v
};
let (seconds, nanoseconds) = if let Some((seconds, nanos)) = value.split_once('.') {
let is_neg = seconds.starts_with("-");
let seconds = i64::from_str(seconds)?;
let nano_magnitude = nanos.chars().filter(|c| c.is_digit(10)).count() as u32;
if nano_magnitude > 9 {
// not enough precision to model the remaining digits
return Err(ParseDurationError::NanosTooSmall);
}
// u32::from_str prevents negative nanos (eg '0.-12s) -> lossless conversion to i32
// 10_u32.pow(...) scales number to appropriate # of nanoseconds
let nanos = u32::from_str(nanos)? as i32;
let mut nanos = nanos * 10_i32.pow(9 - nano_magnitude);
if is_neg {
nanos = -nanos;
}
(seconds, nanos)
} else {
(i64::from_str(value)?, 0)
};
if seconds >= abs_duration {
Err(ParseDurationError::SecondOverflow { seconds, max_seconds: abs_duration })
} else if seconds <= -abs_duration {
Err(ParseDurationError::SecondUnderflow { seconds, min_seconds: -abs_duration })
} else {
Ok(chrono::Duration::seconds(seconds) + chrono::Duration::nanoseconds(nanoseconds.into()))
}
}
impl From<std::num::ParseIntError> for ParseDurationError {
fn from(pie: std::num::ParseIntError) -> Self {
ParseDurationError::ParseIntError(pie)
@@ -903,7 +909,7 @@ pub mod types {
match self {
ParseDurationError::MissingSecondSuffix => write!(f, "'s' suffix was not present"),
ParseDurationError::NanosTooSmall => write!(f, "more than 9 digits of second precision required"),
ParseDurationError::ParseIntError(pie) => write!(f, "{}", pie),
ParseDurationError::ParseIntError(pie) => write!(f, "{:?}", pie),
ParseDurationError::SecondOverflow { seconds, max_seconds } => write!(f, "seconds overflow (got {}, maximum seconds possible {})", seconds, max_seconds),
ParseDurationError::SecondUnderflow { seconds, min_seconds } => write!(f, "seconds underflow (got {}, minimum seconds possible {})", seconds, min_seconds)
}
@@ -912,52 +918,41 @@ pub mod types {
impl std::error::Error for ParseDurationError {}
impl <'a> TryFrom<IntermediateDuration<'a>> for Duration {
type Error = ParseDurationError;
fn try_from(value: IntermediateDuration<'a>) -> Result<Self, Self::Error> {
let abs_duration = 315576000000i64;
// TODO: Test strings like -.s, -0.0s
let value = match value.0.strip_suffix('s') {
None => return Err(ParseDurationError::MissingSecondSuffix),
Some(v) => v
};
let (seconds, nanoseconds) = if let Some((seconds, nanos)) = value.split_once('.') {
let is_neg = seconds.starts_with("-");
let seconds = i64::from_str(seconds)?;
let nano_magnitude = nanos.chars().filter(|c| c.is_digit(10)).count() as u32;
if nano_magnitude > 9 {
// not enough precision to model the remaining digits
return Err(ParseDurationError::NanosTooSmall);
pub fn to_duration_str<S>(x: Option<&chrono::Duration>, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match x {
None => s.serialize_none(),
Some(x) => {
let seconds = x.num_seconds();
let nanoseconds = (*x - chrono::Duration::seconds(seconds))
.num_nanoseconds()
.expect("number of nanoseconds is less than or equal to 1 billion") as i32;
// might be left with -1 + non-zero nanos
if nanoseconds != 0 {
if seconds == 0 && nanoseconds.is_negative() {
s.serialize_str(&format!("-0.{}s", nanoseconds.abs()))
} else {
s.serialize_str(&format!("{}.{}s", seconds, nanoseconds.abs()))
}
} else {
s.serialize_str(&format!("{}s", seconds))
}
// u32::from_str prevents negative nanos (eg '0.-12s) -> lossless conversion to i32
// 10_u32.pow(...) scales number to appropriate # of nanoseconds
let nanos = u32::from_str(nanos)? as i32;
let mut nanos = nanos * 10_i32.pow(9 - nano_magnitude);
if is_neg {
nanos = -nanos;
}
(seconds, nanos)
} else {
(i64::from_str(value)?, 0)
};
if seconds >= abs_duration {
Err(ParseDurationError::SecondOverflow { seconds, max_seconds: abs_duration })
} else if seconds <= -abs_duration {
Err(ParseDurationError::SecondUnderflow { seconds, min_seconds: -abs_duration })
} else {
Ok(Duration { seconds, nanoseconds})
}
}
}
// #[serde(deserialize_with = "path")]
pub fn from_duration_str<'de, D>(deserializer: D) -> Result<Option<chrono::Duration>, D::Error>
where
D: Deserializer<'de>,
{
let s: Option<&str> = Deserialize::deserialize(deserializer)?;
// TODO: Map error
Ok(s.map(|s| parse_duration_from_str(s).unwrap()))
}
// #[serde(serialize_with = "path")]
pub fn to_urlsafe_base64<S>(x: Option<&str>, s: S) -> Result<S::Ok, S::Error>
pub fn to_urlsafe_base64<S>(x: Option<&[u8]>, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
@@ -966,6 +961,7 @@ pub mod types {
Some(x) => s.serialize_some(&base64::encode_config(x, base64::URL_SAFE))
}
}
// #[serde(deserialize_with = "path")]
pub fn from_urlsafe_base64<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
where