feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,13 @@
# Manifest of the `pytorch-tests` crate.
# Version, edition and license are inherited from the workspace.
[package]
name = "pytorch-tests"
version.workspace = true
edition.workspace = true
license.workspace = true
# Only dev-dependencies are declared, so these crates are available to tests,
# examples and benches but not to a regular library/binary target.
[dev-dependencies]
burn = { path = "../../burn" }
burn-ndarray = { path = "../../burn-ndarray" }
burn-autodiff = { path = "../../burn-autodiff" }
# burn-store lives in the parent directory; the `pytorch` feature is enabled here.
burn-store = { path = "../", features = ["std", "pytorch"] }
serde = { workspace = true }
float-cmp = { workspace = true }

View File

@@ -0,0 +1 @@
/// Backend used by the test modules of this crate: the ndarray backend with `f32` elements.
pub type TestBackend = burn_ndarray::NdArray<f32>;

View File

@@ -0,0 +1,41 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Minimal module wrapping a single ``BatchNorm2d`` layer with 5 features."""

    def __init__(self):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(5)

    def forward(self, x):
        """Return ``x`` passed through the batch-norm layer."""
        return self.norm1(x)
def main():
    """Export ``batch_norm2d.pt`` and print a reference forward pass.

    The model is run once in training mode so the running statistics are
    updated, switched to eval mode, saved, and then evaluated on a second
    input whose output is printed.
    """
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    # Condition batch norm (each forward will affect the running stats)
    x1 = torch.ones(1, 5, 2, 2) - 0.5
    _ = model(x1)
    model.eval()  # Set to eval mode to freeze running stats

    # Save the model after the first forward
    torch.save(model.state_dict(), "batch_norm2d.pt")

    x2 = torch.ones(1, 5, 2, 2) - 0.3
    # f-strings: passing "{}" as a separate print() argument would emit the
    # braces literally instead of substituting the value.
    print(f"Input shape: {x2.shape}")
    output = model(x2)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,62 @@
use burn::{
module::Module,
nn::{BatchNorm, BatchNormConfig},
tensor::{Tensor, backend::Backend},
};
/// Model consisting of a single batch-normalisation layer.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    norm1: BatchNorm<B>,
}

impl<B: Backend> Net<B> {
    /// Builds the model on `device`.
    pub fn new(device: &B::Device) -> Self {
        // Python model uses BatchNorm2d(5)
        let norm1 = BatchNormConfig::new(5).init(device);
        Self { norm1 }
    }

    /// Forward pass: applies the batch-norm layer to `x`.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        self.norm1.forward(x)
    }
}
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;
    use burn::tensor::Tolerance;
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    /// Loads `batch_norm2d.pt` into the model and compares the output for a
    /// constant input with the hard-coded reference values.
    #[test]
    fn batch_norm2d() {
        let device = Default::default();

        let mut net = Net::<TestBackend>::new(&device);
        let mut weights = PytorchStore::from_file("tests/batch_norm/batch_norm2d.pt");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");

        let input = Tensor::<TestBackend, 4>::ones([1, 5, 2, 2], &device) - 0.3;
        let actual = net.forward(input).to_data();

        // All 20 reference entries are the same value, so the [1, 5, 2, 2]
        // array is built by repeating one 2x2 plane five times.
        let plane = [[0.68515635, 0.68515635], [0.68515635, 0.68515635]];
        let expected = Tensor::<TestBackend, 4>::from_data([[plane; 5]], &device);

        actual.assert_approx_eq::<f32>(&expected.to_data(), Tolerance::default());
    }
}

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Module whose only state is a persistent boolean buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer(
            "buffer", torch.tensor([True, False, True]), persistent=True
        )

    def forward(self, x):
        """Ignore ``x`` and return the stored buffer."""
        return self.buffer
def main():
    """Save the model's state dict to ``boolean.pt`` and print a reference run."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))
    torch.save(model.state_dict(), "boolean.pt")

    # Named ``x`` rather than ``input`` so the builtin is not shadowed.
    x = torch.ones(3, 3)

    # f-strings: passing "{}" as a separate print() argument would emit the
    # braces literally instead of substituting the value.
    print(f"Input shape: {x.shape}")
    print(f"Input: {x}")
    output = model(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,58 @@
use burn::{
module::{Module, Param, ParamId},
tensor::{Bool, Tensor, TensorData, backend::Backend},
};
/// Model holding a single rank-1 boolean tensor as a parameter.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    buffer: Param<Tensor<B, 1, Bool>>,
}

impl<B: Backend> Net<B> {
    /// Create a new model with placeholder values (three `false` entries).
    pub fn init(device: &B::Device) -> Self {
        let id = ParamId::new();
        let placeholder: Tensor<B, 1, Bool> =
            Tensor::from_bool(TensorData::from([false, false, false]), device);
        Self {
            buffer: Param::initialized(id, placeholder),
        }
    }

    /// Forward pass: the input is ignored and the stored buffer is returned.
    pub fn forward(&self, _x: Tensor<B, 2>) -> Tensor<B, 1, Bool> {
        self.buffer.val()
    }
}
#[cfg(test)]
mod tests {
    use burn::tensor::TensorData;
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;
    use crate::backend::TestBackend;

    /// After loading `boolean.pt` the buffer must equal `[true, false, true]`.
    #[test]
    fn boolean() {
        let device = Default::default();

        let mut net = Net::<TestBackend>::init(&device);
        let mut weights = PytorchStore::from_file("tests/boolean/boolean.pt");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");

        let input = Tensor::<TestBackend, 2>::ones([3, 3], &device);
        let actual = net.forward(input).to_data();

        let reference = TensorData::from([true, false, true]);
        let expected = Tensor::<TestBackend, 1, Bool>::from_bool(reference, &device).to_data();

        assert_eq!(actual, expected);
    }
}

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Module that adds a persistent 3x3 buffer of ones to its input."""

    def __init__(self):
        super().__init__()
        self.register_buffer("buffer", torch.ones(3, 3), persistent=True)

    def forward(self, x):
        """Return ``buffer + x``."""
        return self.buffer + x
def main():
    """Save the model's state dict to ``buffer.pt`` and print a reference run."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))
    torch.save(model.state_dict(), "buffer.pt")

    # Named ``x`` rather than ``input`` so the builtin is not shadowed.
    x = torch.ones(3, 3)

    # f-strings: passing "{}" as a separate print() argument would emit the
    # braces literally instead of substituting the value.
    print(f"Input shape: {x.shape}")
    print(f"Input: {x}")
    output = model(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,53 @@
use burn::{
module::{Module, Param},
tensor::{Tensor, backend::Backend},
};
/// Model holding a single 2D float tensor that is added to the input.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    buffer: Param<Tensor<B, 2>>,
}

impl<B: Backend> Net<B> {
    /// Create a new model with placeholder values (a 3x3 tensor of zeros).
    pub fn init(device: &B::Device) -> Self {
        let placeholder = Tensor::<B, 2>::zeros([3, 3], device);
        Self {
            buffer: Param::from_tensor(placeholder),
        }
    }

    /// Forward pass: returns the stored buffer plus `x`.
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let stored = self.buffer.val();
        stored + x
    }
}
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;
    use burn::tensor::Tolerance;
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    /// With the buffer loaded from `buffer.pt`, ones + buffer must equal a tensor of twos.
    #[test]
    fn buffer() {
        let device = Default::default();

        let mut net = Net::<TestBackend>::init(&device);
        let mut weights = PytorchStore::from_file("tests/buffer/buffer.pt");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");

        let input = Tensor::<TestBackend, 2>::ones([3, 3], &device);
        let actual = net.forward(input).to_data();

        let expected = (Tensor::<TestBackend, 2>::ones([3, 3], &device) * 2.0).to_data();

        actual.assert_approx_eq::<f32>(&expected, Tolerance::default());
    }
}

View File

@@ -0,0 +1,69 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBlock(nn.Module):
    """A ``Conv2d`` layer followed by ``BatchNorm2d`` over its output channels."""

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.norm = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Convolve ``x`` and normalise the result."""
        return self.norm(self.conv(x))
class Net(nn.Module):
    """Two stacked ``ConvBlock``s, a batch norm, and a two-layer classifier head."""

    def __init__(self):
        super().__init__()
        # Sub-modules are created in the same order as before so that seeded
        # weight initialisation is unchanged.
        self.conv_blocks = nn.Sequential(
            ConvBlock(2, 4, (3, 2)),
            ConvBlock(4, 6, (3, 2)),
        )
        self.norm1 = nn.BatchNorm2d(6)
        self.fc1 = nn.Linear(120, 12)
        self.fc2 = nn.Linear(12, 10)

    def forward(self, x):
        """Return log-softmax scores (over dim 1) for input ``x``."""
        features = self.norm1(self.conv_blocks(x))
        flat = torch.flatten(features, 1)
        hidden = F.relu(self.fc1(flat))
        return F.log_softmax(self.fc2(hidden), dim=1)
def main():
    """Export ``complex_nested.pt`` and print a reference forward pass.

    Two training-mode forward passes update the batch-norm statistics before
    the model is switched to eval mode, saved, and evaluated on the test input.
    """
    torch.set_printoptions(precision=8)
    torch.manual_seed(2)

    model = Net().to(torch.device("cpu"))

    # Condition the model (batch norm requires a forward pass to compute the mean and variance)
    x1 = torch.ones(1, 2, 9, 6) - 0.1
    x2 = torch.ones(1, 2, 9, 6) - 0.3
    # Only the side effect on the running statistics matters; outputs are discarded.
    _ = model(x1)
    _ = model(x2)
    model.eval()  # set to eval mode

    torch.save(model.state_dict(), "complex_nested.pt")

    # feed test data
    x = torch.ones(1, 2, 9, 6) - 0.5
    output = model(x)

    # f-strings: passing "{}" as a separate print() argument would emit the
    # braces literally instead of substituting the value.
    print(f"Input shape: {x.shape}")
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,240 @@
use burn::tensor::Tolerance;
use burn::tensor::ops::FloatElem;
use burn::{
module::Module,
nn::{
BatchNorm, BatchNormConfig, Linear, LinearConfig,
conv::{Conv2d, Conv2dConfig},
},
tensor::{
Tensor,
activation::{log_softmax, relu},
backend::Backend,
},
};
use burn_autodiff::Autodiff;
use burn_store::{ModuleSnapshot, PytorchStore};
/// Building block made of a 2D convolution and a batch-norm layer.
#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
// Convolution layer of the block.
conv: Conv2d<B>,
// Batch normalisation layer of the block.
norm: BatchNorm<B>,
}
/// Nested model: a list of convolution blocks, a batch norm and two linear layers.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv_blocks: Vec<ConvBlock<B>>,
    norm1: BatchNorm<B>,
    fc1: Linear<B>,
    fc2: Linear<B>,
}

impl<B: Backend> Net<B> {
    /// Builds the model on `device`.
    pub fn init(device: &B::Device) -> Self {
        // [in, out] channel pairs of the two blocks; both use a 3x2 kernel and a
        // batch norm sized to the convolution's output channels.
        let conv_blocks: Vec<ConvBlock<B>> = [[2, 4], [4, 6]]
            .into_iter()
            .map(|channels| ConvBlock {
                conv: Conv2dConfig::new(channels, [3, 2]).init(device),
                norm: BatchNormConfig::new(channels[1]).init(device),
            })
            .collect();

        Self {
            conv_blocks,
            norm1: BatchNormConfig::new(6).init(device),
            fc1: LinearConfig::new(120, 12).init(device),
            fc2: LinearConfig::new(12, 10).init(device),
        }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {
        let after_first = self.conv_blocks[0].forward(x);
        let after_second = self.conv_blocks[1].forward(after_first);
        let normalized = self.norm1.forward(after_second);

        // Keep dimension 0 and collapse everything else into one axis.
        let flat: Tensor<B, 2> = normalized.reshape([0, -1]);

        let hidden = relu(self.fc1.forward(flat));
        log_softmax(self.fc2.forward(hidden), 1)
    }
}
impl<B: Backend> ConvBlock<B> {
    /// Applies the convolution and then the batch norm to `x`.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let convolved = self.conv.forward(x);
        self.norm.forward(convolved)
    }
}
/// Partial model to test loading of partial records.
#[derive(Module, Debug)]
pub struct PartialNet<B: Backend> {
    conv1: ConvBlock<B>,
}

impl<B: Backend> PartialNet<B> {
    /// Builds the partial model on `device` with a single convolution block.
    pub fn init(device: &B::Device) -> Self {
        Self {
            conv1: ConvBlock {
                conv: Conv2dConfig::new([2, 4], [3, 2]).init(device),
                // matches conv output channels
                norm: BatchNormConfig::new(4).init(device),
            },
        }
    }

    /// Forward pass: delegates to the single convolution block.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        self.conv1.forward(x)
    }
}
/// Model with extra fields to test loading of records (e.g. from a different model).
#[derive(Module, Debug)]
pub struct PartialWithExtraNet<B: Backend> {
    conv1: ConvBlock<B>,
    // This field is not present in the pytorch model
    extra_field: bool,
}

impl<B: Backend> PartialWithExtraNet<B> {
    /// Builds the model on `device`; `extra_field` starts out as `true`.
    pub fn init(device: &B::Device) -> Self {
        Self {
            conv1: ConvBlock {
                conv: Conv2dConfig::new([2, 4], [3, 2]).init(device),
                // matches conv output channels
                norm: BatchNormConfig::new(4).init(device),
            },
            extra_field: true,
        }
    }

    /// Forward pass: delegates to the single convolution block.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        self.conv1.forward(x)
    }
}
type TestBackend = burn_ndarray::NdArray<f32>;

/// Runs a fixed input through `model` and compares the output with hard-coded
/// reference values, allowing an absolute error of `precision`.
fn model_test(model: Net<TestBackend>, precision: f32) {
    let device = Default::default();

    let reference = [[
        -2.306_613,
        -2.058_945_4,
        -2.298_372_7,
        -2.358_294,
        -2.296_395_5,
        -2.416_090_5,
        -2.107_669,
        -2.428_420_8,
        -2.526_469,
        -2.319_918_6,
    ]];
    let expected = Tensor::<TestBackend, 2>::from_data(reference, &device).to_data();

    let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;
    let actual = model.forward(input).to_data();

    actual.assert_approx_eq::<FloatElem<TestBackend>>(&expected, Tolerance::absolute(precision));
}
/// Loads every tensor of `complex_nested.pt` and checks the output with a tight tolerance.
#[test]
fn full_record() {
    let device = Default::default();
    let mut net = Net::<TestBackend>::init(&device);

    let mut weights = PytorchStore::from_file("tests/complex_nested/complex_nested.pt");
    net.load_from(&mut weights)
        .expect("Should decode state successfully");

    model_test(net, 1e-8);
}
/// Checks that the weights also load into a model on the autodiff-wrapped backend.
/// Only successful loading is asserted; no forward pass is run.
#[test]
fn full_record_autodiff() {
    let device = Default::default();
    let mut net = Net::<Autodiff<TestBackend>>::init(&device);

    let mut weights = PytorchStore::from_file("tests/complex_nested/complex_nested.pt");
    net.load_from(&mut weights)
        .expect("Should decode state successfully");
}
/// Same flow as `full_record`, but the output is checked with a looser tolerance.
#[test]
fn half_record() {
    let device = Default::default();
    let mut net = Net::<TestBackend>::init(&device);

    let mut weights = PytorchStore::from_file("tests/complex_nested/complex_nested.pt");
    net.load_from(&mut weights)
        .expect("Should decode state successfully");

    model_test(net, 1e-4);
}
/// Loads only the first convolution block of the full checkpoint into `PartialNet`.
#[test]
fn partial_model_loading() {
    let device = Default::default();
    let mut net = PartialNet::<TestBackend>::init(&device);

    // Load the full model but rename "conv_blocks.0.*" to "conv1.*"
    let mut weights = PytorchStore::from_file("tests/complex_nested/complex_nested.pt")
        .with_key_remapping("conv_blocks\\.0\\.(.*)", "conv1.$1")
        .allow_partial(true);
    net.load_from(&mut weights)
        .expect("Should decode state successfully");

    let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;

    // get the sum of all elements in the output tensor for quick check
    let total = net.forward(input).sum().into_scalar();
    let deviation = (total - 4.871538).abs();
    assert!(deviation < 0.000002);
}
/// Loads the first convolution block into a model that also carries a field with no
/// counterpart in the checkpoint, and checks that the field keeps its initial value.
#[test]
fn extra_field_model_loading() {
    let device = Default::default();
    let mut net = PartialWithExtraNet::<TestBackend>::init(&device);

    // Load the full model but rename "conv_blocks.0.*" to "conv1.*"
    let mut weights = PytorchStore::from_file("tests/complex_nested/complex_nested.pt")
        .with_key_remapping("conv_blocks\\.0\\.(.*)", "conv1.$1")
        .allow_partial(true);
    net.load_from(&mut weights)
        .expect("Should decode state successfully");

    let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;

    // get the sum of all elements in the output tensor for quick check
    let total = net.forward(input).sum().into_scalar();
    let deviation = (total - 4.871538).abs();
    assert!(deviation < 0.000002);

    assert!(net.extra_field);
}

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Two linear layers (the second without bias) separated by a ReLU."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 3)
        self.fc2 = nn.Linear(3, 4, bias=False)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        # Add relu so that PyTorch optimizer does not combine fc1 and fc2
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)
# Configuration dictionary saved next to the weights under the "my_config" key
# (see main()). It mixes ints, a float, a bool, a string, lists and a nested dict.
CONFIG = {
"n_head": 2,
"n_layer": 3,
"d_model": 512,
"some_float": 0.1,
"some_int": 1,
"some_bool": True,
"some_str": "hello",
"some_list_int": [1, 2, 3],
"some_list_str": ["hello", "world"],
"some_list_float": [0.1, 0.2, 0.3],
"some_dict": {
"some_key": "some_value"
}
}
class ModelWithBias(nn.Module):
    """Single linear layer (with bias) mapping 2 features to 3."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 3)

    def forward(self, x):
        """Return ``fc1(x)``."""
        return self.fc1(x)
def main():
    """Save the model weights together with ``CONFIG`` to ``weights_with_config.pt``."""
    model = Model().to(torch.device("cpu"))

    # One file, two top-level keys: the state dict and the plain-Python config.
    payload = {"my_model": model.state_dict(), "my_config": CONFIG}
    torch.save(payload, "weights_with_config.pt")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,53 @@
#![allow(clippy::too_many_arguments)] // To mute derive Config warning
use std::collections::HashMap;
use burn::config::Config;
/// Typed counterpart of the configuration stored under the `my_config` key of
/// `weights_with_config.pt`; the field names match the keys of `CONFIG` in the
/// Python export script.
#[allow(clippy::too_many_arguments)]
#[derive(Debug, PartialEq, Config)]
struct NetConfig {
n_head: usize,
n_layer: usize,
d_model: usize,
some_float: f64,
some_int: i32,
some_bool: bool,
some_str: String,
some_list_int: Vec<i32>,
some_list_str: Vec<String>,
some_list_float: Vec<f64>,
// Nested dictionary with string keys and string values.
some_dict: HashMap<String, String>,
}
#[cfg(test)]
mod tests {
    use burn_store::pytorch::PytorchReader;

    use super::*;

    /// Reads the config stored under `my_config` and compares it field by field
    /// (via `PartialEq`) with the values written by the export script.
    #[test]
    fn test_net_config() {
        let config: NetConfig =
            PytorchReader::load_config("tests/config/weights_with_config.pt", Some("my_config"))
                .unwrap();

        let config_expected = NetConfig {
            n_head: 2,
            n_layer: 3,
            d_model: 512,
            some_float: 0.1,
            some_int: 1,
            some_bool: true,
            some_str: String::from("hello"),
            some_list_int: vec![1, 2, 3],
            some_list_str: vec![String::from("hello"), String::from("world")],
            some_list_float: vec![0.1, 0.2, 0.3],
            some_dict: HashMap::from([(String::from("some_key"), String::from("some_value"))]),
        };

        assert_eq!(config, config_expected);
    }
}

View File

@@ -0,0 +1,39 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Two stacked ``Conv1d`` layers; the second one has no bias."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(2, 2, 2)
        self.conv2 = nn.Conv1d(2, 2, 2, bias=False)

    def forward(self, x):
        """Return ``conv2(conv1(x))``."""
        return self.conv2(self.conv1(x))
def main():
    """Save the model's state dict to ``conv1d.pt`` and print a reference run."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))
    torch.save(model.state_dict(), "conv1d.pt")

    # Named ``x`` rather than ``input`` so the builtin is not shadowed.
    x = torch.rand(1, 2, 6)

    # f-strings: passing "{}" as a separate print() argument would emit the
    # braces literally instead of substituting the value.
    print(f"Input shape: {x.shape}")
    print(f"Input: {x}")
    output = model(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,97 @@
use burn::{
module::Module,
nn::conv::{Conv1d, Conv1dConfig},
tensor::{Tensor, backend::Backend},
};
/// Two stacked 1D convolutions; the second one is configured without bias.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv1: Conv1d<B>,
    conv2: Conv1d<B>,
}

impl<B: Backend> Net<B> {
    /// Builds the model on `device` from the two layer configs.
    pub fn init(device: &B::Device) -> Self {
        Self {
            conv1: Conv1dConfig::new(2, 2, 2).init(device),
            conv2: Conv1dConfig::new(2, 2, 2).with_bias(false).init(device),
        }
    }

    /// Forward pass: applies `conv1` and then `conv2`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let hidden = self.conv1.forward(x);
        self.conv2.forward(hidden)
    }
}
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;
    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_store::{ModuleSnapshot, PytorchStore};

    type FT = FloatElem<TestBackend>;

    use super::*;

    /// Creates a model and loads the weights from `conv1d.pt` into it.
    fn load_model() -> Net<TestBackend> {
        let device = Default::default();
        let mut model = Net::<TestBackend>::init(&device);
        let mut store = PytorchStore::from_file("tests/conv1d/conv1d.pt");
        model
            .load_from(&mut store)
            .expect("Should decode state successfully");
        model
    }

    /// Runs the fixed input through `model` and compares the result with the
    /// reference output, allowing an absolute error of `precision`.
    fn conv1d(model: Net<TestBackend>, precision: f32) {
        let device = Default::default();

        let input = Tensor::<TestBackend, 3>::from_data(
            [[
                [
                    0.93708336, 0.65559506, 0.31379688, 0.19801933, 0.41619217, 0.28432965,
                ],
                [
                    0.33977574, 0.523_940_8, 0.798_063_9, 0.77176833, 0.01122457, 0.80996025,
                ],
            ]],
            &device,
        );
        let expected = Tensor::<TestBackend, 3>::from_data(
            [[
                [0.02987457, 0.03134188, 0.04234261, -0.02437721],
                [-0.03788019, -0.02972012, -0.00806090, -0.01981254],
            ]],
            &device,
        );

        let actual = model.forward(input).to_data();
        actual.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::absolute(precision));
    }

    #[test]
    fn conv1d_full_precision() {
        conv1d(load_model(), 1e-7);
    }

    #[test]
    fn conv1d_half_precision() {
        conv1d(load_model(), 1e-4);
    }
}

View File

@@ -0,0 +1,39 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Two stacked ``Conv2d`` layers; the second one has no bias."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, (2, 2))
        self.conv2 = nn.Conv2d(2, 2, (2, 2), bias=False)

    def forward(self, x):
        """Return ``conv2(conv1(x))``."""
        return self.conv2(self.conv1(x))
def main():
    """Save the model's state dict to ``conv2d.pt`` and print a reference run."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))
    torch.save(model.state_dict(), "conv2d.pt")

    # Named ``x`` rather than ``input`` so the builtin is not shadowed.
    x = torch.rand(1, 2, 5, 5)

    # f-strings: passing "{}" as a separate print() argument would emit the
    # braces literally instead of substituting the value.
    print(f"Input shape: {x.shape}")
    print(f"Input: {x}")
    output = model(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,134 @@
use burn::{
module::Module,
nn::conv::{Conv2d, Conv2dConfig},
tensor::{Tensor, backend::Backend},
};
/// Two stacked 2D convolutions; the second one is configured without bias.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
}

impl<B: Backend> Net<B> {
    /// Builds the model on `device` from the two layer configs.
    pub fn init(device: &B::Device) -> Self {
        Self {
            conv1: Conv2dConfig::new([2, 2], [2, 2]).init(device),
            conv2: Conv2dConfig::new([2, 2], [2, 2])
                .with_bias(false)
                .init(device),
        }
    }

    /// Forward pass: applies `conv1` and then `conv2`.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let hidden = self.conv1.forward(x);
        self.conv2.forward(hidden)
    }
}
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;
    use burn::tensor::Tolerance;
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    /// Creates a model and loads the weights from `conv2d.pt` into it.
    fn load_model() -> Net<TestBackend> {
        let device = Default::default();
        let mut model = Net::<TestBackend>::init(&device);
        let mut store = PytorchStore::from_file("tests/conv2d/conv2d.pt");
        model
            .load_from(&mut store)
            .expect("Should decode state successfully");
        model
    }

    /// Runs the fixed input through `model` and compares the result with the
    /// reference output, allowing an absolute error of `precision`.
    fn conv2d(model: Net<TestBackend>, precision: f32) {
        let device = Default::default();

        // Input of shape [1, 2, 5, 5], one 5x5 plane per channel.
        let channel_0 = [
            [0.024_595_8, 0.25883394, 0.93905586, 0.416_715_5, 0.713_979_7],
            [0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4, 0.505_920_8],
            [0.23659128, 0.757_007_4, 0.23458993, 0.64705235, 0.355_621_4],
            [0.445_182_8, 0.01930594, 0.26160914, 0.771_317, 0.37846136],
            [0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845, 0.804_481_1],
        ];
        let channel_1 = [
            [0.65517855, 0.17679012, 0.824_772_3, 0.803_550_9, 0.943_447_5],
            [0.21972018, 0.417_697, 0.49031407, 0.57302874, 0.12054086],
            [0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7, 0.52850497],
            [0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537, 0.03694397],
            [0.751_675_7, 0.148_438_4, 0.12274551, 0.530_407_2, 0.414_796_4],
        ];
        let input = Tensor::<TestBackend, 4>::from_data([[channel_0, channel_1]], &device);

        // Reference output of shape [1, 2, 3, 3].
        let expected = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [-0.02502128, 0.00250649, 0.04841233],
                    [0.04589614, -0.00296854, 0.01991477],
                    [0.02920526, 0.059_497_3, 0.04326791],
                ],
                [
                    [-0.04825336, 0.080_190_9, -0.02375088],
                    [0.02885434, 0.09638263, -0.07460806],
                    [0.02004079, 0.06244051, 0.035_887_1],
                ],
            ]],
            &device,
        );

        let actual = model.forward(input).to_data();
        actual.assert_approx_eq::<f32>(&expected.to_data(), Tolerance::absolute(precision));
    }

    #[test]
    fn conv2d_full_precision() {
        conv2d(load_model(), 1e-7);
    }

    #[test]
    fn conv2d_half_precision() {
        conv2d(load_model(), 1e-4);
    }
}

View File

@@ -0,0 +1,39 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Two stacked ``ConvTranspose1d`` layers; the second one has no bias."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.ConvTranspose1d(2, 2, 2)
        self.conv2 = nn.ConvTranspose1d(2, 2, 2, bias=False)

    def forward(self, x):
        """Return ``conv2(conv1(x))``."""
        return self.conv2(self.conv1(x))
def main():
    """Save the model's state dict to ``conv_transpose1d.pt`` and print a reference run."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))
    torch.save(model.state_dict(), "conv_transpose1d.pt")

    # Named ``x`` rather than ``input`` so the builtin is not shadowed.
    x = torch.rand(1, 2, 2)

    # f-strings: passing "{}" as a separate print() argument would emit the
    # braces literally instead of substituting the value.
    print(f"Input shape: {x.shape}")
    print(f"Input: {x}")
    output = model(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,87 @@
use burn::{
module::Module,
nn::conv::{ConvTranspose1d, ConvTranspose1dConfig},
tensor::{Tensor, backend::Backend},
};
/// Two stacked 1D transposed convolutions; the second one is configured without bias.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv1: ConvTranspose1d<B>,
    conv2: ConvTranspose1d<B>,
}

impl<B: Backend> Net<B> {
    /// Builds the model on `device` from the two layer configs.
    pub fn init(device: &B::Device) -> Self {
        Self {
            conv1: ConvTranspose1dConfig::new([2, 2], 2).init(device),
            conv2: ConvTranspose1dConfig::new([2, 2], 2)
                .with_bias(false)
                .init(device),
        }
    }

    /// Forward pass: applies `conv1` and then `conv2`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let hidden = self.conv1.forward(x);
        self.conv2.forward(hidden)
    }
}
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;
    use burn::tensor::Tolerance;
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    /// Creates a model and loads the weights from `conv_transpose1d.pt` into it.
    fn load_model() -> Net<TestBackend> {
        let device = Default::default();
        let mut model = Net::<TestBackend>::init(&device);
        let mut store = PytorchStore::from_file("tests/conv_transpose1d/conv_transpose1d.pt");
        model
            .load_from(&mut store)
            .expect("Should decode state successfully");
        model
    }

    /// Runs the fixed input through `model` and compares the result with the
    /// reference output, allowing an absolute error of `precision`.
    fn conv_transpose1d(model: Net<TestBackend>, precision: f32) {
        let device = Default::default();

        let input = Tensor::<TestBackend, 3>::from_data(
            [[[0.93708336, 0.65559506], [0.31379688, 0.19801933]]],
            &device,
        );
        let expected = Tensor::<TestBackend, 3>::from_data(
            [[
                [0.02935525, 0.01119324, -0.01356167, -0.00682688],
                [0.01644749, -0.01429807, 0.00083987, 0.00279229],
            ]],
            &device,
        );

        let actual = model.forward(input).to_data();
        actual.assert_approx_eq::<f32>(&expected.to_data(), Tolerance::absolute(precision));
    }

    #[test]
    fn conv_transpose1d_full() {
        conv_transpose1d(load_model(), 1e-8);
    }

    #[test]
    fn conv_transpose1d_half() {
        conv_transpose1d(load_model(), 1e-4);
    }
}

View File

@@ -0,0 +1,39 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Two stacked ``ConvTranspose2d`` layers; the second one has no bias."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.ConvTranspose2d(2, 2, (2, 2))
        self.conv2 = nn.ConvTranspose2d(2, 2, (2, 2), bias=False)

    def forward(self, x):
        """Return ``conv2(conv1(x))``."""
        return self.conv2(self.conv1(x))
def main():
    """Save the model's state dict to ``conv_transpose2d.pt`` and print a reference run."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))
    torch.save(model.state_dict(), "conv_transpose2d.pt")

    # Named ``x`` rather than ``input`` so the builtin is not shadowed.
    x = torch.rand(1, 2, 2, 2)

    # f-strings: passing "{}" as a separate print() argument would emit the
    # braces literally instead of substituting the value.
    print(f"Input shape: {x.shape}")
    print(f"Input: {x}")
    output = model(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,99 @@
use burn::{
module::Module,
nn::conv::{ConvTranspose2d, ConvTranspose2dConfig},
tensor::{Tensor, backend::Backend},
};
/// Two stacked 2D transposed convolutions; the second one is configured without bias.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv1: ConvTranspose2d<B>,
    conv2: ConvTranspose2d<B>,
}

impl<B: Backend> Net<B> {
    /// Builds the model on `device` from the two layer configs.
    pub fn init(device: &B::Device) -> Self {
        Self {
            conv1: ConvTranspose2dConfig::new([2, 2], [2, 2]).init(device),
            conv2: ConvTranspose2dConfig::new([2, 2], [2, 2])
                .with_bias(false)
                .init(device),
        }
    }

    /// Forward pass: applies `conv1` and then `conv2`.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let hidden = self.conv1.forward(x);
        self.conv2.forward(hidden)
    }
}
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;
    use burn::tensor::Tolerance;
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    /// Creates a model and loads the weights from `conv_transpose2d.pt` into it.
    fn load_model() -> Net<TestBackend> {
        let device = Default::default();
        let mut model = Net::<TestBackend>::init(&device);
        let mut store = PytorchStore::from_file("tests/conv_transpose2d/conv_transpose2d.pt");
        model
            .load_from(&mut store)
            .expect("Should decode state successfully");
        model
    }

    /// Runs the fixed input through `model` and compares the result with the
    /// reference output, allowing an absolute error of `precision`.
    fn conv_transpose2d(model: Net<TestBackend>, precision: f32) {
        let device = Default::default();

        let input = Tensor::<TestBackend, 4>::from_data(
            [[
                [[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]],
                [[0.713_979_7, 0.267_644_3], [0.990_609, 0.28845078]],
            ]],
            &device,
        );

        // Reference output of shape [1, 2, 4, 4], one 4x4 plane per channel.
        let channel_0 = [
            [0.04547675, 0.01879685, -0.01636661, 0.00310803],
            [0.02090115, 0.01192738, -0.048_240_2, 0.02252235],
            [0.03249975, -0.00460748, 0.05003899, 0.04029131],
            [0.02185687, -0.10226749, -0.06508022, -0.01267705],
        ];
        let channel_1 = [
            [0.00277598, -0.00513832, -0.059_048_3, 0.00567626],
            [-0.03149522, -0.195_757_4, 0.03474613, 0.01997269],
            [-0.10096474, 0.00679589, 0.041_919_7, -0.02464108],
            [-0.03174751, 0.02963913, -0.02703723, -0.01860938],
        ];
        let expected = Tensor::<TestBackend, 4>::from_data([[channel_0, channel_1]], &device);

        let actual = model.forward(input).to_data();
        actual.assert_approx_eq::<f32>(&expected.to_data(), Tolerance::absolute(precision));
    }

    #[test]
    fn conv_transpose2d_full() {
        conv_transpose2d(load_model(), 1e-7);
    }

    #[test]
    fn conv_transpose2d_half() {
        conv_transpose2d(load_model(), 1e-4);
    }
}

View File

@@ -0,0 +1,37 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Single ``Embedding`` table with 10 entries of dimension 3."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 3)

    def forward(self, x):
        """Look up the embedding vectors for the indices in ``x``."""
        return self.embed(x)
def main():
    """Save the model's state dict to ``embedding.pt`` and print a reference run."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))
    torch.save(model.state_dict(), "embedding.pt")

    # Named ``x`` rather than ``input`` so the builtin is not shadowed.
    x = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])

    # f-strings: passing "{}" as a separate print() argument would emit the
    # braces literally instead of substituting the value.
    print(f"Input shape: {x.shape}")
    print(f"Input: {x}")
    output = model(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,86 @@
use burn::{
module::Module,
nn::{Embedding, EmbeddingConfig},
tensor::{Int, Tensor, backend::Backend},
};
/// Model consisting of a single embedding table.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    embed: Embedding<B>,
}

impl<B: Backend> Net<B> {
    /// Create a new model.
    pub fn init(device: &B::Device) -> Self {
        // Same arguments as `nn.Embedding(10, 3)` in the Python export script.
        Self {
            embed: EmbeddingConfig::new(10, 3).init(device),
        }
    }

    /// Forward pass: looks up the embeddings for the integer indices in `x`.
    pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
        self.embed.forward(x)
    }
}
#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;
    use burn::tensor::Tolerance;
    use burn_store::{ModuleSnapshot, PytorchStore};

    use super::*;

    /// Creates a model and loads the weights from `embedding.pt` into it.
    fn load_model() -> Net<TestBackend> {
        let device = Default::default();
        let mut model = Net::<TestBackend>::init(&device);
        let mut store = PytorchStore::from_file("tests/embedding/embedding.pt");
        model
            .load_from(&mut store)
            .expect("Should decode state successfully");
        model
    }

    /// Looks up a fixed batch of indices and compares the result with the
    /// reference output, allowing an absolute error of `precision`.
    fn embedding(model: Net<TestBackend>, precision: f32) {
        let device = Default::default();

        let input =
            Tensor::<TestBackend, 2, Int>::from_data([[1, 2, 4, 5], [4, 3, 2, 9]], &device);

        // Indices 2 and 4 occur in both input rows, so their reference rows are shared.
        let index_2 = [-0.97977227, -1.609_096_3, -0.712_144_6];
        let index_4 = [-0.22227049, 1.687_113_4, -0.32062083];
        let expected = Tensor::<TestBackend, 3>::from_data(
            [
                [
                    [-1.609_484_9, -0.10016718, -0.609_188_9],
                    index_2,
                    index_4,
                    [-0.29934573, 1.879_345_7, -0.07213178],
                ],
                [
                    index_4,
                    [0.303_722, -0.777_314_3, -0.25145486],
                    index_2,
                    [-0.02878714, 2.357_111, -1.037_338_7],
                ],
            ],
            &device,
        );

        let actual = model.forward(input).to_data();
        actual.assert_approx_eq::<f32>(&expected.to_data(), Tolerance::absolute(precision));
    }

    #[test]
    fn embedding_full_precision() {
        embedding(load_model(), 1e-3);
    }

    #[test]
    fn embedding_half_precision() {
        embedding(load_model(), 1e-3);
    }
}

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python3
import torch
from torch import nn, Tensor
class DwsConv(nn.Module):
    """Depthwise separable convolution."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
        super().__init__()
        # Depthwise conv: one group per input channel.
        self.dconv = nn.Conv2d(
            in_channels, in_channels, kernel_size, groups=in_channels
        )
        # Pointwise conv: 1x1 kernel mixing the channels.
        self.pconv = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=1)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the depthwise convolution, then the pointwise one."""
        depthwise = self.dconv(x)
        return self.pconv(depthwise)
class Model(nn.Module):
    """Wraps either a ``DwsConv`` or a plain ``Conv2d`` under the attribute ``conv``."""

    def __init__(self, depthwise: bool = False) -> None:
        super().__init__()
        if depthwise:
            self.conv = DwsConv(2, 2, 3)
        else:
            self.conv = nn.Conv2d(2, 2, 3)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the selected convolution to ``x``."""
        return self.conv(x)
def main():
    """Save both model variants and print a reference input/output pair for each."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    # Statement order matters: every model construction and `torch.rand` call
    # draws from the seeded generator, so reordering changes the saved values.
    model = Model().to(torch.device("cpu"))
    torch.save(model.state_dict(), "enum_depthwise_false.pt")

    x = torch.rand(1, 2, 5, 5)

    print("Depthwise is False")
    print(f"Input shape: {x.shape}")
    print(f"Input: {x}")
    output = model(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")

    print("Depthwise is True")
    model = Model(depthwise=True).to(torch.device("cpu"))
    torch.save(model.state_dict(), "enum_depthwise_true.pt")

    print(f"Input shape: {x.shape}")
    print(f"Input: {x}")
    output = model(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,197 @@
use burn::{
module::Module,
nn::conv::{Conv2d, Conv2dConfig},
tensor::{Tensor, backend::Backend},
};
/// Convolution variants that a [`Net`] can hold in its `conv` field.
#[derive(Module, Debug)]
#[allow(clippy::large_enum_variant)]
pub enum Conv<B: Backend> {
/// Two chained convolutions (`dconv` then `pconv`), see [`DwsConv`].
DwsConv(DwsConv<B>),
/// A single 2-D convolution.
Conv(Conv2d<B>),
}
/// Pair of convolutions applied one after the other: `dconv` first, then `pconv`.
///
/// The field names match the attributes of the `DwsConv` class in the PyTorch
/// export script for this test.
#[derive(Module, Debug)]
pub struct DwsConv<B: Backend> {
// First convolution applied to the input.
dconv: Conv2d<B>,
// Second convolution, applied to the output of `dconv`.
pconv: Conv2d<B>,
}
/// Test model with a single `conv` field that can be either [`Conv`] variant.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
// Selected convolution variant; `forward` dispatches on it.
conv: Conv<B>,
}
impl<B: Backend> Net<B> {
    /// Build a model holding the [`Conv::DwsConv`] variant.
    ///
    /// The first convolution uses a 3x3 kernel with 2 groups, the second a
    /// 1x1 kernel with 1 group; both map 2 channels to 2 channels.
    pub fn init_dws_conv(device: &B::Device) -> Self {
        let depthwise = Conv2dConfig::new([2, 2], [3, 3]).with_groups(2);
        let pointwise = Conv2dConfig::new([2, 2], [1, 1]).with_groups(1);

        // Fields are initialised in declaration order: `dconv` before `pconv`.
        let dws = DwsConv {
            dconv: depthwise.init(device),
            pconv: pointwise.init(device),
        };

        Self {
            conv: Conv::DwsConv(dws),
        }
    }

    /// Build a model holding the [`Conv::Conv`] variant (2 -> 2 channels, 3x3 kernel).
    pub fn init_conv(device: &B::Device) -> Self {
        let plain = Conv2dConfig::new([2, 2], [3, 3]).init(device);
        Self {
            conv: Conv::Conv(plain),
        }
    }

    /// Run the input through whichever convolution variant the model holds.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        match &self.conv {
            Conv::Conv(plain) => plain.forward(x),
            Conv::DwsConv(DwsConv { dconv, pconv }) => pconv.forward(dconv.forward(x)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::TestBackend;
    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_store::{ModuleSnapshot, PytorchStore};

    type FT = FloatElem<TestBackend>;

    /// The 1x2x5x5 input tensor that both tests feed to the network.
    fn sample_input() -> Tensor<TestBackend, 4> {
        let device = Default::default();
        Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [0.713_979_7, 0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4],
                    [0.505_920_8, 0.23659128, 0.757_007_4, 0.23458993, 0.64705235],
                    [0.355_621_4, 0.445_182_8, 0.01930594, 0.26160914, 0.771_317],
                    [0.37846136, 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845],
                    [0.804_481_1, 0.65517855, 0.17679012, 0.824_772_3, 0.803_550_9],
                ],
                [
                    [0.943_447_5, 0.21972018, 0.417_697, 0.49031407, 0.57302874],
                    [0.12054086, 0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7],
                    [0.52850497, 0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537],
                    [0.03694397, 0.751_675_7, 0.148_438_4, 0.12274551, 0.530_407_2],
                    [0.414_796_4, 0.793_662, 0.21043217, 0.05550903, 0.863_884_4],
                ],
            ]],
            &device,
        )
    }

    /// Plain `Conv2d` variant loaded from `enum_depthwise_false.pt`.
    #[test]
    fn depthwise_false() {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init_conv(&device);
        let mut weights = PytorchStore::from_file("tests/enum_module/enum_depthwise_false.pt");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");

        let actual = net.forward(sample_input());

        let expected = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [0.35449377, -0.02832414, 0.490_976_1],
                    [0.29709217, 0.332_586_3, 0.30594018],
                    [0.18101373, 0.30932188, 0.30558896],
                ],
                [
                    [-0.17683622, -0.13244139, -0.05608707],
                    [0.23467252, -0.07038684, 0.255_044_1],
                    [-0.241_931_3, -0.20476191, -0.14468731],
                ],
            ]],
            &device,
        );

        actual
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }

    /// `DwsConv` variant loaded from `enum_depthwise_true.pt`.
    #[test]
    fn depthwise_true() {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init_dws_conv(&device);
        let mut weights = PytorchStore::from_file("tests/enum_module/enum_depthwise_true.pt");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");

        let actual = net.forward(sample_input());

        let expected = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [0.77874625, 0.859_017_6, 0.834_283_5],
                    [0.773_056_4, 0.73817325, 0.78292674],
                    [0.710_775_2, 0.747_187_2, 0.733_264_4],
                ],
                [
                    [-0.44891885, -0.49027523, -0.394_170_7],
                    [-0.43836114, -0.33961445, -0.387_311_5],
                    [-0.581_134_3, -0.34197026, -0.535_035_7],
                ],
            ]],
            &device,
        );

        actual
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }
}

View File

@@ -0,0 +1,37 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Single GroupNorm layer (2 groups over 6 channels) stored as `norm1`."""

    def __init__(self):
        super().__init__()
        self.norm1 = nn.GroupNorm(2, 6)

    def forward(self, x):
        """Normalise `x` with `norm1`."""
        return self.norm1(x)
def main():
    """Save the GroupNorm weights and print a reference input/output pair."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))
    torch.save(model.state_dict(), "group_norm.pt")

    x2 = torch.rand(1, 6, 2, 2)
    print(f"Input shape: {x2.shape}")
    print(f"Input: {x2}")
    output = model(x2)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,90 @@
use burn::{
module::Module,
nn::{GroupNorm, GroupNormConfig},
tensor::{Tensor, backend::Backend},
};
/// Test model wrapping a single group-normalisation layer.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
// Field name matches the `norm1` attribute of the PyTorch export script.
norm1: GroupNorm<B>,
}
impl<B: Backend> Net<B> {
    /// Build a model with a freshly initialised `GroupNormConfig::new(2, 6)` layer on `device`.
    pub fn init(device: &B::Device) -> Self {
        Self {
            norm1: GroupNormConfig::new(2, 6).init(device),
        }
    }

    /// Apply the group-norm layer to `x`.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let Self { norm1 } = self;
        norm1.forward(x)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::TestBackend;
    use burn::tensor::Tolerance;
    use burn_store::{ModuleSnapshot, PytorchStore};

    /// Initialise a `Net` and load `tests/group_norm/group_norm.pt` into it.
    fn load_net() -> Net<TestBackend> {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut weights = PytorchStore::from_file("tests/group_norm/group_norm.pt");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");
        net
    }

    /// Run the fixed input through `model` and compare against the reference
    /// output with an absolute tolerance of `precision`.
    fn group_norm(model: Net<TestBackend>, precision: f32) {
        let device = Default::default();

        let expected = Tensor::<TestBackend, 4>::from_data(
            [[
                [[1.042_578_5, -1.122_016_7], [-0.56195974, 0.938_733_6]],
                [[-2.253_500_7, 1.233_672_9], [-0.588_804_1, 1.027_827_3]],
                [[0.19124532, -0.40036356], [0.504_276_5, -0.01168585]],
                [[1.013_829_2, -0.891_984_6], [-0.09224463, -0.13546038]],
                [[0.45772314, 0.08172822], [2.298_641_4, 0.877_410_4]],
                [[-0.84832406, -1.432_883_4], [-0.331_331_5, -0.997_103_7]],
            ]],
            &device,
        );

        let input = Tensor::<TestBackend, 4>::from_data(
            [[
                [[0.757_631_6, 0.27931088], [0.40306926, 0.73468447]],
                [[0.02928156, 0.799_858_6], [0.39713734, 0.75437194]],
                [[0.569_508_5, 0.43877792], [0.63868046, 0.524_665_9]],
                [[0.682_614_1, 0.305_149_5], [0.46354562, 0.45498633]],
                [[0.572_472, 0.498_002_6], [0.93708336, 0.65559506]],
                [[0.31379688, 0.19801933], [0.41619217, 0.28432965]],
            ]],
            &device,
        );

        let actual = model.forward(input);
        actual
            .to_data()
            .assert_approx_eq::<f32>(&expected.to_data(), Tolerance::absolute(precision));
    }

    #[test]
    fn group_norm_full() {
        group_norm(load_net(), 1e-3);
    }

    // NOTE(review): loads the same file with the same settings and tolerance as
    // `group_norm_full`; confirm whether a distinct half-precision path is intended.
    #[test]
    fn group_norm_half() {
        group_norm(load_net(), 1e-3);
    }
}

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Module with no parameters, only a persistent integer buffer named `buffer`."""

    def __init__(self):
        super().__init__()
        buffer = torch.tensor([1, 2, 3])
        # persistent=True keeps the buffer in the state dict that gets saved.
        self.register_buffer("buffer", buffer, persistent=True)

    def forward(self, x):
        """Ignore `x` and return the stored buffer."""
        return self.buffer
def main():
    """Save the integer buffer state dict and print a reference input/output pair."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))
    torch.save(model.state_dict(), "integer.pt")

    x = torch.ones(3, 3)
    print(f"Input shape: {x.shape}")
    print(f"Input: {x}")
    output = model(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,72 @@
use burn::{
module::{Module, Param, ParamId},
tensor::{Int, Tensor, TensorData, backend::Backend},
};
/// Test model holding a single rank-1 integer tensor as a parameter.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
// Field name matches the `buffer` entry registered by the PyTorch export script.
buffer: Param<Tensor<B, 1, Int>>,
}
impl<B: Backend> Net<B> {
    /// Build a model whose buffer is a `[0, 0, 0]` placeholder, to be
    /// overwritten when weights are loaded.
    pub fn init(device: &B::Device) -> Self {
        let id = ParamId::new();
        let placeholder = Tensor::<B, 1, Int>::from_data(TensorData::from([0, 0, 0]), device);
        Self {
            buffer: Param::initialized(id, placeholder),
        }
    }

    /// Return the stored buffer; the input tensor is not used.
    pub fn forward(&self, _x: Tensor<B, 2>) -> Tensor<B, 1, Int> {
        let Self { buffer } = self;
        buffer.val()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::TestBackend;
    use burn::tensor::TensorData;
    use burn_store::{ModuleSnapshot, PytorchStore};

    /// Initialise a `Net` and load `tests/integer/integer.pt` into it.
    fn load_net() -> Net<TestBackend> {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut weights = PytorchStore::from_file("tests/integer/integer.pt");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");
        net
    }

    /// Check that the loaded buffer is exactly `[1, 2, 3]`.
    fn integer(model: Net<TestBackend>) {
        let device = Default::default();
        let expected =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 2, 3]), &device);

        let ones = Tensor::<TestBackend, 2>::ones([3, 3], &device);
        let actual = model.forward(ones);

        assert_eq!(actual.to_data(), expected.to_data());
    }

    #[test]
    fn integer_full_precision() {
        integer(load_net());
    }

    // NOTE(review): identical to `integer_full_precision`; confirm whether a
    // distinct half-precision path is intended.
    #[test]
    fn integer_half_precision() {
        integer(load_net());
    }
}

View File

@@ -0,0 +1,48 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvModule(nn.Module):
    """Two stacked 2x2 convolutions; the second one has no bias."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, (2, 2))
        self.conv2 = nn.Conv2d(2, 2, (2, 2), bias=False)

    def forward(self, x):
        """Apply `conv1`, then `conv2`."""
        return self.conv2(self.conv1(x))
class Model(nn.Module):
    """Wraps ConvModule under the attribute `conv`, so saved keys start with `conv.`."""

    def __init__(self):
        super().__init__()
        self.conv = ConvModule()

    def forward(self, x):
        """Delegate to the wrapped ConvModule."""
        return self.conv(x)
def main():
    """Save the nested model's state dict and print a reference input/output pair."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))
    torch.save(model.state_dict(), "key_remap.pt")

    x = torch.rand(1, 2, 5, 5)
    print(f"Input shape: {x.shape}")
    print(f"Input: {x}")
    output = model(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,118 @@
use burn::{
module::Module,
nn::conv::{Conv2d, Conv2dConfig},
tensor::{Tensor, backend::Backend},
};
/// Test model with two convolutions held directly as top-level fields.
///
/// The PyTorch export script nests the same layers under a `conv` attribute,
/// which is why the test remaps keys before loading.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
// First convolution applied in `forward`.
conv1: Conv2d<B>,
// Second convolution; initialised without bias.
conv2: Conv2d<B>,
}
impl<B: Backend> Net<B> {
    /// Build a model with two 2 -> 2 channel, 2x2 convolutions; `conv2` has no bias.
    pub fn init(device: &B::Device) -> Self {
        let with_bias = Conv2dConfig::new([2, 2], [2, 2]);
        let without_bias = Conv2dConfig::new([2, 2], [2, 2]).with_bias(false);

        // Fields are initialised in declaration order: `conv1` before `conv2`.
        Self {
            conv1: with_bias.init(device),
            conv2: without_bias.init(device),
        }
    }

    /// Apply `conv1` followed by `conv2`.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        self.conv2.forward(self.conv1.forward(x))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::TestBackend;
    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_store::{ModuleSnapshot, PytorchStore};

    type FT = FloatElem<TestBackend>;

    /// The 1x2x5x5 input tensor used by the test.
    fn sample_input() -> Tensor<TestBackend, 4> {
        let device = Default::default();
        Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [0.024_595_8, 0.25883394, 0.93905586, 0.416_715_5, 0.713_979_7],
                    [0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4, 0.505_920_8],
                    [0.23659128, 0.757_007_4, 0.23458993, 0.64705235, 0.355_621_4],
                    [0.445_182_8, 0.01930594, 0.26160914, 0.771_317, 0.37846136],
                    [0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845, 0.804_481_1],
                ],
                [
                    [0.65517855, 0.17679012, 0.824_772_3, 0.803_550_9, 0.943_447_5],
                    [0.21972018, 0.417_697, 0.49031407, 0.57302874, 0.12054086],
                    [0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7, 0.52850497],
                    [0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537, 0.03694397],
                    [0.751_675_7, 0.148_438_4, 0.12274551, 0.530_407_2, 0.414_796_4],
                ],
            ]],
            &device,
        )
    }

    /// Loads weights saved with a `conv.` key prefix into a model that has no such nesting.
    #[test]
    fn key_remap() {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);

        // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
        let mut weights = PytorchStore::from_file("tests/key_remap/key_remap.pt")
            .with_key_remapping("conv\\.(.*)", "$1");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");

        let actual = net.forward(sample_input());

        let expected = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [-0.02502128, 0.00250649, 0.04841233],
                    [0.04589614, -0.00296854, 0.01991477],
                    [0.02920526, 0.059_497_3, 0.04326791],
                ],
                [
                    [-0.04825336, 0.080_190_9, -0.02375088],
                    [0.02885434, 0.09638263, -0.07460806],
                    [0.02004079, 0.06244051, 0.035_887_1],
                ],
            ]],
            &device,
        );

        actual
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }
}

View File

@@ -0,0 +1,57 @@
#!/usr/bin/env python3
import torch
from torch import nn, Tensor
class ConvBlock(nn.Module):
    """Bias-free 1x1 convolution followed by batch norm, held in a Sequential named `block`."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Sequential indexes its children, so saved keys are `block.0.*` and `block.1.*`.
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Run `x` through the conv + batch-norm pair."""
        return self.block(x)
class Model(nn.Module):
    """Conv + batch norm stem followed by a Sequential of two ConvBlocks named `layer`."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, 3, bias=False)
        self.bn = nn.BatchNorm2d(6)
        self.layer = nn.Sequential(ConvBlock(6, 6), ConvBlock(6, 6))

    def forward(self, x: Tensor) -> Tensor:
        """Apply `conv`, `bn`, then the two blocks in `layer`."""
        x = self.conv(x)
        x = self.bn(x)
        return self.layer(x)
def main():
    """Condition the batch-norm stats, print a reference input/output pair, save the weights."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(42)

    model = Model()

    x = torch.rand(1, 3, 4, 4)
    model(x)  # condition batch norm
    model.eval()

    with torch.no_grad():
        print(f"Input shape: {x.shape}")
        print(f"Input type: {x.dtype}")
        print(f"Input: {x}")
        output = model(x)

        print(f"Output: {output}")
        print(f"Output Shape: {output.shape}")

    torch.save(model.state_dict(), "key_remap.pt")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,179 @@
use std::marker::PhantomData;
use burn::{
module::Module,
nn::{
BatchNorm, BatchNormConfig,
conv::{Conv2d, Conv2dConfig},
},
tensor::{Device, Tensor, backend::Backend},
};
/// Some module that implements a specific method so it can be used in a sequential block.
///
/// [`ModuleBlock`] requires this bound on its element type so it can chain
/// the elements' `forward` calls.
pub trait ForwardModule<B: Backend> {
/// Map a rank-4 tensor to a rank-4 tensor.
fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4>;
}
/// Conv2d + BatchNorm block.
///
/// The PyTorch export script stores the same two layers as `block.0` and
/// `block.1`; the tests remap those keys to `conv` and `bn`.
#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
// Convolution applied first.
conv: Conv2d<B>,
// Batch norm applied to the convolution output.
bn: BatchNorm<B>,
}
impl<B: Backend> ForwardModule<B> for ConvBlock<B> {
    /// Apply the convolution, then batch norm.
    fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let Self { conv, bn } = self;
        bn.forward(conv.forward(input))
    }
}
impl<B: Backend> ConvBlock<B> {
    /// Build a block with a bias-free 1x1 convolution from `in_channels` to
    /// `out_channels`, followed by a batch norm over `out_channels`.
    pub fn new(in_channels: usize, out_channels: usize, device: &Device<B>) -> Self {
        let conv_config = Conv2dConfig::new([in_channels, out_channels], [1, 1]).with_bias(false);
        let bn_config = BatchNormConfig::new(out_channels);

        // Fields are initialised in declaration order: `conv` before `bn`.
        Self {
            conv: conv_config.init(device),
            bn: bn_config.init(device),
        }
    }
}
/// Collection of sequential blocks.
///
/// `forward` feeds the output of each element of `blocks` into the next one.
#[derive(Module, Debug)]
pub struct ModuleBlock<B: Backend, M> {
// Blocks in application order.
blocks: Vec<M>,
// Marker that ties the otherwise unused backend parameter `B` to the struct.
_backend: PhantomData<B>,
}
impl<B: Backend, M: ForwardModule<B>> ModuleBlock<B, M> {
    /// Feed `input` through every block in order and return the final output.
    ///
    /// With no blocks, the input is returned unchanged.
    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        self.blocks
            .iter()
            .fold(input, |activation, block| block.forward(activation))
    }
}
impl<B: Backend> ModuleBlock<B, ConvBlock<B>> {
    /// Build a collection of two 6 -> 6 channel [`ConvBlock`]s on `device`.
    pub fn new(device: &Device<B>) -> Self {
        let first = ConvBlock::new(6, 6, device);
        let second = ConvBlock::new(6, 6, device);
        Self {
            blocks: vec![first, second],
            _backend: PhantomData,
        }
    }
}
/// Test model: a convolution and batch norm followed by a [`ModuleBlock`] of `M` blocks.
#[derive(Module, Debug)]
pub struct Model<B: Backend, M> {
// Convolution applied first.
conv: Conv2d<B>,
// Batch norm applied to the convolution output.
bn: BatchNorm<B>,
// Sequential blocks applied last.
layer: ModuleBlock<B, M>,
}
impl<B: Backend> Model<B, ConvBlock<B>> {
    /// Build the model: a bias-free 3 -> 6 channel 3x3 convolution, a batch
    /// norm over 6 channels, and a [`ModuleBlock`] of conv blocks.
    pub fn new(device: &Device<B>) -> Self {
        let stem = Conv2dConfig::new([3, 6], [3, 3]).with_bias(false);

        // Fields are initialised in declaration order: `conv`, `bn`, `layer`.
        Self {
            conv: stem.init(device),
            bn: BatchNormConfig::new(6).init(device),
            layer: ModuleBlock::new(device),
        }
    }

    /// Apply `conv`, `bn`, then the block collection.
    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let Self { conv, bn, layer } = self;
        let normalized = bn.forward(conv.forward(input));
        layer.forward(normalized)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::TestBackend;
    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_store::{ModuleSnapshot, PytorchStore};

    type FT = FloatElem<TestBackend>;

    /// The 1x3x4x4 input tensor used by `key_remap_chained`.
    fn sample_input() -> Tensor<TestBackend, 4> {
        let device = Default::default();
        Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [0.76193494, 0.626_546_1, 0.49510366, 0.11974698],
                    [0.07161391, 0.03232569, 0.704_681, 0.254_516],
                    [0.399_373_7, 0.21224737, 0.40888822, 0.14808255],
                    [0.17329216, 0.665_855_4, 0.351_401_8, 0.808_671_6],
                ],
                [
                    [0.33959562, 0.13321638, 0.41178054, 0.257_626_3],
                    [0.347_029_2, 0.02400219, 0.77974546, 0.15189773],
                    [0.75130886, 0.726_892_1, 0.85721636, 0.11647397],
                    [0.859_598_4, 0.263_624_2, 0.685_534_6, 0.96955734],
                ],
                [
                    [0.42948407, 0.49613327, 0.38488472, 0.08250773],
                    [0.73995143, 0.00364107, 0.81039995, 0.87411255],
                    [0.972_853_2, 0.38206023, 0.08917904, 0.61241513],
                    [0.77621365, 0.00234562, 0.38650817, 0.20027226],
                ],
            ]],
            &device,
        )
    }

    /// The 1x6x2x2 reference output for `sample_input`.
    fn expected_output() -> Tensor<TestBackend, 4> {
        let device = Default::default();
        Tensor::<TestBackend, 4>::from_data(
            [[
                [[0.198_967_1, 0.17847246], [0.06883702, 0.20012866]],
                [[0.17582723, 0.11344293], [0.05444185, 0.13307181]],
                [[0.192_229_5, 0.20391327], [0.06150475, 0.22688155]],
                [[0.00230906, -0.02177845], [0.01129148, 0.00925517]],
                [[0.14751078, 0.14433631], [0.05498439, 0.29049855]],
                [[0.16868964, 0.133_269_3], [0.06917118, 0.35094324]],
            ]],
            &device,
        )
    }

    #[test]
    #[should_panic]
    fn key_remap_chained_missing_pattern() {
        // Loading record should fail due to missing pattern to map the layer.blocks
        let device = Default::default();
        let mut net: Model<TestBackend, _> = Model::new(&device);
        let mut weights = PytorchStore::from_file("tests/key_remap_chained/key_remap.pt")
            // Map *.block.0.* -> *.conv.*
            .with_key_remapping("(.+)\\.block\\.0\\.(.+)", "$1.conv.$2")
            // Map *.block.1.* -> *.bn.*
            .with_key_remapping("(.+)\\.block\\.1\\.(.+)", "$1.bn.$2");

        let outcome = net.load_from(&mut weights);
        outcome.expect("Should decode state successfully");
    }

    #[test]
    fn key_remap_chained() {
        let device = Default::default();
        let mut net: Model<TestBackend, _> = Model::new(&device);
        let mut weights = PytorchStore::from_file("tests/key_remap_chained/key_remap.pt")
            // Map *.block.0.* -> *.conv.*
            .with_key_remapping("(.+)\\.block\\.0\\.(.+)", "$1.conv.$2")
            // Map *.block.1.* -> *.bn.*
            .with_key_remapping("(.+)\\.block\\.1\\.(.+)", "$1.bn.$2")
            // Map layer.[i].* -> layer.blocks.[i].*
            .with_key_remapping("layer\\.([0-9])\\.(.+)", "layer.blocks.$1.$2");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");

        let expected = expected_output();
        let actual = net.forward(sample_input());

        actual
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }
}

View File

@@ -0,0 +1,37 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Single LayerNorm over a last dimension of size 2, stored as `norm1`."""

    def __init__(self):
        super().__init__()
        self.norm1 = nn.LayerNorm(2)

    def forward(self, x):
        """Normalise `x` with `norm1`."""
        return self.norm1(x)
def main():
    """Save the LayerNorm weights and print a reference input/output pair."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))
    torch.save(model.state_dict(), "layer_norm.pt")

    x2 = torch.rand(1, 2, 2, 2)
    print(f"Input shape: {x2.shape}")
    print(f"Input: {x2}")
    output = model(x2)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,82 @@
use burn::{
module::Module,
nn::{LayerNorm, LayerNormConfig},
tensor::{Tensor, backend::Backend},
};
/// Test model wrapping a single layer-normalisation layer.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
// Field name matches the `norm1` attribute of the PyTorch export script.
norm1: LayerNorm<B>,
}
impl<B: Backend> Net<B> {
    /// Build a model with a freshly initialised `LayerNormConfig::new(2)` layer on `device`.
    pub fn init(device: &B::Device) -> Self {
        Self {
            norm1: LayerNormConfig::new(2).init(device),
        }
    }

    /// Apply the layer-norm layer to `x`.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let Self { norm1 } = self;
        norm1.forward(x)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::TestBackend;
    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_store::{ModuleSnapshot, PytorchStore};

    type FT = FloatElem<TestBackend>;

    /// Initialise a `Net` and load `tests/layer_norm/layer_norm.pt` into it.
    fn load_net() -> Net<TestBackend> {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut weights = PytorchStore::from_file("tests/layer_norm/layer_norm.pt");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");
        net
    }

    /// Run the fixed input through `model` and compare against the reference
    /// output with an absolute tolerance of `precision`.
    fn layer_norm(model: Net<TestBackend>, precision: f32) {
        let device = Default::default();

        let expected = Tensor::<TestBackend, 4>::from_data(
            [[
                [[0.99991274, -0.999_912_5], [-0.999_818_3, 0.999_818_3]],
                [[-0.999_966_2, 0.99996626], [-0.99984336, 0.99984336]],
            ]],
            &device,
        );

        let input = Tensor::<TestBackend, 4>::from_data(
            [[
                [[0.757_631_6, 0.27931088], [0.40306926, 0.73468447]],
                [[0.02928156, 0.799_858_6], [0.39713734, 0.75437194]],
            ]],
            &device,
        );

        let actual = model.forward(input);
        actual
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::absolute(precision));
    }

    #[test]
    fn layer_norm_full() {
        layer_norm(load_net(), 1e-3);
    }

    // NOTE(review): loads the same file with the same settings and tolerance as
    // `layer_norm_full`; confirm whether a distinct half-precision path is intended.
    #[test]
    fn layer_norm_half() {
        layer_norm(load_net(), 1e-3);
    }
}

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Two linear layers (2 -> 3 with bias, 3 -> 4 without) separated by a ReLU."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 3)
        self.fc2 = nn.Linear(3, 4, bias=False)

    def forward(self, x):
        """Apply `fc1`, ReLU, then `fc2`."""
        x = self.fc1(x)
        x = F.relu(x)  # Add relu so that PyTorch optimizer does not combine fc1 and fc2
        return self.fc2(x)
class ModelWithBias(nn.Module):
    """Single linear layer (2 -> 3) with bias, stored as `fc1`."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 3)

    def forward(self, x):
        """Apply `fc1`."""
        return self.fc1(x)
def main():
    """Save both linear models and print reference outputs for a shared input."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    # Statement order matters: both constructors and `torch.rand` draw from the
    # seeded generator, so reordering changes the saved values.
    model = Model().to(torch.device("cpu"))
    model_with_bias = ModelWithBias().to(torch.device("cpu"))

    torch.save(model.state_dict(), "linear.pt")
    torch.save(model_with_bias.state_dict(), "linear_with_bias.pt")

    x = torch.rand(1, 2, 2, 2)
    print(f"Input shape: {x.shape}")
    print(f"Input: {x}")
    output = model(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")

    print("Model with bias")
    output = model_with_bias(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,154 @@
use burn::{
module::Module,
nn::{Linear, LinearConfig, Relu},
tensor::{Tensor, backend::Backend},
};
/// Test model: two linear layers with a ReLU in between.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
// First linear layer applied in `forward`.
fc1: Linear<B>,
// Second linear layer; initialised without bias.
fc2: Linear<B>,
// Activation applied between `fc1` and `fc2`.
relu: Relu,
}
impl<B: Backend> Net<B> {
    /// Build a model with `fc1` (2 -> 3) and a bias-free `fc2` (3 -> 4).
    pub fn init(device: &B::Device) -> Self {
        let first = LinearConfig::new(2, 3);
        let second = LinearConfig::new(3, 4).with_bias(false);

        // Fields are initialised in declaration order: `fc1` before `fc2`.
        Self {
            fc1: first.init(device),
            fc2: second.init(device),
            relu: Relu,
        }
    }

    /// Apply `fc1`, ReLU, then `fc2`.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let hidden = self.relu.forward(self.fc1.forward(x));
        self.fc2.forward(hidden)
    }
}
/// Test model with a single linear layer, used by the `linear_with_bias` test.
#[derive(Module, Debug)]
struct NetWithBias<B: Backend> {
// Field name matches the `fc1` attribute of `ModelWithBias` in the export script.
fc1: Linear<B>,
}
impl<B: Backend> NetWithBias<B> {
    /// Build a model with a single `LinearConfig::new(2, 3)` layer on `device`.
    pub fn init(device: &B::Device) -> Self {
        Self {
            fc1: LinearConfig::new(2, 3).init(device),
        }
    }

    /// Apply `fc1` to `x`.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let Self { fc1 } = self;
        fc1.forward(x)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::TestBackend;
    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_store::{ModuleSnapshot, PytorchStore};

    type FT = FloatElem<TestBackend>;

    /// The 1x2x2x2 input tensor shared by all tests in this module.
    fn sample_input() -> Tensor<TestBackend, 4> {
        let device = Default::default();
        Tensor::<TestBackend, 4>::from_data(
            [[
                [[0.63968194, 0.97427773], [0.830_029_9, 0.04443115]],
                [[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]],
            ]],
            &device,
        )
    }

    /// Initialise a `Net` and load `tests/linear/linear.pt` into it.
    fn load_net() -> Net<TestBackend> {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut weights = PytorchStore::from_file("tests/linear/linear.pt");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");
        net
    }

    /// Run the shared input through `model` and compare against the reference
    /// output with an absolute tolerance of `precision`.
    fn linear_test(model: Net<TestBackend>, precision: f32) {
        let device = Default::default();
        let actual = model.forward(sample_input());

        let expected = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [0.09778349, -0.13756673, 0.04962806, 0.08856435],
                    [0.03163241, -0.02848549, 0.01437942, 0.11905234],
                ],
                [
                    [0.07628226, -0.10757702, 0.03656857, 0.03824598],
                    [0.05443089, -0.06904714, 0.02744314, 0.09997337],
                ],
            ]],
            &device,
        );

        actual
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::absolute(precision));
    }

    #[test]
    fn linear_full_precision() {
        linear_test(load_net(), 1e-7);
    }

    // NOTE(review): loads the same file with the same settings as
    // `linear_full_precision`, only with a looser tolerance; confirm whether a
    // distinct half-precision path is intended.
    #[test]
    fn linear_half_precision() {
        linear_test(load_net(), 1e-4);
    }

    #[test]
    fn linear_with_bias() {
        let device = Default::default();
        let mut net = NetWithBias::<TestBackend>::init(&device);
        let mut weights = PytorchStore::from_file("tests/linear/linear_with_bias.pt");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");

        let actual = net.forward(sample_input());

        let expected = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [-0.00432095, -1.107_101_2, 0.870_691_4],
                    [0.024_595_5, -0.954_462_9, 0.48518157],
                ],
                [
                    [0.34315687, -0.757_384_2, 0.548_288],
                    [-0.06608963, -1.072_072_7, 0.645_800_5],
                ],
            ]],
            &device,
        );

        actual
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }
}

View File

@@ -0,0 +1,24 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Single 2x2 convolution (2 -> 2 channels) stored as `conv1`."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, (2, 2))

    def forward(self, x):
        """Apply `conv1`."""
        return self.conv1(x)
def main():
    """Save the model's state dict to `missing_module_field.pt`."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))
    torch.save(model.state_dict(), "missing_module_field.pt")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,37 @@
use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend};
/// Test model whose only field has no counterpart in the saved PyTorch file.
///
/// The export script saves a layer named `conv1`, so loading into this struct
/// is expected to fail (see the `should_panic` test in this file).
#[derive(Module, Debug)]
#[allow(unused)]
pub struct Net<B: Backend> {
// Deliberately named differently from the `conv1` key in the `.pt` file.
do_not_exist_in_pt: Conv2d<B>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::TestBackend;
    use burn::nn::conv::Conv2dConfig;
    use burn_store::{ModuleSnapshot, PytorchStore};

    impl<B: Backend> Net<B> {
        /// Build a model with a 2 -> 2 channel, 2x2 convolution on `device`.
        pub fn init(device: &B::Device) -> Self {
            let conv = Conv2dConfig::new([2, 2], [2, 2]).init(device);
            Self {
                do_not_exist_in_pt: conv,
            }
        }
    }

    /// Loading must panic with a message that names the unmatched field.
    #[test]
    #[should_panic(expected = "do_not_exist_in_pt")]
    fn should_fail_if_struct_field_is_missing() {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut weights =
            PytorchStore::from_file("tests/missing_module_field/missing_module_field.pt");

        let outcome = net.load_from(&mut weights);
        outcome.expect("Should decode state successfully");
    }
}

View File

@@ -0,0 +1,42 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Five Conv2d + ReLU pairs packed into one Sequential named `fc`."""

    def __init__(self):
        super().__init__()

        num_layers = 5  # Number of repeated convolutional layers

        # Conv layers land at even Sequential indices (0, 2, 4, 6, 8) because a
        # parameter-free ReLU sits between them.
        layers = []
        for _ in range(num_layers):
            layers.append(nn.Conv2d(2, 2, kernel_size=3, padding=1, bias=True))
            layers.append(nn.ReLU(inplace=True))

        # Use nn.Sequential to create a single module from the layers
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Run `x` through the whole Sequential."""
        return self.fc(x)
def main():
    """Save the Sequential model's weights and print a reference input/output pair."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))
    torch.save(model.state_dict(), "non_contiguous_indexes.pt")

    x = torch.rand(1, 2, 5, 5)
    print(f"Input shape: {x.shape}")
    print(f"Input: {x}")
    output = model(x)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,110 @@
use burn::{
module::Module,
nn::{
PaddingConfig2d,
conv::{Conv2d, Conv2dConfig},
},
tensor::{Tensor, activation::relu, backend::Backend},
};
/// Test model holding a list of convolutions that `forward` applies in order.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
// Field name matches the `fc` Sequential of the PyTorch export script.
fc: Vec<Conv2d<B>>,
}
impl<B: Backend> Net<B> {
    /// Build a model with five identically configured convolutions
    /// (2 -> 2 channels, 3x3 kernel, `Same` padding).
    pub fn init(device: &B::Device) -> Self {
        let config = Conv2dConfig::new([2, 2], [3, 3]).with_padding(PaddingConfig2d::Same);

        // The PyTorch file has 5 Conv2d layers at non-contiguous indices (0, 2, 4, 6, 8)
        // in the Sequential (alternating with ReLU layers)
        let fc = (0..5).map(|_| config.init(device)).collect();

        Net { fc }
    }

    /// Apply each convolution in order, with a ReLU after every one.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let mut activation = x;
        for conv in &self.fc {
            activation = relu(conv.forward(activation));
        }
        activation
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::TestBackend;
    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_store::{ModuleSnapshot, PytorchStore};

    type FT = FloatElem<TestBackend>;

    /// The 1x2x5x5 input tensor used by the test.
    fn sample_input() -> Tensor<TestBackend, 4> {
        let device = Default::default();
        Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    [0.67890584, 0.307_537_2, 0.265_156_2, 0.528_318_8, 0.86194897],
                    [0.14828813, 0.73480314, 0.821_220_7, 0.989_098_6, 0.15003455],
                    [0.62109494, 0.13028657, 0.926_875_1, 0.30604684, 0.80117637],
                    [0.514_885_7, 0.46105868, 0.484_046_1, 0.58499724, 0.73569804],
                    [0.58018994, 0.65252745, 0.05023766, 0.864_268_7, 0.935_932],
                ],
                [
                    [0.913_302_9, 0.869_611_3, 0.139_184_3, 0.314_65, 0.94086266],
                    [0.11917073, 0.953_610_6, 0.10675198, 0.14779574, 0.744_439],
                    [0.14075547, 0.38544965, 0.863_745_9, 0.89604443, 0.97287786],
                    [0.39854127, 0.11136961, 0.99230546, 0.39348692, 0.29428244],
                    [0.621_886_9, 0.15033776, 0.828_640_1, 0.81336635, 0.10325938],
                ],
            ]],
            &device,
        )
    }

    /// Loads convolutions stored at non-contiguous Sequential indices and checks the output.
    #[test]
    fn non_contiguous_indexes() {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut weights =
            PytorchStore::from_file("tests/non_contiguous_indexes/non_contiguous_indexes.pt");
        net.load_from(&mut weights)
            .expect("Should decode state successfully");

        let actual = net.forward(sample_input());

        // Every expected row is all zeros except the last row of the first channel.
        let zeros = [0.0; 5];
        let expected = Tensor::<TestBackend, 4>::from_data(
            [[
                [
                    zeros,
                    zeros,
                    zeros,
                    zeros,
                    [0.04485746, 0.03582812, 0.03432692, 0.02892298, 0.013_844_3],
                ],
                [zeros, zeros, zeros, zeros, zeros],
            ]],
            &device,
        );

        actual
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::absolute(1e-7));
    }
}

View File

@@ -0,0 +1,22 @@
mod backend;
mod batch_norm;
mod boolean;
mod buffer;
mod complex_nested;
mod config;
mod conv1d;
mod conv2d;
mod conv_transpose1d;
mod conv_transpose2d;
mod embedding;
mod enum_module;
mod group_norm;
mod integer;
mod key_remap;
mod key_remap_chained;
mod layer_norm;
mod linear;
mod missing_module_field;
mod non_contiguous_indexes;
mod top_level_key;

View File

@@ -0,0 +1,24 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Single 2x2 convolution (2 -> 2 channels) stored as `conv1`."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, (2, 2))

    def forward(self, x):
        """Apply `conv1`."""
        return self.conv1(x)
def main():
    """Save the state dict nested under the top-level key `my_state_dict`."""
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    # The state dict is wrapped in an outer dict, so a loader has to select
    # "my_state_dict" to reach the weights.
    torch.save({"my_state_dict": model.state_dict()}, "top_level_key.pt")


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,48 @@
use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend};
/// Test model with a single convolution, used to exercise top-level-key loading.
#[derive(Module, Debug)]
#[allow(unused)]
pub struct Net<B: Backend> {
// Field name matches the `conv1` attribute of the PyTorch export script.
conv1: Conv2d<B>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::TestBackend;
    use burn::nn::conv::Conv2dConfig;
    use burn_store::{ModuleSnapshot, PytorchStore};

    impl<B: Backend> Net<B> {
        /// Build a model with a 2 -> 2 channel, 2x2 convolution on `device`.
        pub fn init(device: &B::Device) -> Self {
            let conv1 = Conv2dConfig::new([2, 2], [2, 2]).init(device);
            Self { conv1 }
        }
    }

    /// Loading without selecting the top-level key is expected to panic.
    #[test]
    #[should_panic]
    fn should_fail_if_not_found() {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut weights = PytorchStore::from_file("tests/top_level_key/top_level_key.pt");

        let outcome = net.load_from(&mut weights);
        outcome.expect("Should decode state successfully");
    }

    /// Loading succeeds once the store is pointed at `my_state_dict`.
    #[test]
    fn should_load() {
        let device = Default::default();
        let mut net = Net::<TestBackend>::init(&device);
        let mut weights = PytorchStore::from_file("tests/top_level_key/top_level_key.pt")
            .with_top_level_key("my_state_dict");

        let outcome = net.load_from(&mut weights);
        outcome.expect("Should decode state successfully");
    }
}