//! Module adapters for transforming tensors between different formats //! //! This module provides adapters that handle differences between PyTorch and Burn: //! - Linear layer weight transposition //! - Normalization parameter naming (weight/bias vs gamma/beta) use crate::TensorSnapshot; use alloc::boxed::Box; use alloc::rc::Rc; use alloc::string::String; use alloc::string::ToString; use alloc::vec; use burn_tensor::TensorData; // Module type names as they appear in the container_type field // These come from the Module derive macro which uses stringify! on the struct name // Format: "Struct:TypeName" for user-defined structs mod module_names { // The actual string constants that match what the Module derive macro produces pub const LINEAR: &str = "Struct:Linear"; pub const BATCH_NORM: &str = "Struct:BatchNorm"; pub const LAYER_NORM: &str = "Struct:LayerNorm"; pub const GROUP_NORM: &str = "Struct:GroupNorm"; } /// Trait for adapting tensor snapshots between different module formats pub trait ModuleAdapter: Send + Sync { /// Adapt a tensor snapshot based on its container type and parameter name fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot; /// Get alternative parameter name to try during matching /// /// When looking for a parameter in a module, this method provides an alternative /// name to try if the direct name doesn't match. This enables matching parameters /// with different naming conventions (e.g., PyTorch's "weight" vs Burn's "gamma"). /// /// # Arguments /// * `param_name` - The parameter name we're looking for /// * `container_type` - The type of container module (e.g., "BatchNorm") /// /// # Returns /// Alternative parameter name to try, or None if no alternative exists fn get_alternative_param_name( &self, _param_name: &str, _container_type: &str, ) -> Option { None } /// Clone the adapter into a boxed trait object fn clone_box(&self) -> Box; /// Chain adapters together, applying `self` first and then `next`. /// /// This is useful when multiple transformations are required when importing model weights /// (e.g. PyTorch -> Burn layout conversion, then dtype casting, then custom remapping). /// /// The semantics follow a simple pipeline: /// - `adapt`: `next.adapt(&self.adapt(snapshot))` /// - `get_alternative_param_name`: try `self` first; if it returns an alternative name, /// try `next` with that name, otherwise return the first alternative name. fn chain(self, next: A) -> ChainAdapter where Self: Sized + 'static, A: ModuleAdapter + 'static, { ChainAdapter::new(self, next) } } impl Clone for Box { fn clone(&self) -> Self { self.clone_box() } } /// Adapter that applies two adapters in sequence. /// /// This allows composing smaller adapters instead of creating one large monolithic adapter. #[derive(Clone)] pub struct ChainAdapter { first: Box, second: Box, } impl ChainAdapter { /// Create a new adapter chain. pub fn new(first: A, second: B) -> Self where A: ModuleAdapter + 'static, B: ModuleAdapter + 'static, { Self { first: Box::new(first), second: Box::new(second), } } } impl ModuleAdapter for ChainAdapter { fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot { let snapshot = self.first.adapt(snapshot); self.second.adapt(&snapshot) } fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option { if let Some(name) = self .first .get_alternative_param_name(param_name, container_type) { self.second .get_alternative_param_name(&name, container_type) .or(Some(name)) } else { self.second .get_alternative_param_name(param_name, container_type) } } fn clone_box(&self) -> Box { Box::new(self.clone()) } } /// Identity adapter that passes tensors through unchanged #[derive(Debug, Clone, Default)] pub struct IdentityAdapter; impl ModuleAdapter for IdentityAdapter { fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot { snapshot.clone() } fn clone_box(&self) -> Box { Box::new(self.clone()) } } /// Adapter for converting from PyTorch format to Burn format /// /// Handles: /// - Linear layer weight transposition (PyTorch: [out, in] → Burn: [in, out]) /// - Normalization parameter renaming (weight → gamma, bias → beta) #[derive(Debug, Clone, Default)] pub struct PyTorchToBurnAdapter; impl ModuleAdapter for PyTorchToBurnAdapter { fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot { adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::PyTorchToBurn) } fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option { // For PyTorch->Burn: When looking for Burn names (gamma/beta), try PyTorch names (weight/bias) if is_normalization_layer(container_type) { burn_norm_param_to_pytorch(param_name).map(|s| s.to_string()) } else { None } } fn clone_box(&self) -> Box { Box::new(self.clone()) } } /// Adapter for converting from Burn format to PyTorch format /// /// Handles: /// - Linear layer weight transposition (Burn: [in, out] → PyTorch: [out, in]) /// - Normalization parameter renaming (gamma → weight, beta → bias) #[derive(Debug, Clone, Default)] pub struct BurnToPyTorchAdapter; impl ModuleAdapter for BurnToPyTorchAdapter { fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot { adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::BurnToPyTorch) } fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option { // For Burn->PyTorch: When looking for PyTorch names (weight/bias), try Burn names (gamma/beta) if is_normalization_layer(container_type) { pytorch_norm_param_to_burn(param_name).map(|s| s.to_string()) } else { None } } fn clone_box(&self) -> Box { Box::new(self.clone()) } } /// Direction of PyTorch conversion for parameter naming #[derive(Debug, Clone, Copy)] enum PyTorchConversionDirection { PyTorchToBurn, BurnToPyTorch, } /// Check if container type is a normalization layer fn is_normalization_layer(container_type: &str) -> bool { matches!( container_type, module_names::BATCH_NORM | module_names::LAYER_NORM | module_names::GROUP_NORM ) } /// Map PyTorch normalization parameter name to Burn fn pytorch_norm_param_to_burn(param_name: &str) -> Option<&'static str> { match param_name { "weight" => Some("gamma"), "bias" => Some("beta"), _ => None, } } /// Map Burn normalization parameter name to PyTorch fn burn_norm_param_to_pytorch(param_name: &str) -> Option<&'static str> { match param_name { "gamma" => Some("weight"), "beta" => Some("bias"), _ => None, } } /// Core tensor adaptation logic for PyTorch format conversions fn adapt_pytorch_tensor( snapshot: &TensorSnapshot, direction: PyTorchConversionDirection, ) -> TensorSnapshot { // Extract path and parameter name let (path_stack, param_name) = match get_path_and_param(snapshot) { Some(result) => result, None => return snapshot.clone(), }; // Get module type for matching (ignores Vec/Array wrappers) let module_type = match snapshot.module_type() { Some(mt) => mt, None => return snapshot.clone(), // No user-defined module found }; // Linear: transpose weight (bidirectional - same operation both ways) if module_type == module_names::LINEAR && param_name == "weight" && snapshot.shape.len() == 2 { return transpose_2d_tensor(snapshot); } // Normalization layers: rename parameters based on direction if is_normalization_layer(&module_type) { let new_name = match direction { PyTorchConversionDirection::PyTorchToBurn => pytorch_norm_param_to_burn(param_name), PyTorchConversionDirection::BurnToPyTorch => burn_norm_param_to_pytorch(param_name), }; if let Some(new_name) = new_name { return rename_parameter(snapshot, path_stack, new_name); } } snapshot.clone() } /// Extract path stack and parameter name from snapshot fn get_path_and_param(snapshot: &TensorSnapshot) -> Option<(&[String], &str)> { let path_stack = snapshot.path_stack.as_ref()?; let param_name = path_stack.last()?.as_str(); Some((path_stack.as_slice(), param_name)) } /// Rename a parameter in the snapshot fn rename_parameter( snapshot: &TensorSnapshot, path_stack: &[String], new_name: &str, ) -> TensorSnapshot { let mut new_path = path_stack.to_vec(); *new_path.last_mut().unwrap() = new_name.to_string(); TensorSnapshot::from_closure( snapshot.clone_data_fn(), snapshot.dtype, snapshot.shape.clone(), new_path, snapshot.container_stack.clone().unwrap_or_default(), snapshot.tensor_id.unwrap_or_default(), ) } /// Transpose a 2D tensor fn transpose_2d_tensor(snapshot: &TensorSnapshot) -> TensorSnapshot { if snapshot.shape.len() != 2 { return snapshot.clone(); } let original_data_fn = snapshot.clone_data_fn(); let dtype = snapshot.dtype; let transposed_shape = vec![snapshot.shape[1], snapshot.shape[0]]; // Create a lazy closure that transposes when called let transposed_data_fn = Rc::new(move || { let data = original_data_fn()?; Ok(transpose_tensor_data(data)) }); TensorSnapshot::from_closure( transposed_data_fn, dtype, transposed_shape, snapshot.path_stack.clone().unwrap_or_default(), snapshot.container_stack.clone().unwrap_or_default(), snapshot.tensor_id.unwrap_or_default(), ) } /// Transpose tensor data (assumes 2D shape is already validated) fn transpose_tensor_data(data: TensorData) -> TensorData { let shape = &data.shape; let rows = shape[0]; let cols = shape[1]; let transposed_shape = vec![cols, rows]; // Get the raw bytes and element size let bytes = data.as_bytes(); let element_size = data.dtype.size(); // Create a new buffer for transposed data let mut transposed_bytes = vec![0u8; bytes.len()]; // Transpose at the byte level - works for any data type for i in 0..rows { for j in 0..cols { let src_idx = (i * cols + j) * element_size; let dst_idx = (j * rows + i) * element_size; // Copy the bytes for this element transposed_bytes[dst_idx..dst_idx + element_size] .copy_from_slice(&bytes[src_idx..src_idx + element_size]); } } // Create new TensorData from transposed bytes TensorData::from_bytes_vec(transposed_bytes, transposed_shape, data.dtype) } #[cfg(test)] mod tests { use super::*; use alloc::rc::Rc; use alloc::sync::Arc; use burn_tensor::{DType, TensorData}; use core::sync::atomic::{AtomicUsize, Ordering}; #[test] fn test_module_names_match_burn_nn() { // If these types are renamed or moved in `burn-nn`, this test will fail to compile. // This use statement replicates the previous check/alarm system. #[allow(unused_imports)] use burn_nn::{BatchNorm, GroupNorm, LayerNorm, Linear}; // These assert statements work as extra checks that should remind maintainers more // clearly that the hardcoded strings needs get updated. assert_eq!(module_names::LINEAR, "Struct:Linear"); assert_eq!(module_names::BATCH_NORM, "Struct:BatchNorm"); assert_eq!(module_names::LAYER_NORM, "Struct:LayerNorm"); assert_eq!(module_names::GROUP_NORM, "Struct:GroupNorm"); } fn create_test_snapshot(path: &str, shape: Vec, container_type: &str) -> TensorSnapshot { let path_parts: Vec = path.split('.').map(|s| s.to_string()).collect(); let values = vec![1.0f32; shape.iter().product()]; let data = TensorData::new(values, shape.clone()); TensorSnapshot::from_closure( Rc::new(move || Ok(data.clone())), DType::F32, shape, path_parts, vec![container_type.to_string()], burn_core::module::ParamId::new(), ) } #[test] fn test_pytorch_to_burn_linear_weight() { let adapter = PyTorchToBurnAdapter; // Linear layer weight should be transposed let snapshot = create_test_snapshot("fc.weight", vec![10, 5], module_names::LINEAR); let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.shape, vec![5, 10]); // Linear layer bias should not be transposed let snapshot = create_test_snapshot("fc.bias", vec![10], module_names::LINEAR); let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.shape, vec![10]); } #[test] fn test_pytorch_to_burn_norm_params() { let adapter = PyTorchToBurnAdapter; // BatchNorm weight -> gamma let snapshot = create_test_snapshot("norm.weight", vec![10], module_names::BATCH_NORM); let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.full_path(), "norm.gamma"); // BatchNorm bias -> beta let snapshot = create_test_snapshot("norm.bias", vec![10], module_names::BATCH_NORM); let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.full_path(), "norm.beta"); } #[test] fn test_burn_to_pytorch_linear_weight() { let adapter = BurnToPyTorchAdapter; // Linear layer weight should be transposed let snapshot = create_test_snapshot("fc.weight", vec![5, 10], module_names::LINEAR); let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.shape, vec![10, 5]); } #[test] fn test_burn_to_pytorch_norm_params() { let adapter = BurnToPyTorchAdapter; // BatchNorm gamma -> weight let snapshot = create_test_snapshot("norm.gamma", vec![10], module_names::BATCH_NORM); let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.full_path(), "norm.weight"); // BatchNorm beta -> bias let snapshot = create_test_snapshot("norm.beta", vec![10], module_names::BATCH_NORM); let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.full_path(), "norm.bias"); } #[test] fn test_transpose_different_dtypes() { // Test that transpose works for different data types // Test with F32 let f32_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]); let transposed = transpose_tensor_data(f32_data); assert_eq!(transposed.shape, vec![3, 2]); let values = transposed.to_vec::().unwrap(); assert_eq!(values, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]); // Test with I32 let i32_data = TensorData::new(vec![1i32, 2, 3, 4, 5, 6], vec![2, 3]); let transposed = transpose_tensor_data(i32_data); assert_eq!(transposed.shape, vec![3, 2]); let values = transposed.to_vec::().unwrap(); assert_eq!(values, vec![1, 4, 2, 5, 3, 6]); // Test with F64 let f64_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], vec![2, 2]); let transposed = transpose_tensor_data(f64_data); assert_eq!(transposed.shape, vec![2, 2]); let values = transposed.to_vec::().unwrap(); assert_eq!(values, vec![1.0, 3.0, 2.0, 4.0]); } #[test] fn test_no_container_info() { let adapter = PyTorchToBurnAdapter; // Without container info, adapter returns unchanged for non-norm parameters let mut snapshot = create_test_snapshot("fc.weight", vec![10, 5], module_names::LINEAR); snapshot.container_stack = None; // Without container info, no transformation occurs for linear layers let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.shape, vec![10, 5]); // No transposition without container info // Test a non-linear, non-norm parameter - should pass through unchanged let mut snapshot2 = create_test_snapshot("other.weight", vec![10, 5], "Struct:Other"); snapshot2.container_stack = None; let adapted2 = adapter.adapt(&snapshot2); assert_eq!(adapted2.shape, vec![10, 5]); // No transposition } #[derive(Clone)] struct RenameParamAdapter { from: &'static str, to: &'static str, called: Arc, } impl ModuleAdapter for RenameParamAdapter { fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot { self.called.fetch_add(1, Ordering::Relaxed); let path_stack = match snapshot.path_stack.as_ref() { Some(stack) => stack, None => return snapshot.clone(), }; let param = match path_stack.last() { Some(p) => p.as_str(), None => return snapshot.clone(), }; if param != self.from { return snapshot.clone(); } let mut new_path = path_stack.to_vec(); *new_path.last_mut().unwrap() = self.to.to_string(); TensorSnapshot::from_closure( snapshot.clone_data_fn(), snapshot.dtype, snapshot.shape.clone(), new_path, snapshot.container_stack.clone().unwrap_or_default(), snapshot.tensor_id.unwrap_or_default(), ) } fn get_alternative_param_name( &self, _param_name: &str, _container_type: &str, ) -> Option { None } fn clone_box(&self) -> Box { Box::new(self.clone()) } } #[derive(Clone)] struct AltNameAdapter { from: &'static str, to: &'static str, called: Arc, } impl ModuleAdapter for AltNameAdapter { fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot { TensorSnapshot::from_closure( snapshot.clone_data_fn(), snapshot.dtype, snapshot.shape.clone(), snapshot.path_stack.clone().unwrap_or_default(), snapshot.container_stack.clone().unwrap_or_default(), snapshot.tensor_id.unwrap_or_default(), ) } fn get_alternative_param_name( &self, param_name: &str, _container_type: &str, ) -> Option { self.called.fetch_add(1, Ordering::Relaxed); if param_name == self.from { Some(self.to.to_string()) } else { None } } fn clone_box(&self) -> Box { Box::new(self.clone()) } } #[test] fn test_chain_adapter_pipes_adapt() { let called1 = Arc::new(AtomicUsize::new(0)); let called2 = Arc::new(AtomicUsize::new(0)); let a = RenameParamAdapter { from: "weight", to: "a", called: called1.clone(), }; let b = RenameParamAdapter { from: "a", to: "b", called: called2.clone(), }; let chain = a.chain(b); let snapshot = create_test_snapshot("fc.weight", vec![2, 2], module_names::LINEAR); let adapted = chain.adapt(&snapshot); assert_eq!(adapted.full_path(), "fc.b"); assert_eq!(called1.load(Ordering::Relaxed), 1); assert_eq!(called2.load(Ordering::Relaxed), 1); } #[test] fn test_chain_adapter_alternative_name_pipes_and_fallbacks() { let called1 = Arc::new(AtomicUsize::new(0)); let called2 = Arc::new(AtomicUsize::new(0)); let a = AltNameAdapter { from: "gamma", to: "weight", called: called1.clone(), }; let b = AltNameAdapter { from: "weight", to: "scale", called: called2.clone(), }; let chain = a.chain(b); let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM); assert_eq!(alt.as_deref(), Some("scale")); assert_eq!(called1.load(Ordering::Relaxed), 1); assert_eq!(called2.load(Ordering::Relaxed), 1); // If the second adapter doesn't have a mapping for the first alternative, // fall back to the first alternative name. let called1 = Arc::new(AtomicUsize::new(0)); let called2 = Arc::new(AtomicUsize::new(0)); let a = AltNameAdapter { from: "gamma", to: "weight", called: called1.clone(), }; let b = AltNameAdapter { from: "something-else", to: "unused", called: called2.clone(), }; let chain = a.chain(b); let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM); assert_eq!(alt.as_deref(), Some("weight")); assert_eq!(called1.load(Ordering::Relaxed), 1); assert_eq!(called2.load(Ordering::Relaxed), 1); // If the first adapter doesn't provide an alternative, try the second with the original name. let called1 = Arc::new(AtomicUsize::new(0)); let called2 = Arc::new(AtomicUsize::new(0)); let a = AltNameAdapter { from: "something-else", to: "unused", called: called1.clone(), }; let b = AltNameAdapter { from: "gamma", to: "weight", called: called2.clone(), }; let chain = a.chain(b); let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM); assert_eq!(alt.as_deref(), Some("weight")); assert_eq!(called1.load(Ordering::Relaxed), 1); assert_eq!(called2.load(Ordering::Relaxed), 1); // clone_box must preserve behavior. let boxed = chain.clone_box(); let alt = boxed.get_alternative_param_name("gamma", module_names::LAYER_NORM); assert_eq!(alt.as_deref(), Some("weight")); } }