diff --git a/google-apis-common/src/serde.rs b/google-apis-common/src/serde.rs index e933776c7a..f0fb840e5b 100644 --- a/google-apis-common/src/serde.rs +++ b/google-apis-common/src/serde.rs @@ -127,8 +127,9 @@ pub mod duration { D: Deserializer<'de>, { let s: Option<&str> = Deserialize::deserialize(deserializer)?; - s.map(|s| parse_duration(s).map_err(serde::de::Error::custom)) + s.map(parse_duration) .transpose() + .map_err(serde::de::Error::custom) } } @@ -150,8 +151,9 @@ pub mod urlsafe_base64 { D: Deserializer<'de>, { let s: Option<&str> = Deserialize::deserialize(deserializer)?; - s.map(|s| base64::decode_config(s, base64::URL_SAFE).map_err(serde::de::Error::custom)) + s.map(|s| base64::decode_config(s, base64::URL_SAFE)) .transpose() + .map_err(serde::de::Error::custom) } } @@ -210,30 +212,65 @@ pub mod field_mask { } } +pub mod str_like { + /// Implementation based on `https://chromium.googlesource.com/infra/luci/luci-go/+/23ea7a05c6a5/common/proto/fieldmasks.go#184` + use serde::{Deserialize, Deserializer, Serializer}; + use std::str::FromStr; + + pub fn serialize(x: &Option, s: S) -> Result + where + S: Serializer, + T: std::fmt::Display, + { + match x { + None => s.serialize_none(), + Some(num) => s.serialize_some(num.to_string().as_str()), + } + } + + pub fn deserialize<'de, D, T>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + T: FromStr, + ::Err: std::fmt::Display, + { + let s: Option<&str> = Deserialize::deserialize(deserializer)?; + s.map(T::from_str) + .transpose() + .map_err(serde::de::Error::custom) + } +} + #[cfg(test)] mod test { - use super::{duration, field_mask, urlsafe_base64}; + use super::{duration, field_mask, str_like, urlsafe_base64}; use crate::FieldMask; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug, PartialEq)] struct DurationWrapper { - #[serde(with = "duration")] + #[serde(default, with = "duration")] duration: Option, } #[derive(Serialize, Deserialize, Debug, PartialEq)] struct Base64Wrapper { - #[serde(with = "urlsafe_base64")] + #[serde(default, with = "urlsafe_base64")] bytes: Option>, } #[derive(Serialize, Deserialize, Debug, PartialEq)] struct FieldMaskWrapper { - #[serde(with = "field_mask")] + #[serde(default, with = "field_mask")] fields: Option, } + #[derive(Serialize, Deserialize, Debug, PartialEq)] + struct I64Wrapper { + #[serde(default, with = "str_like")] + num: Option, + } + #[test] fn test_duration_de_success_cases() { let durations = [ @@ -338,4 +375,31 @@ mod test { "round trip should succeed" ); } + + #[test] + fn num_roundtrip() { + let wrapper = I64Wrapper { + num: Some(i64::MAX), + }; + + let json_repr = &serde_json::to_string(&wrapper); + assert!(json_repr.is_ok(), "serialization should succeed"); + assert_eq!( + wrapper, + serde_json::from_str(&format!("{{\"num\": \"{}\"}}", i64::MAX)).unwrap() + ); + assert_eq!( + wrapper, + serde_json::from_str(json_repr.as_ref().unwrap()).unwrap(), + "round trip should succeed" + ); + } + + #[test] + fn test_empty_wrapper() { + assert_eq!(DurationWrapper { duration: None }, serde_json::from_str("{}").unwrap()); + assert_eq!(Base64Wrapper { bytes: None }, serde_json::from_str("{}").unwrap()); + assert_eq!(FieldMaskWrapper { fields: None }, serde_json::from_str("{}").unwrap()); + assert_eq!(I64Wrapper { num: None }, serde_json::from_str("{}").unwrap()); + } } diff --git a/src/generator/templates/api/lib/schema.mako b/src/generator/templates/api/lib/schema.mako index 34f6cb6769..d4b22f941b 100644 --- a/src/generator/templates/api/lib/schema.mako +++ b/src/generator/templates/api/lib/schema.mako @@ -18,11 +18,13 @@ ${struct} { #[serde(rename="${pn}")] % endif % if p.get("format") == "byte": - #[serde(with = "client::serde::urlsafe_base64")] + #[serde(default, with = "client::serde::urlsafe_base64")] % elif p.get("format") == "google-duration": - #[serde(with = "client::serde::duration")] + #[serde(default, with = "client::serde::duration")] % elif p.get("format") == "google-fieldmask": - #[serde(with = "client::serde::field_mask")] + #[serde(default, with = "client::serde::field_mask")] + % elif p.get("format") in {"uint64", "int64"}: + #[serde(default, with = "client::serde::str_like")] % endif pub ${mangle_ident(pn)}: ${to_rust_type(schemas, s.id, pn, p, allow_optionals=allow_optionals)}, % endfor