feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,690 @@
use super::*;
use burn_tensor::{Shape, Tolerance, module::conv3d, ops::ConvOptions};
#[test]
fn test_conv3d_basic() {
let test = Conv3dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
kernel_size_3: 3,
padding_1: 1,
padding_2: 1,
padding_3: 1,
stride_1: 1,
stride_2: 1,
stride_3: 1,
dilation_1: 1,
dilation_2: 1,
dilation_3: 1,
groups: 1,
depth: 4,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[
[
[
[
[536., 816., 816., 552.],
[840., 1278., 1278., 864.],
[840., 1278., 1278., 864.],
[584., 888., 888., 600.],
],
[
[912., 1386., 1386., 936.],
[1422., 2160., 2160., 1458.],
[1422., 2160., 2160., 1458.],
[984., 1494., 1494., 1008.],
],
[
[912., 1386., 1386., 936.],
[1422., 2160., 2160., 1458.],
[1422., 2160., 2160., 1458.],
[984., 1494., 1494., 1008.],
],
[
[680., 1032., 1032., 696.],
[1056., 1602., 1602., 1080.],
[1056., 1602., 1602., 1080.],
[728., 1104., 1104., 744.],
],
],
[
[
[968., 1464., 1464., 984.],
[1488., 2250., 2250., 1512.],
[1488., 2250., 2250., 1512.],
[1016., 1536., 1536., 1032.],
],
[
[1560., 2358., 2358., 1584.],
[2394., 3618., 3618., 2430.],
[2394., 3618., 3618., 2430.],
[1632., 2466., 2466., 1656.],
],
[
[1560., 2358., 2358., 1584.],
[2394., 3618., 3618., 2430.],
[2394., 3618., 3618., 2430.],
[1632., 2466., 2466., 1656.],
],
[
[1112., 1680., 1680., 1128.],
[1704., 2574., 2574., 1728.],
[1704., 2574., 2574., 1728.],
[1160., 1752., 1752., 1176.],
],
],
],
[
[
[
[536., 816., 816., 552.],
[840., 1278., 1278., 864.],
[840., 1278., 1278., 864.],
[584., 888., 888., 600.],
],
[
[912., 1386., 1386., 936.],
[1422., 2160., 2160., 1458.],
[1422., 2160., 2160., 1458.],
[984., 1494., 1494., 1008.],
],
[
[912., 1386., 1386., 936.],
[1422., 2160., 2160., 1458.],
[1422., 2160., 2160., 1458.],
[984., 1494., 1494., 1008.],
],
[
[680., 1032., 1032., 696.],
[1056., 1602., 1602., 1080.],
[1056., 1602., 1602., 1080.],
[728., 1104., 1104., 744.],
],
],
[
[
[968., 1464., 1464., 984.],
[1488., 2250., 2250., 1512.],
[1488., 2250., 2250., 1512.],
[1016., 1536., 1536., 1032.],
],
[
[1560., 2358., 2358., 1584.],
[2394., 3618., 3618., 2430.],
[2394., 3618., 3618., 2430.],
[1632., 2466., 2466., 1656.],
],
[
[1560., 2358., 2358., 1584.],
[2394., 3618., 3618., 2430.],
[2394., 3618., 3618., 2430.],
[1632., 2466., 2466., 1656.],
],
[
[1112., 1680., 1680., 1128.],
[1704., 2574., 2574., 1728.],
[1704., 2574., 2574., 1728.],
[1160., 1752., 1752., 1176.],
],
],
],
],
&device,
),
weight: TestTensor::from_floats(
[
[
[
[
[4590., 6156., 4644.],
[6264., 8400., 6336.],
[4806., 6444., 4860.],
],
[
[6696., 8976., 6768.],
[9120., 12224., 9216.],
[6984., 9360., 7056.],
],
[
[5454., 7308., 5508.],
[7416., 9936., 7488.],
[5670., 7596., 5724.],
],
],
[
[
[8046., 10764., 8100.],
[10872., 14544., 10944.],
[8262., 11052., 8316.],
],
[
[11304., 15120., 11376.],
[15264., 20416., 15360.],
[11592., 15504., 11664.],
],
[
[8910., 11916., 8964.],
[12024., 16080., 12096.],
[9126., 12204., 9180.],
],
],
],
[
[
[
[4590., 6156., 4644.],
[6264., 8400., 6336.],
[4806., 6444., 4860.],
],
[
[6696., 8976., 6768.],
[9120., 12224., 9216.],
[6984., 9360., 7056.],
],
[
[5454., 7308., 5508.],
[7416., 9936., 7488.],
[5670., 7596., 5724.],
],
],
[
[
[8046., 10764., 8100.],
[10872., 14544., 10944.],
[8262., 11052., 8316.],
],
[
[11304., 15120., 11376.],
[15264., 20416., 15360.],
[11592., 15504., 11664.],
],
[
[8910., 11916., 8964.],
[12024., 16080., 12096.],
[9126., 12204., 9180.],
],
],
],
],
&device,
),
bias: TestTensor::from_floats([128., 128.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv3d_complex() {
let test = Conv3dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 3,
kernel_size_1: 2,
kernel_size_2: 3,
kernel_size_3: 4,
padding_1: 1,
padding_2: 2,
padding_3: 3,
stride_1: 1,
stride_2: 2,
stride_3: 3,
dilation_1: 2,
dilation_2: 3,
dilation_3: 4,
groups: 1,
depth: 5,
height: 6,
width: 7,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[
[0., 147., 0., 0., 0., 150., 0.],
[0., 159., 0., 0., 0., 162., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 159., 0., 0., 0., 162., 0.],
[0., 171., 0., 0., 0., 174., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 330., 0., 0., 0., 336., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 378., 0., 0., 0., 384., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 330., 0., 0., 0., 336., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 378., 0., 0., 0., 384., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 330., 0., 0., 0., 336., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 354., 0., 0., 0., 360., 0.],
[0., 378., 0., 0., 0., 384., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 183., 0., 0., 0., 186., 0.],
[0., 195., 0., 0., 0., 198., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 195., 0., 0., 0., 198., 0.],
[0., 207., 0., 0., 0., 210., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
],
[
[
[0., 219., 0., 0., 0., 222., 0.],
[0., 231., 0., 0., 0., 234., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 231., 0., 0., 0., 234., 0.],
[0., 243., 0., 0., 0., 246., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 474., 0., 0., 0., 480., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 522., 0., 0., 0., 528., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 474., 0., 0., 0., 480., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 522., 0., 0., 0., 528., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 474., 0., 0., 0., 480., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 498., 0., 0., 0., 504., 0.],
[0., 522., 0., 0., 0., 528., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
[
[0., 255., 0., 0., 0., 258., 0.],
[0., 267., 0., 0., 0., 270., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 267., 0., 0., 0., 270., 0.],
[0., 279., 0., 0., 0., 282., 0.],
[0., 0., 0., 0., 0., 0., 0.],
],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[
[
[0., 256., 272., 0.],
[0., 624., 656., 0.],
[0., 368., 384., 0.],
],
[
[0., 424., 440., 0.],
[0., 960., 992., 0.],
[0., 536., 552., 0.],
],
],
[
[
[0., 1096., 1112., 0.],
[0., 2304., 2336., 0.],
[0., 1208., 1224., 0.],
],
[
[0., 1264., 1280., 0.],
[0., 2640., 2672., 0.],
[0., 1376., 1392., 0.],
],
],
],
[
[
[
[0., 256., 272., 0.],
[0., 624., 656., 0.],
[0., 368., 384., 0.],
],
[
[0., 424., 440., 0.],
[0., 960., 992., 0.],
[0., 536., 552., 0.],
],
],
[
[
[0., 1096., 1112., 0.],
[0., 2304., 2336., 0.],
[0., 1208., 1224., 0.],
],
[
[0., 1264., 1280., 0.],
[0., 2640., 2672., 0.],
[0., 1376., 1392., 0.],
],
],
],
[
[
[
[0., 256., 272., 0.],
[0., 624., 656., 0.],
[0., 368., 384., 0.],
],
[
[0., 424., 440., 0.],
[0., 960., 992., 0.],
[0., 536., 552., 0.],
],
],
[
[
[0., 1096., 1112., 0.],
[0., 2304., 2336., 0.],
[0., 1208., 1224., 0.],
],
[
[0., 1264., 1280., 0.],
[0., 2640., 2672., 0.],
[0., 1376., 1392., 0.],
],
],
],
],
&device,
),
bias: TestTensor::from_floats([10., 10., 10.], &device),
};
test.assert_grads(grads);
}
#[test]
fn test_conv3d_groups_stride_2_no_pad() {
let test = Conv3dTestCase {
batch_size: 1,
channels_in: 4,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
kernel_size_3: 3,
padding_1: 0,
padding_2: 0,
padding_3: 0,
stride_1: 2,
stride_2: 2,
stride_3: 2,
dilation_1: 1,
dilation_2: 1,
dilation_3: 1,
groups: 2,
depth: 4,
height: 4,
width: 4,
};
let device = Default::default();
let grads = Grads {
x: TestTensor::from_floats(
[[
[
[
[0., 1., 2., 0.],
[3., 4., 5., 0.],
[6., 7., 8., 0.],
[0., 0., 0., 0.],
],
[
[9., 10., 11., 0.],
[12., 13., 14., 0.],
[15., 16., 17., 0.],
[0., 0., 0., 0.],
],
[
[18., 19., 20., 0.],
[21., 22., 23., 0.],
[24., 25., 26., 0.],
[0., 0., 0., 0.],
],
[
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
],
],
[
[
[27., 28., 29., 0.],
[30., 31., 32., 0.],
[33., 34., 35., 0.],
[0., 0., 0., 0.],
],
[
[36., 37., 38., 0.],
[39., 40., 41., 0.],
[42., 43., 44., 0.],
[0., 0., 0., 0.],
],
[
[45., 46., 47., 0.],
[48., 49., 50., 0.],
[51., 52., 53., 0.],
[0., 0., 0., 0.],
],
[
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
],
],
[
[
[54., 55., 56., 0.],
[57., 58., 59., 0.],
[60., 61., 62., 0.],
[0., 0., 0., 0.],
],
[
[63., 64., 65., 0.],
[66., 67., 68., 0.],
[69., 70., 71., 0.],
[0., 0., 0., 0.],
],
[
[72., 73., 74., 0.],
[75., 76., 77., 0.],
[78., 79., 80., 0.],
[0., 0., 0., 0.],
],
[
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
],
],
[
[
[81., 82., 83., 0.],
[84., 85., 86., 0.],
[87., 88., 89., 0.],
[0., 0., 0., 0.],
],
[
[90., 91., 92., 0.],
[93., 94., 95., 0.],
[96., 97., 98., 0.],
[0., 0., 0., 0.],
],
[
[99., 100., 101., 0.],
[102., 103., 104., 0.],
[105., 106., 107., 0.],
[0., 0., 0., 0.],
],
[
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
],
],
]],
&device,
),
weight: TestTensor::from_floats(
[
[
[
[[0., 1., 2.], [4., 5., 6.], [8., 9., 10.]],
[[16., 17., 18.], [20., 21., 22.], [24., 25., 26.]],
[[32., 33., 34.], [36., 37., 38.], [40., 41., 42.]],
],
[
[[64., 65., 66.], [68., 69., 70.], [72., 73., 74.]],
[[80., 81., 82.], [84., 85., 86.], [88., 89., 90.]],
[[96., 97., 98.], [100., 101., 102.], [104., 105., 106.]],
],
],
[
[
[[128., 129., 130.], [132., 133., 134.], [136., 137., 138.]],
[[144., 145., 146.], [148., 149., 150.], [152., 153., 154.]],
[[160., 161., 162.], [164., 165., 166.], [168., 169., 170.]],
],
[
[[192., 193., 194.], [196., 197., 198.], [200., 201., 202.]],
[[208., 209., 210.], [212., 213., 214.], [216., 217., 218.]],
[[224., 225., 226.], [228., 229., 230.], [232., 233., 234.]],
],
],
],
&device,
),
bias: TestTensor::from_floats([1., 1.], &device),
};
test.assert_grads(grads);
}
struct Conv3dTestCase {
batch_size: usize,
channels_in: usize,
channels_out: usize,
kernel_size_1: usize,
kernel_size_2: usize,
kernel_size_3: usize,
padding_1: usize,
padding_2: usize,
padding_3: usize,
stride_1: usize,
stride_2: usize,
stride_3: usize,
dilation_1: usize,
dilation_2: usize,
dilation_3: usize,
groups: usize,
depth: usize,
height: usize,
width: usize,
}
struct Grads {
x: TestTensor<5>,
weight: TestTensor<5>,
bias: TestTensor<1>,
}
impl Conv3dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let shape_x = Shape::new([
self.batch_size,
self.channels_in,
self.depth,
self.height,
self.width,
]);
let shape_weight = Shape::new([
self.channels_out,
self.channels_in / self.groups,
self.kernel_size_1,
self.kernel_size_2,
self.kernel_size_3,
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape::<5, _>(shape_weight)
.into_data(),
&device,
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape::<5, _>(shape_x)
.into_data(),
&device,
)
.require_grad();
let output = conv3d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvOptions::new(
[self.stride_1, self.stride_2, self.stride_3],
[self.padding_1, self.padding_2, self.padding_3],
[self.dilation_1, self.dilation_2, self.dilation_3],
self.groups,
),
);
let grads = output.backward();
// Assert
let x_grad_actual = x.grad(&grads).unwrap();
let weight_grad_actual = weight.grad(&grads).unwrap();
let bias_grad_actual = bias.grad(&grads).unwrap();
let tolerance = Tolerance::default();
expected_grads
.bias
.to_data()
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
expected_grads
.x
.to_data()
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
expected_grads
.weight
.to_data()
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
}
}