Files
RustyUI/crates/stable-diffusion-burn/burn-crates/burn-store/src/lib.rs
Ben_Kosytorz 3a67c0979c feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
2026-03-05 19:39:14 +01:00

119 lines
4.3 KiB
Rust

#![cfg_attr(not(feature = "std"), no_std)]
//! # Burn Store
//!
//! Advanced model storage and serialization infrastructure for the Burn deep learning framework.
//!
//! This crate provides comprehensive functionality for storing and loading Burn modules
//! and their tensor data, with support for cross-framework interoperability, flexible filtering,
//! and efficient memory management through lazy materialization.
//!
//! ## Key Features
//!
//! - **Burnpack Format**: Native Burn format with CBOR metadata, ParamId persistence for stateful training, and no-std support
//! - **SafeTensors Format**: Industry-standard format for secure and efficient tensor serialization
//! - **PyTorch Compatibility**: Load PyTorch models directly into Burn with automatic weight transformation
//! - **Zero-Copy Loading**: Memory-mapped files and lazy tensor materialization for optimal performance
//! - **Flexible Filtering**: Load/save specific model subsets using regex, exact paths, or custom predicates
//! - **Tensor Remapping**: Rename tensors during load/save operations for framework compatibility
//! - **No-std Support**: Core functionality available in embedded and WASM environments
//!
//! ## Quick Start
//!
//! ### Basic Save and Load
//!
//! ```rust,ignore
//! use burn_store::{ModuleSnapshot, SafetensorsStore};
//!
//! // Save a model
//! let mut store = SafetensorsStore::from_file("model.safetensors");
//! model.save_into(&mut store)?;
//!
//! // Load a model
//! let mut store = SafetensorsStore::from_file("model.safetensors");
//! model.load_from(&mut store)?;
//! ```
//!
//! ### Loading PyTorch Models
//!
//! ```rust,ignore
//! use burn_store::PytorchStore;
//!
//! // Load PyTorch model (automatic weight transformation via PyTorchToBurnAdapter)
//! let mut store = PytorchStore::from_file("pytorch_model.pth")
//! .with_top_level_key("state_dict") // Access nested state dict if needed
//! .allow_partial(true); // Skip unknown tensors
//!
//! model.load_from(&mut store)?;
//! ```
//!
//! ### Filtering and Remapping
//!
//! ```rust,no_run
//! # use burn_store::SafetensorsStore;
//! // Save only specific layers with renaming
//! let mut store = SafetensorsStore::from_file("encoder.safetensors")
//! .with_regex(r"^encoder\..*") // Filter: only encoder layers
//! .with_key_remapping(r"^encoder\.", "transformer.") // Rename: encoder.X -> transformer.X
//! .metadata("subset", "encoder_only");
//!
//! // Then pass it to `model.save_into(&mut store)?` to write only the filtered, renamed tensors.
//! ```
//!
//! ## Core Components
//!
//! - [`ModuleSnapshot`]: Extension trait for Burn modules providing `collect()` and `apply()` methods
//! - [`BurnpackStore`]: Native Burn format with ParamId persistence for stateful training workflows (requires the `burnpack` feature)
//! - [`SafetensorsStore`]: Primary storage implementation supporting the SafeTensors format (requires the `safetensors` feature)
//! - [`PytorchStore`]: PyTorch model loader supporting `.pth` and `.pt` files (requires the `pytorch` feature)
//! - [`PathFilter`]: Flexible filtering system for selective tensor loading/saving
//! - [`KeyRemapper`]: Advanced tensor name remapping with regex patterns (requires the `std` feature)
//! - [`ModuleAdapter`]: Framework adapters for cross-framework compatibility
//!
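//! ## Feature Flags
//!
//! - `std`: Enables file I/O and other std-only features, including [`KeyRemapper`] (default)
//! - `safetensors`: Enables SafeTensors format support via [`SafetensorsStore`] (default)
//! - `pytorch`: Enables loading PyTorch `.pth`/`.pt` checkpoints via [`PytorchStore`]
//! - `burnpack`: Enables the native Burnpack format via [`BurnpackStore`] and [`BurnpackWriter`]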
extern crate alloc;

// ---------------------------------------------------------------------------
// Core modules: compiled in every configuration, including `no_std`.
//
// `crate::`-prefixed paths are used for every re-export below so each one
// unambiguously names a module of this crate rather than an external crate.
// ---------------------------------------------------------------------------
mod adapter;
mod applier;
mod apply_result;
mod collector;
mod filter;
mod tensor_snapshot;
mod traits;

pub use crate::{
    adapter::{
        BurnToPyTorchAdapter, ChainAdapter, IdentityAdapter, ModuleAdapter, PyTorchToBurnAdapter,
    },
    applier::Applier,
    apply_result::{ApplyError, ApplyResult},
    collector::Collector,
    filter::PathFilter,
    tensor_snapshot::{TensorSnapshot, TensorSnapshotError},
    traits::{ModuleSnapshot, ModuleStore},
};

// ---------------------------------------------------------------------------
// Feature-gated modules. Each `mod` and its re-exports share the same `cfg`,
// so the public names exist exactly when their defining module is compiled.
// ---------------------------------------------------------------------------

// Tensor-name remapping (`std` feature).
#[cfg(feature = "std")]
mod keyremapper;
#[cfg(feature = "std")]
pub use crate::keyremapper::{KeyRemapper, map_indices_contiguous};

// PyTorch loader (`pytorch` feature). Unlike the other format modules, this
// module is itself `pub`, so its items are also reachable as `pytorch::*`.
#[cfg(feature = "pytorch")]
pub mod pytorch;
#[cfg(feature = "pytorch")]
pub use crate::pytorch::{PytorchStore, PytorchStoreError};

// SafeTensors format (`safetensors` feature).
#[cfg(feature = "safetensors")]
mod safetensors;
#[cfg(feature = "safetensors")]
pub use crate::safetensors::{SafetensorsStore, SafetensorsStoreError};

// Native Burnpack format (`burnpack` feature).
#[cfg(feature = "burnpack")]
mod burnpack;
#[cfg(feature = "burnpack")]
pub use crate::burnpack::{base::BurnpackError, store::BurnpackStore, writer::BurnpackWriter};