Support serde for arbitrary field types

This introduces the `serde_with` dependency and `rust_type.py`, to allow supporting arbitrary types for serialization.
Since fields may have arbitrary types (eg. `HashMap<_, chrono::Duration>`) which need deserialization, it is necessary to
use type-based serialization to avoid implementing (de)serialization for every permutation of types that require special serialization.
However, `serde` does not let you (de)serialize one type as another (eg. `chrono::Duration` as `Wrapper`) - thus necessitating `serde_with`, which does. `rust_type.py` introduces the `RustType` class, which makes it easy to describe the (de)serialization type used by `serde_with`
This commit is contained in:
philippeitis
2022-10-08 23:01:30 -07:00
parent 8cc2707563
commit f6cced9605
7 changed files with 221 additions and 108 deletions

View File

@@ -0,0 +1,76 @@
from typing import Optional, List
from copy import deepcopy
class RustType:
def __init__(self, name: str, members: Optional[List["RustType"]] = None):
self.name = name
self.members = members
def serde_replace_inner_ty(self, from_to):
if self.members is None:
return False
changed = False
for i, member in enumerate(self.members):
if member in from_to:
self.members[i] = from_to[member]
changed = True
else:
# serde_as fails to compile if type definition includes
# types without custom serialization
if not member.serde_replace_inner_ty(from_to):
self.members[i] = Base("_")
return changed
def serde_as(self) -> "RustType":
copied = deepcopy(self)
from_to = {
Vec(Base("u8")): Base("::client::serde::urlsafe_base64::Wrapper"),
Base("client::chrono::Duration"): Base("::client::serde::duration::Wrapper"),
Base("i64"): Base("::client::serde_with::DisplayFromStr"),
Base("u64"): Base("::client::serde_with::DisplayFromStr"),
}
changed = copied.serde_replace_inner_ty(from_to)
return copied, changed
def __str__(self):
if self.members:
return f"{self.name}<{', '.join(str(m) for m in self.members)}>"
return self.name
def __eq__(self, other):
if not isinstance(other, RustType):
return False
return self.name == other.name and self.members == other.members
def __hash__(self):
if self.members:
return hash((self.name, *[(i, v) for i, v in enumerate(self.members)]))
return hash((self.name, None))
class Option(RustType):
def __init__(self, member):
super().__init__("Option", [member])
class Box(RustType):
def __init__(self, member):
super().__init__("Box", [member])
class Vec(RustType):
def __init__(self, member):
super().__init__("Vec", [member])
class HashMap(RustType):
def __init__(self, key, value):
super().__init__("HashMap", [key, value])
class Base(RustType):
def __init__(self, name):
super().__init__(name)

View File

@@ -6,6 +6,7 @@ from dataclasses import dataclass
from random import (randint, random, choice, seed)
from typing import Any, Dict, List, Mapping, Tuple
from copy import deepcopy
from .rust_type import Base, Box, HashMap, Vec, Option, RustType
seed(1337)
@@ -53,6 +54,36 @@ TYPE_MAP = {
"google-fieldmask": "client::FieldMask"
}
RUST_TYPE_MAP = {
'boolean': Base("bool"),
'integer': USE_FORMAT,
'number': USE_FORMAT,
'uint32': Base("u32"),
'double': Base("f64"),
'float': Base("f32"),
'int32': Base("i32"),
'any': Base("String"), # TODO: Figure out how to handle it. It's 'interface' in Go ...
'int64': Base("i64"),
'uint64': Base("u64"),
'array': Vec(None),
'string': Base("String"),
'object': HashMap(None, None),
# https://github.com/protocolbuffers/protobuf/blob/ec1a70913e5793a7d0a7b5fbf7e0e4f75409dd41/src/google/protobuf/timestamp.proto
# In JSON format, the Timestamp type is encoded as a string in the [RFC 3339] format
'google-datetime': Base(CHRONO_DATETIME),
# Per .json files: RFC 3339 timestamp
'date-time': Base(CHRONO_DATETIME),
# Per .json files: A date in RFC 3339 format with only the date part
# e.g. "2013-01-15"
'date': Base(CHRONO_DATE),
# https://github.com/protocolbuffers/protobuf/blob/ec1a70913e5793a7d0a7b5fbf7e0e4f75409dd41/src/google/protobuf/duration.proto
'google-duration': Base(f"{CHRONO_PATH}::Duration"),
# guessing bytes is universally url-safe b64
"byte": Vec(Base("u8")),
# https://github.com/protocolbuffers/protobuf/blob/ec1a70913e5793a7d0a7b5fbf7e0e4f75409dd41/src/google/protobuf/field_mask.proto
"google-fieldmask": Base("client::FieldMask")
}
RESERVED_WORDS = set(('abstract', 'alignof', 'as', 'become', 'box', 'break', 'const', 'continue', 'crate', 'do',
'else', 'enum', 'extern', 'false', 'final', 'fn', 'for', 'if', 'impl', 'in', 'let', 'loop',
'macro', 'match', 'mod', 'move', 'mut', 'offsetof', 'override', 'priv', 'pub', 'pure', 'ref',
@@ -430,21 +461,43 @@ def to_rust_type(
allow_optionals=True,
_is_recursive=False
) -> str:
def nested_type(nt):
return str(to_rust_type_inner(schemas, schema_name, property_name, t, allow_optionals, _is_recursive))
def to_serde_type(
schemas,
schema_name,
property_name,
t,
allow_optionals=True,
_is_recursive=False
) -> Tuple[RustType, bool]:
return to_rust_type_inner(schemas, schema_name, property_name, t, allow_optionals, _is_recursive).serde_as()
def to_rust_type_inner(
schemas,
schema_name,
property_name,
t,
allow_optionals=True,
_is_recursive=False
) -> RustType:
def nested_type(nt) -> RustType:
if 'items' in nt:
nt = nt['items']
elif 'additionalProperties' in nt:
nt = nt['additionalProperties']
else:
assert (is_nested_type_property(nt))
assert is_nested_type_property(nt)
# It's a nested type - we take it literally like $ref, but generate a name for the type ourselves
return _assure_unique_type_name(schemas, nested_type_name(schema_name, property_name))
return to_rust_type(schemas, schema_name, property_name, nt, allow_optionals=False, _is_recursive=True)
return Base(_assure_unique_type_name(schemas, nested_type_name(schema_name, property_name)))
return to_rust_type_inner(schemas, schema_name, property_name, nt, allow_optionals=False, _is_recursive=True)
def wrap_type(tn):
def wrap_type(rt) -> RustType:
if allow_optionals:
tn = "Option<%s>" % tn
return tn
return Option(rt)
return rt
# unconditionally handle $ref types, which should point to another schema.
if TREF in t:
@@ -452,22 +505,21 @@ def to_rust_type(
# which is fine for now. 'allow_optionals' implicitly restricts type boxing for simple types - it
# is usually on the first call, and off when recursion is involved.
tn = t[TREF]
rt = Base(tn)
if not _is_recursive and tn == schema_name:
tn = 'Option<Box<%s>>' % tn
return wrap_type(tn)
rt = Option(Box(rt))
return wrap_type(rt)
try:
# prefer format if present
rust_type = TYPE_MAP[t.get("format", t["type"])]
if t['type'] == 'array':
return wrap_type("%s<%s>" % (rust_type, nested_type(t)))
elif t['type'] == 'object':
rust_type = RUST_TYPE_MAP[t.get("format", t["type"])]
if rust_type == Vec(None):
return wrap_type(Vec(nested_type(t)))
if rust_type == HashMap(None, None):
if is_map_prop(t):
return wrap_type("%s<String, %s>" % (rust_type, nested_type(t)))
return wrap_type(HashMap(Base("String"), nested_type(t)))
return wrap_type(nested_type(t))
if t.get('repeated', False):
return 'Vec<%s>' % rust_type
return Vec(rust_type)
return wrap_type(rust_type)
except KeyError as err:
raise AssertionError(

View File

@@ -31,7 +31,7 @@ use tokio::time::sleep;
use tower_service;
use serde::{Serialize, Deserialize};
use crate::{client, client::GetToken, client::oauth2};
use crate::{client, client::GetToken, client::oauth2, client::serde_with};
// ##############
// UTILITIES ###

View File

@@ -1,5 +1,5 @@
<%!
from generator.lib.util import (schema_markers, rust_doc_comment, mangle_ident, to_rust_type, put_and,
from generator.lib.util import (schema_markers, rust_doc_comment, mangle_ident, to_serde_type, to_rust_type, put_and,
IO_TYPES, activity_split, enclose_in, REQUEST_MARKER_TRAIT, mb_type, indent_all_but_first_by,
NESTED_TYPE_SUFFIX, RESPONSE_MARKER_TRAIT, split_camelcase_s, METHODS_RESOURCE,
PART_MARKER_TRAIT, canonical_type_name, TO_PARTS_MARKER, UNUSED_TYPE_MARKER, is_schema_with_optionals,
@@ -17,14 +17,14 @@ ${struct} {
% if pn != mangle_ident(pn):
#[serde(rename="${pn}")]
% endif
% if p.get("format") == "byte":
#[serde(default, with = "client::serde::urlsafe_base64")]
% elif p.get("format") == "google-duration":
#[serde(default, with = "client::serde::duration")]
% elif p.get("format") in {"uint64", "int64"}:
#[serde(default, with = "client::serde::str_like")]
<%
rust_ty = to_rust_type(schemas, s.id, pn, p, allow_optionals=allow_optionals)
serde_ty, use_custom_serde = to_serde_type(schemas, s.id, pn, p, allow_optionals=allow_optionals)
%>
% if use_custom_serde:
#[serde_as(as = "${serde_ty}")]
% endif
pub ${mangle_ident(pn)}: ${to_rust_type(schemas, s.id, pn, p, allow_optionals=allow_optionals)},
pub ${mangle_ident(pn)}: ${rust_ty},
% endfor
}
% elif 'additionalProperties' in s:
@@ -83,7 +83,8 @@ ${struct} { _never_set: Option<bool> }
<%block filter="rust_doc_sanitize, rust_doc_comment">\
${doc(s, c)}\
</%block>
#[derive(${', '.join(traits)})]
#[serde_with::serde_as(crate = "::client::serde_with")]
#[derive(${', '.join(traits)})]
% if s.type == 'object':
${_new_object(s, s.get('properties'), c, allow_optionals)}\
% elif s.type == 'array':