From eb034889730744b469c8d8f46db650e1684811e0 Mon Sep 17 00:00:00 2001 From: Lukas Kalbertodt Date: Fri, 14 Oct 2022 16:21:34 +0200 Subject: [PATCH] Fix using `env` with `deserialize_with` Fixes #2 I simply forgot to use the `deserialize_with` attribute for the env deserialization. The previous code was somewhat weirdly coded in that we would always deserialize `Option` (as it's a partial type) and the "env variable not present" info would travel through the deserializer to the `Option as Deserialize` impl. Now it's more straight forward. --- macro/src/gen.rs | 11 +++-- src/env.rs | 103 +++++++++++++++++------------------------------ src/error.rs | 15 +++++++ src/internal.rs | 31 +++++++++++--- 4 files changed, 85 insertions(+), 75 deletions(-) diff --git a/macro/src/gen.rs b/macro/src/gen.rs index e7f6b40..0248cc0 100644 --- a/macro/src/gen.rs +++ b/macro/src/gen.rs @@ -152,10 +152,15 @@ fn gen_partial_mod(input: &ir::Input) -> TokenStream { let from_env_fields = input.fields.iter().map(|f| { match &f.kind { - FieldKind::Leaf { env: Some(key), .. } => { + FieldKind::Leaf { env: Some(key), deserialize_with, .. } => { let field = format!("{}::{}", input.name, f.name); - quote! { - confique::internal::from_env(#key, #field)? + match deserialize_with { + None => quote! { + confique::internal::from_env(#key, #field)? + }, + Some(d) => quote! { + confique::internal::from_env_with(#key, #field, #d)? + }, } } FieldKind::Leaf { .. } => quote! { None }, diff --git a/src/env.rs b/src/env.rs index 8bf5fa6..fb0d7e0 100644 --- a/src/env.rs +++ b/src/env.rs @@ -2,21 +2,16 @@ use std::fmt; -use serde::de::{Error as _, IntoDeserializer}; +use serde::de::IntoDeserializer; -pub(crate) fn deserialize<'de, T: serde::Deserialize<'de>>( - value: Option, -) -> Result { - let mut deserializer = Deserializer { value }; - T::deserialize(&mut deserializer) -} - - -/// Private error type only for deserialization. Gets converted into -/// `ErrorKind::EnvDeserialization` before reaching the public API. +/// Error type only for deserialization of env values. +/// +/// Semantically private, only public as it's used in the API of the `internal` +/// module. Gets converted into `ErrorKind::EnvDeserialization` before reaching +/// the real public API. #[derive(PartialEq)] -pub(crate) struct DeError(pub(crate) String); +pub struct DeError(pub(crate) String); impl std::error::Error for DeError {} @@ -42,14 +37,14 @@ impl serde::de::Error for DeError { } -/// Deserializer type. -struct Deserializer { - value: Option, +/// Deserializer type. Semantically private (see `DeError`). +pub struct Deserializer { + value: String, } impl Deserializer { - fn need_value(&mut self) -> Result { - self.value.take().ok_or_else(|| DeError::custom("environment variable not set")) + pub(crate) fn new(value: String) -> Self { + Self { value } } } @@ -59,8 +54,8 @@ macro_rules! deserialize_via_parse { where V: serde::de::Visitor<'de> { - let s = self.need_value()?; - let v = s.trim().parse().map_err(|e| { + let s = self.value.trim(); + let v = s.parse().map_err(|e| { DeError(format!( concat!("invalid value '{}' for type ", stringify!($int), ": {}"), s, @@ -72,24 +67,21 @@ macro_rules! deserialize_via_parse { }; } -impl<'de> serde::Deserializer<'de> for &mut Deserializer { +impl<'de> serde::Deserializer<'de> for Deserializer { type Error = DeError; fn deserialize_any(self, visitor: V) -> Result where V: serde::de::Visitor<'de> { - match self.value.take() { - None => visitor.visit_none(), - Some(s) => s.into_deserializer().deserialize_any(visitor), - } + self.value.into_deserializer().deserialize_any(visitor) } fn deserialize_bool(self, visitor: V) -> Result where V: serde::de::Visitor<'de> { - let v = match self.need_value()?.trim() { + let v = match self.value.trim() { "1" | "true" | "TRUE" => true, "0" | "false" | "FALSE" => false, other => return Err(DeError(format!("invalid value for bool: '{}'", other))), @@ -120,21 +112,12 @@ impl<'de> serde::Deserializer<'de> for &mut Deserializer { visitor.visit_newtype_struct(self) } - fn deserialize_option(self, visitor: V) -> Result - where - V: serde::de::Visitor<'de> - { - match self.value { - None => visitor.visit_none(), - Some(_) => visitor.visit_some(self), - } - } - serde::forward_to_deserialize_any! { char str string bytes byte_buf unit unit_struct map + option struct identifier ignored_any @@ -151,50 +134,36 @@ impl<'de> serde::Deserializer<'de> for &mut Deserializer { mod tests { use super::*; - fn de<'de, T: serde::Deserialize<'de>>(v: impl Into>) -> Result { - deserialize(v.into().map(|s| s.to_owned())) + fn de<'de, T: serde::Deserialize<'de>>(v: &'static str) -> Result { + T::deserialize(Deserializer { value: v.into() }) } #[test] fn boolean() { - assert_eq!(de("1"), Ok(Some(true))); - assert_eq!(de("true "), Ok(Some(true))); - assert_eq!(de(" TRUE"), Ok(Some(true))); - assert_eq!(de("0 "), Ok(Some(false))); - assert_eq!(de(" false"), Ok(Some(false))); - assert_eq!(de("FALSE "), Ok(Some(false))); - - assert_eq!(de(None), Ok(Option::::None)); + assert_eq!(de("1"), Ok(true)); + assert_eq!(de("true "), Ok(true)); + assert_eq!(de(" TRUE"), Ok(true)); + assert_eq!(de("0 "), Ok(false)); + assert_eq!(de(" false"), Ok(false)); + assert_eq!(de("FALSE "), Ok(false)); } #[test] fn ints() { - assert_eq!(de("0"), Ok(Some(0u8))); - assert_eq!(de("-1 "), Ok(Some(-1i8))); - assert_eq!(de(" 27"), Ok(Some(27u16))); - assert_eq!(de("-27"), Ok(Some(-27i16))); - assert_eq!(de(" 4301"), Ok(Some(4301u32))); - assert_eq!(de(" -123456"), Ok(Some(-123456i32))); - assert_eq!(de(" 986543210 "), Ok(Some(986543210u64))); - assert_eq!(de("-986543210"), Ok(Some(-986543210i64))); - - assert_eq!(de(None), Ok(Option::::None)); - assert_eq!(de(None), Ok(Option::::None)); - assert_eq!(de(None), Ok(Option::::None)); - assert_eq!(de(None), Ok(Option::::None)); - assert_eq!(de(None), Ok(Option::::None)); - assert_eq!(de(None), Ok(Option::::None)); - assert_eq!(de(None), Ok(Option::::None)); - assert_eq!(de(None), Ok(Option::::None)); + assert_eq!(de("0"), Ok(0u8)); + assert_eq!(de("-1 "), Ok(-1i8)); + assert_eq!(de(" 27"), Ok(27u16)); + assert_eq!(de("-27"), Ok(-27i16)); + assert_eq!(de(" 4301"), Ok(4301u32)); + assert_eq!(de(" -123456"), Ok(-123456i32)); + assert_eq!(de(" 986543210 "), Ok(986543210u64)); + assert_eq!(de("-986543210"), Ok(-986543210i64)); } #[test] fn floats() { - assert_eq!(de("3.1415"), Ok(Some(3.1415f32))); - assert_eq!(de("-123.456"), Ok(Some(-123.456f64))); - - assert_eq!(de(None), Ok(Option::::None)); - assert_eq!(de(None), Ok(Option::::None)); + assert_eq!(de("3.1415"), Ok(3.1415f32)); + assert_eq!(de("-123.456"), Ok(-123.456f64)); } } diff --git a/src/error.rs b/src/error.rs index 136bcb5..9367d4e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -34,6 +34,12 @@ pub(crate) enum ErrorInner { err: Box, }, + /// When the env variable `key` is not Unicode. + EnvNotUnicode { + field: String, + key: String, + }, + /// When deserialization via `env` fails. The string is what is passed to /// `serde::de::Error::custom`. EnvDeserialization { @@ -72,6 +78,7 @@ impl std::error::Error for Error { #[cfg(any(feature = "toml", feature = "yaml"))] ErrorInner::Deserialization { err, .. } => Some(&**err), ErrorInner::MissingValue(_) => None, + ErrorInner::EnvNotUnicode { .. } => None, ErrorInner::EnvDeserialization { .. } => None, #[cfg(any(feature = "toml", feature = "yaml"))] ErrorInner::UnsupportedFileFormat { .. } => None, @@ -108,6 +115,14 @@ impl fmt::Display for Error { ErrorInner::Deserialization { source: None, .. } => { std::write!(f, "failed to deserialize configuration") } + ErrorInner::EnvNotUnicode { field, key } => { + std::write!(f, + "failed to load value `{}` from environment variable `{}`: \ + value is not valid unicode", + field, + key, + ) + } ErrorInner::EnvDeserialization { field, key, msg } => { std::write!(f, "failed to deserialize value `{}` from environment variable `{}`: {}", diff --git a/src/internal.rs b/src/internal.rs index 4db4b1b..75ca6e1 100644 --- a/src/internal.rs +++ b/src/internal.rs @@ -32,12 +32,33 @@ pub fn prepend_missing_value_error(e: Error, prefix: &str) -> Error { } } -pub fn from_env<'de, T: serde::Deserialize<'de>>(key: &str, field: &str) -> Result { - crate::env::deserialize(std::env::var(key).ok()).map_err(|e| { - ErrorInner::EnvDeserialization { +pub fn from_env<'de, T: serde::Deserialize<'de>>( + key: &str, + field: &str, +) -> Result, Error> { + from_env_with(key, field, |de| T::deserialize(de)) +} + +pub fn from_env_with( + key: &str, + field: &str, + deserialize: fn(crate::env::Deserializer) -> Result, +) -> Result, Error> { + let s = match std::env::var(key) { + Err(std::env::VarError::NotPresent) => return Ok(None), + Err(std::env::VarError::NotUnicode(_)) => return Err(ErrorInner::EnvNotUnicode { + key: key.into(), + field: field.into(), + }.into()), + Ok(s) => s, + }; + + match deserialize(crate::env::Deserializer::new(s)) { + Ok(v) => Ok(Some(v)), + Err(e) => Err(ErrorInner::EnvDeserialization { key: key.into(), field: field.into(), msg: e.0, - }.into() - }) + }.into()), + } }