diff --git a/google-apis-common/Cargo.toml b/google-apis-common/Cargo.toml index ac27adbe23..ca375dbd0d 100644 --- a/google-apis-common/Cargo.toml +++ b/google-apis-common/Cargo.toml @@ -18,9 +18,12 @@ doctest = false [dependencies] mime = "^ 0.2.0" serde = { version = "^ 1.0", features = ["derive"] } -base64 = "0.13.0" +serde_with = "2.0.1" serde_json = "^ 1.0" + +base64 = "0.13.0" chrono = { version = "0.4.22", features = ["serde"] } + ## TODO: Make yup-oauth2 optional ## yup-oauth2 = { version = "^ 7.0", optional = true } yup-oauth2 = "^ 7.0" diff --git a/google-apis-common/src/lib.rs b/google-apis-common/src/lib.rs index 80f4ed565f..022b9a22c0 100644 --- a/google-apis-common/src/lib.rs +++ b/google-apis-common/src/lib.rs @@ -1,5 +1,5 @@ -pub mod serde; pub mod field_mask; +pub mod serde; use std::error; use std::error::Error as StdError; @@ -28,6 +28,7 @@ use tower_service; pub use chrono; pub use field_mask::FieldMask; +pub use serde_with; pub use yup_oauth2 as oauth2; const LINE_ENDING: &str = "\r\n"; diff --git a/google-apis-common/src/serde.rs b/google-apis-common/src/serde.rs index d5a0942aa6..a34d4318ad 100644 --- a/google-apis-common/src/serde.rs +++ b/google-apis-common/src/serde.rs @@ -1,9 +1,9 @@ pub mod duration { + use serde::{Deserialize, Deserializer}; + use serde_with::{DeserializeAs, SerializeAs}; use std::fmt::Formatter; use std::str::FromStr; - use serde::{Deserialize, Deserializer, Serializer}; - use chrono::Duration; const MAX_SECONDS: i64 = 315576000000i64; @@ -53,7 +53,7 @@ pub mod duration { impl std::error::Error for ParseDurationError {} - fn parse_duration(s: &str) -> Result { + fn duration_from_str(s: &str) -> Result { // TODO: Test strings like -.s, -0.0s let value = match s.strip_suffix('s') { None => return Err(ParseDurationError::MissingSecondSuffix), @@ -97,115 +97,95 @@ pub mod duration { } } - pub fn serialize(x: &Option, s: S) -> Result - where - S: Serializer, - { - match x { - None => s.serialize_none(), - Some(x) => { - let seconds = x.num_seconds(); - let nanoseconds = (*x - Duration::seconds(seconds)) - .num_nanoseconds() - .expect("absolute number of nanoseconds is less than 1 billion") - as i32; - if nanoseconds != 0 { - if seconds == 0 && nanoseconds.is_negative() { - s.serialize_str(&format!("-0.{:0>9}s", nanoseconds.abs())) - } else { - s.serialize_str(&format!("{}.{:0>9}s", seconds, nanoseconds.abs())) - } - } else { - s.serialize_str(&format!("{}s", seconds)) - } + fn duration_to_string(duration: &Duration) -> String { + let seconds = duration.num_seconds(); + let nanoseconds = (*duration - Duration::seconds(seconds)) + .num_nanoseconds() + .expect("absolute number of nanoseconds is less than 1 billion") + as i32; + if nanoseconds != 0 { + if seconds == 0 && nanoseconds.is_negative() { + format!("-0.{:0>9}s", nanoseconds.abs()) + } else { + format!("{}.{:0>9}s", seconds, nanoseconds.abs()) } + } else { + format!("{}s", seconds) } } - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - { - let s: Option<&str> = Deserialize::deserialize(deserializer)?; - s.map(parse_duration) - .transpose() - .map_err(serde::de::Error::custom) + pub struct Wrapper; + + impl SerializeAs for Wrapper { + fn serialize_as(value: &Duration, s: S) -> Result + where + S: serde::Serializer, + { + s.serialize_str(&duration_to_string(value)) + } + } + + impl<'de> DeserializeAs<'de, Duration> for Wrapper { + fn deserialize_as(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = Deserialize::deserialize(deserializer)?; + duration_from_str(s).map_err(serde::de::Error::custom) + } } } pub mod urlsafe_base64 { use serde::{Deserialize, Deserializer, Serializer}; + use serde_with::{DeserializeAs, SerializeAs}; - pub fn serialize(x: &Option>, s: S) -> Result - where - S: Serializer, - { - match x { - None => s.serialize_none(), - Some(x) => s.serialize_some(&base64::encode_config(x, base64::URL_SAFE)), + pub struct Wrapper; + + impl SerializeAs> for Wrapper { + fn serialize_as(value: &Vec, s: S) -> Result + where + S: Serializer, + { + s.serialize_str(&base64::encode_config(value, base64::URL_SAFE)) } } - pub fn deserialize<'de, D>(deserializer: D) -> Result>, D::Error> - where - D: Deserializer<'de>, - { - let s: Option<&str> = Deserialize::deserialize(deserializer)?; - s.map(|s| base64::decode_config(s, base64::URL_SAFE)) - .transpose() - .map_err(serde::de::Error::custom) - } -} - -pub mod str_like { - /// Implementation based on `https://chromium.googlesource.com/infra/luci/luci-go/+/23ea7a05c6a5/common/proto/fieldmasks.go#184` - use serde::{Deserialize, Deserializer, Serializer}; - use std::str::FromStr; - - pub fn serialize(x: &Option, s: S) -> Result - where - S: Serializer, - T: std::fmt::Display, - { - match x { - None => s.serialize_none(), - Some(num) => s.serialize_some(num.to_string().as_str()), + impl<'de> DeserializeAs<'de, Vec> for Wrapper { + fn deserialize_as(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let s: &str = Deserialize::deserialize(deserializer)?; + base64::decode_config(s, base64::URL_SAFE).map_err(serde::de::Error::custom) } } - - pub fn deserialize<'de, D, T>(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, - T: FromStr, - ::Err: std::fmt::Display, - { - let s: Option<&str> = Deserialize::deserialize(deserializer)?; - s.map(T::from_str) - .transpose() - .map_err(serde::de::Error::custom) - } } #[cfg(test)] mod test { - use super::{duration, str_like, urlsafe_base64}; + use super::{duration, urlsafe_base64}; use serde::{Deserialize, Serialize}; + use serde_with::{serde_as, DisplayFromStr}; + #[serde_as] #[derive(Serialize, Deserialize, Debug, PartialEq)] struct DurationWrapper { - #[serde(default, with = "duration")] + #[serde_as(as = "Option")] duration: Option, } + #[serde_as] #[derive(Serialize, Deserialize, Debug, PartialEq)] struct Base64Wrapper { - #[serde(default, with = "urlsafe_base64")] + #[serde_as(as = "Option")] bytes: Option>, } + #[serde_as] #[derive(Serialize, Deserialize, Debug, PartialEq)] struct I64Wrapper { - #[serde(default, with = "str_like")] + #[serde_as(as = "Option")] num: Option, } diff --git a/src/generator/lib/rust_type.py b/src/generator/lib/rust_type.py new file mode 100644 index 0000000000..5a07765672 --- /dev/null +++ b/src/generator/lib/rust_type.py @@ -0,0 +1,76 @@ +from typing import Optional, List +from copy import deepcopy + + +class RustType: + def __init__(self, name: str, members: Optional[List["RustType"]] = None): + self.name = name + self.members = members + + def serde_replace_inner_ty(self, from_to): + if self.members is None: + return False + + changed = False + for i, member in enumerate(self.members): + if member in from_to: + self.members[i] = from_to[member] + changed = True + else: + # serde_as fails to compile if type definition includes + # types without custom serialization + if not member.serde_replace_inner_ty(from_to): + self.members[i] = Base("_") + return changed + + def serde_as(self) -> "RustType": + copied = deepcopy(self) + from_to = { + Vec(Base("u8")): Base("::client::serde::urlsafe_base64::Wrapper"), + Base("client::chrono::Duration"): Base("::client::serde::duration::Wrapper"), + Base("i64"): Base("::client::serde_with::DisplayFromStr"), + Base("u64"): Base("::client::serde_with::DisplayFromStr"), + } + + changed = copied.serde_replace_inner_ty(from_to) + + return copied, changed + + def __str__(self): + if self.members: + return f"{self.name}<{', '.join(str(m) for m in self.members)}>" + return self.name + + def __eq__(self, other): + if not isinstance(other, RustType): + return False + return self.name == other.name and self.members == other.members + + def __hash__(self): + if self.members: + return hash((self.name, *[(i, v) for i, v in enumerate(self.members)])) + return hash((self.name, None)) + +class Option(RustType): + def __init__(self, member): + super().__init__("Option", [member]) + + +class Box(RustType): + def __init__(self, member): + super().__init__("Box", [member]) + + +class Vec(RustType): + def __init__(self, member): + super().__init__("Vec", [member]) + + +class HashMap(RustType): + def __init__(self, key, value): + super().__init__("HashMap", [key, value]) + + +class Base(RustType): + def __init__(self, name): + super().__init__(name) \ No newline at end of file diff --git a/src/generator/lib/util.py b/src/generator/lib/util.py index b17d508bc1..5379c04a55 100644 --- a/src/generator/lib/util.py +++ b/src/generator/lib/util.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from random import (randint, random, choice, seed) from typing import Any, Dict, List, Mapping, Tuple from copy import deepcopy +from .rust_type import Base, Box, HashMap, Vec, Option, RustType seed(1337) @@ -53,6 +54,36 @@ TYPE_MAP = { "google-fieldmask": "client::FieldMask" } +RUST_TYPE_MAP = { + 'boolean': Base("bool"), + 'integer': USE_FORMAT, + 'number': USE_FORMAT, + 'uint32': Base("u32"), + 'double': Base("f64"), + 'float': Base("f32"), + 'int32': Base("i32"), + 'any': Base("String"), # TODO: Figure out how to handle it. It's 'interface' in Go ... + 'int64': Base("i64"), + 'uint64': Base("u64"), + 'array': Vec(None), + 'string': Base("String"), + 'object': HashMap(None, None), + # https://github.com/protocolbuffers/protobuf/blob/ec1a70913e5793a7d0a7b5fbf7e0e4f75409dd41/src/google/protobuf/timestamp.proto + # In JSON format, the Timestamp type is encoded as a string in the [RFC 3339] format + 'google-datetime': Base(CHRONO_DATETIME), + # Per .json files: RFC 3339 timestamp + 'date-time': Base(CHRONO_DATETIME), + # Per .json files: A date in RFC 3339 format with only the date part + # e.g. "2013-01-15" + 'date': Base(CHRONO_DATE), + # https://github.com/protocolbuffers/protobuf/blob/ec1a70913e5793a7d0a7b5fbf7e0e4f75409dd41/src/google/protobuf/duration.proto + 'google-duration': Base(f"{CHRONO_PATH}::Duration"), + # guessing bytes is universally url-safe b64 + "byte": Vec(Base("u8")), + # https://github.com/protocolbuffers/protobuf/blob/ec1a70913e5793a7d0a7b5fbf7e0e4f75409dd41/src/google/protobuf/field_mask.proto + "google-fieldmask": Base("client::FieldMask") +} + RESERVED_WORDS = set(('abstract', 'alignof', 'as', 'become', 'box', 'break', 'const', 'continue', 'crate', 'do', 'else', 'enum', 'extern', 'false', 'final', 'fn', 'for', 'if', 'impl', 'in', 'let', 'loop', 'macro', 'match', 'mod', 'move', 'mut', 'offsetof', 'override', 'priv', 'pub', 'pure', 'ref', @@ -430,21 +461,43 @@ def to_rust_type( allow_optionals=True, _is_recursive=False ) -> str: - def nested_type(nt): + return str(to_rust_type_inner(schemas, schema_name, property_name, t, allow_optionals, _is_recursive)) + + +def to_serde_type( + schemas, + schema_name, + property_name, + t, + allow_optionals=True, + _is_recursive=False +) -> Tuple[RustType, bool]: + return to_rust_type_inner(schemas, schema_name, property_name, t, allow_optionals, _is_recursive).serde_as() + + +def to_rust_type_inner( + schemas, + schema_name, + property_name, + t, + allow_optionals=True, + _is_recursive=False +) -> RustType: + def nested_type(nt) -> RustType: if 'items' in nt: nt = nt['items'] elif 'additionalProperties' in nt: nt = nt['additionalProperties'] else: - assert (is_nested_type_property(nt)) + assert is_nested_type_property(nt) # It's a nested type - we take it literally like $ref, but generate a name for the type ourselves - return _assure_unique_type_name(schemas, nested_type_name(schema_name, property_name)) - return to_rust_type(schemas, schema_name, property_name, nt, allow_optionals=False, _is_recursive=True) + return Base(_assure_unique_type_name(schemas, nested_type_name(schema_name, property_name))) + return to_rust_type_inner(schemas, schema_name, property_name, nt, allow_optionals=False, _is_recursive=True) - def wrap_type(tn): + def wrap_type(rt) -> RustType: if allow_optionals: - tn = "Option<%s>" % tn - return tn + return Option(rt) + return rt # unconditionally handle $ref types, which should point to another schema. if TREF in t: @@ -452,22 +505,21 @@ def to_rust_type( # which is fine for now. 'allow_optionals' implicitly restricts type boxing for simple types - it # is usually on the first call, and off when recursion is involved. tn = t[TREF] + rt = Base(tn) if not _is_recursive and tn == schema_name: - tn = 'Option>' % tn - return wrap_type(tn) + rt = Option(Box(rt)) + return wrap_type(rt) try: # prefer format if present - rust_type = TYPE_MAP[t.get("format", t["type"])] - - if t['type'] == 'array': - return wrap_type("%s<%s>" % (rust_type, nested_type(t))) - elif t['type'] == 'object': + rust_type = RUST_TYPE_MAP[t.get("format", t["type"])] + if rust_type == Vec(None): + return wrap_type(Vec(nested_type(t))) + if rust_type == HashMap(None, None): if is_map_prop(t): - return wrap_type("%s" % (rust_type, nested_type(t))) + return wrap_type(HashMap(Base("String"), nested_type(t))) return wrap_type(nested_type(t)) - if t.get('repeated', False): - return 'Vec<%s>' % rust_type + return Vec(rust_type) return wrap_type(rust_type) except KeyError as err: raise AssertionError( diff --git a/src/generator/templates/api/api.rs.mako b/src/generator/templates/api/api.rs.mako index 7bc840be33..cf15ee2aff 100644 --- a/src/generator/templates/api/api.rs.mako +++ b/src/generator/templates/api/api.rs.mako @@ -31,7 +31,7 @@ use tokio::time::sleep; use tower_service; use serde::{Serialize, Deserialize}; -use crate::{client, client::GetToken, client::oauth2}; +use crate::{client, client::GetToken, client::oauth2, client::serde_with}; // ############## // UTILITIES ### diff --git a/src/generator/templates/api/lib/schema.mako b/src/generator/templates/api/lib/schema.mako index e10d88fbce..0a5b87ca05 100644 --- a/src/generator/templates/api/lib/schema.mako +++ b/src/generator/templates/api/lib/schema.mako @@ -1,5 +1,5 @@ <%! - from generator.lib.util import (schema_markers, rust_doc_comment, mangle_ident, to_rust_type, put_and, + from generator.lib.util import (schema_markers, rust_doc_comment, mangle_ident, to_serde_type, to_rust_type, put_and, IO_TYPES, activity_split, enclose_in, REQUEST_MARKER_TRAIT, mb_type, indent_all_but_first_by, NESTED_TYPE_SUFFIX, RESPONSE_MARKER_TRAIT, split_camelcase_s, METHODS_RESOURCE, PART_MARKER_TRAIT, canonical_type_name, TO_PARTS_MARKER, UNUSED_TYPE_MARKER, is_schema_with_optionals, @@ -17,14 +17,14 @@ ${struct} { % if pn != mangle_ident(pn): #[serde(rename="${pn}")] % endif - % if p.get("format") == "byte": - #[serde(default, with = "client::serde::urlsafe_base64")] - % elif p.get("format") == "google-duration": - #[serde(default, with = "client::serde::duration")] - % elif p.get("format") in {"uint64", "int64"}: - #[serde(default, with = "client::serde::str_like")] + <% + rust_ty = to_rust_type(schemas, s.id, pn, p, allow_optionals=allow_optionals) + serde_ty, use_custom_serde = to_serde_type(schemas, s.id, pn, p, allow_optionals=allow_optionals) + %> + % if use_custom_serde: + #[serde_as(as = "${serde_ty}")] % endif - pub ${mangle_ident(pn)}: ${to_rust_type(schemas, s.id, pn, p, allow_optionals=allow_optionals)}, + pub ${mangle_ident(pn)}: ${rust_ty}, % endfor } % elif 'additionalProperties' in s: @@ -83,7 +83,8 @@ ${struct} { _never_set: Option } <%block filter="rust_doc_sanitize, rust_doc_comment">\ ${doc(s, c)}\ -#[derive(${', '.join(traits)})] + #[serde_with::serde_as(crate = "::client::serde_with")] + #[derive(${', '.join(traits)})] % if s.type == 'object': ${_new_object(s, s.get('properties'), c, allow_optionals)}\ % elif s.type == 'array':