summaryrefslogtreecommitdiff
path: root/alacritty_config_derive
diff options
context:
space:
mode:
Diffstat (limited to 'alacritty_config_derive')
-rw-r--r--alacritty_config_derive/Cargo.toml21
l---------alacritty_config_derive/LICENSE-APACHE1
-rw-r--r--alacritty_config_derive/LICENSE-MIT23
-rw-r--r--alacritty_config_derive/src/de_enum.rs66
-rw-r--r--alacritty_config_derive/src/de_struct.rs226
-rw-r--r--alacritty_config_derive/src/lib.rs27
-rw-r--r--alacritty_config_derive/tests/config.rs155
7 files changed, 519 insertions, 0 deletions
diff --git a/alacritty_config_derive/Cargo.toml b/alacritty_config_derive/Cargo.toml
new file mode 100644
index 00000000..8d6d9c4b
--- /dev/null
+++ b/alacritty_config_derive/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "alacritty_config_derive"
+version = "0.1.0"
+authors = ["Christian Duerr <contact@christianduerr.com>"]
+license = "MIT/Apache-2.0"
+description = "Failure resistant deserialization derive"
+homepage = "https://github.com/alacritty/alacritty"
+edition = "2018"
+
+[lib]
+proc-macro = true
+
+[dependencies]
+syn = { version = "1.0.53", features = ["derive", "parsing", "proc-macro", "printing"], default-features = false }
+proc-macro2 = "1.0.24"
+quote = "1.0.7"
+
+[dev-dependencies]
+serde_yaml = "0.8.14"
+serde = "1.0.117"
+log = "0.4.11"
diff --git a/alacritty_config_derive/LICENSE-APACHE b/alacritty_config_derive/LICENSE-APACHE
new file mode 120000
index 00000000..965b606f
--- /dev/null
+++ b/alacritty_config_derive/LICENSE-APACHE
@@ -0,0 +1 @@
+../LICENSE-APACHE \ No newline at end of file
diff --git a/alacritty_config_derive/LICENSE-MIT b/alacritty_config_derive/LICENSE-MIT
new file mode 100644
index 00000000..31aa7938
--- /dev/null
+++ b/alacritty_config_derive/LICENSE-MIT
@@ -0,0 +1,23 @@
+Permission is hereby granted, free of charge, to any
+person obtaining a copy of this software and associated
+documentation files (the "Software"), to deal in the
+Software without restriction, including without
+limitation the rights to use, copy, modify, merge,
+publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software
+is furnished to do so, subject to the following
+conditions:
+
+The above copyright notice and this permission notice
+shall be included in all copies or substantial portions
+of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
+ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
+TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
+SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
+IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
diff --git a/alacritty_config_derive/src/de_enum.rs b/alacritty_config_derive/src/de_enum.rs
new file mode 100644
index 00000000..98247c0c
--- /dev/null
+++ b/alacritty_config_derive/src/de_enum.rs
@@ -0,0 +1,66 @@
+use proc_macro::TokenStream;
+use proc_macro2::TokenStream as TokenStream2;
+use quote::{format_ident, quote};
+use syn::{DataEnum, Ident};
+
+pub fn derive_deserialize(ident: Ident, data_enum: DataEnum) -> TokenStream {
+ let visitor = format_ident!("{}Visitor", ident);
+
+ // Create match arm streams and get a list with all available values.
+ let mut match_arms_stream = TokenStream2::new();
+ let mut available_values = String::from("one of ");
+ for variant in data_enum.variants.iter().filter(|variant| {
+ // Skip deserialization for `#[config(skip)]` fields.
+ variant.attrs.iter().all(|attr| {
+ !crate::path_ends_with(&attr.path, "config") || attr.tokens.to_string() != "(skip)"
+ })
+ }) {
+ let variant_ident = &variant.ident;
+ let variant_str = variant_ident.to_string();
+ available_values = format!("{}`{}`, ", available_values, variant_str);
+
+ let literal = variant_str.to_lowercase();
+
+ match_arms_stream.extend(quote! {
+ #literal => Ok(#ident :: #variant_ident),
+ });
+ }
+
+ // Remove trailing `, ` from the last enum variant.
+ available_values.truncate(available_values.len().saturating_sub(2));
+
+ // Generate deserialization impl.
+ let tokens = quote! {
+ struct #visitor;
+ impl<'de> serde::de::Visitor<'de> for #visitor {
+ type Value = #ident;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+ formatter.write_str(#available_values)
+ }
+
+ fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
+ where
+ E: serde::de::Error,
+ {
+ match s.to_lowercase().as_str() {
+ #match_arms_stream
+ _ => Err(E::custom(
+ &format!("unknown variant `{}`, expected {}", s, #available_values)
+ )),
+ }
+ }
+ }
+
+ impl<'de> serde::Deserialize<'de> for #ident {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ deserializer.deserialize_str(#visitor)
+ }
+ }
+ };
+
+ tokens.into()
+}
diff --git a/alacritty_config_derive/src/de_struct.rs b/alacritty_config_derive/src/de_struct.rs
new file mode 100644
index 00000000..1325cae3
--- /dev/null
+++ b/alacritty_config_derive/src/de_struct.rs
@@ -0,0 +1,226 @@
+use proc_macro::TokenStream;
+use proc_macro2::TokenStream as TokenStream2;
+use quote::{format_ident, quote};
+use syn::parse::{self, Parse, ParseStream};
+use syn::punctuated::Punctuated;
+use syn::spanned::Spanned;
+use syn::{Error, Field, GenericParam, Generics, Ident, LitStr, Token, Type, TypeParam};
+
+/// Error message when attempting to flatten multiple fields.
+const MULTIPLE_FLATTEN_ERROR: &str = "At most one instance of #[config(flatten)] is supported";
+/// Use this crate's name as log target.
+const LOG_TARGET: &str = env!("CARGO_PKG_NAME");
+
+pub fn derive_deserialize<T>(
+ ident: Ident,
+ generics: Generics,
+ fields: Punctuated<Field, T>,
+) -> TokenStream {
+ // Create all necessary tokens for the implementation.
+ let GenericsStreams { unconstrained, constrained, phantoms } =
+ generics_streams(generics.params);
+ let FieldStreams { flatten, match_assignments } = fields_deserializer(&fields);
+ let visitor = format_ident!("{}Visitor", ident);
+
+ // Generate deserialization impl.
+ let tokens = quote! {
+ #[derive(Default)]
+ #[allow(non_snake_case)]
+ struct #visitor < #unconstrained > {
+ #phantoms
+ }
+
+ impl<'de, #constrained> serde::de::Visitor<'de> for #visitor < #unconstrained > {
+ type Value = #ident < #unconstrained >;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+ formatter.write_str("a mapping")
+ }
+
+ fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
+ where
+ M: serde::de::MapAccess<'de>,
+ {
+ let mut config = Self::Value::default();
+
+ // NOTE: This could be used to print unused keys.
+ let mut unused = serde_yaml::Mapping::new();
+
+ while let Some((key, value)) = map.next_entry::<String, serde_yaml::Value>()? {
+ match key.as_str() {
+ #match_assignments
+ _ => {
+ unused.insert(serde_yaml::Value::String(key), value);
+ },
+ }
+ }
+
+ #flatten
+
+ Ok(config)
+ }
+ }
+
+ impl<'de, #constrained> serde::Deserialize<'de> for #ident < #unconstrained > {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ deserializer.deserialize_map(#visitor :: default())
+ }
+ }
+ };
+
+ tokens.into()
+}
+
+// Token streams created from the fields in the struct.
+#[derive(Default)]
+struct FieldStreams {
+ match_assignments: TokenStream2,
+ flatten: TokenStream2,
+}
+
+/// Create the deserializers for match arms and flattened fields.
+fn fields_deserializer<T>(fields: &Punctuated<Field, T>) -> FieldStreams {
+ let mut field_streams = FieldStreams::default();
+
+ // Create the deserialization stream for each field.
+ for field in fields.iter() {
+ if let Err(err) = field_deserializer(&mut field_streams, field) {
+ field_streams.flatten = err.to_compile_error();
+ return field_streams;
+ }
+ }
+
+ field_streams
+}
+
+/// Append a single field deserializer to the stream.
+fn field_deserializer(field_streams: &mut FieldStreams, field: &Field) -> Result<(), Error> {
+ let ident = field.ident.as_ref().expect("unreachable tuple struct");
+ let literal = ident.to_string();
+ let mut literals = vec![literal.clone()];
+
+ // Create default stream for deserializing fields.
+ let mut match_assignment_stream = quote! {
+ match serde::Deserialize::deserialize(value) {
+ Ok(value) => config.#ident = value,
+ Err(err) => {
+ log::error!(target: #LOG_TARGET, "Config error: {}: {}", #literal, err);
+ },
+ }
+ };
+
+ // Iterate over all #[config(...)] attributes.
+ for attr in field.attrs.iter().filter(|attr| crate::path_ends_with(&attr.path, "config")) {
+ let parsed = match attr.parse_args::<Attr>() {
+ Ok(parsed) => parsed,
+ Err(_) => continue,
+ };
+
+ match parsed.ident.as_str() {
+ // Skip deserialization for `#[config(skip)]` fields.
+ "skip" => return Ok(()),
+ "flatten" => {
+ // NOTE: Currently only a single instance of flatten is supported per struct
+ // for complexity reasons.
+ if !field_streams.flatten.is_empty() {
+ return Err(Error::new(attr.span(), MULTIPLE_FLATTEN_ERROR));
+ }
+
+ // Create the tokens to deserialize the flattened struct from the unused fields.
+ field_streams.flatten.extend(quote! {
+ let unused = serde_yaml::Value::Mapping(unused);
+ config.#ident = serde::Deserialize::deserialize(unused).unwrap_or_default();
+ });
+ },
+ "deprecated" => {
+ // Construct deprecation message and append optional attribute override.
+ let mut message = format!("Config warning: {} is deprecated", literal);
+ if let Some(warning) = parsed.param {
+ message = format!("{}; {}", message, warning.value());
+ }
+
+ // Append stream to log deprecation warning.
+ match_assignment_stream.extend(quote! {
+ log::warn!(target: #LOG_TARGET, #message);
+ });
+ },
+ // Add aliases to match pattern.
+ "alias" => {
+ if let Some(alias) = parsed.param {
+ literals.push(alias.value());
+ }
+ },
+ _ => (),
+ }
+ }
+
+ // Create token stream for deserializing "none" string into `Option<T>`.
+ if let Type::Path(type_path) = &field.ty {
+ if crate::path_ends_with(&type_path.path, "Option") {
+ match_assignment_stream = quote! {
+ if value.as_str().map_or(false, |s| s.eq_ignore_ascii_case("none")) {
+ config.#ident = None;
+ continue;
+ }
+ #match_assignment_stream
+ };
+ }
+ }
+
+ // Create the token stream for deserialization and error handling.
+ field_streams.match_assignments.extend(quote! {
+ #(#literals)|* => { #match_assignment_stream },
+ });
+
+ Ok(())
+}
+
+/// Field attribute.
+struct Attr {
+ ident: String,
+ param: Option<LitStr>,
+}
+
+impl Parse for Attr {
+ fn parse(input: ParseStream) -> parse::Result<Self> {
+ let ident = input.parse::<Ident>()?.to_string();
+ let param = input.parse::<Token![=]>().and_then(|_| input.parse()).ok();
+ Ok(Self { ident, param })
+ }
+}
+
+/// Storage for all necessary generics information.
+#[derive(Default)]
+struct GenericsStreams {
+ unconstrained: TokenStream2,
+ constrained: TokenStream2,
+ phantoms: TokenStream2,
+}
+
+/// Create the necessary generics annotations.
+///
+/// This will create three different token streams, which might look like this:
+/// - unconstrained: `T`
+/// - constrained: `T: Default + Deserialize<'de>`
+/// - phantoms: `T: PhantomData<T>,`
+fn generics_streams<T>(params: Punctuated<GenericParam, T>) -> GenericsStreams {
+ let mut generics = GenericsStreams::default();
+
+ for generic in params {
+ // NOTE: Lifetimes and const params are not supported.
+ if let GenericParam::Type(TypeParam { ident, .. }) = generic {
+ generics.unconstrained.extend(quote!( #ident , ));
+ generics.constrained.extend(quote! {
+ #ident : Default + serde::Deserialize<'de> ,
+ });
+ generics.phantoms.extend(quote! {
+ #ident : std::marker::PhantomData < #ident >,
+ });
+ }
+ }
+
+ generics
+}
diff --git a/alacritty_config_derive/src/lib.rs b/alacritty_config_derive/src/lib.rs
new file mode 100644
index 00000000..8601d5cb
--- /dev/null
+++ b/alacritty_config_derive/src/lib.rs
@@ -0,0 +1,27 @@
+use proc_macro::TokenStream;
+use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Error, Fields, Path};
+
+mod de_enum;
+mod de_struct;
+
+/// Error if the derive was used on an unsupported type.
+const UNSUPPORTED_ERROR: &str = "ConfigDeserialize must be used on a struct with fields";
+
+#[proc_macro_derive(ConfigDeserialize, attributes(config))]
+pub fn derive_config_deserialize(input: TokenStream) -> TokenStream {
+ let input = parse_macro_input!(input as DeriveInput);
+
+ match input.data {
+ Data::Struct(DataStruct { fields: Fields::Named(fields), .. }) => {
+ de_struct::derive_deserialize(input.ident, input.generics, fields.named)
+ },
+ Data::Enum(data_enum) => de_enum::derive_deserialize(input.ident, data_enum),
+ _ => Error::new(input.ident.span(), UNSUPPORTED_ERROR).to_compile_error().into(),
+ }
+}
+
+/// Verify that a token path ends with a specific segment.
+pub(crate) fn path_ends_with(path: &Path, segment: &str) -> bool {
+ let segments = path.segments.iter();
+ segments.last().map_or(false, |s| s.ident == segment)
+}
diff --git a/alacritty_config_derive/tests/config.rs b/alacritty_config_derive/tests/config.rs
new file mode 100644
index 00000000..03abf893
--- /dev/null
+++ b/alacritty_config_derive/tests/config.rs
@@ -0,0 +1,155 @@
+use std::sync::{Arc, Mutex};
+
+use log::{Level, Log, Metadata, Record};
+
+use alacritty_config_derive::ConfigDeserialize;
+
+#[derive(ConfigDeserialize, Debug, PartialEq, Eq)]
+enum TestEnum {
+ One,
+ Two,
+ Three,
+ #[config(skip)]
+ Nine(String),
+}
+
+impl Default for TestEnum {
+ fn default() -> Self {
+ Self::Nine(String::from("nine"))
+ }
+}
+
+#[derive(ConfigDeserialize)]
+struct Test {
+ #[config(alias = "noalias")]
+ #[config(deprecated = "use field2 instead")]
+ field1: usize,
+ #[config(deprecated = "shouldn't be hit")]
+ field2: String,
+ field3: Option<u8>,
+ #[doc("aaa")]
+ nesting: Test2<usize>,
+ #[config(flatten)]
+ flatten: Test3,
+ enom_small: TestEnum,
+ enom_big: TestEnum,
+ #[config(deprecated)]
+ enom_error: TestEnum,
+}
+
+impl Default for Test {
+ fn default() -> Self {
+ Self {
+ field1: 13,
+ field2: String::from("field2"),
+ field3: Some(23),
+ nesting: Test2::default(),
+ flatten: Test3::default(),
+ enom_small: TestEnum::default(),
+ enom_big: TestEnum::default(),
+ enom_error: TestEnum::default(),
+ }
+ }
+}
+
+#[derive(ConfigDeserialize, Default)]
+struct Test2<T: Default> {
+ field1: T,
+ field2: Option<usize>,
+ #[config(skip)]
+ field3: usize,
+ #[config(alias = "aliased")]
+ field4: u8,
+}
+
+#[derive(ConfigDeserialize, Default)]
+struct Test3 {
+ flatty: usize,
+}
+
+#[test]
+fn config_deserialize() {
+ let logger = unsafe {
+ LOGGER = Some(Logger::default());
+ LOGGER.as_mut().unwrap()
+ };
+
+ log::set_logger(logger).unwrap();
+ log::set_max_level(log::LevelFilter::Warn);
+
+ let test: Test = serde_yaml::from_str(
+ r#"
+ field1: 3
+ field3: 32
+ nesting:
+ field1: "testing"
+ field2: None
+ field3: 99
+ aliased: 8
+ flatty: 123
+ enom_small: "one"
+ enom_big: "THREE"
+ enom_error: "HugaBuga"
+ "#,
+ )
+ .unwrap();
+
+ // Verify fields were deserialized correctly.
+ assert_eq!(test.field1, 3);
+ assert_eq!(test.field2, Test::default().field2);
+ assert_eq!(test.field3, Some(32));
+ assert_eq!(test.enom_small, TestEnum::One);
+ assert_eq!(test.enom_big, TestEnum::Three);
+ assert_eq!(test.enom_error, Test::default().enom_error);
+ assert_eq!(test.nesting.field1, Test::default().nesting.field1);
+ assert_eq!(test.nesting.field2, None);
+ assert_eq!(test.nesting.field3, Test::default().nesting.field3);
+ assert_eq!(test.nesting.field4, 8);
+ assert_eq!(test.flatten.flatty, 123);
+
+ // Verify all log messages are correct.
+ let error_logs = logger.error_logs.lock().unwrap();
+ assert_eq!(error_logs.as_slice(), [
+ "Config error: field1: invalid type: string \"testing\", expected usize",
+ "Config error: enom_error: unknown variant `HugaBuga`, expected one of `One`, `Two`, \
+ `Three`",
+ ]);
+ let warn_logs = logger.warn_logs.lock().unwrap();
+ assert_eq!(warn_logs.as_slice(), [
+ "Config warning: field1 is deprecated; use field2 instead",
+ "Config warning: enom_error is deprecated",
+ ]);
+}
+
+static mut LOGGER: Option<Logger> = None;
+
+/// Logger storing all messages for later validation.
+#[derive(Default)]
+struct Logger {
+ error_logs: Arc<Mutex<Vec<String>>>,
+ warn_logs: Arc<Mutex<Vec<String>>>,
+}
+
+impl Log for Logger {
+ fn log(&self, record: &Record) {
+ assert_eq!(record.target(), env!("CARGO_PKG_NAME"));
+
+ match record.level() {
+ Level::Error => {
+ let mut error_logs = self.error_logs.lock().unwrap();
+ error_logs.push(record.args().to_string());
+ },
+ Level::Warn => {
+ let mut warn_logs = self.warn_logs.lock().unwrap();
+ warn_logs.push(record.args().to_string());
+ },
+ _ => unreachable!(),
+ }
+ }
+
+ fn enabled(&self, _metadata: &Metadata) -> bool {
+ true
+ }
+
+ fn flush(&self) {}
+}