Support serde for arbitrary field types

This introduces the `serde_with` dependency and `rust_type.py` to support (de)serialization of arbitrary field types.
Since fields may have arbitrary types (e.g. `HashMap<_, chrono::Duration>`) which need custom (de)serialization, it is necessary to
use type-based serialization to avoid implementing (de)serialization for every permutation of types that requires special handling.
However, `serde` does not let you (de)serialize one type as another (e.g. `chrono::Duration` as `Wrapper`), thus necessitating `serde_with`, which does. `rust_type.py` introduces the `RustType` class, which makes it easy to describe the (de)serialization type used by `serde_with`.
This commit is contained in:
philippeitis
2022-10-08 23:01:30 -07:00
parent 8cc2707563
commit f6cced9605
7 changed files with 221 additions and 108 deletions

View File

@@ -18,9 +18,12 @@ doctest = false
[dependencies]
mime = "^ 0.2.0"
serde = { version = "^ 1.0", features = ["derive"] }
base64 = "0.13.0"
serde_with = "2.0.1"
serde_json = "^ 1.0"
base64 = "0.13.0"
chrono = { version = "0.4.22", features = ["serde"] }
## TODO: Make yup-oauth2 optional
## yup-oauth2 = { version = "^ 7.0", optional = true }
yup-oauth2 = "^ 7.0"

View File

@@ -1,5 +1,5 @@
pub mod serde;
pub mod field_mask;
pub mod serde;
use std::error;
use std::error::Error as StdError;
@@ -28,6 +28,7 @@ use tower_service;
pub use chrono;
pub use field_mask::FieldMask;
pub use serde_with;
pub use yup_oauth2 as oauth2;
const LINE_ENDING: &str = "\r\n";

View File

@@ -1,9 +1,9 @@
pub mod duration {
use serde::{Deserialize, Deserializer};
use serde_with::{DeserializeAs, SerializeAs};
use std::fmt::Formatter;
use std::str::FromStr;
use serde::{Deserialize, Deserializer, Serializer};
use chrono::Duration;
const MAX_SECONDS: i64 = 315576000000i64;
@@ -53,7 +53,7 @@ pub mod duration {
impl std::error::Error for ParseDurationError {}
fn parse_duration(s: &str) -> Result<Duration, ParseDurationError> {
fn duration_from_str(s: &str) -> Result<Duration, ParseDurationError> {
// TODO: Test strings like -.s, -0.0s
let value = match s.strip_suffix('s') {
None => return Err(ParseDurationError::MissingSecondSuffix),
@@ -97,115 +97,95 @@ pub mod duration {
}
}
pub fn serialize<S>(x: &Option<Duration>, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match x {
None => s.serialize_none(),
Some(x) => {
let seconds = x.num_seconds();
let nanoseconds = (*x - Duration::seconds(seconds))
.num_nanoseconds()
.expect("absolute number of nanoseconds is less than 1 billion")
as i32;
if nanoseconds != 0 {
if seconds == 0 && nanoseconds.is_negative() {
s.serialize_str(&format!("-0.{:0>9}s", nanoseconds.abs()))
} else {
s.serialize_str(&format!("{}.{:0>9}s", seconds, nanoseconds.abs()))
}
} else {
s.serialize_str(&format!("{}s", seconds))
}
/// Renders `duration` as a seconds string ending in `s`.
///
/// When the sub-second remainder is zero the output is `"<seconds>s"`;
/// otherwise the remainder is appended as nine zero-padded nanosecond digits,
/// e.g. `"<seconds>.<nanoseconds>s"`.
fn duration_to_string(duration: &Duration) -> String {
    let whole_secs = duration.num_seconds();
    let remainder = *duration - Duration::seconds(whole_secs);
    let frac_nanos = remainder
        .num_nanoseconds()
        .expect("absolute number of nanoseconds is less than 1 billion") as i32;

    if frac_nanos == 0 {
        return format!("{}s", whole_secs);
    }

    // A seconds part of 0 cannot carry the sign by itself, so a negative
    // remainder needs an explicit "-0" prefix.
    if whole_secs == 0 && frac_nanos.is_negative() {
        format!("-0.{:0>9}s", frac_nanos.abs())
    } else {
        format!("{}.{:0>9}s", whole_secs, frac_nanos.abs())
    }
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
where
D: Deserializer<'de>,
{
let s: Option<&str> = Deserialize::deserialize(deserializer)?;
s.map(parse_duration)
.transpose()
.map_err(serde::de::Error::custom)
/// Marker type implementing `SerializeAs<Duration>` and
/// `DeserializeAs<Duration>` so it can be named as the `serde_with` adapter
/// for `Duration` fields.
pub struct Wrapper;
impl SerializeAs<Duration> for Wrapper {
    /// Serializes `value` as the string produced by `duration_to_string`.
    fn serialize_as<S>(value: &Duration, s: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let text = duration_to_string(value);
        s.serialize_str(text.as_str())
    }
}
impl<'de> DeserializeAs<'de, Duration> for Wrapper {
fn deserialize_as<D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
let s = Deserialize::deserialize(deserializer)?;
duration_from_str(s).map_err(serde::de::Error::custom)
}
}
}
pub mod urlsafe_base64 {
use serde::{Deserialize, Deserializer, Serializer};
use serde_with::{DeserializeAs, SerializeAs};
pub fn serialize<S>(x: &Option<Vec<u8>>, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match x {
None => s.serialize_none(),
Some(x) => s.serialize_some(&base64::encode_config(x, base64::URL_SAFE)),
/// Marker type implementing `SerializeAs<Vec<u8>>` and
/// `DeserializeAs<Vec<u8>>` so it can be named as the `serde_with` adapter
/// for byte fields encoded with `base64::URL_SAFE`.
pub struct Wrapper;
impl SerializeAs<Vec<u8>> for Wrapper {
    /// Serializes `value` as a string encoded with `base64::URL_SAFE`.
    fn serialize_as<S>(value: &Vec<u8>, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let encoded = base64::encode_config(value, base64::URL_SAFE);
        s.serialize_str(encoded.as_str())
    }
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
where
D: Deserializer<'de>,
{
let s: Option<&str> = Deserialize::deserialize(deserializer)?;
s.map(|s| base64::decode_config(s, base64::URL_SAFE))
.transpose()
.map_err(serde::de::Error::custom)
}
}
pub mod str_like {
/// Implementation based on `https://chromium.googlesource.com/infra/luci/luci-go/+/23ea7a05c6a5/common/proto/fieldmasks.go#184`
use serde::{Deserialize, Deserializer, Serializer};
use std::str::FromStr;
pub fn serialize<S, T>(x: &Option<T>, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
T: std::fmt::Display,
{
match x {
None => s.serialize_none(),
Some(num) => s.serialize_some(num.to_string().as_str()),
impl<'de> DeserializeAs<'de, Vec<u8>> for Wrapper {
fn deserialize_as<D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where
D: Deserializer<'de>,
{
let s: &str = Deserialize::deserialize(deserializer)?;
base64::decode_config(s, base64::URL_SAFE).map_err(serde::de::Error::custom)
}
}
pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Option<T>, D::Error>
where
D: Deserializer<'de>,
T: FromStr,
<T as FromStr>::Err: std::fmt::Display,
{
let s: Option<&str> = Deserialize::deserialize(deserializer)?;
s.map(T::from_str)
.transpose()
.map_err(serde::de::Error::custom)
}
}
#[cfg(test)]
mod test {
use super::{duration, str_like, urlsafe_base64};
use super::{duration, urlsafe_base64};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
#[serde_as]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct DurationWrapper {
#[serde(default, with = "duration")]
#[serde_as(as = "Option<duration::Wrapper>")]
duration: Option<chrono::Duration>,
}
#[serde_as]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct Base64Wrapper {
#[serde(default, with = "urlsafe_base64")]
#[serde_as(as = "Option<urlsafe_base64::Wrapper>")]
bytes: Option<Vec<u8>>,
}
#[serde_as]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct I64Wrapper {
#[serde(default, with = "str_like")]
#[serde_as(as = "Option<DisplayFromStr>")]
num: Option<i64>,
}

View File

@@ -0,0 +1,76 @@
from copy import deepcopy
from typing import List, Optional, Tuple
class RustType:
    """A Rust type expression: a name plus optional generic parameters.

    ``members`` is ``None`` for a type without generic parameters; otherwise
    it is the list of parameter types, rendered by ``__str__`` as
    ``name<member, member, ...>``.

    Instances compare equal (and hash equal) when both the name and the
    members match, so they can be used as dictionary keys.
    """

    def __init__(self, name: str, members: Optional[List["RustType"]] = None):
        self.name = name
        self.members = members

    def serde_replace_inner_ty(self, from_to) -> bool:
        """Rewrite the generic parameters of this type, in place, for serde_as.

        Each member found in ``from_to`` is swapped for its mapped type.
        Members that are not in ``from_to`` are searched recursively; a member
        in which nothing was replaced is turned into ``_``.

        :param from_to: mapping from a ``RustType`` to its replacement.
        :return: ``True`` if a replacement happened in this type or anywhere
            beneath it, ``False`` otherwise (including when this type has no
            members).
        """
        if self.members is None:
            return False

        changed = False
        for i, member in enumerate(self.members):
            if member in from_to:
                self.members[i] = from_to[member]
                changed = True
            elif member.serde_replace_inner_ty(from_to):
                # A replacement happened deeper inside this member, so the
                # enclosing type carries custom serialization as well and
                # must report it to the caller.
                changed = True
            else:
                # serde_as fails to compile if type definition includes
                # types without custom serialization
                self.members[i] = Base("_")
        return changed

    def serde_as(self) -> Tuple["RustType", bool]:
        """Build the type expression to use in a ``serde_as`` attribute.

        Works on a deep copy, so ``self`` is left untouched.

        :return: ``(rewritten_copy, changed)`` where ``changed`` tells whether
            any member was replaced by a custom (de)serialization type.
        """
        copied = deepcopy(self)
        from_to = {
            Vec(Base("u8")): Base("::client::serde::urlsafe_base64::Wrapper"),
            Base("client::chrono::Duration"): Base("::client::serde::duration::Wrapper"),
            Base("i64"): Base("::client::serde_with::DisplayFromStr"),
            Base("u64"): Base("::client::serde_with::DisplayFromStr"),
        }
        changed = copied.serde_replace_inner_ty(from_to)
        return copied, changed

    def __str__(self):
        if self.members:
            return f"{self.name}<{', '.join(str(m) for m in self.members)}>"
        return self.name

    def __eq__(self, other):
        if not isinstance(other, RustType):
            return False
        return self.name == other.name and self.members == other.members

    def __hash__(self):
        if self.members:
            return hash((self.name, *[(i, v) for i, v in enumerate(self.members)]))
        return hash((self.name, None))
class Option(RustType):
    """A ``RustType`` named ``Option`` with ``member`` as its only parameter."""

    def __init__(self, member):
        super().__init__(name="Option", members=[member])
class Box(RustType):
    """A ``RustType`` named ``Box`` with ``member`` as its only parameter."""

    def __init__(self, member):
        super().__init__(name="Box", members=[member])
class Vec(RustType):
    """A ``RustType`` named ``Vec`` with ``member`` as its only parameter."""

    def __init__(self, member):
        super().__init__(name="Vec", members=[member])
class HashMap(RustType):
    """A ``RustType`` named ``HashMap`` with ``key`` and ``value`` parameters, in that order."""

    def __init__(self, key, value):
        super().__init__(name="HashMap", members=[key, value])
class Base(RustType):
    """A ``RustType`` without generic parameters (``members`` stays ``None``)."""

    def __init__(self, name):
        super().__init__(name=name)

View File

@@ -6,6 +6,7 @@ from dataclasses import dataclass
from random import (randint, random, choice, seed)
from typing import Any, Dict, List, Mapping, Tuple
from copy import deepcopy
from .rust_type import Base, Box, HashMap, Vec, Option, RustType
seed(1337)
@@ -53,6 +54,36 @@ TYPE_MAP = {
"google-fieldmask": "client::FieldMask"
}
# Maps a schema `format` (or, failing that, `type`) name to a RustType.
# `to_rust_type_inner` looks entries up with `t.get("format", t["type"])` and
# compares the result against `Vec(None)` / `HashMap(None, None)` to detect
# the container entries, whose real parameters are taken from the nested schema.
# NOTE(review): USE_FORMAT is defined outside this view; presumably it marks
# types that must be resolved through their `format` — confirm.
RUST_TYPE_MAP = {
'boolean': Base("bool"),
'integer': USE_FORMAT,
'number': USE_FORMAT,
'uint32': Base("u32"),
'double': Base("f64"),
'float': Base("f32"),
'int32': Base("i32"),
'any': Base("String"), # TODO: Figure out how to handle it. It's 'interface' in Go ...
'int64': Base("i64"),
'uint64': Base("u64"),
'array': Vec(None),
'string': Base("String"),
'object': HashMap(None, None),
# https://github.com/protocolbuffers/protobuf/blob/ec1a70913e5793a7d0a7b5fbf7e0e4f75409dd41/src/google/protobuf/timestamp.proto
# In JSON format, the Timestamp type is encoded as a string in the [RFC 3339] format
'google-datetime': Base(CHRONO_DATETIME),
# Per .json files: RFC 3339 timestamp
'date-time': Base(CHRONO_DATETIME),
# Per .json files: A date in RFC 3339 format with only the date part
# e.g. "2013-01-15"
'date': Base(CHRONO_DATE),
# https://github.com/protocolbuffers/protobuf/blob/ec1a70913e5793a7d0a7b5fbf7e0e4f75409dd41/src/google/protobuf/duration.proto
'google-duration': Base(f"{CHRONO_PATH}::Duration"),
# guessing bytes is universally url-safe b64
"byte": Vec(Base("u8")),
# https://github.com/protocolbuffers/protobuf/blob/ec1a70913e5793a7d0a7b5fbf7e0e4f75409dd41/src/google/protobuf/field_mask.proto
"google-fieldmask": Base("client::FieldMask")
}
RESERVED_WORDS = set(('abstract', 'alignof', 'as', 'become', 'box', 'break', 'const', 'continue', 'crate', 'do',
'else', 'enum', 'extern', 'false', 'final', 'fn', 'for', 'if', 'impl', 'in', 'let', 'loop',
'macro', 'match', 'mod', 'move', 'mut', 'offsetof', 'override', 'priv', 'pub', 'pure', 'ref',
@@ -430,21 +461,43 @@ def to_rust_type(
allow_optionals=True,
_is_recursive=False
) -> str:
def nested_type(nt):
return str(to_rust_type_inner(schemas, schema_name, property_name, t, allow_optionals, _is_recursive))
def to_serde_type(
        schemas,
        schema_name,
        property_name,
        t,
        allow_optionals=True,
        _is_recursive=False
) -> Tuple[RustType, bool]:
    """Resolve the property's type and derive its ``serde_as`` counterpart.

    The arguments are forwarded unchanged to ``to_rust_type_inner``; the
    result of calling ``serde_as()`` on the resolved type is returned.
    """
    resolved = to_rust_type_inner(
        schemas,
        schema_name,
        property_name,
        t,
        allow_optionals=allow_optionals,
        _is_recursive=_is_recursive,
    )
    return resolved.serde_as()
def to_rust_type_inner(
schemas,
schema_name,
property_name,
t,
allow_optionals=True,
_is_recursive=False
) -> RustType:
def nested_type(nt) -> RustType:
if 'items' in nt:
nt = nt['items']
elif 'additionalProperties' in nt:
nt = nt['additionalProperties']
else:
assert (is_nested_type_property(nt))
assert is_nested_type_property(nt)
# It's a nested type - we take it literally like $ref, but generate a name for the type ourselves
return _assure_unique_type_name(schemas, nested_type_name(schema_name, property_name))
return to_rust_type(schemas, schema_name, property_name, nt, allow_optionals=False, _is_recursive=True)
return Base(_assure_unique_type_name(schemas, nested_type_name(schema_name, property_name)))
return to_rust_type_inner(schemas, schema_name, property_name, nt, allow_optionals=False, _is_recursive=True)
def wrap_type(tn):
def wrap_type(rt) -> RustType:
if allow_optionals:
tn = "Option<%s>" % tn
return tn
return Option(rt)
return rt
# unconditionally handle $ref types, which should point to another schema.
if TREF in t:
@@ -452,22 +505,21 @@ def to_rust_type(
# which is fine for now. 'allow_optionals' implicitly restricts type boxing for simple types - it
# is usually on the first call, and off when recursion is involved.
tn = t[TREF]
rt = Base(tn)
if not _is_recursive and tn == schema_name:
tn = 'Option<Box<%s>>' % tn
return wrap_type(tn)
rt = Option(Box(rt))
return wrap_type(rt)
try:
# prefer format if present
rust_type = TYPE_MAP[t.get("format", t["type"])]
if t['type'] == 'array':
return wrap_type("%s<%s>" % (rust_type, nested_type(t)))
elif t['type'] == 'object':
rust_type = RUST_TYPE_MAP[t.get("format", t["type"])]
if rust_type == Vec(None):
return wrap_type(Vec(nested_type(t)))
if rust_type == HashMap(None, None):
if is_map_prop(t):
return wrap_type("%s<String, %s>" % (rust_type, nested_type(t)))
return wrap_type(HashMap(Base("String"), nested_type(t)))
return wrap_type(nested_type(t))
if t.get('repeated', False):
return 'Vec<%s>' % rust_type
return Vec(rust_type)
return wrap_type(rust_type)
except KeyError as err:
raise AssertionError(

View File

@@ -31,7 +31,7 @@ use tokio::time::sleep;
use tower_service;
use serde::{Serialize, Deserialize};
use crate::{client, client::GetToken, client::oauth2};
use crate::{client, client::GetToken, client::oauth2, client::serde_with};
// ##############
// UTILITIES ###

View File

@@ -1,5 +1,5 @@
<%!
from generator.lib.util import (schema_markers, rust_doc_comment, mangle_ident, to_rust_type, put_and,
from generator.lib.util import (schema_markers, rust_doc_comment, mangle_ident, to_serde_type, to_rust_type, put_and,
IO_TYPES, activity_split, enclose_in, REQUEST_MARKER_TRAIT, mb_type, indent_all_but_first_by,
NESTED_TYPE_SUFFIX, RESPONSE_MARKER_TRAIT, split_camelcase_s, METHODS_RESOURCE,
PART_MARKER_TRAIT, canonical_type_name, TO_PARTS_MARKER, UNUSED_TYPE_MARKER, is_schema_with_optionals,
@@ -17,14 +17,14 @@ ${struct} {
% if pn != mangle_ident(pn):
#[serde(rename="${pn}")]
% endif
% if p.get("format") == "byte":
#[serde(default, with = "client::serde::urlsafe_base64")]
% elif p.get("format") == "google-duration":
#[serde(default, with = "client::serde::duration")]
% elif p.get("format") in {"uint64", "int64"}:
#[serde(default, with = "client::serde::str_like")]
<%
rust_ty = to_rust_type(schemas, s.id, pn, p, allow_optionals=allow_optionals)
serde_ty, use_custom_serde = to_serde_type(schemas, s.id, pn, p, allow_optionals=allow_optionals)
%>
% if use_custom_serde:
#[serde_as(as = "${serde_ty}")]
% endif
pub ${mangle_ident(pn)}: ${to_rust_type(schemas, s.id, pn, p, allow_optionals=allow_optionals)},
pub ${mangle_ident(pn)}: ${rust_ty},
% endfor
}
% elif 'additionalProperties' in s:
@@ -83,7 +83,8 @@ ${struct} { _never_set: Option<bool> }
<%block filter="rust_doc_sanitize, rust_doc_comment">\
${doc(s, c)}\
</%block>
#[derive(${', '.join(traits)})]
#[serde_with::serde_as(crate = "::client::serde_with")]
#[derive(${', '.join(traits)})]
% if s.type == 'object':
${_new_object(s, s.get('properties'), c, allow_optionals)}\
% elif s.type == 'array':