From e2dded17fabe3eb6ab8d6652a089aa1cafb5f763 Mon Sep 17 00:00:00 2001 From: Lukas Kalbertodt Date: Sun, 6 Nov 2022 11:23:21 +0100 Subject: [PATCH] Change `parse_env` error type from `Debug` to `impl std::error::Error` This is the more appropriate trait I think and should work well for most real world use cases. --- examples/parse_env.rs | 7 ++---- macro/src/gen/mod.rs | 7 +++--- src/error.rs | 12 ++++++++++ src/internal.rs | 51 +++++++++++++++++++++++-------------------- tests/general.rs | 10 ++++----- 5 files changed, 50 insertions(+), 37 deletions(-) diff --git a/examples/parse_env.rs b/examples/parse_env.rs index e9d56ff..266a493 100644 --- a/examples/parse_env.rs +++ b/examples/parse_env.rs @@ -6,7 +6,7 @@ use confique::{ }, Config, }; -use std::{collections::HashSet, num::NonZeroU64, path::PathBuf, str::FromStr}; +use std::{collections::HashSet, num::NonZeroU64, path::PathBuf, str::FromStr, convert::Infallible}; #[derive(Debug, Config)] struct Conf { @@ -45,10 +45,7 @@ enum Format { Yaml, } -#[derive(Debug)] -enum Error {} - -fn parse_formats(input: &str) -> Result, Error> { +fn parse_formats(input: &str) -> Result, Infallible> { let mut result = Vec::new(); if input.contains("toml") { diff --git a/macro/src/gen/mod.rs b/macro/src/gen/mod.rs index d51c698..93f5b33 100644 --- a/macro/src/gen/mod.rs +++ b/macro/src/gen/mod.rs @@ -156,10 +156,11 @@ fn gen_partial_mod(input: &ir::Input) -> TokenStream { confique::internal::from_env(#key, #field)? }, (None, Some(deserialize_with)) => quote! { - confique::internal::deserialize_from_env_with(#key, #field, #deserialize_with)? + confique::internal::from_env_with_deserializer( + #key, #field, #deserialize_with)? }, - (Some(parse_env), None) | (Some(parse_env), Some(_)) => quote! { - confique::internal::parse_from_env_with(#key, #field, #parse_env)? + (Some(parse_env), _) => quote! { + confique::internal::from_env_with_parser(#key, #field, #parse_env)? }, } } diff --git a/src/error.rs b/src/error.rs index f3e7e2e..f46fa02 100644 --- a/src/error.rs +++ b/src/error.rs @@ -50,6 +50,13 @@ pub(crate) enum ErrorInner { msg: String, }, + /// When a custom `parse_env` function fails. + EnvParseError { + field: String, + key: String, + err: Box, + }, + /// Returned by the [`Source`] impls for `Path` and `PathBuf` if the file /// extension is not supported by confique or if the corresponding Cargo /// feature of confique was not enabled. @@ -71,6 +78,7 @@ impl std::error::Error for Error { ErrorInner::MissingValue(_) => None, ErrorInner::EnvNotUnicode { .. } => None, ErrorInner::EnvDeserialization { .. } => None, + ErrorInner::EnvParseError { err, .. } => Some(&**err), ErrorInner::UnsupportedFileFormat { .. } => None, ErrorInner::MissingFileExtension { .. } => None, ErrorInner::MissingRequiredFile { .. } => None, @@ -107,6 +115,10 @@ impl fmt::Display for Error { std::write!(f, "failed to deserialize value `{field}` from \ environment variable `{key}`: {msg}") } + ErrorInner::EnvParseError { field, key, err } => { + std::write!(f, "failed to parse environment variable `{key}` into \ + field `{field}`: {err}") + } ErrorInner::UnsupportedFileFormat { path } => { std::write!(f, "unknown configuration file format/extension: '{}'", diff --git a/src/internal.rs b/src/internal.rs index c0b4856..3ad175f 100644 --- a/src/internal.rs +++ b/src/internal.rs @@ -2,8 +2,6 @@ //! intended to be used directly. None of this is covered by semver! Do not use //! any of this directly. -use std::fmt::Debug; - use crate::{error::ErrorInner, Error}; pub fn deserialize_default(src: I) -> Result @@ -38,48 +36,53 @@ pub fn map_err_prefix_path(res: Result, prefix: &str) -> Result { + match std::env::var($key) { + Err(std::env::VarError::NotPresent) => return Ok(None), + Err(std::env::VarError::NotUnicode(_)) => { + let err = ErrorInner::EnvNotUnicode { + key: $key.into(), + field: $field.into(), + }; + return Err(err.into()); + } + Ok(s) => s, + } + }; +} + pub fn from_env<'de, T: serde::Deserialize<'de>>( key: &str, field: &str, ) -> Result, Error> { - deserialize_from_env_with(key, field, |de| T::deserialize(de)) + from_env_with_deserializer(key, field, |de| T::deserialize(de)) } -pub fn parse_from_env_with( +pub fn from_env_with_parser( key: &str, field: &str, parse: fn(&str) -> Result, ) -> Result, Error> { - from_env::(key, field)? - .as_deref() - .map(parse) - .transpose() + let v = get_env_var!(key, field); + parse(&v) + .map(Some) .map_err(|err| { - ErrorInner::EnvDeserialization { + ErrorInner::EnvParseError { field: field.to_owned(), key: key.to_owned(), - msg: format!("Error while parse: {:?}", err), - } - .into() + err: Box::new(err), + }.into() }) } -pub fn deserialize_from_env_with( +pub fn from_env_with_deserializer( key: &str, field: &str, deserialize: fn(crate::env::Deserializer) -> Result, ) -> Result, Error> { - let s = match std::env::var(key) { - Err(std::env::VarError::NotPresent) => return Ok(None), - Err(std::env::VarError::NotUnicode(_)) => { - let err = ErrorInner::EnvNotUnicode { - key: key.into(), - field: field.into(), - }; - return Err(err.into()); - } - Ok(s) => s, - }; + let s = get_env_var!(key, field); match deserialize(crate::env::Deserializer::new(s)) { Ok(v) => Ok(Some(v)), diff --git a/tests/general.rs b/tests/general.rs index d33b36b..493f689 100644 --- a/tests/general.rs +++ b/tests/general.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, net::IpAddr, path::PathBuf}; +use std::{collections::HashMap, net::IpAddr, path::PathBuf, convert::Infallible}; use pretty_assertions::assert_eq; use serde::Deserialize; @@ -114,7 +114,7 @@ mod full { optional: Option, #[config(env = "ENV_TEST_FULL_4", parse_env = parse_dummy_collection)] - env_collection: DummyCollection<','>, + env_collection: DummyCollection, } } @@ -294,11 +294,11 @@ where } #[derive(Debug, PartialEq, Deserialize)] -struct DummyCollection(Vec); +struct DummyCollection(Vec); -pub(crate) fn parse_dummy_collection(input: &str) -> Result, String> { +pub(crate) fn parse_dummy_collection(input: &str) -> Result { Ok(DummyCollection( - input.split(SEPARATOR).map(ToString::to_string).collect(), + input.split(',').map(ToString::to_string).collect(), )) }