Fix using env with deserialize_with

Fixes #2

I simply forgot to use the `deserialize_with` attribute for the env
deserialization. The previous code was somewhat weirdly coded in that
we would always deserialize `Option<T>` (as it's a partial type) and
the "env variable not present" info would travel through the
deserializer to the `Option<T> as Deserialize` impl. Now it's more
straight forward.
This commit is contained in:
Lukas Kalbertodt
2022-10-14 16:21:34 +02:00
parent 04f7f6b3be
commit eb03488973
4 changed files with 85 additions and 75 deletions

View File

@@ -152,10 +152,15 @@ fn gen_partial_mod(input: &ir::Input) -> TokenStream {
let from_env_fields = input.fields.iter().map(|f| {
match &f.kind {
FieldKind::Leaf { env: Some(key), .. } => {
FieldKind::Leaf { env: Some(key), deserialize_with, .. } => {
let field = format!("{}::{}", input.name, f.name);
quote! {
confique::internal::from_env(#key, #field)?
match deserialize_with {
None => quote! {
confique::internal::from_env(#key, #field)?
},
Some(d) => quote! {
confique::internal::from_env_with(#key, #field, #d)?
},
}
}
FieldKind::Leaf { .. } => quote! { None },

View File

@@ -2,21 +2,16 @@
use std::fmt;
use serde::de::{Error as _, IntoDeserializer};
use serde::de::IntoDeserializer;
pub(crate) fn deserialize<'de, T: serde::Deserialize<'de>>(
value: Option<String>,
) -> Result<T, DeError> {
let mut deserializer = Deserializer { value };
T::deserialize(&mut deserializer)
}
/// Private error type only for deserialization. Gets converted into
/// `ErrorKind::EnvDeserialization` before reaching the public API.
/// Error type only for deserialization of env values.
///
/// Semantically private, only public as it's used in the API of the `internal`
/// module. Gets converted into `ErrorKind::EnvDeserialization` before reaching
/// the real public API.
#[derive(PartialEq)]
pub(crate) struct DeError(pub(crate) String);
pub struct DeError(pub(crate) String);
impl std::error::Error for DeError {}
@@ -42,14 +37,14 @@ impl serde::de::Error for DeError {
}
/// Deserializer type.
struct Deserializer {
value: Option<String>,
/// Deserializer type. Semantically private (see `DeError`).
pub struct Deserializer {
value: String,
}
impl Deserializer {
fn need_value(&mut self) -> Result<String, DeError> {
self.value.take().ok_or_else(|| DeError::custom("environment variable not set"))
pub(crate) fn new(value: String) -> Self {
Self { value }
}
}
@@ -59,8 +54,8 @@ macro_rules! deserialize_via_parse {
where
V: serde::de::Visitor<'de>
{
let s = self.need_value()?;
let v = s.trim().parse().map_err(|e| {
let s = self.value.trim();
let v = s.parse().map_err(|e| {
DeError(format!(
concat!("invalid value '{}' for type ", stringify!($int), ": {}"),
s,
@@ -72,24 +67,21 @@ macro_rules! deserialize_via_parse {
};
}
impl<'de> serde::Deserializer<'de> for &mut Deserializer {
impl<'de> serde::Deserializer<'de> for Deserializer {
type Error = DeError;
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>
{
match self.value.take() {
None => visitor.visit_none(),
Some(s) => s.into_deserializer().deserialize_any(visitor),
}
self.value.into_deserializer().deserialize_any(visitor)
}
fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>
{
let v = match self.need_value()?.trim() {
let v = match self.value.trim() {
"1" | "true" | "TRUE" => true,
"0" | "false" | "FALSE" => false,
other => return Err(DeError(format!("invalid value for bool: '{}'", other))),
@@ -120,21 +112,12 @@ impl<'de> serde::Deserializer<'de> for &mut Deserializer {
visitor.visit_newtype_struct(self)
}
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>
{
match self.value {
None => visitor.visit_none(),
Some(_) => visitor.visit_some(self),
}
}
serde::forward_to_deserialize_any! {
char str string
bytes byte_buf
unit unit_struct
map
option
struct
identifier
ignored_any
@@ -151,50 +134,36 @@ impl<'de> serde::Deserializer<'de> for &mut Deserializer {
mod tests {
use super::*;
fn de<'de, T: serde::Deserialize<'de>>(v: impl Into<Option<&'static str>>) -> Result<T, DeError> {
deserialize(v.into().map(|s| s.to_owned()))
fn de<'de, T: serde::Deserialize<'de>>(v: &'static str) -> Result<T, DeError> {
T::deserialize(Deserializer { value: v.into() })
}
#[test]
fn boolean() {
assert_eq!(de("1"), Ok(Some(true)));
assert_eq!(de("true "), Ok(Some(true)));
assert_eq!(de(" TRUE"), Ok(Some(true)));
assert_eq!(de("0 "), Ok(Some(false)));
assert_eq!(de(" false"), Ok(Some(false)));
assert_eq!(de("FALSE "), Ok(Some(false)));
assert_eq!(de(None), Ok(Option::<bool>::None));
assert_eq!(de("1"), Ok(true));
assert_eq!(de("true "), Ok(true));
assert_eq!(de(" TRUE"), Ok(true));
assert_eq!(de("0 "), Ok(false));
assert_eq!(de(" false"), Ok(false));
assert_eq!(de("FALSE "), Ok(false));
}
#[test]
fn ints() {
assert_eq!(de("0"), Ok(Some(0u8)));
assert_eq!(de("-1 "), Ok(Some(-1i8)));
assert_eq!(de(" 27"), Ok(Some(27u16)));
assert_eq!(de("-27"), Ok(Some(-27i16)));
assert_eq!(de(" 4301"), Ok(Some(4301u32)));
assert_eq!(de(" -123456"), Ok(Some(-123456i32)));
assert_eq!(de(" 986543210 "), Ok(Some(986543210u64)));
assert_eq!(de("-986543210"), Ok(Some(-986543210i64)));
assert_eq!(de(None), Ok(Option::<i8>::None));
assert_eq!(de(None), Ok(Option::<u8>::None));
assert_eq!(de(None), Ok(Option::<i16>::None));
assert_eq!(de(None), Ok(Option::<u16>::None));
assert_eq!(de(None), Ok(Option::<i32>::None));
assert_eq!(de(None), Ok(Option::<u32>::None));
assert_eq!(de(None), Ok(Option::<i64>::None));
assert_eq!(de(None), Ok(Option::<u64>::None));
assert_eq!(de("0"), Ok(0u8));
assert_eq!(de("-1 "), Ok(-1i8));
assert_eq!(de(" 27"), Ok(27u16));
assert_eq!(de("-27"), Ok(-27i16));
assert_eq!(de(" 4301"), Ok(4301u32));
assert_eq!(de(" -123456"), Ok(-123456i32));
assert_eq!(de(" 986543210 "), Ok(986543210u64));
assert_eq!(de("-986543210"), Ok(-986543210i64));
}
#[test]
fn floats() {
assert_eq!(de("3.1415"), Ok(Some(3.1415f32)));
assert_eq!(de("-123.456"), Ok(Some(-123.456f64)));
assert_eq!(de(None), Ok(Option::<f32>::None));
assert_eq!(de(None), Ok(Option::<f64>::None));
assert_eq!(de("3.1415"), Ok(3.1415f32));
assert_eq!(de("-123.456"), Ok(-123.456f64));
}
}

View File

@@ -34,6 +34,12 @@ pub(crate) enum ErrorInner {
err: Box<dyn std::error::Error + Send + Sync>,
},
/// When the env variable `key` is not Unicode.
EnvNotUnicode {
field: String,
key: String,
},
/// When deserialization via `env` fails. The string is what is passed to
/// `serde::de::Error::custom`.
EnvDeserialization {
@@ -72,6 +78,7 @@ impl std::error::Error for Error {
#[cfg(any(feature = "toml", feature = "yaml"))]
ErrorInner::Deserialization { err, .. } => Some(&**err),
ErrorInner::MissingValue(_) => None,
ErrorInner::EnvNotUnicode { .. } => None,
ErrorInner::EnvDeserialization { .. } => None,
#[cfg(any(feature = "toml", feature = "yaml"))]
ErrorInner::UnsupportedFileFormat { .. } => None,
@@ -108,6 +115,14 @@ impl fmt::Display for Error {
ErrorInner::Deserialization { source: None, .. } => {
std::write!(f, "failed to deserialize configuration")
}
ErrorInner::EnvNotUnicode { field, key } => {
std::write!(f,
"failed to load value `{}` from environment variable `{}`: \
value is not valid unicode",
field,
key,
)
}
ErrorInner::EnvDeserialization { field, key, msg } => {
std::write!(f,
"failed to deserialize value `{}` from environment variable `{}`: {}",

View File

@@ -32,12 +32,33 @@ pub fn prepend_missing_value_error(e: Error, prefix: &str) -> Error {
}
}
pub fn from_env<'de, T: serde::Deserialize<'de>>(key: &str, field: &str) -> Result<T, Error> {
crate::env::deserialize(std::env::var(key).ok()).map_err(|e| {
ErrorInner::EnvDeserialization {
pub fn from_env<'de, T: serde::Deserialize<'de>>(
key: &str,
field: &str,
) -> Result<Option<T>, Error> {
from_env_with(key, field, |de| T::deserialize(de))
}
pub fn from_env_with<T>(
key: &str,
field: &str,
deserialize: fn(crate::env::Deserializer) -> Result<T, crate::env::DeError>,
) -> Result<Option<T>, Error> {
let s = match std::env::var(key) {
Err(std::env::VarError::NotPresent) => return Ok(None),
Err(std::env::VarError::NotUnicode(_)) => return Err(ErrorInner::EnvNotUnicode {
key: key.into(),
field: field.into(),
}.into()),
Ok(s) => s,
};
match deserialize(crate::env::Deserializer::new(s)) {
Ok(v) => Ok(Some(v)),
Err(e) => Err(ErrorInner::EnvDeserialization {
key: key.into(),
field: field.into(),
msg: e.0,
}.into()
})
}.into()),
}
}