diff --git a/builder/Cargo.toml b/builder/Cargo.toml index 9e84555cb..4bf350f36 100644 --- a/builder/Cargo.toml +++ b/builder/Cargo.toml @@ -16,4 +16,5 @@ path = "tests/progress.rs" trybuild = { version = "1.0.49", features = ["diff"] } [dependencies] -# TODO +quote = "1.0.36" +syn = { version = "2.0.72", features = ["extra-traits", "printing"] } diff --git a/builder/src/lib.rs b/builder/src/lib.rs index f49435b6a..1dff3c69f 100644 --- a/builder/src/lib.rs +++ b/builder/src/lib.rs @@ -1,8 +1,97 @@ +#![allow(unused)] + use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn; #[proc_macro_derive(Builder)] -pub fn derive(input: TokenStream) -> TokenStream { - let _ = input; +pub fn derive(item: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(item as syn::DeriveInput); + let syn::DeriveInput { + ident, + data: syn::Data::Struct(syn::DataStruct { fields, .. }), + .. + } = input + else { + todo!() + }; + let builder_ident = format_ident!("{}Builder", ident); + + let syn::Fields::Named(syn::FieldsNamed { named, .. }) = fields else { + todo!() + }; + let fields = named.iter(); + + let struct_impl = quote! { + impl #ident { + pub fn builder() -> #builder_ident { + #builder_ident { + ..Default::default() + } + } + } + }; + + let option_fields = fields.clone().map(|f| { + let name = &f.ident; + let ty = &f.ty; + quote! { #name: std::option::Option<#ty> } + }); + let builder_def = quote! { + #[derive(Default)] + pub struct #builder_ident { + #( #option_fields ),* + } + }; + + let field_methods = fields.clone().map(|f| { + let name = &f.ident; + let ty = &f.ty; + let optional_fields; + if let syn::Type::Path(syn::TypePath { + path: syn::Path { segments, .. }, + .. + }) = ty + { + optional_fields = segments.iter().filter_map(|seg| { + let syn::PathSegment { ident, arguments } = seg; + if ident == &format_ident!("Option") { + return None; + } + Some(seg) + }) + } + quote! { + pub fn #name(&mut self, #name: #ty) -> &mut Self { + self.#name = Some(#name); + self + } + } + }); + let built_fields = fields.clone().map(|f| { + let name = &f.ident; + let ty = &f.ty; + let err_msg = format!("{} missing", name.as_ref().unwrap()); + quote! { + #name: self.#name.clone().ok_or(#err_msg)? + } + }); + + let builder_impl = quote! { + impl #builder_ident { + #( #field_methods )* + fn build(&self) -> Result<#ident, Box> { + Ok(#ident { + #( #built_fields ),* + }) + } + } + }; - unimplemented!() + quote! { + #struct_impl + #builder_def + #builder_impl + } + .into() } diff --git a/builder/tests/progress.rs b/builder/tests/progress.rs index cd33fe070..03a85876a 100644 --- a/builder/tests/progress.rs +++ b/builder/tests/progress.rs @@ -1,11 +1,11 @@ #[test] fn tests() { let t = trybuild::TestCases::new(); - //t.pass("tests/01-parse.rs"); - //t.pass("tests/02-create-builder.rs"); - //t.pass("tests/03-call-setters.rs"); - //t.pass("tests/04-call-build.rs"); - //t.pass("tests/05-method-chaining.rs"); + t.pass("tests/01-parse.rs"); + t.pass("tests/02-create-builder.rs"); + t.pass("tests/03-call-setters.rs"); + t.pass("tests/04-call-build.rs"); + t.pass("tests/05-method-chaining.rs"); //t.pass("tests/06-optional-field.rs"); //t.pass("tests/07-repeated-field.rs"); //t.compile_fail("tests/08-unrecognized-attribute.rs");