feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,106 @@
[package]
authors = ["Dilshod Tadjibaev (@antimora)"]
categories = ["science", "no-std", "embedded", "wasm"]
description = "Storage and serialization infrastructure for Burn"
documentation = "https://docs.rs/burn-store"
edition.workspace = true
keywords = [
"deep-learning",
"machine-learning",
"tensor",
"storage",
"serialization",
]
license.workspace = true
name = "burn-store"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-store"
version.workspace = true
[lints]
workspace = true
[features]
default = ["std", "pytorch", "safetensors", "burnpack", "memmap"]
memmap = ["std", "dep:memmap2"]
std = [
"dep:memmap2",
"safetensors/std",
"burn-core/std",
"burn-tensor/std",
"dep:regex",
"byteorder/std",
]
tracing = [
"burn-core/tracing",
"burn-cuda?/tracing",
"burn-nn/tracing",
"burn-tch?/tracing",
"burn-tensor/tracing",
"burn-wgpu?/tracing",
]
burnpack = ["serde", "ciborium"]
cuda = ["burn-cuda"]
metal = ["wgpu", "burn-wgpu/metal"]
tch = ["burn-tch"]
wgpu = ["burn-wgpu"]
safetensors = ["dep:safetensors"]
pytorch = ["burn-core/record-item-custom-serde", "zip", "serde", "tar"]
[dependencies]
burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", default-features = false }
# External dependencies
byteorder = { workspace = true, default-features = false }
bytes = { workspace = true }
ciborium = { workspace = true, optional = true }
half = { workspace = true }
hashbrown = { workspace = true, features = ["serde"] }
memmap2 = { workspace = true, optional = true }
regex = { workspace = true, optional = true }
serde = { workspace = true, optional = true }
textdistance = { workspace = true }
zip = { workspace = true, optional = true }
tar = { workspace = true, optional = true }
# Workaround to force broken minor version to update
lzma-rust2 = { workspace = true, optional = true }
safetensors = { workspace = true, optional = true }
# Optional backend dependencies for benchmarks
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true }
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true }
[dev-dependencies]
# burn-import = { path = "../burn-import", version = "=0.21.0-pre.2" } # disabled (circular dep in publish, only for bench)
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
burn-nn = { path = "../burn-nn", version = "=0.21.0-pre.2", default-features = false }
divan = "0.1"
tempfile = { workspace = true }
[[bench]]
harness = false
name = "resnet18_loading"
[[bench]]
harness = false
name = "unified_loading"
[[bench]]
harness = false
name = "unified_saving"
[[bench]]
harness = false
name = "zero_copy_loading"
# Enable extra-platforms for bytes on targets without native atomics (e.g., thumbv6m-none-eabi)
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
bytes = { workspace = true, features = ["extra-platforms"] }

View File

@@ -0,0 +1,325 @@
# Migration Guide: burn-import to burn-store
This guide helps you migrate from the deprecated `burn-import` recorders (`PyTorchFileRecorder`,
`SafetensorsFileRecorder`) to the new `burn-store` API (`PytorchStore`, `SafetensorsStore`).
## Overview
The new `burn-store` API provides:
- **Simpler API**: Load directly into models instead of records
- **Fluent builder pattern**: Chain configuration methods
- **Better error handling**: Detailed load results with applied/missing/errors info
- **Bidirectional support**: Both load and save operations
- **More features**: Filtering, partial loading, metadata, zero-copy loading
## Quick Migration
### PyTorch Files (.pt/.pth)
**Before (burn-import):**
```rust
use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
// Load into a record, then create model from record
let record: ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load("model.pt".into(), &device)
.expect("Failed to load");
let model = Model::init(&device).load_record(record);
```
**After (burn-store):**
```rust
use burn_store::{ModuleSnapshot, PytorchStore};
// Initialize model, then load weights directly
let mut model = Model::init(&device);
let mut store = PytorchStore::from_file("model.pt");
model.load_from(&mut store).expect("Failed to load");
```
### SafeTensors Files (.safetensors)
**Before (burn-import):**
```rust
use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::safetensors::{AdapterType, LoadArgs, SafetensorsFileRecorder};
let record: ModelRecord<B> = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
.load("model.safetensors".into(), &device)
.expect("Failed to load");
let model = Model::init(&device).load_record(record);
```
**After (burn-store):**
```rust
use burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};
let mut model = Model::init(&device);
// For SafeTensors exported from PyTorch, use the adapter
let mut store = SafetensorsStore::from_file("model.safetensors")
.with_from_adapter(PyTorchToBurnAdapter);
model.load_from(&mut store).expect("Failed to load");
// For native Burn SafeTensors, no adapter needed
let mut store = SafetensorsStore::from_file("model.safetensors");
model.load_from(&mut store).expect("Failed to load");
```
## API Mapping
### PyTorchFileRecorder Options
| burn-import | burn-store |
| ---------------------------------------------- | ------------------------------------------- |
| `LoadArgs::new(path)` | `PytorchStore::from_file(path)` |
| `.with_key_remap(pattern, replacement)` | `.with_key_remapping(pattern, replacement)` |
| `.with_top_level_key(key)` | `.with_top_level_key(key)` |
| `.with_debug_print()` | _(use tracing/logging instead)_ |
| `PyTorchFileRecorder::<FullPrecisionSettings>` | _(precision handled automatically)_ |
### SafetensorsFileRecorder Options
| burn-import | burn-store |
| -------------------------------------------------- | ------------------------------------------- |
| `LoadArgs::new(path)` | `SafetensorsStore::from_file(path)` |
| `.with_key_remap(pattern, replacement)` | `.with_key_remapping(pattern, replacement)` |
| `.with_adapter_type(AdapterType::PyTorch)` | `.with_from_adapter(PyTorchToBurnAdapter)` |
| `.with_adapter_type(AdapterType::NoAdapter)` | _(default, no adapter)_ |
| `.with_debug_print()` | _(use tracing/logging instead)_ |
| `SafetensorsFileRecorder::<FullPrecisionSettings>` | _(precision handled automatically)_ |
## Detailed Examples
### Key Remapping
**Before:**
```rust
let args = LoadArgs::new("model.pt".into())
.with_key_remap("conv\\.(.*)", "$1")
.with_key_remap("^old_prefix\\.", "new_prefix.");
let record: ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load(args, &device)?;
```
**After:**
```rust
let mut store = PytorchStore::from_file("model.pt")
.with_key_remapping("conv\\.(.*)", "$1")
.with_key_remapping("^old_prefix\\.", "new_prefix.");
model.load_from(&mut store)?;
```
### Top-Level Key Access
**Before:**
```rust
let args = LoadArgs::new("checkpoint.pt".into())
.with_top_level_key("state_dict");
let record: ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load(args, &device)?;
```
**After:**
```rust
let mut store = PytorchStore::from_file("checkpoint.pt")
.with_top_level_key("state_dict");
model.load_from(&mut store)?;
```
### PyTorch Adapter for SafeTensors
**Before:**
```rust
use burn_import::safetensors::{AdapterType, LoadArgs};
let args = LoadArgs::new("pytorch_model.safetensors".into())
.with_adapter_type(AdapterType::PyTorch);
let record: ModelRecord<B> = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
.load(args, &device)?;
```
**After:**
```rust
use burn_store::{PyTorchToBurnAdapter, SafetensorsStore};
let mut store = SafetensorsStore::from_file("pytorch_model.safetensors")
.with_from_adapter(PyTorchToBurnAdapter);
model.load_from(&mut store)?;
```
## New Features in burn-store
### Partial Loading
Handle missing tensors gracefully:
```rust
let mut store = PytorchStore::from_file("model.pt")
.allow_partial(true);
let result = model.load_from(&mut store)?;
println!("Loaded: {:?}", result.applied);
println!("Missing: {:?}", result.missing);
```
### Filtering
Load only specific tensors:
```rust
let mut store = SafetensorsStore::from_file("model.safetensors")
.with_regex(r"^encoder\..*") // Only encoder layers
.allow_partial(true);
model.load_from(&mut store)?;
```
### Saving Models
Save models (not supported by old recorders):
```rust
// Save to SafeTensors
let mut store = SafetensorsStore::from_file("output.safetensors")
.metadata("version", "1.0");
model.save_into(&mut store)?;
// Save to Burnpack (native format)
let mut store = BurnpackStore::from_file("output.bpk");
model.save_into(&mut store)?;
```
### Load Results
Get detailed information about loading:
```rust
let result = model.load_from(&mut store)?;
// Print the full result for debugging - shows applied, skipped, missing, and errors
println!("{}", result);
// Or access individual fields
println!("Applied: {} tensors", result.applied.len());
println!("Skipped: {} tensors", result.skipped.len());
println!("Missing: {:?}", result.missing);
println!("Errors: {:?}", result.errors);
// Check if fully successful
if result.is_success() {
println!("All tensors loaded successfully");
}
```
The `LoadResult` implements `Display`, so printing it shows a formatted summary with suggestions for
common issues (e.g., using `allow_partial(true)` for missing tensors).
## Updating Cargo.toml
**Before:**
```toml
[dependencies]
burn-import = { version = "0.x", features = ["pytorch", "safetensors"] }
```
**After:**
```toml
[dependencies]
burn-store = { version = "0.x", features = ["pytorch", "safetensors"] }
```
## Common Migration Issues
### 1. Model vs Record
The new API loads directly into models, not records. Update your model initialization:
```rust
// Before: Create record, then model from record
let record = recorder.load(...)?;
let model = Model::init(&device).load_record(record);
// After: Create model, then load into it
let mut model = Model::init(&device);
model.load_from(&mut store)?;
```
### 2. Inference Functions
If you had functions that took `ModelRecord`, update them to take `Model`:
```rust
// Before
fn infer(record: ModelRecord<B>) {
let model = Model::init(&device).load_record(record);
// ...
}
// After
fn infer(model: Model<B>) {
// Model already has weights loaded
// ...
}
```
### 3. Precision Settings
The old API required explicit precision settings. The new API handles this automatically:
```rust
// Before: Had to specify FullPrecisionSettings or HalfPrecisionSettings
PyTorchFileRecorder::<FullPrecisionSettings>::default()
// After: Precision handled automatically based on tensor dtype
PytorchStore::from_file("model.pt")
```
### 4. Error Handling
The new API provides richer error information:
```rust
// Before: Simple Result
let record = recorder.load(args, &device)?;
// After: LoadResult with detailed info
let result = model.load_from(&mut store)?;
// Print the result to see a helpful summary with suggestions
println!("{}", result);
// Or handle specific issues programmatically
if !result.errors.is_empty() {
for (path, error) in &result.errors {
eprintln!("Error loading {}: {}", path, error);
}
}
```
## See Also
- [burn-store README](README.md) - Full documentation
- [import-model-weights example](../../examples/import-model-weights/) - Working example

View File

@@ -0,0 +1,77 @@
# Burn Store
> Advanced model storage and serialization for the Burn deep learning framework
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-store.svg)](https://crates.io/crates/burn-store)
[![Documentation](https://docs.rs/burn-store/badge.svg)](https://docs.rs/burn-store)
A comprehensive storage library for Burn that enables efficient model serialization, cross-framework
interoperability, and advanced tensor management.
> **Migrating from burn-import?** See the [Migration Guide](MIGRATION.md) for help moving from
> `PyTorchFileRecorder`/`SafetensorsFileRecorder` to the new Store API.
## Features
- **Burnpack Format** - Native Burn format with CBOR metadata, memory-mapped loading, ParamId
persistence for stateful training, and no-std support
- **SafeTensors Format** - Industry-standard format for secure and efficient tensor serialization
- **PyTorch Support** - Direct loading of PyTorch .pth/.pt files with automatic weight
transformation
- **Zero-Copy Loading** - Memory-mapped files and lazy tensor materialization for optimal
performance
- **Flexible Filtering** - Load/save specific model subsets with regex, exact paths, or custom
predicates
- **Tensor Remapping** - Rename tensors during load/save for framework compatibility
- **No-std Support** - Burnpack and SafeTensors formats available in embedded and WASM environments
## Quick Start
```rust
use burn_store::{ModuleSnapshot, PytorchStore, SafetensorsStore, BurnpackStore};
// Load from PyTorch
let mut store = PytorchStore::from_file("model.pt");
model.load_from(&mut store)?;
// Load from SafeTensors (with PyTorch adapter)
let mut store = SafetensorsStore::from_file("model.safetensors")
.with_from_adapter(PyTorchToBurnAdapter);
model.load_from(&mut store)?;
// Save to Burnpack
let mut store = BurnpackStore::from_file("model.bpk");
model.save_into(&mut store)?;
```
## Documentation
For comprehensive documentation including:
- Exporting weights from PyTorch
- Loading weights into Burn models
- Saving models to various formats
- Advanced features (filtering, remapping, partial loading, zero-copy)
- API reference and troubleshooting
See the **[Burn Book - Model Weights](https://burn.dev/book/import/model-weights.html)** chapter.
## Running Benchmarks
```bash
# Generate model files (one-time setup)
uv run benches/generate_unified_models.py
# Run loading benchmarks
cargo bench --bench unified_loading
# Run saving benchmarks
cargo bench --bench unified_saving
# With specific backend
cargo bench --bench unified_loading --features metal
```
## License
This project is dual-licensed under MIT and Apache-2.0.

View File

@@ -0,0 +1,82 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.8"
# dependencies = [
# "torch",
# "torchvision",
# ]
# ///
"""
Download ResNet18 PyTorch model for benchmarking.
This script downloads a pre-trained ResNet18 model from PyTorch Hub
and saves it in a format suitable for benchmarking.
"""
import os
import sys
import tempfile
from pathlib import Path
import torch
import torchvision.models as models
def download_resnet18():
"""Download ResNet18 model and save to temp directory."""
# Create a temporary directory for the model
temp_dir = Path(tempfile.gettempdir()) / "burn_resnet18_benchmark"
temp_dir.mkdir(parents=True, exist_ok=True)
output_path = temp_dir / "resnet18.pth"
# Check if already downloaded
if output_path.exists():
file_size_mb = output_path.stat().st_size / (1024 * 1024)
print(f"✅ ResNet18 already exists at: {output_path}")
print(f" Size: {file_size_mb:.1f} MB")
return str(output_path)
print("📥 Downloading ResNet18 model...")
try:
# Download pre-trained ResNet18 model
model = models.resnet18(pretrained=True)
# Save the model state dict (this is what burn-store reads)
# Using the legacy format for compatibility
torch.save(model.state_dict(), output_path, _use_new_zipfile_serialization=False)
file_size_mb = output_path.stat().st_size / (1024 * 1024)
print(f"✅ Successfully downloaded ResNet18 to: {output_path}")
print(f" Size: {file_size_mb:.1f} MB")
print(f" Format: PyTorch legacy format")
# Verify it's readable
state_dict = torch.load(output_path, map_location='cpu')
print(f" Tensors: {len(state_dict)} tensors")
# Print a few tensor names and shapes for verification
print("\n Sample tensors:")
for i, (name, tensor) in enumerate(state_dict.items()):
if i < 3:
print(f" - {name}: {list(tensor.shape)}")
return str(output_path)
except Exception as e:
print(f"❌ Failed to download ResNet18: {e}")
sys.exit(1)
def main():
"""Main entry point."""
path = download_resnet18()
# Write the path to a file that the benchmark can read
bench_config = Path(tempfile.gettempdir()) / "burn_resnet18_benchmark" / "path.txt"
bench_config.write_text(path)
print(f"\n💡 Model ready for benchmarking")
print(f" Run: cargo bench --bench resnet18_loading")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,175 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.8"
# dependencies = [
# "torch",
# "safetensors",
# "packaging",
# "numpy",
# ]
# ///
"""
Generate a large model (~312MB) in both PyTorch and SafeTensors formats for unified benchmarking.
Usage:
uv run benches/generate_unified_models.py
The script will create model files in /tmp/simple_bench_models/ directory.
"""
import torch
import torch.nn as nn
import os
from pathlib import Path
import tempfile
from safetensors.torch import save_file
def get_temp_dir():
"""Get the appropriate temp directory."""
temp_dir = Path(tempfile.gettempdir()) / "simple_bench_models"
temp_dir.mkdir(parents=True, exist_ok=True)
return temp_dir
class LargeModel(nn.Module):
"""Large model with 20 layers to match Rust benchmark."""
def __init__(self):
super().__init__()
self.layers = nn.ModuleList()
# Create a model with 20 layers matching the Rust LargeModel
for i in range(20):
in_size = 1024 if i == 0 else 2048
out_size = 2048
self.layers.append(nn.Linear(in_size, out_size))
print(f"Created model with {len(self.layers)} layers")
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
def calculate_model_size(model):
"""Calculate the size of the model in MB."""
total_params = sum(p.numel() for p in model.parameters())
size_mb = (total_params * 4) / (1024 * 1024) # 4 bytes per float32
return total_params, size_mb
def initialize_weights(model):
"""Initialize model weights with random values."""
for param in model.parameters():
if param.dim() > 1:
nn.init.xavier_uniform_(param)
else:
nn.init.zeros_(param)
def save_pytorch_format(model, output_dir):
"""Save model in PyTorch format."""
pt_path = output_dir / "large_model.pt"
# Save as checkpoint with model_state_dict (common format)
checkpoint = {
'model_state_dict': model.state_dict(),
'metadata': {
'model_type': 'large_benchmark_model',
'num_layers': len(model.layers),
}
}
torch.save(checkpoint, pt_path)
return pt_path
def save_safetensors_format(model, output_dir):
"""Save model in SafeTensors format."""
st_path = output_dir / "large_model.safetensors"
# Convert state dict to safetensors format
state_dict = model.state_dict()
# Ensure all tensors are contiguous and on CPU
state_dict = {k: v.contiguous().cpu() for k, v in state_dict.items()}
# Save with metadata
metadata = {
'model_type': 'large_benchmark_model',
'num_layers': str(len(model.layers)),
}
save_file(state_dict, st_path, metadata=metadata)
return st_path
def verify_files(pt_path, st_path):
"""Verify the saved files can be loaded."""
# Verify PyTorch file
checkpoint = torch.load(pt_path, map_location='cpu')
pt_keys = set(checkpoint['model_state_dict'].keys())
print(f" PyTorch file: {len(pt_keys)} tensors")
# Verify SafeTensors file
from safetensors import safe_open
with safe_open(st_path, framework="pt", device="cpu") as f:
st_keys = set(f.keys())
print(f" SafeTensors file: {len(st_keys)} tensors")
# Check keys match
if pt_keys != st_keys:
print(" ⚠️ Warning: Keys don't match between formats!")
else:
print(" ✓ Keys match between formats")
def main():
print("🔧 Generating unified benchmark model files...")
print("")
output_dir = get_temp_dir()
print(f"📁 Output directory: {output_dir}")
print("")
# Set random seed for reproducibility
torch.manual_seed(42)
# Create the large model
print("📝 Creating large model...")
model = LargeModel()
# Calculate and display model size
total_params, size_mb = calculate_model_size(model)
print(f" Total parameters: {total_params:,}")
print(f" Model size: {size_mb:.2f} MB")
print("")
# Initialize weights
print("🎲 Initializing weights...")
initialize_weights(model)
# Save in PyTorch format
print("💾 Saving PyTorch format...")
pt_path = save_pytorch_format(model, output_dir)
pt_size_mb = pt_path.stat().st_size / (1024 * 1024)
print(f" Saved: {pt_path}")
print(f" File size: {pt_size_mb:.2f} MB")
print("")
# Save in SafeTensors format
print("💾 Saving SafeTensors format...")
st_path = save_safetensors_format(model, output_dir)
st_size_mb = st_path.stat().st_size / (1024 * 1024)
print(f" Saved: {st_path}")
print(f" File size: {st_size_mb:.2f} MB")
print("")
# Verify files
print("🔍 Verifying saved files...")
verify_files(pt_path, st_path)
print("")
print(f"✅ Model files generated successfully!")
print("")
print("📊 Summary:")
print(f" PyTorch file: {pt_path.name} ({pt_size_mb:.2f} MB)")
print(f" SafeTensors file: {st_path.name} ({st_size_mb:.2f} MB)")
print("")
print("💡 To run the unified benchmark:")
print(" cargo bench --bench unified_loading")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,213 @@
//! Benchmark for ResNet18 loading to verify lazy loading memory usage.
//!
//! resnet18.pth is pytorch's legacy file format.
//!
//! This benchmark loads a ResNet18 model and materializes all tensors
//! to ensure memory usage stays reasonable with lazy loading.
//!
//! Run the benchmark:
//! ```bash
//! cargo bench --bench resnet18_loading
//! ```
use burn_store::pytorch::PytorchReader;
use divan::{AllocProfiler, Bencher};
use std::path::PathBuf;
#[global_allocator]
static ALLOC: AllocProfiler = AllocProfiler::system();
#[allow(clippy::manual_range_contains)]
fn main() {
// Check if ResNet18 file exists
let path = resnet18_path();
if !path.exists() {
eprintln!("❌ ResNet18 model not found!");
eprintln!();
eprintln!("Please download it first by running:");
eprintln!(" python benches/download_resnet18.py");
eprintln!();
eprintln!("Or if you don't have Python/PyTorch installed:");
eprintln!(" uv run benches/download_resnet18.py");
eprintln!();
eprintln!("Expected location: {}", path.display());
std::process::exit(1);
}
// Verify file size is reasonable
let metadata = std::fs::metadata(&path).expect("Failed to read file metadata");
let size_mb = metadata.len() as f64 / 1_048_576.0;
if size_mb < 40.0 || size_mb > 50.0 {
eprintln!(
"⚠️ Warning: ResNet18 file size ({:.1} MB) seems unusual",
size_mb
);
eprintln!("Expected size is around 45 MB");
}
println!("✅ Found ResNet18 model at: {}", path.display());
println!("📦 File size: {:.1} MB", size_mb);
println!("📊 Running ResNet18 loading benchmarks...\n");
// Run divan benchmarks
divan::main();
}
/// Get the path to ResNet18 model file
fn resnet18_path() -> PathBuf {
// First try to read from the path file created by download script
let temp_dir = std::env::temp_dir();
let config_file = temp_dir.join("burn_resnet18_benchmark").join("path.txt");
if config_file.exists()
&& let Ok(path_str) = std::fs::read_to_string(&config_file)
{
let path = PathBuf::from(path_str.trim());
if path.exists() {
return path;
}
}
// Fallback to default location
temp_dir
.join("burn_resnet18_benchmark")
.join("resnet18.pth")
}
#[divan::bench(sample_count = 10)]
fn load_resnet18_metadata(bencher: Bencher) {
let path = resnet18_path();
bencher.bench_local(|| {
let reader = PytorchReader::new(&path).expect("Failed to load ResNet18");
let metadata = reader.metadata();
// Just access metadata without materializing tensors
assert_eq!(metadata.tensor_count, 122);
});
}
#[divan::bench(sample_count = 5)]
fn load_resnet18_materialize_all(bencher: Bencher) {
let path = resnet18_path();
bencher.bench_local(|| {
let reader = PytorchReader::new(&path).expect("Failed to load ResNet18");
let keys = reader.keys();
let mut total_bytes = 0usize;
// Materialize all tensors one by one
for key in &keys {
let tensor = reader.get(key).expect("Failed to get tensor");
// Materialize the tensor data
let _data = tensor.to_data().expect("Failed to materialize tensor data");
total_bytes += tensor.data_len();
}
// Verify we processed all the data
assert!(total_bytes > 40_000_000); // Should be ~45MB
});
}
#[divan::bench(sample_count = 5)]
fn load_resnet18_materialize_sequential(bencher: Bencher) {
let path = resnet18_path();
bencher.bench_local(|| {
let reader = PytorchReader::new(&path).expect("Failed to load ResNet18");
let keys = reader.keys();
// Materialize tensors one at a time, letting previous ones be dropped
// This simulates processing tensors sequentially without keeping all in memory
for key in &keys {
let tensor = reader.get(key).expect("Failed to get tensor");
let data = tensor.to_data().expect("Failed to materialize tensor data");
// Do minimal work with the data to prevent optimization
let sum = match data.dtype {
burn_tensor::DType::F32 => data
.as_slice::<f32>()
.map(|s| s.iter().sum::<f32>())
.unwrap_or(0.0) as f64,
burn_tensor::DType::F64 => data
.as_slice::<f64>()
.map(|s| s.iter().sum::<f64>())
.unwrap_or(0.0),
_ => 0.0,
};
// Use the sum to prevent dead code elimination
std::hint::black_box(sum);
}
});
}
#[divan::bench(sample_count = 10)]
fn load_resnet18_largest_tensor(bencher: Bencher) {
let path = resnet18_path();
bencher.bench_local(|| {
let reader = PytorchReader::new(&path).expect("Failed to load ResNet18");
// Find and materialize only the largest tensor
// This tests peak memory for a single tensor operation
let keys = reader.keys();
let mut largest_key = String::new();
let mut largest_size = 0usize;
for key in &keys {
let tensor = reader.get(key).expect("Failed to get tensor");
let size = tensor.data_len();
if size > largest_size {
largest_size = size;
largest_key = key.clone();
}
}
// Materialize the largest tensor
let tensor = reader
.get(&largest_key)
.expect("Failed to get largest tensor");
let _data = tensor.to_data().expect("Failed to materialize tensor data");
assert!(largest_size > 9_000_000); // Should be ~9MB for layer4.0.conv2.weight
});
}
#[divan::bench(sample_count = 10)]
fn load_resnet18_memory_profile(bencher: Bencher) {
let path = resnet18_path();
bencher
.with_inputs(|| path.clone())
.bench_local_values(|path| {
let reader = PytorchReader::new(&path).expect("Failed to load ResNet18");
let keys = reader.keys();
let mut peak_single_tensor = 0usize;
let mut total_data = 0usize;
// Process each tensor and track memory
for key in &keys {
let tensor = reader.get(key).expect("Failed to get tensor");
let tensor_size = tensor.data_len();
// Track largest single tensor
if tensor_size > peak_single_tensor {
peak_single_tensor = tensor_size;
}
// Materialize the tensor
let data = tensor.to_data().expect("Failed to materialize tensor data");
total_data += tensor_size;
// Drop data immediately to test lazy loading memory efficiency
drop(data);
}
// Return stats for verification
(peak_single_tensor, total_data)
});
}

View File

@@ -0,0 +1,332 @@
#![recursion_limit = "256"]
//! Unified benchmark comparing all loading methods:
//! - BurnpackStore (new native format)
//! - NamedMpkFileRecorder (old native format)
//! - SafetensorsStore (new)
//! - SafetensorsFileRecorder (old)
//! - PytorchStore (new)
//! - PyTorchFileRecorder (old)
//!
//! Before running this benchmark, generate the model files:
//! ```bash
//! cd crates/burn-store
//! uv run benches/generate_unified_models.py
//! ```
//!
//! Then run the benchmark:
//! ```bash
//! cargo bench --bench unified_loading
//! ```
use burn_core as burn;
use burn_core::module::Module;
use burn_core::prelude::*;
use burn_core::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
// use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
// use burn_import::safetensors::SafetensorsFileRecorder;
use burn_nn as nn;
use burn_store::{
BurnpackStore, ModuleSnapshot, PyTorchToBurnAdapter, PytorchStore, SafetensorsStore,
};
use divan::{AllocProfiler, Bencher};
use std::fs;
use std::path::{Path, PathBuf};
#[global_allocator]
static ALLOC: AllocProfiler = AllocProfiler::system();
// Backend type aliases
type NdArrayBackend = burn_ndarray::NdArray<f32>;
#[cfg(feature = "wgpu")]
type WgpuBackend = burn_wgpu::Wgpu;
#[cfg(feature = "cuda")]
type CudaBackend = burn_cuda::Cuda<f32, i32>;
#[cfg(feature = "tch")]
type TchBackend = burn_tch::LibTorch<f32>;
#[cfg(feature = "metal")]
type MetalBackend = burn_wgpu::Metal;
// Use the same LargeModel as other benchmarks for fair comparison
#[derive(Module, Debug)]
struct LargeModel<B: Backend> {
layers: Vec<nn::Linear<B>>,
}
impl<B: Backend> LargeModel<B> {
fn new(device: &B::Device) -> Self {
let mut layers = Vec::new();
// Create a model with 20 layers - same as safetensor_loading benchmark
for i in 0..20 {
let in_size = if i == 0 { 1024 } else { 2048 };
layers.push(nn::LinearConfig::new(in_size, 2048).init(device));
}
Self { layers }
}
}
/// Get the path to the model files
fn get_model_dir() -> PathBuf {
std::env::temp_dir().join("simple_bench_models")
}
/// Generate Burnpack and NamedMpk files from existing SafeTensors file
fn generate_burn_formats(st_path: &Path, bp_path: &Path, mpk_path: &Path) {
type TestBackend = NdArrayBackend;
let device = Default::default();
// Load the model from SafeTensors
let mut model = LargeModel::<TestBackend>::new(&device);
let mut store = SafetensorsStore::from_file(st_path).with_from_adapter(PyTorchToBurnAdapter);
model
.load_from(&mut store)
.expect("Failed to load from SafeTensors");
// Save as Burnpack
if !bp_path.exists() {
println!(" Creating Burnpack file...");
let mut burnpack_store = BurnpackStore::from_file(bp_path);
model
.save_into(&mut burnpack_store)
.expect("Failed to save as Burnpack");
}
// Save as NamedMpk
if !mpk_path.exists() {
println!(" Creating NamedMpk file...");
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
model
.save_file(mpk_path, &recorder)
.expect("Failed to save as NamedMpk");
}
}
/// Get paths to the model files
fn get_model_paths() -> (PathBuf, PathBuf, PathBuf, PathBuf) {
let dir = get_model_dir();
(
dir.join("large_model.bpk"),
dir.join("large_model.mpk"),
dir.join("large_model.safetensors"),
dir.join("large_model.pt"),
)
}
/// Check if model files exist
fn check_model_files() -> Result<(), String> {
let (_, _, st_path, pt_path) = get_model_paths();
// For now, only check safetensors and pytorch files (will generate burnpack/mpk later)
if !st_path.exists() || !pt_path.exists() {
return Err(format!(
"\n❌ Model files not found!\n\
\n\
Please generate the model files first by running:\n\
\n\
cd crates/burn-store\n\
uv run benches/generate_unified_models.py\n\
\n\
Expected files:\n\
- {}\n\
- {}\n",
st_path.display(),
pt_path.display()
));
}
Ok(())
}
fn main() {
// Check if model files exist before running benchmarks
match check_model_files() {
Ok(()) => {
let (bp_path, mpk_path, st_path, pt_path) = get_model_paths();
// First, generate Burnpack and MPK files if they don't exist
if !bp_path.exists() || !mpk_path.exists() {
println!("⏳ Generating Burnpack and NamedMpk files from SafeTensors...");
generate_burn_formats(&st_path, &bp_path, &mpk_path);
}
let bp_size = fs::metadata(&bp_path)
.ok()
.map(|m| m.len() as f64 / 1_048_576.0);
let mpk_size = fs::metadata(&mpk_path)
.ok()
.map(|m| m.len() as f64 / 1_048_576.0);
let st_size = fs::metadata(&st_path).unwrap().len() as f64 / 1_048_576.0;
let pt_size = fs::metadata(&pt_path).unwrap().len() as f64 / 1_048_576.0;
println!("✅ Found model files:");
if let Some(size) = bp_size {
println!(" Burnpack: {} ({:.1} MB)", bp_path.display(), size);
}
if let Some(size) = mpk_size {
println!(" NamedMpk: {} ({:.1} MB)", mpk_path.display(), size);
}
println!(" SafeTensors: {} ({:.1} MB)", st_path.display(), st_size);
println!(" PyTorch: {} ({:.1} MB)", pt_path.display(), pt_size);
println!();
println!("🚀 Running unified loading benchmarks...");
println!();
println!("Comparing 6 loading methods:");
println!(" 1. BurnpackStore (new native format - lazy loading)");
println!(" 2. NamedMpkFileRecorder (old native format - loads all to memory)");
println!(" 3. SafetensorsStore (new)");
println!(" 4. SafetensorsFileRecorder (old)");
println!(" 5. PytorchStore (new)");
println!(" 6. PyTorchFileRecorder (old)");
println!();
println!("Available backends:");
println!(" - NdArray (CPU)");
#[cfg(feature = "wgpu")]
println!(" - WGPU (GPU)");
#[cfg(feature = "cuda")]
println!(" - CUDA (NVIDIA GPU)");
#[cfg(feature = "tch")]
println!(" - LibTorch");
#[cfg(feature = "metal")]
println!(" - Metal (Apple GPU)");
println!();
divan::main();
}
Err(msg) => {
eprintln!("{}", msg);
std::process::exit(1);
}
}
}
// Macro to generate benchmarks for each backend
macro_rules! bench_backend {
($backend:ty, $mod_name:ident, $backend_name:literal) => {
#[divan::bench_group(name = $backend_name, sample_count = 10)]
mod $mod_name {
use super::*;
type TestBackend = $backend;
type TestDevice = <TestBackend as Backend>::Device;
#[divan::bench]
fn burnpack_store(bencher: Bencher) {
let (bp_path, _, _, _) = get_model_paths();
let file_size = fs::metadata(&bp_path).unwrap().len();
bencher
.counter(divan::counter::BytesCount::new(file_size))
.bench(|| {
let device: TestDevice = Default::default();
let mut model = LargeModel::<TestBackend>::new(&device);
let mut store = BurnpackStore::from_file(bp_path.clone());
model.load_from(&mut store).expect("Failed to load");
});
}
#[divan::bench]
fn namedmpk_recorder(bencher: Bencher) {
let (_, mpk_path, _, _) = get_model_paths();
let file_size = fs::metadata(&mpk_path).unwrap().len();
bencher
.counter(divan::counter::BytesCount::new(file_size))
.bench(|| {
let device: TestDevice = Default::default();
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
let record = recorder
.load(mpk_path.clone().into(), &device)
.expect("Failed to load");
let _model = LargeModel::<TestBackend>::new(&device).load_record(record);
});
}
#[divan::bench]
fn safetensors_store(bencher: Bencher) {
let (_, _, st_path, _) = get_model_paths();
let file_size = fs::metadata(&st_path).unwrap().len();
bencher
.counter(divan::counter::BytesCount::new(file_size))
.bench(|| {
let device: TestDevice = Default::default();
let mut model = LargeModel::<TestBackend>::new(&device);
let mut store = SafetensorsStore::from_file(st_path.clone())
.with_from_adapter(PyTorchToBurnAdapter);
model.load_from(&mut store).expect("Failed to load");
});
}
// #[divan::bench]
// fn safetensors_recorder(bencher: Bencher) {
// let (_, _, st_path, _) = get_model_paths();
// let file_size = fs::metadata(&st_path).unwrap().len();
// bencher
// .counter(divan::counter::BytesCount::new(file_size))
// .bench(|| {
// let device: TestDevice = Default::default();
// let recorder = SafetensorsFileRecorder::<FullPrecisionSettings>::default();
// let record = recorder
// .load(st_path.clone().into(), &device)
// .expect("Failed to load");
// let _model = LargeModel::<TestBackend>::new(&device).load_record(record);
// });
// }
#[divan::bench]
fn pytorch_store(bencher: Bencher) {
let (_, _, _, pt_path) = get_model_paths();
let file_size = fs::metadata(&pt_path).unwrap().len();
bencher
.counter(divan::counter::BytesCount::new(file_size))
.bench(|| {
let device: TestDevice = Default::default();
let mut model = LargeModel::<TestBackend>::new(&device);
let mut store = PytorchStore::from_file(pt_path.clone())
.with_top_level_key("model_state_dict")
.allow_partial(true);
model.load_from(&mut store).expect("Failed to load");
});
}
// #[divan::bench]
// fn pytorch_recorder(bencher: Bencher) {
// let (_, _, _, pt_path) = get_model_paths();
// let file_size = fs::metadata(&pt_path).unwrap().len();
// bencher
// .counter(divan::counter::BytesCount::new(file_size))
// .bench(|| {
// let device: TestDevice = Default::default();
// let recorder = PyTorchFileRecorder::<FullPrecisionSettings>::default();
// let load_args =
// LoadArgs::new(pt_path.clone()).with_top_level_key("model_state_dict");
// let record = recorder.load(load_args, &device).expect("Failed to load");
// let _model = LargeModel::<TestBackend>::new(&device).load_record(record);
// });
// }
}
};
}
// Generate benchmarks for each backend
bench_backend!(NdArrayBackend, ndarray_backend, "NdArray Backend (CPU)");
#[cfg(feature = "wgpu")]
bench_backend!(WgpuBackend, wgpu_backend, "WGPU Backend (GPU)");
#[cfg(feature = "cuda")]
bench_backend!(CudaBackend, cuda_backend, "CUDA Backend (NVIDIA GPU)");
#[cfg(feature = "tch")]
bench_backend!(TchBackend, tch_backend, "LibTorch Backend");
#[cfg(feature = "metal")]
bench_backend!(MetalBackend, metal_backend, "Metal Backend (Apple GPU)");

View File

@@ -0,0 +1,183 @@
#![recursion_limit = "256"]
//! Unified benchmark comparing all saving methods:
//! - BurnpackStore (new native format)
//! - NamedMpkFileRecorder (old native format)
//! - SafetensorsStore (new)
//!
//! Before running this benchmark, ensure the directory exists:
//! ```bash
//! mkdir -p /tmp/simple_bench_models
//! ```
//!
//! Then run the benchmark:
//! ```bash
//! cargo bench --bench unified_saving
//! ```
use burn_core as burn;
use burn_core::module::Module;
use burn_core::prelude::*;
use burn_core::record::{FullPrecisionSettings, NamedMpkFileRecorder};
use burn_nn as nn;
use burn_store::{BurnpackStore, ModuleSnapshot, SafetensorsStore};
use divan::{AllocProfiler, Bencher};
use std::fs;
use std::path::PathBuf;
#[global_allocator]
static ALLOC: AllocProfiler = AllocProfiler::system();
// Backend type aliases
type NdArrayBackend = burn_ndarray::NdArray<f32>;
#[cfg(feature = "wgpu")]
type WgpuBackend = burn_wgpu::Wgpu;
#[cfg(feature = "cuda")]
type CudaBackend = burn_cuda::Cuda<f32, i32>;
#[cfg(feature = "tch")]
type TchBackend = burn_tch::LibTorch<f32>;
#[cfg(feature = "metal")]
type MetalBackend = burn_wgpu::Metal;
// Use the same LargeModel as other benchmarks for fair comparison
#[derive(Module, Debug)]
struct LargeModel<B: Backend> {
layers: Vec<nn::Linear<B>>,
}
impl<B: Backend> LargeModel<B> {
fn new(device: &B::Device) -> Self {
let mut layers = Vec::new();
// Create a model with 20 layers - same as loading benchmarks
for i in 0..20 {
let in_size = if i == 0 { 1024 } else { 2048 };
layers.push(nn::LinearConfig::new(in_size, 2048).init(device));
}
Self { layers }
}
}
/// Get the path to the output directory
fn get_output_dir() -> PathBuf {
std::env::temp_dir().join("simple_bench_models_saving")
}
/// Ensure output directory exists
fn ensure_output_dir() -> Result<(), String> {
let dir = get_output_dir();
if !dir.exists() {
fs::create_dir_all(&dir)
.map_err(|e| format!("Failed to create output directory: {}", e))?;
}
Ok(())
}
fn main() {
match ensure_output_dir() {
Ok(()) => {
println!("✅ Output directory ready: {}", get_output_dir().display());
println!();
println!("🚀 Running unified saving benchmarks...");
println!();
println!("Comparing 3 saving methods:");
println!(" 1. BurnpackStore (new native format)");
println!(" 2. NamedMpkFileRecorder (old native format)");
println!(" 3. SafetensorsStore (new)");
println!();
println!("Available backends:");
println!(" - NdArray (CPU)");
#[cfg(feature = "wgpu")]
println!(" - WGPU (GPU)");
#[cfg(feature = "cuda")]
println!(" - CUDA (NVIDIA GPU)");
#[cfg(feature = "tch")]
println!(" - LibTorch");
#[cfg(feature = "metal")]
println!(" - Metal (Apple GPU)");
println!();
divan::main();
}
Err(msg) => {
eprintln!("{}", msg);
std::process::exit(1);
}
}
}
// Macro to generate benchmarks for each backend
macro_rules! bench_backend {
($backend:ty, $mod_name:ident, $backend_name:literal) => {
#[divan::bench_group(name = $backend_name, sample_count = 10)]
mod $mod_name {
use super::*;
type TestBackend = $backend;
type TestDevice = <TestBackend as Backend>::Device;
#[divan::bench]
fn burnpack_store(bencher: Bencher) {
bencher.bench(|| {
let device: TestDevice = Default::default();
let model = LargeModel::<TestBackend>::new(&device);
let output_path = get_output_dir().join("test_burnpack.bpk");
let mut store = BurnpackStore::from_file(output_path.clone()).overwrite(true);
model
.save_into(&mut store)
.expect("Failed to save with BurnpackStore");
// Clean up
let _ = fs::remove_file(output_path);
});
}
#[divan::bench]
fn namedmpk_recorder(bencher: Bencher) {
bencher.bench(|| {
let device: TestDevice = Default::default();
let model = LargeModel::<TestBackend>::new(&device);
let output_path = get_output_dir().join("test_namedmpk.mpk");
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
model
.save_file(output_path.clone(), &recorder)
.expect("Failed to save with NamedMpkFileRecorder");
// Clean up
let _ = fs::remove_file(output_path);
});
}
#[divan::bench]
fn safetensors_store(bencher: Bencher) {
bencher.bench(|| {
let device: TestDevice = Default::default();
let model = LargeModel::<TestBackend>::new(&device);
let output_path = get_output_dir().join("test_safetensors_store.safetensors");
let mut store = SafetensorsStore::from_file(output_path.clone());
model
.save_into(&mut store)
.expect("Failed to save with SafetensorsStore");
// Clean up
let _ = fs::remove_file(output_path);
});
}
}
};
}
// Generate benchmarks for each backend
bench_backend!(NdArrayBackend, ndarray_backend, "NdArray Backend (CPU)");
#[cfg(feature = "wgpu")]
bench_backend!(WgpuBackend, wgpu_backend, "WGPU Backend (GPU)");
#[cfg(feature = "cuda")]
bench_backend!(CudaBackend, cuda_backend, "CUDA Backend (NVIDIA GPU)");
#[cfg(feature = "tch")]
bench_backend!(TchBackend, tch_backend, "LibTorch Backend");
#[cfg(feature = "metal")]
bench_backend!(MetalBackend, metal_backend, "Metal Backend (Apple GPU)");

View File

@@ -0,0 +1,596 @@
#![recursion_limit = "256"]
//! Benchmark comparing zero-copy vs copy loading modes for BurnpackStore.
//!
//! This benchmark measures the performance difference between:
//! - `zero_copy(false)` - Default mode, copies tensor data into new allocations
//! - `zero_copy(true)` - Zero-copy mode, slices tensor data without copying
//!
//! ## Understanding the Results
//!
//! **IMPORTANT**: For NdArray backend, you'll see similar allocation numbers because:
//! - NdArray uses `ndarray::ArrayD` which MUST own data as `Vec<T>`
//! - Even with zero-copy, the backend eventually copies data into its own format
//!
//! The zero-copy benefit is:
//! - **Without zero-copy**: File → Copy to heap (Bytes) → Copy to Vec (backend)
//! - **With zero-copy**: File → Zero-copy slice → Copy to Vec (backend)
//!
//! So zero-copy saves ONE memory copy at the store level. The `store_only_*` benchmarks
//! show the raw store performance without backend allocation overhead.
//!
//! GPU backends that can consume `Bytes` directly will show larger benefits.
//!
//! ## Running the benchmark
//!
//! Before running this benchmark, generate the model files:
//! ```bash
//! cd crates/burn-store
//! uv run benches/generate_unified_models.py
//! ```
//!
//! Then run the benchmark:
//! ```bash
//! cargo bench --bench zero_copy_loading
//! ```
use burn_core as burn;
use burn_core::module::Module;
use burn_core::prelude::*;
use burn_nn as nn;
use burn_store::{
BurnpackStore, ModuleSnapshot, ModuleStore, PyTorchToBurnAdapter, SafetensorsStore,
};
use burn_tensor::{AllocationProperty, Bytes};
use divan::{AllocProfiler, Bencher};
use std::fs;
use std::path::PathBuf;
use std::sync::OnceLock;
#[global_allocator]
static ALLOC: AllocProfiler = AllocProfiler::system();
// Static storage for embedded model bytes (simulating include_bytes!)
static STATIC_MODEL_BYTES: OnceLock<&'static [u8]> = OnceLock::new();
// Backend type aliases
type NdArrayBackend = burn_ndarray::NdArray<f32>;
#[cfg(feature = "wgpu")]
type WgpuBackend = burn_wgpu::Wgpu;
#[cfg(feature = "cuda")]
type CudaBackend = burn_cuda::Cuda<f32, i32>;
#[cfg(feature = "tch")]
type TchBackend = burn_tch::LibTorch<f32>;
#[cfg(feature = "metal")]
type MetalBackend = burn_wgpu::Metal;
// Use the same LargeModel as other benchmarks for fair comparison
#[derive(Module, Debug)]
struct LargeModel<B: Backend> {
layers: Vec<nn::Linear<B>>,
}
impl<B: Backend> LargeModel<B> {
fn new(device: &B::Device) -> Self {
let mut layers = Vec::new();
// Create a model with 20 layers - same as unified_loading benchmark
for i in 0..20 {
let in_size = if i == 0 { 1024 } else { 2048 };
layers.push(nn::LinearConfig::new(in_size, 2048).init(device));
}
Self { layers }
}
}
/// Get the path to the model files
fn get_model_dir() -> PathBuf {
std::env::temp_dir().join("simple_bench_models")
}
/// Get path to Burnpack model file
fn get_burnpack_path() -> PathBuf {
get_model_dir().join("large_model.bpk")
}
/// Generate Burnpack file from existing SafeTensors file if needed
fn ensure_burnpack_file() {
let bp_path = get_burnpack_path();
let st_path = get_model_dir().join("large_model.safetensors");
if bp_path.exists() {
return;
}
if !st_path.exists() {
panic!(
"\n❌ SafeTensors model file not found!\n\
\n\
Please generate the model files first by running:\n\
\n\
cd crates/burn-store\n\
uv run benches/generate_unified_models.py\n\
\n\
Expected file: {}\n",
st_path.display()
);
}
println!("⏳ Generating Burnpack file from SafeTensors...");
type TestBackend = NdArrayBackend;
let device = Default::default();
// Load from SafeTensors
let mut model = LargeModel::<TestBackend>::new(&device);
let mut store = SafetensorsStore::from_file(&st_path).with_from_adapter(PyTorchToBurnAdapter);
model
.load_from(&mut store)
.expect("Failed to load from SafeTensors");
// Save as Burnpack
let mut burnpack_store = BurnpackStore::from_file(&bp_path);
model
.save_into(&mut burnpack_store)
.expect("Failed to save as Burnpack");
println!("✅ Created Burnpack file: {}", bp_path.display());
}
/// Initialize static model bytes (simulating include_bytes! at runtime for benchmarks)
fn get_static_model_bytes() -> &'static [u8] {
STATIC_MODEL_BYTES.get_or_init(|| {
let bp_path = get_burnpack_path();
let bytes = fs::read(&bp_path).expect("Failed to read Burnpack file");
// Leak the bytes to get a 'static lifetime (acceptable for benchmarks)
Box::leak(bytes.into_boxed_slice())
})
}
fn main() {
// Ensure Burnpack file exists
ensure_burnpack_file();
let bp_path = get_burnpack_path();
let file_size = fs::metadata(&bp_path).unwrap().len() as f64 / 1_048_576.0;
println!("✅ Found Burnpack model file:");
println!(" Path: {}", bp_path.display());
println!(" Size: {:.1} MB", file_size);
println!();
println!("🚀 Running zero-copy loading benchmarks...");
println!();
println!("Comparing loading modes:");
println!(" 1. file_copy - from_file().zero_copy(false) - copies tensor data");
println!(" 2. file_zero_copy - from_file().zero_copy(true) - zero-copy via mmap");
println!(" 3. static_copy - from_bytes() with Vec copy - copies from static");
println!(" 4. static_zero_copy - from_static() - zero-copy from static");
println!();
println!("Available backends:");
println!(" - NdArray (CPU)");
#[cfg(feature = "wgpu")]
println!(" - WGPU (GPU)");
#[cfg(feature = "cuda")]
println!(" - CUDA (NVIDIA GPU)");
#[cfg(feature = "tch")]
println!(" - LibTorch");
#[cfg(feature = "metal")]
println!(" - Metal (Apple GPU)");
println!();
// Pre-initialize static bytes before benchmarks
let _ = get_static_model_bytes();
divan::main();
}
// Macro to generate benchmarks for each backend
macro_rules! bench_backend {
($backend:ty, $mod_name:ident, $backend_name:literal) => {
#[divan::bench_group(name = $backend_name, sample_count = 10)]
mod $mod_name {
use super::*;
type TestBackend = $backend;
type TestDevice = <TestBackend as Backend>::Device;
/// File-based loading with copy mode (default)
#[divan::bench]
fn file_copy(bencher: Bencher) {
let bp_path = get_burnpack_path();
let file_size = fs::metadata(&bp_path).unwrap().len();
bencher
.counter(divan::counter::BytesCount::new(file_size))
.bench(|| {
let device: TestDevice = Default::default();
let mut model = LargeModel::<TestBackend>::new(&device);
let mut store = BurnpackStore::from_file(&bp_path).zero_copy(false);
model.load_from(&mut store).expect("Failed to load");
});
}
/// File-based loading with zero-copy mode (mmap + bytes::Bytes)
#[divan::bench]
fn file_zero_copy(bencher: Bencher) {
let bp_path = get_burnpack_path();
let file_size = fs::metadata(&bp_path).unwrap().len();
bencher
.counter(divan::counter::BytesCount::new(file_size))
.bench(|| {
let device: TestDevice = Default::default();
let mut model = LargeModel::<TestBackend>::new(&device);
let mut store = BurnpackStore::from_file(&bp_path).zero_copy(true);
model.load_from(&mut store).expect("Failed to load");
});
}
/// Static bytes with copy mode (simulating old behavior)
#[divan::bench]
fn static_copy(bencher: Bencher) {
let static_bytes = get_static_model_bytes();
let file_size = static_bytes.len() as u64;
bencher
.counter(divan::counter::BytesCount::new(file_size))
.bench(|| {
let device: TestDevice = Default::default();
let mut model = LargeModel::<TestBackend>::new(&device);
// Simulate old behavior: copy static bytes to Vec, then load
let bytes = Bytes::from_bytes_vec(static_bytes.to_vec());
let mut store = BurnpackStore::from_bytes(Some(bytes)).zero_copy(false);
model.load_from(&mut store).expect("Failed to load");
});
}
/// Static bytes with zero-copy mode (new from_static)
#[divan::bench]
fn static_zero_copy(bencher: Bencher) {
let static_bytes = get_static_model_bytes();
let file_size = static_bytes.len() as u64;
bencher
.counter(divan::counter::BytesCount::new(file_size))
.bench(|| {
let device: TestDevice = Default::default();
let mut model = LargeModel::<TestBackend>::new(&device);
// Zero-copy: use from_static which keeps data in .rodata
let mut store = BurnpackStore::from_static(static_bytes);
model.load_from(&mut store).expect("Failed to load");
});
}
/// In-memory shared bytes with zero-copy
#[divan::bench]
fn memory_shared_zero_copy(bencher: Bencher) {
let static_bytes = get_static_model_bytes();
let file_size = static_bytes.len() as u64;
// Pre-create shared bytes outside the benchmark loop
let shared = bytes::Bytes::from_static(static_bytes);
bencher
.counter(divan::counter::BytesCount::new(file_size))
.bench(|| {
let device: TestDevice = Default::default();
let mut model = LargeModel::<TestBackend>::new(&device);
// Create Bytes from shared (cheap clone of Arc)
let bytes = Bytes::from_shared(shared.clone(), AllocationProperty::Other);
let mut store = BurnpackStore::from_bytes(Some(bytes)).zero_copy(true);
model.load_from(&mut store).expect("Failed to load");
});
}
}
};
}
// =============================================================================
// Zero-copy verification (proves operations use static region data)
// =============================================================================
/// Verify that zero-copy loading actually uses data from the static region.
/// This runs once at startup to prove correctness before benchmarking.
#[divan::bench_group(name = "Zero-Copy Verification", sample_count = 1)]
mod verification {
use super::*;
use burn_ndarray::NdArray;
type B = NdArray<f32>;
/// Verify zero-copy: tensor storage is borrowed (not owned)
#[divan::bench]
fn verify_storage_is_borrowed() {
let static_bytes = get_static_model_bytes();
// Load model with zero-copy from static bytes
let device = Default::default();
let mut model = LargeModel::<B>::new(&device);
let mut store = BurnpackStore::from_static(static_bytes);
model.load_from(&mut store).expect("Failed to load");
// Get the first layer's weight tensor and verify it uses borrowed storage
let weight = model.layers[0].weight.val();
// .into_primitive() returns TensorPrimitive<B>, .tensor() extracts B::FloatTensorPrimitive
let ndarray_tensor = weight.into_primitive().tensor();
// Verify the storage is borrowed (zero-copy from static region)
assert!(
ndarray_tensor.is_borrowed(),
"ZERO-COPY FAILURE: Tensor storage is NOT borrowed. \
Data was copied instead of being zero-copy!"
);
println!("✅ Verified: Tensor storage is borrowed (zero-copy from static region)");
}
/// Verify ALL layers use borrowed (zero-copy) storage.
/// This is the key proof that loaded weights point to static memory.
#[divan::bench]
fn verify_all_layers_borrowed() {
let static_bytes = get_static_model_bytes();
// Load model with zero-copy
let device = Default::default();
let mut model = LargeModel::<B>::new(&device);
let mut store = BurnpackStore::from_static(static_bytes);
model.load_from(&mut store).expect("Failed to load");
// Check ALL layers have borrowed storage
let mut total_elements = 0usize;
for (i, layer) in model.layers.iter().enumerate() {
let weight = layer.weight.val();
total_elements += weight.shape().num_elements();
assert!(
weight.into_primitive().tensor().is_borrowed(),
"Layer {} weight should be borrowed (zero-copy)",
i
);
}
let total_mb = (total_elements * 4) as f64 / 1_048_576.0;
println!(
"✅ Verified: All {} layers use borrowed storage",
model.layers.len()
);
println!(
" - Model size: {:.2} MB - all pointing to static region",
total_mb
);
}
/// Verify data is readable and correct using sum().into_scalar().
/// Note: sum() triggers COW copy, so this shows ops work correctly on zero-copy data.
#[divan::bench]
fn verify_ops_produce_correct_results() {
let static_bytes = get_static_model_bytes();
let device = Default::default();
let mut model = LargeModel::<B>::new(&device);
let mut store = BurnpackStore::from_static(static_bytes);
model.load_from(&mut store).expect("Failed to load");
// Compute sum of first layer weight - proves data is valid
let weight = model.layers[0].weight.val();
let sum: f32 = weight.sum().into_scalar();
assert!(sum.is_finite(), "Sum should be finite");
println!("✅ Verified: Operations on zero-copy data produce valid results");
println!(" - First layer sum: {:.4}", sum);
}
/// Verify operations produce correct results on zero-copy data
#[divan::bench]
fn verify_operations_on_static_data() {
let static_bytes = get_static_model_bytes();
// Load model with zero-copy
let device = Default::default();
let mut model = LargeModel::<B>::new(&device);
let mut store = BurnpackStore::from_static(static_bytes);
model.load_from(&mut store).expect("Failed to load");
// Perform operations on the loaded weights
let weight = model.layers[0].weight.val();
let shape = weight.shape();
// Test 1: Sum should be finite (not NaN or Inf)
let sum: f32 = weight.clone().sum().to_data().to_vec().unwrap()[0];
assert!(
sum.is_finite(),
"Operation failed: sum is not finite ({})",
sum
);
// Test 2: Matrix multiply with itself transposed (W @ W.T)
let transposed = weight.clone().transpose();
let matmul_result = weight.clone().matmul(transposed);
let matmul_sum: f32 = matmul_result.sum().to_data().to_vec().unwrap()[0];
assert!(
matmul_sum.is_finite(),
"Matmul failed: result sum is not finite ({})",
matmul_sum
);
// Test 3: Element-wise operations
let doubled = weight.clone() * 2.0;
let doubled_sum: f32 = doubled.sum().to_data().to_vec().unwrap()[0];
assert!(
(doubled_sum - sum * 2.0).abs() < 1e-3,
"Element-wise op failed: doubled_sum ({}) != sum*2 ({})",
doubled_sum,
sum * 2.0
);
println!("✅ Verified: Operations on zero-copy data produce correct results");
println!(" - Weight shape: {:?}", shape.as_slice());
println!(" - Sum: {:.4}", sum);
println!(" - Matmul result sum: {:.4}", matmul_sum);
}
/// Compare zero-copy vs copy: verify both produce identical results
#[divan::bench]
fn verify_copy_vs_zero_copy_equality() {
let static_bytes = get_static_model_bytes();
let device: <B as Backend>::Device = Default::default();
// Load with zero-copy
let mut model_zc = LargeModel::<B>::new(&device);
let mut store_zc = BurnpackStore::from_static(static_bytes);
model_zc
.load_from(&mut store_zc)
.expect("Failed to load zero-copy");
// Load with copy (simulate old behavior)
let mut model_copy = LargeModel::<B>::new(&device);
let bytes = Bytes::from_bytes_vec(static_bytes.to_vec());
let mut store_copy = BurnpackStore::from_bytes(Some(bytes)).zero_copy(false);
model_copy
.load_from(&mut store_copy)
.expect("Failed to load copy");
// Compare weights from both models
for (i, (layer_zc, layer_copy)) in model_zc
.layers
.iter()
.zip(model_copy.layers.iter())
.enumerate()
{
let weight_zc = layer_zc.weight.val();
let weight_copy = layer_copy.weight.val();
// Check shapes match
assert_eq!(
weight_zc.shape(),
weight_copy.shape(),
"Layer {} weight shapes don't match",
i
);
// Check values match (using sum as a proxy)
let sum_zc: f32 = weight_zc.clone().sum().to_data().to_vec().unwrap()[0];
let sum_copy: f32 = weight_copy.clone().sum().to_data().to_vec().unwrap()[0];
assert!(
(sum_zc - sum_copy).abs() < 1e-6,
"Layer {} weight sums don't match: zero-copy={}, copy={}",
i,
sum_zc,
sum_copy
);
}
println!(
"✅ Verified: Zero-copy and copy loading produce identical results for all {} layers",
model_zc.layers.len()
);
}
}
// =============================================================================
// Store-only benchmarks (no backend allocation overhead)
// These show the TRUE zero-copy benefit at the store level
// =============================================================================
#[divan::bench_group(name = "Store Only (no backend)", sample_count = 10)]
mod store_only {
use super::*;
/// File-based store with copy mode - measures store overhead only
#[divan::bench]
fn file_copy(bencher: Bencher) {
let bp_path = get_burnpack_path();
let file_size = fs::metadata(&bp_path).unwrap().len();
bencher
.counter(divan::counter::BytesCount::new(file_size))
.bench(|| {
let mut store = BurnpackStore::from_file(&bp_path).zero_copy(false);
// Just iterate through all tensor snapshots, calling to_data() on each
// This forces the store to read and materialize all tensor data
let snapshots = store.get_all_snapshots().expect("Failed to get snapshots");
for snapshot in snapshots.values() {
let _data = snapshot.to_data().expect("Failed to get tensor data");
}
});
}
/// File-based store with zero-copy mode - measures store overhead only
#[divan::bench]
fn file_zero_copy(bencher: Bencher) {
let bp_path = get_burnpack_path();
let file_size = fs::metadata(&bp_path).unwrap().len();
bencher
.counter(divan::counter::BytesCount::new(file_size))
.bench(|| {
let mut store = BurnpackStore::from_file(&bp_path).zero_copy(true);
let snapshots = store.get_all_snapshots().expect("Failed to get snapshots");
for snapshot in snapshots.values() {
let _data = snapshot.to_data().expect("Failed to get tensor data");
}
});
}
/// Static bytes with copy mode - measures store overhead only
#[divan::bench]
fn static_copy(bencher: Bencher) {
let static_bytes = get_static_model_bytes();
let file_size = static_bytes.len() as u64;
bencher
.counter(divan::counter::BytesCount::new(file_size))
.bench(|| {
// Simulate old behavior: copy static bytes to Vec
let bytes = Bytes::from_bytes_vec(static_bytes.to_vec());
let mut store = BurnpackStore::from_bytes(Some(bytes)).zero_copy(false);
let snapshots = store.get_all_snapshots().expect("Failed to get snapshots");
for snapshot in snapshots.values() {
let _data = snapshot.to_data().expect("Failed to get tensor data");
}
});
}
/// Static bytes with zero-copy mode - measures store overhead only
#[divan::bench]
fn static_zero_copy(bencher: Bencher) {
let static_bytes = get_static_model_bytes();
let file_size = static_bytes.len() as u64;
bencher
.counter(divan::counter::BytesCount::new(file_size))
.bench(|| {
let mut store = BurnpackStore::from_static(static_bytes);
let snapshots = store.get_all_snapshots().expect("Failed to get snapshots");
for snapshot in snapshots.values() {
let _data = snapshot.to_data().expect("Failed to get tensor data");
}
});
}
}
// =============================================================================
// Full model loading benchmarks (includes backend allocation)
// =============================================================================
// Generate benchmarks for each backend
bench_backend!(NdArrayBackend, ndarray_backend, "NdArray Backend (CPU)");
#[cfg(feature = "wgpu")]
bench_backend!(WgpuBackend, wgpu_backend, "WGPU Backend (GPU)");
#[cfg(feature = "cuda")]
bench_backend!(CudaBackend, cuda_backend, "CUDA Backend (NVIDIA GPU)");
#[cfg(feature = "tch")]
bench_backend!(TchBackend, tch_backend, "LibTorch Backend");
#[cfg(feature = "metal")]
bench_backend!(MetalBackend, metal_backend, "Metal Backend (Apple GPU)");

View File

@@ -0,0 +1,148 @@
//! Example: Generate a Burnpack file for inspection
//!
//! This example creates a simple Burnpack file that you can examine to understand the format.
//!
//! Usage:
//! cargo run --example burnpack-inspect [output_path]
//!
//! Example:
//! cargo run --example burnpack-inspect sample.bpk
//! cargo run --example burnpack-inspect /tmp/test.bpk
//!
//! After generating the file, examine it with:
//! hexdump -C sample.bpk | head -100
//! xxd sample.bpk | head -100
//! hexyl sample.bpk
use burn_core as burn;
use burn_core::module::Module;
use burn_ndarray::NdArray;
use burn_nn::{Linear, LinearConfig};
use burn_store::{BurnpackStore, ModuleSnapshot};
use burn_tensor::backend::Backend;
use std::env;
// Simple model with a few layers
#[derive(Module, Debug)]
struct SampleModel<B: Backend> {
linear1: Linear<B>,
linear2: Linear<B>,
linear3: Linear<B>,
}
impl<B: Backend> SampleModel<B> {
fn new(device: &B::Device) -> Self {
Self {
linear1: LinearConfig::new(128, 64).init(device),
linear2: LinearConfig::new(64, 32).init(device),
linear3: LinearConfig::new(32, 10).init(device),
}
}
}
fn main() {
type Backend = NdArray<f32>;
// Get output path from command line or use default
let output_path = env::args()
.nth(1)
.unwrap_or_else(|| "sample.bpk".to_string());
println!("Creating sample Burnpack file: {}", output_path);
println!();
// Create a simple model
let device = Default::default();
let model = SampleModel::<Backend>::new(&device);
// Save to Burnpack format with metadata
let mut store = BurnpackStore::from_file(&output_path)
.overwrite(true)
.metadata("format", "burnpack")
.metadata("description", "Sample file for examining Burnpack format")
.metadata("version", env!("CARGO_PKG_VERSION"))
.metadata("author", "Burn Example");
model.save_into(&mut store).expect("Failed to save model");
println!("✅ Successfully created: {}", output_path);
println!();
println!("📋 File Structure:");
println!(" ┌─────────────────────────────────────┐");
println!(" │ Header (10 bytes) │");
println!(" ├─────────────────────────────────────┤");
println!(" │ - Magic: 0x4E525542 (BURN in LE) │");
println!(" │ - Version: 0x0001 (2 bytes) │");
println!(" │ - Metadata size: (4 bytes, u32 LE) │");
println!(" ├─────────────────────────────────────┤");
println!(" │ Metadata (CBOR format) │");
println!(" ├─────────────────────────────────────┤");
println!(" │ - Tensor descriptors │");
println!(" │ * name, dtype, shape, offsets │");
println!(" │ - User metadata │");
println!(" ├─────────────────────────────────────┤");
println!(" │ Tensor Data (raw bytes, LE) │");
println!(" ├─────────────────────────────────────┤");
println!(" │ - linear1.weight [64, 128] │");
println!(" │ - linear1.bias [64] │");
println!(" │ - linear2.weight [32, 64] │");
println!(" │ - linear2.bias [32] │");
println!(" │ - linear3.weight [10, 32] │");
println!(" │ - linear3.bias [10] │");
println!(" └─────────────────────────────────────┘");
println!();
println!("📊 Model Contents:");
println!(" - linear1.weight: [64, 128] = 8,192 params → 32,768 bytes");
println!(" - linear1.bias: [64] = 64 params → 256 bytes");
println!(" - linear2.weight: [32, 64] = 2,048 params → 8,192 bytes");
println!(" - linear2.bias: [32] = 32 params → 128 bytes");
println!(" - linear3.weight: [10, 32] = 320 params → 1,280 bytes");
println!(" - linear3.bias: [10] = 10 params → 40 bytes");
println!(" ───────────────────────────────────────────────────────");
let total_params = 8192 + 64 + 2048 + 32 + 320 + 10;
let total_bytes = total_params * 4;
println!(
" Total: {} parameters = {} KB",
total_params,
total_bytes / 1024
);
println!();
// Get actual file size
if let Ok(metadata) = std::fs::metadata(&output_path) {
let file_size = metadata.len();
println!(
"📦 File size: {} bytes ({:.2} KB)",
file_size,
file_size as f64 / 1024.0
);
}
println!();
println!("🔍 Inspection Commands:");
println!();
println!(" # View first 100 bytes in hex:");
println!(" hexdump -C {} | head -20", output_path);
println!();
println!(" # View header only (10 bytes):");
println!(" head -c 10 {} | hexdump -C", output_path);
println!();
println!(" # View with prettier hex viewer (if installed):");
println!(" hexyl {} | head -50", output_path);
println!();
println!(" # View in binary format:");
println!(" xxd -b {} | head -20", output_path);
println!();
println!(" # Extract and examine header:");
println!(" # Magic (bytes 0-3): Should be 42 55 52 4E (BURN)");
println!(" # Version (bytes 4-5): Should be 01 00");
println!(" # Metadata size (bytes 6-9): u32 little-endian");
println!();
println!(" # Load back the model:");
println!(
" # let mut store = BurnpackStore::from_file(\"{}\");",
output_path
);
println!(" # model.load_from(&mut store)?;");
}

View File

@@ -0,0 +1,13 @@
[package]
name = "pytorch-tests"
version.workspace = true
edition.workspace = true
license.workspace = true
[dev-dependencies]
burn = { path = "../../burn" }
burn-ndarray = { path = "../../burn-ndarray" }
burn-autodiff = { path = "../../burn-autodiff" }
burn-store = { path = "../", features = ["std", "pytorch"] }
serde = { workspace = true }
float-cmp = { workspace = true }

View File

@@ -0,0 +1 @@
pub type TestBackend = burn_ndarray::NdArray<f32>;

View File

@@ -0,0 +1,41 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.norm1 = nn.BatchNorm2d(5)
def forward(self, x):
x = self.norm1(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
# Condition batch norm (each forward will affect the running stats)
x1 = torch.ones(1, 5, 2, 2) - 0.5
_ = model(x1)
model.eval() # Set to eval mode to freeze running stats
# Save the model after the first forward
torch.save(model.state_dict(), "batch_norm2d.pt")
x2 = torch.ones(1, 5, 2, 2) - 0.3
print("Input shape: {}", x2.shape)
output = model(x2)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,62 @@
use burn::{
module::Module,
nn::{BatchNorm, BatchNormConfig},
tensor::{Tensor, backend::Backend},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
norm1: BatchNorm<B>,
}
impl<B: Backend> Net<B> {
pub fn new(device: &B::Device) -> Self {
Self {
norm1: BatchNormConfig::new(5).init(device), // Python model uses BatchNorm2d(5)
}
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.norm1.forward(x)
}
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::tensor::Tolerance;
use burn_store::{ModuleSnapshot, PytorchStore};
use super::*;
#[test]
fn batch_norm2d() {
let device = Default::default();
let mut model = Net::<TestBackend>::new(&device);
let mut store = PytorchStore::from_file("tests/batch_norm/batch_norm2d.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
let input = Tensor::<TestBackend, 4>::ones([1, 5, 2, 2], &device) - 0.3;
let output = model.forward(input);
let expected = Tensor::<TestBackend, 4>::from_data(
[[
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
]],
&device,
);
output
.to_data()
.assert_approx_eq::<f32>(&expected.to_data(), Tolerance::default());
}
}

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
buffer = torch.tensor([True, False, True])
self.register_buffer("buffer", buffer, persistent=True)
def forward(self, x):
x = self.buffer
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "boolean.pt")
input = torch.ones(3, 3)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,58 @@
use burn::{
module::{Module, Param, ParamId},
tensor::{Bool, Tensor, TensorData, backend::Backend},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
buffer: Param<Tensor<B, 1, Bool>>,
}
impl<B: Backend> Net<B> {
/// Create a new model with placeholder values.
pub fn init(device: &B::Device) -> Self {
Self {
buffer: Param::initialized(
ParamId::new(),
Tensor::from_bool(TensorData::from([false, false, false]), device),
),
}
}
/// Forward pass of the model.
pub fn forward(&self, _x: Tensor<B, 2>) -> Tensor<B, 1, Bool> {
self.buffer.val()
}
}
#[cfg(test)]
mod tests {
use burn::tensor::TensorData;
use burn_store::{ModuleSnapshot, PytorchStore};
use super::*;
use crate::backend::TestBackend;
#[test]
fn boolean() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/boolean/boolean.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
let input = Tensor::<TestBackend, 2>::ones([3, 3], &device);
let output = model.forward(input);
let expected = Tensor::<TestBackend, 1, Bool>::from_bool(
TensorData::from([true, false, true]),
&device,
);
assert_eq!(output.to_data(), expected.to_data());
}
}

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
buffer = torch.ones(3, 3)
self.register_buffer("buffer", buffer, persistent=True)
def forward(self, x):
x = self.buffer + x
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "buffer.pt")
input = torch.ones(3, 3)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,53 @@
use burn::{
module::{Module, Param},
tensor::{Tensor, backend::Backend},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
buffer: Param<Tensor<B, 2>>,
}
impl<B: Backend> Net<B> {
/// Create a new model with placeholder values.
pub fn init(device: &B::Device) -> Self {
Self {
buffer: Param::from_tensor(Tensor::zeros([3, 3], device)),
}
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
self.buffer.val() + x
}
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::tensor::Tolerance;
use burn_store::{ModuleSnapshot, PytorchStore};
use super::*;
#[test]
fn buffer() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/buffer/buffer.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
let input = Tensor::<TestBackend, 2>::ones([3, 3], &device);
let output = model.forward(input);
let expected = Tensor::<TestBackend, 2>::ones([3, 3], &device) * 2.0;
output
.to_data()
.assert_approx_eq::<f32>(&expected.to_data(), Tolerance::default());
}
}

View File

@@ -0,0 +1,69 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
super(ConvBlock, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
self.norm = nn.BatchNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
return x
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv_blocks = nn.Sequential(
ConvBlock(2, 4, (3, 2)),
ConvBlock(4, 6, (3, 2)),
)
self.norm1 = nn.BatchNorm2d(6)
self.fc1 = nn.Linear(120, 12)
self.fc2 = nn.Linear(12, 10)
def forward(self, x):
x = self.conv_blocks(x)
x = self.norm1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.log_softmax(x, dim=1)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(2)
model = Net().to(torch.device("cpu"))
# Condition the model (batch norm requires a forward pass to compute the mean and variance)
x1 = torch.ones(1, 2, 9, 6) - 0.1
x2 = torch.ones(1, 2, 9, 6) - 0.3
output = model(x1)
output = model(x2)
model.eval() # set to eval mode
torch.save(model.state_dict(), "complex_nested.pt")
# feed test data
x = torch.ones(1, 2, 9, 6) - 0.5
output = model(x)
print("Input shape: {}", x.shape)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,240 @@
use burn::tensor::Tolerance;
use burn::tensor::ops::FloatElem;
use burn::{
module::Module,
nn::{
BatchNorm, BatchNormConfig, Linear, LinearConfig,
conv::{Conv2d, Conv2dConfig},
},
tensor::{
Tensor,
activation::{log_softmax, relu},
backend::Backend,
},
};
use burn_autodiff::Autodiff;
use burn_store::{ModuleSnapshot, PytorchStore};
#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
conv: Conv2d<B>,
norm: BatchNorm<B>,
}
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv_blocks: Vec<ConvBlock<B>>,
norm1: BatchNorm<B>,
fc1: Linear<B>,
fc2: Linear<B>,
}
impl<B: Backend> Net<B> {
pub fn init(device: &B::Device) -> Self {
let conv_blocks = vec![
ConvBlock {
conv: Conv2dConfig::new([2, 4], [3, 2]).init(device),
norm: BatchNormConfig::new(4).init(device), // matches conv output channels
},
ConvBlock {
conv: Conv2dConfig::new([4, 6], [3, 2]).init(device),
norm: BatchNormConfig::new(6).init(device), // matches conv output channels
},
];
let norm1 = BatchNormConfig::new(6).init(device);
let fc1 = LinearConfig::new(120, 12).init(device);
let fc2 = LinearConfig::new(12, 10).init(device);
Self {
conv_blocks,
norm1,
fc1,
fc2,
}
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {
let x = self.conv_blocks[0].forward(x);
let x = self.conv_blocks[1].forward(x);
let x = self.norm1.forward(x);
let x = x.reshape([0, -1]);
let x = self.fc1.forward(x);
let x = relu(x);
let x = self.fc2.forward(x);
log_softmax(x, 1)
}
}
impl<B: Backend> ConvBlock<B> {
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv.forward(x);
self.norm.forward(x)
}
}
/// Partial model to test loading of partial records.
#[derive(Module, Debug)]
pub struct PartialNet<B: Backend> {
conv1: ConvBlock<B>,
}
impl<B: Backend> PartialNet<B> {
/// Create a new model from the given record.
pub fn init(device: &B::Device) -> Self {
let conv1 = ConvBlock {
conv: Conv2dConfig::new([2, 4], [3, 2]).init(device),
norm: BatchNormConfig::new(4).init(device), // matches conv output channels
};
Self { conv1 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.conv1.forward(x)
}
}
/// Model with extra fields to test loading of records (e.g. from a different model).
#[derive(Module, Debug)]
pub struct PartialWithExtraNet<B: Backend> {
conv1: ConvBlock<B>,
extra_field: bool, // This field is not present in the pytorch model
}
impl<B: Backend> PartialWithExtraNet<B> {
/// Create a new model from the given record.
pub fn init(device: &B::Device) -> Self {
let conv1 = ConvBlock {
conv: Conv2dConfig::new([2, 4], [3, 2]).init(device),
norm: BatchNormConfig::new(4).init(device), // matches conv output channels
};
Self {
conv1,
extra_field: true,
}
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.conv1.forward(x)
}
}
type TestBackend = burn_ndarray::NdArray<f32>;
fn model_test(model: Net<TestBackend>, precision: f32) {
let device = Default::default();
let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;
let output = model.forward(input);
let expected = Tensor::<TestBackend, 2>::from_data(
[[
-2.306_613,
-2.058_945_4,
-2.298_372_7,
-2.358_294,
-2.296_395_5,
-2.416_090_5,
-2.107_669,
-2.428_420_8,
-2.526_469,
-2.319_918_6,
]],
&device,
);
output.to_data().assert_approx_eq::<FloatElem<TestBackend>>(
&expected.to_data(),
Tolerance::absolute(precision),
);
}
#[test]
fn full_record() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/complex_nested/complex_nested.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
model_test(model, 1e-8);
}
#[test]
fn full_record_autodiff() {
let device = Default::default();
let mut model = Net::<Autodiff<TestBackend>>::init(&device);
let mut store = PytorchStore::from_file("tests/complex_nested/complex_nested.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
}
#[test]
fn half_record() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/complex_nested/complex_nested.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
model_test(model, 1e-4);
}
#[test]
fn partial_model_loading() {
let device = Default::default();
let mut model = PartialNet::<TestBackend>::init(&device);
// Load the full model but rename "conv_blocks.0.*" to "conv1.*"
let mut store = PytorchStore::from_file("tests/complex_nested/complex_nested.pt")
.with_key_remapping("conv_blocks\\.0\\.(.*)", "conv1.$1")
.allow_partial(true);
model
.load_from(&mut store)
.expect("Should decode state successfully");
let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;
let output = model.forward(input);
// get the sum of all elements in the output tensor for quick check
let sum = output.sum();
assert!((sum.into_scalar() - 4.871538).abs() < 0.000002);
}
#[test]
fn extra_field_model_loading() {
let device = Default::default();
let mut model = PartialWithExtraNet::<TestBackend>::init(&device);
// Load the full model but rename "conv_blocks.0.*" to "conv1.*"
let mut store = PytorchStore::from_file("tests/complex_nested/complex_nested.pt")
.with_key_remapping("conv_blocks\\.0\\.(.*)", "conv1.$1")
.allow_partial(true);
model
.load_from(&mut store)
.expect("Should decode state successfully");
let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;
let output = model.forward(input);
// get the sum of all elements in the output tensor for quick check
let sum = output.sum();
assert!((sum.into_scalar() - 4.871538).abs() < 0.000002);
assert!(model.extra_field);
}

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc1 = nn.Linear(2, 3)
self.fc2 = nn.Linear(3, 4, bias=False)
def forward(self, x):
x = self.fc1(x)
x = F.relu(x) # Add relu so that PyTorch optimizer does not combine fc1 and fc2
x = self.fc2(x)
return x
CONFIG = {
"n_head": 2,
"n_layer": 3,
"d_model": 512,
"some_float": 0.1,
"some_int": 1,
"some_bool": True,
"some_str": "hello",
"some_list_int": [1, 2, 3],
"some_list_str": ["hello", "world"],
"some_list_float": [0.1, 0.2, 0.3],
"some_dict": {
"some_key": "some_value"
}
}
class ModelWithBias(nn.Module):
def __init__(self):
super(ModelWithBias, self).__init__()
self.fc1 = nn.Linear(2, 3)
def forward(self, x):
x = self.fc1(x)
return x
def main():
model = Model().to(torch.device("cpu"))
weights_with_config = {
"my_model": model.state_dict(),
"my_config": CONFIG
}
torch.save(weights_with_config, "weights_with_config.pt")
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,53 @@
#![allow(clippy::too_many_arguments)] // To mute derive Config warning
use std::collections::HashMap;
use burn::config::Config;
#[allow(clippy::too_many_arguments)]
#[derive(Debug, PartialEq, Config)]
struct NetConfig {
n_head: usize,
n_layer: usize,
d_model: usize,
some_float: f64,
some_int: i32,
some_bool: bool,
some_str: String,
some_list_int: Vec<i32>,
some_list_str: Vec<String>,
some_list_float: Vec<f64>,
some_dict: HashMap<String, String>,
}
#[cfg(test)]
mod tests {
use burn_store::pytorch::PytorchReader;
use super::*;
#[test]
fn test_net_config() {
let config_expected = NetConfig {
n_head: 2,
n_layer: 3,
d_model: 512,
some_float: 0.1,
some_int: 1,
some_bool: true,
some_str: "hello".to_string(),
some_list_int: vec![1, 2, 3],
some_list_str: vec!["hello".to_string(), "world".to_string()],
some_list_float: vec![0.1, 0.2, 0.3],
some_dict: {
let mut map = HashMap::new();
map.insert("some_key".to_string(), "some_value".to_string());
map
},
};
let path = "tests/config/weights_with_config.pt";
let top_level_key = Some("my_config");
let config: NetConfig = PytorchReader::load_config(path, top_level_key).unwrap();
assert_eq!(config, config_expected);
}
}

View File

@@ -0,0 +1,39 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv1d(2, 2, 2)
self.conv2 = nn.Conv1d(2, 2, 2, bias=False)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "conv1d.pt")
input = torch.rand(1, 2, 6)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,97 @@
use burn::{
module::Module,
nn::conv::{Conv1d, Conv1dConfig},
tensor::{Tensor, backend::Backend},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv1: Conv1d<B>,
conv2: Conv1d<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn init(device: &B::Device) -> Self {
let conv1 = Conv1dConfig::new(2, 2, 2).init(device);
let conv2 = Conv1dConfig::new(2, 2, 2).with_bias(false).init(device);
Self { conv1, conv2 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
let x = self.conv1.forward(x);
self.conv2.forward(x)
}
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::tensor::{Tolerance, ops::FloatElem};
use burn_store::{ModuleSnapshot, PytorchStore};
type FT = FloatElem<TestBackend>;
use super::*;
fn conv1d(model: Net<TestBackend>, precision: f32) {
let device = Default::default();
let input = Tensor::<TestBackend, 3>::from_data(
[[
[
0.93708336, 0.65559506, 0.31379688, 0.19801933, 0.41619217, 0.28432965,
],
[
0.33977574,
0.523_940_8,
0.798_063_9,
0.77176833,
0.01122457,
0.80996025,
],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<TestBackend, 3>::from_data(
[[
[0.02987457, 0.03134188, 0.04234261, -0.02437721],
[-0.03788019, -0.02972012, -0.00806090, -0.01981254],
]],
&device,
);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::absolute(precision));
}
#[test]
fn conv1d_full_precision() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/conv1d/conv1d.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
conv1d(model, 1e-7);
}
#[test]
fn conv1d_half_precision() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/conv1d/conv1d.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
conv1d(model, 1e-4);
}
}

View File

@@ -0,0 +1,39 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(2, 2, (2,2))
self.conv2 = nn.Conv2d(2, 2, (2,2), bias=False)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "conv2d.pt")
input = torch.rand(1, 2, 5, 5)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,134 @@
use burn::{
module::Module,
nn::conv::{Conv2d, Conv2dConfig},
tensor::{Tensor, backend::Backend},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv1: Conv2d<B>,
conv2: Conv2d<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn init(device: &B::Device) -> Self {
let conv1 = Conv2dConfig::new([2, 2], [2, 2]).init(device);
let conv2 = Conv2dConfig::new([2, 2], [2, 2])
.with_bias(false)
.init(device);
Self { conv1, conv2 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv1.forward(x);
self.conv2.forward(x)
}
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::tensor::Tolerance;
use burn_store::{ModuleSnapshot, PytorchStore};
use super::*;
fn conv2d(model: Net<TestBackend>, precision: f32) {
let device = Default::default();
let input = Tensor::<TestBackend, 4>::from_data(
[[
[
[
0.024_595_8,
0.25883394,
0.93905586,
0.416_715_5,
0.713_979_7,
],
[0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4, 0.505_920_8],
[0.23659128, 0.757_007_4, 0.23458993, 0.64705235, 0.355_621_4],
[0.445_182_8, 0.01930594, 0.26160914, 0.771_317, 0.37846136],
[
0.99802476,
0.900_794_2,
0.476_588_2,
0.16625845,
0.804_481_1,
],
],
[
[
0.65517855,
0.17679012,
0.824_772_3,
0.803_550_9,
0.943_447_5,
],
[0.21972018, 0.417_697, 0.49031407, 0.57302874, 0.12054086],
[0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7, 0.52850497],
[0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537, 0.03694397],
[
0.751_675_7,
0.148_438_4,
0.12274551,
0.530_407_2,
0.414_796_4,
],
],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<TestBackend, 4>::from_data(
[[
[
[-0.02502128, 0.00250649, 0.04841233],
[0.04589614, -0.00296854, 0.01991477],
[0.02920526, 0.059_497_3, 0.04326791],
],
[
[-0.04825336, 0.080_190_9, -0.02375088],
[0.02885434, 0.09638263, -0.07460806],
[0.02004079, 0.06244051, 0.035_887_1],
],
]],
&device,
);
output
.to_data()
.assert_approx_eq::<f32>(&expected.to_data(), Tolerance::absolute(precision));
}
#[test]
fn conv2d_full_precision() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/conv2d/conv2d.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
conv2d(model, 1e-7);
}
#[test]
fn conv2d_half_precision() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/conv2d/conv2d.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
conv2d(model, 1e-4);
}
}

View File

@@ -0,0 +1,39 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.ConvTranspose1d(2, 2, 2)
self.conv2 = nn.ConvTranspose1d(2, 2, 2, bias=False)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "conv_transpose1d.pt")
input = torch.rand(1, 2, 2)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,87 @@
use burn::{
module::Module,
nn::conv::{ConvTranspose1d, ConvTranspose1dConfig},
tensor::{Tensor, backend::Backend},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv1: ConvTranspose1d<B>,
conv2: ConvTranspose1d<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn init(device: &B::Device) -> Self {
let conv1 = ConvTranspose1dConfig::new([2, 2], 2).init(device);
let conv2 = ConvTranspose1dConfig::new([2, 2], 2)
.with_bias(false)
.init(device);
Self { conv1, conv2 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
let x = self.conv1.forward(x);
self.conv2.forward(x)
}
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::tensor::Tolerance;
use burn_store::{ModuleSnapshot, PytorchStore};
use super::*;
fn conv_transpose1d(model: Net<TestBackend>, precision: f32) {
let device = Default::default();
let input = Tensor::<TestBackend, 3>::from_data(
[[[0.93708336, 0.65559506], [0.31379688, 0.19801933]]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<TestBackend, 3>::from_data(
[[
[0.02935525, 0.01119324, -0.01356167, -0.00682688],
[0.01644749, -0.01429807, 0.00083987, 0.00279229],
]],
&device,
);
output
.to_data()
.assert_approx_eq::<f32>(&expected.to_data(), Tolerance::absolute(precision));
}
#[test]
fn conv_transpose1d_full() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/conv_transpose1d/conv_transpose1d.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
conv_transpose1d(model, 1e-8);
}
#[test]
fn conv_transpose1d_half() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/conv_transpose1d/conv_transpose1d.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
conv_transpose1d(model, 1e-4);
}
}

View File

@@ -0,0 +1,39 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.ConvTranspose2d(2, 2, (2, 2))
self.conv2 = nn.ConvTranspose2d(2, 2, (2, 2), bias=False)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "conv_transpose2d.pt")
input = torch.rand(1, 2, 2, 2)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,99 @@
use burn::{
module::Module,
nn::conv::{ConvTranspose2d, ConvTranspose2dConfig},
tensor::{Tensor, backend::Backend},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv1: ConvTranspose2d<B>,
conv2: ConvTranspose2d<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn init(device: &B::Device) -> Self {
let conv1 = ConvTranspose2dConfig::new([2, 2], [2, 2]).init(device);
let conv2 = ConvTranspose2dConfig::new([2, 2], [2, 2])
.with_bias(false)
.init(device);
Self { conv1, conv2 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv1.forward(x);
self.conv2.forward(x)
}
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::tensor::Tolerance;
use burn_store::{ModuleSnapshot, PytorchStore};
use super::*;
fn conv_transpose2d(model: Net<TestBackend>, precision: f32) {
let device = Default::default();
let input = Tensor::<TestBackend, 4>::from_data(
[[
[[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]],
[[0.713_979_7, 0.267_644_3], [0.990_609, 0.28845078]],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<TestBackend, 4>::from_data(
[[
[
[0.04547675, 0.01879685, -0.01636661, 0.00310803],
[0.02090115, 0.01192738, -0.048_240_2, 0.02252235],
[0.03249975, -0.00460748, 0.05003899, 0.04029131],
[0.02185687, -0.10226749, -0.06508022, -0.01267705],
],
[
[0.00277598, -0.00513832, -0.059_048_3, 0.00567626],
[-0.03149522, -0.195_757_4, 0.03474613, 0.01997269],
[-0.10096474, 0.00679589, 0.041_919_7, -0.02464108],
[-0.03174751, 0.02963913, -0.02703723, -0.01860938],
],
]],
&device,
);
output
.to_data()
.assert_approx_eq::<f32>(&expected.to_data(), Tolerance::absolute(precision));
}
#[test]
fn conv_transpose2d_full() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/conv_transpose2d/conv_transpose2d.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
conv_transpose2d(model, 1e-7);
}
#[test]
fn conv_transpose2d_half() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/conv_transpose2d/conv_transpose2d.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
conv_transpose2d(model, 1e-4);
}
}

View File

@@ -0,0 +1,37 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.embed = nn.Embedding(10, 3)
def forward(self, x):
x = self.embed(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "embedding.pt")
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,86 @@
use burn::{
module::Module,
nn::{Embedding, EmbeddingConfig},
tensor::{Int, Tensor, backend::Backend},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
embed: Embedding<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model.
pub fn init(device: &B::Device) -> Self {
let embed = EmbeddingConfig::new(10, 3).init(device);
Self { embed }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
self.embed.forward(x)
}
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::tensor::Tolerance;
use burn_store::{ModuleSnapshot, PytorchStore};
use super::*;
fn embedding(model: Net<TestBackend>, precision: f32) {
let device = Default::default();
let input = Tensor::<TestBackend, 2, Int>::from_data([[1, 2, 4, 5], [4, 3, 2, 9]], &device);
let output = model.forward(input);
let expected = Tensor::<TestBackend, 3>::from_data(
[
[
[-1.609_484_9, -0.10016718, -0.609_188_9],
[-0.97977227, -1.609_096_3, -0.712_144_6],
[-0.22227049, 1.687_113_4, -0.32062083],
[-0.29934573, 1.879_345_7, -0.07213178],
],
[
[-0.22227049, 1.687_113_4, -0.32062083],
[0.303_722, -0.777_314_3, -0.25145486],
[-0.97977227, -1.609_096_3, -0.712_144_6],
[-0.02878714, 2.357_111, -1.037_338_7],
],
],
&device,
);
output
.to_data()
.assert_approx_eq::<f32>(&expected.to_data(), Tolerance::absolute(precision));
}
#[test]
fn embedding_full_precision() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/embedding/embedding.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
embedding(model, 1e-3);
}
#[test]
fn embedding_half_precision() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/embedding/embedding.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
embedding(model, 1e-3);
}
}

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python3
import torch
from torch import nn, Tensor
class DwsConv(nn.Module):
"""Depthwise separable convolution."""
def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
super().__init__()
# Depthwise conv
self.dconv = nn.Conv2d(in_channels, in_channels, kernel_size, groups=in_channels)
# Pointwise conv
self.pconv = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=1)
def forward(self, x: Tensor) -> Tensor:
x = self.dconv(x)
return self.pconv(x)
class Model(nn.Module):
def __init__(self, depthwise: bool = False) -> None:
super().__init__()
self.conv = DwsConv(2, 2, 3) if depthwise else nn.Conv2d(2, 2, 3)
def forward(self, x: Tensor) -> Tensor:
return self.conv(x)
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "enum_depthwise_false.pt")
input = torch.rand(1, 2, 5, 5)
print("Depthwise is False")
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
print("Depthwise is True")
model = Model(depthwise=True).to(torch.device("cpu"))
torch.save(model.state_dict(), "enum_depthwise_true.pt")
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,197 @@
use burn::{
module::Module,
nn::conv::{Conv2d, Conv2dConfig},
tensor::{Tensor, backend::Backend},
};
#[derive(Module, Debug)]
#[allow(clippy::large_enum_variant)]
pub enum Conv<B: Backend> {
DwsConv(DwsConv<B>),
Conv(Conv2d<B>),
}
#[derive(Module, Debug)]
pub struct DwsConv<B: Backend> {
dconv: Conv2d<B>,
pconv: Conv2d<B>,
}
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv: Conv<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model with DwsConv variant.
pub fn init_dws_conv(device: &B::Device) -> Self {
let dconv = Conv2dConfig::new([2, 2], [3, 3])
.with_groups(2)
.init(device);
let pconv = Conv2dConfig::new([2, 2], [1, 1])
.with_groups(1)
.init(device);
Net {
conv: Conv::DwsConv(DwsConv { dconv, pconv }),
}
}
/// Create a new model with Conv variant.
pub fn init_conv(device: &B::Device) -> Self {
let conv2d_config = Conv2dConfig::new([2, 2], [3, 3]);
Net {
conv: Conv::Conv(conv2d_config.init(device)),
}
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
match &self.conv {
Conv::DwsConv(dws_conv) => {
let x = dws_conv.dconv.forward(x);
dws_conv.pconv.forward(x)
}
Conv::Conv(conv) => conv.forward(x),
}
}
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::tensor::{Tolerance, ops::FloatElem};
use burn_store::{ModuleSnapshot, PytorchStore};
type FT = FloatElem<TestBackend>;
use super::*;
#[test]
fn depthwise_false() {
let device = Default::default();
let mut model = Net::<TestBackend>::init_conv(&device);
let mut store = PytorchStore::from_file("tests/enum_module/enum_depthwise_false.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
let input = Tensor::<TestBackend, 4>::from_data(
[[
[
[0.713_979_7, 0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4],
[0.505_920_8, 0.23659128, 0.757_007_4, 0.23458993, 0.64705235],
[0.355_621_4, 0.445_182_8, 0.01930594, 0.26160914, 0.771_317],
[0.37846136, 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845],
[
0.804_481_1,
0.65517855,
0.17679012,
0.824_772_3,
0.803_550_9,
],
],
[
[0.943_447_5, 0.21972018, 0.417_697, 0.49031407, 0.57302874],
[0.12054086, 0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7],
[0.52850497, 0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537],
[
0.03694397,
0.751_675_7,
0.148_438_4,
0.12274551,
0.530_407_2,
],
[0.414_796_4, 0.793_662, 0.21043217, 0.05550903, 0.863_884_4],
],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<TestBackend, 4>::from_data(
[[
[
[0.35449377, -0.02832414, 0.490_976_1],
[0.29709217, 0.332_586_3, 0.30594018],
[0.18101373, 0.30932188, 0.30558896],
],
[
[-0.17683622, -0.13244139, -0.05608707],
[0.23467252, -0.07038684, 0.255_044_1],
[-0.241_931_3, -0.20476191, -0.14468731],
],
]],
&device,
);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
}
#[test]
fn depthwise_true() {
let device = Default::default();
let mut model = Net::<TestBackend>::init_dws_conv(&device);
let mut store = PytorchStore::from_file("tests/enum_module/enum_depthwise_true.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
let input = Tensor::<TestBackend, 4>::from_data(
[[
[
[0.713_979_7, 0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4],
[0.505_920_8, 0.23659128, 0.757_007_4, 0.23458993, 0.64705235],
[0.355_621_4, 0.445_182_8, 0.01930594, 0.26160914, 0.771_317],
[0.37846136, 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845],
[
0.804_481_1,
0.65517855,
0.17679012,
0.824_772_3,
0.803_550_9,
],
],
[
[0.943_447_5, 0.21972018, 0.417_697, 0.49031407, 0.57302874],
[0.12054086, 0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7],
[0.52850497, 0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537],
[
0.03694397,
0.751_675_7,
0.148_438_4,
0.12274551,
0.530_407_2,
],
[0.414_796_4, 0.793_662, 0.21043217, 0.05550903, 0.863_884_4],
],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<TestBackend, 4>::from_data(
[[
[
[0.77874625, 0.859_017_6, 0.834_283_5],
[0.773_056_4, 0.73817325, 0.78292674],
[0.710_775_2, 0.747_187_2, 0.733_264_4],
],
[
[-0.44891885, -0.49027523, -0.394_170_7],
[-0.43836114, -0.33961445, -0.387_311_5],
[-0.581_134_3, -0.34197026, -0.535_035_7],
],
]],
&device,
);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
}
}

View File

@@ -0,0 +1,37 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.norm1 = nn.GroupNorm(2, 6)
def forward(self, x):
x = self.norm1(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "group_norm.pt")
x2 = torch.rand(1, 6, 2, 2)
print("Input shape: {}", x2.shape)
print("Input: {}", x2)
output = model(x2)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,90 @@
use burn::{
module::Module,
nn::{GroupNorm, GroupNormConfig},
tensor::{Tensor, backend::Backend},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
norm1: GroupNorm<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn init(device: &B::Device) -> Self {
let norm1 = GroupNormConfig::new(2, 6).init(device);
Self { norm1 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.norm1.forward(x)
}
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::tensor::Tolerance;
use burn_store::{ModuleSnapshot, PytorchStore};
use super::*;
fn group_norm(model: Net<TestBackend>, precision: f32) {
let device = Default::default();
let input = Tensor::<TestBackend, 4>::from_data(
[[
[[0.757_631_6, 0.27931088], [0.40306926, 0.73468447]],
[[0.02928156, 0.799_858_6], [0.39713734, 0.75437194]],
[[0.569_508_5, 0.43877792], [0.63868046, 0.524_665_9]],
[[0.682_614_1, 0.305_149_5], [0.46354562, 0.45498633]],
[[0.572_472, 0.498_002_6], [0.93708336, 0.65559506]],
[[0.31379688, 0.19801933], [0.41619217, 0.28432965]],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<TestBackend, 4>::from_data(
[[
[[1.042_578_5, -1.122_016_7], [-0.56195974, 0.938_733_6]],
[[-2.253_500_7, 1.233_672_9], [-0.588_804_1, 1.027_827_3]],
[[0.19124532, -0.40036356], [0.504_276_5, -0.01168585]],
[[1.013_829_2, -0.891_984_6], [-0.09224463, -0.13546038]],
[[0.45772314, 0.08172822], [2.298_641_4, 0.877_410_4]],
[[-0.84832406, -1.432_883_4], [-0.331_331_5, -0.997_103_7]],
]],
&device,
);
output
.to_data()
.assert_approx_eq::<f32>(&expected.to_data(), Tolerance::absolute(precision));
}
#[test]
fn group_norm_full() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/group_norm/group_norm.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
group_norm(model, 1e-3);
}
#[test]
fn group_norm_half() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/group_norm/group_norm.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
group_norm(model, 1e-3);
}
}

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
buffer = torch.tensor([1, 2, 3])
self.register_buffer("buffer", buffer, persistent=True)
def forward(self, x):
x = self.buffer
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "integer.pt")
input = torch.ones(3, 3)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,72 @@
use burn::{
module::{Module, Param, ParamId},
tensor::{Int, Tensor, TensorData, backend::Backend},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
buffer: Param<Tensor<B, 1, Int>>,
}
impl<B: Backend> Net<B> {
/// Create a new model with placeholder values.
pub fn init(device: &B::Device) -> Self {
Self {
buffer: Param::initialized(
ParamId::new(),
Tensor::<B, 1, Int>::from_data(TensorData::from([0, 0, 0]), device),
),
}
}
/// Forward pass of the model.
pub fn forward(&self, _x: Tensor<B, 2>) -> Tensor<B, 1, Int> {
self.buffer.val()
}
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::tensor::TensorData;
use burn_store::{ModuleSnapshot, PytorchStore};
use super::*;
fn integer(model: Net<TestBackend>) {
let device = Default::default();
let input = Tensor::<TestBackend, 2>::ones([3, 3], &device);
let output = model.forward(input);
let expected =
Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 2, 3]), &device);
assert_eq!(output.to_data(), expected.to_data());
}
#[test]
fn integer_full_precision() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/integer/integer.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
integer(model);
}
#[test]
fn integer_half_precision() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/integer/integer.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
integer(model);
}
}

View File

@@ -0,0 +1,48 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvModule(nn.Module):
def __init__(self):
super(ConvModule, self).__init__()
self.conv1 = nn.Conv2d(2, 2, (2,2))
self.conv2 = nn.Conv2d(2, 2, (2,2), bias=False)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv = ConvModule()
def forward(self, x):
x = self.conv(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "key_remap.pt")
input = torch.rand(1, 2, 5, 5)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,118 @@
use burn::{
module::Module,
nn::conv::{Conv2d, Conv2dConfig},
tensor::{Tensor, backend::Backend},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv1: Conv2d<B>,
conv2: Conv2d<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model.
pub fn init(device: &B::Device) -> Self {
let conv1 = Conv2dConfig::new([2, 2], [2, 2]).init(device);
let conv2 = Conv2dConfig::new([2, 2], [2, 2])
.with_bias(false)
.init(device);
Self { conv1, conv2 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv1.forward(x);
self.conv2.forward(x)
}
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::tensor::{Tolerance, ops::FloatElem};
use burn_store::{ModuleSnapshot, PytorchStore};
type FT = FloatElem<TestBackend>;
use super::*;
#[test]
fn key_remap() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/key_remap/key_remap.pt")
.with_key_remapping("conv\\.(.*)", "$1"); // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
model
.load_from(&mut store)
.expect("Should decode state successfully");
let input = Tensor::<TestBackend, 4>::from_data(
[[
[
[
0.024_595_8,
0.25883394,
0.93905586,
0.416_715_5,
0.713_979_7,
],
[0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4, 0.505_920_8],
[0.23659128, 0.757_007_4, 0.23458993, 0.64705235, 0.355_621_4],
[0.445_182_8, 0.01930594, 0.26160914, 0.771_317, 0.37846136],
[
0.99802476,
0.900_794_2,
0.476_588_2,
0.16625845,
0.804_481_1,
],
],
[
[
0.65517855,
0.17679012,
0.824_772_3,
0.803_550_9,
0.943_447_5,
],
[0.21972018, 0.417_697, 0.49031407, 0.57302874, 0.12054086],
[0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7, 0.52850497],
[0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537, 0.03694397],
[
0.751_675_7,
0.148_438_4,
0.12274551,
0.530_407_2,
0.414_796_4,
],
],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<TestBackend, 4>::from_data(
[[
[
[-0.02502128, 0.00250649, 0.04841233],
[0.04589614, -0.00296854, 0.01991477],
[0.02920526, 0.059_497_3, 0.04326791],
],
[
[-0.04825336, 0.080_190_9, -0.02375088],
[0.02885434, 0.09638263, -0.07460806],
[0.02004079, 0.06244051, 0.035_887_1],
],
]],
&device,
);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
}
}

View File

@@ -0,0 +1,57 @@
#!/usr/bin/env python3
import torch
from torch import nn, Tensor
class ConvBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
)
def forward(self, x: Tensor) -> Tensor:
return self.block(x)
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 6, 3, bias=False)
self.bn = nn.BatchNorm2d(6)
self.layer = nn.Sequential(ConvBlock(6, 6), ConvBlock(6, 6))
def forward(self, x: Tensor) -> Tensor:
x = self.conv(x)
x = self.bn(x)
x = self.layer(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(42)
model = Model()
input = torch.rand(1, 3, 4, 4)
model(input) # condition batch norm
model.eval()
with torch.no_grad():
print(f"Input shape: {input.shape}")
print("Input type: {}", input.dtype)
print(f"Input: {input}")
output = model(input)
print(f"Output: {output}")
print(f"Output Shape: {output.shape}")
torch.save(model.state_dict(), "key_remap.pt")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,179 @@
use std::marker::PhantomData;
use burn::{
module::Module,
nn::{
BatchNorm, BatchNormConfig,
conv::{Conv2d, Conv2dConfig},
},
tensor::{Device, Tensor, backend::Backend},
};
/// Some module that implements a specific method so it can be used in a sequential block.
pub trait ForwardModule<B: Backend> {
fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4>;
}
/// Conv2d + BatchNorm block.
#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
conv: Conv2d<B>,
bn: BatchNorm<B>,
}
impl<B: Backend> ForwardModule<B> for ConvBlock<B> {
fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let out = self.conv.forward(input);
self.bn.forward(out)
}
}
impl<B: Backend> ConvBlock<B> {
pub fn new(in_channels: usize, out_channels: usize, device: &Device<B>) -> Self {
let conv = Conv2dConfig::new([in_channels, out_channels], [1, 1])
.with_bias(false)
.init(device);
let bn = BatchNormConfig::new(out_channels).init(device);
Self { conv, bn }
}
}
/// Collection of sequential blocks.
#[derive(Module, Debug)]
pub struct ModuleBlock<B: Backend, M> {
blocks: Vec<M>,
_backend: PhantomData<B>,
}
impl<B: Backend, M: ForwardModule<B>> ModuleBlock<B, M> {
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let mut out = input;
for block in &self.blocks {
out = block.forward(out);
}
out
}
}
impl<B: Backend> ModuleBlock<B, ConvBlock<B>> {
pub fn new(device: &Device<B>) -> Self {
let blocks = vec![ConvBlock::new(6, 6, device), ConvBlock::new(6, 6, device)];
Self {
blocks,
_backend: PhantomData,
}
}
}
#[derive(Module, Debug)]
pub struct Model<B: Backend, M> {
conv: Conv2d<B>,
bn: BatchNorm<B>,
layer: ModuleBlock<B, M>,
}
impl<B: Backend> Model<B, ConvBlock<B>> {
pub fn new(device: &Device<B>) -> Self {
let conv = Conv2dConfig::new([3, 6], [3, 3])
.with_bias(false)
.init(device);
let bn = BatchNormConfig::new(6).init(device);
let layer = ModuleBlock::new(device);
Self { conv, bn, layer }
}
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let out = self.conv.forward(input);
let out = self.bn.forward(out);
self.layer.forward(out)
}
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::tensor::{Tolerance, ops::FloatElem};
use burn_store::{ModuleSnapshot, PytorchStore};
type FT = FloatElem<TestBackend>;
use super::*;
#[test]
#[should_panic]
fn key_remap_chained_missing_pattern() {
// Loading record should fail due to missing pattern to map the layer.blocks
let device = Default::default();
let mut model: Model<TestBackend, _> = Model::new(&device);
let mut store = PytorchStore::from_file("tests/key_remap_chained/key_remap.pt")
// Map *.block.0.* -> *.conv.*
.with_key_remapping("(.+)\\.block\\.0\\.(.+)", "$1.conv.$2")
// Map *.block.1.* -> *.bn.*
.with_key_remapping("(.+)\\.block\\.1\\.(.+)", "$1.bn.$2");
model
.load_from(&mut store)
.expect("Should decode state successfully");
}
#[test]
fn key_remap_chained() {
let device = Default::default();
let mut model: Model<TestBackend, _> = Model::new(&device);
let mut store = PytorchStore::from_file("tests/key_remap_chained/key_remap.pt")
// Map *.block.0.* -> *.conv.*
.with_key_remapping("(.+)\\.block\\.0\\.(.+)", "$1.conv.$2")
// Map *.block.1.* -> *.bn.*
.with_key_remapping("(.+)\\.block\\.1\\.(.+)", "$1.bn.$2")
// Map layer.[i].* -> layer.blocks.[i].*
.with_key_remapping("layer\\.([0-9])\\.(.+)", "layer.blocks.$1.$2");
model
.load_from(&mut store)
.expect("Should decode state successfully");
let input = Tensor::<TestBackend, 4>::from_data(
[[
[
[0.76193494, 0.626_546_1, 0.49510366, 0.11974698],
[0.07161391, 0.03232569, 0.704_681, 0.254_516],
[0.399_373_7, 0.21224737, 0.40888822, 0.14808255],
[0.17329216, 0.665_855_4, 0.351_401_8, 0.808_671_6],
],
[
[0.33959562, 0.13321638, 0.41178054, 0.257_626_3],
[0.347_029_2, 0.02400219, 0.77974546, 0.15189773],
[0.75130886, 0.726_892_1, 0.85721636, 0.11647397],
[0.859_598_4, 0.263_624_2, 0.685_534_6, 0.96955734],
],
[
[0.42948407, 0.49613327, 0.38488472, 0.08250773],
[0.73995143, 0.00364107, 0.81039995, 0.87411255],
[0.972_853_2, 0.38206023, 0.08917904, 0.61241513],
[0.77621365, 0.00234562, 0.38650817, 0.20027226],
],
]],
&device,
);
let expected = Tensor::<TestBackend, 4>::from_data(
[[
[[0.198_967_1, 0.17847246], [0.06883702, 0.20012866]],
[[0.17582723, 0.11344293], [0.05444185, 0.13307181]],
[[0.192_229_5, 0.20391327], [0.06150475, 0.22688155]],
[[0.00230906, -0.02177845], [0.01129148, 0.00925517]],
[[0.14751078, 0.14433631], [0.05498439, 0.29049855]],
[[0.16868964, 0.133_269_3], [0.06917118, 0.35094324]],
]],
&device,
);
let output = model.forward(input);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
}
}

View File

@@ -0,0 +1,37 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.norm1 = nn.LayerNorm(2)
def forward(self, x):
x = self.norm1(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "layer_norm.pt")
x2 = torch.rand(1, 2, 2, 2)
print("Input shape: {}", x2.shape)
print("Input: {}", x2)
output = model(x2)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,82 @@
use burn::{
module::Module,
nn::{LayerNorm, LayerNormConfig},
tensor::{Tensor, backend::Backend},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
norm1: LayerNorm<B>,
}
impl<B: Backend> Net<B> {
/// Create a new model.
pub fn init(device: &B::Device) -> Self {
let norm1 = LayerNormConfig::new(2).init(device);
Self { norm1 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.norm1.forward(x)
}
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::tensor::{Tolerance, ops::FloatElem};
use burn_store::{ModuleSnapshot, PytorchStore};
type FT = FloatElem<TestBackend>;
use super::*;
fn layer_norm(model: Net<TestBackend>, precision: f32) {
let device = Default::default();
let input = Tensor::<TestBackend, 4>::from_data(
[[
[[0.757_631_6, 0.27931088], [0.40306926, 0.73468447]],
[[0.02928156, 0.799_858_6], [0.39713734, 0.75437194]],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<TestBackend, 4>::from_data(
[[
[[0.99991274, -0.999_912_5], [-0.999_818_3, 0.999_818_3]],
[[-0.999_966_2, 0.99996626], [-0.99984336, 0.99984336]],
]],
&device,
);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::absolute(precision));
}
#[test]
fn layer_norm_full() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/layer_norm/layer_norm.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
layer_norm(model, 1e-3);
}
#[test]
fn layer_norm_half() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/layer_norm/layer_norm.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
layer_norm(model, 1e-3);
}
}

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc1 = nn.Linear(2, 3)
self.fc2 = nn.Linear(3, 4, bias=False)
def forward(self, x):
x = self.fc1(x)
x = F.relu(x) # Add relu so that PyTorch optimizer does not combine fc1 and fc2
x = self.fc2(x)
return x
class ModelWithBias(nn.Module):
def __init__(self):
super(ModelWithBias, self).__init__()
self.fc1 = nn.Linear(2, 3)
def forward(self, x):
x = self.fc1(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
model_with_bias = ModelWithBias().to(torch.device("cpu"))
torch.save(model.state_dict(), "linear.pt")
torch.save(model_with_bias.state_dict(), "linear_with_bias.pt")
input = torch.rand(1, 2, 2, 2)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
print("Model with bias")
output = model_with_bias(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,154 @@
use burn::{
module::Module,
nn::{Linear, LinearConfig, Relu},
tensor::{Tensor, backend::Backend},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
fc1: Linear<B>,
fc2: Linear<B>,
relu: Relu,
}
impl<B: Backend> Net<B> {
/// Create a new model.
pub fn init(device: &B::Device) -> Self {
let fc1 = LinearConfig::new(2, 3).init(device);
let fc2 = LinearConfig::new(3, 4).with_bias(false).init(device);
let relu = Relu;
Self { fc1, fc2, relu }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.fc1.forward(x);
let x = self.relu.forward(x);
self.fc2.forward(x)
}
}
#[derive(Module, Debug)]
struct NetWithBias<B: Backend> {
fc1: Linear<B>,
}
impl<B: Backend> NetWithBias<B> {
/// Create a new model.
pub fn init(device: &B::Device) -> Self {
let fc1 = LinearConfig::new(2, 3).init(device);
Self { fc1 }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.fc1.forward(x)
}
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::tensor::{Tolerance, ops::FloatElem};
use burn_store::{ModuleSnapshot, PytorchStore};
type FT = FloatElem<TestBackend>;
use super::*;
fn linear_test(model: Net<TestBackend>, precision: f32) {
let device = Default::default();
let input = Tensor::<TestBackend, 4>::from_data(
[[
[[0.63968194, 0.97427773], [0.830_029_9, 0.04443115]],
[[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<TestBackend, 4>::from_data(
[[
[
[0.09778349, -0.13756673, 0.04962806, 0.08856435],
[0.03163241, -0.02848549, 0.01437942, 0.11905234],
],
[
[0.07628226, -0.10757702, 0.03656857, 0.03824598],
[0.05443089, -0.06904714, 0.02744314, 0.09997337],
],
]],
&device,
);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::absolute(precision));
}
#[test]
fn linear_full_precision() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/linear/linear.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
linear_test(model, 1e-7);
}
#[test]
fn linear_half_precision() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/linear/linear.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
linear_test(model, 1e-4);
}
#[test]
fn linear_with_bias() {
let device = Default::default();
let mut model = NetWithBias::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/linear/linear_with_bias.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
let input = Tensor::<TestBackend, 4>::from_data(
[[
[[0.63968194, 0.97427773], [0.830_029_9, 0.04443115]],
[[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<TestBackend, 4>::from_data(
[[
[
[-0.00432095, -1.107_101_2, 0.870_691_4],
[0.024_595_5, -0.954_462_9, 0.48518157],
],
[
[0.34315687, -0.757_384_2, 0.548_288],
[-0.06608963, -1.072_072_7, 0.645_800_5],
],
]],
&device,
);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
}
}

View File

@@ -0,0 +1,24 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(2, 2, (2,2))
def forward(self, x):
x = self.conv1(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "missing_module_field.pt")
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,37 @@
use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend};
#[derive(Module, Debug)]
#[allow(unused)]
pub struct Net<B: Backend> {
do_not_exist_in_pt: Conv2d<B>,
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::nn::conv::Conv2dConfig;
use burn_store::{ModuleSnapshot, PytorchStore};
use super::*;
impl<B: Backend> Net<B> {
pub fn init(device: &B::Device) -> Self {
Self {
do_not_exist_in_pt: Conv2dConfig::new([2, 2], [2, 2]).init(device),
}
}
}
#[test]
#[should_panic(expected = "do_not_exist_in_pt")]
fn should_fail_if_struct_field_is_missing() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store =
PytorchStore::from_file("tests/missing_module_field/missing_module_field.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
}
}

View File

@@ -0,0 +1,42 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
num_layers = 5 # Number of repeated convolutional layers
# Create a list to store the layers
layers = []
for _ in range(num_layers):
layers.append(nn.Conv2d(2, 2, kernel_size=3, padding=1, bias=True))
layers.append(nn.ReLU(inplace=True))
# Use nn.Sequential to create a single module from the layers
self.fc = nn.Sequential(*layers)
def forward(self, x):
x = self.fc(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "non_contiguous_indexes.pt")
input = torch.rand(1, 2, 5, 5)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,110 @@
use burn::{
module::Module,
nn::{
PaddingConfig2d,
conv::{Conv2d, Conv2dConfig},
},
tensor::{Tensor, activation::relu, backend::Backend},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
fc: Vec<Conv2d<B>>,
}
impl<B: Backend> Net<B> {
/// Create a new model with placeholder values.
pub fn init(device: &B::Device) -> Self {
let conv2d_config = Conv2dConfig::new([2, 2], [3, 3]).with_padding(PaddingConfig2d::Same);
// The PyTorch file has 5 Conv2d layers at non-contiguous indices (0, 2, 4, 6, 8)
// in the Sequential (alternating with ReLU layers)
let fc = vec![
conv2d_config.init(device),
conv2d_config.init(device),
conv2d_config.init(device),
conv2d_config.init(device),
conv2d_config.init(device),
];
Net { fc }
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.fc.iter().fold(x, |x_i, conv| relu(conv.forward(x_i)))
}
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::tensor::{Tolerance, ops::FloatElem};
use burn_store::{ModuleSnapshot, PytorchStore};
type FT = FloatElem<TestBackend>;
use super::*;
#[test]
fn non_contiguous_indexes() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store =
PytorchStore::from_file("tests/non_contiguous_indexes/non_contiguous_indexes.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
let input = Tensor::<TestBackend, 4>::from_data(
[[
[
[
0.67890584,
0.307_537_2,
0.265_156_2,
0.528_318_8,
0.86194897,
],
[0.14828813, 0.73480314, 0.821_220_7, 0.989_098_6, 0.15003455],
[0.62109494, 0.13028657, 0.926_875_1, 0.30604684, 0.80117637],
[0.514_885_7, 0.46105868, 0.484_046_1, 0.58499724, 0.73569804],
[0.58018994, 0.65252745, 0.05023766, 0.864_268_7, 0.935_932],
],
[
[0.913_302_9, 0.869_611_3, 0.139_184_3, 0.314_65, 0.94086266],
[0.11917073, 0.953_610_6, 0.10675198, 0.14779574, 0.744_439],
[0.14075547, 0.38544965, 0.863_745_9, 0.89604443, 0.97287786],
[0.39854127, 0.11136961, 0.99230546, 0.39348692, 0.29428244],
[0.621_886_9, 0.15033776, 0.828_640_1, 0.81336635, 0.10325938],
],
]],
&device,
);
let output = model.forward(input);
let expected = Tensor::<TestBackend, 4>::from_data(
[[
[
[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],
[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],
[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],
[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],
[0.04485746, 0.03582812, 0.03432692, 0.02892298, 0.013_844_3],
],
[
[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],
[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],
[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],
[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],
[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],
],
]],
&device,
);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::absolute(1e-7));
}
}

View File

@@ -0,0 +1,22 @@
mod backend;
mod batch_norm;
mod boolean;
mod buffer;
mod complex_nested;
mod config;
mod conv1d;
mod conv2d;
mod conv_transpose1d;
mod conv_transpose2d;
mod embedding;
mod enum_module;
mod group_norm;
mod integer;
mod key_remap;
mod key_remap_chained;
mod layer_norm;
mod linear;
mod missing_module_field;
mod non_contiguous_indexes;
mod top_level_key;

View File

@@ -0,0 +1,24 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(2, 2, (2,2))
def forward(self, x):
x = self.conv1(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save({"my_state_dict": model.state_dict()}, "top_level_key.pt")
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,48 @@
use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend};
#[derive(Module, Debug)]
#[allow(unused)]
pub struct Net<B: Backend> {
conv1: Conv2d<B>,
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::nn::conv::Conv2dConfig;
use burn_store::{ModuleSnapshot, PytorchStore};
use super::*;
impl<B: Backend> Net<B> {
pub fn init(device: &B::Device) -> Self {
Self {
conv1: Conv2dConfig::new([2, 2], [2, 2]).init(device),
}
}
}
#[test]
#[should_panic]
fn should_fail_if_not_found() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/top_level_key/top_level_key.pt");
model
.load_from(&mut store)
.expect("Should decode state successfully");
}
#[test]
fn should_load() {
let device = Default::default();
let mut model = Net::<TestBackend>::init(&device);
let mut store = PytorchStore::from_file("tests/top_level_key/top_level_key.pt")
.with_top_level_key("my_state_dict");
model
.load_from(&mut store)
.expect("Should decode state successfully");
}
}

View File

@@ -0,0 +1,13 @@
[package]
name = "safetensors-tests"
version.workspace = true
edition.workspace = true
license.workspace = true
[dev-dependencies]
burn = { path = "../../burn" }
burn-ndarray = { path = "../../burn-ndarray" }
burn-autodiff = { path = "../../burn-autodiff" }
burn-store = { path = "../", features = ["std", "safetensors"] }
serde = { workspace = true }
float-cmp = { workspace = true }

View File

@@ -0,0 +1 @@
pub type TestBackend = burn_ndarray::NdArray<f32>;

View File

@@ -0,0 +1,92 @@
use burn::{
module::Module,
nn::{
BatchNorm, BatchNormConfig, Linear, LinearConfig, PaddingConfig2d, Relu,
conv::{Conv2d, Conv2dConfig},
},
tensor::{Tensor, backend::Backend},
};
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv1: Conv2d<B>,
norm1: BatchNorm<B>,
fc1: Linear<B>,
relu: Relu,
}
impl<B: Backend> Net<B> {
pub fn new(device: &B::Device) -> Self {
Self {
conv1: Conv2dConfig::new([3, 4], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
.init(device),
norm1: BatchNormConfig::new(4).init(device),
fc1: LinearConfig::new(4 * 8 * 8, 16).init(device),
relu: Relu::new(),
}
}
/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {
let x = self.conv1.forward(x);
let x = self.norm1.forward(x);
let x = self.relu.forward(x);
// Flatten all dimensions except the batch dimension
let x = x.flatten(1, 3);
self.fc1.forward(x)
}
}
#[cfg(test)]
mod tests {
use crate::backend::TestBackend;
use burn::tensor::Tolerance;
use burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};
use super::*;
#[test]
fn multi_layer_model() {
let device = Default::default();
let mut model = Net::<TestBackend>::new(&device);
let mut store = SafetensorsStore::from_file("tests/multi_layer/multi_layer.safetensors")
.with_from_adapter(PyTorchToBurnAdapter);
model
.load_from(&mut store)
.expect("Should decode state successfully");
let input = Tensor::<TestBackend, 4>::ones([1, 3, 8, 8], &device);
let output = model.forward(input);
// Note: Expected values should be updated based on the actual output from the PyTorch model
let expected = Tensor::<TestBackend, 2>::from_data(
[[
0.04971555,
-0.16849735,
0.05182848,
-0.18032673,
0.23138367,
0.05041867,
0.13005908,
-0.32202929,
-0.07915690,
-0.03232457,
-0.19790289,
-0.17476529,
-0.19627589,
-0.21757686,
-0.31376451,
0.08377837,
]],
&device,
);
output
.to_data()
.assert_approx_eq::<f32>(&expected.to_data(), Tolerance::default());
}
}

View File

@@ -0,0 +1,49 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import save_file
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(3, 4, kernel_size=3, padding=1)
self.norm1 = nn.BatchNorm2d(4)
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(4 * 8 * 8, 16) # Changed for smaller input size
def forward(self, x):
x = self.conv1(x)
x = self.norm1(x)
x = F.relu(x)
x = self.flatten(x)
x = self.fc1(x)
return x
def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
# Use a smaller input size
# 1 batch, 3 channels (RGB), 8x8 image (small input)
x1 = torch.ones(1, 3, 8, 8)
_ = model(x1)
model.eval() # Set to eval mode to freeze running stats
# Save the model to safetensors after the first forward
save_file(model.state_dict(), "multi_layer.safetensors")
x2 = torch.ones(1, 3, 8, 8)
print("Input shape: {}", x2.shape)
output = model(x2)
print("Output: {}", output)
print("Output Shape: {}", output.shape)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,3 @@
mod backend;
mod multi_layer;

View File

@@ -0,0 +1,663 @@
//! Module adapters for transforming tensors between different formats
//!
//! This module provides adapters that handle differences between PyTorch and Burn:
//! - Linear layer weight transposition
//! - Normalization parameter naming (weight/bias vs gamma/beta)
use crate::TensorSnapshot;
use alloc::boxed::Box;
use alloc::rc::Rc;
use alloc::string::String;
use alloc::string::ToString;
use alloc::vec;
use burn_tensor::TensorData;
// Module type names as they appear in the container_type field
// These come from the Module derive macro which uses stringify! on the struct name
// Format: "Struct:TypeName" for user-defined structs
mod module_names {
// The actual string constants that match what the Module derive macro produces
pub const LINEAR: &str = "Struct:Linear";
pub const BATCH_NORM: &str = "Struct:BatchNorm";
pub const LAYER_NORM: &str = "Struct:LayerNorm";
pub const GROUP_NORM: &str = "Struct:GroupNorm";
}
/// Trait for adapting tensor snapshots between different module formats
pub trait ModuleAdapter: Send + Sync {
/// Adapt a tensor snapshot based on its container type and parameter name
fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot;
/// Get alternative parameter name to try during matching
///
/// When looking for a parameter in a module, this method provides an alternative
/// name to try if the direct name doesn't match. This enables matching parameters
/// with different naming conventions (e.g., PyTorch's "weight" vs Burn's "gamma").
///
/// # Arguments
/// * `param_name` - The parameter name we're looking for
/// * `container_type` - The type of container module (e.g., "BatchNorm")
///
/// # Returns
/// Alternative parameter name to try, or None if no alternative exists
fn get_alternative_param_name(
&self,
_param_name: &str,
_container_type: &str,
) -> Option<String> {
None
}
/// Clone the adapter into a boxed trait object
fn clone_box(&self) -> Box<dyn ModuleAdapter>;
/// Chain adapters together, applying `self` first and then `next`.
///
/// This is useful when multiple transformations are required when importing model weights
/// (e.g. PyTorch -> Burn layout conversion, then dtype casting, then custom remapping).
///
/// The semantics follow a simple pipeline:
/// - `adapt`: `next.adapt(&self.adapt(snapshot))`
/// - `get_alternative_param_name`: try `self` first; if it returns an alternative name,
/// try `next` with that name, otherwise return the first alternative name.
fn chain<A>(self, next: A) -> ChainAdapter
where
Self: Sized + 'static,
A: ModuleAdapter + 'static,
{
ChainAdapter::new(self, next)
}
}
impl Clone for Box<dyn ModuleAdapter> {
fn clone(&self) -> Self {
self.clone_box()
}
}
/// Adapter that applies two adapters in sequence.
///
/// This allows composing smaller adapters instead of creating one large monolithic adapter.
#[derive(Clone)]
pub struct ChainAdapter {
first: Box<dyn ModuleAdapter>,
second: Box<dyn ModuleAdapter>,
}
impl ChainAdapter {
/// Create a new adapter chain.
pub fn new<A, B>(first: A, second: B) -> Self
where
A: ModuleAdapter + 'static,
B: ModuleAdapter + 'static,
{
Self {
first: Box::new(first),
second: Box::new(second),
}
}
}
impl ModuleAdapter for ChainAdapter {
fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
let snapshot = self.first.adapt(snapshot);
self.second.adapt(&snapshot)
}
fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option<String> {
if let Some(name) = self
.first
.get_alternative_param_name(param_name, container_type)
{
self.second
.get_alternative_param_name(&name, container_type)
.or(Some(name))
} else {
self.second
.get_alternative_param_name(param_name, container_type)
}
}
fn clone_box(&self) -> Box<dyn ModuleAdapter> {
Box::new(self.clone())
}
}
/// Identity adapter that passes tensors through unchanged
#[derive(Debug, Clone, Default)]
pub struct IdentityAdapter;
impl ModuleAdapter for IdentityAdapter {
fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
snapshot.clone()
}
fn clone_box(&self) -> Box<dyn ModuleAdapter> {
Box::new(self.clone())
}
}
/// Adapter for converting from PyTorch format to Burn format
///
/// Handles:
/// - Linear layer weight transposition (PyTorch: [out, in] → Burn: [in, out])
/// - Normalization parameter renaming (weight → gamma, bias → beta)
#[derive(Debug, Clone, Default)]
pub struct PyTorchToBurnAdapter;
impl ModuleAdapter for PyTorchToBurnAdapter {
fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::PyTorchToBurn)
}
fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option<String> {
// For PyTorch->Burn: When looking for Burn names (gamma/beta), try PyTorch names (weight/bias)
if is_normalization_layer(container_type) {
burn_norm_param_to_pytorch(param_name).map(|s| s.to_string())
} else {
None
}
}
fn clone_box(&self) -> Box<dyn ModuleAdapter> {
Box::new(self.clone())
}
}
/// Adapter for converting from Burn format to PyTorch format
///
/// Handles:
/// - Linear layer weight transposition (Burn: [in, out] → PyTorch: [out, in])
/// - Normalization parameter renaming (gamma → weight, beta → bias)
#[derive(Debug, Clone, Default)]
pub struct BurnToPyTorchAdapter;
impl ModuleAdapter for BurnToPyTorchAdapter {
fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::BurnToPyTorch)
}
fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option<String> {
// For Burn->PyTorch: When looking for PyTorch names (weight/bias), try Burn names (gamma/beta)
if is_normalization_layer(container_type) {
pytorch_norm_param_to_burn(param_name).map(|s| s.to_string())
} else {
None
}
}
fn clone_box(&self) -> Box<dyn ModuleAdapter> {
Box::new(self.clone())
}
}
/// Direction of PyTorch conversion for parameter naming
#[derive(Debug, Clone, Copy)]
enum PyTorchConversionDirection {
PyTorchToBurn,
BurnToPyTorch,
}
/// Check if container type is a normalization layer
fn is_normalization_layer(container_type: &str) -> bool {
matches!(
container_type,
module_names::BATCH_NORM | module_names::LAYER_NORM | module_names::GROUP_NORM
)
}
/// Map PyTorch normalization parameter name to Burn
fn pytorch_norm_param_to_burn(param_name: &str) -> Option<&'static str> {
match param_name {
"weight" => Some("gamma"),
"bias" => Some("beta"),
_ => None,
}
}
/// Map Burn normalization parameter name to PyTorch
fn burn_norm_param_to_pytorch(param_name: &str) -> Option<&'static str> {
match param_name {
"gamma" => Some("weight"),
"beta" => Some("bias"),
_ => None,
}
}
/// Core tensor adaptation logic for PyTorch format conversions
fn adapt_pytorch_tensor(
snapshot: &TensorSnapshot,
direction: PyTorchConversionDirection,
) -> TensorSnapshot {
// Extract path and parameter name
let (path_stack, param_name) = match get_path_and_param(snapshot) {
Some(result) => result,
None => return snapshot.clone(),
};
// Get module type for matching (ignores Vec/Array wrappers)
let module_type = match snapshot.module_type() {
Some(mt) => mt,
None => return snapshot.clone(), // No user-defined module found
};
// Linear: transpose weight (bidirectional - same operation both ways)
if module_type == module_names::LINEAR && param_name == "weight" && snapshot.shape.len() == 2 {
return transpose_2d_tensor(snapshot);
}
// Normalization layers: rename parameters based on direction
if is_normalization_layer(&module_type) {
let new_name = match direction {
PyTorchConversionDirection::PyTorchToBurn => pytorch_norm_param_to_burn(param_name),
PyTorchConversionDirection::BurnToPyTorch => burn_norm_param_to_pytorch(param_name),
};
if let Some(new_name) = new_name {
return rename_parameter(snapshot, path_stack, new_name);
}
}
snapshot.clone()
}
/// Extract path stack and parameter name from snapshot
fn get_path_and_param(snapshot: &TensorSnapshot) -> Option<(&[String], &str)> {
let path_stack = snapshot.path_stack.as_ref()?;
let param_name = path_stack.last()?.as_str();
Some((path_stack.as_slice(), param_name))
}
/// Rename a parameter in the snapshot
fn rename_parameter(
snapshot: &TensorSnapshot,
path_stack: &[String],
new_name: &str,
) -> TensorSnapshot {
let mut new_path = path_stack.to_vec();
*new_path.last_mut().unwrap() = new_name.to_string();
TensorSnapshot::from_closure(
snapshot.clone_data_fn(),
snapshot.dtype,
snapshot.shape.clone(),
new_path,
snapshot.container_stack.clone().unwrap_or_default(),
snapshot.tensor_id.unwrap_or_default(),
)
}
/// Transpose a 2D tensor
fn transpose_2d_tensor(snapshot: &TensorSnapshot) -> TensorSnapshot {
if snapshot.shape.len() != 2 {
return snapshot.clone();
}
let original_data_fn = snapshot.clone_data_fn();
let dtype = snapshot.dtype;
let transposed_shape = vec![snapshot.shape[1], snapshot.shape[0]];
// Create a lazy closure that transposes when called
let transposed_data_fn = Rc::new(move || {
let data = original_data_fn()?;
Ok(transpose_tensor_data(data))
});
TensorSnapshot::from_closure(
transposed_data_fn,
dtype,
transposed_shape,
snapshot.path_stack.clone().unwrap_or_default(),
snapshot.container_stack.clone().unwrap_or_default(),
snapshot.tensor_id.unwrap_or_default(),
)
}
/// Transpose tensor data (assumes 2D shape is already validated)
fn transpose_tensor_data(data: TensorData) -> TensorData {
let shape = &data.shape;
let rows = shape[0];
let cols = shape[1];
let transposed_shape = vec![cols, rows];
// Get the raw bytes and element size
let bytes = data.as_bytes();
let element_size = data.dtype.size();
// Create a new buffer for transposed data
let mut transposed_bytes = vec![0u8; bytes.len()];
// Transpose at the byte level - works for any data type
for i in 0..rows {
for j in 0..cols {
let src_idx = (i * cols + j) * element_size;
let dst_idx = (j * rows + i) * element_size;
// Copy the bytes for this element
transposed_bytes[dst_idx..dst_idx + element_size]
.copy_from_slice(&bytes[src_idx..src_idx + element_size]);
}
}
// Create new TensorData from transposed bytes
TensorData::from_bytes_vec(transposed_bytes, transposed_shape, data.dtype)
}
#[cfg(test)]
mod tests {
use super::*;
use alloc::rc::Rc;
use alloc::sync::Arc;
use burn_tensor::{DType, TensorData};
use core::sync::atomic::{AtomicUsize, Ordering};
#[test]
fn test_module_names_match_burn_nn() {
// If these types are renamed or moved in `burn-nn`, this test will fail to compile.
// This use statement replicates the previous check/alarm system.
#[allow(unused_imports)]
use burn_nn::{BatchNorm, GroupNorm, LayerNorm, Linear};
// These assert statements work as extra checks that should remind maintainers more
// clearly that the hardcoded strings needs get updated.
assert_eq!(module_names::LINEAR, "Struct:Linear");
assert_eq!(module_names::BATCH_NORM, "Struct:BatchNorm");
assert_eq!(module_names::LAYER_NORM, "Struct:LayerNorm");
assert_eq!(module_names::GROUP_NORM, "Struct:GroupNorm");
}
fn create_test_snapshot(path: &str, shape: Vec<usize>, container_type: &str) -> TensorSnapshot {
let path_parts: Vec<String> = path.split('.').map(|s| s.to_string()).collect();
let values = vec![1.0f32; shape.iter().product()];
let data = TensorData::new(values, shape.clone());
TensorSnapshot::from_closure(
Rc::new(move || Ok(data.clone())),
DType::F32,
shape,
path_parts,
vec![container_type.to_string()],
burn_core::module::ParamId::new(),
)
}
#[test]
fn test_pytorch_to_burn_linear_weight() {
let adapter = PyTorchToBurnAdapter;
// Linear layer weight should be transposed
let snapshot = create_test_snapshot("fc.weight", vec![10, 5], module_names::LINEAR);
let adapted = adapter.adapt(&snapshot);
assert_eq!(adapted.shape, vec![5, 10]);
// Linear layer bias should not be transposed
let snapshot = create_test_snapshot("fc.bias", vec![10], module_names::LINEAR);
let adapted = adapter.adapt(&snapshot);
assert_eq!(adapted.shape, vec![10]);
}
#[test]
fn test_pytorch_to_burn_norm_params() {
let adapter = PyTorchToBurnAdapter;
// BatchNorm weight -> gamma
let snapshot = create_test_snapshot("norm.weight", vec![10], module_names::BATCH_NORM);
let adapted = adapter.adapt(&snapshot);
assert_eq!(adapted.full_path(), "norm.gamma");
// BatchNorm bias -> beta
let snapshot = create_test_snapshot("norm.bias", vec![10], module_names::BATCH_NORM);
let adapted = adapter.adapt(&snapshot);
assert_eq!(adapted.full_path(), "norm.beta");
}
#[test]
fn test_burn_to_pytorch_linear_weight() {
let adapter = BurnToPyTorchAdapter;
// Linear layer weight should be transposed
let snapshot = create_test_snapshot("fc.weight", vec![5, 10], module_names::LINEAR);
let adapted = adapter.adapt(&snapshot);
assert_eq!(adapted.shape, vec![10, 5]);
}
#[test]
fn test_burn_to_pytorch_norm_params() {
let adapter = BurnToPyTorchAdapter;
// BatchNorm gamma -> weight
let snapshot = create_test_snapshot("norm.gamma", vec![10], module_names::BATCH_NORM);
let adapted = adapter.adapt(&snapshot);
assert_eq!(adapted.full_path(), "norm.weight");
// BatchNorm beta -> bias
let snapshot = create_test_snapshot("norm.beta", vec![10], module_names::BATCH_NORM);
let adapted = adapter.adapt(&snapshot);
assert_eq!(adapted.full_path(), "norm.bias");
}
#[test]
fn test_transpose_different_dtypes() {
// Test that transpose works for different data types
// Test with F32
let f32_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
let transposed = transpose_tensor_data(f32_data);
assert_eq!(transposed.shape, vec![3, 2]);
let values = transposed.to_vec::<f32>().unwrap();
assert_eq!(values, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
// Test with I32
let i32_data = TensorData::new(vec![1i32, 2, 3, 4, 5, 6], vec![2, 3]);
let transposed = transpose_tensor_data(i32_data);
assert_eq!(transposed.shape, vec![3, 2]);
let values = transposed.to_vec::<i32>().unwrap();
assert_eq!(values, vec![1, 4, 2, 5, 3, 6]);
// Test with F64
let f64_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], vec![2, 2]);
let transposed = transpose_tensor_data(f64_data);
assert_eq!(transposed.shape, vec![2, 2]);
let values = transposed.to_vec::<f64>().unwrap();
assert_eq!(values, vec![1.0, 3.0, 2.0, 4.0]);
}
#[test]
fn test_no_container_info() {
let adapter = PyTorchToBurnAdapter;
// Without container info, adapter returns unchanged for non-norm parameters
let mut snapshot = create_test_snapshot("fc.weight", vec![10, 5], module_names::LINEAR);
snapshot.container_stack = None;
// Without container info, no transformation occurs for linear layers
let adapted = adapter.adapt(&snapshot);
assert_eq!(adapted.shape, vec![10, 5]); // No transposition without container info
// Test a non-linear, non-norm parameter - should pass through unchanged
let mut snapshot2 = create_test_snapshot("other.weight", vec![10, 5], "Struct:Other");
snapshot2.container_stack = None;
let adapted2 = adapter.adapt(&snapshot2);
assert_eq!(adapted2.shape, vec![10, 5]); // No transposition
}
#[derive(Clone)]
struct RenameParamAdapter {
from: &'static str,
to: &'static str,
called: Arc<AtomicUsize>,
}
impl ModuleAdapter for RenameParamAdapter {
fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
self.called.fetch_add(1, Ordering::Relaxed);
let path_stack = match snapshot.path_stack.as_ref() {
Some(stack) => stack,
None => return snapshot.clone(),
};
let param = match path_stack.last() {
Some(p) => p.as_str(),
None => return snapshot.clone(),
};
if param != self.from {
return snapshot.clone();
}
let mut new_path = path_stack.to_vec();
*new_path.last_mut().unwrap() = self.to.to_string();
TensorSnapshot::from_closure(
snapshot.clone_data_fn(),
snapshot.dtype,
snapshot.shape.clone(),
new_path,
snapshot.container_stack.clone().unwrap_or_default(),
snapshot.tensor_id.unwrap_or_default(),
)
}
fn get_alternative_param_name(
&self,
_param_name: &str,
_container_type: &str,
) -> Option<String> {
None
}
fn clone_box(&self) -> Box<dyn ModuleAdapter> {
Box::new(self.clone())
}
}
#[derive(Clone)]
struct AltNameAdapter {
from: &'static str,
to: &'static str,
called: Arc<AtomicUsize>,
}
impl ModuleAdapter for AltNameAdapter {
fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
TensorSnapshot::from_closure(
snapshot.clone_data_fn(),
snapshot.dtype,
snapshot.shape.clone(),
snapshot.path_stack.clone().unwrap_or_default(),
snapshot.container_stack.clone().unwrap_or_default(),
snapshot.tensor_id.unwrap_or_default(),
)
}
fn get_alternative_param_name(
&self,
param_name: &str,
_container_type: &str,
) -> Option<String> {
self.called.fetch_add(1, Ordering::Relaxed);
if param_name == self.from {
Some(self.to.to_string())
} else {
None
}
}
fn clone_box(&self) -> Box<dyn ModuleAdapter> {
Box::new(self.clone())
}
}
#[test]
fn test_chain_adapter_pipes_adapt() {
let called1 = Arc::new(AtomicUsize::new(0));
let called2 = Arc::new(AtomicUsize::new(0));
let a = RenameParamAdapter {
from: "weight",
to: "a",
called: called1.clone(),
};
let b = RenameParamAdapter {
from: "a",
to: "b",
called: called2.clone(),
};
let chain = a.chain(b);
let snapshot = create_test_snapshot("fc.weight", vec![2, 2], module_names::LINEAR);
let adapted = chain.adapt(&snapshot);
assert_eq!(adapted.full_path(), "fc.b");
assert_eq!(called1.load(Ordering::Relaxed), 1);
assert_eq!(called2.load(Ordering::Relaxed), 1);
}
#[test]
fn test_chain_adapter_alternative_name_pipes_and_fallbacks() {
let called1 = Arc::new(AtomicUsize::new(0));
let called2 = Arc::new(AtomicUsize::new(0));
let a = AltNameAdapter {
from: "gamma",
to: "weight",
called: called1.clone(),
};
let b = AltNameAdapter {
from: "weight",
to: "scale",
called: called2.clone(),
};
let chain = a.chain(b);
let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM);
assert_eq!(alt.as_deref(), Some("scale"));
assert_eq!(called1.load(Ordering::Relaxed), 1);
assert_eq!(called2.load(Ordering::Relaxed), 1);
// If the second adapter doesn't have a mapping for the first alternative,
// fall back to the first alternative name.
let called1 = Arc::new(AtomicUsize::new(0));
let called2 = Arc::new(AtomicUsize::new(0));
let a = AltNameAdapter {
from: "gamma",
to: "weight",
called: called1.clone(),
};
let b = AltNameAdapter {
from: "something-else",
to: "unused",
called: called2.clone(),
};
let chain = a.chain(b);
let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM);
assert_eq!(alt.as_deref(), Some("weight"));
assert_eq!(called1.load(Ordering::Relaxed), 1);
assert_eq!(called2.load(Ordering::Relaxed), 1);
// If the first adapter doesn't provide an alternative, try the second with the original name.
let called1 = Arc::new(AtomicUsize::new(0));
let called2 = Arc::new(AtomicUsize::new(0));
let a = AltNameAdapter {
from: "something-else",
to: "unused",
called: called1.clone(),
};
let b = AltNameAdapter {
from: "gamma",
to: "weight",
called: called2.clone(),
};
let chain = a.chain(b);
let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM);
assert_eq!(alt.as_deref(), Some("weight"));
assert_eq!(called1.load(Ordering::Relaxed), 1);
assert_eq!(called2.load(Ordering::Relaxed), 1);
// clone_box must preserve behavior.
let boxed = chain.clone_box();
let alt = boxed.get_alternative_param_name("gamma", module_names::LAYER_NORM);
assert_eq!(alt.as_deref(), Some("weight"));
}
}

View File

@@ -0,0 +1,608 @@
//! Applier that correctly applies tensor snapshots with adapter support
use alloc::boxed::Box;
use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use hashbrown::{HashMap, HashSet};
use burn_core::module::{ModuleMapper, Param};
use burn_tensor::{Bool, Int, Shape, Tensor, backend::Backend};
use crate::apply_result::{ApplyError, ApplyResult};
use crate::{ModuleAdapter, PathFilter, TensorSnapshot};
/// Applier that applies tensor snapshots to module parameters
/// with proper adapter support using container type information
pub struct Applier<B: Backend> {
/// Map of tensor paths to their snapshots
snapshots: HashMap<String, TensorSnapshot>,
/// Current path in the module hierarchy
path_stack: Vec<String>,
/// Current container type stack in the module hierarchy
container_stack: Vec<String>,
/// Optional filter for selective application
filter: Option<PathFilter>,
/// Optional adapter to transform tensors based on container types
adapter: Option<Box<dyn ModuleAdapter>>,
/// Successfully applied tensor paths
applied: Vec<String>,
/// Skipped tensor paths
skipped: HashSet<String>,
/// Errors encountered during application
errors: Vec<ApplyError>,
/// Track visited paths with their container stacks (in dot notation) to find missing tensors
visited_paths: HashMap<String, String>,
/// Skip enum variant names when matching paths
/// When true, "feature.BaseConv.weight" will also try to match "feature.weight"
skip_enum_variants: bool,
/// Phantom data for backend type
_backend: core::marker::PhantomData<B>,
}
impl<B: Backend> Applier<B> {
/// Create a new applier with snapshots, optional filter, and optional adapter
///
/// # Arguments
///
/// * `views` - A vector of TensorSnapshot objects to apply
/// * `filter` - An optional [`PathFilter`] to determine which tensors to apply.
/// When `None`, all available tensors are applied.
/// * `adapter` - Optional adapter to transform tensors based on container types
/// * `skip_enum_variants` - Skip enum variant names when matching paths
pub fn new(
views: Vec<TensorSnapshot>,
filter: Option<PathFilter>,
adapter: Option<Box<dyn ModuleAdapter>>,
skip_enum_variants: bool,
) -> Self {
let views_map: HashMap<String, TensorSnapshot> = views
.into_iter()
.map(|view| (view.full_path(), view))
.collect();
Self {
snapshots: views_map,
path_stack: Vec::new(),
container_stack: Vec::new(),
filter,
adapter,
applied: Vec::new(),
skipped: HashSet::new(),
errors: Vec::new(),
visited_paths: HashMap::new(),
skip_enum_variants,
_backend: core::marker::PhantomData,
}
}
/// Get the current path in the module hierarchy
fn current_path(&self) -> String {
self.path_stack.join(".")
}
/// Get the current module type (last Struct/Enum in container stack)
fn current_module_type(&self) -> Option<&str> {
self.container_stack
.iter()
.rev()
.find(|ct| ct.starts_with("Struct:") || ct.starts_with("Enum:"))
.map(|s| s.as_str())
}
/// Check if a tensor should be applied based on filter
fn should_apply(&self) -> bool {
match &self.filter {
None => true,
Some(f) => f.matches_with_container_path(&self.path_stack, &self.container_stack),
}
}
/// Convert the applier into a result
pub fn into_result(self) -> ApplyResult {
let mut unused: Vec<String> = self
.snapshots
.keys()
.filter(|path| !self.visited_paths.contains_key(*path) && !self.skipped.contains(*path))
.cloned()
.collect();
// Sort for stable output order
unused.sort();
// Create a set of successfully applied paths for efficient lookup
let applied_set: HashSet<String> = self.applied.iter().cloned().collect();
// Extract paths that have errors - these are not "missing", they were found but had issues
let errored_paths: HashSet<String> = self
.errors
.iter()
.map(|e| match e {
ApplyError::ShapeMismatch { path, .. } => path.clone(),
ApplyError::DTypeMismatch { path, .. } => path.clone(),
ApplyError::AdapterError { path, .. } => path.clone(),
ApplyError::LoadError { path, .. } => path.clone(),
})
.collect();
// A path is missing if it was visited but not successfully applied, not skipped, and didn't have an error
// Store both the path and its container stack (in dot notation)
let mut missing: Vec<(String, String)> = self
.visited_paths
.into_iter()
.filter(|(p, _)| {
!applied_set.contains(p) && !self.skipped.contains(p) && !errored_paths.contains(p)
})
.collect();
// Sort for stable output order (by path)
missing.sort_by(|a, b| a.0.cmp(&b.0));
// Convert skipped HashSet to sorted Vec for stable output
let mut skipped: Vec<String> = self.skipped.into_iter().collect();
skipped.sort();
ApplyResult {
applied: self.applied,
skipped,
missing,
unused,
errors: self.errors,
}
}
/// Apply a tensor snapshot with shape validation and optional adapter transformation
/// Returns None if snapshot not found, filtered, or validation fails
fn apply_tensor<const D: usize, K>(
&mut self,
target_device: &B::Device,
target_shape: Shape,
) -> Option<Tensor<B, D, K>>
where
K: burn_tensor::TensorKind<B>,
K: burn_tensor::BasicOps<B>,
{
let path = self.current_path();
let container_stack_str = self.container_stack.join(".");
self.visited_paths.insert(path.clone(), container_stack_str);
// Try to get snapshot with original path first
let mut snapshot = self.snapshots.get(&path).cloned();
// If not found and we have an adapter, try alternative parameter names
if snapshot.is_none()
&& let Some(ref adapter) = self.adapter
&& let Some(module_type) = self.current_module_type()
{
// Get alternative name based on current module type (user-defined module only)
let param_name = self.path_stack.last()?;
if let Some(alt_name) = adapter.get_alternative_param_name(param_name, module_type) {
// Build alternative path with parameter name substitution
let mut alt_path_stack = self.path_stack.clone();
*alt_path_stack.last_mut().unwrap() = alt_name.clone();
let alt_path = alt_path_stack.join(".");
// Try to get snapshot with alternative name
snapshot = self.snapshots.get(&alt_path).cloned();
// Don't mark the alternative path as visited - only the original Burn path
// should be tracked. The alternative path is just for lookup.
}
}
let mut snapshot = snapshot?;
// Apply adapter transformation using current container_stack context (for data transformation like transpose)
if let Some(ref adapter) = self.adapter {
// Create a temporary snapshot with current context for adaptation
let snapshot_with_context = TensorSnapshot::from_closure(
snapshot.clone_data_fn(),
snapshot.dtype,
snapshot.shape.clone(),
self.path_stack.clone(),
self.container_stack.clone(),
snapshot.tensor_id.unwrap_or_default(),
);
// Transform using adapter (handles transpose)
snapshot = adapter.adapt(&snapshot_with_context);
}
// Check if we should apply based on filter
if !self.should_apply() {
self.skipped.insert(path.clone());
return None;
}
// Load tensor data
let data = match snapshot.to_data() {
Ok(data) => data,
Err(e) => {
self.errors.push(ApplyError::LoadError {
path: path.clone(),
message: format!("Failed to load tensor data: {:?}", e),
});
return None; // Signal caller to fall back to initialization
}
};
// Validate shape
if data.shape != *target_shape {
self.errors.push(ApplyError::ShapeMismatch {
path: path.clone(),
expected: target_shape.to_vec(),
found: data.shape.clone(),
});
return None; // Signal caller to fall back to initialization
}
self.applied.push(path);
Some(Tensor::from_data_dtype(data, target_device, snapshot.dtype))
}
}
impl<B: Backend> ModuleMapper<B> for Applier<B> {
fn enter_module(&mut self, name: &str, container_type: &str) {
// Always track the container type for proper module type detection
self.container_stack.push(container_type.to_string());
// Only add to path if it's not an enum variant (when skip_enum_variants is enabled)
// This ensures paths are built without enum variant names from the start
if !self.skip_enum_variants || !container_type.starts_with("Enum:") {
self.path_stack.push(name.to_string());
}
}
fn exit_module(&mut self, _name: &str, container_type: &str) {
self.container_stack.pop();
// Only pop from path if we added it (not an enum variant when skip_enum_variants is enabled)
if !self.skip_enum_variants || !container_type.starts_with("Enum:") {
self.path_stack.pop();
}
}
fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
let param_id = param.id;
let target_device = param.lazy_device();
let target_shape = param.lazy_shape();
// Try to apply snapshot with shape validation
match self.apply_tensor(&target_device, target_shape) {
Some(tensor) => {
// We have a tensor to apply - load it
param.transform_for_load(tensor, param_id)
}
None => {
// No snapshot, filtered, or validation failed - return param unchanged
param
}
}
}
fn map_int<const D: usize>(
&mut self,
param: Param<Tensor<B, D, Int>>,
) -> Param<Tensor<B, D, Int>> {
let param_id = param.id;
let target_device = param.lazy_device();
let target_shape = param.lazy_shape();
// Try to apply snapshot with shape validation
match self.apply_tensor(&target_device, target_shape) {
Some(tensor) => {
// We have a tensor to apply - load it
param.transform_for_load(tensor, param_id)
}
None => {
// No snapshot, filtered, or validation failed - return param unchanged
param
}
}
}
fn map_bool<const D: usize>(
&mut self,
param: Param<Tensor<B, D, Bool>>,
) -> Param<Tensor<B, D, Bool>> {
let param_id = param.id;
let target_device = param.lazy_device();
let target_shape = param.lazy_shape();
// Try to apply snapshot with shape validation
match self.apply_tensor(&target_device, target_shape) {
Some(tensor) => {
// We have a tensor to apply - load it
param.transform_for_load(tensor, param_id)
}
None => {
// No snapshot, filtered, or validation failed - return param unchanged
param
}
}
}
}
#[cfg(all(test, feature = "std", target_has_atomic = "ptr"))]
mod tests {
use super::*;
use burn_core::module::{ModuleMapper, Param, ParamId};
use burn_tensor::{DType, Tensor, TensorData};
type TestBackend = burn_ndarray::NdArray;
#[test]
fn root_level_parameters() {
let device = Default::default();
// Create root-level parameters (not inside any module)
let weight = Param::<Tensor<TestBackend, 2>>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
let bias = Param::<Tensor<TestBackend, 1>>::from_data([5.0, 6.0], &device);
// Create snapshots with root-level paths (single-element path, no nested modules)
let weight_snapshot = crate::TensorSnapshot::from_data(
weight.val().to_data(),
vec!["weight".to_string()], // root-level parameter name
vec![], // no container
ParamId::new(),
);
let bias_snapshot = crate::TensorSnapshot::from_data(
bias.val().to_data(),
vec!["bias".to_string()], // root-level parameter name
vec![], // no container
ParamId::new(),
);
// Create applier with root-level snapshots
let mut applier =
Applier::<TestBackend>::new(vec![weight_snapshot, bias_snapshot], None, None, false);
// Create new params to load into
let weight_target = Param::initialized(
ParamId::new(),
Tensor::<TestBackend, 2>::zeros([2, 2], &device),
);
let bias_target = Param::initialized(
ParamId::new(),
Tensor::<TestBackend, 1>::zeros([2], &device),
);
// Apply using the ModuleMapper interface - simulate module traversal
// Enter "weight" path (as if we're visiting a field named "weight")
applier.enter_module("weight", "");
let weight_loaded = applier.map_float(weight_target);
applier.exit_module("weight", "");
// Enter "bias" path (as if we're visiting a field named "bias")
applier.enter_module("bias", "");
let bias_loaded = applier.map_float(bias_target);
applier.exit_module("bias", "");
// Verify values were loaded
let weight_data = weight_loaded.val().to_data().to_vec::<f32>().unwrap();
let bias_data = bias_loaded.val().to_data().to_vec::<f32>().unwrap();
assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]);
assert_eq!(bias_data, vec![5.0, 6.0]);
// Verify applier result
let result = applier.into_result();
assert_eq!(result.applied.len(), 2);
assert_eq!(result.errors.len(), 0);
}
/// Test that the applier preserves dtype when loading tensor data.
/// This is a regression test for the bug where F16 tensors were being
/// loaded as F32 because `Tensor::from_data` was used instead of
/// `Tensor::from_data_dtype`.
#[test]
fn dtype_preservation_f64() {
// Use NdArray<f64> backend to properly test F64 dtype preservation
type TestBackendF64 = burn_ndarray::NdArray<f64>;
let device = Default::default();
// Create TensorData with F64 dtype explicitly
let f64_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], [2, 2]);
assert_eq!(f64_data.dtype, DType::F64, "Test setup: data should be F64");
// Create a snapshot with F64 data
let snapshot = crate::TensorSnapshot::from_data(
f64_data.clone(),
vec!["weight".to_string()],
vec![],
ParamId::new(),
);
assert_eq!(
snapshot.dtype,
DType::F64,
"Snapshot should preserve F64 dtype"
);
// Create applier with the F64 snapshot
let mut applier = Applier::<TestBackendF64>::new(vec![snapshot], None, None, false);
// Create target parameter
let target = Param::initialized(
ParamId::new(),
Tensor::<TestBackendF64, 2>::zeros([2, 2], &device),
);
// Apply the snapshot
applier.enter_module("weight", "");
let loaded = applier.map_float(target);
applier.exit_module("weight", "");
// Verify dtype is preserved - this would fail before the fix
// because the data would be converted to the backend's default FloatElem
assert_eq!(
loaded.val().dtype(),
DType::F64,
"Loaded tensor should have F64 dtype"
);
// Verify data values are correct
let loaded_data = loaded.val().to_data().to_vec::<f64>().unwrap();
assert_eq!(loaded_data, vec![1.0, 2.0, 3.0, 4.0]);
// Verify applier result
let result = applier.into_result();
assert_eq!(result.applied.len(), 1);
assert_eq!(result.errors.len(), 0);
}
/// Test that F32 dtype is preserved when loading (verifies we didn't break F32 handling)
#[test]
fn dtype_preservation_f32() {
let device = Default::default();
// Create TensorData with F32 dtype
let f32_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [2, 2]);
assert_eq!(f32_data.dtype, DType::F32);
// Create a snapshot with F32 data
let snapshot = crate::TensorSnapshot::from_data(
f32_data.clone(),
vec!["weight".to_string()],
vec![],
ParamId::new(),
);
assert_eq!(snapshot.dtype, DType::F32);
// Create applier with the F32 snapshot
let mut applier = Applier::<TestBackend>::new(vec![snapshot], None, None, false);
// Create target parameter
let target = Param::initialized(
ParamId::new(),
Tensor::<TestBackend, 2>::zeros([2, 2], &device),
);
// Apply the snapshot
applier.enter_module("weight", "");
let loaded = applier.map_float(target);
applier.exit_module("weight", "");
// Verify dtype is F32
assert_eq!(loaded.val().dtype(), DType::F32);
// Verify data values
let loaded_data = loaded.val().to_data().to_vec::<f32>().unwrap();
assert_eq!(loaded_data, vec![1.0, 2.0, 3.0, 4.0]);
}
/// Test that F16 dtype is correctly preserved in TensorSnapshot.
///
/// Note: Full F16 tensor loading requires a backend that supports F16
/// (e.g., CUDA, WebGPU). The NdArray backend does not support F16.
/// This test verifies that the snapshot correctly preserves F16 dtype,
/// which is the key part of the dtype preservation fix.
#[test]
fn dtype_preservation_f16_snapshot() {
use half::f16;
// Create TensorData with F16 dtype using the half crate
let f16_values: Vec<f16> = vec![
f16::from_f32(1.0),
f16::from_f32(2.0),
f16::from_f32(3.0),
f16::from_f32(4.0),
];
let f16_data = TensorData::new(f16_values.clone(), [2, 2]);
assert_eq!(
f16_data.dtype,
DType::F16,
"TensorData should have F16 dtype"
);
// Create a snapshot with F16 data
let snapshot = crate::TensorSnapshot::from_data(
f16_data.clone(),
vec!["weight".to_string()],
vec![],
ParamId::new(),
);
// Verify snapshot preserves F16 dtype
assert_eq!(
snapshot.dtype,
DType::F16,
"TensorSnapshot should preserve F16 dtype"
);
// Verify the data can be retrieved with correct dtype
let retrieved_data = snapshot.to_data().expect("Should be able to retrieve data");
assert_eq!(
retrieved_data.dtype,
DType::F16,
"Retrieved data should have F16 dtype"
);
// Verify the actual values are preserved
let retrieved_values: Vec<f16> = retrieved_data
.to_vec()
.expect("Should be able to convert to f16 vec");
assert_eq!(
retrieved_values, f16_values,
"F16 values should be preserved"
);
// Note: To fully test F16 tensor creation, you would need a backend
// that supports F16 (like CUDA or WebGPU). The applier fix ensures
// that `Tensor::from_data_dtype(data, device, snapshot.dtype)` is
// called with DType::F16, which will correctly create an F16 tensor
// on backends that support it.
}
/// Test that BF16 dtype is correctly preserved in TensorSnapshot.
#[test]
fn dtype_preservation_bf16_snapshot() {
use half::bf16;
// Create TensorData with BF16 dtype
let bf16_values: Vec<bf16> = vec![
bf16::from_f32(1.0),
bf16::from_f32(2.0),
bf16::from_f32(3.0),
bf16::from_f32(4.0),
];
let bf16_data = TensorData::new(bf16_values.clone(), [2, 2]);
assert_eq!(
bf16_data.dtype,
DType::BF16,
"TensorData should have BF16 dtype"
);
// Create a snapshot with BF16 data
let snapshot = crate::TensorSnapshot::from_data(
bf16_data.clone(),
vec!["weight".to_string()],
vec![],
ParamId::new(),
);
// Verify snapshot preserves BF16 dtype
assert_eq!(
snapshot.dtype,
DType::BF16,
"TensorSnapshot should preserve BF16 dtype"
);
// Verify the data can be retrieved with correct dtype
let retrieved_data = snapshot.to_data().expect("Should be able to retrieve data");
assert_eq!(
retrieved_data.dtype,
DType::BF16,
"Retrieved data should have BF16 dtype"
);
// Verify the actual values are preserved
let retrieved_values: Vec<bf16> = retrieved_data
.to_vec()
.expect("Should be able to convert to bf16 vec");
assert_eq!(
retrieved_values, bf16_values,
"BF16 values should be preserved"
);
}
}

View File

@@ -0,0 +1,300 @@
//! Result types and diagnostics for tensor application operations
use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;
use burn_tensor::DType;
/// Error types that can occur during tensor application
#[derive(Debug, Clone)]
pub enum ApplyError {
/// Shape mismatch between expected and actual tensor
ShapeMismatch {
/// Path of the tensor
path: String,
/// Expected shape
expected: Vec<usize>,
/// Found shape
found: Vec<usize>,
},
/// Data type mismatch between expected and actual tensor
DTypeMismatch {
/// Path of the tensor
path: String,
/// Expected data type
expected: DType,
/// Found data type
found: DType,
},
/// Error from adapter transformation
AdapterError {
/// Path of the tensor
path: String,
/// Error message
message: String,
},
/// Error loading tensor data
LoadError {
/// Path of the tensor
path: String,
/// Error message
message: String,
},
}
impl core::fmt::Display for ApplyError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Self::ShapeMismatch {
path,
expected,
found,
} => {
write!(
f,
"Shape mismatch for '{}': expected {:?}, found {:?}",
path, expected, found
)
}
Self::DTypeMismatch {
path,
expected,
found,
} => {
write!(
f,
"DType mismatch for '{}': expected {:?}, found {:?}",
path, expected, found
)
}
Self::AdapterError { path, message } => {
write!(f, "Adapter error for '{}': {}", path, message)
}
Self::LoadError { path, message } => {
write!(f, "Load error for '{}': {}", path, message)
}
}
}
}
impl core::error::Error for ApplyError {}
/// Result of applying tensor snapshots to a module
#[derive(Clone)]
pub struct ApplyResult {
/// Successfully applied tensor paths
pub applied: Vec<String>,
/// Skipped tensor paths (due to filter)
pub skipped: Vec<String>,
/// Missing tensor paths with their container stacks in dot notation (path, container_stack)
/// Container stack shows the hierarchy: "Struct:Model.Struct:Linear" or "Struct:Model.Enum:ConvType.Struct:Linear"
pub missing: Vec<(String, String)>,
/// Unused tensor paths (in snapshots but not in module)
pub unused: Vec<String>,
/// Errors encountered during application
pub errors: Vec<ApplyError>,
}
impl ApplyResult {
/// Try to strip enum variant from a path
/// e.g., "field.BaseConv.weight" -> "field.weight"
fn strip_enum_variant(path: &str) -> Option<String> {
let segments: Vec<&str> = path.split('.').collect();
// Find segments that look like enum variants (CamelCase in middle of path)
let variant_indices: Vec<usize> = segments
.iter()
.enumerate()
.filter(|(i, segment)| {
*i > 0 && *i < segments.len() - 1 // Not first or last
&& !segment.is_empty()
&& segment.chars().next().map(|c| c.is_uppercase()).unwrap_or(false)
&& segment.len() > 1
&& segment.chars().skip(1).any(|c| c.is_lowercase())
})
.map(|(i, _)| i)
.collect();
if variant_indices.is_empty() {
return None;
}
// Remove the first found variant and return the modified path
let mut result_segments = segments.clone();
result_segments.remove(variant_indices[0]);
Some(result_segments.join("."))
}
/// Find similar paths for a given missing path (for "Did you mean?" suggestions)
fn find_similar_paths(&self, missing_path: &str, max_suggestions: usize) -> Vec<String> {
// First, try exact match with enum variant stripped
if let Some(stripped) = Self::strip_enum_variant(missing_path)
&& self.unused.contains(&stripped)
{
return vec![stripped];
}
// Fall back to Jaro similarity (used by Elixir for "did you mean?" suggestions)
// Jaro gives higher weight to matching prefixes, ideal for hierarchical tensor paths
let mut similarities: Vec<(String, f64)> = self
.unused
.iter()
.map(|available| {
let similarity = textdistance::nstr::jaro(missing_path, available);
(available.clone(), similarity)
})
.collect();
// Sort by similarity (higher = more similar)
similarities
.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(core::cmp::Ordering::Equal));
// Only suggest paths with >= 70% similarity
const SIMILARITY_THRESHOLD: f64 = 0.7;
similarities
.into_iter()
.filter(|(_, sim)| *sim >= SIMILARITY_THRESHOLD)
.take(max_suggestions)
.map(|(path, _)| path)
.collect()
}
}
impl ApplyResult {
/// Check if the apply operation was successful (no errors)
/// Note: Missing tensors are not considered errors when allow_partial is true
pub fn is_success(&self) -> bool {
self.errors.is_empty()
}
}
impl core::fmt::Debug for ApplyResult {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
// Delegate to Display for comprehensive output
core::fmt::Display::fmt(self, f)
}
}
impl core::fmt::Display for ApplyResult {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
writeln!(f, "┌─ Tensor Loading Summary ─────────────────────────")?;
writeln!(f, "")?;
writeln!(
f,
"│ ✓ Successfully applied: {} tensors",
self.applied.len()
)?;
writeln!(f, "│ ⊘ Skipped (filtered): {} tensors", self.skipped.len())?;
writeln!(
f,
"│ ✗ Missing in source: {} tensors",
self.missing.len()
)?;
writeln!(f, "│ ? Unused in target: {} tensors", self.unused.len())?;
writeln!(f, "│ ! Errors: {} errors", self.errors.len())?;
if !self.missing.is_empty() {
writeln!(f, "")?;
writeln!(
f,
"├─ Missing Tensors (requested by model but not found in source)"
)?;
writeln!(f, "")?;
// Use actual container stack data to detect enum variants
// Count how many missing paths have "Enum:" in their container stack
let enum_variant_missing: Vec<_> = self
.missing
.iter()
.filter(|(_, stack)| stack.contains("Enum:"))
.collect();
if !enum_variant_missing.is_empty() {
writeln!(
f,
"│ ⚠️ {} paths contain enum variants (detected from container stack)",
enum_variant_missing.len()
)?;
writeln!(
f,
"│ Burn includes enum variant names in paths, but PyTorch doesn't."
)?;
writeln!(
f,
"│ Example: Burn has 'field.BaseConv.weight', PyTorch has 'field.weight'"
)?;
writeln!(f, "")?;
writeln!(
f,
"│ 💡 Solution 1: Enable skip_enum_variants flag (simplest):"
)?;
writeln!(f, "")?;
writeln!(
f,
"│ let mut store = PytorchStore::from_file(\"model.pth\")"
)?;
writeln!(f, "│ .skip_enum_variants(true); // ← Add this")?;
writeln!(f, "")?;
writeln!(
f,
"│ 💡 Solution 2: Remap enum keys in source (most precise):"
)?;
writeln!(f, "")?;
writeln!(
f,
"│ let mut store = SafetensorsStore::from_file(\"model.safetensors\")"
)?;
writeln!(
f,
"│ .with_key_remapping(r\"field\\.(\\w+)\", \"field.BaseConv.$1\");"
)?;
writeln!(f, "")?;
}
writeln!(f, "│ First 10 missing tensors:")?;
for (path, _) in self.missing.iter().take(10) {
writeln!(f, "│ • {}", path)?;
// Show "Did you mean?" suggestions for this path
let suggestions = self.find_similar_paths(path, 1);
if !suggestions.is_empty() {
writeln!(f, "│ Did you mean: '{}'?", suggestions[0])?;
}
}
if self.missing.len() > 10 {
writeln!(f, "│ ... and {} more", self.missing.len() - 10)?;
}
}
if !self.unused.is_empty() {
writeln!(f, "")?;
writeln!(f, "├─ Unused Tensors (in source but not used by model)")?;
writeln!(f, "")?;
writeln!(f, "│ First 10 unused tensors:")?;
for path in self.unused.iter().take(10) {
writeln!(f, "│ • {}", path)?;
}
if self.unused.len() > 10 {
writeln!(f, "│ ... and {} more", self.unused.len() - 10)?;
}
}
if !self.errors.is_empty() {
writeln!(f, "")?;
writeln!(f, "├─ Errors")?;
writeln!(f, "")?;
for error in self.errors.iter().take(10) {
writeln!(f, "│ ⚠️ {}", error)?;
}
if self.errors.len() > 10 {
writeln!(f, "│ ... and {} more", self.errors.len() - 10)?;
}
}
writeln!(f, "")?;
write!(f, "└───────────────────────────────────────────────────")?;
Ok(())
}
}

View File

@@ -0,0 +1,231 @@
//! Core types and constants for the Burnpack file format.
//!
//! See the [parent module](crate::burnpack) for the complete file format specification.
use alloc::collections::BTreeMap;
use alloc::string::String;
use alloc::vec::Vec;
use burn_tensor::DType;
use byteorder::{ByteOrder, LittleEndian};
use serde::{Deserialize, Serialize};
/// Magic number identifying a Burnpack file: "BURN" in ASCII (0x4255524E)
/// When written to file in little-endian format, appears as "NRUB" bytes
pub const MAGIC_NUMBER: u32 = 0x4255524E;
/// Current format version
pub const FORMAT_VERSION: u16 = 0x0001;
/// Size of the magic number in bytes
pub const MAGIC_SIZE: usize = 4;
/// Size of the format version in bytes
pub const VERSION_SIZE: usize = 2;
/// Size of the metadata size field in bytes
pub const METADATA_SIZE_FIELD_SIZE: usize = 4;
/// Total header size (computed from components)
pub const HEADER_SIZE: usize = MAGIC_SIZE + VERSION_SIZE + METADATA_SIZE_FIELD_SIZE;
/// Alignment for tensor data in bytes.
///
/// All tensor data is aligned to 256-byte boundaries to enable efficient
/// memory-mapped (mmap) zero-copy loading. This alignment ensures:
/// - Proper pointer alignment for all tensor element types (f64 requires 8-byte alignment)
/// - Cache-line friendly access (most CPUs use 64-byte cache lines)
/// - GPU memory alignment (CUDA prefers 256-byte for coalesced access)
/// - Future-proofing for wider SIMD (AVX-512 = 64 bytes, future AVX-1024 = 128 bytes)
///
/// Industry alignment choices:
/// - 256-byte: GGUF, MLX, ncnn, MNN, TNN, vLLM-AWQ, Marlin (15+ formats)
/// - 64-byte: SafeTensors (minimum for AVX-512)
/// - 4096-byte: Core ML
///
/// 256-byte alignment has negligible overhead for typical tensor sizes while
/// providing maximum compatibility with current and future hardware.
pub const TENSOR_ALIGNMENT: u64 = 256;
/// Calculate the byte offset where the tensor data section starts.
///
/// The data section is padded to start at a 256-byte aligned position
/// so that all tensor offsets (which are relative to data section) result
/// in properly aligned absolute file positions for mmap zero-copy access.
///
/// This function must be used consistently by both writer and reader.
#[inline]
pub fn aligned_data_section_start(metadata_size: usize) -> usize {
let unaligned_start = (HEADER_SIZE + metadata_size) as u64;
// Keep multiplication in u64 space to avoid overflow on 32-bit systems
(unaligned_start.div_ceil(TENSOR_ALIGNMENT) * TENSOR_ALIGNMENT) as usize
}
// Security limits to prevent DoS attacks via resource exhaustion
// These can be adjusted based on your use case
/// Maximum allowed metadata size (100 MB)
/// Prevents memory exhaustion attacks via oversized metadata claims
pub const MAX_METADATA_SIZE: u32 = 100 * 1024 * 1024;
/// Maximum allowed tensor size per tensor
/// Prevents memory exhaustion attacks via oversized tensor claims
/// 32-bit platforms: 2 GB limit (to fit within usize range)
/// 64-bit platforms: 10 GB limit
#[cfg(target_pointer_width = "32")]
pub const MAX_TENSOR_SIZE: usize = 2 * 1024 * 1024 * 1024;
#[cfg(not(target_pointer_width = "32"))]
pub const MAX_TENSOR_SIZE: usize = 10 * 1024 * 1024 * 1024;
/// Maximum allowed number of tensors (100,000)
/// Prevents resource exhaustion via excessive tensor counts
pub const MAX_TENSOR_COUNT: usize = 100_000;
/// Maximum CBOR deserialization recursion depth (128 levels)
/// Prevents stack overflow attacks via deeply nested CBOR structures
pub const MAX_CBOR_RECURSION_DEPTH: usize = 128;
/// Maximum allowed file size (100 GB)
/// Prevents resource exhaustion from extremely large files
/// This limit applies to file-based loading (mmap and buffered)
#[cfg(feature = "std")]
pub const MAX_FILE_SIZE: u64 = 100 * 1024 * 1024 * 1024;
/// Byte range for magic number in header
pub const fn magic_range() -> core::ops::Range<usize> {
let start = 0;
let end = start + MAGIC_SIZE;
start..end
}
/// Byte range for format version in header
pub const fn version_range() -> core::ops::Range<usize> {
let start = MAGIC_SIZE;
let end = start + VERSION_SIZE;
start..end
}
/// Byte range for metadata size field in header
pub const fn metadata_size_range() -> core::ops::Range<usize> {
let start = MAGIC_SIZE + VERSION_SIZE;
let end = start + METADATA_SIZE_FIELD_SIZE;
start..end
}
// Compile-time validation that ranges are correct
const _: () = assert!(MAGIC_SIZE + VERSION_SIZE + METADATA_SIZE_FIELD_SIZE == HEADER_SIZE);
/// Header structure for Burnpack files
#[derive(Debug, Clone, Copy)]
pub struct BurnpackHeader {
/// Magic number (4 bytes): 0x4255524E ("BURN")
pub magic: u32,
/// Format version (2 bytes)
pub version: u16,
/// Size of CBOR metadata in bytes (4 bytes)
pub metadata_size: u32,
}
impl BurnpackHeader {
/// Create a new header with the given metadata size
#[allow(dead_code)]
pub fn new(metadata_size: u32) -> Self {
Self {
magic: MAGIC_NUMBER,
version: FORMAT_VERSION,
metadata_size,
}
}
/// Serialize header into bytes
pub fn into_bytes(self) -> [u8; HEADER_SIZE] {
let mut bytes = [0u8; HEADER_SIZE];
LittleEndian::write_u32(&mut bytes[magic_range()], self.magic);
LittleEndian::write_u16(&mut bytes[version_range()], self.version);
LittleEndian::write_u32(&mut bytes[metadata_size_range()], self.metadata_size);
bytes
}
/// Deserialize header from bytes
pub fn from_bytes(bytes: &[u8]) -> Result<Self, BurnpackError> {
if bytes.len() < HEADER_SIZE {
return Err(BurnpackError::InvalidHeader);
}
let magic = LittleEndian::read_u32(&bytes[magic_range()]);
if magic != MAGIC_NUMBER {
return Err(BurnpackError::InvalidMagicNumber);
}
let version = LittleEndian::read_u16(&bytes[version_range()]);
let metadata_size = LittleEndian::read_u32(&bytes[metadata_size_range()]);
Ok(Self {
magic,
version,
metadata_size,
})
}
}
/// Metadata structure serialized with CBOR
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BurnpackMetadata {
/// Tensor descriptors mapped by name for efficient lookup
pub tensors: BTreeMap<String, TensorDescriptor>,
/// Optional additional metadata
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub metadata: BTreeMap<String, String>,
}
/// Individual tensor descriptor
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorDescriptor {
/// Data type of the tensor
pub dtype: DType,
/// Tensor shape dimensions
pub shape: Vec<u64>,
/// Byte offsets in data section (start, end)
pub data_offsets: (u64, u64),
/// Parameter ID for training state persistence matching.
/// Generated automatically if not present during loading.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub param_id: Option<u64>,
}
/// Error types for Burnpack operations
#[derive(Debug)]
pub enum BurnpackError {
InvalidHeader,
InvalidMagicNumber,
InvalidVersion,
MetadataSerializationError(String),
MetadataDeserializationError(String),
IoError(String),
TensorNotFound(String),
TensorBytesSizeMismatch(String),
ValidationError(String),
}
impl core::fmt::Display for BurnpackError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
BurnpackError::InvalidHeader => write!(f, "Invalid header: insufficient bytes"),
BurnpackError::InvalidMagicNumber => write!(f, "Invalid magic number"),
BurnpackError::InvalidVersion => write!(f, "Unsupported version"),
BurnpackError::MetadataSerializationError(e) => {
write!(f, "Metadata serialization error: {}", e)
}
BurnpackError::MetadataDeserializationError(e) => {
write!(f, "Metadata deserialization error: {}", e)
}
BurnpackError::IoError(e) => write!(f, "I/O error: {}", e),
BurnpackError::TensorNotFound(name) => write!(f, "Tensor not found: {}", name),
BurnpackError::TensorBytesSizeMismatch(e) => {
write!(f, "Tensor bytes size mismatch: {}", e)
}
BurnpackError::ValidationError(e) => write!(f, "Validation error: {}", e),
}
}
}
impl core::error::Error for BurnpackError {}

View File

@@ -0,0 +1,62 @@
//! # Burnpack - Native Burn Model Storage Format
//!
//! Burnpack is the native binary storage format for Burn models, designed for efficient
//! serialization, fast loading, and cross-platform compatibility.
//!
//! ## Key Features
//!
//! - **CBOR Metadata**: Structured metadata with efficient binary encoding
//! - **Memory-Mapped Loading**: Zero-copy loading for optimal performance
//! - **256-byte Tensor Alignment**: Enables efficient mmap zero-copy access
//! - **No-std Support**: Works in embedded and WASM environments
//! - **ParamId Persistence**: Preserves parameter identities for stateful training
//! - **Lazy Tensor Loading**: Deferred data materialization for efficient memory usage
//!
//! ## File Format Structure
//!
//! ```text
//! ┌──────────────────────────────────┐
//! │ Header (10 bytes) │
//! ├──────────────────────────────────┤
//! │ - Magic number (4 bytes) │ 0x4E525542 ("NRUB" in LE)
//! │ - Version (2 bytes) │ Format version (0x0001)
//! │ - Metadata size (4 bytes) │ Size of CBOR metadata (u32)
//! ├──────────────────────────────────┤
//! │ Metadata (CBOR) │
//! ├──────────────────────────────────┤
//! │ - Tensor descriptors (BTreeMap) │ Order-preserving map of tensor metadata
//! │ Key: tensor name (string) │ e.g., "model.layer1.weight"
//! │ Value: TensorDescriptor │
//! │ - dtype: DType │ Data type (F32, F64, I32, etc.)
//! │ - shape: Vec<u64> │ Tensor dimensions
//! │ - data_offsets: (u64, u64) │ (start, end) byte offsets (256-byte aligned)
//! │ - param_id: Option<u64> │ Parameter ID (for training state)
//! │ - Additional metadata(BTreeMap) │ User-defined key-value pairs
//! ├──────────────────────────────────┤
//! │ Tensor Data Section │
//! ├──────────────────────────────────┤
//! │ [padding][tensor1][padding]... │ Each tensor aligned to 256-byte boundary
//! │ Raw tensor bytes (little-endian)│ Enables mmap zero-copy loading
//! └──────────────────────────────────┘
//! ```
//!
//! ## Tensor Alignment
//!
//! All tensor data is aligned to 256-byte boundaries to enable efficient memory-mapped
//! (mmap) zero-copy loading. This alignment ensures:
//!
//! - Proper pointer alignment for all tensor element types (f64 requires 8 bytes)
//! - Cache-line friendly access (most CPUs use 64-byte cache lines)
//! - GPU memory alignment (CUDA prefers 256-byte for coalesced access)
//! - Future-proofing for wider SIMD instructions (AVX-512, future AVX-1024)
//!
//! The 256-byte alignment matches industry standards used by GGUF, MLX, ncnn, MNN,
//! and other major model formats.
pub mod base;
pub mod reader;
pub mod store;
pub mod writer;
#[cfg(test)]
mod tests;

View File

@@ -0,0 +1,761 @@
#[cfg(feature = "std")]
use super::base::MAX_FILE_SIZE;
use super::base::{
BurnpackError, BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER,
MAX_CBOR_RECURSION_DEPTH, MAX_METADATA_SIZE, MAX_TENSOR_COUNT, MAX_TENSOR_SIZE,
aligned_data_section_start,
};
use crate::TensorSnapshot;
use alloc::format;
use alloc::rc::Rc;
use alloc::string::ToString;
use alloc::vec;
use alloc::vec::Vec;
use burn_core::module::ParamId;
use burn_tensor::{Bytes, TensorData};
#[cfg(feature = "std")]
use std::cell::RefCell;
#[cfg(feature = "std")]
use std::fs::File;
#[cfg(feature = "std")]
use std::io::{Read, Seek};
#[cfg(feature = "std")]
use std::path::Path;
/// Storage backend for BurnpackReader
pub(crate) enum StorageBackend {
/// Memory-based storage (also used for memory-mapped files converted to bytes::Bytes)
Memory(Rc<Bytes>),
/// File-based storage with buffered reading
#[cfg(feature = "std")]
#[allow(dead_code)]
FileBuffered { file: Rc<RefCell<File>> },
}
impl StorageBackend {
/// Read data from storage into the provided buffer at the given offset.
///
/// # Arguments
/// * `bytes` - The buffer to read into (caller-allocated)
/// * `offset` - Absolute file/data position to start reading from
///
/// # Errors
///
/// Returns an error if:
/// - The requested data range is out of bounds
/// - Less data is available than requested (indicates corruption or incorrect offset)
/// - File I/O fails
///
/// # Notes
///
/// The caller allocates the buffer, which allows for buffer reuse and future optimizations
/// like memory pools and pinned memory.
///
/// This method ensures all backends have consistent behavior: if the exact number of
/// requested bytes cannot be read, an error is returned to prevent data corruption.
pub(crate) fn read_into(&self, bytes: &mut [u8], offset: usize) -> Result<(), BurnpackError> {
match self {
StorageBackend::Memory(data) => {
let data_bytes = data.as_ref();
let end = offset.checked_add(bytes.len()).ok_or_else(|| {
BurnpackError::IoError(format!(
"Offset overflow: offset {} + length {} exceeds maximum",
offset,
bytes.len()
))
})?;
if end > data_bytes.len() {
return Err(BurnpackError::IoError(format!(
"Read out of bounds: requested {}..{} but data length is {}",
offset,
end,
data_bytes.len()
)));
}
bytes.copy_from_slice(&data_bytes[offset..end]);
Ok(())
}
#[cfg(feature = "std")]
StorageBackend::FileBuffered { file } => {
use std::io::SeekFrom;
let mut file = file.borrow_mut();
file.seek(SeekFrom::Start(offset as u64)).map_err(|e| {
BurnpackError::IoError(format!("Failed to seek in file: {}", e))
})?;
file.read_exact(bytes).map_err(|e| {
BurnpackError::IoError(format!("Failed to read from file: {}", e))
})?;
Ok(())
}
}
}
/// Get full data reference for raw access
#[allow(dead_code)]
pub(crate) fn as_bytes(&self) -> Result<&[u8], BurnpackError> {
match self {
StorageBackend::Memory(data) => Ok(data.as_ref()),
#[cfg(feature = "std")]
StorageBackend::FileBuffered { .. } => Err(BurnpackError::IoError(
"Cannot get full bytes reference for FileBuffered backend".into(),
)),
}
}
/// Attempt to slice bytes without copying (zero-copy).
///
/// This uses `Bytes::clone()` + `split()` which is zero-copy when the underlying
/// `Bytes` was created via `Bytes::from_shared()` (backed by `bytes::Bytes`).
///
/// # Returns
/// - `Ok(bytes)` - Successfully created a zero-copy slice
/// - `Err(_)` - Backend doesn't support zero-copy or split failed
pub(crate) fn slice_bytes(&self, start: usize, end: usize) -> Result<Bytes, BurnpackError> {
if end < start {
return Err(BurnpackError::IoError(format!(
"Invalid slice range: end ({}) < start ({})",
end, start
)));
}
match self {
StorageBackend::Memory(data) => {
// Clone the Bytes - cheap if backed by SharedBytesAllocationController
let cloned = (**data).clone();
// Split at start offset to get (_, right)
let (_, right) = cloned.split(start).map_err(|(_, e)| {
BurnpackError::IoError(format!("Failed to split at start {}: {:?}", start, e))
})?;
// Split right at (end - start) to get (middle, _)
let slice_len = end - start;
let (middle, _) = right.split(slice_len).map_err(|(_, e)| {
BurnpackError::IoError(format!(
"Failed to split at length {}: {:?}",
slice_len, e
))
})?;
Ok(middle)
}
#[cfg(feature = "std")]
StorageBackend::FileBuffered { .. } => Err(BurnpackError::IoError(
"Zero-copy not supported for buffered file reading. Use from_file() with memmap feature for zero-copy loading.".into(),
)),
}
}
}
/// Reader for loading Burnpack files
pub struct BurnpackReader {
/// Parsed metadata
pub(crate) metadata: BurnpackMetadata,
/// Storage backend
pub(crate) storage: StorageBackend,
/// Offset to the start of tensor data
pub(crate) data_offset: usize,
}
impl BurnpackReader {
/// Load from bytes
pub fn from_bytes(bytes: Bytes) -> Result<Self, BurnpackError> {
// Validate minimum size
if bytes.len() < HEADER_SIZE {
return Err(BurnpackError::InvalidHeader);
}
// Parse header
let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE])?;
// Verify magic number
if header.magic != MAGIC_NUMBER {
return Err(BurnpackError::InvalidMagicNumber);
}
// Verify version compatibility
if header.version > FORMAT_VERSION {
return Err(BurnpackError::InvalidVersion);
}
// Validate metadata size against security limit
if header.metadata_size > MAX_METADATA_SIZE {
return Err(BurnpackError::ValidationError(format!(
"Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)",
header.metadata_size, MAX_METADATA_SIZE
)));
}
// Parse metadata
let metadata_start = HEADER_SIZE;
let metadata_end = metadata_start
.checked_add(header.metadata_size as usize)
.ok_or_else(|| {
BurnpackError::IoError(format!(
"Metadata size overflow: {} + {}",
metadata_start, header.metadata_size
))
})?;
if bytes.len() < metadata_end {
return Err(BurnpackError::InvalidHeader);
}
let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit(
&bytes[metadata_start..metadata_end],
MAX_CBOR_RECURSION_DEPTH,
)
.map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?;
// Validate tensor count against security limit
if metadata.tensors.len() > MAX_TENSOR_COUNT {
return Err(BurnpackError::ValidationError(format!(
"File contains {} tensors, exceeding maximum of {} (potential DoS attack)",
metadata.tensors.len(),
MAX_TENSOR_COUNT
)));
}
// Validate total file size - ensure file is large enough for all claimed tensor data
if !metadata.tensors.is_empty() {
let max_data_offset = metadata
.tensors
.values()
.map(|t| t.data_offsets.1)
.max()
.unwrap_or(0);
let max_data_offset_usize: usize = max_data_offset.try_into().map_err(|_| {
BurnpackError::ValidationError(format!(
"Data offset {} exceeds platform maximum",
max_data_offset
))
})?;
let min_file_size =
metadata_end
.checked_add(max_data_offset_usize)
.ok_or_else(|| {
BurnpackError::ValidationError("File size calculation overflow".into())
})?;
if bytes.len() < min_file_size {
return Err(BurnpackError::ValidationError(format!(
"File truncated: expected at least {} bytes, got {} bytes",
min_file_size,
bytes.len()
)));
}
}
Ok(Self {
metadata,
storage: StorageBackend::Memory(Rc::new(bytes)),
data_offset: aligned_data_section_start(header.metadata_size as usize),
})
}
/// Load from file with memory mapping (most efficient for large files)
#[cfg(all(feature = "std", feature = "memmap"))]
pub(crate) fn from_file_mmap<P: AsRef<Path>>(path: P) -> Result<Self, BurnpackError> {
let file = File::open(&path).map_err(|e| BurnpackError::IoError(e.to_string()))?;
// Validate maximum file size to prevent resource exhaustion
let file_size = file
.metadata()
.map_err(|e| BurnpackError::IoError(e.to_string()))?
.len();
if file_size > MAX_FILE_SIZE {
return Err(BurnpackError::ValidationError(format!(
"File size {} bytes exceeds maximum allowed size of {} bytes",
file_size, MAX_FILE_SIZE
)));
}
// Memory map the file
let mmap = unsafe {
memmap2::MmapOptions::new()
.map(&file)
.map_err(|e| BurnpackError::IoError(e.to_string()))?
};
// Parse header
if mmap.len() < HEADER_SIZE {
return Err(BurnpackError::InvalidHeader);
}
let header = BurnpackHeader::from_bytes(&mmap[..HEADER_SIZE])?;
// Verify magic number and version
if header.magic != MAGIC_NUMBER {
return Err(BurnpackError::InvalidMagicNumber);
}
if header.version > FORMAT_VERSION {
return Err(BurnpackError::InvalidVersion);
}
// Validate metadata size against security limit
if header.metadata_size > MAX_METADATA_SIZE {
return Err(BurnpackError::ValidationError(format!(
"Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)",
header.metadata_size, MAX_METADATA_SIZE
)));
}
// Parse metadata
let metadata_start = HEADER_SIZE;
let metadata_end = metadata_start
.checked_add(header.metadata_size as usize)
.ok_or_else(|| {
BurnpackError::IoError(format!(
"Metadata size overflow: {} + {}",
metadata_start, header.metadata_size
))
})?;
if mmap.len() < metadata_end {
return Err(BurnpackError::InvalidHeader);
}
let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit(
&mmap[metadata_start..metadata_end],
MAX_CBOR_RECURSION_DEPTH,
)
.map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?;
// Validate tensor count against security limit
if metadata.tensors.len() > MAX_TENSOR_COUNT {
return Err(BurnpackError::ValidationError(format!(
"File contains {} tensors, exceeding maximum of {} (potential DoS attack)",
metadata.tensors.len(),
MAX_TENSOR_COUNT
)));
}
// Validate total file size - ensure file is large enough for all claimed tensor data
if !metadata.tensors.is_empty() {
let max_data_offset = metadata
.tensors
.values()
.map(|t| t.data_offsets.1)
.max()
.unwrap_or(0);
let max_data_offset_usize: usize = max_data_offset.try_into().map_err(|_| {
BurnpackError::ValidationError(format!(
"Data offset {} exceeds platform maximum",
max_data_offset
))
})?;
let min_file_size =
metadata_end
.checked_add(max_data_offset_usize)
.ok_or_else(|| {
BurnpackError::ValidationError("File size calculation overflow".into())
})?;
if mmap.len() < min_file_size {
return Err(BurnpackError::ValidationError(format!(
"File truncated: expected at least {} bytes, got {} bytes",
min_file_size,
mmap.len()
)));
}
}
// Convert mmap to bytes::Bytes for zero-copy slicing support
// bytes::Bytes::from_owner takes ownership and enables efficient slicing
let shared_bytes = bytes::Bytes::from_owner(mmap);
let bytes = Bytes::from_shared(shared_bytes, burn_tensor::AllocationProperty::File);
Ok(Self {
metadata,
storage: StorageBackend::Memory(Rc::new(bytes)),
data_offset: aligned_data_section_start(header.metadata_size as usize),
})
}
/// Load from file - automatically uses memory mapping if available, otherwise uses buffered reading
#[cfg(feature = "std")]
pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, BurnpackError> {
#[cfg(feature = "memmap")]
{
// Use memory mapping for efficient access
Self::from_file_mmap(path)
}
#[cfg(not(feature = "memmap"))]
{
// Fall back to buffered reading for memory efficiency
Self::from_file_buffered(path)
}
}
/// Load from file with buffered reading (memory efficient but slower)
/// This is less efficient than memory mapping but works everywhere
#[cfg(feature = "std")]
#[allow(dead_code)]
pub(crate) fn from_file_buffered<P: AsRef<Path>>(path: P) -> Result<Self, BurnpackError> {
let mut file = File::open(&path).map_err(|e| BurnpackError::IoError(e.to_string()))?;
// Validate maximum file size to prevent resource exhaustion
let file_size = file
.metadata()
.map_err(|e| BurnpackError::IoError(e.to_string()))?
.len();
if file_size > MAX_FILE_SIZE {
return Err(BurnpackError::ValidationError(format!(
"File size {} bytes exceeds maximum allowed size of {} bytes",
file_size, MAX_FILE_SIZE
)));
}
// Read header
let mut header_bytes = [0u8; HEADER_SIZE];
file.read_exact(&mut header_bytes)
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
let header = BurnpackHeader::from_bytes(&header_bytes)?;
// Verify version
if header.version > FORMAT_VERSION {
return Err(BurnpackError::InvalidVersion);
}
// Validate metadata size against security limit
if header.metadata_size > MAX_METADATA_SIZE {
return Err(BurnpackError::ValidationError(format!(
"Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)",
header.metadata_size, MAX_METADATA_SIZE
)));
}
// Read metadata
let mut metadata_bytes = vec![0u8; header.metadata_size as usize];
file.read_exact(&mut metadata_bytes)
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit(
metadata_bytes.as_slice(),
MAX_CBOR_RECURSION_DEPTH,
)
.map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?;
// Validate tensor count against security limit
if metadata.tensors.len() > MAX_TENSOR_COUNT {
return Err(BurnpackError::ValidationError(format!(
"File contains {} tensors, exceeding maximum of {} (potential DoS attack)",
metadata.tensors.len(),
MAX_TENSOR_COUNT
)));
}
// Calculate metadata end offset
let metadata_end = HEADER_SIZE
.checked_add(header.metadata_size as usize)
.ok_or_else(|| {
BurnpackError::IoError(format!(
"Metadata size overflow: {} + {}",
HEADER_SIZE, header.metadata_size
))
})?;
// Validate total file size - ensure file is large enough for all claimed tensor data
if !metadata.tensors.is_empty() {
let max_data_offset = metadata
.tensors
.values()
.map(|t| t.data_offsets.1)
.max()
.unwrap_or(0);
let max_data_offset_usize: usize = max_data_offset.try_into().map_err(|_| {
BurnpackError::ValidationError(format!(
"Data offset {} exceeds platform maximum",
max_data_offset
))
})?;
let min_file_size =
metadata_end
.checked_add(max_data_offset_usize)
.ok_or_else(|| {
BurnpackError::ValidationError("File size calculation overflow".into())
})?;
// Get actual file size
let file_size = file
.metadata()
.map_err(|e| BurnpackError::IoError(e.to_string()))?
.len() as usize;
if file_size < min_file_size {
return Err(BurnpackError::ValidationError(format!(
"File truncated: expected at least {} bytes, got {} bytes",
min_file_size, file_size
)));
}
}
Ok(Self {
metadata,
storage: StorageBackend::FileBuffered {
file: Rc::new(RefCell::new(file)),
},
data_offset: aligned_data_section_start(header.metadata_size as usize),
})
}
/// Get all tensor snapshots at once for efficient loading (always copies data)
pub fn get_snapshots(&self) -> Result<Vec<TensorSnapshot>, BurnpackError> {
self.get_snapshots_internal(false)
}
/// Get all tensor snapshots with optional zero-copy loading.
///
/// When `zero_copy` is true and the backend supports it (Memory backend with
/// `Bytes::from_shared()`), tensor data is sliced without copying. This keeps
/// the original data alive as long as any tensor holds a reference.
///
/// When `zero_copy` is false or the backend doesn't support it, data is copied
/// into newly allocated buffers (default behavior).
pub fn get_snapshots_zero_copy(
&self,
zero_copy: bool,
) -> Result<Vec<TensorSnapshot>, BurnpackError> {
self.get_snapshots_internal(zero_copy)
}
/// Internal implementation with optional zero-copy support
fn get_snapshots_internal(
&self,
zero_copy: bool,
) -> Result<Vec<TensorSnapshot>, BurnpackError> {
let mut snapshots = Vec::new();
for (name, descriptor) in &self.metadata.tensors {
// Clone metadata for use in closure
// Convert shape dimensions with overflow checking
let shape: Vec<usize> = descriptor
.shape
.iter()
.map(|&s| {
s.try_into().map_err(|_| {
BurnpackError::ValidationError(format!(
"Tensor '{}' has corrupted shape data: dimension {} exceeds platform maximum",
name, s
))
})
})
.collect::<Result<Vec<usize>, BurnpackError>>()?;
let dtype = descriptor.dtype;
// Clone storage reference for the closure
let storage = match &self.storage {
StorageBackend::Memory(data) => StorageBackend::Memory(data.clone()),
#[cfg(feature = "std")]
StorageBackend::FileBuffered { file } => {
StorageBackend::FileBuffered { file: file.clone() }
}
};
// Always use absolute positions for all backends
// Convert offsets with overflow checking
let offset_start: usize = descriptor.data_offsets.0.try_into().map_err(|_| {
BurnpackError::ValidationError(format!(
"Tensor '{}' has corrupted offset data: start offset {} exceeds platform maximum",
name, descriptor.data_offsets.0
))
})?;
let offset_end: usize = descriptor.data_offsets.1.try_into().map_err(|_| {
BurnpackError::ValidationError(format!(
"Tensor '{}' has corrupted offset data: end offset {} exceeds platform maximum",
name, descriptor.data_offsets.1
))
})?;
let start = self.data_offset.checked_add(offset_start).ok_or_else(|| {
BurnpackError::ValidationError(format!(
"Tensor '{}' has corrupted offset data: start offset overflow {} + {}",
name, self.data_offset, offset_start
))
})?;
let end = self.data_offset.checked_add(offset_end).ok_or_else(|| {
BurnpackError::ValidationError(format!(
"Tensor '{}' has corrupted offset data: end offset overflow {} + {}",
name, self.data_offset, offset_end
))
})?;
// Clone shape for the closure (TensorSnapshot::from_closure will also need it)
let shape_for_closure = shape.clone();
// Validate offset range
if end < start {
return Err(BurnpackError::ValidationError(format!(
"Tensor '{}' has corrupted offset data: end offset {} < start offset {}",
name, end, start
)));
}
// Validate tensor size against security limit
let tensor_size = end - start;
if tensor_size > MAX_TENSOR_SIZE {
return Err(BurnpackError::ValidationError(format!(
"Tensor '{}' size {} exceeds maximum allowed size of {} bytes (potential DoS attack)",
name, tensor_size, MAX_TENSOR_SIZE
)));
}
// Restore param_id if it was saved, otherwise generate
let tensor_id = descriptor
.param_id
.map(ParamId::from)
.unwrap_or_else(ParamId::new);
// Create the data-loading closure based on zero_copy flag
let data_fn: Rc<dyn Fn() -> Result<TensorData, crate::TensorSnapshotError>> =
if zero_copy {
// Zero-copy closure: slice without copying, error if not supported
Rc::new(move || {
let bytes = storage.slice_bytes(start, end).map_err(|e| {
crate::TensorSnapshotError::IoError(format!(
"Zero-copy slice failed: {}",
e
))
})?;
Ok(TensorData::from_bytes(
bytes,
shape_for_closure.clone(),
dtype,
))
})
} else {
// Copying closure: always allocate and copy
Rc::new(move || {
let len = end - start;
// TODO Should be allocated by the backend in the future
// See https://github.com/tracel-ai/burn/pull/3792#discussion_r2416812091
let mut data_bytes = vec![0u8; len];
storage.read_into(&mut data_bytes, start).map_err(|e| {
crate::TensorSnapshotError::IoError(format!(
"Failed to read tensor data: {}",
e
))
})?;
Ok(TensorData::from_bytes_vec(
data_bytes,
shape_for_closure.clone(),
dtype,
))
})
};
// Create lazy TensorSnapshot
let snapshot = TensorSnapshot::from_closure(
data_fn,
dtype,
shape,
name.split('.').map(|s| s.to_string()).collect(),
vec![], // empty container_stack
tensor_id, // restored or newly generated param id
);
snapshots.push(snapshot);
}
Ok(snapshots)
}
// Legacy methods for test compatibility - will be removed
/// Get tensor as TensorSnapshot with lazy loading
#[allow(dead_code)]
pub(crate) fn get_tensor_snapshot(&self, name: &str) -> Result<TensorSnapshot, BurnpackError> {
let snapshots = self.get_snapshots()?;
snapshots
.into_iter()
.find(|s| s.full_path() == name)
.ok_or_else(|| BurnpackError::TensorNotFound(name.to_string()))
}
/// Get list of tensor names
#[allow(dead_code)]
pub(crate) fn tensor_names(&self) -> Vec<&str> {
self.metadata
.tensors
.keys()
.map(|name| name.as_str())
.collect()
}
/// Get metadata
#[allow(dead_code)]
pub(crate) fn metadata(&self) -> &BurnpackMetadata {
&self.metadata
}
/// Get tensor data as raw bytes
#[allow(dead_code)]
pub(crate) fn get_tensor_data(&self, name: &str) -> Result<Vec<u8>, BurnpackError> {
let descriptor = self
.metadata
.tensors
.get(name)
.ok_or_else(|| BurnpackError::TensorNotFound(name.to_string()))?;
// Always use absolute positions for all backends
// Convert offsets with overflow checking
let offset_start: usize = descriptor.data_offsets.0.try_into().map_err(|_| {
BurnpackError::IoError(format!(
"Tensor '{}' has corrupted offset data: start offset {} exceeds platform maximum",
name, descriptor.data_offsets.0
))
})?;
let offset_end: usize = descriptor.data_offsets.1.try_into().map_err(|_| {
BurnpackError::IoError(format!(
"Tensor '{}' has corrupted offset data: end offset {} exceeds platform maximum",
name, descriptor.data_offsets.1
))
})?;
let start = self.data_offset.checked_add(offset_start).ok_or_else(|| {
BurnpackError::IoError(format!(
"Tensor '{}' has corrupted offset data: start offset overflow {} + {}",
name, self.data_offset, offset_start
))
})?;
let end = self.data_offset.checked_add(offset_end).ok_or_else(|| {
BurnpackError::IoError(format!(
"Tensor '{}' has corrupted offset data: end offset overflow {} + {}",
name, self.data_offset, offset_end
))
})?;
// Validate offset range
if end < start {
return Err(BurnpackError::IoError(format!(
"Tensor '{}' has corrupted offset data: end offset {} < start offset {}",
name, end, start
)));
}
let len = end - start;
let mut buffer = vec![0u8; len];
self.storage.read_into(&mut buffer, start)?;
Ok(buffer)
}
}

View File

@@ -0,0 +1,507 @@
#[cfg(feature = "std")]
use std::path::PathBuf;
use super::reader::BurnpackReader;
use super::writer::BurnpackWriter;
#[cfg(feature = "std")]
use crate::KeyRemapper;
use crate::burnpack::base::BurnpackError;
use crate::{ModuleSnapshot, ModuleStore, PathFilter, TensorSnapshot};
use alloc::collections::BTreeMap;
use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
use burn_core::prelude::Backend;
use burn_tensor::Bytes;
/// Store mode for BurnpackStore
enum StoreMode {
#[cfg(feature = "std")]
File(PathBuf),
Bytes(Option<Bytes>),
}
/// BurnpackStore - A Burn-specific file format store using CBOR for metadata
pub struct BurnpackStore {
/// Store mode - either file path or bytes
mode: StoreMode,
/// Optional filter for selective loading/saving
filter: Option<PathFilter>,
/// Additional metadata
metadata: BTreeMap<String, String>,
/// Allow partial loading (ignore missing tensors)
allow_partial: bool,
/// Validate tensors during loading (check shapes and dtypes)
validate: bool,
/// Allow overwriting existing files (default: false)
overwrite: bool,
/// Enable zero-copy tensor loading (default: false)
///
/// When enabled and the backend supports it, tensor data is sliced from
/// the source without copying. This requires keeping the source data alive.
zero_copy: bool,
/// Automatically append .bpk extension if not present (default: true)
#[cfg(feature = "std")]
auto_extension: bool,
/// Key remapper for tensor name transformations
#[cfg(feature = "std")]
remapper: KeyRemapper,
/// Writer for saving
writer: Option<BurnpackWriter>,
/// Reader for loading
reader: Option<BurnpackReader>,
/// Cached tensor snapshots (parsed once, reused)
snapshots_cache: Option<BTreeMap<String, TensorSnapshot>>,
}
impl BurnpackStore {
/// Get the default metadata that includes Burn framework information.
///
/// This includes:
/// - `format`: "burnpack"
/// - `producer`: "burn"
/// - `version`: The version of burn-store crate (from CARGO_PKG_VERSION)
///
/// These metadata fields are automatically added to all saved models.
pub fn default_metadata() -> BTreeMap<String, String> {
let mut metadata = BTreeMap::new();
metadata.insert("format".into(), "burnpack".into());
metadata.insert("producer".into(), "burn".into());
metadata.insert("version".into(), env!("CARGO_PKG_VERSION").into());
metadata
}
/// Create a new store from a file path
///
/// By default, automatically appends `.bpk` extension if the path doesn't have one.
/// Use `.auto_extension(false)` to disable this behavior.
///
/// # Examples
///
/// ```no_run
/// # use burn_store::BurnpackStore;
/// // Automatically appends .bpk
/// let store = BurnpackStore::from_file("model"); // creates "model.bpk"
///
/// // Already has extension, no append
/// let store = BurnpackStore::from_file("model.bpk"); // uses "model.bpk"
/// let store = BurnpackStore::from_file("model.myext"); // uses "model.myext"
///
/// // Disable auto-extension
/// let store = BurnpackStore::from_file("model").auto_extension(false); // uses "model"
/// ```
#[cfg(feature = "std")]
pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> Self {
Self {
mode: StoreMode::File(path.as_ref().to_path_buf()),
filter: None,
metadata: Self::default_metadata(),
allow_partial: false,
validate: true,
overwrite: false,
zero_copy: false,
#[cfg(feature = "std")]
auto_extension: true,
#[cfg(feature = "std")]
remapper: KeyRemapper::new(),
writer: None,
reader: None,
snapshots_cache: None,
}
}
/// Create a new store from bytes (for reading) or empty (for writing)
pub fn from_bytes(bytes: Option<Bytes>) -> Self {
Self {
mode: StoreMode::Bytes(bytes),
filter: None,
metadata: Self::default_metadata(),
allow_partial: false,
validate: true,
overwrite: false,
zero_copy: false,
#[cfg(feature = "std")]
auto_extension: false, // Not used for bytes mode
#[cfg(feature = "std")]
remapper: KeyRemapper::new(),
writer: None,
reader: None,
snapshots_cache: None,
}
}
/// Create a new store from static bytes with zero-copy loading enabled.
///
/// This is optimized for embedded model weights where the data lives in the
/// binary's `.rodata` section. Tensor data is sliced without copying, keeping
/// the static reference alive.
///
/// # Example
///
/// ```ignore
/// static MODEL_DATA: &[u8] = include_bytes!("model.bpk");
/// let store = BurnpackStore::from_static(MODEL_DATA);
/// ```
pub fn from_static(data: &'static [u8]) -> Self {
use burn_tensor::AllocationProperty;
// Create bytes::Bytes from static data (zero-copy, stays in .rodata)
let shared = bytes::Bytes::from_static(data);
// Wrap in cubecl Bytes with shared-bytes allocation controller
let bytes = Bytes::from_shared(shared, AllocationProperty::Other);
Self {
mode: StoreMode::Bytes(Some(bytes)),
filter: None,
metadata: Self::default_metadata(),
allow_partial: false,
validate: true,
overwrite: false,
zero_copy: true, // Enable zero-copy by default for static data
#[cfg(feature = "std")]
auto_extension: false,
#[cfg(feature = "std")]
remapper: KeyRemapper::new(),
writer: None,
reader: None,
snapshots_cache: None,
}
}
/// Add metadata key-value pair
pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.metadata.insert(key.into(), value.into());
self
}
/// Clear all metadata (including defaults)
///
/// This removes all metadata including the default format, producer, and version fields.
/// Use with caution as some tools may expect these fields to be present.
pub fn clear_metadata(mut self) -> Self {
self.metadata.clear();
self
}
/// Allow partial loading (ignore missing tensors)
///
/// When set to `true`, the store will not fail if some tensors are missing
/// during loading. This is useful when loading a subset of a model's parameters.
///
/// Default: `false`
pub fn allow_partial(mut self, allow: bool) -> Self {
self.allow_partial = allow;
self
}
/// Enable or disable validation during loading
///
/// When validation is enabled, the store will check that loaded tensors
/// match the expected shapes and data types. Disabling validation can
/// improve performance but may lead to runtime errors if data is corrupted.
///
/// Default: `true`
pub fn validate(mut self, validate: bool) -> Self {
self.validate = validate;
self
}
/// Allow overwriting existing files when saving
///
/// When set to `false`, attempting to save to an existing file will result in an error.
/// When set to `true`, existing files will be overwritten without warning.
///
/// Default: `false`
pub fn overwrite(mut self, overwrite: bool) -> Self {
self.overwrite = overwrite;
self
}
/// Enable or disable zero-copy tensor loading.
///
/// When enabled and the backend supports it (memory-backed with shared bytes),
/// tensor data is sliced from the source without copying. This keeps the source
/// data alive as long as any tensor holds a reference.
///
/// Zero-copy is automatically enabled when using [`from_static`](Self::from_static).
/// Use this method to enable it for other memory-backed stores created with
/// [`from_bytes`](Self::from_bytes) when using `Bytes::from_shared()`.
///
/// Default: `false` (except for `from_static` which defaults to `true`)
pub fn zero_copy(mut self, enable: bool) -> Self {
self.zero_copy = enable;
self
}
/// Enable or disable automatic .bpk extension appending
///
/// When enabled (default), automatically appends `.bpk` to the file path
/// if no extension is detected. If an extension is already present, it is preserved.
///
/// When disabled, uses the exact path provided without modification.
///
/// Default: `true`
///
/// # Examples
///
/// ```no_run
/// # use burn_store::BurnpackStore;
/// // With auto_extension enabled (default)
/// let store = BurnpackStore::from_file("model"); // -> "model.bpk"
///
/// // With auto_extension disabled
/// let store = BurnpackStore::from_file("model")
/// .auto_extension(false); // -> "model"
/// ```
#[cfg(feature = "std")]
pub fn auto_extension(mut self, enable: bool) -> Self {
self.auto_extension = enable;
self
}
/// Set path filter for selective loading/saving
pub fn with_filter(mut self, filter: PathFilter) -> Self {
self.filter = Some(filter);
self
}
/// Add regex pattern to filter
#[cfg(feature = "std")]
pub fn with_regex(mut self, pattern: &str) -> Self {
let filter = self.filter.unwrap_or_default();
self.filter = Some(filter.with_regex(pattern));
self
}
/// Add exact path to filter
pub fn with_full_path(mut self, path: impl Into<String>) -> Self {
let filter = self.filter.unwrap_or_default();
self.filter = Some(filter.with_full_path(path));
self
}
/// Match all tensors (no filtering)
pub fn match_all(mut self) -> Self {
self.filter = Some(PathFilter::new().match_all());
self
}
/// Set key remapper for tensor name transformations during loading
#[cfg(feature = "std")]
pub fn remap(mut self, remapper: KeyRemapper) -> Self {
self.remapper = remapper;
self
}
/// Add a single regex pattern for key remapping
#[cfg(feature = "std")]
pub fn with_remap_pattern<S1, S2>(mut self, from: S1, to: S2) -> Self
where
S1: AsRef<str>,
S2: Into<String>,
{
self.remapper = self
.remapper
.add_pattern(from.as_ref(), to.into())
.expect("Invalid regex pattern");
self
}
/// Set the path filter
pub fn filter(mut self, filter: PathFilter) -> Self {
self.filter = Some(filter);
self
}
/// Get the bytes after writing (only valid for bytes mode after collecting)
pub fn get_bytes(&self) -> Result<Bytes, BurnpackError> {
if let Some(writer) = &self.writer {
return writer.to_bytes();
}
match &self.mode {
StoreMode::Bytes(Some(bytes)) => Ok(bytes.clone()),
_ => Err(BurnpackError::IoError("No bytes available".into())),
}
}
/// Process the file path with auto-extension logic
#[cfg(feature = "std")]
fn process_path(&self, path: &std::path::Path) -> PathBuf {
if !self.auto_extension {
return path.to_path_buf();
}
// Check if path already has an extension
if path.extension().is_some() {
// Has extension, use as-is
return path.to_path_buf();
}
// No extension, append .bpk
let mut new_path = path.to_path_buf();
new_path.set_extension("bpk");
new_path
}
/// Ensure the reader is initialized, loading from storage if needed
fn ensure_reader(&mut self) -> Result<&BurnpackReader, BurnpackError> {
if self.reader.is_none() {
let reader = match &self.mode {
#[cfg(feature = "std")]
StoreMode::File(path) => {
let final_path = self.process_path(path);
BurnpackReader::from_file(&final_path)?
}
StoreMode::Bytes(Some(bytes)) => BurnpackReader::from_bytes(bytes.clone())?,
StoreMode::Bytes(None) => {
return Err(BurnpackError::IoError("No bytes to read from".into()));
}
};
self.reader = Some(reader);
}
self.reader
.as_ref()
.ok_or_else(|| BurnpackError::IoError("Reader not initialized".into()))
}
}
impl ModuleStore for BurnpackStore {
type Error = BurnpackError;
fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
&mut self,
module: &M,
) -> Result<(), Self::Error> {
// Invalidate cache since we're writing new data
self.snapshots_cache = None;
self.reader = None;
// Collect snapshots from module
let snapshots = module.collect(self.filter.clone(), None, false);
// Initialize writer with snapshots
let mut writer = BurnpackWriter::new(snapshots);
// Add metadata using builder pattern
for (key, value) in &self.metadata {
writer = writer.with_metadata(key.as_str(), value.as_str());
}
// Store the writer for finalization
self.writer = Some(writer);
// Write to storage based on mode
if let Some(writer) = &self.writer {
match &self.mode {
#[cfg(feature = "std")]
StoreMode::File(path) => {
// Process path with auto-extension logic
let final_path = self.process_path(path);
// Check if file exists and overwrite is disabled
if final_path.exists() && !self.overwrite {
return Err(BurnpackError::IoError(format!(
"File already exists: {}. Use .overwrite(true) to overwrite.",
final_path.display()
)));
}
writer.write_to_file(&final_path)?;
}
StoreMode::Bytes(_) => {
// Generate and store the bytes
let bytes_data = writer.to_bytes()?;
// Update mode with bytes - this pattern is irrefutable in no-std mode
#[cfg_attr(not(feature = "std"), allow(irrefutable_let_patterns))]
let StoreMode::Bytes(bytes_ref) = &mut self.mode else {
unreachable!("We just matched Bytes variant");
};
*bytes_ref = Some(bytes_data);
}
}
}
Ok(())
}
fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
&mut self,
module: &mut M,
) -> Result<crate::ApplyResult, Self::Error> {
// Get all snapshots using the cached method
let snapshots: Vec<TensorSnapshot> = self.get_all_snapshots()?.values().cloned().collect();
// Apply all snapshots at once to the module
// Burnpack is Burn's native format, so no enum variant skipping needed
// Filter is applied here during apply, not during cache population
let result = module.apply(snapshots, self.filter.clone(), None, false);
// Validate if needed
if self.validate && !result.errors.is_empty() {
return Err(BurnpackError::ValidationError(format!(
"Import errors: {:?}",
result.errors
)));
}
// Check for missing tensors if partial loading is not allowed
if !self.allow_partial && !result.missing.is_empty() {
return Err(BurnpackError::ValidationError(format!(
"Missing tensors: {:?}",
result.missing
)));
}
Ok(result)
}
fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error> {
// Ensure cache is populated
self.ensure_snapshots_cache()?;
Ok(self.snapshots_cache.as_ref().unwrap().get(name))
}
fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error> {
// Ensure cache is populated
self.ensure_snapshots_cache()?;
Ok(self.snapshots_cache.as_ref().unwrap())
}
fn keys(&mut self) -> Result<Vec<String>, Self::Error> {
// Always use the cache to ensure remapping is applied consistently
Ok(self.get_all_snapshots()?.keys().cloned().collect())
}
}
impl BurnpackStore {
/// Ensure the snapshots cache is populated
fn ensure_snapshots_cache(&mut self) -> Result<(), BurnpackError> {
if self.snapshots_cache.is_some() {
return Ok(());
}
// Ensure reader is loaded
self.ensure_reader()?;
// Get snapshots from reader with zero-copy if enabled
let reader = self.reader.as_ref().unwrap();
let snapshots = reader.get_snapshots_zero_copy(self.zero_copy)?;
// Apply remapping if configured (but NOT filtering - that's done at apply time)
#[cfg(feature = "std")]
let snapshots = if !self.remapper.patterns.is_empty() {
let (remapped, _remapped_names) = self.remapper.remap(snapshots);
remapped
} else {
snapshots
};
// Build the cache as BTreeMap
let cache: BTreeMap<String, TensorSnapshot> =
snapshots.into_iter().map(|s| (s.full_path(), s)).collect();
self.snapshots_cache = Some(cache);
Ok(())
}
}

View File

@@ -0,0 +1,434 @@
//! Tests for tensor data alignment in burnpack format.
//!
//! These tests verify that tensor data is properly aligned for mmap zero-copy access.
use crate::TensorSnapshot;
use crate::burnpack::{
base::{
BurnpackHeader, BurnpackMetadata, HEADER_SIZE, TENSOR_ALIGNMENT, aligned_data_section_start,
},
reader::BurnpackReader,
writer::BurnpackWriter,
};
use burn_core::module::ParamId;
use burn_tensor::{DType, TensorData};
/// Verify that aligned_data_section_start always returns 256-byte aligned values
#[test]
fn test_aligned_data_section_start_is_always_aligned() {
// Test various metadata sizes
for metadata_size in 0..1024 {
let result = aligned_data_section_start(metadata_size);
assert_eq!(
result % TENSOR_ALIGNMENT as usize,
0,
"aligned_data_section_start({}) = {} is not aligned to {}",
metadata_size,
result,
TENSOR_ALIGNMENT
);
}
}
/// Verify data section starts at correct aligned position
#[test]
fn test_data_section_alignment() {
// Create a tensor
let data = [1.0f32, 2.0, 3.0, 4.0];
let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes, vec![4], DType::F32),
vec!["tensor".to_string()],
vec![],
ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
let file_bytes = writer.to_bytes().unwrap();
// Parse header to get metadata size
let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();
let data_section_start = aligned_data_section_start(header.metadata_size as usize);
// Verify data section starts at 256-byte aligned position
assert_eq!(
data_section_start % TENSOR_ALIGNMENT as usize,
0,
"Data section start {} is not 256-byte aligned",
data_section_start
);
// Verify the file is large enough
assert!(
file_bytes.len() >= data_section_start,
"File too small: {} < {}",
file_bytes.len(),
data_section_start
);
}
/// Verify that first tensor's absolute file position is 256-byte aligned
#[test]
fn test_first_tensor_absolute_position_aligned() {
let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(data, vec![8], DType::U8),
vec!["first".to_string()],
vec![],
ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
let file_bytes = writer.to_bytes().unwrap();
let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();
let metadata_end = HEADER_SIZE + header.metadata_size as usize;
let metadata: BurnpackMetadata =
ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap();
let tensor_desc = metadata.tensors.get("first").unwrap();
let data_section_start = aligned_data_section_start(header.metadata_size as usize);
// Absolute file position of first tensor
let absolute_pos = data_section_start + tensor_desc.data_offsets.0 as usize;
assert_eq!(
absolute_pos % TENSOR_ALIGNMENT as usize,
0,
"First tensor absolute position {} is not 256-byte aligned",
absolute_pos
);
}
/// Verify that all tensors in a multi-tensor file have 256-byte aligned absolute positions
#[test]
fn test_all_tensors_absolute_positions_aligned() {
// Create multiple tensors of different sizes (all U8 to simplify shape calculation)
let tensors = vec![
("tensor_a", vec![1u8, 2, 3]), // 3 bytes
("tensor_b", vec![0u8; 16]), // 16 bytes
("tensor_c", vec![0u8; 64]), // 64 bytes
("tensor_d", vec![42u8]), // 1 byte
("tensor_e", vec![0u8; 400]), // 400 bytes
];
let snapshots: Vec<TensorSnapshot> = tensors
.into_iter()
.map(|(name, data)| {
let len = data.len();
TensorSnapshot::from_data(
TensorData::from_bytes_vec(data, vec![len], DType::U8),
vec![name.to_string()],
vec![],
ParamId::new(),
)
})
.collect();
let writer = BurnpackWriter::new(snapshots);
let file_bytes = writer.to_bytes().unwrap();
let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();
let metadata_end = HEADER_SIZE + header.metadata_size as usize;
let metadata: BurnpackMetadata =
ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap();
let data_section_start = aligned_data_section_start(header.metadata_size as usize);
// Check every tensor has aligned absolute position
for (name, desc) in &metadata.tensors {
let absolute_pos = data_section_start + desc.data_offsets.0 as usize;
assert_eq!(
absolute_pos % TENSOR_ALIGNMENT as usize,
0,
"Tensor '{}' at absolute position {} is not 256-byte aligned (offset in data section: {})",
name,
absolute_pos,
desc.data_offsets.0
);
}
}
/// Test edge case: metadata size that results in no padding needed
#[test]
fn test_alignment_with_minimal_padding() {
// We can't control metadata size directly, but we can verify the math works
// When HEADER_SIZE + metadata_size is already a multiple of 256, no padding needed
let aligned_metadata_size = TENSOR_ALIGNMENT as usize - HEADER_SIZE; // 256 - 10 = 246
let result = aligned_data_section_start(aligned_metadata_size);
assert_eq!(result, TENSOR_ALIGNMENT as usize); // Should be exactly 256
// One byte more should still round up to 256
let result_plus_one = aligned_data_section_start(aligned_metadata_size + 1);
assert_eq!(result_plus_one, 2 * TENSOR_ALIGNMENT as usize); // Should be 512
}
/// Verify padding bytes in the file are zeros
#[test]
fn test_padding_bytes_are_zeros() {
let data: Vec<u8> = vec![0xAA; 16]; // Distinctive pattern
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(data.clone(), vec![16], DType::U8),
vec!["tensor".to_string()],
vec![],
ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
let file_bytes = writer.to_bytes().unwrap();
let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();
let metadata_end = HEADER_SIZE + header.metadata_size as usize;
let data_section_start = aligned_data_section_start(header.metadata_size as usize);
// Check padding between metadata and data section
if data_section_start > metadata_end {
let padding = &file_bytes[metadata_end..data_section_start];
assert!(
padding.iter().all(|&b| b == 0),
"Padding bytes between metadata and data section contain non-zero values"
);
}
}
/// Verify alignment is sufficient for all primitive types
/// 256-byte alignment is a multiple of all primitive type alignments:
/// - f64/i64/u64: 8 bytes
/// - f32/i32/u32: 4 bytes
/// - f16/bf16/i16/u16: 2 bytes
/// - i8/u8/bool: 1 byte
#[test]
#[allow(clippy::modulo_one)]
fn test_alignment_covers_all_primitive_types() {
// 256 must be divisible by all common alignments
assert_eq!(
TENSOR_ALIGNMENT % 8,
0,
"256 not divisible by 8 (f64 alignment)"
);
assert_eq!(
TENSOR_ALIGNMENT % 4,
0,
"256 not divisible by 4 (f32 alignment)"
);
assert_eq!(
TENSOR_ALIGNMENT % 2,
0,
"256 not divisible by 2 (f16 alignment)"
);
assert_eq!(
TENSOR_ALIGNMENT % 1,
0,
"256 not divisible by 1 (u8 alignment)"
);
}
/// Verify that tensor data can be read correctly after alignment
#[test]
fn test_aligned_tensor_data_readable() {
// Create f32 tensor
let f32_data = vec![1.0f32, 2.0, 3.0, 4.0];
let f32_bytes: Vec<u8> = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect();
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(f32_bytes.clone(), vec![4], DType::F32),
vec!["floats".to_string()],
vec![],
ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
let file_bytes = writer.to_bytes().unwrap();
let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();
let metadata_end = HEADER_SIZE + header.metadata_size as usize;
let metadata: BurnpackMetadata =
ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap();
let tensor_desc = metadata.tensors.get("floats").unwrap();
let data_section_start = aligned_data_section_start(header.metadata_size as usize);
let start = data_section_start + tensor_desc.data_offsets.0 as usize;
let end = data_section_start + tensor_desc.data_offsets.1 as usize;
let tensor_bytes = &file_bytes[start..end];
// Verify the bytes match what we wrote
assert_eq!(tensor_bytes, f32_bytes.as_slice());
// Verify we can interpret them as floats
let mut floats = Vec::new();
for chunk in tensor_bytes.chunks_exact(4) {
floats.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
}
assert_eq!(floats, f32_data);
}
/// Verify alignment works with f64 data
#[test]
fn test_aligned_f64_tensor_data_readable() {
let f64_data = vec![1.0f64, 2.0, 3.0, 4.0];
let f64_bytes: Vec<u8> = f64_data.iter().flat_map(|f| f.to_le_bytes()).collect();
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(f64_bytes.clone(), vec![4], DType::F64),
vec!["doubles".to_string()],
vec![],
ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
let file_bytes = writer.to_bytes().unwrap();
let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();
let metadata_end = HEADER_SIZE + header.metadata_size as usize;
let metadata: BurnpackMetadata =
ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap();
let tensor_desc = metadata.tensors.get("doubles").unwrap();
let data_section_start = aligned_data_section_start(header.metadata_size as usize);
let start = data_section_start + tensor_desc.data_offsets.0 as usize;
let end = data_section_start + tensor_desc.data_offsets.1 as usize;
let tensor_bytes = &file_bytes[start..end];
// Verify the bytes match
assert_eq!(tensor_bytes, f64_bytes.as_slice());
// Verify we can interpret them as doubles
let mut doubles = Vec::new();
for chunk in tensor_bytes.chunks_exact(8) {
doubles.push(f64::from_le_bytes([
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
]));
}
assert_eq!(doubles, f64_data);
}
/// Test round-trip preserves alignment (write then read)
#[test]
fn test_round_trip_maintains_alignment() {
let f32_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let f32_bytes: Vec<u8> = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect();
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(f32_bytes, vec![2, 4], DType::F32),
vec!["matrix".to_string()],
vec![],
ParamId::new(),
);
// Write
let writer = BurnpackWriter::new(vec![snapshot]);
let file_bytes = writer.to_bytes().unwrap();
// Read back
let reader = BurnpackReader::from_bytes(file_bytes.clone()).unwrap();
let snapshots = reader.get_snapshots().unwrap();
assert_eq!(snapshots.len(), 1);
let loaded = &snapshots[0];
assert_eq!(loaded.full_path(), "matrix");
// Verify the loaded data is correct
let tensor_data = loaded.to_data().unwrap();
let mut loaded_floats = Vec::new();
for chunk in tensor_data.bytes.chunks_exact(4) {
loaded_floats.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
}
assert_eq!(loaded_floats, f32_data);
}
/// Test that tensor offsets within data section are also aligned
#[test]
fn test_tensor_relative_offsets_are_aligned() {
// Create several small tensors to force multiple alignment padding
let tensors: Vec<_> = (0..5)
.map(|i| {
let data = vec![i as u8; 7]; // 7 bytes each - not aligned
TensorSnapshot::from_data(
TensorData::from_bytes_vec(data, vec![7], DType::U8),
vec![format!("tensor_{}", i)],
vec![],
ParamId::new(),
)
})
.collect();
let writer = BurnpackWriter::new(tensors);
let file_bytes = writer.to_bytes().unwrap();
let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();
let metadata_end = HEADER_SIZE + header.metadata_size as usize;
let metadata: BurnpackMetadata =
ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap();
// All tensor start offsets within data section should be multiples of 256
for (name, desc) in &metadata.tensors {
assert_eq!(
desc.data_offsets.0 % TENSOR_ALIGNMENT,
0,
"Tensor '{}' relative offset {} is not 256-byte aligned",
name,
desc.data_offsets.0
);
}
}
#[cfg(feature = "std")]
mod file_tests {
use super::*;
use std::fs;
use tempfile::tempdir;
/// Test alignment is preserved when writing to and reading from file
#[test]
fn test_file_io_preserves_alignment() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("aligned.bpk");
let f32_data = [1.0f32, 2.0, 3.0, 4.0];
let f32_bytes: Vec<u8> = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect();
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(f32_bytes, vec![4], DType::F32),
vec!["floats".to_string()],
vec![],
ParamId::new(),
);
// Write to file
let writer = BurnpackWriter::new(vec![snapshot]);
writer.write_to_file(&file_path).unwrap();
// Read file bytes directly
let file_bytes = fs::read(&file_path).unwrap();
let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();
let metadata_end = HEADER_SIZE + header.metadata_size as usize;
let metadata: BurnpackMetadata =
ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap();
let tensor_desc = metadata.tensors.get("floats").unwrap();
let data_section_start = aligned_data_section_start(header.metadata_size as usize);
let absolute_pos = data_section_start + tensor_desc.data_offsets.0 as usize;
assert_eq!(
absolute_pos % TENSOR_ALIGNMENT as usize,
0,
"Tensor absolute position in file {} is not 256-byte aligned",
absolute_pos
);
// Verify data is correct
let start = data_section_start + tensor_desc.data_offsets.0 as usize;
let end = data_section_start + tensor_desc.data_offsets.1 as usize;
let tensor_bytes = &file_bytes[start..end];
let mut floats = Vec::new();
for chunk in tensor_bytes.chunks_exact(4) {
floats.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
}
assert_eq!(floats, vec![1.0f32, 2.0, 3.0, 4.0]);
}
}

View File

@@ -0,0 +1,365 @@
use crate::TensorSnapshot;
use crate::burnpack::{
base::{BurnpackHeader, HEADER_SIZE},
reader::BurnpackReader,
writer::BurnpackWriter,
};
use burn_core::module::ParamId;
use burn_tensor::{DType, TensorData};
#[test]
fn test_maximum_metadata_size() {
// Create metadata that approaches u32::MAX (4GB limit)
// In practice, we'll test with a reasonably large metadata
let large_key = "x".repeat(1000);
let large_value = "y".repeat(10000);
let mut writer = BurnpackWriter::new(vec![]);
for i in 0..100 {
writer = writer.with_metadata(&format!("{}_{}", large_key, i), &large_value);
}
let result = writer.to_bytes();
assert!(result.is_ok());
let bytes = result.unwrap();
let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();
// Metadata size should be large but within u32 bounds
assert!(header.metadata_size > 1000000); // At least 1MB of metadata
assert!(header.metadata_size < u32::MAX);
}
#[test]
fn test_zero_size_tensor_shapes() {
// Test various zero-dimensional shapes
let test_cases = [
(vec![0], vec![]), // Empty 1D
(vec![0, 10], vec![]), // Zero rows
(vec![10, 0], vec![]), // Zero columns
(vec![0, 0], vec![]), // Zero both dimensions
(vec![5, 0, 10], vec![]), // Zero in middle dimension
];
let mut snapshots = vec![];
for (i, (shape, data)) in test_cases.iter().enumerate() {
let name = format!("zero_tensor_{}", i);
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(data.clone(), shape.clone(), DType::F32),
vec![name.clone()],
vec![],
ParamId::new(),
);
snapshots.push(snapshot);
}
let writer = BurnpackWriter::new(snapshots);
let bytes = writer.to_bytes().unwrap();
// Read back and verify
let reader = BurnpackReader::from_bytes(bytes).unwrap();
let names = reader.tensor_names();
assert_eq!(names.len(), 5);
}
#[test]
fn test_extremely_long_tensor_names() {
// Create a tensor with an extremely long name
let long_name = "a".repeat(10000);
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8),
vec![long_name.clone()],
vec![],
ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes).unwrap();
let names = reader.tensor_names();
assert_eq!(names[0].len(), 10000);
}
#[test]
fn test_unicode_in_names_and_metadata() {
// Test various Unicode characters in tensor names and metadata
let unicode_names = vec![
"测试_tensor", // Chinese
"тест_tensor", // Cyrillic
"テスト_tensor", // Japanese
"🔥_burn_tensor", // Emoji
"αβγδ_tensor", // Greek
"한글_tensor", // Korean
];
let mut snapshots = vec![];
for name in &unicode_names {
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![1], vec![1], DType::U8),
vec![name.to_string()],
vec![],
ParamId::new(),
);
snapshots.push(snapshot);
}
let writer = BurnpackWriter::new(snapshots)
.with_metadata("模型名称", "测试模型")
.with_metadata("מודל", "בדיקה")
.with_metadata("🔥", "fire");
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes).unwrap();
// Verify all Unicode names are preserved
let names = reader.tensor_names();
assert_eq!(names.len(), unicode_names.len());
// Verify metadata
assert_eq!(
reader.metadata().metadata.get("模型名称"),
Some(&"测试模型".to_string())
);
assert_eq!(
reader.metadata().metadata.get("🔥"),
Some(&"fire".to_string())
);
}
#[test]
fn test_all_supported_dtypes() {
// Test all DTypes with their boundary values
let dtypes_with_data = [
(
DType::F32,
[
f32::MIN.to_le_bytes().to_vec(),
f32::MAX.to_le_bytes().to_vec(),
]
.concat(),
),
(
DType::F64,
[
f64::MIN.to_le_bytes().to_vec(),
f64::MAX.to_le_bytes().to_vec(),
]
.concat(),
),
(
DType::I32,
[
i32::MIN.to_le_bytes().to_vec(),
i32::MAX.to_le_bytes().to_vec(),
]
.concat(),
),
(
DType::I64,
[
i64::MIN.to_le_bytes().to_vec(),
i64::MAX.to_le_bytes().to_vec(),
]
.concat(),
),
(
DType::U32,
[
u32::MIN.to_le_bytes().to_vec(),
u32::MAX.to_le_bytes().to_vec(),
]
.concat(),
),
(
DType::U64,
[
u64::MIN.to_le_bytes().to_vec(),
u64::MAX.to_le_bytes().to_vec(),
]
.concat(),
),
(DType::U8, vec![u8::MIN, u8::MAX]),
(DType::Bool, vec![0, 1]),
];
let mut snapshots = vec![];
for (i, (dtype, data)) in dtypes_with_data.iter().enumerate() {
let name = format!("dtype_test_{}", i);
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(data.clone(), vec![2], *dtype),
vec![name],
vec![],
ParamId::new(),
);
snapshots.push(snapshot);
}
let writer = BurnpackWriter::new(snapshots);
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes).unwrap();
assert_eq!(reader.tensor_names().len(), dtypes_with_data.len());
// Verify dtypes are preserved
for (i, (expected_dtype, _)) in dtypes_with_data.iter().enumerate() {
let name = format!("dtype_test_{}", i);
let snapshot = reader.get_tensor_snapshot(&name).unwrap();
assert_eq!(snapshot.dtype, *expected_dtype);
}
}
#[test]
fn test_special_float_values() {
// Test special floating-point values (NaN, Inf, -Inf)
let special_values = [
f32::NAN,
f32::INFINITY,
f32::NEG_INFINITY,
0.0_f32,
-0.0_f32,
];
let data: Vec<u8> = special_values
.iter()
.flat_map(|f| f.to_le_bytes())
.collect();
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(data.clone(), vec![5], DType::F32),
vec!["special_floats".to_string()],
vec![],
ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes).unwrap();
let tensor_data = reader.get_tensor_data("special_floats").unwrap();
// Check data is preserved exactly (bit-for-bit)
assert_eq!(tensor_data, data);
}
#[test]
fn test_metadata_with_empty_values() {
let writer = BurnpackWriter::new(vec![])
.with_metadata("empty_value", "")
.with_metadata("", "empty_key")
.with_metadata("normal", "value");
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes).unwrap();
let metadata = &reader.metadata().metadata;
assert_eq!(metadata.get("empty_value"), Some(&"".to_string()));
assert_eq!(metadata.get(""), Some(&"empty_key".to_string()));
assert_eq!(metadata.get("normal"), Some(&"value".to_string()));
}
#[test]
fn test_single_byte_tensor() {
// Test the smallest possible tensor (1 byte)
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![42], vec![1], DType::U8),
vec!["single_byte".to_string()],
vec![],
ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes).unwrap();
let data = reader.get_tensor_data("single_byte").unwrap();
assert_eq!(data, vec![42]);
}
#[test]
fn test_high_dimensional_tensor() {
// Test a tensor with many dimensions (10D)
let shape = vec![2, 2, 2, 2, 2, 2, 2, 2, 2, 2]; // 10 dimensions, 1024 elements total
let data = vec![1u8; 1024];
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(data.clone(), shape.clone(), DType::U8),
vec!["high_dim".to_string()],
vec![],
ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes).unwrap();
let loaded_snapshot = reader.get_tensor_snapshot("high_dim").unwrap();
assert_eq!(loaded_snapshot.shape, shape);
}
#[test]
fn test_metadata_key_collision() {
// Test that later values override earlier ones for the same key
let writer = BurnpackWriter::new(vec![])
.with_metadata("key", "value1")
.with_metadata("key", "value2")
.with_metadata("key", "value3");
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes).unwrap();
assert_eq!(
reader.metadata().metadata.get("key"),
Some(&"value3".to_string())
);
}
#[test]
fn test_tensor_name_with_path_separators() {
// Test tensor names that look like file paths
let path_like_names = vec![
"model/encoder/layer1/weights",
"model\\decoder\\layer1\\bias",
"model::module::param",
"model.submodule.weight",
];
let mut snapshots = vec![];
for name in &path_like_names {
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8),
vec![name.to_string()],
vec![],
ParamId::new(),
);
snapshots.push(snapshot);
}
let writer = BurnpackWriter::new(snapshots);
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes).unwrap();
let names = reader.tensor_names();
// All names should be preserved exactly
for expected_name in &path_like_names {
assert!(names.contains(expected_name));
}
}
// The following tests are commented out as they test error conditions
// that might be handled differently in the new API
// #[test]
// fn test_data_overflow_protection() {
// // Test that we handle potential integer overflows in offset calculations
// ...
// }
// #[test]
// fn test_reading_corrupted_header() {
// // Test reading files with corrupted headers
// ...
// }

View File

@@ -0,0 +1,61 @@
use crate::burnpack::base::*;
#[test]
fn test_header_serialization() {
let header = BurnpackHeader::new(12345);
// Check fields
assert_eq!(header.magic, MAGIC_NUMBER);
assert_eq!(header.version, FORMAT_VERSION);
assert_eq!(header.metadata_size, 12345);
// Serialize to bytes
let bytes = header.into_bytes();
assert_eq!(bytes.len(), HEADER_SIZE);
// Deserialize back
let header2 = BurnpackHeader::from_bytes(&bytes).unwrap();
assert_eq!(header2.magic, header.magic);
assert_eq!(header2.version, header.version);
assert_eq!(header2.metadata_size, header.metadata_size);
}
#[test]
fn test_header_invalid_magic() {
let mut bytes = [0u8; HEADER_SIZE];
// Write wrong magic number
bytes[0..4].copy_from_slice(&[0x00, 0x00, 0x00, 0x00]);
let result = BurnpackHeader::from_bytes(&bytes);
match result {
Err(BurnpackError::InvalidMagicNumber) => {}
_ => panic!("Expected InvalidMagicNumber error"),
}
}
#[test]
fn test_header_insufficient_bytes() {
let bytes = [0u8; 5]; // Too short
let result = BurnpackHeader::from_bytes(&bytes);
match result {
Err(BurnpackError::InvalidHeader) => {}
_ => panic!("Expected InvalidHeader error"),
}
}
#[test]
fn test_version_compatibility() {
// Create a header with current version
let header = BurnpackHeader::new(100);
let bytes = header.into_bytes();
// Should succeed with current version
let result = BurnpackHeader::from_bytes(&bytes);
assert!(result.is_ok());
// Test with future version (should fail in real implementation)
// For now, we just verify the version field is correctly set
let header = result.unwrap();
assert_eq!(header.version, FORMAT_VERSION);
}

View File

@@ -0,0 +1,19 @@
use crate::TensorSnapshot;
use burn_core::module::ParamId;
use burn_tensor::{DType, TensorData};
/// Helper to create a test TensorSnapshot
#[allow(dead_code)]
pub fn create_test_snapshot(
name: String,
data: Vec<u8>,
shape: Vec<usize>,
dtype: DType,
) -> TensorSnapshot {
TensorSnapshot::from_data(
TensorData::from_bytes_vec(data, shape, dtype),
vec![name],
vec![],
ParamId::new(),
)
}

View File

@@ -0,0 +1,11 @@
use crate::TensorSnapshot;
mod alignment;
mod edge_cases;
mod header;
mod helpers;
mod reader;
mod round_trip;
mod store;
mod writer;
mod zero_copy;

View File

@@ -0,0 +1,775 @@
use crate::burnpack::{
base::{
BurnpackError, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER, magic_range, metadata_size_range,
version_range,
},
reader::BurnpackReader,
writer::BurnpackWriter,
};
use super::*;
use burn_tensor::{Bytes, DType, TensorData};
#[test]
fn test_reader_from_bytes_empty() {
// Create empty burnpack data
let writer = BurnpackWriter::new(Vec::new());
let bytes = writer.to_bytes().unwrap();
// Read it back
let reader = BurnpackReader::from_bytes(bytes).unwrap();
assert_eq!(reader.metadata().tensors.len(), 0);
assert!(reader.metadata().metadata.is_empty());
}
#[test]
fn test_reader_from_bytes_with_data() {
// Create test data
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),
vec!["test_tensor".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]).with_metadata("test", "value");
let bytes = writer.to_bytes().unwrap();
// Read it back
let reader = BurnpackReader::from_bytes(bytes).unwrap();
assert_eq!(reader.metadata().tensors.len(), 1);
assert_eq!(
reader.metadata().metadata.get("test"),
Some(&"value".to_string())
);
// Get tensor data
let tensor_data = reader.get_tensor_data("test_tensor").unwrap();
assert_eq!(tensor_data, &[1, 2, 3, 4]);
}
#[test]
fn test_reader_invalid_magic_number() {
let mut bytes = vec![0u8; 100];
// Write invalid magic number
bytes[magic_range()].copy_from_slice(b"NOPE");
let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes));
assert!(matches!(result, Err(BurnpackError::InvalidMagicNumber)));
}
#[test]
fn test_reader_invalid_version() {
let mut bytes = vec![0u8; 100];
// Write correct magic but invalid version
bytes[magic_range()].copy_from_slice(&MAGIC_NUMBER.to_le_bytes());
bytes[version_range()].copy_from_slice(&999u16.to_le_bytes()); // Invalid version
bytes[metadata_size_range()].copy_from_slice(&10u32.to_le_bytes()); // Metadata size
let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes));
assert!(matches!(result, Err(BurnpackError::InvalidVersion)));
}
#[test]
fn test_reader_header_too_short() {
let bytes = vec![0u8; 5]; // Less than HEADER_SIZE
let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes));
assert!(matches!(result, Err(BurnpackError::InvalidHeader)));
}
#[test]
fn test_reader_metadata_truncated() {
let mut bytes = vec![0u8; HEADER_SIZE + 10];
// Write valid header
bytes[magic_range()].copy_from_slice(&MAGIC_NUMBER.to_le_bytes());
bytes[version_range()].copy_from_slice(&FORMAT_VERSION.to_le_bytes());
bytes[metadata_size_range()].copy_from_slice(&100u32.to_le_bytes()); // Claims 100 bytes of metadata
// But only provide 10 bytes after header
let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes));
assert!(matches!(result, Err(BurnpackError::InvalidHeader)));
}
#[test]
fn test_reader_get_tensor_not_found() {
let writer = BurnpackWriter::new(Vec::new());
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes).unwrap();
let result = reader.get_tensor_data("non_existent");
assert!(matches!(result, Err(BurnpackError::TensorNotFound(_))));
}
#[test]
fn test_reader_get_tensor_snapshot() {
let data = [1.0f32, 2.0, 3.0, 4.0];
let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes, vec![2, 2], DType::F32),
vec!["weights".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
let writer_bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(writer_bytes).unwrap();
// Get tensor as snapshot
let loaded_snapshot = reader.get_tensor_snapshot("weights").unwrap();
// Verify snapshot metadata
assert_eq!(loaded_snapshot.full_path(), "weights");
assert_eq!(loaded_snapshot.dtype, DType::F32);
assert_eq!(loaded_snapshot.shape, vec![2, 2]);
// Verify data through closure
let tensor_data = loaded_snapshot.to_data().unwrap();
assert_eq!(tensor_data.shape, vec![2, 2]);
}
#[test]
fn test_reader_multiple_tensors() {
// Add multiple tensors
let mut snapshots = Vec::new();
for i in 0..10 {
let name = format!("tensor_{}", i);
let data = vec![i as u8; 100];
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(data, vec![100], DType::U8),
vec![name.clone()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(snapshot);
}
let writer = BurnpackWriter::new(snapshots);
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes).unwrap();
// Verify all tensors can be read
for i in 0..10 {
let name = format!("tensor_{}", i);
let data = reader.get_tensor_data(&name).unwrap();
assert_eq!(data.len(), 100);
assert!(data.iter().all(|&b| b == i as u8));
}
}
#[test]
fn test_reader_lazy_loading() {
// Create large tensor
let size = 1024 * 1024; // 1MB
let data = vec![42u8; size];
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(data.clone(), vec![size], DType::U8),
vec!["large".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes).unwrap();
// Get snapshot (should be lazy)
let snapshot = reader.get_tensor_snapshot("large").unwrap();
// Data should only be accessed when to_data is called
let tensor_data = snapshot.to_data().unwrap();
assert_eq!(tensor_data.bytes.len(), size);
assert!(tensor_data.bytes.iter().all(|&b| b == 42));
}
#[test]
fn test_reader_all_dtypes() {
// Test all data types
let test_data = [
(DType::F32, [1.0f32.to_le_bytes().to_vec()].concat()),
(DType::F64, [2.0f64.to_le_bytes().to_vec()].concat()),
(DType::I32, [3i32.to_le_bytes().to_vec()].concat()),
(DType::I64, [4i64.to_le_bytes().to_vec()].concat()),
(DType::U32, [5u32.to_le_bytes().to_vec()].concat()),
(DType::U64, [6u64.to_le_bytes().to_vec()].concat()),
(DType::U8, vec![7u8]),
(DType::Bool, vec![1u8]),
];
let mut snapshots = Vec::new();
for (i, (dtype, data)) in test_data.iter().enumerate() {
let name = format!("tensor_{}", i);
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(data.clone(), vec![1], *dtype),
vec![name.clone()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(snapshot);
}
let writer = BurnpackWriter::new(snapshots);
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes).unwrap();
// Verify all dtypes are preserved
for (i, (expected_dtype, expected_data)) in test_data.iter().enumerate() {
let name = format!("tensor_{}", i);
let snapshot = reader.get_tensor_snapshot(&name).unwrap();
assert_eq!(snapshot.dtype, *expected_dtype);
let data = reader.get_tensor_data(&name).unwrap();
assert_eq!(data, expected_data.as_slice());
}
}
#[test]
fn test_reader_empty_tensor() {
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![], vec![0], DType::F32),
vec!["empty".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes).unwrap();
let data = reader.get_tensor_data("empty").unwrap();
assert_eq!(data.len(), 0);
let snapshot = reader.get_tensor_snapshot("empty").unwrap();
assert_eq!(snapshot.shape, vec![0]);
}
#[cfg(feature = "std")]
#[test]
fn test_reader_from_file() {
use tempfile::tempdir;
let dir = tempdir().unwrap();
let file_path = dir.path().join("test.bpk");
// Create test file
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![10, 20, 30], vec![3], DType::U8),
vec!["file_tensor".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]).with_metadata("from_file_test", "true");
writer.write_to_file(&file_path).unwrap();
// Read from file
let reader = BurnpackReader::from_file(&file_path).unwrap();
assert_eq!(
reader.metadata().metadata.get("from_file_test"),
Some(&"true".to_string())
);
let data = reader.get_tensor_data("file_tensor").unwrap();
assert_eq!(data, &[10, 20, 30]);
}
#[cfg(all(feature = "std", feature = "memmap"))]
#[test]
fn test_reader_from_file_mmap() {
use tempfile::tempdir;
let dir = tempdir().unwrap();
let file_path = dir.path().join("test_mmap.bpk");
// Create large test file
let size = 1024 * 1024; // 1MB
let data = vec![99u8; size];
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(data, vec![size], DType::U8),
vec!["large_mmap".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
writer.write_to_file(&file_path).unwrap();
// Read using mmap
let reader = BurnpackReader::from_file_mmap(&file_path).unwrap();
let data = reader.get_tensor_data("large_mmap").unwrap();
assert_eq!(data.len(), size);
assert!(data.iter().all(|&b| b == 99));
}
#[cfg(feature = "std")]
#[test]
fn test_reader_from_file_buffered() {
use tempfile::tempdir;
let dir = tempdir().unwrap();
let file_path = dir.path().join("test_buffered.bpk");
// Create test file
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![5, 10, 15], vec![3], DType::U8),
vec!["buffered_tensor".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
writer.write_to_file(&file_path).unwrap();
// Read using buffered reader
let reader = BurnpackReader::from_file_buffered(&file_path).unwrap();
let data = reader.get_tensor_data("buffered_tensor").unwrap();
assert_eq!(data, &[5, 10, 15]);
}
#[test]
fn test_reader_metadata_access() {
// Add various metadata using builder pattern
let writer = BurnpackWriter::new(Vec::new())
.with_metadata("model_name", "test_model")
.with_metadata("version", "1.2.3")
.with_metadata("author", "test_author")
.with_metadata("description", "A test model");
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes).unwrap();
let metadata = reader.metadata();
assert_eq!(metadata.metadata.len(), 4);
assert_eq!(
metadata.metadata.get("model_name"),
Some(&"test_model".to_string())
);
assert_eq!(metadata.metadata.get("version"), Some(&"1.2.3".to_string()));
assert_eq!(
metadata.metadata.get("author"),
Some(&"test_author".to_string())
);
assert_eq!(
metadata.metadata.get("description"),
Some(&"A test model".to_string())
);
}
#[test]
fn test_reader_tensor_iteration() {
// Add tensors
let tensor_names = vec!["weights", "bias", "running_mean", "running_var"];
let mut snapshots = Vec::new();
for name in &tensor_names {
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8),
vec![name.to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(snapshot);
}
let writer = BurnpackWriter::new(snapshots);
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes).unwrap();
// Iterate through all tensors
let metadata = reader.metadata();
assert_eq!(metadata.tensors.len(), 4);
// Check that all expected tensor names are present
for name in &tensor_names {
let tensor_desc = metadata.tensors.get(*name).unwrap();
assert_eq!(tensor_desc.shape, vec![4u64]);
assert_eq!(tensor_desc.dtype, DType::U8);
}
// Verify the keys match the expected names
let mut actual_names: Vec<_> = metadata.tensors.keys().cloned().collect();
actual_names.sort();
let mut expected_names = tensor_names
.iter()
.map(|s| s.to_string())
.collect::<Vec<_>>();
expected_names.sort();
assert_eq!(actual_names, expected_names);
}
#[test]
fn test_reader_corrupt_metadata() {
let mut bytes = vec![0u8; 100];
// Write valid header
bytes[magic_range()].copy_from_slice(&MAGIC_NUMBER.to_le_bytes());
bytes[version_range()].copy_from_slice(&FORMAT_VERSION.to_le_bytes());
bytes[metadata_size_range()].copy_from_slice(&50u32.to_le_bytes()); // 50 bytes of metadata
// Write garbage as metadata
#[allow(clippy::needless_range_loop)]
for i in HEADER_SIZE..HEADER_SIZE + 50 {
bytes[i] = 0xFF;
}
let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes));
assert!(result.is_err());
}
#[test]
fn test_reader_data_offsets_validation() {
// Add two tensors
let snapshot1 = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8),
vec!["tensor1".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let snapshot2 = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![5, 6, 7, 8], vec![4], DType::U8),
vec!["tensor2".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot1, snapshot2]);
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes).unwrap();
// Verify offsets don't overlap and are properly aligned
let metadata = reader.metadata();
let tensor1_desc = metadata.tensors.get("tensor1").unwrap();
let tensor2_desc = metadata.tensors.get("tensor2").unwrap();
// First tensor starts at offset 0 (already aligned to 256 bytes)
assert_eq!(tensor1_desc.data_offsets, (0, 4));
// Second tensor starts at next 256-byte aligned offset
assert_eq!(tensor2_desc.data_offsets, (256, 260));
}
#[test]
fn test_reader_out_of_bounds_error() {
use crate::burnpack::reader::StorageBackend;
use alloc::rc::Rc;
// Create a small data buffer
let data = Bytes::from_bytes_vec(vec![1, 2, 3, 4, 5]);
let backend = StorageBackend::Memory(Rc::new(data));
// Try to read beyond the available data
let mut buffer = vec![0u8; 10];
let result = backend.read_into(&mut buffer, 0);
// Should return an error
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.to_string().contains("out of bounds"));
}
#[test]
fn test_reader_offset_overflow_error() {
use crate::burnpack::reader::StorageBackend;
use alloc::rc::Rc;
let data = Bytes::from_bytes_vec(vec![1, 2, 3, 4, 5]);
let backend = StorageBackend::Memory(Rc::new(data));
// Try to read with an offset that would overflow
let mut buffer = vec![0u8; 10];
let result = backend.read_into(&mut buffer, usize::MAX - 5);
// Should return an error about overflow
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.to_string().contains("overflow"));
}
#[test]
fn test_reader_corrupted_shape_returns_error() {
// Only test this on platforms where usize is smaller than u64
// On 64-bit platforms, u64 values can fit in usize
#[cfg(target_pointer_width = "32")]
{
use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor};
use alloc::collections::BTreeMap;
use alloc::rc::Rc;
use burn_tensor::DType;
// Create metadata with a shape dimension that exceeds usize::MAX on 32-bit platforms
let mut tensors = BTreeMap::new();
tensors.insert(
"corrupted_tensor".to_string(),
TensorDescriptor {
dtype: DType::F32,
shape: vec![u64::MAX, 2, 3], // First dimension exceeds usize::MAX on 32-bit
data_offsets: (0, 100),
param_id: None,
},
);
let metadata = BurnpackMetadata {
tensors,
metadata: BTreeMap::new(),
};
// Create a small data buffer
let data = Bytes::from_bytes_vec(vec![0u8; 1000]);
let backend = crate::burnpack::reader::StorageBackend::Memory(Rc::new(data));
let reader = BurnpackReader {
metadata,
storage: backend,
data_offset: 0,
};
// This should return an error, not panic
let result = reader.get_snapshots();
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, BurnpackError::ValidationError(_)));
assert!(
err.to_string().contains("corrupted shape data")
|| err.to_string().contains("exceeds platform maximum")
);
}
#[cfg(not(target_pointer_width = "32"))]
{
// On 64-bit platforms, just pass the test
// The conversion logic is still correct, but u64 fits in usize
}
}
#[test]
fn test_reader_corrupted_offsets_returns_error() {
// Only test this on platforms where usize is smaller than u64
#[cfg(target_pointer_width = "32")]
{
use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor};
use alloc::collections::BTreeMap;
use alloc::rc::Rc;
use burn_tensor::DType;
// Create metadata with offsets that would overflow
let mut tensors = BTreeMap::new();
tensors.insert(
"tensor_bad_offset".to_string(),
TensorDescriptor {
dtype: DType::F32,
shape: vec![2, 2],
data_offsets: (u64::MAX - 10, u64::MAX), // Offsets that exceed usize::MAX on 32-bit
param_id: None,
},
);
let metadata = BurnpackMetadata {
tensors,
metadata: BTreeMap::new(),
};
let data = Bytes::from_bytes_vec(vec![0u8; 1000]);
let backend = crate::burnpack::reader::StorageBackend::Memory(Rc::new(data));
let reader = BurnpackReader {
metadata,
storage: backend,
data_offset: 0,
};
// This should return an error, not panic
let result = reader.get_snapshots();
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, BurnpackError::ValidationError(_)));
assert!(
err.to_string().contains("corrupted offset data")
|| err.to_string().contains("exceeds platform maximum")
);
}
#[cfg(not(target_pointer_width = "32"))]
{
use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor};
use alloc::collections::BTreeMap;
use alloc::rc::Rc;
use burn_tensor::DType;
// On 64-bit platforms, test offset overflow during addition
let mut tensors = BTreeMap::new();
tensors.insert(
"tensor_overflow".to_string(),
TensorDescriptor {
dtype: DType::F32,
shape: vec![2, 2],
data_offsets: (0, 100),
param_id: None,
},
);
let metadata = BurnpackMetadata {
tensors,
metadata: BTreeMap::new(),
};
let data = Bytes::from_bytes_vec(vec![0u8; 1000]);
let backend = crate::burnpack::reader::StorageBackend::Memory(Rc::new(data));
// Use a data_offset that will overflow when added to the tensor offset
let reader = BurnpackReader {
metadata,
storage: backend,
data_offset: usize::MAX - 50, // Will overflow when added to 100
};
// This should return an error, not panic
let result = reader.get_snapshots();
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, BurnpackError::ValidationError(_)));
assert!(err.to_string().contains("overflow"));
}
}
#[test]
fn test_reader_inverted_offsets_returns_error() {
use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor};
use alloc::collections::BTreeMap;
use alloc::rc::Rc;
use burn_tensor::DType;
// Create metadata with end offset < start offset (corrupted)
let mut tensors = BTreeMap::new();
tensors.insert(
"inverted_tensor".to_string(),
TensorDescriptor {
dtype: DType::F32,
shape: vec![2, 2],
data_offsets: (100, 50), // End offset < start offset
param_id: None,
},
);
let metadata = BurnpackMetadata {
tensors,
metadata: BTreeMap::new(),
};
let data = Bytes::from_bytes_vec(vec![0u8; 1000]);
let backend = crate::burnpack::reader::StorageBackend::Memory(Rc::new(data));
let reader = BurnpackReader {
metadata,
storage: backend,
data_offset: 0,
};
// This should return an error, not panic
let result = reader.get_snapshots();
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, BurnpackError::ValidationError(_)));
assert!(err.to_string().contains("end offset") && err.to_string().contains("start offset"));
}
#[test]
fn test_reader_truncated_file_from_bytes() {
// Create a valid burnpack with tensor data
let tensor_size = 1024; // 1KB of data
let data = vec![42u8; tensor_size];
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(data, vec![tensor_size], DType::U8),
vec!["large_tensor".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
let full_bytes = writer.to_bytes().unwrap();
// Truncate the bytes - remove the last 512 bytes of tensor data
let truncated_len = full_bytes.len() - 512;
let truncated_bytes = Bytes::from_bytes_vec(full_bytes.to_vec()[..truncated_len].to_vec());
// This should fail with a validation error indicating file truncation
let result = BurnpackReader::from_bytes(truncated_bytes);
assert!(result.is_err());
if let Err(err) = result {
assert!(matches!(err, BurnpackError::ValidationError(_)));
assert!(err.to_string().contains("File truncated"));
assert!(err.to_string().contains("expected at least"));
}
}
#[cfg(feature = "std")]
#[test]
fn test_reader_truncated_file_from_file() {
use std::fs::OpenOptions;
use tempfile::tempdir;
let dir = tempdir().unwrap();
let file_path = dir.path().join("truncated.bpk");
// Create a valid burnpack file with tensor data
let tensor_size = 2048; // 2KB of data
let data = vec![99u8; tensor_size];
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(data, vec![tensor_size], DType::U8),
vec!["data_tensor".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
writer.write_to_file(&file_path).unwrap();
// Read the full file to get its size
let full_size = std::fs::metadata(&file_path).unwrap().len();
// Truncate the file - remove the last 1KB
let truncated_size = full_size - 1024;
let truncated_file = OpenOptions::new().write(true).open(&file_path).unwrap();
truncated_file.set_len(truncated_size).unwrap();
drop(truncated_file);
// Try to read the truncated file - should fail with validation error
let result = BurnpackReader::from_file(&file_path);
assert!(result.is_err());
if let Err(err) = result {
assert!(matches!(err, BurnpackError::ValidationError(_)));
assert!(err.to_string().contains("File truncated"));
assert!(err.to_string().contains("expected at least"));
}
}
#[test]
fn test_reader_file_size_exactly_correct() {
// Test that a file with exactly the right size passes validation
let tensor_size = 100;
let data = vec![77u8; tensor_size];
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(data, vec![tensor_size], DType::U8),
vec!["exact_size".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
let bytes = writer.to_bytes().unwrap();
// This should succeed - file is exactly the right size
let reader = BurnpackReader::from_bytes(bytes);
assert!(reader.is_ok());
// Verify we can read the data
let reader = reader.unwrap();
let tensor_data = reader.get_tensor_data("exact_size").unwrap();
assert_eq!(tensor_data.len(), tensor_size);
assert!(tensor_data.iter().all(|&b| b == 77));
}

View File

@@ -0,0 +1,606 @@
use crate::burnpack::{reader::BurnpackReader, writer::BurnpackWriter};
use super::*;
use alloc::collections::BTreeMap;
use alloc::string::String;
use burn_tensor::{DType, TensorData};
/// Helper function to perform round-trip test
fn round_trip_test<F>(setup: F)
where
F: FnOnce(&mut Vec<TensorSnapshot>, &mut BTreeMap<String, String>),
{
// Collect snapshots and metadata
let mut snapshots = Vec::new();
let mut metadata = BTreeMap::new();
setup(&mut snapshots, &mut metadata);
// Sort snapshots by name to ensure consistent ordering
// This is necessary because BTreeMap will store them sorted
snapshots.sort_by_key(|a| a.full_path());
// Create writer with snapshots and metadata
let mut writer = BurnpackWriter::new(snapshots);
for (key, value) in &metadata {
writer = writer.with_metadata(key, value);
}
let bytes = writer.to_bytes().unwrap();
let reader = BurnpackReader::from_bytes(bytes.clone()).unwrap();
// Write to bytes again from reader data
let mut snapshots2 = Vec::new();
// Copy tensors (metadata.tensors is now BTreeMap<String, TensorDescriptor>)
// They will come out in sorted order from tensor_names()
for tensor_name in reader.tensor_names() {
let snapshot = reader.get_tensor_snapshot(tensor_name).unwrap();
snapshots2.push(snapshot);
}
// Create writer2 with collected snapshots and metadata
let mut writer2 = BurnpackWriter::new(snapshots2);
for (key, value) in &reader.metadata().metadata {
writer2 = writer2.with_metadata(key, value);
}
let bytes2 = writer2.to_bytes().unwrap();
// Both byte representations should be identical
assert_eq!(bytes, bytes2, "Round-trip produced different bytes");
}
#[test]
fn test_round_trip_empty() {
round_trip_test(|_snapshots, _metadata| {
// Empty writer
});
}
#[test]
fn test_round_trip_metadata_only() {
round_trip_test(|_snapshots, metadata| {
metadata.insert("key1".to_string(), "value1".to_string());
metadata.insert("key2".to_string(), "value2".to_string());
metadata.insert("key3".to_string(), "value3".to_string());
});
}
#[test]
fn test_round_trip_f32() {
round_trip_test(|snapshots, _metadata| {
let data = [1.0f32, 2.0, 3.0, 4.0, 5.0];
let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes, vec![5], DType::F32),
vec!["f32_tensor".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(snapshot);
});
}
#[test]
fn test_round_trip_f64() {
round_trip_test(|snapshots, _metadata| {
let data = [1.0f64, 2.0, 3.0];
let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes, vec![3], DType::F64),
vec!["f64_tensor".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(snapshot);
});
}
#[test]
fn test_round_trip_i32() {
round_trip_test(|snapshots, _metadata| {
let data = [-10i32, 0, 10, 20];
let bytes: Vec<u8> = data.iter().flat_map(|i| i.to_le_bytes()).collect();
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes, vec![4], DType::I32),
vec!["i32_tensor".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(snapshot);
});
}
#[test]
fn test_round_trip_i64() {
round_trip_test(|snapshots, _metadata| {
let data = [i64::MIN, 0, i64::MAX];
let bytes: Vec<u8> = data.iter().flat_map(|i| i.to_le_bytes()).collect();
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes, vec![3], DType::I64),
vec!["i64_tensor".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(snapshot);
});
}
#[test]
fn test_round_trip_u32() {
round_trip_test(|snapshots, _metadata| {
let data = [0u32, 100, 1000, u32::MAX];
let bytes: Vec<u8> = data.iter().flat_map(|u| u.to_le_bytes()).collect();
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes, vec![4], DType::U32),
vec!["u32_tensor".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(snapshot);
});
}
#[test]
fn test_round_trip_u64() {
round_trip_test(|snapshots, _metadata| {
let data = [0u64, u64::MAX / 2, u64::MAX];
let bytes: Vec<u8> = data.iter().flat_map(|u| u.to_le_bytes()).collect();
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes, vec![3], DType::U64),
vec!["u64_tensor".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(snapshot);
});
}
#[test]
fn test_round_trip_u8() {
round_trip_test(|snapshots, _metadata| {
let data = vec![0u8, 127, 255];
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(data, vec![3], DType::U8),
vec!["u8_tensor".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(snapshot);
});
}
#[test]
fn test_round_trip_bool() {
round_trip_test(|snapshots, _metadata| {
let data = vec![0u8, 1, 0, 1, 1];
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(data, vec![5], DType::Bool),
vec!["bool_tensor".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(snapshot);
});
}
#[test]
fn test_round_trip_mixed_dtypes() {
round_trip_test(|snapshots, _metadata| {
// F32
let f32_data = [1.0f32, 2.0];
let f32_bytes: Vec<u8> = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect();
let f32_snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(f32_bytes, vec![2], DType::F32),
vec!["f32".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(f32_snapshot);
// I64
let i64_data = [100i64, 200];
let i64_bytes: Vec<u8> = i64_data.iter().flat_map(|i| i.to_le_bytes()).collect();
let i64_snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(i64_bytes, vec![2], DType::I64),
vec!["i64".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(i64_snapshot);
// Bool
let bool_snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![1, 0, 1], vec![3], DType::Bool),
vec!["bool".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(bool_snapshot);
});
}
#[test]
fn test_round_trip_multidimensional() {
round_trip_test(|snapshots, _metadata| {
// 2D tensor
let data_2d = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let bytes_2d: Vec<u8> = data_2d.iter().flat_map(|f| f.to_le_bytes()).collect();
let snapshot_2d = TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes_2d, vec![2, 3], DType::F32),
vec!["tensor_2d".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(snapshot_2d);
// 3D tensor
let data_3d = [1.0f32; 24];
let bytes_3d: Vec<u8> = data_3d.iter().flat_map(|f| f.to_le_bytes()).collect();
let snapshot_3d = TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes_3d, vec![2, 3, 4], DType::F32),
vec!["tensor_3d".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(snapshot_3d);
// 4D tensor (common for CNNs)
let data_4d = vec![1.0f32; 120];
let bytes_4d: Vec<u8> = data_4d.iter().flat_map(|f| f.to_le_bytes()).collect();
let snapshot_4d = TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes_4d, vec![2, 3, 4, 5], DType::F32),
vec!["tensor_4d".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(snapshot_4d);
});
}
#[test]
fn test_round_trip_with_metadata_and_tensors() {
round_trip_test(|snapshots, metadata| {
// Add metadata
metadata.insert("model_name".to_string(), "test_model".to_string());
metadata.insert("version".to_string(), "1.0.0".to_string());
metadata.insert(
"description".to_string(),
"A test model for round-trip testing".to_string(),
);
// Add tensors
let weights = [0.1f32, 0.2, 0.3, 0.4];
let weights_bytes: Vec<u8> = weights.iter().flat_map(|f| f.to_le_bytes()).collect();
let weights_snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(weights_bytes, vec![2, 2], DType::F32),
vec!["layer1".to_string(), "weights".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(weights_snapshot);
let bias = [0.5f32, 0.6];
let bias_bytes: Vec<u8> = bias.iter().flat_map(|f| f.to_le_bytes()).collect();
let bias_snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(bias_bytes, vec![2], DType::F32),
vec!["layer1".to_string(), "bias".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(bias_snapshot);
});
}
#[test]
fn test_round_trip_special_values() {
round_trip_test(|snapshots, _metadata| {
// Test special float values
let special_f32 = [
0.0f32,
-0.0,
f32::INFINITY,
f32::NEG_INFINITY,
f32::NAN,
f32::MIN,
f32::MAX,
f32::EPSILON,
];
let f32_bytes: Vec<u8> = special_f32.iter().flat_map(|f| f.to_le_bytes()).collect();
let f32_snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(f32_bytes, vec![8], DType::F32),
vec!["special_f32".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(f32_snapshot);
// Test special f64 values
let special_f64 = [
0.0f64,
-0.0,
f64::INFINITY,
f64::NEG_INFINITY,
f64::NAN,
f64::MIN,
f64::MAX,
f64::EPSILON,
];
let f64_bytes: Vec<u8> = special_f64.iter().flat_map(|f| f.to_le_bytes()).collect();
let f64_snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(f64_bytes, vec![8], DType::F64),
vec!["special_f64".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(f64_snapshot);
});
}
#[test]
fn test_round_trip_large_tensors() {
round_trip_test(|snapshots, _metadata| {
// Large tensor (100KB)
let size = 25600; // 100KB / 4 bytes per f32
let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes, vec![size], DType::F32),
vec!["large_tensor".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(snapshot);
});
}
#[cfg(feature = "std")]
#[test]
fn test_round_trip_file_io() {
use std::fs;
use tempfile::tempdir;
use crate::burnpack::writer::BurnpackWriter;
let dir = tempdir().unwrap();
let file_path = dir.path().join("round_trip.bpk");
// Create original data
let data = [1.0f32, 2.0, 3.0, 4.0];
let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes, vec![2, 2], DType::F32),
vec!["weights".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]).with_metadata("test", "round_trip");
// Write to file
writer.write_to_file(&file_path).unwrap();
// Read from file
let reader = BurnpackReader::from_file(&file_path).unwrap();
// Write to another file
let file_path2 = dir.path().join("round_trip2.bpk");
// Collect snapshots from reader
let mut snapshots2 = Vec::new();
for tensor_name in reader.tensor_names() {
let snapshot = reader.get_tensor_snapshot(tensor_name).unwrap();
snapshots2.push(snapshot);
}
// Create writer2 with snapshots and metadata
let mut writer2 = BurnpackWriter::new(snapshots2);
for (key, value) in &reader.metadata().metadata {
writer2 = writer2.with_metadata(key, value);
}
writer2.write_to_file(&file_path2).unwrap();
// Compare files
let bytes1 = fs::read(&file_path).unwrap();
let bytes2 = fs::read(&file_path2).unwrap();
assert_eq!(
bytes1, bytes2,
"Round-trip through files produced different content"
);
}
#[test]
fn test_round_trip_empty_shapes() {
round_trip_test(|snapshots, _metadata| {
// Scalar (0-dimensional)
let scalar = [42.0f32];
let scalar_bytes: Vec<u8> = scalar.iter().flat_map(|f| f.to_le_bytes()).collect();
let scalar_snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(scalar_bytes, vec![], DType::F32),
vec!["scalar".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(scalar_snapshot);
// Empty tensor
let empty_snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![], vec![0], DType::F32),
vec!["empty".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(empty_snapshot);
});
}
#[test]
fn test_param_id_persistence() {
use burn_core::module::ParamId;
// Create a specific ParamId with a known value
let original_param_id = ParamId::from(123456789u64);
let data = [1.0f32, 2.0, 3.0, 4.0];
let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes, vec![2, 2], DType::F32),
vec!["weights".to_string()],
vec![],
original_param_id,
);
// Write to burnpack
let writer = BurnpackWriter::new(vec![snapshot]);
let bytes = writer.to_bytes().unwrap();
// Read back from burnpack
let reader = BurnpackReader::from_bytes(bytes).unwrap();
let loaded_snapshot = reader.get_tensor_snapshot("weights").unwrap();
// Verify ParamId was preserved
assert!(
loaded_snapshot.tensor_id.is_some(),
"ParamId should be present"
);
let loaded_param_id = loaded_snapshot.tensor_id.unwrap();
assert_eq!(
loaded_param_id.val(),
original_param_id.val(),
"ParamId value should be preserved: expected {}, got {}",
original_param_id.val(),
loaded_param_id.val()
);
}
#[test]
fn test_param_id_backward_compatibility() {
use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor};
use alloc::collections::BTreeMap;
// Create metadata without param_id (simulating old burnpack format)
let mut tensors = BTreeMap::new();
tensors.insert(
"old_tensor".to_string(),
TensorDescriptor {
dtype: DType::F32,
shape: vec![2, 2],
data_offsets: (0, 16),
param_id: None, // No param_id stored (old format)
},
);
let metadata = BurnpackMetadata {
tensors,
metadata: BTreeMap::new(),
};
// Serialize metadata
let mut metadata_bytes = Vec::new();
ciborium::ser::into_writer(&metadata, &mut metadata_bytes).unwrap();
// Create a complete burnpack with header and data
use crate::burnpack::base::{BurnpackHeader, FORMAT_VERSION, MAGIC_NUMBER};
let metadata_size = metadata_bytes.len() as u32;
let header = BurnpackHeader {
magic: MAGIC_NUMBER,
version: FORMAT_VERSION,
metadata_size,
};
let mut full_bytes = Vec::new();
full_bytes.extend_from_slice(&header.into_bytes());
full_bytes.extend_from_slice(&metadata_bytes);
// Add tensor data (4 f32 values = 16 bytes)
let tensor_data = vec![1.0f32, 2.0, 3.0, 4.0];
for value in tensor_data {
full_bytes.extend_from_slice(&value.to_le_bytes());
}
// Read the old format burnpack
let reader =
BurnpackReader::from_bytes(burn_tensor::Bytes::from_bytes_vec(full_bytes)).unwrap();
let loaded_snapshot = reader.get_tensor_snapshot("old_tensor").unwrap();
// Verify that a new ParamId was generated (backward compatibility)
assert!(
loaded_snapshot.tensor_id.is_some(),
"ParamId should be generated for old format"
);
// The generated ParamId should be different each time (it's new), but we can't test the exact value
// We just verify it exists and has a valid u64 value
let generated_param_id = loaded_snapshot.tensor_id.unwrap();
assert!(
generated_param_id.val() > 0,
"Generated ParamId should have a valid value"
);
}
#[test]
fn test_multiple_tensors_preserve_distinct_param_ids() {
use burn_core::module::ParamId;
// Create multiple tensors with distinct ParamIds
let param_id_1 = ParamId::from(111111u64);
let param_id_2 = ParamId::from(222222u64);
let param_id_3 = ParamId::from(333333u64);
let mut snapshots = Vec::new();
let data1 = [1.0f32, 2.0];
let bytes1: Vec<u8> = data1.iter().flat_map(|f| f.to_le_bytes()).collect();
snapshots.push(TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes1, vec![2], DType::F32),
vec!["tensor1".to_string()],
vec![],
param_id_1,
));
let data2 = [3.0f32, 4.0];
let bytes2: Vec<u8> = data2.iter().flat_map(|f| f.to_le_bytes()).collect();
snapshots.push(TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes2, vec![2], DType::F32),
vec!["tensor2".to_string()],
vec![],
param_id_2,
));
let data3 = [5.0f32, 6.0];
let bytes3: Vec<u8> = data3.iter().flat_map(|f| f.to_le_bytes()).collect();
snapshots.push(TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes3, vec![2], DType::F32),
vec!["tensor3".to_string()],
vec![],
param_id_3,
));
// Write to burnpack
let writer = BurnpackWriter::new(snapshots);
let bytes = writer.to_bytes().unwrap();
// Read back
let reader = BurnpackReader::from_bytes(bytes).unwrap();
let snapshot1 = reader.get_tensor_snapshot("tensor1").unwrap();
let snapshot2 = reader.get_tensor_snapshot("tensor2").unwrap();
let snapshot3 = reader.get_tensor_snapshot("tensor3").unwrap();
// Verify each ParamId was preserved correctly
assert_eq!(snapshot1.tensor_id.unwrap().val(), param_id_1.val());
assert_eq!(snapshot2.tensor_id.unwrap().val(), param_id_2.val());
assert_eq!(snapshot3.tensor_id.unwrap().val(), param_id_3.val());
// Verify they are distinct
let id1 = snapshot1.tensor_id.unwrap().val();
let id2 = snapshot2.tensor_id.unwrap().val();
let id3 = snapshot3.tensor_id.unwrap().val();
assert_ne!(id1, id2, "ParamIds should be distinct");
assert_ne!(id2, id3, "ParamIds should be distinct");
assert_ne!(id1, id3, "ParamIds should be distinct");
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,744 @@
use crate::burnpack::{
base::{
BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER,
aligned_data_section_start, magic_range,
},
writer::BurnpackWriter,
};
use super::*;
use burn_core::module::ParamId;
use burn_tensor::{DType, TensorData};
use std::rc::Rc;
#[test]
fn test_writer_new() {
let writer = BurnpackWriter::new(vec![]);
assert_eq!(writer.snapshots.len(), 0);
assert!(writer.metadata.is_empty());
}
#[test]
fn test_writer_add_metadata() {
let writer = BurnpackWriter::new(vec![])
.with_metadata("model_name", "test_model")
.with_metadata("version", "1.0.0")
.with_metadata("author", "test_author");
assert_eq!(writer.metadata.len(), 3);
assert_eq!(
writer.metadata.get("model_name"),
Some(&"test_model".to_string())
);
assert_eq!(writer.metadata.get("version"), Some(&"1.0.0".to_string()));
assert_eq!(
writer.metadata.get("author"),
Some(&"test_author".to_string())
);
}
#[test]
fn test_writer_add_tensor_snapshot() {
// Create test tensor snapshots
let snapshot1 = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),
vec!["layer1".to_string(), "weights".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let snapshot2 = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![5, 6, 7, 8], vec![4], DType::U8),
vec!["layer1".to_string(), "bias".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot1, snapshot2]);
assert_eq!(writer.snapshots.len(), 2);
assert_eq!(writer.snapshots[0].full_path(), "layer1.weights");
assert_eq!(writer.snapshots[1].full_path(), "layer1.bias");
}
#[test]
fn test_writer_to_bytes_empty() {
let writer = BurnpackWriter::new(vec![]);
let bytes = writer.to_bytes().unwrap();
// Verify header
assert!(bytes.len() >= HEADER_SIZE);
assert_eq!(&bytes[magic_range()], &MAGIC_NUMBER.to_le_bytes());
// Parse header
let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();
assert_eq!(header.magic, MAGIC_NUMBER);
assert_eq!(header.version, FORMAT_VERSION);
// Verify metadata
let metadata_end = HEADER_SIZE + header.metadata_size as usize;
let metadata_bytes = &bytes[HEADER_SIZE..metadata_end];
let metadata: BurnpackMetadata = ciborium::de::from_reader(metadata_bytes).unwrap();
assert_eq!(metadata.tensors.len(), 0);
assert!(metadata.metadata.is_empty());
}
#[test]
fn test_writer_to_bytes_with_tensors() {
// Add tensors with different data types
let f32_data = [1.0f32, 2.0, 3.0, 4.0];
let f32_bytes: Vec<u8> = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect();
let snapshot_f32 = TensorSnapshot::from_data(
TensorData::from_bytes_vec(f32_bytes.clone(), vec![2, 2], DType::F32),
vec!["weights".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let i64_data = [10i64, 20, 30];
let i64_bytes: Vec<u8> = i64_data.iter().flat_map(|i| i.to_le_bytes()).collect();
let snapshot_i64 = TensorSnapshot::from_data(
TensorData::from_bytes_vec(i64_bytes.clone(), vec![3], DType::I64),
vec!["bias".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot_f32, snapshot_i64])
.with_metadata("test_key", "test_value");
let bytes = writer.to_bytes().unwrap();
// Parse and verify
let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();
let metadata_end = HEADER_SIZE + header.metadata_size as usize;
let metadata: BurnpackMetadata =
ciborium::de::from_reader(&bytes[HEADER_SIZE..metadata_end]).unwrap();
// Verify metadata
assert_eq!(
metadata.metadata.get("test_key"),
Some(&"test_value".to_string())
);
// Verify tensors
assert_eq!(metadata.tensors.len(), 2);
let weights = metadata.tensors.get("weights").unwrap();
assert_eq!(weights.dtype, DType::F32);
assert_eq!(weights.shape, vec![2, 2]);
assert_eq!(weights.data_offsets.1 - weights.data_offsets.0, 16); // 4 * 4 bytes
let bias = metadata.tensors.get("bias").unwrap();
assert_eq!(bias.dtype, DType::I64);
assert_eq!(bias.shape, vec![3]);
assert_eq!(bias.data_offsets.1 - bias.data_offsets.0, 24); // 3 * 8 bytes
// Verify actual tensor data
// Data section starts at aligned position after metadata
let data_section_start = aligned_data_section_start(header.metadata_size as usize);
let weights = metadata.tensors.get("weights").unwrap();
let bias = metadata.tensors.get("bias").unwrap();
let weights_data = &bytes[data_section_start + weights.data_offsets.0 as usize
..data_section_start + weights.data_offsets.1 as usize];
assert_eq!(weights_data, f32_bytes);
let bias_data = &bytes[data_section_start + bias.data_offsets.0 as usize
..data_section_start + bias.data_offsets.1 as usize];
assert_eq!(bias_data, i64_bytes);
}
#[test]
fn test_writer_all_dtypes() {
use half::{bf16, f16};
// Test all supported data types (excluding QFloat which is tested separately)
// Format: (DType, expected_size_per_element, sample_data_bytes)
let test_cases = vec![
// Floating point types
(DType::F64, 8, 1.0f64.to_le_bytes().to_vec()),
(DType::F32, 4, 1.0f32.to_le_bytes().to_vec()),
(DType::F16, 2, f16::from_f32(1.0).to_le_bytes().to_vec()),
(DType::BF16, 2, bf16::from_f32(1.0).to_le_bytes().to_vec()),
// Signed integers
(DType::I64, 8, 1i64.to_le_bytes().to_vec()),
(DType::I32, 4, 1i32.to_le_bytes().to_vec()),
(DType::I16, 2, 1i16.to_le_bytes().to_vec()),
(DType::I8, 1, 1i8.to_le_bytes().to_vec()),
// Unsigned integers
(DType::U64, 8, 255u64.to_le_bytes().to_vec()),
(DType::U32, 4, 255u32.to_le_bytes().to_vec()),
(DType::U16, 2, 255u16.to_le_bytes().to_vec()),
(DType::U8, 1, vec![255u8]),
// Boolean
(DType::Bool, 1, vec![1u8]),
];
let mut snapshots = vec![];
let mut expected_data = vec![];
for (i, (dtype, expected_size, data)) in test_cases.into_iter().enumerate() {
let name = format!("tensor_{}", i);
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(data.clone(), vec![1], dtype),
vec![name.clone()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(snapshot);
expected_data.push((name, dtype, expected_size, data));
}
let writer = BurnpackWriter::new(snapshots);
let bytes = writer.to_bytes().unwrap();
// Parse and verify metadata
let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();
let metadata: BurnpackMetadata =
ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize])
.unwrap();
assert_eq!(
metadata.tensors.len(),
13,
"Expected 13 dtypes to be tested"
);
// Verify each tensor's metadata and data
let data_section_start = aligned_data_section_start(header.metadata_size as usize);
for (name, expected_dtype, expected_size, expected_bytes) in expected_data {
let tensor = metadata
.tensors
.get(&name)
.unwrap_or_else(|| panic!("Missing tensor: {}", name));
assert_eq!(tensor.dtype, expected_dtype, "DType mismatch for {}", name);
assert_eq!(tensor.shape, vec![1], "Shape mismatch for {}", name);
// Verify data size matches expected
let data_size = (tensor.data_offsets.1 - tensor.data_offsets.0) as usize;
assert_eq!(
data_size, expected_size,
"Data size mismatch for {} ({:?})",
name, expected_dtype
);
// Verify actual data bytes match
let actual_bytes = &bytes[data_section_start + tensor.data_offsets.0 as usize
..data_section_start + tensor.data_offsets.1 as usize];
assert_eq!(
actual_bytes,
expected_bytes.as_slice(),
"Data mismatch for {} ({:?})",
name,
expected_dtype
);
}
}
#[test]
fn test_writer_all_dtypes_round_trip() {
use crate::burnpack::reader::BurnpackReader;
use half::{bf16, f16};
// Test all dtypes can be written and read back correctly
let test_cases = vec![
// Floating point types - use multiple elements to better test
(
"f64_tensor",
DType::F64,
[1.0f64, 2.0, 3.0, 4.0]
.iter()
.flat_map(|v| v.to_le_bytes())
.collect::<Vec<u8>>(),
vec![4],
),
(
"f32_tensor",
DType::F32,
[1.0f32, 2.0, 3.0, 4.0]
.iter()
.flat_map(|v| v.to_le_bytes())
.collect::<Vec<u8>>(),
vec![2, 2],
),
(
"f16_tensor",
DType::F16,
[f16::from_f32(1.0), f16::from_f32(2.0)]
.iter()
.flat_map(|v| v.to_le_bytes())
.collect::<Vec<u8>>(),
vec![2],
),
(
"bf16_tensor",
DType::BF16,
[bf16::from_f32(1.0), bf16::from_f32(2.0)]
.iter()
.flat_map(|v| v.to_le_bytes())
.collect::<Vec<u8>>(),
vec![2],
),
// Signed integers
(
"i64_tensor",
DType::I64,
[1i64, -2, 3, -4]
.iter()
.flat_map(|v| v.to_le_bytes())
.collect::<Vec<u8>>(),
vec![4],
),
(
"i32_tensor",
DType::I32,
[1i32, -2, 3, -4]
.iter()
.flat_map(|v| v.to_le_bytes())
.collect::<Vec<u8>>(),
vec![2, 2],
),
(
"i16_tensor",
DType::I16,
[1i16, -2, 3, -4]
.iter()
.flat_map(|v| v.to_le_bytes())
.collect::<Vec<u8>>(),
vec![4],
),
(
"i8_tensor",
DType::I8,
[1i8, -2, 3, -4]
.iter()
.flat_map(|v| v.to_le_bytes())
.collect::<Vec<u8>>(),
vec![2, 2],
),
// Unsigned integers
(
"u64_tensor",
DType::U64,
[1u64, 2, 3, 4]
.iter()
.flat_map(|v| v.to_le_bytes())
.collect::<Vec<u8>>(),
vec![4],
),
(
"u32_tensor",
DType::U32,
[1u32, 2, 3, 4]
.iter()
.flat_map(|v| v.to_le_bytes())
.collect::<Vec<u8>>(),
vec![2, 2],
),
(
"u16_tensor",
DType::U16,
[1u16, 2, 3, 4]
.iter()
.flat_map(|v| v.to_le_bytes())
.collect::<Vec<u8>>(),
vec![4],
),
("u8_tensor", DType::U8, vec![1u8, 2, 3, 4], vec![2, 2]),
// Boolean
("bool_tensor", DType::Bool, vec![1u8, 0, 1, 0], vec![4]),
];
let mut snapshots = vec![];
let mut expected_results: Vec<(&str, DType, Vec<u8>, Vec<usize>)> = vec![];
for (name, dtype, data, shape) in test_cases.into_iter() {
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(data.clone(), shape.clone(), dtype),
vec![name.to_string()],
vec![],
burn_core::module::ParamId::new(),
);
snapshots.push(snapshot);
expected_results.push((name, dtype, data, shape));
}
// Write to bytes
let writer = BurnpackWriter::new(snapshots);
let bytes = writer.to_bytes().unwrap();
// Read back using BurnpackReader
let reader = BurnpackReader::from_bytes(bytes).unwrap();
// Verify each tensor can be read back with correct data
for (name, expected_dtype, expected_data, expected_shape) in expected_results {
let snapshot = reader
.get_tensor_snapshot(name)
.unwrap_or_else(|e| panic!("Failed to get tensor snapshot {}: {}", name, e));
let tensor_data = snapshot
.to_data()
.unwrap_or_else(|e| panic!("Failed to read tensor data {}: {}", name, e));
assert_eq!(
tensor_data.dtype, expected_dtype,
"DType mismatch for {}",
name
);
assert_eq!(
tensor_data.shape, expected_shape,
"Shape mismatch for {}",
name
);
assert_eq!(
&tensor_data.bytes[..],
expected_data.as_slice(),
"Data mismatch for {}",
name
);
}
}
#[test]
fn test_writer_large_tensor() {
// Create a large tensor (1MB)
let size = 256 * 1024; // 256K floats = 1MB
let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(bytes.clone(), vec![size], DType::F32),
vec!["large_tensor".to_string()],
vec![],
burn_core::module::ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
let result = writer.to_bytes().unwrap();
// Verify the large tensor is correctly stored
let header = BurnpackHeader::from_bytes(&result[..HEADER_SIZE]).unwrap();
let metadata: BurnpackMetadata = ciborium::de::from_reader(
&result[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize],
)
.unwrap();
assert_eq!(metadata.tensors.len(), 1);
let tensor = metadata.tensors.get("large_tensor").unwrap();
assert_eq!(tensor.shape, vec![size as u64]);
assert_eq!(
tensor.data_offsets.1 - tensor.data_offsets.0,
(size * 4) as u64
);
}
#[test]
fn test_writer_empty_tensors() {
// Add tensor with empty data
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![], vec![0], DType::F32),
vec!["empty".to_string()],
vec![],
ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
let bytes = writer.to_bytes().unwrap();
let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();
let metadata: BurnpackMetadata =
ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize])
.unwrap();
assert_eq!(metadata.tensors.len(), 1);
let tensor = metadata.tensors.get("empty").unwrap();
assert_eq!(tensor.shape, vec![0]);
assert_eq!(tensor.data_offsets.1 - tensor.data_offsets.0, 0);
}
#[test]
fn test_writer_special_characters_in_names() {
// Test various special characters in tensor names
let special_names = vec![
"layer.0.weight",
"model/encoder/layer1",
"model::layer::weight",
"layer[0].bias",
"layer_1_weight",
"layer-1-bias",
"layer@1#weight",
"emoji_😀_tensor",
"unicode_测试_tensor",
"spaces in name",
];
let mut snapshots = vec![];
for name in &special_names {
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8),
vec![name.to_string()],
vec![],
ParamId::new(),
);
snapshots.push(snapshot);
}
let writer = BurnpackWriter::new(snapshots);
let bytes = writer.to_bytes().unwrap();
let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();
let metadata: BurnpackMetadata =
ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize])
.unwrap();
assert_eq!(metadata.tensors.len(), 10);
for (tensor_name, _tensor) in metadata.tensors.iter() {
assert!(!tensor_name.is_empty());
// Names should be preserved exactly
assert!(
tensor_name.contains("layer")
|| tensor_name.contains("model")
|| tensor_name.contains("emoji")
|| tensor_name.contains("unicode")
|| tensor_name.contains("spaces")
);
}
}
#[test]
fn test_writer_metadata_overwrite() {
let writer = BurnpackWriter::new(vec![])
.with_metadata("key", "value1")
.with_metadata("key", "value2");
assert_eq!(writer.metadata.get("key"), Some(&"value2".to_string()));
assert_eq!(writer.metadata.len(), 1);
}
#[test]
fn test_writer_tensor_order_preserved() {
// Add tensors in specific order
let names = vec!["z_tensor", "a_tensor", "m_tensor", "b_tensor"];
let mut snapshots = vec![];
for name in &names {
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![1], vec![1], DType::U8),
vec![name.to_string()],
vec![],
ParamId::new(),
);
snapshots.push(snapshot);
}
let writer = BurnpackWriter::new(snapshots);
let bytes = writer.to_bytes().unwrap();
let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();
let metadata: BurnpackMetadata =
ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize])
.unwrap();
// Verify all tensors are present (BTreeMap stores in sorted order by key)
let expected_sorted = vec!["a_tensor", "b_tensor", "m_tensor", "z_tensor"];
let actual_names: Vec<_> = metadata.tensors.keys().collect();
assert_eq!(actual_names, expected_sorted);
}
#[test]
fn test_writer_lazy_snapshot_evaluation() {
// Create a lazy snapshot using closure
let data = Rc::new(vec![1.0f32, 2.0, 3.0, 4.0]);
let data_clone = data.clone();
let snapshot = TensorSnapshot::from_closure(
Rc::new(move || {
let bytes: Vec<u8> = data_clone.iter().flat_map(|f| f.to_le_bytes()).collect();
Ok(TensorData::from_bytes_vec(bytes, vec![2, 2], DType::F32))
}),
DType::F32,
vec![2, 2],
vec!["lazy".to_string()],
vec![],
ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
// The closure should only be evaluated when to_bytes is called
let bytes = writer.to_bytes().unwrap();
let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();
let metadata_end = HEADER_SIZE + header.metadata_size as usize;
let metadata: BurnpackMetadata =
ciborium::de::from_reader(&bytes[HEADER_SIZE..metadata_end]).unwrap();
assert_eq!(metadata.tensors.len(), 1);
let tensor = metadata.tensors.get("lazy").unwrap();
assert_eq!(tensor.dtype, DType::F32);
assert_eq!(tensor.shape, vec![2, 2]);
// Verify the data was correctly written
// Data section starts at aligned position after metadata
let data_section_start = aligned_data_section_start(header.metadata_size as usize);
let tensor_data = &bytes[data_section_start..data_section_start + 16];
let expected: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0]
.iter()
.flat_map(|f| f.to_le_bytes())
.collect();
assert_eq!(tensor_data, expected.as_slice());
}
#[cfg(feature = "std")]
#[test]
fn test_writer_write_to_file() {
use std::fs;
use tempfile::tempdir;
let dir = tempdir().unwrap();
let file_path = dir.path().join("test.bpk");
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),
vec!["test".to_string()],
vec![],
ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]).with_metadata("file_test", "true");
writer.write_to_file(&file_path).unwrap();
// Verify file exists and has correct content
assert!(file_path.exists());
let file_bytes = fs::read(&file_path).unwrap();
let memory_bytes = writer.to_bytes().unwrap();
assert_eq!(file_bytes.as_slice(), &*memory_bytes);
}
#[test]
fn test_writer_size() {
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),
vec!["test".to_string()],
vec![],
ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]).with_metadata("test", "value");
let size = writer.size().unwrap();
let bytes = writer.to_bytes().unwrap();
// Size should match actual bytes length
assert_eq!(size, bytes.len());
}
#[test]
fn test_writer_write_into() {
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),
vec!["test".to_string()],
vec![],
ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]).with_metadata("test", "value");
// Get size and allocate buffer
let size = writer.size().unwrap();
let mut buffer = vec![0u8; size];
// Write into buffer
writer.write_into(&mut buffer).unwrap();
// Compare with to_bytes()
let bytes = writer.to_bytes().unwrap();
assert_eq!(buffer.as_slice(), &*bytes);
}
#[test]
fn test_writer_write_into_buffer_too_small() {
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),
vec!["test".to_string()],
vec![],
ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
// Allocate a buffer that's too small
let mut buffer = vec![0u8; 10];
// Should fail with buffer too small error
let result = writer.write_into(&mut buffer);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Buffer too small"));
}
#[test]
fn test_writer_write_into_buffer_larger_than_needed() {
let snapshot = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),
vec!["test".to_string()],
vec![],
ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot]);
// Allocate a larger buffer
let size = writer.size().unwrap();
let mut buffer = vec![0u8; size + 100]; // Extra 100 bytes
// Should succeed and only write the necessary bytes
writer.write_into(&mut buffer).unwrap();
// Compare the written portion with to_bytes()
let bytes = writer.to_bytes().unwrap();
assert_eq!(&buffer[..size], &*bytes);
}
#[test]
fn test_writer_write_into_multiple_tensors() {
let snapshot1 = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),
vec!["tensor1".to_string()],
vec![],
ParamId::new(),
);
let snapshot2 = TensorSnapshot::from_data(
TensorData::from_bytes_vec(vec![5, 6, 7, 8, 9, 10], vec![2, 3], DType::U8),
vec!["tensor2".to_string()],
vec![],
ParamId::new(),
);
let writer = BurnpackWriter::new(vec![snapshot1, snapshot2]).with_metadata("test", "multiple");
let size = writer.size().unwrap();
let mut buffer = vec![0u8; size];
writer.write_into(&mut buffer).unwrap();
let bytes = writer.to_bytes().unwrap();
assert_eq!(buffer.as_slice(), &*bytes);
}
#[test]
fn test_writer_write_into_empty() {
let writer = BurnpackWriter::new(vec![]);
let size = writer.size().unwrap();
let mut buffer = vec![0u8; size];
writer.write_into(&mut buffer).unwrap();
let bytes = writer.to_bytes().unwrap();
assert_eq!(buffer.as_slice(), &*bytes);
}

View File

@@ -0,0 +1,211 @@
//! Tests for zero-copy tensor loading functionality.
use crate::ModuleStore;
use crate::burnpack::store::BurnpackStore;
use burn_core as burn;
use burn_core::module::{Module, Param};
use burn_tensor::{AllocationProperty, Bytes, Tensor, backend::Backend};
type TestBackend = burn_ndarray::NdArray;
#[derive(Module, Debug)]
struct SimpleModule<B: Backend> {
weight: Param<Tensor<B, 2>>,
bias: Param<Tensor<B, 1>>,
}
impl<B: Backend> SimpleModule<B> {
fn new(device: &B::Device) -> Self {
Self {
weight: Param::from_data([[1.0f32, 2.0], [3.0, 4.0]], device),
bias: Param::from_data([0.5f32, 1.5], device),
}
}
fn new_zeros(device: &B::Device) -> Self {
Self {
weight: Param::from_tensor(Tensor::zeros([2, 2], device)),
bias: Param::from_tensor(Tensor::zeros([2], device)),
}
}
}
/// Test that from_static creates a store with zero_copy enabled by default.
#[test]
fn test_from_static_enables_zero_copy() {
let device = Default::default();
let module = SimpleModule::<TestBackend>::new(&device);
// Save to bytes first
let mut save_store = BurnpackStore::from_bytes(None);
save_store.collect_from(&module).unwrap();
let bytes = save_store.get_bytes().unwrap();
// Convert to Vec<u8> and then leak to get &'static [u8]
let bytes_vec: Vec<u8> = bytes.to_vec();
let static_bytes: &'static [u8] = Box::leak(bytes_vec.into_boxed_slice());
// Create store from static - zero_copy should be enabled
let mut load_store = BurnpackStore::from_static(static_bytes);
// Load into a new module
let mut loaded_module = SimpleModule::<TestBackend>::new_zeros(&device);
load_store.apply_to(&mut loaded_module).unwrap();
// Verify data is correct
let loaded_weight = loaded_module.weight.val().to_data();
let loaded_bias = loaded_module.bias.val().to_data();
assert_eq!(
loaded_weight.to_vec::<f32>().unwrap(),
vec![1.0, 2.0, 3.0, 4.0]
);
assert_eq!(loaded_bias.to_vec::<f32>().unwrap(), vec![0.5, 1.5]);
}
/// Test that zero_copy builder method works.
#[test]
fn test_zero_copy_builder_method() {
let device = Default::default();
let module = SimpleModule::<TestBackend>::new(&device);
// Save to bytes first
let mut save_store = BurnpackStore::from_bytes(None);
save_store.collect_from(&module).unwrap();
let bytes = save_store.get_bytes().unwrap();
// Create shared bytes for zero-copy
let shared = bytes::Bytes::from(bytes.to_vec());
let cubecl_bytes = Bytes::from_shared(shared, AllocationProperty::Other);
// Create store with zero_copy enabled
let mut load_store = BurnpackStore::from_bytes(Some(cubecl_bytes)).zero_copy(true);
// Load into a new module
let mut loaded_module = SimpleModule::<TestBackend>::new_zeros(&device);
load_store.apply_to(&mut loaded_module).unwrap();
// Verify data is correct
let loaded_weight = loaded_module.weight.val().to_data();
assert_eq!(
loaded_weight.to_vec::<f32>().unwrap(),
vec![1.0, 2.0, 3.0, 4.0]
);
}
/// Test that zero_copy(false) uses copying even with shared bytes.
#[test]
fn test_zero_copy_disabled_uses_copy() {
let device = Default::default();
let module = SimpleModule::<TestBackend>::new(&device);
// Save to bytes first
let mut save_store = BurnpackStore::from_bytes(None);
save_store.collect_from(&module).unwrap();
let bytes = save_store.get_bytes().unwrap();
// Convert to Vec<u8> and then leak to get &'static [u8]
let bytes_vec: Vec<u8> = bytes.to_vec();
let static_bytes: &'static [u8] = Box::leak(bytes_vec.into_boxed_slice());
// Create store from static but disable zero_copy
let mut load_store = BurnpackStore::from_static(static_bytes).zero_copy(false);
// Load into a new module
let mut loaded_module = SimpleModule::<TestBackend>::new_zeros(&device);
load_store.apply_to(&mut loaded_module).unwrap();
// Verify data is correct (copied, not zero-copy)
let loaded_weight = loaded_module.weight.val().to_data();
assert_eq!(
loaded_weight.to_vec::<f32>().unwrap(),
vec![1.0, 2.0, 3.0, 4.0]
);
}
/// Test that from_bytes with regular Bytes uses copying by default.
#[test]
fn test_from_bytes_uses_copy_by_default() {
let device = Default::default();
let module = SimpleModule::<TestBackend>::new(&device);
// Save to bytes
let mut save_store = BurnpackStore::from_bytes(None);
save_store.collect_from(&module).unwrap();
let bytes = save_store.get_bytes().unwrap();
// Load from bytes (default: zero_copy = false)
let mut load_store = BurnpackStore::from_bytes(Some(bytes));
let mut loaded_module = SimpleModule::<TestBackend>::new_zeros(&device);
load_store.apply_to(&mut loaded_module).unwrap();
// Verify data is correct
let loaded_weight = loaded_module.weight.val().to_data();
assert_eq!(
loaded_weight.to_vec::<f32>().unwrap(),
vec![1.0, 2.0, 3.0, 4.0]
);
}
/// Test that slice_bytes works correctly on StorageBackend.
#[test]
fn test_storage_backend_slice_bytes() {
use crate::burnpack::reader::BurnpackReader;
let device = Default::default();
let module = SimpleModule::<TestBackend>::new(&device);
// Save to bytes first
let mut save_store = BurnpackStore::from_bytes(None);
save_store.collect_from(&module).unwrap();
let bytes = save_store.get_bytes().unwrap();
// Create shared bytes
let shared = bytes::Bytes::from(bytes.to_vec());
let cubecl_bytes = Bytes::from_shared(shared, AllocationProperty::Other);
// Create reader and get snapshots with zero-copy
let reader = BurnpackReader::from_bytes(cubecl_bytes).unwrap();
let snapshots = reader.get_snapshots_zero_copy(true).unwrap();
// Verify we got the expected number of tensors
assert_eq!(snapshots.len(), 2);
// Load the tensor data
for snapshot in &snapshots {
let data = snapshot.to_data().unwrap();
// Just verify we can access the data - the actual content depends on tensor order
assert!(!data.bytes.is_empty());
}
}
/// Test that zero_copy=true with file-based loading works (via mmap + bytes::Bytes).
#[test]
fn test_zero_copy_file_based_works() {
use tempfile::NamedTempFile;
let device = Default::default();
let module = SimpleModule::<TestBackend>::new(&device);
// Save to a temporary file
let temp_file = NamedTempFile::new().unwrap();
let path = temp_file.path();
let mut save_store = BurnpackStore::from_file(path).overwrite(true);
save_store.collect_from(&module).unwrap();
// Load with zero_copy=true - should work because mmap is converted to bytes::Bytes
let mut load_store = BurnpackStore::from_file(path).zero_copy(true);
let mut loaded_module = SimpleModule::<TestBackend>::new_zeros(&device);
// The apply should succeed - mmap now supports zero-copy via bytes::Bytes::from_owner()
load_store.apply_to(&mut loaded_module).unwrap();
// Verify data is correct
let loaded_weight = loaded_module.weight.val().to_data();
assert_eq!(
loaded_weight.to_vec::<f32>().unwrap(),
vec![1.0, 2.0, 3.0, 4.0]
);
}

View File

@@ -0,0 +1,331 @@
use super::base::{
BurnpackError, BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER,
TENSOR_ALIGNMENT, TensorDescriptor, aligned_data_section_start,
};
use crate::TensorSnapshot;
use alloc::collections::BTreeMap;
use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec;
use alloc::vec::Vec;
use burn_tensor::Bytes;
#[cfg(feature = "std")]
use std::fs::File;
#[cfg(feature = "std")]
use std::io::Write;
#[cfg(feature = "std")]
use std::path::Path;
/// Align an offset to the specified alignment boundary.
///
/// Returns the smallest value >= `offset` that is a multiple of `alignment`.
#[inline]
const fn align_offset(offset: u64, alignment: u64) -> u64 {
offset.div_ceil(alignment) * alignment
}
/// Writer for creating Burnpack files
pub struct BurnpackWriter {
/// Tensors to write
pub(crate) snapshots: Vec<TensorSnapshot>,
/// Metadata key-value pairs
pub(crate) metadata: BTreeMap<String, String>,
}
impl BurnpackWriter {
/// Create a new writer
pub fn new(snapshots: Vec<TensorSnapshot>) -> Self {
Self {
snapshots,
metadata: BTreeMap::new(),
}
}
/// Builder pattern: add metadata and return self
pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
self.metadata.insert(key.to_string(), value.to_string());
self
}
/// Build tensor descriptors and metadata
fn build_metadata(&self) -> Result<(BurnpackMetadata, Vec<u8>), BurnpackError> {
// Build tensor descriptors and calculate offsets with alignment
let mut tensors = BTreeMap::new();
let mut current_offset = 0u64;
for snapshot in &self.snapshots {
let data_len = snapshot.data_len() as u64;
// Align the start offset for mmap zero-copy support
let aligned_start = align_offset(current_offset, TENSOR_ALIGNMENT);
let end = aligned_start.checked_add(data_len).ok_or_else(|| {
BurnpackError::IoError(format!(
"Tensor offset overflow: {} + {} exceeds maximum",
aligned_start, data_len
))
})?;
tensors.insert(
snapshot.full_path(),
TensorDescriptor {
dtype: snapshot.dtype,
shape: snapshot.shape.iter().map(|&s| s as u64).collect(),
data_offsets: (aligned_start, end),
param_id: snapshot.tensor_id.map(|id| id.val()),
},
);
current_offset = end;
}
// Create metadata structure
let metadata = BurnpackMetadata {
tensors,
metadata: self.metadata.clone(),
};
// Serialize metadata with CBOR
let mut metadata_bytes = Vec::new();
ciborium::ser::into_writer(&metadata, &mut metadata_bytes)
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
Ok((metadata, metadata_bytes))
}
/// Calculate the total size needed for the burnpack data
///
/// This is useful when you want to pre-allocate a buffer for `write_into()`.
/// The size includes padding bytes for both metadata alignment and tensor alignment.
pub fn size(&self) -> Result<usize, BurnpackError> {
let (metadata, metadata_bytes) = self.build_metadata()?;
// Data section starts at aligned position after header + metadata
let data_section_start = aligned_data_section_start(metadata_bytes.len());
// Calculate total data section size from aligned offsets
// The last tensor's end offset gives us the total data section size
let data_size = metadata
.tensors
.values()
.map(|t| t.data_offsets.1)
.max()
.unwrap_or(0) as usize;
Ok(data_section_start + data_size)
}
/// Write burnpack data into a caller-provided buffer
///
/// The buffer must be large enough to hold all data. Use `size()` to determine
/// the required buffer size. If the buffer is too small, this will return an error.
///
/// This allows the caller to control buffer allocation, enabling optimizations like:
/// - Buffer reuse across multiple writes
/// - Custom allocators
/// - Pinned memory for GPU transfers
///
/// # Arguments
///
/// * `buffer` - Mutable slice to write data into. Must be at least `size()` bytes.
pub fn write_into(&self, buffer: &mut [u8]) -> Result<(), BurnpackError> {
let (metadata, metadata_bytes) = self.build_metadata()?;
// Check metadata size fits in u32
let metadata_size: u32 = metadata_bytes.len().try_into().map_err(|_| {
BurnpackError::IoError(format!(
"Metadata size {} exceeds maximum of {} bytes",
metadata_bytes.len(),
u32::MAX
))
})?;
// Create header
let header = BurnpackHeader {
magic: MAGIC_NUMBER,
version: FORMAT_VERSION,
metadata_size,
};
// Data section starts at aligned position after header + metadata
let data_section_start = aligned_data_section_start(metadata_bytes.len());
// Calculate required size from aligned offsets
let data_size = metadata
.tensors
.values()
.map(|t| t.data_offsets.1)
.max()
.unwrap_or(0) as usize;
let total_size = data_section_start + data_size;
// Check buffer size
if buffer.len() < total_size {
return Err(BurnpackError::IoError(format!(
"Buffer too small: need {} bytes, got {} bytes",
total_size,
buffer.len()
)));
}
let mut offset = 0;
// Write header
let header_bytes = header.into_bytes();
buffer[offset..offset + HEADER_SIZE].copy_from_slice(&header_bytes);
offset += HEADER_SIZE;
// Write metadata
buffer[offset..offset + metadata_bytes.len()].copy_from_slice(&metadata_bytes);
offset += metadata_bytes.len();
// Write padding to align data section start
if data_section_start > offset {
buffer[offset..data_section_start].fill(0);
offset = data_section_start;
}
// Write tensor data with alignment padding
for snapshot in &self.snapshots {
// Get the aligned offset from metadata
let descriptor = metadata.tensors.get(&snapshot.full_path()).ok_or_else(|| {
BurnpackError::IoError(format!(
"Internal error: tensor '{}' not found in metadata",
snapshot.full_path()
))
})?;
let aligned_offset = descriptor.data_offsets.0 as usize;
let target_offset = data_section_start + aligned_offset;
// Write padding zeros if needed
if target_offset > offset {
buffer[offset..target_offset].fill(0);
offset = target_offset;
}
let expected_len = snapshot.data_len();
let data = snapshot.to_data().map_err(|e| {
BurnpackError::IoError(format!("Failed to get tensor data: {:?}", e))
})?;
let actual_len = data.bytes.len();
// Validate data length consistency
if actual_len != expected_len {
return Err(BurnpackError::IoError(format!(
"Data corruption: tensor '{}' has inconsistent length (expected {}, got {})",
snapshot.full_path(),
expected_len,
actual_len
)));
}
buffer[offset..offset + actual_len].copy_from_slice(&data.bytes);
offset += actual_len;
}
Ok(())
}
/// Write to a byte buffer (convenience method)
///
/// This allocates a buffer internally and writes the burnpack data.
/// For more control over buffer allocation, use `size()` + `write_into()`.
pub fn to_bytes(&self) -> Result<Bytes, BurnpackError> {
let size = self.size()?;
let mut buffer = vec![0u8; size];
self.write_into(&mut buffer)?;
Ok(Bytes::from_bytes_vec(buffer))
}
/// Write directly to a file (more memory efficient for large models)
#[cfg(feature = "std")]
pub fn write_to_file<P: AsRef<Path>>(&self, path: P) -> Result<(), BurnpackError> {
let mut file = File::create(path).map_err(|e| BurnpackError::IoError(e.to_string()))?;
let (metadata, metadata_bytes) = self.build_metadata()?;
// Check metadata size fits in u32
let metadata_size: u32 = metadata_bytes.len().try_into().map_err(|_| {
BurnpackError::IoError(format!(
"Metadata size {} exceeds maximum of {} bytes",
metadata_bytes.len(),
u32::MAX
))
})?;
// Create and write header
let header = BurnpackHeader {
magic: MAGIC_NUMBER,
version: FORMAT_VERSION,
metadata_size,
};
file.write_all(&header.into_bytes())
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
// Write metadata
file.write_all(&metadata_bytes)
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
// Data section starts at aligned position after header + metadata
let data_section_start = aligned_data_section_start(metadata_bytes.len());
let current_file_pos = HEADER_SIZE + metadata_bytes.len();
// Write padding to align data section start
if data_section_start > current_file_pos {
let padding_size = data_section_start - current_file_pos;
let padding = vec![0u8; padding_size];
file.write_all(&padding)
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
}
// Track current position within data section (relative to data_section_start)
let mut data_offset = 0usize;
// Stream tensor data directly to file with alignment padding
for snapshot in &self.snapshots {
// Get the aligned offset from metadata
let descriptor = metadata.tensors.get(&snapshot.full_path()).ok_or_else(|| {
BurnpackError::IoError(format!(
"Internal error: tensor '{}' not found in metadata",
snapshot.full_path()
))
})?;
let aligned_offset = descriptor.data_offsets.0 as usize;
// Write padding zeros if needed
if aligned_offset > data_offset {
let padding_size = aligned_offset - data_offset;
let padding = vec![0u8; padding_size];
file.write_all(&padding)
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
data_offset = aligned_offset;
}
let expected_len = snapshot.data_len();
let data = snapshot.to_data().map_err(|e| {
BurnpackError::IoError(format!("Failed to get tensor data: {:?}", e))
})?;
let actual_len = data.bytes.len();
// Validate data length consistency
if actual_len != expected_len {
return Err(BurnpackError::IoError(format!(
"Data corruption: tensor '{}' has inconsistent length (expected {}, got {})",
snapshot.full_path(),
expected_len,
actual_len
)));
}
file.write_all(&data.bytes)
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
data_offset += actual_len;
}
file.flush()
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
Ok(())
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,625 @@
use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
use core::fmt;
#[cfg(feature = "std")]
use regex::Regex;
/// A sophisticated path filter that supports multiple matching strategies.
///
/// The filter uses an OR logic - a path is included if it matches ANY of the configured criteria.
/// This allows for flexible and powerful filtering configurations.
///
/// # Examples
///
/// ```rust,no_run
/// # use burn_store::PathFilter;
/// // Create a filter that matches encoder paths or any weight path
/// let filter = PathFilter::new()
/// .with_regex(r"^encoder\..*")
/// .with_regex(r".*\.weight$")
/// .with_full_path("special_tensor");
///
/// // Check if a path should be included
/// if filter.matches("encoder.layer1.weight") {
/// // This will match due to both regex patterns
/// }
/// ```
#[derive(Debug, Clone, Default)]
pub struct PathFilter {
/// Compiled regex patterns for matching paths
#[cfg(feature = "std")]
regex_patterns: Vec<Regex>,
/// Exact full paths to match
exact_paths: Vec<String>,
/// Predicate functions for custom matching logic based on path and container path
/// Note: These cannot be cloned, so we store them separately
predicates: Vec<fn(&str, &str) -> bool>,
/// If true, matches all paths (overrides other filters)
match_all: bool,
}
impl PathFilter {
/// Create a new empty filter (matches nothing by default)
pub fn new() -> Self {
Self::default()
}
/// Create a filter that matches all paths
pub fn all() -> Self {
Self {
match_all: true,
..Default::default()
}
}
/// Create a filter that matches nothing
pub fn none() -> Self {
Self::default()
}
/// Add a regex pattern for matching paths
#[cfg(feature = "std")]
pub fn with_regex<S: AsRef<str>>(mut self, pattern: S) -> Self {
if let Ok(regex) = Regex::new(pattern.as_ref()) {
self.regex_patterns.push(regex);
}
// TODO: Consider returning Result to handle regex compilation errors
self
}
/// Add multiple regex patterns
#[cfg(feature = "std")]
pub fn with_regexes<I, S>(mut self, patterns: I) -> Self
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
for pattern in patterns {
if let Ok(regex) = Regex::new(pattern.as_ref()) {
self.regex_patterns.push(regex);
}
}
self
}
/// Add an exact full path to match
pub fn with_full_path<S: Into<String>>(mut self, path: S) -> Self {
self.exact_paths.push(path.into());
self
}
/// Add multiple exact full paths
pub fn with_full_paths<I, S>(mut self, paths: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.exact_paths.extend(paths.into_iter().map(|p| p.into()));
self
}
/// Add a predicate function for custom matching based on path and container path
pub fn with_predicate(mut self, predicate: fn(&str, &str) -> bool) -> Self {
self.predicates.push(predicate);
self
}
/// Add multiple predicates
pub fn with_predicates<I>(mut self, predicates: I) -> Self
where
I: IntoIterator<Item = fn(&str, &str) -> bool>,
{
self.predicates.extend(predicates);
self
}
/// Set to match all paths
pub fn match_all(mut self) -> Self {
self.match_all = true;
self
}
/// Check if a path matches this filter (assumes empty container path for backward compatibility)
pub fn matches(&self, path: &str) -> bool {
self.matches_with_container_path_str(path, "")
}
/// Check if a path and container type match this filter (for backward compatibility)
pub fn matches_with_container(&self, path: &str, container_type: &str) -> bool {
// For backward compatibility, treat single container type as the full path
self.matches_with_container_path_str(path, container_type)
}
/// Check if a path and container path match this filter
pub fn matches_with_container_path(&self, path: &[String], container_stack: &[String]) -> bool {
let path_str = path.join(".");
let container_path = container_stack.join(".");
self.matches_with_container_path_str(&path_str, &container_path)
}
/// Check if a path and container path (dot-notated strings) match this filter
pub fn matches_with_container_path_str(&self, path: &str, container_path: &str) -> bool {
// If match_all is set, always return true
if self.match_all {
return true;
}
// If no filters are configured, match nothing
if self.is_empty() {
return false;
}
// Check exact path matches
if self.exact_paths.iter().any(|p| p == path) {
return true;
}
// Check regex patterns (on the path)
#[cfg(feature = "std")]
{
for regex in &self.regex_patterns {
if regex.is_match(path) {
return true;
}
}
}
// Check predicates with container path
if self
.predicates
.iter()
.any(|pred| pred(path, container_path))
{
return true;
}
false
}
/// Check if the filter is empty (matches nothing)
pub fn is_empty(&self) -> bool {
if self.match_all {
return false;
}
#[cfg(feature = "std")]
let regex_empty = self.regex_patterns.is_empty();
#[cfg(not(feature = "std"))]
let regex_empty = true;
self.exact_paths.is_empty() && self.predicates.is_empty() && regex_empty
}
/// Get the number of filter criteria configured
pub fn criteria_count(&self) -> usize {
if self.match_all {
return 1;
}
#[allow(unused_mut)]
let mut count = self.exact_paths.len() + self.predicates.len();
#[cfg(feature = "std")]
{
count += self.regex_patterns.len();
}
count
}
/// Clear all regex patterns
#[cfg(feature = "std")]
pub fn clear_regex(&mut self) -> &mut Self {
self.regex_patterns.clear();
self
}
/// Clear all exact paths
pub fn clear_paths(&mut self) -> &mut Self {
self.exact_paths.clear();
self
}
/// Clear all predicates
pub fn clear_predicates(&mut self) -> &mut Self {
self.predicates.clear();
self
}
/// Clear all filters
pub fn clear(&mut self) -> &mut Self {
#[cfg(feature = "std")]
self.clear_regex();
self.clear_paths().clear_predicates();
self.match_all = false;
self
}
/// Create a filter from regex patterns only
#[cfg(feature = "std")]
pub fn from_regex_patterns<I, S>(patterns: I) -> Self
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
Self::new().with_regexes(patterns)
}
/// Create a filter from exact paths only
pub fn from_paths<I, S>(paths: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
Self::new().with_full_paths(paths)
}
/// Create a filter from a single predicate
pub fn from_predicate(predicate: fn(&str, &str) -> bool) -> Self {
Self::new().with_predicate(predicate)
}
/// Combine with another filter using OR logic
pub fn or(mut self, other: Self) -> Self {
if self.match_all || other.match_all {
return Self::all();
}
#[cfg(feature = "std")]
{
self.regex_patterns.extend(other.regex_patterns);
}
self.exact_paths.extend(other.exact_paths);
self.predicates.extend(other.predicates);
self
}
}
impl fmt::Display for PathFilter {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.match_all {
return write!(f, "PathFilter::all()");
}
if self.is_empty() {
return write!(f, "PathFilter::none()");
}
write!(f, "PathFilter[")?;
let mut parts = Vec::new();
#[cfg(feature = "std")]
if !self.regex_patterns.is_empty() {
parts.push(format!("regex: {:?}", self.regex_patterns));
}
if !self.exact_paths.is_empty() {
parts.push(format!("paths: {:?}", self.exact_paths));
}
if !self.predicates.is_empty() {
parts.push(format!("predicates: {}", self.predicates.len()));
}
write!(f, "{}]", parts.join(", "))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn empty_filter() {
let filter = PathFilter::new();
assert!(filter.is_empty());
assert!(!filter.matches("encoder.weight"));
assert!(!filter.matches("decoder.bias"));
}
#[test]
fn match_all() {
let filter = PathFilter::all();
assert!(!filter.is_empty());
assert!(filter.matches("encoder.weight"));
assert!(filter.matches("decoder.bias"));
assert!(filter.matches("anything"));
}
#[test]
fn exact_paths() {
let filter = PathFilter::new()
.with_full_path("encoder.weight")
.with_full_path("decoder.bias");
assert!(filter.matches("encoder.weight"));
assert!(filter.matches("decoder.bias"));
assert!(!filter.matches("encoder.bias"));
assert!(!filter.matches("decoder.weight"));
}
#[test]
#[cfg(feature = "std")]
fn regex_patterns() {
let filter = PathFilter::new()
.with_regex(r"^encoder\..*")
.with_regex(r".*\.weight$");
assert!(filter.matches("encoder.layer1.bias"));
assert!(filter.matches("decoder.weight"));
assert!(filter.matches("encoder.weight"));
assert!(!filter.matches("decoder.bias"));
}
#[test]
fn predicates() {
fn contains_norm(path: &str, _container_path: &str) -> bool {
path.contains("norm")
}
fn is_short(path: &str, _container_path: &str) -> bool {
path.len() < 10
}
let filter = PathFilter::new()
.with_predicate(contains_norm)
.with_predicate(is_short);
assert!(filter.matches("norm.weight"));
assert!(filter.matches("layer.norm.bias"));
assert!(filter.matches("bias"));
assert!(!filter.matches("encoder.decoder.weight.long.name"));
}
#[test]
fn combined_filters() {
let filter = PathFilter::new()
.with_full_path("special.tensor")
.with_predicate(|path, _container_path| path.contains("attention"));
#[cfg(feature = "std")]
let filter = filter.with_regex(r"^encoder\..*");
assert!(filter.matches("special.tensor"));
assert!(filter.matches("self_attention.query"));
#[cfg(feature = "std")]
assert!(filter.matches("encoder.anything"));
assert!(!filter.matches("decoder.weight"));
}
#[test]
fn or_combination() {
let encoder_filter = PathFilter::new().with_full_path("encoder.weight");
let decoder_filter = PathFilter::new().with_full_path("decoder.bias");
let combined = encoder_filter.or(decoder_filter);
assert!(combined.matches("encoder.weight"));
assert!(combined.matches("decoder.bias"));
assert!(!combined.matches("model.head.weight"));
}
#[test]
#[cfg(feature = "std")]
fn common_patterns() {
// Test encoder pattern
let encoder = PathFilter::new().with_regex(r"^encoder\..*");
assert!(encoder.matches("encoder.weight"));
assert!(!encoder.matches("decoder.weight"));
// Test weights-only pattern
let weights = PathFilter::new().with_regex(r".*\.weight$");
assert!(weights.matches("encoder.weight"));
assert!(weights.matches("decoder.weight"));
assert!(!weights.matches("encoder.bias"));
// Test layer-specific patterns
let layers = PathFilter::new()
.with_regex(r"(^|.*\.)layers\.0\.")
.with_regex(r"(^|.*\.)layers\.2\.")
.with_regex(r"(^|.*\.)layers\.4\.");
assert!(layers.matches("model.layers.0.weight"));
assert!(layers.matches("layers.2.bias"));
assert!(!layers.matches("layers.1.weight"));
}
#[test]
fn criteria_count() {
let filter = PathFilter::new()
.with_full_path("path1")
.with_full_path("path2")
.with_predicate(|_, _| true);
#[cfg(feature = "std")]
let filter = filter.with_regex(".*");
#[cfg(feature = "std")]
assert_eq!(filter.criteria_count(), 4);
#[cfg(not(feature = "std"))]
assert_eq!(filter.criteria_count(), 3);
}
#[test]
fn clear_operations() {
let mut filter = PathFilter::new().with_full_path("test");
filter.clear_paths();
assert!(!filter.matches("test"));
filter.clear();
assert!(filter.is_empty());
}
#[test]
fn container_predicates() {
// Filter that matches only Linear module weights
let linear_weights = PathFilter::new().with_predicate(|path, container_path| {
container_path.split('.').next_back() == Some("Linear") && path.ends_with(".weight")
});
assert!(linear_weights.matches_with_container("layer1.weight", "Linear"));
assert!(!linear_weights.matches_with_container("layer1.weight", "Conv2d"));
assert!(!linear_weights.matches_with_container("layer1.bias", "Linear"));
// Filter for specific container types
let conv_only = PathFilter::new().with_predicate(|_path, container_path| {
let last = container_path.split('.').next_back();
last == Some("Conv2d") || last == Some("ConvTranspose2d")
});
assert!(conv_only.matches_with_container("encoder.weight", "Conv2d"));
assert!(conv_only.matches_with_container("decoder.weight", "ConvTranspose2d"));
assert!(!conv_only.matches_with_container("fc.weight", "Linear"));
// Combine path and container predicates
let combined = PathFilter::new()
.with_predicate(|path, _container_path| path.starts_with("encoder."))
.with_predicate(|_path, container_path| {
container_path.split('.').next_back() == Some("BatchNorm2d")
});
// Should match either condition (OR logic)
assert!(combined.matches_with_container("encoder.layer1", "Linear"));
assert!(combined.matches_with_container("decoder.bn", "BatchNorm2d"));
assert!(!combined.matches_with_container("decoder.layer", "Linear"));
}
#[test]
fn container_predicate_with_regex() {
// Combine regex patterns with container predicates
#[cfg(feature = "std")]
{
let filter = PathFilter::new()
.with_regex(r"^encoder\..*")
.with_predicate(|path, container_path| {
container_path.split('.').next_back() == Some("Linear")
&& path.contains(".bias")
});
// Matches due to regex
assert!(filter.matches_with_container("encoder.layer1.weight", "Conv2d"));
// Matches due to container predicate
assert!(filter.matches_with_container("decoder.fc.bias", "Linear"));
// Doesn't match either
assert!(!filter.matches_with_container("decoder.conv.weight", "Conv2d"));
}
}
#[test]
fn container_stack_predicates() {
// Filter using full container path - only tensors nested in a specific hierarchy
let nested_filter = PathFilter::new().with_predicate(|_path, container_path| {
// Check if tensor is nested within: Model -> TransformerBlock -> Linear
let parts: Vec<&str> = container_path.split('.').collect();
parts.len() >= 3
&& parts[0] == "Model"
&& parts[1] == "TransformerBlock"
&& parts[2] == "Linear"
});
assert!(nested_filter.matches_with_container_path_str(
"encoder.weight",
"Model.TransformerBlock.Linear.Param"
));
assert!(
!nested_filter
.matches_with_container_path_str("decoder.weight", "Model.Decoder.Linear.Param")
);
assert!(!nested_filter.matches_with_container_path_str(
"encoder.weight",
"Model.TransformerBlock.Conv2d.Param"
));
// Filter that checks for specific depth in hierarchy
let depth_filter = PathFilter::new().with_predicate(|_path, container_path| {
let parts: Vec<&str> = container_path.split('.').collect();
parts.len() == 4 && parts.get(2) == Some(&"Linear")
});
assert!(depth_filter.matches_with_container_path_str(
"model.layer.weight",
"Model.TransformerBlock.Linear.Param"
));
assert!(
!depth_filter
.matches_with_container_path_str("model.weight", "Model.TransformerBlock.Conv2d")
); // Too shallow
// Filter that checks any Linear in the path (not just the last)
let any_linear = PathFilter::new()
.with_predicate(|_path, container_path| container_path.contains("Linear"));
assert!(
any_linear.matches_with_container_path_str(
"some.path",
"Model.TransformerBlock.Linear.Param"
)
);
assert!(
any_linear.matches_with_container_path_str("other.path", "Model.Decoder.Linear.Param")
);
assert!(
!any_linear.matches_with_container_path_str(
"conv.path",
"Model.TransformerBlock.Conv2d.Param"
)
);
}
#[test]
fn container_path_dot_notation() {
// Filter using dot-notated container path
let dot_filter = PathFilter::new().with_predicate(|_path, container_path| {
container_path.starts_with("Model.TransformerBlock")
});
// Test with matches_with_container_path
assert!(
dot_filter.matches_with_container_path_str("weight", "Model.TransformerBlock.Linear")
);
assert!(!dot_filter.matches_with_container_path_str("weight", "Model.Decoder.Linear"));
// Filter that checks for specific patterns in container path
let pattern_filter = PathFilter::new().with_predicate(|_path, container_path| {
// Match any path that has Linear after a block
container_path.contains("Block.Linear") || container_path.contains("Block.Conv")
});
assert!(
pattern_filter
.matches_with_container_path_str("weight", "Model.TransformerBlock.Linear")
);
assert!(pattern_filter.matches_with_container_path_str("weight", "Model.ResBlock.Conv2d"));
assert!(!pattern_filter.matches_with_container_path_str("weight", "Model.Linear.Param"));
// Filter combining path and container path patterns
let combined = PathFilter::new().with_predicate(|path, container_path| {
// Only weights in Linear layers that are inside blocks
path.ends_with(".weight")
&& container_path.contains("Block")
&& container_path.split('.').next_back() == Some("Linear")
});
assert!(
combined
.matches_with_container_path_str("layer.weight", "Model.TransformerBlock.Linear")
);
assert!(
!combined
.matches_with_container_path_str("layer.bias", "Model.TransformerBlock.Linear")
);
assert!(!combined.matches_with_container_path_str("layer.weight", "Model.Decoder.Linear"));
}
}

View File

@@ -0,0 +1,674 @@
use alloc::collections::BTreeMap;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use regex::{self, Regex};
use crate::TensorSnapshot;
/// Key remapper for transforming tensor names.
///
/// This allows mapping tensor names from one naming convention to another,
/// which is useful for loading models from different frameworks or versions.
///
/// # Examples
///
/// ```rust
/// # use burn_store::KeyRemapper;
/// // Create a key remapper
/// let remapper = KeyRemapper::new()
/// .add_pattern(r"^pytorch\.(.*)", "burn.$1").expect("valid regex") // pytorch.layer -> burn.layer
/// .add_pattern(r"\.gamma$", ".weight").expect("valid regex"); // layer.gamma -> layer.weight
///
/// // Use remapper with stores
/// // store.remap(remapper)
/// ```
#[derive(Debug, Clone, Default)]
pub struct KeyRemapper {
/// Pattern-based remapping rules (regex pattern, replacement string)
pub patterns: Vec<(Regex, String)>,
}
impl KeyRemapper {
/// Create a new empty key remapper
pub fn new() -> Self {
Self::default()
}
/// Add a remapping pattern (compiles regex)
///
/// # Arguments
///
/// * `from` - Source pattern (regex string)
/// * `to` - Replacement string (can include capture groups like `$1`)
///
/// # Returns
///
/// * `Ok(Self)` - Updated remapping configuration
/// * `Err(regex::Error)` - If regex compilation fails
pub fn add_pattern<S1, S2>(mut self, from: S1, to: S2) -> Result<Self, regex::Error>
where
S1: AsRef<str>,
S2: Into<String>,
{
let regex = Regex::new(from.as_ref())?;
self.patterns.push((regex, to.into()));
Ok(self)
}
/// Create from a list of compiled regex patterns
pub fn from_compiled_patterns(patterns: Vec<(Regex, String)>) -> Self {
Self { patterns }
}
/// Create from string patterns (will compile to regex)
///
/// # Arguments
///
/// * `patterns` - Vector of (pattern, replacement) tuples
///
/// # Returns
///
/// * `Ok(Self)` - New remapping configuration
/// * `Err(regex::Error)` - If any regex compilation fails
pub fn from_patterns<S1, S2>(patterns: Vec<(S1, S2)>) -> Result<Self, regex::Error>
where
S1: AsRef<str>,
S2: Into<String>,
{
let mut compiled_patterns = Vec::new();
for (pattern, replacement) in patterns {
let regex = Regex::new(pattern.as_ref())?;
compiled_patterns.push((regex, replacement.into()));
}
Ok(Self {
patterns: compiled_patterns,
})
}
/// Create from an iterator of patterns
///
/// # Arguments
///
/// * `iter` - Iterator yielding (pattern, replacement) tuples
///
/// # Returns
///
/// * `Ok(Self)` - New remapping configuration
/// * `Err(regex::Error)` - If any regex compilation fails
pub fn from_pattern_iter<I, S1, S2>(iter: I) -> Result<Self, regex::Error>
where
I: IntoIterator<Item = (S1, S2)>,
S1: AsRef<str>,
S2: Into<String>,
{
let patterns: Result<Vec<_>, _> = iter
.into_iter()
.map(|(from, to)| Ok((Regex::new(from.as_ref())?, to.into())))
.collect();
Ok(Self {
patterns: patterns?,
})
}
/// Check if the remapping is empty
pub fn is_empty(&self) -> bool {
self.patterns.is_empty()
}
/// Convert to the format expected by remap_tensor_paths_with_patterns
pub fn to_regex_pairs(&self) -> Vec<(Regex, String)> {
self.patterns.clone()
}
/// Remap tensor paths using the configured patterns.
///
/// # Arguments
///
/// * `tensors` - Vec of TensorSnapshots to remap
///
/// # Returns
///
/// A tuple containing:
/// * The remapped Vec of TensorSnapshots with updated paths
/// * A vector of (new_path, original_path) showing the transformations
pub fn remap(
&self,
mut tensors: Vec<TensorSnapshot>,
) -> (Vec<TensorSnapshot>, Vec<(String, String)>) {
if self.patterns.is_empty() {
let remapped_names = tensors
.iter()
.map(|v| {
let path = v.full_path();
(path.clone(), path)
})
.collect();
return (tensors, remapped_names);
}
let mut remapped_snapshots = Vec::new();
let mut remapped_names = Vec::new();
for mut snapshot in tensors.drain(..) {
let original_path = snapshot.full_path();
let mut new_path = original_path.clone();
// Apply all patterns to get the new path
for (pattern, replacement) in &self.patterns {
if pattern.is_match(&new_path) {
new_path = pattern
.replace_all(&new_path, replacement.as_str())
.to_string();
}
}
// Update the snapshot's internal path_stack if the path changed
if new_path != original_path
&& let Some(ref mut path_stack) = snapshot.path_stack
{
*path_stack = new_path.split('.').map(|s| s.to_string()).collect();
}
remapped_names.push((new_path.clone(), original_path));
remapped_snapshots.push(snapshot);
}
(remapped_snapshots, remapped_names)
}
}
/// Map tensor paths to have contiguous numeric indices.
///
/// This function detects numeric indices in tensor paths and renumbers them
/// to be contiguous (0, 1, 2, ...) while preserving their relative order.
/// It handles nested sequential structures by processing ALL numeric indices
/// in each path independently based on their position context.
///
/// This is useful when loading PyTorch models that have gaps in layer numbering,
/// such as when using `nn.Sequential` with mixed layer types (e.g., Conv2d + ReLU
/// where only Conv2d has parameters).
///
/// # Example
///
/// Simple case - input paths:
/// - `fc.0.weight`, `fc.0.bias`
/// - `fc.2.weight`, `fc.2.bias`
/// - `fc.4.weight`, `fc.4.bias`
///
/// Output paths:
/// - `fc.0.weight`, `fc.0.bias`
/// - `fc.1.weight`, `fc.1.bias`
/// - `fc.2.weight`, `fc.2.bias`
///
/// Nested case - input paths:
/// - `feature.layers.0.conv_block.0.weight`
/// - `feature.layers.0.conv_block.2.weight`
/// - `feature.layers.2.conv_block.0.weight`
/// - `feature.layers.2.conv_block.2.weight`
///
/// Output paths:
/// - `feature.layers.0.conv_block.0.weight`
/// - `feature.layers.0.conv_block.1.weight`
/// - `feature.layers.1.conv_block.0.weight`
/// - `feature.layers.1.conv_block.1.weight`
///
/// # Arguments
///
/// * `tensors` - Vec of TensorSnapshots to map
///
/// # Returns
///
/// A tuple containing:
/// * The mapped Vec of TensorSnapshots with updated paths
/// * A vector of (new_path, original_path) showing the transformations
pub fn map_indices_contiguous(
mut tensors: Vec<TensorSnapshot>,
) -> (Vec<TensorSnapshot>, Vec<(String, String)>) {
if tensors.is_empty() {
return (tensors, Vec::new());
}
// Step 1: Collect all paths and find all index positions
// For each index position (identified by prefix using ORIGINAL indices),
// collect all indices seen at that position.
//
// Key: prefix using original path (e.g., "feature.layers." or "feature.layers.0.conv_block.")
// Value: BTreeMap of original_index -> new_index
let mut index_maps: BTreeMap<String, BTreeMap<usize, usize>> = BTreeMap::new();
// First pass: collect all indices at each position using original prefixes
for snapshot in &tensors {
let path = snapshot.full_path();
let parts: Vec<&str> = path.split('.').collect();
// Check each part for numeric indices
for (i, part) in parts.iter().enumerate() {
if let Ok(index) = part.parse::<usize>() {
// The prefix is everything before this index (using original path)
let prefix = if i > 0 {
format!("{}.", parts[..i].join("."))
} else {
String::new()
};
index_maps
.entry(prefix)
.or_default()
.entry(index)
.or_insert(usize::MAX); // Placeholder
}
}
}
// Second pass: assign contiguous indices for each position
for indices in index_maps.values_mut() {
let mut sorted_indices: Vec<usize> = indices.keys().cloned().collect();
sorted_indices.sort();
for (new_idx, old_idx) in sorted_indices.into_iter().enumerate() {
indices.insert(old_idx, new_idx);
}
}
// Third pass: apply the remapping to all tensors
// We use original prefixes for lookup since that's how we collected indices
let mut mapped_snapshots = Vec::new();
let mut transformations = Vec::new();
for mut snapshot in tensors.drain(..) {
let original_path = snapshot.full_path();
let new_path = remap_all_indices_with_original_prefix(&original_path, &index_maps);
// Update the snapshot's internal path_stack if the path changed
if new_path != original_path
&& let Some(ref mut path_stack) = snapshot.path_stack
{
*path_stack = new_path.split('.').map(|s| s.to_string()).collect();
}
transformations.push((new_path, original_path));
mapped_snapshots.push(snapshot);
}
(mapped_snapshots, transformations)
}
/// Remap all numeric indices in a path using the provided index maps.
/// Uses original path prefixes for lookup.
fn remap_all_indices_with_original_prefix(
path: &str,
index_maps: &BTreeMap<String, BTreeMap<usize, usize>>,
) -> String {
let parts: Vec<&str> = path.split('.').collect();
let mut result_parts: Vec<String> = Vec::with_capacity(parts.len());
for (i, part) in parts.iter().enumerate() {
if let Ok(index) = part.parse::<usize>() {
// Build the prefix from ORIGINAL parts (not remapped)
let prefix = if i > 0 {
format!("{}.", parts[..i].join("."))
} else {
String::new()
};
// Look up the new index using original prefix
if let Some(index_map) = index_maps.get(&prefix)
&& let Some(&new_index) = index_map.get(&index)
{
result_parts.push(new_index.to_string());
continue;
}
}
// Not a numeric index or no mapping found, keep as-is
result_parts.push((*part).to_string());
}
result_parts.join(".")
}
#[cfg(all(test, feature = "std"))]
mod tests {
use super::*;
use burn_core::module::ParamId;
use burn_tensor::TensorData;
fn create_test_tensor_snapshot(name: &str) -> TensorSnapshot {
let data = TensorData {
bytes: burn_tensor::Bytes::from_bytes_vec(vec![1, 2, 3, 4]),
shape: vec![2, 2],
dtype: burn_tensor::DType::F32,
};
let path_parts: Vec<String> = name.split('.').map(|s| s.to_string()).collect();
TensorSnapshot::from_data(data, path_parts, vec!["Test".to_string()], ParamId::new())
}
#[test]
fn test_key_remapper_basic() {
let remapper = KeyRemapper::new()
.add_pattern(r"^encoder\.", "transformer.encoder.")
.expect("valid regex");
let tensors = vec![
create_test_tensor_snapshot("encoder.layer1.weight"),
create_test_tensor_snapshot("decoder.layer1.weight"),
];
let (remapped, transformations) = remapper.remap(tensors);
// Check that remapped views exist with correct paths
assert!(
remapped
.iter()
.any(|v| v.full_path() == "transformer.encoder.layer1.weight")
);
assert!(
remapped
.iter()
.any(|v| v.full_path() == "decoder.layer1.weight")
);
assert_eq!(remapped.len(), 2);
// Check transformations
let encoder_transform = transformations
.iter()
.find(|(_new, old)| old == "encoder.layer1.weight")
.expect("should find encoder transformation");
assert_eq!(encoder_transform.0, "transformer.encoder.layer1.weight");
}
#[test]
fn test_key_remapper_multiple_patterns() {
let remapper = KeyRemapper::new()
.add_pattern(r"^encoder\.", "transformer.encoder.")
.expect("valid regex")
.add_pattern(r"\.gamma$", ".weight")
.expect("valid regex");
let tensors = vec![create_test_tensor_snapshot("encoder.layer1.gamma")];
let (remapped, _) = remapper.remap(tensors);
assert!(
remapped
.iter()
.any(|v| v.full_path() == "transformer.encoder.layer1.weight")
);
assert_eq!(remapped.len(), 1);
}
#[test]
fn test_key_remapper_from_patterns() {
let patterns = vec![(r"^pytorch\.", "burn."), (r"\.bias$", ".bias_param")];
let remapper = KeyRemapper::from_patterns(patterns).expect("valid patterns");
let tensors = vec![create_test_tensor_snapshot("pytorch.linear.bias")];
let (remapped, _) = remapper.remap(tensors);
assert!(
remapped
.iter()
.any(|v| v.full_path() == "burn.linear.bias_param")
);
}
#[test]
fn test_key_remapper_empty() {
let remapper = KeyRemapper::new();
assert!(remapper.is_empty());
let tensors = vec![create_test_tensor_snapshot("test.weight")];
let (remapped, transformations) = remapper.remap(tensors);
assert!(remapped.iter().any(|v| v.full_path() == "test.weight"));
assert_eq!(remapped.len(), 1);
assert_eq!(transformations.len(), 1);
assert_eq!(
transformations[0],
("test.weight".to_string(), "test.weight".to_string())
);
}
#[test]
fn test_map_indices_contiguous_basic() {
// Simulate PyTorch nn.Sequential with Conv2d (0, 2, 4) and ReLU (1, 3, 5)
// Only Conv2d layers have parameters
let tensors = vec![
create_test_tensor_snapshot("fc.0.weight"),
create_test_tensor_snapshot("fc.0.bias"),
create_test_tensor_snapshot("fc.2.weight"),
create_test_tensor_snapshot("fc.2.bias"),
create_test_tensor_snapshot("fc.4.weight"),
create_test_tensor_snapshot("fc.4.bias"),
];
let (reindexed, transformations) = map_indices_contiguous(tensors);
// Check that indices are now contiguous
assert!(reindexed.iter().any(|v| v.full_path() == "fc.0.weight"));
assert!(reindexed.iter().any(|v| v.full_path() == "fc.0.bias"));
assert!(reindexed.iter().any(|v| v.full_path() == "fc.1.weight"));
assert!(reindexed.iter().any(|v| v.full_path() == "fc.1.bias"));
assert!(reindexed.iter().any(|v| v.full_path() == "fc.2.weight"));
assert!(reindexed.iter().any(|v| v.full_path() == "fc.2.bias"));
assert_eq!(reindexed.len(), 6);
// Check transformations
let transform_2_to_1 = transformations
.iter()
.find(|(_, old)| old == "fc.2.weight")
.expect("should find fc.2.weight transformation");
assert_eq!(transform_2_to_1.0, "fc.1.weight");
let transform_4_to_2 = transformations
.iter()
.find(|(_, old)| old == "fc.4.weight")
.expect("should find fc.4.weight transformation");
assert_eq!(transform_4_to_2.0, "fc.2.weight");
}
#[test]
fn test_map_indices_contiguous_already_contiguous() {
// Already contiguous indices should remain unchanged
let tensors = vec![
create_test_tensor_snapshot("fc.0.weight"),
create_test_tensor_snapshot("fc.1.weight"),
create_test_tensor_snapshot("fc.2.weight"),
];
let (reindexed, transformations) = map_indices_contiguous(tensors);
assert!(reindexed.iter().any(|v| v.full_path() == "fc.0.weight"));
assert!(reindexed.iter().any(|v| v.full_path() == "fc.1.weight"));
assert!(reindexed.iter().any(|v| v.full_path() == "fc.2.weight"));
assert_eq!(reindexed.len(), 3);
// All transformations should have same old and new paths
for (new, old) in &transformations {
assert_eq!(new, old);
}
}
#[test]
fn test_map_indices_contiguous_multiple_prefixes() {
// Different prefixes should be mapped independently
let tensors = vec![
create_test_tensor_snapshot("encoder.0.weight"),
create_test_tensor_snapshot("encoder.2.weight"),
create_test_tensor_snapshot("decoder.1.weight"),
create_test_tensor_snapshot("decoder.5.weight"),
];
let (reindexed, _) = map_indices_contiguous(tensors);
// encoder: 0, 2 -> 0, 1
assert!(
reindexed
.iter()
.any(|v| v.full_path() == "encoder.0.weight")
);
assert!(
reindexed
.iter()
.any(|v| v.full_path() == "encoder.1.weight")
);
// decoder: 1, 5 -> 0, 1
assert!(
reindexed
.iter()
.any(|v| v.full_path() == "decoder.0.weight")
);
assert!(
reindexed
.iter()
.any(|v| v.full_path() == "decoder.1.weight")
);
}
#[test]
fn test_map_indices_contiguous_no_indices() {
// Paths without indices should remain unchanged
let tensors = vec![
create_test_tensor_snapshot("encoder.weight"),
create_test_tensor_snapshot("decoder.bias"),
];
let (reindexed, transformations) = map_indices_contiguous(tensors);
assert!(reindexed.iter().any(|v| v.full_path() == "encoder.weight"));
assert!(reindexed.iter().any(|v| v.full_path() == "decoder.bias"));
for (new, old) in &transformations {
assert_eq!(new, old);
}
}
#[test]
fn test_map_indices_contiguous_empty() {
let tensors: Vec<TensorSnapshot> = vec![];
let (reindexed, transformations) = map_indices_contiguous(tensors);
assert!(reindexed.is_empty());
assert!(transformations.is_empty());
}
#[test]
fn test_map_indices_contiguous_mixed_indexed_and_non_indexed() {
// Mix of indexed and non-indexed paths
let tensors = vec![
create_test_tensor_snapshot("fc.0.weight"),
create_test_tensor_snapshot("fc.2.weight"),
create_test_tensor_snapshot("output.weight"), // no index
];
let (reindexed, _) = map_indices_contiguous(tensors);
assert!(reindexed.iter().any(|v| v.full_path() == "fc.0.weight"));
assert!(reindexed.iter().any(|v| v.full_path() == "fc.1.weight")); // 2 -> 1
assert!(reindexed.iter().any(|v| v.full_path() == "output.weight")); // unchanged
}
#[test]
fn test_map_indices_contiguous_nested_sequential() {
// Test nested sequential structures like:
// feature = nn.Sequential(ConvBlock, ReLU, ConvBlock, ReLU, ConvBlock)
// where ConvBlock = nn.Sequential(Conv2d, ReLU, Conv2d)
//
// This produces paths like:
// feature.layers.0.conv_block.0.weight (layer 0, conv 0)
// feature.layers.0.conv_block.2.weight (layer 0, conv 2 - skipping ReLU at 1)
// feature.layers.2.conv_block.0.weight (layer 2 - skipping ReLU at 1, conv 0)
// feature.layers.2.conv_block.2.weight (layer 2, conv 2)
let tensors = vec![
create_test_tensor_snapshot("feature.layers.0.conv_block.0.weight"),
create_test_tensor_snapshot("feature.layers.0.conv_block.2.weight"),
create_test_tensor_snapshot("feature.layers.2.conv_block.0.weight"),
create_test_tensor_snapshot("feature.layers.2.conv_block.2.weight"),
];
let (mapped, transformations) = map_indices_contiguous(tensors);
// Expected mapping:
// feature.layers: 0, 2 -> 0, 1
// feature.layers.0.conv_block: 0, 2 -> 0, 1
// feature.layers.2.conv_block: 0, 2 -> 0, 1
//
// Result:
// feature.layers.0.conv_block.0.weight -> feature.layers.0.conv_block.0.weight
// feature.layers.0.conv_block.2.weight -> feature.layers.0.conv_block.1.weight
// feature.layers.2.conv_block.0.weight -> feature.layers.1.conv_block.0.weight
// feature.layers.2.conv_block.2.weight -> feature.layers.1.conv_block.1.weight
assert!(
mapped
.iter()
.any(|v| v.full_path() == "feature.layers.0.conv_block.0.weight"),
"0.0 should stay as 0.0"
);
assert!(
mapped
.iter()
.any(|v| v.full_path() == "feature.layers.0.conv_block.1.weight"),
"0.2 should become 0.1"
);
assert!(
mapped
.iter()
.any(|v| v.full_path() == "feature.layers.1.conv_block.0.weight"),
"2.0 should become 1.0"
);
assert!(
mapped
.iter()
.any(|v| v.full_path() == "feature.layers.1.conv_block.1.weight"),
"2.2 should become 1.1"
);
// Verify specific transformations
let t1 = transformations
.iter()
.find(|(_, old)| old == "feature.layers.2.conv_block.2.weight");
assert_eq!(
t1.map(|(new, _)| new.as_str()),
Some("feature.layers.1.conv_block.1.weight"),
"2.2 should map to 1.1"
);
}
#[test]
fn test_map_indices_contiguous_deeply_nested() {
// Test with three levels of nesting
let tensors = vec![
create_test_tensor_snapshot("a.0.b.0.c.0.weight"),
create_test_tensor_snapshot("a.0.b.0.c.2.weight"),
create_test_tensor_snapshot("a.0.b.2.c.0.weight"),
create_test_tensor_snapshot("a.2.b.0.c.0.weight"),
];
let (mapped, _) = map_indices_contiguous(tensors);
// a: 0, 2 -> 0, 1
// a.0.b: 0, 2 -> 0, 1
// a.2.b: 0 -> 0
// a.0.b.0.c: 0, 2 -> 0, 1
// a.0.b.2.c: 0 -> 0
// a.2.b.0.c: 0 -> 0
assert!(mapped.iter().any(|v| v.full_path() == "a.0.b.0.c.0.weight"));
assert!(
mapped.iter().any(|v| v.full_path() == "a.0.b.0.c.1.weight"),
"a.0.b.0.c.2 should become a.0.b.0.c.1"
);
assert!(
mapped.iter().any(|v| v.full_path() == "a.0.b.1.c.0.weight"),
"a.0.b.2.c.0 should become a.0.b.1.c.0"
);
assert!(
mapped.iter().any(|v| v.full_path() == "a.1.b.0.c.0.weight"),
"a.2.b.0.c.0 should become a.1.b.0.c.0"
);
}
}

View File

@@ -0,0 +1,118 @@
#![cfg_attr(not(feature = "std"), no_std)]
//! # Burn Store
//!
//! Advanced model storage and serialization infrastructure for the Burn deep learning framework.
//!
//! This crate provides comprehensive functionality for storing and loading Burn modules
//! and their tensor data, with support for cross-framework interoperability, flexible filtering,
//! and efficient memory management through lazy materialization.
//!
//! ## Key Features
//!
//! - **Burnpack Format**: Native Burn format with CBOR metadata, ParamId persistence for stateful training, and no-std support
//! - **SafeTensors Format**: Industry-standard format for secure and efficient tensor serialization
//! - **PyTorch Compatibility**: Load PyTorch models directly into Burn with automatic weight transformation
//! - **Zero-Copy Loading**: Memory-mapped files and lazy tensor materialization for optimal performance
//! - **Flexible Filtering**: Load/save specific model subsets using regex, exact paths, or custom predicates
//! - **Tensor Remapping**: Rename tensors during load/save operations for framework compatibility
//! - **No-std Support**: Core functionality available in embedded and WASM environments
//!
//! ## Quick Start
//!
//! ### Basic Save and Load
//!
//! ```rust,ignore
//! use burn_store::{ModuleSnapshot, SafetensorsStore};
//!
//! // Save a model
//! let mut store = SafetensorsStore::from_file("model.safetensors");
//! model.save_into(&mut store)?;
//!
//! // Load a model
//! let mut store = SafetensorsStore::from_file("model.safetensors");
//! model.load_from(&mut store)?;
//! ```
//!
//! ### Loading PyTorch Models
//!
//! ```rust,ignore
//! use burn_store::PytorchStore;
//!
//! // Load PyTorch model (automatic weight transformation via PyTorchToBurnAdapter)
//! let mut store = PytorchStore::from_file("pytorch_model.pth")
//! .with_top_level_key("state_dict") // Access nested state dict if needed
//! .allow_partial(true); // Skip unknown tensors
//!
//! model.load_from(&mut store)?;
//! ```
//!
//! ### Filtering and Remapping
//!
//! ```rust,no_run
//! # use burn_store::SafetensorsStore;
//! // Save only specific layers with renaming
//! let mut store = SafetensorsStore::from_file("encoder.safetensors")
//! .with_regex(r"^encoder\..*") // Filter: only encoder layers
//! .with_key_remapping(r"^encoder\.", "transformer.") // Rename: encoder.X -> transformer.X
//! .metadata("subset", "encoder_only");
//!
//! // Use store with model.save_into(&mut store)?;
//! ```
//!
//! ## Core Components
//!
//! - [`ModuleSnapshot`]: Extension trait for Burn modules providing `collect()` and `apply()` methods
//! - [`BurnpackStore`]: Native Burn format with ParamId persistence for stateful training workflows
//! - [`SafetensorsStore`]: Primary storage implementation supporting the SafeTensors format
//! - [`PytorchStore`]: PyTorch model loader supporting .pth and .pt files
//! - [`PathFilter`]: Flexible filtering system for selective tensor loading/saving
//! - [`KeyRemapper`]: Advanced tensor name remapping with regex patterns
//! - [`ModuleAdapter`]: Framework adapters for cross-framework compatibility
//!
//! ## Feature Flags
//!
//! - `std`: Enables file I/O and other std-only features (default)
//! - `safetensors`: Enables SafeTensors format support (default)
extern crate alloc;
mod adapter;
mod applier;
mod apply_result;
mod collector;
mod filter;
mod tensor_snapshot;
mod traits;
pub use adapter::{
BurnToPyTorchAdapter, ChainAdapter, IdentityAdapter, ModuleAdapter, PyTorchToBurnAdapter,
};
pub use applier::Applier;
pub use apply_result::{ApplyError, ApplyResult};
pub use collector::Collector;
pub use filter::PathFilter;
pub use tensor_snapshot::{TensorSnapshot, TensorSnapshotError};
pub use traits::{ModuleSnapshot, ModuleStore};
#[cfg(feature = "std")]
mod keyremapper;
#[cfg(feature = "std")]
pub use keyremapper::{KeyRemapper, map_indices_contiguous};
#[cfg(feature = "pytorch")]
pub mod pytorch;
#[cfg(feature = "pytorch")]
pub use pytorch::{PytorchStore, PytorchStoreError};
#[cfg(feature = "safetensors")]
mod safetensors;
#[cfg(feature = "safetensors")]
pub use safetensors::{SafetensorsStore, SafetensorsStoreError};
#[cfg(feature = "burnpack")]
mod burnpack;
#[cfg(feature = "burnpack")]
pub use burnpack::writer::BurnpackWriter;
#[cfg(feature = "burnpack")]
pub use burnpack::{base::BurnpackError, store::BurnpackStore};

View File

@@ -0,0 +1,567 @@
//! Lazy data loading support for PyTorch files.
//!
//! This module provides abstractions for lazy loading of tensor data from PyTorch files,
//! avoiding the need to load all data into memory upfront.
use alloc::string::String;
use alloc::vec::Vec;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Read, Seek};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, RwLock};
use zip::ZipArchive;
/// A data source that can lazily load tensor data.
#[derive(Clone)]
pub enum LazyDataSource {
/// ZIP archive with lazy loading
Zip(Arc<Mutex<ZipSource>>),
/// TAR archive format (older torchvision models)
Tar(Arc<Mutex<TarSource>>),
/// Legacy format with multiple storages in single blob
LegacyMultiStorage(Arc<Mutex<LegacyMultiStorageSource>>),
}
/// ZIP archive source for lazy loading
pub struct ZipSource {
path: PathBuf,
// Cache the file list to avoid reopening archive repeatedly
file_list: Vec<(String, u64, u64)>, // (name, offset, compressed_size)
}
/// TAR archive source for lazy loading (older torchvision models like AlexNet, SqueezeNet)
///
/// Older PyTorch/torchvision models (pre-1.6) use TAR format instead of ZIP.
/// The TAR archive contains:
/// - `sys_info`: System info pickle (endianness, type sizes)
/// - `pickle`: OrderedDict mapping tensor names to storage keys
/// - `tensors`: Tensor metadata pickles (unused, metadata is embedded in pickle)
/// - `storages`: Storage count + sequential (metadata pickle, element count, raw data)
pub struct TarSource {
/// Cached storage map: storage_key -> (offset_in_storages, size_bytes)
storage_map: HashMap<String, (usize, usize)>,
/// The raw storages data (kept in memory for TAR format)
storages_data: Vec<u8>,
}
/// Legacy multi-storage source for old PyTorch format (0.1.10 - 1.5)
///
/// Legacy format stores tensor data as concatenated raw binary without explicit
/// storage boundaries. This source tracks storage usage during tensor parsing
/// to build a storage map for lazy loading.
///
/// ## Storage Layout
/// - Pickle metadata with tensor definitions
/// - List of storage keys (determines concatenation order)
/// - Raw binary blob with all storages concatenated
pub struct LegacyMultiStorageSource {
path: PathBuf,
data_offset: u64,
#[allow(dead_code)]
data_size: u64,
// Map of storage_key -> (offset_in_blob, size)
storage_map: RwLock<Option<HashMap<String, (u64, u64)>>>,
// Storage keys in order (for boundary calculation)
storage_keys: RwLock<Option<Vec<String>>>,
// Track storage usage as tensors are accessed
storage_usage: RwLock<HashMap<String, usize>>, // key -> max_bytes_needed
}
impl ZipSource {
/// Create a new ZIP source
pub fn new(path: PathBuf) -> std::io::Result<Self> {
let file = File::open(&path)?;
let reader = BufReader::new(file);
let mut archive = ZipArchive::new(reader)?;
// Cache file metadata
let mut file_list = Vec::new();
for i in 0..archive.len() {
let file = archive.by_index(i)?;
let name = file.name().to_string();
let offset = file.data_start();
let compressed_size = file.compressed_size();
file_list.push((
name,
offset.expect("should have an offset"),
compressed_size,
));
}
Ok(Self { path, file_list })
}
/// Check if a file exists in the archive
pub fn contains(&self, name: &str) -> bool {
self.file_list.iter().any(|(n, _, _)| n == name)
}
/// Get list of data files (excluding pickle files)
pub fn data_files(&self) -> Vec<String> {
self.file_list
.iter()
.filter(|(name, _, _)| name.starts_with("data/") || name.contains("/data/"))
.filter(|(name, _, _)| !name.ends_with(".pkl") && !name.ends_with("/"))
.map(|(name, _, _)| name.clone())
.collect()
}
/// Read a specific file from the archive
pub fn read_file(&self, name: &str) -> std::io::Result<Vec<u8>> {
let file = File::open(&self.path)?;
let reader = BufReader::new(file);
let mut archive = ZipArchive::new(reader)?;
let mut file = archive.by_name(name)?;
let mut contents = Vec::with_capacity(file.size() as usize);
file.read_to_end(&mut contents)?;
Ok(contents)
}
/// Read a portion of a file
pub fn read_file_range(
&self,
name: &str,
offset: usize,
length: usize,
) -> std::io::Result<Vec<u8>> {
let file = File::open(&self.path)?;
let reader = BufReader::new(file);
let mut archive = ZipArchive::new(reader)?;
let mut file = archive.by_name(name)?;
let mut buffer = vec![0u8; length];
// Skip to offset
let mut skip_buffer = vec![0u8; offset.min(8192)];
let mut skipped = 0;
while skipped < offset {
let to_skip = (offset - skipped).min(skip_buffer.len());
file.read_exact(&mut skip_buffer[..to_skip])?;
skipped += to_skip;
}
// Read the requested data
file.read_exact(&mut buffer)?;
Ok(buffer)
}
}
impl LegacyMultiStorageSource {
/// Create a new legacy multi-storage source
pub fn new(path: PathBuf, data_offset: u64, data_size: u64) -> Self {
Self {
path,
data_offset,
data_size,
storage_map: RwLock::new(None),
storage_keys: RwLock::new(None),
storage_usage: RwLock::new(HashMap::new()),
}
}
/// Set the ordered storage keys from the pickle
pub fn set_storage_keys(&self, keys: Vec<String>) {
let mut storage_keys = self
.storage_keys
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner());
*storage_keys = Some(keys);
}
/// Track storage usage from tensor access
/// This is called from within tensor loading closures
pub fn track_storage_usage(&self, storage_key: &str, offset: usize, size: usize) {
let mut usage = self
.storage_usage
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner());
let max_extent = offset + size;
usage
.entry(storage_key.to_string())
.and_modify(|current| *current = (*current).max(max_extent))
.or_insert(max_extent);
// Try to build storage map if we have enough information
drop(usage);
self.try_build_storage_map();
}
/// Try to build the storage map from tracked usage
fn try_build_storage_map(&self) {
// Only build if we don't already have a map
if self
.storage_map
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.is_some()
{
return;
}
// Check if we have storage keys
let keys_guard = self
.storage_keys
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner());
if let Some(ref keys) = *keys_guard {
let usage = self
.storage_usage
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner());
// Only build if we have usage info for all storages
if keys.iter().all(|k| usage.contains_key(k)) {
let mut map = HashMap::new();
let mut current_offset = 0u64;
for key in keys {
if let Some(&size) = usage.get(key) {
map.insert(key.clone(), (current_offset, size as u64));
current_offset += size as u64;
}
}
// Set the storage map
drop(keys_guard);
drop(usage);
let mut storage_map = self
.storage_map
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner());
*storage_map = Some(map);
}
}
}
/// Read data for a specific storage key
/// Only loads the specific storage portion, never the entire blob
pub fn read(&self, key: &str) -> std::io::Result<Vec<u8>> {
// Extract numeric key from paths like "data/0" or just "0"
let storage_key = key.split('/').next_back().unwrap_or(key);
// Get storage map - must be available for lazy loading to work
let storage_map = self
.storage_map
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner());
if let Some(ref map) = *storage_map
&& let Some(&(offset, size)) = map.get(storage_key)
{
// Load only this specific storage
let mut file = File::open(&self.path)?;
file.seek(std::io::SeekFrom::Start(self.data_offset + offset))?;
let mut buffer = vec![0u8; size as usize];
file.read_exact(&mut buffer)?;
return Ok(buffer);
}
// NO FALLBACK! If we don't have storage boundaries, we cannot load data lazily
// The storage map MUST be built from tensor metadata for lazy loading to work
Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!(
"Storage boundaries not available for key '{}'. Cannot perform lazy loading.",
storage_key
),
))
}
}
impl TarSource {
/// Create a new TAR source by parsing storages data.
///
/// # Arguments
/// * `storages_data` - Raw storages blob with structure:
/// - Count pickle (number of storages)
/// - For each storage: metadata pickle + u64 num_elements + raw binary data
pub fn new(storages_data: Vec<u8>) -> std::io::Result<Self> {
use super::pickle_reader::{read_pickle, storage_type_to_element_size};
use std::io::Cursor;
let mut storage_map = HashMap::new();
let mut pos = 0usize;
// First, read the count of storages
let mut cursor = Cursor::new(&storages_data[pos..]);
let storage_count =
if let Ok(super::pickle_reader::Object::Int(count)) = read_pickle(&mut cursor) {
pos += cursor.position() as usize;
count as usize
} else {
0
};
// Parse each storage entry
for _i in 0..storage_count {
if pos >= storages_data.len() {
break;
}
// Read the storage metadata pickle: (storage_key, device, storage_type)
let mut cursor = Cursor::new(&storages_data[pos..]);
if let Ok(obj) = read_pickle(&mut cursor) {
let pickle_size = cursor.position() as usize;
pos += pickle_size;
// Extract storage info from pickle tuple
let (storage_key, storage_type) = match obj {
super::pickle_reader::Object::Tuple(tuple) if tuple.len() >= 3 => {
let key = match &tuple[0] {
super::pickle_reader::Object::Int(i) => i.to_string(),
super::pickle_reader::Object::String(s) => s.clone(),
_ => continue,
};
// tuple[1] is device (e.g., "cpu")
// tuple[2] is storage type class
let stype = match &tuple[2] {
super::pickle_reader::Object::Class { name, .. } => name.clone(),
other => {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Expected Class for storage type, got {:?}", other),
));
}
};
(key, stype)
}
_ => continue,
};
// Read the number of elements (u64 little-endian)
if pos + 8 > storages_data.len() {
break;
}
let num_elements = u64::from_le_bytes([
storages_data[pos],
storages_data[pos + 1],
storages_data[pos + 2],
storages_data[pos + 3],
storages_data[pos + 4],
storages_data[pos + 5],
storages_data[pos + 6],
storages_data[pos + 7],
]) as usize;
pos += 8;
// Determine element size from storage type
let element_size = storage_type_to_element_size(&storage_type)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
let data_size = num_elements * element_size;
// Store the offset to raw data and its size
storage_map.insert(storage_key, (pos, data_size));
// Skip the raw binary data
pos += data_size;
} else {
break;
}
}
Ok(Self {
storage_map,
storages_data,
})
}
/// Read data for a specific storage key
pub fn read_file(&self, key: &str) -> std::io::Result<Vec<u8>> {
// Extract the storage key from paths like "data/0"
let storage_key = key.split('/').next_back().unwrap_or(key);
if let Some(&(offset, size)) = self.storage_map.get(storage_key)
&& offset + size <= self.storages_data.len()
{
return Ok(self.storages_data[offset..offset + size].to_vec());
}
Err(std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("Storage key '{}' not found in TAR archive", storage_key),
))
}
/// Read a range of data for a specific storage key (avoids double allocation)
pub fn read_file_range(
&self,
key: &str,
offset: usize,
length: usize,
) -> std::io::Result<Vec<u8>> {
let storage_key = key.split('/').next_back().unwrap_or(key);
if let Some(&(storage_offset, storage_size)) = self.storage_map.get(storage_key)
&& storage_offset + storage_size <= self.storages_data.len()
{
let start = storage_offset + offset;
let end = (storage_offset + offset + length).min(storage_offset + storage_size);
return Ok(self.storages_data[start..end].to_vec());
}
Err(std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("Storage key '{}' not found in TAR archive", storage_key),
))
}
/// Check if a storage key exists
pub fn contains(&self, key: &str) -> bool {
let storage_key = key.split('/').next_back().unwrap_or(key);
self.storage_map.contains_key(storage_key)
}
/// Get list of storage keys
pub fn keys(&self) -> Vec<String> {
self.storage_map.keys().cloned().collect()
}
}
impl LazyDataSource {
/// Create from a ZIP file
pub fn from_zip(path: impl AsRef<Path>) -> std::io::Result<Self> {
Ok(Self::Zip(Arc::new(Mutex::new(ZipSource::new(
path.as_ref().to_path_buf(),
)?))))
}
/// Create from a TAR archive's storages data
pub fn from_tar(storages_data: &[u8]) -> std::io::Result<Self> {
Ok(Self::Tar(Arc::new(Mutex::new(TarSource::new(
storages_data.to_vec(),
)?))))
}
/// Create from a legacy multi-storage file
pub fn from_legacy_multi_storage(
path: impl AsRef<Path>,
data_offset: u64,
data_size: u64,
) -> Self {
Self::LegacyMultiStorage(Arc::new(Mutex::new(LegacyMultiStorageSource::new(
path.as_ref().to_path_buf(),
data_offset,
data_size,
))))
}
/// Read data for a specific key
pub fn read(&self, key: &str) -> std::io::Result<Vec<u8>> {
match self {
Self::Zip(source) => {
let source = source
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
source.read_file(key)
}
Self::Tar(source) => {
let source = source
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
source.read_file(key)
}
Self::LegacyMultiStorage(source) => {
let source = source
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
source.read(key)
}
}
}
/// Read a portion of data for a specific key
pub fn read_range(&self, key: &str, offset: usize, length: usize) -> std::io::Result<Vec<u8>> {
match self {
Self::Zip(source) => {
let source = source
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
source.read_file_range(key, offset, length)
}
Self::Tar(source) => {
let source = source
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
source.read_file_range(key, offset, length)
}
Self::LegacyMultiStorage(source) => {
// For legacy format, read only the requested range
let storage_key = key.split('/').next_back().unwrap_or(key);
let source = source
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
// Get storage boundaries
let storage_map = source
.storage_map
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner());
if let Some(ref map) = *storage_map
&& let Some(&(storage_offset, storage_size)) = map.get(storage_key)
{
// Calculate actual file position
let file_offset = source.data_offset + storage_offset + offset as u64;
let read_length = length.min((storage_size as usize).saturating_sub(offset));
// Read only the requested range
let mut file = File::open(&source.path)?;
file.seek(std::io::SeekFrom::Start(file_offset))?;
let mut buffer = vec![0u8; read_length];
file.read_exact(&mut buffer)?;
Ok(buffer)
} else {
Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!(
"Storage boundaries not available for key '{}'. Cannot perform lazy loading.",
storage_key
),
))
}
}
}
}
/// Check if a key exists
pub fn contains(&self, key: &str) -> bool {
match self {
Self::Zip(source) => {
let source = source
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
source.contains(key)
}
Self::Tar(source) => {
let source = source
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
source.contains(key)
}
Self::LegacyMultiStorage(_) => true, // Legacy format has all data
}
}
/// Get list of available keys (for ZIP sources)
pub fn keys(&self) -> Vec<String> {
match self {
Self::Zip(source) => {
let source = source
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
source.data_files()
}
Self::Tar(source) => {
let source = source
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
source.keys()
}
Self::LegacyMultiStorage(_) => vec![], // Legacy format doesn't have distinct keys
}
}
}

View File

@@ -0,0 +1,48 @@
//! PyTorch format support for burn-store.
//!
//! This module provides comprehensive support for loading PyTorch model files (.pth, .pt)
//! into Burn, with automatic weight transformation and flexible configuration options.
//!
//! ## Features
//!
//! - **Direct .pth/.pt file loading**: Load PyTorch checkpoint and state dict files
//! - **Automatic weight transformation**: `PyTorchToBurnAdapter` is applied by default:
//! - Linear layer weights are automatically transposed
//! - Normalization parameters are renamed (gamma → weight, beta → bias)
//! - Conv2d weights maintain their format
//! - **Flexible filtering**: Load only specific layers or parameters
//! - **Key remapping**: Rename tensors during loading to match your model structure
//! - **Partial loading**: Continue even when some tensors are missing
//!
//! ## Example
//!
//! ```rust,ignore
//! use burn_store::PytorchStore;
//!
//! // Load a PyTorch model (PyTorchToBurnAdapter is applied automatically)
//! let mut store = PytorchStore::from_file("model.pth")
//! .with_top_level_key("state_dict") // Access nested state dict
//! .with_regex(r"^encoder\..*") // Only load encoder layers
//! .with_key_remapping(r"^fc\.", "linear.") // Rename fc -> linear
//! .allow_partial(true); // Skip missing tensors
//!
//! let mut model = MyModel::new(&device);
//! let result = model.load_from(&mut store)?;
//!
//! println!("Loaded {} tensors", result.applied.len());
//! if !result.missing.is_empty() {
//! println!("Missing tensors: {:?}", result.missing);
//! }
//! ```
pub mod lazy_data;
pub mod pickle_reader;
pub mod reader;
pub mod store;
#[cfg(test)]
pub mod tests;
// Main public interface
pub use reader::{PytorchError, PytorchReader};
pub use store::{PytorchStore, PytorchStoreError};

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,442 @@
//! PyTorch store implementation for saving and loading models in PyTorch format.
use crate::{
ApplyResult, KeyRemapper, ModuleSnapshot, ModuleStore, PathFilter, PyTorchToBurnAdapter,
TensorSnapshot, map_indices_contiguous,
};
use alloc::collections::BTreeMap;
use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use burn_tensor::backend::Backend;
use core::fmt;
use std::path::PathBuf;
use super::reader::{PytorchError as ReaderError, PytorchReader};
/// Errors that can occur during PyTorch operations.
#[derive(Debug)]
pub enum PytorchStoreError {
/// Reader error.
Reader(ReaderError),
/// I/O error.
Io(std::io::Error),
/// Tensor not found.
TensorNotFound(String),
/// Validation failed.
ValidationFailed(String),
/// Other error.
Other(String),
}
impl fmt::Display for PytorchStoreError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Reader(e) => write!(f, "PyTorch reader error: {}", e),
Self::Io(e) => write!(f, "I/O error: {}", e),
Self::TensorNotFound(name) => write!(f, "Tensor not found: {}", name),
Self::ValidationFailed(msg) => write!(f, "Validation failed: {}", msg),
Self::Other(msg) => write!(f, "{}", msg),
}
}
}
impl std::error::Error for PytorchStoreError {}
impl From<ReaderError> for PytorchStoreError {
fn from(e: ReaderError) -> Self {
PytorchStoreError::Reader(e)
}
}
impl From<std::io::Error> for PytorchStoreError {
fn from(e: std::io::Error) -> Self {
PytorchStoreError::Io(e)
}
}
/// PyTorch store for file-based storage only.
///
/// This store allows loading models from PyTorch checkpoint files (.pt/.pth)
/// with automatic weight transformation using `PyTorchToBurnAdapter`.
/// Linear weights are automatically transposed and normalization parameters
/// are renamed (gamma -> weight, beta -> bias).
///
/// Note that saving to PyTorch format is not yet supported.
pub struct PytorchStore {
pub(crate) path: PathBuf,
pub(crate) filter: PathFilter,
pub(crate) remapper: KeyRemapper,
pub(crate) validate: bool,
pub(crate) allow_partial: bool,
pub(crate) top_level_key: Option<String>,
pub(crate) skip_enum_variants: bool,
/// Enable contiguous mapping of layer indices (default: true)
pub(crate) map_indices_contiguous: bool,
/// Cached tensor snapshots (parsed once, reused)
snapshots_cache: Option<BTreeMap<String, TensorSnapshot>>,
}
impl PytorchStore {
/// Create a store for loading from a PyTorch file.
///
/// # Arguments
/// * `path` - Path to the PyTorch checkpoint file (.pt or .pth)
///
/// # Example
/// ```rust,no_run
/// use burn_store::PytorchStore;
///
/// let store = PytorchStore::from_file("model.pth");
/// ```
pub fn from_file(path: impl Into<PathBuf>) -> Self {
Self {
path: path.into(),
filter: PathFilter::new(),
remapper: KeyRemapper::new(),
validate: true,
allow_partial: false,
top_level_key: None,
// PyTorch models never include enum variant names in paths
skip_enum_variants: true,
// Enable contiguous index mapping by default for PyTorch files
// This handles nn.Sequential models with gaps in layer indices
map_indices_contiguous: true,
snapshots_cache: None,
}
}
/// Set a top-level key to extract tensors from.
///
/// PyTorch files often contain nested dictionaries. Use this to extract
/// tensors from a specific top-level key like "state_dict" or "model_state_dict".
///
/// # Example
/// ```rust,no_run
/// # use burn_store::PytorchStore;
/// let store = PytorchStore::from_file("checkpoint.pth")
/// .with_top_level_key("model_state_dict");
/// ```
pub fn with_top_level_key(mut self, key: impl Into<String>) -> Self {
self.top_level_key = Some(key.into());
self
}
/// Filter which tensors to load.
pub fn filter(mut self, filter: PathFilter) -> Self {
self.filter = filter;
self
}
/// Add a regex pattern to filter tensors.
///
/// Multiple patterns can be added and they work with OR logic.
///
/// # Example
/// ```rust,no_run
/// # use burn_store::PytorchStore;
/// let store = PytorchStore::from_file("model.pth")
/// .with_regex(r"^encoder\..*") // Match all encoder tensors
/// .with_regex(r".*\.weight$"); // OR match any weight tensors
/// ```
pub fn with_regex<S: AsRef<str>>(mut self, pattern: S) -> Self {
self.filter = self.filter.with_regex(pattern);
self
}
/// Add multiple regex patterns to filter tensors.
pub fn with_regexes<I, S>(mut self, patterns: I) -> Self
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
self.filter = self.filter.with_regexes(patterns);
self
}
/// Add an exact full path to match.
///
/// # Example
/// ```rust,no_run
/// # use burn_store::PytorchStore;
/// let store = PytorchStore::from_file("model.pth")
/// .with_full_path("encoder.layer1.weight")
/// .with_full_path("decoder.output.bias");
/// ```
pub fn with_full_path<S: Into<String>>(mut self, path: S) -> Self {
self.filter = self.filter.with_full_path(path);
self
}
/// Add multiple exact full paths to match.
pub fn with_full_paths<I, S>(mut self, paths: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.filter = self.filter.with_full_paths(paths);
self
}
/// Add a predicate function for custom filtering logic.
///
/// The predicate receives the tensor path and container path.
///
/// # Example
/// ```rust,no_run
/// # use burn_store::PytorchStore;
/// let store = PytorchStore::from_file("model.pth")
/// .with_predicate(|path, _| path.starts_with("encoder.") || path.ends_with(".bias"));
/// ```
pub fn with_predicate(mut self, predicate: fn(&str, &str) -> bool) -> Self {
self.filter = self.filter.with_predicate(predicate);
self
}
/// Add multiple predicate functions.
pub fn with_predicates<I>(mut self, predicates: I) -> Self
where
I: IntoIterator<Item = fn(&str, &str) -> bool>,
{
self.filter = self.filter.with_predicates(predicates);
self
}
/// Set the filter to match all paths (disables filtering).
pub fn match_all(mut self) -> Self {
self.filter = self.filter.match_all();
self
}
/// Remap tensor names during load.
pub fn remap(mut self, remapper: KeyRemapper) -> Self {
self.remapper = remapper;
self
}
/// Add a regex pattern to remap tensor names during load.
///
/// # Example
/// ```rust,no_run
/// # use burn_store::PytorchStore;
/// let store = PytorchStore::from_file("model.pth")
/// .with_key_remapping(r"^encoder\.", "transformer.encoder.") // encoder.X -> transformer.encoder.X
/// .with_key_remapping(r"\.gamma$", ".weight"); // X.gamma -> X.weight
/// ```
pub fn with_key_remapping(
mut self,
from_pattern: impl AsRef<str>,
to_pattern: impl Into<String>,
) -> Self {
self.remapper = self
.remapper
.add_pattern(from_pattern, to_pattern)
.expect("Invalid regex pattern");
self
}
/// Set whether to validate tensors during loading (default: true).
pub fn validate(mut self, validate: bool) -> Self {
self.validate = validate;
self
}
/// Allow partial loading of tensors (continue even if some tensors are missing).
pub fn allow_partial(mut self, allow: bool) -> Self {
self.allow_partial = allow;
self
}
/// Skip enum variant names when matching tensor paths (default: true).
///
/// When enabled, tensor paths from PyTorch that don't include enum variants
/// can be matched against Burn module paths that do include them.
/// For example, PyTorch path "feature.weight" can match Burn path "feature.BaseConv.weight".
///
/// This defaults to `true` for PytorchStore since PyTorch models never include
/// enum variant names in their parameter paths.
///
/// # Example
/// ```rust,no_run
/// # use burn_store::PytorchStore;
/// // Disable enum variant skipping (not typical)
/// let store = PytorchStore::from_file("model.pth")
/// .skip_enum_variants(false);
/// ```
pub fn skip_enum_variants(mut self, skip: bool) -> Self {
self.skip_enum_variants = skip;
self
}
/// Enable or disable automatic contiguous mapping of layer indices (default: true).
///
/// When enabled, non-contiguous numeric indices in tensor paths are renumbered
/// to be contiguous. This is useful when loading PyTorch models that have gaps
/// in layer numbering, such as when using `nn.Sequential` with mixed layer types
/// (e.g., Conv2d layers at indices 0, 2, 4 with ReLU layers at 1, 3, 5).
///
/// # Example
///
/// With index mapping enabled (default):
/// - `fc.0.weight` → `fc.0.weight`
/// - `fc.2.weight` → `fc.1.weight` (gap filled)
/// - `fc.4.weight` → `fc.2.weight` (gap filled)
///
/// # Arguments
///
/// * `map` - `true` to enable contiguous index mapping, `false` to disable
///
/// # Example
/// ```rust,no_run
/// # use burn_store::PytorchStore;
/// // Disable contiguous index mapping if your model already has contiguous indices
/// let store = PytorchStore::from_file("model.pth")
/// .map_indices_contiguous(false);
/// ```
pub fn map_indices_contiguous(mut self, map: bool) -> Self {
self.map_indices_contiguous = map;
self
}
/// Apply remapping to tensor snapshots.
fn apply_remapping(&self, snapshots: Vec<TensorSnapshot>) -> Vec<TensorSnapshot> {
if self.remapper.is_empty() {
return snapshots;
}
let (remapped, _) = self.remapper.remap(snapshots);
remapped
}
/// Create a PytorchReader for the configured path and options.
fn create_reader(&self) -> Result<PytorchReader, PytorchStoreError> {
let reader = if let Some(ref key) = self.top_level_key {
PytorchReader::with_top_level_key(&self.path, key)?
} else {
PytorchReader::new(&self.path)?
};
Ok(reader)
}
}
impl ModuleStore for PytorchStore {
type Error = PytorchStoreError;
fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
&mut self,
_module: &M,
) -> Result<(), Self::Error> {
// Saving to PyTorch format is not yet supported
Err(PytorchStoreError::Other(
"Saving to PyTorch format is not yet supported. Use other formats for saving."
.to_string(),
))
}
fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
&mut self,
module: &mut M,
) -> Result<ApplyResult, Self::Error> {
// Get snapshots from cache
let snapshots: Vec<TensorSnapshot> = self.get_all_snapshots()?.values().cloned().collect();
// Get filter (convert to Option for apply)
let filter_opt = if self.filter.is_empty() {
None
} else {
Some(self.filter.clone())
};
// Apply to module with PyTorchToBurnAdapter (always used for PyTorch files)
// This adapter handles:
// - Transposing linear weights from PyTorch format to Burn format
// - Renaming normalization parameters (gamma -> weight, beta -> bias)
// Filter is applied here during apply, not during cache population
let result = module.apply(
snapshots,
filter_opt,
Some(Box::new(PyTorchToBurnAdapter)),
self.skip_enum_variants,
);
// Validate if needed
if self.validate && !result.errors.is_empty() {
return Err(PytorchStoreError::ValidationFailed(format!(
"Import errors:\n{}",
result
)));
}
if !self.allow_partial && !result.missing.is_empty() {
return Err(PytorchStoreError::TensorNotFound(format!("\n{}", result)));
}
Ok(result)
}
fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error> {
self.ensure_snapshots_cache()?;
Ok(self.snapshots_cache.as_ref().unwrap().get(name))
}
fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error> {
self.ensure_snapshots_cache()?;
Ok(self.snapshots_cache.as_ref().unwrap())
}
fn keys(&mut self) -> Result<Vec<String>, Self::Error> {
// Always use the cache to ensure remapping is applied consistently
Ok(self.get_all_snapshots()?.keys().cloned().collect())
}
}
impl PytorchStore {
/// Ensure the snapshots cache is populated
fn ensure_snapshots_cache(&mut self) -> Result<(), PytorchStoreError> {
if self.snapshots_cache.is_some() {
return Ok(());
}
let reader = self.create_reader()?;
// Convert to tensor snapshots
let mut snapshots: Vec<TensorSnapshot> = reader
.into_tensors()
.into_iter()
.map(|(key, mut snapshot)| {
// Parse the key into path parts (split by '.')
let path_parts: Vec<String> = key.split('.').map(|s| s.to_string()).collect();
// Set the path stack from the key
snapshot.path_stack = Some(path_parts);
snapshot.container_stack = None;
snapshot.tensor_id = None;
snapshot
})
.collect();
// Apply remapping (but NOT filtering - that's done at apply time)
snapshots = self.apply_remapping(snapshots);
// Apply contiguous index mapping if enabled
// This must be done after remapping so that remapped paths are mapped
if self.map_indices_contiguous {
let (mapped, _) = map_indices_contiguous(snapshots);
snapshots = mapped;
}
// Build cache as BTreeMap
let cache: BTreeMap<String, TensorSnapshot> =
snapshots.into_iter().map(|s| (s.full_path(), s)).collect();
self.snapshots_cache = Some(cache);
Ok(())
}
}

View File

@@ -0,0 +1,2 @@
pub mod reader;
pub mod store;

View File

@@ -0,0 +1,76 @@
#!/usr/bin/env python3
# /// script
# dependencies = ["torch"]
# ///
"""Create a legacy format PyTorch file with specific storage offsets to test offset handling."""
import torch
# Create tensors with known values at specific storage offsets
# This will help us verify we're reading from the correct location
# Create a state dict with tensors that share storage
# This is common in PyTorch models (e.g., weight and transposed weight views)
state_dict = {}
# Create a base tensor with known pattern
base_data = torch.arange(100, dtype=torch.float32)
# tensor1: uses elements 10-19 (offset 10*4 = 40 bytes)
tensor1 = base_data[10:20].clone()
tensor1[:] = torch.arange(1.0, 1.1, 0.01)[:10] # 1.00, 1.01, 1.02, ...
# tensor2: uses elements 30-35 (offset 30*4 = 120 bytes)
tensor2 = base_data[30:35].clone()
tensor2[:] = torch.arange(2.0, 2.5, 0.1)[:5] # 2.0, 2.1, 2.2, 2.3, 2.4
# tensor3: starts at beginning (offset 0)
tensor3 = base_data[:5].clone()
tensor3[:] = torch.arange(3.0, 3.5, 0.1)[:5] # 3.0, 3.1, 3.2, 3.3, 3.4
state_dict['tensor1'] = tensor1
state_dict['tensor2'] = tensor2
state_dict['tensor3'] = tensor3
# Save in legacy format
output_file = 'test_data/legacy_with_offsets.pt'
torch.save(state_dict, output_file, _use_new_zipfile_serialization=False)
print(f"Created {output_file}")
# Verify by loading
loaded = torch.load(output_file, weights_only=False)
print("\nVerification - expected values:")
for key, tensor in loaded.items():
print(f" {key}: {tensor.tolist()}")
print(f" Storage offset: {tensor.storage_offset()}")
print(f" Storage size: {len(tensor.storage())}")
# Also create a test with multiple tensors sharing the same storage
# This is important for proper offset handling
shared_storage = torch.randn(1000)
# Create views into the same storage at different offsets
view1 = shared_storage[100:110] # offset 100
view2 = shared_storage[500:520] # offset 500
view3 = shared_storage[0:10] # offset 0
# Need to save these properly - PyTorch will handle the storage sharing
shared_dict = {
'view1': view1.clone(), # Clone to avoid view issues
'view2': view2.clone(),
'view3': view3.clone(),
}
output_file2 = 'test_data/legacy_shared_storage.pt'
torch.save(shared_dict, output_file2, _use_new_zipfile_serialization=False)
print(f"\nCreated {output_file2}")
# Print exact values for test verification
print("\nExact test values for legacy_with_offsets.pt:")
print("tensor1 (10 elements starting at 1.0):")
print(" First 3 values: [1.00, 1.01, 1.02]")
print("tensor2 (5 elements starting at 2.0):")
print(" All values: [2.0, 2.1, 2.2, 2.3, 2.4]")
print("tensor3 (5 elements starting at 3.0):")
print(" All values: [3.0, 3.1, 3.2, 3.3, 3.4]")

View File

@@ -0,0 +1,361 @@
#!/usr/bin/env python3
"""
Create TAR format test fixtures for burn-store integration tests.
The TAR format was used by very early versions of PyTorch (pre 0.1.10).
Modern torch.save cannot create this format, so we construct it manually.
TAR format structure:
- sys_info: pickle with {protocol_version, little_endian, type_sizes}
- pickle: pickle with OrderedDict containing _rebuild_tensor_v2 REDUCE calls
- storages: count_pickle + for each storage: (key, device, class) pickle + u64 num_elements + raw data
"""
import io
import pickle
import struct
import tarfile
import os
from collections import OrderedDict
def create_sys_info():
"""Create sys_info pickle data."""
sys_info = {
"protocol_version": 1000,
"little_endian": True,
"type_sizes": {
"short": 2,
"int": 4,
"long": 8,
},
}
return pickle.dumps(sys_info, protocol=2)
def encode_tensor_data(values: list, storage_type: str) -> tuple:
"""Encode tensor values to bytes and return (bytes, element_size)."""
fmt_map = {
"FloatStorage": ("<f", 4),
"DoubleStorage": ("<d", 8),
"LongStorage": ("<q", 8),
"IntStorage": ("<i", 4),
"ShortStorage": ("<h", 2),
"ByteStorage": ("<B", 1),
"CharStorage": ("<b", 1),
"BoolStorage": ("<B", 1),
"HalfStorage": ("<e", 2),
}
fmt, size = fmt_map[storage_type]
data = b"".join(struct.pack(fmt, v) for v in values)
return data, size
def write_int(buffer, value):
"""Write an integer using appropriate pickle opcode."""
if 0 <= value < 256:
buffer.write(b'K') # BININT1
buffer.write(bytes([value]))
elif 0 <= value < 65536:
buffer.write(b'M') # BININT2
buffer.write(struct.pack('<H', value))
else:
buffer.write(b'J') # BININT
buffer.write(struct.pack('<i', value))
def write_string(buffer, s):
"""Write a string using appropriate pickle opcode."""
s_bytes = s.encode('utf-8')
if len(s_bytes) < 256:
buffer.write(b'U') # SHORT_BINSTRING
buffer.write(bytes([len(s_bytes)]))
buffer.write(s_bytes)
else:
buffer.write(b'T') # BINSTRING
buffer.write(struct.pack('<I', len(s_bytes)))
buffer.write(s_bytes)
def create_storages_blob_manual(tensors: list) -> bytes:
"""
Create the storages binary blob manually.
Args:
tensors: List of (key, storage_type, element_size, data_bytes) tuples
"""
buffer = io.BytesIO()
# Write storage count as pickle (simple integer)
pickle.dump(len(tensors), buffer, protocol=2)
for key, storage_type, element_size, data_bytes in tensors:
# Manually construct the tuple pickle with GLOBAL class reference
# Format: (key, "cpu", <class 'torch.FloatStorage'>)
tuple_buffer = io.BytesIO()
# Protocol 2 header
tuple_buffer.write(b'\x80\x02')
# Build tuple with MARK + items + TUPLE
tuple_buffer.write(b'(') # MARK
# First item: storage key (string)
write_string(tuple_buffer, key)
# Second item: device "cpu"
tuple_buffer.write(b'U\x03cpu')
# Third item: class reference using GLOBAL
tuple_buffer.write(b'c') # GLOBAL opcode
tuple_buffer.write(b'torch\n') # module
tuple_buffer.write(storage_type.encode('ascii') + b'\n') # name
# End tuple
tuple_buffer.write(b't') # TUPLE
tuple_buffer.write(b'.') # STOP
buffer.write(tuple_buffer.getvalue())
# Write num_elements as u64 little-endian
num_elements = len(data_bytes) // element_size
buffer.write(struct.pack("<Q", num_elements))
# Write raw data
buffer.write(data_bytes)
return buffer.getvalue()
def create_main_pickle_manual(tensors_info: list) -> bytes:
"""
Create the main pickle containing _rebuild_tensor_v2 REDUCE calls.
For each tensor, we need:
- GLOBAL torch._utils _rebuild_tensor_v2
- MARK
- args tuple: (persistent_id, offset, shape, stride, requires_grad, hooks)
- TUPLE
- REDUCE
The persistent_id is a PersistentTuple: ('storage', <class>, key, device, num_elements)
"""
buffer = io.BytesIO()
# Protocol 2 header
buffer.write(b'\x80\x02')
# Build OrderedDict: GLOBAL + EMPTY_LIST + items + TUPLE + REDUCE
# OrderedDict([('name1', tensor1), ('name2', tensor2)])
# GLOBAL collections OrderedDict
buffer.write(b'ccollections\nOrderedDict\n')
# Start list for items
buffer.write(b'(') # MARK
buffer.write(b']') # EMPTY_LIST
# For each tensor, add (name, rebuilt_tensor) to the list
for name, storage_key, storage_type, shape, num_elements in tensors_info:
# Calculate stride for row-major (C) order
stride = []
s = 1
for dim in reversed(shape):
stride.insert(0, s)
s *= dim
# Build inner tuple: (name, tensor_value)
buffer.write(b'(') # MARK for (name, value) tuple
# Write name
write_string(buffer, name)
# Now build the tensor using _rebuild_tensor_v2 REDUCE
# GLOBAL torch._utils _rebuild_tensor_v2
buffer.write(b'ctorch._utils\n_rebuild_tensor_v2\n')
# Build args tuple for _rebuild_tensor_v2
# (persistent_id, offset, shape, stride, requires_grad, backward_hooks)
buffer.write(b'(') # MARK for args tuple
# arg 0: persistent_id tuple: ('storage', class, key, device, num_elements)
# This will be converted to PersistentTuple by the reader
buffer.write(b'(') # MARK for persistent_id
write_string(buffer, 'storage')
# Class reference - GLOBAL torch FloatStorage
buffer.write(b'c')
buffer.write(b'torch\n')
buffer.write(storage_type.encode('ascii') + b'\n')
# Storage key
write_string(buffer, storage_key)
# Device
buffer.write(b'U\x03cpu')
# num_elements
write_int(buffer, num_elements)
buffer.write(b't') # TUPLE - end persistent_id
# arg 1: storage offset (0)
buffer.write(b'K\x00')
# arg 2: shape tuple
buffer.write(b'(')
for dim in shape:
write_int(buffer, dim)
buffer.write(b't')
# arg 3: stride tuple
buffer.write(b'(')
for s_val in stride:
write_int(buffer, s_val)
buffer.write(b't')
# arg 4: requires_grad (False)
buffer.write(b'\x89') # NEWFALSE
# arg 5: backward_hooks (empty OrderedDict)
buffer.write(b'ccollections\nOrderedDict\n')
buffer.write(b'(')
buffer.write(b']')
buffer.write(b't')
buffer.write(b'R') # REDUCE to create empty OrderedDict
buffer.write(b't') # TUPLE - end args tuple
buffer.write(b'R') # REDUCE - call _rebuild_tensor_v2 with args
buffer.write(b't') # TUPLE - end (name, tensor) tuple
buffer.write(b'a') # APPEND to list
buffer.write(b't') # TUPLE - wrap list in tuple for REDUCE
buffer.write(b'R') # REDUCE - call OrderedDict with the list
buffer.write(b'.') # STOP
return buffer.getvalue()
def create_tar_pytorch_file(filename: str, tensors: dict, dtypes: dict):
"""
Create a TAR format PyTorch file.
Args:
filename: Output file path
tensors: Dict of tensor_name -> (values_list, shape)
dtypes: Dict of tensor_name -> storage_type
"""
# Prepare storage data
storage_list = [] # (key, storage_type, element_size, data_bytes)
tensors_info = [] # (name, storage_key, storage_type, shape, num_elements)
for idx, (name, (values, shape)) in enumerate(tensors.items()):
storage_key = str(idx)
storage_type = dtypes[name]
data_bytes, element_size = encode_tensor_data(values, storage_type)
num_elements = len(values)
storage_list.append((storage_key, storage_type, element_size, data_bytes))
tensors_info.append((name, storage_key, storage_type, shape, num_elements))
# Create the three main entries
sys_info_data = create_sys_info()
pickle_data = create_main_pickle_manual(tensors_info)
storages_data = create_storages_blob_manual(storage_list)
# Write TAR archive
os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)
with tarfile.open(filename, "w") as tar:
# Add sys_info
tarinfo = tarfile.TarInfo(name="sys_info")
tarinfo.size = len(sys_info_data)
tar.addfile(tarinfo, io.BytesIO(sys_info_data))
# Add pickle
tarinfo = tarfile.TarInfo(name="pickle")
tarinfo.size = len(pickle_data)
tar.addfile(tarinfo, io.BytesIO(pickle_data))
# Add storages
tarinfo = tarfile.TarInfo(name="storages")
tarinfo.size = len(storages_data)
tar.addfile(tarinfo, io.BytesIO(storages_data))
size = os.path.getsize(filename)
print(f"Created {filename} ({size} bytes)")
print(f" Tensors: {list(tensors.keys())}")
def main():
# Create test_data directory
os.makedirs("test_data", exist_ok=True)
# Test 1: Single float32 tensor
create_tar_pytorch_file(
"test_data/tar_float32.tar",
{"tensor": ([1.0, 2.5, -3.7, 0.0], [4])},
{"tensor": "FloatStorage"},
)
# Test 2: Single float64 tensor
create_tar_pytorch_file(
"test_data/tar_float64.tar",
{"tensor": ([1.1, 2.2, 3.3], [3])},
{"tensor": "DoubleStorage"},
)
# Test 3: Single int64 tensor
create_tar_pytorch_file(
"test_data/tar_int64.tar",
{"tensor": ([100, -200, 300, 0], [4])},
{"tensor": "LongStorage"},
)
# Test 4: Multiple tensors (weight + bias)
create_tar_pytorch_file(
"test_data/tar_weight_bias.tar",
{
"weight": ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [2, 3]),
"bias": ([0.01, 0.02], [2]),
},
{
"weight": "FloatStorage",
"bias": "FloatStorage",
},
)
# Test 5: Different dtypes in one file
create_tar_pytorch_file(
"test_data/tar_multi_dtype.tar",
{
"float_tensor": ([1.5, 2.5, 3.5], [3]),
"double_tensor": ([1.111, 2.222], [2]),
"int_tensor": ([10, 20, 30, 40], [4]),
},
{
"float_tensor": "FloatStorage",
"double_tensor": "DoubleStorage",
"int_tensor": "LongStorage",
},
)
# Test 6: 2D tensor for shape verification
create_tar_pytorch_file(
"test_data/tar_2d_tensor.tar",
{
"matrix": ([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], [3, 4]),
},
{"matrix": "FloatStorage"},
)
print("\nAll TAR format test files created!")
print("\nTo run tests: cargo test -p burn-store --features pytorch test_tar")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,25 @@
#!/usr/bin/env python3
# /// script
# dependencies = ["torch"]
# ///
"""Create a simple legacy format PyTorch file."""
import torch
# Create a simple state dict
state_dict = {
'weight': torch.randn(2, 3),
'bias': torch.ones(2),
'running_mean': torch.zeros(2),
}
# Save without using zip format (legacy format)
torch.save(state_dict, 'test_data/simple_legacy.pt', _use_new_zipfile_serialization=False)
print("Created simple_legacy.pt")
# Verify
loaded = torch.load('test_data/simple_legacy.pt', weights_only=False)
print(f"Loaded {len(loaded)} tensors")
for key, val in loaded.items():
print(f" {key}: shape {val.shape}, dtype {val.dtype}")

View File

@@ -0,0 +1,227 @@
#!/usr/bin/env python3
# /// script
# dependencies = ["torch", "numpy"]
# ///
"""
Generate test PyTorch .pt files for testing the burn-store PyTorch reader.
Run with: uv run test_files.py
"""
import torch
import numpy as np
import os
from pathlib import Path
# Create test directory
test_dir = Path(__file__).parent / "test_data"
test_dir.mkdir(exist_ok=True)
def save_test_file(filename, data, description):
"""Save a test file and print what was saved."""
filepath = test_dir / filename
torch.save(data, filepath)
print(f"{filename}: {description}")
return filepath
# Test 1: Simple tensors of different types
print("\n=== Generating Basic Tensor Tests ===")
# Float32 tensor (wrap in dict for compatibility)
float32_tensor = torch.tensor([1.0, 2.5, -3.7, 0.0], dtype=torch.float32)
save_test_file("float32.pt", {"tensor": float32_tensor}, "Float32 tensor [1.0, 2.5, -3.7, 0.0]")
# Float64 tensor
float64_tensor = torch.tensor([1.1, 2.2, 3.3], dtype=torch.float64)
save_test_file("float64.pt", {"tensor": float64_tensor}, "Float64 tensor [1.1, 2.2, 3.3]")
# Int64 tensor
int64_tensor = torch.tensor([100, -200, 300, 0], dtype=torch.int64)
save_test_file("int64.pt", {"tensor": int64_tensor}, "Int64 tensor [100, -200, 300, 0]")
# Int32 tensor
int32_tensor = torch.tensor([10, 20, -30], dtype=torch.int32)
save_test_file("int32.pt", {"tensor": int32_tensor}, "Int32 tensor [10, 20, -30]")
# Int16 tensor
int16_tensor = torch.tensor([1000, -2000, 3000], dtype=torch.int16)
save_test_file("int16.pt", {"tensor": int16_tensor}, "Int16 tensor [1000, -2000, 3000]")
# Int8 tensor
int8_tensor = torch.tensor([127, -128, 0, 50], dtype=torch.int8)
save_test_file("int8.pt", {"tensor": int8_tensor}, "Int8 tensor [127, -128, 0, 50]")
# Boolean tensor
bool_tensor = torch.tensor([True, False, True, True, False], dtype=torch.bool)
save_test_file("bool.pt", {"tensor": bool_tensor}, "Bool tensor [True, False, True, True, False]")
# Float16 tensor (half precision)
float16_tensor = torch.tensor([1.5, -2.25, 3.125], dtype=torch.float16)
save_test_file("float16.pt", {"tensor": float16_tensor}, "Float16 tensor [1.5, -2.25, 3.125]")
# BFloat16 tensor
bfloat16_tensor = torch.tensor([1.5, -2.5, 3.5], dtype=torch.bfloat16)
save_test_file("bfloat16.pt", {"tensor": bfloat16_tensor}, "BFloat16 tensor [1.5, -2.5, 3.5]")
# UInt8 tensor
uint8_tensor = torch.tensor([0, 128, 255, 42], dtype=torch.uint8)
save_test_file("uint8.pt", {"tensor": uint8_tensor}, "UInt8 tensor [0, 128, 255, 42]")
# Test 2: Multi-dimensional tensors
print("\n=== Generating Multi-dimensional Tensor Tests ===")
# 2D tensor
tensor_2d = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float32)
save_test_file("tensor_2d.pt", {"tensor": tensor_2d}, "2D tensor shape (3, 2)")
# 3D tensor
torch.manual_seed(42)
tensor_3d = torch.randn(2, 3, 4) * 10
save_test_file("tensor_3d.pt", {"tensor": tensor_3d}, "3D tensor shape (2, 3, 4)")
# 4D tensor (common for conv weights)
tensor_4d = torch.randn(2, 3, 2, 2)
save_test_file("tensor_4d.pt", {"tensor": tensor_4d}, "4D tensor shape (2, 3, 2, 2)")
# Test 3: State dict (multiple tensors)
print("\n=== Generating State Dict Tests ===")
state_dict = {
"weight": torch.randn(3, 4),
"bias": torch.randn(3),
"running_mean": torch.zeros(3),
"running_var": torch.ones(3),
}
save_test_file("state_dict.pt", state_dict, "State dict with 4 tensors")
# Nested state dict
nested_dict = {
"layer1": {
"weight": torch.randn(2, 3),
"bias": torch.randn(2)
},
"layer2": {
"weight": torch.randn(4, 2),
"bias": torch.randn(4)
}
}
save_test_file("nested_dict.pt", nested_dict, "Nested state dict")
# Test 4: Model checkpoint format
print("\n=== Generating Model Checkpoint Tests ===")
# Typical checkpoint format (use string keys for compatibility)
checkpoint = {
"model_state_dict": {
"fc1.weight": torch.randn(10, 5),
"fc1.bias": torch.randn(10),
"fc2.weight": torch.randn(3, 10),
"fc2.bias": torch.randn(3),
},
"optimizer_state_dict": {
"state": {
"0": { # Use string key instead of integer
"momentum_buffer": torch.randn(10, 5)
}
}
},
"epoch": 42,
"loss": 0.123
}
save_test_file("checkpoint.pt", checkpoint, "Full checkpoint with model and optimizer state")
# Test 5: Edge cases
print("\n=== Generating Edge Case Tests ===")
# Empty tensor (1D with 0 elements)
empty_tensor = torch.zeros(0)
save_test_file("empty.pt", {"tensor": empty_tensor}, "Empty tensor")
# Scalar tensor (0-dimensional)
scalar_tensor = torch.tensor(42.0)
save_test_file("scalar.pt", {"tensor": scalar_tensor}, "Scalar tensor (0-dim)")
# Large shape but small data (testing shape vs actual data)
sparse_like = torch.zeros(100, 100)
sparse_like[0, 0] = 1.0
sparse_like[50, 50] = 2.0
sparse_like[99, 99] = 3.0
save_test_file("large_shape.pt", {"tensor": sparse_like}, "Large shape (100, 100) mostly zeros")
# Test 6: Mixed types in dict
print("\n=== Generating Mixed Type Tests ===")
mixed_types = {
"float32": torch.tensor([1.0, 2.0], dtype=torch.float32),
"int64": torch.tensor([100, 200], dtype=torch.int64),
"bool": torch.tensor([True, False], dtype=torch.bool),
"float64": torch.tensor([1.1, 2.2], dtype=torch.float64),
}
save_test_file("mixed_types.pt", mixed_types, "Dict with mixed tensor types")
# Test 7: Special values
print("\n=== Generating Special Value Tests ===")
# NaN and Inf values
special_values = torch.tensor([float('nan'), float('inf'), float('-inf'), 0.0, 1.0])
save_test_file("special_values.pt", {"tensor": special_values}, "Tensor with NaN and Inf")
# Very small and very large values
extreme_values = torch.tensor([1e-30, 1e30, -1e-30, -1e30], dtype=torch.float32)
save_test_file("extreme_values.pt", {"tensor": extreme_values}, "Tensor with extreme values")
# Test 8: Parameter wrapper (common in models)
print("\n=== Generating Parameter Tests ===")
import torch.nn as nn
param = nn.Parameter(torch.randn(3, 3))
param_dict = {"param": param}
save_test_file("parameter.pt", param_dict, "nn.Parameter wrapped tensor")
# Test 9: Buffer-style tensors
print("\n=== Generating Buffer Tests ===")
# Simulate model buffers
buffers = {
"buffer1": torch.tensor([1, 2, 3], dtype=torch.int32),
"buffer2": torch.tensor([True, False], dtype=torch.bool),
}
save_test_file("buffers.pt", buffers, "Model buffers")
# Test 10: Complex nested structure
print("\n=== Generating Complex Structure Tests ===")
complex_structure = {
"metadata": {
"version": 1,
"name": "test_model"
},
"state": {
"encoder": {
"layer_0": {
"weight": torch.randn(4, 3),
"bias": torch.randn(4)
},
"layer_1": {
"weight": torch.randn(2, 4),
"bias": torch.randn(2)
}
},
"decoder": {
"weight": torch.randn(3, 2),
"bias": torch.randn(3)
}
},
"config": {
"hidden_size": 4,
"num_layers": 2
}
}
save_test_file("complex_structure.pt", complex_structure, "Complex nested structure")
print(f"\n✅ Generated {len(list(test_dir.glob('*.pt')))} test files in {test_dir}")
print("\nTest files can be used to verify PyTorch reader functionality:")
print("- Different data types (float32, int64, bool, etc.)")
print("- Multi-dimensional tensors")
print("- State dicts and nested structures")
print("- Edge cases (empty, scalar, special values)")
print("- Model checkpoints and parameters")

Some files were not shown because too many files have changed in this diff Show More