feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
This commit is contained in:
106
crates/stable-diffusion-burn/burn-crates/burn-store/Cargo.toml
Normal file
106
crates/stable-diffusion-burn/burn-crates/burn-store/Cargo.toml
Normal file
@@ -0,0 +1,106 @@
|
||||
[package]
|
||||
authors = ["Dilshod Tadjibaev (@antimora)"]
|
||||
categories = ["science", "no-std", "embedded", "wasm"]
|
||||
description = "Storage and serialization infrastructure for Burn"
|
||||
documentation = "https://docs.rs/burn-store"
|
||||
edition.workspace = true
|
||||
keywords = [
|
||||
"deep-learning",
|
||||
"machine-learning",
|
||||
"tensor",
|
||||
"storage",
|
||||
"serialization",
|
||||
]
|
||||
license.workspace = true
|
||||
name = "burn-store"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-store"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = ["std", "pytorch", "safetensors", "burnpack", "memmap"]
|
||||
memmap = ["std", "dep:memmap2"]
|
||||
std = [
|
||||
"dep:memmap2",
|
||||
"safetensors/std",
|
||||
"burn-core/std",
|
||||
"burn-tensor/std",
|
||||
"dep:regex",
|
||||
"byteorder/std",
|
||||
]
|
||||
tracing = [
|
||||
"burn-core/tracing",
|
||||
"burn-cuda?/tracing",
|
||||
"burn-nn/tracing",
|
||||
"burn-tch?/tracing",
|
||||
"burn-tensor/tracing",
|
||||
"burn-wgpu?/tracing",
|
||||
]
|
||||
|
||||
|
||||
burnpack = ["serde", "ciborium"]
|
||||
cuda = ["burn-cuda"]
|
||||
metal = ["wgpu", "burn-wgpu/metal"]
|
||||
tch = ["burn-tch"]
|
||||
wgpu = ["burn-wgpu"]
|
||||
|
||||
safetensors = ["dep:safetensors"]
|
||||
|
||||
pytorch = ["burn-core/record-item-custom-serde", "zip", "serde", "tar"]
|
||||
|
||||
[dependencies]
|
||||
burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", default-features = false }
|
||||
burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", default-features = false }
|
||||
|
||||
# External dependencies
|
||||
byteorder = { workspace = true, default-features = false }
|
||||
bytes = { workspace = true }
|
||||
ciborium = { workspace = true, optional = true }
|
||||
half = { workspace = true }
|
||||
hashbrown = { workspace = true, features = ["serde"] }
|
||||
memmap2 = { workspace = true, optional = true }
|
||||
regex = { workspace = true, optional = true }
|
||||
serde = { workspace = true, optional = true }
|
||||
textdistance = { workspace = true }
|
||||
zip = { workspace = true, optional = true }
|
||||
tar = { workspace = true, optional = true }
|
||||
|
||||
# Workaround to force broken minor version to update
|
||||
lzma-rust2 = { workspace = true, optional = true }
|
||||
|
||||
safetensors = { workspace = true, optional = true }
|
||||
|
||||
# Optional backend dependencies for benchmarks
|
||||
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true }
|
||||
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true }
|
||||
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
# burn-import = { path = "../burn-import", version = "=0.21.0-pre.2" } # disabled (circular dep in publish, only for bench)
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
|
||||
burn-nn = { path = "../burn-nn", version = "=0.21.0-pre.2", default-features = false }
|
||||
divan = "0.1"
|
||||
tempfile = { workspace = true }
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "resnet18_loading"
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "unified_loading"
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "unified_saving"
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "zero_copy_loading"
|
||||
|
||||
# Enable extra-platforms for bytes on targets without native atomics (e.g., thumbv6m-none-eabi)
|
||||
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
|
||||
bytes = { workspace = true, features = ["extra-platforms"] }
|
||||
325
crates/stable-diffusion-burn/burn-crates/burn-store/MIGRATION.md
Normal file
325
crates/stable-diffusion-burn/burn-crates/burn-store/MIGRATION.md
Normal file
@@ -0,0 +1,325 @@
|
||||
# Migration Guide: burn-import to burn-store
|
||||
|
||||
This guide helps you migrate from the deprecated `burn-import` recorders (`PyTorchFileRecorder`,
|
||||
`SafetensorsFileRecorder`) to the new `burn-store` API (`PytorchStore`, `SafetensorsStore`).
|
||||
|
||||
## Overview
|
||||
|
||||
The new `burn-store` API provides:
|
||||
|
||||
- **Simpler API**: Load directly into models instead of records
|
||||
- **Fluent builder pattern**: Chain configuration methods
|
||||
- **Better error handling**: Detailed load results with applied/missing/errors info
|
||||
- **Bidirectional support**: Both load and save operations
|
||||
- **More features**: Filtering, partial loading, metadata, zero-copy loading
|
||||
|
||||
## Quick Migration
|
||||
|
||||
### PyTorch Files (.pt/.pth)
|
||||
|
||||
**Before (burn-import):**
|
||||
|
||||
```rust
|
||||
use burn::record::{FullPrecisionSettings, Recorder};
|
||||
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
|
||||
|
||||
// Load into a record, then create model from record
|
||||
let record: ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load("model.pt".into(), &device)
|
||||
.expect("Failed to load");
|
||||
|
||||
let model = Model::init(&device).load_record(record);
|
||||
```
|
||||
|
||||
**After (burn-store):**
|
||||
|
||||
```rust
|
||||
use burn_store::{ModuleSnapshot, PytorchStore};
|
||||
|
||||
// Initialize model, then load weights directly
|
||||
let mut model = Model::init(&device);
|
||||
let mut store = PytorchStore::from_file("model.pt");
|
||||
model.load_from(&mut store).expect("Failed to load");
|
||||
```
|
||||
|
||||
### SafeTensors Files (.safetensors)
|
||||
|
||||
**Before (burn-import):**
|
||||
|
||||
```rust
|
||||
use burn::record::{FullPrecisionSettings, Recorder};
|
||||
use burn_import::safetensors::{AdapterType, LoadArgs, SafetensorsFileRecorder};
|
||||
|
||||
let record: ModelRecord<B> = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load("model.safetensors".into(), &device)
|
||||
.expect("Failed to load");
|
||||
|
||||
let model = Model::init(&device).load_record(record);
|
||||
```
|
||||
|
||||
**After (burn-store):**
|
||||
|
||||
```rust
|
||||
use burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};
|
||||
|
||||
let mut model = Model::init(&device);
|
||||
|
||||
// For SafeTensors exported from PyTorch, use the adapter
|
||||
let mut store = SafetensorsStore::from_file("model.safetensors")
|
||||
.with_from_adapter(PyTorchToBurnAdapter);
|
||||
model.load_from(&mut store).expect("Failed to load");
|
||||
|
||||
// For native Burn SafeTensors, no adapter needed
|
||||
let mut store = SafetensorsStore::from_file("model.safetensors");
|
||||
model.load_from(&mut store).expect("Failed to load");
|
||||
```
|
||||
|
||||
## API Mapping
|
||||
|
||||
### PyTorchFileRecorder Options
|
||||
|
||||
| burn-import | burn-store |
|
||||
| ---------------------------------------------- | ------------------------------------------- |
|
||||
| `LoadArgs::new(path)` | `PytorchStore::from_file(path)` |
|
||||
| `.with_key_remap(pattern, replacement)` | `.with_key_remapping(pattern, replacement)` |
|
||||
| `.with_top_level_key(key)` | `.with_top_level_key(key)` |
|
||||
| `.with_debug_print()` | _(use tracing/logging instead)_ |
|
||||
| `PyTorchFileRecorder::<FullPrecisionSettings>` | _(precision handled automatically)_ |
|
||||
|
||||
### SafetensorsFileRecorder Options
|
||||
|
||||
| burn-import | burn-store |
|
||||
| -------------------------------------------------- | ------------------------------------------- |
|
||||
| `LoadArgs::new(path)` | `SafetensorsStore::from_file(path)` |
|
||||
| `.with_key_remap(pattern, replacement)` | `.with_key_remapping(pattern, replacement)` |
|
||||
| `.with_adapter_type(AdapterType::PyTorch)` | `.with_from_adapter(PyTorchToBurnAdapter)` |
|
||||
| `.with_adapter_type(AdapterType::NoAdapter)` | _(default, no adapter)_ |
|
||||
| `.with_debug_print()` | _(use tracing/logging instead)_ |
|
||||
| `SafetensorsFileRecorder::<FullPrecisionSettings>` | _(precision handled automatically)_ |
|
||||
|
||||
## Detailed Examples
|
||||
|
||||
### Key Remapping
|
||||
|
||||
**Before:**
|
||||
|
||||
```rust
|
||||
let args = LoadArgs::new("model.pt".into())
|
||||
.with_key_remap("conv\\.(.*)", "$1")
|
||||
.with_key_remap("^old_prefix\\.", "new_prefix.");
|
||||
|
||||
let record: ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load(args, &device)?;
|
||||
```
|
||||
|
||||
**After:**
|
||||
|
||||
```rust
|
||||
let mut store = PytorchStore::from_file("model.pt")
|
||||
.with_key_remapping("conv\\.(.*)", "$1")
|
||||
.with_key_remapping("^old_prefix\\.", "new_prefix.");
|
||||
|
||||
model.load_from(&mut store)?;
|
||||
```
|
||||
|
||||
### Top-Level Key Access
|
||||
|
||||
**Before:**
|
||||
|
||||
```rust
|
||||
let args = LoadArgs::new("checkpoint.pt".into())
|
||||
.with_top_level_key("state_dict");
|
||||
|
||||
let record: ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load(args, &device)?;
|
||||
```
|
||||
|
||||
**After:**
|
||||
|
||||
```rust
|
||||
let mut store = PytorchStore::from_file("checkpoint.pt")
|
||||
.with_top_level_key("state_dict");
|
||||
|
||||
model.load_from(&mut store)?;
|
||||
```
|
||||
|
||||
### PyTorch Adapter for SafeTensors
|
||||
|
||||
**Before:**
|
||||
|
||||
```rust
|
||||
use burn_import::safetensors::{AdapterType, LoadArgs};
|
||||
|
||||
let args = LoadArgs::new("pytorch_model.safetensors".into())
|
||||
.with_adapter_type(AdapterType::PyTorch);
|
||||
|
||||
let record: ModelRecord<B> = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load(args, &device)?;
|
||||
```
|
||||
|
||||
**After:**
|
||||
|
||||
```rust
|
||||
use burn_store::{PyTorchToBurnAdapter, SafetensorsStore};
|
||||
|
||||
let mut store = SafetensorsStore::from_file("pytorch_model.safetensors")
|
||||
.with_from_adapter(PyTorchToBurnAdapter);
|
||||
|
||||
model.load_from(&mut store)?;
|
||||
```
|
||||
|
||||
## New Features in burn-store
|
||||
|
||||
### Partial Loading
|
||||
|
||||
Handle missing tensors gracefully:
|
||||
|
||||
```rust
|
||||
let mut store = PytorchStore::from_file("model.pt")
|
||||
.allow_partial(true);
|
||||
|
||||
let result = model.load_from(&mut store)?;
|
||||
println!("Loaded: {:?}", result.applied);
|
||||
println!("Missing: {:?}", result.missing);
|
||||
```
|
||||
|
||||
### Filtering
|
||||
|
||||
Load only specific tensors:
|
||||
|
||||
```rust
|
||||
let mut store = SafetensorsStore::from_file("model.safetensors")
|
||||
.with_regex(r"^encoder\..*") // Only encoder layers
|
||||
.allow_partial(true);
|
||||
|
||||
model.load_from(&mut store)?;
|
||||
```
|
||||
|
||||
### Saving Models
|
||||
|
||||
Save models (not supported by old recorders):
|
||||
|
||||
```rust
|
||||
// Save to SafeTensors
|
||||
let mut store = SafetensorsStore::from_file("output.safetensors")
|
||||
.metadata("version", "1.0");
|
||||
model.save_into(&mut store)?;
|
||||
|
||||
// Save to Burnpack (native format)
|
||||
let mut store = BurnpackStore::from_file("output.bpk");
|
||||
model.save_into(&mut store)?;
|
||||
```
|
||||
|
||||
### Load Results
|
||||
|
||||
Get detailed information about loading:
|
||||
|
||||
```rust
|
||||
let result = model.load_from(&mut store)?;
|
||||
|
||||
// Print the full result for debugging - shows applied, skipped, missing, and errors
|
||||
println!("{}", result);
|
||||
|
||||
// Or access individual fields
|
||||
println!("Applied: {} tensors", result.applied.len());
|
||||
println!("Skipped: {} tensors", result.skipped.len());
|
||||
println!("Missing: {:?}", result.missing);
|
||||
println!("Errors: {:?}", result.errors);
|
||||
|
||||
// Check if fully successful
|
||||
if result.is_success() {
|
||||
println!("All tensors loaded successfully");
|
||||
}
|
||||
```
|
||||
|
||||
The `LoadResult` implements `Display`, so printing it shows a formatted summary with suggestions for
|
||||
common issues (e.g., using `allow_partial(true)` for missing tensors).
|
||||
|
||||
## Updating Cargo.toml
|
||||
|
||||
**Before:**
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
burn-import = { version = "0.x", features = ["pytorch", "safetensors"] }
|
||||
```
|
||||
|
||||
**After:**
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
burn-store = { version = "0.x", features = ["pytorch", "safetensors"] }
|
||||
```
|
||||
|
||||
## Common Migration Issues
|
||||
|
||||
### 1. Model vs Record
|
||||
|
||||
The new API loads directly into models, not records. Update your model initialization:
|
||||
|
||||
```rust
|
||||
// Before: Create record, then model from record
|
||||
let record = recorder.load(...)?;
|
||||
let model = Model::init(&device).load_record(record);
|
||||
|
||||
// After: Create model, then load into it
|
||||
let mut model = Model::init(&device);
|
||||
model.load_from(&mut store)?;
|
||||
```
|
||||
|
||||
### 2. Inference Functions
|
||||
|
||||
If you had functions that took `ModelRecord`, update them to take `Model`:
|
||||
|
||||
```rust
|
||||
// Before
|
||||
fn infer(record: ModelRecord<B>) {
|
||||
let model = Model::init(&device).load_record(record);
|
||||
// ...
|
||||
}
|
||||
|
||||
// After
|
||||
fn infer(model: Model<B>) {
|
||||
// Model already has weights loaded
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Precision Settings
|
||||
|
||||
The old API required explicit precision settings. The new API handles this automatically:
|
||||
|
||||
```rust
|
||||
// Before: Had to specify FullPrecisionSettings or HalfPrecisionSettings
|
||||
PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||
|
||||
// After: Precision handled automatically based on tensor dtype
|
||||
PytorchStore::from_file("model.pt")
|
||||
```
|
||||
|
||||
### 4. Error Handling
|
||||
|
||||
The new API provides richer error information:
|
||||
|
||||
```rust
|
||||
// Before: Simple Result
|
||||
let record = recorder.load(args, &device)?;
|
||||
|
||||
// After: LoadResult with detailed info
|
||||
let result = model.load_from(&mut store)?;
|
||||
|
||||
// Print the result to see a helpful summary with suggestions
|
||||
println!("{}", result);
|
||||
|
||||
// Or handle specific issues programmatically
|
||||
if !result.errors.is_empty() {
|
||||
for (path, error) in &result.errors {
|
||||
eprintln!("Error loading {}: {}", path, error);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## See Also
|
||||
|
||||
- [burn-store README](README.md) - Full documentation
|
||||
- [import-model-weights example](../../examples/import-model-weights/) - Working example
|
||||
@@ -0,0 +1,77 @@
|
||||
# Burn Store
|
||||
|
||||
> Advanced model storage and serialization for the Burn deep learning framework
|
||||
|
||||
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-store.svg)](https://crates.io/crates/burn-store)
|
||||
[![Documentation](https://docs.rs/burn-store/badge.svg)](https://docs.rs/burn-store)
|
||||
|
||||
A comprehensive storage library for Burn that enables efficient model serialization, cross-framework
|
||||
interoperability, and advanced tensor management.
|
||||
|
||||
> **Migrating from burn-import?** See the [Migration Guide](MIGRATION.md) for help moving from
|
||||
> `PyTorchFileRecorder`/`SafetensorsFileRecorder` to the new Store API.
|
||||
|
||||
## Features
|
||||
|
||||
- **Burnpack Format** - Native Burn format with CBOR metadata, memory-mapped loading, ParamId
|
||||
persistence for stateful training, and no-std support
|
||||
- **SafeTensors Format** - Industry-standard format for secure and efficient tensor serialization
|
||||
- **PyTorch Support** - Direct loading of PyTorch .pth/.pt files with automatic weight
|
||||
transformation
|
||||
- **Zero-Copy Loading** - Memory-mapped files and lazy tensor materialization for optimal
|
||||
performance
|
||||
- **Flexible Filtering** - Load/save specific model subsets with regex, exact paths, or custom
|
||||
predicates
|
||||
- **Tensor Remapping** - Rename tensors during load/save for framework compatibility
|
||||
- **No-std Support** - Burnpack and SafeTensors formats available in embedded and WASM environments
|
||||
|
||||
## Quick Start
|
||||
|
||||
```rust
|
||||
use burn_store::{BurnpackStore, ModuleSnapshot, PyTorchToBurnAdapter, PytorchStore, SafetensorsStore};
|
||||
|
||||
// Load from PyTorch
|
||||
let mut store = PytorchStore::from_file("model.pt");
|
||||
model.load_from(&mut store)?;
|
||||
|
||||
// Load from SafeTensors (with PyTorch adapter)
|
||||
let mut store = SafetensorsStore::from_file("model.safetensors")
|
||||
.with_from_adapter(PyTorchToBurnAdapter);
|
||||
model.load_from(&mut store)?;
|
||||
|
||||
// Save to Burnpack
|
||||
let mut store = BurnpackStore::from_file("model.bpk");
|
||||
model.save_into(&mut store)?;
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
For comprehensive documentation including:
|
||||
|
||||
- Exporting weights from PyTorch
|
||||
- Loading weights into Burn models
|
||||
- Saving models to various formats
|
||||
- Advanced features (filtering, remapping, partial loading, zero-copy)
|
||||
- API reference and troubleshooting
|
||||
|
||||
See the **[Burn Book - Model Weights](https://burn.dev/book/import/model-weights.html)** chapter.
|
||||
|
||||
## Running Benchmarks
|
||||
|
||||
```bash
|
||||
# Generate model files (one-time setup)
|
||||
uv run benches/generate_unified_models.py
|
||||
|
||||
# Run loading benchmarks
|
||||
cargo bench --bench unified_loading
|
||||
|
||||
# Run saving benchmarks
|
||||
cargo bench --bench unified_saving
|
||||
|
||||
# With specific backend
|
||||
cargo bench --bench unified_loading --features metal
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This project is dual-licensed under MIT and Apache-2.0.
|
||||
@@ -0,0 +1,82 @@
|
||||
#!/usr/bin/env python3
|
||||
# /// script
|
||||
# requires-python = ">=3.8"
|
||||
# dependencies = [
|
||||
# "torch",
|
||||
# "torchvision",
|
||||
# ]
|
||||
# ///
|
||||
"""
|
||||
Download ResNet18 PyTorch model for benchmarking.
|
||||
This script downloads a pre-trained ResNet18 model from PyTorch Hub
|
||||
and saves it in a format suitable for benchmarking.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
|
||||
def download_resnet18():
    """Download the pre-trained ResNet18 weights and cache them in the temp directory.

    The state dict is written to ``<tempdir>/burn_resnet18_benchmark/resnet18.pth``
    with PyTorch's legacy (non-zip) serialization. If that file already exists it
    is reused and nothing is downloaded.

    Returns:
        str: Path of the ``resnet18.pth`` file.

    Exits the process with status 1 if downloading, saving, or the read-back
    verification fails.
    """
    # Create a temporary directory for the model
    temp_dir = Path(tempfile.gettempdir()) / "burn_resnet18_benchmark"
    temp_dir.mkdir(parents=True, exist_ok=True)

    output_path = temp_dir / "resnet18.pth"

    # Check if already downloaded
    if output_path.exists():
        file_size_mb = output_path.stat().st_size / (1024 * 1024)
        print(f"✅ ResNet18 already exists at: {output_path}")
        print(f" Size: {file_size_mb:.1f} MB")
        return str(output_path)

    print("📥 Downloading ResNet18 model...")

    try:
        # Prefer the `weights=` API when this torchvision provides it, because
        # `pretrained=True` is deprecated there. IMAGENET1K_V1 is the weight set
        # the legacy flag selected. Older releases fall back to the flag.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            model = models.resnet18(pretrained=True)

        # Save the model state dict (this is what burn-store reads)
        # Using the legacy format for compatibility
        torch.save(model.state_dict(), output_path, _use_new_zipfile_serialization=False)

        file_size_mb = output_path.stat().st_size / (1024 * 1024)
        print(f"✅ Successfully downloaded ResNet18 to: {output_path}")
        print(f" Size: {file_size_mb:.1f} MB")
        print(f" Format: PyTorch legacy format")

        # Verify it's readable
        state_dict = torch.load(output_path, map_location='cpu')
        print(f" Tensors: {len(state_dict)} tensors")

        # Print a few tensor names and shapes for verification
        print("\n Sample tensors:")
        for name, tensor in list(state_dict.items())[:3]:
            print(f" - {name}: {list(tensor.shape)}")

        return str(output_path)

    except Exception as e:
        # Remove a partially written or unreadable file so the next run does not
        # mistake it for a finished download.
        output_path.unlink(missing_ok=True)
        print(f"❌ Failed to download ResNet18: {e}")
        sys.exit(1)
|
||||
|
||||
def main():
    """Fetch the ResNet18 weights and record their location for the benchmark.

    The model path is written to ``<tempdir>/burn_resnet18_benchmark/path.txt``.
    """
    model_path = download_resnet18()

    # Write the path to a file that the benchmark can read
    config_dir = Path(tempfile.gettempdir()) / "burn_resnet18_benchmark"
    (config_dir / "path.txt").write_text(model_path)

    print(f"\n💡 Model ready for benchmarking")
    print(f" Run: cargo bench --bench resnet18_loading")


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,175 @@
|
||||
#!/usr/bin/env python3
|
||||
# /// script
|
||||
# requires-python = ">=3.8"
|
||||
# dependencies = [
|
||||
# "torch",
|
||||
# "safetensors",
|
||||
# "packaging",
|
||||
# "numpy",
|
||||
# ]
|
||||
# ///
|
||||
"""
|
||||
Generate a large model (~312MB) in both PyTorch and SafeTensors formats for unified benchmarking.
|
||||
|
||||
Usage:
|
||||
uv run benches/generate_unified_models.py
|
||||
|
||||
The script will create the model files in the `simple_bench_models` directory inside the system temp directory (e.g. /tmp/simple_bench_models/).
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from safetensors.torch import save_file
|
||||
|
||||
def get_temp_dir():
    """Return the benchmark model directory, creating it when it is missing.

    The directory is ``simple_bench_models`` inside the system temp directory.
    """
    models_dir = Path(tempfile.gettempdir(), "simple_bench_models")
    models_dir.mkdir(parents=True, exist_ok=True)
    return models_dir
|
||||
|
||||
class LargeModel(nn.Module):
    """Stack of ``nn.Linear`` layers applied one after another (no activations).

    With the default arguments this is the 20-layer model intended to match
    the Rust benchmark's LargeModel: the first layer maps 1024 -> 2048 features
    and every following layer maps 2048 -> 2048.

    Args:
        num_layers: Number of linear layers to create.
        input_size: Input feature count of the first layer.
        hidden_size: Output feature count of every layer, and the input
            feature count of every layer after the first.
    """

    def __init__(self, num_layers=20, input_size=1024, hidden_size=2048):
        super().__init__()
        self.layers = nn.ModuleList()

        for i in range(num_layers):
            # Only the first layer sees the raw input width.
            in_size = input_size if i == 0 else hidden_size
            self.layers.append(nn.Linear(in_size, hidden_size))

        print(f"Created model with {len(self.layers)} layers")

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x
|
||||
|
||||
def calculate_model_size(model):
    """Count the model's parameters and the memory their values occupy.

    Args:
        model: Module whose ``parameters()`` are measured.

    Returns:
        tuple: ``(total_params, size_mb)`` where ``total_params`` is the number
        of parameter elements and ``size_mb`` is their combined size in MiB.
    """
    total_params = 0
    total_bytes = 0
    for p in model.parameters():
        count = p.numel()
        total_params += count
        # Use the real per-element width (4 for float32) so that models with
        # other parameter dtypes are sized correctly.
        total_bytes += count * p.element_size()

    size_mb = total_bytes / (1024 * 1024)
    return total_params, size_mb
|
||||
|
||||
def initialize_weights(model):
    """Re-initialize every parameter of ``model`` in place.

    Parameters with more than one dimension get Xavier-uniform values;
    all other parameters are set to zero.
    """
    for tensor in model.parameters():
        init_fn = nn.init.xavier_uniform_ if tensor.dim() > 1 else nn.init.zeros_
        init_fn(tensor)
|
||||
|
||||
def save_pytorch_format(model, output_dir):
    """Write the model as a PyTorch checkpoint and return the file path.

    The checkpoint is a dict holding the state dict under ``model_state_dict``
    and a small ``metadata`` dict; it is saved as ``large_model.pt`` inside
    ``output_dir``.
    """
    # Save as checkpoint with model_state_dict (common format)
    checkpoint = {'model_state_dict': model.state_dict()}
    checkpoint['metadata'] = {
        'model_type': 'large_benchmark_model',
        'num_layers': len(model.layers),
    }

    destination = output_dir / "large_model.pt"
    torch.save(checkpoint, destination)
    return destination
|
||||
|
||||
def save_safetensors_format(model, output_dir):
    """Write the model's state dict as SafeTensors and return the file path.

    The file is saved as ``large_model.safetensors`` inside ``output_dir``
    together with string metadata describing the model.
    """
    # Ensure all tensors are contiguous and on CPU
    tensors = {}
    for name, tensor in model.state_dict().items():
        tensors[name] = tensor.contiguous().cpu()

    destination = output_dir / "large_model.safetensors"
    save_file(
        tensors,
        destination,
        metadata={
            'model_type': 'large_benchmark_model',
            'num_layers': str(len(model.layers)),
        },
    )
    return destination
|
||||
|
||||
def verify_files(pt_path, st_path):
    """Reload both saved files and report whether their tensor names agree.

    Prints the tensor count of each file, followed by either a confirmation or
    a warning; a mismatch does not raise.
    """
    # Tensor names stored in the PyTorch checkpoint
    pt_state = torch.load(pt_path, map_location='cpu')['model_state_dict']
    pt_keys = set(pt_state.keys())
    print(f" PyTorch file: {len(pt_keys)} tensors")

    # Tensor names stored in the SafeTensors file
    from safetensors import safe_open
    with safe_open(st_path, framework="pt", device="cpu") as handle:
        st_keys = set(handle.keys())
    print(f" SafeTensors file: {len(st_keys)} tensors")

    # Check keys match
    if pt_keys == st_keys:
        print(" ✓ Keys match between formats")
    else:
        print(" ⚠️ Warning: Keys don't match between formats!")
|
||||
|
||||
def main():
    """Build the benchmark model, save it in both formats, and verify the files.

    Progress and a final summary are printed to stdout.
    """
    def report_saved(path):
        """Print where ``path`` was saved and its size; return the size in MB."""
        file_mb = path.stat().st_size / (1024 * 1024)
        print(f" Saved: {path}")
        print(f" File size: {file_mb:.2f} MB")
        print("")
        return file_mb

    print("🔧 Generating unified benchmark model files...")
    print("")

    output_dir = get_temp_dir()
    print(f"📁 Output directory: {output_dir}")
    print("")

    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Create the large model
    print("📝 Creating large model...")
    model = LargeModel()

    # Calculate and display model size
    total_params, size_mb = calculate_model_size(model)
    print(f" Total parameters: {total_params:,}")
    print(f" Model size: {size_mb:.2f} MB")
    print("")

    # Initialize weights
    print("🎲 Initializing weights...")
    initialize_weights(model)

    # Save in PyTorch format
    print("💾 Saving PyTorch format...")
    pt_path = save_pytorch_format(model, output_dir)
    pt_size_mb = report_saved(pt_path)

    # Save in SafeTensors format
    print("💾 Saving SafeTensors format...")
    st_path = save_safetensors_format(model, output_dir)
    st_size_mb = report_saved(st_path)

    # Verify files
    print("🔍 Verifying saved files...")
    verify_files(pt_path, st_path)
    print("")

    print(f"✅ Model files generated successfully!")
    print("")
    print("📊 Summary:")
    print(f" PyTorch file: {pt_path.name} ({pt_size_mb:.2f} MB)")
    print(f" SafeTensors file: {st_path.name} ({st_size_mb:.2f} MB)")
    print("")
    print("💡 To run the unified benchmark:")
    print(" cargo bench --bench unified_loading")


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,213 @@
|
||||
//! Benchmark for ResNet18 loading to verify lazy loading memory usage.
|
||||
//!
|
||||
//! `resnet18.pth` is stored in PyTorch's legacy file format.
|
||||
//!
|
||||
//! This benchmark loads a ResNet18 model and materializes all tensors
|
||||
//! to ensure memory usage stays reasonable with lazy loading.
|
||||
//!
|
||||
//! Run the benchmark:
|
||||
//! ```bash
|
||||
//! cargo bench --bench resnet18_loading
|
||||
//! ```
|
||||
|
||||
use burn_store::pytorch::PytorchReader;
|
||||
use divan::{AllocProfiler, Bencher};
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[global_allocator]
|
||||
static ALLOC: AllocProfiler = AllocProfiler::system();
|
||||
|
||||
#[allow(clippy::manual_range_contains)]
|
||||
fn main() {
|
||||
// Check if ResNet18 file exists
|
||||
let path = resnet18_path();
|
||||
if !path.exists() {
|
||||
eprintln!("❌ ResNet18 model not found!");
|
||||
eprintln!();
|
||||
eprintln!("Please download it first by running:");
|
||||
eprintln!(" python benches/download_resnet18.py");
|
||||
eprintln!();
|
||||
eprintln!("Or if you don't have Python/PyTorch installed:");
|
||||
eprintln!(" uv run benches/download_resnet18.py");
|
||||
eprintln!();
|
||||
eprintln!("Expected location: {}", path.display());
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Verify file size is reasonable
|
||||
let metadata = std::fs::metadata(&path).expect("Failed to read file metadata");
|
||||
let size_mb = metadata.len() as f64 / 1_048_576.0;
|
||||
|
||||
if size_mb < 40.0 || size_mb > 50.0 {
|
||||
eprintln!(
|
||||
"⚠️ Warning: ResNet18 file size ({:.1} MB) seems unusual",
|
||||
size_mb
|
||||
);
|
||||
eprintln!("Expected size is around 45 MB");
|
||||
}
|
||||
|
||||
println!("✅ Found ResNet18 model at: {}", path.display());
|
||||
println!("📦 File size: {:.1} MB", size_mb);
|
||||
println!("📊 Running ResNet18 loading benchmarks...\n");
|
||||
|
||||
// Run divan benchmarks
|
||||
divan::main();
|
||||
}
|
||||
|
||||
/// Locate the ResNet18 model file used by the benchmarks.
///
/// If `<tempdir>/burn_resnet18_benchmark/path.txt` can be read and the
/// (trimmed) path it contains exists, that path is returned. Otherwise the
/// default `<tempdir>/burn_resnet18_benchmark/resnet18.pth` is returned,
/// whether or not it exists.
fn resnet18_path() -> PathBuf {
    let bench_dir = std::env::temp_dir().join("burn_resnet18_benchmark");

    // Reading fails when the pointer file is absent, so a failed read simply
    // falls through to the default location.
    if let Ok(recorded) = std::fs::read_to_string(bench_dir.join("path.txt")) {
        let candidate = PathBuf::from(recorded.trim());
        if candidate.exists() {
            return candidate;
        }
    }

    // Fallback to default location
    bench_dir.join("resnet18.pth")
}
|
||||
|
||||
#[divan::bench(sample_count = 10)]
|
||||
fn load_resnet18_metadata(bencher: Bencher) {
|
||||
let path = resnet18_path();
|
||||
|
||||
bencher.bench_local(|| {
|
||||
let reader = PytorchReader::new(&path).expect("Failed to load ResNet18");
|
||||
let metadata = reader.metadata();
|
||||
|
||||
// Just access metadata without materializing tensors
|
||||
assert_eq!(metadata.tensor_count, 122);
|
||||
});
|
||||
}
|
||||
|
||||
#[divan::bench(sample_count = 5)]
|
||||
fn load_resnet18_materialize_all(bencher: Bencher) {
|
||||
let path = resnet18_path();
|
||||
|
||||
bencher.bench_local(|| {
|
||||
let reader = PytorchReader::new(&path).expect("Failed to load ResNet18");
|
||||
let keys = reader.keys();
|
||||
|
||||
let mut total_bytes = 0usize;
|
||||
|
||||
// Materialize all tensors one by one
|
||||
for key in &keys {
|
||||
let tensor = reader.get(key).expect("Failed to get tensor");
|
||||
// Materialize the tensor data
|
||||
let _data = tensor.to_data().expect("Failed to materialize tensor data");
|
||||
total_bytes += tensor.data_len();
|
||||
}
|
||||
|
||||
// Verify we processed all the data
|
||||
assert!(total_bytes > 40_000_000); // Should be ~45MB
|
||||
});
|
||||
}
|
||||
|
||||
#[divan::bench(sample_count = 5)]
|
||||
fn load_resnet18_materialize_sequential(bencher: Bencher) {
|
||||
let path = resnet18_path();
|
||||
|
||||
bencher.bench_local(|| {
|
||||
let reader = PytorchReader::new(&path).expect("Failed to load ResNet18");
|
||||
let keys = reader.keys();
|
||||
|
||||
// Materialize tensors one at a time, letting previous ones be dropped
|
||||
// This simulates processing tensors sequentially without keeping all in memory
|
||||
for key in &keys {
|
||||
let tensor = reader.get(key).expect("Failed to get tensor");
|
||||
let data = tensor.to_data().expect("Failed to materialize tensor data");
|
||||
|
||||
// Do minimal work with the data to prevent optimization
|
||||
let sum = match data.dtype {
|
||||
burn_tensor::DType::F32 => data
|
||||
.as_slice::<f32>()
|
||||
.map(|s| s.iter().sum::<f32>())
|
||||
.unwrap_or(0.0) as f64,
|
||||
burn_tensor::DType::F64 => data
|
||||
.as_slice::<f64>()
|
||||
.map(|s| s.iter().sum::<f64>())
|
||||
.unwrap_or(0.0),
|
||||
_ => 0.0,
|
||||
};
|
||||
|
||||
// Use the sum to prevent dead code elimination
|
||||
std::hint::black_box(sum);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[divan::bench(sample_count = 10)]
|
||||
fn load_resnet18_largest_tensor(bencher: Bencher) {
|
||||
let path = resnet18_path();
|
||||
|
||||
bencher.bench_local(|| {
|
||||
let reader = PytorchReader::new(&path).expect("Failed to load ResNet18");
|
||||
|
||||
// Find and materialize only the largest tensor
|
||||
// This tests peak memory for a single tensor operation
|
||||
let keys = reader.keys();
|
||||
let mut largest_key = String::new();
|
||||
let mut largest_size = 0usize;
|
||||
|
||||
for key in &keys {
|
||||
let tensor = reader.get(key).expect("Failed to get tensor");
|
||||
let size = tensor.data_len();
|
||||
if size > largest_size {
|
||||
largest_size = size;
|
||||
largest_key = key.clone();
|
||||
}
|
||||
}
|
||||
|
||||
// Materialize the largest tensor
|
||||
let tensor = reader
|
||||
.get(&largest_key)
|
||||
.expect("Failed to get largest tensor");
|
||||
let _data = tensor.to_data().expect("Failed to materialize tensor data");
|
||||
|
||||
assert!(largest_size > 9_000_000); // Should be ~9MB for layer4.0.conv2.weight
|
||||
});
|
||||
}
|
||||
|
||||
#[divan::bench(sample_count = 10)]
|
||||
fn load_resnet18_memory_profile(bencher: Bencher) {
|
||||
let path = resnet18_path();
|
||||
|
||||
bencher
|
||||
.with_inputs(|| path.clone())
|
||||
.bench_local_values(|path| {
|
||||
let reader = PytorchReader::new(&path).expect("Failed to load ResNet18");
|
||||
let keys = reader.keys();
|
||||
|
||||
let mut peak_single_tensor = 0usize;
|
||||
let mut total_data = 0usize;
|
||||
|
||||
// Process each tensor and track memory
|
||||
for key in &keys {
|
||||
let tensor = reader.get(key).expect("Failed to get tensor");
|
||||
let tensor_size = tensor.data_len();
|
||||
|
||||
// Track largest single tensor
|
||||
if tensor_size > peak_single_tensor {
|
||||
peak_single_tensor = tensor_size;
|
||||
}
|
||||
|
||||
// Materialize the tensor
|
||||
let data = tensor.to_data().expect("Failed to materialize tensor data");
|
||||
total_data += tensor_size;
|
||||
|
||||
// Drop data immediately to test lazy loading memory efficiency
|
||||
drop(data);
|
||||
}
|
||||
|
||||
// Return stats for verification
|
||||
(peak_single_tensor, total_data)
|
||||
});
|
||||
}
|
||||
@@ -0,0 +1,332 @@
|
||||
#![recursion_limit = "256"]
|
||||
|
||||
//! Unified benchmark comparing all loading methods:
|
||||
//! - BurnpackStore (new native format)
|
||||
//! - NamedMpkFileRecorder (old native format)
|
||||
//! - SafetensorsStore (new)
|
||||
//! - SafetensorsFileRecorder (old; its benchmark is currently commented out below)
|
||||
//! - PytorchStore (new)
|
||||
//! - PyTorchFileRecorder (old; its benchmark is currently commented out below)
|
||||
//!
|
||||
//! Before running this benchmark, generate the model files:
|
||||
//! ```bash
|
||||
//! cd crates/burn-store
|
||||
//! uv run benches/generate_unified_models.py
|
||||
//! ```
|
||||
//!
|
||||
//! Then run the benchmark:
|
||||
//! ```bash
|
||||
//! cargo bench --bench unified_loading
|
||||
//! ```
|
||||
|
||||
use burn_core as burn;
|
||||
|
||||
use burn_core::module::Module;
|
||||
use burn_core::prelude::*;
|
||||
use burn_core::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
|
||||
// use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
|
||||
// use burn_import::safetensors::SafetensorsFileRecorder;
|
||||
use burn_nn as nn;
|
||||
use burn_store::{
|
||||
BurnpackStore, ModuleSnapshot, PyTorchToBurnAdapter, PytorchStore, SafetensorsStore,
|
||||
};
|
||||
use divan::{AllocProfiler, Bencher};
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
#[global_allocator]
|
||||
static ALLOC: AllocProfiler = AllocProfiler::system();
|
||||
|
||||
// Backend type aliases
|
||||
type NdArrayBackend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
type WgpuBackend = burn_wgpu::Wgpu;
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
type CudaBackend = burn_cuda::Cuda<f32, i32>;
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
type TchBackend = burn_tch::LibTorch<f32>;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
type MetalBackend = burn_wgpu::Metal;
|
||||
|
||||
// Use the same LargeModel as other benchmarks for fair comparison
|
||||
#[derive(Module, Debug)]
|
||||
struct LargeModel<B: Backend> {
|
||||
layers: Vec<nn::Linear<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> LargeModel<B> {
|
||||
fn new(device: &B::Device) -> Self {
|
||||
let mut layers = Vec::new();
|
||||
// Create a model with 20 layers - same as safetensor_loading benchmark
|
||||
for i in 0..20 {
|
||||
let in_size = if i == 0 { 1024 } else { 2048 };
|
||||
layers.push(nn::LinearConfig::new(in_size, 2048).init(device));
|
||||
}
|
||||
Self { layers }
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the directory holding the model files: `simple_bench_models` inside the
/// system temporary directory.
fn get_model_dir() -> PathBuf {
    let mut dir = std::env::temp_dir();
    dir.push("simple_bench_models");
    dir
}
|
||||
|
||||
/// Generate Burnpack and NamedMpk files from existing SafeTensors file
|
||||
fn generate_burn_formats(st_path: &Path, bp_path: &Path, mpk_path: &Path) {
|
||||
type TestBackend = NdArrayBackend;
|
||||
let device = Default::default();
|
||||
|
||||
// Load the model from SafeTensors
|
||||
let mut model = LargeModel::<TestBackend>::new(&device);
|
||||
let mut store = SafetensorsStore::from_file(st_path).with_from_adapter(PyTorchToBurnAdapter);
|
||||
model
|
||||
.load_from(&mut store)
|
||||
.expect("Failed to load from SafeTensors");
|
||||
|
||||
// Save as Burnpack
|
||||
if !bp_path.exists() {
|
||||
println!(" Creating Burnpack file...");
|
||||
let mut burnpack_store = BurnpackStore::from_file(bp_path);
|
||||
model
|
||||
.save_into(&mut burnpack_store)
|
||||
.expect("Failed to save as Burnpack");
|
||||
}
|
||||
|
||||
// Save as NamedMpk
|
||||
if !mpk_path.exists() {
|
||||
println!(" Creating NamedMpk file...");
|
||||
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
|
||||
model
|
||||
.save_file(mpk_path, &recorder)
|
||||
.expect("Failed to save as NamedMpk");
|
||||
}
|
||||
}
|
||||
|
||||
/// Get paths to the model files
|
||||
fn get_model_paths() -> (PathBuf, PathBuf, PathBuf, PathBuf) {
|
||||
let dir = get_model_dir();
|
||||
(
|
||||
dir.join("large_model.bpk"),
|
||||
dir.join("large_model.mpk"),
|
||||
dir.join("large_model.safetensors"),
|
||||
dir.join("large_model.pt"),
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if model files exist
|
||||
fn check_model_files() -> Result<(), String> {
|
||||
let (_, _, st_path, pt_path) = get_model_paths();
|
||||
|
||||
// For now, only check safetensors and pytorch files (will generate burnpack/mpk later)
|
||||
if !st_path.exists() || !pt_path.exists() {
|
||||
return Err(format!(
|
||||
"\n❌ Model files not found!\n\
|
||||
\n\
|
||||
Please generate the model files first by running:\n\
|
||||
\n\
|
||||
cd crates/burn-store\n\
|
||||
uv run benches/generate_unified_models.py\n\
|
||||
\n\
|
||||
Expected files:\n\
|
||||
- {}\n\
|
||||
- {}\n",
|
||||
st_path.display(),
|
||||
pt_path.display()
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// Check if model files exist before running benchmarks
|
||||
match check_model_files() {
|
||||
Ok(()) => {
|
||||
let (bp_path, mpk_path, st_path, pt_path) = get_model_paths();
|
||||
|
||||
// First, generate Burnpack and MPK files if they don't exist
|
||||
if !bp_path.exists() || !mpk_path.exists() {
|
||||
println!("⏳ Generating Burnpack and NamedMpk files from SafeTensors...");
|
||||
generate_burn_formats(&st_path, &bp_path, &mpk_path);
|
||||
}
|
||||
|
||||
let bp_size = fs::metadata(&bp_path)
|
||||
.ok()
|
||||
.map(|m| m.len() as f64 / 1_048_576.0);
|
||||
let mpk_size = fs::metadata(&mpk_path)
|
||||
.ok()
|
||||
.map(|m| m.len() as f64 / 1_048_576.0);
|
||||
let st_size = fs::metadata(&st_path).unwrap().len() as f64 / 1_048_576.0;
|
||||
let pt_size = fs::metadata(&pt_path).unwrap().len() as f64 / 1_048_576.0;
|
||||
|
||||
println!("✅ Found model files:");
|
||||
if let Some(size) = bp_size {
|
||||
println!(" Burnpack: {} ({:.1} MB)", bp_path.display(), size);
|
||||
}
|
||||
if let Some(size) = mpk_size {
|
||||
println!(" NamedMpk: {} ({:.1} MB)", mpk_path.display(), size);
|
||||
}
|
||||
println!(" SafeTensors: {} ({:.1} MB)", st_path.display(), st_size);
|
||||
println!(" PyTorch: {} ({:.1} MB)", pt_path.display(), pt_size);
|
||||
println!();
|
||||
println!("🚀 Running unified loading benchmarks...");
|
||||
println!();
|
||||
println!("Comparing 6 loading methods:");
|
||||
println!(" 1. BurnpackStore (new native format - lazy loading)");
|
||||
println!(" 2. NamedMpkFileRecorder (old native format - loads all to memory)");
|
||||
println!(" 3. SafetensorsStore (new)");
|
||||
println!(" 4. SafetensorsFileRecorder (old)");
|
||||
println!(" 5. PytorchStore (new)");
|
||||
println!(" 6. PyTorchFileRecorder (old)");
|
||||
println!();
|
||||
println!("Available backends:");
|
||||
println!(" - NdArray (CPU)");
|
||||
#[cfg(feature = "wgpu")]
|
||||
println!(" - WGPU (GPU)");
|
||||
#[cfg(feature = "cuda")]
|
||||
println!(" - CUDA (NVIDIA GPU)");
|
||||
#[cfg(feature = "tch")]
|
||||
println!(" - LibTorch");
|
||||
#[cfg(feature = "metal")]
|
||||
println!(" - Metal (Apple GPU)");
|
||||
println!();
|
||||
|
||||
divan::main();
|
||||
}
|
||||
Err(msg) => {
|
||||
eprintln!("{}", msg);
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Macro to generate benchmarks for each backend
|
||||
macro_rules! bench_backend {
|
||||
($backend:ty, $mod_name:ident, $backend_name:literal) => {
|
||||
#[divan::bench_group(name = $backend_name, sample_count = 10)]
|
||||
mod $mod_name {
|
||||
use super::*;
|
||||
|
||||
type TestBackend = $backend;
|
||||
type TestDevice = <TestBackend as Backend>::Device;
|
||||
|
||||
#[divan::bench]
|
||||
fn burnpack_store(bencher: Bencher) {
|
||||
let (bp_path, _, _, _) = get_model_paths();
|
||||
let file_size = fs::metadata(&bp_path).unwrap().len();
|
||||
|
||||
bencher
|
||||
.counter(divan::counter::BytesCount::new(file_size))
|
||||
.bench(|| {
|
||||
let device: TestDevice = Default::default();
|
||||
let mut model = LargeModel::<TestBackend>::new(&device);
|
||||
let mut store = BurnpackStore::from_file(bp_path.clone());
|
||||
model.load_from(&mut store).expect("Failed to load");
|
||||
});
|
||||
}
|
||||
|
||||
#[divan::bench]
|
||||
fn namedmpk_recorder(bencher: Bencher) {
|
||||
let (_, mpk_path, _, _) = get_model_paths();
|
||||
let file_size = fs::metadata(&mpk_path).unwrap().len();
|
||||
|
||||
bencher
|
||||
.counter(divan::counter::BytesCount::new(file_size))
|
||||
.bench(|| {
|
||||
let device: TestDevice = Default::default();
|
||||
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
|
||||
let record = recorder
|
||||
.load(mpk_path.clone().into(), &device)
|
||||
.expect("Failed to load");
|
||||
let _model = LargeModel::<TestBackend>::new(&device).load_record(record);
|
||||
});
|
||||
}
|
||||
|
||||
#[divan::bench]
|
||||
fn safetensors_store(bencher: Bencher) {
|
||||
let (_, _, st_path, _) = get_model_paths();
|
||||
let file_size = fs::metadata(&st_path).unwrap().len();
|
||||
|
||||
bencher
|
||||
.counter(divan::counter::BytesCount::new(file_size))
|
||||
.bench(|| {
|
||||
let device: TestDevice = Default::default();
|
||||
let mut model = LargeModel::<TestBackend>::new(&device);
|
||||
let mut store = SafetensorsStore::from_file(st_path.clone())
|
||||
.with_from_adapter(PyTorchToBurnAdapter);
|
||||
model.load_from(&mut store).expect("Failed to load");
|
||||
});
|
||||
}
|
||||
|
||||
// #[divan::bench]
|
||||
// fn safetensors_recorder(bencher: Bencher) {
|
||||
// let (_, _, st_path, _) = get_model_paths();
|
||||
// let file_size = fs::metadata(&st_path).unwrap().len();
|
||||
|
||||
// bencher
|
||||
// .counter(divan::counter::BytesCount::new(file_size))
|
||||
// .bench(|| {
|
||||
// let device: TestDevice = Default::default();
|
||||
// let recorder = SafetensorsFileRecorder::<FullPrecisionSettings>::default();
|
||||
// let record = recorder
|
||||
// .load(st_path.clone().into(), &device)
|
||||
// .expect("Failed to load");
|
||||
// let _model = LargeModel::<TestBackend>::new(&device).load_record(record);
|
||||
// });
|
||||
// }
|
||||
|
||||
#[divan::bench]
|
||||
fn pytorch_store(bencher: Bencher) {
|
||||
let (_, _, _, pt_path) = get_model_paths();
|
||||
let file_size = fs::metadata(&pt_path).unwrap().len();
|
||||
|
||||
bencher
|
||||
.counter(divan::counter::BytesCount::new(file_size))
|
||||
.bench(|| {
|
||||
let device: TestDevice = Default::default();
|
||||
let mut model = LargeModel::<TestBackend>::new(&device);
|
||||
let mut store = PytorchStore::from_file(pt_path.clone())
|
||||
.with_top_level_key("model_state_dict")
|
||||
.allow_partial(true);
|
||||
model.load_from(&mut store).expect("Failed to load");
|
||||
});
|
||||
}
|
||||
|
||||
// #[divan::bench]
|
||||
// fn pytorch_recorder(bencher: Bencher) {
|
||||
// let (_, _, _, pt_path) = get_model_paths();
|
||||
// let file_size = fs::metadata(&pt_path).unwrap().len();
|
||||
|
||||
// bencher
|
||||
// .counter(divan::counter::BytesCount::new(file_size))
|
||||
// .bench(|| {
|
||||
// let device: TestDevice = Default::default();
|
||||
// let recorder = PyTorchFileRecorder::<FullPrecisionSettings>::default();
|
||||
// let load_args =
|
||||
// LoadArgs::new(pt_path.clone()).with_top_level_key("model_state_dict");
|
||||
// let record = recorder.load(load_args, &device).expect("Failed to load");
|
||||
// let _model = LargeModel::<TestBackend>::new(&device).load_record(record);
|
||||
// });
|
||||
// }
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Generate benchmarks for each backend
|
||||
bench_backend!(NdArrayBackend, ndarray_backend, "NdArray Backend (CPU)");
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
bench_backend!(WgpuBackend, wgpu_backend, "WGPU Backend (GPU)");
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
bench_backend!(CudaBackend, cuda_backend, "CUDA Backend (NVIDIA GPU)");
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
bench_backend!(TchBackend, tch_backend, "LibTorch Backend");
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
bench_backend!(MetalBackend, metal_backend, "Metal Backend (Apple GPU)");
|
||||
@@ -0,0 +1,183 @@
|
||||
#![recursion_limit = "256"]
|
||||
|
||||
//! Unified benchmark comparing all saving methods:
|
||||
//! - BurnpackStore (new native format)
|
||||
//! - NamedMpkFileRecorder (old native format)
|
||||
//! - SafetensorsStore (new)
|
||||
//!
|
||||
//! No manual setup is needed: on startup the benchmark creates its output
//! directory, `simple_bench_models_saving`, inside the system temp directory.
|
||||
//!
|
||||
//! Then run the benchmark:
|
||||
//! ```bash
|
||||
//! cargo bench --bench unified_saving
|
||||
//! ```
|
||||
use burn_core as burn;
|
||||
|
||||
use burn_core::module::Module;
|
||||
use burn_core::prelude::*;
|
||||
use burn_core::record::{FullPrecisionSettings, NamedMpkFileRecorder};
|
||||
use burn_nn as nn;
|
||||
use burn_store::{BurnpackStore, ModuleSnapshot, SafetensorsStore};
|
||||
use divan::{AllocProfiler, Bencher};
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[global_allocator]
|
||||
static ALLOC: AllocProfiler = AllocProfiler::system();
|
||||
|
||||
// Backend type aliases
|
||||
type NdArrayBackend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
type WgpuBackend = burn_wgpu::Wgpu;
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
type CudaBackend = burn_cuda::Cuda<f32, i32>;
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
type TchBackend = burn_tch::LibTorch<f32>;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
type MetalBackend = burn_wgpu::Metal;
|
||||
|
||||
// Use the same LargeModel as other benchmarks for fair comparison
|
||||
#[derive(Module, Debug)]
|
||||
struct LargeModel<B: Backend> {
|
||||
layers: Vec<nn::Linear<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> LargeModel<B> {
|
||||
fn new(device: &B::Device) -> Self {
|
||||
let mut layers = Vec::new();
|
||||
// Create a model with 20 layers - same as loading benchmarks
|
||||
for i in 0..20 {
|
||||
let in_size = if i == 0 { 1024 } else { 2048 };
|
||||
layers.push(nn::LinearConfig::new(in_size, 2048).init(device));
|
||||
}
|
||||
Self { layers }
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the directory the saving benchmarks write into:
/// `simple_bench_models_saving` inside the system temporary directory.
fn get_output_dir() -> PathBuf {
    let mut dir = std::env::temp_dir();
    dir.push("simple_bench_models_saving");
    dir
}
|
||||
|
||||
/// Ensure output directory exists
|
||||
fn ensure_output_dir() -> Result<(), String> {
|
||||
let dir = get_output_dir();
|
||||
if !dir.exists() {
|
||||
fs::create_dir_all(&dir)
|
||||
.map_err(|e| format!("Failed to create output directory: {}", e))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() {
|
||||
match ensure_output_dir() {
|
||||
Ok(()) => {
|
||||
println!("✅ Output directory ready: {}", get_output_dir().display());
|
||||
println!();
|
||||
println!("🚀 Running unified saving benchmarks...");
|
||||
println!();
|
||||
println!("Comparing 3 saving methods:");
|
||||
println!(" 1. BurnpackStore (new native format)");
|
||||
println!(" 2. NamedMpkFileRecorder (old native format)");
|
||||
println!(" 3. SafetensorsStore (new)");
|
||||
println!();
|
||||
println!("Available backends:");
|
||||
println!(" - NdArray (CPU)");
|
||||
#[cfg(feature = "wgpu")]
|
||||
println!(" - WGPU (GPU)");
|
||||
#[cfg(feature = "cuda")]
|
||||
println!(" - CUDA (NVIDIA GPU)");
|
||||
#[cfg(feature = "tch")]
|
||||
println!(" - LibTorch");
|
||||
#[cfg(feature = "metal")]
|
||||
println!(" - Metal (Apple GPU)");
|
||||
println!();
|
||||
|
||||
divan::main();
|
||||
}
|
||||
Err(msg) => {
|
||||
eprintln!("❌ {}", msg);
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Macro to generate benchmarks for each backend
|
||||
macro_rules! bench_backend {
|
||||
($backend:ty, $mod_name:ident, $backend_name:literal) => {
|
||||
#[divan::bench_group(name = $backend_name, sample_count = 10)]
|
||||
mod $mod_name {
|
||||
use super::*;
|
||||
|
||||
type TestBackend = $backend;
|
||||
type TestDevice = <TestBackend as Backend>::Device;
|
||||
|
||||
#[divan::bench]
|
||||
fn burnpack_store(bencher: Bencher) {
|
||||
bencher.bench(|| {
|
||||
let device: TestDevice = Default::default();
|
||||
let model = LargeModel::<TestBackend>::new(&device);
|
||||
let output_path = get_output_dir().join("test_burnpack.bpk");
|
||||
let mut store = BurnpackStore::from_file(output_path.clone()).overwrite(true);
|
||||
model
|
||||
.save_into(&mut store)
|
||||
.expect("Failed to save with BurnpackStore");
|
||||
// Clean up
|
||||
let _ = fs::remove_file(output_path);
|
||||
});
|
||||
}
|
||||
|
||||
#[divan::bench]
|
||||
fn namedmpk_recorder(bencher: Bencher) {
|
||||
bencher.bench(|| {
|
||||
let device: TestDevice = Default::default();
|
||||
let model = LargeModel::<TestBackend>::new(&device);
|
||||
let output_path = get_output_dir().join("test_namedmpk.mpk");
|
||||
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
|
||||
model
|
||||
.save_file(output_path.clone(), &recorder)
|
||||
.expect("Failed to save with NamedMpkFileRecorder");
|
||||
// Clean up
|
||||
let _ = fs::remove_file(output_path);
|
||||
});
|
||||
}
|
||||
|
||||
#[divan::bench]
|
||||
fn safetensors_store(bencher: Bencher) {
|
||||
bencher.bench(|| {
|
||||
let device: TestDevice = Default::default();
|
||||
let model = LargeModel::<TestBackend>::new(&device);
|
||||
let output_path = get_output_dir().join("test_safetensors_store.safetensors");
|
||||
let mut store = SafetensorsStore::from_file(output_path.clone());
|
||||
model
|
||||
.save_into(&mut store)
|
||||
.expect("Failed to save with SafetensorsStore");
|
||||
// Clean up
|
||||
let _ = fs::remove_file(output_path);
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Generate benchmarks for each backend
|
||||
bench_backend!(NdArrayBackend, ndarray_backend, "NdArray Backend (CPU)");
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
bench_backend!(WgpuBackend, wgpu_backend, "WGPU Backend (GPU)");
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
bench_backend!(CudaBackend, cuda_backend, "CUDA Backend (NVIDIA GPU)");
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
bench_backend!(TchBackend, tch_backend, "LibTorch Backend");
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
bench_backend!(MetalBackend, metal_backend, "Metal Backend (Apple GPU)");
|
||||
@@ -0,0 +1,596 @@
|
||||
#![recursion_limit = "256"]
|
||||
|
||||
//! Benchmark comparing zero-copy vs copy loading modes for BurnpackStore.
|
||||
//!
|
||||
//! This benchmark measures the performance difference between:
|
||||
//! - `zero_copy(false)` - Default mode, copies tensor data into new allocations
|
||||
//! - `zero_copy(true)` - Zero-copy mode, slices tensor data without copying
|
||||
//!
|
||||
//! ## Understanding the Results
|
||||
//!
|
||||
//! **IMPORTANT**: For NdArray backend, you'll see similar allocation numbers because:
|
||||
//! - NdArray uses `ndarray::ArrayD` which MUST own data as `Vec<T>`
|
||||
//! - Even with zero-copy, the backend eventually copies data into its own format
|
||||
//!
|
||||
//! The zero-copy benefit is:
|
||||
//! - **Without zero-copy**: File → Copy to heap (Bytes) → Copy to Vec (backend)
|
||||
//! - **With zero-copy**: File → Zero-copy slice → Copy to Vec (backend)
|
||||
//!
|
||||
//! So zero-copy saves ONE memory copy at the store level. The `store_only_*` benchmarks
|
||||
//! show the raw store performance without backend allocation overhead.
|
||||
//!
|
||||
//! GPU backends that can consume `Bytes` directly will show larger benefits.
|
||||
//!
|
||||
//! ## Running the benchmark
|
||||
//!
|
||||
//! Before running this benchmark, generate the model files:
|
||||
//! ```bash
|
||||
//! cd crates/burn-store
|
||||
//! uv run benches/generate_unified_models.py
|
||||
//! ```
|
||||
//!
|
||||
//! Then run the benchmark:
|
||||
//! ```bash
|
||||
//! cargo bench --bench zero_copy_loading
|
||||
//! ```
|
||||
|
||||
use burn_core as burn;
|
||||
|
||||
use burn_core::module::Module;
|
||||
use burn_core::prelude::*;
|
||||
use burn_nn as nn;
|
||||
use burn_store::{
|
||||
BurnpackStore, ModuleSnapshot, ModuleStore, PyTorchToBurnAdapter, SafetensorsStore,
|
||||
};
|
||||
use burn_tensor::{AllocationProperty, Bytes};
|
||||
use divan::{AllocProfiler, Bencher};
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
#[global_allocator]
|
||||
static ALLOC: AllocProfiler = AllocProfiler::system();
|
||||
|
||||
// Static storage for embedded model bytes (simulating include_bytes!)
|
||||
static STATIC_MODEL_BYTES: OnceLock<&'static [u8]> = OnceLock::new();
|
||||
|
||||
// Backend type aliases
|
||||
type NdArrayBackend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
type WgpuBackend = burn_wgpu::Wgpu;
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
type CudaBackend = burn_cuda::Cuda<f32, i32>;
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
type TchBackend = burn_tch::LibTorch<f32>;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
type MetalBackend = burn_wgpu::Metal;
|
||||
|
||||
// Use the same LargeModel as other benchmarks for fair comparison
|
||||
#[derive(Module, Debug)]
|
||||
struct LargeModel<B: Backend> {
|
||||
layers: Vec<nn::Linear<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> LargeModel<B> {
|
||||
fn new(device: &B::Device) -> Self {
|
||||
let mut layers = Vec::new();
|
||||
// Create a model with 20 layers - same as unified_loading benchmark
|
||||
for i in 0..20 {
|
||||
let in_size = if i == 0 { 1024 } else { 2048 };
|
||||
layers.push(nn::LinearConfig::new(in_size, 2048).init(device));
|
||||
}
|
||||
Self { layers }
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the directory containing the model files: `simple_bench_models` inside the
/// system temporary directory.
fn get_model_dir() -> PathBuf {
    let base = std::env::temp_dir();
    base.join("simple_bench_models")
}
|
||||
|
||||
/// Get path to Burnpack model file
|
||||
fn get_burnpack_path() -> PathBuf {
|
||||
get_model_dir().join("large_model.bpk")
|
||||
}
|
||||
|
||||
/// Generate Burnpack file from existing SafeTensors file if needed
|
||||
fn ensure_burnpack_file() {
|
||||
let bp_path = get_burnpack_path();
|
||||
let st_path = get_model_dir().join("large_model.safetensors");
|
||||
|
||||
if bp_path.exists() {
|
||||
return;
|
||||
}
|
||||
|
||||
if !st_path.exists() {
|
||||
panic!(
|
||||
"\n❌ SafeTensors model file not found!\n\
|
||||
\n\
|
||||
Please generate the model files first by running:\n\
|
||||
\n\
|
||||
cd crates/burn-store\n\
|
||||
uv run benches/generate_unified_models.py\n\
|
||||
\n\
|
||||
Expected file: {}\n",
|
||||
st_path.display()
|
||||
);
|
||||
}
|
||||
|
||||
println!("⏳ Generating Burnpack file from SafeTensors...");
|
||||
|
||||
type TestBackend = NdArrayBackend;
|
||||
let device = Default::default();
|
||||
|
||||
// Load from SafeTensors
|
||||
let mut model = LargeModel::<TestBackend>::new(&device);
|
||||
let mut store = SafetensorsStore::from_file(&st_path).with_from_adapter(PyTorchToBurnAdapter);
|
||||
model
|
||||
.load_from(&mut store)
|
||||
.expect("Failed to load from SafeTensors");
|
||||
|
||||
// Save as Burnpack
|
||||
let mut burnpack_store = BurnpackStore::from_file(&bp_path);
|
||||
model
|
||||
.save_into(&mut burnpack_store)
|
||||
.expect("Failed to save as Burnpack");
|
||||
|
||||
println!("✅ Created Burnpack file: {}", bp_path.display());
|
||||
}
|
||||
|
||||
/// Initialize static model bytes (simulating include_bytes! at runtime for benchmarks)
|
||||
fn get_static_model_bytes() -> &'static [u8] {
|
||||
STATIC_MODEL_BYTES.get_or_init(|| {
|
||||
let bp_path = get_burnpack_path();
|
||||
let bytes = fs::read(&bp_path).expect("Failed to read Burnpack file");
|
||||
// Leak the bytes to get a 'static lifetime (acceptable for benchmarks)
|
||||
Box::leak(bytes.into_boxed_slice())
|
||||
})
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// Ensure Burnpack file exists
|
||||
ensure_burnpack_file();
|
||||
|
||||
let bp_path = get_burnpack_path();
|
||||
let file_size = fs::metadata(&bp_path).unwrap().len() as f64 / 1_048_576.0;
|
||||
|
||||
println!("✅ Found Burnpack model file:");
|
||||
println!(" Path: {}", bp_path.display());
|
||||
println!(" Size: {:.1} MB", file_size);
|
||||
println!();
|
||||
println!("🚀 Running zero-copy loading benchmarks...");
|
||||
println!();
|
||||
println!("Comparing loading modes:");
|
||||
println!(" 1. file_copy - from_file().zero_copy(false) - copies tensor data");
|
||||
println!(" 2. file_zero_copy - from_file().zero_copy(true) - zero-copy via mmap");
|
||||
println!(" 3. static_copy - from_bytes() with Vec copy - copies from static");
|
||||
println!(" 4. static_zero_copy - from_static() - zero-copy from static");
|
||||
println!();
|
||||
println!("Available backends:");
|
||||
println!(" - NdArray (CPU)");
|
||||
#[cfg(feature = "wgpu")]
|
||||
println!(" - WGPU (GPU)");
|
||||
#[cfg(feature = "cuda")]
|
||||
println!(" - CUDA (NVIDIA GPU)");
|
||||
#[cfg(feature = "tch")]
|
||||
println!(" - LibTorch");
|
||||
#[cfg(feature = "metal")]
|
||||
println!(" - Metal (Apple GPU)");
|
||||
println!();
|
||||
|
||||
// Pre-initialize static bytes before benchmarks
|
||||
let _ = get_static_model_bytes();
|
||||
|
||||
divan::main();
|
||||
}
|
||||
|
||||
// Generates one divan benchmark group per backend: a module named `$mod_name`, shown as
// `$backend_name`, that loads `LargeModel` on `$backend` from the Burnpack data using
// each loading mode. Every benchmark reports throughput against the size of its input.
macro_rules! bench_backend {
    ($backend:ty, $mod_name:ident, $backend_name:literal) => {
        #[divan::bench_group(name = $backend_name, sample_count = 10)]
        mod $mod_name {
            use super::*;

            type TestBackend = $backend;
            type TestDevice = <TestBackend as Backend>::Device;

            /// File-based loading with `zero_copy(false)`
            #[divan::bench]
            fn file_copy(bencher: Bencher) {
                let bp_path = get_burnpack_path();
                let throughput =
                    divan::counter::BytesCount::new(fs::metadata(&bp_path).unwrap().len());

                bencher.counter(throughput).bench(|| {
                    let device: TestDevice = Default::default();
                    let mut model = LargeModel::<TestBackend>::new(&device);
                    let mut store = BurnpackStore::from_file(&bp_path).zero_copy(false);
                    model.load_from(&mut store).expect("Failed to load");
                });
            }

            /// File-based loading with `zero_copy(true)`
            #[divan::bench]
            fn file_zero_copy(bencher: Bencher) {
                let bp_path = get_burnpack_path();
                let throughput =
                    divan::counter::BytesCount::new(fs::metadata(&bp_path).unwrap().len());

                bencher.counter(throughput).bench(|| {
                    let device: TestDevice = Default::default();
                    let mut model = LargeModel::<TestBackend>::new(&device);
                    let mut store = BurnpackStore::from_file(&bp_path).zero_copy(true);
                    model.load_from(&mut store).expect("Failed to load");
                });
            }

            /// Static bytes copied into a `Vec` first, then loaded with `zero_copy(false)`
            /// (simulating the old behavior)
            #[divan::bench]
            fn static_copy(bencher: Bencher) {
                let model_bytes = get_static_model_bytes();
                let throughput = divan::counter::BytesCount::new(model_bytes.len() as u64);

                bencher.counter(throughput).bench(|| {
                    let device: TestDevice = Default::default();
                    let mut model = LargeModel::<TestBackend>::new(&device);

                    // The copy of the static bytes is part of the measured work.
                    let owned = Bytes::from_bytes_vec(model_bytes.to_vec());
                    let mut store = BurnpackStore::from_bytes(Some(owned)).zero_copy(false);
                    model.load_from(&mut store).expect("Failed to load");
                });
            }

            /// Static bytes handed to `BurnpackStore::from_static` directly
            #[divan::bench]
            fn static_zero_copy(bencher: Bencher) {
                let model_bytes = get_static_model_bytes();
                let throughput = divan::counter::BytesCount::new(model_bytes.len() as u64);

                bencher.counter(throughput).bench(|| {
                    let device: TestDevice = Default::default();
                    let mut model = LargeModel::<TestBackend>::new(&device);

                    // No intermediate copy: the store is built on the static slice itself.
                    let mut store = BurnpackStore::from_static(model_bytes);
                    model.load_from(&mut store).expect("Failed to load");
                });
            }

            /// In-memory shared bytes loaded with `zero_copy(true)`
            #[divan::bench]
            fn memory_shared_zero_copy(bencher: Bencher) {
                let model_bytes = get_static_model_bytes();
                let throughput = divan::counter::BytesCount::new(model_bytes.len() as u64);

                // The shared buffer is created once, outside the measured closure.
                let shared = bytes::Bytes::from_static(model_bytes);

                bencher.counter(throughput).bench(|| {
                    let device: TestDevice = Default::default();
                    let mut model = LargeModel::<TestBackend>::new(&device);

                    // Each iteration wraps a clone of the shared handle.
                    let wrapped = Bytes::from_shared(shared.clone(), AllocationProperty::Other);
                    let mut store = BurnpackStore::from_bytes(Some(wrapped)).zero_copy(true);
                    model.load_from(&mut store).expect("Failed to load");
                });
            }
        }
    };
}
|
||||
|
||||
// =============================================================================
|
||||
// Zero-copy verification (proves operations use static region data)
|
||||
// =============================================================================
|
||||
|
||||
/// Verify that zero-copy loading actually uses data from the static region.
|
||||
/// This runs once at startup to prove correctness before benchmarking.
|
||||
#[divan::bench_group(name = "Zero-Copy Verification", sample_count = 1)]
|
||||
mod verification {
|
||||
use super::*;
|
||||
use burn_ndarray::NdArray;
|
||||
|
||||
type B = NdArray<f32>;
|
||||
|
||||
/// Verify zero-copy: tensor storage is borrowed (not owned)
|
||||
#[divan::bench]
|
||||
fn verify_storage_is_borrowed() {
|
||||
let static_bytes = get_static_model_bytes();
|
||||
|
||||
// Load model with zero-copy from static bytes
|
||||
let device = Default::default();
|
||||
let mut model = LargeModel::<B>::new(&device);
|
||||
let mut store = BurnpackStore::from_static(static_bytes);
|
||||
model.load_from(&mut store).expect("Failed to load");
|
||||
|
||||
// Get the first layer's weight tensor and verify it uses borrowed storage
|
||||
let weight = model.layers[0].weight.val();
|
||||
// .into_primitive() returns TensorPrimitive<B>, .tensor() extracts B::FloatTensorPrimitive
|
||||
let ndarray_tensor = weight.into_primitive().tensor();
|
||||
|
||||
// Verify the storage is borrowed (zero-copy from static region)
|
||||
assert!(
|
||||
ndarray_tensor.is_borrowed(),
|
||||
"ZERO-COPY FAILURE: Tensor storage is NOT borrowed. \
|
||||
Data was copied instead of being zero-copy!"
|
||||
);
|
||||
|
||||
println!("✅ Verified: Tensor storage is borrowed (zero-copy from static region)");
|
||||
}
|
||||
|
||||
/// Verify ALL layers use borrowed (zero-copy) storage.
|
||||
/// This is the key proof that loaded weights point to static memory.
|
||||
#[divan::bench]
|
||||
fn verify_all_layers_borrowed() {
|
||||
let static_bytes = get_static_model_bytes();
|
||||
|
||||
// Load model with zero-copy
|
||||
let device = Default::default();
|
||||
let mut model = LargeModel::<B>::new(&device);
|
||||
let mut store = BurnpackStore::from_static(static_bytes);
|
||||
model.load_from(&mut store).expect("Failed to load");
|
||||
|
||||
// Check ALL layers have borrowed storage
|
||||
let mut total_elements = 0usize;
|
||||
for (i, layer) in model.layers.iter().enumerate() {
|
||||
let weight = layer.weight.val();
|
||||
total_elements += weight.shape().num_elements();
|
||||
|
||||
assert!(
|
||||
weight.into_primitive().tensor().is_borrowed(),
|
||||
"Layer {} weight should be borrowed (zero-copy)",
|
||||
i
|
||||
);
|
||||
}
|
||||
|
||||
let total_mb = (total_elements * 4) as f64 / 1_048_576.0;
|
||||
println!(
|
||||
"✅ Verified: All {} layers use borrowed storage",
|
||||
model.layers.len()
|
||||
);
|
||||
println!(
|
||||
" - Model size: {:.2} MB - all pointing to static region",
|
||||
total_mb
|
||||
);
|
||||
}
|
||||
|
||||
/// Verify data is readable and correct using sum().into_scalar().
|
||||
/// Note: sum() triggers COW copy, so this shows ops work correctly on zero-copy data.
|
||||
#[divan::bench]
|
||||
fn verify_ops_produce_correct_results() {
|
||||
let static_bytes = get_static_model_bytes();
|
||||
|
||||
let device = Default::default();
|
||||
let mut model = LargeModel::<B>::new(&device);
|
||||
let mut store = BurnpackStore::from_static(static_bytes);
|
||||
model.load_from(&mut store).expect("Failed to load");
|
||||
|
||||
// Compute sum of first layer weight - proves data is valid
|
||||
let weight = model.layers[0].weight.val();
|
||||
let sum: f32 = weight.sum().into_scalar();
|
||||
|
||||
assert!(sum.is_finite(), "Sum should be finite");
|
||||
println!("✅ Verified: Operations on zero-copy data produce valid results");
|
||||
println!(" - First layer sum: {:.4}", sum);
|
||||
}
|
||||
|
||||
/// Verify operations produce correct results on zero-copy data
|
||||
#[divan::bench]
|
||||
fn verify_operations_on_static_data() {
|
||||
let static_bytes = get_static_model_bytes();
|
||||
|
||||
// Load model with zero-copy
|
||||
let device = Default::default();
|
||||
let mut model = LargeModel::<B>::new(&device);
|
||||
let mut store = BurnpackStore::from_static(static_bytes);
|
||||
model.load_from(&mut store).expect("Failed to load");
|
||||
|
||||
// Perform operations on the loaded weights
|
||||
let weight = model.layers[0].weight.val();
|
||||
let shape = weight.shape();
|
||||
|
||||
// Test 1: Sum should be finite (not NaN or Inf)
|
||||
let sum: f32 = weight.clone().sum().to_data().to_vec().unwrap()[0];
|
||||
assert!(
|
||||
sum.is_finite(),
|
||||
"Operation failed: sum is not finite ({})",
|
||||
sum
|
||||
);
|
||||
|
||||
// Test 2: Matrix multiply with itself transposed (W @ W.T)
|
||||
let transposed = weight.clone().transpose();
|
||||
let matmul_result = weight.clone().matmul(transposed);
|
||||
let matmul_sum: f32 = matmul_result.sum().to_data().to_vec().unwrap()[0];
|
||||
assert!(
|
||||
matmul_sum.is_finite(),
|
||||
"Matmul failed: result sum is not finite ({})",
|
||||
matmul_sum
|
||||
);
|
||||
|
||||
// Test 3: Element-wise operations
|
||||
let doubled = weight.clone() * 2.0;
|
||||
let doubled_sum: f32 = doubled.sum().to_data().to_vec().unwrap()[0];
|
||||
assert!(
|
||||
(doubled_sum - sum * 2.0).abs() < 1e-3,
|
||||
"Element-wise op failed: doubled_sum ({}) != sum*2 ({})",
|
||||
doubled_sum,
|
||||
sum * 2.0
|
||||
);
|
||||
|
||||
println!("✅ Verified: Operations on zero-copy data produce correct results");
|
||||
println!(" - Weight shape: {:?}", shape.as_slice());
|
||||
println!(" - Sum: {:.4}", sum);
|
||||
println!(" - Matmul result sum: {:.4}", matmul_sum);
|
||||
}
|
||||
|
||||
/// Compare zero-copy vs copy: verify both produce identical results
|
||||
#[divan::bench]
|
||||
fn verify_copy_vs_zero_copy_equality() {
|
||||
let static_bytes = get_static_model_bytes();
|
||||
let device: <B as Backend>::Device = Default::default();
|
||||
|
||||
// Load with zero-copy
|
||||
let mut model_zc = LargeModel::<B>::new(&device);
|
||||
let mut store_zc = BurnpackStore::from_static(static_bytes);
|
||||
model_zc
|
||||
.load_from(&mut store_zc)
|
||||
.expect("Failed to load zero-copy");
|
||||
|
||||
// Load with copy (simulate old behavior)
|
||||
let mut model_copy = LargeModel::<B>::new(&device);
|
||||
let bytes = Bytes::from_bytes_vec(static_bytes.to_vec());
|
||||
let mut store_copy = BurnpackStore::from_bytes(Some(bytes)).zero_copy(false);
|
||||
model_copy
|
||||
.load_from(&mut store_copy)
|
||||
.expect("Failed to load copy");
|
||||
|
||||
// Compare weights from both models
|
||||
for (i, (layer_zc, layer_copy)) in model_zc
|
||||
.layers
|
||||
.iter()
|
||||
.zip(model_copy.layers.iter())
|
||||
.enumerate()
|
||||
{
|
||||
let weight_zc = layer_zc.weight.val();
|
||||
let weight_copy = layer_copy.weight.val();
|
||||
|
||||
// Check shapes match
|
||||
assert_eq!(
|
||||
weight_zc.shape(),
|
||||
weight_copy.shape(),
|
||||
"Layer {} weight shapes don't match",
|
||||
i
|
||||
);
|
||||
|
||||
// Check values match (using sum as a proxy)
|
||||
let sum_zc: f32 = weight_zc.clone().sum().to_data().to_vec().unwrap()[0];
|
||||
let sum_copy: f32 = weight_copy.clone().sum().to_data().to_vec().unwrap()[0];
|
||||
assert!(
|
||||
(sum_zc - sum_copy).abs() < 1e-6,
|
||||
"Layer {} weight sums don't match: zero-copy={}, copy={}",
|
||||
i,
|
||||
sum_zc,
|
||||
sum_copy
|
||||
);
|
||||
}
|
||||
|
||||
println!(
|
||||
"✅ Verified: Zero-copy and copy loading produce identical results for all {} layers",
|
||||
model_zc.layers.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Store-only benchmarks (no backend allocation overhead)
|
||||
// These show the TRUE zero-copy benefit at the store level
|
||||
// =============================================================================
|
||||
|
||||
#[divan::bench_group(name = "Store Only (no backend)", sample_count = 10)]
|
||||
mod store_only {
|
||||
use super::*;
|
||||
|
||||
/// File-based store with copy mode - measures store overhead only
|
||||
#[divan::bench]
|
||||
fn file_copy(bencher: Bencher) {
|
||||
let bp_path = get_burnpack_path();
|
||||
let file_size = fs::metadata(&bp_path).unwrap().len();
|
||||
|
||||
bencher
|
||||
.counter(divan::counter::BytesCount::new(file_size))
|
||||
.bench(|| {
|
||||
let mut store = BurnpackStore::from_file(&bp_path).zero_copy(false);
|
||||
// Just iterate through all tensor snapshots, calling to_data() on each
|
||||
// This forces the store to read and materialize all tensor data
|
||||
let snapshots = store.get_all_snapshots().expect("Failed to get snapshots");
|
||||
for snapshot in snapshots.values() {
|
||||
let _data = snapshot.to_data().expect("Failed to get tensor data");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// File-based store with zero-copy mode - measures store overhead only
|
||||
#[divan::bench]
|
||||
fn file_zero_copy(bencher: Bencher) {
|
||||
let bp_path = get_burnpack_path();
|
||||
let file_size = fs::metadata(&bp_path).unwrap().len();
|
||||
|
||||
bencher
|
||||
.counter(divan::counter::BytesCount::new(file_size))
|
||||
.bench(|| {
|
||||
let mut store = BurnpackStore::from_file(&bp_path).zero_copy(true);
|
||||
let snapshots = store.get_all_snapshots().expect("Failed to get snapshots");
|
||||
for snapshot in snapshots.values() {
|
||||
let _data = snapshot.to_data().expect("Failed to get tensor data");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Static bytes with copy mode - measures store overhead only
|
||||
#[divan::bench]
|
||||
fn static_copy(bencher: Bencher) {
|
||||
let static_bytes = get_static_model_bytes();
|
||||
let file_size = static_bytes.len() as u64;
|
||||
|
||||
bencher
|
||||
.counter(divan::counter::BytesCount::new(file_size))
|
||||
.bench(|| {
|
||||
// Simulate old behavior: copy static bytes to Vec
|
||||
let bytes = Bytes::from_bytes_vec(static_bytes.to_vec());
|
||||
let mut store = BurnpackStore::from_bytes(Some(bytes)).zero_copy(false);
|
||||
let snapshots = store.get_all_snapshots().expect("Failed to get snapshots");
|
||||
for snapshot in snapshots.values() {
|
||||
let _data = snapshot.to_data().expect("Failed to get tensor data");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Static bytes with zero-copy mode - measures store overhead only
|
||||
#[divan::bench]
|
||||
fn static_zero_copy(bencher: Bencher) {
|
||||
let static_bytes = get_static_model_bytes();
|
||||
let file_size = static_bytes.len() as u64;
|
||||
|
||||
bencher
|
||||
.counter(divan::counter::BytesCount::new(file_size))
|
||||
.bench(|| {
|
||||
let mut store = BurnpackStore::from_static(static_bytes);
|
||||
let snapshots = store.get_all_snapshots().expect("Failed to get snapshots");
|
||||
for snapshot in snapshots.values() {
|
||||
let _data = snapshot.to_data().expect("Failed to get tensor data");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Full model loading benchmarks (includes backend allocation)
|
||||
// =============================================================================
|
||||
|
||||
// Generate benchmarks for each backend
// The NdArray invocation is unconditional; every other backend is compiled
// only when its Cargo feature (`wgpu`, `cuda`, `tch`, `metal`) is enabled.
// NOTE(review): the macro arguments look like (backend type, generated module
// name, display name) - confirm against the `bench_backend!` definition.
|
||||
bench_backend!(NdArrayBackend, ndarray_backend, "NdArray Backend (CPU)");
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
bench_backend!(WgpuBackend, wgpu_backend, "WGPU Backend (GPU)");
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
bench_backend!(CudaBackend, cuda_backend, "CUDA Backend (NVIDIA GPU)");
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
bench_backend!(TchBackend, tch_backend, "LibTorch Backend");
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
bench_backend!(MetalBackend, metal_backend, "Metal Backend (Apple GPU)");
|
||||
@@ -0,0 +1,148 @@
|
||||
//! Example: Generate a Burnpack file for inspection
|
||||
//!
|
||||
//! This example creates a simple Burnpack file that you can examine to understand the format.
|
||||
//!
|
||||
//! Usage:
|
||||
//! cargo run --example burnpack-inspect [output_path]
|
||||
//!
|
||||
//! Example:
|
||||
//! cargo run --example burnpack-inspect sample.bpk
|
||||
//! cargo run --example burnpack-inspect /tmp/test.bpk
|
||||
//!
|
||||
//! After generating the file, examine it with:
|
||||
//! hexdump -C sample.bpk | head -100
|
||||
//! xxd sample.bpk | head -100
|
||||
//! hexyl sample.bpk
|
||||
use burn_core as burn;
|
||||
|
||||
use burn_core::module::Module;
|
||||
use burn_ndarray::NdArray;
|
||||
use burn_nn::{Linear, LinearConfig};
|
||||
use burn_store::{BurnpackStore, ModuleSnapshot};
|
||||
use burn_tensor::backend::Backend;
|
||||
use std::env;
|
||||
|
||||
// Simple model with a few layers
|
||||
#[derive(Module, Debug)]
|
||||
struct SampleModel<B: Backend> {
|
||||
linear1: Linear<B>,
|
||||
linear2: Linear<B>,
|
||||
linear3: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> SampleModel<B> {
|
||||
fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
linear1: LinearConfig::new(128, 64).init(device),
|
||||
linear2: LinearConfig::new(64, 32).init(device),
|
||||
linear3: LinearConfig::new(32, 10).init(device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
type Backend = NdArray<f32>;
|
||||
|
||||
// Get output path from command line or use default
|
||||
let output_path = env::args()
|
||||
.nth(1)
|
||||
.unwrap_or_else(|| "sample.bpk".to_string());
|
||||
|
||||
println!("Creating sample Burnpack file: {}", output_path);
|
||||
println!();
|
||||
|
||||
// Create a simple model
|
||||
let device = Default::default();
|
||||
let model = SampleModel::<Backend>::new(&device);
|
||||
|
||||
// Save to Burnpack format with metadata
|
||||
let mut store = BurnpackStore::from_file(&output_path)
|
||||
.overwrite(true)
|
||||
.metadata("format", "burnpack")
|
||||
.metadata("description", "Sample file for examining Burnpack format")
|
||||
.metadata("version", env!("CARGO_PKG_VERSION"))
|
||||
.metadata("author", "Burn Example");
|
||||
|
||||
model.save_into(&mut store).expect("Failed to save model");
|
||||
|
||||
println!("✅ Successfully created: {}", output_path);
|
||||
println!();
|
||||
println!("📋 File Structure:");
|
||||
println!(" ┌─────────────────────────────────────┐");
|
||||
println!(" │ Header (10 bytes) │");
|
||||
println!(" ├─────────────────────────────────────┤");
|
||||
println!(" │ - Magic: 0x4E525542 (BURN in LE) │");
|
||||
println!(" │ - Version: 0x0001 (2 bytes) │");
|
||||
println!(" │ - Metadata size: (4 bytes, u32 LE) │");
|
||||
println!(" ├─────────────────────────────────────┤");
|
||||
println!(" │ Metadata (CBOR format) │");
|
||||
println!(" ├─────────────────────────────────────┤");
|
||||
println!(" │ - Tensor descriptors │");
|
||||
println!(" │ * name, dtype, shape, offsets │");
|
||||
println!(" │ - User metadata │");
|
||||
println!(" ├─────────────────────────────────────┤");
|
||||
println!(" │ Tensor Data (raw bytes, LE) │");
|
||||
println!(" ├─────────────────────────────────────┤");
|
||||
println!(" │ - linear1.weight [64, 128] │");
|
||||
println!(" │ - linear1.bias [64] │");
|
||||
println!(" │ - linear2.weight [32, 64] │");
|
||||
println!(" │ - linear2.bias [32] │");
|
||||
println!(" │ - linear3.weight [10, 32] │");
|
||||
println!(" │ - linear3.bias [10] │");
|
||||
println!(" └─────────────────────────────────────┘");
|
||||
println!();
|
||||
println!("📊 Model Contents:");
|
||||
println!(" - linear1.weight: [64, 128] = 8,192 params → 32,768 bytes");
|
||||
println!(" - linear1.bias: [64] = 64 params → 256 bytes");
|
||||
println!(" - linear2.weight: [32, 64] = 2,048 params → 8,192 bytes");
|
||||
println!(" - linear2.bias: [32] = 32 params → 128 bytes");
|
||||
println!(" - linear3.weight: [10, 32] = 320 params → 1,280 bytes");
|
||||
println!(" - linear3.bias: [10] = 10 params → 40 bytes");
|
||||
println!(" ───────────────────────────────────────────────────────");
|
||||
|
||||
let total_params = 8192 + 64 + 2048 + 32 + 320 + 10;
|
||||
let total_bytes = total_params * 4;
|
||||
println!(
|
||||
" Total: {} parameters = {} KB",
|
||||
total_params,
|
||||
total_bytes / 1024
|
||||
);
|
||||
println!();
|
||||
|
||||
// Get actual file size
|
||||
if let Ok(metadata) = std::fs::metadata(&output_path) {
|
||||
let file_size = metadata.len();
|
||||
println!(
|
||||
"📦 File size: {} bytes ({:.2} KB)",
|
||||
file_size,
|
||||
file_size as f64 / 1024.0
|
||||
);
|
||||
}
|
||||
|
||||
println!();
|
||||
println!("🔍 Inspection Commands:");
|
||||
println!();
|
||||
println!(" # View first 100 bytes in hex:");
|
||||
println!(" hexdump -C {} | head -20", output_path);
|
||||
println!();
|
||||
println!(" # View header only (10 bytes):");
|
||||
println!(" head -c 10 {} | hexdump -C", output_path);
|
||||
println!();
|
||||
println!(" # View with prettier hex viewer (if installed):");
|
||||
println!(" hexyl {} | head -50", output_path);
|
||||
println!();
|
||||
println!(" # View in binary format:");
|
||||
println!(" xxd -b {} | head -20", output_path);
|
||||
println!();
|
||||
println!(" # Extract and examine header:");
|
||||
println!(" # Magic (bytes 0-3): Should be 42 55 52 4E (BURN)");
|
||||
println!(" # Version (bytes 4-5): Should be 01 00");
|
||||
println!(" # Metadata size (bytes 6-9): u32 little-endian");
|
||||
println!();
|
||||
println!(" # Load back the model:");
|
||||
println!(
|
||||
" # let mut store = BurnpackStore::from_file(\"{}\");",
|
||||
output_path
|
||||
);
|
||||
println!(" # model.load_from(&mut store)?;");
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
[package]
|
||||
name = "pytorch-tests"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
burn = { path = "../../burn" }
|
||||
burn-ndarray = { path = "../../burn-ndarray" }
|
||||
burn-autodiff = { path = "../../burn-autodiff" }
|
||||
burn-store = { path = "../", features = ["std", "pytorch"] }
|
||||
serde = { workspace = true }
|
||||
float-cmp = { workspace = true }
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
/// Backend used by the tests in this crate: `NdArray` with `f32` elements.
pub type TestBackend = burn_ndarray::NdArray<f32>;
|
||||
@@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Model(nn.Module):
    """Single-layer model that applies a 5-channel 2D batch norm to its input."""

    def __init__(self):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(5)

    def forward(self, x):
        """Return ``x`` passed through the batch-norm layer."""
        return self.norm1(x)
|
||||
|
||||
|
||||
def main():
    """Condition the batch norm with one forward pass, save the state dict to
    ``batch_norm2d.pt``, and print the output for a second input."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    # Condition batch norm (each forward will affect the running stats)
    x1 = torch.ones(1, 5, 2, 2) - 0.5
    _ = model(x1)
    model.eval()  # Set to eval mode to freeze running stats

    # Save the model after the first forward
    torch.save(model.state_dict(), "batch_norm2d.pt")

    x2 = torch.ones(1, 5, 2, 2) - 0.3
    output = model(x2)

    # f-strings: the previous "{}" placeholders were never substituted by print().
    print(f"Input shape: {x2.shape}")
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,62 @@
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{BatchNorm, BatchNormConfig},
|
||||
tensor::{Tensor, backend::Backend},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
norm1: BatchNorm<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
norm1: BatchNormConfig::new(5).init(device), // Python model uses BatchNorm2d(5)
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
self.norm1.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    use crate::backend::TestBackend;
    use burn::tensor::Tolerance;
    use burn_store::{ModuleSnapshot, PytorchStore};

    /// Loads the exported PyTorch state and compares the forward output with
    /// the reference value, which is the same for every output element.
    #[test]
    fn batch_norm2d() {
        let device = Default::default();
        let mut net = Net::<TestBackend>::new(&device);
        let mut pt_store = PytorchStore::from_file("tests/batch_norm/batch_norm2d.pt");
        net.load_from(&mut pt_store)
            .expect("Should decode state successfully");

        let input = Tensor::<TestBackend, 4>::ones([1, 5, 2, 2], &device) - 0.3;
        let output = net.forward(input);

        // Every element of the [1, 5, 2, 2] reference tensor is identical.
        let reference = 0.68515635;
        let expected =
            Tensor::<TestBackend, 4>::from_data([[[[reference; 2]; 2]; 5]], &device);

        output
            .to_data()
            .assert_approx_eq::<f32>(&expected.to_data(), Tolerance::default());
    }
}
|
||||
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Model(nn.Module):
    """Model whose only state is a persistent boolean buffer."""

    def __init__(self):
        super().__init__()
        buffer = torch.tensor([True, False, True])
        self.register_buffer("buffer", buffer, persistent=True)

    def forward(self, x):
        """Return the stored buffer; the input ``x`` is ignored."""
        return self.buffer
|
||||
|
||||
|
||||
def main():
    """Save the model's state dict to ``boolean.pt`` and print a sample
    input/output pair."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "boolean.pt")

    input = torch.ones(3, 3)
    # f-strings: the previous "{}" placeholders were never substituted by print().
    print(f"Input shape: {input.shape}")
    print(f"Input: {input}")
    output = model(input)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,58 @@
|
||||
use burn::{
|
||||
module::{Module, Param, ParamId},
|
||||
tensor::{Bool, Tensor, TensorData, backend::Backend},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
buffer: Param<Tensor<B, 1, Bool>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
/// Create a new model with placeholder values.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
Self {
|
||||
buffer: Param::initialized(
|
||||
ParamId::new(),
|
||||
Tensor::from_bool(TensorData::from([false, false, false]), device),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, _x: Tensor<B, 2>) -> Tensor<B, 1, Bool> {
|
||||
self.buffer.val()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    use crate::backend::TestBackend;
    use burn::tensor::TensorData;
    use burn_store::{ModuleSnapshot, PytorchStore};

    /// Loads the exported boolean buffer and checks the forward pass returns it.
    #[test]
    fn boolean() {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut pt_store = PytorchStore::from_file("tests/boolean/boolean.pt");
        net.load_from(&mut pt_store)
            .expect("Should decode state successfully");

        let output = net.forward(Tensor::<TestBackend, 2>::ones([3, 3], &device));

        let expected_data = TensorData::from([true, false, true]);
        let expected = Tensor::<TestBackend, 1, Bool>::from_bool(expected_data, &device);

        assert_eq!(output.to_data(), expected.to_data());
    }
}
|
||||
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Model(nn.Module):
    """Model whose only state is a persistent 3x3 buffer of ones."""

    def __init__(self):
        super().__init__()
        buffer = torch.ones(3, 3)
        self.register_buffer("buffer", buffer, persistent=True)

    def forward(self, x):
        """Return the stored buffer added to ``x``."""
        return self.buffer + x
|
||||
|
||||
|
||||
def main():
    """Save the model's state dict to ``buffer.pt`` and print a sample
    input/output pair."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "buffer.pt")

    input = torch.ones(3, 3)
    # f-strings: the previous "{}" placeholders were never substituted by print().
    print(f"Input shape: {input.shape}")
    print(f"Input: {input}")
    output = model(input)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,53 @@
|
||||
use burn::{
|
||||
module::{Module, Param},
|
||||
tensor::{Tensor, backend::Backend},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
buffer: Param<Tensor<B, 2>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
/// Create a new model with placeholder values.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
Self {
|
||||
buffer: Param::from_tensor(Tensor::zeros([3, 3], device)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
self.buffer.val() + x
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    use crate::backend::TestBackend;
    use burn::tensor::Tolerance;
    use burn_store::{ModuleSnapshot, PytorchStore};

    /// Loads the exported buffer and checks that adding it to a tensor of ones
    /// yields a tensor of twos.
    #[test]
    fn buffer() {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut pt_store = PytorchStore::from_file("tests/buffer/buffer.pt");
        net.load_from(&mut pt_store)
            .expect("Should decode state successfully");

        let output = net.forward(Tensor::<TestBackend, 2>::ones([3, 3], &device));
        let expected = Tensor::<TestBackend, 2>::ones([3, 3], &device) * 2.0;

        output
            .to_data()
            .assert_approx_eq::<f32>(&expected.to_data(), Tolerance::default());
    }
}
|
||||
@@ -0,0 +1,69 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class ConvBlock(nn.Module):
    """A 2D convolution followed by a 2D batch norm over its output channels."""

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.norm = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return ``x`` passed through the convolution, then the batch norm."""
        return self.norm(self.conv(x))
|
||||
|
||||
class Net(nn.Module):
    """Two sequential conv blocks, a batch norm, and two linear layers ending
    in a log-softmax over dimension 1."""

    def __init__(self):
        super().__init__()

        self.conv_blocks = nn.Sequential(
            ConvBlock(2, 4, (3, 2)),
            ConvBlock(4, 6, (3, 2)),
        )
        self.norm1 = nn.BatchNorm2d(6)

        self.fc1 = nn.Linear(120, 12)
        self.fc2 = nn.Linear(12, 10)

    def forward(self, x):
        """Run the full network on ``x`` and return log-probabilities."""
        x = self.conv_blocks(x)
        x = self.norm1(x)
        # Flatten everything but the batch dimension before the linear layers.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
|
||||
|
||||
def main():
    """Condition the model with two forward passes, save the state dict to
    ``complex_nested.pt``, and print the output for a test input."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(2)

    model = Net().to(torch.device("cpu"))

    # Condition the model (batch norm requires a forward pass to compute the mean and variance)
    x1 = torch.ones(1, 2, 9, 6) - 0.1
    x2 = torch.ones(1, 2, 9, 6) - 0.3
    model(x1)
    model(x2)
    model.eval()  # set to eval mode

    torch.save(model.state_dict(), "complex_nested.pt")

    # feed test data
    x = torch.ones(1, 2, 9, 6) - 0.5
    output = model(x)

    # f-strings: the previous "{}" placeholders were never substituted by print().
    print(f"Input shape: {x.shape}")
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,240 @@
|
||||
use burn::tensor::Tolerance;
|
||||
use burn::tensor::ops::FloatElem;
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{
|
||||
BatchNorm, BatchNormConfig, Linear, LinearConfig,
|
||||
conv::{Conv2d, Conv2dConfig},
|
||||
},
|
||||
tensor::{
|
||||
Tensor,
|
||||
activation::{log_softmax, relu},
|
||||
backend::Backend,
|
||||
},
|
||||
};
|
||||
use burn_autodiff::Autodiff;
|
||||
use burn_store::{ModuleSnapshot, PytorchStore};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ConvBlock<B: Backend> {
|
||||
conv: Conv2d<B>,
|
||||
norm: BatchNorm<B>,
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
conv_blocks: Vec<ConvBlock<B>>,
|
||||
norm1: BatchNorm<B>,
|
||||
fc1: Linear<B>,
|
||||
fc2: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
let conv_blocks = vec![
|
||||
ConvBlock {
|
||||
conv: Conv2dConfig::new([2, 4], [3, 2]).init(device),
|
||||
norm: BatchNormConfig::new(4).init(device), // matches conv output channels
|
||||
},
|
||||
ConvBlock {
|
||||
conv: Conv2dConfig::new([4, 6], [3, 2]).init(device),
|
||||
norm: BatchNormConfig::new(6).init(device), // matches conv output channels
|
||||
},
|
||||
];
|
||||
let norm1 = BatchNormConfig::new(6).init(device);
|
||||
let fc1 = LinearConfig::new(120, 12).init(device);
|
||||
let fc2 = LinearConfig::new(12, 10).init(device);
|
||||
|
||||
Self {
|
||||
conv_blocks,
|
||||
norm1,
|
||||
fc1,
|
||||
fc2,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {
|
||||
let x = self.conv_blocks[0].forward(x);
|
||||
let x = self.conv_blocks[1].forward(x);
|
||||
let x = self.norm1.forward(x);
|
||||
let x = x.reshape([0, -1]);
|
||||
let x = self.fc1.forward(x);
|
||||
let x = relu(x);
|
||||
let x = self.fc2.forward(x);
|
||||
|
||||
log_softmax(x, 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ConvBlock<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.conv.forward(x);
|
||||
|
||||
self.norm.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
/// Partial model to test loading of partial records.
|
||||
#[derive(Module, Debug)]
|
||||
pub struct PartialNet<B: Backend> {
|
||||
conv1: ConvBlock<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> PartialNet<B> {
|
||||
/// Create a new model from the given record.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
let conv1 = ConvBlock {
|
||||
conv: Conv2dConfig::new([2, 4], [3, 2]).init(device),
|
||||
norm: BatchNormConfig::new(4).init(device), // matches conv output channels
|
||||
};
|
||||
Self { conv1 }
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
self.conv1.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
/// Model with extra fields to test loading of records (e.g. from a different model).
|
||||
#[derive(Module, Debug)]
|
||||
pub struct PartialWithExtraNet<B: Backend> {
|
||||
conv1: ConvBlock<B>,
|
||||
extra_field: bool, // This field is not present in the pytorch model
|
||||
}
|
||||
|
||||
impl<B: Backend> PartialWithExtraNet<B> {
|
||||
/// Create a new model from the given record.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
let conv1 = ConvBlock {
|
||||
conv: Conv2dConfig::new([2, 4], [3, 2]).init(device),
|
||||
norm: BatchNormConfig::new(4).init(device), // matches conv output channels
|
||||
};
|
||||
|
||||
Self {
|
||||
conv1,
|
||||
extra_field: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
self.conv1.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
/// Backend used by the tests in this file: `NdArray` with `f32` elements.
type TestBackend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
fn model_test(model: Net<TestBackend>, precision: f32) {
|
||||
let device = Default::default();
|
||||
|
||||
let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;
|
||||
|
||||
let output = model.forward(input);
|
||||
|
||||
let expected = Tensor::<TestBackend, 2>::from_data(
|
||||
[[
|
||||
-2.306_613,
|
||||
-2.058_945_4,
|
||||
-2.298_372_7,
|
||||
-2.358_294,
|
||||
-2.296_395_5,
|
||||
-2.416_090_5,
|
||||
-2.107_669,
|
||||
-2.428_420_8,
|
||||
-2.526_469,
|
||||
-2.319_918_6,
|
||||
]],
|
||||
&device,
|
||||
);
|
||||
|
||||
output.to_data().assert_approx_eq::<FloatElem<TestBackend>>(
|
||||
&expected.to_data(),
|
||||
Tolerance::absolute(precision),
|
||||
);
|
||||
}
|
||||
|
||||
/// Loads the full PyTorch state and checks the output with a tolerance of 1e-8.
#[test]
fn full_record() {
    let device = Default::default();
    let mut net = Net::<TestBackend>::init(&device);
    let mut pt_store = PytorchStore::from_file("tests/complex_nested/complex_nested.pt");
    net.load_from(&mut pt_store)
        .expect("Should decode state successfully");

    model_test(net, 1e-8);
}
|
||||
|
||||
/// Checks that `complex_nested.pt` loads into the model on the autodiff
/// backend; the model output is not compared here.
#[test]
fn full_record_autodiff() {
    let device = Default::default();
    let mut net = Net::<Autodiff<TestBackend>>::init(&device);
    let mut store = PytorchStore::from_file("tests/complex_nested/complex_nested.pt");

    let outcome = net.load_from(&mut store);
    outcome.expect("Should decode state successfully");
}
|
||||
|
||||
/// Loads `complex_nested.pt` into the full model and checks its output with
/// the looser absolute tolerance of `1e-4`.
#[test]
fn half_record() {
    let device = Default::default();
    let mut net = Net::<TestBackend>::init(&device);
    let mut store = PytorchStore::from_file("tests/complex_nested/complex_nested.pt");

    let outcome = net.load_from(&mut store);
    outcome.expect("Should decode state successfully");

    model_test(net, 1e-4);
}
|
||||
|
||||
/// Loads only the first conv block of the stored model into `PartialNet`
/// and checks the sum of the resulting output.
#[test]
fn partial_model_loading() {
    let device = Default::default();
    let mut net = PartialNet::<TestBackend>::init(&device);

    // The file holds the full model: map its "conv_blocks.0.*" keys onto
    // "conv1.*" and allow a partial load.
    let mut store = PytorchStore::from_file("tests/complex_nested/complex_nested.pt")
        .with_key_remapping("conv_blocks\\.0\\.(.*)", "conv1.$1")
        .allow_partial(true);

    let outcome = net.load_from(&mut store);
    outcome.expect("Should decode state successfully");

    let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;

    // Comparing the sum of all output elements serves as a quick check.
    let total = net.forward(input).sum().into_scalar();

    assert!((total - 4.871538).abs() < 0.000002);
}
|
||||
|
||||
/// Same check as the partial-loading test, for a model that also carries a
/// field with no counterpart in the stored weights.
#[test]
fn extra_field_model_loading() {
    let device = Default::default();
    let mut net = PartialWithExtraNet::<TestBackend>::init(&device);

    // The file holds the full model: map its "conv_blocks.0.*" keys onto
    // "conv1.*" and allow a partial load.
    let mut store = PytorchStore::from_file("tests/complex_nested/complex_nested.pt")
        .with_key_remapping("conv_blocks\\.0\\.(.*)", "conv1.$1")
        .allow_partial(true);

    let outcome = net.load_from(&mut store);
    outcome.expect("Should decode state successfully");

    let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;

    // Comparing the sum of all output elements serves as a quick check.
    let total = net.forward(input).sum().into_scalar();

    assert!((total - 4.871538).abs() < 0.000002);

    // The extra field still holds the value set by `init`.
    assert!(net.extra_field);
}
|
||||
@@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Model(nn.Module):
    """Two linear layers (2 -> 3 -> 4) with a ReLU in between; ``fc2`` has no bias."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 3)
        self.fc2 = nn.Linear(3, 4, bias=False)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        # The relu keeps the PyTorch optimizer from combining fc1 and fc2.
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)
|
||||
|
||||
# Configuration values stored next to the weights; they cover ints, a float,
# a bool, a string, three kinds of list and a nested dict.
CONFIG = dict(
    n_head=2,
    n_layer=3,
    d_model=512,
    some_float=0.1,
    some_int=1,
    some_bool=True,
    some_str="hello",
    some_list_int=[1, 2, 3],
    some_list_str=["hello", "world"],
    some_list_float=[0.1, 0.2, 0.3],
    some_dict=dict(some_key="some_value"),
)
|
||||
|
||||
class ModelWithBias(nn.Module):
    """Model consisting of a single ``nn.Linear(2, 3)`` layer."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 3)

    def forward(self, x):
        """Return ``fc1(x)``."""
        return self.fc1(x)
|
||||
|
||||
|
||||
def main():
    """Save the model weights together with ``CONFIG`` to ``weights_with_config.pt``.

    The saved object is a dict with the state dict under ``"my_model"`` and
    the configuration under ``"my_config"``.
    """
    cpu_model = Model().to(torch.device("cpu"))
    payload = dict(my_model=cpu_model.state_dict(), my_config=CONFIG)
    torch.save(payload, "weights_with_config.pt")


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,53 @@
|
||||
#![allow(clippy::too_many_arguments)] // To mute derive Config warning
|
||||
use std::collections::HashMap;
|
||||
|
||||
use burn::config::Config;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[derive(Debug, PartialEq, Config)]
|
||||
struct NetConfig {
|
||||
n_head: usize,
|
||||
n_layer: usize,
|
||||
d_model: usize,
|
||||
some_float: f64,
|
||||
some_int: i32,
|
||||
some_bool: bool,
|
||||
some_str: String,
|
||||
some_list_int: Vec<i32>,
|
||||
some_list_str: Vec<String>,
|
||||
some_list_float: Vec<f64>,
|
||||
some_dict: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use burn_store::pytorch::PytorchReader;

    use super::*;

    /// Reads the `my_config` entry of `weights_with_config.pt` and compares
    /// every field with the expected values.
    #[test]
    fn test_net_config() {
        let loaded: NetConfig =
            PytorchReader::load_config("tests/config/weights_with_config.pt", Some("my_config"))
                .unwrap();

        let expected = NetConfig {
            n_head: 2,
            n_layer: 3,
            d_model: 512,
            some_float: 0.1,
            some_int: 1,
            some_bool: true,
            some_str: String::from("hello"),
            some_list_int: vec![1, 2, 3],
            some_list_str: ["hello", "world"].map(String::from).to_vec(),
            some_list_float: vec![0.1, 0.2, 0.3],
            some_dict: HashMap::from([("some_key".to_string(), "some_value".to_string())]),
        };

        assert_eq!(loaded, expected);
    }
}
|
||||
@@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Model(nn.Module):
    """Two stacked ``nn.Conv1d(2, 2, 2)`` layers; ``conv2`` has no bias."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(2, 2, 2)
        self.conv2 = nn.Conv1d(2, 2, 2, bias=False)

    def forward(self, x):
        """Return ``conv2(conv1(x))``."""
        return self.conv2(self.conv1(x))
|
||||
|
||||
|
||||
def main():
    """Export ``conv1d.pt`` and print a sample input/output pair.

    Seeds torch with 1, saves the model's ``state_dict`` and then prints a
    random ``(1, 2, 6)`` input together with the model's output for it.
    """
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "conv1d.pt")

    sample = torch.rand(1, 2, 6)
    # print() performs no "{}" substitution, so the label and the value are
    # passed as separate arguments.
    print("Input shape:", sample.shape)
    print("Input:", sample)
    output = model(sample)
    print("Output:", output)
    print("Output Shape:", output.shape)


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,97 @@
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::conv::{Conv1d, Conv1dConfig},
|
||||
tensor::{Tensor, backend::Backend},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
conv1: Conv1d<B>,
|
||||
conv2: Conv1d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
/// Create a new model from the given record.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
let conv1 = Conv1dConfig::new(2, 2, 2).init(device);
|
||||
let conv2 = Conv1dConfig::new(2, 2, 2).with_bias(false).init(device);
|
||||
|
||||
Self { conv1, conv2 }
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
let x = self.conv1.forward(x);
|
||||
|
||||
self.conv2.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;
    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    type FT = FloatElem<TestBackend>;

    /// Builds a `Net` and loads `tests/conv1d/conv1d.pt` into it.
    fn load_net() -> Net<TestBackend> {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut store = PytorchStore::from_file("tests/conv1d/conv1d.pt");
        net.load_from(&mut store)
            .expect("Should decode state successfully");
        net
    }

    /// Runs a fixed input through `model` and compares the output with the
    /// reference values, using `precision` as absolute tolerance.
    fn conv1d(model: Net<TestBackend>, precision: f32) {
        let device = Default::default();

        let sample = Tensor::<TestBackend, 3>::from_data(
            [[
                [
                    0.93708336, 0.65559506, 0.31379688, 0.19801933, 0.41619217, 0.28432965,
                ],
                [
                    0.33977574,
                    0.523_940_8,
                    0.798_063_9,
                    0.77176833,
                    0.01122457,
                    0.80996025,
                ],
            ]],
            &device,
        );

        let reference = Tensor::<TestBackend, 3>::from_data(
            [[
                [0.02987457, 0.03134188, 0.04234261, -0.02437721],
                [-0.03788019, -0.02972012, -0.00806090, -0.01981254],
            ]],
            &device,
        );

        let tolerance = Tolerance::absolute(precision);
        model
            .forward(sample)
            .to_data()
            .assert_approx_eq::<FT>(&reference.to_data(), tolerance);
    }

    #[test]
    fn conv1d_full_precision() {
        conv1d(load_net(), 1e-7);
    }

    #[test]
    fn conv1d_half_precision() {
        conv1d(load_net(), 1e-4);
    }
}
|
||||
@@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Model(nn.Module):
    """Two stacked ``nn.Conv2d(2, 2, (2, 2))`` layers; ``conv2`` has no bias."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, (2, 2))
        self.conv2 = nn.Conv2d(2, 2, (2, 2), bias=False)

    def forward(self, x):
        """Return ``conv2(conv1(x))``."""
        return self.conv2(self.conv1(x))
|
||||
|
||||
|
||||
def main():
    """Export ``conv2d.pt`` and print a sample input/output pair.

    Seeds torch with 1, saves the model's ``state_dict`` and then prints a
    random ``(1, 2, 5, 5)`` input together with the model's output for it.
    """
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "conv2d.pt")

    sample = torch.rand(1, 2, 5, 5)
    # print() performs no "{}" substitution, so the label and the value are
    # passed as separate arguments.
    print("Input shape:", sample.shape)
    print("Input:", sample)
    output = model(sample)
    print("Output:", output)
    print("Output Shape:", output.shape)


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,134 @@
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::conv::{Conv2d, Conv2dConfig},
|
||||
tensor::{Tensor, backend::Backend},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
conv1: Conv2d<B>,
|
||||
conv2: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
/// Create a new model from the given record.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
let conv1 = Conv2dConfig::new([2, 2], [2, 2]).init(device);
|
||||
let conv2 = Conv2dConfig::new([2, 2], [2, 2])
|
||||
.with_bias(false)
|
||||
.init(device);
|
||||
|
||||
Self { conv1, conv2 }
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.conv1.forward(x);
|
||||
|
||||
self.conv2.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;

    use burn::tensor::Tolerance;
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    /// Builds a `Net` and loads `tests/conv2d/conv2d.pt` into it.
    fn load_net() -> Net<TestBackend> {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut store = PytorchStore::from_file("tests/conv2d/conv2d.pt");
        net.load_from(&mut store)
            .expect("Should decode state successfully");
        net
    }

    /// Runs a fixed input through `model` and compares the output with the
    /// reference values, using `precision` as absolute tolerance.
    fn conv2d(model: Net<TestBackend>, precision: f32) {
        let device = Default::default();

        let sample = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [0.024_595_8, 0.25883394, 0.93905586, 0.416_715_5, 0.713_979_7],
                    [0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4, 0.505_920_8],
                    [0.23659128, 0.757_007_4, 0.23458993, 0.64705235, 0.355_621_4],
                    [0.445_182_8, 0.01930594, 0.26160914, 0.771_317, 0.37846136],
                    [0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845, 0.804_481_1],
                ],
                [
                    [0.65517855, 0.17679012, 0.824_772_3, 0.803_550_9, 0.943_447_5],
                    [0.21972018, 0.417_697, 0.49031407, 0.57302874, 0.12054086],
                    [0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7, 0.52850497],
                    [0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537, 0.03694397],
                    [0.751_675_7, 0.148_438_4, 0.12274551, 0.530_407_2, 0.414_796_4],
                ],
            ]],
            &device,
        );

        let reference = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [-0.02502128, 0.00250649, 0.04841233],
                    [0.04589614, -0.00296854, 0.01991477],
                    [0.02920526, 0.059_497_3, 0.04326791],
                ],
                [
                    [-0.04825336, 0.080_190_9, -0.02375088],
                    [0.02885434, 0.09638263, -0.07460806],
                    [0.02004079, 0.06244051, 0.035_887_1],
                ],
            ]],
            &device,
        );

        let tolerance = Tolerance::absolute(precision);
        model
            .forward(sample)
            .to_data()
            .assert_approx_eq::<f32>(&reference.to_data(), tolerance);
    }

    #[test]
    fn conv2d_full_precision() {
        conv2d(load_net(), 1e-7);
    }

    #[test]
    fn conv2d_half_precision() {
        conv2d(load_net(), 1e-4);
    }
}
|
||||
@@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Model(nn.Module):
    """Two stacked ``nn.ConvTranspose1d(2, 2, 2)`` layers; ``conv2`` has no bias."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.ConvTranspose1d(2, 2, 2)
        self.conv2 = nn.ConvTranspose1d(2, 2, 2, bias=False)

    def forward(self, x):
        """Return ``conv2(conv1(x))``."""
        return self.conv2(self.conv1(x))
|
||||
|
||||
|
||||
def main():
    """Export ``conv_transpose1d.pt`` and print a sample input/output pair.

    Seeds torch with 1, saves the model's ``state_dict`` and then prints a
    random ``(1, 2, 2)`` input together with the model's output for it.
    """
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "conv_transpose1d.pt")

    sample = torch.rand(1, 2, 2)
    # print() performs no "{}" substitution, so the label and the value are
    # passed as separate arguments.
    print("Input shape:", sample.shape)
    print("Input:", sample)
    output = model(sample)
    print("Output:", output)
    print("Output Shape:", output.shape)


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,87 @@
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::conv::{ConvTranspose1d, ConvTranspose1dConfig},
|
||||
tensor::{Tensor, backend::Backend},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
conv1: ConvTranspose1d<B>,
|
||||
conv2: ConvTranspose1d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
/// Create a new model from the given record.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
let conv1 = ConvTranspose1dConfig::new([2, 2], 2).init(device);
|
||||
let conv2 = ConvTranspose1dConfig::new([2, 2], 2)
|
||||
.with_bias(false)
|
||||
.init(device);
|
||||
|
||||
Self { conv1, conv2 }
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
let x = self.conv1.forward(x);
|
||||
|
||||
self.conv2.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;

    use burn::tensor::Tolerance;
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    /// Builds a `Net` and loads `tests/conv_transpose1d/conv_transpose1d.pt` into it.
    fn load_net() -> Net<TestBackend> {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut store = PytorchStore::from_file("tests/conv_transpose1d/conv_transpose1d.pt");
        net.load_from(&mut store)
            .expect("Should decode state successfully");
        net
    }

    /// Runs a fixed input through `model` and compares the output with the
    /// reference values, using `precision` as absolute tolerance.
    fn conv_transpose1d(model: Net<TestBackend>, precision: f32) {
        let device = Default::default();

        let sample = Tensor::<TestBackend, 3>::from_data(
            [[[0.93708336, 0.65559506], [0.31379688, 0.19801933]]],
            &device,
        );

        let reference = Tensor::<TestBackend, 3>::from_data(
            [[
                [0.02935525, 0.01119324, -0.01356167, -0.00682688],
                [0.01644749, -0.01429807, 0.00083987, 0.00279229],
            ]],
            &device,
        );

        let tolerance = Tolerance::absolute(precision);
        model
            .forward(sample)
            .to_data()
            .assert_approx_eq::<f32>(&reference.to_data(), tolerance);
    }

    #[test]
    fn conv_transpose1d_full() {
        conv_transpose1d(load_net(), 1e-8);
    }

    #[test]
    fn conv_transpose1d_half() {
        conv_transpose1d(load_net(), 1e-4);
    }
}
|
||||
@@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Model(nn.Module):
    """Two stacked ``nn.ConvTranspose2d(2, 2, (2, 2))`` layers; ``conv2`` has no bias."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.ConvTranspose2d(2, 2, (2, 2))
        self.conv2 = nn.ConvTranspose2d(2, 2, (2, 2), bias=False)

    def forward(self, x):
        """Return ``conv2(conv1(x))``."""
        return self.conv2(self.conv1(x))
|
||||
|
||||
|
||||
def main():
    """Export ``conv_transpose2d.pt`` and print a sample input/output pair.

    Seeds torch with 1, saves the model's ``state_dict`` and then prints a
    random ``(1, 2, 2, 2)`` input together with the model's output for it.
    """
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "conv_transpose2d.pt")

    sample = torch.rand(1, 2, 2, 2)
    # print() performs no "{}" substitution, so the label and the value are
    # passed as separate arguments.
    print("Input shape:", sample.shape)
    print("Input:", sample)
    output = model(sample)
    print("Output:", output)
    print("Output Shape:", output.shape)


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,99 @@
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::conv::{ConvTranspose2d, ConvTranspose2dConfig},
|
||||
tensor::{Tensor, backend::Backend},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
conv1: ConvTranspose2d<B>,
|
||||
conv2: ConvTranspose2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
/// Create a new model from the given record.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
let conv1 = ConvTranspose2dConfig::new([2, 2], [2, 2]).init(device);
|
||||
let conv2 = ConvTranspose2dConfig::new([2, 2], [2, 2])
|
||||
.with_bias(false)
|
||||
.init(device);
|
||||
|
||||
Self { conv1, conv2 }
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.conv1.forward(x);
|
||||
|
||||
self.conv2.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;

    use burn::tensor::Tolerance;
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    /// Builds a `Net` and loads `tests/conv_transpose2d/conv_transpose2d.pt` into it.
    fn load_net() -> Net<TestBackend> {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut store = PytorchStore::from_file("tests/conv_transpose2d/conv_transpose2d.pt");
        net.load_from(&mut store)
            .expect("Should decode state successfully");
        net
    }

    /// Runs a fixed input through `model` and compares the output with the
    /// reference values, using `precision` as absolute tolerance.
    fn conv_transpose2d(model: Net<TestBackend>, precision: f32) {
        let device = Default::default();

        let sample = Tensor::<TestBackend, 4>::from_data(
            [[
                [[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]],
                [[0.713_979_7, 0.267_644_3], [0.990_609, 0.28845078]],
            ]],
            &device,
        );

        let reference = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [0.04547675, 0.01879685, -0.01636661, 0.00310803],
                    [0.02090115, 0.01192738, -0.048_240_2, 0.02252235],
                    [0.03249975, -0.00460748, 0.05003899, 0.04029131],
                    [0.02185687, -0.10226749, -0.06508022, -0.01267705],
                ],
                [
                    [0.00277598, -0.00513832, -0.059_048_3, 0.00567626],
                    [-0.03149522, -0.195_757_4, 0.03474613, 0.01997269],
                    [-0.10096474, 0.00679589, 0.041_919_7, -0.02464108],
                    [-0.03174751, 0.02963913, -0.02703723, -0.01860938],
                ],
            ]],
            &device,
        );

        let tolerance = Tolerance::absolute(precision);
        model
            .forward(sample)
            .to_data()
            .assert_approx_eq::<f32>(&reference.to_data(), tolerance);
    }

    #[test]
    fn conv_transpose2d_full() {
        conv_transpose2d(load_net(), 1e-7);
    }

    #[test]
    fn conv_transpose2d_half() {
        conv_transpose2d(load_net(), 1e-4);
    }
}
|
||||
@@ -0,0 +1,37 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Model(nn.Module):
    """Model consisting of a single ``nn.Embedding(10, 3)`` layer."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 3)

    def forward(self, x):
        """Return ``embed(x)``."""
        return self.embed(x)
|
||||
|
||||
|
||||
def main():
    """Export ``embedding.pt`` and print a sample input/output pair.

    Seeds torch with 1, saves the model's ``state_dict`` and then prints a
    fixed ``(2, 4)`` index tensor together with the model's output for it.
    """
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "embedding.pt")

    sample = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
    # print() performs no "{}" substitution, so the label and the value are
    # passed as separate arguments.
    print("Input shape:", sample.shape)
    print("Input:", sample)
    output = model(sample)
    print("Output:", output)
    print("Output Shape:", output.shape)


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,86 @@
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{Embedding, EmbeddingConfig},
|
||||
tensor::{Int, Tensor, backend::Backend},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
embed: Embedding<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
/// Create a new model.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
let embed = EmbeddingConfig::new(10, 3).init(device);
|
||||
Self { embed }
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
|
||||
self.embed.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;
    use burn::tensor::Tolerance;
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    /// Builds a `Net` and loads `tests/embedding/embedding.pt` into it.
    fn load_net() -> Net<TestBackend> {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut store = PytorchStore::from_file("tests/embedding/embedding.pt");
        net.load_from(&mut store)
            .expect("Should decode state successfully");
        net
    }

    /// Runs a fixed index tensor through `model` and compares the output with
    /// the reference values, using `precision` as absolute tolerance.
    fn embedding(model: Net<TestBackend>, precision: f32) {
        let device = Default::default();

        let indices =
            Tensor::<TestBackend, 2, Int>::from_data([[1, 2, 4, 5], [4, 3, 2, 9]], &device);

        let reference = Tensor::<TestBackend, 3>::from_data(
            [
                [
                    [-1.609_484_9, -0.10016718, -0.609_188_9],
                    [-0.97977227, -1.609_096_3, -0.712_144_6],
                    [-0.22227049, 1.687_113_4, -0.32062083],
                    [-0.29934573, 1.879_345_7, -0.07213178],
                ],
                [
                    [-0.22227049, 1.687_113_4, -0.32062083],
                    [0.303_722, -0.777_314_3, -0.25145486],
                    [-0.97977227, -1.609_096_3, -0.712_144_6],
                    [-0.02878714, 2.357_111, -1.037_338_7],
                ],
            ],
            &device,
        );

        let tolerance = Tolerance::absolute(precision);
        model
            .forward(indices)
            .to_data()
            .assert_approx_eq::<f32>(&reference.to_data(), tolerance);
    }

    #[test]
    fn embedding_full_precision() {
        embedding(load_net(), 1e-3);
    }

    #[test]
    fn embedding_half_precision() {
        embedding(load_net(), 1e-3);
    }
}
|
||||
@@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env python3
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
class DwsConv(nn.Module):
    """Depthwise separable convolution: a depthwise conv followed by a pointwise conv."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
        super().__init__()
        # Depthwise stage: one group per input channel.
        self.dconv = nn.Conv2d(in_channels, in_channels, kernel_size, groups=in_channels)
        # Pointwise stage: 1x1 kernel, single group.
        self.pconv = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=1)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``pconv(dconv(x))``."""
        depthwise_out = self.dconv(x)
        return self.pconv(depthwise_out)
|
||||
|
||||
|
||||
class Model(nn.Module):
    """Model whose ``conv`` is either a ``DwsConv`` or a plain ``nn.Conv2d``."""

    def __init__(self, depthwise: bool = False) -> None:
        super().__init__()
        if depthwise:
            self.conv = DwsConv(2, 2, 3)
        else:
            self.conv = nn.Conv2d(2, 2, 3)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``conv(x)``."""
        out = self.conv(x)
        return out
|
||||
|
||||
|
||||
def main():
    """Export both enum-variant weight files and print sample input/output pairs.

    Seeds torch with 1, saves the plain-conv model to
    ``enum_depthwise_false.pt`` and prints a random ``(1, 2, 5, 5)`` input with
    its output. Then builds the depthwise model, saves it to
    ``enum_depthwise_true.pt`` and prints its output for the same input.
    """
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "enum_depthwise_false.pt")

    sample = torch.rand(1, 2, 5, 5)

    # print() performs no "{}" substitution, so each label and value are
    # passed as separate arguments.
    print("Depthwise is False")
    print("Input shape:", sample.shape)
    print("Input:", sample)
    output = model(sample)
    print("Output:", output)
    print("Output Shape:", output.shape)

    print("Depthwise is True")
    model = Model(depthwise=True).to(torch.device("cpu"))
    torch.save(model.state_dict(), "enum_depthwise_true.pt")

    print("Input shape:", sample.shape)
    print("Input:", sample)
    output = model(sample)
    print("Output:", output)
    print("Output Shape:", output.shape)


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,197 @@
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::conv::{Conv2d, Conv2dConfig},
|
||||
tensor::{Tensor, backend::Backend},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub enum Conv<B: Backend> {
|
||||
DwsConv(DwsConv<B>),
|
||||
Conv(Conv2d<B>),
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct DwsConv<B: Backend> {
|
||||
dconv: Conv2d<B>,
|
||||
pconv: Conv2d<B>,
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
conv: Conv<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
/// Create a new model with DwsConv variant.
|
||||
pub fn init_dws_conv(device: &B::Device) -> Self {
|
||||
let dconv = Conv2dConfig::new([2, 2], [3, 3])
|
||||
.with_groups(2)
|
||||
.init(device);
|
||||
let pconv = Conv2dConfig::new([2, 2], [1, 1])
|
||||
.with_groups(1)
|
||||
.init(device);
|
||||
Net {
|
||||
conv: Conv::DwsConv(DwsConv { dconv, pconv }),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new model with Conv variant.
|
||||
pub fn init_conv(device: &B::Device) -> Self {
|
||||
let conv2d_config = Conv2dConfig::new([2, 2], [3, 3]);
|
||||
Net {
|
||||
conv: Conv::Conv(conv2d_config.init(device)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
match &self.conv {
|
||||
Conv::DwsConv(dws_conv) => {
|
||||
let x = dws_conv.dconv.forward(x);
|
||||
dws_conv.pconv.forward(x)
|
||||
}
|
||||
Conv::Conv(conv) => conv.forward(x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;

    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    type FT = FloatElem<TestBackend>;

    /// The `[1, 2, 5, 5]` input fed to the model in both tests.
    fn sample_input() -> Tensor<TestBackend, 4> {
        let device = Default::default();

        Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [0.713_979_7, 0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4],
                    [0.505_920_8, 0.23659128, 0.757_007_4, 0.23458993, 0.64705235],
                    [0.355_621_4, 0.445_182_8, 0.01930594, 0.26160914, 0.771_317],
                    [0.37846136, 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845],
                    [0.804_481_1, 0.65517855, 0.17679012, 0.824_772_3, 0.803_550_9],
                ],
                [
                    [0.943_447_5, 0.21972018, 0.417_697, 0.49031407, 0.57302874],
                    [0.12054086, 0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7],
                    [0.52850497, 0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537],
                    [0.03694397, 0.751_675_7, 0.148_438_4, 0.12274551, 0.530_407_2],
                    [0.414_796_4, 0.793_662, 0.21043217, 0.05550903, 0.863_884_4],
                ],
            ]],
            &device,
        )
    }

    /// Loads `enum_depthwise_false.pt` into the plain-conv variant and
    /// compares the output with the reference values.
    #[test]
    fn depthwise_false() {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init_conv(&device);
        let mut store = PytorchStore::from_file("tests/enum_module/enum_depthwise_false.pt");

        net.load_from(&mut store)
            .expect("Should decode state successfully");

        let output = net.forward(sample_input());

        let reference = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [0.35449377, -0.02832414, 0.490_976_1],
                    [0.29709217, 0.332_586_3, 0.30594018],
                    [0.18101373, 0.30932188, 0.30558896],
                ],
                [
                    [-0.17683622, -0.13244139, -0.05608707],
                    [0.23467252, -0.07038684, 0.255_044_1],
                    [-0.241_931_3, -0.20476191, -0.14468731],
                ],
            ]],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq::<FT>(&reference.to_data(), Tolerance::default());
    }

    /// Loads `enum_depthwise_true.pt` into the depthwise separable variant
    /// and compares the output with the reference values.
    #[test]
    fn depthwise_true() {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init_dws_conv(&device);
        let mut store = PytorchStore::from_file("tests/enum_module/enum_depthwise_true.pt");

        net.load_from(&mut store)
            .expect("Should decode state successfully");

        let output = net.forward(sample_input());

        let reference = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [0.77874625, 0.859_017_6, 0.834_283_5],
                    [0.773_056_4, 0.73817325, 0.78292674],
                    [0.710_775_2, 0.747_187_2, 0.733_264_4],
                ],
                [
                    [-0.44891885, -0.49027523, -0.394_170_7],
                    [-0.43836114, -0.33961445, -0.387_311_5],
                    [-0.581_134_3, -0.34197026, -0.535_035_7],
                ],
            ]],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq::<FT>(&reference.to_data(), Tolerance::default());
    }
}
|
||||
@@ -0,0 +1,37 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Model(nn.Module):
    """Single ``GroupNorm`` layer: 6 channels split into 2 groups."""

    def __init__(self):
        super().__init__()
        self.norm1 = nn.GroupNorm(num_groups=2, num_channels=6)

    def forward(self, x):
        """Apply the group normalization to ``x`` and return the result."""
        return self.norm1(x)
|
||||
|
||||
|
||||
def main():
    """Save the seeded model's weights to ``group_norm.pt`` and print a sample input/output pair."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "group_norm.pt")

    x2 = torch.rand(1, 6, 2, 2)
    # Interpolate with f-strings: print() does not expand "{}" placeholders.
    print(f"Input shape: {x2.shape}")
    print(f"Input: {x2}")
    output = model(x2)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,90 @@
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{GroupNorm, GroupNormConfig},
|
||||
tensor::{Tensor, backend::Backend},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
norm1: GroupNorm<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
/// Create a new model from the given record.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
let norm1 = GroupNormConfig::new(2, 6).init(device);
|
||||
Self { norm1 }
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
self.norm1.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;
    use burn::tensor::Tolerance;
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    /// Build a [`Net`] and populate it from `tests/group_norm/group_norm.pt`.
    fn load_net() -> Net<TestBackend> {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut weights = PytorchStore::from_file("tests/group_norm/group_norm.pt");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");
        net
    }

    /// Run `model` on the fixed input and compare with the reference output using an
    /// absolute tolerance of `precision`.
    fn group_norm(model: Net<TestBackend>, precision: f32) {
        let device = Default::default();

        let expected = Tensor::<TestBackend, 4>::from_data(
            [[
                [[1.042_578_5, -1.122_016_7], [-0.56195974, 0.938_733_6]],
                [[-2.253_500_7, 1.233_672_9], [-0.588_804_1, 1.027_827_3]],
                [[0.19124532, -0.40036356], [0.504_276_5, -0.01168585]],
                [[1.013_829_2, -0.891_984_6], [-0.09224463, -0.13546038]],
                [[0.45772314, 0.08172822], [2.298_641_4, 0.877_410_4]],
                [[-0.84832406, -1.432_883_4], [-0.331_331_5, -0.997_103_7]],
            ]],
            &device,
        );

        let x = Tensor::<TestBackend, 4>::from_data(
            [[
                [[0.757_631_6, 0.27931088], [0.40306926, 0.73468447]],
                [[0.02928156, 0.799_858_6], [0.39713734, 0.75437194]],
                [[0.569_508_5, 0.43877792], [0.63868046, 0.524_665_9]],
                [[0.682_614_1, 0.305_149_5], [0.46354562, 0.45498633]],
                [[0.572_472, 0.498_002_6], [0.93708336, 0.65559506]],
                [[0.31379688, 0.19801933], [0.41619217, 0.28432965]],
            ]],
            &device,
        );

        let actual = model.forward(x).to_data();
        actual.assert_approx_eq::<f32>(&expected.to_data(), Tolerance::absolute(precision));
    }

    #[test]
    fn group_norm_full() {
        group_norm(load_net(), 1e-3);
    }

    #[test]
    fn group_norm_half() {
        group_norm(load_net(), 1e-3);
    }
}
|
||||
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Model(nn.Module):
    """Module whose only state is a persistent integer buffer ``[1, 2, 3]``."""

    def __init__(self):
        super().__init__()
        # persistent=True keeps the buffer in state_dict(), so it is saved with the model.
        self.register_buffer("buffer", torch.tensor([1, 2, 3]), persistent=True)

    def forward(self, x):
        """Ignore ``x`` and return the stored buffer."""
        return self.buffer
|
||||
|
||||
|
||||
def main():
    """Save the model's state to ``integer.pt`` and print a sample input/output pair."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "integer.pt")

    # Named `x` rather than `input` so the builtin is not shadowed.
    x = torch.ones(3, 3)
    # Interpolate with f-strings: print() does not expand "{}" placeholders.
    print(f"Input shape: {x.shape}")
    print(f"Input: {x}")
    output = model(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,72 @@
|
||||
use burn::{
|
||||
module::{Module, Param, ParamId},
|
||||
tensor::{Int, Tensor, TensorData, backend::Backend},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
buffer: Param<Tensor<B, 1, Int>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
/// Create a new model with placeholder values.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
Self {
|
||||
buffer: Param::initialized(
|
||||
ParamId::new(),
|
||||
Tensor::<B, 1, Int>::from_data(TensorData::from([0, 0, 0]), device),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, _x: Tensor<B, 2>) -> Tensor<B, 1, Int> {
|
||||
self.buffer.val()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;
    use burn::tensor::TensorData;
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    /// Build a [`Net`] and populate it from `tests/integer/integer.pt`.
    fn load_net() -> Net<TestBackend> {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut weights = PytorchStore::from_file("tests/integer/integer.pt");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");
        net
    }

    /// Check that the forward pass of `model` returns the integer values `[1, 2, 3]`.
    fn integer(model: Net<TestBackend>) {
        let device = Default::default();

        let expected =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 2, 3]), &device);

        let ones = Tensor::<TestBackend, 2>::ones([3, 3], &device);
        let actual = model.forward(ones);

        assert_eq!(actual.to_data(), expected.to_data());
    }

    #[test]
    fn integer_full_precision() {
        integer(load_net());
    }

    #[test]
    fn integer_half_precision() {
        integer(load_net());
    }
}
|
||||
@@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class ConvModule(nn.Module):
    """Two stacked 2x2 convolutions (2 -> 2 channels); the second one has no bias."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, kernel_size=(2, 2))
        self.conv2 = nn.Conv2d(2, 2, kernel_size=(2, 2), bias=False)

    def forward(self, x):
        """Run ``x`` through ``conv1`` and then ``conv2``."""
        return self.conv2(self.conv1(x))
|
||||
|
||||
class Model(nn.Module):
    """Wraps ``ConvModule`` under the attribute ``conv``, so state-dict keys start with ``conv.``."""

    def __init__(self):
        super().__init__()
        self.conv = ConvModule()

    def forward(self, x):
        """Delegate to the wrapped ``ConvModule``."""
        return self.conv(x)
|
||||
|
||||
|
||||
def main():
    """Save the seeded model's weights to ``key_remap.pt`` and print a sample input/output pair."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "key_remap.pt")

    # Named `x` rather than `input` so the builtin is not shadowed.
    x = torch.rand(1, 2, 5, 5)
    # Interpolate with f-strings: print() does not expand "{}" placeholders.
    print(f"Input shape: {x.shape}")
    print(f"Input: {x}")
    output = model(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,118 @@
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::conv::{Conv2d, Conv2dConfig},
|
||||
tensor::{Tensor, backend::Backend},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
conv1: Conv2d<B>,
|
||||
conv2: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
/// Create a new model.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
let conv1 = Conv2dConfig::new([2, 2], [2, 2]).init(device);
|
||||
let conv2 = Conv2dConfig::new([2, 2], [2, 2])
|
||||
.with_bias(false)
|
||||
.init(device);
|
||||
Self { conv1, conv2 }
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.conv1.forward(x);
|
||||
|
||||
self.conv2.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;

    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    type FT = FloatElem<TestBackend>;

    /// Loads `key_remap.pt` with a remapping rule applied to its keys, then compares the
    /// forward output with the reference values.
    #[test]
    fn key_remap() {
        let device = Default::default();

        let mut net = Net::<TestBackend>::init(&device);
        let mut weights = PytorchStore::from_file("tests/key_remap/key_remap.pt")
            .with_key_remapping("conv\\.(.*)", "$1"); // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
        net.load_from(&mut weights)
            .expect("Should decode state successfully");

        let expected = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [-0.02502128, 0.00250649, 0.04841233],
                    [0.04589614, -0.00296854, 0.01991477],
                    [0.02920526, 0.059_497_3, 0.04326791],
                ],
                [
                    [-0.04825336, 0.080_190_9, -0.02375088],
                    [0.02885434, 0.09638263, -0.07460806],
                    [0.02004079, 0.06244051, 0.035_887_1],
                ],
            ]],
            &device,
        );

        let x = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [
                        0.024_595_8,
                        0.25883394,
                        0.93905586,
                        0.416_715_5,
                        0.713_979_7,
                    ],
                    [0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4, 0.505_920_8],
                    [0.23659128, 0.757_007_4, 0.23458993, 0.64705235, 0.355_621_4],
                    [0.445_182_8, 0.01930594, 0.26160914, 0.771_317, 0.37846136],
                    [
                        0.99802476,
                        0.900_794_2,
                        0.476_588_2,
                        0.16625845,
                        0.804_481_1,
                    ],
                ],
                [
                    [
                        0.65517855,
                        0.17679012,
                        0.824_772_3,
                        0.803_550_9,
                        0.943_447_5,
                    ],
                    [0.21972018, 0.417_697, 0.49031407, 0.57302874, 0.12054086],
                    [0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7, 0.52850497],
                    [0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537, 0.03694397],
                    [
                        0.751_675_7,
                        0.148_438_4,
                        0.12274551,
                        0.530_407_2,
                        0.414_796_4,
                    ],
                ],
            ]],
            &device,
        );

        let actual = net.forward(x).to_data();
        actual.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }
}
|
||||
@@ -0,0 +1,57 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
    """Bias-free 1x1 convolution followed by batch norm, held in ``self.block``."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        # Sequential indexes its children, so the keys are "block.0.*" and "block.1.*".
        self.block = nn.Sequential(conv, norm)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the convolution and then the batch norm."""
        return self.block(x)
|
||||
|
||||
|
||||
class Model(nn.Module):
    """3x3 bias-free convolution (3 -> 6), batch norm, then two chained ``ConvBlock(6, 6)``."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, kernel_size=3, bias=False)
        self.bn = nn.BatchNorm2d(6)
        first_block = ConvBlock(6, 6)
        second_block = ConvBlock(6, 6)
        self.layer = nn.Sequential(first_block, second_block)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through ``conv``, ``bn`` and ``layer`` in that order."""
        return self.layer(self.bn(self.conv(x)))
|
||||
|
||||
|
||||
def main():
    """Print a sample input/output pair in eval mode and save the weights to ``key_remap.pt``."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(42)

    model = Model()

    # Named `x` rather than `input` so the builtin is not shadowed.
    x = torch.rand(1, 3, 4, 4)
    model(x)  # condition batch norm
    model.eval()

    with torch.no_grad():
        print(f"Input shape: {x.shape}")
        # Interpolated like the neighbouring lines: print() does not expand "{}" placeholders.
        print(f"Input type: {x.dtype}")
        print(f"Input: {x}")
        output = model(x)

        print(f"Output: {output}")
        print(f"Output Shape: {output.shape}")

    torch.save(model.state_dict(), "key_remap.pt")


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,179 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{
|
||||
BatchNorm, BatchNormConfig,
|
||||
conv::{Conv2d, Conv2dConfig},
|
||||
},
|
||||
tensor::{Device, Tensor, backend::Backend},
|
||||
};
|
||||
|
||||
/// Some module that implements a specific method so it can be used in a sequential block.
|
||||
pub trait ForwardModule<B: Backend> {
|
||||
fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4>;
|
||||
}
|
||||
|
||||
/// Conv2d + BatchNorm block.
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ConvBlock<B: Backend> {
|
||||
conv: Conv2d<B>,
|
||||
bn: BatchNorm<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ForwardModule<B> for ConvBlock<B> {
|
||||
fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let out = self.conv.forward(input);
|
||||
self.bn.forward(out)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ConvBlock<B> {
|
||||
pub fn new(in_channels: usize, out_channels: usize, device: &Device<B>) -> Self {
|
||||
let conv = Conv2dConfig::new([in_channels, out_channels], [1, 1])
|
||||
.with_bias(false)
|
||||
.init(device);
|
||||
let bn = BatchNormConfig::new(out_channels).init(device);
|
||||
|
||||
Self { conv, bn }
|
||||
}
|
||||
}
|
||||
|
||||
/// Collection of sequential blocks.
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ModuleBlock<B: Backend, M> {
|
||||
blocks: Vec<M>,
|
||||
_backend: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend, M: ForwardModule<B>> ModuleBlock<B, M> {
|
||||
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let mut out = input;
|
||||
for block in &self.blocks {
|
||||
out = block.forward(out);
|
||||
}
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleBlock<B, ConvBlock<B>> {
|
||||
pub fn new(device: &Device<B>) -> Self {
|
||||
let blocks = vec![ConvBlock::new(6, 6, device), ConvBlock::new(6, 6, device)];
|
||||
|
||||
Self {
|
||||
blocks,
|
||||
_backend: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend, M> {
|
||||
conv: Conv2d<B>,
|
||||
bn: BatchNorm<B>,
|
||||
layer: ModuleBlock<B, M>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model<B, ConvBlock<B>> {
|
||||
pub fn new(device: &Device<B>) -> Self {
|
||||
let conv = Conv2dConfig::new([3, 6], [3, 3])
|
||||
.with_bias(false)
|
||||
.init(device);
|
||||
let bn = BatchNormConfig::new(6).init(device);
|
||||
|
||||
let layer = ModuleBlock::new(device);
|
||||
|
||||
Self { conv, bn, layer }
|
||||
}
|
||||
|
||||
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let out = self.conv.forward(input);
|
||||
let out = self.bn.forward(out);
|
||||
self.layer.forward(out)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;

    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    type FT = FloatElem<TestBackend>;

    /// Only the two `block.N` rules are registered here, so the load is expected to panic.
    #[test]
    #[should_panic]
    fn key_remap_chained_missing_pattern() {
        // Loading record should fail due to missing pattern to map the layer.blocks
        let device = Default::default();
        let mut net: Model<TestBackend, _> = Model::new(&device);
        let mut weights = PytorchStore::from_file("tests/key_remap_chained/key_remap.pt")
            // Map *.block.0.* -> *.conv.*
            .with_key_remapping("(.+)\\.block\\.0\\.(.+)", "$1.conv.$2")
            // Map *.block.1.* -> *.bn.*
            .with_key_remapping("(.+)\\.block\\.1\\.(.+)", "$1.bn.$2");

        net.load_from(&mut weights)
            .expect("Should decode state successfully");
    }

    /// With all three remapping rules chained, the weights load and the forward output
    /// matches the reference values.
    #[test]
    fn key_remap_chained() {
        let device = Default::default();
        let mut net: Model<TestBackend, _> = Model::new(&device);
        let mut weights = PytorchStore::from_file("tests/key_remap_chained/key_remap.pt")
            // Map *.block.0.* -> *.conv.*
            .with_key_remapping("(.+)\\.block\\.0\\.(.+)", "$1.conv.$2")
            // Map *.block.1.* -> *.bn.*
            .with_key_remapping("(.+)\\.block\\.1\\.(.+)", "$1.bn.$2")
            // Map layer.[i].* -> layer.blocks.[i].*
            .with_key_remapping("layer\\.([0-9])\\.(.+)", "layer.blocks.$1.$2");

        net.load_from(&mut weights)
            .expect("Should decode state successfully");

        let x = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [0.76193494, 0.626_546_1, 0.49510366, 0.11974698],
                    [0.07161391, 0.03232569, 0.704_681, 0.254_516],
                    [0.399_373_7, 0.21224737, 0.40888822, 0.14808255],
                    [0.17329216, 0.665_855_4, 0.351_401_8, 0.808_671_6],
                ],
                [
                    [0.33959562, 0.13321638, 0.41178054, 0.257_626_3],
                    [0.347_029_2, 0.02400219, 0.77974546, 0.15189773],
                    [0.75130886, 0.726_892_1, 0.85721636, 0.11647397],
                    [0.859_598_4, 0.263_624_2, 0.685_534_6, 0.96955734],
                ],
                [
                    [0.42948407, 0.49613327, 0.38488472, 0.08250773],
                    [0.73995143, 0.00364107, 0.81039995, 0.87411255],
                    [0.972_853_2, 0.38206023, 0.08917904, 0.61241513],
                    [0.77621365, 0.00234562, 0.38650817, 0.20027226],
                ],
            ]],
            &device,
        );
        let actual = net.forward(x).to_data();

        let expected = Tensor::<TestBackend, 4>::from_data(
            [[
                [[0.198_967_1, 0.17847246], [0.06883702, 0.20012866]],
                [[0.17582723, 0.11344293], [0.05444185, 0.13307181]],
                [[0.192_229_5, 0.20391327], [0.06150475, 0.22688155]],
                [[0.00230906, -0.02177845], [0.01129148, 0.00925517]],
                [[0.14751078, 0.14433631], [0.05498439, 0.29049855]],
                [[0.16868964, 0.133_269_3], [0.06917118, 0.35094324]],
            ]],
            &device,
        );

        actual.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }
}
|
||||
@@ -0,0 +1,37 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Model(nn.Module):
    """Single ``LayerNorm`` layer over a last dimension of size 2."""

    def __init__(self):
        super().__init__()
        self.norm1 = nn.LayerNorm(normalized_shape=2)

    def forward(self, x):
        """Apply the layer normalization to ``x`` and return the result."""
        return self.norm1(x)
|
||||
|
||||
|
||||
def main():
    """Save the seeded model's weights to ``layer_norm.pt`` and print a sample input/output pair."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "layer_norm.pt")

    x2 = torch.rand(1, 2, 2, 2)
    # Interpolate with f-strings: print() does not expand "{}" placeholders.
    print(f"Input shape: {x2.shape}")
    print(f"Input: {x2}")
    output = model(x2)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,82 @@
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{LayerNorm, LayerNormConfig},
|
||||
tensor::{Tensor, backend::Backend},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
norm1: LayerNorm<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
/// Create a new model.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
let norm1 = LayerNormConfig::new(2).init(device);
|
||||
Self { norm1 }
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
self.norm1.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;

    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    type FT = FloatElem<TestBackend>;

    /// Build a [`Net`] and populate it from `tests/layer_norm/layer_norm.pt`.
    fn load_net() -> Net<TestBackend> {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut weights = PytorchStore::from_file("tests/layer_norm/layer_norm.pt");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");
        net
    }

    /// Run `model` on the fixed input and compare with the reference output using an
    /// absolute tolerance of `precision`.
    fn layer_norm(model: Net<TestBackend>, precision: f32) {
        let device = Default::default();

        let expected = Tensor::<TestBackend, 4>::from_data(
            [[
                [[0.99991274, -0.999_912_5], [-0.999_818_3, 0.999_818_3]],
                [[-0.999_966_2, 0.99996626], [-0.99984336, 0.99984336]],
            ]],
            &device,
        );

        let x = Tensor::<TestBackend, 4>::from_data(
            [[
                [[0.757_631_6, 0.27931088], [0.40306926, 0.73468447]],
                [[0.02928156, 0.799_858_6], [0.39713734, 0.75437194]],
            ]],
            &device,
        );

        let actual = model.forward(x).to_data();
        actual.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::absolute(precision));
    }

    #[test]
    fn layer_norm_full() {
        layer_norm(load_net(), 1e-3);
    }

    #[test]
    fn layer_norm_half() {
        layer_norm(load_net(), 1e-3);
    }
}
|
||||
@@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Model(nn.Module):
    """``Linear(2, 3)`` -> ReLU -> bias-free ``Linear(3, 4)``."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=2, out_features=3)
        self.fc2 = nn.Linear(in_features=3, out_features=4, bias=False)

    def forward(self, x):
        """Apply ``fc1``, ReLU and ``fc2`` in sequence."""
        # Add relu so that PyTorch optimizer does not combine fc1 and fc2
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)
|
||||
|
||||
|
||||
class ModelWithBias(nn.Module):
    """Single ``Linear(2, 3)`` layer with its default bias."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=2, out_features=3)

    def forward(self, x):
        """Apply ``fc1`` to ``x`` and return the result."""
        return self.fc1(x)
|
||||
|
||||
|
||||
def main():
    """Save both seeded models (``linear.pt``, ``linear_with_bias.pt``) and print sample outputs."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))
    model_with_bias = ModelWithBias().to(torch.device("cpu"))

    torch.save(model.state_dict(), "linear.pt")
    torch.save(model_with_bias.state_dict(), "linear_with_bias.pt")

    # Named `x` rather than `input` so the builtin is not shadowed.
    x = torch.rand(1, 2, 2, 2)
    # Interpolate with f-strings: print() does not expand "{}" placeholders.
    print(f"Input shape: {x.shape}")
    print(f"Input: {x}")

    output = model(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")

    print("Model with bias")
    output = model_with_bias(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,154 @@
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{Linear, LinearConfig, Relu},
|
||||
tensor::{Tensor, backend::Backend},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
fc1: Linear<B>,
|
||||
fc2: Linear<B>,
|
||||
relu: Relu,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
/// Create a new model.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
let fc1 = LinearConfig::new(2, 3).init(device);
|
||||
let fc2 = LinearConfig::new(3, 4).with_bias(false).init(device);
|
||||
let relu = Relu;
|
||||
|
||||
Self { fc1, fc2, relu }
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.fc1.forward(x);
|
||||
let x = self.relu.forward(x);
|
||||
|
||||
self.fc2.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
struct NetWithBias<B: Backend> {
|
||||
fc1: Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> NetWithBias<B> {
|
||||
/// Create a new model.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
let fc1 = LinearConfig::new(2, 3).init(device);
|
||||
|
||||
Self { fc1 }
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
self.fc1.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;

    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    type FT = FloatElem<TestBackend>;

    /// Build a [`Net`] and populate it from `tests/linear/linear.pt`.
    fn load_net() -> Net<TestBackend> {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut weights = PytorchStore::from_file("tests/linear/linear.pt");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");
        net
    }

    /// Run `model` on the fixed input and compare with the reference output using an
    /// absolute tolerance of `precision`.
    fn linear_test(model: Net<TestBackend>, precision: f32) {
        let device = Default::default();

        let expected = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [0.09778349, -0.13756673, 0.04962806, 0.08856435],
                    [0.03163241, -0.02848549, 0.01437942, 0.11905234],
                ],
                [
                    [0.07628226, -0.10757702, 0.03656857, 0.03824598],
                    [0.05443089, -0.06904714, 0.02744314, 0.09997337],
                ],
            ]],
            &device,
        );

        let x = Tensor::<TestBackend, 4>::from_data(
            [[
                [[0.63968194, 0.97427773], [0.830_029_9, 0.04443115]],
                [[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]],
            ]],
            &device,
        );

        let actual = model.forward(x).to_data();
        actual.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::absolute(precision));
    }

    #[test]
    fn linear_full_precision() {
        linear_test(load_net(), 1e-7);
    }

    #[test]
    fn linear_half_precision() {
        linear_test(load_net(), 1e-4);
    }

    /// Loads `linear_with_bias.pt` into [`NetWithBias`] and checks the forward output.
    #[test]
    fn linear_with_bias() {
        let device = Default::default();

        let mut net = NetWithBias::<TestBackend>::init(&device);
        let mut weights = PytorchStore::from_file("tests/linear/linear_with_bias.pt");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");

        let expected = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [-0.00432095, -1.107_101_2, 0.870_691_4],
                    [0.024_595_5, -0.954_462_9, 0.48518157],
                ],
                [
                    [0.34315687, -0.757_384_2, 0.548_288],
                    [-0.06608963, -1.072_072_7, 0.645_800_5],
                ],
            ]],
            &device,
        );

        let x = Tensor::<TestBackend, 4>::from_data(
            [[
                [[0.63968194, 0.97427773], [0.830_029_9, 0.04443115]],
                [[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]],
            ]],
            &device,
        );

        let actual = net.forward(x).to_data();
        actual.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }
}
|
||||
@@ -0,0 +1,24 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Model(nn.Module):
    """Single 2x2 convolution (2 -> 2 channels) stored as ``conv1``."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, kernel_size=(2, 2))

    def forward(self, x):
        """Apply ``conv1`` to ``x`` and return the result."""
        return self.conv1(x)
|
||||
|
||||
|
||||
def main():
    """Save the seeded model's weights to ``missing_module_field.pt``."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)
    cpu = torch.device("cpu")
    state = Model().to(cpu).state_dict()
    torch.save(state, "missing_module_field.pt")


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,37 @@
|
||||
use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
#[allow(unused)]
|
||||
pub struct Net<B: Backend> {
|
||||
do_not_exist_in_pt: Conv2d<B>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;

    use burn::nn::conv::Conv2dConfig;
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    impl<B: Backend> Net<B> {
        /// Create the model with a 2 -> 2 channel, 2x2 kernel convolution.
        pub fn init(device: &B::Device) -> Self {
            let do_not_exist_in_pt = Conv2dConfig::new([2, 2], [2, 2]).init(device);
            Self { do_not_exist_in_pt }
        }
    }

    /// Loading must panic with a message that names the struct field.
    #[test]
    #[should_panic(expected = "do_not_exist_in_pt")]
    fn should_fail_if_struct_field_is_missing() {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut weights =
            PytorchStore::from_file("tests/missing_module_field/missing_module_field.pt");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");
    }
}
|
||||
@@ -0,0 +1,42 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Model(nn.Module):
    """Five ``Conv2d`` + ``ReLU`` pairs packed into one ``nn.Sequential`` named ``fc``."""

    def __init__(self):
        super().__init__()
        num_layers = 5  # Number of repeated convolutional layers

        # Each iteration contributes a convolution followed by its activation, so the
        # convolutions sit at the even positions of the Sequential.
        layers = [
            module
            for _ in range(num_layers)
            for module in (
                nn.Conv2d(2, 2, kernel_size=3, padding=1, bias=True),
                nn.ReLU(inplace=True),
            )
        ]

        # Use nn.Sequential to create a single module from the layers
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the whole sequential stack."""
        return self.fc(x)
|
||||
|
||||
def main():
    """Save the seeded model's weights to ``non_contiguous_indexes.pt`` and print a sample pair."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "non_contiguous_indexes.pt")

    # Named `x` rather than `input` so the builtin is not shadowed.
    x = torch.rand(1, 2, 5, 5)
    # Interpolate with f-strings: print() does not expand "{}" placeholders.
    print(f"Input shape: {x.shape}")
    print(f"Input: {x}")
    output = model(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,110 @@
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{
|
||||
PaddingConfig2d,
|
||||
conv::{Conv2d, Conv2dConfig},
|
||||
},
|
||||
tensor::{Tensor, activation::relu, backend::Backend},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
fc: Vec<Conv2d<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
/// Create a new model with placeholder values.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
let conv2d_config = Conv2dConfig::new([2, 2], [3, 3]).with_padding(PaddingConfig2d::Same);
|
||||
// The PyTorch file has 5 Conv2d layers at non-contiguous indices (0, 2, 4, 6, 8)
|
||||
// in the Sequential (alternating with ReLU layers)
|
||||
let fc = vec![
|
||||
conv2d_config.init(device),
|
||||
conv2d_config.init(device),
|
||||
conv2d_config.init(device),
|
||||
conv2d_config.init(device),
|
||||
conv2d_config.init(device),
|
||||
];
|
||||
Net { fc }
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
self.fc.iter().fold(x, |x_i, conv| relu(conv.forward(x_i)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::TestBackend;

    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_store::{ModuleSnapshot, PytorchStore};

    type FT = FloatElem<TestBackend>;

    /// Loads the exported weights into `Net` and compares the forward output
    /// against the hard-coded reference values below.
    #[test]
    fn non_contiguous_indexes() {
        let device = Default::default();
        let mut model = Net::<TestBackend>::init(&device);

        let mut store =
            PytorchStore::from_file("tests/non_contiguous_indexes/non_contiguous_indexes.pt");
        model
            .load_from(&mut store)
            .expect("Should decode state successfully");

        let input = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [0.67890584, 0.307_537_2, 0.265_156_2, 0.528_318_8, 0.86194897],
                    [0.14828813, 0.73480314, 0.821_220_7, 0.989_098_6, 0.15003455],
                    [0.62109494, 0.13028657, 0.926_875_1, 0.30604684, 0.80117637],
                    [0.514_885_7, 0.46105868, 0.484_046_1, 0.58499724, 0.73569804],
                    [0.58018994, 0.65252745, 0.05023766, 0.864_268_7, 0.935_932],
                ],
                [
                    [0.913_302_9, 0.869_611_3, 0.139_184_3, 0.314_65, 0.94086266],
                    [0.11917073, 0.953_610_6, 0.10675198, 0.14779574, 0.744_439],
                    [0.14075547, 0.38544965, 0.863_745_9, 0.89604443, 0.97287786],
                    [0.39854127, 0.11136961, 0.99230546, 0.39348692, 0.29428244],
                    [0.621_886_9, 0.15033776, 0.828_640_1, 0.81336635, 0.10325938],
                ],
            ]],
            &device,
        );

        let actual = model.forward(input).to_data();

        // Only the last row of the first channel is non-zero in the reference output.
        let expected = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [0.0; 5],
                    [0.0; 5],
                    [0.0; 5],
                    [0.0; 5],
                    [0.04485746, 0.03582812, 0.03432692, 0.02892298, 0.013_844_3],
                ],
                [[0.0; 5]; 5],
            ]],
            &device,
        )
        .to_data();

        actual.assert_approx_eq::<FT>(&expected, Tolerance::absolute(1e-7));
    }
}
|
||||
@@ -0,0 +1,22 @@
|
||||
mod backend;
|
||||
|
||||
mod batch_norm;
|
||||
mod boolean;
|
||||
mod buffer;
|
||||
mod complex_nested;
|
||||
mod config;
|
||||
mod conv1d;
|
||||
mod conv2d;
|
||||
mod conv_transpose1d;
|
||||
mod conv_transpose2d;
|
||||
mod embedding;
|
||||
mod enum_module;
|
||||
mod group_norm;
|
||||
mod integer;
|
||||
mod key_remap;
|
||||
mod key_remap_chained;
|
||||
mod layer_norm;
|
||||
mod linear;
|
||||
mod missing_module_field;
|
||||
mod non_contiguous_indexes;
|
||||
mod top_level_key;
|
||||
@@ -0,0 +1,24 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Model(nn.Module):
    """A single 2->2 channel convolution with a 2x2 kernel."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=(2, 2))

    def forward(self, x):
        """Apply the convolution to ``x``."""
        return self.conv1(x)
|
||||
|
||||
|
||||
def main():
    """Write the model's state dict to ``top_level_key.pt`` under one extra key.

    The state dict is not saved directly; it is stored as the value of the
    ``"my_state_dict"`` entry of an outer dictionary.
    """
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    cpu = torch.device("cpu")
    net = Model().to(cpu)

    checkpoint = {"my_state_dict": net.state_dict()}
    torch.save(checkpoint, "top_level_key.pt")


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,48 @@
|
||||
use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
#[allow(unused)]
|
||||
pub struct Net<B: Backend> {
|
||||
conv1: Conv2d<B>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::TestBackend;

    use burn::nn::conv::Conv2dConfig;
    use burn_store::{ModuleSnapshot, PytorchStore};

    impl<B: Backend> Net<B> {
        /// Build the model with a freshly initialized 2->2 channel, 2x2 kernel convolution.
        pub fn init(device: &B::Device) -> Self {
            let conv1 = Conv2dConfig::new([2, 2], [2, 2]).init(device);
            Self { conv1 }
        }
    }

    /// Loading without naming the top-level key is expected to panic.
    #[test]
    #[should_panic]
    fn should_fail_if_not_found() {
        let device = Default::default();
        let mut model = Net::<TestBackend>::init(&device);

        let mut store = PytorchStore::from_file("tests/top_level_key/top_level_key.pt");
        let outcome = model.load_from(&mut store);

        outcome.expect("Should decode state successfully");
    }

    /// Loading succeeds once the store is told to look under `my_state_dict`.
    #[test]
    fn should_load() {
        let device = Default::default();
        let mut model = Net::<TestBackend>::init(&device);

        let mut store = PytorchStore::from_file("tests/top_level_key/top_level_key.pt")
            .with_top_level_key("my_state_dict");
        let outcome = model.load_from(&mut store);

        outcome.expect("Should decode state successfully");
    }
}
|
||||
@@ -0,0 +1,13 @@
|
||||
[package]
|
||||
name = "safetensors-tests"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
burn = { path = "../../burn" }
|
||||
burn-ndarray = { path = "../../burn-ndarray" }
|
||||
burn-autodiff = { path = "../../burn-autodiff" }
|
||||
burn-store = { path = "../", features = ["std", "safetensors"] }
|
||||
serde = { workspace = true }
|
||||
float-cmp = { workspace = true }
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
/// Backend alias the tests are written against: the NdArray backend with `f32` elements.
pub type TestBackend = burn_ndarray::NdArray<f32>;
|
||||
@@ -0,0 +1,92 @@
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{
|
||||
BatchNorm, BatchNormConfig, Linear, LinearConfig, PaddingConfig2d, Relu,
|
||||
conv::{Conv2d, Conv2dConfig},
|
||||
},
|
||||
tensor::{Tensor, backend::Backend},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
conv1: Conv2d<B>,
|
||||
norm1: BatchNorm<B>,
|
||||
fc1: Linear<B>,
|
||||
relu: Relu,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
conv1: Conv2dConfig::new([3, 4], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
|
||||
.init(device),
|
||||
norm1: BatchNormConfig::new(4).init(device),
|
||||
fc1: LinearConfig::new(4 * 8 * 8, 16).init(device),
|
||||
relu: Relu::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {
|
||||
let x = self.conv1.forward(x);
|
||||
let x = self.norm1.forward(x);
|
||||
let x = self.relu.forward(x);
|
||||
// Flatten all dimensions except the batch dimension
|
||||
let x = x.flatten(1, 3);
|
||||
self.fc1.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::TestBackend;

    use burn::tensor::Tolerance;
    use burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};

    /// Loads the safetensors weights through the PyTorch adapter and compares the
    /// output for an all-ones input with the hard-coded reference values.
    #[test]
    fn multi_layer_model() {
        let device = Default::default();
        let mut model = Net::<TestBackend>::new(&device);

        let mut store = SafetensorsStore::from_file("tests/multi_layer/multi_layer.safetensors")
            .with_from_adapter(PyTorchToBurnAdapter);
        model
            .load_from(&mut store)
            .expect("Should decode state successfully");

        let input = Tensor::<TestBackend, 4>::ones([1, 3, 8, 8], &device);
        let actual = model.forward(input).to_data();

        // Note: Expected values should be updated based on the actual output from the PyTorch model
        let expected = Tensor::<TestBackend, 2>::from_data(
            [[
                0.04971555,
                -0.16849735,
                0.05182848,
                -0.18032673,
                0.23138367,
                0.05041867,
                0.13005908,
                -0.32202929,
                -0.07915690,
                -0.03232457,
                -0.19790289,
                -0.17476529,
                -0.19627589,
                -0.21757686,
                -0.31376451,
                0.08377837,
            ]],
            &device,
        )
        .to_data();

        actual.assert_approx_eq::<f32>(&expected, Tolerance::default());
    }
}
|
||||
@@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
class Model(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU -> Flatten -> Linear."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.norm1 = nn.BatchNorm2d(4)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(4 * 8 * 8, 16)  # Changed for smaller input size

    def forward(self, x):
        """Return the linear layer's output for input ``x``."""
        activated = F.relu(self.norm1(self.conv1(x)))
        return self.fc1(self.flatten(activated))
|
||||
|
||||
|
||||
def main():
    """Export the model weights to safetensors and print a reference output.

    Seeds torch with 1, builds ``Model`` on the CPU, runs one forward pass
    before switching to eval mode, saves the state dict to
    ``multi_layer.safetensors`` in the current directory, then prints the
    output for an all-ones ``(1, 3, 8, 8)`` input.
    """
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    # Use a smaller input size
    # 1 batch, 3 channels (RGB), 8x8 image (small input)
    x1 = torch.ones(1, 3, 8, 8)
    _ = model(x1)
    model.eval()  # Set to eval mode to freeze running stats
    # Save the model to safetensors after the first forward
    save_file(model.state_dict(), "multi_layer.safetensors")

    x2 = torch.ones(1, 3, 8, 8)
    # `print` does not interpolate "{}" on its own; format explicitly so the
    # value replaces the placeholder instead of being printed after it.
    print("Input shape: {}".format(x2.shape))
    output = model(x2)
    print("Output: {}".format(output))
    print("Output Shape: {}".format(output.shape))


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,3 @@
|
||||
mod backend;
|
||||
|
||||
mod multi_layer;
|
||||
@@ -0,0 +1,663 @@
|
||||
//! Module adapters for transforming tensors between different formats
|
||||
//!
|
||||
//! This module provides adapters that handle differences between PyTorch and Burn:
|
||||
//! - Linear layer weight transposition
|
||||
//! - Normalization parameter naming (weight/bias vs gamma/beta)
|
||||
|
||||
use crate::TensorSnapshot;
|
||||
|
||||
use alloc::boxed::Box;
|
||||
use alloc::rc::Rc;
|
||||
use alloc::string::String;
|
||||
use alloc::string::ToString;
|
||||
use alloc::vec;
|
||||
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
// Module type names as they appear in the container_type field
// These come from the Module derive macro which uses stringify! on the struct name
// Format: "Struct:TypeName" for user-defined structs
mod module_names {
    // The actual string constants that match what the Module derive macro produces

    /// Container type string for linear layers.
    pub const LINEAR: &'static str = "Struct:Linear";
    /// Container type string for batch normalization layers.
    pub const BATCH_NORM: &'static str = "Struct:BatchNorm";
    /// Container type string for layer normalization layers.
    pub const LAYER_NORM: &'static str = "Struct:LayerNorm";
    /// Container type string for group normalization layers.
    pub const GROUP_NORM: &'static str = "Struct:GroupNorm";
}
|
||||
|
||||
/// Trait for adapting tensor snapshots between different module formats
|
||||
pub trait ModuleAdapter: Send + Sync {
|
||||
/// Adapt a tensor snapshot based on its container type and parameter name
|
||||
fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot;
|
||||
|
||||
/// Get alternative parameter name to try during matching
|
||||
///
|
||||
/// When looking for a parameter in a module, this method provides an alternative
|
||||
/// name to try if the direct name doesn't match. This enables matching parameters
|
||||
/// with different naming conventions (e.g., PyTorch's "weight" vs Burn's "gamma").
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `param_name` - The parameter name we're looking for
|
||||
/// * `container_type` - The type of container module (e.g., "BatchNorm")
|
||||
///
|
||||
/// # Returns
|
||||
/// Alternative parameter name to try, or None if no alternative exists
|
||||
fn get_alternative_param_name(
|
||||
&self,
|
||||
_param_name: &str,
|
||||
_container_type: &str,
|
||||
) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Clone the adapter into a boxed trait object
|
||||
fn clone_box(&self) -> Box<dyn ModuleAdapter>;
|
||||
|
||||
/// Chain adapters together, applying `self` first and then `next`.
|
||||
///
|
||||
/// This is useful when multiple transformations are required when importing model weights
|
||||
/// (e.g. PyTorch -> Burn layout conversion, then dtype casting, then custom remapping).
|
||||
///
|
||||
/// The semantics follow a simple pipeline:
|
||||
/// - `adapt`: `next.adapt(&self.adapt(snapshot))`
|
||||
/// - `get_alternative_param_name`: try `self` first; if it returns an alternative name,
|
||||
/// try `next` with that name, otherwise return the first alternative name.
|
||||
fn chain<A>(self, next: A) -> ChainAdapter
|
||||
where
|
||||
Self: Sized + 'static,
|
||||
A: ModuleAdapter + 'static,
|
||||
{
|
||||
ChainAdapter::new(self, next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for Box<dyn ModuleAdapter> {
|
||||
fn clone(&self) -> Self {
|
||||
self.clone_box()
|
||||
}
|
||||
}
|
||||
|
||||
/// Adapter that applies two adapters in sequence.
|
||||
///
|
||||
/// This allows composing smaller adapters instead of creating one large monolithic adapter.
|
||||
#[derive(Clone)]
|
||||
pub struct ChainAdapter {
|
||||
first: Box<dyn ModuleAdapter>,
|
||||
second: Box<dyn ModuleAdapter>,
|
||||
}
|
||||
|
||||
impl ChainAdapter {
|
||||
/// Create a new adapter chain.
|
||||
pub fn new<A, B>(first: A, second: B) -> Self
|
||||
where
|
||||
A: ModuleAdapter + 'static,
|
||||
B: ModuleAdapter + 'static,
|
||||
{
|
||||
Self {
|
||||
first: Box::new(first),
|
||||
second: Box::new(second),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleAdapter for ChainAdapter {
|
||||
fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
|
||||
let snapshot = self.first.adapt(snapshot);
|
||||
self.second.adapt(&snapshot)
|
||||
}
|
||||
|
||||
fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option<String> {
|
||||
if let Some(name) = self
|
||||
.first
|
||||
.get_alternative_param_name(param_name, container_type)
|
||||
{
|
||||
self.second
|
||||
.get_alternative_param_name(&name, container_type)
|
||||
.or(Some(name))
|
||||
} else {
|
||||
self.second
|
||||
.get_alternative_param_name(param_name, container_type)
|
||||
}
|
||||
}
|
||||
|
||||
fn clone_box(&self) -> Box<dyn ModuleAdapter> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Identity adapter that passes tensors through unchanged
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct IdentityAdapter;
|
||||
|
||||
impl ModuleAdapter for IdentityAdapter {
|
||||
fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
|
||||
snapshot.clone()
|
||||
}
|
||||
|
||||
fn clone_box(&self) -> Box<dyn ModuleAdapter> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Adapter for converting from PyTorch format to Burn format
|
||||
///
|
||||
/// Handles:
|
||||
/// - Linear layer weight transposition (PyTorch: [out, in] → Burn: [in, out])
|
||||
/// - Normalization parameter renaming (weight → gamma, bias → beta)
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct PyTorchToBurnAdapter;
|
||||
|
||||
impl ModuleAdapter for PyTorchToBurnAdapter {
|
||||
fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
|
||||
adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::PyTorchToBurn)
|
||||
}
|
||||
|
||||
fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option<String> {
|
||||
// For PyTorch->Burn: When looking for Burn names (gamma/beta), try PyTorch names (weight/bias)
|
||||
if is_normalization_layer(container_type) {
|
||||
burn_norm_param_to_pytorch(param_name).map(|s| s.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn clone_box(&self) -> Box<dyn ModuleAdapter> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Adapter for converting from Burn format to PyTorch format
|
||||
///
|
||||
/// Handles:
|
||||
/// - Linear layer weight transposition (Burn: [in, out] → PyTorch: [out, in])
|
||||
/// - Normalization parameter renaming (gamma → weight, beta → bias)
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct BurnToPyTorchAdapter;
|
||||
|
||||
impl ModuleAdapter for BurnToPyTorchAdapter {
|
||||
fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
|
||||
adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::BurnToPyTorch)
|
||||
}
|
||||
|
||||
fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option<String> {
|
||||
// For Burn->PyTorch: When looking for PyTorch names (weight/bias), try Burn names (gamma/beta)
|
||||
if is_normalization_layer(container_type) {
|
||||
pytorch_norm_param_to_burn(param_name).map(|s| s.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn clone_box(&self) -> Box<dyn ModuleAdapter> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Direction of PyTorch conversion for parameter naming
#[derive(Clone, Copy, Debug)]
enum PyTorchConversionDirection {
    /// Rename PyTorch names to Burn names.
    PyTorchToBurn,
    /// Rename Burn names to PyTorch names.
    BurnToPyTorch,
}
|
||||
|
||||
/// Check if container type is a normalization layer
|
||||
fn is_normalization_layer(container_type: &str) -> bool {
|
||||
matches!(
|
||||
container_type,
|
||||
module_names::BATCH_NORM | module_names::LAYER_NORM | module_names::GROUP_NORM
|
||||
)
|
||||
}
|
||||
|
||||
/// Map PyTorch normalization parameter name to Burn
///
/// Returns `None` for any name other than `"weight"` or `"bias"`.
fn pytorch_norm_param_to_burn(param_name: &str) -> Option<&'static str> {
    // (PyTorch name, Burn name)
    const RENAMES: [(&str, &str); 2] = [("weight", "gamma"), ("bias", "beta")];

    RENAMES
        .iter()
        .find(|(pytorch, _)| *pytorch == param_name)
        .map(|&(_, burn)| burn)
}
|
||||
|
||||
/// Map Burn normalization parameter name to PyTorch
///
/// Returns `None` for any name other than `"gamma"` or `"beta"`.
fn burn_norm_param_to_pytorch(param_name: &str) -> Option<&'static str> {
    // (Burn name, PyTorch name)
    const RENAMES: [(&str, &str); 2] = [("gamma", "weight"), ("beta", "bias")];

    RENAMES
        .iter()
        .find(|(burn, _)| *burn == param_name)
        .map(|&(_, pytorch)| pytorch)
}
|
||||
|
||||
/// Core tensor adaptation logic for PyTorch format conversions
|
||||
fn adapt_pytorch_tensor(
|
||||
snapshot: &TensorSnapshot,
|
||||
direction: PyTorchConversionDirection,
|
||||
) -> TensorSnapshot {
|
||||
// Extract path and parameter name
|
||||
let (path_stack, param_name) = match get_path_and_param(snapshot) {
|
||||
Some(result) => result,
|
||||
None => return snapshot.clone(),
|
||||
};
|
||||
|
||||
// Get module type for matching (ignores Vec/Array wrappers)
|
||||
let module_type = match snapshot.module_type() {
|
||||
Some(mt) => mt,
|
||||
None => return snapshot.clone(), // No user-defined module found
|
||||
};
|
||||
|
||||
// Linear: transpose weight (bidirectional - same operation both ways)
|
||||
if module_type == module_names::LINEAR && param_name == "weight" && snapshot.shape.len() == 2 {
|
||||
return transpose_2d_tensor(snapshot);
|
||||
}
|
||||
|
||||
// Normalization layers: rename parameters based on direction
|
||||
if is_normalization_layer(&module_type) {
|
||||
let new_name = match direction {
|
||||
PyTorchConversionDirection::PyTorchToBurn => pytorch_norm_param_to_burn(param_name),
|
||||
PyTorchConversionDirection::BurnToPyTorch => burn_norm_param_to_pytorch(param_name),
|
||||
};
|
||||
|
||||
if let Some(new_name) = new_name {
|
||||
return rename_parameter(snapshot, path_stack, new_name);
|
||||
}
|
||||
}
|
||||
|
||||
snapshot.clone()
|
||||
}
|
||||
|
||||
/// Extract path stack and parameter name from snapshot
|
||||
fn get_path_and_param(snapshot: &TensorSnapshot) -> Option<(&[String], &str)> {
|
||||
let path_stack = snapshot.path_stack.as_ref()?;
|
||||
let param_name = path_stack.last()?.as_str();
|
||||
Some((path_stack.as_slice(), param_name))
|
||||
}
|
||||
|
||||
/// Rename a parameter in the snapshot
|
||||
fn rename_parameter(
|
||||
snapshot: &TensorSnapshot,
|
||||
path_stack: &[String],
|
||||
new_name: &str,
|
||||
) -> TensorSnapshot {
|
||||
let mut new_path = path_stack.to_vec();
|
||||
*new_path.last_mut().unwrap() = new_name.to_string();
|
||||
|
||||
TensorSnapshot::from_closure(
|
||||
snapshot.clone_data_fn(),
|
||||
snapshot.dtype,
|
||||
snapshot.shape.clone(),
|
||||
new_path,
|
||||
snapshot.container_stack.clone().unwrap_or_default(),
|
||||
snapshot.tensor_id.unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Transpose a 2D tensor
|
||||
fn transpose_2d_tensor(snapshot: &TensorSnapshot) -> TensorSnapshot {
|
||||
if snapshot.shape.len() != 2 {
|
||||
return snapshot.clone();
|
||||
}
|
||||
|
||||
let original_data_fn = snapshot.clone_data_fn();
|
||||
let dtype = snapshot.dtype;
|
||||
let transposed_shape = vec![snapshot.shape[1], snapshot.shape[0]];
|
||||
|
||||
// Create a lazy closure that transposes when called
|
||||
let transposed_data_fn = Rc::new(move || {
|
||||
let data = original_data_fn()?;
|
||||
Ok(transpose_tensor_data(data))
|
||||
});
|
||||
|
||||
TensorSnapshot::from_closure(
|
||||
transposed_data_fn,
|
||||
dtype,
|
||||
transposed_shape,
|
||||
snapshot.path_stack.clone().unwrap_or_default(),
|
||||
snapshot.container_stack.clone().unwrap_or_default(),
|
||||
snapshot.tensor_id.unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Transpose tensor data (assumes 2D shape is already validated)
|
||||
fn transpose_tensor_data(data: TensorData) -> TensorData {
|
||||
let shape = &data.shape;
|
||||
let rows = shape[0];
|
||||
let cols = shape[1];
|
||||
let transposed_shape = vec![cols, rows];
|
||||
|
||||
// Get the raw bytes and element size
|
||||
let bytes = data.as_bytes();
|
||||
let element_size = data.dtype.size();
|
||||
|
||||
// Create a new buffer for transposed data
|
||||
let mut transposed_bytes = vec![0u8; bytes.len()];
|
||||
|
||||
// Transpose at the byte level - works for any data type
|
||||
for i in 0..rows {
|
||||
for j in 0..cols {
|
||||
let src_idx = (i * cols + j) * element_size;
|
||||
let dst_idx = (j * rows + i) * element_size;
|
||||
|
||||
// Copy the bytes for this element
|
||||
transposed_bytes[dst_idx..dst_idx + element_size]
|
||||
.copy_from_slice(&bytes[src_idx..src_idx + element_size]);
|
||||
}
|
||||
}
|
||||
|
||||
// Create new TensorData from transposed bytes
|
||||
TensorData::from_bytes_vec(transposed_bytes, transposed_shape, data.dtype)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use alloc::rc::Rc;
|
||||
use alloc::sync::Arc;
|
||||
use burn_tensor::{DType, TensorData};
|
||||
use core::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
#[test]
|
||||
fn test_module_names_match_burn_nn() {
|
||||
// If these types are renamed or moved in `burn-nn`, this test will fail to compile.
|
||||
// This use statement replicates the previous check/alarm system.
|
||||
#[allow(unused_imports)]
|
||||
use burn_nn::{BatchNorm, GroupNorm, LayerNorm, Linear};
|
||||
|
||||
// These assert statements work as extra checks that should remind maintainers more
|
||||
// clearly that the hardcoded strings need to be updated.
|
||||
assert_eq!(module_names::LINEAR, "Struct:Linear");
|
||||
assert_eq!(module_names::BATCH_NORM, "Struct:BatchNorm");
|
||||
assert_eq!(module_names::LAYER_NORM, "Struct:LayerNorm");
|
||||
assert_eq!(module_names::GROUP_NORM, "Struct:GroupNorm");
|
||||
}
|
||||
|
||||
fn create_test_snapshot(path: &str, shape: Vec<usize>, container_type: &str) -> TensorSnapshot {
|
||||
let path_parts: Vec<String> = path.split('.').map(|s| s.to_string()).collect();
|
||||
let values = vec![1.0f32; shape.iter().product()];
|
||||
let data = TensorData::new(values, shape.clone());
|
||||
|
||||
TensorSnapshot::from_closure(
|
||||
Rc::new(move || Ok(data.clone())),
|
||||
DType::F32,
|
||||
shape,
|
||||
path_parts,
|
||||
vec![container_type.to_string()],
|
||||
burn_core::module::ParamId::new(),
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pytorch_to_burn_linear_weight() {
|
||||
let adapter = PyTorchToBurnAdapter;
|
||||
|
||||
// Linear layer weight should be transposed
|
||||
let snapshot = create_test_snapshot("fc.weight", vec![10, 5], module_names::LINEAR);
|
||||
let adapted = adapter.adapt(&snapshot);
|
||||
assert_eq!(adapted.shape, vec![5, 10]);
|
||||
|
||||
// Linear layer bias should not be transposed
|
||||
let snapshot = create_test_snapshot("fc.bias", vec![10], module_names::LINEAR);
|
||||
let adapted = adapter.adapt(&snapshot);
|
||||
assert_eq!(adapted.shape, vec![10]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pytorch_to_burn_norm_params() {
|
||||
let adapter = PyTorchToBurnAdapter;
|
||||
|
||||
// BatchNorm weight -> gamma
|
||||
let snapshot = create_test_snapshot("norm.weight", vec![10], module_names::BATCH_NORM);
|
||||
let adapted = adapter.adapt(&snapshot);
|
||||
assert_eq!(adapted.full_path(), "norm.gamma");
|
||||
|
||||
// BatchNorm bias -> beta
|
||||
let snapshot = create_test_snapshot("norm.bias", vec![10], module_names::BATCH_NORM);
|
||||
let adapted = adapter.adapt(&snapshot);
|
||||
assert_eq!(adapted.full_path(), "norm.beta");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_burn_to_pytorch_linear_weight() {
|
||||
let adapter = BurnToPyTorchAdapter;
|
||||
|
||||
// Linear layer weight should be transposed
|
||||
let snapshot = create_test_snapshot("fc.weight", vec![5, 10], module_names::LINEAR);
|
||||
let adapted = adapter.adapt(&snapshot);
|
||||
assert_eq!(adapted.shape, vec![10, 5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_burn_to_pytorch_norm_params() {
|
||||
let adapter = BurnToPyTorchAdapter;
|
||||
|
||||
// BatchNorm gamma -> weight
|
||||
let snapshot = create_test_snapshot("norm.gamma", vec![10], module_names::BATCH_NORM);
|
||||
let adapted = adapter.adapt(&snapshot);
|
||||
assert_eq!(adapted.full_path(), "norm.weight");
|
||||
|
||||
// BatchNorm beta -> bias
|
||||
let snapshot = create_test_snapshot("norm.beta", vec![10], module_names::BATCH_NORM);
|
||||
let adapted = adapter.adapt(&snapshot);
|
||||
assert_eq!(adapted.full_path(), "norm.bias");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transpose_different_dtypes() {
|
||||
// Test that transpose works for different data types
|
||||
|
||||
// Test with F32
|
||||
let f32_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
|
||||
let transposed = transpose_tensor_data(f32_data);
|
||||
assert_eq!(transposed.shape, vec![3, 2]);
|
||||
let values = transposed.to_vec::<f32>().unwrap();
|
||||
assert_eq!(values, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
|
||||
|
||||
// Test with I32
|
||||
let i32_data = TensorData::new(vec![1i32, 2, 3, 4, 5, 6], vec![2, 3]);
|
||||
let transposed = transpose_tensor_data(i32_data);
|
||||
assert_eq!(transposed.shape, vec![3, 2]);
|
||||
let values = transposed.to_vec::<i32>().unwrap();
|
||||
assert_eq!(values, vec![1, 4, 2, 5, 3, 6]);
|
||||
|
||||
// Test with F64
|
||||
let f64_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], vec![2, 2]);
|
||||
let transposed = transpose_tensor_data(f64_data);
|
||||
assert_eq!(transposed.shape, vec![2, 2]);
|
||||
let values = transposed.to_vec::<f64>().unwrap();
|
||||
assert_eq!(values, vec![1.0, 3.0, 2.0, 4.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_container_info() {
|
||||
let adapter = PyTorchToBurnAdapter;
|
||||
|
||||
// Without container info, adapter returns unchanged for non-norm parameters
|
||||
let mut snapshot = create_test_snapshot("fc.weight", vec![10, 5], module_names::LINEAR);
|
||||
snapshot.container_stack = None;
|
||||
|
||||
// Without container info, no transformation occurs for linear layers
|
||||
let adapted = adapter.adapt(&snapshot);
|
||||
assert_eq!(adapted.shape, vec![10, 5]); // No transposition without container info
|
||||
|
||||
// Test a non-linear, non-norm parameter - should pass through unchanged
|
||||
let mut snapshot2 = create_test_snapshot("other.weight", vec![10, 5], "Struct:Other");
|
||||
snapshot2.container_stack = None;
|
||||
let adapted2 = adapter.adapt(&snapshot2);
|
||||
assert_eq!(adapted2.shape, vec![10, 5]); // No transposition
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct RenameParamAdapter {
|
||||
from: &'static str,
|
||||
to: &'static str,
|
||||
called: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl ModuleAdapter for RenameParamAdapter {
|
||||
fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
|
||||
self.called.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
let path_stack = match snapshot.path_stack.as_ref() {
|
||||
Some(stack) => stack,
|
||||
None => return snapshot.clone(),
|
||||
};
|
||||
let param = match path_stack.last() {
|
||||
Some(p) => p.as_str(),
|
||||
None => return snapshot.clone(),
|
||||
};
|
||||
if param != self.from {
|
||||
return snapshot.clone();
|
||||
}
|
||||
|
||||
let mut new_path = path_stack.to_vec();
|
||||
*new_path.last_mut().unwrap() = self.to.to_string();
|
||||
|
||||
TensorSnapshot::from_closure(
|
||||
snapshot.clone_data_fn(),
|
||||
snapshot.dtype,
|
||||
snapshot.shape.clone(),
|
||||
new_path,
|
||||
snapshot.container_stack.clone().unwrap_or_default(),
|
||||
snapshot.tensor_id.unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn get_alternative_param_name(
|
||||
&self,
|
||||
_param_name: &str,
|
||||
_container_type: &str,
|
||||
) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
fn clone_box(&self) -> Box<dyn ModuleAdapter> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AltNameAdapter {
|
||||
from: &'static str,
|
||||
to: &'static str,
|
||||
called: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl ModuleAdapter for AltNameAdapter {
|
||||
fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
|
||||
TensorSnapshot::from_closure(
|
||||
snapshot.clone_data_fn(),
|
||||
snapshot.dtype,
|
||||
snapshot.shape.clone(),
|
||||
snapshot.path_stack.clone().unwrap_or_default(),
|
||||
snapshot.container_stack.clone().unwrap_or_default(),
|
||||
snapshot.tensor_id.unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn get_alternative_param_name(
|
||||
&self,
|
||||
param_name: &str,
|
||||
_container_type: &str,
|
||||
) -> Option<String> {
|
||||
self.called.fetch_add(1, Ordering::Relaxed);
|
||||
if param_name == self.from {
|
||||
Some(self.to.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn clone_box(&self) -> Box<dyn ModuleAdapter> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chain_adapter_pipes_adapt() {
|
||||
let called1 = Arc::new(AtomicUsize::new(0));
|
||||
let called2 = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
let a = RenameParamAdapter {
|
||||
from: "weight",
|
||||
to: "a",
|
||||
called: called1.clone(),
|
||||
};
|
||||
let b = RenameParamAdapter {
|
||||
from: "a",
|
||||
to: "b",
|
||||
called: called2.clone(),
|
||||
};
|
||||
|
||||
let chain = a.chain(b);
|
||||
let snapshot = create_test_snapshot("fc.weight", vec![2, 2], module_names::LINEAR);
|
||||
let adapted = chain.adapt(&snapshot);
|
||||
|
||||
assert_eq!(adapted.full_path(), "fc.b");
|
||||
assert_eq!(called1.load(Ordering::Relaxed), 1);
|
||||
assert_eq!(called2.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
/// Alternative-name lookup through a chain: the second adapter refines the
/// first adapter's answer when it can, the first answer is kept when it
/// cannot, and the second adapter is asked about the original name when the
/// first has no answer.
#[test]
fn test_chain_adapter_alternative_name_pipes_and_fallbacks() {
    /// Builds an adapter together with a handle on its call counter.
    fn counted(from: &'static str, to: &'static str) -> (AltNameAdapter, Arc<AtomicUsize>) {
        let called = Arc::new(AtomicUsize::new(0));
        let adapter = AltNameAdapter {
            from,
            to,
            called: Arc::clone(&called),
        };
        (adapter, called)
    }

    // Both adapters map: gamma -> weight -> scale.
    let (first, first_calls) = counted("gamma", "weight");
    let (second, second_calls) = counted("weight", "scale");
    let chain = first.chain(second);
    let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM);
    assert_eq!(alt.as_deref(), Some("scale"));
    assert_eq!(first_calls.load(Ordering::Relaxed), 1);
    assert_eq!(second_calls.load(Ordering::Relaxed), 1);

    // If the second adapter doesn't have a mapping for the first alternative,
    // fall back to the first alternative name.
    let (first, first_calls) = counted("gamma", "weight");
    let (second, second_calls) = counted("something-else", "unused");
    let chain = first.chain(second);
    let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM);
    assert_eq!(alt.as_deref(), Some("weight"));
    assert_eq!(first_calls.load(Ordering::Relaxed), 1);
    assert_eq!(second_calls.load(Ordering::Relaxed), 1);

    // If the first adapter doesn't provide an alternative, try the second with the original name.
    let (first, first_calls) = counted("something-else", "unused");
    let (second, second_calls) = counted("gamma", "weight");
    let chain = first.chain(second);
    let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM);
    assert_eq!(alt.as_deref(), Some("weight"));
    assert_eq!(first_calls.load(Ordering::Relaxed), 1);
    assert_eq!(second_calls.load(Ordering::Relaxed), 1);

    // clone_box must preserve behavior.
    let boxed = chain.clone_box();
    let alt = boxed.get_alternative_param_name("gamma", module_names::LAYER_NORM);
    assert_eq!(alt.as_deref(), Some("weight"));
}
|
||||
}
|
||||
@@ -0,0 +1,608 @@
|
||||
//! Applier that correctly applies tensor snapshots with adapter support
|
||||
|
||||
use alloc::boxed::Box;
|
||||
use alloc::format;
|
||||
use alloc::string::{String, ToString};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use hashbrown::{HashMap, HashSet};
|
||||
|
||||
use burn_core::module::{ModuleMapper, Param};
|
||||
use burn_tensor::{Bool, Int, Shape, Tensor, backend::Backend};
|
||||
|
||||
use crate::apply_result::{ApplyError, ApplyResult};
|
||||
use crate::{ModuleAdapter, PathFilter, TensorSnapshot};
|
||||
|
||||
/// Applier that applies tensor snapshots to module parameters
|
||||
/// with proper adapter support using container type information
|
||||
pub struct Applier<B: Backend> {
|
||||
/// Map of tensor paths to their snapshots
|
||||
snapshots: HashMap<String, TensorSnapshot>,
|
||||
/// Current path in the module hierarchy
|
||||
path_stack: Vec<String>,
|
||||
/// Current container type stack in the module hierarchy
|
||||
container_stack: Vec<String>,
|
||||
/// Optional filter for selective application
|
||||
filter: Option<PathFilter>,
|
||||
/// Optional adapter to transform tensors based on container types
|
||||
adapter: Option<Box<dyn ModuleAdapter>>,
|
||||
/// Successfully applied tensor paths
|
||||
applied: Vec<String>,
|
||||
/// Skipped tensor paths
|
||||
skipped: HashSet<String>,
|
||||
/// Errors encountered during application
|
||||
errors: Vec<ApplyError>,
|
||||
/// Track visited paths with their container stacks (in dot notation) to find missing tensors
|
||||
visited_paths: HashMap<String, String>,
|
||||
/// Skip enum variant names when matching paths
|
||||
/// When true, "feature.BaseConv.weight" will also try to match "feature.weight"
|
||||
skip_enum_variants: bool,
|
||||
/// Phantom data for backend type
|
||||
_backend: core::marker::PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Applier<B> {
|
||||
/// Create a new applier with snapshots, optional filter, and optional adapter
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `views` - A vector of TensorSnapshot objects to apply
|
||||
/// * `filter` - An optional [`PathFilter`] to determine which tensors to apply.
|
||||
/// When `None`, all available tensors are applied.
|
||||
/// * `adapter` - Optional adapter to transform tensors based on container types
|
||||
/// * `skip_enum_variants` - Skip enum variant names when matching paths
|
||||
pub fn new(
|
||||
views: Vec<TensorSnapshot>,
|
||||
filter: Option<PathFilter>,
|
||||
adapter: Option<Box<dyn ModuleAdapter>>,
|
||||
skip_enum_variants: bool,
|
||||
) -> Self {
|
||||
let views_map: HashMap<String, TensorSnapshot> = views
|
||||
.into_iter()
|
||||
.map(|view| (view.full_path(), view))
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
snapshots: views_map,
|
||||
path_stack: Vec::new(),
|
||||
container_stack: Vec::new(),
|
||||
filter,
|
||||
adapter,
|
||||
applied: Vec::new(),
|
||||
skipped: HashSet::new(),
|
||||
errors: Vec::new(),
|
||||
visited_paths: HashMap::new(),
|
||||
skip_enum_variants,
|
||||
_backend: core::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the current path in the module hierarchy
|
||||
fn current_path(&self) -> String {
|
||||
self.path_stack.join(".")
|
||||
}
|
||||
|
||||
/// Get the current module type (last Struct/Enum in container stack)
|
||||
fn current_module_type(&self) -> Option<&str> {
|
||||
self.container_stack
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|ct| ct.starts_with("Struct:") || ct.starts_with("Enum:"))
|
||||
.map(|s| s.as_str())
|
||||
}
|
||||
|
||||
/// Check if a tensor should be applied based on filter
|
||||
fn should_apply(&self) -> bool {
|
||||
match &self.filter {
|
||||
None => true,
|
||||
Some(f) => f.matches_with_container_path(&self.path_stack, &self.container_stack),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert the applier into a result
|
||||
pub fn into_result(self) -> ApplyResult {
|
||||
let mut unused: Vec<String> = self
|
||||
.snapshots
|
||||
.keys()
|
||||
.filter(|path| !self.visited_paths.contains_key(*path) && !self.skipped.contains(*path))
|
||||
.cloned()
|
||||
.collect();
|
||||
// Sort for stable output order
|
||||
unused.sort();
|
||||
|
||||
// Create a set of successfully applied paths for efficient lookup
|
||||
let applied_set: HashSet<String> = self.applied.iter().cloned().collect();
|
||||
|
||||
// Extract paths that have errors - these are not "missing", they were found but had issues
|
||||
let errored_paths: HashSet<String> = self
|
||||
.errors
|
||||
.iter()
|
||||
.map(|e| match e {
|
||||
ApplyError::ShapeMismatch { path, .. } => path.clone(),
|
||||
ApplyError::DTypeMismatch { path, .. } => path.clone(),
|
||||
ApplyError::AdapterError { path, .. } => path.clone(),
|
||||
ApplyError::LoadError { path, .. } => path.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
// A path is missing if it was visited but not successfully applied, not skipped, and didn't have an error
|
||||
// Store both the path and its container stack (in dot notation)
|
||||
let mut missing: Vec<(String, String)> = self
|
||||
.visited_paths
|
||||
.into_iter()
|
||||
.filter(|(p, _)| {
|
||||
!applied_set.contains(p) && !self.skipped.contains(p) && !errored_paths.contains(p)
|
||||
})
|
||||
.collect();
|
||||
// Sort for stable output order (by path)
|
||||
missing.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
|
||||
// Convert skipped HashSet to sorted Vec for stable output
|
||||
let mut skipped: Vec<String> = self.skipped.into_iter().collect();
|
||||
skipped.sort();
|
||||
|
||||
ApplyResult {
|
||||
applied: self.applied,
|
||||
skipped,
|
||||
missing,
|
||||
unused,
|
||||
errors: self.errors,
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply a tensor snapshot with shape validation and optional adapter transformation
|
||||
/// Returns None if snapshot not found, filtered, or validation fails
|
||||
fn apply_tensor<const D: usize, K>(
|
||||
&mut self,
|
||||
target_device: &B::Device,
|
||||
target_shape: Shape,
|
||||
) -> Option<Tensor<B, D, K>>
|
||||
where
|
||||
K: burn_tensor::TensorKind<B>,
|
||||
K: burn_tensor::BasicOps<B>,
|
||||
{
|
||||
let path = self.current_path();
|
||||
let container_stack_str = self.container_stack.join(".");
|
||||
self.visited_paths.insert(path.clone(), container_stack_str);
|
||||
|
||||
// Try to get snapshot with original path first
|
||||
let mut snapshot = self.snapshots.get(&path).cloned();
|
||||
|
||||
// If not found and we have an adapter, try alternative parameter names
|
||||
if snapshot.is_none()
|
||||
&& let Some(ref adapter) = self.adapter
|
||||
&& let Some(module_type) = self.current_module_type()
|
||||
{
|
||||
// Get alternative name based on current module type (user-defined module only)
|
||||
let param_name = self.path_stack.last()?;
|
||||
|
||||
if let Some(alt_name) = adapter.get_alternative_param_name(param_name, module_type) {
|
||||
// Build alternative path with parameter name substitution
|
||||
let mut alt_path_stack = self.path_stack.clone();
|
||||
*alt_path_stack.last_mut().unwrap() = alt_name.clone();
|
||||
let alt_path = alt_path_stack.join(".");
|
||||
|
||||
// Try to get snapshot with alternative name
|
||||
snapshot = self.snapshots.get(&alt_path).cloned();
|
||||
|
||||
// Don't mark the alternative path as visited - only the original Burn path
|
||||
// should be tracked. The alternative path is just for lookup.
|
||||
}
|
||||
}
|
||||
|
||||
let mut snapshot = snapshot?;
|
||||
|
||||
// Apply adapter transformation using current container_stack context (for data transformation like transpose)
|
||||
if let Some(ref adapter) = self.adapter {
|
||||
// Create a temporary snapshot with current context for adaptation
|
||||
let snapshot_with_context = TensorSnapshot::from_closure(
|
||||
snapshot.clone_data_fn(),
|
||||
snapshot.dtype,
|
||||
snapshot.shape.clone(),
|
||||
self.path_stack.clone(),
|
||||
self.container_stack.clone(),
|
||||
snapshot.tensor_id.unwrap_or_default(),
|
||||
);
|
||||
|
||||
// Transform using adapter (handles transpose)
|
||||
snapshot = adapter.adapt(&snapshot_with_context);
|
||||
}
|
||||
|
||||
// Check if we should apply based on filter
|
||||
if !self.should_apply() {
|
||||
self.skipped.insert(path.clone());
|
||||
return None;
|
||||
}
|
||||
|
||||
// Load tensor data
|
||||
let data = match snapshot.to_data() {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
self.errors.push(ApplyError::LoadError {
|
||||
path: path.clone(),
|
||||
message: format!("Failed to load tensor data: {:?}", e),
|
||||
});
|
||||
return None; // Signal caller to fall back to initialization
|
||||
}
|
||||
};
|
||||
|
||||
// Validate shape
|
||||
if data.shape != *target_shape {
|
||||
self.errors.push(ApplyError::ShapeMismatch {
|
||||
path: path.clone(),
|
||||
expected: target_shape.to_vec(),
|
||||
found: data.shape.clone(),
|
||||
});
|
||||
return None; // Signal caller to fall back to initialization
|
||||
}
|
||||
|
||||
self.applied.push(path);
|
||||
Some(Tensor::from_data_dtype(data, target_device, snapshot.dtype))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleMapper<B> for Applier<B> {
|
||||
fn enter_module(&mut self, name: &str, container_type: &str) {
|
||||
// Always track the container type for proper module type detection
|
||||
self.container_stack.push(container_type.to_string());
|
||||
|
||||
// Only add to path if it's not an enum variant (when skip_enum_variants is enabled)
|
||||
// This ensures paths are built without enum variant names from the start
|
||||
if !self.skip_enum_variants || !container_type.starts_with("Enum:") {
|
||||
self.path_stack.push(name.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
fn exit_module(&mut self, _name: &str, container_type: &str) {
|
||||
self.container_stack.pop();
|
||||
|
||||
// Only pop from path if we added it (not an enum variant when skip_enum_variants is enabled)
|
||||
if !self.skip_enum_variants || !container_type.starts_with("Enum:") {
|
||||
self.path_stack.pop();
|
||||
}
|
||||
}
|
||||
|
||||
fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
|
||||
let param_id = param.id;
|
||||
let target_device = param.lazy_device();
|
||||
let target_shape = param.lazy_shape();
|
||||
|
||||
// Try to apply snapshot with shape validation
|
||||
match self.apply_tensor(&target_device, target_shape) {
|
||||
Some(tensor) => {
|
||||
// We have a tensor to apply - load it
|
||||
param.transform_for_load(tensor, param_id)
|
||||
}
|
||||
None => {
|
||||
// No snapshot, filtered, or validation failed - return param unchanged
|
||||
param
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn map_int<const D: usize>(
|
||||
&mut self,
|
||||
param: Param<Tensor<B, D, Int>>,
|
||||
) -> Param<Tensor<B, D, Int>> {
|
||||
let param_id = param.id;
|
||||
let target_device = param.lazy_device();
|
||||
let target_shape = param.lazy_shape();
|
||||
|
||||
// Try to apply snapshot with shape validation
|
||||
match self.apply_tensor(&target_device, target_shape) {
|
||||
Some(tensor) => {
|
||||
// We have a tensor to apply - load it
|
||||
param.transform_for_load(tensor, param_id)
|
||||
}
|
||||
None => {
|
||||
// No snapshot, filtered, or validation failed - return param unchanged
|
||||
param
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn map_bool<const D: usize>(
|
||||
&mut self,
|
||||
param: Param<Tensor<B, D, Bool>>,
|
||||
) -> Param<Tensor<B, D, Bool>> {
|
||||
let param_id = param.id;
|
||||
let target_device = param.lazy_device();
|
||||
let target_shape = param.lazy_shape();
|
||||
|
||||
// Try to apply snapshot with shape validation
|
||||
match self.apply_tensor(&target_device, target_shape) {
|
||||
Some(tensor) => {
|
||||
// We have a tensor to apply - load it
|
||||
param.transform_for_load(tensor, param_id)
|
||||
}
|
||||
None => {
|
||||
// No snapshot, filtered, or validation failed - return param unchanged
|
||||
param
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "std", target_has_atomic = "ptr"))]
mod tests {
    use super::*;
    use burn_core::module::{ModuleMapper, Param, ParamId};
    use burn_tensor::{DType, Tensor, TensorData, backend::Backend};

    type TestBackend = burn_ndarray::NdArray;

    /// Snapshot for a root-level parameter: single-segment path, no container.
    fn root_snapshot(data: TensorData, name: &str) -> crate::TensorSnapshot {
        crate::TensorSnapshot::from_data(data, vec![name.to_string()], vec![], ParamId::new())
    }

    /// Simulates visiting a root-level field called `name` and mapping its
    /// float parameter through the applier.
    fn load_root_float<BK: Backend, const D: usize>(
        applier: &mut Applier<BK>,
        name: &str,
        target: Param<Tensor<BK, D>>,
    ) -> Param<Tensor<BK, D>> {
        applier.enter_module(name, "");
        let loaded = applier.map_float(target);
        applier.exit_module(name, "");
        loaded
    }

    /// Parameters that live directly at the root (not inside any module) are
    /// matched by their single-segment path.
    #[test]
    fn root_level_parameters() {
        let device = Default::default();

        // Source values to be carried over through snapshots.
        let source_weight =
            Param::<Tensor<TestBackend, 2>>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
        let source_bias = Param::<Tensor<TestBackend, 1>>::from_data([5.0, 6.0], &device);

        let snapshots = vec![
            root_snapshot(source_weight.val().to_data(), "weight"),
            root_snapshot(source_bias.val().to_data(), "bias"),
        ];
        let mut applier = Applier::<TestBackend>::new(snapshots, None, None, false);

        // Zero-initialised targets to load into.
        let weight_loaded = load_root_float(
            &mut applier,
            "weight",
            Param::initialized(
                ParamId::new(),
                Tensor::<TestBackend, 2>::zeros([2, 2], &device),
            ),
        );
        let bias_loaded = load_root_float(
            &mut applier,
            "bias",
            Param::initialized(
                ParamId::new(),
                Tensor::<TestBackend, 1>::zeros([2], &device),
            ),
        );

        // Values must have been replaced by the snapshot contents.
        assert_eq!(
            weight_loaded.val().to_data().to_vec::<f32>().unwrap(),
            vec![1.0, 2.0, 3.0, 4.0]
        );
        assert_eq!(
            bias_loaded.val().to_data().to_vec::<f32>().unwrap(),
            vec![5.0, 6.0]
        );

        let result = applier.into_result();
        assert_eq!(result.applied.len(), 2);
        assert_eq!(result.errors.len(), 0);
    }

    /// The applier must build the tensor with the snapshot's dtype
    /// (`Tensor::from_data_dtype`) instead of the backend default; exercised
    /// here with F64 data on an `NdArray<f64>` backend.
    #[test]
    fn dtype_preservation_f64() {
        type TestBackendF64 = burn_ndarray::NdArray<f64>;
        let device = Default::default();

        let f64_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], [2, 2]);
        assert_eq!(f64_data.dtype, DType::F64, "Test setup: data should be F64");

        let snapshot = root_snapshot(f64_data, "weight");
        assert_eq!(
            snapshot.dtype,
            DType::F64,
            "Snapshot should preserve F64 dtype"
        );

        let mut applier = Applier::<TestBackendF64>::new(vec![snapshot], None, None, false);
        let loaded = load_root_float(
            &mut applier,
            "weight",
            Param::initialized(
                ParamId::new(),
                Tensor::<TestBackendF64, 2>::zeros([2, 2], &device),
            ),
        );

        assert_eq!(
            loaded.val().dtype(),
            DType::F64,
            "Loaded tensor should have F64 dtype"
        );
        assert_eq!(
            loaded.val().to_data().to_vec::<f64>().unwrap(),
            vec![1.0, 2.0, 3.0, 4.0]
        );

        let result = applier.into_result();
        assert_eq!(result.applied.len(), 1);
        assert_eq!(result.errors.len(), 0);
    }

    /// Test that F32 dtype is preserved when loading (verifies we didn't break F32 handling)
    #[test]
    fn dtype_preservation_f32() {
        let device = Default::default();

        let f32_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [2, 2]);
        assert_eq!(f32_data.dtype, DType::F32);

        let snapshot = root_snapshot(f32_data, "weight");
        assert_eq!(snapshot.dtype, DType::F32);

        let mut applier = Applier::<TestBackend>::new(vec![snapshot], None, None, false);
        let loaded = load_root_float(
            &mut applier,
            "weight",
            Param::initialized(
                ParamId::new(),
                Tensor::<TestBackend, 2>::zeros([2, 2], &device),
            ),
        );

        assert_eq!(loaded.val().dtype(), DType::F32);
        assert_eq!(
            loaded.val().to_data().to_vec::<f32>().unwrap(),
            vec![1.0, 2.0, 3.0, 4.0]
        );
    }

    /// Test that F16 dtype is correctly preserved in TensorSnapshot.
    ///
    /// Only the snapshot round trip is checked; no tensor is created on a
    /// backend here (the original note states NdArray does not support F16).
    #[test]
    fn dtype_preservation_f16_snapshot() {
        use half::f16;

        let f16_values: Vec<f16> = [1.0f32, 2.0, 3.0, 4.0]
            .into_iter()
            .map(f16::from_f32)
            .collect();
        let f16_data = TensorData::new(f16_values.clone(), [2, 2]);
        assert_eq!(
            f16_data.dtype,
            DType::F16,
            "TensorData should have F16 dtype"
        );

        let snapshot = root_snapshot(f16_data, "weight");
        assert_eq!(
            snapshot.dtype,
            DType::F16,
            "TensorSnapshot should preserve F16 dtype"
        );

        // Data read back from the snapshot keeps both dtype and values.
        let retrieved_data = snapshot.to_data().expect("Should be able to retrieve data");
        assert_eq!(
            retrieved_data.dtype,
            DType::F16,
            "Retrieved data should have F16 dtype"
        );
        let retrieved_values = retrieved_data
            .to_vec::<f16>()
            .expect("Should be able to convert to f16 vec");
        assert_eq!(
            retrieved_values, f16_values,
            "F16 values should be preserved"
        );
    }

    /// Test that BF16 dtype is correctly preserved in TensorSnapshot.
    #[test]
    fn dtype_preservation_bf16_snapshot() {
        use half::bf16;

        let bf16_values: Vec<bf16> = [1.0f32, 2.0, 3.0, 4.0]
            .into_iter()
            .map(bf16::from_f32)
            .collect();
        let bf16_data = TensorData::new(bf16_values.clone(), [2, 2]);
        assert_eq!(
            bf16_data.dtype,
            DType::BF16,
            "TensorData should have BF16 dtype"
        );

        let snapshot = root_snapshot(bf16_data, "weight");
        assert_eq!(
            snapshot.dtype,
            DType::BF16,
            "TensorSnapshot should preserve BF16 dtype"
        );

        // Data read back from the snapshot keeps both dtype and values.
        let retrieved_data = snapshot.to_data().expect("Should be able to retrieve data");
        assert_eq!(
            retrieved_data.dtype,
            DType::BF16,
            "Retrieved data should have BF16 dtype"
        );
        let retrieved_values = retrieved_data
            .to_vec::<bf16>()
            .expect("Should be able to convert to bf16 vec");
        assert_eq!(
            retrieved_values, bf16_values,
            "BF16 values should be preserved"
        );
    }
}
|
||||
@@ -0,0 +1,300 @@
|
||||
//! Result types and diagnostics for tensor application operations
|
||||
|
||||
use alloc::string::String;
|
||||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use burn_tensor::DType;
|
||||
|
||||
/// Error types that can occur during tensor application
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ApplyError {
|
||||
/// Shape mismatch between expected and actual tensor
|
||||
ShapeMismatch {
|
||||
/// Path of the tensor
|
||||
path: String,
|
||||
/// Expected shape
|
||||
expected: Vec<usize>,
|
||||
/// Found shape
|
||||
found: Vec<usize>,
|
||||
},
|
||||
/// Data type mismatch between expected and actual tensor
|
||||
DTypeMismatch {
|
||||
/// Path of the tensor
|
||||
path: String,
|
||||
/// Expected data type
|
||||
expected: DType,
|
||||
/// Found data type
|
||||
found: DType,
|
||||
},
|
||||
/// Error from adapter transformation
|
||||
AdapterError {
|
||||
/// Path of the tensor
|
||||
path: String,
|
||||
/// Error message
|
||||
message: String,
|
||||
},
|
||||
/// Error loading tensor data
|
||||
LoadError {
|
||||
/// Path of the tensor
|
||||
path: String,
|
||||
/// Error message
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl core::fmt::Display for ApplyError {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
match self {
|
||||
Self::ShapeMismatch {
|
||||
path,
|
||||
expected,
|
||||
found,
|
||||
} => {
|
||||
write!(
|
||||
f,
|
||||
"Shape mismatch for '{}': expected {:?}, found {:?}",
|
||||
path, expected, found
|
||||
)
|
||||
}
|
||||
Self::DTypeMismatch {
|
||||
path,
|
||||
expected,
|
||||
found,
|
||||
} => {
|
||||
write!(
|
||||
f,
|
||||
"DType mismatch for '{}': expected {:?}, found {:?}",
|
||||
path, expected, found
|
||||
)
|
||||
}
|
||||
Self::AdapterError { path, message } => {
|
||||
write!(f, "Adapter error for '{}': {}", path, message)
|
||||
}
|
||||
Self::LoadError { path, message } => {
|
||||
write!(f, "Load error for '{}': {}", path, message)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl core::error::Error for ApplyError {}
|
||||
|
||||
/// Result of applying tensor snapshots to a module
|
||||
#[derive(Clone)]
|
||||
pub struct ApplyResult {
|
||||
/// Successfully applied tensor paths
|
||||
pub applied: Vec<String>,
|
||||
/// Skipped tensor paths (due to filter)
|
||||
pub skipped: Vec<String>,
|
||||
/// Missing tensor paths with their container stacks in dot notation (path, container_stack)
|
||||
/// Container stack shows the hierarchy: "Struct:Model.Struct:Linear" or "Struct:Model.Enum:ConvType.Struct:Linear"
|
||||
pub missing: Vec<(String, String)>,
|
||||
/// Unused tensor paths (in snapshots but not in module)
|
||||
pub unused: Vec<String>,
|
||||
/// Errors encountered during application
|
||||
pub errors: Vec<ApplyError>,
|
||||
}
|
||||
|
||||
impl ApplyResult {
|
||||
/// Remove the first path segment that looks like an enum variant name
/// e.g., "field.BaseConv.weight" -> "field.weight"
///
/// A segment qualifies when it is neither the first nor the last segment,
/// starts with an uppercase character and contains a lowercase character
/// after it. Returns `None` when no segment qualifies.
fn strip_enum_variant(path: &str) -> Option<String> {
    // Uppercase first character followed by at least one lowercase character.
    let looks_like_variant = |name: &str| {
        let mut chars = name.chars();
        chars.next().is_some_and(char::is_uppercase) && chars.any(char::is_lowercase)
    };

    let mut parts: Vec<&str> = path.split('.').collect();
    // `split` always yields at least one segment, so this cannot underflow.
    let last = parts.len() - 1;

    // Only interior segments are candidates; take the first match.
    let variant_at = (1..last).find(|&i| looks_like_variant(parts[i]))?;
    parts.remove(variant_at);
    Some(parts.join("."))
}
|
||||
|
||||
/// Find similar paths for a given missing path (for "Did you mean?" suggestions)
|
||||
fn find_similar_paths(&self, missing_path: &str, max_suggestions: usize) -> Vec<String> {
|
||||
// First, try exact match with enum variant stripped
|
||||
if let Some(stripped) = Self::strip_enum_variant(missing_path)
|
||||
&& self.unused.contains(&stripped)
|
||||
{
|
||||
return vec![stripped];
|
||||
}
|
||||
|
||||
// Fall back to Jaro similarity (used by Elixir for "did you mean?" suggestions)
|
||||
// Jaro gives higher weight to matching prefixes, ideal for hierarchical tensor paths
|
||||
let mut similarities: Vec<(String, f64)> = self
|
||||
.unused
|
||||
.iter()
|
||||
.map(|available| {
|
||||
let similarity = textdistance::nstr::jaro(missing_path, available);
|
||||
(available.clone(), similarity)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sort by similarity (higher = more similar)
|
||||
similarities
|
||||
.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(core::cmp::Ordering::Equal));
|
||||
|
||||
// Only suggest paths with >= 70% similarity
|
||||
const SIMILARITY_THRESHOLD: f64 = 0.7;
|
||||
similarities
|
||||
.into_iter()
|
||||
.filter(|(_, sim)| *sim >= SIMILARITY_THRESHOLD)
|
||||
.take(max_suggestions)
|
||||
.map(|(path, _)| path)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl ApplyResult {
|
||||
/// Check if the apply operation was successful (no errors)
|
||||
/// Note: Missing tensors are not considered errors when allow_partial is true
|
||||
pub fn is_success(&self) -> bool {
|
||||
self.errors.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl core::fmt::Debug for ApplyResult {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
// Delegate to Display for comprehensive output
|
||||
core::fmt::Display::fmt(self, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl core::fmt::Display for ApplyResult {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
writeln!(f, "┌─ Tensor Loading Summary ─────────────────────────")?;
|
||||
writeln!(f, "│")?;
|
||||
writeln!(
|
||||
f,
|
||||
"│ ✓ Successfully applied: {} tensors",
|
||||
self.applied.len()
|
||||
)?;
|
||||
writeln!(f, "│ ⊘ Skipped (filtered): {} tensors", self.skipped.len())?;
|
||||
writeln!(
|
||||
f,
|
||||
"│ ✗ Missing in source: {} tensors",
|
||||
self.missing.len()
|
||||
)?;
|
||||
writeln!(f, "│ ? Unused in target: {} tensors", self.unused.len())?;
|
||||
writeln!(f, "│ ! Errors: {} errors", self.errors.len())?;
|
||||
|
||||
if !self.missing.is_empty() {
|
||||
writeln!(f, "│")?;
|
||||
writeln!(
|
||||
f,
|
||||
"├─ Missing Tensors (requested by model but not found in source)"
|
||||
)?;
|
||||
writeln!(f, "│")?;
|
||||
|
||||
// Use actual container stack data to detect enum variants
|
||||
// Count how many missing paths have "Enum:" in their container stack
|
||||
let enum_variant_missing: Vec<_> = self
|
||||
.missing
|
||||
.iter()
|
||||
.filter(|(_, stack)| stack.contains("Enum:"))
|
||||
.collect();
|
||||
|
||||
if !enum_variant_missing.is_empty() {
|
||||
writeln!(
|
||||
f,
|
||||
"│ ⚠️ {} paths contain enum variants (detected from container stack)",
|
||||
enum_variant_missing.len()
|
||||
)?;
|
||||
writeln!(
|
||||
f,
|
||||
"│ Burn includes enum variant names in paths, but PyTorch doesn't."
|
||||
)?;
|
||||
writeln!(
|
||||
f,
|
||||
"│ Example: Burn has 'field.BaseConv.weight', PyTorch has 'field.weight'"
|
||||
)?;
|
||||
writeln!(f, "│")?;
|
||||
writeln!(
|
||||
f,
|
||||
"│ 💡 Solution 1: Enable skip_enum_variants flag (simplest):"
|
||||
)?;
|
||||
writeln!(f, "│")?;
|
||||
writeln!(
|
||||
f,
|
||||
"│ let mut store = PytorchStore::from_file(\"model.pth\")"
|
||||
)?;
|
||||
writeln!(f, "│ .skip_enum_variants(true); // ← Add this")?;
|
||||
writeln!(f, "│")?;
|
||||
writeln!(
|
||||
f,
|
||||
"│ 💡 Solution 2: Remap enum keys in source (most precise):"
|
||||
)?;
|
||||
writeln!(f, "│")?;
|
||||
writeln!(
|
||||
f,
|
||||
"│ let mut store = SafetensorsStore::from_file(\"model.safetensors\")"
|
||||
)?;
|
||||
writeln!(
|
||||
f,
|
||||
"│ .with_key_remapping(r\"field\\.(\\w+)\", \"field.BaseConv.$1\");"
|
||||
)?;
|
||||
writeln!(f, "│")?;
|
||||
}
|
||||
|
||||
writeln!(f, "│ First 10 missing tensors:")?;
|
||||
for (path, _) in self.missing.iter().take(10) {
|
||||
writeln!(f, "│ • {}", path)?;
|
||||
|
||||
// Show "Did you mean?" suggestions for this path
|
||||
let suggestions = self.find_similar_paths(path, 1);
|
||||
if !suggestions.is_empty() {
|
||||
writeln!(f, "│ Did you mean: '{}'?", suggestions[0])?;
|
||||
}
|
||||
}
|
||||
if self.missing.len() > 10 {
|
||||
writeln!(f, "│ ... and {} more", self.missing.len() - 10)?;
|
||||
}
|
||||
}
|
||||
|
||||
if !self.unused.is_empty() {
|
||||
writeln!(f, "│")?;
|
||||
writeln!(f, "├─ Unused Tensors (in source but not used by model)")?;
|
||||
writeln!(f, "│")?;
|
||||
writeln!(f, "│ First 10 unused tensors:")?;
|
||||
for path in self.unused.iter().take(10) {
|
||||
writeln!(f, "│ • {}", path)?;
|
||||
}
|
||||
if self.unused.len() > 10 {
|
||||
writeln!(f, "│ ... and {} more", self.unused.len() - 10)?;
|
||||
}
|
||||
}
|
||||
|
||||
if !self.errors.is_empty() {
|
||||
writeln!(f, "│")?;
|
||||
writeln!(f, "├─ Errors")?;
|
||||
writeln!(f, "│")?;
|
||||
for error in self.errors.iter().take(10) {
|
||||
writeln!(f, "│ ⚠️ {}", error)?;
|
||||
}
|
||||
if self.errors.len() > 10 {
|
||||
writeln!(f, "│ ... and {} more", self.errors.len() - 10)?;
|
||||
}
|
||||
}
|
||||
|
||||
writeln!(f, "│")?;
|
||||
write!(f, "└───────────────────────────────────────────────────")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,231 @@
|
||||
//! Core types and constants for the Burnpack file format.
|
||||
//!
|
||||
//! See the [parent module](crate::burnpack) for the complete file format specification.
|
||||
|
||||
use alloc::collections::BTreeMap;
|
||||
use alloc::string::String;
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::DType;
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Magic number identifying a Burnpack file: the ASCII bytes "BURN" read as a
/// big-endian `u32` (0x4255524E).
///
/// The header stores this value little-endian, so the first four bytes of a
/// file read `4E 52 55 42` ("NRUB").
pub const MAGIC_NUMBER: u32 = u32::from_be_bytes(*b"BURN");

/// Current format version.
pub const FORMAT_VERSION: u16 = 1;

/// Size in bytes of the magic number field (a `u32`).
pub const MAGIC_SIZE: usize = core::mem::size_of::<u32>();

/// Size in bytes of the format version field (a `u16`).
pub const VERSION_SIZE: usize = core::mem::size_of::<u16>();

/// Size in bytes of the metadata size field (a `u32`).
pub const METADATA_SIZE_FIELD_SIZE: usize = core::mem::size_of::<u32>();

/// Total header size: the sum of the three header fields above.
pub const HEADER_SIZE: usize = MAGIC_SIZE + VERSION_SIZE + METADATA_SIZE_FIELD_SIZE;

/// Alignment, in bytes, applied to tensor data.
///
/// Tensor data is placed on 256-byte boundaries so that it can be loaded
/// zero-copy through a memory map (mmap). The stated goals of this choice are:
/// - Proper pointer alignment for every tensor element type (f64 needs 8 bytes)
/// - Cache-line friendly access (most CPUs use 64-byte cache lines)
/// - GPU memory alignment (CUDA prefers 256-byte for coalesced access)
/// - Headroom for wider SIMD (AVX-512 = 64 bytes, future AVX-1024 = 128 bytes)
///
/// Alignment choices in other formats, as listed by the format authors:
/// - 256-byte: GGUF, MLX, ncnn, MNN, TNN, vLLM-AWQ, Marlin (15+ formats)
/// - 64-byte: SafeTensors (minimum for AVX-512)
/// - 4096-byte: Core ML
///
/// For typical tensor sizes the padding overhead of 256-byte alignment is
/// negligible, while compatibility with current and future hardware is kept.
pub const TENSOR_ALIGNMENT: u64 = 256;
|
||||
|
||||
/// Calculate the byte offset where the tensor data section starts.
|
||||
///
|
||||
/// The data section is padded to start at a 256-byte aligned position
|
||||
/// so that all tensor offsets (which are relative to data section) result
|
||||
/// in properly aligned absolute file positions for mmap zero-copy access.
|
||||
///
|
||||
/// This function must be used consistently by both writer and reader.
|
||||
#[inline]
|
||||
pub fn aligned_data_section_start(metadata_size: usize) -> usize {
|
||||
let unaligned_start = (HEADER_SIZE + metadata_size) as u64;
|
||||
// Keep multiplication in u64 space to avoid overflow on 32-bit systems
|
||||
(unaligned_start.div_ceil(TENSOR_ALIGNMENT) * TENSOR_ALIGNMENT) as usize
|
||||
}
|
||||
|
||||
// Security limits to prevent DoS attacks via resource exhaustion
|
||||
// These can be adjusted based on your use case
|
||||
|
||||
/// Upper bound on the metadata size accepted from a file: 100 MB (100 * 2^20 bytes).
/// Guards against memory exhaustion through oversized metadata claims.
pub const MAX_METADATA_SIZE: u32 = 100 << 20;

/// Upper bound on the size of a single tensor.
/// Guards against memory exhaustion through oversized tensor claims.
/// 32-bit platforms: 2 GB (2 * 2^30 bytes), chosen to fit in `usize`.
/// All other platforms: 10 GB (10 * 2^30 bytes).
#[cfg(target_pointer_width = "32")]
pub const MAX_TENSOR_SIZE: usize = 2 << 30;
#[cfg(not(target_pointer_width = "32"))]
pub const MAX_TENSOR_SIZE: usize = 10 << 30;

/// Upper bound on the number of tensors in one file (100,000).
/// Guards against resource exhaustion through excessive tensor counts.
pub const MAX_TENSOR_COUNT: usize = 100_000;

/// Upper bound on CBOR nesting depth during deserialization (128 levels).
/// Guards against stack overflow from deeply nested CBOR structures.
pub const MAX_CBOR_RECURSION_DEPTH: usize = 128;

/// Upper bound on the size of a file opened for loading: 100 GB (100 * 2^30 bytes).
/// Applies to the file-based loaders (mmap and buffered); only available with `std`.
#[cfg(feature = "std")]
pub const MAX_FILE_SIZE: u64 = 100 << 30;
|
||||
|
||||
/// Byte range for magic number in header
|
||||
pub const fn magic_range() -> core::ops::Range<usize> {
|
||||
let start = 0;
|
||||
let end = start + MAGIC_SIZE;
|
||||
start..end
|
||||
}
|
||||
|
||||
/// Byte range for format version in header
|
||||
pub const fn version_range() -> core::ops::Range<usize> {
|
||||
let start = MAGIC_SIZE;
|
||||
let end = start + VERSION_SIZE;
|
||||
start..end
|
||||
}
|
||||
|
||||
/// Byte range for metadata size field in header
|
||||
pub const fn metadata_size_range() -> core::ops::Range<usize> {
|
||||
let start = MAGIC_SIZE + VERSION_SIZE;
|
||||
let end = start + METADATA_SIZE_FIELD_SIZE;
|
||||
start..end
|
||||
}
|
||||
|
||||
// Compile-time validation that ranges are correct
|
||||
const _: () = assert!(MAGIC_SIZE + VERSION_SIZE + METADATA_SIZE_FIELD_SIZE == HEADER_SIZE);
|
||||
|
||||
/// Fixed-size header found at the very start of a Burnpack file.
///
/// The three fields are stored back to back, in declaration order.
#[derive(Debug, Clone, Copy)]
pub struct BurnpackHeader {
    /// File signature (4 bytes); 0x4255524E ("BURN") for a valid file.
    pub magic: u32,
    /// Version of the file format (2 bytes).
    pub version: u16,
    /// Length in bytes of the CBOR metadata that follows the header (4 bytes).
    pub metadata_size: u32,
}
|
||||
|
||||
impl BurnpackHeader {
|
||||
/// Create a new header with the given metadata size
|
||||
#[allow(dead_code)]
|
||||
pub fn new(metadata_size: u32) -> Self {
|
||||
Self {
|
||||
magic: MAGIC_NUMBER,
|
||||
version: FORMAT_VERSION,
|
||||
metadata_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Serialize header into bytes
|
||||
pub fn into_bytes(self) -> [u8; HEADER_SIZE] {
|
||||
let mut bytes = [0u8; HEADER_SIZE];
|
||||
LittleEndian::write_u32(&mut bytes[magic_range()], self.magic);
|
||||
LittleEndian::write_u16(&mut bytes[version_range()], self.version);
|
||||
LittleEndian::write_u32(&mut bytes[metadata_size_range()], self.metadata_size);
|
||||
bytes
|
||||
}
|
||||
|
||||
/// Deserialize header from bytes
|
||||
pub fn from_bytes(bytes: &[u8]) -> Result<Self, BurnpackError> {
|
||||
if bytes.len() < HEADER_SIZE {
|
||||
return Err(BurnpackError::InvalidHeader);
|
||||
}
|
||||
|
||||
let magic = LittleEndian::read_u32(&bytes[magic_range()]);
|
||||
if magic != MAGIC_NUMBER {
|
||||
return Err(BurnpackError::InvalidMagicNumber);
|
||||
}
|
||||
|
||||
let version = LittleEndian::read_u16(&bytes[version_range()]);
|
||||
let metadata_size = LittleEndian::read_u32(&bytes[metadata_size_range()]);
|
||||
|
||||
Ok(Self {
|
||||
magic,
|
||||
version,
|
||||
metadata_size,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Metadata structure serialized with CBOR
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BurnpackMetadata {
|
||||
/// Tensor descriptors mapped by name for efficient lookup
|
||||
pub tensors: BTreeMap<String, TensorDescriptor>,
|
||||
/// Optional additional metadata
|
||||
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
|
||||
pub metadata: BTreeMap<String, String>,
|
||||
}
|
||||
|
||||
/// Individual tensor descriptor
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TensorDescriptor {
|
||||
/// Data type of the tensor
|
||||
pub dtype: DType,
|
||||
/// Tensor shape dimensions
|
||||
pub shape: Vec<u64>,
|
||||
/// Byte offsets in data section (start, end)
|
||||
pub data_offsets: (u64, u64),
|
||||
/// Parameter ID for training state persistence matching.
|
||||
/// Generated automatically if not present during loading.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub param_id: Option<u64>,
|
||||
}
|
||||
|
||||
/// Errors produced by Burnpack reading and writing.
#[derive(Debug)]
pub enum BurnpackError {
    /// Not enough bytes were available to hold the header (or the metadata it announces).
    InvalidHeader,
    /// The magic number did not match the Burnpack signature.
    InvalidMagicNumber,
    /// The file's format version is not supported.
    InvalidVersion,
    /// Metadata could not be serialized; the payload is the underlying message.
    MetadataSerializationError(String),
    /// Metadata could not be deserialized; the payload is the underlying message.
    MetadataDeserializationError(String),
    /// An I/O or out-of-bounds read failure; the payload describes it.
    IoError(String),
    /// No tensor with the given name exists.
    TensorNotFound(String),
    /// The byte length of a tensor did not match what was expected.
    TensorBytesSizeMismatch(String),
    /// A validation or security limit check failed.
    ValidationError(String),
}

impl core::fmt::Display for BurnpackError {
    /// Writes a fixed message for payload-free variants and
    /// `"<label>: <payload>"` for variants carrying a message.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let (label, detail) = match self {
            Self::InvalidHeader => return f.write_str("Invalid header: insufficient bytes"),
            Self::InvalidMagicNumber => return f.write_str("Invalid magic number"),
            Self::InvalidVersion => return f.write_str("Unsupported version"),
            Self::MetadataSerializationError(msg) => ("Metadata serialization error", msg),
            Self::MetadataDeserializationError(msg) => ("Metadata deserialization error", msg),
            Self::IoError(msg) => ("I/O error", msg),
            Self::TensorNotFound(name) => ("Tensor not found", name),
            Self::TensorBytesSizeMismatch(msg) => ("Tensor bytes size mismatch", msg),
            Self::ValidationError(msg) => ("Validation error", msg),
        };
        write!(f, "{}: {}", label, detail)
    }
}

impl core::error::Error for BurnpackError {}
|
||||
@@ -0,0 +1,62 @@
|
||||
//! # Burnpack - Native Burn Model Storage Format
|
||||
//!
|
||||
//! Burnpack is the native binary storage format for Burn models, designed for efficient
|
||||
//! serialization, fast loading, and cross-platform compatibility.
|
||||
//!
|
||||
//! ## Key Features
|
||||
//!
|
||||
//! - **CBOR Metadata**: Structured metadata with efficient binary encoding
|
||||
//! - **Memory-Mapped Loading**: Zero-copy loading for optimal performance
|
||||
//! - **256-byte Tensor Alignment**: Enables efficient mmap zero-copy access
|
||||
//! - **No-std Support**: Works in embedded and WASM environments
|
||||
//! - **ParamId Persistence**: Preserves parameter identities for stateful training
|
||||
//! - **Lazy Tensor Loading**: Deferred data materialization for efficient memory usage
|
||||
//!
|
||||
//! ## File Format Structure
|
||||
//!
|
||||
//! ```text
|
||||
//! ┌──────────────────────────────────┐
|
||||
//! │ Header (10 bytes) │
|
||||
//! ├──────────────────────────────────┤
|
||||
//! │ - Magic number (4 bytes) │ 0x4255524E ("BURN"); stored LE as bytes 4E 52 55 42 ("NRUB")
|
||||
//! │ - Version (2 bytes) │ Format version (0x0001)
|
||||
//! │ - Metadata size (4 bytes) │ Size of CBOR metadata (u32)
|
||||
//! ├──────────────────────────────────┤
|
||||
//! │ Metadata (CBOR) │
|
||||
//! ├──────────────────────────────────┤
|
||||
//! │ - Tensor descriptors (BTreeMap) │ Key-sorted map of tensor metadata
|
||||
//! │ Key: tensor name (string) │ e.g., "model.layer1.weight"
|
||||
//! │ Value: TensorDescriptor │
|
||||
//! │ - dtype: DType │ Data type (F32, F64, I32, etc.)
|
||||
//! │ - shape: Vec<u64> │ Tensor dimensions
|
||||
//! │ - data_offsets: (u64, u64) │ (start, end) byte offsets (256-byte aligned)
|
||||
//! │ - param_id: Option<u64> │ Parameter ID (for training state)
|
||||
//! │ - Additional metadata(BTreeMap) │ User-defined key-value pairs
|
||||
//! ├──────────────────────────────────┤
|
||||
//! │ Tensor Data Section │
|
||||
//! ├──────────────────────────────────┤
|
||||
//! │ [padding][tensor1][padding]... │ Each tensor aligned to 256-byte boundary
|
||||
//! │ Raw tensor bytes (little-endian)│ Enables mmap zero-copy loading
|
||||
//! └──────────────────────────────────┘
|
||||
//! ```
|
||||
//!
|
||||
//! ## Tensor Alignment
|
||||
//!
|
||||
//! All tensor data is aligned to 256-byte boundaries to enable efficient memory-mapped
|
||||
//! (mmap) zero-copy loading. This alignment ensures:
|
||||
//!
|
||||
//! - Proper pointer alignment for all tensor element types (f64 requires 8 bytes)
|
||||
//! - Cache-line friendly access (most CPUs use 64-byte cache lines)
|
||||
//! - GPU memory alignment (CUDA prefers 256-byte for coalesced access)
|
||||
//! - Future-proofing for wider SIMD instructions (AVX-512, future AVX-1024)
|
||||
//!
|
||||
//! The 256-byte alignment matches industry standards used by GGUF, MLX, ncnn, MNN,
|
||||
//! and other major model formats.
|
||||
|
||||
pub mod base;
|
||||
pub mod reader;
|
||||
pub mod store;
|
||||
pub mod writer;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
@@ -0,0 +1,761 @@
|
||||
#[cfg(feature = "std")]
|
||||
use super::base::MAX_FILE_SIZE;
|
||||
use super::base::{
|
||||
BurnpackError, BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER,
|
||||
MAX_CBOR_RECURSION_DEPTH, MAX_METADATA_SIZE, MAX_TENSOR_COUNT, MAX_TENSOR_SIZE,
|
||||
aligned_data_section_start,
|
||||
};
|
||||
use crate::TensorSnapshot;
|
||||
use alloc::format;
|
||||
use alloc::rc::Rc;
|
||||
use alloc::string::ToString;
|
||||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
use burn_core::module::ParamId;
|
||||
use burn_tensor::{Bytes, TensorData};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use std::cell::RefCell;
|
||||
#[cfg(feature = "std")]
|
||||
use std::fs::File;
|
||||
#[cfg(feature = "std")]
|
||||
use std::io::{Read, Seek};
|
||||
#[cfg(feature = "std")]
|
||||
use std::path::Path;
|
||||
|
||||
/// Storage backend for BurnpackReader
|
||||
pub(crate) enum StorageBackend {
|
||||
/// Memory-based storage (also used for memory-mapped files converted to bytes::Bytes)
|
||||
Memory(Rc<Bytes>),
|
||||
/// File-based storage with buffered reading
|
||||
#[cfg(feature = "std")]
|
||||
#[allow(dead_code)]
|
||||
FileBuffered { file: Rc<RefCell<File>> },
|
||||
}
|
||||
|
||||
impl StorageBackend {
|
||||
/// Read data from storage into the provided buffer at the given offset.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `bytes` - The buffer to read into (caller-allocated)
|
||||
/// * `offset` - Absolute file/data position to start reading from
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if:
|
||||
/// - The requested data range is out of bounds
|
||||
/// - Less data is available than requested (indicates corruption or incorrect offset)
|
||||
/// - File I/O fails
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The caller allocates the buffer, which allows for buffer reuse and future optimizations
|
||||
/// like memory pools and pinned memory.
|
||||
///
|
||||
/// This method ensures all backends have consistent behavior: if the exact number of
|
||||
/// requested bytes cannot be read, an error is returned to prevent data corruption.
|
||||
pub(crate) fn read_into(&self, bytes: &mut [u8], offset: usize) -> Result<(), BurnpackError> {
|
||||
match self {
|
||||
StorageBackend::Memory(data) => {
|
||||
let data_bytes = data.as_ref();
|
||||
let end = offset.checked_add(bytes.len()).ok_or_else(|| {
|
||||
BurnpackError::IoError(format!(
|
||||
"Offset overflow: offset {} + length {} exceeds maximum",
|
||||
offset,
|
||||
bytes.len()
|
||||
))
|
||||
})?;
|
||||
|
||||
if end > data_bytes.len() {
|
||||
return Err(BurnpackError::IoError(format!(
|
||||
"Read out of bounds: requested {}..{} but data length is {}",
|
||||
offset,
|
||||
end,
|
||||
data_bytes.len()
|
||||
)));
|
||||
}
|
||||
|
||||
bytes.copy_from_slice(&data_bytes[offset..end]);
|
||||
Ok(())
|
||||
}
|
||||
#[cfg(feature = "std")]
|
||||
StorageBackend::FileBuffered { file } => {
|
||||
use std::io::SeekFrom;
|
||||
|
||||
let mut file = file.borrow_mut();
|
||||
file.seek(SeekFrom::Start(offset as u64)).map_err(|e| {
|
||||
BurnpackError::IoError(format!("Failed to seek in file: {}", e))
|
||||
})?;
|
||||
|
||||
file.read_exact(bytes).map_err(|e| {
|
||||
BurnpackError::IoError(format!("Failed to read from file: {}", e))
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get full data reference for raw access
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn as_bytes(&self) -> Result<&[u8], BurnpackError> {
|
||||
match self {
|
||||
StorageBackend::Memory(data) => Ok(data.as_ref()),
|
||||
#[cfg(feature = "std")]
|
||||
StorageBackend::FileBuffered { .. } => Err(BurnpackError::IoError(
|
||||
"Cannot get full bytes reference for FileBuffered backend".into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to slice bytes without copying (zero-copy).
|
||||
///
|
||||
/// This uses `Bytes::clone()` + `split()` which is zero-copy when the underlying
|
||||
/// `Bytes` was created via `Bytes::from_shared()` (backed by `bytes::Bytes`).
|
||||
///
|
||||
/// # Returns
|
||||
/// - `Ok(bytes)` - Successfully created a zero-copy slice
|
||||
/// - `Err(_)` - Backend doesn't support zero-copy or split failed
|
||||
pub(crate) fn slice_bytes(&self, start: usize, end: usize) -> Result<Bytes, BurnpackError> {
|
||||
if end < start {
|
||||
return Err(BurnpackError::IoError(format!(
|
||||
"Invalid slice range: end ({}) < start ({})",
|
||||
end, start
|
||||
)));
|
||||
}
|
||||
|
||||
match self {
|
||||
StorageBackend::Memory(data) => {
|
||||
// Clone the Bytes - cheap if backed by SharedBytesAllocationController
|
||||
let cloned = (**data).clone();
|
||||
|
||||
// Split at start offset to get (_, right)
|
||||
let (_, right) = cloned.split(start).map_err(|(_, e)| {
|
||||
BurnpackError::IoError(format!("Failed to split at start {}: {:?}", start, e))
|
||||
})?;
|
||||
|
||||
// Split right at (end - start) to get (middle, _)
|
||||
let slice_len = end - start;
|
||||
let (middle, _) = right.split(slice_len).map_err(|(_, e)| {
|
||||
BurnpackError::IoError(format!(
|
||||
"Failed to split at length {}: {:?}",
|
||||
slice_len, e
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(middle)
|
||||
}
|
||||
#[cfg(feature = "std")]
|
||||
StorageBackend::FileBuffered { .. } => Err(BurnpackError::IoError(
|
||||
"Zero-copy not supported for buffered file reading. Use from_file() with memmap feature for zero-copy loading.".into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reader for loading Burnpack files
|
||||
pub struct BurnpackReader {
|
||||
/// Parsed metadata
|
||||
pub(crate) metadata: BurnpackMetadata,
|
||||
/// Storage backend
|
||||
pub(crate) storage: StorageBackend,
|
||||
/// Offset to the start of tensor data
|
||||
pub(crate) data_offset: usize,
|
||||
}
|
||||
|
||||
impl BurnpackReader {
|
||||
/// Load from bytes
|
||||
pub fn from_bytes(bytes: Bytes) -> Result<Self, BurnpackError> {
|
||||
// Validate minimum size
|
||||
if bytes.len() < HEADER_SIZE {
|
||||
return Err(BurnpackError::InvalidHeader);
|
||||
}
|
||||
|
||||
// Parse header
|
||||
let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE])?;
|
||||
|
||||
// Verify magic number
|
||||
if header.magic != MAGIC_NUMBER {
|
||||
return Err(BurnpackError::InvalidMagicNumber);
|
||||
}
|
||||
|
||||
// Verify version compatibility
|
||||
if header.version > FORMAT_VERSION {
|
||||
return Err(BurnpackError::InvalidVersion);
|
||||
}
|
||||
|
||||
// Validate metadata size against security limit
|
||||
if header.metadata_size > MAX_METADATA_SIZE {
|
||||
return Err(BurnpackError::ValidationError(format!(
|
||||
"Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)",
|
||||
header.metadata_size, MAX_METADATA_SIZE
|
||||
)));
|
||||
}
|
||||
|
||||
// Parse metadata
|
||||
let metadata_start = HEADER_SIZE;
|
||||
let metadata_end = metadata_start
|
||||
.checked_add(header.metadata_size as usize)
|
||||
.ok_or_else(|| {
|
||||
BurnpackError::IoError(format!(
|
||||
"Metadata size overflow: {} + {}",
|
||||
metadata_start, header.metadata_size
|
||||
))
|
||||
})?;
|
||||
|
||||
if bytes.len() < metadata_end {
|
||||
return Err(BurnpackError::InvalidHeader);
|
||||
}
|
||||
|
||||
let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit(
|
||||
&bytes[metadata_start..metadata_end],
|
||||
MAX_CBOR_RECURSION_DEPTH,
|
||||
)
|
||||
.map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?;
|
||||
|
||||
// Validate tensor count against security limit
|
||||
if metadata.tensors.len() > MAX_TENSOR_COUNT {
|
||||
return Err(BurnpackError::ValidationError(format!(
|
||||
"File contains {} tensors, exceeding maximum of {} (potential DoS attack)",
|
||||
metadata.tensors.len(),
|
||||
MAX_TENSOR_COUNT
|
||||
)));
|
||||
}
|
||||
|
||||
// Validate total file size - ensure file is large enough for all claimed tensor data
|
||||
if !metadata.tensors.is_empty() {
|
||||
let max_data_offset = metadata
|
||||
.tensors
|
||||
.values()
|
||||
.map(|t| t.data_offsets.1)
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
let max_data_offset_usize: usize = max_data_offset.try_into().map_err(|_| {
|
||||
BurnpackError::ValidationError(format!(
|
||||
"Data offset {} exceeds platform maximum",
|
||||
max_data_offset
|
||||
))
|
||||
})?;
|
||||
|
||||
let min_file_size =
|
||||
metadata_end
|
||||
.checked_add(max_data_offset_usize)
|
||||
.ok_or_else(|| {
|
||||
BurnpackError::ValidationError("File size calculation overflow".into())
|
||||
})?;
|
||||
|
||||
if bytes.len() < min_file_size {
|
||||
return Err(BurnpackError::ValidationError(format!(
|
||||
"File truncated: expected at least {} bytes, got {} bytes",
|
||||
min_file_size,
|
||||
bytes.len()
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
metadata,
|
||||
storage: StorageBackend::Memory(Rc::new(bytes)),
|
||||
data_offset: aligned_data_section_start(header.metadata_size as usize),
|
||||
})
|
||||
}
|
||||
|
||||
/// Build a reader by memory-mapping the file at `path`.
///
/// The file size is checked against `MAX_FILE_SIZE`, the file is mapped, and
/// the same header, metadata and length validation as the in-memory loader is
/// applied to the mapping. The mapping is then wrapped in shared bytes so that
/// tensors can be sliced from it.
///
/// # Errors
/// - `IoError` if the file cannot be opened, inspected or mapped
/// - `InvalidHeader`, `InvalidMagicNumber`, `InvalidVersion` for malformed headers
/// - `MetadataDeserializationError` if the CBOR metadata cannot be decoded
/// - `ValidationError` if a security limit is exceeded or the file is truncated
#[cfg(all(feature = "std", feature = "memmap"))]
pub(crate) fn from_file_mmap<P: AsRef<Path>>(path: P) -> Result<Self, BurnpackError> {
    let io_err = |e: std::io::Error| BurnpackError::IoError(e.to_string());

    let file = File::open(path.as_ref()).map_err(io_err)?;

    // Refuse oversized files before mapping them.
    let file_size = file.metadata().map_err(io_err)?.len();
    if file_size > MAX_FILE_SIZE {
        return Err(BurnpackError::ValidationError(format!(
            "File size {} bytes exceeds maximum allowed size of {} bytes",
            file_size, MAX_FILE_SIZE
        )));
    }

    // SAFETY: NOTE(review) — mapping is only sound while the underlying file is
    // not truncated or modified by another process; nothing here enforces that.
    let mmap = unsafe { memmap2::MmapOptions::new().map(&file).map_err(io_err)? };

    let mapped_len = mmap.len();
    if mapped_len < HEADER_SIZE {
        return Err(BurnpackError::InvalidHeader);
    }
    let header = BurnpackHeader::from_bytes(&mmap[..HEADER_SIZE])?;

    if header.magic != MAGIC_NUMBER {
        return Err(BurnpackError::InvalidMagicNumber);
    }
    if header.version > FORMAT_VERSION {
        return Err(BurnpackError::InvalidVersion);
    }
    if header.metadata_size > MAX_METADATA_SIZE {
        return Err(BurnpackError::ValidationError(format!(
            "Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)",
            header.metadata_size, MAX_METADATA_SIZE
        )));
    }

    // Locate the metadata section, which directly follows the header.
    let metadata_len = header.metadata_size as usize;
    let Some(metadata_end) = HEADER_SIZE.checked_add(metadata_len) else {
        return Err(BurnpackError::IoError(format!(
            "Metadata size overflow: {} + {}",
            HEADER_SIZE, header.metadata_size
        )));
    };
    if mapped_len < metadata_end {
        return Err(BurnpackError::InvalidHeader);
    }

    let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit(
        &mmap[HEADER_SIZE..metadata_end],
        MAX_CBOR_RECURSION_DEPTH,
    )
    .map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?;

    let tensor_count = metadata.tensors.len();
    if tensor_count > MAX_TENSOR_COUNT {
        return Err(BurnpackError::ValidationError(format!(
            "File contains {} tensors, exceeding maximum of {} (potential DoS attack)",
            tensor_count, MAX_TENSOR_COUNT
        )));
    }

    // The mapping must be long enough for the furthest tensor end offset.
    let furthest_end = metadata.tensors.values().map(|t| t.data_offsets.1).max();
    if let Some(max_data_offset) = furthest_end {
        let max_data_offset_usize = usize::try_from(max_data_offset).map_err(|_| {
            BurnpackError::ValidationError(format!(
                "Data offset {} exceeds platform maximum",
                max_data_offset
            ))
        })?;

        let Some(min_file_size) = metadata_end.checked_add(max_data_offset_usize) else {
            return Err(BurnpackError::ValidationError(
                "File size calculation overflow".into(),
            ));
        };

        if mapped_len < min_file_size {
            return Err(BurnpackError::ValidationError(format!(
                "File truncated: expected at least {} bytes, got {} bytes",
                min_file_size, mapped_len
            )));
        }
    }

    // Hand the mapping to `bytes::Bytes` (which takes ownership of it) so the
    // resulting shared bytes support slicing.
    let bytes = Bytes::from_shared(
        bytes::Bytes::from_owner(mmap),
        burn_tensor::AllocationProperty::File,
    );

    let data_offset = aligned_data_section_start(metadata_len);
    Ok(Self {
        metadata,
        storage: StorageBackend::Memory(Rc::new(bytes)),
        data_offset,
    })
}
|
||||
|
||||
/// Open the Burnpack file at `path`.
///
/// With the `memmap` feature enabled the file is memory-mapped; without it the
/// buffered loader is used instead. Errors are those of the selected loader.
#[cfg(feature = "std")]
pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, BurnpackError> {
    #[cfg(feature = "memmap")]
    {
        // Preferred path: map the file into memory.
        Self::from_file_mmap(path)
    }
    #[cfg(not(feature = "memmap"))]
    {
        // No mmap support compiled in: read through the file handle on demand.
        Self::from_file_buffered(path)
    }
}
|
||||
|
||||
/// Build a reader that keeps the file open and reads tensor data on demand.
///
/// Only the header and metadata are read eagerly. This is the fallback used
/// when memory mapping is not compiled in.
///
/// # Errors
/// - `IoError` if the file cannot be opened, inspected or read
/// - `InvalidHeader`, `InvalidMagicNumber`, `InvalidVersion` for malformed headers
/// - `MetadataDeserializationError` if the CBOR metadata cannot be decoded
/// - `ValidationError` if a security limit is exceeded or the file is truncated
#[cfg(feature = "std")]
#[allow(dead_code)]
pub(crate) fn from_file_buffered<P: AsRef<Path>>(path: P) -> Result<Self, BurnpackError> {
    let mut file = File::open(&path).map_err(|e| BurnpackError::IoError(e.to_string()))?;

    // Validate maximum file size to prevent resource exhaustion. The size is
    // kept as u64 and reused for the truncation check further down.
    let file_size = file
        .metadata()
        .map_err(|e| BurnpackError::IoError(e.to_string()))?
        .len();

    if file_size > MAX_FILE_SIZE {
        return Err(BurnpackError::ValidationError(format!(
            "File size {} bytes exceeds maximum allowed size of {} bytes",
            file_size, MAX_FILE_SIZE
        )));
    }

    // Read header
    let mut header_bytes = [0u8; HEADER_SIZE];
    file.read_exact(&mut header_bytes)
        .map_err(|e| BurnpackError::IoError(e.to_string()))?;

    // `BurnpackHeader::from_bytes` also rejects a wrong magic number.
    let header = BurnpackHeader::from_bytes(&header_bytes)?;

    // Verify version
    if header.version > FORMAT_VERSION {
        return Err(BurnpackError::InvalidVersion);
    }

    // Validate metadata size against security limit before allocating for it
    if header.metadata_size > MAX_METADATA_SIZE {
        return Err(BurnpackError::ValidationError(format!(
            "Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)",
            header.metadata_size, MAX_METADATA_SIZE
        )));
    }

    // Read metadata
    let mut metadata_bytes = vec![0u8; header.metadata_size as usize];
    file.read_exact(&mut metadata_bytes)
        .map_err(|e| BurnpackError::IoError(e.to_string()))?;

    let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit(
        metadata_bytes.as_slice(),
        MAX_CBOR_RECURSION_DEPTH,
    )
    .map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?;

    // Validate tensor count against security limit
    if metadata.tensors.len() > MAX_TENSOR_COUNT {
        return Err(BurnpackError::ValidationError(format!(
            "File contains {} tensors, exceeding maximum of {} (potential DoS attack)",
            metadata.tensors.len(),
            MAX_TENSOR_COUNT
        )));
    }

    // Calculate metadata end offset
    let metadata_end = HEADER_SIZE
        .checked_add(header.metadata_size as usize)
        .ok_or_else(|| {
            BurnpackError::IoError(format!(
                "Metadata size overflow: {} + {}",
                HEADER_SIZE, header.metadata_size
            ))
        })?;

    // Validate total file size - ensure file is large enough for all claimed tensor data
    if !metadata.tensors.is_empty() {
        let max_data_offset = metadata
            .tensors
            .values()
            .map(|t| t.data_offsets.1)
            .max()
            .unwrap_or(0);

        let max_data_offset_usize: usize = max_data_offset.try_into().map_err(|_| {
            BurnpackError::ValidationError(format!(
                "Data offset {} exceeds platform maximum",
                max_data_offset
            ))
        })?;

        let min_file_size =
            metadata_end
                .checked_add(max_data_offset_usize)
                .ok_or_else(|| {
                    BurnpackError::ValidationError("File size calculation overflow".into())
                })?;

        // Compare in u64: widening `usize` is lossless, whereas narrowing the
        // file length to `usize` would truncate for files over 4 GiB on
        // 32-bit targets and could misreport them as truncated.
        if file_size < min_file_size as u64 {
            return Err(BurnpackError::ValidationError(format!(
                "File truncated: expected at least {} bytes, got {} bytes",
                min_file_size, file_size
            )));
        }
    }

    Ok(Self {
        metadata,
        storage: StorageBackend::FileBuffered {
            file: Rc::new(RefCell::new(file)),
        },
        data_offset: aligned_data_section_start(header.metadata_size as usize),
    })
}
|
||||
|
||||
/// Get all tensor snapshots at once for efficient loading (always copies data)
|
||||
pub fn get_snapshots(&self) -> Result<Vec<TensorSnapshot>, BurnpackError> {
|
||||
self.get_snapshots_internal(false)
|
||||
}
|
||||
|
||||
/// Get all tensor snapshots with optional zero-copy loading.
|
||||
///
|
||||
/// When `zero_copy` is true and the backend supports it (Memory backend with
|
||||
/// `Bytes::from_shared()`), tensor data is sliced without copying. This keeps
|
||||
/// the original data alive as long as any tensor holds a reference.
|
||||
///
|
||||
/// When `zero_copy` is false or the backend doesn't support it, data is copied
|
||||
/// into newly allocated buffers (default behavior).
|
||||
pub fn get_snapshots_zero_copy(
|
||||
&self,
|
||||
zero_copy: bool,
|
||||
) -> Result<Vec<TensorSnapshot>, BurnpackError> {
|
||||
self.get_snapshots_internal(zero_copy)
|
||||
}
|
||||
|
||||
/// Internal implementation with optional zero-copy support
|
||||
fn get_snapshots_internal(
|
||||
&self,
|
||||
zero_copy: bool,
|
||||
) -> Result<Vec<TensorSnapshot>, BurnpackError> {
|
||||
let mut snapshots = Vec::new();
|
||||
|
||||
for (name, descriptor) in &self.metadata.tensors {
|
||||
// Clone metadata for use in closure
|
||||
// Convert shape dimensions with overflow checking
|
||||
let shape: Vec<usize> = descriptor
|
||||
.shape
|
||||
.iter()
|
||||
.map(|&s| {
|
||||
s.try_into().map_err(|_| {
|
||||
BurnpackError::ValidationError(format!(
|
||||
"Tensor '{}' has corrupted shape data: dimension {} exceeds platform maximum",
|
||||
name, s
|
||||
))
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<usize>, BurnpackError>>()?;
|
||||
|
||||
let dtype = descriptor.dtype;
|
||||
|
||||
// Clone storage reference for the closure
|
||||
let storage = match &self.storage {
|
||||
StorageBackend::Memory(data) => StorageBackend::Memory(data.clone()),
|
||||
#[cfg(feature = "std")]
|
||||
StorageBackend::FileBuffered { file } => {
|
||||
StorageBackend::FileBuffered { file: file.clone() }
|
||||
}
|
||||
};
|
||||
|
||||
// Always use absolute positions for all backends
|
||||
// Convert offsets with overflow checking
|
||||
let offset_start: usize = descriptor.data_offsets.0.try_into().map_err(|_| {
|
||||
BurnpackError::ValidationError(format!(
|
||||
"Tensor '{}' has corrupted offset data: start offset {} exceeds platform maximum",
|
||||
name, descriptor.data_offsets.0
|
||||
))
|
||||
})?;
|
||||
|
||||
let offset_end: usize = descriptor.data_offsets.1.try_into().map_err(|_| {
|
||||
BurnpackError::ValidationError(format!(
|
||||
"Tensor '{}' has corrupted offset data: end offset {} exceeds platform maximum",
|
||||
name, descriptor.data_offsets.1
|
||||
))
|
||||
})?;
|
||||
|
||||
let start = self.data_offset.checked_add(offset_start).ok_or_else(|| {
|
||||
BurnpackError::ValidationError(format!(
|
||||
"Tensor '{}' has corrupted offset data: start offset overflow {} + {}",
|
||||
name, self.data_offset, offset_start
|
||||
))
|
||||
})?;
|
||||
|
||||
let end = self.data_offset.checked_add(offset_end).ok_or_else(|| {
|
||||
BurnpackError::ValidationError(format!(
|
||||
"Tensor '{}' has corrupted offset data: end offset overflow {} + {}",
|
||||
name, self.data_offset, offset_end
|
||||
))
|
||||
})?;
|
||||
|
||||
// Clone shape for the closure (TensorSnapshot::from_closure will also need it)
|
||||
let shape_for_closure = shape.clone();
|
||||
|
||||
// Validate offset range
|
||||
if end < start {
|
||||
return Err(BurnpackError::ValidationError(format!(
|
||||
"Tensor '{}' has corrupted offset data: end offset {} < start offset {}",
|
||||
name, end, start
|
||||
)));
|
||||
}
|
||||
|
||||
// Validate tensor size against security limit
|
||||
let tensor_size = end - start;
|
||||
if tensor_size > MAX_TENSOR_SIZE {
|
||||
return Err(BurnpackError::ValidationError(format!(
|
||||
"Tensor '{}' size {} exceeds maximum allowed size of {} bytes (potential DoS attack)",
|
||||
name, tensor_size, MAX_TENSOR_SIZE
|
||||
)));
|
||||
}
|
||||
|
||||
// Restore param_id if it was saved, otherwise generate
|
||||
let tensor_id = descriptor
|
||||
.param_id
|
||||
.map(ParamId::from)
|
||||
.unwrap_or_else(ParamId::new);
|
||||
|
||||
// Create the data-loading closure based on zero_copy flag
|
||||
let data_fn: Rc<dyn Fn() -> Result<TensorData, crate::TensorSnapshotError>> =
|
||||
if zero_copy {
|
||||
// Zero-copy closure: slice without copying, error if not supported
|
||||
Rc::new(move || {
|
||||
let bytes = storage.slice_bytes(start, end).map_err(|e| {
|
||||
crate::TensorSnapshotError::IoError(format!(
|
||||
"Zero-copy slice failed: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
Ok(TensorData::from_bytes(
|
||||
bytes,
|
||||
shape_for_closure.clone(),
|
||||
dtype,
|
||||
))
|
||||
})
|
||||
} else {
|
||||
// Copying closure: always allocate and copy
|
||||
Rc::new(move || {
|
||||
let len = end - start;
|
||||
// TODO Should be allocated by the backend in the future
|
||||
// See https://github.com/tracel-ai/burn/pull/3792#discussion_r2416812091
|
||||
let mut data_bytes = vec![0u8; len];
|
||||
storage.read_into(&mut data_bytes, start).map_err(|e| {
|
||||
crate::TensorSnapshotError::IoError(format!(
|
||||
"Failed to read tensor data: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
Ok(TensorData::from_bytes_vec(
|
||||
data_bytes,
|
||||
shape_for_closure.clone(),
|
||||
dtype,
|
||||
))
|
||||
})
|
||||
};
|
||||
|
||||
// Create lazy TensorSnapshot
|
||||
let snapshot = TensorSnapshot::from_closure(
|
||||
data_fn,
|
||||
dtype,
|
||||
shape,
|
||||
name.split('.').map(|s| s.to_string()).collect(),
|
||||
vec![], // empty container_stack
|
||||
tensor_id, // restored or newly generated param id
|
||||
);
|
||||
|
||||
snapshots.push(snapshot);
|
||||
}
|
||||
|
||||
Ok(snapshots)
|
||||
}
|
||||
|
||||
// Legacy methods for test compatibility - will be removed
|
||||
|
||||
/// Get tensor as TensorSnapshot with lazy loading
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn get_tensor_snapshot(&self, name: &str) -> Result<TensorSnapshot, BurnpackError> {
|
||||
let snapshots = self.get_snapshots()?;
|
||||
snapshots
|
||||
.into_iter()
|
||||
.find(|s| s.full_path() == name)
|
||||
.ok_or_else(|| BurnpackError::TensorNotFound(name.to_string()))
|
||||
}
|
||||
|
||||
/// Get list of tensor names
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn tensor_names(&self) -> Vec<&str> {
|
||||
self.metadata
|
||||
.tensors
|
||||
.keys()
|
||||
.map(|name| name.as_str())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get metadata
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn metadata(&self) -> &BurnpackMetadata {
|
||||
&self.metadata
|
||||
}
|
||||
|
||||
/// Get tensor data as raw bytes
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn get_tensor_data(&self, name: &str) -> Result<Vec<u8>, BurnpackError> {
|
||||
let descriptor = self
|
||||
.metadata
|
||||
.tensors
|
||||
.get(name)
|
||||
.ok_or_else(|| BurnpackError::TensorNotFound(name.to_string()))?;
|
||||
|
||||
// Always use absolute positions for all backends
|
||||
// Convert offsets with overflow checking
|
||||
let offset_start: usize = descriptor.data_offsets.0.try_into().map_err(|_| {
|
||||
BurnpackError::IoError(format!(
|
||||
"Tensor '{}' has corrupted offset data: start offset {} exceeds platform maximum",
|
||||
name, descriptor.data_offsets.0
|
||||
))
|
||||
})?;
|
||||
|
||||
let offset_end: usize = descriptor.data_offsets.1.try_into().map_err(|_| {
|
||||
BurnpackError::IoError(format!(
|
||||
"Tensor '{}' has corrupted offset data: end offset {} exceeds platform maximum",
|
||||
name, descriptor.data_offsets.1
|
||||
))
|
||||
})?;
|
||||
|
||||
let start = self.data_offset.checked_add(offset_start).ok_or_else(|| {
|
||||
BurnpackError::IoError(format!(
|
||||
"Tensor '{}' has corrupted offset data: start offset overflow {} + {}",
|
||||
name, self.data_offset, offset_start
|
||||
))
|
||||
})?;
|
||||
|
||||
let end = self.data_offset.checked_add(offset_end).ok_or_else(|| {
|
||||
BurnpackError::IoError(format!(
|
||||
"Tensor '{}' has corrupted offset data: end offset overflow {} + {}",
|
||||
name, self.data_offset, offset_end
|
||||
))
|
||||
})?;
|
||||
|
||||
// Validate offset range
|
||||
if end < start {
|
||||
return Err(BurnpackError::IoError(format!(
|
||||
"Tensor '{}' has corrupted offset data: end offset {} < start offset {}",
|
||||
name, end, start
|
||||
)));
|
||||
}
|
||||
|
||||
let len = end - start;
|
||||
let mut buffer = vec![0u8; len];
|
||||
self.storage.read_into(&mut buffer, start)?;
|
||||
Ok(buffer)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,507 @@
|
||||
#[cfg(feature = "std")]
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::reader::BurnpackReader;
|
||||
use super::writer::BurnpackWriter;
|
||||
#[cfg(feature = "std")]
|
||||
use crate::KeyRemapper;
|
||||
use crate::burnpack::base::BurnpackError;
|
||||
use crate::{ModuleSnapshot, ModuleStore, PathFilter, TensorSnapshot};
|
||||
use alloc::collections::BTreeMap;
|
||||
use alloc::format;
|
||||
use alloc::string::String;
|
||||
use alloc::vec::Vec;
|
||||
use burn_core::prelude::Backend;
|
||||
use burn_tensor::Bytes;
|
||||
|
||||
/// Store mode for BurnpackStore
|
||||
enum StoreMode {
|
||||
#[cfg(feature = "std")]
|
||||
File(PathBuf),
|
||||
Bytes(Option<Bytes>),
|
||||
}
|
||||
|
||||
/// BurnpackStore - A Burn-specific file format store using CBOR for metadata
|
||||
pub struct BurnpackStore {
|
||||
/// Store mode - either file path or bytes
|
||||
mode: StoreMode,
|
||||
/// Optional filter for selective loading/saving
|
||||
filter: Option<PathFilter>,
|
||||
/// Additional metadata
|
||||
metadata: BTreeMap<String, String>,
|
||||
/// Allow partial loading (ignore missing tensors)
|
||||
allow_partial: bool,
|
||||
/// Validate tensors during loading (check shapes and dtypes)
|
||||
validate: bool,
|
||||
/// Allow overwriting existing files (default: false)
|
||||
overwrite: bool,
|
||||
/// Enable zero-copy tensor loading (default: false)
|
||||
///
|
||||
/// When enabled and the backend supports it, tensor data is sliced from
|
||||
/// the source without copying. This requires keeping the source data alive.
|
||||
zero_copy: bool,
|
||||
/// Automatically append .bpk extension if not present (default: true)
|
||||
#[cfg(feature = "std")]
|
||||
auto_extension: bool,
|
||||
/// Key remapper for tensor name transformations
|
||||
#[cfg(feature = "std")]
|
||||
remapper: KeyRemapper,
|
||||
/// Writer for saving
|
||||
writer: Option<BurnpackWriter>,
|
||||
/// Reader for loading
|
||||
reader: Option<BurnpackReader>,
|
||||
/// Cached tensor snapshots (parsed once, reused)
|
||||
snapshots_cache: Option<BTreeMap<String, TensorSnapshot>>,
|
||||
}
|
||||
|
||||
impl BurnpackStore {
|
||||
/// Get the default metadata that includes Burn framework information.
|
||||
///
|
||||
/// This includes:
|
||||
/// - `format`: "burnpack"
|
||||
/// - `producer`: "burn"
|
||||
/// - `version`: The version of burn-store crate (from CARGO_PKG_VERSION)
|
||||
///
|
||||
/// These metadata fields are automatically added to all saved models.
|
||||
pub fn default_metadata() -> BTreeMap<String, String> {
|
||||
let mut metadata = BTreeMap::new();
|
||||
metadata.insert("format".into(), "burnpack".into());
|
||||
metadata.insert("producer".into(), "burn".into());
|
||||
metadata.insert("version".into(), env!("CARGO_PKG_VERSION").into());
|
||||
metadata
|
||||
}
|
||||
/// Create a new store from a file path
|
||||
///
|
||||
/// By default, automatically appends `.bpk` extension if the path doesn't have one.
|
||||
/// Use `.auto_extension(false)` to disable this behavior.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use burn_store::BurnpackStore;
|
||||
/// // Automatically appends .bpk
|
||||
/// let store = BurnpackStore::from_file("model"); // creates "model.bpk"
|
||||
///
|
||||
/// // Already has extension, no append
|
||||
/// let store = BurnpackStore::from_file("model.bpk"); // uses "model.bpk"
|
||||
/// let store = BurnpackStore::from_file("model.myext"); // uses "model.myext"
|
||||
///
|
||||
/// // Disable auto-extension
|
||||
/// let store = BurnpackStore::from_file("model").auto_extension(false); // uses "model"
|
||||
/// ```
|
||||
#[cfg(feature = "std")]
|
||||
pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> Self {
|
||||
Self {
|
||||
mode: StoreMode::File(path.as_ref().to_path_buf()),
|
||||
filter: None,
|
||||
metadata: Self::default_metadata(),
|
||||
allow_partial: false,
|
||||
validate: true,
|
||||
overwrite: false,
|
||||
zero_copy: false,
|
||||
#[cfg(feature = "std")]
|
||||
auto_extension: true,
|
||||
#[cfg(feature = "std")]
|
||||
remapper: KeyRemapper::new(),
|
||||
writer: None,
|
||||
reader: None,
|
||||
snapshots_cache: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new store from bytes (for reading) or empty (for writing)
|
||||
pub fn from_bytes(bytes: Option<Bytes>) -> Self {
|
||||
Self {
|
||||
mode: StoreMode::Bytes(bytes),
|
||||
filter: None,
|
||||
metadata: Self::default_metadata(),
|
||||
allow_partial: false,
|
||||
validate: true,
|
||||
overwrite: false,
|
||||
zero_copy: false,
|
||||
#[cfg(feature = "std")]
|
||||
auto_extension: false, // Not used for bytes mode
|
||||
#[cfg(feature = "std")]
|
||||
remapper: KeyRemapper::new(),
|
||||
writer: None,
|
||||
reader: None,
|
||||
snapshots_cache: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new store from static bytes with zero-copy loading enabled.
|
||||
///
|
||||
/// This is optimized for embedded model weights where the data lives in the
|
||||
/// binary's `.rodata` section. Tensor data is sliced without copying, keeping
|
||||
/// the static reference alive.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// static MODEL_DATA: &[u8] = include_bytes!("model.bpk");
|
||||
/// let store = BurnpackStore::from_static(MODEL_DATA);
|
||||
/// ```
|
||||
pub fn from_static(data: &'static [u8]) -> Self {
|
||||
use burn_tensor::AllocationProperty;
|
||||
|
||||
// Create bytes::Bytes from static data (zero-copy, stays in .rodata)
|
||||
let shared = bytes::Bytes::from_static(data);
|
||||
|
||||
// Wrap in cubecl Bytes with shared-bytes allocation controller
|
||||
let bytes = Bytes::from_shared(shared, AllocationProperty::Other);
|
||||
|
||||
Self {
|
||||
mode: StoreMode::Bytes(Some(bytes)),
|
||||
filter: None,
|
||||
metadata: Self::default_metadata(),
|
||||
allow_partial: false,
|
||||
validate: true,
|
||||
overwrite: false,
|
||||
zero_copy: true, // Enable zero-copy by default for static data
|
||||
#[cfg(feature = "std")]
|
||||
auto_extension: false,
|
||||
#[cfg(feature = "std")]
|
||||
remapper: KeyRemapper::new(),
|
||||
writer: None,
|
||||
reader: None,
|
||||
snapshots_cache: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add metadata key-value pair
|
||||
pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
|
||||
self.metadata.insert(key.into(), value.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Clear all metadata (including defaults)
|
||||
///
|
||||
/// This removes all metadata including the default format, producer, and version fields.
|
||||
/// Use with caution as some tools may expect these fields to be present.
|
||||
pub fn clear_metadata(mut self) -> Self {
|
||||
self.metadata.clear();
|
||||
self
|
||||
}
|
||||
|
||||
/// Allow partial loading (ignore missing tensors)
|
||||
///
|
||||
/// When set to `true`, the store will not fail if some tensors are missing
|
||||
/// during loading. This is useful when loading a subset of a model's parameters.
|
||||
///
|
||||
/// Default: `false`
|
||||
pub fn allow_partial(mut self, allow: bool) -> Self {
|
||||
self.allow_partial = allow;
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable or disable validation during loading
|
||||
///
|
||||
/// When validation is enabled, the store will check that loaded tensors
|
||||
/// match the expected shapes and data types. Disabling validation can
|
||||
/// improve performance but may lead to runtime errors if data is corrupted.
|
||||
///
|
||||
/// Default: `true`
|
||||
pub fn validate(mut self, validate: bool) -> Self {
|
||||
self.validate = validate;
|
||||
self
|
||||
}
|
||||
|
||||
/// Allow overwriting existing files when saving
|
||||
///
|
||||
/// When set to `false`, attempting to save to an existing file will result in an error.
|
||||
/// When set to `true`, existing files will be overwritten without warning.
|
||||
///
|
||||
/// Default: `false`
|
||||
pub fn overwrite(mut self, overwrite: bool) -> Self {
|
||||
self.overwrite = overwrite;
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable or disable zero-copy tensor loading.
|
||||
///
|
||||
/// When enabled and the backend supports it (memory-backed with shared bytes),
|
||||
/// tensor data is sliced from the source without copying. This keeps the source
|
||||
/// data alive as long as any tensor holds a reference.
|
||||
///
|
||||
/// Zero-copy is automatically enabled when using [`from_static`](Self::from_static).
|
||||
/// Use this method to enable it for other memory-backed stores created with
|
||||
/// [`from_bytes`](Self::from_bytes) when using `Bytes::from_shared()`.
|
||||
///
|
||||
/// Default: `false` (except for `from_static` which defaults to `true`)
|
||||
pub fn zero_copy(mut self, enable: bool) -> Self {
|
||||
self.zero_copy = enable;
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable or disable automatic .bpk extension appending
|
||||
///
|
||||
/// When enabled (default), automatically appends `.bpk` to the file path
|
||||
/// if no extension is detected. If an extension is already present, it is preserved.
|
||||
///
|
||||
/// When disabled, uses the exact path provided without modification.
|
||||
///
|
||||
/// Default: `true`
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use burn_store::BurnpackStore;
|
||||
/// // With auto_extension enabled (default)
|
||||
/// let store = BurnpackStore::from_file("model"); // -> "model.bpk"
|
||||
///
|
||||
/// // With auto_extension disabled
|
||||
/// let store = BurnpackStore::from_file("model")
|
||||
/// .auto_extension(false); // -> "model"
|
||||
/// ```
|
||||
#[cfg(feature = "std")]
|
||||
pub fn auto_extension(mut self, enable: bool) -> Self {
|
||||
self.auto_extension = enable;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set path filter for selective loading/saving
|
||||
pub fn with_filter(mut self, filter: PathFilter) -> Self {
|
||||
self.filter = Some(filter);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add regex pattern to filter
|
||||
#[cfg(feature = "std")]
|
||||
pub fn with_regex(mut self, pattern: &str) -> Self {
|
||||
let filter = self.filter.unwrap_or_default();
|
||||
self.filter = Some(filter.with_regex(pattern));
|
||||
self
|
||||
}
|
||||
|
||||
/// Add exact path to filter
|
||||
pub fn with_full_path(mut self, path: impl Into<String>) -> Self {
|
||||
let filter = self.filter.unwrap_or_default();
|
||||
self.filter = Some(filter.with_full_path(path));
|
||||
self
|
||||
}
|
||||
|
||||
/// Match all tensors (no filtering)
|
||||
pub fn match_all(mut self) -> Self {
|
||||
self.filter = Some(PathFilter::new().match_all());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set key remapper for tensor name transformations during loading
|
||||
#[cfg(feature = "std")]
|
||||
pub fn remap(mut self, remapper: KeyRemapper) -> Self {
|
||||
self.remapper = remapper;
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a single regex pattern for key remapping
|
||||
#[cfg(feature = "std")]
|
||||
pub fn with_remap_pattern<S1, S2>(mut self, from: S1, to: S2) -> Self
|
||||
where
|
||||
S1: AsRef<str>,
|
||||
S2: Into<String>,
|
||||
{
|
||||
self.remapper = self
|
||||
.remapper
|
||||
.add_pattern(from.as_ref(), to.into())
|
||||
.expect("Invalid regex pattern");
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the path filter
|
||||
pub fn filter(mut self, filter: PathFilter) -> Self {
|
||||
self.filter = Some(filter);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the bytes after writing (only valid for bytes mode after collecting)
|
||||
pub fn get_bytes(&self) -> Result<Bytes, BurnpackError> {
|
||||
if let Some(writer) = &self.writer {
|
||||
return writer.to_bytes();
|
||||
}
|
||||
|
||||
match &self.mode {
|
||||
StoreMode::Bytes(Some(bytes)) => Ok(bytes.clone()),
|
||||
_ => Err(BurnpackError::IoError("No bytes available".into())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Process the file path with auto-extension logic
|
||||
#[cfg(feature = "std")]
|
||||
fn process_path(&self, path: &std::path::Path) -> PathBuf {
|
||||
if !self.auto_extension {
|
||||
return path.to_path_buf();
|
||||
}
|
||||
|
||||
// Check if path already has an extension
|
||||
if path.extension().is_some() {
|
||||
// Has extension, use as-is
|
||||
return path.to_path_buf();
|
||||
}
|
||||
|
||||
// No extension, append .bpk
|
||||
let mut new_path = path.to_path_buf();
|
||||
new_path.set_extension("bpk");
|
||||
new_path
|
||||
}
|
||||
|
||||
/// Ensure the reader is initialized, loading from storage if needed
|
||||
fn ensure_reader(&mut self) -> Result<&BurnpackReader, BurnpackError> {
|
||||
if self.reader.is_none() {
|
||||
let reader = match &self.mode {
|
||||
#[cfg(feature = "std")]
|
||||
StoreMode::File(path) => {
|
||||
let final_path = self.process_path(path);
|
||||
BurnpackReader::from_file(&final_path)?
|
||||
}
|
||||
StoreMode::Bytes(Some(bytes)) => BurnpackReader::from_bytes(bytes.clone())?,
|
||||
StoreMode::Bytes(None) => {
|
||||
return Err(BurnpackError::IoError("No bytes to read from".into()));
|
||||
}
|
||||
};
|
||||
self.reader = Some(reader);
|
||||
}
|
||||
|
||||
self.reader
|
||||
.as_ref()
|
||||
.ok_or_else(|| BurnpackError::IoError("Reader not initialized".into()))
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleStore for BurnpackStore {
|
||||
type Error = BurnpackError;
|
||||
|
||||
fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
|
||||
&mut self,
|
||||
module: &M,
|
||||
) -> Result<(), Self::Error> {
|
||||
// Invalidate cache since we're writing new data
|
||||
self.snapshots_cache = None;
|
||||
self.reader = None;
|
||||
|
||||
// Collect snapshots from module
|
||||
let snapshots = module.collect(self.filter.clone(), None, false);
|
||||
|
||||
// Initialize writer with snapshots
|
||||
let mut writer = BurnpackWriter::new(snapshots);
|
||||
|
||||
// Add metadata using builder pattern
|
||||
for (key, value) in &self.metadata {
|
||||
writer = writer.with_metadata(key.as_str(), value.as_str());
|
||||
}
|
||||
|
||||
// Store the writer for finalization
|
||||
self.writer = Some(writer);
|
||||
|
||||
// Write to storage based on mode
|
||||
if let Some(writer) = &self.writer {
|
||||
match &self.mode {
|
||||
#[cfg(feature = "std")]
|
||||
StoreMode::File(path) => {
|
||||
// Process path with auto-extension logic
|
||||
let final_path = self.process_path(path);
|
||||
|
||||
// Check if file exists and overwrite is disabled
|
||||
if final_path.exists() && !self.overwrite {
|
||||
return Err(BurnpackError::IoError(format!(
|
||||
"File already exists: {}. Use .overwrite(true) to overwrite.",
|
||||
final_path.display()
|
||||
)));
|
||||
}
|
||||
writer.write_to_file(&final_path)?;
|
||||
}
|
||||
StoreMode::Bytes(_) => {
|
||||
// Generate and store the bytes
|
||||
let bytes_data = writer.to_bytes()?;
|
||||
// Update mode with bytes - this pattern is irrefutable in no-std mode
|
||||
#[cfg_attr(not(feature = "std"), allow(irrefutable_let_patterns))]
|
||||
let StoreMode::Bytes(bytes_ref) = &mut self.mode else {
|
||||
unreachable!("We just matched Bytes variant");
|
||||
};
|
||||
*bytes_ref = Some(bytes_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
|
||||
&mut self,
|
||||
module: &mut M,
|
||||
) -> Result<crate::ApplyResult, Self::Error> {
|
||||
// Get all snapshots using the cached method
|
||||
let snapshots: Vec<TensorSnapshot> = self.get_all_snapshots()?.values().cloned().collect();
|
||||
|
||||
// Apply all snapshots at once to the module
|
||||
// Burnpack is Burn's native format, so no enum variant skipping needed
|
||||
// Filter is applied here during apply, not during cache population
|
||||
let result = module.apply(snapshots, self.filter.clone(), None, false);
|
||||
|
||||
// Validate if needed
|
||||
if self.validate && !result.errors.is_empty() {
|
||||
return Err(BurnpackError::ValidationError(format!(
|
||||
"Import errors: {:?}",
|
||||
result.errors
|
||||
)));
|
||||
}
|
||||
|
||||
// Check for missing tensors if partial loading is not allowed
|
||||
if !self.allow_partial && !result.missing.is_empty() {
|
||||
return Err(BurnpackError::ValidationError(format!(
|
||||
"Missing tensors: {:?}",
|
||||
result.missing
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error> {
|
||||
// Ensure cache is populated
|
||||
self.ensure_snapshots_cache()?;
|
||||
Ok(self.snapshots_cache.as_ref().unwrap().get(name))
|
||||
}
|
||||
|
||||
fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error> {
|
||||
// Ensure cache is populated
|
||||
self.ensure_snapshots_cache()?;
|
||||
Ok(self.snapshots_cache.as_ref().unwrap())
|
||||
}
|
||||
|
||||
fn keys(&mut self) -> Result<Vec<String>, Self::Error> {
|
||||
// Always use the cache to ensure remapping is applied consistently
|
||||
Ok(self.get_all_snapshots()?.keys().cloned().collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl BurnpackStore {
|
||||
/// Ensure the snapshots cache is populated
|
||||
fn ensure_snapshots_cache(&mut self) -> Result<(), BurnpackError> {
|
||||
if self.snapshots_cache.is_some() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Ensure reader is loaded
|
||||
self.ensure_reader()?;
|
||||
|
||||
// Get snapshots from reader with zero-copy if enabled
|
||||
let reader = self.reader.as_ref().unwrap();
|
||||
let snapshots = reader.get_snapshots_zero_copy(self.zero_copy)?;
|
||||
|
||||
// Apply remapping if configured (but NOT filtering - that's done at apply time)
|
||||
#[cfg(feature = "std")]
|
||||
let snapshots = if !self.remapper.patterns.is_empty() {
|
||||
let (remapped, _remapped_names) = self.remapper.remap(snapshots);
|
||||
remapped
|
||||
} else {
|
||||
snapshots
|
||||
};
|
||||
|
||||
// Build the cache as BTreeMap
|
||||
let cache: BTreeMap<String, TensorSnapshot> =
|
||||
snapshots.into_iter().map(|s| (s.full_path(), s)).collect();
|
||||
|
||||
self.snapshots_cache = Some(cache);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,434 @@
|
||||
//! Tests for tensor data alignment in burnpack format.
|
||||
//!
|
||||
//! These tests verify that tensor data is properly aligned for mmap zero-copy access.
|
||||
|
||||
use crate::TensorSnapshot;
|
||||
use crate::burnpack::{
|
||||
base::{
|
||||
BurnpackHeader, BurnpackMetadata, HEADER_SIZE, TENSOR_ALIGNMENT, aligned_data_section_start,
|
||||
},
|
||||
reader::BurnpackReader,
|
||||
writer::BurnpackWriter,
|
||||
};
|
||||
use burn_core::module::ParamId;
|
||||
use burn_tensor::{DType, TensorData};
|
||||
|
||||
/// `aligned_data_section_start` must return a multiple of `TENSOR_ALIGNMENT`
/// for every metadata size in a representative range.
#[test]
fn test_aligned_data_section_start_is_always_aligned() {
    let alignment = TENSOR_ALIGNMENT as usize;

    (0usize..1024).for_each(|metadata_size| {
        let start = aligned_data_section_start(metadata_size);
        assert_eq!(
            start % alignment,
            0,
            "aligned_data_section_start({}) = {} is not aligned to {}",
            metadata_size,
            start,
            TENSOR_ALIGNMENT
        );
    });
}
|
||||
|
||||
/// The data section of a written file must begin on an aligned offset, and the
/// file must be at least long enough to reach it.
#[test]
fn test_data_section_alignment() {
    // Serialize four f32 values as little-endian bytes.
    let values = [1.0f32, 2.0, 3.0, 4.0];
    let mut raw: Vec<u8> = Vec::new();
    for value in values.iter() {
        raw.extend_from_slice(&value.to_le_bytes());
    }

    let tensor = TensorData::from_bytes_vec(raw, vec![4], DType::F32);
    let snapshot =
        TensorSnapshot::from_data(tensor, vec!["tensor".to_string()], vec![], ParamId::new());

    let file_bytes = BurnpackWriter::new(vec![snapshot]).to_bytes().unwrap();

    // The header tells us how large the metadata block is.
    let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();
    let data_section_start = aligned_data_section_start(header.metadata_size as usize);

    let alignment = TENSOR_ALIGNMENT as usize;
    assert_eq!(
        data_section_start % alignment,
        0,
        "Data section start {} is not 256-byte aligned",
        data_section_start
    );

    let file_len = file_bytes.len();
    assert!(
        file_len >= data_section_start,
        "File too small: {} < {}",
        file_bytes.len(),
        data_section_start
    );
}
|
||||
|
||||
/// The first tensor's absolute file offset (data section start plus its
/// relative offset) must be a multiple of `TENSOR_ALIGNMENT`.
#[test]
fn test_first_tensor_absolute_position_aligned() {
    let payload: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
    let tensor = TensorData::from_bytes_vec(payload, vec![8], DType::U8);
    let snapshot =
        TensorSnapshot::from_data(tensor, vec!["first".to_string()], vec![], ParamId::new());

    let file_bytes = BurnpackWriter::new(vec![snapshot]).to_bytes().unwrap();

    // Decode the header, then the CBOR metadata that follows it.
    let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();
    let metadata_size = header.metadata_size as usize;
    let metadata_end = HEADER_SIZE + metadata_size;
    let metadata: BurnpackMetadata =
        ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap();

    let descriptor = metadata.tensors.get("first").unwrap();
    let data_section_start = aligned_data_section_start(metadata_size);

    // Relative offset -> absolute position in the file.
    let absolute_pos = data_section_start + descriptor.data_offsets.0 as usize;
    let alignment = TENSOR_ALIGNMENT as usize;

    assert_eq!(
        absolute_pos % alignment,
        0,
        "First tensor absolute position {} is not 256-byte aligned",
        absolute_pos
    );
}
|
||||
|
||||
/// In a file holding several tensors of different sizes, every tensor's
/// absolute file offset must be a multiple of `TENSOR_ALIGNMENT`.
#[test]
fn test_all_tensors_absolute_positions_aligned() {
    // Payloads of assorted lengths; all U8 so the shape is just the byte count.
    let payloads = vec![
        ("tensor_a", vec![1u8, 2, 3]), // 3 bytes
        ("tensor_b", vec![0u8; 16]),   // 16 bytes
        ("tensor_c", vec![0u8; 64]),   // 64 bytes
        ("tensor_d", vec![42u8]),      // 1 byte
        ("tensor_e", vec![0u8; 400]),  // 400 bytes
    ];

    let mut snapshots: Vec<TensorSnapshot> = Vec::new();
    for (name, bytes) in payloads {
        let byte_count = bytes.len();
        let tensor = TensorData::from_bytes_vec(bytes, vec![byte_count], DType::U8);
        snapshots.push(TensorSnapshot::from_data(
            tensor,
            vec![name.to_string()],
            vec![],
            ParamId::new(),
        ));
    }

    let file_bytes = BurnpackWriter::new(snapshots).to_bytes().unwrap();

    // Decode the header, then the CBOR metadata that follows it.
    let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();
    let metadata_size = header.metadata_size as usize;
    let metadata_end = HEADER_SIZE + metadata_size;
    let metadata: BurnpackMetadata =
        ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap();

    let data_section_start = aligned_data_section_start(metadata_size);
    let alignment = TENSOR_ALIGNMENT as usize;

    metadata.tensors.iter().for_each(|(name, desc)| {
        let absolute_pos = data_section_start + desc.data_offsets.0 as usize;
        assert_eq!(
            absolute_pos % alignment,
            0,
            "Tensor '{}' at absolute position {} is not 256-byte aligned (offset in data section: {})",
            name,
            absolute_pos,
            desc.data_offsets.0
        );
    });
}
|
||||
|
||||
/// Boundary behavior of `aligned_data_section_start`: no padding when header
/// plus metadata already ends on an alignment boundary, and a jump to the next
/// boundary as soon as it is exceeded by one byte.
#[test]
fn test_alignment_with_minimal_padding() {
    // Metadata size cannot be controlled through the writer, so exercise the
    // arithmetic directly.
    let alignment = TENSOR_ALIGNMENT as usize;

    // Metadata size for which HEADER_SIZE + metadata_size == TENSOR_ALIGNMENT.
    let exact_fit = alignment - HEADER_SIZE;

    // Already aligned: the data section starts exactly at the first boundary.
    let at_boundary = aligned_data_section_start(exact_fit);
    assert_eq!(at_boundary, alignment);

    // One byte more no longer fits, so the start moves to the next boundary
    // (twice the alignment).
    let past_boundary = aligned_data_section_start(exact_fit + 1);
    assert_eq!(past_boundary, 2 * alignment);
}
|
||||
|
||||
/// Verify that the padding between the metadata and the data section is zero-filled.
#[test]
fn test_padding_bytes_are_zeros() {
    // Non-zero payload (0xAA) so tensor bytes cannot be mistaken for padding.
    // The vector is moved straight into the tensor; no copy is needed.
    let snapshot = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(vec![0xAA; 16], vec![16], DType::U8),
        vec!["tensor".to_string()],
        vec![],
        ParamId::new(),
    );

    let writer = BurnpackWriter::new(vec![snapshot]);
    let file_bytes = writer.to_bytes().unwrap();

    let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();
    let metadata_end = HEADER_SIZE + header.metadata_size as usize;
    let data_section_start = aligned_data_section_start(header.metadata_size as usize);

    // Padding only exists when the metadata does not end on an aligned boundary.
    if data_section_start > metadata_end {
        let padding = &file_bytes[metadata_end..data_section_start];
        assert!(
            padding.iter().all(|&b| b == 0),
            "Padding bytes between metadata and data section contain non-zero values"
        );
    }
}
|
||||
|
||||
/// Verify that `TENSOR_ALIGNMENT` is a multiple of every primitive alignment:
/// - f64/i64/u64: 8 bytes
/// - f32/i32/u32: 4 bytes
/// - f16/bf16/i16/u16: 2 bytes
/// - i8/u8/bool: 1 byte
#[test]
#[allow(clippy::modulo_one)]
fn test_alignment_covers_all_primitive_types() {
    // (primitive alignment, message reported when it does not divide the tensor alignment)
    let required_alignments = [
        (8, "256 not divisible by 8 (f64 alignment)"),
        (4, "256 not divisible by 4 (f32 alignment)"),
        (2, "256 not divisible by 2 (f16 alignment)"),
        (1, "256 not divisible by 1 (u8 alignment)"),
    ];

    for (divisor, message) in required_alignments.iter() {
        assert_eq!(TENSOR_ALIGNMENT % *divisor, 0, "{}", message);
    }
}
|
||||
|
||||
/// Writes one f32 tensor, locates it through the header and metadata, and
/// checks that the stored bytes decode back to the original values.
#[test]
fn test_aligned_tensor_data_readable() {
    let f32_data = vec![1.0f32, 2.0, 3.0, 4.0];
    let mut f32_bytes: Vec<u8> = Vec::with_capacity(f32_data.len() * 4);
    for value in &f32_data {
        f32_bytes.extend_from_slice(&value.to_le_bytes());
    }

    let snapshot = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(f32_bytes.clone(), vec![4], DType::F32),
        vec!["floats".to_string()],
        vec![],
        ParamId::new(),
    );

    let writer = BurnpackWriter::new(vec![snapshot]);
    let file_bytes = writer.to_bytes().unwrap();

    let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();
    let metadata_len = header.metadata_size as usize;
    let metadata: BurnpackMetadata =
        ciborium::de::from_reader(&file_bytes[HEADER_SIZE..HEADER_SIZE + metadata_len]).unwrap();

    // Descriptor offsets are relative to the aligned start of the data section.
    let tensor_desc = metadata.tensors.get("floats").unwrap();
    let data_section_start = aligned_data_section_start(metadata_len);
    let start = data_section_start + tensor_desc.data_offsets.0 as usize;
    let end = data_section_start + tensor_desc.data_offsets.1 as usize;
    let tensor_bytes = &file_bytes[start..end];

    // The raw bytes must be exactly what was written.
    assert_eq!(tensor_bytes, f32_bytes.as_slice());

    // They must also decode as little-endian f32 values.
    let floats: Vec<f32> = tensor_bytes
        .chunks_exact(4)
        .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect();
    assert_eq!(floats, f32_data);
}
|
||||
|
||||
/// Same readability check as the f32 test, using 8-byte f64 elements.
#[test]
fn test_aligned_f64_tensor_data_readable() {
    let f64_data = vec![1.0f64, 2.0, 3.0, 4.0];
    let mut f64_bytes: Vec<u8> = Vec::with_capacity(f64_data.len() * 8);
    for value in &f64_data {
        f64_bytes.extend_from_slice(&value.to_le_bytes());
    }

    let snapshot = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(f64_bytes.clone(), vec![4], DType::F64),
        vec!["doubles".to_string()],
        vec![],
        ParamId::new(),
    );

    let writer = BurnpackWriter::new(vec![snapshot]);
    let file_bytes = writer.to_bytes().unwrap();

    let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();
    let metadata_len = header.metadata_size as usize;
    let metadata: BurnpackMetadata =
        ciborium::de::from_reader(&file_bytes[HEADER_SIZE..HEADER_SIZE + metadata_len]).unwrap();

    // Descriptor offsets are relative to the aligned start of the data section.
    let tensor_desc = metadata.tensors.get("doubles").unwrap();
    let data_section_start = aligned_data_section_start(metadata_len);
    let start = data_section_start + tensor_desc.data_offsets.0 as usize;
    let end = data_section_start + tensor_desc.data_offsets.1 as usize;
    let tensor_bytes = &file_bytes[start..end];

    // The raw bytes must be exactly what was written.
    assert_eq!(tensor_bytes, f64_bytes.as_slice());

    // They must also decode as little-endian f64 values.
    let doubles: Vec<f64> = tensor_bytes
        .chunks_exact(8)
        .map(|chunk| {
            let mut word = [0u8; 8];
            word.copy_from_slice(chunk);
            f64::from_le_bytes(word)
        })
        .collect();
    assert_eq!(doubles, f64_data);
}
|
||||
|
||||
/// Round-trips a 2x4 f32 tensor through `BurnpackWriter` and `BurnpackReader`
/// and checks that the values read back equal the values written.
///
/// Alignment itself is not inspected here; this only confirms that tensor
/// data survives a write followed by a read.
#[test]
fn test_round_trip_maintains_alignment() {
    let f32_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let f32_bytes: Vec<u8> = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect();

    let snapshot = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(f32_bytes, vec![2, 4], DType::F32),
        vec!["matrix".to_string()],
        vec![],
        ParamId::new(),
    );

    // Write
    let writer = BurnpackWriter::new(vec![snapshot]);
    let file_bytes = writer.to_bytes().unwrap();

    // Read back. The buffer is not needed afterwards, so it is moved into the
    // reader rather than cloned.
    let reader = BurnpackReader::from_bytes(file_bytes).unwrap();
    let snapshots = reader.get_snapshots().unwrap();

    assert_eq!(snapshots.len(), 1);
    let loaded = &snapshots[0];
    assert_eq!(loaded.full_path(), "matrix");

    // Decode the loaded bytes as little-endian f32 and compare with the input.
    let tensor_data = loaded.to_data().unwrap();
    let loaded_floats: Vec<f32> = tensor_data
        .bytes
        .chunks_exact(4)
        .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect();
    assert_eq!(loaded_floats, f32_data);
}
|
||||
|
||||
/// Tensor start offsets recorded in the metadata (relative to the data
/// section) must themselves be multiples of `TENSOR_ALIGNMENT`.
#[test]
fn test_tensor_relative_offsets_are_aligned() {
    // Five 7-byte tensors: 7 is not a multiple of the alignment, so padding
    // is required between consecutive tensors.
    let mut tensors = Vec::new();
    for i in 0..5 {
        let data = vec![i as u8; 7];
        tensors.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(data, vec![7], DType::U8),
            vec![format!("tensor_{}", i)],
            vec![],
            ParamId::new(),
        ));
    }

    let writer = BurnpackWriter::new(tensors);
    let file_bytes = writer.to_bytes().unwrap();

    let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();
    let metadata_end = HEADER_SIZE + header.metadata_size as usize;
    let metadata: BurnpackMetadata =
        ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap();

    for (name, desc) in &metadata.tensors {
        let relative_start = desc.data_offsets.0;
        assert_eq!(
            relative_start % TENSOR_ALIGNMENT,
            0,
            "Tensor '{}' relative offset {} is not 256-byte aligned",
            name,
            relative_start
        );
    }
}
|
||||
|
||||
#[cfg(feature = "std")]
mod file_tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Writes a burnpack file to a temporary directory, reads the raw bytes
    /// back, and checks both the tensor's alignment and its contents.
    #[test]
    fn test_file_io_preserves_alignment() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("aligned.bpk");

        let f32_data = [1.0f32, 2.0, 3.0, 4.0];
        let mut f32_bytes: Vec<u8> = Vec::with_capacity(f32_data.len() * 4);
        for value in &f32_data {
            f32_bytes.extend_from_slice(&value.to_le_bytes());
        }

        let snapshot = TensorSnapshot::from_data(
            TensorData::from_bytes_vec(f32_bytes, vec![4], DType::F32),
            vec!["floats".to_string()],
            vec![],
            ParamId::new(),
        );

        // Persist through the file API, then inspect the bytes on disk directly.
        let writer = BurnpackWriter::new(vec![snapshot]);
        writer.write_to_file(&file_path).unwrap();
        let file_bytes = fs::read(&file_path).unwrap();

        let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();
        let metadata_len = header.metadata_size as usize;
        let metadata: BurnpackMetadata =
            ciborium::de::from_reader(&file_bytes[HEADER_SIZE..HEADER_SIZE + metadata_len])
                .unwrap();

        let tensor_desc = metadata.tensors.get("floats").unwrap();
        let data_section_start = aligned_data_section_start(metadata_len);
        let absolute_pos = data_section_start + tensor_desc.data_offsets.0 as usize;

        assert_eq!(
            absolute_pos % TENSOR_ALIGNMENT as usize,
            0,
            "Tensor absolute position in file {} is not 256-byte aligned",
            absolute_pos
        );

        // The tensor payload spans [start, end) within the file.
        let start = data_section_start + tensor_desc.data_offsets.0 as usize;
        let end = data_section_start + tensor_desc.data_offsets.1 as usize;
        let tensor_bytes = &file_bytes[start..end];

        let floats: Vec<f32> = tensor_bytes
            .chunks_exact(4)
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect();
        assert_eq!(floats, vec![1.0f32, 2.0, 3.0, 4.0]);
    }
}
|
||||
@@ -0,0 +1,365 @@
|
||||
use crate::TensorSnapshot;
|
||||
use crate::burnpack::{
|
||||
base::{BurnpackHeader, HEADER_SIZE},
|
||||
reader::BurnpackReader,
|
||||
writer::BurnpackWriter,
|
||||
};
|
||||
use burn_core::module::ParamId;
|
||||
use burn_tensor::{DType, TensorData};
|
||||
|
||||
/// Writes a file whose metadata section is large (over 1 MB) and checks the
/// size recorded in the header.
#[test]
fn test_maximum_metadata_size() {
    // The metadata size field is a u32; rather than approaching that limit,
    // this uses 100 entries with ~1 KB keys and 10 KB values.
    let large_key = "x".repeat(1000);
    let large_value = "y".repeat(10000);

    let writer = (0..100).fold(BurnpackWriter::new(vec![]), |acc, i| {
        acc.with_metadata(&format!("{}_{}", large_key, i), &large_value)
    });

    let result = writer.to_bytes();
    assert!(result.is_ok());

    let bytes = result.unwrap();
    let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();

    // Large, but still representable in the header's u32 field.
    assert!(header.metadata_size > 1_000_000); // At least 1MB of metadata
    assert!(header.metadata_size < u32::MAX);
}
|
||||
|
||||
/// Tensors whose shape contains a zero dimension (and therefore carry no
/// bytes) must be writable and must all be listed by the reader.
#[test]
fn test_zero_size_tensor_shapes() {
    // (shape, data) pairs; every shape has at least one zero dimension.
    let test_cases = [
        (vec![0], vec![]),        // Empty 1D
        (vec![0, 10], vec![]),    // Zero rows
        (vec![10, 0], vec![]),    // Zero columns
        (vec![0, 0], vec![]),     // Zero both dimensions
        (vec![5, 0, 10], vec![]), // Zero in middle dimension
    ];

    let mut snapshots = vec![];
    for (i, (shape, data)) in test_cases.iter().enumerate() {
        let snapshot = TensorSnapshot::from_data(
            TensorData::from_bytes_vec(data.clone(), shape.clone(), DType::F32),
            vec![format!("zero_tensor_{}", i)],
            vec![],
            ParamId::new(),
        );
        snapshots.push(snapshot);
    }

    let writer = BurnpackWriter::new(snapshots);
    let bytes = writer.to_bytes().unwrap();

    // Read back and verify that every case produced a tensor entry. The
    // expected count is taken from the table so the two cannot drift apart.
    let reader = BurnpackReader::from_bytes(bytes).unwrap();
    let names = reader.tensor_names();
    assert_eq!(names.len(), test_cases.len());
}
|
||||
|
||||
/// A 10,000-character tensor name must survive a write/read round trip
/// unchanged.
#[test]
fn test_extremely_long_tensor_names() {
    let long_name = "a".repeat(10000);

    let snapshot = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8),
        vec![long_name.clone()],
        vec![],
        ParamId::new(),
    );

    let writer = BurnpackWriter::new(vec![snapshot]);
    let bytes = writer.to_bytes().unwrap();

    let reader = BurnpackReader::from_bytes(bytes).unwrap();
    let names = reader.tensor_names();

    // Check the length first for a readable failure, then the full content so
    // that a same-length but corrupted name cannot pass.
    assert_eq!(names[0].len(), 10000);
    assert!(
        names[0] == long_name.as_str(),
        "long tensor name was not preserved exactly"
    );
}
|
||||
|
||||
/// Non-ASCII tensor names and metadata keys/values must survive a write/read
/// round trip unchanged.
#[test]
fn test_unicode_in_names_and_metadata() {
    let unicode_names = vec![
        "测试_tensor",    // Chinese
        "тест_tensor",    // Cyrillic
        "テスト_tensor",  // Japanese
        "🔥_burn_tensor", // Emoji
        "αβγδ_tensor",    // Greek
        "한글_tensor",    // Korean
    ];

    let mut snapshots = vec![];
    for name in &unicode_names {
        let snapshot = TensorSnapshot::from_data(
            TensorData::from_bytes_vec(vec![1], vec![1], DType::U8),
            vec![name.to_string()],
            vec![],
            ParamId::new(),
        );
        snapshots.push(snapshot);
    }

    let writer = BurnpackWriter::new(snapshots)
        .with_metadata("模型名称", "测试模型")
        .with_metadata("מודל", "בדיקה")
        .with_metadata("🔥", "fire");

    let bytes = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    // Verify all Unicode names are preserved: same count, and each original
    // name is present verbatim.
    let names = reader.tensor_names();
    assert_eq!(names.len(), unicode_names.len());
    for expected_name in &unicode_names {
        assert!(
            names.contains(expected_name),
            "tensor name {:?} was not preserved",
            expected_name
        );
    }

    // Verify every metadata entry that was written, including the
    // right-to-left (Hebrew) one.
    let metadata = &reader.metadata().metadata;
    assert_eq!(metadata.get("模型名称"), Some(&"测试模型".to_string()));
    assert_eq!(metadata.get("מודל"), Some(&"בדיקה".to_string()));
    assert_eq!(metadata.get("🔥"), Some(&"fire".to_string()));
}
|
||||
|
||||
/// Stores two-element tensors holding the minimum and maximum value of
/// several dtypes and checks that each dtype is reported unchanged on read.
#[test]
fn test_all_supported_dtypes() {
    // (dtype, little-endian bytes of [MIN, MAX]) for each tested dtype.
    let dtypes_with_data = [
        (
            DType::F32,
            [f32::MIN.to_le_bytes(), f32::MAX.to_le_bytes()].concat(),
        ),
        (
            DType::F64,
            [f64::MIN.to_le_bytes(), f64::MAX.to_le_bytes()].concat(),
        ),
        (
            DType::I32,
            [i32::MIN.to_le_bytes(), i32::MAX.to_le_bytes()].concat(),
        ),
        (
            DType::I64,
            [i64::MIN.to_le_bytes(), i64::MAX.to_le_bytes()].concat(),
        ),
        (
            DType::U32,
            [u32::MIN.to_le_bytes(), u32::MAX.to_le_bytes()].concat(),
        ),
        (
            DType::U64,
            [u64::MIN.to_le_bytes(), u64::MAX.to_le_bytes()].concat(),
        ),
        (DType::U8, vec![u8::MIN, u8::MAX]),
        (DType::Bool, vec![0, 1]),
    ];

    // One snapshot per table row, named after the row index.
    let snapshots: Vec<TensorSnapshot> = dtypes_with_data
        .iter()
        .enumerate()
        .map(|(i, (dtype, data))| {
            TensorSnapshot::from_data(
                TensorData::from_bytes_vec(data.clone(), vec![2], *dtype),
                vec![format!("dtype_test_{}", i)],
                vec![],
                ParamId::new(),
            )
        })
        .collect();

    let writer = BurnpackWriter::new(snapshots);
    let bytes = writer.to_bytes().unwrap();

    let reader = BurnpackReader::from_bytes(bytes).unwrap();
    assert_eq!(reader.tensor_names().len(), dtypes_with_data.len());

    // Each tensor must come back with the dtype it was written with.
    for (i, (expected_dtype, _)) in dtypes_with_data.iter().enumerate() {
        let tensor_name = format!("dtype_test_{}", i);
        let loaded = reader.get_tensor_snapshot(&tensor_name).unwrap();
        assert_eq!(loaded.dtype, *expected_dtype);
    }
}
|
||||
|
||||
/// NaN, the infinities and both signed zeros must round-trip bit-for-bit.
#[test]
fn test_special_float_values() {
    let special_values = [
        f32::NAN,
        f32::INFINITY,
        f32::NEG_INFINITY,
        0.0_f32,
        -0.0_f32,
    ];

    // Serialize as little-endian bytes; comparing bytes (not floats) keeps
    // NaN and the sign of zero distinguishable.
    let mut data: Vec<u8> = Vec::with_capacity(special_values.len() * 4);
    for value in &special_values {
        data.extend_from_slice(&value.to_le_bytes());
    }

    let snapshot = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(data.clone(), vec![5], DType::F32),
        vec!["special_floats".to_string()],
        vec![],
        ParamId::new(),
    );

    let writer = BurnpackWriter::new(vec![snapshot]);
    let bytes = writer.to_bytes().unwrap();

    let reader = BurnpackReader::from_bytes(bytes).unwrap();
    let tensor_data = reader.get_tensor_data("special_floats").unwrap();

    // Check data is preserved exactly (bit-for-bit)
    assert_eq!(tensor_data, data);
}
|
||||
|
||||
/// An empty metadata value and an empty metadata key are both stored and
/// returned as-is, alongside an ordinary entry.
#[test]
fn test_metadata_with_empty_values() {
    let writer = BurnpackWriter::new(vec![]);
    let writer = writer.with_metadata("empty_value", "");
    let writer = writer.with_metadata("", "empty_key");
    let writer = writer.with_metadata("normal", "value");

    let bytes = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    let entries = &reader.metadata().metadata;
    assert_eq!(entries.get("empty_value"), Some(&"".to_string()));
    assert_eq!(entries.get(""), Some(&"empty_key".to_string()));
    assert_eq!(entries.get("normal"), Some(&"value".to_string()));
}
|
||||
|
||||
/// Round-trips a one-element U8 tensor, i.e. a single byte of tensor data.
#[test]
fn test_single_byte_tensor() {
    let tensor = TensorData::from_bytes_vec(vec![42], vec![1], DType::U8);
    let snapshot = TensorSnapshot::from_data(
        tensor,
        vec!["single_byte".to_string()],
        vec![],
        ParamId::new(),
    );

    let bytes = BurnpackWriter::new(vec![snapshot]).to_bytes().unwrap();

    let reader = BurnpackReader::from_bytes(bytes).unwrap();
    let data = reader.get_tensor_data("single_byte").unwrap();
    assert_eq!(data, vec![42]);
}
|
||||
|
||||
/// A tensor with ten dimensions must keep its full shape across a
/// write/read round trip.
#[test]
fn test_high_dimensional_tensor() {
    // 10 dimensions of extent 2 -> 1024 elements, one byte each for U8.
    let shape = vec![2, 2, 2, 2, 2, 2, 2, 2, 2, 2];
    let data = vec![1u8; 1024];

    // `data` is not needed after this point, so it is moved instead of cloned;
    // `shape` is cloned because it is compared against the loaded shape below.
    let snapshot = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(data, shape.clone(), DType::U8),
        vec!["high_dim".to_string()],
        vec![],
        ParamId::new(),
    );

    let writer = BurnpackWriter::new(vec![snapshot]);
    let bytes = writer.to_bytes().unwrap();

    let reader = BurnpackReader::from_bytes(bytes).unwrap();
    let loaded_snapshot = reader.get_tensor_snapshot("high_dim").unwrap();
    assert_eq!(loaded_snapshot.shape, shape);
}
|
||||
|
||||
/// Setting the same metadata key repeatedly keeps only the last value.
#[test]
fn test_metadata_key_collision() {
    // Apply three values to the same key, in order.
    let mut writer = BurnpackWriter::new(vec![]);
    for value in ["value1", "value2", "value3"].iter() {
        writer = writer.with_metadata("key", value);
    }

    let bytes = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    let stored = reader.metadata().metadata.get("key");
    assert_eq!(stored, Some(&"value3".to_string()));
}
|
||||
|
||||
/// Tensor names containing `/`, `\`, `::` or `.` must be stored verbatim.
#[test]
fn test_tensor_name_with_path_separators() {
    let path_like_names = vec![
        "model/encoder/layer1/weights",
        "model\\decoder\\layer1\\bias",
        "model::module::param",
        "model.submodule.weight",
    ];

    // One small U8 tensor per name.
    let snapshots: Vec<TensorSnapshot> = path_like_names
        .iter()
        .map(|name| {
            TensorSnapshot::from_data(
                TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8),
                vec![name.to_string()],
                vec![],
                ParamId::new(),
            )
        })
        .collect();

    let bytes = BurnpackWriter::new(snapshots).to_bytes().unwrap();

    let reader = BurnpackReader::from_bytes(bytes).unwrap();
    let names = reader.tensor_names();

    // Every original name must appear unchanged in the reader's listing.
    for expected_name in &path_like_names {
        assert!(names.contains(expected_name));
    }
}
|
||||
|
||||
// The following tests are commented out as they test error conditions
|
||||
// that might be handled differently in the new API
|
||||
|
||||
// #[test]
|
||||
// fn test_data_overflow_protection() {
|
||||
// // Test that we handle potential integer overflows in offset calculations
|
||||
// ...
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_reading_corrupted_header() {
|
||||
// // Test reading files with corrupted headers
|
||||
// ...
|
||||
// }
|
||||
@@ -0,0 +1,61 @@
|
||||
use crate::burnpack::base::*;
|
||||
|
||||
/// A freshly built header carries the format constants and survives
/// serialization followed by deserialization.
#[test]
fn test_header_serialization() {
    let header = BurnpackHeader::new(12345);

    // Fields set by the constructor.
    assert_eq!(header.magic, MAGIC_NUMBER);
    assert_eq!(header.version, FORMAT_VERSION);
    assert_eq!(header.metadata_size, 12345);

    // The encoded form has the fixed header size.
    let encoded = header.into_bytes();
    assert_eq!(encoded.len(), HEADER_SIZE);

    // Decoding yields the same field values.
    let decoded = BurnpackHeader::from_bytes(&encoded).unwrap();
    assert_eq!(decoded.magic, header.magic);
    assert_eq!(decoded.version, header.version);
    assert_eq!(decoded.metadata_size, header.metadata_size);
}
|
||||
|
||||
/// A header whose magic bytes are all zero is rejected with
/// `BurnpackError::InvalidMagicNumber`.
#[test]
fn test_header_invalid_magic() {
    let mut bytes = [0u8; HEADER_SIZE];
    // Explicitly write a wrong (all-zero) magic number into the first 4 bytes.
    bytes[0..4].copy_from_slice(&[0x00; 4]);

    let result = BurnpackHeader::from_bytes(&bytes);
    assert!(
        matches!(result, Err(BurnpackError::InvalidMagicNumber)),
        "Expected InvalidMagicNumber error"
    );
}
|
||||
|
||||
/// Input shorter than a full header is rejected with
/// `BurnpackError::InvalidHeader`.
#[test]
fn test_header_insufficient_bytes() {
    let too_short = [0u8; 5]; // Too short

    let result = BurnpackHeader::from_bytes(&too_short);
    assert!(
        matches!(result, Err(BurnpackError::InvalidHeader)),
        "Expected InvalidHeader error"
    );
}
|
||||
|
||||
/// A header written with the current format version parses successfully and
/// reports that version.
#[test]
fn test_version_compatibility() {
    // Encode a header that carries the current version.
    let encoded = BurnpackHeader::new(100).into_bytes();

    // Parsing the current version must succeed.
    let parsed = BurnpackHeader::from_bytes(&encoded);
    assert!(parsed.is_ok());

    // Rejection of future versions is not exercised here; only the version
    // field of the parsed header is checked.
    let header = parsed.unwrap();
    assert_eq!(header.version, FORMAT_VERSION);
}
|
||||
@@ -0,0 +1,19 @@
|
||||
use crate::TensorSnapshot;
|
||||
use burn_core::module::ParamId;
|
||||
use burn_tensor::{DType, TensorData};
|
||||
|
||||
/// Helper to create a test TensorSnapshot
|
||||
#[allow(dead_code)]
|
||||
pub fn create_test_snapshot(
|
||||
name: String,
|
||||
data: Vec<u8>,
|
||||
shape: Vec<usize>,
|
||||
dtype: DType,
|
||||
) -> TensorSnapshot {
|
||||
TensorSnapshot::from_data(
|
||||
TensorData::from_bytes_vec(data, shape, dtype),
|
||||
vec![name],
|
||||
vec![],
|
||||
ParamId::new(),
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
use crate::TensorSnapshot;
|
||||
|
||||
mod alignment;
|
||||
mod edge_cases;
|
||||
mod header;
|
||||
mod helpers;
|
||||
mod reader;
|
||||
mod round_trip;
|
||||
mod store;
|
||||
mod writer;
|
||||
mod zero_copy;
|
||||
@@ -0,0 +1,775 @@
|
||||
use crate::burnpack::{
|
||||
base::{
|
||||
BurnpackError, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER, magic_range, metadata_size_range,
|
||||
version_range,
|
||||
},
|
||||
reader::BurnpackReader,
|
||||
writer::BurnpackWriter,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use burn_tensor::{Bytes, DType, TensorData};
|
||||
|
||||
/// A file written with no tensors and no metadata reads back as empty.
#[test]
fn test_reader_from_bytes_empty() {
    // Serialize an empty burnpack.
    let bytes = BurnpackWriter::new(Vec::new()).to_bytes().unwrap();

    // Parse it again.
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    let file_metadata = reader.metadata();
    assert_eq!(file_metadata.tensors.len(), 0);
    assert!(file_metadata.metadata.is_empty());
}
|
||||
|
||||
/// A file with one tensor and one metadata entry exposes both through the
/// reader.
#[test]
fn test_reader_from_bytes_with_data() {
    let payload = TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8);
    let snapshot = TensorSnapshot::from_data(
        payload,
        vec!["test_tensor".to_string()],
        vec![],
        burn_core::module::ParamId::new(),
    );

    let bytes = BurnpackWriter::new(vec![snapshot])
        .with_metadata("test", "value")
        .to_bytes()
        .unwrap();

    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    // One tensor descriptor and the user metadata entry are present.
    assert_eq!(reader.metadata().tensors.len(), 1);
    let stored_value = reader.metadata().metadata.get("test");
    assert_eq!(stored_value, Some(&"value".to_string()));

    // The tensor bytes are returned unchanged.
    let tensor_data = reader.get_tensor_data("test_tensor").unwrap();
    assert_eq!(tensor_data, &[1, 2, 3, 4]);
}
|
||||
|
||||
/// The reader rejects input whose magic bytes are wrong.
#[test]
fn test_reader_invalid_magic_number() {
    let mut raw = vec![0u8; 100];
    // Overwrite the magic field with an invalid value.
    raw[magic_range()].copy_from_slice(b"NOPE");

    let outcome = BurnpackReader::from_bytes(Bytes::from_bytes_vec(raw));
    assert!(matches!(outcome, Err(BurnpackError::InvalidMagicNumber)));
}
|
||||
|
||||
/// The reader rejects a header with a valid magic number but a version of 999.
#[test]
fn test_reader_invalid_version() {
    let bogus_version: u16 = 999;
    let claimed_metadata_size: u32 = 10;

    let mut raw = vec![0u8; 100];
    raw[magic_range()].copy_from_slice(&MAGIC_NUMBER.to_le_bytes());
    raw[version_range()].copy_from_slice(&bogus_version.to_le_bytes());
    raw[metadata_size_range()].copy_from_slice(&claimed_metadata_size.to_le_bytes());

    let outcome = BurnpackReader::from_bytes(Bytes::from_bytes_vec(raw));
    assert!(matches!(outcome, Err(BurnpackError::InvalidVersion)));
}
|
||||
|
||||
/// Input shorter than `HEADER_SIZE` is rejected as an invalid header.
#[test]
fn test_reader_header_too_short() {
    // Five bytes: fewer than a complete header.
    let truncated = Bytes::from_bytes_vec(vec![0u8; 5]);

    let outcome = BurnpackReader::from_bytes(truncated);
    assert!(matches!(outcome, Err(BurnpackError::InvalidHeader)));
}
|
||||
|
||||
/// A header that announces more metadata than the buffer contains is
/// rejected as an invalid header.
#[test]
fn test_reader_metadata_truncated() {
    let claimed_metadata_size: u32 = 100;
    let available_after_header = 10;

    // Valid magic and version, but the declared metadata size (100) exceeds
    // the 10 bytes that actually follow the header.
    let mut raw = vec![0u8; HEADER_SIZE + available_after_header];
    raw[magic_range()].copy_from_slice(&MAGIC_NUMBER.to_le_bytes());
    raw[version_range()].copy_from_slice(&FORMAT_VERSION.to_le_bytes());
    raw[metadata_size_range()].copy_from_slice(&claimed_metadata_size.to_le_bytes());

    let outcome = BurnpackReader::from_bytes(Bytes::from_bytes_vec(raw));
    assert!(matches!(outcome, Err(BurnpackError::InvalidHeader)));
}
|
||||
|
||||
/// Requesting a tensor name that is absent yields
/// `BurnpackError::TensorNotFound`.
#[test]
fn test_reader_get_tensor_not_found() {
    let bytes = BurnpackWriter::new(Vec::new()).to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    let lookup = reader.get_tensor_data("non_existent");
    assert!(matches!(lookup, Err(BurnpackError::TensorNotFound(_))));
}
|
||||
|
||||
/// `get_tensor_snapshot` returns a snapshot with the original path, dtype,
/// shape and — once materialized with `to_data` — the original values.
#[test]
fn test_reader_get_tensor_snapshot() {
    let data = [1.0f32, 2.0, 3.0, 4.0];
    let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
    let snapshot = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(bytes, vec![2, 2], DType::F32),
        vec!["weights".to_string()],
        vec![],
        burn_core::module::ParamId::new(),
    );

    let writer = BurnpackWriter::new(vec![snapshot]);
    let writer_bytes = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(writer_bytes).unwrap();

    // Get tensor as snapshot
    let loaded_snapshot = reader.get_tensor_snapshot("weights").unwrap();

    // Verify snapshot metadata
    assert_eq!(loaded_snapshot.full_path(), "weights");
    assert_eq!(loaded_snapshot.dtype, DType::F32);
    assert_eq!(loaded_snapshot.shape, vec![2, 2]);

    // Materialize the data and verify both its shape and its contents.
    let tensor_data = loaded_snapshot.to_data().unwrap();
    assert_eq!(tensor_data.shape, vec![2, 2]);

    let loaded_values: Vec<f32> = tensor_data
        .bytes
        .chunks_exact(4)
        .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect();
    assert_eq!(loaded_values, data);
}
|
||||
|
||||
/// Ten tensors written to one file can each be read back by name with the
/// right length and contents.
#[test]
fn test_reader_multiple_tensors() {
    // Tensor `i` holds 100 bytes, all equal to `i`, so the content identifies
    // which tensor was returned.
    let mut snapshots = Vec::new();
    for i in 0..10 {
        let data = vec![i as u8; 100];
        let snapshot = TensorSnapshot::from_data(
            TensorData::from_bytes_vec(data, vec![100], DType::U8),
            vec![format!("tensor_{}", i)],
            vec![],
            burn_core::module::ParamId::new(),
        );
        snapshots.push(snapshot);
    }

    let writer = BurnpackWriter::new(snapshots);
    let bytes = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    // Verify all tensors can be read
    for i in 0..10 {
        let name = format!("tensor_{}", i);
        let data = reader.get_tensor_data(&name).unwrap();
        assert_eq!(data.len(), 100);
        assert!(data.iter().all(|&b| b == i as u8));
    }
}
|
||||
|
||||
/// Reads a 1 MB tensor through a snapshot and checks the data obtained from
/// `to_data`.
///
/// NOTE(review): the test name suggests the snapshot defers loading until
/// `to_data` is called; only the final contents are asserted here, not when
/// the bytes are actually read.
#[test]
fn test_reader_lazy_loading() {
    let size = 1024 * 1024; // 1MB

    // The buffer is moved into the tensor; the expected contents are fully
    // described by `size` and the fill value, so no copy is kept.
    let snapshot = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(vec![42u8; size], vec![size], DType::U8),
        vec!["large".to_string()],
        vec![],
        burn_core::module::ParamId::new(),
    );

    let writer = BurnpackWriter::new(vec![snapshot]);
    let bytes = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    // Obtain the snapshot first, then materialize its data.
    let snapshot = reader.get_tensor_snapshot("large").unwrap();
    let tensor_data = snapshot.to_data().unwrap();

    assert_eq!(tensor_data.bytes.len(), size);
    assert!(tensor_data.bytes.iter().all(|&b| b == 42));
}
|
||||
|
||||
/// One single-element tensor per dtype is written; the reader must report
/// the same dtype and return the same bytes for each.
#[test]
fn test_reader_all_dtypes() {
    // (dtype, little-endian bytes of a single element)
    let test_data = [
        (DType::F32, 1.0f32.to_le_bytes().to_vec()),
        (DType::F64, 2.0f64.to_le_bytes().to_vec()),
        (DType::I32, 3i32.to_le_bytes().to_vec()),
        (DType::I64, 4i64.to_le_bytes().to_vec()),
        (DType::U32, 5u32.to_le_bytes().to_vec()),
        (DType::U64, 6u64.to_le_bytes().to_vec()),
        (DType::U8, vec![7u8]),
        (DType::Bool, vec![1u8]),
    ];

    let mut snapshots = Vec::new();
    for (i, (dtype, data)) in test_data.iter().enumerate() {
        let snapshot = TensorSnapshot::from_data(
            TensorData::from_bytes_vec(data.clone(), vec![1], *dtype),
            vec![format!("tensor_{}", i)],
            vec![],
            burn_core::module::ParamId::new(),
        );
        snapshots.push(snapshot);
    }

    let writer = BurnpackWriter::new(snapshots);
    let bytes = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(bytes).unwrap();

    // Verify all dtypes and payloads are preserved
    for (i, (expected_dtype, expected_data)) in test_data.iter().enumerate() {
        let name = format!("tensor_{}", i);
        let snapshot = reader.get_tensor_snapshot(&name).unwrap();
        assert_eq!(snapshot.dtype, *expected_dtype);

        let data = reader.get_tensor_data(&name).unwrap();
        assert_eq!(data, expected_data.as_slice());
    }
}
|
||||
|
||||
/// Checks that a zero-element F32 tensor (shape `[0]`, no bytes) can be
/// written and read back, keeping both its empty payload and its shape.
#[test]
fn test_reader_empty_tensor() {
    let empty = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(vec![], vec![0], DType::F32),
        vec!["empty".to_string()],
        vec![],
        burn_core::module::ParamId::new(),
    );

    let serialized = BurnpackWriter::new(vec![empty]).to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(serialized).unwrap();

    // Raw payload is empty.
    let payload = reader.get_tensor_data("empty").unwrap();
    assert_eq!(payload.len(), 0);

    // Shape is still reported as a single zero-length dimension.
    let loaded = reader.get_tensor_snapshot("empty").unwrap();
    assert_eq!(loaded.shape, vec![0]);
}
|
||||
|
||||
/// Writes a small burnpack file with one tensor and one metadata entry into a
/// temporary directory, then reads it back through `BurnpackReader::from_file`.
#[cfg(feature = "std")]
#[test]
fn test_reader_from_file() {
    use tempfile::tempdir;

    let dir = tempdir().unwrap();
    let file_path = dir.path().join("test.bpk");

    // Produce the file on disk.
    let tensor = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(vec![10, 20, 30], vec![3], DType::U8),
        vec!["file_tensor".to_string()],
        vec![],
        burn_core::module::ParamId::new(),
    );
    BurnpackWriter::new(vec![tensor])
        .with_metadata("from_file_test", "true")
        .write_to_file(&file_path)
        .unwrap();

    // Load it again and check metadata plus tensor bytes.
    let reader = BurnpackReader::from_file(&file_path).unwrap();

    let flag = reader.metadata().metadata.get("from_file_test");
    assert_eq!(flag, Some(&"true".to_string()));

    let payload = reader.get_tensor_data("file_tensor").unwrap();
    assert_eq!(payload, &[10, 20, 30]);
}
|
||||
|
||||
/// Writes a 1 MiB U8 tensor to a temporary file and reads it back through
/// `BurnpackReader::from_file_mmap`, checking length and contents.
#[cfg(all(feature = "std", feature = "memmap"))]
#[test]
fn test_reader_from_file_mmap() {
    use tempfile::tempdir;

    let dir = tempdir().unwrap();
    let file_path = dir.path().join("test_mmap.bpk");

    // Large payload filled with a single byte value.
    let size = 1024 * 1024; // 1MB
    let tensor = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(vec![99u8; size], vec![size], DType::U8),
        vec!["large_mmap".to_string()],
        vec![],
        burn_core::module::ParamId::new(),
    );
    BurnpackWriter::new(vec![tensor])
        .write_to_file(&file_path)
        .unwrap();

    // Open via the mmap entry point and verify the tensor bytes.
    let reader = BurnpackReader::from_file_mmap(&file_path).unwrap();
    let payload = reader.get_tensor_data("large_mmap").unwrap();

    assert_eq!(payload.len(), size);
    assert!(payload.iter().all(|&b| b == 99));
}
|
||||
|
||||
/// Writes a three-byte U8 tensor to a temporary file and reads it back through
/// `BurnpackReader::from_file_buffered`.
#[cfg(feature = "std")]
#[test]
fn test_reader_from_file_buffered() {
    use tempfile::tempdir;

    let dir = tempdir().unwrap();
    let file_path = dir.path().join("test_buffered.bpk");

    // Produce the file on disk.
    let tensor = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(vec![5, 10, 15], vec![3], DType::U8),
        vec!["buffered_tensor".to_string()],
        vec![],
        burn_core::module::ParamId::new(),
    );
    BurnpackWriter::new(vec![tensor])
        .write_to_file(&file_path)
        .unwrap();

    // Open via the buffered entry point and verify the tensor bytes.
    let reader = BurnpackReader::from_file_buffered(&file_path).unwrap();
    let payload = reader.get_tensor_data("buffered_tensor").unwrap();
    assert_eq!(payload, &[5, 10, 15]);
}
|
||||
|
||||
/// Attaches four metadata entries to a tensor-less burnpack and verifies that
/// the reader exposes exactly those key/value pairs.
#[test]
fn test_reader_metadata_access() {
    // Key/value pairs written through the builder and expected back verbatim.
    let expected = [
        ("model_name", "test_model"),
        ("version", "1.2.3"),
        ("author", "test_author"),
        ("description", "A test model"),
    ];

    let mut writer = BurnpackWriter::new(Vec::new());
    for &(key, value) in &expected {
        writer = writer.with_metadata(key, value);
    }

    let serialized = writer.to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(serialized).unwrap();

    let metadata = reader.metadata();
    assert_eq!(metadata.metadata.len(), 4);
    for &(key, value) in &expected {
        assert_eq!(metadata.metadata.get(key), Some(&value.to_string()));
    }
}
|
||||
|
||||
/// Stores four identically shaped U8 tensors under different names and checks
/// that the reader's tensor table lists exactly those names with the right
/// shape and dtype.
#[test]
fn test_reader_tensor_iteration() {
    let tensor_names = vec!["weights", "bias", "running_mean", "running_var"];

    // One four-byte U8 tensor per name.
    let snapshots: Vec<_> = tensor_names
        .iter()
        .map(|name| {
            TensorSnapshot::from_data(
                TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8),
                vec![name.to_string()],
                vec![],
                burn_core::module::ParamId::new(),
            )
        })
        .collect();

    let serialized = BurnpackWriter::new(snapshots).to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(serialized).unwrap();

    let metadata = reader.metadata();
    assert_eq!(metadata.tensors.len(), 4);

    // Each expected name resolves to a descriptor with the stored shape/dtype.
    for name in &tensor_names {
        let descriptor = metadata.tensors.get(*name).unwrap();
        assert_eq!(descriptor.shape, vec![4u64]);
        assert_eq!(descriptor.dtype, DType::U8);
    }

    // The key set, compared order-independently, equals the expected names.
    let mut found: Vec<_> = metadata.tensors.keys().cloned().collect();
    found.sort();

    let mut wanted: Vec<_> = tensor_names.iter().map(|s| s.to_string()).collect();
    wanted.sort();

    assert_eq!(found, wanted);
}
|
||||
|
||||
/// Builds a buffer with a valid header that announces 50 bytes of metadata,
/// fills that metadata region with `0xFF`, and expects the reader to reject it.
#[test]
fn test_reader_corrupt_metadata() {
    let mut raw = vec![0u8; 100];

    // Well-formed header: magic, format version, metadata size = 50.
    raw[magic_range()].copy_from_slice(&MAGIC_NUMBER.to_le_bytes());
    raw[version_range()].copy_from_slice(&FORMAT_VERSION.to_le_bytes());
    raw[metadata_size_range()].copy_from_slice(&50u32.to_le_bytes());

    // Garbage where the metadata payload should be.
    raw[HEADER_SIZE..HEADER_SIZE + 50].fill(0xFF);

    let outcome = BurnpackReader::from_bytes(Bytes::from_bytes_vec(raw));
    assert!(outcome.is_err());
}
|
||||
|
||||
/// Writes two four-byte tensors and checks the data offsets recorded for them:
/// the first occupies `(0, 4)` and the second starts at the next 256-byte
/// boundary, `(256, 260)`.
#[test]
fn test_reader_data_offsets_validation() {
    // Small helper so both tensors are built the same way.
    let make = |name: &str, payload: Vec<u8>| {
        TensorSnapshot::from_data(
            TensorData::from_bytes_vec(payload, vec![4], DType::U8),
            vec![name.to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        )
    };
    let first = make("tensor1", vec![1, 2, 3, 4]);
    let second = make("tensor2", vec![5, 6, 7, 8]);

    let serialized = BurnpackWriter::new(vec![first, second]).to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(serialized).unwrap();

    let tensors = &reader.metadata().tensors;
    let first_desc = tensors.get("tensor1").unwrap();
    let second_desc = tensors.get("tensor2").unwrap();

    // First tensor sits at the start of the data section.
    assert_eq!(first_desc.data_offsets, (0, 4));
    // Second tensor is placed at the following 256-byte aligned offset.
    assert_eq!(second_desc.data_offsets, (256, 260));
}
|
||||
|
||||
/// Asks an in-memory storage backend holding 5 bytes to fill a 10-byte buffer
/// and expects an error whose message mentions "out of bounds".
#[test]
fn test_reader_out_of_bounds_error() {
    use crate::burnpack::reader::StorageBackend;
    use alloc::rc::Rc;

    // Backend over a five-byte buffer.
    let backend = StorageBackend::Memory(Rc::new(Bytes::from_bytes_vec(vec![1, 2, 3, 4, 5])));

    // Requesting ten bytes from offset 0 exceeds what is available.
    let mut sink = vec![0u8; 10];
    let outcome = backend.read_into(&mut sink, 0);

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("out of bounds"));
}
|
||||
|
||||
/// Reads 10 bytes at offset `usize::MAX - 5` from an in-memory backend and
/// expects an error whose message mentions "overflow".
#[test]
fn test_reader_offset_overflow_error() {
    use crate::burnpack::reader::StorageBackend;
    use alloc::rc::Rc;

    let backend = StorageBackend::Memory(Rc::new(Bytes::from_bytes_vec(vec![1, 2, 3, 4, 5])));

    // offset + buffer length cannot be represented in a usize.
    let mut sink = vec![0u8; 10];
    let outcome = backend.read_into(&mut sink, usize::MAX - 5);

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("overflow"));
}
|
||||
|
||||
/// On 32-bit targets, feeds the reader a tensor descriptor whose first shape
/// dimension is `u64::MAX` and expects `get_snapshots` to return a validation
/// error. On other targets the test body is empty.
#[test]
fn test_reader_corrupted_shape_returns_error() {
    // Only meaningful where usize is narrower than u64; on 64-bit platforms
    // every u64 value fits in usize.
    #[cfg(target_pointer_width = "32")]
    {
        use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor};
        use alloc::collections::BTreeMap;
        use alloc::rc::Rc;
        use burn_tensor::DType;

        // Descriptor whose first dimension exceeds usize::MAX on 32-bit.
        let descriptor = TensorDescriptor {
            dtype: DType::F32,
            shape: vec![u64::MAX, 2, 3],
            data_offsets: (0, 100),
            param_id: None,
        };
        let tensors = BTreeMap::from([("corrupted_tensor".to_string(), descriptor)]);

        // Reader assembled by hand over a small zeroed buffer.
        let reader = BurnpackReader {
            metadata: BurnpackMetadata {
                tensors,
                metadata: BTreeMap::new(),
            },
            storage: crate::burnpack::reader::StorageBackend::Memory(Rc::new(
                Bytes::from_bytes_vec(vec![0u8; 1000]),
            )),
            data_offset: 0,
        };

        // Must surface as an error rather than a panic.
        let outcome = reader.get_snapshots();
        assert!(outcome.is_err());

        let err = outcome.unwrap_err();
        assert!(matches!(err, BurnpackError::ValidationError(_)));

        let message = err.to_string();
        assert!(
            message.contains("corrupted shape data")
                || message.contains("exceeds platform maximum")
        );
    }

    #[cfg(not(target_pointer_width = "32"))]
    {
        // On 64-bit platforms, just pass the test.
        // The conversion logic is still correct, but u64 fits in usize.
    }
}
|
||||
|
||||
/// Verifies that `get_snapshots` reports a validation error for unusable
/// offsets. On 32-bit targets the tensor's own offsets exceed `usize::MAX`;
/// on other targets the reader's `data_offset` is chosen so that adding the
/// tensor offset overflows.
#[test]
fn test_reader_corrupted_offsets_returns_error() {
    use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor};
    use alloc::collections::BTreeMap;
    use alloc::rc::Rc;
    use burn_tensor::DType;

    // Platform-specific inputs: tensor name, descriptor offsets, reader base offset.
    #[cfg(target_pointer_width = "32")]
    let (tensor_name, data_offsets, data_offset) = (
        "tensor_bad_offset",
        (u64::MAX - 10, u64::MAX), // Offsets that exceed usize::MAX on 32-bit
        0usize,
    );
    #[cfg(not(target_pointer_width = "32"))]
    let (tensor_name, data_offsets, data_offset) = (
        "tensor_overflow",
        (0u64, 100u64),
        usize::MAX - 50, // Will overflow when added to 100
    );

    let mut tensors = BTreeMap::new();
    tensors.insert(
        tensor_name.to_string(),
        TensorDescriptor {
            dtype: DType::F32,
            shape: vec![2, 2],
            data_offsets,
            param_id: None,
        },
    );

    // Reader assembled by hand over a small zeroed buffer.
    let reader = BurnpackReader {
        metadata: BurnpackMetadata {
            tensors,
            metadata: BTreeMap::new(),
        },
        storage: crate::burnpack::reader::StorageBackend::Memory(Rc::new(
            Bytes::from_bytes_vec(vec![0u8; 1000]),
        )),
        data_offset,
    };

    // Must surface as an error rather than a panic.
    let outcome = reader.get_snapshots();
    assert!(outcome.is_err());

    let err = outcome.unwrap_err();
    assert!(matches!(err, BurnpackError::ValidationError(_)));

    #[cfg(target_pointer_width = "32")]
    {
        assert!(
            err.to_string().contains("corrupted offset data")
                || err.to_string().contains("exceeds platform maximum")
        );
    }

    #[cfg(not(target_pointer_width = "32"))]
    {
        assert!(err.to_string().contains("overflow"));
    }
}
|
||||
|
||||
/// Hands the reader a tensor descriptor whose end offset (50) lies before its
/// start offset (100) and expects `get_snapshots` to return a validation error
/// that mentions both "end offset" and "start offset".
#[test]
fn test_reader_inverted_offsets_returns_error() {
    use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor};
    use alloc::collections::BTreeMap;
    use alloc::rc::Rc;
    use burn_tensor::DType;

    // Corrupted descriptor: end offset < start offset.
    let descriptor = TensorDescriptor {
        dtype: DType::F32,
        shape: vec![2, 2],
        data_offsets: (100, 50),
        param_id: None,
    };
    let tensors = BTreeMap::from([("inverted_tensor".to_string(), descriptor)]);

    // Reader assembled by hand over a small zeroed buffer.
    let reader = BurnpackReader {
        metadata: BurnpackMetadata {
            tensors,
            metadata: BTreeMap::new(),
        },
        storage: crate::burnpack::reader::StorageBackend::Memory(Rc::new(
            Bytes::from_bytes_vec(vec![0u8; 1000]),
        )),
        data_offset: 0,
    };

    // Must surface as an error rather than a panic.
    let outcome = reader.get_snapshots();
    assert!(outcome.is_err());

    let err = outcome.unwrap_err();
    assert!(matches!(err, BurnpackError::ValidationError(_)));

    let message = err.to_string();
    assert!(message.contains("end offset") && message.contains("start offset"));
}
|
||||
|
||||
/// Serialises a 1 KiB tensor, chops the final 512 bytes off the buffer, and
/// expects `BurnpackReader::from_bytes` to fail with a "File truncated"
/// validation error.
#[test]
fn test_reader_truncated_file_from_bytes() {
    // A valid burnpack holding 1KB of tensor data.
    let tensor_size = 1024;
    let tensor = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(vec![42u8; tensor_size], vec![tensor_size], DType::U8),
        vec!["large_tensor".to_string()],
        vec![],
        burn_core::module::ParamId::new(),
    );
    let full_bytes = BurnpackWriter::new(vec![tensor]).to_bytes().unwrap();

    // Drop the last 512 bytes of tensor data.
    let truncated_len = full_bytes.len() - 512;
    let mut shortened = full_bytes.to_vec();
    shortened.truncate(truncated_len);
    let truncated_bytes = Bytes::from_bytes_vec(shortened);

    // The reader must detect the missing tail.
    let outcome = BurnpackReader::from_bytes(truncated_bytes);
    assert!(outcome.is_err());

    if let Err(err) = outcome {
        assert!(matches!(err, BurnpackError::ValidationError(_)));
        let message = err.to_string();
        assert!(message.contains("File truncated"));
        assert!(message.contains("expected at least"));
    }
}
|
||||
|
||||
/// Writes a 2 KiB tensor to disk, shortens the file by 1 KiB with `set_len`,
/// and expects `BurnpackReader::from_file` to fail with a "File truncated"
/// validation error.
#[cfg(feature = "std")]
#[test]
fn test_reader_truncated_file_from_file() {
    use std::fs::OpenOptions;
    use tempfile::tempdir;

    let dir = tempdir().unwrap();
    let file_path = dir.path().join("truncated.bpk");

    // A valid burnpack file holding 2KB of tensor data.
    let tensor_size = 2048;
    let tensor = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(vec![99u8; tensor_size], vec![tensor_size], DType::U8),
        vec!["data_tensor".to_string()],
        vec![],
        burn_core::module::ParamId::new(),
    );
    BurnpackWriter::new(vec![tensor])
        .write_to_file(&file_path)
        .unwrap();

    // Shrink the file by 1KB; the handle is scoped so it closes before reading.
    let full_size = std::fs::metadata(&file_path).unwrap().len();
    let truncated_size = full_size - 1024;
    {
        let handle = OpenOptions::new().write(true).open(&file_path).unwrap();
        handle.set_len(truncated_size).unwrap();
    }

    // The reader must detect the missing tail.
    let outcome = BurnpackReader::from_file(&file_path);
    assert!(outcome.is_err());

    if let Err(err) = outcome {
        assert!(matches!(err, BurnpackError::ValidationError(_)));
        let message = err.to_string();
        assert!(message.contains("File truncated"));
        assert!(message.contains("expected at least"));
    }
}
|
||||
|
||||
/// Confirms that an untouched serialised buffer (exactly the size the writer
/// produced) passes the reader's validation and yields the stored tensor.
#[test]
fn test_reader_file_size_exactly_correct() {
    let tensor_size = 100;
    let tensor = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(vec![77u8; tensor_size], vec![tensor_size], DType::U8),
        vec!["exact_size".to_string()],
        vec![],
        burn_core::module::ParamId::new(),
    );
    let serialized = BurnpackWriter::new(vec![tensor]).to_bytes().unwrap();

    // Opening must succeed: nothing was removed from the buffer.
    let opened = BurnpackReader::from_bytes(serialized);
    assert!(opened.is_ok());

    // The tensor data is readable and unchanged.
    let reader = opened.unwrap();
    let payload = reader.get_tensor_data("exact_size").unwrap();
    assert_eq!(payload.len(), tensor_size);
    assert!(payload.iter().all(|&b| b == 77));
}
|
||||
@@ -0,0 +1,606 @@
|
||||
use crate::burnpack::{reader::BurnpackReader, writer::BurnpackWriter};
|
||||
|
||||
use super::*;
|
||||
use alloc::collections::BTreeMap;
|
||||
use alloc::string::String;
|
||||
use burn_tensor::{DType, TensorData};
|
||||
|
||||
/// Helper function to perform round-trip test
|
||||
fn round_trip_test<F>(setup: F)
|
||||
where
|
||||
F: FnOnce(&mut Vec<TensorSnapshot>, &mut BTreeMap<String, String>),
|
||||
{
|
||||
// Collect snapshots and metadata
|
||||
let mut snapshots = Vec::new();
|
||||
let mut metadata = BTreeMap::new();
|
||||
setup(&mut snapshots, &mut metadata);
|
||||
|
||||
// Sort snapshots by name to ensure consistent ordering
|
||||
// This is necessary because BTreeMap will store them sorted
|
||||
snapshots.sort_by_key(|a| a.full_path());
|
||||
|
||||
// Create writer with snapshots and metadata
|
||||
let mut writer = BurnpackWriter::new(snapshots);
|
||||
for (key, value) in &metadata {
|
||||
writer = writer.with_metadata(key, value);
|
||||
}
|
||||
|
||||
let bytes = writer.to_bytes().unwrap();
|
||||
let reader = BurnpackReader::from_bytes(bytes.clone()).unwrap();
|
||||
|
||||
// Write to bytes again from reader data
|
||||
let mut snapshots2 = Vec::new();
|
||||
|
||||
// Copy tensors (metadata.tensors is now BTreeMap<String, TensorDescriptor>)
|
||||
// They will come out in sorted order from tensor_names()
|
||||
for tensor_name in reader.tensor_names() {
|
||||
let snapshot = reader.get_tensor_snapshot(tensor_name).unwrap();
|
||||
snapshots2.push(snapshot);
|
||||
}
|
||||
|
||||
// Create writer2 with collected snapshots and metadata
|
||||
let mut writer2 = BurnpackWriter::new(snapshots2);
|
||||
for (key, value) in &reader.metadata().metadata {
|
||||
writer2 = writer2.with_metadata(key, value);
|
||||
}
|
||||
|
||||
let bytes2 = writer2.to_bytes().unwrap();
|
||||
|
||||
// Both byte representations should be identical
|
||||
assert_eq!(bytes, bytes2, "Round-trip produced different bytes");
|
||||
}
|
||||
|
||||
/// Round-trips a burnpack that has neither tensors nor metadata.
#[test]
fn test_round_trip_empty() {
    // Nothing is added: the writer stays empty.
    round_trip_test(|_snapshots, _metadata| {});
}
|
||||
|
||||
/// Round-trips a burnpack that carries three metadata entries and no tensors.
#[test]
fn test_round_trip_metadata_only() {
    round_trip_test(|_snapshots, metadata| {
        let entries = [("key1", "value1"), ("key2", "value2"), ("key3", "value3")];
        for &(key, value) in &entries {
            metadata.insert(key.to_string(), value.to_string());
        }
    });
}
|
||||
|
||||
/// Round-trips a five-element F32 tensor.
#[test]
fn test_round_trip_f32() {
    round_trip_test(|snapshots, _metadata| {
        let values = [1.0f32, 2.0, 3.0, 4.0, 5.0];

        // Little-endian encoding of every element, back to back.
        let mut raw: Vec<u8> = Vec::new();
        for value in &values {
            raw.extend_from_slice(&value.to_le_bytes());
        }

        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(raw, vec![5], DType::F32),
            vec!["f32_tensor".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));
    });
}
|
||||
|
||||
/// Round-trips a three-element F64 tensor.
#[test]
fn test_round_trip_f64() {
    round_trip_test(|snapshots, _metadata| {
        let values = [1.0f64, 2.0, 3.0];

        // Little-endian encoding of every element, back to back.
        let mut raw: Vec<u8> = Vec::new();
        for value in &values {
            raw.extend_from_slice(&value.to_le_bytes());
        }

        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(raw, vec![3], DType::F64),
            vec!["f64_tensor".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));
    });
}
|
||||
|
||||
/// Round-trips a four-element I32 tensor that includes a negative value.
#[test]
fn test_round_trip_i32() {
    round_trip_test(|snapshots, _metadata| {
        let values = [-10i32, 0, 10, 20];

        // Little-endian encoding of every element, back to back.
        let mut raw: Vec<u8> = Vec::new();
        for value in &values {
            raw.extend_from_slice(&value.to_le_bytes());
        }

        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(raw, vec![4], DType::I32),
            vec!["i32_tensor".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));
    });
}
|
||||
|
||||
/// Round-trips a three-element I64 tensor covering both extremes and zero.
#[test]
fn test_round_trip_i64() {
    round_trip_test(|snapshots, _metadata| {
        let values = [i64::MIN, 0, i64::MAX];

        // Little-endian encoding of every element, back to back.
        let mut raw: Vec<u8> = Vec::new();
        for value in &values {
            raw.extend_from_slice(&value.to_le_bytes());
        }

        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(raw, vec![3], DType::I64),
            vec!["i64_tensor".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));
    });
}
|
||||
|
||||
/// Round-trips a four-element U32 tensor that includes `u32::MAX`.
#[test]
fn test_round_trip_u32() {
    round_trip_test(|snapshots, _metadata| {
        let values = [0u32, 100, 1000, u32::MAX];

        // Little-endian encoding of every element, back to back.
        let mut raw: Vec<u8> = Vec::new();
        for value in &values {
            raw.extend_from_slice(&value.to_le_bytes());
        }

        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(raw, vec![4], DType::U32),
            vec!["u32_tensor".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));
    });
}
|
||||
|
||||
/// Round-trips a three-element U64 tensor that includes `u64::MAX`.
#[test]
fn test_round_trip_u64() {
    round_trip_test(|snapshots, _metadata| {
        let values = [0u64, u64::MAX / 2, u64::MAX];

        // Little-endian encoding of every element, back to back.
        let mut raw: Vec<u8> = Vec::new();
        for value in &values {
            raw.extend_from_slice(&value.to_le_bytes());
        }

        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(raw, vec![3], DType::U64),
            vec!["u64_tensor".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));
    });
}
|
||||
|
||||
/// Round-trips a three-element U8 tensor holding the values 0, 127 and 255.
#[test]
fn test_round_trip_u8() {
    round_trip_test(|snapshots, _metadata| {
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(vec![0u8, 127, 255], vec![3], DType::U8),
            vec!["u8_tensor".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));
    });
}
|
||||
|
||||
/// Round-trips a five-element Bool tensor stored as one byte per element.
#[test]
fn test_round_trip_bool() {
    round_trip_test(|snapshots, _metadata| {
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(vec![0u8, 1, 0, 1, 1], vec![5], DType::Bool),
            vec!["bool_tensor".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));
    });
}
|
||||
|
||||
/// Round-trips a burnpack holding three tensors of different dtypes
/// (F32, I64 and Bool) side by side.
#[test]
fn test_round_trip_mixed_dtypes() {
    round_trip_test(|snapshots, _metadata| {
        // F32: two elements.
        let f32_raw: Vec<u8> = [1.0f32, 2.0]
            .iter()
            .flat_map(|value| value.to_le_bytes())
            .collect();
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(f32_raw, vec![2], DType::F32),
            vec!["f32".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));

        // I64: two elements.
        let i64_raw: Vec<u8> = [100i64, 200]
            .iter()
            .flat_map(|value| value.to_le_bytes())
            .collect();
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(i64_raw, vec![2], DType::I64),
            vec!["i64".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));

        // Bool: three elements, one byte each.
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(vec![1, 0, 1], vec![3], DType::Bool),
            vec!["bool".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));
    });
}
|
||||
|
||||
/// Round-trips F32 tensors of rank 2, 3 and 4 to check that multi-dimensional
/// shapes are preserved.
#[test]
fn test_round_trip_multidimensional() {
    round_trip_test(|snapshots, _metadata| {
        // 2D tensor: shape [2, 3], six distinct values.
        let raw_2d: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
            .iter()
            .flat_map(|value| value.to_le_bytes())
            .collect();
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(raw_2d, vec![2, 3], DType::F32),
            vec!["tensor_2d".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));

        // 3D tensor: shape [2, 3, 4], 24 ones.
        let raw_3d: Vec<u8> = [1.0f32; 24]
            .iter()
            .flat_map(|value| value.to_le_bytes())
            .collect();
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(raw_3d, vec![2, 3, 4], DType::F32),
            vec!["tensor_3d".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));

        // 4D tensor (common for CNNs): shape [2, 3, 4, 5], 120 ones.
        let raw_4d: Vec<u8> = vec![1.0f32; 120]
            .iter()
            .flat_map(|value| value.to_le_bytes())
            .collect();
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(raw_4d, vec![2, 3, 4, 5], DType::F32),
            vec!["tensor_4d".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));
    });
}
|
||||
|
||||
/// Round-trips a burnpack that combines three metadata entries with two F32
/// tensors stored under two-segment paths (`layer1` / `weights`, `layer1` / `bias`).
#[test]
fn test_round_trip_with_metadata_and_tensors() {
    round_trip_test(|snapshots, metadata| {
        // Metadata entries.
        let entries = [
            ("model_name", "test_model"),
            ("version", "1.0.0"),
            ("description", "A test model for round-trip testing"),
        ];
        for &(key, value) in &entries {
            metadata.insert(key.to_string(), value.to_string());
        }

        // layer1 weights: shape [2, 2].
        let weights_raw: Vec<u8> = [0.1f32, 0.2, 0.3, 0.4]
            .iter()
            .flat_map(|value| value.to_le_bytes())
            .collect();
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(weights_raw, vec![2, 2], DType::F32),
            vec!["layer1".to_string(), "weights".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));

        // layer1 bias: shape [2].
        let bias_raw: Vec<u8> = [0.5f32, 0.6]
            .iter()
            .flat_map(|value| value.to_le_bytes())
            .collect();
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(bias_raw, vec![2], DType::F32),
            vec!["layer1".to_string(), "bias".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));
    });
}
|
||||
|
||||
/// Round-trips F32 and F64 tensors made of special floating-point values
/// (signed zeros, infinities, NaN, MIN, MAX, EPSILON).
#[test]
fn test_round_trip_special_values() {
    round_trip_test(|snapshots, _metadata| {
        // Special f32 values.
        let f32_values = [
            0.0f32,
            -0.0,
            f32::INFINITY,
            f32::NEG_INFINITY,
            f32::NAN,
            f32::MIN,
            f32::MAX,
            f32::EPSILON,
        ];
        let mut f32_raw: Vec<u8> = Vec::new();
        for value in &f32_values {
            f32_raw.extend_from_slice(&value.to_le_bytes());
        }
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(f32_raw, vec![8], DType::F32),
            vec!["special_f32".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));

        // Special f64 values.
        let f64_values = [
            0.0f64,
            -0.0,
            f64::INFINITY,
            f64::NEG_INFINITY,
            f64::NAN,
            f64::MIN,
            f64::MAX,
            f64::EPSILON,
        ];
        let mut f64_raw: Vec<u8> = Vec::new();
        for value in &f64_values {
            f64_raw.extend_from_slice(&value.to_le_bytes());
        }
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(f64_raw, vec![8], DType::F64),
            vec!["special_f64".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));
    });
}
|
||||
|
||||
/// Hands `round_trip_test` a single 1-D `f32` tensor of 25600 elements
/// (100KB of payload) whose values are the element indices.
#[test]
fn test_round_trip_large_tensors() {
    round_trip_test(|snapshots, _metadata| {
        // 100KB / 4 bytes per f32
        let element_count = 25600;
        let values: Vec<f32> = (0..element_count).map(|index| index as f32).collect();

        let mut payload: Vec<u8> = Vec::new();
        for value in values.iter() {
            payload.extend_from_slice(&value.to_le_bytes());
        }

        let large = TensorSnapshot::from_data(
            TensorData::from_bytes_vec(payload, vec![element_count], DType::F32),
            vec!["large_tensor".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        );
        snapshots.push(large);
    });
}
|
||||
|
||||
/// Writes a burnpack file, reads it back, re-writes the recovered snapshots
/// and metadata to a second file, and requires both files to be byte-identical.
#[cfg(feature = "std")]
#[test]
fn test_round_trip_file_io() {
    use std::fs;
    use tempfile::tempdir;

    use crate::burnpack::writer::BurnpackWriter;

    let workspace = tempdir().unwrap();
    let first_path = workspace.path().join("round_trip.bpk");

    // Original payload: a 2x2 f32 tensor named "weights".
    let payload: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0]
        .iter()
        .flat_map(|value| value.to_le_bytes())
        .collect();
    let original = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(payload, vec![2, 2], DType::F32),
        vec!["weights".to_string()],
        vec![],
        burn_core::module::ParamId::new(),
    );

    // First pass: write to disk, then load the file again.
    let first_writer = BurnpackWriter::new(vec![original]).with_metadata("test", "round_trip");
    first_writer.write_to_file(&first_path).unwrap();
    let reader = BurnpackReader::from_file(&first_path).unwrap();

    // Second pass: rebuild a writer from everything the reader exposes.
    let second_path = workspace.path().join("round_trip2.bpk");
    let recovered: Vec<_> = reader
        .tensor_names()
        .into_iter()
        .map(|tensor_name| reader.get_tensor_snapshot(tensor_name).unwrap())
        .collect();
    let second_writer = reader
        .metadata()
        .metadata
        .iter()
        .fold(BurnpackWriter::new(recovered), |writer, (key, value)| {
            writer.with_metadata(key, value)
        });
    second_writer.write_to_file(&second_path).unwrap();

    // Both files must contain exactly the same bytes.
    let first_contents = fs::read(&first_path).unwrap();
    let second_contents = fs::read(&second_path).unwrap();
    assert_eq!(
        first_contents, second_contents,
        "Round-trip through files produced different content"
    );
}
|
||||
|
||||
/// Hands `round_trip_test` two degenerate tensors: a scalar with an empty
/// shape and a tensor of shape `[0]` with no payload bytes.
#[test]
fn test_round_trip_empty_shapes() {
    round_trip_test(|snapshots, _metadata| {
        // Scalar: one f32 value, zero dimensions.
        let scalar_payload: Vec<u8> = 42.0f32.to_le_bytes().to_vec();
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(scalar_payload, vec![], DType::F32),
            vec!["scalar".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));

        // Empty: a single dimension of length zero and no bytes.
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(vec![], vec![0], DType::F32),
            vec!["empty".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));
    });
}
|
||||
|
||||
/// Checks that a snapshot written with a known `ParamId` comes back from the
/// reader carrying the same id value.
#[test]
fn test_param_id_persistence() {
    use burn_core::module::ParamId;

    // Fixed id so the value can be compared after reading back.
    let original_param_id = ParamId::from(123456789u64);

    let mut payload: Vec<u8> = Vec::new();
    for value in [1.0f32, 2.0, 3.0, 4.0].iter() {
        payload.extend_from_slice(&value.to_le_bytes());
    }
    let snapshot = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(payload, vec![2, 2], DType::F32),
        vec!["weights".to_string()],
        vec![],
        original_param_id,
    );

    // Serialize to an in-memory burnpack and parse it again.
    let encoded = BurnpackWriter::new(vec![snapshot]).to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(encoded).unwrap();
    let loaded_snapshot = reader.get_tensor_snapshot("weights").unwrap();

    // The id must exist and carry the original value.
    let Some(loaded_param_id) = loaded_snapshot.tensor_id else {
        panic!("ParamId should be present");
    };
    assert_eq!(
        loaded_param_id.val(),
        original_param_id.val(),
        "ParamId value should be preserved: expected {}, got {}",
        original_param_id.val(),
        loaded_param_id.val()
    );
}
|
||||
|
||||
/// Hand-assembles a burnpack whose tensor descriptor has `param_id: None`
/// and checks that the reader still returns a snapshot with some `ParamId`.
#[test]
fn test_param_id_backward_compatibility() {
    use crate::burnpack::base::{BurnpackHeader, FORMAT_VERSION, MAGIC_NUMBER};
    use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor};
    use alloc::collections::BTreeMap;

    // Descriptor without a param_id (simulating old burnpack format).
    let legacy_descriptor = TensorDescriptor {
        dtype: DType::F32,
        shape: vec![2, 2],
        data_offsets: (0, 16),
        param_id: None,
    };
    let mut tensors = BTreeMap::new();
    tensors.insert("old_tensor".to_string(), legacy_descriptor);

    let metadata = BurnpackMetadata {
        tensors,
        metadata: BTreeMap::new(),
    };

    // Encode the metadata section.
    let mut metadata_bytes = Vec::new();
    ciborium::ser::into_writer(&metadata, &mut metadata_bytes).unwrap();

    // Header describing the metadata section that follows it.
    let header = BurnpackHeader {
        magic: MAGIC_NUMBER,
        version: FORMAT_VERSION,
        metadata_size: metadata_bytes.len() as u32,
    };

    // Layout built here: header, metadata, then 4 f32 values (16 bytes).
    let mut full_bytes = Vec::new();
    full_bytes.extend_from_slice(&header.into_bytes());
    full_bytes.extend_from_slice(&metadata_bytes);
    full_bytes.extend(
        [1.0f32, 2.0, 3.0, 4.0]
            .iter()
            .flat_map(|value| value.to_le_bytes()),
    );

    // Parse the hand-built burnpack.
    let reader =
        BurnpackReader::from_bytes(burn_tensor::Bytes::from_bytes_vec(full_bytes)).unwrap();
    let loaded_snapshot = reader.get_tensor_snapshot("old_tensor").unwrap();

    // An id must be present even though none was stored; its exact value is
    // not predictable, so only a non-zero check is made.
    let Some(generated_param_id) = loaded_snapshot.tensor_id else {
        panic!("ParamId should be generated for old format");
    };
    assert!(
        generated_param_id.val() > 0,
        "Generated ParamId should have a valid value"
    );
}
|
||||
|
||||
/// Writes three tensors with three different known `ParamId`s and checks that
/// each one is read back with its own id, and that the ids stay pairwise distinct.
#[test]
fn test_multiple_tensors_preserve_distinct_param_ids() {
    use burn_core::module::ParamId;

    let param_id_1 = ParamId::from(111111u64);
    let param_id_2 = ParamId::from(222222u64);
    let param_id_3 = ParamId::from(333333u64);

    // (tensor name, id to store, two f32 values)
    let fixtures = [
        ("tensor1", param_id_1, [1.0f32, 2.0]),
        ("tensor2", param_id_2, [3.0f32, 4.0]),
        ("tensor3", param_id_3, [5.0f32, 6.0]),
    ];

    let mut snapshots = Vec::new();
    for (name, id, values) in fixtures.iter() {
        let raw: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(raw, vec![2], DType::F32),
            vec![name.to_string()],
            vec![],
            *id,
        ));
    }

    // Serialize and parse again.
    let encoded = BurnpackWriter::new(snapshots).to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(encoded).unwrap();

    let id1 = reader
        .get_tensor_snapshot("tensor1")
        .unwrap()
        .tensor_id
        .unwrap()
        .val();
    let id2 = reader
        .get_tensor_snapshot("tensor2")
        .unwrap()
        .tensor_id
        .unwrap()
        .val();
    let id3 = reader
        .get_tensor_snapshot("tensor3")
        .unwrap()
        .tensor_id
        .unwrap()
        .val();

    // Each id matches what was stored for that tensor.
    assert_eq!(id1, param_id_1.val());
    assert_eq!(id2, param_id_2.val());
    assert_eq!(id3, param_id_3.val());

    // And no two tensors share an id.
    assert_ne!(id1, id2, "ParamIds should be distinct");
    assert_ne!(id2, id3, "ParamIds should be distinct");
    assert_ne!(id1, id3, "ParamIds should be distinct");
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,744 @@
|
||||
use crate::burnpack::{
|
||||
base::{
|
||||
BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER,
|
||||
aligned_data_section_start, magic_range,
|
||||
},
|
||||
writer::BurnpackWriter,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use burn_core::module::ParamId;
|
||||
use burn_tensor::{DType, TensorData};
|
||||
use std::rc::Rc;
|
||||
|
||||
/// A writer built from an empty snapshot list holds no snapshots and no metadata.
#[test]
fn test_writer_new() {
    let empty_writer = BurnpackWriter::new(vec![]);

    assert!(empty_writer.metadata.is_empty());
    assert_eq!(empty_writer.snapshots.len(), 0);
}
|
||||
|
||||
/// Three distinct keys added through `with_metadata` are all stored with
/// their values.
#[test]
fn test_writer_add_metadata() {
    let entries = [
        ("model_name", "test_model"),
        ("version", "1.0.0"),
        ("author", "test_author"),
    ];

    let writer = BurnpackWriter::new(vec![])
        .with_metadata("model_name", "test_model")
        .with_metadata("version", "1.0.0")
        .with_metadata("author", "test_author");

    assert_eq!(writer.metadata.len(), 3);
    for &(key, expected) in entries.iter() {
        assert_eq!(writer.metadata.get(key), Some(&expected.to_string()));
    }
}
|
||||
|
||||
/// Snapshots passed to `BurnpackWriter::new` are kept in the given order and
/// report their dotted full paths.
#[test]
fn test_writer_add_tensor_snapshot() {
    let writer = BurnpackWriter::new(vec![
        // 2x2 u8 tensor at layer1.weights
        TensorSnapshot::from_data(
            TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),
            vec!["layer1".to_string(), "weights".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ),
        // 4-element u8 tensor at layer1.bias
        TensorSnapshot::from_data(
            TensorData::from_bytes_vec(vec![5, 6, 7, 8], vec![4], DType::U8),
            vec!["layer1".to_string(), "bias".to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ),
    ]);

    assert_eq!(writer.snapshots.len(), 2);
    assert_eq!(writer.snapshots[0].full_path(), "layer1.weights");
    assert_eq!(writer.snapshots[1].full_path(), "layer1.bias");
}
|
||||
|
||||
/// An empty writer still emits a valid header followed by metadata that
/// lists no tensors and no key/value entries.
#[test]
fn test_writer_to_bytes_empty() {
    let bytes = BurnpackWriter::new(vec![]).to_bytes().unwrap();

    // Raw header check: long enough, and the magic bytes are in place.
    assert!(bytes.len() >= HEADER_SIZE);
    assert_eq!(&bytes[magic_range()], &MAGIC_NUMBER.to_le_bytes());

    // Structured header check.
    let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();
    assert_eq!(header.magic, MAGIC_NUMBER);
    assert_eq!(header.version, FORMAT_VERSION);

    // The metadata section sits directly after the header.
    let metadata: BurnpackMetadata =
        ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize])
            .unwrap();
    assert_eq!(metadata.tensors.len(), 0);
    assert!(metadata.metadata.is_empty());
}
|
||||
|
||||
/// Serializes an `f32` and an `i64` tensor plus one metadata entry, then
/// checks the decoded descriptors and the raw payload bytes in the data section.
#[test]
fn test_writer_to_bytes_with_tensors() {
    // 2x2 f32 tensor named "weights".
    let f32_bytes: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0]
        .iter()
        .flat_map(|value| value.to_le_bytes())
        .collect();
    let snapshot_f32 = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(f32_bytes.clone(), vec![2, 2], DType::F32),
        vec!["weights".to_string()],
        vec![],
        burn_core::module::ParamId::new(),
    );

    // 3-element i64 tensor named "bias".
    let i64_bytes: Vec<u8> = [10i64, 20, 30]
        .iter()
        .flat_map(|value| value.to_le_bytes())
        .collect();
    let snapshot_i64 = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(i64_bytes.clone(), vec![3], DType::I64),
        vec!["bias".to_string()],
        vec![],
        burn_core::module::ParamId::new(),
    );

    let bytes = BurnpackWriter::new(vec![snapshot_f32, snapshot_i64])
        .with_metadata("test_key", "test_value")
        .to_bytes()
        .unwrap();

    // Decode header and metadata.
    let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();
    let metadata: BurnpackMetadata =
        ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize])
            .unwrap();

    assert_eq!(
        metadata.metadata.get("test_key"),
        Some(&"test_value".to_string())
    );
    assert_eq!(metadata.tensors.len(), 2);

    let weights = metadata.tensors.get("weights").unwrap();
    assert_eq!(weights.dtype, DType::F32);
    assert_eq!(weights.shape, vec![2, 2]);
    assert_eq!(weights.data_offsets.1 - weights.data_offsets.0, 16); // 4 * 4 bytes

    let bias = metadata.tensors.get("bias").unwrap();
    assert_eq!(bias.dtype, DType::I64);
    assert_eq!(bias.shape, vec![3]);
    assert_eq!(bias.data_offsets.1 - bias.data_offsets.0, 24); // 3 * 8 bytes

    // Offsets are relative to the aligned start of the data section.
    let data_section_start = aligned_data_section_start(header.metadata_size as usize);

    let weights_begin = data_section_start + weights.data_offsets.0 as usize;
    let weights_end = data_section_start + weights.data_offsets.1 as usize;
    assert_eq!(&bytes[weights_begin..weights_end], f32_bytes);

    let bias_begin = data_section_start + bias.data_offsets.0 as usize;
    let bias_end = data_section_start + bias.data_offsets.1 as usize;
    assert_eq!(&bytes[bias_begin..bias_end], i64_bytes);
}
|
||||
|
||||
/// Writes one single-element tensor per listed dtype (QFloat is not listed
/// here) and checks dtype, shape, payload size and payload bytes for each.
#[test]
fn test_writer_all_dtypes() {
    use half::{bf16, f16};

    // Each case: (dtype, bytes per element, little-endian bytes of one sample value)
    let test_cases = vec![
        // Floating point
        (DType::F64, 8, 1.0f64.to_le_bytes().to_vec()),
        (DType::F32, 4, 1.0f32.to_le_bytes().to_vec()),
        (DType::F16, 2, f16::from_f32(1.0).to_le_bytes().to_vec()),
        (DType::BF16, 2, bf16::from_f32(1.0).to_le_bytes().to_vec()),
        // Signed integers
        (DType::I64, 8, 1i64.to_le_bytes().to_vec()),
        (DType::I32, 4, 1i32.to_le_bytes().to_vec()),
        (DType::I16, 2, 1i16.to_le_bytes().to_vec()),
        (DType::I8, 1, 1i8.to_le_bytes().to_vec()),
        // Unsigned integers
        (DType::U64, 8, 255u64.to_le_bytes().to_vec()),
        (DType::U32, 4, 255u32.to_le_bytes().to_vec()),
        (DType::U16, 2, 255u16.to_le_bytes().to_vec()),
        (DType::U8, 1, vec![255u8]),
        // Boolean
        (DType::Bool, 1, vec![1u8]),
    ];

    // Build one snapshot per case and remember what to expect for it.
    let mut snapshots = vec![];
    let mut expectations = vec![];
    for (index, (dtype, element_size, sample)) in test_cases.into_iter().enumerate() {
        let tensor_name = format!("tensor_{}", index);
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(sample.clone(), vec![1], dtype),
            vec![tensor_name.clone()],
            vec![],
            burn_core::module::ParamId::new(),
        ));
        expectations.push((tensor_name, dtype, element_size, sample));
    }

    let bytes = BurnpackWriter::new(snapshots).to_bytes().unwrap();

    // Decode header and metadata.
    let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();
    let metadata_end = HEADER_SIZE + header.metadata_size as usize;
    let metadata: BurnpackMetadata =
        ciborium::de::from_reader(&bytes[HEADER_SIZE..metadata_end]).unwrap();

    assert_eq!(
        metadata.tensors.len(),
        13,
        "Expected 13 dtypes to be tested"
    );

    // Offsets are relative to the aligned start of the data section.
    let data_section_start = aligned_data_section_start(header.metadata_size as usize);
    for (name, expected_dtype, expected_size, expected_bytes) in expectations {
        let Some(tensor) = metadata.tensors.get(&name) else {
            panic!("Missing tensor: {}", name);
        };
        assert_eq!(tensor.dtype, expected_dtype, "DType mismatch for {}", name);
        assert_eq!(tensor.shape, vec![1], "Shape mismatch for {}", name);

        // Recorded payload length equals the element size.
        let data_size = (tensor.data_offsets.1 - tensor.data_offsets.0) as usize;
        assert_eq!(
            data_size, expected_size,
            "Data size mismatch for {} ({:?})",
            name, expected_dtype
        );

        // Payload bytes are exactly what was supplied.
        let begin = data_section_start + tensor.data_offsets.0 as usize;
        let end = data_section_start + tensor.data_offsets.1 as usize;
        assert_eq!(
            &bytes[begin..end],
            expected_bytes.as_slice(),
            "Data mismatch for {} ({:?})",
            name,
            expected_dtype
        );
    }
}
|
||||
|
||||
/// Writes one multi-element tensor per listed dtype, reads everything back
/// through `BurnpackReader`, and compares dtype, shape and raw bytes.
#[test]
fn test_writer_all_dtypes_round_trip() {
    use crate::burnpack::reader::BurnpackReader;
    use half::{bf16, f16};

    // Little-endian byte image of a homogeneous list of values. The values go
    // through an array literal so unsuffixed literals take the type of the
    // first element.
    macro_rules! le_bytes {
        ($($value:expr),+ $(,)?) => {
            [$($value),+]
                .iter()
                .flat_map(|v| v.to_le_bytes())
                .collect::<Vec<u8>>()
        };
    }

    // Each case: (tensor name, dtype, payload bytes, shape)
    let test_cases = vec![
        // Floating point types
        (
            "f64_tensor",
            DType::F64,
            le_bytes!(1.0f64, 2.0, 3.0, 4.0),
            vec![4],
        ),
        (
            "f32_tensor",
            DType::F32,
            le_bytes!(1.0f32, 2.0, 3.0, 4.0),
            vec![2, 2],
        ),
        (
            "f16_tensor",
            DType::F16,
            le_bytes!(f16::from_f32(1.0), f16::from_f32(2.0)),
            vec![2],
        ),
        (
            "bf16_tensor",
            DType::BF16,
            le_bytes!(bf16::from_f32(1.0), bf16::from_f32(2.0)),
            vec![2],
        ),
        // Signed integers
        ("i64_tensor", DType::I64, le_bytes!(1i64, -2, 3, -4), vec![4]),
        (
            "i32_tensor",
            DType::I32,
            le_bytes!(1i32, -2, 3, -4),
            vec![2, 2],
        ),
        ("i16_tensor", DType::I16, le_bytes!(1i16, -2, 3, -4), vec![4]),
        ("i8_tensor", DType::I8, le_bytes!(1i8, -2, 3, -4), vec![2, 2]),
        // Unsigned integers
        ("u64_tensor", DType::U64, le_bytes!(1u64, 2, 3, 4), vec![4]),
        (
            "u32_tensor",
            DType::U32,
            le_bytes!(1u32, 2, 3, 4),
            vec![2, 2],
        ),
        ("u16_tensor", DType::U16, le_bytes!(1u16, 2, 3, 4), vec![4]),
        ("u8_tensor", DType::U8, vec![1u8, 2, 3, 4], vec![2, 2]),
        // Boolean
        ("bool_tensor", DType::Bool, vec![1u8, 0, 1, 0], vec![4]),
    ];

    // Build the snapshots and keep the inputs around as the expected results.
    let mut snapshots = vec![];
    let mut expected_results: Vec<(&str, DType, Vec<u8>, Vec<usize>)> = vec![];
    for (name, dtype, payload, shape) in test_cases.into_iter() {
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(payload.clone(), shape.clone(), dtype),
            vec![name.to_string()],
            vec![],
            burn_core::module::ParamId::new(),
        ));
        expected_results.push((name, dtype, payload, shape));
    }

    // Serialize, then parse with the reader.
    let encoded = BurnpackWriter::new(snapshots).to_bytes().unwrap();
    let reader = BurnpackReader::from_bytes(encoded).unwrap();

    // Every tensor must come back with identical dtype, shape and bytes.
    for (name, expected_dtype, expected_data, expected_shape) in expected_results {
        let snapshot = match reader.get_tensor_snapshot(name) {
            Ok(snapshot) => snapshot,
            Err(e) => panic!("Failed to get tensor snapshot {}: {}", name, e),
        };
        let tensor_data = match snapshot.to_data() {
            Ok(data) => data,
            Err(e) => panic!("Failed to read tensor data {}: {}", name, e),
        };

        assert_eq!(
            tensor_data.dtype, expected_dtype,
            "DType mismatch for {}",
            name
        );
        assert_eq!(
            tensor_data.shape, expected_shape,
            "Shape mismatch for {}",
            name
        );
        assert_eq!(
            &tensor_data.bytes[..],
            expected_data.as_slice(),
            "Data mismatch for {}",
            name
        );
    }
}
|
||||
|
||||
/// Serializes a single 1MB `f32` tensor and checks that the decoded
/// descriptor reports the full shape and payload length.
#[test]
fn test_writer_large_tensor() {
    // 256K floats = 1MB
    let size = 256 * 1024;

    // Build the payload directly as bytes. It is moved into the tensor data
    // rather than cloned, since nothing below needs the buffer again.
    let bytes: Vec<u8> = (0..size)
        .flat_map(|i| (i as f32).to_le_bytes())
        .collect();

    let snapshot = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(bytes, vec![size], DType::F32),
        vec!["large_tensor".to_string()],
        vec![],
        burn_core::module::ParamId::new(),
    );

    let writer = BurnpackWriter::new(vec![snapshot]);
    let result = writer.to_bytes().unwrap();

    // Decode header and metadata to inspect the stored descriptor.
    let header = BurnpackHeader::from_bytes(&result[..HEADER_SIZE]).unwrap();
    let metadata: BurnpackMetadata = ciborium::de::from_reader(
        &result[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize],
    )
    .unwrap();

    assert_eq!(metadata.tensors.len(), 1);
    let tensor = metadata.tensors.get("large_tensor").unwrap();
    assert_eq!(tensor.shape, vec![size as u64]);
    assert_eq!(
        tensor.data_offsets.1 - tensor.data_offsets.0,
        (size * 4) as u64
    );
}
|
||||
|
||||
/// A tensor of shape `[0]` with no payload is still listed in the metadata,
/// with a zero-length data range.
#[test]
fn test_writer_empty_tensors() {
    let empty_snapshot = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(vec![], vec![0], DType::F32),
        vec!["empty".to_string()],
        vec![],
        ParamId::new(),
    );

    let bytes = BurnpackWriter::new(vec![empty_snapshot])
        .to_bytes()
        .unwrap();

    // Decode header and metadata.
    let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();
    let metadata_end = HEADER_SIZE + header.metadata_size as usize;
    let metadata: BurnpackMetadata =
        ciborium::de::from_reader(&bytes[HEADER_SIZE..metadata_end]).unwrap();

    assert_eq!(metadata.tensors.len(), 1);
    let descriptor = metadata.tensors.get("empty").unwrap();
    assert_eq!(descriptor.shape, vec![0]);
    assert_eq!(descriptor.data_offsets.1 - descriptor.data_offsets.0, 0);
}
|
||||
|
||||
/// Tensor names containing dots, slashes, brackets, punctuation, emoji,
/// non-ASCII text and spaces must appear in the metadata exactly as given.
#[test]
fn test_writer_special_characters_in_names() {
    // Test various special characters in tensor names
    let special_names = [
        "layer.0.weight",
        "model/encoder/layer1",
        "model::layer::weight",
        "layer[0].bias",
        "layer_1_weight",
        "layer-1-bias",
        "layer@1#weight",
        "emoji_😀_tensor",
        "unicode_测试_tensor",
        "spaces in name",
    ];

    let mut snapshots = vec![];
    for name in &special_names {
        let snapshot = TensorSnapshot::from_data(
            TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8),
            vec![name.to_string()],
            vec![],
            ParamId::new(),
        );
        snapshots.push(snapshot);
    }

    let writer = BurnpackWriter::new(snapshots);
    let bytes = writer.to_bytes().unwrap();

    let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();
    let metadata: BurnpackMetadata =
        ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize])
            .unwrap();

    // One entry per input name; the count follows the list above so the two
    // cannot drift apart.
    assert_eq!(metadata.tensors.len(), special_names.len());

    // Names should be preserved exactly: look each original name up as a key
    // instead of only checking for a loose substring.
    for name in &special_names {
        assert!(
            metadata.tensors.contains_key(*name),
            "Tensor name {:?} was not preserved exactly",
            name
        );
    }
}
|
||||
|
||||
/// Setting the same metadata key twice keeps a single entry holding the
/// later value.
#[test]
fn test_writer_metadata_overwrite() {
    let base = BurnpackWriter::new(vec![]);
    let first = base.with_metadata("key", "value1");
    let overwritten = first.with_metadata("key", "value2");

    assert_eq!(overwritten.metadata.len(), 1);
    assert_eq!(overwritten.metadata.get("key"), Some(&"value2".to_string()));
}
|
||||
|
||||
/// Tensors inserted in a non-alphabetical order are all present in the
/// decoded metadata, whose keys iterate in sorted order.
#[test]
fn test_writer_tensor_order_preserved() {
    // Deliberately unsorted insertion order.
    let insertion_order = ["z_tensor", "a_tensor", "m_tensor", "b_tensor"];

    let mut snapshots = vec![];
    for tensor_name in insertion_order.iter() {
        snapshots.push(TensorSnapshot::from_data(
            TensorData::from_bytes_vec(vec![1], vec![1], DType::U8),
            vec![tensor_name.to_string()],
            vec![],
            ParamId::new(),
        ));
    }

    let bytes = BurnpackWriter::new(snapshots).to_bytes().unwrap();

    // Decode header and metadata.
    let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();
    let metadata_end = HEADER_SIZE + header.metadata_size as usize;
    let metadata: BurnpackMetadata =
        ciborium::de::from_reader(&bytes[HEADER_SIZE..metadata_end]).unwrap();

    // Verify all tensors are present (BTreeMap stores in sorted order by key)
    let stored_names: Vec<&str> = metadata.tensors.keys().map(String::as_str).collect();
    assert_eq!(
        stored_names,
        ["a_tensor", "b_tensor", "m_tensor", "z_tensor"]
    );
}
|
||||
|
||||
/// Builds a snapshot from a closure instead of ready-made data and checks
/// that the serialized descriptor and payload match what the closure produces.
///
/// The assertions only inspect the written output; they do not observe when
/// the closure is invoked.
#[test]
fn test_writer_lazy_snapshot_evaluation() {
    // Shared source values; the closure owns its own handle to them.
    let source = Rc::new(vec![1.0f32, 2.0, 3.0, 4.0]);
    let source_for_closure = source.clone();

    let lazy_snapshot = TensorSnapshot::from_closure(
        Rc::new(move || {
            let mut raw: Vec<u8> = Vec::new();
            for value in source_for_closure.iter() {
                raw.extend_from_slice(&value.to_le_bytes());
            }
            Ok(TensorData::from_bytes_vec(raw, vec![2, 2], DType::F32))
        }),
        DType::F32,
        vec![2, 2],
        vec!["lazy".to_string()],
        vec![],
        ParamId::new(),
    );

    let writer = BurnpackWriter::new(vec![lazy_snapshot]);
    let bytes = writer.to_bytes().unwrap();

    // Decode header and metadata.
    let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();
    let metadata: BurnpackMetadata =
        ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize])
            .unwrap();

    assert_eq!(metadata.tensors.len(), 1);
    let descriptor = metadata.tensors.get("lazy").unwrap();
    assert_eq!(descriptor.dtype, DType::F32);
    assert_eq!(descriptor.shape, vec![2, 2]);

    // The only tensor occupies the first 16 bytes of the aligned data section.
    let data_section_start = aligned_data_section_start(header.metadata_size as usize);
    let written = &bytes[data_section_start..data_section_start + 16];

    let mut expected: Vec<u8> = Vec::new();
    for value in [1.0f32, 2.0, 3.0, 4.0].iter() {
        expected.extend_from_slice(&value.to_le_bytes());
    }
    assert_eq!(written, expected.as_slice());
}
|
||||
|
||||
/// `write_to_file` creates the file, and its contents equal what `to_bytes`
/// returns for the same writer.
#[cfg(feature = "std")]
#[test]
fn test_writer_write_to_file() {
    use std::fs;
    use tempfile::tempdir;

    let scratch = tempdir().unwrap();
    let target = scratch.path().join("test.bpk");

    let writer = BurnpackWriter::new(vec![TensorSnapshot::from_data(
        TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),
        vec!["test".to_string()],
        vec![],
        ParamId::new(),
    )])
    .with_metadata("file_test", "true");

    writer.write_to_file(&target).unwrap();
    assert!(target.exists());

    // On-disk bytes must equal the in-memory serialization.
    let on_disk = fs::read(&target).unwrap();
    let in_memory = writer.to_bytes().unwrap();
    assert_eq!(on_disk.as_slice(), &*in_memory);
}
|
||||
|
||||
/// `size()` reports exactly the length of the buffer produced by `to_bytes()`.
#[test]
fn test_writer_size() {
    let writer = BurnpackWriter::new(vec![TensorSnapshot::from_data(
        TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),
        vec!["test".to_string()],
        vec![],
        ParamId::new(),
    )])
    .with_metadata("test", "value");

    let reported = writer.size().unwrap();
    let serialized = writer.to_bytes().unwrap();

    assert_eq!(reported, serialized.len());
}
|
||||
|
||||
/// `write_into` fills a buffer of exactly `size()` bytes with the same
/// content that `to_bytes()` returns.
#[test]
fn test_writer_write_into() {
    let writer = BurnpackWriter::new(vec![TensorSnapshot::from_data(
        TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),
        vec!["test".to_string()],
        vec![],
        ParamId::new(),
    )])
    .with_metadata("test", "value");

    // Destination sized exactly to what the writer reports.
    let required = writer.size().unwrap();
    let mut destination = vec![0u8; required];
    writer.write_into(&mut destination).unwrap();

    let reference = writer.to_bytes().unwrap();
    assert_eq!(destination.as_slice(), &*reference);
}
|
||||
|
||||
#[test]
fn test_writer_write_into_buffer_too_small() {
    let tensor = TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8);
    let snapshot =
        TensorSnapshot::from_data(tensor, vec!["test".to_string()], vec![], ParamId::new());
    let writer = BurnpackWriter::new(vec![snapshot]);

    // Ten bytes is deliberately too small a destination for this writer's output.
    let mut undersized = vec![0u8; 10];
    let outcome = writer.write_into(&mut undersized);

    // The writer must refuse and say why.
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("Buffer too small"));
}
|
||||
|
||||
#[test]
fn test_writer_write_into_buffer_larger_than_needed() {
    let tensor = TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8);
    let snapshot =
        TensorSnapshot::from_data(tensor, vec!["test".to_string()], vec![], ParamId::new());
    let writer = BurnpackWriter::new(vec![snapshot]);

    // Over-allocate by 100 bytes; the writer should still succeed.
    let required = writer.size().unwrap();
    let mut oversized = vec![0u8; required + 100];
    writer.write_into(&mut oversized).unwrap();

    // Only the leading `required` bytes are compared against `to_bytes()`.
    let expected = writer.to_bytes().unwrap();
    let written = &oversized[..required];
    assert_eq!(written, &*expected);
}
|
||||
|
||||
#[test]
fn test_writer_write_into_multiple_tensors() {
    // Two U8 tensors of different shapes (2x2 and 2x3) under distinct paths.
    let first = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),
        vec!["tensor1".to_string()],
        vec![],
        ParamId::new(),
    );
    let second = TensorSnapshot::from_data(
        TensorData::from_bytes_vec(vec![5, 6, 7, 8, 9, 10], vec![2, 3], DType::U8),
        vec!["tensor2".to_string()],
        vec![],
        ParamId::new(),
    );
    let writer = BurnpackWriter::new(vec![first, second]).with_metadata("test", "multiple");

    let required = writer.size().unwrap();
    let mut out = vec![0u8; required];
    writer.write_into(&mut out).unwrap();

    // Both serialization paths must agree byte for byte.
    let expected = writer.to_bytes().unwrap();
    assert_eq!(out.as_slice(), &*expected);
}
|
||||
|
||||
#[test]
fn test_writer_write_into_empty() {
    // A writer with no tensors still has to serialize consistently.
    let writer = BurnpackWriter::new(vec![]);

    let required = writer.size().unwrap();
    let mut out = vec![0u8; required];
    writer.write_into(&mut out).unwrap();

    let expected = writer.to_bytes().unwrap();
    assert_eq!(out.as_slice(), &*expected);
}
|
||||
@@ -0,0 +1,211 @@
|
||||
//! Tests for zero-copy tensor loading functionality.
|
||||
|
||||
use crate::ModuleStore;
|
||||
use crate::burnpack::store::BurnpackStore;
|
||||
|
||||
use burn_core as burn;
|
||||
use burn_core::module::{Module, Param};
|
||||
use burn_tensor::{AllocationProperty, Bytes, Tensor, backend::Backend};
|
||||
|
||||
type TestBackend = burn_ndarray::NdArray;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
struct SimpleModule<B: Backend> {
|
||||
weight: Param<Tensor<B, 2>>,
|
||||
bias: Param<Tensor<B, 1>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> SimpleModule<B> {
|
||||
fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
weight: Param::from_data([[1.0f32, 2.0], [3.0, 4.0]], device),
|
||||
bias: Param::from_data([0.5f32, 1.5], device),
|
||||
}
|
||||
}
|
||||
|
||||
fn new_zeros(device: &B::Device) -> Self {
|
||||
Self {
|
||||
weight: Param::from_tensor(Tensor::zeros([2, 2], device)),
|
||||
bias: Param::from_tensor(Tensor::zeros([2], device)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Round-trip a module through a store created with `BurnpackStore::from_static`
/// and check that both parameters come back with their original values.
#[test]
fn test_from_static_enables_zero_copy() {
    let device = Default::default();
    let source = SimpleModule::<TestBackend>::new(&device);

    // Serialize the source module into an in-memory burnpack.
    let mut writer_store = BurnpackStore::from_bytes(None);
    writer_store.collect_from(&source).unwrap();
    let serialized = writer_store.get_bytes().unwrap();

    // Leak a boxed copy to obtain the `&'static [u8]` that `from_static` takes.
    let leaked: &'static [u8] = Box::leak(serialized.to_vec().into_boxed_slice());

    // Load from the static bytes into a zero-initialized module.
    let mut reader_store = BurnpackStore::from_static(leaked);
    let mut target = SimpleModule::<TestBackend>::new_zeros(&device);
    reader_store.apply_to(&mut target).unwrap();

    let weight = target.weight.val().to_data().to_vec::<f32>().unwrap();
    let bias = target.bias.val().to_data().to_vec::<f32>().unwrap();
    assert_eq!(weight, vec![1.0, 2.0, 3.0, 4.0]);
    assert_eq!(bias, vec![0.5, 1.5]);
}
|
||||
|
||||
/// Load from shared bytes with `.zero_copy(true)` set through the builder and
/// check that the weight tensor is restored correctly.
#[test]
fn test_zero_copy_builder_method() {
    let device = Default::default();
    let source = SimpleModule::<TestBackend>::new(&device);

    // Serialize the source module into an in-memory burnpack.
    let mut writer_store = BurnpackStore::from_bytes(None);
    writer_store.collect_from(&source).unwrap();
    let serialized = writer_store.get_bytes().unwrap();

    // Re-wrap the payload as shared (`bytes::Bytes`-backed) storage.
    let shared_payload = Bytes::from_shared(
        bytes::Bytes::from(serialized.to_vec()),
        AllocationProperty::Other,
    );

    let mut reader_store = BurnpackStore::from_bytes(Some(shared_payload)).zero_copy(true);
    let mut target = SimpleModule::<TestBackend>::new_zeros(&device);
    reader_store.apply_to(&mut target).unwrap();

    let weight = target.weight.val().to_data().to_vec::<f32>().unwrap();
    assert_eq!(weight, vec![1.0, 2.0, 3.0, 4.0]);
}
|
||||
|
||||
/// Load from static bytes with `.zero_copy(false)` and check that the weight
/// tensor is still restored correctly.
#[test]
fn test_zero_copy_disabled_uses_copy() {
    let device = Default::default();
    let source = SimpleModule::<TestBackend>::new(&device);

    // Serialize the source module into an in-memory burnpack.
    let mut writer_store = BurnpackStore::from_bytes(None);
    writer_store.collect_from(&source).unwrap();
    let serialized = writer_store.get_bytes().unwrap();

    // Leak a boxed copy to obtain the `&'static [u8]` that `from_static` takes.
    let leaked: &'static [u8] = Box::leak(serialized.to_vec().into_boxed_slice());

    // Same static source as the zero-copy test, but with zero-copy switched off.
    let mut reader_store = BurnpackStore::from_static(leaked).zero_copy(false);
    let mut target = SimpleModule::<TestBackend>::new_zeros(&device);
    reader_store.apply_to(&mut target).unwrap();

    let weight = target.weight.val().to_data().to_vec::<f32>().unwrap();
    assert_eq!(weight, vec![1.0, 2.0, 3.0, 4.0]);
}
|
||||
|
||||
/// Load from the bytes returned by `get_bytes()` without touching the
/// `zero_copy` setting and check that the weight tensor is restored correctly.
#[test]
fn test_from_bytes_uses_copy_by_default() {
    let device = Default::default();
    let source = SimpleModule::<TestBackend>::new(&device);

    // Serialize the source module into an in-memory burnpack.
    let mut writer_store = BurnpackStore::from_bytes(None);
    writer_store.collect_from(&source).unwrap();
    let serialized = writer_store.get_bytes().unwrap();

    // Feed the very same bytes back in, leaving the store's defaults alone.
    let mut reader_store = BurnpackStore::from_bytes(Some(serialized));
    let mut target = SimpleModule::<TestBackend>::new_zeros(&device);
    reader_store.apply_to(&mut target).unwrap();

    let weight = target.weight.val().to_data().to_vec::<f32>().unwrap();
    assert_eq!(weight, vec![1.0, 2.0, 3.0, 4.0]);
}
|
||||
|
||||
/// Read snapshots through `BurnpackReader::get_snapshots_zero_copy(true)` from
/// shared bytes and check that every snapshot yields non-empty tensor data.
#[test]
fn test_storage_backend_slice_bytes() {
    use crate::burnpack::reader::BurnpackReader;

    let device = Default::default();
    let source = SimpleModule::<TestBackend>::new(&device);

    // Serialize the source module into an in-memory burnpack.
    let mut writer_store = BurnpackStore::from_bytes(None);
    writer_store.collect_from(&source).unwrap();
    let serialized = writer_store.get_bytes().unwrap();

    // Re-wrap the payload as shared (`bytes::Bytes`-backed) storage.
    let shared_payload = Bytes::from_shared(
        bytes::Bytes::from(serialized.to_vec()),
        AllocationProperty::Other,
    );

    let reader = BurnpackReader::from_bytes(shared_payload).unwrap();
    let snapshots = reader.get_snapshots_zero_copy(true).unwrap();

    // `SimpleModule` has exactly two parameters: weight and bias.
    assert_eq!(snapshots.len(), 2);

    // Content depends on tensor order, so only check that data is reachable.
    snapshots
        .iter()
        .map(|snapshot| snapshot.to_data().unwrap())
        .for_each(|data| assert!(!data.bytes.is_empty()));
}
|
||||
|
||||
/// Save a module to a temporary file, load it back with `.zero_copy(true)` and
/// check that the weight tensor is restored correctly.
#[test]
fn test_zero_copy_file_based_works() {
    use tempfile::NamedTempFile;

    let device = Default::default();
    let source = SimpleModule::<TestBackend>::new(&device);

    // The handle is kept alive until the end of the test.
    let scratch = NamedTempFile::new().unwrap();
    let location = scratch.path();

    // The temp file already exists, hence `overwrite(true)` on the saving store.
    let mut writer_store = BurnpackStore::from_file(location).overwrite(true);
    writer_store.collect_from(&source).unwrap();

    // Loading from a file with zero-copy requested must succeed.
    let mut reader_store = BurnpackStore::from_file(location).zero_copy(true);
    let mut target = SimpleModule::<TestBackend>::new_zeros(&device);
    reader_store.apply_to(&mut target).unwrap();

    let weight = target.weight.val().to_data().to_vec::<f32>().unwrap();
    assert_eq!(weight, vec![1.0, 2.0, 3.0, 4.0]);
}
|
||||
@@ -0,0 +1,331 @@
|
||||
use super::base::{
|
||||
BurnpackError, BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER,
|
||||
TENSOR_ALIGNMENT, TensorDescriptor, aligned_data_section_start,
|
||||
};
|
||||
use crate::TensorSnapshot;
|
||||
use alloc::collections::BTreeMap;
|
||||
use alloc::format;
|
||||
use alloc::string::{String, ToString};
|
||||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::Bytes;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use std::fs::File;
|
||||
#[cfg(feature = "std")]
|
||||
use std::io::Write;
|
||||
#[cfg(feature = "std")]
|
||||
use std::path::Path;
|
||||
|
||||
/// Round `offset` up to the next multiple of `alignment`.
///
/// Returns `offset` unchanged when it is already aligned; otherwise returns the
/// smallest multiple of `alignment` that is greater than `offset`.
///
/// # Panics
///
/// Panics if `alignment` is zero.
#[inline]
const fn align_offset(offset: u64, alignment: u64) -> u64 {
    offset.next_multiple_of(alignment)
}
|
||||
|
||||
/// Writer for creating Burnpack files
|
||||
pub struct BurnpackWriter {
|
||||
/// Tensors to write
|
||||
pub(crate) snapshots: Vec<TensorSnapshot>,
|
||||
/// Metadata key-value pairs
|
||||
pub(crate) metadata: BTreeMap<String, String>,
|
||||
}
|
||||
|
||||
impl BurnpackWriter {
|
||||
/// Create a new writer
|
||||
pub fn new(snapshots: Vec<TensorSnapshot>) -> Self {
|
||||
Self {
|
||||
snapshots,
|
||||
metadata: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder pattern: add metadata and return self
|
||||
pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
|
||||
self.metadata.insert(key.to_string(), value.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Build tensor descriptors and metadata
|
||||
fn build_metadata(&self) -> Result<(BurnpackMetadata, Vec<u8>), BurnpackError> {
|
||||
// Build tensor descriptors and calculate offsets with alignment
|
||||
let mut tensors = BTreeMap::new();
|
||||
let mut current_offset = 0u64;
|
||||
|
||||
for snapshot in &self.snapshots {
|
||||
let data_len = snapshot.data_len() as u64;
|
||||
|
||||
// Align the start offset for mmap zero-copy support
|
||||
let aligned_start = align_offset(current_offset, TENSOR_ALIGNMENT);
|
||||
let end = aligned_start.checked_add(data_len).ok_or_else(|| {
|
||||
BurnpackError::IoError(format!(
|
||||
"Tensor offset overflow: {} + {} exceeds maximum",
|
||||
aligned_start, data_len
|
||||
))
|
||||
})?;
|
||||
|
||||
tensors.insert(
|
||||
snapshot.full_path(),
|
||||
TensorDescriptor {
|
||||
dtype: snapshot.dtype,
|
||||
shape: snapshot.shape.iter().map(|&s| s as u64).collect(),
|
||||
data_offsets: (aligned_start, end),
|
||||
param_id: snapshot.tensor_id.map(|id| id.val()),
|
||||
},
|
||||
);
|
||||
|
||||
current_offset = end;
|
||||
}
|
||||
|
||||
// Create metadata structure
|
||||
let metadata = BurnpackMetadata {
|
||||
tensors,
|
||||
metadata: self.metadata.clone(),
|
||||
};
|
||||
|
||||
// Serialize metadata with CBOR
|
||||
let mut metadata_bytes = Vec::new();
|
||||
ciborium::ser::into_writer(&metadata, &mut metadata_bytes)
|
||||
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
|
||||
|
||||
Ok((metadata, metadata_bytes))
|
||||
}
|
||||
|
||||
/// Calculate the total size needed for the burnpack data
|
||||
///
|
||||
/// This is useful when you want to pre-allocate a buffer for `write_into()`.
|
||||
/// The size includes padding bytes for both metadata alignment and tensor alignment.
|
||||
pub fn size(&self) -> Result<usize, BurnpackError> {
|
||||
let (metadata, metadata_bytes) = self.build_metadata()?;
|
||||
|
||||
// Data section starts at aligned position after header + metadata
|
||||
let data_section_start = aligned_data_section_start(metadata_bytes.len());
|
||||
|
||||
// Calculate total data section size from aligned offsets
|
||||
// The last tensor's end offset gives us the total data section size
|
||||
let data_size = metadata
|
||||
.tensors
|
||||
.values()
|
||||
.map(|t| t.data_offsets.1)
|
||||
.max()
|
||||
.unwrap_or(0) as usize;
|
||||
|
||||
Ok(data_section_start + data_size)
|
||||
}
|
||||
|
||||
/// Write burnpack data into a caller-provided buffer
|
||||
///
|
||||
/// The buffer must be large enough to hold all data. Use `size()` to determine
|
||||
/// the required buffer size. If the buffer is too small, this will return an error.
|
||||
///
|
||||
/// This allows the caller to control buffer allocation, enabling optimizations like:
|
||||
/// - Buffer reuse across multiple writes
|
||||
/// - Custom allocators
|
||||
/// - Pinned memory for GPU transfers
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buffer` - Mutable slice to write data into. Must be at least `size()` bytes.
|
||||
pub fn write_into(&self, buffer: &mut [u8]) -> Result<(), BurnpackError> {
|
||||
let (metadata, metadata_bytes) = self.build_metadata()?;
|
||||
|
||||
// Check metadata size fits in u32
|
||||
let metadata_size: u32 = metadata_bytes.len().try_into().map_err(|_| {
|
||||
BurnpackError::IoError(format!(
|
||||
"Metadata size {} exceeds maximum of {} bytes",
|
||||
metadata_bytes.len(),
|
||||
u32::MAX
|
||||
))
|
||||
})?;
|
||||
|
||||
// Create header
|
||||
let header = BurnpackHeader {
|
||||
magic: MAGIC_NUMBER,
|
||||
version: FORMAT_VERSION,
|
||||
metadata_size,
|
||||
};
|
||||
|
||||
// Data section starts at aligned position after header + metadata
|
||||
let data_section_start = aligned_data_section_start(metadata_bytes.len());
|
||||
|
||||
// Calculate required size from aligned offsets
|
||||
let data_size = metadata
|
||||
.tensors
|
||||
.values()
|
||||
.map(|t| t.data_offsets.1)
|
||||
.max()
|
||||
.unwrap_or(0) as usize;
|
||||
let total_size = data_section_start + data_size;
|
||||
|
||||
// Check buffer size
|
||||
if buffer.len() < total_size {
|
||||
return Err(BurnpackError::IoError(format!(
|
||||
"Buffer too small: need {} bytes, got {} bytes",
|
||||
total_size,
|
||||
buffer.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut offset = 0;
|
||||
|
||||
// Write header
|
||||
let header_bytes = header.into_bytes();
|
||||
buffer[offset..offset + HEADER_SIZE].copy_from_slice(&header_bytes);
|
||||
offset += HEADER_SIZE;
|
||||
|
||||
// Write metadata
|
||||
buffer[offset..offset + metadata_bytes.len()].copy_from_slice(&metadata_bytes);
|
||||
offset += metadata_bytes.len();
|
||||
|
||||
// Write padding to align data section start
|
||||
if data_section_start > offset {
|
||||
buffer[offset..data_section_start].fill(0);
|
||||
offset = data_section_start;
|
||||
}
|
||||
|
||||
// Write tensor data with alignment padding
|
||||
for snapshot in &self.snapshots {
|
||||
// Get the aligned offset from metadata
|
||||
let descriptor = metadata.tensors.get(&snapshot.full_path()).ok_or_else(|| {
|
||||
BurnpackError::IoError(format!(
|
||||
"Internal error: tensor '{}' not found in metadata",
|
||||
snapshot.full_path()
|
||||
))
|
||||
})?;
|
||||
let aligned_offset = descriptor.data_offsets.0 as usize;
|
||||
let target_offset = data_section_start + aligned_offset;
|
||||
|
||||
// Write padding zeros if needed
|
||||
if target_offset > offset {
|
||||
buffer[offset..target_offset].fill(0);
|
||||
offset = target_offset;
|
||||
}
|
||||
|
||||
let expected_len = snapshot.data_len();
|
||||
let data = snapshot.to_data().map_err(|e| {
|
||||
BurnpackError::IoError(format!("Failed to get tensor data: {:?}", e))
|
||||
})?;
|
||||
let actual_len = data.bytes.len();
|
||||
|
||||
// Validate data length consistency
|
||||
if actual_len != expected_len {
|
||||
return Err(BurnpackError::IoError(format!(
|
||||
"Data corruption: tensor '{}' has inconsistent length (expected {}, got {})",
|
||||
snapshot.full_path(),
|
||||
expected_len,
|
||||
actual_len
|
||||
)));
|
||||
}
|
||||
|
||||
buffer[offset..offset + actual_len].copy_from_slice(&data.bytes);
|
||||
offset += actual_len;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Write to a byte buffer (convenience method)
|
||||
///
|
||||
/// This allocates a buffer internally and writes the burnpack data.
|
||||
/// For more control over buffer allocation, use `size()` + `write_into()`.
|
||||
pub fn to_bytes(&self) -> Result<Bytes, BurnpackError> {
|
||||
let size = self.size()?;
|
||||
let mut buffer = vec![0u8; size];
|
||||
self.write_into(&mut buffer)?;
|
||||
Ok(Bytes::from_bytes_vec(buffer))
|
||||
}
|
||||
|
||||
/// Write directly to a file (more memory efficient for large models)
|
||||
#[cfg(feature = "std")]
|
||||
pub fn write_to_file<P: AsRef<Path>>(&self, path: P) -> Result<(), BurnpackError> {
|
||||
let mut file = File::create(path).map_err(|e| BurnpackError::IoError(e.to_string()))?;
|
||||
|
||||
let (metadata, metadata_bytes) = self.build_metadata()?;
|
||||
|
||||
// Check metadata size fits in u32
|
||||
let metadata_size: u32 = metadata_bytes.len().try_into().map_err(|_| {
|
||||
BurnpackError::IoError(format!(
|
||||
"Metadata size {} exceeds maximum of {} bytes",
|
||||
metadata_bytes.len(),
|
||||
u32::MAX
|
||||
))
|
||||
})?;
|
||||
|
||||
// Create and write header
|
||||
let header = BurnpackHeader {
|
||||
magic: MAGIC_NUMBER,
|
||||
version: FORMAT_VERSION,
|
||||
metadata_size,
|
||||
};
|
||||
|
||||
file.write_all(&header.into_bytes())
|
||||
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
|
||||
|
||||
// Write metadata
|
||||
file.write_all(&metadata_bytes)
|
||||
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
|
||||
|
||||
// Data section starts at aligned position after header + metadata
|
||||
let data_section_start = aligned_data_section_start(metadata_bytes.len());
|
||||
let current_file_pos = HEADER_SIZE + metadata_bytes.len();
|
||||
|
||||
// Write padding to align data section start
|
||||
if data_section_start > current_file_pos {
|
||||
let padding_size = data_section_start - current_file_pos;
|
||||
let padding = vec![0u8; padding_size];
|
||||
file.write_all(&padding)
|
||||
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
|
||||
}
|
||||
|
||||
// Track current position within data section (relative to data_section_start)
|
||||
let mut data_offset = 0usize;
|
||||
|
||||
// Stream tensor data directly to file with alignment padding
|
||||
for snapshot in &self.snapshots {
|
||||
// Get the aligned offset from metadata
|
||||
let descriptor = metadata.tensors.get(&snapshot.full_path()).ok_or_else(|| {
|
||||
BurnpackError::IoError(format!(
|
||||
"Internal error: tensor '{}' not found in metadata",
|
||||
snapshot.full_path()
|
||||
))
|
||||
})?;
|
||||
let aligned_offset = descriptor.data_offsets.0 as usize;
|
||||
|
||||
// Write padding zeros if needed
|
||||
if aligned_offset > data_offset {
|
||||
let padding_size = aligned_offset - data_offset;
|
||||
let padding = vec![0u8; padding_size];
|
||||
file.write_all(&padding)
|
||||
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
|
||||
data_offset = aligned_offset;
|
||||
}
|
||||
|
||||
let expected_len = snapshot.data_len();
|
||||
let data = snapshot.to_data().map_err(|e| {
|
||||
BurnpackError::IoError(format!("Failed to get tensor data: {:?}", e))
|
||||
})?;
|
||||
let actual_len = data.bytes.len();
|
||||
|
||||
// Validate data length consistency
|
||||
if actual_len != expected_len {
|
||||
return Err(BurnpackError::IoError(format!(
|
||||
"Data corruption: tensor '{}' has inconsistent length (expected {}, got {})",
|
||||
snapshot.full_path(),
|
||||
expected_len,
|
||||
actual_len
|
||||
)));
|
||||
}
|
||||
|
||||
file.write_all(&data.bytes)
|
||||
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
|
||||
data_offset += actual_len;
|
||||
}
|
||||
|
||||
file.flush()
|
||||
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
1137
crates/stable-diffusion-burn/burn-crates/burn-store/src/collector.rs
Normal file
1137
crates/stable-diffusion-burn/burn-crates/burn-store/src/collector.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,625 @@
|
||||
use alloc::format;
|
||||
use alloc::string::String;
|
||||
use alloc::vec::Vec;
|
||||
use core::fmt;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use regex::Regex;
|
||||
|
||||
/// A sophisticated path filter that supports multiple matching strategies.
///
/// Criteria are combined with OR logic: a path is included as soon as it
/// satisfies ANY configured criterion (exact path, regex pattern or predicate).
/// A filter with no criteria matches nothing; a filter with `match_all` set
/// matches everything.
///
/// # Examples
///
/// ```rust,no_run
/// # use burn_store::PathFilter;
/// // Create a filter that matches encoder paths or any weight path
/// let filter = PathFilter::new()
///     .with_regex(r"^encoder\..*")
///     .with_regex(r".*\.weight$")
///     .with_full_path("special_tensor");
///
/// // Check if a path should be included
/// if filter.matches("encoder.layer1.weight") {
///     // This will match due to both regex patterns
/// }
/// ```
#[derive(Debug, Clone, Default)]
pub struct PathFilter {
    /// Compiled regex patterns, tested against the tensor path
    #[cfg(feature = "std")]
    regex_patterns: Vec<Regex>,

    /// Exact full paths to match
    exact_paths: Vec<String>,

    /// Custom predicates called as `predicate(path, container_path)`.
    /// They are plain function pointers, so they are copied along with the
    /// filter when it is cloned.
    predicates: Vec<fn(&str, &str) -> bool>,

    /// If true, matches all paths (overrides other filters)
    match_all: bool,
}

impl PathFilter {
    /// Create a new empty filter (matches nothing by default)
    pub fn new() -> Self {
        Default::default()
    }

    /// Create a filter that matches all paths
    pub fn all() -> Self {
        let mut filter = Self::default();
        filter.match_all = true;
        filter
    }

    /// Create a filter that matches nothing
    pub fn none() -> Self {
        Default::default()
    }

    /// Add a regex pattern for matching paths
    ///
    /// A pattern that fails to compile is silently ignored.
    #[cfg(feature = "std")]
    pub fn with_regex<S: AsRef<str>>(mut self, pattern: S) -> Self {
        // TODO: Consider returning Result to handle regex compilation errors
        self.regex_patterns
            .extend(Regex::new(pattern.as_ref()).ok());
        self
    }

    /// Add multiple regex patterns
    ///
    /// Patterns that fail to compile are silently ignored.
    #[cfg(feature = "std")]
    pub fn with_regexes<I, S>(mut self, patterns: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let compiled = patterns
            .into_iter()
            .filter_map(|pattern| Regex::new(pattern.as_ref()).ok());
        self.regex_patterns.extend(compiled);
        self
    }

    /// Add an exact full path to match
    pub fn with_full_path<S: Into<String>>(mut self, path: S) -> Self {
        let owned: String = path.into();
        self.exact_paths.push(owned);
        self
    }

    /// Add multiple exact full paths
    pub fn with_full_paths<I, S>(mut self, paths: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        for path in paths {
            self.exact_paths.push(path.into());
        }
        self
    }

    /// Add a predicate function for custom matching based on path and container path
    pub fn with_predicate(mut self, predicate: fn(&str, &str) -> bool) -> Self {
        self.predicates.push(predicate);
        self
    }

    /// Add multiple predicates
    pub fn with_predicates<I>(mut self, predicates: I) -> Self
    where
        I: IntoIterator<Item = fn(&str, &str) -> bool>,
    {
        for predicate in predicates {
            self.predicates.push(predicate);
        }
        self
    }

    /// Set to match all paths
    pub fn match_all(self) -> Self {
        Self {
            match_all: true,
            ..self
        }
    }

    /// Check if a path matches this filter (assumes empty container path for backward compatibility)
    pub fn matches(&self, path: &str) -> bool {
        self.matches_with_container_path_str(path, "")
    }

    /// Check if a path and container type match this filter (for backward compatibility)
    ///
    /// The single container type is passed on as the whole container path.
    pub fn matches_with_container(&self, path: &str, container_type: &str) -> bool {
        self.matches_with_container_path_str(path, container_type)
    }

    /// Check if a path and container path match this filter
    ///
    /// Both slices are joined with `.` before matching.
    pub fn matches_with_container_path(&self, path: &[String], container_stack: &[String]) -> bool {
        let dotted_path = path.join(".");
        let dotted_containers = container_stack.join(".");
        self.matches_with_container_path_str(&dotted_path, &dotted_containers)
    }

    /// Check if a path and container path (dot-notated strings) match this filter
    ///
    /// Evaluation order: `match_all`, exact paths, regex patterns (on the path
    /// only), then predicates (on path and container path). With no criteria
    /// configured every check fails and the result is `false`.
    pub fn matches_with_container_path_str(&self, path: &str, container_path: &str) -> bool {
        if self.match_all {
            return true;
        }

        if self.exact_paths.iter().any(|known| known.as_str() == path) {
            return true;
        }

        #[cfg(feature = "std")]
        if self.regex_patterns.iter().any(|re| re.is_match(path)) {
            return true;
        }

        self.predicates
            .iter()
            .any(|accepts| accepts(path, container_path))
    }

    /// Check if the filter is empty (matches nothing)
    pub fn is_empty(&self) -> bool {
        #[cfg(feature = "std")]
        if !self.regex_patterns.is_empty() {
            return false;
        }

        !self.match_all && self.exact_paths.is_empty() && self.predicates.is_empty()
    }

    /// Get the number of filter criteria configured
    ///
    /// A match-all filter counts as a single criterion regardless of what else
    /// is configured.
    pub fn criteria_count(&self) -> usize {
        if self.match_all {
            return 1;
        }

        let total = self.exact_paths.len() + self.predicates.len();

        #[cfg(feature = "std")]
        let total = total + self.regex_patterns.len();

        total
    }

    /// Clear all regex patterns
    #[cfg(feature = "std")]
    pub fn clear_regex(&mut self) -> &mut Self {
        self.regex_patterns.clear();
        self
    }

    /// Clear all exact paths
    pub fn clear_paths(&mut self) -> &mut Self {
        self.exact_paths.clear();
        self
    }

    /// Clear all predicates
    pub fn clear_predicates(&mut self) -> &mut Self {
        self.predicates.clear();
        self
    }

    /// Clear all filters, including the match-all flag
    pub fn clear(&mut self) -> &mut Self {
        #[cfg(feature = "std")]
        self.regex_patterns.clear();

        self.exact_paths.clear();
        self.predicates.clear();
        self.match_all = false;
        self
    }

    /// Create a filter from regex patterns only
    #[cfg(feature = "std")]
    pub fn from_regex_patterns<I, S>(patterns: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        Self::default().with_regexes(patterns)
    }

    /// Create a filter from exact paths only
    pub fn from_paths<I, S>(paths: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self::default().with_full_paths(paths)
    }

    /// Create a filter from a single predicate
    pub fn from_predicate(predicate: fn(&str, &str) -> bool) -> Self {
        Self::default().with_predicate(predicate)
    }

    /// Combine with another filter using OR logic
    ///
    /// If either side matches everything the result is a plain match-all
    /// filter; otherwise the criteria of `other` are appended to `self`.
    pub fn or(mut self, other: Self) -> Self {
        if self.match_all || other.match_all {
            return Self::all();
        }

        self.predicates.extend(other.predicates);
        self.exact_paths.extend(other.exact_paths);

        #[cfg(feature = "std")]
        self.regex_patterns.extend(other.regex_patterns);

        self
    }
}
|
||||
|
||||
impl fmt::Display for PathFilter {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
if self.match_all {
|
||||
return write!(f, "PathFilter::all()");
|
||||
}
|
||||
|
||||
if self.is_empty() {
|
||||
return write!(f, "PathFilter::none()");
|
||||
}
|
||||
|
||||
write!(f, "PathFilter[")?;
|
||||
|
||||
let mut parts = Vec::new();
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
if !self.regex_patterns.is_empty() {
|
||||
parts.push(format!("regex: {:?}", self.regex_patterns));
|
||||
}
|
||||
|
||||
if !self.exact_paths.is_empty() {
|
||||
parts.push(format!("paths: {:?}", self.exact_paths));
|
||||
}
|
||||
|
||||
if !self.predicates.is_empty() {
|
||||
parts.push(format!("predicates: {}", self.predicates.len()));
|
||||
}
|
||||
|
||||
write!(f, "{}]", parts.join(", "))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A filter built with `new()` has no criteria and rejects the sampled paths.
    #[test]
    fn empty_filter() {
        let nothing = PathFilter::new();

        assert!(nothing.is_empty());
        for path in ["encoder.weight", "decoder.bias"] {
            assert!(!nothing.matches(path));
        }
    }

    /// `PathFilter::all()` is not empty and accepts arbitrary paths.
    #[test]
    fn match_all() {
        let everything = PathFilter::all();

        assert!(!everything.is_empty());
        for path in ["encoder.weight", "decoder.bias", "anything"] {
            assert!(everything.matches(path));
        }
    }

    /// Exact paths match only the literal strings that were registered.
    #[test]
    fn exact_paths() {
        let listed = PathFilter::new()
            .with_full_path("encoder.weight")
            .with_full_path("decoder.bias");

        for hit in ["encoder.weight", "decoder.bias"] {
            assert!(listed.matches(hit));
        }
        for miss in ["encoder.bias", "decoder.weight"] {
            assert!(!listed.matches(miss));
        }
    }

    /// A path matches when at least one of the regex patterns matches.
    #[test]
    #[cfg(feature = "std")]
    fn regex_patterns() {
        let encoder_or_weight = PathFilter::new()
            .with_regex(r"^encoder\..*")
            .with_regex(r".*\.weight$");

        for hit in ["encoder.layer1.bias", "decoder.weight", "encoder.weight"] {
            assert!(encoder_or_weight.matches(hit));
        }
        assert!(!encoder_or_weight.matches("decoder.bias"));
    }

    /// A path matches when at least one predicate returns `true`.
    #[test]
    fn predicates() {
        fn mentions_norm(path: &str, _container_path: &str) -> bool {
            path.contains("norm")
        }

        fn shorter_than_ten(path: &str, _container_path: &str) -> bool {
            path.len() < 10
        }

        let norm_or_short = PathFilter::new()
            .with_predicate(mentions_norm)
            .with_predicate(shorter_than_ten);

        for hit in ["norm.weight", "layer.norm.bias", "bias"] {
            assert!(norm_or_short.matches(hit));
        }
        assert!(!norm_or_short.matches("encoder.decoder.weight.long.name"));
    }

    /// Exact paths, predicates and (with `std`) regexes combine with OR logic.
    #[test]
    fn combined_filters() {
        let mixed = PathFilter::new()
            .with_full_path("special.tensor")
            .with_predicate(|path, _container_path| path.contains("attention"));

        #[cfg(feature = "std")]
        let mixed = mixed.with_regex(r"^encoder\..*");

        assert!(mixed.matches("special.tensor"));
        assert!(mixed.matches("self_attention.query"));

        #[cfg(feature = "std")]
        assert!(mixed.matches("encoder.anything"));

        assert!(!mixed.matches("decoder.weight"));
    }

    /// `or` yields a filter that accepts what either operand accepts.
    #[test]
    fn or_combination() {
        let left = PathFilter::new().with_full_path("encoder.weight");
        let right = PathFilter::new().with_full_path("decoder.bias");

        let union = left.or(right);

        assert!(union.matches("encoder.weight"));
        assert!(union.matches("decoder.bias"));
        assert!(!union.matches("model.head.weight"));
    }

    /// Typical regex recipes: by prefix, by suffix, and by layer index.
    #[test]
    #[cfg(feature = "std")]
    fn common_patterns() {
        // Everything under `encoder.`
        let encoder_only = PathFilter::new().with_regex(r"^encoder\..*");
        assert!(encoder_only.matches("encoder.weight"));
        assert!(!encoder_only.matches("decoder.weight"));

        // Only tensors whose path ends in `.weight`
        let weights_only = PathFilter::new().with_regex(r".*\.weight$");
        assert!(weights_only.matches("encoder.weight"));
        assert!(weights_only.matches("decoder.weight"));
        assert!(!weights_only.matches("encoder.bias"));

        // Selected layer indices, at the root or nested under a prefix
        let even_layers = PathFilter::new()
            .with_regex(r"(^|.*\.)layers\.0\.")
            .with_regex(r"(^|.*\.)layers\.2\.")
            .with_regex(r"(^|.*\.)layers\.4\.");
        assert!(even_layers.matches("model.layers.0.weight"));
        assert!(even_layers.matches("layers.2.bias"));
        assert!(!even_layers.matches("layers.1.weight"));
    }

    /// `criteria_count` adds up exact paths, predicates and (with `std`) regexes.
    #[test]
    fn criteria_count() {
        let counted = PathFilter::new()
            .with_full_path("path1")
            .with_full_path("path2")
            .with_predicate(|_, _| true);

        #[cfg(feature = "std")]
        let counted = counted.with_regex(".*");

        // Two paths + one predicate, plus one regex when `std` is enabled.
        let expected = if cfg!(feature = "std") { 4 } else { 3 };
        assert_eq!(counted.criteria_count(), expected);
    }

    /// Clearing paths stops them from matching; `clear` leaves an empty filter.
    #[test]
    fn clear_operations() {
        let mut cleared = PathFilter::new().with_full_path("test");

        cleared.clear_paths();
        assert!(!cleared.matches("test"));

        cleared.clear();
        assert!(cleared.is_empty());
    }

    /// Predicates can inspect the container path passed to `matches_with_container`.
    #[test]
    fn container_predicates() {
        /// Last dot-separated segment of a container path.
        fn innermost(container_path: &str) -> Option<&str> {
            container_path.rsplit('.').next()
        }

        // Weights that live directly inside a Linear module
        let linear_weights = PathFilter::new().with_predicate(|path, container_path| {
            innermost(container_path) == Some("Linear") && path.ends_with(".weight")
        });

        assert!(linear_weights.matches_with_container("layer1.weight", "Linear"));
        assert!(!linear_weights.matches_with_container("layer1.weight", "Conv2d"));
        assert!(!linear_weights.matches_with_container("layer1.bias", "Linear"));

        // Anything that lives inside a convolution-type container
        let conv_only = PathFilter::new().with_predicate(|_path, container_path| {
            matches!(
                innermost(container_path),
                Some("Conv2d" | "ConvTranspose2d")
            )
        });

        assert!(conv_only.matches_with_container("encoder.weight", "Conv2d"));
        assert!(conv_only.matches_with_container("decoder.weight", "ConvTranspose2d"));
        assert!(!conv_only.matches_with_container("fc.weight", "Linear"));

        // One predicate on the path, one on the container
        let either = PathFilter::new()
            .with_predicate(|path, _container_path| path.starts_with("encoder."))
            .with_predicate(|_path, container_path| {
                innermost(container_path) == Some("BatchNorm2d")
            });

        // Either predicate is enough (OR logic)
        assert!(either.matches_with_container("encoder.layer1", "Linear"));
        assert!(either.matches_with_container("decoder.bn", "BatchNorm2d"));
        assert!(!either.matches_with_container("decoder.layer", "Linear"));
    }

    /// Regex patterns and container-aware predicates can be mixed in one filter.
    #[test]
    fn container_predicate_with_regex() {
        // Regex support needs `std`; without it this test has nothing to check.
        #[cfg(feature = "std")]
        {
            let mixed = PathFilter::new()
                .with_regex(r"^encoder\..*")
                .with_predicate(|path, container_path| {
                    container_path.rsplit('.').next() == Some("Linear") && path.contains(".bias")
                });

            // Accepted by the regex
            assert!(mixed.matches_with_container("encoder.layer1.weight", "Conv2d"));
            // Accepted by the container predicate
            assert!(mixed.matches_with_container("decoder.fc.bias", "Linear"));
            // Accepted by neither
            assert!(!mixed.matches_with_container("decoder.conv.weight", "Conv2d"));
        }
    }

    /// Predicates can reason about the whole container hierarchy.
    #[test]
    fn container_stack_predicates() {
        // Require the hierarchy to start with Model -> TransformerBlock -> Linear
        let nested_filter = PathFilter::new().with_predicate(|_path, container_path| {
            let mut segments = container_path.split('.');
            segments.next() == Some("Model")
                && segments.next() == Some("TransformerBlock")
                && segments.next() == Some("Linear")
        });

        assert!(nested_filter.matches_with_container_path_str(
            "encoder.weight",
            "Model.TransformerBlock.Linear.Param"
        ));
        assert!(
            !nested_filter
                .matches_with_container_path_str("decoder.weight", "Model.Decoder.Linear.Param")
        );
        assert!(!nested_filter.matches_with_container_path_str(
            "encoder.weight",
            "Model.TransformerBlock.Conv2d.Param"
        ));

        // Require exactly four levels with `Linear` at the third one
        let depth_filter = PathFilter::new().with_predicate(|_path, container_path| {
            container_path.split('.').count() == 4
                && container_path.split('.').nth(2) == Some("Linear")
        });

        assert!(depth_filter.matches_with_container_path_str(
            "model.layer.weight",
            "Model.TransformerBlock.Linear.Param"
        ));
        // Only three levels deep, so the depth requirement fails
        assert!(
            !depth_filter
                .matches_with_container_path_str("model.weight", "Model.TransformerBlock.Conv2d")
        );

        // Accept `Linear` anywhere in the hierarchy, not just at the end
        let any_linear = PathFilter::new()
            .with_predicate(|_path, container_path| container_path.contains("Linear"));

        assert!(any_linear.matches_with_container_path_str(
            "some.path",
            "Model.TransformerBlock.Linear.Param"
        ));
        assert!(
            any_linear.matches_with_container_path_str("other.path", "Model.Decoder.Linear.Param")
        );
        assert!(!any_linear.matches_with_container_path_str(
            "conv.path",
            "Model.TransformerBlock.Conv2d.Param"
        ));
    }

    /// Predicates can treat the container path as a plain dot-notated string.
    #[test]
    fn container_path_dot_notation() {
        // Prefix test on the dot-notated container path
        let dot_filter = PathFilter::new().with_predicate(|_path, container_path| {
            container_path.starts_with("Model.TransformerBlock")
        });

        assert!(
            dot_filter.matches_with_container_path_str("weight", "Model.TransformerBlock.Linear")
        );
        assert!(!dot_filter.matches_with_container_path_str("weight", "Model.Decoder.Linear"));

        // Substring test: a Linear or Conv* directly after some *Block
        let pattern_filter = PathFilter::new().with_predicate(|_path, container_path| {
            ["Block.Linear", "Block.Conv"]
                .iter()
                .any(|needle| container_path.contains(needle))
        });

        assert!(
            pattern_filter
                .matches_with_container_path_str("weight", "Model.TransformerBlock.Linear")
        );
        assert!(pattern_filter.matches_with_container_path_str("weight", "Model.ResBlock.Conv2d"));
        assert!(!pattern_filter.matches_with_container_path_str("weight", "Model.Linear.Param"));

        // Tensor path and container path checked together:
        // weights of Linear layers that sit inside some block
        let combined = PathFilter::new().with_predicate(|path, container_path| {
            let ends_in_linear = container_path.rsplit('.').next() == Some("Linear");
            path.ends_with(".weight") && container_path.contains("Block") && ends_in_linear
        });

        assert!(
            combined
                .matches_with_container_path_str("layer.weight", "Model.TransformerBlock.Linear")
        );
        assert!(
            !combined
                .matches_with_container_path_str("layer.bias", "Model.TransformerBlock.Linear")
        );
        assert!(!combined.matches_with_container_path_str("layer.weight", "Model.Decoder.Linear"));
    }
}
|
||||
@@ -0,0 +1,674 @@
|
||||
use alloc::collections::BTreeMap;
|
||||
use alloc::string::{String, ToString};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use regex::{self, Regex};
|
||||
|
||||
use crate::TensorSnapshot;
|
||||
|
||||
/// Key remapper for transforming tensor names.
|
||||
///
|
||||
/// This allows mapping tensor names from one naming convention to another,
|
||||
/// which is useful for loading models from different frameworks or versions.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// # use burn_store::KeyRemapper;
|
||||
/// // Create a key remapper
|
||||
/// let remapper = KeyRemapper::new()
|
||||
/// .add_pattern(r"^pytorch\.(.*)", "burn.$1").expect("valid regex") // pytorch.layer -> burn.layer
|
||||
/// .add_pattern(r"\.gamma$", ".weight").expect("valid regex"); // layer.gamma -> layer.weight
|
||||
///
|
||||
/// // Use remapper with stores
|
||||
/// // store.remap(remapper)
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct KeyRemapper {
|
||||
/// Pattern-based remapping rules (regex pattern, replacement string)
|
||||
pub patterns: Vec<(Regex, String)>,
|
||||
}
|
||||
|
||||
impl KeyRemapper {
|
||||
/// Create a new empty key remapper
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Add a remapping pattern (compiles regex)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `from` - Source pattern (regex string)
|
||||
/// * `to` - Replacement string (can include capture groups like `$1`)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(Self)` - Updated remapping configuration
|
||||
/// * `Err(regex::Error)` - If regex compilation fails
|
||||
pub fn add_pattern<S1, S2>(mut self, from: S1, to: S2) -> Result<Self, regex::Error>
|
||||
where
|
||||
S1: AsRef<str>,
|
||||
S2: Into<String>,
|
||||
{
|
||||
let regex = Regex::new(from.as_ref())?;
|
||||
self.patterns.push((regex, to.into()));
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Create from a list of compiled regex patterns
|
||||
pub fn from_compiled_patterns(patterns: Vec<(Regex, String)>) -> Self {
|
||||
Self { patterns }
|
||||
}
|
||||
|
||||
/// Create from string patterns (will compile to regex)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `patterns` - Vector of (pattern, replacement) tuples
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(Self)` - New remapping configuration
|
||||
/// * `Err(regex::Error)` - If any regex compilation fails
|
||||
pub fn from_patterns<S1, S2>(patterns: Vec<(S1, S2)>) -> Result<Self, regex::Error>
|
||||
where
|
||||
S1: AsRef<str>,
|
||||
S2: Into<String>,
|
||||
{
|
||||
let mut compiled_patterns = Vec::new();
|
||||
for (pattern, replacement) in patterns {
|
||||
let regex = Regex::new(pattern.as_ref())?;
|
||||
compiled_patterns.push((regex, replacement.into()));
|
||||
}
|
||||
Ok(Self {
|
||||
patterns: compiled_patterns,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create from an iterator of patterns
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `iter` - Iterator yielding (pattern, replacement) tuples
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(Self)` - New remapping configuration
|
||||
/// * `Err(regex::Error)` - If any regex compilation fails
|
||||
pub fn from_pattern_iter<I, S1, S2>(iter: I) -> Result<Self, regex::Error>
|
||||
where
|
||||
I: IntoIterator<Item = (S1, S2)>,
|
||||
S1: AsRef<str>,
|
||||
S2: Into<String>,
|
||||
{
|
||||
let patterns: Result<Vec<_>, _> = iter
|
||||
.into_iter()
|
||||
.map(|(from, to)| Ok((Regex::new(from.as_ref())?, to.into())))
|
||||
.collect();
|
||||
Ok(Self {
|
||||
patterns: patterns?,
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if the remapping is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.patterns.is_empty()
|
||||
}
|
||||
|
||||
/// Convert to the format expected by remap_tensor_paths_with_patterns
|
||||
pub fn to_regex_pairs(&self) -> Vec<(Regex, String)> {
|
||||
self.patterns.clone()
|
||||
}
|
||||
|
||||
/// Remap tensor paths using the configured patterns.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensors` - Vec of TensorSnapshots to remap
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tuple containing:
|
||||
/// * The remapped Vec of TensorSnapshots with updated paths
|
||||
/// * A vector of (new_path, original_path) showing the transformations
|
||||
pub fn remap(
|
||||
&self,
|
||||
mut tensors: Vec<TensorSnapshot>,
|
||||
) -> (Vec<TensorSnapshot>, Vec<(String, String)>) {
|
||||
if self.patterns.is_empty() {
|
||||
let remapped_names = tensors
|
||||
.iter()
|
||||
.map(|v| {
|
||||
let path = v.full_path();
|
||||
(path.clone(), path)
|
||||
})
|
||||
.collect();
|
||||
return (tensors, remapped_names);
|
||||
}
|
||||
|
||||
let mut remapped_snapshots = Vec::new();
|
||||
let mut remapped_names = Vec::new();
|
||||
|
||||
for mut snapshot in tensors.drain(..) {
|
||||
let original_path = snapshot.full_path();
|
||||
let mut new_path = original_path.clone();
|
||||
|
||||
// Apply all patterns to get the new path
|
||||
for (pattern, replacement) in &self.patterns {
|
||||
if pattern.is_match(&new_path) {
|
||||
new_path = pattern
|
||||
.replace_all(&new_path, replacement.as_str())
|
||||
.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
// Update the snapshot's internal path_stack if the path changed
|
||||
if new_path != original_path
|
||||
&& let Some(ref mut path_stack) = snapshot.path_stack
|
||||
{
|
||||
*path_stack = new_path.split('.').map(|s| s.to_string()).collect();
|
||||
}
|
||||
|
||||
remapped_names.push((new_path.clone(), original_path));
|
||||
remapped_snapshots.push(snapshot);
|
||||
}
|
||||
|
||||
(remapped_snapshots, remapped_names)
|
||||
}
|
||||
}
|
||||
|
||||
/// Map tensor paths to have contiguous numeric indices.
|
||||
///
|
||||
/// This function detects numeric indices in tensor paths and renumbers them
|
||||
/// to be contiguous (0, 1, 2, ...) while preserving their relative order.
|
||||
/// It handles nested sequential structures by processing ALL numeric indices
|
||||
/// in each path independently based on their position context.
|
||||
///
|
||||
/// This is useful when loading PyTorch models that have gaps in layer numbering,
|
||||
/// such as when using `nn.Sequential` with mixed layer types (e.g., Conv2d + ReLU
|
||||
/// where only Conv2d has parameters).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// Simple case - input paths:
|
||||
/// - `fc.0.weight`, `fc.0.bias`
|
||||
/// - `fc.2.weight`, `fc.2.bias`
|
||||
/// - `fc.4.weight`, `fc.4.bias`
|
||||
///
|
||||
/// Output paths:
|
||||
/// - `fc.0.weight`, `fc.0.bias`
|
||||
/// - `fc.1.weight`, `fc.1.bias`
|
||||
/// - `fc.2.weight`, `fc.2.bias`
|
||||
///
|
||||
/// Nested case - input paths:
|
||||
/// - `feature.layers.0.conv_block.0.weight`
|
||||
/// - `feature.layers.0.conv_block.2.weight`
|
||||
/// - `feature.layers.2.conv_block.0.weight`
|
||||
/// - `feature.layers.2.conv_block.2.weight`
|
||||
///
|
||||
/// Output paths:
|
||||
/// - `feature.layers.0.conv_block.0.weight`
|
||||
/// - `feature.layers.0.conv_block.1.weight`
|
||||
/// - `feature.layers.1.conv_block.0.weight`
|
||||
/// - `feature.layers.1.conv_block.1.weight`
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensors` - Vec of TensorSnapshots to map
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tuple containing:
|
||||
/// * The mapped Vec of TensorSnapshots with updated paths
|
||||
/// * A vector of (new_path, original_path) showing the transformations
|
||||
pub fn map_indices_contiguous(
|
||||
mut tensors: Vec<TensorSnapshot>,
|
||||
) -> (Vec<TensorSnapshot>, Vec<(String, String)>) {
|
||||
if tensors.is_empty() {
|
||||
return (tensors, Vec::new());
|
||||
}
|
||||
|
||||
// Step 1: Collect all paths and find all index positions
|
||||
// For each index position (identified by prefix using ORIGINAL indices),
|
||||
// collect all indices seen at that position.
|
||||
//
|
||||
// Key: prefix using original path (e.g., "feature.layers." or "feature.layers.0.conv_block.")
|
||||
// Value: BTreeMap of original_index -> new_index
|
||||
let mut index_maps: BTreeMap<String, BTreeMap<usize, usize>> = BTreeMap::new();
|
||||
|
||||
// First pass: collect all indices at each position using original prefixes
|
||||
for snapshot in &tensors {
|
||||
let path = snapshot.full_path();
|
||||
let parts: Vec<&str> = path.split('.').collect();
|
||||
|
||||
// Check each part for numeric indices
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
if let Ok(index) = part.parse::<usize>() {
|
||||
// The prefix is everything before this index (using original path)
|
||||
let prefix = if i > 0 {
|
||||
format!("{}.", parts[..i].join("."))
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
index_maps
|
||||
.entry(prefix)
|
||||
.or_default()
|
||||
.entry(index)
|
||||
.or_insert(usize::MAX); // Placeholder
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: assign contiguous indices for each position
|
||||
for indices in index_maps.values_mut() {
|
||||
let mut sorted_indices: Vec<usize> = indices.keys().cloned().collect();
|
||||
sorted_indices.sort();
|
||||
|
||||
for (new_idx, old_idx) in sorted_indices.into_iter().enumerate() {
|
||||
indices.insert(old_idx, new_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Third pass: apply the remapping to all tensors
|
||||
// We use original prefixes for lookup since that's how we collected indices
|
||||
let mut mapped_snapshots = Vec::new();
|
||||
let mut transformations = Vec::new();
|
||||
|
||||
for mut snapshot in tensors.drain(..) {
|
||||
let original_path = snapshot.full_path();
|
||||
let new_path = remap_all_indices_with_original_prefix(&original_path, &index_maps);
|
||||
|
||||
// Update the snapshot's internal path_stack if the path changed
|
||||
if new_path != original_path
|
||||
&& let Some(ref mut path_stack) = snapshot.path_stack
|
||||
{
|
||||
*path_stack = new_path.split('.').map(|s| s.to_string()).collect();
|
||||
}
|
||||
|
||||
transformations.push((new_path, original_path));
|
||||
mapped_snapshots.push(snapshot);
|
||||
}
|
||||
|
||||
(mapped_snapshots, transformations)
|
||||
}
|
||||
|
||||
/// Remap all numeric indices in a dot-separated path using `index_maps`.
///
/// For every segment that parses as `usize`, the lookup key is the part of
/// the ORIGINAL path that precedes the segment (trailing dot included, empty
/// for the first segment). If `index_maps` has an entry for that prefix and
/// index, the segment is replaced by the mapped index; any other segment is
/// copied unchanged.
///
/// # Arguments
///
/// * `path` - Original dot-separated tensor path
/// * `index_maps` - For each prefix, the mapping original_index -> new_index
///
/// # Returns
///
/// The path with every mapped index replaced.
fn remap_all_indices_with_original_prefix(
    path: &str,
    index_maps: &BTreeMap<String, BTreeMap<usize, usize>>,
) -> String {
    let mut result = String::with_capacity(path.len());

    // Byte offset at which the current segment starts in `path`. The slice
    // `path[..segment_start]` is the original prefix of that segment, so no
    // prefix string has to be rebuilt.
    let mut segment_start = 0;

    for segment in path.split('.') {
        // Every segment except the first is preceded by a separator.
        if segment_start > 0 {
            result.push('.');
        }

        let new_index = segment.parse::<usize>().ok().and_then(|index| {
            index_maps
                .get(&path[..segment_start])
                .and_then(|index_map| index_map.get(&index))
        });

        match new_index {
            Some(mapped) => result.push_str(&mapped.to_string()),
            // Not a numeric index or no mapping found, keep as-is
            None => result.push_str(segment),
        }

        segment_start += segment.len() + 1;
    }

    result
}
|
||||
|
||||
#[cfg(all(test, feature = "std"))]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_core::module::ParamId;
|
||||
use burn_tensor::TensorData;
|
||||
|
||||
fn create_test_tensor_snapshot(name: &str) -> TensorSnapshot {
|
||||
let data = TensorData {
|
||||
bytes: burn_tensor::Bytes::from_bytes_vec(vec![1, 2, 3, 4]),
|
||||
shape: vec![2, 2],
|
||||
dtype: burn_tensor::DType::F32,
|
||||
};
|
||||
let path_parts: Vec<String> = name.split('.').map(|s| s.to_string()).collect();
|
||||
TensorSnapshot::from_data(data, path_parts, vec!["Test".to_string()], ParamId::new())
|
||||
}
|
||||
|
||||
/// A single prefix rule rewrites matching paths and leaves the others alone.
#[test]
fn test_key_remapper_basic() {
    let inputs = vec![
        create_test_tensor_snapshot("encoder.layer1.weight"),
        create_test_tensor_snapshot("decoder.layer1.weight"),
    ];
    let remapper = KeyRemapper::new()
        .add_pattern(r"^encoder\.", "transformer.encoder.")
        .expect("valid regex");

    let (remapped, transformations) = remapper.remap(inputs);

    // Both snapshots survive; only the encoder one is renamed
    let paths: Vec<String> = remapped.iter().map(|v| v.full_path()).collect();
    assert!(
        paths
            .iter()
            .any(|p| p == "transformer.encoder.layer1.weight")
    );
    assert!(paths.iter().any(|p| p == "decoder.layer1.weight"));
    assert_eq!(remapped.len(), 2);

    // The rename is reported as (new, old)
    let (new_name, _old_name) = transformations
        .iter()
        .find(|(_new, old)| old == "encoder.layer1.weight")
        .expect("should find encoder transformation");
    assert_eq!(new_name, "transformer.encoder.layer1.weight");
}
|
||||
|
||||
/// Several rules are applied one after another to the same path.
#[test]
fn test_key_remapper_multiple_patterns() {
    let inputs = vec![create_test_tensor_snapshot("encoder.layer1.gamma")];
    let remapper = KeyRemapper::new()
        .add_pattern(r"^encoder\.", "transformer.encoder.")
        .expect("valid regex")
        .add_pattern(r"\.gamma$", ".weight")
        .expect("valid regex");

    let (remapped, _) = remapper.remap(inputs);

    let paths: Vec<String> = remapped.iter().map(|v| v.full_path()).collect();
    assert!(
        paths
            .iter()
            .any(|p| p == "transformer.encoder.layer1.weight")
    );
    assert_eq!(remapped.len(), 1);
}
|
||||
|
||||
/// `from_patterns` compiles string pairs into working rules.
#[test]
fn test_key_remapper_from_patterns() {
    let rules = vec![(r"^pytorch\.", "burn."), (r"\.bias$", ".bias_param")];
    let remapper = KeyRemapper::from_patterns(rules).expect("valid patterns");

    let inputs = vec![create_test_tensor_snapshot("pytorch.linear.bias")];
    let (remapped, _) = remapper.remap(inputs);

    let paths: Vec<String> = remapped.iter().map(|v| v.full_path()).collect();
    assert!(paths.iter().any(|p| p == "burn.linear.bias_param"));
}
|
||||
|
||||
/// A remapper without rules returns paths unchanged and reports identity pairs.
#[test]
fn test_key_remapper_empty() {
    let remapper = KeyRemapper::new();
    assert!(remapper.is_empty());

    let inputs = vec![create_test_tensor_snapshot("test.weight")];
    let (remapped, transformations) = remapper.remap(inputs);

    assert_eq!(remapped.len(), 1);
    assert!(remapped.iter().any(|v| v.full_path() == "test.weight"));

    let identity = ("test.weight".to_string(), "test.weight".to_string());
    assert_eq!(transformations.len(), 1);
    assert_eq!(transformations[0], identity);
}
|
||||
|
||||
/// Gapped indices (0, 2, 4) under one prefix are renumbered to (0, 1, 2).
#[test]
fn test_map_indices_contiguous_basic() {
    // Mimics a PyTorch nn.Sequential alternating Conv2d (0, 2, 4) and
    // ReLU (1, 3, 5), where only the Conv2d layers carry parameters.
    let tensors: Vec<TensorSnapshot> = [
        "fc.0.weight",
        "fc.0.bias",
        "fc.2.weight",
        "fc.2.bias",
        "fc.4.weight",
        "fc.4.bias",
    ]
    .iter()
    .map(|name| create_test_tensor_snapshot(name))
    .collect();

    let (reindexed, transformations) = map_indices_contiguous(tensors);

    // Every index from 0 to 2 is now present, for weights and biases
    let paths: Vec<String> = reindexed.iter().map(|v| v.full_path()).collect();
    for expected in [
        "fc.0.weight",
        "fc.0.bias",
        "fc.1.weight",
        "fc.1.bias",
        "fc.2.weight",
        "fc.2.bias",
    ]
    .iter()
    {
        assert!(paths.iter().any(|p| p == expected));
    }
    assert_eq!(reindexed.len(), 6);

    // The reported renames are (new, old) pairs
    let renamed_from = |old_path: &str| {
        transformations
            .iter()
            .find(|(_, old)| old == old_path)
            .map(|(new, _)| new.clone())
    };

    let from_2 = renamed_from("fc.2.weight").expect("should find fc.2.weight transformation");
    assert_eq!(from_2, "fc.1.weight");

    let from_4 = renamed_from("fc.4.weight").expect("should find fc.4.weight transformation");
    assert_eq!(from_4, "fc.2.weight");
}
|
||||
|
||||
/// Indices that are already 0, 1, 2 come back untouched.
#[test]
fn test_map_indices_contiguous_already_contiguous() {
    let names = ["fc.0.weight", "fc.1.weight", "fc.2.weight"];
    let tensors: Vec<TensorSnapshot> = names
        .iter()
        .map(|name| create_test_tensor_snapshot(name))
        .collect();

    let (reindexed, transformations) = map_indices_contiguous(tensors);

    let paths: Vec<String> = reindexed.iter().map(|v| v.full_path()).collect();
    for expected in names.iter() {
        assert!(paths.iter().any(|p| p == expected));
    }
    assert_eq!(reindexed.len(), 3);

    // Every reported rename must be an identity pair
    for (new, old) in &transformations {
        assert_eq!(new, old);
    }
}
|
||||
|
||||
/// Each prefix gets its own, independent renumbering.
#[test]
fn test_map_indices_contiguous_multiple_prefixes() {
    let tensors: Vec<TensorSnapshot> = [
        "encoder.0.weight",
        "encoder.2.weight",
        "decoder.1.weight",
        "decoder.5.weight",
    ]
    .iter()
    .map(|name| create_test_tensor_snapshot(name))
    .collect();

    let (reindexed, _) = map_indices_contiguous(tensors);
    let paths: Vec<String> = reindexed.iter().map(|v| v.full_path()).collect();

    // encoder: 0, 2 -> 0, 1
    assert!(paths.iter().any(|p| p == "encoder.0.weight"));
    assert!(paths.iter().any(|p| p == "encoder.1.weight"));

    // decoder: 1, 5 -> 0, 1
    assert!(paths.iter().any(|p| p == "decoder.0.weight"));
    assert!(paths.iter().any(|p| p == "decoder.1.weight"));
}
|
||||
|
||||
/// Paths that contain no numeric segment are passed through unchanged.
#[test]
fn test_map_indices_contiguous_no_indices() {
    let names = ["encoder.weight", "decoder.bias"];
    let tensors: Vec<TensorSnapshot> = names
        .iter()
        .map(|name| create_test_tensor_snapshot(name))
        .collect();

    let (reindexed, transformations) = map_indices_contiguous(tensors);

    let paths: Vec<String> = reindexed.iter().map(|v| v.full_path()).collect();
    for expected in names.iter() {
        assert!(paths.iter().any(|p| p == expected));
    }

    // Every reported rename must be an identity pair
    for (new, old) in &transformations {
        assert_eq!(new, old);
    }
}
|
||||
|
||||
/// An empty input yields an empty output and no reported renames.
#[test]
fn test_map_indices_contiguous_empty() {
    let no_tensors: Vec<TensorSnapshot> = Vec::new();

    let (reindexed, transformations) = map_indices_contiguous(no_tensors);

    assert!(reindexed.is_empty());
    assert!(transformations.is_empty());
}
|
||||
|
||||
/// Indexed paths are renumbered while index-free paths stay as they are.
#[test]
fn test_map_indices_contiguous_mixed_indexed_and_non_indexed() {
    let tensors: Vec<TensorSnapshot> = [
        "fc.0.weight",
        "fc.2.weight",
        "output.weight", // no index
    ]
    .iter()
    .map(|name| create_test_tensor_snapshot(name))
    .collect();

    let (reindexed, _) = map_indices_contiguous(tensors);
    let paths: Vec<String> = reindexed.iter().map(|v| v.full_path()).collect();

    assert!(paths.iter().any(|p| p == "fc.0.weight"));
    assert!(paths.iter().any(|p| p == "fc.1.weight")); // 2 -> 1
    assert!(paths.iter().any(|p| p == "output.weight")); // unchanged
}
|
||||
|
||||
#[test]
fn test_map_indices_contiguous_nested_sequential() {
    // Nested sequential containers, e.g.
    //   feature   = nn.Sequential(ConvBlock, ReLU, ConvBlock, ReLU, ConvBlock)
    //   ConvBlock = nn.Sequential(Conv2d, ReLU, Conv2d)
    //
    // The parameter-free ReLU entries leave gaps at both nesting levels:
    //   feature.layers.0.conv_block.0.weight  (layer 0, conv 0)
    //   feature.layers.0.conv_block.2.weight  (layer 0, conv 2 - ReLU at 1 skipped)
    //   feature.layers.2.conv_block.0.weight  (layer 2 - ReLU at 1 skipped, conv 0)
    //   feature.layers.2.conv_block.2.weight  (layer 2, conv 2)
    let snapshots = vec![
        create_test_tensor_snapshot("feature.layers.0.conv_block.0.weight"),
        create_test_tensor_snapshot("feature.layers.0.conv_block.2.weight"),
        create_test_tensor_snapshot("feature.layers.2.conv_block.0.weight"),
        create_test_tensor_snapshot("feature.layers.2.conv_block.2.weight"),
    ];

    let (mapped, transformations) = map_indices_contiguous(snapshots);

    // Each level is compacted independently:
    //   feature.layers:              0, 2 -> 0, 1
    //   feature.layers.0.conv_block: 0, 2 -> 0, 1
    //   feature.layers.2.conv_block: 0, 2 -> 0, 1
    // so the four inputs become layers.{0,1}.conv_block.{0,1}.
    let expectations = [
        (
            "feature.layers.0.conv_block.0.weight",
            "0.0 should stay as 0.0",
        ),
        (
            "feature.layers.0.conv_block.1.weight",
            "0.2 should become 0.1",
        ),
        (
            "feature.layers.1.conv_block.0.weight",
            "2.0 should become 1.0",
        ),
        (
            "feature.layers.1.conv_block.1.weight",
            "2.2 should become 1.1",
        ),
    ];
    for (path, message) in expectations.iter() {
        assert!(
            mapped.iter().any(|v| v.full_path() == *path),
            "{}",
            message
        );
    }

    // The transformation list must record the original name of a remapped path.
    let remapped = transformations
        .iter()
        .find(|(_, old)| old == "feature.layers.2.conv_block.2.weight")
        .map(|(new, _)| new.as_str());
    assert_eq!(
        remapped,
        Some("feature.layers.1.conv_block.1.weight"),
        "2.2 should map to 1.1"
    );
}
|
||||
|
||||
#[test]
fn test_map_indices_contiguous_deeply_nested() {
    // Three levels of indexed nesting: a.N.b.N.c.N
    let snapshots = vec![
        create_test_tensor_snapshot("a.0.b.0.c.0.weight"),
        create_test_tensor_snapshot("a.0.b.0.c.2.weight"),
        create_test_tensor_snapshot("a.0.b.2.c.0.weight"),
        create_test_tensor_snapshot("a.2.b.0.c.0.weight"),
    ];

    let (mapped, _) = map_indices_contiguous(snapshots);

    // Expected per-prefix compaction:
    //   a:         0, 2 -> 0, 1
    //   a.0.b:     0, 2 -> 0, 1
    //   a.2.b:     0    -> 0
    //   a.0.b.0.c: 0, 2 -> 0, 1
    //   a.0.b.2.c: 0    -> 0
    //   a.2.b.0.c: 0    -> 0
    let has_path = |expected: &str| mapped.iter().any(|v| v.full_path() == expected);

    assert!(has_path("a.0.b.0.c.0.weight"));
    assert!(
        has_path("a.0.b.0.c.1.weight"),
        "a.0.b.0.c.2 should become a.0.b.0.c.1"
    );
    assert!(
        has_path("a.0.b.1.c.0.weight"),
        "a.0.b.2.c.0 should become a.0.b.1.c.0"
    );
    assert!(
        has_path("a.1.b.0.c.0.weight"),
        "a.2.b.0.c.0 should become a.1.b.0.c.0"
    );
}
|
||||
}
|
||||
118
crates/stable-diffusion-burn/burn-crates/burn-store/src/lib.rs
Normal file
118
crates/stable-diffusion-burn/burn-crates/burn-store/src/lib.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
|
||||
//! # Burn Store
|
||||
//!
|
||||
//! Advanced model storage and serialization infrastructure for the Burn deep learning framework.
|
||||
//!
|
||||
//! This crate provides comprehensive functionality for storing and loading Burn modules
|
||||
//! and their tensor data, with support for cross-framework interoperability, flexible filtering,
|
||||
//! and efficient memory management through lazy materialization.
|
||||
//!
|
||||
//! ## Key Features
|
||||
//!
|
||||
//! - **Burnpack Format**: Native Burn format with CBOR metadata, ParamId persistence for stateful training, and no-std support
|
||||
//! - **SafeTensors Format**: Industry-standard format for secure and efficient tensor serialization
|
||||
//! - **PyTorch Compatibility**: Load PyTorch models directly into Burn with automatic weight transformation
|
||||
//! - **Zero-Copy Loading**: Memory-mapped files and lazy tensor materialization for optimal performance
|
||||
//! - **Flexible Filtering**: Load/save specific model subsets using regex, exact paths, or custom predicates
|
||||
//! - **Tensor Remapping**: Rename tensors during load/save operations for framework compatibility
|
||||
//! - **No-std Support**: Core functionality available in embedded and WASM environments
|
||||
//!
|
||||
//! ## Quick Start
|
||||
//!
|
||||
//! ### Basic Save and Load
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use burn_store::{ModuleSnapshot, SafetensorsStore};
|
||||
//!
|
||||
//! // Save a model
|
||||
//! let mut store = SafetensorsStore::from_file("model.safetensors");
|
||||
//! model.save_into(&mut store)?;
|
||||
//!
|
||||
//! // Load a model
|
||||
//! let mut store = SafetensorsStore::from_file("model.safetensors");
|
||||
//! model.load_from(&mut store)?;
|
||||
//! ```
|
||||
//!
|
||||
//! ### Loading PyTorch Models
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use burn_store::PytorchStore;
|
||||
//!
|
||||
//! // Load PyTorch model (automatic weight transformation via PyTorchToBurnAdapter)
|
||||
//! let mut store = PytorchStore::from_file("pytorch_model.pth")
|
||||
//! .with_top_level_key("state_dict") // Access nested state dict if needed
|
||||
//! .allow_partial(true); // Skip unknown tensors
|
||||
//!
|
||||
//! model.load_from(&mut store)?;
|
||||
//! ```
|
||||
//!
|
||||
//! ### Filtering and Remapping
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! # use burn_store::SafetensorsStore;
|
||||
//! // Save only specific layers with renaming
|
||||
//! let mut store = SafetensorsStore::from_file("encoder.safetensors")
|
||||
//! .with_regex(r"^encoder\..*") // Filter: only encoder layers
|
||||
//! .with_key_remapping(r"^encoder\.", "transformer.") // Rename: encoder.X -> transformer.X
|
||||
//! .metadata("subset", "encoder_only");
|
||||
//!
|
||||
//! // Use store with model.save_into(&mut store)?;
|
||||
//! ```
|
||||
//!
|
||||
//! ## Core Components
|
||||
//!
|
||||
//! - [`ModuleSnapshot`]: Extension trait for Burn modules providing `collect()` and `apply()` methods
|
||||
//! - [`BurnpackStore`]: Native Burn format with ParamId persistence for stateful training workflows
|
||||
//! - [`SafetensorsStore`]: Primary storage implementation supporting the SafeTensors format
|
||||
//! - [`PytorchStore`]: PyTorch model loader supporting .pth and .pt files
|
||||
//! - [`PathFilter`]: Flexible filtering system for selective tensor loading/saving
|
||||
//! - [`KeyRemapper`]: Advanced tensor name remapping with regex patterns
|
||||
//! - [`ModuleAdapter`]: Framework adapters for cross-framework compatibility
|
||||
//!
|
||||
//! ## Feature Flags
|
||||
//!
|
||||
//! - `std`: Enables file I/O and other std-only features (default)
|
||||
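//! - `safetensors`: Enables SafeTensors format support (default)
//! - `pytorch`: Enables loading PyTorch `.pt`/`.pth` files (default)
//! - `burnpack`: Enables the native Burnpack format (default)
//! - `memmap`: Enables memory-mapped file loading (default)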
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
mod adapter;
|
||||
mod applier;
|
||||
mod apply_result;
|
||||
mod collector;
|
||||
mod filter;
|
||||
mod tensor_snapshot;
|
||||
mod traits;
|
||||
|
||||
pub use adapter::{
|
||||
BurnToPyTorchAdapter, ChainAdapter, IdentityAdapter, ModuleAdapter, PyTorchToBurnAdapter,
|
||||
};
|
||||
pub use applier::Applier;
|
||||
pub use apply_result::{ApplyError, ApplyResult};
|
||||
pub use collector::Collector;
|
||||
pub use filter::PathFilter;
|
||||
pub use tensor_snapshot::{TensorSnapshot, TensorSnapshotError};
|
||||
pub use traits::{ModuleSnapshot, ModuleStore};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
mod keyremapper;
|
||||
#[cfg(feature = "std")]
|
||||
pub use keyremapper::{KeyRemapper, map_indices_contiguous};
|
||||
|
||||
#[cfg(feature = "pytorch")]
|
||||
pub mod pytorch;
|
||||
#[cfg(feature = "pytorch")]
|
||||
pub use pytorch::{PytorchStore, PytorchStoreError};
|
||||
|
||||
#[cfg(feature = "safetensors")]
|
||||
mod safetensors;
|
||||
#[cfg(feature = "safetensors")]
|
||||
pub use safetensors::{SafetensorsStore, SafetensorsStoreError};
|
||||
|
||||
#[cfg(feature = "burnpack")]
|
||||
mod burnpack;
|
||||
#[cfg(feature = "burnpack")]
|
||||
pub use burnpack::writer::BurnpackWriter;
|
||||
#[cfg(feature = "burnpack")]
|
||||
pub use burnpack::{base::BurnpackError, store::BurnpackStore};
|
||||
@@ -0,0 +1,567 @@
|
||||
//! Lazy data loading support for PyTorch files.
|
||||
//!
|
||||
//! This module provides abstractions for lazy loading of tensor data from PyTorch files,
|
||||
//! avoiding the need to load all data into memory upfront.
|
||||
|
||||
use alloc::string::String;
|
||||
use alloc::vec::Vec;
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, Read, Seek};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::{Arc, Mutex, RwLock};
|
||||
use zip::ZipArchive;
|
||||
|
||||
/// A data source that can lazily load tensor data.
|
||||
#[derive(Clone)]
|
||||
pub enum LazyDataSource {
|
||||
/// ZIP archive with lazy loading
|
||||
Zip(Arc<Mutex<ZipSource>>),
|
||||
/// TAR archive format (older torchvision models)
|
||||
Tar(Arc<Mutex<TarSource>>),
|
||||
/// Legacy format with multiple storages in single blob
|
||||
LegacyMultiStorage(Arc<Mutex<LegacyMultiStorageSource>>),
|
||||
}
|
||||
|
||||
/// Lazily-read ZIP archive backing a PyTorch checkpoint.
///
/// Only the archive location and a snapshot of its entry table are stored;
/// entry contents are read from disk when requested.
pub struct ZipSource {
    /// Location of the archive on disk.
    path: PathBuf,
    /// Entry table captured at construction time, so listings do not need to
    /// reopen the archive: `(name, offset, compressed_size)`.
    file_list: Vec<(String, u64, u64)>,
}
|
||||
|
||||
/// Storage source for TAR-based checkpoints (older torchvision models such as
/// AlexNet and SqueezeNet).
///
/// PyTorch/torchvision releases before 1.6 wrote TAR archives rather than ZIP
/// archives. Such an archive contains:
/// - `sys_info`: system info pickle (endianness, type sizes)
/// - `pickle`: `OrderedDict` mapping tensor names to storage keys
/// - `tensors`: tensor metadata pickles (unused, metadata is embedded in `pickle`)
/// - `storages`: storage count followed by, per storage, a metadata pickle,
///   an element count and the raw data
pub struct TarSource {
    /// Index into `storages_data`: storage key -> `(offset_in_storages, size_bytes)`.
    storage_map: HashMap<String, (usize, usize)>,
    /// The complete raw `storages` blob, held in memory for the TAR format.
    storages_data: Vec<u8>,
}
|
||||
|
||||
/// Storage source for the legacy PyTorch file format (0.1.10 - 1.5).
///
/// In this format the tensor data is one concatenated binary blob with no
/// explicit storage boundaries. The source therefore records how much of each
/// storage the parsed tensors need and derives a storage map from that, which
/// is then used for lazy loading.
///
/// ## Storage Layout
/// - Pickle metadata with tensor definitions
/// - List of storage keys (determines concatenation order)
/// - Raw binary blob with all storages concatenated
pub struct LegacyMultiStorageSource {
    /// Location of the checkpoint file on disk.
    path: PathBuf,
    /// Offset added to every storage offset when seeking in the file.
    data_offset: u64,
    /// Size of the blob as supplied at construction; currently never read.
    #[allow(dead_code)]
    data_size: u64,
    /// storage_key -> `(offset_in_blob, size)`; `None` until it can be derived.
    storage_map: RwLock<Option<HashMap<String, (u64, u64)>>>,
    /// Storage keys in concatenation order (used for boundary calculation).
    storage_keys: RwLock<Option<Vec<String>>>,
    /// storage_key -> maximum number of bytes any tracked tensor needs from it.
    storage_usage: RwLock<HashMap<String, usize>>,
}
|
||||
|
||||
impl ZipSource {
|
||||
/// Create a new ZIP source
|
||||
pub fn new(path: PathBuf) -> std::io::Result<Self> {
|
||||
let file = File::open(&path)?;
|
||||
let reader = BufReader::new(file);
|
||||
let mut archive = ZipArchive::new(reader)?;
|
||||
|
||||
// Cache file metadata
|
||||
let mut file_list = Vec::new();
|
||||
for i in 0..archive.len() {
|
||||
let file = archive.by_index(i)?;
|
||||
let name = file.name().to_string();
|
||||
let offset = file.data_start();
|
||||
let compressed_size = file.compressed_size();
|
||||
file_list.push((
|
||||
name,
|
||||
offset.expect("should have an offset"),
|
||||
compressed_size,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Self { path, file_list })
|
||||
}
|
||||
|
||||
/// Check if a file exists in the archive
|
||||
pub fn contains(&self, name: &str) -> bool {
|
||||
self.file_list.iter().any(|(n, _, _)| n == name)
|
||||
}
|
||||
|
||||
/// Get list of data files (excluding pickle files)
|
||||
pub fn data_files(&self) -> Vec<String> {
|
||||
self.file_list
|
||||
.iter()
|
||||
.filter(|(name, _, _)| name.starts_with("data/") || name.contains("/data/"))
|
||||
.filter(|(name, _, _)| !name.ends_with(".pkl") && !name.ends_with("/"))
|
||||
.map(|(name, _, _)| name.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Read a specific file from the archive
|
||||
pub fn read_file(&self, name: &str) -> std::io::Result<Vec<u8>> {
|
||||
let file = File::open(&self.path)?;
|
||||
let reader = BufReader::new(file);
|
||||
let mut archive = ZipArchive::new(reader)?;
|
||||
|
||||
let mut file = archive.by_name(name)?;
|
||||
let mut contents = Vec::with_capacity(file.size() as usize);
|
||||
file.read_to_end(&mut contents)?;
|
||||
Ok(contents)
|
||||
}
|
||||
|
||||
/// Read a portion of a file
|
||||
pub fn read_file_range(
|
||||
&self,
|
||||
name: &str,
|
||||
offset: usize,
|
||||
length: usize,
|
||||
) -> std::io::Result<Vec<u8>> {
|
||||
let file = File::open(&self.path)?;
|
||||
let reader = BufReader::new(file);
|
||||
let mut archive = ZipArchive::new(reader)?;
|
||||
|
||||
let mut file = archive.by_name(name)?;
|
||||
let mut buffer = vec![0u8; length];
|
||||
|
||||
// Skip to offset
|
||||
let mut skip_buffer = vec![0u8; offset.min(8192)];
|
||||
let mut skipped = 0;
|
||||
while skipped < offset {
|
||||
let to_skip = (offset - skipped).min(skip_buffer.len());
|
||||
file.read_exact(&mut skip_buffer[..to_skip])?;
|
||||
skipped += to_skip;
|
||||
}
|
||||
|
||||
// Read the requested data
|
||||
file.read_exact(&mut buffer)?;
|
||||
Ok(buffer)
|
||||
}
|
||||
}
|
||||
|
||||
impl LegacyMultiStorageSource {
|
||||
/// Create a new legacy multi-storage source
|
||||
pub fn new(path: PathBuf, data_offset: u64, data_size: u64) -> Self {
|
||||
Self {
|
||||
path,
|
||||
data_offset,
|
||||
data_size,
|
||||
storage_map: RwLock::new(None),
|
||||
storage_keys: RwLock::new(None),
|
||||
storage_usage: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the ordered storage keys from the pickle
|
||||
pub fn set_storage_keys(&self, keys: Vec<String>) {
|
||||
let mut storage_keys = self
|
||||
.storage_keys
|
||||
.write()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
*storage_keys = Some(keys);
|
||||
}
|
||||
|
||||
/// Track storage usage from tensor access
|
||||
/// This is called from within tensor loading closures
|
||||
pub fn track_storage_usage(&self, storage_key: &str, offset: usize, size: usize) {
|
||||
let mut usage = self
|
||||
.storage_usage
|
||||
.write()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
let max_extent = offset + size;
|
||||
usage
|
||||
.entry(storage_key.to_string())
|
||||
.and_modify(|current| *current = (*current).max(max_extent))
|
||||
.or_insert(max_extent);
|
||||
|
||||
// Try to build storage map if we have enough information
|
||||
drop(usage);
|
||||
self.try_build_storage_map();
|
||||
}
|
||||
|
||||
/// Try to build the storage map from tracked usage
|
||||
fn try_build_storage_map(&self) {
|
||||
// Only build if we don't already have a map
|
||||
if self
|
||||
.storage_map
|
||||
.read()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
.is_some()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if we have storage keys
|
||||
let keys_guard = self
|
||||
.storage_keys
|
||||
.read()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
if let Some(ref keys) = *keys_guard {
|
||||
let usage = self
|
||||
.storage_usage
|
||||
.read()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
|
||||
// Only build if we have usage info for all storages
|
||||
if keys.iter().all(|k| usage.contains_key(k)) {
|
||||
let mut map = HashMap::new();
|
||||
let mut current_offset = 0u64;
|
||||
|
||||
for key in keys {
|
||||
if let Some(&size) = usage.get(key) {
|
||||
map.insert(key.clone(), (current_offset, size as u64));
|
||||
current_offset += size as u64;
|
||||
}
|
||||
}
|
||||
|
||||
// Set the storage map
|
||||
drop(keys_guard);
|
||||
drop(usage);
|
||||
let mut storage_map = self
|
||||
.storage_map
|
||||
.write()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
*storage_map = Some(map);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read data for a specific storage key
|
||||
/// Only loads the specific storage portion, never the entire blob
|
||||
pub fn read(&self, key: &str) -> std::io::Result<Vec<u8>> {
|
||||
// Extract numeric key from paths like "data/0" or just "0"
|
||||
let storage_key = key.split('/').next_back().unwrap_or(key);
|
||||
|
||||
// Get storage map - must be available for lazy loading to work
|
||||
let storage_map = self
|
||||
.storage_map
|
||||
.read()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
|
||||
if let Some(ref map) = *storage_map
|
||||
&& let Some(&(offset, size)) = map.get(storage_key)
|
||||
{
|
||||
// Load only this specific storage
|
||||
let mut file = File::open(&self.path)?;
|
||||
file.seek(std::io::SeekFrom::Start(self.data_offset + offset))?;
|
||||
|
||||
let mut buffer = vec![0u8; size as usize];
|
||||
file.read_exact(&mut buffer)?;
|
||||
return Ok(buffer);
|
||||
}
|
||||
|
||||
// NO FALLBACK! If we don't have storage boundaries, we cannot load data lazily
|
||||
// The storage map MUST be built from tensor metadata for lazy loading to work
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!(
|
||||
"Storage boundaries not available for key '{}'. Cannot perform lazy loading.",
|
||||
storage_key
|
||||
),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl TarSource {
|
||||
/// Create a new TAR source by parsing storages data.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `storages_data` - Raw storages blob with structure:
|
||||
/// - Count pickle (number of storages)
|
||||
/// - For each storage: metadata pickle + u64 num_elements + raw binary data
|
||||
pub fn new(storages_data: Vec<u8>) -> std::io::Result<Self> {
|
||||
use super::pickle_reader::{read_pickle, storage_type_to_element_size};
|
||||
use std::io::Cursor;
|
||||
|
||||
let mut storage_map = HashMap::new();
|
||||
let mut pos = 0usize;
|
||||
|
||||
// First, read the count of storages
|
||||
let mut cursor = Cursor::new(&storages_data[pos..]);
|
||||
let storage_count =
|
||||
if let Ok(super::pickle_reader::Object::Int(count)) = read_pickle(&mut cursor) {
|
||||
pos += cursor.position() as usize;
|
||||
count as usize
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
// Parse each storage entry
|
||||
for _i in 0..storage_count {
|
||||
if pos >= storages_data.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Read the storage metadata pickle: (storage_key, device, storage_type)
|
||||
let mut cursor = Cursor::new(&storages_data[pos..]);
|
||||
if let Ok(obj) = read_pickle(&mut cursor) {
|
||||
let pickle_size = cursor.position() as usize;
|
||||
pos += pickle_size;
|
||||
|
||||
// Extract storage info from pickle tuple
|
||||
let (storage_key, storage_type) = match obj {
|
||||
super::pickle_reader::Object::Tuple(tuple) if tuple.len() >= 3 => {
|
||||
let key = match &tuple[0] {
|
||||
super::pickle_reader::Object::Int(i) => i.to_string(),
|
||||
super::pickle_reader::Object::String(s) => s.clone(),
|
||||
_ => continue,
|
||||
};
|
||||
// tuple[1] is device (e.g., "cpu")
|
||||
// tuple[2] is storage type class
|
||||
let stype = match &tuple[2] {
|
||||
super::pickle_reader::Object::Class { name, .. } => name.clone(),
|
||||
other => {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Expected Class for storage type, got {:?}", other),
|
||||
));
|
||||
}
|
||||
};
|
||||
(key, stype)
|
||||
}
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
// Read the number of elements (u64 little-endian)
|
||||
if pos + 8 > storages_data.len() {
|
||||
break;
|
||||
}
|
||||
let num_elements = u64::from_le_bytes([
|
||||
storages_data[pos],
|
||||
storages_data[pos + 1],
|
||||
storages_data[pos + 2],
|
||||
storages_data[pos + 3],
|
||||
storages_data[pos + 4],
|
||||
storages_data[pos + 5],
|
||||
storages_data[pos + 6],
|
||||
storages_data[pos + 7],
|
||||
]) as usize;
|
||||
pos += 8;
|
||||
|
||||
// Determine element size from storage type
|
||||
let element_size = storage_type_to_element_size(&storage_type)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
let data_size = num_elements * element_size;
|
||||
|
||||
// Store the offset to raw data and its size
|
||||
storage_map.insert(storage_key, (pos, data_size));
|
||||
|
||||
// Skip the raw binary data
|
||||
pos += data_size;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
storage_map,
|
||||
storages_data,
|
||||
})
|
||||
}
|
||||
|
||||
/// Read data for a specific storage key
|
||||
pub fn read_file(&self, key: &str) -> std::io::Result<Vec<u8>> {
|
||||
// Extract the storage key from paths like "data/0"
|
||||
let storage_key = key.split('/').next_back().unwrap_or(key);
|
||||
|
||||
if let Some(&(offset, size)) = self.storage_map.get(storage_key)
|
||||
&& offset + size <= self.storages_data.len()
|
||||
{
|
||||
return Ok(self.storages_data[offset..offset + size].to_vec());
|
||||
}
|
||||
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::NotFound,
|
||||
format!("Storage key '{}' not found in TAR archive", storage_key),
|
||||
))
|
||||
}
|
||||
|
||||
/// Read a range of data for a specific storage key (avoids double allocation)
|
||||
pub fn read_file_range(
|
||||
&self,
|
||||
key: &str,
|
||||
offset: usize,
|
||||
length: usize,
|
||||
) -> std::io::Result<Vec<u8>> {
|
||||
let storage_key = key.split('/').next_back().unwrap_or(key);
|
||||
|
||||
if let Some(&(storage_offset, storage_size)) = self.storage_map.get(storage_key)
|
||||
&& storage_offset + storage_size <= self.storages_data.len()
|
||||
{
|
||||
let start = storage_offset + offset;
|
||||
let end = (storage_offset + offset + length).min(storage_offset + storage_size);
|
||||
return Ok(self.storages_data[start..end].to_vec());
|
||||
}
|
||||
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::NotFound,
|
||||
format!("Storage key '{}' not found in TAR archive", storage_key),
|
||||
))
|
||||
}
|
||||
|
||||
/// Check if a storage key exists
|
||||
pub fn contains(&self, key: &str) -> bool {
|
||||
let storage_key = key.split('/').next_back().unwrap_or(key);
|
||||
self.storage_map.contains_key(storage_key)
|
||||
}
|
||||
|
||||
/// Get list of storage keys
|
||||
pub fn keys(&self) -> Vec<String> {
|
||||
self.storage_map.keys().cloned().collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl LazyDataSource {
|
||||
/// Create from a ZIP file
|
||||
pub fn from_zip(path: impl AsRef<Path>) -> std::io::Result<Self> {
|
||||
Ok(Self::Zip(Arc::new(Mutex::new(ZipSource::new(
|
||||
path.as_ref().to_path_buf(),
|
||||
)?))))
|
||||
}
|
||||
|
||||
/// Create from a TAR archive's storages data
|
||||
pub fn from_tar(storages_data: &[u8]) -> std::io::Result<Self> {
|
||||
Ok(Self::Tar(Arc::new(Mutex::new(TarSource::new(
|
||||
storages_data.to_vec(),
|
||||
)?))))
|
||||
}
|
||||
|
||||
/// Create from a legacy multi-storage file
|
||||
pub fn from_legacy_multi_storage(
|
||||
path: impl AsRef<Path>,
|
||||
data_offset: u64,
|
||||
data_size: u64,
|
||||
) -> Self {
|
||||
Self::LegacyMultiStorage(Arc::new(Mutex::new(LegacyMultiStorageSource::new(
|
||||
path.as_ref().to_path_buf(),
|
||||
data_offset,
|
||||
data_size,
|
||||
))))
|
||||
}
|
||||
|
||||
/// Read data for a specific key
|
||||
pub fn read(&self, key: &str) -> std::io::Result<Vec<u8>> {
|
||||
match self {
|
||||
Self::Zip(source) => {
|
||||
let source = source
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
source.read_file(key)
|
||||
}
|
||||
Self::Tar(source) => {
|
||||
let source = source
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
source.read_file(key)
|
||||
}
|
||||
Self::LegacyMultiStorage(source) => {
|
||||
let source = source
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
source.read(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read a portion of data for a specific key
|
||||
pub fn read_range(&self, key: &str, offset: usize, length: usize) -> std::io::Result<Vec<u8>> {
|
||||
match self {
|
||||
Self::Zip(source) => {
|
||||
let source = source
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
source.read_file_range(key, offset, length)
|
||||
}
|
||||
Self::Tar(source) => {
|
||||
let source = source
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
source.read_file_range(key, offset, length)
|
||||
}
|
||||
Self::LegacyMultiStorage(source) => {
|
||||
// For legacy format, read only the requested range
|
||||
let storage_key = key.split('/').next_back().unwrap_or(key);
|
||||
let source = source
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
|
||||
// Get storage boundaries
|
||||
let storage_map = source
|
||||
.storage_map
|
||||
.read()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
if let Some(ref map) = *storage_map
|
||||
&& let Some(&(storage_offset, storage_size)) = map.get(storage_key)
|
||||
{
|
||||
// Calculate actual file position
|
||||
let file_offset = source.data_offset + storage_offset + offset as u64;
|
||||
let read_length = length.min((storage_size as usize).saturating_sub(offset));
|
||||
|
||||
// Read only the requested range
|
||||
let mut file = File::open(&source.path)?;
|
||||
file.seek(std::io::SeekFrom::Start(file_offset))?;
|
||||
|
||||
let mut buffer = vec![0u8; read_length];
|
||||
file.read_exact(&mut buffer)?;
|
||||
Ok(buffer)
|
||||
} else {
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!(
|
||||
"Storage boundaries not available for key '{}'. Cannot perform lazy loading.",
|
||||
storage_key
|
||||
),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a key exists
|
||||
pub fn contains(&self, key: &str) -> bool {
|
||||
match self {
|
||||
Self::Zip(source) => {
|
||||
let source = source
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
source.contains(key)
|
||||
}
|
||||
Self::Tar(source) => {
|
||||
let source = source
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
source.contains(key)
|
||||
}
|
||||
Self::LegacyMultiStorage(_) => true, // Legacy format has all data
|
||||
}
|
||||
}
|
||||
|
||||
/// Get list of available keys (for ZIP sources)
|
||||
pub fn keys(&self) -> Vec<String> {
|
||||
match self {
|
||||
Self::Zip(source) => {
|
||||
let source = source
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
source.data_files()
|
||||
}
|
||||
Self::Tar(source) => {
|
||||
let source = source
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
source.keys()
|
||||
}
|
||||
Self::LegacyMultiStorage(_) => vec![], // Legacy format doesn't have distinct keys
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
//! PyTorch format support for burn-store.
|
||||
//!
|
||||
//! This module provides comprehensive support for loading PyTorch model files (.pth, .pt)
|
||||
//! into Burn, with automatic weight transformation and flexible configuration options.
|
||||
//!
|
||||
//! ## Features
|
||||
//!
|
||||
//! - **Direct .pth/.pt file loading**: Load PyTorch checkpoint and state dict files
|
||||
//! - **Automatic weight transformation**: `PyTorchToBurnAdapter` is applied by default:
|
||||
//! - Linear layer weights are automatically transposed
|
||||
//! - Normalization parameters are renamed (gamma → weight, beta → bias)
|
||||
//! - Conv2d weights maintain their format
|
||||
//! - **Flexible filtering**: Load only specific layers or parameters
|
||||
//! - **Key remapping**: Rename tensors during loading to match your model structure
|
||||
//! - **Partial loading**: Continue even when some tensors are missing
|
||||
//!
|
||||
//! ## Example
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use burn_store::PytorchStore;
|
||||
//!
|
||||
//! // Load a PyTorch model (PyTorchToBurnAdapter is applied automatically)
|
||||
//! let mut store = PytorchStore::from_file("model.pth")
|
||||
//! .with_top_level_key("state_dict") // Access nested state dict
|
||||
//! .with_regex(r"^encoder\..*") // Only load encoder layers
|
||||
//! .with_key_remapping(r"^fc\.", "linear.") // Rename fc -> linear
|
||||
//! .allow_partial(true); // Skip missing tensors
|
||||
//!
|
||||
//! let mut model = MyModel::new(&device);
|
||||
//! let result = model.load_from(&mut store)?;
|
||||
//!
|
||||
//! println!("Loaded {} tensors", result.applied.len());
|
||||
//! if !result.missing.is_empty() {
|
||||
//! println!("Missing tensors: {:?}", result.missing);
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
pub mod lazy_data;
|
||||
pub mod pickle_reader;
|
||||
pub mod reader;
|
||||
pub mod store;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests;
|
||||
|
||||
// Main public interface
|
||||
pub use reader::{PytorchError, PytorchReader};
|
||||
pub use store::{PytorchStore, PytorchStoreError};
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,442 @@
|
||||
//! PyTorch store implementation for loading models from PyTorch format (saving is not yet supported).
|
||||
|
||||
use crate::{
|
||||
ApplyResult, KeyRemapper, ModuleSnapshot, ModuleStore, PathFilter, PyTorchToBurnAdapter,
|
||||
TensorSnapshot, map_indices_contiguous,
|
||||
};
|
||||
|
||||
use alloc::collections::BTreeMap;
|
||||
|
||||
use alloc::format;
|
||||
use alloc::string::{String, ToString};
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::backend::Backend;
|
||||
use core::fmt;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::reader::{PytorchError as ReaderError, PytorchReader};
|
||||
|
||||
/// Errors that can occur during PyTorch operations.
|
||||
#[derive(Debug)]
|
||||
pub enum PytorchStoreError {
|
||||
/// Reader error.
|
||||
Reader(ReaderError),
|
||||
|
||||
/// I/O error.
|
||||
Io(std::io::Error),
|
||||
|
||||
/// Tensor not found.
|
||||
TensorNotFound(String),
|
||||
|
||||
/// Validation failed.
|
||||
ValidationFailed(String),
|
||||
|
||||
/// Other error.
|
||||
Other(String),
|
||||
}
|
||||
|
||||
impl fmt::Display for PytorchStoreError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Reader(e) => write!(f, "PyTorch reader error: {}", e),
|
||||
Self::Io(e) => write!(f, "I/O error: {}", e),
|
||||
Self::TensorNotFound(name) => write!(f, "Tensor not found: {}", name),
|
||||
Self::ValidationFailed(msg) => write!(f, "Validation failed: {}", msg),
|
||||
Self::Other(msg) => write!(f, "{}", msg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for PytorchStoreError {}
|
||||
|
||||
impl From<ReaderError> for PytorchStoreError {
|
||||
fn from(e: ReaderError) -> Self {
|
||||
PytorchStoreError::Reader(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for PytorchStoreError {
|
||||
fn from(e: std::io::Error) -> Self {
|
||||
PytorchStoreError::Io(e)
|
||||
}
|
||||
}
|
||||
|
||||
/// PyTorch store for file-based storage only.
|
||||
///
|
||||
/// This store allows loading models from PyTorch checkpoint files (.pt/.pth)
|
||||
/// with automatic weight transformation using `PyTorchToBurnAdapter`.
|
||||
/// Linear weights are automatically transposed and normalization parameters
|
||||
/// are renamed (gamma -> weight, beta -> bias).
|
||||
///
|
||||
/// Note that saving to PyTorch format is not yet supported.
|
||||
pub struct PytorchStore {
|
||||
pub(crate) path: PathBuf,
|
||||
pub(crate) filter: PathFilter,
|
||||
pub(crate) remapper: KeyRemapper,
|
||||
pub(crate) validate: bool,
|
||||
pub(crate) allow_partial: bool,
|
||||
pub(crate) top_level_key: Option<String>,
|
||||
pub(crate) skip_enum_variants: bool,
|
||||
/// Enable contiguous mapping of layer indices (default: true)
|
||||
pub(crate) map_indices_contiguous: bool,
|
||||
/// Cached tensor snapshots (parsed once, reused)
|
||||
snapshots_cache: Option<BTreeMap<String, TensorSnapshot>>,
|
||||
}
|
||||
|
||||
impl PytorchStore {
|
||||
/// Create a store for loading from a PyTorch file.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path` - Path to the PyTorch checkpoint file (.pt or .pth)
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,no_run
|
||||
/// use burn_store::PytorchStore;
|
||||
///
|
||||
/// let store = PytorchStore::from_file("model.pth");
|
||||
/// ```
|
||||
pub fn from_file(path: impl Into<PathBuf>) -> Self {
|
||||
Self {
|
||||
path: path.into(),
|
||||
filter: PathFilter::new(),
|
||||
remapper: KeyRemapper::new(),
|
||||
validate: true,
|
||||
allow_partial: false,
|
||||
top_level_key: None,
|
||||
// PyTorch models never include enum variant names in paths
|
||||
skip_enum_variants: true,
|
||||
// Enable contiguous index mapping by default for PyTorch files
|
||||
// This handles nn.Sequential models with gaps in layer indices
|
||||
map_indices_contiguous: true,
|
||||
snapshots_cache: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a top-level key to extract tensors from.
|
||||
///
|
||||
/// PyTorch files often contain nested dictionaries. Use this to extract
|
||||
/// tensors from a specific top-level key like "state_dict" or "model_state_dict".
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,no_run
|
||||
/// # use burn_store::PytorchStore;
|
||||
/// let store = PytorchStore::from_file("checkpoint.pth")
|
||||
/// .with_top_level_key("model_state_dict");
|
||||
/// ```
|
||||
pub fn with_top_level_key(mut self, key: impl Into<String>) -> Self {
|
||||
self.top_level_key = Some(key.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Filter which tensors to load.
|
||||
pub fn filter(mut self, filter: PathFilter) -> Self {
|
||||
self.filter = filter;
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a regex pattern to filter tensors.
|
||||
///
|
||||
/// Multiple patterns can be added and they work with OR logic.
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,no_run
|
||||
/// # use burn_store::PytorchStore;
|
||||
/// let store = PytorchStore::from_file("model.pth")
|
||||
/// .with_regex(r"^encoder\..*") // Match all encoder tensors
|
||||
/// .with_regex(r".*\.weight$"); // OR match any weight tensors
|
||||
/// ```
|
||||
pub fn with_regex<S: AsRef<str>>(mut self, pattern: S) -> Self {
|
||||
self.filter = self.filter.with_regex(pattern);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add multiple regex patterns to filter tensors.
|
||||
pub fn with_regexes<I, S>(mut self, patterns: I) -> Self
|
||||
where
|
||||
I: IntoIterator<Item = S>,
|
||||
S: AsRef<str>,
|
||||
{
|
||||
self.filter = self.filter.with_regexes(patterns);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add an exact full path to match.
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,no_run
|
||||
/// # use burn_store::PytorchStore;
|
||||
/// let store = PytorchStore::from_file("model.pth")
|
||||
/// .with_full_path("encoder.layer1.weight")
|
||||
/// .with_full_path("decoder.output.bias");
|
||||
/// ```
|
||||
pub fn with_full_path<S: Into<String>>(mut self, path: S) -> Self {
|
||||
self.filter = self.filter.with_full_path(path);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add multiple exact full paths to match.
|
||||
pub fn with_full_paths<I, S>(mut self, paths: I) -> Self
|
||||
where
|
||||
I: IntoIterator<Item = S>,
|
||||
S: Into<String>,
|
||||
{
|
||||
self.filter = self.filter.with_full_paths(paths);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a predicate function for custom filtering logic.
|
||||
///
|
||||
/// The predicate receives the tensor path and container path.
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,no_run
|
||||
/// # use burn_store::PytorchStore;
|
||||
/// let store = PytorchStore::from_file("model.pth")
|
||||
/// .with_predicate(|path, _| path.starts_with("encoder.") || path.ends_with(".bias"));
|
||||
/// ```
|
||||
pub fn with_predicate(mut self, predicate: fn(&str, &str) -> bool) -> Self {
|
||||
self.filter = self.filter.with_predicate(predicate);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add multiple predicate functions.
|
||||
pub fn with_predicates<I>(mut self, predicates: I) -> Self
|
||||
where
|
||||
I: IntoIterator<Item = fn(&str, &str) -> bool>,
|
||||
{
|
||||
self.filter = self.filter.with_predicates(predicates);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the filter to match all paths (disables filtering).
|
||||
pub fn match_all(mut self) -> Self {
|
||||
self.filter = self.filter.match_all();
|
||||
self
|
||||
}
|
||||
|
||||
/// Remap tensor names during load.
|
||||
pub fn remap(mut self, remapper: KeyRemapper) -> Self {
|
||||
self.remapper = remapper;
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a regex pattern to remap tensor names during load.
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,no_run
|
||||
/// # use burn_store::PytorchStore;
|
||||
/// let store = PytorchStore::from_file("model.pth")
|
||||
/// .with_key_remapping(r"^encoder\.", "transformer.encoder.") // encoder.X -> transformer.encoder.X
|
||||
/// .with_key_remapping(r"\.gamma$", ".weight"); // X.gamma -> X.weight
|
||||
/// ```
|
||||
pub fn with_key_remapping(
|
||||
mut self,
|
||||
from_pattern: impl AsRef<str>,
|
||||
to_pattern: impl Into<String>,
|
||||
) -> Self {
|
||||
self.remapper = self
|
||||
.remapper
|
||||
.add_pattern(from_pattern, to_pattern)
|
||||
.expect("Invalid regex pattern");
|
||||
self
|
||||
}
|
||||
|
||||
/// Set whether to validate tensors during loading (default: true).
|
||||
pub fn validate(mut self, validate: bool) -> Self {
|
||||
self.validate = validate;
|
||||
self
|
||||
}
|
||||
|
||||
/// Allow partial loading of tensors (continue even if some tensors are missing).
|
||||
pub fn allow_partial(mut self, allow: bool) -> Self {
|
||||
self.allow_partial = allow;
|
||||
self
|
||||
}
|
||||
|
||||
/// Skip enum variant names when matching tensor paths (default: true).
|
||||
///
|
||||
/// When enabled, tensor paths from PyTorch that don't include enum variants
|
||||
/// can be matched against Burn module paths that do include them.
|
||||
/// For example, PyTorch path "feature.weight" can match Burn path "feature.BaseConv.weight".
|
||||
///
|
||||
/// This defaults to `true` for PytorchStore since PyTorch models never include
|
||||
/// enum variant names in their parameter paths.
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,no_run
|
||||
/// # use burn_store::PytorchStore;
|
||||
/// // Disable enum variant skipping (not typical)
|
||||
/// let store = PytorchStore::from_file("model.pth")
|
||||
/// .skip_enum_variants(false);
|
||||
/// ```
|
||||
pub fn skip_enum_variants(mut self, skip: bool) -> Self {
|
||||
self.skip_enum_variants = skip;
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable or disable automatic contiguous mapping of layer indices (default: true).
|
||||
///
|
||||
/// When enabled, non-contiguous numeric indices in tensor paths are renumbered
|
||||
/// to be contiguous. This is useful when loading PyTorch models that have gaps
|
||||
/// in layer numbering, such as when using `nn.Sequential` with mixed layer types
|
||||
/// (e.g., Conv2d layers at indices 0, 2, 4 with ReLU layers at 1, 3, 5).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// With index mapping enabled (default):
|
||||
/// - `fc.0.weight` → `fc.0.weight`
|
||||
/// - `fc.2.weight` → `fc.1.weight` (gap filled)
|
||||
/// - `fc.4.weight` → `fc.2.weight` (gap filled)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `map` - `true` to enable contiguous index mapping, `false` to disable
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust,no_run
|
||||
/// # use burn_store::PytorchStore;
|
||||
/// // Disable contiguous index mapping if your model already has contiguous indices
|
||||
/// let store = PytorchStore::from_file("model.pth")
|
||||
/// .map_indices_contiguous(false);
|
||||
/// ```
|
||||
pub fn map_indices_contiguous(mut self, map: bool) -> Self {
|
||||
self.map_indices_contiguous = map;
|
||||
self
|
||||
}
|
||||
|
||||
/// Apply remapping to tensor snapshots.
|
||||
fn apply_remapping(&self, snapshots: Vec<TensorSnapshot>) -> Vec<TensorSnapshot> {
|
||||
if self.remapper.is_empty() {
|
||||
return snapshots;
|
||||
}
|
||||
|
||||
let (remapped, _) = self.remapper.remap(snapshots);
|
||||
remapped
|
||||
}
|
||||
|
||||
/// Create a PytorchReader for the configured path and options.
|
||||
fn create_reader(&self) -> Result<PytorchReader, PytorchStoreError> {
|
||||
let reader = if let Some(ref key) = self.top_level_key {
|
||||
PytorchReader::with_top_level_key(&self.path, key)?
|
||||
} else {
|
||||
PytorchReader::new(&self.path)?
|
||||
};
|
||||
Ok(reader)
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleStore for PytorchStore {
|
||||
type Error = PytorchStoreError;
|
||||
|
||||
fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
|
||||
&mut self,
|
||||
_module: &M,
|
||||
) -> Result<(), Self::Error> {
|
||||
// Saving to PyTorch format is not yet supported
|
||||
Err(PytorchStoreError::Other(
|
||||
"Saving to PyTorch format is not yet supported. Use other formats for saving."
|
||||
.to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
|
||||
&mut self,
|
||||
module: &mut M,
|
||||
) -> Result<ApplyResult, Self::Error> {
|
||||
// Get snapshots from cache
|
||||
let snapshots: Vec<TensorSnapshot> = self.get_all_snapshots()?.values().cloned().collect();
|
||||
|
||||
// Get filter (convert to Option for apply)
|
||||
let filter_opt = if self.filter.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(self.filter.clone())
|
||||
};
|
||||
|
||||
// Apply to module with PyTorchToBurnAdapter (always used for PyTorch files)
|
||||
// This adapter handles:
|
||||
// - Transposing linear weights from PyTorch format to Burn format
|
||||
// - Renaming normalization parameters (gamma -> weight, beta -> bias)
|
||||
// Filter is applied here during apply, not during cache population
|
||||
let result = module.apply(
|
||||
snapshots,
|
||||
filter_opt,
|
||||
Some(Box::new(PyTorchToBurnAdapter)),
|
||||
self.skip_enum_variants,
|
||||
);
|
||||
|
||||
// Validate if needed
|
||||
if self.validate && !result.errors.is_empty() {
|
||||
return Err(PytorchStoreError::ValidationFailed(format!(
|
||||
"Import errors:\n{}",
|
||||
result
|
||||
)));
|
||||
}
|
||||
|
||||
if !self.allow_partial && !result.missing.is_empty() {
|
||||
return Err(PytorchStoreError::TensorNotFound(format!("\n{}", result)));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error> {
|
||||
self.ensure_snapshots_cache()?;
|
||||
Ok(self.snapshots_cache.as_ref().unwrap().get(name))
|
||||
}
|
||||
|
||||
fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error> {
|
||||
self.ensure_snapshots_cache()?;
|
||||
Ok(self.snapshots_cache.as_ref().unwrap())
|
||||
}
|
||||
|
||||
fn keys(&mut self) -> Result<Vec<String>, Self::Error> {
|
||||
// Always use the cache to ensure remapping is applied consistently
|
||||
Ok(self.get_all_snapshots()?.keys().cloned().collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl PytorchStore {
|
||||
/// Ensure the snapshots cache is populated
|
||||
fn ensure_snapshots_cache(&mut self) -> Result<(), PytorchStoreError> {
|
||||
if self.snapshots_cache.is_some() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let reader = self.create_reader()?;
|
||||
|
||||
// Convert to tensor snapshots
|
||||
let mut snapshots: Vec<TensorSnapshot> = reader
|
||||
.into_tensors()
|
||||
.into_iter()
|
||||
.map(|(key, mut snapshot)| {
|
||||
// Parse the key into path parts (split by '.')
|
||||
let path_parts: Vec<String> = key.split('.').map(|s| s.to_string()).collect();
|
||||
|
||||
// Set the path stack from the key
|
||||
snapshot.path_stack = Some(path_parts);
|
||||
snapshot.container_stack = None;
|
||||
snapshot.tensor_id = None;
|
||||
|
||||
snapshot
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Apply remapping (but NOT filtering - that's done at apply time)
|
||||
snapshots = self.apply_remapping(snapshots);
|
||||
|
||||
// Apply contiguous index mapping if enabled
|
||||
// This must be done after remapping so that remapped paths are mapped
|
||||
if self.map_indices_contiguous {
|
||||
let (mapped, _) = map_indices_contiguous(snapshots);
|
||||
snapshots = mapped;
|
||||
}
|
||||
|
||||
// Build cache as BTreeMap
|
||||
let cache: BTreeMap<String, TensorSnapshot> =
|
||||
snapshots.into_iter().map(|s| (s.full_path(), s)).collect();
|
||||
|
||||
self.snapshots_cache = Some(cache);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
pub mod reader;
|
||||
pub mod store;
|
||||
@@ -0,0 +1,76 @@
|
||||
#!/usr/bin/env python3
# /// script
# dependencies = ["torch"]
# ///
"""Create a legacy format PyTorch file with specific storage offsets to test offset handling."""

import os

import torch

# Create tensors with known values at specific storage offsets
# This will help us verify we're reading from the correct location

# Create a state dict with tensors that share storage
# This is common in PyTorch models (e.g., weight and transposed weight views)
# NOTE(review): every tensor below is `.clone()`d before saving, so presumably
# each one ends up with its own storage at offset 0 rather than a shared
# storage with a non-zero offset. The verification loop prints the actual
# storage offsets; confirm they match what the tests expect.
state_dict = {}

# Create a base tensor with known pattern
base_data = torch.arange(100, dtype=torch.float32)

# tensor1: uses elements 10-19 (offset 10*4 = 40 bytes)
tensor1 = base_data[10:20].clone()
tensor1[:] = torch.arange(1.0, 1.1, 0.01)[:10]  # 1.00, 1.01, 1.02, ...

# tensor2: uses elements 30-35 (offset 30*4 = 120 bytes)
tensor2 = base_data[30:35].clone()
tensor2[:] = torch.arange(2.0, 2.5, 0.1)[:5]  # 2.0, 2.1, 2.2, 2.3, 2.4

# tensor3: starts at beginning (offset 0)
tensor3 = base_data[:5].clone()
tensor3[:] = torch.arange(3.0, 3.5, 0.1)[:5]  # 3.0, 3.1, 3.2, 3.3, 3.4

state_dict['tensor1'] = tensor1
state_dict['tensor2'] = tensor2
state_dict['tensor3'] = tensor3

# Save in legacy format
output_file = 'test_data/legacy_with_offsets.pt'
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
torch.save(state_dict, output_file, _use_new_zipfile_serialization=False)

print(f"Created {output_file}")

# Verify by loading
loaded = torch.load(output_file, weights_only=False)
print("\nVerification - expected values:")
for key, tensor in loaded.items():
    print(f" {key}: {tensor.tolist()}")
    print(f" Storage offset: {tensor.storage_offset()}")
    print(f" Storage size: {len(tensor.storage())}")

# Also create a test with multiple tensors sharing the same storage
# This is important for proper offset handling
shared_storage = torch.randn(1000)

# Create views into the same storage at different offsets
view1 = shared_storage[100:110]  # offset 100
view2 = shared_storage[500:520]  # offset 500
view3 = shared_storage[0:10]  # offset 0

# Need to save these properly - PyTorch will handle the storage sharing
shared_dict = {
    'view1': view1.clone(),  # Clone to avoid view issues
    'view2': view2.clone(),
    'view3': view3.clone(),
}

output_file2 = 'test_data/legacy_shared_storage.pt'
os.makedirs(os.path.dirname(output_file2) or ".", exist_ok=True)
torch.save(shared_dict, output_file2, _use_new_zipfile_serialization=False)
print(f"\nCreated {output_file2}")

# Print exact values for test verification
print("\nExact test values for legacy_with_offsets.pt:")
print("tensor1 (10 elements starting at 1.0):")
print(" First 3 values: [1.00, 1.01, 1.02]")
print("tensor2 (5 elements starting at 2.0):")
print(" All values: [2.0, 2.1, 2.2, 2.3, 2.4]")
print("tensor3 (5 elements starting at 3.0):")
print(" All values: [3.0, 3.1, 3.2, 3.3, 3.4]")
|
||||
@@ -0,0 +1,361 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Create TAR format test fixtures for burn-store integration tests.
|
||||
|
||||
The TAR format was used by very early versions of PyTorch (pre 0.1.10).
|
||||
Modern torch.save cannot create this format, so we construct it manually.
|
||||
|
||||
TAR format structure:
|
||||
- sys_info: pickle with {protocol_version, little_endian, type_sizes}
|
||||
- pickle: pickle with OrderedDict containing _rebuild_tensor_v2 REDUCE calls
|
||||
- storages: count_pickle + for each storage: (key, device, class) pickle + u64 num_elements + raw data
|
||||
"""
|
||||
|
||||
import io
|
||||
import pickle
|
||||
import struct
|
||||
import tarfile
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def create_sys_info():
    """Return the protocol-2 pickle bytes for the archive's `sys_info` entry.

    The payload records the protocol version, the byte order, and the sizes
    of the C `short`, `int` and `long` types.
    """
    type_sizes = {"short": 2, "int": 4, "long": 8}
    payload = {
        "protocol_version": 1000,
        "little_endian": True,
        "type_sizes": type_sizes,
    }
    return pickle.dumps(payload, protocol=2)
|
||||
|
||||
|
||||
def encode_tensor_data(values: list, storage_type: str) -> tuple:
    """Pack `values` as little-endian scalars for the given torch storage type.

    Returns `(data_bytes, element_size)`. Raises `KeyError` for a storage type
    that is not in the table below.
    """
    layouts = {
        "FloatStorage": ("<f", 4),
        "DoubleStorage": ("<d", 8),
        "LongStorage": ("<q", 8),
        "IntStorage": ("<i", 4),
        "ShortStorage": ("<h", 2),
        "ByteStorage": ("<B", 1),
        "CharStorage": ("<b", 1),
        "BoolStorage": ("<B", 1),
        "HalfStorage": ("<e", 2),
    }
    code, width = layouts[storage_type]
    # One precompiled packer for every element instead of re-parsing the format.
    pack_one = struct.Struct(code).pack
    return b"".join(map(pack_one, values)), width
|
||||
|
||||
|
||||
def write_int(buffer, value):
    """Write an integer to `buffer` using the smallest fitting pickle opcode.

    * 0..255 -> BININT1 (`K`, 1 unsigned byte)
    * 256..65535 -> BININT2 (`M`, 2 unsigned bytes, little-endian)
    * remaining signed 32-bit values -> BININT (`J`, 4 signed bytes)
    * anything wider -> LONG1 / LONG4 (protocol 2 arbitrary-size integers)

    The wide path matters because BININT holds only a signed 32-bit value;
    packing a larger number with `'<i'` raises `struct.error`.
    """
    if 0 <= value < 256:
        buffer.write(b'K')  # BININT1
        buffer.write(bytes([value]))
    elif 0 <= value < 65536:
        buffer.write(b'M')  # BININT2
        buffer.write(struct.pack('<H', value))
    elif -0x80000000 <= value <= 0x7FFFFFFF:
        buffer.write(b'J')  # BININT
        buffer.write(struct.pack('<i', value))
    else:
        # Little-endian two's complement; the extra byte leaves room for the sign bit.
        payload = value.to_bytes((value.bit_length() + 8) // 8, 'little', signed=True)
        if len(payload) < 256:
            buffer.write(b'\x8a')  # LONG1
            buffer.write(bytes([len(payload)]))
        else:
            buffer.write(b'\x8b')  # LONG4
            buffer.write(struct.pack('<i', len(payload)))
        buffer.write(payload)
|
||||
|
||||
|
||||
def write_string(buffer, s):
    """Write `s` (UTF-8 encoded) to `buffer` as a pickle byte-string.

    Payloads shorter than 256 bytes use SHORT_BINSTRING (`U` + 1-byte length);
    longer ones use BINSTRING (`T` + 4-byte little-endian length).
    """
    payload = s.encode('utf-8')
    size = len(payload)
    if size >= 256:
        header = b'T' + struct.pack('<I', size)  # BINSTRING
    else:
        header = b'U' + bytes([size])  # SHORT_BINSTRING
    buffer.write(header)
    buffer.write(payload)
|
||||
|
||||
|
||||
def create_storages_blob_manual(tensors: list) -> bytes:
    """
    Build the binary `storages` entry by hand.

    Layout: a pickled storage count, then for each storage a pickled
    `(key, "cpu", torch.<StorageClass>)` tuple, a little-endian u64 element
    count, and the raw data bytes.

    Args:
        tensors: List of (key, storage_type, element_size, data_bytes) tuples
    """
    blob = io.BytesIO()

    # Storage count, pickled as a plain integer.
    pickle.dump(len(tensors), blob, protocol=2)

    for key, storage_type, element_size, data_bytes in tensors:
        # The header tuple is written opcode by opcode so that the class is
        # referenced with GLOBAL: (key, "cpu", <class 'torch.FloatStorage'>)
        header = io.BytesIO()
        header.write(b'\x80\x02' + b'(')  # protocol 2 header, MARK
        write_string(header, key)  # item 1: storage key
        header.write(
            b'U\x03cpu'  # item 2: device
            + b'c' + b'torch\n' + storage_type.encode('ascii') + b'\n'  # item 3: GLOBAL class
            + b't'  # TUPLE
            + b'.'  # STOP
        )
        blob.write(header.getvalue())

        # Element count as u64 little-endian, followed by the raw data.
        element_count = len(data_bytes) // element_size
        blob.write(struct.pack("<Q", element_count))
        blob.write(data_bytes)

    return blob.getvalue()
|
||||
|
||||
|
||||
def create_main_pickle_manual(tensors_info: list) -> bytes:
    """
    Hand-assemble the main pickle: an OrderedDict of name -> rebuilt tensor.

    Every tensor is emitted as a REDUCE of `torch._utils._rebuild_tensor_v2`
    over the argument tuple
    (persistent_id, offset, shape, stride, requires_grad, backward_hooks),
    where persistent_id is the tuple
    ('storage', <class>, key, device, num_elements).

    Args:
        tensors_info: List of (name, storage_key, storage_type, shape, num_elements) tuples
    """
    out = io.BytesIO()
    emit = out.write

    emit(b'\x80\x02')  # protocol 2 header

    # OrderedDict([(name1, tensor1), ...]) == GLOBAL + MARK + list + TUPLE + REDUCE
    emit(b'ccollections\nOrderedDict\n')
    emit(b'(' + b']')  # MARK, EMPTY_LIST that collects the items

    for name, storage_key, storage_type, shape, num_elements in tensors_info:
        # Row-major (C order) strides: the last axis has stride 1 and each
        # earlier axis strides over the product of all later extents.
        rank = len(shape)
        strides = [1] * rank
        for axis in range(rank - 2, -1, -1):
            strides[axis] = strides[axis + 1] * shape[axis + 1]

        emit(b'(')  # MARK for the (name, tensor) item tuple
        write_string(out, name)

        emit(b'ctorch._utils\n_rebuild_tensor_v2\n')  # GLOBAL callable
        emit(b'(')  # MARK for the args tuple

        # arg 0: persistent_id tuple
        emit(b'(')
        write_string(out, 'storage')
        emit(b'c' + b'torch\n' + storage_type.encode('ascii') + b'\n')  # GLOBAL storage class
        write_string(out, storage_key)
        emit(b'U\x03cpu')  # device
        write_int(out, num_elements)
        emit(b't')  # TUPLE - end persistent_id

        # arg 1: storage offset (0)
        emit(b'K\x00')

        # args 2 and 3: shape tuple, then stride tuple
        for extents in (shape, strides):
            emit(b'(')
            for extent in extents:
                write_int(out, extent)
            emit(b't')

        # arg 4: requires_grad (False)
        emit(b'\x89')  # NEWFALSE

        # arg 5: backward_hooks, an empty OrderedDict built via REDUCE
        emit(b'ccollections\nOrderedDict\n' + b'(' + b']' + b't' + b'R')

        emit(b't')  # TUPLE - end args tuple
        emit(b'R')  # REDUCE - call _rebuild_tensor_v2 with args
        emit(b't')  # TUPLE - end (name, tensor) tuple
        emit(b'a')  # APPEND to list

    emit(b't')  # TUPLE - wrap list in tuple for REDUCE
    emit(b'R')  # REDUCE - call OrderedDict with the list
    emit(b'.')  # STOP

    return out.getvalue()
|
||||
|
||||
|
||||
def create_tar_pytorch_file(filename: str, tensors: dict, dtypes: dict):
    """
    Write a TAR format PyTorch file with `sys_info`, `pickle` and `storages` entries.

    Each tensor gets its own storage, keyed by its position in `tensors`.

    Args:
        filename: Output file path
        tensors: Dict of tensor_name -> (values_list, shape)
        dtypes: Dict of tensor_name -> storage_type
    """
    storage_list = []  # (key, storage_type, element_size, data_bytes)
    tensors_info = []  # (name, storage_key, storage_type, shape, num_elements)

    for position, (name, (values, shape)) in enumerate(tensors.items()):
        storage_key = str(position)
        storage_type = dtypes[name]
        data_bytes, element_size = encode_tensor_data(values, storage_type)

        storage_list.append((storage_key, storage_type, element_size, data_bytes))
        tensors_info.append((name, storage_key, storage_type, shape, len(values)))

    # The three archive members, in the order they are written.
    members = (
        ("sys_info", create_sys_info()),
        ("pickle", create_main_pickle_manual(tensors_info)),
        ("storages", create_storages_blob_manual(storage_list)),
    )

    os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)

    with tarfile.open(filename, "w") as tar:
        for member_name, payload in members:
            entry = tarfile.TarInfo(name=member_name)
            entry.size = len(payload)
            tar.addfile(entry, io.BytesIO(payload))

    size = os.path.getsize(filename)
    print(f"Created {filename} ({size} bytes)")
    print(f" Tensors: {list(tensors.keys())}")
|
||||
|
||||
|
||||
def main():
    """Generate every TAR format fixture under `test_data/`."""
    os.makedirs("test_data", exist_ok=True)

    # (output path, tensor_name -> (values, shape), tensor_name -> storage type)
    fixtures = [
        # Test 1: Single float32 tensor
        (
            "test_data/tar_float32.tar",
            {"tensor": ([1.0, 2.5, -3.7, 0.0], [4])},
            {"tensor": "FloatStorage"},
        ),
        # Test 2: Single float64 tensor
        (
            "test_data/tar_float64.tar",
            {"tensor": ([1.1, 2.2, 3.3], [3])},
            {"tensor": "DoubleStorage"},
        ),
        # Test 3: Single int64 tensor
        (
            "test_data/tar_int64.tar",
            {"tensor": ([100, -200, 300, 0], [4])},
            {"tensor": "LongStorage"},
        ),
        # Test 4: Multiple tensors (weight + bias)
        (
            "test_data/tar_weight_bias.tar",
            {
                "weight": ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [2, 3]),
                "bias": ([0.01, 0.02], [2]),
            },
            {
                "weight": "FloatStorage",
                "bias": "FloatStorage",
            },
        ),
        # Test 5: Different dtypes in one file
        (
            "test_data/tar_multi_dtype.tar",
            {
                "float_tensor": ([1.5, 2.5, 3.5], [3]),
                "double_tensor": ([1.111, 2.222], [2]),
                "int_tensor": ([10, 20, 30, 40], [4]),
            },
            {
                "float_tensor": "FloatStorage",
                "double_tensor": "DoubleStorage",
                "int_tensor": "LongStorage",
            },
        ),
        # Test 6: 2D tensor for shape verification
        (
            "test_data/tar_2d_tensor.tar",
            {
                "matrix": ([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], [3, 4]),
            },
            {"matrix": "FloatStorage"},
        ),
    ]

    for path, tensors, dtypes in fixtures:
        create_tar_pytorch_file(path, tensors, dtypes)

    print("\nAll TAR format test files created!")
    print("\nTo run tests: cargo test -p burn-store --features pytorch test_tar")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env python3
# /// script
# dependencies = ["torch"]
# ///
"""Create a simple legacy format PyTorch file."""

import os

import torch

# Create a simple state dict
state_dict = {
    'weight': torch.randn(2, 3),
    'bias': torch.ones(2),
    'running_mean': torch.zeros(2),
}

# torch.save does not create missing directories, so make sure the target exists.
os.makedirs('test_data', exist_ok=True)

# Save without using zip format (legacy format)
torch.save(state_dict, 'test_data/simple_legacy.pt', _use_new_zipfile_serialization=False)

print("Created simple_legacy.pt")

# Verify by loading the file back and listing what it contains
loaded = torch.load('test_data/simple_legacy.pt', weights_only=False)
print(f"Loaded {len(loaded)} tensors")
for key, val in loaded.items():
    print(f" {key}: shape {val.shape}, dtype {val.dtype}")
|
||||
@@ -0,0 +1,227 @@
|
||||
#!/usr/bin/env python3
|
||||
# /// script
|
||||
# dependencies = ["torch", "numpy"]
|
||||
# ///
|
||||
"""
|
||||
Generate test PyTorch .pt files for testing the burn-store PyTorch reader.
|
||||
Run with: uv run test_files.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Create test directory
|
||||
test_dir = Path(__file__).parent / "test_data"
|
||||
test_dir.mkdir(exist_ok=True)
|
||||
|
||||
def save_test_file(filename, data, description):
    """Write ``data`` into the test-data directory and report it.

    Args:
        filename: Name of the file to create inside ``test_dir``.
        data: Object passed unchanged to ``torch.save``.
        description: Short human-readable summary echoed on stdout.

    Returns:
        The path of the file that was written (``test_dir / filename``).
    """
    destination = test_dir / filename
    torch.save(data, destination)
    # One progress line per fixture so the console output doubles as a manifest.
    print(f"✓ {filename}: {description}")
    return destination
|
||||
|
||||
# Test 1: Simple tensors of different types
print("\n=== Generating Basic Tensor Tests ===")

# One row per fixture: (file name, element values, torch dtype, description).
# Rows are processed top to bottom, so files are written in this exact order.
basic_tensor_cases = [
    ("float32.pt", [1.0, 2.5, -3.7, 0.0], torch.float32, "Float32 tensor [1.0, 2.5, -3.7, 0.0]"),
    ("float64.pt", [1.1, 2.2, 3.3], torch.float64, "Float64 tensor [1.1, 2.2, 3.3]"),
    ("int64.pt", [100, -200, 300, 0], torch.int64, "Int64 tensor [100, -200, 300, 0]"),
    ("int32.pt", [10, 20, -30], torch.int32, "Int32 tensor [10, 20, -30]"),
    ("int16.pt", [1000, -2000, 3000], torch.int16, "Int16 tensor [1000, -2000, 3000]"),
    ("int8.pt", [127, -128, 0, 50], torch.int8, "Int8 tensor [127, -128, 0, 50]"),
    ("bool.pt", [True, False, True, True, False], torch.bool, "Bool tensor [True, False, True, True, False]"),
    # Half precision
    ("float16.pt", [1.5, -2.25, 3.125], torch.float16, "Float16 tensor [1.5, -2.25, 3.125]"),
    ("bfloat16.pt", [1.5, -2.5, 3.5], torch.bfloat16, "BFloat16 tensor [1.5, -2.5, 3.5]"),
    ("uint8.pt", [0, 128, 255, 42], torch.uint8, "UInt8 tensor [0, 128, 255, 42]"),
]

for case_file, case_values, case_dtype, case_summary in basic_tensor_cases:
    # Each tensor is wrapped in a single-key dict for compatibility.
    case_tensor = torch.tensor(case_values, dtype=case_dtype)
    save_test_file(case_file, {"tensor": case_tensor}, case_summary)
|
||||
|
||||
# Test 2: Multi-dimensional tensors
print("\n=== Generating Multi-dimensional Tensor Tests ===")

# 2D tensor: three rows of two float32 values.
tensor_2d = torch.tensor(
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    dtype=torch.float32,
)
save_test_file("tensor_2d.pt", dict(tensor=tensor_2d), "2D tensor shape (3, 2)")

# Seed the RNG before the first random draw so the values below are repeatable.
torch.manual_seed(42)

# 3D tensor: random samples scaled by 10.
samples_3d = torch.randn(2, 3, 4)
tensor_3d = samples_3d * 10
save_test_file("tensor_3d.pt", dict(tensor=tensor_3d), "3D tensor shape (2, 3, 4)")

# 4D tensor (common for conv weights)
tensor_4d = torch.randn(2, 3, 2, 2)
save_test_file("tensor_4d.pt", dict(tensor=tensor_4d), "4D tensor shape (2, 3, 2, 2)")
|
||||
|
||||
# Test 3: State dict (multiple tensors)
print("\n=== Generating State Dict Tests ===")

# Flat state dict. Entries are inserted one at a time; "weight" is drawn before
# "bias" so the random number sequence is consumed in a fixed order.
state_dict = {}
state_dict["weight"] = torch.randn(3, 4)
state_dict["bias"] = torch.randn(3)
state_dict["running_mean"] = torch.zeros(3)
state_dict["running_var"] = torch.ones(3)
save_test_file("state_dict.pt", state_dict, "State dict with 4 tensors")

# Nested state dict: two layers, each holding its own weight/bias pair.
# layer1 is built completely before layer2 to keep the draw order stable.
layer1_tensors = {"weight": torch.randn(2, 3), "bias": torch.randn(2)}
layer2_tensors = {"weight": torch.randn(4, 2), "bias": torch.randn(4)}
nested_dict = {"layer1": layer1_tensors, "layer2": layer2_tensors}
save_test_file("nested_dict.pt", nested_dict, "Nested state dict")
|
||||
|
||||
# Test 4: Model checkpoint format
print("\n=== Generating Model Checkpoint Tests ===")

# Model weights for a two-layer network, keyed by dotted parameter names.
model_state = {
    "fc1.weight": torch.randn(10, 5),
    "fc1.bias": torch.randn(10),
    "fc2.weight": torch.randn(3, 10),
    "fc2.bias": torch.randn(3),
}

# Optimizer state is keyed by the string "0" rather than the integer 0
# (string keys are used for compatibility).
momentum_entry = {"momentum_buffer": torch.randn(10, 5)}
optimizer_state = {"state": {"0": momentum_entry}}

# Typical checkpoint layout: model state, optimizer state and plain scalars.
checkpoint = {
    "model_state_dict": model_state,
    "optimizer_state_dict": optimizer_state,
    "epoch": 42,
    "loss": 0.123,
}
save_test_file("checkpoint.pt", checkpoint, "Full checkpoint with model and optimizer state")
|
||||
|
||||
# Test 5: Edge cases
print("\n=== Generating Edge Case Tests ===")

# Empty tensor (1D with 0 elements)
empty_tensor = torch.zeros(0)
save_test_file("empty.pt", dict(tensor=empty_tensor), "Empty tensor")

# Scalar tensor (0-dimensional)
scalar_tensor = torch.tensor(42.0)
save_test_file("scalar.pt", dict(tensor=scalar_tensor), "Scalar tensor (0-dim)")

# Large shape but small data (testing shape vs actual data): a 100x100 grid of
# zeros with three non-zero entries placed on the diagonal.
sparse_like = torch.zeros(100, 100)
for diag_index, diag_value in ((0, 1.0), (50, 2.0), (99, 3.0)):
    sparse_like[diag_index, diag_index] = diag_value
save_test_file("large_shape.pt", dict(tensor=sparse_like), "Large shape (100, 100) mostly zeros")
|
||||
|
||||
# Test 6: Mixed types in dict
print("\n=== Generating Mixed Type Tests ===")

# A single dict whose values deliberately use four different dtypes; each key
# is named after the dtype of the tensor it holds.
mixed_types = dict(
    [
        ("float32", torch.tensor([1.0, 2.0], dtype=torch.float32)),
        ("int64", torch.tensor([100, 200], dtype=torch.int64)),
        ("bool", torch.tensor([True, False], dtype=torch.bool)),
        ("float64", torch.tensor([1.1, 2.2], dtype=torch.float64)),
    ]
)
save_test_file("mixed_types.pt", mixed_types, "Dict with mixed tensor types")
|
||||
|
||||
# Test 7: Special values
print("\n=== Generating Special Value Tests ===")

# NaN and Inf values, followed by two ordinary numbers for contrast.
non_finite = [float('nan'), float('inf'), float('-inf')]
special_values = torch.tensor(non_finite + [0.0, 1.0])
save_test_file("special_values.pt", dict(tensor=special_values), "Tensor with NaN and Inf")

# Very small and very large values
magnitudes = [1e-30, 1e30, -1e-30, -1e30]
extreme_values = torch.tensor(magnitudes, dtype=torch.float32)
save_test_file("extreme_values.pt", dict(tensor=extreme_values), "Tensor with extreme values")
|
||||
|
||||
# Test 8: Parameter wrapper (common in models)
print("\n=== Generating Parameter Tests ===")

import torch.nn as nn

# Wrap a random 3x3 tensor in nn.Parameter and store it under the key "param".
raw_values = torch.randn(3, 3)
param = nn.Parameter(raw_values)
param_dict = dict(param=param)
save_test_file("parameter.pt", param_dict, "nn.Parameter wrapped tensor")
|
||||
|
||||
# Test 9: Buffer-style tensors
print("\n=== Generating Buffer Tests ===")

# Simulate model buffers: one int32 tensor and one bool tensor.
buffers = {}
buffers["buffer1"] = torch.tensor([1, 2, 3], dtype=torch.int32)
buffers["buffer2"] = torch.tensor([True, False], dtype=torch.bool)
save_test_file("buffers.pt", buffers, "Model buffers")
|
||||
|
||||
# Test 10: Complex nested structure
print("\n=== Generating Complex Structure Tests ===")

# The tensor-bearing parts are built first, encoder layers before the decoder,
# so the random draws happen in a fixed order: layer_0, layer_1, decoder.
encoder_layers = {
    "layer_0": {"weight": torch.randn(4, 3), "bias": torch.randn(4)},
    "layer_1": {"weight": torch.randn(2, 4), "bias": torch.randn(2)},
}
decoder_tensors = {"weight": torch.randn(3, 2), "bias": torch.randn(3)}

# Plain (non-tensor) sections that sit alongside the tensors in the file.
model_metadata = {"version": 1, "name": "test_model"}
model_config = {"hidden_size": 4, "num_layers": 2}

# Final layout mixes metadata, nested tensor state and configuration values.
complex_structure = {
    "metadata": model_metadata,
    "state": {
        "encoder": encoder_layers,
        "decoder": decoder_tensors,
    },
    "config": model_config,
}
save_test_file("complex_structure.pt", complex_structure, "Complex nested structure")
|
||||
|
||||
# Final report. The count comes from globbing the output directory, so it
# covers every *.pt file present there, not only files written in this run.
pt_files = list(test_dir.glob('*.pt'))
print(f"\n✅ Generated {len(pt_files)} test files in {test_dir}")

# Short reminder of what the fixtures are meant to exercise.
summary_lines = (
    "\nTest files can be used to verify PyTorch reader functionality:",
    "- Different data types (float32, int64, bool, etc.)",
    "- Multi-dimensional tensors",
    "- State dicts and nested structures",
    "- Edge cases (empty, scalar, special values)",
    "- Model checkpoints and parameters",
)
for summary_line in summary_lines:
    print(summary_line)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user