make enums their own class in python to better handle variants where name overlaps might happen

This commit is contained in:
OMGeeky
2024-05-31 15:09:06 +02:00
parent 0946a3b41d
commit d80498a8b6
4 changed files with 152 additions and 119 deletions

View File

@@ -2,7 +2,21 @@ from typing import Any
from .rust_type import RustType
from .types import Base
from .util import Context, UNUSED_TYPE_MARKER, schema_markers, canonical_type_name, remove_invalid_chars_in_ident, \
singular, activity_split, items
activity_split, items, Enum, EnumVariant
def get_enum_from_dict(name: RustType, s: dict) -> Enum:
description: str | None = get_from_enum(s, 'description')
(variants, is_any_variant_deprecated) = _get_enum_variants(s)
default: EnumVariant | str | None = get_enum_default(s)
default_variant: EnumVariant | None = None
if default is not None:
for variant in variants:
if variant.name == default:
default_variant = variant
break
return Enum(name, description, variants, default_variant, is_any_variant_deprecated)
def is_property_enum(s: dict) -> bool:
@@ -27,12 +41,33 @@ def get_enum_type(schema_name: str, property_name: str) -> RustType:
return Base(name)
def get_enum_variants(enum: dict) -> list[str]:
return get_from_enum(enum, 'enum')
def _get_enum_variants(enum: dict) -> tuple[list[EnumVariant], bool]:
variants = get_from_enum(enum, 'enum')
descriptions = get_from_enum(enum, 'enumDescriptions')
if not descriptions:
descriptions = variants
result = []
is_any_variant_deprecated = False
for variant in descriptions:
if is_enum_variant_deprecated(variant):
is_any_variant_deprecated = True
break
def get_enum_variants_descriptions(enum: dict) -> list[str]:
return get_from_enum(enum, 'enumDescriptions')
for i in range(len(variants)):
variant = variants[i]
description = None
if descriptions:
description = descriptions[i]
if variant is None:
continue
is_deprecated = is_enum_variant_deprecated(description)
name = to_enum_variant_name(variant, not is_any_variant_deprecated)
result.append(EnumVariant(name, variant, description, is_deprecated, description))
return result, is_any_variant_deprecated
def get_enum_default(enum: dict) -> str | None:
@@ -54,16 +89,14 @@ def get_from_enum(enum: dict, key: str) -> list[Any] | None:
if nested:
return nested
if key != 'default': # just a debugging help
print(f"could not find key '{key}' in enum:", enum)
return None
def get_enum_if_is_enum(k, property_name: str, property_value: dict) -> RustType | None:
def get_enum_if_is_enum(k, property_name: str, property_value: dict) -> Enum | None:
if property_value is None:
return None
if is_property_enum(property_value):
return get_enum_type(k, property_name)
return get_enum_from_dict(get_enum_type(k, property_name), property_value)
return get_enum_if_is_enum(k, property_name, _get_inner_enum(property_value))
@@ -79,8 +112,8 @@ def _get_inner_enum(pv: dict):
return None
def find_enums_in_context(c: Context) -> list[tuple[str, Any, RustType, Any]]:
enums: dict[RustType, tuple[str, Any, RustType, Any]] = {}
def find_enums_in_context(c: Context) -> list[Enum]:
enums: dict[RustType, tuple[str, Any, Enum]] = {}
for name, s in items(c.schemas):
if UNUSED_TYPE_MARKER in schema_markers(s, c, transitive=True):
continue
@@ -99,24 +132,42 @@ def find_enums_in_context(c: Context) -> list[tuple[str, Any, RustType, Any]]:
for pk, pv in items(parameters):
add_to_enums_if_enum(name, pk, pv, enums)
return list(enums.values())
result = []
for enum in enums.values():
result.append(enum[2])
return result
def add_to_enums_if_enum(schema_name, property_name, property_value,
enums: dict[RustType, tuple[str, Any, RustType, Any]]):
enum = get_enum_if_is_enum(schema_name, property_name, property_value)
enums: dict[RustType, tuple[str, Any, Enum]]):
enum: Enum | None = get_enum_if_is_enum(schema_name, property_name, property_value)
if enum:
existing_enum = enums.get(enum)
existing_enum = enums.get(enum.ty)
if existing_enum:
if existing_enum[2] != enum:
print('WARNING: duplicate enum entry. ', enum.name, schema_name, property_name, property_value)
print('existing enum: ', existing_enum[2].name, existing_enum[0], existing_enum[1], existing_enum[3])
if existing_enum[2].ty != enum.ty or existing_enum[2].variants != enum.variants:
print('WARNING: duplicate enum entry. ', enum.ty, schema_name, property_name, property_value)
print('existing enum: ', existing_enum[2].ty, existing_enum[0], existing_enum[1],
existing_enum[2])
return
enums[enum] = (schema_name, property_name, enum, property_value)
enums[enum.ty] = (schema_name, property_name, enum)
def to_enum_variant_name(name: str) -> str:
c_name = canonical_type_name(name)
c_name = remove_invalid_chars_in_ident(c_name)
return c_name
def to_enum_variant_name(name: str, make_camel_case: bool = True) -> str:
if make_camel_case:
name = canonical_type_name(name)
name = remove_invalid_chars_in_ident(name)
return name
def is_enum_variant_deprecated(description: str | None) -> bool:
if description is None:
return False
s = description.lower()
if 'deprecated' in s:
return True
return False

View File

@@ -24,15 +24,12 @@ re_relative_links = re.compile(r"\]\s*\([^h]")
HTTP_METHODS = set(("OPTIONS", "GET", "POST", "PUT", "DELETE", "HEAD", "TRACE", "CONNECT", "PATCH"))
RESERVED_WORDS = set(('abstract', 'alignof', 'as', 'become', 'box', 'break', 'const', 'continue', 'crate', 'do',
'else', 'enum', 'extern', 'false', 'final', 'fn', 'for', 'if', 'impl', 'in', 'let', 'loop',
'macro', 'match', 'mod', 'move', 'mut', 'offsetof', 'override', 'priv', 'pub', 'pure', 'ref',
'return', 'sizeof', 'static', 'self', 'struct', 'super', 'true', 'trait', 'type', 'typeof',
'unsafe', 'unsized', 'use', 'virtual', 'where', 'while', 'yield'))
TREF = '$ref'
IO_RESPONSE = 'response'
IO_REQUEST = 'request'
@@ -99,10 +96,8 @@ data_unit_multipliers = {
'%': 1,
}
inflection = inflect.engine()
HUB_TYPE_PARAMETERS = ('S',)
@@ -136,6 +131,7 @@ def rust_doc_comment(s):
def use_automatic_links_in_rust_doc_comment(s: str) -> str:
"""Surrounds all links in the text with <>."""
def replace_links(match):
link = match.group()
return f"<{link}>"
@@ -143,10 +139,12 @@ def use_automatic_links_in_rust_doc_comment(s: str) -> str:
# return re_non_hyper_links.sub(replace_links, s)
return s
# returns true if there is an indication for something that is interpreted as doc comment by rustdoc
def has_markdown_codeblock_with_indentation(s):
return re_spaces_after_newline.search(s) != None
def preprocess(base_url, s):
if base_url is None:
print(f"WARNING {s} has no base_url")
@@ -157,16 +155,18 @@ def preprocess(base_url, s):
stdout=subprocess.PIPE,
env={"URL_BASE": base_url or ""}
)
res = p.communicate(s.encode('utf-8'))
exitcode = p.wait(timeout=1)
if exitcode != 0:
raise ValueError(f"Child process exited with non-zero code {exitcode}")
return res[0].decode('utf-8')
def has_relative_links(s):
return re_relative_links.search(s) is not None
# runs the preprocessor in case there is evidence for code blocks using indentation
def rust_doc_sanitize(base_url):
def fixer(s):
@@ -174,6 +174,7 @@ def rust_doc_sanitize(base_url):
return preprocess(base_url, s)
else:
return s
return fixer
@@ -800,6 +801,26 @@ def build_all_params(c, m):
## -- End Activity Utilities -- @}
@dataclass
class EnumVariant:
name: str
""" the rust name of the enum variant """
value: str
""" the value of the enum variant (which is used for the api) """
description: str | None
deprecated: bool
deprecation_message: str | None
@dataclass
class Enum:
ty: RustType
description: str | None
variants: list[EnumVariant]
default: EnumVariant | None
has_deprecated_variants: bool
@dataclass
class Context:
sta_map: Dict[str, Any]
@@ -807,7 +828,7 @@ class Context:
rta_map: Dict[str, Any]
rtc_map: Dict[str, Any]
schemas: Dict[str, Any]
enums: List[Tuple[str, str, RustType, Dict[str, Any]]]
enums: List[Enum]
# return a newly build context from the given data
@@ -1221,18 +1242,17 @@ def rnd_arg_val_for_type(tn: str, c: Context = None) -> str:
return str(RUST_TYPE_RND_MAP[name]())
if c:
from .enum_utils import get_enum_variants, to_enum_variant_name
if tn.startswith("&"): # sometimes the types get passed as ref which doesn't make too much sense here
tn = tn[1:]
for (_, _, enum, values) in c.enums:
if tn == enum.name:
variants = get_enum_variants(values)
for enum in c.enums:
if tn == enum.ty:
variants = enum.variants
if len(variants) > 0:
variant = to_enum_variant_name(variants[0])
variant = variants[0].name
return f"&{tn}::{variant}"
print('Enum has no variants. This is probably not right...', enum, values)
print('Enum has no variants. This is probably not right...', enum)
return "&Default::default()"
@@ -1281,44 +1301,3 @@ def unique(
if __name__ == '__main__':
raise AssertionError('For import only')

View File

@@ -6,7 +6,6 @@
<%namespace name="schema" file="../lib/schema.mako"/>\
<%
from generator.lib.util import (new_context, hub_type, hub_type_params_s)
from generator.lib.enum_utils import (find_enums_in_context)
c = new_context(schemas, resources)
hub_type = hub_type(c.schemas, util.canonical_name())
@@ -18,7 +17,6 @@
use super::*;
% for schema_name,property_name,enum_type, e in c.enums:
${enum.new(enum_type, e, c)}
% for e in c.enums:
${enum.new(e, c)}
% endfor

View File

@@ -7,9 +7,7 @@
REQUEST_MARKER_TRAIT, RESPONSE_MARKER_TRAIT, supports_scopes, to_api_version,
to_fqan, METHODS_RESOURCE, ADD_PARAM_MEDIA_EXAMPLE, PROTOCOL_TYPE_INFO, enclose_in,
upload_action_fn, METHODS_BUILDER_MARKER_TRAIT, DELEGATE_TYPE,
to_extern_crate_name, rust_doc_sanitize)
from generator.lib.enum_utils import (to_enum_variant_name, get_enum_variants, get_enum_variants_descriptions, get_enum_default)
to_extern_crate_name, rust_doc_sanitize, escape_rust_string)
def pretty_name(name):
return ' '.join(split_camelcase_s(name).split('.'))
@@ -72,68 +70,75 @@ impl Default for Scope {
## Builds any generic enum for the API
###############################################################################################
###############################################################################################
<%def name="new(enum_type, e, c)">\
// region ${enum_type}
#[derive(Clone, Copy, Eq, Hash, Debug, PartialEq, Serialize, Deserialize)]
% if e.get('description'):
${rust_doc_comment(e.description)}
<%def name="new(enum, c)">\
// region ${enum.ty}
% if enum.has_deprecated_variants:
#[allow(non_camel_case_types, deprecated)]
% endif
pub enum ${enum_type} {
#[derive(Clone, Copy, Eq, Hash, Debug, PartialEq, Serialize, Deserialize)]
% if enum.description:
${rust_doc_comment(enum.description)}
% endif
pub enum ${enum.ty} {
<%
enum_variants = get_enum_variants(e)
enum_variants = enum.variants
if not enum_variants:
print('enum had no variants', e)
enum_variants = ['NO_VARIANTS_FOUND']
enum_descriptions = get_enum_variants_descriptions(e)
if not enum_descriptions:
enum_descriptions = ['no description found'] * len(enum_variants)
print('enum had no variants', enum)
enum_variants = []
%>\
% for (variant_name,description) in zip(enum_variants, enum_descriptions):
<% #print(variant_name, '=>', description)
%>
% if description:
${rust_doc_comment(description)}
% endif\
% for variant in enum_variants:
% if variant.description:
${rust_doc_comment(variant.description)}
% endif
/// value:
/// "${variant_name}"
#[serde(rename="${variant_name}")]
${to_enum_variant_name(variant_name)},
/// "${variant.value}"
#[serde(rename="${variant.value}")]
% if variant.deprecated:
#[deprecated(note="${escape_rust_string(variant.deprecation_message)}")]
% endif
${variant.name},
% endfor
}
impl AsRef<str> for ${enum_type} {
impl AsRef<str> for ${enum.ty} {
% if enum.has_deprecated_variants:
#[allow(deprecated)]
% endif
fn as_ref(&self) -> &str {
match *self {
% for variant in enum_variants:
${enum_type}::${to_enum_variant_name(variant)} => "${variant}",
${enum.ty}::${variant.name} => "${escape_rust_string(variant.value)}",
% endfor
}
}
}
impl ::std::convert::TryFrom< &str > for ${enum_type} {
impl ::std::convert::TryFrom< &str > for ${enum.ty} {
type Error = ();
fn try_from(value: &str) -> ::std::result::Result<Self, < ${enum_type} as ::std::convert::TryFrom < &str > >::Error> {
% if enum.has_deprecated_variants:
#[allow(deprecated)]
% endif
fn try_from(value: &str) -> ::std::result::Result<Self, < ${enum.ty} as ::std::convert::TryFrom < &str > >::Error> {
match value {
% for variant in enum_variants:
"${variant}" => ::std::result::Result::Ok(${enum_type}::${to_enum_variant_name(variant)}),
"${variant.value}" => ::std::result::Result::Ok(${enum.ty}::${variant.name}),
% endfor
_=> ::std::result::Result::Err(()),
_ => ::std::result::Result::Err(()),
}
}
}
impl<'a> Into<::std::borrow::Cow<'a, str>> for &'a ${enum_type} {
fn into(self) -> ::std::borrow::Cow<'a, str> {
self.as_ref().into()
impl<'a> From < &'a ${enum.ty} > for ::std::borrow::Cow< 'a, str > {
fn from(val: &'a ${enum.ty}) -> Self {
val.as_ref().into()
}
}
% if get_enum_default(e) is not None:
impl ::core::default::Default for ${enum_type} {
fn default() -> ${enum_type} {
${enum_type}::${to_enum_variant_name(e.default)}
% if enum.default is not None:
impl ::core::default::Default for ${enum.ty} {
fn default() -> ${enum.ty} {
${enum.ty}::${enum.default.name}
}
}
% endif