feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,23 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "Derive crate for the Burn framework"
edition.workspace = true
keywords = []
license.workspace = true
name = "burn-derive"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-derive"
version.workspace = true
[lints]
workspace = true
[lib]
proc-macro = true
[dependencies]
proc-macro2 = { workspace = true }
quote = { workspace = true }
syn = { workspace = true }
derive-new = { workspace = true }

View File

@@ -0,0 +1 @@
../../LICENSE-APACHE

View File

@@ -0,0 +1 @@
../../LICENSE-MIT

View File

@@ -0,0 +1,6 @@
# Burn Derive
This crate should only be used with [burn](https://github.com/tracel-ai/burn).
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-derive.svg)](https://crates.io/crates/burn-derive)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-derive/blob/master/README.md)

View File

@@ -0,0 +1,87 @@
use super::ConfigEnumAnalyzer;
use crate::config::ConfigStructAnalyzer;
use crate::shared::{attribute::AttributeItem, field::FieldTypeAnalyzer};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Field, Ident};
pub struct ConfigAnalyzerFactory {}
pub trait ConfigAnalyzer {
fn gen_new_fn(&self) -> TokenStream {
quote! {}
}
fn gen_builder_fns(&self) -> TokenStream {
quote! {}
}
fn gen_serde_impl(&self) -> TokenStream;
fn gen_clone_impl(&self) -> TokenStream;
fn gen_display_impl(&self) -> TokenStream;
fn gen_config_impl(&self) -> TokenStream;
}
impl ConfigAnalyzerFactory {
pub fn new() -> Self {
Self {}
}
pub fn create_analyzer(&self, item: &syn::DeriveInput) -> Box<dyn ConfigAnalyzer> {
let name = item.ident.clone();
let config_type = parse_asm(item);
match config_type {
ConfigType::Struct(data) => Box::new(self.create_struct_analyzer(name, data)),
ConfigType::Enum(data) => Box::new(self.create_enum_analyzer(name, data)),
}
}
fn create_struct_analyzer(&self, name: Ident, fields: Vec<Field>) -> ConfigStructAnalyzer {
let fields = fields.into_iter().map(FieldTypeAnalyzer::new);
let mut fields_required = Vec::new();
let mut fields_option = Vec::new();
let mut fields_default = Vec::new();
for field in fields {
let attributes: Vec<AttributeItem> = field
.attributes()
.filter(|attr| attr.has_name("config"))
.map(|attr| attr.item())
.collect();
if !attributes.is_empty() {
let item = attributes.first().unwrap().clone();
fields_default.push((field.clone(), item));
continue;
}
if field.is_of_type(&["Option"]) {
fields_option.push(field.clone());
continue;
}
fields_required.push(field.clone());
}
ConfigStructAnalyzer::new(name, fields_required, fields_option, fields_default)
}
fn create_enum_analyzer(&self, name: Ident, data: syn::DataEnum) -> ConfigEnumAnalyzer {
ConfigEnumAnalyzer::new(name, data)
}
}
enum ConfigType {
Struct(Vec<Field>),
Enum(syn::DataEnum),
}
fn parse_asm(ast: &syn::DeriveInput) -> ConfigType {
match &ast.data {
syn::Data::Struct(struct_data) => {
ConfigType::Struct(struct_data.fields.clone().into_iter().collect())
}
syn::Data::Enum(enum_data) => ConfigType::Enum(enum_data.clone()),
syn::Data::Union(_) => panic!("Only struct and enum can be derived"),
}
}

View File

@@ -0,0 +1,141 @@
use crate::shared::enum_variant::map_enum_variant;
use super::ConfigAnalyzer;
use proc_macro2::{Ident, TokenStream};
use quote::quote;
pub struct ConfigEnumAnalyzer {
name: Ident,
data: syn::DataEnum,
}
impl ConfigEnumAnalyzer {
pub fn new(name: Ident, data: syn::DataEnum) -> Self {
Self { name, data }
}
fn serde_enum_ident(&self) -> Ident {
Ident::new(&format!("{}Serde", self.name), self.name.span())
}
fn gen_serde_enum(&self) -> TokenStream {
let enum_name = self.serde_enum_ident();
let data = &self.data.variants;
quote! {
#[derive(burn::serde::Serialize, burn::serde::Deserialize)]
#[serde(crate = "burn::serde")]
enum #enum_name {
#data
}
}
}
fn gen_serialize_fn(&self) -> TokenStream {
let enum_name = self.serde_enum_ident();
let variants = self.data.variants.iter().map(|variant| {
let variant_name = &variant.ident;
let (inputs, outputs) = map_enum_variant(variant, |ident| quote! { #ident.clone() });
quote! { Self::#variant_name #inputs => #enum_name::#variant_name #outputs }
});
let name = &self.name;
quote! {
impl burn::serde::Serialize for #name {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: burn::serde::Serializer {
let serde_state = match self {
#(#variants),*
};
serde_state.serialize(serializer)
}
}
}
}
fn gen_deserialize_fn(&self) -> TokenStream {
let enum_name = self.serde_enum_ident();
let variants = self.data.variants.iter().map(|variant| {
let variant_name = &variant.ident;
let (inputs, outputs) = map_enum_variant(variant, |ident| quote! { #ident.clone() });
quote! { #enum_name::#variant_name #inputs => Self::#variant_name #outputs }
});
let name = &self.name;
quote! {
impl<'de> burn::serde::Deserialize<'de> for #name {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: burn::serde::Deserializer<'de> {
let serde_state = #enum_name::deserialize(deserializer)?;
Ok(match serde_state {
#(#variants),*
})
}
}
}
}
}
impl ConfigAnalyzer for ConfigEnumAnalyzer {
fn gen_serde_impl(&self) -> TokenStream {
let struct_gen = self.gen_serde_enum();
let serialize_gen = self.gen_serialize_fn();
let deserialize_gen = self.gen_deserialize_fn();
quote! {
#struct_gen
#serialize_gen
#deserialize_gen
}
}
fn gen_clone_impl(&self) -> TokenStream {
let variants = self.data.variants.iter().map(|variant| {
let variant_name = &variant.ident;
let (inputs, outputs) = map_enum_variant(variant, |ident| quote! { #ident.clone() });
quote! { Self::#variant_name #inputs => Self::#variant_name #outputs }
});
let name = &self.name;
quote! {
impl Clone for #name {
fn clone(&self) -> Self {
match self {
#(#variants),*
}
}
}
}
}
fn gen_display_impl(&self) -> TokenStream {
let name = &self.name;
quote! {
impl core::fmt::Display for #name {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(&burn::config::config_to_json(self))
}
}
}
}
fn gen_config_impl(&self) -> TokenStream {
let name = &self.name;
quote! {
impl burn::config::Config for #name {
}
}
}
}

View File

@@ -0,0 +1,380 @@
use super::ConfigAnalyzer;
use crate::shared::{attribute::AttributeItem, field::FieldTypeAnalyzer};
use proc_macro2::{Ident, TokenStream};
use quote::quote;
pub struct ConfigStructAnalyzer {
name: Ident,
fields_required: Vec<FieldTypeAnalyzer>,
fields_option: Vec<FieldTypeAnalyzer>,
fields_default: Vec<(FieldTypeAnalyzer, AttributeItem)>,
}
impl ConfigStructAnalyzer {
pub fn new(
name: Ident,
fields_required: Vec<FieldTypeAnalyzer>,
fields_option: Vec<FieldTypeAnalyzer>,
fields_default: Vec<(FieldTypeAnalyzer, AttributeItem)>,
) -> Self {
Self {
name,
fields_required,
fields_option,
fields_default,
}
}
fn wrap_impl_block(&self, tokens: TokenStream) -> TokenStream {
let name = &self.name;
quote! {
impl #name {
#tokens
}
}
}
fn names(&self) -> Vec<FieldTypeAnalyzer> {
let mut names = Vec::new();
for field in self.fields_required.iter() {
names.push(field.clone());
}
for field in self.fields_option.iter() {
names.push(field.clone());
}
for (field, _) in self.fields_default.iter() {
names.push(field.clone());
}
names
}
fn name_types(&self, names: &[FieldTypeAnalyzer]) -> Vec<TokenStream> {
let mut name_types = Vec::new();
for field in names.iter() {
let name = field.ident();
let ty = &field.field.ty;
name_types.push(quote! {
#name: #ty
});
}
name_types
}
fn serde_struct_ident(&self) -> Ident {
Ident::new(&format!("{}Serde", self.name), self.name.span())
}
fn gen_serialize_fn(
&self,
struct_name: &Ident,
struct_gen: &TokenStream,
names: &[FieldTypeAnalyzer],
) -> TokenStream {
let name = &self.name;
let names = names.iter().map(|name| {
let name = name.ident();
quote! { #name: self.#name.clone() }
});
quote! {
impl burn::serde::Serialize for #name {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: burn::serde::Serializer {
#[derive(burn::serde::Serialize)]
#[serde(crate = "burn::serde")]
#struct_gen
let serde_state = #struct_name {
#(#names),*
};
serde_state.serialize(serializer)
}
}
}
}
fn gen_deserialize_fn(
&self,
struct_name: &Ident,
struct_gen: &TokenStream,
names: &[FieldTypeAnalyzer],
) -> TokenStream {
let name = &self.name;
let names = names.iter().map(|name| {
let name = name.ident();
quote! { #name: serde_state.#name }
});
quote! {
impl<'de> burn::serde::Deserialize<'de> for #name {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: burn::serde::Deserializer<'de> {
#[derive(burn::serde::Deserialize)]
#[serde(crate = "burn::serde")]
#struct_gen
let serde_state = #struct_name::deserialize(deserializer)?;
Ok(#name {
#(#names),*
})
}
}
}
}
fn gen_serde_struct(&self, names: &[TokenStream]) -> TokenStream {
let struct_name = self.serde_struct_ident();
quote! {
struct #struct_name {
#(#names),*
}
}
}
}
impl ConfigAnalyzer for ConfigStructAnalyzer {
fn gen_new_fn(&self) -> TokenStream {
let mut body = quote! {};
let mut args = Vec::new();
let mut fn_docs = quote! {};
let mut has_field_docs = false;
let mut has_required_docs = false;
let mut has_option_docs = false;
let mut has_default_docs = false;
let mut docs_header = |fn_docs: &mut TokenStream,
required_docs: bool,
option_docs: bool,
default_docs: bool| {
if !has_field_docs {
has_field_docs = true;
fn_docs.extend(quote! {
#[doc = "# Arguments"]
});
}
if !has_required_docs && required_docs {
fn_docs.extend(quote! {
#[doc = "###### Required Arguments"]
});
has_required_docs = true;
}
if !has_option_docs && option_docs {
fn_docs.extend(quote! {
#[doc = "###### Optional Arguments"]
});
has_option_docs = true;
}
if !has_default_docs && default_docs {
fn_docs.extend(quote! {
#[doc = "###### Default Arguments"]
});
has_default_docs = true;
}
};
for field in self.fields_required.iter() {
let name = field.ident();
let ty = &field.field.ty;
let docs = field.docs();
body.extend(quote! {
#name: #name,
});
args.push(quote! {
#name: #ty
});
docs_header(&mut fn_docs, true, false, false);
let doc_str = format!("###### `{}`\n\n", quote!(#name));
fn_docs.extend(quote! {
#[doc = #doc_str]
#(#docs)*
});
}
for field in self.fields_option.iter() {
let name = field.ident();
let docs = field.docs();
body.extend(quote! {
#name: None,
});
docs_header(&mut fn_docs, false, true, false);
let default_doc = "- Defaults to `None`";
let doc_str = format!("###### `{}`\n", quote!(#name));
fn_docs.extend(quote! {
#[doc = #doc_str]
#(#docs)*
#[doc = #default_doc]
});
}
for (field, attribute) in self.fields_default.iter() {
let name = field.ident();
let value = &attribute.value;
let docs = field.docs();
match value {
syn::Lit::Str(value) => {
let stream: proc_macro2::TokenStream = value.value().parse().unwrap();
body.extend(quote! {
#name: #stream,
});
}
_ => {
body.extend(quote! {
#name: #value,
});
}
};
docs_header(&mut fn_docs, false, false, true);
let default_doc = format!("- Defaults to `{}`", quote!(#value));
let doc_str = format!("###### `{}`\n", quote!(#name));
fn_docs.extend(quote! {
#[doc = #doc_str]
#(#docs)*
#[doc = #default_doc]
});
}
let body = quote! {
#[doc = "Create a new instance of the config."]
#fn_docs
#[allow(clippy::too_many_arguments)]
pub fn new(
#(#args),*
) -> Self {
Self { #body }
}
};
self.wrap_impl_block(body)
}
fn gen_builder_fns(&self) -> TokenStream {
let mut body = quote! {};
for (field, attribute) in self.fields_default.iter() {
let name = field.ident();
let ty = &field.field.ty;
let value = &attribute.value;
let docs = field.docs();
let default_doc = format!("- Defaults to `{}`", quote!(#value));
let doc_str = format!(
"Sets the value for the field [`{}`](Self::{0}).\n\n",
quote!(#name)
);
let fn_docs = quote! {
#[doc = #doc_str]
#(#docs)*
#[doc = #default_doc]
};
let fn_name = Ident::new(&format!("with_{name}"), name.span());
body.extend(quote! {
#fn_docs
pub fn #fn_name(mut self, #name: #ty) -> Self {
self.#name = #name;
self
}
});
}
for field in self.fields_option.iter() {
let name = field.ident();
let ty = &field.field.ty;
let docs = field.docs();
let default_doc = "- Defaults to `None`";
let doc_str = format!(
"Sets the value for the field [`{}`](Self::{0}).\n\n",
quote!(#name)
);
let fn_docs = quote! {
#[doc = #doc_str]
#(#docs)*
#[doc = #default_doc]
};
let fn_name = Ident::new(&format!("with_{name}"), name.span());
body.extend(quote! {
#fn_docs
pub fn #fn_name(mut self, #name: #ty) -> Self {
self.#name = #name;
self
}
});
}
self.wrap_impl_block(body)
}
fn gen_serde_impl(&self) -> TokenStream {
let names = self.names();
let struct_name = self.serde_struct_ident();
let name_types = self.name_types(&names);
let struct_gen = self.gen_serde_struct(&name_types);
let serialize_gen = self.gen_serialize_fn(&struct_name, &struct_gen, &names);
let deserialize_gen = self.gen_deserialize_fn(&struct_name, &struct_gen, &names);
quote! {
#serialize_gen
#deserialize_gen
}
}
fn gen_clone_impl(&self) -> TokenStream {
let name = &self.name;
let names = self.names().into_iter().map(|name| {
let name = name.ident();
quote! { #name: self.#name.clone() }
});
quote! {
impl Clone for #name {
fn clone(&self) -> Self {
Self {
#(#names),*
}
}
}
}
}
fn gen_display_impl(&self) -> TokenStream {
let name = &self.name;
quote! {
impl core::fmt::Display for #name {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(&burn::config::config_to_json(self))
}
}
}
}
fn gen_config_impl(&self) -> TokenStream {
let name = &self.name;
quote! {
impl burn::config::Config for #name {
}
}
}
}

View File

@@ -0,0 +1,24 @@
use super::ConfigAnalyzerFactory;
use quote::quote;
pub(crate) fn derive_impl(item: &syn::DeriveInput) -> proc_macro::TokenStream {
let factory = ConfigAnalyzerFactory::new();
let analyzer = factory.create_analyzer(item);
let constructor = analyzer.gen_new_fn();
let builders = analyzer.gen_builder_fns();
let serde = analyzer.gen_serde_impl();
let clone = analyzer.gen_clone_impl();
let display = analyzer.gen_display_impl();
let config_impl = analyzer.gen_config_impl();
quote! {
#config_impl
#constructor
#builders
#serde
#clone
#display
}
.into()
}

View File

@@ -0,0 +1,9 @@
mod analyzer;
mod analyzer_enum;
mod analyzer_struct;
mod base;
pub(crate) use analyzer::*;
pub(crate) use analyzer_enum::*;
pub(crate) use analyzer_struct::*;
pub(crate) use base::*;

View File

@@ -0,0 +1,34 @@
#![warn(missing_docs)]
//! The derive crate of Burn.
#[macro_use]
extern crate derive_new;
use proc_macro::TokenStream;
pub(crate) mod config;
pub(crate) mod module;
pub(crate) mod record;
pub(crate) mod shared;
/// Derive macro for the module.
#[proc_macro_derive(Module, attributes(module))]
pub fn module_derive(input: TokenStream) -> TokenStream {
let input = syn::parse(input).unwrap();
module::derive_impl(&input)
}
/// Derive macro for the record.
#[proc_macro_derive(Record)]
pub fn record_derive(input: TokenStream) -> TokenStream {
let input = syn::parse(input).unwrap();
record::derive_impl(&input)
}
/// Derive macro for the config.
#[proc_macro_derive(Config, attributes(config))]
pub fn config_derive(input: TokenStream) -> TokenStream {
let item = syn::parse(input).unwrap();
config::derive_impl(&item)
}

View File

@@ -0,0 +1,39 @@
use super::{
codegen::{generate_module_const, generate_module_standard},
codegen_enum::EnumModuleCodegen,
codegen_struct::StructModuleCodegen,
};
use proc_macro::TokenStream;
pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream {
let has_backend = ast
.generics
.type_params()
.map(|param| param.ident == "B")
.reduce(|accum, is_backend| is_backend || accum)
.unwrap_or(false);
match &ast.data {
syn::Data::Struct(_) => {
if has_backend {
generate_module_standard(ast, StructModuleCodegen::from_ast(ast))
} else {
generate_module_const(ast)
}
}
syn::Data::Enum(_data) => match EnumModuleCodegen::from_ast(ast) {
Ok(enum_codegen) => {
if has_backend {
generate_module_standard(ast, enum_codegen)
} else {
generate_module_const(ast)
}
}
Err(err) => err.to_compile_error(),
},
syn::Data::Union(_) => {
syn::Error::new_spanned(ast, "Union modules aren't supported").to_compile_error()
}
}
.into()
}

View File

@@ -0,0 +1,319 @@
use super::{display, record::ModuleRecordCodegen};
use crate::shared::generics::GenericsHelper;
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{Attribute, Generics, parse_quote};
/// Basic trait to be implemented for Module generation.
pub(crate) trait ModuleCodegen {
type RecordCodegen: ModuleRecordCodegen;
fn gen_num_params(&self) -> TokenStream;
fn gen_visit(&self) -> TokenStream;
fn gen_collect_devices(&self) -> TokenStream;
fn gen_to_device(&self) -> TokenStream;
fn gen_fork(&self) -> TokenStream;
fn gen_map(&self) -> TokenStream;
fn gen_valid(&self) -> TokenStream;
fn gen_from_inner(&self) -> TokenStream;
fn gen_into_record(&self) -> TokenStream;
fn gen_load_record(&self) -> TokenStream;
fn gen_clone(&self) -> TokenStream;
fn record_codegen(self) -> Self::RecordCodegen;
}
pub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(
ast: &syn::DeriveInput,
codegen: Codegen,
) -> TokenStream {
let name = &ast.ident;
let generics = GenericsParser::from_ast(&ast.generics);
let display_fn = display::display_fn(ast);
let attributes_fn = display::attributes_fn(ast);
let num_params_fn = codegen.gen_num_params();
let visit = codegen.gen_visit();
let map_mut = codegen.gen_map();
let collect_devices = codegen.gen_collect_devices();
let to_device = codegen.gen_to_device();
let fork = codegen.gen_fork();
let valid_fn = codegen.gen_valid();
let from_inner_fn = codegen.gen_from_inner();
let into_record_fn = codegen.gen_into_record();
let load_record_fn = codegen.gen_load_record();
let clone_fn = codegen.gen_clone();
let record = codegen.record_codegen();
let record_name = Ident::new(format!("{name}Record").as_str(), name.span());
let record_type = record.gen_record_type(&record_name, &generics.module);
let (generics_module, generics_ty_module, generics_where_module) =
generics.module.split_for_impl();
let (generics_module_autodiff, generics_ty_module_autodiff, generics_where_module_autodiff) =
generics.module_autodiff.split_for_impl();
let (generics_module_has_autodiff, _generics_ty, generics_where_module_has_autodiff) =
generics.module_has_autodiff.split_for_impl();
let generics_ty_inner_module = generics.inner_module_ty;
let generics_ty_train_module = generics.train_module_ty;
let generics_ty_train_inner_module = generics.train_inner_ty;
let mut codegen = quote! {
impl #generics_module burn::module::Module<B> for #name #generics_ty_module #generics_where_module {
type Record = #record_name #generics_ty_module;
#load_record_fn
#into_record_fn
#num_params_fn
#visit
#map_mut
#collect_devices
#to_device
#fork
}
impl #generics_module_autodiff burn::module::AutodiffModule<B> for #name #generics_ty_module_autodiff #generics_where_module_autodiff
{
type InnerModule=#name<B::InnerBackend, #generics_ty_inner_module>;
#valid_fn
#from_inner_fn
}
impl #generics_module_has_autodiff burn::module::HasAutodiffModule<B> for #name<B::InnerBackend, #generics_ty_train_module> #generics_where_module_has_autodiff
{
type TrainModule=#name<B, #generics_ty_train_inner_module>;
}
impl #generics_module core::fmt::Display for #name #generics_ty_module #generics_where_module {
#display_fn
}
impl #generics_module burn::module::ModuleDisplayDefault for #name #generics_ty_module #generics_where_module {
#attributes_fn
fn num_params(&self) -> usize {
burn::module::Module::num_params(self)
}
}
impl #generics_module Clone for #name #generics_ty_module #generics_where_module {
#clone_fn
}
#record_type
};
if !has_custom_display(&ast.attrs) {
codegen.extend(quote! {
impl #generics_module burn::module::ModuleDisplay for #name #generics_ty_module #generics_where_module {
}
});
}
codegen
}
// When there is no backend in the generic parameter, the type is considered as a constant.
pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream {
let name = &ast.ident;
let (generics, generics_ty, generics_where) = ast.generics.split_for_impl();
let backend: syn::Generics = parse_quote! { <B: burn::tensor::backend::Backend >};
let backend_ad: syn::Generics = parse_quote! { <B: burn::tensor::backend::AutodiffBackend >};
let mut generics_module = ast.generics.clone();
let mut generics_module_autodiff = ast.generics.clone();
for param in backend.params.into_iter() {
generics_module.params.push(param);
}
for param in backend_ad.params.into_iter() {
generics_module_autodiff.params.push(param);
}
let (generics_module, _, _) = generics_module.split_for_impl();
let (generics_module_ad, _, _) = generics_module_autodiff.split_for_impl();
let display_fn = display::display_fn(ast);
let attributes_fn = display::attributes_fn(ast);
let mut codegen = quote! {
impl #generics_module burn::module::Module<B> for #name #generics_ty #generics_where {
burn::constant!(module);
}
impl #generics_module_ad burn::module::AutodiffModule<B>
for #name #generics_ty #generics_where {
burn::constant!(ad_module, #name #generics_ty);
}
impl #generics core::fmt::Display for #name #generics_ty #generics_where {
#display_fn
}
impl #generics burn::module::ModuleDisplayDefault for #name #generics_ty #generics_where {
#attributes_fn
}
};
if !has_custom_display(&ast.attrs) {
codegen.extend(quote! {
impl #generics burn::module::ModuleDisplay for #name #generics_ty #generics_where {
}
});
}
codegen
}
struct GenericsParser {
module: Generics,
module_autodiff: Generics,
module_has_autodiff: Generics,
inner_module_ty: TokenStream,
train_module_ty: TokenStream,
train_inner_ty: TokenStream,
}
impl GenericsParser {
fn from_ast(generics: &Generics) -> Self {
let mut module = GenericsHelper::new(generics.clone());
let mut module_autodiff = GenericsHelper::new(generics.clone());
let mut module_has_autodiff = GenericsHelper::new(generics.clone());
let backend_trait = module.fetch_backend_trait();
module_autodiff.add_predicate(parse_quote! {
B: burn::tensor::backend::AutodiffBackend
});
module_autodiff.add_predicate(parse_quote! {
<B as burn::tensor::backend::AutodiffBackend>::InnerBackend: #backend_trait
});
module_has_autodiff.add_predicate(parse_quote! {
B: burn::tensor::backend::AutodiffBackend
});
module_has_autodiff.add_predicate(parse_quote! {
<B as burn::tensor::backend::AutodiffBackend>::InnerBackend: #backend_trait
});
let mut generics_names_except_backend = quote! {};
let mut train_generics_names_except_backend = quote! {};
let mut train_inner_generics_names_except_backend = quote! {};
module
.types()
.into_iter()
.filter(|ident| ident != "B")
.for_each(|ident| {
module.add_predicate(
parse_quote! {
#ident: burn::module::Module<B>
}
);
module.add_predicate(
parse_quote! {
#ident: burn::module::ModuleDisplay
}
);
module_autodiff.add_predicate(
parse_quote! {
#ident: burn::module::AutodiffModule<B>
}
);
module_autodiff.add_predicate(
parse_quote! {
<#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::Module<B::InnerBackend>
}
);
module_autodiff.add_predicate(
parse_quote! {
<#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::ModuleDisplay
}
);
generics_names_except_backend.extend(quote! { <#ident as burn::module::AutodiffModule<B>>::InnerModule, });
module_autodiff.add_predicate(
parse_quote! {
#ident: burn::module::ModuleDisplay
}
);
module_has_autodiff.add_predicate(
parse_quote! {
#ident: burn::module::Module<B::InnerBackend>
}
);
module_has_autodiff.add_predicate(
parse_quote! {
#ident: burn::module::ModuleDisplay
}
);
module_has_autodiff.add_predicate(
parse_quote! {
#ident: burn::module::HasAutodiffModule<B>
}
);
module_has_autodiff.add_predicate(
parse_quote! {
#ident::TrainModule: burn::module::ModuleDisplay
}
);
train_generics_names_except_backend.extend(quote! { #ident, });
train_inner_generics_names_except_backend.extend(quote! { #ident::TrainModule, });
});
module.consts().into_iter().for_each(|ident| {
generics_names_except_backend.extend(quote! { #ident, });
train_generics_names_except_backend.extend(quote! { #ident, });
train_inner_generics_names_except_backend.extend(quote! { #ident, });
});
Self {
module: module.generics,
module_autodiff: module_autodiff.generics,
module_has_autodiff: module_has_autodiff.generics,
inner_module_ty: generics_names_except_backend,
train_module_ty: train_generics_names_except_backend,
train_inner_ty: train_inner_generics_names_except_backend,
}
}
}
fn has_custom_display(attrs: &[Attribute]) -> bool {
attrs.iter().any(|attr| {
attr.path().is_ident("module")
&& attr
.parse_nested_meta(|meta| {
if meta.path.is_ident("custom_display") {
Ok(())
} else {
Err(meta.error("unsupported attribute"))
}
})
.is_ok()
})
}

View File

@@ -0,0 +1,236 @@
use super::{codegen::ModuleCodegen, record_enum::EnumModuleRecordCodegen};
use crate::shared::enum_variant::{EnumVariant, parse_variants};
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use syn::Visibility;
pub(crate) struct EnumModuleCodegen {
pub name: Ident,
pub variants: Vec<EnumVariant>,
pub vis: Visibility,
}
impl ModuleCodegen for EnumModuleCodegen {
type RecordCodegen = EnumModuleRecordCodegen;
fn gen_num_params(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|_| {
quote! {
burn::module::Module::<B>::num_params(module)
}
});
quote! {
fn num_params(&self) -> usize {
#match_body
}
}
}
fn gen_visit(&self) -> TokenStream {
let enum_name = self.name.to_string();
let container_type = format!("Enum:{}", enum_name);
let match_body = self.gen_variants_match_fn(|variant_name| {
let variant_str = variant_name.to_string();
quote! {
{
visitor.enter_module(#variant_str, #container_type);
burn::module::Module::visit(module, visitor);
visitor.exit_module(#variant_str, #container_type);
}
}
});
quote! {
fn visit<Visitor: burn::module::ModuleVisitor<B>>(&self, visitor: &mut Visitor) {
#match_body
}
}
}
fn gen_collect_devices(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|_| {
quote! {
burn::module::Module::<B>::collect_devices(module, devices)
}
});
quote! {
fn collect_devices(
&self,
devices: burn::module::Devices<B>
) -> burn::module::Devices<B> {
#match_body
}
}
}
fn gen_to_device(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
Self::#variant(burn::module::Module::<B>::to_device(module, device))
}
});
quote! {
fn to_device(self, device: &B::Device) -> Self {
#match_body
}
}
}
fn gen_fork(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
Self::#variant(burn::module::Module::<B>::fork(module, device))
}
});
quote! {
fn fork(self, device: &B::Device) -> Self {
#match_body
}
}
}
fn gen_map(&self) -> TokenStream {
let enum_name = self.name.to_string();
let container_type = format!("Enum:{}", enum_name);
let match_body = self.gen_variants_match_fn(|variant| {
let variant_str = variant.to_string();
quote! {
{
mapper.enter_module(#variant_str, #container_type);
let result = burn::module::Module::<B>::map(module, mapper);
mapper.exit_module(#variant_str, #container_type);
Self::#variant(result)
}
}
});
quote! {
fn map<Mapper: burn::module::ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self {
#match_body
}
}
}
fn gen_valid(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
Self::InnerModule::#variant(burn::module::AutodiffModule::<B>::valid(module))
}
});
quote! {
fn valid(&self) -> Self::InnerModule {
#match_body
}
}
}
fn gen_from_inner(&self) -> TokenStream {
let match_body =
self.gen_variants_match_fn_param("module", "Self::InnerModule::", |variant| {
quote! {
Self::#variant(burn::module::AutodiffModule::<B>::from_inner(module))
}
});
quote! {
fn from_inner(module: Self::InnerModule) -> Self {
#match_body
}
}
}
fn gen_into_record(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
Self::Record::#variant(burn::module::Module::<B>::into_record(module))
}
});
quote! {
fn into_record(self) -> Self::Record {
#match_body
}
}
}
fn gen_load_record(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
{
let Self::Record::#variant(r) = record else {panic!("Can't parse record from a different variant");};
Self::#variant(burn::module::Module::<B>::load_record(module, r))
}
}
});
quote! {
fn load_record(self, record: Self::Record) -> Self {
#match_body
}
}
}
fn gen_clone(&self) -> TokenStream {
let match_body = self.gen_variants_match_fn(|variant| {
quote! {
Self::#variant(module.clone())
}
});
quote! {
fn clone(&self) -> Self {
#match_body
}
}
}
fn record_codegen(self) -> Self::RecordCodegen {
EnumModuleRecordCodegen::new(self.variants, self.vis)
}
}
impl EnumModuleCodegen {
pub fn from_ast(ast: &syn::DeriveInput) -> syn::Result<Self> {
Ok(Self {
name: ast.ident.clone(),
variants: parse_variants(ast)?,
vis: ast.vis.clone(),
})
}
/// Generate the enum variants' match arms with the provided function
fn gen_variants_match_fn<F>(&self, func: F) -> TokenStream
where
F: Fn(Ident) -> TokenStream,
{
self.gen_variants_match_fn_param("self", "Self::", func)
}
/// Generate a match expression over the given argument (e.g., `self`)
/// and using the provided prefix for variants (e.g., `Self::` or `Self::InnerModule::`)
fn gen_variants_match_fn_param<F>(&self, arg: &str, prefix: &str, func: F) -> TokenStream
where
F: Fn(Ident) -> TokenStream,
{
let match_arms = self.variants.iter().map(|variant| {
let name = &variant.ident;
let full_variant = syn::parse_str::<syn::Path>(&format!("{prefix}{name}")).unwrap();
let arm_pattern = quote! { #full_variant(module) };
let arm_code = func(name.clone());
quote! { #arm_pattern => #arm_code, }
});
let arg = Ident::new(arg, Span::call_site());
quote! {
match #arg {
#(#match_arms)*
}
}
}
}

View File

@@ -0,0 +1,267 @@
use super::{codegen::ModuleCodegen, record_struct::StructModuleRecordCodegen};
use crate::shared::field::{FieldTypeAnalyzer, parse_fields};
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::Visibility;
pub(crate) struct StructModuleCodegen {
pub name: Ident,
pub fields: Vec<FieldTypeAnalyzer>,
pub vis: Visibility,
}
impl ModuleCodegen for StructModuleCodegen {
type RecordCodegen = StructModuleRecordCodegen;
fn gen_num_params(&self) -> TokenStream {
let body = self.gen_fields_fn(|name| {
quote! {
num_params += burn::module::Module::<B>::num_params(&self.#name);
}
});
quote! {
fn num_params(&self) -> usize {
let mut num_params = 0;
#body
num_params
}
}
}
fn gen_visit(&self) -> TokenStream {
let struct_name = self.name.to_string();
let container_type = format!("Struct:{}", struct_name);
let body = self.gen_fields_fn(|name| {
let name_str = name.to_string();
quote! {
visitor.enter_module(#name_str, #container_type);
burn::module::Module::visit(&self.#name, visitor);
visitor.exit_module(#name_str, #container_type);
}
});
quote! {
fn visit<Visitor: burn::module::ModuleVisitor<B>>(&self, visitor: &mut Visitor) {
#body
}
}
}
fn gen_collect_devices(&self) -> TokenStream {
let body = self.gen_fields_fn(|name| {
quote! {
let devices = burn::module::Module::<B>::collect_devices(&self.#name, devices);
}
});
quote! {
fn collect_devices(
&self,
devices: burn::module::Devices<B>
) -> burn::module::Devices<B> {
#body
devices
}
}
}
fn gen_to_device(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = burn::module::Module::<B>::to_device(self.#name, device);
}
});
quote! {
fn to_device(self, device: &B::Device) -> Self {
#body
Self {
#(#names),*
}
}
}
}
fn gen_fork(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = burn::module::Module::<B>::fork(self.#name, device);
}
});
quote! {
fn fork(self, device: &B::Device) -> Self {
#body
Self {
#(#names),*
}
}
}
}
fn gen_map(&self) -> TokenStream {
let struct_name = self.name.to_string();
let container_type = format!("Struct:{}", struct_name);
let (names, body) = self.gen_fields_fn_names(|name| {
let name_str = name.to_string();
quote! {
mapper.enter_module(#name_str, #container_type);
let #name = burn::module::Module::<B>::map(self.#name, mapper);
mapper.exit_module(#name_str, #container_type);
}
});
quote! {
fn map<Mapper: burn::module::ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self {
#body
Self {
#(#names),*
}
}
}
}
fn gen_valid(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = burn::module::AutodiffModule::<B>::valid(&self.#name);
}
});
quote! {
fn valid(&self) -> Self::InnerModule {
#body
Self::InnerModule {
#(#names),*
}
}
}
}
fn gen_from_inner(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = burn::module::AutodiffModule::<B>::from_inner(#name);
}
});
// Destructure inner module to move all fields
let destructure = quote! {
let Self::InnerModule { #(#names),* } = module;
};
quote! {
fn from_inner(module: Self::InnerModule) -> Self {
#destructure
#body
Self {
#(#names),*
}
}
}
}
fn gen_into_record(&self) -> TokenStream {
let body = self.gen_fields_fn(|name| {
quote! {
#name: burn::module::Module::<B>::into_record(self.#name),
}
});
quote! {
fn into_record(self) -> Self::Record {
Self::Record {
#body
}
}
}
}
fn gen_load_record(&self) -> TokenStream {
let body = self.gen_fields_fn(|name| {
quote! {
#name: burn::module::Module::<B>::load_record(self.#name, record.#name),
}
});
quote! {
fn load_record(self, record: Self::Record) -> Self {
Self {
#body
}
}
}
}
fn gen_clone(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = self.#name.clone();
}
});
quote! {
fn clone(&self) -> Self {
#body
Self {
#(#names),*
}
}
}
}
fn record_codegen(self) -> Self::RecordCodegen {
StructModuleRecordCodegen::new(self.fields, self.vis)
}
}
impl StructModuleCodegen {
pub fn from_ast(ast: &syn::DeriveInput) -> Self {
Self {
name: ast.ident.clone(),
fields: parse_fields(ast)
.into_iter()
.map(FieldTypeAnalyzer::new)
.collect(),
vis: ast.vis.clone(),
}
}
fn gen_fields_fn_names<F>(&self, func: F) -> (Vec<Ident>, TokenStream)
where
F: Fn(Ident) -> TokenStream,
{
let mut body = quote! {};
let mut names = Vec::new();
for field in self.fields.iter() {
let name = field.ident();
names.push(name.clone());
body.extend(func(field.ident()));
}
(names, body)
}
fn gen_fields_fn<F>(&self, func: F) -> TokenStream
where
F: Fn(Ident) -> TokenStream,
{
let mut body = quote! {};
for field in self.fields.iter() {
body.extend(func(field.ident()));
}
body
}
}

View File

@@ -0,0 +1,94 @@
use quote::quote;
pub fn attributes_fn(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
match &ast.data {
syn::Data::Struct(data_struct) => {
let fields = match &data_struct.fields {
syn::Fields::Named(named_fields) => named_fields.named.iter().collect::<Vec<_>>(),
syn::Fields::Unit => Vec::new(),
_ => panic!("attributes_fn only supports structs with named or unit fields"),
};
let field_prints = fields.iter().map(|field| {
let field_name = &field.ident;
quote! { .add(stringify!(#field_name), &self.#field_name) }
});
let struct_name = &ast.ident;
quote! {
fn content(&self, mut content: burn::module::Content) -> Option<burn::module::Content> {
content
.set_top_level_type(&stringify!(#struct_name))
#(#field_prints)*
.optional()
}
}
}
syn::Data::Enum(data_enum) => {
let variant_prints = data_enum.variants.iter().map(|variant| {
let variant_name = &variant.ident;
match &variant.fields {
syn::Fields::Unit => {
quote! {
Self::#variant_name => {
content.add_formatted(&stringify!(#variant_name).to_string())
.optional()
}
}
}
syn::Fields::Named(named_fields) => {
let field_prints = named_fields.named.iter().map(|field| {
let field_name = &field.ident;
quote! { .add(stringify!(#field_name), &self.#field_name) }
});
let field_names = named_fields.named.iter().map(|field| {
let field_name = &field.ident;
quote! { #field_name }
});
quote! {
Self::#variant_name { #(#field_names),* } => {
content.set_top_level_type(&stringify!(#variant_name))
#(#field_prints)*
.optional()
}
}
}
syn::Fields::Unnamed(unnamed_fields) => {
let field_names = (0..unnamed_fields.unnamed.len()).map(|i| {
syn::Ident::new(&format!("_{i}"), proc_macro2::Span::call_site())
});
let field_prints = field_names.clone().map(|field_name| {
quote! { .add(stringify!(#field_name), #field_name) }
});
quote! {
Self::#variant_name(#(#field_names),*) => {
content.set_top_level_type(&stringify!(#variant_name))
#(#field_prints)*
.optional()
}
}
}
}
});
quote! {
fn content(&self, mut content: burn::module::Content) -> Option<burn::module::Content> {
match self {
#(#variant_prints)*
}
}
}
}
_ => panic!("attributes_fn only supports structs and enums"),
}
}
pub fn display_fn(_ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
quote! {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let formatted = burn::module::ModuleDisplay::format(self, Default::default());
write!(f, "{}", formatted)
}
}
}

View File

@@ -0,0 +1,11 @@
pub(crate) mod codegen;
pub(crate) mod codegen_enum;
pub(crate) mod codegen_struct;
pub(crate) mod display;
pub(crate) mod record;
pub(crate) mod record_enum;
pub(crate) mod record_struct;
mod base;
pub(crate) use base::*;

View File

@@ -0,0 +1,8 @@
use proc_macro2::{Ident, TokenStream};
use syn::Generics;
/// Basic trait to generate a record type based on the Module struct.
pub(crate) trait ModuleRecordCodegen {
/// Generate the record type (i.e a struct)
fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream;
}

View File

@@ -0,0 +1,41 @@
use crate::shared::enum_variant::EnumVariant;
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{Generics, Visibility};
use super::record::ModuleRecordCodegen;
#[derive(new)]
pub(crate) struct EnumModuleRecordCodegen {
variants: Vec<EnumVariant>,
vis: Visibility,
}
impl ModuleRecordCodegen for EnumModuleRecordCodegen {
fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream {
let mut variants = quote! {};
let vis = &self.vis;
// Capture the Record enum variant types
for variant in self.variants.iter() {
let ty = &variant.ty;
let name = &variant.ident;
variants.extend(quote! {
/// The module record associative type.
#name(<#ty as burn::module::Module<B>>::Record),
});
}
let (generics, _generics_ty, generics_where) = generics.split_for_impl();
quote! {
/// The record type for the module.
#[derive(burn::record::Record)]
#vis enum #record_name #generics #generics_where {
#variants
}
}
}
}

View File

@@ -0,0 +1,40 @@
use crate::shared::field::FieldTypeAnalyzer;
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{Generics, Visibility};
use super::record::ModuleRecordCodegen;
#[derive(new)]
pub(crate) struct StructModuleRecordCodegen {
fields: Vec<FieldTypeAnalyzer>,
vis: Visibility,
}
impl ModuleRecordCodegen for StructModuleRecordCodegen {
fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream {
let mut fields = quote! {};
let vis = &self.vis;
for field in self.fields.iter() {
let ty = &field.field.ty;
let name = &field.field.ident;
fields.extend(quote! {
/// The module record associative type.
#vis #name: <#ty as burn::module::Module<B>>::Record,
});
}
let (generics, _generics_ty, generics_where) = generics.split_for_impl();
quote! {
/// The record type for the module.
#[derive(burn::record::Record)]
#vis struct #record_name #generics #generics_where {
#fields
}
}
}
}

View File

@@ -0,0 +1,13 @@
use super::{
codegen::generate_record,
item::{codegen_enum::EnumRecordItemCodegen, codegen_struct::StructRecordItemCodegen},
};
pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> proc_macro::TokenStream {
match &ast.data {
syn::Data::Struct(_) => generate_record::<StructRecordItemCodegen>(ast),
syn::Data::Enum(_) => generate_record::<EnumRecordItemCodegen>(ast),
syn::Data::Union(_) => panic!("Union modules aren't supported yet."),
}
.into()
}

View File

@@ -0,0 +1,145 @@
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{Generics, parse_quote};
use crate::record::item::codegen::RecordItemCodegen;
pub(crate) fn generate_record<G: RecordItemCodegen>(ast: &syn::DeriveInput) -> TokenStream {
let record_gen: syn::Result<RecordCodegen<G>> = RecordCodegen::from_ast(ast);
match record_gen {
Ok(record_gen) => {
let item_type = record_gen.gen_record_type();
let record_impl = record_gen.gen_impl_record();
quote! {
#item_type
#record_impl
}
}
Err(err) => err.to_compile_error(),
}
}
pub(crate) struct RecordCodegen<G: RecordItemCodegen> {
/// Record type info.
ty: RecordType,
/// Record item code gen.
codegen: G,
}
impl<G: RecordItemCodegen> RecordCodegen<G> {
/// Generate the record type with the correct generics.
pub(crate) fn gen_record_type(&self) -> TokenStream {
// Add precision settings type bound
let param: syn::Generics = parse_quote! { <S: burn::record::PrecisionSettings >};
let mut generics = self.ty.generics.clone();
for param in param.params.into_iter() {
generics.params.push(param);
}
// Generate the record item definition
self.codegen
.gen_item_type(&self.ty.item, &generics, self.ty.has_backend)
}
/// Generate the implementation for the Record trait.
pub(crate) fn gen_impl_record(&self) -> TokenStream {
// Capture the record type's generics and bounds in where clauses
let item_generics = self.record_item_generics();
let (_, ty_generics_item, _) = item_generics.split_for_impl();
let (impl_generics, ty_generics, where_clause) = self.ty.generics.split_for_impl();
let impl_generics = if let Some(impl_generic) = self.impl_generics() {
impl_generic
} else {
quote! { #impl_generics }
};
let name_item = &self.ty.item;
let into_item_fn = self.codegen.gen_into_item(name_item);
let from_item_fn = self.codegen.gen_from_item();
// Return the generated stream of token trees (i.e., code to be generated)
let name = &self.ty.name;
quote! {
impl #impl_generics burn::record::Record<B> for #name #ty_generics #where_clause {
type Item<S: burn::record::PrecisionSettings> = #name_item #ty_generics_item;
#into_item_fn
#from_item_fn
}
}
}
/// Add backend generic type to the implementation block.
fn impl_generics(&self) -> Option<TokenStream> {
if self.ty.has_backend {
return None;
}
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
let mut generics = self.ty.generics.clone();
generics.params.push(syn::GenericParam::Type(param));
let (impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
Some(quote! {#impl_generics})
}
/// Get the generics attached to the record item type.
fn record_item_generics(&self) -> Generics {
let param: syn::Generics = parse_quote! { <S: burn::record::PrecisionSettings >};
let mut generics = self.ty.generics.clone();
for param in param.params.into_iter() {
generics.params.push(param);
}
if !self.ty.has_backend {
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
generics.params.push(syn::GenericParam::Type(param));
}
generics
}
pub(crate) fn from_ast(ast: &syn::DeriveInput) -> syn::Result<Self> {
Ok(Self {
ty: RecordType::from_ast(ast),
codegen: G::from_ast(ast)?,
})
}
}
/// Information about a record type.
struct RecordType {
/// Record type name.
name: Ident,
/// Record item type name.
item: Ident,
/// Lifetimes and type parameters attached to the record type declaration.
generics: Generics,
/// Whether or not the record type should specify a backend generic.
has_backend: bool,
}
impl RecordType {
fn from_ast(ast: &syn::DeriveInput) -> Self {
let name = ast.ident.clone();
let item = Ident::new(format!("{name}Item").as_str(), name.span());
let has_backend = ast
.generics
.type_params()
.map(|param| param.ident == "B")
.reduce(|accum, is_backend| is_backend || accum)
.unwrap_or(false);
Self {
name,
item,
generics: ast.generics.clone(),
has_backend,
}
}
}

View File

@@ -0,0 +1,21 @@
use proc_macro2::{Ident, TokenStream};
use syn::Generics;
/// Basic trait to be implemented for record generation.
pub(crate) trait RecordItemCodegen {
/// Initialize the record item.
fn from_ast(ast: &syn::DeriveInput) -> syn::Result<Self>
where
Self: Sized;
/// Generate the record item type.
fn gen_item_type(
&self,
item_name: &Ident,
generics: &Generics,
has_backend: bool,
) -> TokenStream;
/// Generate the into_item function.
fn gen_into_item(&self, item_name: &Ident) -> TokenStream;
/// Generate the from item function.
fn gen_from_item(&self) -> TokenStream;
}

View File

@@ -0,0 +1,137 @@
use crate::shared::enum_variant::{EnumVariant, parse_variants};
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{Generics, Visibility, parse_quote};
use super::codegen::RecordItemCodegen;
pub(crate) struct EnumRecordItemCodegen {
/// Enum variants.
variants: Vec<EnumVariant>,
vis: Visibility,
}
impl RecordItemCodegen for EnumRecordItemCodegen {
fn from_ast(ast: &syn::DeriveInput) -> syn::Result<Self> {
Ok(Self {
variants: parse_variants(ast)?,
vis: ast.vis.clone(),
})
}
fn gen_item_type(
&self,
item_name: &Ident,
generics: &Generics,
has_backend: bool,
) -> TokenStream {
let mut variants = quote! {};
let mut serde_bounds = quote! {};
let mut clone_bounds = vec![];
let mut clone_match_arms = quote! {};
let vis = &self.vis;
// Capture the Record enum variant types and names to transpose them in RecordItem
for variant in self.variants.iter() {
let ty = &variant.ty;
let name = &variant.ident;
variants.extend(quote! {
/// Variant to be serialized.
#name(<#ty as burn::record::Record<B>>::Item<S>),
});
// Item types must implement serialization/deserialization
serde_bounds.extend(quote! {
<#ty as burn::record::Record<B>>::Item<S>: burn::serde::Serialize + burn::serde::de::DeserializeOwned,
});
clone_bounds.push(parse_quote! {
<#ty as burn::record::Record<B>>::Item<S>: Clone
});
clone_match_arms.extend(quote! {
Self::#name(inner) => Self::#name(inner.clone()),
});
}
let serde_bound = serde_bounds.to_string();
// Capture the type's generics and bounds in where clauses
let mut generics = generics.clone();
if !has_backend {
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
generics.params.push(syn::GenericParam::Type(param));
}
let (generics, type_generics, generics_where) = generics.split_for_impl();
let clone_bounds = generics_where.cloned().map(|mut where_clause| {
for predicate in clone_bounds {
where_clause.predicates.push(predicate);
}
where_clause
});
let clone_impl = quote! {
impl #generics Clone for #item_name #type_generics #clone_bounds {
fn clone(&self) -> Self {
match self {
#clone_match_arms
}
}
}
};
// Return the generated stream of token trees (i.e., code to be generated)
quote! {
/// The record item type for the module.
#[derive(burn::serde::Serialize, burn::serde::Deserialize)]
#[serde(crate = "burn::serde")]
#[serde(bound = #serde_bound)]
#vis enum #item_name #generics #generics_where {
#variants
}
#clone_impl
}
}
fn gen_into_item(&self, _item_name: &Ident) -> TokenStream {
let mut into_item_match_arms = quote! {};
for variant in self.variants.iter() {
let name = &variant.ident;
into_item_match_arms.extend(quote! {
Self::#name(record) => Self::Item::#name(burn::record::Record::<B>::into_item::<S>(record)),
});
}
quote! {
fn into_item<S: burn::record::PrecisionSettings>(self) -> Self::Item<S> {
match self {
#into_item_match_arms
}
}
}
}
fn gen_from_item(&self) -> TokenStream {
let mut from_item_match_arms = quote! {};
for variant in self.variants.iter() {
let name = &variant.ident;
from_item_match_arms.extend(quote! {
Self::Item::#name(item) => Self::#name(burn::record::Record::<B>::from_item::<S>(item, device)),
});
}
quote! {
fn from_item<S: burn::record::PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
match item {
#from_item_match_arms
}
}
}
}
}

View File

@@ -0,0 +1,136 @@
use crate::shared::field::{FieldTypeAnalyzer, parse_fields};
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{Generics, Visibility, parse_quote};
use super::codegen::RecordItemCodegen;
pub(crate) struct StructRecordItemCodegen {
fields: Vec<FieldTypeAnalyzer>,
vis: Visibility,
}
impl RecordItemCodegen for StructRecordItemCodegen {
fn from_ast(ast: &syn::DeriveInput) -> syn::Result<Self> {
Ok(Self {
fields: parse_fields(ast)
.into_iter()
.map(FieldTypeAnalyzer::new)
.collect(),
vis: ast.vis.clone(),
})
}
fn gen_item_type(
&self,
item_name: &Ident,
generics: &Generics,
has_backend: bool,
) -> TokenStream {
let mut fields = quote! {};
let mut serde_bounds = quote! {};
let mut clone_bounds = vec![];
let mut clone_delegate = quote! {};
let vis = &self.vis;
for field in self.fields.iter() {
let ty = &field.field.ty;
let name = &field.field.ident;
fields.extend(quote! {
/// Field to be serialized.
pub #name: <#ty as burn::record::Record<B>>::Item<S>,
});
serde_bounds.extend(quote! {
<#ty as burn::record::Record<B>>::Item<S>: burn::serde::Serialize + burn::serde::de::DeserializeOwned,
});
clone_bounds.push(parse_quote! {
<#ty as burn::record::Record<B>>::Item<S>: Clone
});
clone_delegate.extend(quote! {
#name: self.#name.clone(),
});
}
let serde_bound = serde_bounds.to_string();
let mut generics = generics.clone();
if !has_backend {
let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };
generics.params.push(syn::GenericParam::Type(param));
}
let (generics, type_generics, generics_where) = generics.split_for_impl();
let clone_bounds = generics_where.cloned().map(|mut where_clause| {
for predicate in clone_bounds {
where_clause.predicates.push(predicate);
}
where_clause
});
let clone_impl = quote! {
impl #generics Clone for #item_name #type_generics #clone_bounds {
fn clone(&self) -> Self {
Self {
#clone_delegate
}
}
}
};
quote! {
/// The record item type for the module.
#[derive(burn::serde::Serialize, burn::serde::Deserialize)]
#[serde(crate = "burn::serde")]
#[serde(bound = #serde_bound)]
#vis struct #item_name #generics #generics_where {
#fields
}
#clone_impl
}
}
fn gen_into_item(&self, item_name: &Ident) -> TokenStream {
let mut body_into_item = quote! {};
for field in self.fields.iter() {
let name = &field.field.ident;
body_into_item.extend(quote! {
#name: burn::record::Record::<B>::into_item::<S>(self.#name),
});
}
quote! {
fn into_item<S: burn::record::PrecisionSettings>(self) -> Self::Item<S> {
#item_name {
#body_into_item
}
}
}
}
fn gen_from_item(&self) -> TokenStream {
let mut body_from_item = quote! {};
for field in self.fields.iter() {
let name = &field.field.ident;
body_from_item.extend(quote! {
#name: burn::record::Record::<B>::from_item::<S>(item.#name, device),
});
}
quote! {
fn from_item<S: burn::record::PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
Self {
#body_from_item
}
}
}
}
}

View File

@@ -0,0 +1,3 @@
pub(crate) mod codegen;
pub(crate) mod codegen_enum;
pub(crate) mod codegen_struct;

View File

@@ -0,0 +1,5 @@
pub(crate) mod codegen;
pub(crate) mod item;
mod base;
pub(crate) use base::*;

View File

@@ -0,0 +1,49 @@
use syn::{Attribute, Meta};
pub struct AttributeAnalyzer {
attr: Attribute,
}
#[derive(Clone)]
pub struct AttributeItem {
pub value: syn::Lit,
}
impl AttributeAnalyzer {
pub fn new(attr: Attribute) -> Self {
Self { attr }
}
pub fn item(&self) -> AttributeItem {
let value = match &self.attr.meta {
Meta::List(val) => val.parse_args::<syn::MetaNameValue>().unwrap(),
Meta::NameValue(meta) => meta.clone(),
Meta::Path(_) => panic!("Path meta unsupported"),
};
let lit = match value.value {
syn::Expr::Lit(lit) => lit.lit,
_ => panic!("Only literal is supported"),
};
AttributeItem { value: lit }
}
pub fn has_name(&self, name: &str) -> bool {
Self::path_syn_name(self.attr.path()) == name
}
fn path_syn_name(path: &syn::Path) -> String {
let length = path.segments.len();
let mut name = String::new();
for (i, segment) in path.segments.iter().enumerate() {
if i == length - 1 {
name += segment.ident.to_string().as_str();
} else {
let tmp = segment.ident.to_string() + "::";
name += tmp.as_str();
}
}
name
}
}

View File

@@ -0,0 +1,103 @@
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use syn::{FieldsNamed, Variant};
/// Process a variant of an enum where the output is the result of the given mapper.
pub(crate) fn map_enum_variant<Mapper>(
variant: &Variant,
mapper: Mapper,
) -> (TokenStream, TokenStream)
where
Mapper: Fn(&Ident) -> TokenStream,
{
let gen_fields_unnamed = |num: usize| {
let mut inputs = Vec::new();
let mut outputs = Vec::new();
for i in 0..num {
let arg_name = Ident::new(&format!("arg_{i}"), Span::call_site());
let input = quote! { #arg_name };
let output = mapper(&arg_name);
inputs.push(input);
outputs.push(output);
}
(quote! (( #(#inputs),* )), quote! (( #(#outputs),* )))
};
let gen_fields_named = |fields: &FieldsNamed| {
let mut inputs = Vec::new();
let mut outputs = Vec::new();
fields.named.iter().for_each(|field| {
let ident = field.ident.as_ref().expect("Named field to have a name.");
let input = quote! { #ident };
let output = mapper(ident);
inputs.push(input);
outputs.push(quote! {
#ident: #output
});
});
(quote! {{ #(#inputs),* }}, quote! {{ #(#outputs),* }})
};
match &variant.fields {
syn::Fields::Named(fields) => gen_fields_named(fields),
syn::Fields::Unnamed(_) => gen_fields_unnamed(variant.fields.len()),
syn::Fields::Unit => (quote! {}, quote! {}),
}
}
/// An enum variant (simplified).
pub(crate) struct EnumVariant {
pub ident: syn::Ident,
pub ty: syn::Type,
}
pub(crate) fn parse_variants(ast: &syn::DeriveInput) -> syn::Result<Vec<EnumVariant>> {
let enum_data = match &ast.data {
syn::Data::Enum(data) => data,
_ => {
return Err(syn::Error::new_spanned(
ast,
"Module can only be derived for enums.",
));
}
};
let mut variants = Vec::new();
for variant in enum_data.variants.iter() {
match &variant.fields {
syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
let field = &fields.unnamed[0];
variants.push(EnumVariant {
ident: variant.ident.clone(),
ty: field.ty.clone(),
});
}
syn::Fields::Unnamed(_) => {
return Err(syn::Error::new_spanned(
variant,
"Module derive only supports tuple enum variants with exactly one field.",
));
}
syn::Fields::Named(_) => {
return Err(syn::Error::new_spanned(
variant,
"Module derive does not support struct enum variants.",
));
}
syn::Fields::Unit => {
return Err(syn::Error::new_spanned(
variant,
"Module derive does not support unit enum variants.",
));
}
}
}
Ok(variants)
}

View File

@@ -0,0 +1,99 @@
use super::attribute::AttributeAnalyzer;
use proc_macro2::Ident;
use syn::{Field, Type, TypePath};
#[derive(Clone)]
pub struct FieldTypeAnalyzer {
pub field: Field,
}
impl FieldTypeAnalyzer {
pub fn new(field: Field) -> Self {
FieldTypeAnalyzer { field }
}
pub fn ident(&self) -> Ident {
self.field.ident.clone().unwrap()
}
pub fn is_of_type(&self, paths: &[&str]) -> bool {
match &self.field.ty {
syn::Type::Path(path) => {
let name = Self::path_name(path);
paths.contains(&name.as_str())
}
_ => false,
}
}
#[allow(dead_code)]
pub fn first_generic_field(&self) -> TypePath {
let err = || panic!("Field {} as no generic", self.field.ident.clone().unwrap());
match &self.field.ty {
syn::Type::Path(path) => Self::path_generic_argument(path),
_ => err(),
}
}
pub fn path_generic_argument(path: &TypePath) -> TypePath {
let segment = path.path.segments.last().unwrap();
let err = || panic!("Path segment {} has no generic", segment.ident.clone(),);
match &segment.arguments {
syn::PathArguments::None => err(),
syn::PathArguments::AngleBracketed(param) => {
let first_param = param.args.first().unwrap();
if let syn::GenericArgument::Type(Type::Path(path)) = first_param {
path.clone()
} else {
err()
}
}
syn::PathArguments::Parenthesized(_) => err(),
}
}
fn path_name(path: &TypePath) -> String {
let length = path.path.segments.len();
let mut name = String::new();
for (i, segment) in path.path.segments.iter().enumerate() {
if i == length - 1 {
name += segment.ident.to_string().as_str();
} else {
let tmp = segment.ident.to_string() + "::";
name += tmp.as_str();
}
}
name
}
/// Returns the docs of the field.
pub fn docs(&self) -> impl Iterator<Item = &syn::Attribute> {
self.field
.attrs
.iter()
.filter(|attr| attr.path().is_ident("doc"))
}
pub fn attributes(&self) -> impl Iterator<Item = AttributeAnalyzer> {
self.field
.attrs
.clone()
.into_iter()
.map(AttributeAnalyzer::new)
}
}
pub(crate) fn parse_fields(ast: &syn::DeriveInput) -> Vec<Field> {
let mut fields = Vec::new();
match &ast.data {
syn::Data::Struct(struct_data) => {
for field in struct_data.fields.iter() {
fields.push(field.clone());
}
}
syn::Data::Enum(_) => panic!("Only struct can be derived"),
syn::Data::Union(_) => panic!("Only struct can be derived"),
};
fields
}

View File

@@ -0,0 +1,63 @@
use proc_macro2::Ident;
use quote::quote;
use syn::{Generics, WhereClause, WherePredicate, parse_quote};
#[derive(new)]
pub struct GenericsHelper {
pub(crate) generics: Generics,
}
impl GenericsHelper {
pub fn add_predicate(&mut self, predicate: WherePredicate) {
let where_clause: WhereClause = match &self.generics.where_clause {
Some(val) => parse_quote! {
#val
#predicate,
},
None => parse_quote! {
where
#predicate,
},
};
self.generics.where_clause = Some(where_clause);
}
pub fn consts(&self) -> Vec<Ident> {
self.generics
.const_params()
.map(|c| c.ident.clone())
.collect()
}
pub fn types(&self) -> Vec<Ident> {
self.generics
.type_params()
.map(|tp| tp.ident.clone())
.collect()
}
pub fn fetch_backend_trait(&self) -> proc_macro2::TokenStream {
static BACKEND_TRAIT_COMPILATION_ERROR_MSG: &str =
"Modules should be generic over a backend.
- The generic argument named `B` should have its first trait bound being a backend trait.
- The default backend trait is `burn::tensor::backend::Backend`.
- Any backend trait is supported.";
for param in self.generics.params.iter() {
if let syn::GenericParam::Type(ty) = &param
&& ty.ident == "B"
{
let bound = ty
.bounds
.first()
.expect(BACKEND_TRAIT_COMPILATION_ERROR_MSG);
return quote! {
#bound
};
}
}
panic!("{BACKEND_TRAIT_COMPILATION_ERROR_MSG}");
}
}

View File

@@ -0,0 +1,4 @@
pub(crate) mod attribute;
pub(crate) mod enum_variant;
pub(crate) mod field;
pub(crate) mod generics;