diff options
author | Christian Duerr <contact@christianduerr.com> | 2020-12-21 02:44:38 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-12-21 02:44:38 +0000 |
commit | 6e1b9d8b2502f5b47dc28eb5e0853e46ad8b4e84 (patch) | |
tree | 623a6cd8785529b28cc28af201c26b56fb47ac46 /alacritty_config_derive | |
parent | 37a3198d8882463c9873011c1d18c325ea46d7c8 (diff) | |
download | alacritty-6e1b9d8b2502f5b47dc28eb5e0853e46ad8b4e84.tar.gz alacritty-6e1b9d8b2502f5b47dc28eb5e0853e46ad8b4e84.zip |
Replace serde's derive with custom proc macro
This replaces the existing `Deserialize` derive from serde with a
`ConfigDeserialize` derive. The goal of this new proc macro is to allow
a more error-friendly deserialization for the Alacritty configuration
file without having to manage a lot of boilerplate code inside the
configuration modules.
The first part of the derive macro is for struct deserialization. This
takes structs which have `Default` implemented and will only replace
fields which can be successfully deserialized. Otherwise the `log` crate
is used for printing errors. Since this deserialization takes the
default value from the struct instead of the value, it removes the
necessity for creating new types just to implement `Default` on them for
deserialization.
Additionally, the struct deserialization also checks for `Option` values
and makes sure that explicitly specifying `none` as text literal is
allowed for all options.
The other part of the derive macro is responsible for deserializing
enums. While only enums with Unit variants are supported, it will
automatically implement a deserializer for these enums which accepts any
form of capitalization.
Since this custom derive prevents us from using serde's attributes on
fields, some of the attributes have been reimplemented for
`ConfigDeserialize`. These include `#[config(flatten)]`,
`#[config(skip)]` and `#[config(alias = "alias)]`. The flatten attribute
is currently limited to at most one per struct.
Additionally the `#[config(deprecated = "optional message")]` attribute
allows easily defining uniform deprecation messages for fields on
structs.
Diffstat (limited to 'alacritty_config_derive')
-rw-r--r-- | alacritty_config_derive/Cargo.toml | 21 | ||||
l--------- | alacritty_config_derive/LICENSE-APACHE | 1 | ||||
-rw-r--r-- | alacritty_config_derive/LICENSE-MIT | 23 | ||||
-rw-r--r-- | alacritty_config_derive/src/de_enum.rs | 66 | ||||
-rw-r--r-- | alacritty_config_derive/src/de_struct.rs | 226 | ||||
-rw-r--r-- | alacritty_config_derive/src/lib.rs | 27 | ||||
-rw-r--r-- | alacritty_config_derive/tests/config.rs | 155 |
7 files changed, 519 insertions, 0 deletions
diff --git a/alacritty_config_derive/Cargo.toml b/alacritty_config_derive/Cargo.toml new file mode 100644 index 00000000..8d6d9c4b --- /dev/null +++ b/alacritty_config_derive/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "alacritty_config_derive" +version = "0.1.0" +authors = ["Christian Duerr <contact@christianduerr.com>"] +license = "MIT/Apache-2.0" +description = "Failure resistant deserialization derive" +homepage = "https://github.com/alacritty/alacritty" +edition = "2018" + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "1.0.53", features = ["derive", "parsing", "proc-macro", "printing"], default-features = false } +proc-macro2 = "1.0.24" +quote = "1.0.7" + +[dev-dependencies] +serde_yaml = "0.8.14" +serde = "1.0.117" +log = "0.4.11" diff --git a/alacritty_config_derive/LICENSE-APACHE b/alacritty_config_derive/LICENSE-APACHE new file mode 120000 index 00000000..965b606f --- /dev/null +++ b/alacritty_config_derive/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE
\ No newline at end of file diff --git a/alacritty_config_derive/LICENSE-MIT b/alacritty_config_derive/LICENSE-MIT new file mode 100644 index 00000000..31aa7938 --- /dev/null +++ b/alacritty_config_derive/LICENSE-MIT @@ -0,0 +1,23 @@ +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/alacritty_config_derive/src/de_enum.rs b/alacritty_config_derive/src/de_enum.rs new file mode 100644 index 00000000..98247c0c --- /dev/null +++ b/alacritty_config_derive/src/de_enum.rs @@ -0,0 +1,66 @@ +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote}; +use syn::{DataEnum, Ident}; + +pub fn derive_deserialize(ident: Ident, data_enum: DataEnum) -> TokenStream { + let visitor = format_ident!("{}Visitor", ident); + + // Create match arm streams and get a list with all available values. + let mut match_arms_stream = TokenStream2::new(); + let mut available_values = String::from("one of "); + for variant in data_enum.variants.iter().filter(|variant| { + // Skip deserialization for `#[config(skip)]` fields. + variant.attrs.iter().all(|attr| { + !crate::path_ends_with(&attr.path, "config") || attr.tokens.to_string() != "(skip)" + }) + }) { + let variant_ident = &variant.ident; + let variant_str = variant_ident.to_string(); + available_values = format!("{}`{}`, ", available_values, variant_str); + + let literal = variant_str.to_lowercase(); + + match_arms_stream.extend(quote! { + #literal => Ok(#ident :: #variant_ident), + }); + } + + // Remove trailing `, ` from the last enum variant. + available_values.truncate(available_values.len().saturating_sub(2)); + + // Generate deserialization impl. + let tokens = quote! { + struct #visitor; + impl<'de> serde::de::Visitor<'de> for #visitor { + type Value = #ident; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str(#available_values) + } + + fn visit_str<E>(self, s: &str) -> Result<Self::Value, E> + where + E: serde::de::Error, + { + match s.to_lowercase().as_str() { + #match_arms_stream + _ => Err(E::custom( + &format!("unknown variant `{}`, expected {}", s, #available_values) + )), + } + } + } + + impl<'de> serde::Deserialize<'de> for #ident { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_str(#visitor) + } + } + }; + + tokens.into() +} diff --git a/alacritty_config_derive/src/de_struct.rs b/alacritty_config_derive/src/de_struct.rs new file mode 100644 index 00000000..1325cae3 --- /dev/null +++ b/alacritty_config_derive/src/de_struct.rs @@ -0,0 +1,226 @@ +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote}; +use syn::parse::{self, Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::spanned::Spanned; +use syn::{Error, Field, GenericParam, Generics, Ident, LitStr, Token, Type, TypeParam}; + +/// Error message when attempting to flatten multiple fields. +const MULTIPLE_FLATTEN_ERROR: &str = "At most one instance of #[config(flatten)] is supported"; +/// Use this crate's name as log target. +const LOG_TARGET: &str = env!("CARGO_PKG_NAME"); + +pub fn derive_deserialize<T>( + ident: Ident, + generics: Generics, + fields: Punctuated<Field, T>, +) -> TokenStream { + // Create all necessary tokens for the implementation. + let GenericsStreams { unconstrained, constrained, phantoms } = + generics_streams(generics.params); + let FieldStreams { flatten, match_assignments } = fields_deserializer(&fields); + let visitor = format_ident!("{}Visitor", ident); + + // Generate deserialization impl. + let tokens = quote! { + #[derive(Default)] + #[allow(non_snake_case)] + struct #visitor < #unconstrained > { + #phantoms + } + + impl<'de, #constrained> serde::de::Visitor<'de> for #visitor < #unconstrained > { + type Value = #ident < #unconstrained >; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a mapping") + } + + fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error> + where + M: serde::de::MapAccess<'de>, + { + let mut config = Self::Value::default(); + + // NOTE: This could be used to print unused keys. + let mut unused = serde_yaml::Mapping::new(); + + while let Some((key, value)) = map.next_entry::<String, serde_yaml::Value>()? { + match key.as_str() { + #match_assignments + _ => { + unused.insert(serde_yaml::Value::String(key), value); + }, + } + } + + #flatten + + Ok(config) + } + } + + impl<'de, #constrained> serde::Deserialize<'de> for #ident < #unconstrained > { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_map(#visitor :: default()) + } + } + }; + + tokens.into() +} + +// Token streams created from the fields in the struct. +#[derive(Default)] +struct FieldStreams { + match_assignments: TokenStream2, + flatten: TokenStream2, +} + +/// Create the deserializers for match arms and flattened fields. +fn fields_deserializer<T>(fields: &Punctuated<Field, T>) -> FieldStreams { + let mut field_streams = FieldStreams::default(); + + // Create the deserialization stream for each field. + for field in fields.iter() { + if let Err(err) = field_deserializer(&mut field_streams, field) { + field_streams.flatten = err.to_compile_error(); + return field_streams; + } + } + + field_streams +} + +/// Append a single field deserializer to the stream. +fn field_deserializer(field_streams: &mut FieldStreams, field: &Field) -> Result<(), Error> { + let ident = field.ident.as_ref().expect("unreachable tuple struct"); + let literal = ident.to_string(); + let mut literals = vec![literal.clone()]; + + // Create default stream for deserializing fields. + let mut match_assignment_stream = quote! { + match serde::Deserialize::deserialize(value) { + Ok(value) => config.#ident = value, + Err(err) => { + log::error!(target: #LOG_TARGET, "Config error: {}: {}", #literal, err); + }, + } + }; + + // Iterate over all #[config(...)] attributes. + for attr in field.attrs.iter().filter(|attr| crate::path_ends_with(&attr.path, "config")) { + let parsed = match attr.parse_args::<Attr>() { + Ok(parsed) => parsed, + Err(_) => continue, + }; + + match parsed.ident.as_str() { + // Skip deserialization for `#[config(skip)]` fields. + "skip" => return Ok(()), + "flatten" => { + // NOTE: Currently only a single instance of flatten is supported per struct + // for complexity reasons. + if !field_streams.flatten.is_empty() { + return Err(Error::new(attr.span(), MULTIPLE_FLATTEN_ERROR)); + } + + // Create the tokens to deserialize the flattened struct from the unused fields. + field_streams.flatten.extend(quote! { + let unused = serde_yaml::Value::Mapping(unused); + config.#ident = serde::Deserialize::deserialize(unused).unwrap_or_default(); + }); + }, + "deprecated" => { + // Construct deprecation message and append optional attribute override. + let mut message = format!("Config warning: {} is deprecated", literal); + if let Some(warning) = parsed.param { + message = format!("{}; {}", message, warning.value()); + } + + // Append stream to log deprecation warning. + match_assignment_stream.extend(quote! { + log::warn!(target: #LOG_TARGET, #message); + }); + }, + // Add aliases to match pattern. + "alias" => { + if let Some(alias) = parsed.param { + literals.push(alias.value()); + } + }, + _ => (), + } + } + + // Create token stream for deserializing "none" string into `Option<T>`. + if let Type::Path(type_path) = &field.ty { + if crate::path_ends_with(&type_path.path, "Option") { + match_assignment_stream = quote! { + if value.as_str().map_or(false, |s| s.eq_ignore_ascii_case("none")) { + config.#ident = None; + continue; + } + #match_assignment_stream + }; + } + } + + // Create the token stream for deserialization and error handling. + field_streams.match_assignments.extend(quote! { + #(#literals)|* => { #match_assignment_stream }, + }); + + Ok(()) +} + +/// Field attribute. +struct Attr { + ident: String, + param: Option<LitStr>, +} + +impl Parse for Attr { + fn parse(input: ParseStream) -> parse::Result<Self> { + let ident = input.parse::<Ident>()?.to_string(); + let param = input.parse::<Token![=]>().and_then(|_| input.parse()).ok(); + Ok(Self { ident, param }) + } +} + +/// Storage for all necessary generics information. +#[derive(Default)] +struct GenericsStreams { + unconstrained: TokenStream2, + constrained: TokenStream2, + phantoms: TokenStream2, +} + +/// Create the necessary generics annotations. +/// +/// This will create three different token streams, which might look like this: +/// - unconstrained: `T` +/// - constrained: `T: Default + Deserialize<'de>` +/// - phantoms: `T: PhantomData<T>,` +fn generics_streams<T>(params: Punctuated<GenericParam, T>) -> GenericsStreams { + let mut generics = GenericsStreams::default(); + + for generic in params { + // NOTE: Lifetimes and const params are not supported. + if let GenericParam::Type(TypeParam { ident, .. }) = generic { + generics.unconstrained.extend(quote!( #ident , )); + generics.constrained.extend(quote! { + #ident : Default + serde::Deserialize<'de> , + }); + generics.phantoms.extend(quote! { + #ident : std::marker::PhantomData < #ident >, + }); + } + } + + generics +} diff --git a/alacritty_config_derive/src/lib.rs b/alacritty_config_derive/src/lib.rs new file mode 100644 index 00000000..8601d5cb --- /dev/null +++ b/alacritty_config_derive/src/lib.rs @@ -0,0 +1,27 @@ +use proc_macro::TokenStream; +use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Error, Fields, Path}; + +mod de_enum; +mod de_struct; + +/// Error if the derive was used on an unsupported type. +const UNSUPPORTED_ERROR: &str = "ConfigDeserialize must be used on a struct with fields"; + +#[proc_macro_derive(ConfigDeserialize, attributes(config))] +pub fn derive_config_deserialize(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + + match input.data { + Data::Struct(DataStruct { fields: Fields::Named(fields), .. }) => { + de_struct::derive_deserialize(input.ident, input.generics, fields.named) + }, + Data::Enum(data_enum) => de_enum::derive_deserialize(input.ident, data_enum), + _ => Error::new(input.ident.span(), UNSUPPORTED_ERROR).to_compile_error().into(), + } +} + +/// Verify that a token path ends with a specific segment. +pub(crate) fn path_ends_with(path: &Path, segment: &str) -> bool { + let segments = path.segments.iter(); + segments.last().map_or(false, |s| s.ident == segment) +} diff --git a/alacritty_config_derive/tests/config.rs b/alacritty_config_derive/tests/config.rs new file mode 100644 index 00000000..03abf893 --- /dev/null +++ b/alacritty_config_derive/tests/config.rs @@ -0,0 +1,155 @@ +use std::sync::{Arc, Mutex}; + +use log::{Level, Log, Metadata, Record}; + +use alacritty_config_derive::ConfigDeserialize; + +#[derive(ConfigDeserialize, Debug, PartialEq, Eq)] +enum TestEnum { + One, + Two, + Three, + #[config(skip)] + Nine(String), +} + +impl Default for TestEnum { + fn default() -> Self { + Self::Nine(String::from("nine")) + } +} + +#[derive(ConfigDeserialize)] +struct Test { + #[config(alias = "noalias")] + #[config(deprecated = "use field2 instead")] + field1: usize, + #[config(deprecated = "shouldn't be hit")] + field2: String, + field3: Option<u8>, + #[doc("aaa")] + nesting: Test2<usize>, + #[config(flatten)] + flatten: Test3, + enom_small: TestEnum, + enom_big: TestEnum, + #[config(deprecated)] + enom_error: TestEnum, +} + +impl Default for Test { + fn default() -> Self { + Self { + field1: 13, + field2: String::from("field2"), + field3: Some(23), + nesting: Test2::default(), + flatten: Test3::default(), + enom_small: TestEnum::default(), + enom_big: TestEnum::default(), + enom_error: TestEnum::default(), + } + } +} + +#[derive(ConfigDeserialize, Default)] +struct Test2<T: Default> { + field1: T, + field2: Option<usize>, + #[config(skip)] + field3: usize, + #[config(alias = "aliased")] + field4: u8, +} + +#[derive(ConfigDeserialize, Default)] +struct Test3 { + flatty: usize, +} + +#[test] +fn config_deserialize() { + let logger = unsafe { + LOGGER = Some(Logger::default()); + LOGGER.as_mut().unwrap() + }; + + log::set_logger(logger).unwrap(); + log::set_max_level(log::LevelFilter::Warn); + + let test: Test = serde_yaml::from_str( + r#" + field1: 3 + field3: 32 + nesting: + field1: "testing" + field2: None + field3: 99 + aliased: 8 + flatty: 123 + enom_small: "one" + enom_big: "THREE" + enom_error: "HugaBuga" + "#, + ) + .unwrap(); + + // Verify fields were deserialized correctly. + assert_eq!(test.field1, 3); + assert_eq!(test.field2, Test::default().field2); + assert_eq!(test.field3, Some(32)); + assert_eq!(test.enom_small, TestEnum::One); + assert_eq!(test.enom_big, TestEnum::Three); + assert_eq!(test.enom_error, Test::default().enom_error); + assert_eq!(test.nesting.field1, Test::default().nesting.field1); + assert_eq!(test.nesting.field2, None); + assert_eq!(test.nesting.field3, Test::default().nesting.field3); + assert_eq!(test.nesting.field4, 8); + assert_eq!(test.flatten.flatty, 123); + + // Verify all log messages are correct. + let error_logs = logger.error_logs.lock().unwrap(); + assert_eq!(error_logs.as_slice(), [ + "Config error: field1: invalid type: string \"testing\", expected usize", + "Config error: enom_error: unknown variant `HugaBuga`, expected one of `One`, `Two`, \ + `Three`", + ]); + let warn_logs = logger.warn_logs.lock().unwrap(); + assert_eq!(warn_logs.as_slice(), [ + "Config warning: field1 is deprecated; use field2 instead", + "Config warning: enom_error is deprecated", + ]); +} + +static mut LOGGER: Option<Logger> = None; + +/// Logger storing all messages for later validation. +#[derive(Default)] +struct Logger { + error_logs: Arc<Mutex<Vec<String>>>, + warn_logs: Arc<Mutex<Vec<String>>>, +} + +impl Log for Logger { + fn log(&self, record: &Record) { + assert_eq!(record.target(), env!("CARGO_PKG_NAME")); + + match record.level() { + Level::Error => { + let mut error_logs = self.error_logs.lock().unwrap(); + error_logs.push(record.args().to_string()); + }, + Level::Warn => { + let mut warn_logs = self.warn_logs.lock().unwrap(); + warn_logs.push(record.args().to_string()); + }, + _ => unreachable!(), + } + } + + fn enabled(&self, _metadata: &Metadata) -> bool { + true + } + + fn flush(&self) {} +} |