mirror of
https://github.com/OMGeeky/confique.git
synced 2025-12-27 06:29:27 +01:00
Fix using env with deserialize_with
Fixes #2 I simply forgot to use the `deserialize_with` attribute for the env deserialization. The previous code was somewhat weirdly coded in that we would always deserialize `Option<T>` (as it's a partial type) and the "env variable not present" info would travel through the deserializer to the `Option<T> as Deserialize` impl. Now it's more straight forward.
This commit is contained in:
@@ -152,10 +152,15 @@ fn gen_partial_mod(input: &ir::Input) -> TokenStream {
|
||||
|
||||
let from_env_fields = input.fields.iter().map(|f| {
|
||||
match &f.kind {
|
||||
FieldKind::Leaf { env: Some(key), .. } => {
|
||||
FieldKind::Leaf { env: Some(key), deserialize_with, .. } => {
|
||||
let field = format!("{}::{}", input.name, f.name);
|
||||
quote! {
|
||||
confique::internal::from_env(#key, #field)?
|
||||
match deserialize_with {
|
||||
None => quote! {
|
||||
confique::internal::from_env(#key, #field)?
|
||||
},
|
||||
Some(d) => quote! {
|
||||
confique::internal::from_env_with(#key, #field, #d)?
|
||||
},
|
||||
}
|
||||
}
|
||||
FieldKind::Leaf { .. } => quote! { None },
|
||||
|
||||
103
src/env.rs
103
src/env.rs
@@ -2,21 +2,16 @@
|
||||
|
||||
use std::fmt;
|
||||
|
||||
use serde::de::{Error as _, IntoDeserializer};
|
||||
use serde::de::IntoDeserializer;
|
||||
|
||||
|
||||
pub(crate) fn deserialize<'de, T: serde::Deserialize<'de>>(
|
||||
value: Option<String>,
|
||||
) -> Result<T, DeError> {
|
||||
let mut deserializer = Deserializer { value };
|
||||
T::deserialize(&mut deserializer)
|
||||
}
|
||||
|
||||
|
||||
/// Private error type only for deserialization. Gets converted into
|
||||
/// `ErrorKind::EnvDeserialization` before reaching the public API.
|
||||
/// Error type only for deserialization of env values.
|
||||
///
|
||||
/// Semantically private, only public as it's used in the API of the `internal`
|
||||
/// module. Gets converted into `ErrorKind::EnvDeserialization` before reaching
|
||||
/// the real public API.
|
||||
#[derive(PartialEq)]
|
||||
pub(crate) struct DeError(pub(crate) String);
|
||||
pub struct DeError(pub(crate) String);
|
||||
|
||||
impl std::error::Error for DeError {}
|
||||
|
||||
@@ -42,14 +37,14 @@ impl serde::de::Error for DeError {
|
||||
}
|
||||
|
||||
|
||||
/// Deserializer type.
|
||||
struct Deserializer {
|
||||
value: Option<String>,
|
||||
/// Deserializer type. Semantically private (see `DeError`).
|
||||
pub struct Deserializer {
|
||||
value: String,
|
||||
}
|
||||
|
||||
impl Deserializer {
|
||||
fn need_value(&mut self) -> Result<String, DeError> {
|
||||
self.value.take().ok_or_else(|| DeError::custom("environment variable not set"))
|
||||
pub(crate) fn new(value: String) -> Self {
|
||||
Self { value }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,8 +54,8 @@ macro_rules! deserialize_via_parse {
|
||||
where
|
||||
V: serde::de::Visitor<'de>
|
||||
{
|
||||
let s = self.need_value()?;
|
||||
let v = s.trim().parse().map_err(|e| {
|
||||
let s = self.value.trim();
|
||||
let v = s.parse().map_err(|e| {
|
||||
DeError(format!(
|
||||
concat!("invalid value '{}' for type ", stringify!($int), ": {}"),
|
||||
s,
|
||||
@@ -72,24 +67,21 @@ macro_rules! deserialize_via_parse {
|
||||
};
|
||||
}
|
||||
|
||||
impl<'de> serde::Deserializer<'de> for &mut Deserializer {
|
||||
impl<'de> serde::Deserializer<'de> for Deserializer {
|
||||
type Error = DeError;
|
||||
|
||||
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: serde::de::Visitor<'de>
|
||||
{
|
||||
match self.value.take() {
|
||||
None => visitor.visit_none(),
|
||||
Some(s) => s.into_deserializer().deserialize_any(visitor),
|
||||
}
|
||||
self.value.into_deserializer().deserialize_any(visitor)
|
||||
}
|
||||
|
||||
fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: serde::de::Visitor<'de>
|
||||
{
|
||||
let v = match self.need_value()?.trim() {
|
||||
let v = match self.value.trim() {
|
||||
"1" | "true" | "TRUE" => true,
|
||||
"0" | "false" | "FALSE" => false,
|
||||
other => return Err(DeError(format!("invalid value for bool: '{}'", other))),
|
||||
@@ -120,21 +112,12 @@ impl<'de> serde::Deserializer<'de> for &mut Deserializer {
|
||||
visitor.visit_newtype_struct(self)
|
||||
}
|
||||
|
||||
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: serde::de::Visitor<'de>
|
||||
{
|
||||
match self.value {
|
||||
None => visitor.visit_none(),
|
||||
Some(_) => visitor.visit_some(self),
|
||||
}
|
||||
}
|
||||
|
||||
serde::forward_to_deserialize_any! {
|
||||
char str string
|
||||
bytes byte_buf
|
||||
unit unit_struct
|
||||
map
|
||||
option
|
||||
struct
|
||||
identifier
|
||||
ignored_any
|
||||
@@ -151,50 +134,36 @@ impl<'de> serde::Deserializer<'de> for &mut Deserializer {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn de<'de, T: serde::Deserialize<'de>>(v: impl Into<Option<&'static str>>) -> Result<T, DeError> {
|
||||
deserialize(v.into().map(|s| s.to_owned()))
|
||||
fn de<'de, T: serde::Deserialize<'de>>(v: &'static str) -> Result<T, DeError> {
|
||||
T::deserialize(Deserializer { value: v.into() })
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn boolean() {
|
||||
assert_eq!(de("1"), Ok(Some(true)));
|
||||
assert_eq!(de("true "), Ok(Some(true)));
|
||||
assert_eq!(de(" TRUE"), Ok(Some(true)));
|
||||
assert_eq!(de("0 "), Ok(Some(false)));
|
||||
assert_eq!(de(" false"), Ok(Some(false)));
|
||||
assert_eq!(de("FALSE "), Ok(Some(false)));
|
||||
|
||||
assert_eq!(de(None), Ok(Option::<bool>::None));
|
||||
assert_eq!(de("1"), Ok(true));
|
||||
assert_eq!(de("true "), Ok(true));
|
||||
assert_eq!(de(" TRUE"), Ok(true));
|
||||
assert_eq!(de("0 "), Ok(false));
|
||||
assert_eq!(de(" false"), Ok(false));
|
||||
assert_eq!(de("FALSE "), Ok(false));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ints() {
|
||||
assert_eq!(de("0"), Ok(Some(0u8)));
|
||||
assert_eq!(de("-1 "), Ok(Some(-1i8)));
|
||||
assert_eq!(de(" 27"), Ok(Some(27u16)));
|
||||
assert_eq!(de("-27"), Ok(Some(-27i16)));
|
||||
assert_eq!(de(" 4301"), Ok(Some(4301u32)));
|
||||
assert_eq!(de(" -123456"), Ok(Some(-123456i32)));
|
||||
assert_eq!(de(" 986543210 "), Ok(Some(986543210u64)));
|
||||
assert_eq!(de("-986543210"), Ok(Some(-986543210i64)));
|
||||
|
||||
assert_eq!(de(None), Ok(Option::<i8>::None));
|
||||
assert_eq!(de(None), Ok(Option::<u8>::None));
|
||||
assert_eq!(de(None), Ok(Option::<i16>::None));
|
||||
assert_eq!(de(None), Ok(Option::<u16>::None));
|
||||
assert_eq!(de(None), Ok(Option::<i32>::None));
|
||||
assert_eq!(de(None), Ok(Option::<u32>::None));
|
||||
assert_eq!(de(None), Ok(Option::<i64>::None));
|
||||
assert_eq!(de(None), Ok(Option::<u64>::None));
|
||||
assert_eq!(de("0"), Ok(0u8));
|
||||
assert_eq!(de("-1 "), Ok(-1i8));
|
||||
assert_eq!(de(" 27"), Ok(27u16));
|
||||
assert_eq!(de("-27"), Ok(-27i16));
|
||||
assert_eq!(de(" 4301"), Ok(4301u32));
|
||||
assert_eq!(de(" -123456"), Ok(-123456i32));
|
||||
assert_eq!(de(" 986543210 "), Ok(986543210u64));
|
||||
assert_eq!(de("-986543210"), Ok(-986543210i64));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn floats() {
|
||||
assert_eq!(de("3.1415"), Ok(Some(3.1415f32)));
|
||||
assert_eq!(de("-123.456"), Ok(Some(-123.456f64)));
|
||||
|
||||
assert_eq!(de(None), Ok(Option::<f32>::None));
|
||||
assert_eq!(de(None), Ok(Option::<f64>::None));
|
||||
assert_eq!(de("3.1415"), Ok(3.1415f32));
|
||||
assert_eq!(de("-123.456"), Ok(-123.456f64));
|
||||
}
|
||||
}
|
||||
|
||||
15
src/error.rs
15
src/error.rs
@@ -34,6 +34,12 @@ pub(crate) enum ErrorInner {
|
||||
err: Box<dyn std::error::Error + Send + Sync>,
|
||||
},
|
||||
|
||||
/// When the env variable `key` is not Unicode.
|
||||
EnvNotUnicode {
|
||||
field: String,
|
||||
key: String,
|
||||
},
|
||||
|
||||
/// When deserialization via `env` fails. The string is what is passed to
|
||||
/// `serde::de::Error::custom`.
|
||||
EnvDeserialization {
|
||||
@@ -72,6 +78,7 @@ impl std::error::Error for Error {
|
||||
#[cfg(any(feature = "toml", feature = "yaml"))]
|
||||
ErrorInner::Deserialization { err, .. } => Some(&**err),
|
||||
ErrorInner::MissingValue(_) => None,
|
||||
ErrorInner::EnvNotUnicode { .. } => None,
|
||||
ErrorInner::EnvDeserialization { .. } => None,
|
||||
#[cfg(any(feature = "toml", feature = "yaml"))]
|
||||
ErrorInner::UnsupportedFileFormat { .. } => None,
|
||||
@@ -108,6 +115,14 @@ impl fmt::Display for Error {
|
||||
ErrorInner::Deserialization { source: None, .. } => {
|
||||
std::write!(f, "failed to deserialize configuration")
|
||||
}
|
||||
ErrorInner::EnvNotUnicode { field, key } => {
|
||||
std::write!(f,
|
||||
"failed to load value `{}` from environment variable `{}`: \
|
||||
value is not valid unicode",
|
||||
field,
|
||||
key,
|
||||
)
|
||||
}
|
||||
ErrorInner::EnvDeserialization { field, key, msg } => {
|
||||
std::write!(f,
|
||||
"failed to deserialize value `{}` from environment variable `{}`: {}",
|
||||
|
||||
@@ -32,12 +32,33 @@ pub fn prepend_missing_value_error(e: Error, prefix: &str) -> Error {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_env<'de, T: serde::Deserialize<'de>>(key: &str, field: &str) -> Result<T, Error> {
|
||||
crate::env::deserialize(std::env::var(key).ok()).map_err(|e| {
|
||||
ErrorInner::EnvDeserialization {
|
||||
pub fn from_env<'de, T: serde::Deserialize<'de>>(
|
||||
key: &str,
|
||||
field: &str,
|
||||
) -> Result<Option<T>, Error> {
|
||||
from_env_with(key, field, |de| T::deserialize(de))
|
||||
}
|
||||
|
||||
pub fn from_env_with<T>(
|
||||
key: &str,
|
||||
field: &str,
|
||||
deserialize: fn(crate::env::Deserializer) -> Result<T, crate::env::DeError>,
|
||||
) -> Result<Option<T>, Error> {
|
||||
let s = match std::env::var(key) {
|
||||
Err(std::env::VarError::NotPresent) => return Ok(None),
|
||||
Err(std::env::VarError::NotUnicode(_)) => return Err(ErrorInner::EnvNotUnicode {
|
||||
key: key.into(),
|
||||
field: field.into(),
|
||||
}.into()),
|
||||
Ok(s) => s,
|
||||
};
|
||||
|
||||
match deserialize(crate::env::Deserializer::new(s)) {
|
||||
Ok(v) => Ok(Some(v)),
|
||||
Err(e) => Err(ErrorInner::EnvDeserialization {
|
||||
key: key.into(),
|
||||
field: field.into(),
|
||||
msg: e.0,
|
||||
}.into()
|
||||
})
|
||||
}.into()),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user