mirror of
https://github.com/OMGeeky/google-apis-rs.git
synced 2025-12-26 17:02:24 +01:00
Support serde for arbitrary field types
This introduces the `serde_with` dependency and `rust_type.py`, to allow supporting arbitrary types for serialization. Since fields may have arbitrary types (eg. `HashMap<_, chrono::Duration>`) which need deserialization, it is necessary to use type-based serialization to avoid implementing (de)serialization for every permutation of types that require special serialization. However, `serde` does not let you (de)serialize one type as another (eg. `chrono::Duration` as `Wrapper`) - thus necessitating `serde_with`, which does. `rust_type.py` introduces the `RustType` class, which makes it easy to describe the (de)serialization type used by `serde_with`
This commit is contained in:
@@ -18,9 +18,12 @@ doctest = false
|
||||
[dependencies]
|
||||
mime = "^ 0.2.0"
|
||||
serde = { version = "^ 1.0", features = ["derive"] }
|
||||
base64 = "0.13.0"
|
||||
serde_with = "2.0.1"
|
||||
serde_json = "^ 1.0"
|
||||
|
||||
base64 = "0.13.0"
|
||||
chrono = { version = "0.4.22", features = ["serde"] }
|
||||
|
||||
## TODO: Make yup-oauth2 optional
|
||||
## yup-oauth2 = { version = "^ 7.0", optional = true }
|
||||
yup-oauth2 = "^ 7.0"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
pub mod serde;
|
||||
pub mod field_mask;
|
||||
pub mod serde;
|
||||
|
||||
use std::error;
|
||||
use std::error::Error as StdError;
|
||||
@@ -28,6 +28,7 @@ use tower_service;
|
||||
|
||||
pub use chrono;
|
||||
pub use field_mask::FieldMask;
|
||||
pub use serde_with;
|
||||
pub use yup_oauth2 as oauth2;
|
||||
|
||||
const LINE_ENDING: &str = "\r\n";
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
pub mod duration {
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use serde_with::{DeserializeAs, SerializeAs};
|
||||
use std::fmt::Formatter;
|
||||
use std::str::FromStr;
|
||||
|
||||
use serde::{Deserialize, Deserializer, Serializer};
|
||||
|
||||
use chrono::Duration;
|
||||
|
||||
const MAX_SECONDS: i64 = 315576000000i64;
|
||||
@@ -53,7 +53,7 @@ pub mod duration {
|
||||
|
||||
impl std::error::Error for ParseDurationError {}
|
||||
|
||||
fn parse_duration(s: &str) -> Result<Duration, ParseDurationError> {
|
||||
fn duration_from_str(s: &str) -> Result<Duration, ParseDurationError> {
|
||||
// TODO: Test strings like -.s, -0.0s
|
||||
let value = match s.strip_suffix('s') {
|
||||
None => return Err(ParseDurationError::MissingSecondSuffix),
|
||||
@@ -97,115 +97,95 @@ pub mod duration {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serialize<S>(x: &Option<Duration>, s: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
match x {
|
||||
None => s.serialize_none(),
|
||||
Some(x) => {
|
||||
let seconds = x.num_seconds();
|
||||
let nanoseconds = (*x - Duration::seconds(seconds))
|
||||
.num_nanoseconds()
|
||||
.expect("absolute number of nanoseconds is less than 1 billion")
|
||||
as i32;
|
||||
if nanoseconds != 0 {
|
||||
if seconds == 0 && nanoseconds.is_negative() {
|
||||
s.serialize_str(&format!("-0.{:0>9}s", nanoseconds.abs()))
|
||||
} else {
|
||||
s.serialize_str(&format!("{}.{:0>9}s", seconds, nanoseconds.abs()))
|
||||
}
|
||||
} else {
|
||||
s.serialize_str(&format!("{}s", seconds))
|
||||
}
|
||||
fn duration_to_string(duration: &Duration) -> String {
|
||||
let seconds = duration.num_seconds();
|
||||
let nanoseconds = (*duration - Duration::seconds(seconds))
|
||||
.num_nanoseconds()
|
||||
.expect("absolute number of nanoseconds is less than 1 billion")
|
||||
as i32;
|
||||
if nanoseconds != 0 {
|
||||
if seconds == 0 && nanoseconds.is_negative() {
|
||||
format!("-0.{:0>9}s", nanoseconds.abs())
|
||||
} else {
|
||||
format!("{}.{:0>9}s", seconds, nanoseconds.abs())
|
||||
}
|
||||
} else {
|
||||
format!("{}s", seconds)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s: Option<&str> = Deserialize::deserialize(deserializer)?;
|
||||
s.map(parse_duration)
|
||||
.transpose()
|
||||
.map_err(serde::de::Error::custom)
|
||||
pub struct Wrapper;
|
||||
|
||||
impl SerializeAs<Duration> for Wrapper {
|
||||
fn serialize_as<S>(value: &Duration, s: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
s.serialize_str(&duration_to_string(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> DeserializeAs<'de, Duration> for Wrapper {
|
||||
fn deserialize_as<D>(deserializer: D) -> Result<Duration, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s = Deserialize::deserialize(deserializer)?;
|
||||
duration_from_str(s).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub mod urlsafe_base64 {
|
||||
use serde::{Deserialize, Deserializer, Serializer};
|
||||
use serde_with::{DeserializeAs, SerializeAs};
|
||||
|
||||
pub fn serialize<S>(x: &Option<Vec<u8>>, s: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
match x {
|
||||
None => s.serialize_none(),
|
||||
Some(x) => s.serialize_some(&base64::encode_config(x, base64::URL_SAFE)),
|
||||
pub struct Wrapper;
|
||||
|
||||
impl SerializeAs<Vec<u8>> for Wrapper {
|
||||
fn serialize_as<S>(value: &Vec<u8>, s: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
s.serialize_str(&base64::encode_config(value, base64::URL_SAFE))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s: Option<&str> = Deserialize::deserialize(deserializer)?;
|
||||
s.map(|s| base64::decode_config(s, base64::URL_SAFE))
|
||||
.transpose()
|
||||
.map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
pub mod str_like {
|
||||
/// Implementation based on `https://chromium.googlesource.com/infra/luci/luci-go/+/23ea7a05c6a5/common/proto/fieldmasks.go#184`
|
||||
use serde::{Deserialize, Deserializer, Serializer};
|
||||
use std::str::FromStr;
|
||||
|
||||
pub fn serialize<S, T>(x: &Option<T>, s: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
T: std::fmt::Display,
|
||||
{
|
||||
match x {
|
||||
None => s.serialize_none(),
|
||||
Some(num) => s.serialize_some(num.to_string().as_str()),
|
||||
impl<'de> DeserializeAs<'de, Vec<u8>> for Wrapper {
|
||||
fn deserialize_as<D>(deserializer: D) -> Result<Vec<u8>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s: &str = Deserialize::deserialize(deserializer)?;
|
||||
base64::decode_config(s, base64::URL_SAFE).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Option<T>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
T: FromStr,
|
||||
<T as FromStr>::Err: std::fmt::Display,
|
||||
{
|
||||
let s: Option<&str> = Deserialize::deserialize(deserializer)?;
|
||||
s.map(T::from_str)
|
||||
.transpose()
|
||||
.map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::{duration, str_like, urlsafe_base64};
|
||||
use super::{duration, urlsafe_base64};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
struct DurationWrapper {
|
||||
#[serde(default, with = "duration")]
|
||||
#[serde_as(as = "Option<duration::Wrapper>")]
|
||||
duration: Option<chrono::Duration>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
struct Base64Wrapper {
|
||||
#[serde(default, with = "urlsafe_base64")]
|
||||
#[serde_as(as = "Option<urlsafe_base64::Wrapper>")]
|
||||
bytes: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
struct I64Wrapper {
|
||||
#[serde(default, with = "str_like")]
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
num: Option<i64>,
|
||||
}
|
||||
|
||||
|
||||
76
src/generator/lib/rust_type.py
Normal file
76
src/generator/lib/rust_type.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from typing import Optional, List
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class RustType:
|
||||
def __init__(self, name: str, members: Optional[List["RustType"]] = None):
|
||||
self.name = name
|
||||
self.members = members
|
||||
|
||||
def serde_replace_inner_ty(self, from_to):
|
||||
if self.members is None:
|
||||
return False
|
||||
|
||||
changed = False
|
||||
for i, member in enumerate(self.members):
|
||||
if member in from_to:
|
||||
self.members[i] = from_to[member]
|
||||
changed = True
|
||||
else:
|
||||
# serde_as fails to compile if type definition includes
|
||||
# types without custom serialization
|
||||
if not member.serde_replace_inner_ty(from_to):
|
||||
self.members[i] = Base("_")
|
||||
return changed
|
||||
|
||||
def serde_as(self) -> "RustType":
|
||||
copied = deepcopy(self)
|
||||
from_to = {
|
||||
Vec(Base("u8")): Base("::client::serde::urlsafe_base64::Wrapper"),
|
||||
Base("client::chrono::Duration"): Base("::client::serde::duration::Wrapper"),
|
||||
Base("i64"): Base("::client::serde_with::DisplayFromStr"),
|
||||
Base("u64"): Base("::client::serde_with::DisplayFromStr"),
|
||||
}
|
||||
|
||||
changed = copied.serde_replace_inner_ty(from_to)
|
||||
|
||||
return copied, changed
|
||||
|
||||
def __str__(self):
|
||||
if self.members:
|
||||
return f"{self.name}<{', '.join(str(m) for m in self.members)}>"
|
||||
return self.name
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, RustType):
|
||||
return False
|
||||
return self.name == other.name and self.members == other.members
|
||||
|
||||
def __hash__(self):
|
||||
if self.members:
|
||||
return hash((self.name, *[(i, v) for i, v in enumerate(self.members)]))
|
||||
return hash((self.name, None))
|
||||
|
||||
class Option(RustType):
|
||||
def __init__(self, member):
|
||||
super().__init__("Option", [member])
|
||||
|
||||
|
||||
class Box(RustType):
|
||||
def __init__(self, member):
|
||||
super().__init__("Box", [member])
|
||||
|
||||
|
||||
class Vec(RustType):
|
||||
def __init__(self, member):
|
||||
super().__init__("Vec", [member])
|
||||
|
||||
|
||||
class HashMap(RustType):
|
||||
def __init__(self, key, value):
|
||||
super().__init__("HashMap", [key, value])
|
||||
|
||||
|
||||
class Base(RustType):
|
||||
def __init__(self, name):
|
||||
super().__init__(name)
|
||||
@@ -6,6 +6,7 @@ from dataclasses import dataclass
|
||||
from random import (randint, random, choice, seed)
|
||||
from typing import Any, Dict, List, Mapping, Tuple
|
||||
from copy import deepcopy
|
||||
from .rust_type import Base, Box, HashMap, Vec, Option, RustType
|
||||
|
||||
seed(1337)
|
||||
|
||||
@@ -53,6 +54,36 @@ TYPE_MAP = {
|
||||
"google-fieldmask": "client::FieldMask"
|
||||
}
|
||||
|
||||
RUST_TYPE_MAP = {
|
||||
'boolean': Base("bool"),
|
||||
'integer': USE_FORMAT,
|
||||
'number': USE_FORMAT,
|
||||
'uint32': Base("u32"),
|
||||
'double': Base("f64"),
|
||||
'float': Base("f32"),
|
||||
'int32': Base("i32"),
|
||||
'any': Base("String"), # TODO: Figure out how to handle it. It's 'interface' in Go ...
|
||||
'int64': Base("i64"),
|
||||
'uint64': Base("u64"),
|
||||
'array': Vec(None),
|
||||
'string': Base("String"),
|
||||
'object': HashMap(None, None),
|
||||
# https://github.com/protocolbuffers/protobuf/blob/ec1a70913e5793a7d0a7b5fbf7e0e4f75409dd41/src/google/protobuf/timestamp.proto
|
||||
# In JSON format, the Timestamp type is encoded as a string in the [RFC 3339] format
|
||||
'google-datetime': Base(CHRONO_DATETIME),
|
||||
# Per .json files: RFC 3339 timestamp
|
||||
'date-time': Base(CHRONO_DATETIME),
|
||||
# Per .json files: A date in RFC 3339 format with only the date part
|
||||
# e.g. "2013-01-15"
|
||||
'date': Base(CHRONO_DATE),
|
||||
# https://github.com/protocolbuffers/protobuf/blob/ec1a70913e5793a7d0a7b5fbf7e0e4f75409dd41/src/google/protobuf/duration.proto
|
||||
'google-duration': Base(f"{CHRONO_PATH}::Duration"),
|
||||
# guessing bytes is universally url-safe b64
|
||||
"byte": Vec(Base("u8")),
|
||||
# https://github.com/protocolbuffers/protobuf/blob/ec1a70913e5793a7d0a7b5fbf7e0e4f75409dd41/src/google/protobuf/field_mask.proto
|
||||
"google-fieldmask": Base("client::FieldMask")
|
||||
}
|
||||
|
||||
RESERVED_WORDS = set(('abstract', 'alignof', 'as', 'become', 'box', 'break', 'const', 'continue', 'crate', 'do',
|
||||
'else', 'enum', 'extern', 'false', 'final', 'fn', 'for', 'if', 'impl', 'in', 'let', 'loop',
|
||||
'macro', 'match', 'mod', 'move', 'mut', 'offsetof', 'override', 'priv', 'pub', 'pure', 'ref',
|
||||
@@ -430,21 +461,43 @@ def to_rust_type(
|
||||
allow_optionals=True,
|
||||
_is_recursive=False
|
||||
) -> str:
|
||||
def nested_type(nt):
|
||||
return str(to_rust_type_inner(schemas, schema_name, property_name, t, allow_optionals, _is_recursive))
|
||||
|
||||
|
||||
def to_serde_type(
|
||||
schemas,
|
||||
schema_name,
|
||||
property_name,
|
||||
t,
|
||||
allow_optionals=True,
|
||||
_is_recursive=False
|
||||
) -> Tuple[RustType, bool]:
|
||||
return to_rust_type_inner(schemas, schema_name, property_name, t, allow_optionals, _is_recursive).serde_as()
|
||||
|
||||
|
||||
def to_rust_type_inner(
|
||||
schemas,
|
||||
schema_name,
|
||||
property_name,
|
||||
t,
|
||||
allow_optionals=True,
|
||||
_is_recursive=False
|
||||
) -> RustType:
|
||||
def nested_type(nt) -> RustType:
|
||||
if 'items' in nt:
|
||||
nt = nt['items']
|
||||
elif 'additionalProperties' in nt:
|
||||
nt = nt['additionalProperties']
|
||||
else:
|
||||
assert (is_nested_type_property(nt))
|
||||
assert is_nested_type_property(nt)
|
||||
# It's a nested type - we take it literally like $ref, but generate a name for the type ourselves
|
||||
return _assure_unique_type_name(schemas, nested_type_name(schema_name, property_name))
|
||||
return to_rust_type(schemas, schema_name, property_name, nt, allow_optionals=False, _is_recursive=True)
|
||||
return Base(_assure_unique_type_name(schemas, nested_type_name(schema_name, property_name)))
|
||||
return to_rust_type_inner(schemas, schema_name, property_name, nt, allow_optionals=False, _is_recursive=True)
|
||||
|
||||
def wrap_type(tn):
|
||||
def wrap_type(rt) -> RustType:
|
||||
if allow_optionals:
|
||||
tn = "Option<%s>" % tn
|
||||
return tn
|
||||
return Option(rt)
|
||||
return rt
|
||||
|
||||
# unconditionally handle $ref types, which should point to another schema.
|
||||
if TREF in t:
|
||||
@@ -452,22 +505,21 @@ def to_rust_type(
|
||||
# which is fine for now. 'allow_optionals' implicitly restricts type boxing for simple types - it
|
||||
# is usually on the first call, and off when recursion is involved.
|
||||
tn = t[TREF]
|
||||
rt = Base(tn)
|
||||
if not _is_recursive and tn == schema_name:
|
||||
tn = 'Option<Box<%s>>' % tn
|
||||
return wrap_type(tn)
|
||||
rt = Option(Box(rt))
|
||||
return wrap_type(rt)
|
||||
try:
|
||||
# prefer format if present
|
||||
rust_type = TYPE_MAP[t.get("format", t["type"])]
|
||||
|
||||
if t['type'] == 'array':
|
||||
return wrap_type("%s<%s>" % (rust_type, nested_type(t)))
|
||||
elif t['type'] == 'object':
|
||||
rust_type = RUST_TYPE_MAP[t.get("format", t["type"])]
|
||||
if rust_type == Vec(None):
|
||||
return wrap_type(Vec(nested_type(t)))
|
||||
if rust_type == HashMap(None, None):
|
||||
if is_map_prop(t):
|
||||
return wrap_type("%s<String, %s>" % (rust_type, nested_type(t)))
|
||||
return wrap_type(HashMap(Base("String"), nested_type(t)))
|
||||
return wrap_type(nested_type(t))
|
||||
|
||||
if t.get('repeated', False):
|
||||
return 'Vec<%s>' % rust_type
|
||||
return Vec(rust_type)
|
||||
return wrap_type(rust_type)
|
||||
except KeyError as err:
|
||||
raise AssertionError(
|
||||
|
||||
@@ -31,7 +31,7 @@ use tokio::time::sleep;
|
||||
use tower_service;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::{client, client::GetToken, client::oauth2};
|
||||
use crate::{client, client::GetToken, client::oauth2, client::serde_with};
|
||||
|
||||
// ##############
|
||||
// UTILITIES ###
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<%!
|
||||
from generator.lib.util import (schema_markers, rust_doc_comment, mangle_ident, to_rust_type, put_and,
|
||||
from generator.lib.util import (schema_markers, rust_doc_comment, mangle_ident, to_serde_type, to_rust_type, put_and,
|
||||
IO_TYPES, activity_split, enclose_in, REQUEST_MARKER_TRAIT, mb_type, indent_all_but_first_by,
|
||||
NESTED_TYPE_SUFFIX, RESPONSE_MARKER_TRAIT, split_camelcase_s, METHODS_RESOURCE,
|
||||
PART_MARKER_TRAIT, canonical_type_name, TO_PARTS_MARKER, UNUSED_TYPE_MARKER, is_schema_with_optionals,
|
||||
@@ -17,14 +17,14 @@ ${struct} {
|
||||
% if pn != mangle_ident(pn):
|
||||
#[serde(rename="${pn}")]
|
||||
% endif
|
||||
% if p.get("format") == "byte":
|
||||
#[serde(default, with = "client::serde::urlsafe_base64")]
|
||||
% elif p.get("format") == "google-duration":
|
||||
#[serde(default, with = "client::serde::duration")]
|
||||
% elif p.get("format") in {"uint64", "int64"}:
|
||||
#[serde(default, with = "client::serde::str_like")]
|
||||
<%
|
||||
rust_ty = to_rust_type(schemas, s.id, pn, p, allow_optionals=allow_optionals)
|
||||
serde_ty, use_custom_serde = to_serde_type(schemas, s.id, pn, p, allow_optionals=allow_optionals)
|
||||
%>
|
||||
% if use_custom_serde:
|
||||
#[serde_as(as = "${serde_ty}")]
|
||||
% endif
|
||||
pub ${mangle_ident(pn)}: ${to_rust_type(schemas, s.id, pn, p, allow_optionals=allow_optionals)},
|
||||
pub ${mangle_ident(pn)}: ${rust_ty},
|
||||
% endfor
|
||||
}
|
||||
% elif 'additionalProperties' in s:
|
||||
@@ -83,7 +83,8 @@ ${struct} { _never_set: Option<bool> }
|
||||
<%block filter="rust_doc_sanitize, rust_doc_comment">\
|
||||
${doc(s, c)}\
|
||||
</%block>
|
||||
#[derive(${', '.join(traits)})]
|
||||
#[serde_with::serde_as(crate = "::client::serde_with")]
|
||||
#[derive(${', '.join(traits)})]
|
||||
% if s.type == 'object':
|
||||
${_new_object(s, s.get('properties'), c, allow_optionals)}\
|
||||
% elif s.type == 'array':
|
||||
|
||||
Reference in New Issue
Block a user