feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
@@ -0,0 +1,23 @@
|
||||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science"]
|
||||
description = "Derive crate for the Burn framework"
|
||||
edition.workspace = true
|
||||
keywords = []
|
||||
license.workspace = true
|
||||
name = "burn-derive"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-derive"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
||||
[dependencies]
|
||||
proc-macro2 = { workspace = true }
|
||||
quote = { workspace = true }
|
||||
syn = { workspace = true }
|
||||
derive-new = { workspace = true }
|
||||
@@ -0,0 +1 @@
|
||||
../../LICENSE-APACHE
|
||||
1
crates/stable-diffusion-burn/burn-crates/burn-derive/LICENSE-MIT
Symbolic link
1
crates/stable-diffusion-burn/burn-crates/burn-derive/LICENSE-MIT
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-MIT
|
||||
@@ -0,0 +1,6 @@
|
||||
# Burn Derive
|
||||
|
||||
This crate should only be used with [burn](https://github.com/tracel-ai/burn).
|
||||
|
||||
[](https://crates.io/crates/burn-derive)
|
||||
[](https://github.com/tracel-ai/burn-derive/blob/master/README.md)
|
||||
@@ -0,0 +1,87 @@
|
||||
use super::ConfigEnumAnalyzer;
|
||||
use crate::config::ConfigStructAnalyzer;
|
||||
use crate::shared::{attribute::AttributeItem, field::FieldTypeAnalyzer};
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
use syn::{Field, Ident};
|
||||
|
||||
pub struct ConfigAnalyzerFactory {}
|
||||
|
||||
pub trait ConfigAnalyzer {
|
||||
fn gen_new_fn(&self) -> TokenStream {
|
||||
quote! {}
|
||||
}
|
||||
fn gen_builder_fns(&self) -> TokenStream {
|
||||
quote! {}
|
||||
}
|
||||
fn gen_serde_impl(&self) -> TokenStream;
|
||||
fn gen_clone_impl(&self) -> TokenStream;
|
||||
fn gen_display_impl(&self) -> TokenStream;
|
||||
fn gen_config_impl(&self) -> TokenStream;
|
||||
}
|
||||
|
||||
impl ConfigAnalyzerFactory {
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
pub fn create_analyzer(&self, item: &syn::DeriveInput) -> Box<dyn ConfigAnalyzer> {
|
||||
let name = item.ident.clone();
|
||||
let config_type = parse_asm(item);
|
||||
|
||||
match config_type {
|
||||
ConfigType::Struct(data) => Box::new(self.create_struct_analyzer(name, data)),
|
||||
ConfigType::Enum(data) => Box::new(self.create_enum_analyzer(name, data)),
|
||||
}
|
||||
}
|
||||
|
||||
fn create_struct_analyzer(&self, name: Ident, fields: Vec<Field>) -> ConfigStructAnalyzer {
|
||||
let fields = fields.into_iter().map(FieldTypeAnalyzer::new);
|
||||
|
||||
let mut fields_required = Vec::new();
|
||||
let mut fields_option = Vec::new();
|
||||
let mut fields_default = Vec::new();
|
||||
|
||||
for field in fields {
|
||||
let attributes: Vec<AttributeItem> = field
|
||||
.attributes()
|
||||
.filter(|attr| attr.has_name("config"))
|
||||
.map(|attr| attr.item())
|
||||
.collect();
|
||||
|
||||
if !attributes.is_empty() {
|
||||
let item = attributes.first().unwrap().clone();
|
||||
fields_default.push((field.clone(), item));
|
||||
continue;
|
||||
}
|
||||
|
||||
if field.is_of_type(&["Option"]) {
|
||||
fields_option.push(field.clone());
|
||||
continue;
|
||||
}
|
||||
|
||||
fields_required.push(field.clone());
|
||||
}
|
||||
|
||||
ConfigStructAnalyzer::new(name, fields_required, fields_option, fields_default)
|
||||
}
|
||||
|
||||
fn create_enum_analyzer(&self, name: Ident, data: syn::DataEnum) -> ConfigEnumAnalyzer {
|
||||
ConfigEnumAnalyzer::new(name, data)
|
||||
}
|
||||
}
|
||||
|
||||
enum ConfigType {
|
||||
Struct(Vec<Field>),
|
||||
Enum(syn::DataEnum),
|
||||
}
|
||||
|
||||
fn parse_asm(ast: &syn::DeriveInput) -> ConfigType {
|
||||
match &ast.data {
|
||||
syn::Data::Struct(struct_data) => {
|
||||
ConfigType::Struct(struct_data.fields.clone().into_iter().collect())
|
||||
}
|
||||
syn::Data::Enum(enum_data) => ConfigType::Enum(enum_data.clone()),
|
||||
syn::Data::Union(_) => panic!("Only struct and enum can be derived"),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
use crate::shared::enum_variant::map_enum_variant;
|
||||
|
||||
use super::ConfigAnalyzer;
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
|
||||
pub struct ConfigEnumAnalyzer {
|
||||
name: Ident,
|
||||
data: syn::DataEnum,
|
||||
}
|
||||
|
||||
impl ConfigEnumAnalyzer {
|
||||
pub fn new(name: Ident, data: syn::DataEnum) -> Self {
|
||||
Self { name, data }
|
||||
}
|
||||
|
||||
fn serde_enum_ident(&self) -> Ident {
|
||||
Ident::new(&format!("{}Serde", self.name), self.name.span())
|
||||
}
|
||||
|
||||
fn gen_serde_enum(&self) -> TokenStream {
|
||||
let enum_name = self.serde_enum_ident();
|
||||
let data = &self.data.variants;
|
||||
|
||||
quote! {
|
||||
#[derive(burn::serde::Serialize, burn::serde::Deserialize)]
|
||||
#[serde(crate = "burn::serde")]
|
||||
enum #enum_name {
|
||||
#data
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_serialize_fn(&self) -> TokenStream {
|
||||
let enum_name = self.serde_enum_ident();
|
||||
let variants = self.data.variants.iter().map(|variant| {
|
||||
let variant_name = &variant.ident;
|
||||
let (inputs, outputs) = map_enum_variant(variant, |ident| quote! { #ident.clone() });
|
||||
|
||||
quote! { Self::#variant_name #inputs => #enum_name::#variant_name #outputs }
|
||||
});
|
||||
|
||||
let name = &self.name;
|
||||
|
||||
quote! {
|
||||
impl burn::serde::Serialize for #name {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: burn::serde::Serializer {
|
||||
let serde_state = match self {
|
||||
#(#variants),*
|
||||
};
|
||||
serde_state.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_deserialize_fn(&self) -> TokenStream {
|
||||
let enum_name = self.serde_enum_ident();
|
||||
let variants = self.data.variants.iter().map(|variant| {
|
||||
let variant_name = &variant.ident;
|
||||
let (inputs, outputs) = map_enum_variant(variant, |ident| quote! { #ident.clone() });
|
||||
|
||||
quote! { #enum_name::#variant_name #inputs => Self::#variant_name #outputs }
|
||||
});
|
||||
let name = &self.name;
|
||||
|
||||
quote! {
|
||||
impl<'de> burn::serde::Deserialize<'de> for #name {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: burn::serde::Deserializer<'de> {
|
||||
let serde_state = #enum_name::deserialize(deserializer)?;
|
||||
Ok(match serde_state {
|
||||
#(#variants),*
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ConfigAnalyzer for ConfigEnumAnalyzer {
|
||||
fn gen_serde_impl(&self) -> TokenStream {
|
||||
let struct_gen = self.gen_serde_enum();
|
||||
let serialize_gen = self.gen_serialize_fn();
|
||||
let deserialize_gen = self.gen_deserialize_fn();
|
||||
|
||||
quote! {
|
||||
#struct_gen
|
||||
#serialize_gen
|
||||
#deserialize_gen
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_clone_impl(&self) -> TokenStream {
|
||||
let variants = self.data.variants.iter().map(|variant| {
|
||||
let variant_name = &variant.ident;
|
||||
let (inputs, outputs) = map_enum_variant(variant, |ident| quote! { #ident.clone() });
|
||||
|
||||
quote! { Self::#variant_name #inputs => Self::#variant_name #outputs }
|
||||
});
|
||||
let name = &self.name;
|
||||
|
||||
quote! {
|
||||
impl Clone for #name {
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
#(#variants),*
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_display_impl(&self) -> TokenStream {
|
||||
let name = &self.name;
|
||||
|
||||
quote! {
|
||||
impl core::fmt::Display for #name {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
f.write_str(&burn::config::config_to_json(self))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_config_impl(&self) -> TokenStream {
|
||||
let name = &self.name;
|
||||
|
||||
quote! {
|
||||
impl burn::config::Config for #name {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,380 @@
|
||||
use super::ConfigAnalyzer;
|
||||
use crate::shared::{attribute::AttributeItem, field::FieldTypeAnalyzer};
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
|
||||
pub struct ConfigStructAnalyzer {
|
||||
name: Ident,
|
||||
fields_required: Vec<FieldTypeAnalyzer>,
|
||||
fields_option: Vec<FieldTypeAnalyzer>,
|
||||
fields_default: Vec<(FieldTypeAnalyzer, AttributeItem)>,
|
||||
}
|
||||
|
||||
impl ConfigStructAnalyzer {
|
||||
pub fn new(
|
||||
name: Ident,
|
||||
fields_required: Vec<FieldTypeAnalyzer>,
|
||||
fields_option: Vec<FieldTypeAnalyzer>,
|
||||
fields_default: Vec<(FieldTypeAnalyzer, AttributeItem)>,
|
||||
) -> Self {
|
||||
Self {
|
||||
name,
|
||||
fields_required,
|
||||
fields_option,
|
||||
fields_default,
|
||||
}
|
||||
}
|
||||
|
||||
fn wrap_impl_block(&self, tokens: TokenStream) -> TokenStream {
|
||||
let name = &self.name;
|
||||
|
||||
quote! {
|
||||
impl #name {
|
||||
#tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn names(&self) -> Vec<FieldTypeAnalyzer> {
|
||||
let mut names = Vec::new();
|
||||
|
||||
for field in self.fields_required.iter() {
|
||||
names.push(field.clone());
|
||||
}
|
||||
|
||||
for field in self.fields_option.iter() {
|
||||
names.push(field.clone());
|
||||
}
|
||||
|
||||
for (field, _) in self.fields_default.iter() {
|
||||
names.push(field.clone());
|
||||
}
|
||||
|
||||
names
|
||||
}
|
||||
|
||||
fn name_types(&self, names: &[FieldTypeAnalyzer]) -> Vec<TokenStream> {
|
||||
let mut name_types = Vec::new();
|
||||
|
||||
for field in names.iter() {
|
||||
let name = field.ident();
|
||||
let ty = &field.field.ty;
|
||||
|
||||
name_types.push(quote! {
|
||||
#name: #ty
|
||||
});
|
||||
}
|
||||
|
||||
name_types
|
||||
}
|
||||
|
||||
fn serde_struct_ident(&self) -> Ident {
|
||||
Ident::new(&format!("{}Serde", self.name), self.name.span())
|
||||
}
|
||||
|
||||
fn gen_serialize_fn(
|
||||
&self,
|
||||
struct_name: &Ident,
|
||||
struct_gen: &TokenStream,
|
||||
names: &[FieldTypeAnalyzer],
|
||||
) -> TokenStream {
|
||||
let name = &self.name;
|
||||
let names = names.iter().map(|name| {
|
||||
let name = name.ident();
|
||||
quote! { #name: self.#name.clone() }
|
||||
});
|
||||
|
||||
quote! {
|
||||
impl burn::serde::Serialize for #name {
|
||||
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: burn::serde::Serializer {
|
||||
#[derive(burn::serde::Serialize)]
|
||||
#[serde(crate = "burn::serde")]
|
||||
#struct_gen
|
||||
|
||||
let serde_state = #struct_name {
|
||||
#(#names),*
|
||||
};
|
||||
serde_state.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_deserialize_fn(
|
||||
&self,
|
||||
struct_name: &Ident,
|
||||
struct_gen: &TokenStream,
|
||||
names: &[FieldTypeAnalyzer],
|
||||
) -> TokenStream {
|
||||
let name = &self.name;
|
||||
let names = names.iter().map(|name| {
|
||||
let name = name.ident();
|
||||
quote! { #name: serde_state.#name }
|
||||
});
|
||||
|
||||
quote! {
|
||||
impl<'de> burn::serde::Deserialize<'de> for #name {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: burn::serde::Deserializer<'de> {
|
||||
#[derive(burn::serde::Deserialize)]
|
||||
#[serde(crate = "burn::serde")]
|
||||
#struct_gen
|
||||
|
||||
let serde_state = #struct_name::deserialize(deserializer)?;
|
||||
Ok(#name {
|
||||
#(#names),*
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_serde_struct(&self, names: &[TokenStream]) -> TokenStream {
|
||||
let struct_name = self.serde_struct_ident();
|
||||
|
||||
quote! {
|
||||
struct #struct_name {
|
||||
#(#names),*
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ConfigAnalyzer for ConfigStructAnalyzer {
|
||||
fn gen_new_fn(&self) -> TokenStream {
|
||||
let mut body = quote! {};
|
||||
let mut args = Vec::new();
|
||||
|
||||
let mut fn_docs = quote! {};
|
||||
let mut has_field_docs = false;
|
||||
let mut has_required_docs = false;
|
||||
let mut has_option_docs = false;
|
||||
let mut has_default_docs = false;
|
||||
let mut docs_header = |fn_docs: &mut TokenStream,
|
||||
required_docs: bool,
|
||||
option_docs: bool,
|
||||
default_docs: bool| {
|
||||
if !has_field_docs {
|
||||
has_field_docs = true;
|
||||
fn_docs.extend(quote! {
|
||||
#[doc = "# Arguments"]
|
||||
});
|
||||
}
|
||||
if !has_required_docs && required_docs {
|
||||
fn_docs.extend(quote! {
|
||||
#[doc = "###### Required Arguments"]
|
||||
});
|
||||
has_required_docs = true;
|
||||
}
|
||||
if !has_option_docs && option_docs {
|
||||
fn_docs.extend(quote! {
|
||||
#[doc = "###### Optional Arguments"]
|
||||
});
|
||||
has_option_docs = true;
|
||||
}
|
||||
if !has_default_docs && default_docs {
|
||||
fn_docs.extend(quote! {
|
||||
#[doc = "###### Default Arguments"]
|
||||
});
|
||||
has_default_docs = true;
|
||||
}
|
||||
};
|
||||
|
||||
for field in self.fields_required.iter() {
|
||||
let name = field.ident();
|
||||
let ty = &field.field.ty;
|
||||
let docs = field.docs();
|
||||
|
||||
body.extend(quote! {
|
||||
#name: #name,
|
||||
});
|
||||
args.push(quote! {
|
||||
#name: #ty
|
||||
});
|
||||
docs_header(&mut fn_docs, true, false, false);
|
||||
let doc_str = format!("###### `{}`\n\n", quote!(#name));
|
||||
fn_docs.extend(quote! {
|
||||
#[doc = #doc_str]
|
||||
#(#docs)*
|
||||
});
|
||||
}
|
||||
|
||||
for field in self.fields_option.iter() {
|
||||
let name = field.ident();
|
||||
let docs = field.docs();
|
||||
|
||||
body.extend(quote! {
|
||||
#name: None,
|
||||
});
|
||||
docs_header(&mut fn_docs, false, true, false);
|
||||
let default_doc = "- Defaults to `None`";
|
||||
let doc_str = format!("###### `{}`\n", quote!(#name));
|
||||
fn_docs.extend(quote! {
|
||||
#[doc = #doc_str]
|
||||
#(#docs)*
|
||||
#[doc = #default_doc]
|
||||
});
|
||||
}
|
||||
|
||||
for (field, attribute) in self.fields_default.iter() {
|
||||
let name = field.ident();
|
||||
let value = &attribute.value;
|
||||
let docs = field.docs();
|
||||
|
||||
match value {
|
||||
syn::Lit::Str(value) => {
|
||||
let stream: proc_macro2::TokenStream = value.value().parse().unwrap();
|
||||
|
||||
body.extend(quote! {
|
||||
#name: #stream,
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
body.extend(quote! {
|
||||
#name: #value,
|
||||
});
|
||||
}
|
||||
};
|
||||
docs_header(&mut fn_docs, false, false, true);
|
||||
let default_doc = format!("- Defaults to `{}`", quote!(#value));
|
||||
let doc_str = format!("###### `{}`\n", quote!(#name));
|
||||
fn_docs.extend(quote! {
|
||||
#[doc = #doc_str]
|
||||
#(#docs)*
|
||||
#[doc = #default_doc]
|
||||
});
|
||||
}
|
||||
|
||||
let body = quote! {
|
||||
#[doc = "Create a new instance of the config."]
|
||||
#fn_docs
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
#(#args),*
|
||||
) -> Self {
|
||||
Self { #body }
|
||||
}
|
||||
};
|
||||
self.wrap_impl_block(body)
|
||||
}
|
||||
|
||||
fn gen_builder_fns(&self) -> TokenStream {
|
||||
let mut body = quote! {};
|
||||
|
||||
for (field, attribute) in self.fields_default.iter() {
|
||||
let name = field.ident();
|
||||
let ty = &field.field.ty;
|
||||
let value = &attribute.value;
|
||||
let docs = field.docs();
|
||||
let default_doc = format!("- Defaults to `{}`", quote!(#value));
|
||||
let doc_str = format!(
|
||||
"Sets the value for the field [`{}`](Self::{0}).\n\n",
|
||||
quote!(#name)
|
||||
);
|
||||
let fn_docs = quote! {
|
||||
#[doc = #doc_str]
|
||||
#(#docs)*
|
||||
#[doc = #default_doc]
|
||||
};
|
||||
let fn_name = Ident::new(&format!("with_{name}"), name.span());
|
||||
|
||||
body.extend(quote! {
|
||||
#fn_docs
|
||||
pub fn #fn_name(mut self, #name: #ty) -> Self {
|
||||
self.#name = #name;
|
||||
self
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for field in self.fields_option.iter() {
|
||||
let name = field.ident();
|
||||
let ty = &field.field.ty;
|
||||
let docs = field.docs();
|
||||
let default_doc = "- Defaults to `None`";
|
||||
let doc_str = format!(
|
||||
"Sets the value for the field [`{}`](Self::{0}).\n\n",
|
||||
quote!(#name)
|
||||
);
|
||||
let fn_docs = quote! {
|
||||
#[doc = #doc_str]
|
||||
#(#docs)*
|
||||
#[doc = #default_doc]
|
||||
};
|
||||
let fn_name = Ident::new(&format!("with_{name}"), name.span());
|
||||
|
||||
body.extend(quote! {
|
||||
#fn_docs
|
||||
pub fn #fn_name(mut self, #name: #ty) -> Self {
|
||||
self.#name = #name;
|
||||
self
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
self.wrap_impl_block(body)
|
||||
}
|
||||
|
||||
fn gen_serde_impl(&self) -> TokenStream {
|
||||
let names = self.names();
|
||||
|
||||
let struct_name = self.serde_struct_ident();
|
||||
let name_types = self.name_types(&names);
|
||||
let struct_gen = self.gen_serde_struct(&name_types);
|
||||
|
||||
let serialize_gen = self.gen_serialize_fn(&struct_name, &struct_gen, &names);
|
||||
let deserialize_gen = self.gen_deserialize_fn(&struct_name, &struct_gen, &names);
|
||||
|
||||
quote! {
|
||||
#serialize_gen
|
||||
#deserialize_gen
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_clone_impl(&self) -> TokenStream {
|
||||
let name = &self.name;
|
||||
let names = self.names().into_iter().map(|name| {
|
||||
let name = name.ident();
|
||||
quote! { #name: self.#name.clone() }
|
||||
});
|
||||
|
||||
quote! {
|
||||
impl Clone for #name {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
#(#names),*
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_display_impl(&self) -> TokenStream {
|
||||
let name = &self.name;
|
||||
|
||||
quote! {
|
||||
impl core::fmt::Display for #name {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
f.write_str(&burn::config::config_to_json(self))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_config_impl(&self) -> TokenStream {
|
||||
let name = &self.name;
|
||||
|
||||
quote! {
|
||||
impl burn::config::Config for #name {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
use super::ConfigAnalyzerFactory;
|
||||
use quote::quote;
|
||||
|
||||
pub(crate) fn derive_impl(item: &syn::DeriveInput) -> proc_macro::TokenStream {
|
||||
let factory = ConfigAnalyzerFactory::new();
|
||||
let analyzer = factory.create_analyzer(item);
|
||||
|
||||
let constructor = analyzer.gen_new_fn();
|
||||
let builders = analyzer.gen_builder_fns();
|
||||
let serde = analyzer.gen_serde_impl();
|
||||
let clone = analyzer.gen_clone_impl();
|
||||
let display = analyzer.gen_display_impl();
|
||||
let config_impl = analyzer.gen_config_impl();
|
||||
|
||||
quote! {
|
||||
#config_impl
|
||||
#constructor
|
||||
#builders
|
||||
#serde
|
||||
#clone
|
||||
#display
|
||||
}
|
||||
.into()
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
mod analyzer;
|
||||
mod analyzer_enum;
|
||||
mod analyzer_struct;
|
||||
mod base;
|
||||
|
||||
pub(crate) use analyzer::*;
|
||||
pub(crate) use analyzer_enum::*;
|
||||
pub(crate) use analyzer_struct::*;
|
||||
pub(crate) use base::*;
|
||||
@@ -0,0 +1,34 @@
|
||||
#![warn(missing_docs)]
|
||||
|
||||
//! The derive crate of Burn.
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
use proc_macro::TokenStream;
|
||||
|
||||
pub(crate) mod config;
|
||||
pub(crate) mod module;
|
||||
pub(crate) mod record;
|
||||
pub(crate) mod shared;
|
||||
|
||||
/// Derive macro for the module.
|
||||
#[proc_macro_derive(Module, attributes(module))]
|
||||
pub fn module_derive(input: TokenStream) -> TokenStream {
|
||||
let input = syn::parse(input).unwrap();
|
||||
module::derive_impl(&input)
|
||||
}
|
||||
|
||||
/// Derive macro for the record.
|
||||
#[proc_macro_derive(Record)]
|
||||
pub fn record_derive(input: TokenStream) -> TokenStream {
|
||||
let input = syn::parse(input).unwrap();
|
||||
record::derive_impl(&input)
|
||||
}
|
||||
|
||||
/// Derive macro for the config.
|
||||
#[proc_macro_derive(Config, attributes(config))]
|
||||
pub fn config_derive(input: TokenStream) -> TokenStream {
|
||||
let item = syn::parse(input).unwrap();
|
||||
config::derive_impl(&item)
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
use super::{
|
||||
codegen::{generate_module_const, generate_module_standard},
|
||||
codegen_enum::EnumModuleCodegen,
|
||||
codegen_struct::StructModuleCodegen,
|
||||
};
|
||||
use proc_macro::TokenStream;
|
||||
|
||||
pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream {
|
||||
let has_backend = ast
|
||||
.generics
|
||||
.type_params()
|
||||
.map(|param| param.ident == "B")
|
||||
.reduce(|accum, is_backend| is_backend || accum)
|
||||
.unwrap_or(false);
|
||||
|
||||
match &ast.data {
|
||||
syn::Data::Struct(_) => {
|
||||
if has_backend {
|
||||
generate_module_standard(ast, StructModuleCodegen::from_ast(ast))
|
||||
} else {
|
||||
generate_module_const(ast)
|
||||
}
|
||||
}
|
||||
syn::Data::Enum(_data) => match EnumModuleCodegen::from_ast(ast) {
|
||||
Ok(enum_codegen) => {
|
||||
if has_backend {
|
||||
generate_module_standard(ast, enum_codegen)
|
||||
} else {
|
||||
generate_module_const(ast)
|
||||
}
|
||||
}
|
||||
Err(err) => err.to_compile_error(),
|
||||
},
|
||||
syn::Data::Union(_) => {
|
||||
syn::Error::new_spanned(ast, "Union modules aren't supported").to_compile_error()
|
||||
}
|
||||
}
|
||||
.into()
|
||||
}
|
||||
@@ -0,0 +1,319 @@
|
||||
use super::{display, record::ModuleRecordCodegen};
|
||||
use crate::shared::generics::GenericsHelper;
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::{Attribute, Generics, parse_quote};
|
||||
|
||||
/// Basic trait to be implemented for Module generation.
|
||||
pub(crate) trait ModuleCodegen {
|
||||
type RecordCodegen: ModuleRecordCodegen;
|
||||
|
||||
fn gen_num_params(&self) -> TokenStream;
|
||||
fn gen_visit(&self) -> TokenStream;
|
||||
fn gen_collect_devices(&self) -> TokenStream;
|
||||
fn gen_to_device(&self) -> TokenStream;
|
||||
fn gen_fork(&self) -> TokenStream;
|
||||
fn gen_map(&self) -> TokenStream;
|
||||
fn gen_valid(&self) -> TokenStream;
|
||||
fn gen_from_inner(&self) -> TokenStream;
|
||||
fn gen_into_record(&self) -> TokenStream;
|
||||
fn gen_load_record(&self) -> TokenStream;
|
||||
fn gen_clone(&self) -> TokenStream;
|
||||
|
||||
fn record_codegen(self) -> Self::RecordCodegen;
|
||||
}
|
||||
|
||||
pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
|
||||
ast: &syn::DeriveInput,
|
||||
codegen: Codegen,
|
||||
) -> TokenStream {
|
||||
let name = &ast.ident;
|
||||
|
||||
let generics = GenericsParser::from_ast(&ast.generics);
|
||||
|
||||
let display_fn = display::display_fn(ast);
|
||||
let attributes_fn = display::attributes_fn(ast);
|
||||
let num_params_fn = codegen.gen_num_params();
|
||||
let visit = codegen.gen_visit();
|
||||
let map_mut = codegen.gen_map();
|
||||
let collect_devices = codegen.gen_collect_devices();
|
||||
let to_device = codegen.gen_to_device();
|
||||
let fork = codegen.gen_fork();
|
||||
let valid_fn = codegen.gen_valid();
|
||||
let from_inner_fn = codegen.gen_from_inner();
|
||||
let into_record_fn = codegen.gen_into_record();
|
||||
let load_record_fn = codegen.gen_load_record();
|
||||
let clone_fn = codegen.gen_clone();
|
||||
|
||||
let record = codegen.record_codegen();
|
||||
let record_name = Ident::new(format!("{name}Record").as_str(), name.span());
|
||||
let record_type = record.gen_record_type(&record_name, &generics.module);
|
||||
|
||||
let (generics_module, generics_ty_module, generics_where_module) =
|
||||
generics.module.split_for_impl();
|
||||
let (generics_module_autodiff, generics_ty_module_autodiff, generics_where_module_autodiff) =
|
||||
generics.module_autodiff.split_for_impl();
|
||||
let (generics_module_has_autodiff, _generics_ty, generics_where_module_has_autodiff) =
|
||||
generics.module_has_autodiff.split_for_impl();
|
||||
|
||||
let generics_ty_inner_module = generics.inner_module_ty;
|
||||
let generics_ty_train_module = generics.train_module_ty;
|
||||
let generics_ty_train_inner_module = generics.train_inner_ty;
|
||||
|
||||
let mut codegen = quote! {
|
||||
impl #generics_module burn::module::Module<B> for #name #generics_ty_module #generics_where_module {
|
||||
type Record = #record_name #generics_ty_module;
|
||||
|
||||
#load_record_fn
|
||||
#into_record_fn
|
||||
|
||||
#num_params_fn
|
||||
|
||||
#visit
|
||||
#map_mut
|
||||
|
||||
#collect_devices
|
||||
#to_device
|
||||
#fork
|
||||
|
||||
}
|
||||
|
||||
impl #generics_module_autodiff burn::module::AutodiffModule<B> for #name #generics_ty_module_autodiff #generics_where_module_autodiff
|
||||
{
|
||||
type InnerModule=#name<B::InnerBackend, #generics_ty_inner_module>;
|
||||
|
||||
#valid_fn
|
||||
|
||||
#from_inner_fn
|
||||
}
|
||||
|
||||
impl #generics_module_has_autodiff burn::module::HasAutodiffModule<B> for #name<B::InnerBackend, #generics_ty_train_module> #generics_where_module_has_autodiff
|
||||
{
|
||||
type TrainModule=#name<B, #generics_ty_train_inner_module>;
|
||||
}
|
||||
|
||||
impl #generics_module core::fmt::Display for #name #generics_ty_module #generics_where_module {
|
||||
#display_fn
|
||||
}
|
||||
|
||||
|
||||
impl #generics_module burn::module::ModuleDisplayDefault for #name #generics_ty_module #generics_where_module {
|
||||
#attributes_fn
|
||||
|
||||
fn num_params(&self) -> usize {
|
||||
burn::module::Module::num_params(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl #generics_module Clone for #name #generics_ty_module #generics_where_module {
|
||||
#clone_fn
|
||||
}
|
||||
|
||||
#record_type
|
||||
};
|
||||
|
||||
if !has_custom_display(&ast.attrs) {
|
||||
codegen.extend(quote! {
|
||||
impl #generics_module burn::module::ModuleDisplay for #name #generics_ty_module #generics_where_module {
|
||||
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
codegen
|
||||
}
|
||||
|
||||
// When there is no backend in the generic parameter, the type is considered as a constant.
|
||||
pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream {
|
||||
let name = &ast.ident;
|
||||
let (generics, generics_ty, generics_where) = ast.generics.split_for_impl();
|
||||
|
||||
let backend: syn::Generics = parse_quote! { <B: burn::tensor::backend::Backend >};
|
||||
let backend_ad: syn::Generics = parse_quote! { <B: burn::tensor::backend::AutodiffBackend >};
|
||||
|
||||
let mut generics_module = ast.generics.clone();
|
||||
let mut generics_module_autodiff = ast.generics.clone();
|
||||
|
||||
for param in backend.params.into_iter() {
|
||||
generics_module.params.push(param);
|
||||
}
|
||||
for param in backend_ad.params.into_iter() {
|
||||
generics_module_autodiff.params.push(param);
|
||||
}
|
||||
let (generics_module, _, _) = generics_module.split_for_impl();
|
||||
let (generics_module_ad, _, _) = generics_module_autodiff.split_for_impl();
|
||||
|
||||
let display_fn = display::display_fn(ast);
|
||||
let attributes_fn = display::attributes_fn(ast);
|
||||
|
||||
let mut codegen = quote! {
|
||||
impl #generics_module burn::module::Module<B> for #name #generics_ty #generics_where {
|
||||
burn::constant!(module);
|
||||
}
|
||||
|
||||
impl #generics_module_ad burn::module::AutodiffModule<B>
|
||||
for #name #generics_ty #generics_where {
|
||||
burn::constant!(ad_module, #name #generics_ty);
|
||||
}
|
||||
|
||||
impl #generics core::fmt::Display for #name #generics_ty #generics_where {
|
||||
#display_fn
|
||||
}
|
||||
|
||||
|
||||
impl #generics burn::module::ModuleDisplayDefault for #name #generics_ty #generics_where {
|
||||
#attributes_fn
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
if !has_custom_display(&ast.attrs) {
|
||||
codegen.extend(quote! {
|
||||
impl #generics burn::module::ModuleDisplay for #name #generics_ty #generics_where {
|
||||
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
codegen
|
||||
}
|
||||
|
||||
struct GenericsParser {
|
||||
module: Generics,
|
||||
module_autodiff: Generics,
|
||||
module_has_autodiff: Generics,
|
||||
inner_module_ty: TokenStream,
|
||||
train_module_ty: TokenStream,
|
||||
train_inner_ty: TokenStream,
|
||||
}
|
||||
|
||||
impl GenericsParser {
|
||||
fn from_ast(generics: &Generics) -> Self {
|
||||
let mut module = GenericsHelper::new(generics.clone());
|
||||
let mut module_autodiff = GenericsHelper::new(generics.clone());
|
||||
let mut module_has_autodiff = GenericsHelper::new(generics.clone());
|
||||
|
||||
let backend_trait = module.fetch_backend_trait();
|
||||
|
||||
module_autodiff.add_predicate(parse_quote! {
|
||||
B: burn::tensor::backend::AutodiffBackend
|
||||
});
|
||||
|
||||
module_autodiff.add_predicate(parse_quote! {
|
||||
<B as burn::tensor::backend::AutodiffBackend>::InnerBackend: #backend_trait
|
||||
});
|
||||
|
||||
module_has_autodiff.add_predicate(parse_quote! {
|
||||
B: burn::tensor::backend::AutodiffBackend
|
||||
});
|
||||
|
||||
module_has_autodiff.add_predicate(parse_quote! {
|
||||
<B as burn::tensor::backend::AutodiffBackend>::InnerBackend: #backend_trait
|
||||
});
|
||||
|
||||
let mut generics_names_except_backend = quote! {};
|
||||
let mut train_generics_names_except_backend = quote! {};
|
||||
let mut train_inner_generics_names_except_backend = quote! {};
|
||||
|
||||
module
|
||||
.types()
|
||||
.into_iter()
|
||||
.filter(|ident| ident != "B")
|
||||
.for_each(|ident| {
|
||||
module.add_predicate(
|
||||
parse_quote! {
|
||||
#ident: burn::module::Module<B>
|
||||
}
|
||||
);
|
||||
|
||||
module.add_predicate(
|
||||
parse_quote! {
|
||||
#ident: burn::module::ModuleDisplay
|
||||
}
|
||||
);
|
||||
|
||||
module_autodiff.add_predicate(
|
||||
parse_quote! {
|
||||
#ident: burn::module::AutodiffModule<B>
|
||||
}
|
||||
);
|
||||
|
||||
module_autodiff.add_predicate(
|
||||
parse_quote! {
|
||||
<#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::Module<B::InnerBackend>
|
||||
}
|
||||
);
|
||||
|
||||
module_autodiff.add_predicate(
|
||||
parse_quote! {
|
||||
<#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::ModuleDisplay
|
||||
}
|
||||
);
|
||||
|
||||
generics_names_except_backend.extend(quote! { <#ident as burn::module::AutodiffModule<B>>::InnerModule, });
|
||||
|
||||
module_autodiff.add_predicate(
|
||||
parse_quote! {
|
||||
#ident: burn::module::ModuleDisplay
|
||||
}
|
||||
);
|
||||
|
||||
module_has_autodiff.add_predicate(
|
||||
parse_quote! {
|
||||
#ident: burn::module::Module<B::InnerBackend>
|
||||
}
|
||||
);
|
||||
|
||||
module_has_autodiff.add_predicate(
|
||||
parse_quote! {
|
||||
#ident: burn::module::ModuleDisplay
|
||||
}
|
||||
);
|
||||
|
||||
module_has_autodiff.add_predicate(
|
||||
parse_quote! {
|
||||
#ident: burn::module::HasAutodiffModule<B>
|
||||
}
|
||||
);
|
||||
|
||||
module_has_autodiff.add_predicate(
|
||||
parse_quote! {
|
||||
#ident::TrainModule: burn::module::ModuleDisplay
|
||||
}
|
||||
);
|
||||
train_generics_names_except_backend.extend(quote! { #ident, });
|
||||
train_inner_generics_names_except_backend.extend(quote! { #ident::TrainModule, });
|
||||
|
||||
});
|
||||
|
||||
module.consts().into_iter().for_each(|ident| {
|
||||
generics_names_except_backend.extend(quote! { #ident, });
|
||||
train_generics_names_except_backend.extend(quote! { #ident, });
|
||||
train_inner_generics_names_except_backend.extend(quote! { #ident, });
|
||||
});
|
||||
|
||||
Self {
|
||||
module: module.generics,
|
||||
module_autodiff: module_autodiff.generics,
|
||||
module_has_autodiff: module_has_autodiff.generics,
|
||||
inner_module_ty: generics_names_except_backend,
|
||||
train_module_ty: train_generics_names_except_backend,
|
||||
train_inner_ty: train_inner_generics_names_except_backend,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn has_custom_display(attrs: &[Attribute]) -> bool {
|
||||
attrs.iter().any(|attr| {
|
||||
attr.path().is_ident("module")
|
||||
&& attr
|
||||
.parse_nested_meta(|meta| {
|
||||
if meta.path.is_ident("custom_display") {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(meta.error("unsupported attribute"))
|
||||
}
|
||||
})
|
||||
.is_ok()
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,236 @@
|
||||
use super::{codegen::ModuleCodegen, record_enum::EnumModuleRecordCodegen};
|
||||
use crate::shared::enum_variant::{EnumVariant, parse_variants};
|
||||
use proc_macro2::{Ident, Span, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::Visibility;
|
||||
|
||||
pub(crate) struct EnumModuleCodegen {
|
||||
pub name: Ident,
|
||||
pub variants: Vec<EnumVariant>,
|
||||
pub vis: Visibility,
|
||||
}
|
||||
|
||||
impl ModuleCodegen for EnumModuleCodegen {
|
||||
type RecordCodegen = EnumModuleRecordCodegen;
|
||||
|
||||
fn gen_num_params(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|_| {
|
||||
quote! {
|
||||
burn::module::Module::<B>::num_params(module)
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn num_params(&self) -> usize {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_visit(&self) -> TokenStream {
|
||||
let enum_name = self.name.to_string();
|
||||
let container_type = format!("Enum:{}", enum_name);
|
||||
let match_body = self.gen_variants_match_fn(|variant_name| {
|
||||
let variant_str = variant_name.to_string();
|
||||
quote! {
|
||||
{
|
||||
visitor.enter_module(#variant_str, #container_type);
|
||||
burn::module::Module::visit(module, visitor);
|
||||
visitor.exit_module(#variant_str, #container_type);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn visit<Visitor: burn::module::ModuleVisitor<B>>(&self, visitor: &mut Visitor) {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_collect_devices(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|_| {
|
||||
quote! {
|
||||
burn::module::Module::<B>::collect_devices(module, devices)
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn collect_devices(
|
||||
&self,
|
||||
devices: burn::module::Devices<B>
|
||||
) -> burn::module::Devices<B> {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_to_device(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|variant| {
|
||||
quote! {
|
||||
Self::#variant(burn::module::Module::<B>::to_device(module, device))
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn to_device(self, device: &B::Device) -> Self {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_fork(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|variant| {
|
||||
quote! {
|
||||
Self::#variant(burn::module::Module::<B>::fork(module, device))
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn fork(self, device: &B::Device) -> Self {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_map(&self) -> TokenStream {
|
||||
let enum_name = self.name.to_string();
|
||||
let container_type = format!("Enum:{}", enum_name);
|
||||
let match_body = self.gen_variants_match_fn(|variant| {
|
||||
let variant_str = variant.to_string();
|
||||
quote! {
|
||||
{
|
||||
mapper.enter_module(#variant_str, #container_type);
|
||||
let result = burn::module::Module::<B>::map(module, mapper);
|
||||
mapper.exit_module(#variant_str, #container_type);
|
||||
Self::#variant(result)
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn map<Mapper: burn::module::ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_valid(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|variant| {
|
||||
quote! {
|
||||
Self::InnerModule::#variant(burn::module::AutodiffModule::<B>::valid(module))
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_from_inner(&self) -> TokenStream {
|
||||
let match_body =
|
||||
self.gen_variants_match_fn_param("module", "Self::InnerModule::", |variant| {
|
||||
quote! {
|
||||
Self::#variant(burn::module::AutodiffModule::<B>::from_inner(module))
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn from_inner(module: Self::InnerModule) -> Self {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_into_record(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|variant| {
|
||||
quote! {
|
||||
Self::Record::#variant(burn::module::Module::<B>::into_record(module))
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn into_record(self) -> Self::Record {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_load_record(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|variant| {
|
||||
quote! {
|
||||
{
|
||||
let Self::Record::#variant(r) = record else {panic!("Can't parse record from a different variant");};
|
||||
Self::#variant(burn::module::Module::<B>::load_record(module, r))
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn load_record(self, record: Self::Record) -> Self {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_clone(&self) -> TokenStream {
|
||||
let match_body = self.gen_variants_match_fn(|variant| {
|
||||
quote! {
|
||||
Self::#variant(module.clone())
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn clone(&self) -> Self {
|
||||
#match_body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn record_codegen(self) -> Self::RecordCodegen {
|
||||
EnumModuleRecordCodegen::new(self.variants, self.vis)
|
||||
}
|
||||
}
|
||||
|
||||
impl EnumModuleCodegen {
|
||||
pub fn from_ast(ast: &syn::DeriveInput) -> syn::Result<Self> {
|
||||
Ok(Self {
|
||||
name: ast.ident.clone(),
|
||||
variants: parse_variants(ast)?,
|
||||
vis: ast.vis.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate the enum variants' match arms with the provided function
|
||||
fn gen_variants_match_fn<F>(&self, func: F) -> TokenStream
|
||||
where
|
||||
F: Fn(Ident) -> TokenStream,
|
||||
{
|
||||
self.gen_variants_match_fn_param("self", "Self::", func)
|
||||
}
|
||||
|
||||
/// Generate a match expression over the given argument (e.g., `self`)
|
||||
/// and using the provided prefix for variants (e.g., `Self::` or `Self::InnerModule::`)
|
||||
fn gen_variants_match_fn_param<F>(&self, arg: &str, prefix: &str, func: F) -> TokenStream
|
||||
where
|
||||
F: Fn(Ident) -> TokenStream,
|
||||
{
|
||||
let match_arms = self.variants.iter().map(|variant| {
|
||||
let name = &variant.ident;
|
||||
let full_variant = syn::parse_str::<syn::Path>(&format!("{prefix}{name}")).unwrap();
|
||||
let arm_pattern = quote! { #full_variant(module) };
|
||||
let arm_code = func(name.clone());
|
||||
quote! { #arm_pattern => #arm_code, }
|
||||
});
|
||||
|
||||
let arg = Ident::new(arg, Span::call_site());
|
||||
|
||||
quote! {
|
||||
match #arg {
|
||||
#(#match_arms)*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
use super::{codegen::ModuleCodegen, record_struct::StructModuleRecordCodegen};
|
||||
use crate::shared::field::{FieldTypeAnalyzer, parse_fields};
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::Visibility;
|
||||
|
||||
pub(crate) struct StructModuleCodegen {
|
||||
pub name: Ident,
|
||||
pub fields: Vec<FieldTypeAnalyzer>,
|
||||
pub vis: Visibility,
|
||||
}
|
||||
|
||||
impl ModuleCodegen for StructModuleCodegen {
|
||||
type RecordCodegen = StructModuleRecordCodegen;
|
||||
|
||||
fn gen_num_params(&self) -> TokenStream {
|
||||
let body = self.gen_fields_fn(|name| {
|
||||
quote! {
|
||||
num_params += burn::module::Module::<B>::num_params(&self.#name);
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn num_params(&self) -> usize {
|
||||
let mut num_params = 0;
|
||||
#body
|
||||
num_params
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_visit(&self) -> TokenStream {
|
||||
let struct_name = self.name.to_string();
|
||||
let container_type = format!("Struct:{}", struct_name);
|
||||
let body = self.gen_fields_fn(|name| {
|
||||
let name_str = name.to_string();
|
||||
quote! {
|
||||
visitor.enter_module(#name_str, #container_type);
|
||||
burn::module::Module::visit(&self.#name, visitor);
|
||||
visitor.exit_module(#name_str, #container_type);
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn visit<Visitor: burn::module::ModuleVisitor<B>>(&self, visitor: &mut Visitor) {
|
||||
#body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_collect_devices(&self) -> TokenStream {
|
||||
let body = self.gen_fields_fn(|name| {
|
||||
quote! {
|
||||
let devices = burn::module::Module::<B>::collect_devices(&self.#name, devices);
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn collect_devices(
|
||||
&self,
|
||||
devices: burn::module::Devices<B>
|
||||
) -> burn::module::Devices<B> {
|
||||
#body
|
||||
|
||||
devices
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_to_device(&self) -> TokenStream {
|
||||
let (names, body) = self.gen_fields_fn_names(|name| {
|
||||
quote! {
|
||||
let #name = burn::module::Module::<B>::to_device(self.#name, device);
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn to_device(self, device: &B::Device) -> Self {
|
||||
#body
|
||||
|
||||
Self {
|
||||
#(#names),*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_fork(&self) -> TokenStream {
|
||||
let (names, body) = self.gen_fields_fn_names(|name| {
|
||||
quote! {
|
||||
let #name = burn::module::Module::<B>::fork(self.#name, device);
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn fork(self, device: &B::Device) -> Self {
|
||||
#body
|
||||
|
||||
Self {
|
||||
#(#names),*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_map(&self) -> TokenStream {
|
||||
let struct_name = self.name.to_string();
|
||||
let container_type = format!("Struct:{}", struct_name);
|
||||
let (names, body) = self.gen_fields_fn_names(|name| {
|
||||
let name_str = name.to_string();
|
||||
quote! {
|
||||
mapper.enter_module(#name_str, #container_type);
|
||||
let #name = burn::module::Module::<B>::map(self.#name, mapper);
|
||||
mapper.exit_module(#name_str, #container_type);
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn map<Mapper: burn::module::ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self {
|
||||
#body
|
||||
|
||||
Self {
|
||||
#(#names),*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_valid(&self) -> TokenStream {
|
||||
let (names, body) = self.gen_fields_fn_names(|name| {
|
||||
quote! {
|
||||
let #name = burn::module::AutodiffModule::<B>::valid(&self.#name);
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
#body
|
||||
|
||||
Self::InnerModule {
|
||||
#(#names),*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_from_inner(&self) -> TokenStream {
|
||||
let (names, body) = self.gen_fields_fn_names(|name| {
|
||||
quote! {
|
||||
let #name = burn::module::AutodiffModule::<B>::from_inner(#name);
|
||||
}
|
||||
});
|
||||
|
||||
// Destructure inner module to move all fields
|
||||
let destructure = quote! {
|
||||
let Self::InnerModule { #(#names),* } = module;
|
||||
};
|
||||
|
||||
quote! {
|
||||
fn from_inner(module: Self::InnerModule) -> Self {
|
||||
#destructure
|
||||
#body
|
||||
|
||||
Self {
|
||||
#(#names),*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_into_record(&self) -> TokenStream {
|
||||
let body = self.gen_fields_fn(|name| {
|
||||
quote! {
|
||||
#name: burn::module::Module::<B>::into_record(self.#name),
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn into_record(self) -> Self::Record {
|
||||
Self::Record {
|
||||
#body
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_load_record(&self) -> TokenStream {
|
||||
let body = self.gen_fields_fn(|name| {
|
||||
quote! {
|
||||
#name: burn::module::Module::<B>::load_record(self.#name, record.#name),
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn load_record(self, record: Self::Record) -> Self {
|
||||
Self {
|
||||
#body
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_clone(&self) -> TokenStream {
|
||||
let (names, body) = self.gen_fields_fn_names(|name| {
|
||||
quote! {
|
||||
let #name = self.#name.clone();
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn clone(&self) -> Self {
|
||||
#body
|
||||
|
||||
Self {
|
||||
#(#names),*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn record_codegen(self) -> Self::RecordCodegen {
|
||||
StructModuleRecordCodegen::new(self.fields, self.vis)
|
||||
}
|
||||
}
|
||||
|
||||
impl StructModuleCodegen {
|
||||
pub fn from_ast(ast: &syn::DeriveInput) -> Self {
|
||||
Self {
|
||||
name: ast.ident.clone(),
|
||||
fields: parse_fields(ast)
|
||||
.into_iter()
|
||||
.map(FieldTypeAnalyzer::new)
|
||||
.collect(),
|
||||
vis: ast.vis.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_fields_fn_names<F>(&self, func: F) -> (Vec<Ident>, TokenStream)
|
||||
where
|
||||
F: Fn(Ident) -> TokenStream,
|
||||
{
|
||||
let mut body = quote! {};
|
||||
let mut names = Vec::new();
|
||||
|
||||
for field in self.fields.iter() {
|
||||
let name = field.ident();
|
||||
|
||||
names.push(name.clone());
|
||||
body.extend(func(field.ident()));
|
||||
}
|
||||
|
||||
(names, body)
|
||||
}
|
||||
|
||||
fn gen_fields_fn<F>(&self, func: F) -> TokenStream
|
||||
where
|
||||
F: Fn(Ident) -> TokenStream,
|
||||
{
|
||||
let mut body = quote! {};
|
||||
|
||||
for field in self.fields.iter() {
|
||||
body.extend(func(field.ident()));
|
||||
}
|
||||
|
||||
body
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
use quote::quote;
|
||||
|
||||
pub fn attributes_fn(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
|
||||
match &ast.data {
|
||||
syn::Data::Struct(data_struct) => {
|
||||
let fields = match &data_struct.fields {
|
||||
syn::Fields::Named(named_fields) => named_fields.named.iter().collect::<Vec<_>>(),
|
||||
syn::Fields::Unit => Vec::new(),
|
||||
_ => panic!("attributes_fn only supports structs with named or unit fields"),
|
||||
};
|
||||
let field_prints = fields.iter().map(|field| {
|
||||
let field_name = &field.ident;
|
||||
quote! { .add(stringify!(#field_name), &self.#field_name) }
|
||||
});
|
||||
let struct_name = &ast.ident;
|
||||
quote! {
|
||||
fn content(&self, mut content: burn::module::Content) -> Option<burn::module::Content> {
|
||||
content
|
||||
.set_top_level_type(&stringify!(#struct_name))
|
||||
#(#field_prints)*
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
}
|
||||
syn::Data::Enum(data_enum) => {
|
||||
let variant_prints = data_enum.variants.iter().map(|variant| {
|
||||
let variant_name = &variant.ident;
|
||||
match &variant.fields {
|
||||
syn::Fields::Unit => {
|
||||
quote! {
|
||||
Self::#variant_name => {
|
||||
content.add_formatted(&stringify!(#variant_name).to_string())
|
||||
.optional()
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
syn::Fields::Named(named_fields) => {
|
||||
let field_prints = named_fields.named.iter().map(|field| {
|
||||
let field_name = &field.ident;
|
||||
quote! { .add(stringify!(#field_name), &self.#field_name) }
|
||||
});
|
||||
|
||||
let field_names = named_fields.named.iter().map(|field| {
|
||||
let field_name = &field.ident;
|
||||
quote! { #field_name }
|
||||
});
|
||||
|
||||
quote! {
|
||||
Self::#variant_name { #(#field_names),* } => {
|
||||
content.set_top_level_type(&stringify!(#variant_name))
|
||||
#(#field_prints)*
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
}
|
||||
syn::Fields::Unnamed(unnamed_fields) => {
|
||||
let field_names = (0..unnamed_fields.unnamed.len()).map(|i| {
|
||||
syn::Ident::new(&format!("_{i}"), proc_macro2::Span::call_site())
|
||||
});
|
||||
|
||||
let field_prints = field_names.clone().map(|field_name| {
|
||||
quote! { .add(stringify!(#field_name), #field_name) }
|
||||
});
|
||||
quote! {
|
||||
Self::#variant_name(#(#field_names),*) => {
|
||||
content.set_top_level_type(&stringify!(#variant_name))
|
||||
#(#field_prints)*
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
quote! {
|
||||
fn content(&self, mut content: burn::module::Content) -> Option<burn::module::Content> {
|
||||
match self {
|
||||
#(#variant_prints)*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => panic!("attributes_fn only supports structs and enums"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_fn(_ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
|
||||
quote! {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
let formatted = burn::module::ModuleDisplay::format(self, Default::default());
|
||||
write!(f, "{}", formatted)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
pub(crate) mod codegen;
|
||||
pub(crate) mod codegen_enum;
|
||||
pub(crate) mod codegen_struct;
|
||||
pub(crate) mod display;
|
||||
pub(crate) mod record;
|
||||
pub(crate) mod record_enum;
|
||||
pub(crate) mod record_struct;
|
||||
|
||||
mod base;
|
||||
|
||||
pub(crate) use base::*;
|
||||
@@ -0,0 +1,8 @@
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use syn::Generics;
|
||||
|
||||
/// Basic trait to generate a record type based on the Module struct.
|
||||
pub(crate) trait ModuleRecordCodegen {
|
||||
/// Generate the record type (i.e a struct)
|
||||
fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream;
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
use crate::shared::enum_variant::EnumVariant;
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::{Generics, Visibility};
|
||||
|
||||
use super::record::ModuleRecordCodegen;
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct EnumModuleRecordCodegen {
|
||||
variants: Vec<EnumVariant>,
|
||||
vis: Visibility,
|
||||
}
|
||||
|
||||
impl ModuleRecordCodegen for EnumModuleRecordCodegen {
|
||||
fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream {
|
||||
let mut variants = quote! {};
|
||||
let vis = &self.vis;
|
||||
|
||||
// Capture the Record enum variant types
|
||||
for variant in self.variants.iter() {
|
||||
let ty = &variant.ty;
|
||||
let name = &variant.ident;
|
||||
|
||||
variants.extend(quote! {
|
||||
/// The module record associative type.
|
||||
#name(<#ty as burn::module::Module<B>>::Record),
|
||||
});
|
||||
}
|
||||
|
||||
let (generics, _generics_ty, generics_where) = generics.split_for_impl();
|
||||
|
||||
quote! {
|
||||
|
||||
/// The record type for the module.
|
||||
#[derive(burn::record::Record)]
|
||||
#vis enum #record_name #generics #generics_where {
|
||||
#variants
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
use crate::shared::field::FieldTypeAnalyzer;
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::{Generics, Visibility};
|
||||
|
||||
use super::record::ModuleRecordCodegen;
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct StructModuleRecordCodegen {
|
||||
fields: Vec<FieldTypeAnalyzer>,
|
||||
vis: Visibility,
|
||||
}
|
||||
|
||||
impl ModuleRecordCodegen for StructModuleRecordCodegen {
|
||||
fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream {
|
||||
let mut fields = quote! {};
|
||||
let vis = &self.vis;
|
||||
|
||||
for field in self.fields.iter() {
|
||||
let ty = &field.field.ty;
|
||||
let name = &field.field.ident;
|
||||
|
||||
fields.extend(quote! {
|
||||
/// The module record associative type.
|
||||
#vis #name: <#ty as burn::module::Module<B>>::Record,
|
||||
});
|
||||
}
|
||||
|
||||
let (generics, _generics_ty, generics_where) = generics.split_for_impl();
|
||||
|
||||
quote! {
|
||||
|
||||
/// The record type for the module.
|
||||
#[derive(burn::record::Record)]
|
||||
#vis struct #record_name #generics #generics_where {
|
||||
#fields
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
use super::{
|
||||
codegen::generate_record,
|
||||
item::{codegen_enum::EnumRecordItemCodegen, codegen_struct::StructRecordItemCodegen},
|
||||
};
|
||||
|
||||
pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> proc_macro::TokenStream {
|
||||
match &ast.data {
|
||||
syn::Data::Struct(_) => generate_record::<StructRecordItemCodegen>(ast),
|
||||
syn::Data::Enum(_) => generate_record::<EnumRecordItemCodegen>(ast),
|
||||
syn::Data::Union(_) => panic!("Union modules aren't supported yet."),
|
||||
}
|
||||
.into()
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::{Generics, parse_quote};
|
||||
|
||||
use crate::record::item::codegen::RecordItemCodegen;
|
||||
|
||||
pub(crate) fn generate_record<G: RecordItemCodegen>(ast: &syn::DeriveInput) -> TokenStream {
|
||||
let record_gen: syn::Result<RecordCodegen<G>> = RecordCodegen::from_ast(ast);
|
||||
match record_gen {
|
||||
Ok(record_gen) => {
|
||||
let item_type = record_gen.gen_record_type();
|
||||
let record_impl = record_gen.gen_impl_record();
|
||||
|
||||
quote! {
|
||||
#item_type
|
||||
#record_impl
|
||||
}
|
||||
}
|
||||
Err(err) => err.to_compile_error(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct RecordCodegen<G: RecordItemCodegen> {
|
||||
/// Record type info.
|
||||
ty: RecordType,
|
||||
/// Record item code gen.
|
||||
codegen: G,
|
||||
}
|
||||
|
||||
impl<G: RecordItemCodegen> RecordCodegen<G> {
|
||||
/// Generate the record type with the correct generics.
|
||||
pub(crate) fn gen_record_type(&self) -> TokenStream {
|
||||
// Add precision settings type bound
|
||||
let param: syn::Generics = parse_quote! { <S: burn::record::PrecisionSettings >};
|
||||
let mut generics = self.ty.generics.clone();
|
||||
|
||||
for param in param.params.into_iter() {
|
||||
generics.params.push(param);
|
||||
}
|
||||
|
||||
// Generate the record item definition
|
||||
self.codegen
|
||||
.gen_item_type(&self.ty.item, &generics, self.ty.has_backend)
|
||||
}
|
||||
|
||||
/// Generate the implementation for the Record trait.
|
||||
pub(crate) fn gen_impl_record(&self) -> TokenStream {
|
||||
// Capture the record type's generics and bounds in where clauses
|
||||
let item_generics = self.record_item_generics();
|
||||
let (_, ty_generics_item, _) = item_generics.split_for_impl();
|
||||
let (impl_generics, ty_generics, where_clause) = self.ty.generics.split_for_impl();
|
||||
|
||||
let impl_generics = if let Some(impl_generic) = self.impl_generics() {
|
||||
impl_generic
|
||||
} else {
|
||||
quote! { #impl_generics }
|
||||
};
|
||||
|
||||
let name_item = &self.ty.item;
|
||||
let into_item_fn = self.codegen.gen_into_item(name_item);
|
||||
let from_item_fn = self.codegen.gen_from_item();
|
||||
|
||||
// Return the generated stream of token trees (i.e., code to be generated)
|
||||
let name = &self.ty.name;
|
||||
quote! {
|
||||
impl #impl_generics burn::record::Record<B> for #name #ty_generics #where_clause {
|
||||
type Item<S: burn::record::PrecisionSettings> = #name_item #ty_generics_item;
|
||||
|
||||
#into_item_fn
|
||||
#from_item_fn
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Add backend generic type to the implementation block.
|
||||
fn impl_generics(&self) -> Option<TokenStream> {
|
||||
if self.ty.has_backend {
|
||||
return None;
|
||||
}
|
||||
|
||||
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
|
||||
let mut generics = self.ty.generics.clone();
|
||||
generics.params.push(syn::GenericParam::Type(param));
|
||||
|
||||
let (impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
|
||||
|
||||
Some(quote! {#impl_generics})
|
||||
}
|
||||
|
||||
/// Get the generics attached to the record item type.
|
||||
fn record_item_generics(&self) -> Generics {
|
||||
let param: syn::Generics = parse_quote! { <S: burn::record::PrecisionSettings >};
|
||||
let mut generics = self.ty.generics.clone();
|
||||
for param in param.params.into_iter() {
|
||||
generics.params.push(param);
|
||||
}
|
||||
|
||||
if !self.ty.has_backend {
|
||||
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
|
||||
generics.params.push(syn::GenericParam::Type(param));
|
||||
}
|
||||
|
||||
generics
|
||||
}
|
||||
|
||||
pub(crate) fn from_ast(ast: &syn::DeriveInput) -> syn::Result<Self> {
|
||||
Ok(Self {
|
||||
ty: RecordType::from_ast(ast),
|
||||
codegen: G::from_ast(ast)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Information about a record type.
|
||||
struct RecordType {
|
||||
/// Record type name.
|
||||
name: Ident,
|
||||
/// Record item type name.
|
||||
item: Ident,
|
||||
/// Lifetimes and type parameters attached to the record type declaration.
|
||||
generics: Generics,
|
||||
/// Whether or not the record type should specify a backend generic.
|
||||
has_backend: bool,
|
||||
}
|
||||
|
||||
impl RecordType {
|
||||
fn from_ast(ast: &syn::DeriveInput) -> Self {
|
||||
let name = ast.ident.clone();
|
||||
let item = Ident::new(format!("{name}Item").as_str(), name.span());
|
||||
let has_backend = ast
|
||||
.generics
|
||||
.type_params()
|
||||
.map(|param| param.ident == "B")
|
||||
.reduce(|accum, is_backend| is_backend || accum)
|
||||
.unwrap_or(false);
|
||||
|
||||
Self {
|
||||
name,
|
||||
item,
|
||||
generics: ast.generics.clone(),
|
||||
has_backend,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use syn::Generics;
|
||||
|
||||
/// Basic trait to be implemented for record generation.
|
||||
pub(crate) trait RecordItemCodegen {
|
||||
/// Initialize the record item.
|
||||
fn from_ast(ast: &syn::DeriveInput) -> syn::Result<Self>
|
||||
where
|
||||
Self: Sized;
|
||||
/// Generate the record item type.
|
||||
fn gen_item_type(
|
||||
&self,
|
||||
item_name: &Ident,
|
||||
generics: &Generics,
|
||||
has_backend: bool,
|
||||
) -> TokenStream;
|
||||
/// Generate the into_item function.
|
||||
fn gen_into_item(&self, item_name: &Ident) -> TokenStream;
|
||||
/// Generate the from item function.
|
||||
fn gen_from_item(&self) -> TokenStream;
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
use crate::shared::enum_variant::{EnumVariant, parse_variants};
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::{Generics, Visibility, parse_quote};
|
||||
|
||||
use super::codegen::RecordItemCodegen;
|
||||
|
||||
pub(crate) struct EnumRecordItemCodegen {
|
||||
/// Enum variants.
|
||||
variants: Vec<EnumVariant>,
|
||||
vis: Visibility,
|
||||
}
|
||||
|
||||
impl RecordItemCodegen for EnumRecordItemCodegen {
|
||||
fn from_ast(ast: &syn::DeriveInput) -> syn::Result<Self> {
|
||||
Ok(Self {
|
||||
variants: parse_variants(ast)?,
|
||||
vis: ast.vis.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn gen_item_type(
|
||||
&self,
|
||||
item_name: &Ident,
|
||||
generics: &Generics,
|
||||
has_backend: bool,
|
||||
) -> TokenStream {
|
||||
let mut variants = quote! {};
|
||||
let mut serde_bounds = quote! {};
|
||||
let mut clone_bounds = vec![];
|
||||
let mut clone_match_arms = quote! {};
|
||||
let vis = &self.vis;
|
||||
|
||||
// Capture the Record enum variant types and names to transpose them in RecordItem
|
||||
for variant in self.variants.iter() {
|
||||
let ty = &variant.ty;
|
||||
let name = &variant.ident;
|
||||
|
||||
variants.extend(quote! {
|
||||
/// Variant to be serialized.
|
||||
#name(<#ty as burn::record::Record<B>>::Item<S>),
|
||||
});
|
||||
|
||||
// Item types must implement serialization/deserialization
|
||||
serde_bounds.extend(quote! {
|
||||
<#ty as burn::record::Record<B>>::Item<S>: burn::serde::Serialize + burn::serde::de::DeserializeOwned,
|
||||
});
|
||||
clone_bounds.push(parse_quote! {
|
||||
<#ty as burn::record::Record<B>>::Item<S>: Clone
|
||||
});
|
||||
|
||||
clone_match_arms.extend(quote! {
|
||||
Self::#name(inner) => Self::#name(inner.clone()),
|
||||
});
|
||||
}
|
||||
let serde_bound = serde_bounds.to_string();
|
||||
|
||||
// Capture the type's generics and bounds in where clauses
|
||||
let mut generics = generics.clone();
|
||||
if !has_backend {
|
||||
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
|
||||
generics.params.push(syn::GenericParam::Type(param));
|
||||
}
|
||||
let (generics, type_generics, generics_where) = generics.split_for_impl();
|
||||
|
||||
let clone_bounds = generics_where.cloned().map(|mut where_clause| {
|
||||
for predicate in clone_bounds {
|
||||
where_clause.predicates.push(predicate);
|
||||
}
|
||||
where_clause
|
||||
});
|
||||
|
||||
let clone_impl = quote! {
|
||||
impl #generics Clone for #item_name #type_generics #clone_bounds {
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
#clone_match_arms
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Return the generated stream of token trees (i.e., code to be generated)
|
||||
quote! {
|
||||
|
||||
/// The record item type for the module.
|
||||
#[derive(burn::serde::Serialize, burn::serde::Deserialize)]
|
||||
#[serde(crate = "burn::serde")]
|
||||
#[serde(bound = #serde_bound)]
|
||||
#vis enum #item_name #generics #generics_where {
|
||||
#variants
|
||||
}
|
||||
|
||||
#clone_impl
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_into_item(&self, _item_name: &Ident) -> TokenStream {
|
||||
let mut into_item_match_arms = quote! {};
|
||||
|
||||
for variant in self.variants.iter() {
|
||||
let name = &variant.ident;
|
||||
|
||||
into_item_match_arms.extend(quote! {
|
||||
Self::#name(record) => Self::Item::#name(burn::record::Record::<B>::into_item::<S>(record)),
|
||||
});
|
||||
}
|
||||
|
||||
quote! {
|
||||
fn into_item<S: burn::record::PrecisionSettings>(self) -> Self::Item<S> {
|
||||
match self {
|
||||
#into_item_match_arms
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_from_item(&self) -> TokenStream {
|
||||
let mut from_item_match_arms = quote! {};
|
||||
|
||||
for variant in self.variants.iter() {
|
||||
let name = &variant.ident;
|
||||
|
||||
from_item_match_arms.extend(quote! {
|
||||
Self::Item::#name(item) => Self::#name(burn::record::Record::<B>::from_item::<S>(item, device)),
|
||||
});
|
||||
}
|
||||
|
||||
quote! {
|
||||
fn from_item<S: burn::record::PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
match item {
|
||||
#from_item_match_arms
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
use crate::shared::field::{FieldTypeAnalyzer, parse_fields};
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::{Generics, Visibility, parse_quote};
|
||||
|
||||
use super::codegen::RecordItemCodegen;
|
||||
|
||||
pub(crate) struct StructRecordItemCodegen {
|
||||
fields: Vec<FieldTypeAnalyzer>,
|
||||
vis: Visibility,
|
||||
}
|
||||
|
||||
impl RecordItemCodegen for StructRecordItemCodegen {
|
||||
fn from_ast(ast: &syn::DeriveInput) -> syn::Result<Self> {
|
||||
Ok(Self {
|
||||
fields: parse_fields(ast)
|
||||
.into_iter()
|
||||
.map(FieldTypeAnalyzer::new)
|
||||
.collect(),
|
||||
vis: ast.vis.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn gen_item_type(
|
||||
&self,
|
||||
item_name: &Ident,
|
||||
generics: &Generics,
|
||||
has_backend: bool,
|
||||
) -> TokenStream {
|
||||
let mut fields = quote! {};
|
||||
let mut serde_bounds = quote! {};
|
||||
let mut clone_bounds = vec![];
|
||||
let mut clone_delegate = quote! {};
|
||||
let vis = &self.vis;
|
||||
|
||||
for field in self.fields.iter() {
|
||||
let ty = &field.field.ty;
|
||||
let name = &field.field.ident;
|
||||
|
||||
fields.extend(quote! {
|
||||
/// Field to be serialized.
|
||||
pub #name: <#ty as burn::record::Record<B>>::Item<S>,
|
||||
});
|
||||
|
||||
serde_bounds.extend(quote! {
|
||||
<#ty as burn::record::Record<B>>::Item<S>: burn::serde::Serialize + burn::serde::de::DeserializeOwned,
|
||||
});
|
||||
|
||||
clone_bounds.push(parse_quote! {
|
||||
<#ty as burn::record::Record<B>>::Item<S>: Clone
|
||||
});
|
||||
|
||||
clone_delegate.extend(quote! {
|
||||
#name: self.#name.clone(),
|
||||
});
|
||||
}
|
||||
let serde_bound = serde_bounds.to_string();
|
||||
|
||||
let mut generics = generics.clone();
|
||||
if !has_backend {
|
||||
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
|
||||
generics.params.push(syn::GenericParam::Type(param));
|
||||
}
|
||||
let (generics, type_generics, generics_where) = generics.split_for_impl();
|
||||
|
||||
let clone_bounds = generics_where.cloned().map(|mut where_clause| {
|
||||
for predicate in clone_bounds {
|
||||
where_clause.predicates.push(predicate);
|
||||
}
|
||||
where_clause
|
||||
});
|
||||
|
||||
let clone_impl = quote! {
|
||||
impl #generics Clone for #item_name #type_generics #clone_bounds {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
#clone_delegate
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
quote! {
|
||||
|
||||
/// The record item type for the module.
|
||||
#[derive(burn::serde::Serialize, burn::serde::Deserialize)]
|
||||
#[serde(crate = "burn::serde")]
|
||||
#[serde(bound = #serde_bound)]
|
||||
#vis struct #item_name #generics #generics_where {
|
||||
#fields
|
||||
}
|
||||
|
||||
#clone_impl
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_into_item(&self, item_name: &Ident) -> TokenStream {
|
||||
let mut body_into_item = quote! {};
|
||||
|
||||
for field in self.fields.iter() {
|
||||
let name = &field.field.ident;
|
||||
|
||||
body_into_item.extend(quote! {
|
||||
#name: burn::record::Record::<B>::into_item::<S>(self.#name),
|
||||
});
|
||||
}
|
||||
|
||||
quote! {
|
||||
fn into_item<S: burn::record::PrecisionSettings>(self) -> Self::Item<S> {
|
||||
#item_name {
|
||||
#body_into_item
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_from_item(&self) -> TokenStream {
|
||||
let mut body_from_item = quote! {};
|
||||
|
||||
for field in self.fields.iter() {
|
||||
let name = &field.field.ident;
|
||||
|
||||
body_from_item.extend(quote! {
|
||||
#name: burn::record::Record::<B>::from_item::<S>(item.#name, device),
|
||||
});
|
||||
}
|
||||
|
||||
quote! {
|
||||
fn from_item<S: burn::record::PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
Self {
|
||||
#body_from_item
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
pub(crate) mod codegen;
|
||||
pub(crate) mod codegen_enum;
|
||||
pub(crate) mod codegen_struct;
|
||||
@@ -0,0 +1,5 @@
|
||||
pub(crate) mod codegen;
|
||||
pub(crate) mod item;
|
||||
|
||||
mod base;
|
||||
pub(crate) use base::*;
|
||||
@@ -0,0 +1,49 @@
|
||||
use syn::{Attribute, Meta};
|
||||
|
||||
pub struct AttributeAnalyzer {
|
||||
attr: Attribute,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AttributeItem {
|
||||
pub value: syn::Lit,
|
||||
}
|
||||
|
||||
impl AttributeAnalyzer {
|
||||
pub fn new(attr: Attribute) -> Self {
|
||||
Self { attr }
|
||||
}
|
||||
|
||||
pub fn item(&self) -> AttributeItem {
|
||||
let value = match &self.attr.meta {
|
||||
Meta::List(val) => val.parse_args::<syn::MetaNameValue>().unwrap(),
|
||||
Meta::NameValue(meta) => meta.clone(),
|
||||
Meta::Path(_) => panic!("Path meta unsupported"),
|
||||
};
|
||||
|
||||
let lit = match value.value {
|
||||
syn::Expr::Lit(lit) => lit.lit,
|
||||
_ => panic!("Only literal is supported"),
|
||||
};
|
||||
|
||||
AttributeItem { value: lit }
|
||||
}
|
||||
|
||||
pub fn has_name(&self, name: &str) -> bool {
|
||||
Self::path_syn_name(self.attr.path()) == name
|
||||
}
|
||||
|
||||
fn path_syn_name(path: &syn::Path) -> String {
|
||||
let length = path.segments.len();
|
||||
let mut name = String::new();
|
||||
for (i, segment) in path.segments.iter().enumerate() {
|
||||
if i == length - 1 {
|
||||
name += segment.ident.to_string().as_str();
|
||||
} else {
|
||||
let tmp = segment.ident.to_string() + "::";
|
||||
name += tmp.as_str();
|
||||
}
|
||||
}
|
||||
name
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
use proc_macro2::{Ident, Span, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::{FieldsNamed, Variant};
|
||||
|
||||
/// Process a variant of an enum where the output is the result of the given mapper.
|
||||
pub(crate) fn map_enum_variant<Mapper>(
|
||||
variant: &Variant,
|
||||
mapper: Mapper,
|
||||
) -> (TokenStream, TokenStream)
|
||||
where
|
||||
Mapper: Fn(&Ident) -> TokenStream,
|
||||
{
|
||||
let gen_fields_unnamed = |num: usize| {
|
||||
let mut inputs = Vec::new();
|
||||
let mut outputs = Vec::new();
|
||||
|
||||
for i in 0..num {
|
||||
let arg_name = Ident::new(&format!("arg_{i}"), Span::call_site());
|
||||
let input = quote! { #arg_name };
|
||||
let output = mapper(&arg_name);
|
||||
|
||||
inputs.push(input);
|
||||
outputs.push(output);
|
||||
}
|
||||
|
||||
(quote! (( #(#inputs),* )), quote! (( #(#outputs),* )))
|
||||
};
|
||||
let gen_fields_named = |fields: &FieldsNamed| {
|
||||
let mut inputs = Vec::new();
|
||||
let mut outputs = Vec::new();
|
||||
|
||||
fields.named.iter().for_each(|field| {
|
||||
let ident = field.ident.as_ref().expect("Named field to have a name.");
|
||||
let input = quote! { #ident };
|
||||
let output = mapper(ident);
|
||||
|
||||
inputs.push(input);
|
||||
outputs.push(quote! {
|
||||
#ident: #output
|
||||
});
|
||||
});
|
||||
|
||||
(quote! {{ #(#inputs),* }}, quote! {{ #(#outputs),* }})
|
||||
};
|
||||
|
||||
match &variant.fields {
|
||||
syn::Fields::Named(fields) => gen_fields_named(fields),
|
||||
syn::Fields::Unnamed(_) => gen_fields_unnamed(variant.fields.len()),
|
||||
syn::Fields::Unit => (quote! {}, quote! {}),
|
||||
}
|
||||
}
|
||||
|
||||
/// An enum variant (simplified).
|
||||
pub(crate) struct EnumVariant {
|
||||
pub ident: syn::Ident,
|
||||
pub ty: syn::Type,
|
||||
}
|
||||
pub(crate) fn parse_variants(ast: &syn::DeriveInput) -> syn::Result<Vec<EnumVariant>> {
|
||||
let enum_data = match &ast.data {
|
||||
syn::Data::Enum(data) => data,
|
||||
_ => {
|
||||
return Err(syn::Error::new_spanned(
|
||||
ast,
|
||||
"Module can only be derived for enums.",
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let mut variants = Vec::new();
|
||||
|
||||
for variant in enum_data.variants.iter() {
|
||||
match &variant.fields {
|
||||
syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
|
||||
let field = &fields.unnamed[0];
|
||||
|
||||
variants.push(EnumVariant {
|
||||
ident: variant.ident.clone(),
|
||||
ty: field.ty.clone(),
|
||||
});
|
||||
}
|
||||
syn::Fields::Unnamed(_) => {
|
||||
return Err(syn::Error::new_spanned(
|
||||
variant,
|
||||
"Module derive only supports tuple enum variants with exactly one field.",
|
||||
));
|
||||
}
|
||||
syn::Fields::Named(_) => {
|
||||
return Err(syn::Error::new_spanned(
|
||||
variant,
|
||||
"Module derive does not support struct enum variants.",
|
||||
));
|
||||
}
|
||||
syn::Fields::Unit => {
|
||||
return Err(syn::Error::new_spanned(
|
||||
variant,
|
||||
"Module derive does not support unit enum variants.",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(variants)
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
use super::attribute::AttributeAnalyzer;
|
||||
use proc_macro2::Ident;
|
||||
use syn::{Field, Type, TypePath};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FieldTypeAnalyzer {
|
||||
pub field: Field,
|
||||
}
|
||||
|
||||
impl FieldTypeAnalyzer {
|
||||
pub fn new(field: Field) -> Self {
|
||||
FieldTypeAnalyzer { field }
|
||||
}
|
||||
|
||||
pub fn ident(&self) -> Ident {
|
||||
self.field.ident.clone().unwrap()
|
||||
}
|
||||
|
||||
pub fn is_of_type(&self, paths: &[&str]) -> bool {
|
||||
match &self.field.ty {
|
||||
syn::Type::Path(path) => {
|
||||
let name = Self::path_name(path);
|
||||
paths.contains(&name.as_str())
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn first_generic_field(&self) -> TypePath {
|
||||
let err = || panic!("Field {} as no generic", self.field.ident.clone().unwrap());
|
||||
match &self.field.ty {
|
||||
syn::Type::Path(path) => Self::path_generic_argument(path),
|
||||
_ => err(),
|
||||
}
|
||||
}
|
||||
pub fn path_generic_argument(path: &TypePath) -> TypePath {
|
||||
let segment = path.path.segments.last().unwrap();
|
||||
let err = || panic!("Path segment {} has no generic", segment.ident.clone(),);
|
||||
match &segment.arguments {
|
||||
syn::PathArguments::None => err(),
|
||||
syn::PathArguments::AngleBracketed(param) => {
|
||||
let first_param = param.args.first().unwrap();
|
||||
|
||||
if let syn::GenericArgument::Type(Type::Path(path)) = first_param {
|
||||
path.clone()
|
||||
} else {
|
||||
err()
|
||||
}
|
||||
}
|
||||
syn::PathArguments::Parenthesized(_) => err(),
|
||||
}
|
||||
}
|
||||
|
||||
fn path_name(path: &TypePath) -> String {
|
||||
let length = path.path.segments.len();
|
||||
let mut name = String::new();
|
||||
for (i, segment) in path.path.segments.iter().enumerate() {
|
||||
if i == length - 1 {
|
||||
name += segment.ident.to_string().as_str();
|
||||
} else {
|
||||
let tmp = segment.ident.to_string() + "::";
|
||||
name += tmp.as_str();
|
||||
}
|
||||
}
|
||||
name
|
||||
}
|
||||
|
||||
/// Returns the docs of the field.
|
||||
pub fn docs(&self) -> impl Iterator<Item = &syn::Attribute> {
|
||||
self.field
|
||||
.attrs
|
||||
.iter()
|
||||
.filter(|attr| attr.path().is_ident("doc"))
|
||||
}
|
||||
|
||||
pub fn attributes(&self) -> impl Iterator<Item = AttributeAnalyzer> {
|
||||
self.field
|
||||
.attrs
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(AttributeAnalyzer::new)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn parse_fields(ast: &syn::DeriveInput) -> Vec<Field> {
|
||||
let mut fields = Vec::new();
|
||||
|
||||
match &ast.data {
|
||||
syn::Data::Struct(struct_data) => {
|
||||
for field in struct_data.fields.iter() {
|
||||
fields.push(field.clone());
|
||||
}
|
||||
}
|
||||
syn::Data::Enum(_) => panic!("Only struct can be derived"),
|
||||
syn::Data::Union(_) => panic!("Only struct can be derived"),
|
||||
};
|
||||
fields
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
use proc_macro2::Ident;
|
||||
use quote::quote;
|
||||
use syn::{Generics, WhereClause, WherePredicate, parse_quote};
|
||||
|
||||
#[derive(new)]
|
||||
pub struct GenericsHelper {
|
||||
pub(crate) generics: Generics,
|
||||
}
|
||||
|
||||
impl GenericsHelper {
|
||||
pub fn add_predicate(&mut self, predicate: WherePredicate) {
|
||||
let where_clause: WhereClause = match &self.generics.where_clause {
|
||||
Some(val) => parse_quote! {
|
||||
#val
|
||||
#predicate,
|
||||
},
|
||||
None => parse_quote! {
|
||||
where
|
||||
#predicate,
|
||||
},
|
||||
};
|
||||
self.generics.where_clause = Some(where_clause);
|
||||
}
|
||||
|
||||
pub fn consts(&self) -> Vec<Ident> {
|
||||
self.generics
|
||||
.const_params()
|
||||
.map(|c| c.ident.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn types(&self) -> Vec<Ident> {
|
||||
self.generics
|
||||
.type_params()
|
||||
.map(|tp| tp.ident.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn fetch_backend_trait(&self) -> proc_macro2::TokenStream {
|
||||
static BACKEND_TRAIT_COMPILATION_ERROR_MSG: &str =
|
||||
"Modules should be generic over a backend.
|
||||
- The generic argument named `B` should have its first trait bound being a backend trait.
|
||||
- The default backend trait is `burn::tensor::backend::Backend`.
|
||||
- Any backend trait is supported.";
|
||||
|
||||
for param in self.generics.params.iter() {
|
||||
if let syn::GenericParam::Type(ty) = ¶m
|
||||
&& ty.ident == "B"
|
||||
{
|
||||
let bound = ty
|
||||
.bounds
|
||||
.first()
|
||||
.expect(BACKEND_TRAIT_COMPILATION_ERROR_MSG);
|
||||
|
||||
return quote! {
|
||||
#bound
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
panic!("{BACKEND_TRAIT_COMPILATION_ERROR_MSG}");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
pub(crate) mod attribute;
|
||||
pub(crate) mod enum_variant;
|
||||
pub(crate) mod field;
|
||||
pub(crate) mod generics;
|
||||
Reference in New Issue
Block a user