diff options
Diffstat (limited to 'alacritty_config_derive')
-rw-r--r-- | alacritty_config_derive/Cargo.toml | 21 | ||||
l--------- | alacritty_config_derive/LICENSE-APACHE | 1 | ||||
-rw-r--r-- | alacritty_config_derive/LICENSE-MIT | 23 | ||||
-rw-r--r-- | alacritty_config_derive/src/de_enum.rs | 66 | ||||
-rw-r--r-- | alacritty_config_derive/src/de_struct.rs | 226 | ||||
-rw-r--r-- | alacritty_config_derive/src/lib.rs | 27 | ||||
-rw-r--r-- | alacritty_config_derive/tests/config.rs | 155 |
7 files changed, 519 insertions, 0 deletions
diff --git a/alacritty_config_derive/Cargo.toml b/alacritty_config_derive/Cargo.toml new file mode 100644 index 00000000..8d6d9c4b --- /dev/null +++ b/alacritty_config_derive/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "alacritty_config_derive" +version = "0.1.0" +authors = ["Christian Duerr <contact@christianduerr.com>"] +license = "MIT/Apache-2.0" +description = "Failure resistant deserialization derive" +homepage = "https://github.com/alacritty/alacritty" +edition = "2018" + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "1.0.53", features = ["derive", "parsing", "proc-macro", "printing"], default-features = false } +proc-macro2 = "1.0.24" +quote = "1.0.7" + +[dev-dependencies] +serde_yaml = "0.8.14" +serde = "1.0.117" +log = "0.4.11" diff --git a/alacritty_config_derive/LICENSE-APACHE b/alacritty_config_derive/LICENSE-APACHE new file mode 120000 index 00000000..965b606f --- /dev/null +++ b/alacritty_config_derive/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE
\ No newline at end of file diff --git a/alacritty_config_derive/LICENSE-MIT b/alacritty_config_derive/LICENSE-MIT new file mode 100644 index 00000000..31aa7938 --- /dev/null +++ b/alacritty_config_derive/LICENSE-MIT @@ -0,0 +1,23 @@ +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/alacritty_config_derive/src/de_enum.rs b/alacritty_config_derive/src/de_enum.rs new file mode 100644 index 00000000..98247c0c --- /dev/null +++ b/alacritty_config_derive/src/de_enum.rs @@ -0,0 +1,66 @@ +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote}; +use syn::{DataEnum, Ident}; + +pub fn derive_deserialize(ident: Ident, data_enum: DataEnum) -> TokenStream { + let visitor = format_ident!("{}Visitor", ident); + + // Create match arm streams and get a list with all available values. + let mut match_arms_stream = TokenStream2::new(); + let mut available_values = String::from("one of "); + for variant in data_enum.variants.iter().filter(|variant| { + // Skip deserialization for `#[config(skip)]` fields. + variant.attrs.iter().all(|attr| { + !crate::path_ends_with(&attr.path, "config") || attr.tokens.to_string() != "(skip)" + }) + }) { + let variant_ident = &variant.ident; + let variant_str = variant_ident.to_string(); + available_values = format!("{}`{}`, ", available_values, variant_str); + + let literal = variant_str.to_lowercase(); + + match_arms_stream.extend(quote! { + #literal => Ok(#ident :: #variant_ident), + }); + } + + // Remove trailing `, ` from the last enum variant. + available_values.truncate(available_values.len().saturating_sub(2)); + + // Generate deserialization impl. + let tokens = quote! { + struct #visitor; + impl<'de> serde::de::Visitor<'de> for #visitor { + type Value = #ident; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str(#available_values) + } + + fn visit_str<E>(self, s: &str) -> Result<Self::Value, E> + where + E: serde::de::Error, + { + match s.to_lowercase().as_str() { + #match_arms_stream + _ => Err(E::custom( + &format!("unknown variant `{}`, expected {}", s, #available_values) + )), + } + } + } + + impl<'de> serde::Deserialize<'de> for #ident { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_str(#visitor) + } + } + }; + + tokens.into() +} diff --git a/alacritty_config_derive/src/de_struct.rs b/alacritty_config_derive/src/de_struct.rs new file mode 100644 index 00000000..1325cae3 --- /dev/null +++ b/alacritty_config_derive/src/de_struct.rs @@ -0,0 +1,226 @@ +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote}; +use syn::parse::{self, Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::spanned::Spanned; +use syn::{Error, Field, GenericParam, Generics, Ident, LitStr, Token, Type, TypeParam}; + +/// Error message when attempting to flatten multiple fields. +const MULTIPLE_FLATTEN_ERROR: &str = "At most one instance of #[config(flatten)] is supported"; +/// Use this crate's name as log target. +const LOG_TARGET: &str = env!("CARGO_PKG_NAME"); + +pub fn derive_deserialize<T>( + ident: Ident, + generics: Generics, + fields: Punctuated<Field, T>, +) -> TokenStream { + // Create all necessary tokens for the implementation. + let GenericsStreams { unconstrained, constrained, phantoms } = + generics_streams(generics.params); + let FieldStreams { flatten, match_assignments } = fields_deserializer(&fields); + let visitor = format_ident!("{}Visitor", ident); + + // Generate deserialization impl. + let tokens = quote! { + #[derive(Default)] + #[allow(non_snake_case)] + struct #visitor < #unconstrained > { + #phantoms + } + + impl<'de, #constrained> serde::de::Visitor<'de> for #visitor < #unconstrained > { + type Value = #ident < #unconstrained >; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a mapping") + } + + fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error> + where + M: serde::de::MapAccess<'de>, + { + let mut config = Self::Value::default(); + + // NOTE: This could be used to print unused keys. + let mut unused = serde_yaml::Mapping::new(); + + while let Some((key, value)) = map.next_entry::<String, serde_yaml::Value>()? { + match key.as_str() { + #match_assignments + _ => { + unused.insert(serde_yaml::Value::String(key), value); + }, + } + } + + #flatten + + Ok(config) + } + } + + impl<'de, #constrained> serde::Deserialize<'de> for #ident < #unconstrained > { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_map(#visitor :: default()) + } + } + }; + + tokens.into() +} + +// Token streams created from the fields in the struct. +#[derive(Default)] +struct FieldStreams { + match_assignments: TokenStream2, + flatten: TokenStream2, +} + +/// Create the deserializers for match arms and flattened fields. +fn fields_deserializer<T>(fields: &Punctuated<Field, T>) -> FieldStreams { + let mut field_streams = FieldStreams::default(); + + // Create the deserialization stream for each field. + for field in fields.iter() { + if let Err(err) = field_deserializer(&mut field_streams, field) { + field_streams.flatten = err.to_compile_error(); + return field_streams; + } + } + + field_streams +} + +/// Append a single field deserializer to the stream. +fn field_deserializer(field_streams: &mut FieldStreams, field: &Field) -> Result<(), Error> { + let ident = field.ident.as_ref().expect("unreachable tuple struct"); + let literal = ident.to_string(); + let mut literals = vec![literal.clone()]; + + // Create default stream for deserializing fields. + let mut match_assignment_stream = quote! { + match serde::Deserialize::deserialize(value) { + Ok(value) => config.#ident = value, + Err(err) => { + log::error!(target: #LOG_TARGET, "Config error: {}: {}", #literal, err); + }, + } + }; + + // Iterate over all #[config(...)] attributes. + for attr in field.attrs.iter().filter(|attr| crate::path_ends_with(&attr.path, "config")) { + let parsed = match attr.parse_args::<Attr>() { + Ok(parsed) => parsed, + Err(_) => continue, + }; + + match parsed.ident.as_str() { + // Skip deserialization for `#[config(skip)]` fields. + "skip" => return Ok(()), + "flatten" => { + // NOTE: Currently only a single instance of flatten is supported per struct + // for complexity reasons. + if !field_streams.flatten.is_empty() { + return Err(Error::new(attr.span(), MULTIPLE_FLATTEN_ERROR)); + } + + // Create the tokens to deserialize the flattened struct from the unused fields. + field_streams.flatten.extend(quote! { + let unused = serde_yaml::Value::Mapping(unused); + config.#ident = serde::Deserialize::deserialize(unused).unwrap_or_default(); + }); + }, + "deprecated" => { + // Construct deprecation message and append optional attribute override. + let mut message = format!("Config warning: {} is deprecated", literal); + if let Some(warning) = parsed.param { + message = format!("{}; {}", message, warning.value()); + } + + // Append stream to log deprecation warning. + match_assignment_stream.extend(quote! { + log::warn!(target: #LOG_TARGET, #message); + }); + }, + // Add aliases to match pattern. + "alias" => { + if let Some(alias) = parsed.param { + literals.push(alias.value()); + } + }, + _ => (), + } + } + + // Create token stream for deserializing "none" string into `Option<T>`. + if let Type::Path(type_path) = &field.ty { + if crate::path_ends_with(&type_path.path, "Option") { + match_assignment_stream = quote! { + if value.as_str().map_or(false, |s| s.eq_ignore_ascii_case("none")) { + config.#ident = None; + continue; + } + #match_assignment_stream + }; + } + } + + // Create the token stream for deserialization and error handling. + field_streams.match_assignments.extend(quote! { + #(#literals)|* => { #match_assignment_stream }, + }); + + Ok(()) +} + +/// Field attribute. +struct Attr { + ident: String, + param: Option<LitStr>, +} + +impl Parse for Attr { + fn parse(input: ParseStream) -> parse::Result<Self> { + let ident = input.parse::<Ident>()?.to_string(); + let param = input.parse::<Token![=]>().and_then(|_| input.parse()).ok(); + Ok(Self { ident, param }) + } +} + +/// Storage for all necessary generics information. +#[derive(Default)] +struct GenericsStreams { + unconstrained: TokenStream2, + constrained: TokenStream2, + phantoms: TokenStream2, +} + +/// Create the necessary generics annotations. +/// +/// This will create three different token streams, which might look like this: +/// - unconstrained: `T` +/// - constrained: `T: Default + Deserialize<'de>` +/// - phantoms: `T: PhantomData<T>,` +fn generics_streams<T>(params: Punctuated<GenericParam, T>) -> GenericsStreams { + let mut generics = GenericsStreams::default(); + + for generic in params { + // NOTE: Lifetimes and const params are not supported. + if let GenericParam::Type(TypeParam { ident, .. }) = generic { + generics.unconstrained.extend(quote!( #ident , )); + generics.constrained.extend(quote! { + #ident : Default + serde::Deserialize<'de> , + }); + generics.phantoms.extend(quote! { + #ident : std::marker::PhantomData < #ident >, + }); + } + } + + generics +} diff --git a/alacritty_config_derive/src/lib.rs b/alacritty_config_derive/src/lib.rs new file mode 100644 index 00000000..8601d5cb --- /dev/null +++ b/alacritty_config_derive/src/lib.rs @@ -0,0 +1,27 @@ +use proc_macro::TokenStream; +use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Error, Fields, Path}; + +mod de_enum; +mod de_struct; + +/// Error if the derive was used on an unsupported type. +const UNSUPPORTED_ERROR: &str = "ConfigDeserialize must be used on a struct with fields"; + +#[proc_macro_derive(ConfigDeserialize, attributes(config))] +pub fn derive_config_deserialize(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + + match input.data { + Data::Struct(DataStruct { fields: Fields::Named(fields), .. }) => { + de_struct::derive_deserialize(input.ident, input.generics, fields.named) + }, + Data::Enum(data_enum) => de_enum::derive_deserialize(input.ident, data_enum), + _ => Error::new(input.ident.span(), UNSUPPORTED_ERROR).to_compile_error().into(), + } +} + +/// Verify that a token path ends with a specific segment. +pub(crate) fn path_ends_with(path: &Path, segment: &str) -> bool { + let segments = path.segments.iter(); + segments.last().map_or(false, |s| s.ident == segment) +} diff --git a/alacritty_config_derive/tests/config.rs b/alacritty_config_derive/tests/config.rs new file mode 100644 index 00000000..03abf893 --- /dev/null +++ b/alacritty_config_derive/tests/config.rs @@ -0,0 +1,155 @@ +use std::sync::{Arc, Mutex}; + +use log::{Level, Log, Metadata, Record}; + +use alacritty_config_derive::ConfigDeserialize; + +#[derive(ConfigDeserialize, Debug, PartialEq, Eq)] +enum TestEnum { + One, + Two, + Three, + #[config(skip)] + Nine(String), +} + +impl Default for TestEnum { + fn default() -> Self { + Self::Nine(String::from("nine")) + } +} + +#[derive(ConfigDeserialize)] +struct Test { + #[config(alias = "noalias")] + #[config(deprecated = "use field2 instead")] + field1: usize, + #[config(deprecated = "shouldn't be hit")] + field2: String, + field3: Option<u8>, + #[doc("aaa")] + nesting: Test2<usize>, + #[config(flatten)] + flatten: Test3, + enom_small: TestEnum, + enom_big: TestEnum, + #[config(deprecated)] + enom_error: TestEnum, +} + +impl Default for Test { + fn default() -> Self { + Self { + field1: 13, + field2: String::from("field2"), + field3: Some(23), + nesting: Test2::default(), + flatten: Test3::default(), + enom_small: TestEnum::default(), + enom_big: TestEnum::default(), + enom_error: TestEnum::default(), + } + } +} + +#[derive(ConfigDeserialize, Default)] +struct Test2<T: Default> { + field1: T, + field2: Option<usize>, + #[config(skip)] + field3: usize, + #[config(alias = "aliased")] + field4: u8, +} + +#[derive(ConfigDeserialize, Default)] +struct Test3 { + flatty: usize, +} + +#[test] +fn config_deserialize() { + let logger = unsafe { + LOGGER = Some(Logger::default()); + LOGGER.as_mut().unwrap() + }; + + log::set_logger(logger).unwrap(); + log::set_max_level(log::LevelFilter::Warn); + + let test: Test = serde_yaml::from_str( + r#" + field1: 3 + field3: 32 + nesting: + field1: "testing" + field2: None + field3: 99 + aliased: 8 + flatty: 123 + enom_small: "one" + enom_big: "THREE" + enom_error: "HugaBuga" + "#, + ) + .unwrap(); + + // Verify fields were deserialized correctly. + assert_eq!(test.field1, 3); + assert_eq!(test.field2, Test::default().field2); + assert_eq!(test.field3, Some(32)); + assert_eq!(test.enom_small, TestEnum::One); + assert_eq!(test.enom_big, TestEnum::Three); + assert_eq!(test.enom_error, Test::default().enom_error); + assert_eq!(test.nesting.field1, Test::default().nesting.field1); + assert_eq!(test.nesting.field2, None); + assert_eq!(test.nesting.field3, Test::default().nesting.field3); + assert_eq!(test.nesting.field4, 8); + assert_eq!(test.flatten.flatty, 123); + + // Verify all log messages are correct. + let error_logs = logger.error_logs.lock().unwrap(); + assert_eq!(error_logs.as_slice(), [ + "Config error: field1: invalid type: string \"testing\", expected usize", + "Config error: enom_error: unknown variant `HugaBuga`, expected one of `One`, `Two`, \ + `Three`", + ]); + let warn_logs = logger.warn_logs.lock().unwrap(); + assert_eq!(warn_logs.as_slice(), [ + "Config warning: field1 is deprecated; use field2 instead", + "Config warning: enom_error is deprecated", + ]); +} + +static mut LOGGER: Option<Logger> = None; + +/// Logger storing all messages for later validation. +#[derive(Default)] +struct Logger { + error_logs: Arc<Mutex<Vec<String>>>, + warn_logs: Arc<Mutex<Vec<String>>>, +} + +impl Log for Logger { + fn log(&self, record: &Record) { + assert_eq!(record.target(), env!("CARGO_PKG_NAME")); + + match record.level() { + Level::Error => { + let mut error_logs = self.error_logs.lock().unwrap(); + error_logs.push(record.args().to_string()); + }, + Level::Warn => { + let mut warn_logs = self.warn_logs.lock().unwrap(); + warn_logs.push(record.args().to_string()); + }, + _ => unreachable!(), + } + } + + fn enabled(&self, _metadata: &Metadata) -> bool { + true + } + + fn flush(&self) {} +} |