From 3ac922ca2f62ad8423d51473f48bc4b6fab42650 Mon Sep 17 00:00:00 2001 From: Lukas Kalbertodt Date: Tue, 27 Jul 2021 01:10:12 +0200 Subject: [PATCH] Add deserialization logic for reading values from environment variables --- src/env.rs | 200 +++++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + 2 files changed, 201 insertions(+) create mode 100644 src/env.rs diff --git a/src/env.rs b/src/env.rs new file mode 100644 index 0000000..8bf5fa6 --- /dev/null +++ b/src/env.rs @@ -0,0 +1,200 @@ +//! Deserialize values from environment variables. + +use std::fmt; + +use serde::de::{Error as _, IntoDeserializer}; + + +pub(crate) fn deserialize<'de, T: serde::Deserialize<'de>>( + value: Option, +) -> Result { + let mut deserializer = Deserializer { value }; + T::deserialize(&mut deserializer) +} + + +/// Private error type only for deserialization. Gets converted into +/// `ErrorKind::EnvDeserialization` before reaching the public API. +#[derive(PartialEq)] +pub(crate) struct DeError(pub(crate) String); + +impl std::error::Error for DeError {} + +impl fmt::Debug for DeError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(self, f) + } +} + +impl fmt::Display for DeError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + +impl serde::de::Error for DeError { + fn custom(msg: T) -> Self + where + T: fmt::Display, + { + Self(msg.to_string()) + } +} + + +/// Deserializer type. +struct Deserializer { + value: Option, +} + +impl Deserializer { + fn need_value(&mut self) -> Result { + self.value.take().ok_or_else(|| DeError::custom("environment variable not set")) + } +} + +macro_rules! deserialize_via_parse { + ($method:ident, $visit_method:ident, $int:ident) => { + fn $method(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de> + { + let s = self.need_value()?; + let v = s.trim().parse().map_err(|e| { + DeError(format!( + concat!("invalid value '{}' for type ", stringify!($int), ": {}"), + s, + e, + )) + })?; + visitor.$visit_method(v) + } + }; +} + +impl<'de> serde::Deserializer<'de> for &mut Deserializer { + type Error = DeError; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de> + { + match self.value.take() { + None => visitor.visit_none(), + Some(s) => s.into_deserializer().deserialize_any(visitor), + } + } + + fn deserialize_bool(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de> + { + let v = match self.need_value()?.trim() { + "1" | "true" | "TRUE" => true, + "0" | "false" | "FALSE" => false, + other => return Err(DeError(format!("invalid value for bool: '{}'", other))), + }; + + visitor.visit_bool(v) + } + + deserialize_via_parse!(deserialize_i8, visit_i8, i8); + deserialize_via_parse!(deserialize_i16, visit_i16, i16); + deserialize_via_parse!(deserialize_i32, visit_i32, i32); + deserialize_via_parse!(deserialize_i64, visit_i64, i64); + deserialize_via_parse!(deserialize_u8, visit_u8, u8); + deserialize_via_parse!(deserialize_u16, visit_u16, u16); + deserialize_via_parse!(deserialize_u32, visit_u32, u32); + deserialize_via_parse!(deserialize_u64, visit_u64, u64); + deserialize_via_parse!(deserialize_f32, visit_f32, f32); + deserialize_via_parse!(deserialize_f64, visit_f64, f64); + + fn deserialize_newtype_struct( + self, + _: &'static str, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de> + { + match self.value { + None => visitor.visit_none(), + Some(_) => visitor.visit_some(self), + } + } + + serde::forward_to_deserialize_any! { + char str string + bytes byte_buf + unit unit_struct + map + struct + identifier + ignored_any + + // TODO: think about manually implementing these + enum + seq + tuple tuple_struct + } +} + + +#[cfg(test)] +mod tests { + use super::*; + + fn de<'de, T: serde::Deserialize<'de>>(v: impl Into>) -> Result { + deserialize(v.into().map(|s| s.to_owned())) + } + + + #[test] + fn boolean() { + assert_eq!(de("1"), Ok(Some(true))); + assert_eq!(de("true "), Ok(Some(true))); + assert_eq!(de(" TRUE"), Ok(Some(true))); + assert_eq!(de("0 "), Ok(Some(false))); + assert_eq!(de(" false"), Ok(Some(false))); + assert_eq!(de("FALSE "), Ok(Some(false))); + + assert_eq!(de(None), Ok(Option::::None)); + } + + #[test] + fn ints() { + assert_eq!(de("0"), Ok(Some(0u8))); + assert_eq!(de("-1 "), Ok(Some(-1i8))); + assert_eq!(de(" 27"), Ok(Some(27u16))); + assert_eq!(de("-27"), Ok(Some(-27i16))); + assert_eq!(de(" 4301"), Ok(Some(4301u32))); + assert_eq!(de(" -123456"), Ok(Some(-123456i32))); + assert_eq!(de(" 986543210 "), Ok(Some(986543210u64))); + assert_eq!(de("-986543210"), Ok(Some(-986543210i64))); + + assert_eq!(de(None), Ok(Option::::None)); + assert_eq!(de(None), Ok(Option::::None)); + assert_eq!(de(None), Ok(Option::::None)); + assert_eq!(de(None), Ok(Option::::None)); + assert_eq!(de(None), Ok(Option::::None)); + assert_eq!(de(None), Ok(Option::::None)); + assert_eq!(de(None), Ok(Option::::None)); + assert_eq!(de(None), Ok(Option::::None)); + } + + #[test] + fn floats() { + assert_eq!(de("3.1415"), Ok(Some(3.1415f32))); + assert_eq!(de("-123.456"), Ok(Some(-123.456f64))); + + assert_eq!(de(None), Ok(Option::::None)); + assert_eq!(de(None), Ok(Option::::None)); + } +} diff --git a/src/lib.rs b/src/lib.rs index f808cae..865f30c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ use serde::Deserialize; pub mod internal; mod builder; +mod env; mod error; mod file; pub mod meta;