- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
691 lines
24 KiB
Rust
691 lines
24 KiB
Rust
use super::*;
|
|
use burn_tensor::{Shape, Tolerance, module::conv3d, ops::ConvOptions};
|
|
|
|
#[test]
|
|
fn test_conv3d_basic() {
|
|
let test = Conv3dTestCase {
|
|
batch_size: 2,
|
|
channels_in: 2,
|
|
channels_out: 2,
|
|
kernel_size_1: 3,
|
|
kernel_size_2: 3,
|
|
kernel_size_3: 3,
|
|
padding_1: 1,
|
|
padding_2: 1,
|
|
padding_3: 1,
|
|
stride_1: 1,
|
|
stride_2: 1,
|
|
stride_3: 1,
|
|
dilation_1: 1,
|
|
dilation_2: 1,
|
|
dilation_3: 1,
|
|
groups: 1,
|
|
depth: 4,
|
|
height: 4,
|
|
width: 4,
|
|
};
|
|
let device = Default::default();
|
|
let grads = Grads {
|
|
x: TestTensor::from_floats(
|
|
[
|
|
[
|
|
[
|
|
[
|
|
[536., 816., 816., 552.],
|
|
[840., 1278., 1278., 864.],
|
|
[840., 1278., 1278., 864.],
|
|
[584., 888., 888., 600.],
|
|
],
|
|
[
|
|
[912., 1386., 1386., 936.],
|
|
[1422., 2160., 2160., 1458.],
|
|
[1422., 2160., 2160., 1458.],
|
|
[984., 1494., 1494., 1008.],
|
|
],
|
|
[
|
|
[912., 1386., 1386., 936.],
|
|
[1422., 2160., 2160., 1458.],
|
|
[1422., 2160., 2160., 1458.],
|
|
[984., 1494., 1494., 1008.],
|
|
],
|
|
[
|
|
[680., 1032., 1032., 696.],
|
|
[1056., 1602., 1602., 1080.],
|
|
[1056., 1602., 1602., 1080.],
|
|
[728., 1104., 1104., 744.],
|
|
],
|
|
],
|
|
[
|
|
[
|
|
[968., 1464., 1464., 984.],
|
|
[1488., 2250., 2250., 1512.],
|
|
[1488., 2250., 2250., 1512.],
|
|
[1016., 1536., 1536., 1032.],
|
|
],
|
|
[
|
|
[1560., 2358., 2358., 1584.],
|
|
[2394., 3618., 3618., 2430.],
|
|
[2394., 3618., 3618., 2430.],
|
|
[1632., 2466., 2466., 1656.],
|
|
],
|
|
[
|
|
[1560., 2358., 2358., 1584.],
|
|
[2394., 3618., 3618., 2430.],
|
|
[2394., 3618., 3618., 2430.],
|
|
[1632., 2466., 2466., 1656.],
|
|
],
|
|
[
|
|
[1112., 1680., 1680., 1128.],
|
|
[1704., 2574., 2574., 1728.],
|
|
[1704., 2574., 2574., 1728.],
|
|
[1160., 1752., 1752., 1176.],
|
|
],
|
|
],
|
|
],
|
|
[
|
|
[
|
|
[
|
|
[536., 816., 816., 552.],
|
|
[840., 1278., 1278., 864.],
|
|
[840., 1278., 1278., 864.],
|
|
[584., 888., 888., 600.],
|
|
],
|
|
[
|
|
[912., 1386., 1386., 936.],
|
|
[1422., 2160., 2160., 1458.],
|
|
[1422., 2160., 2160., 1458.],
|
|
[984., 1494., 1494., 1008.],
|
|
],
|
|
[
|
|
[912., 1386., 1386., 936.],
|
|
[1422., 2160., 2160., 1458.],
|
|
[1422., 2160., 2160., 1458.],
|
|
[984., 1494., 1494., 1008.],
|
|
],
|
|
[
|
|
[680., 1032., 1032., 696.],
|
|
[1056., 1602., 1602., 1080.],
|
|
[1056., 1602., 1602., 1080.],
|
|
[728., 1104., 1104., 744.],
|
|
],
|
|
],
|
|
[
|
|
[
|
|
[968., 1464., 1464., 984.],
|
|
[1488., 2250., 2250., 1512.],
|
|
[1488., 2250., 2250., 1512.],
|
|
[1016., 1536., 1536., 1032.],
|
|
],
|
|
[
|
|
[1560., 2358., 2358., 1584.],
|
|
[2394., 3618., 3618., 2430.],
|
|
[2394., 3618., 3618., 2430.],
|
|
[1632., 2466., 2466., 1656.],
|
|
],
|
|
[
|
|
[1560., 2358., 2358., 1584.],
|
|
[2394., 3618., 3618., 2430.],
|
|
[2394., 3618., 3618., 2430.],
|
|
[1632., 2466., 2466., 1656.],
|
|
],
|
|
[
|
|
[1112., 1680., 1680., 1128.],
|
|
[1704., 2574., 2574., 1728.],
|
|
[1704., 2574., 2574., 1728.],
|
|
[1160., 1752., 1752., 1176.],
|
|
],
|
|
],
|
|
],
|
|
],
|
|
&device,
|
|
),
|
|
weight: TestTensor::from_floats(
|
|
[
|
|
[
|
|
[
|
|
[
|
|
[4590., 6156., 4644.],
|
|
[6264., 8400., 6336.],
|
|
[4806., 6444., 4860.],
|
|
],
|
|
[
|
|
[6696., 8976., 6768.],
|
|
[9120., 12224., 9216.],
|
|
[6984., 9360., 7056.],
|
|
],
|
|
[
|
|
[5454., 7308., 5508.],
|
|
[7416., 9936., 7488.],
|
|
[5670., 7596., 5724.],
|
|
],
|
|
],
|
|
[
|
|
[
|
|
[8046., 10764., 8100.],
|
|
[10872., 14544., 10944.],
|
|
[8262., 11052., 8316.],
|
|
],
|
|
[
|
|
[11304., 15120., 11376.],
|
|
[15264., 20416., 15360.],
|
|
[11592., 15504., 11664.],
|
|
],
|
|
[
|
|
[8910., 11916., 8964.],
|
|
[12024., 16080., 12096.],
|
|
[9126., 12204., 9180.],
|
|
],
|
|
],
|
|
],
|
|
[
|
|
[
|
|
[
|
|
[4590., 6156., 4644.],
|
|
[6264., 8400., 6336.],
|
|
[4806., 6444., 4860.],
|
|
],
|
|
[
|
|
[6696., 8976., 6768.],
|
|
[9120., 12224., 9216.],
|
|
[6984., 9360., 7056.],
|
|
],
|
|
[
|
|
[5454., 7308., 5508.],
|
|
[7416., 9936., 7488.],
|
|
[5670., 7596., 5724.],
|
|
],
|
|
],
|
|
[
|
|
[
|
|
[8046., 10764., 8100.],
|
|
[10872., 14544., 10944.],
|
|
[8262., 11052., 8316.],
|
|
],
|
|
[
|
|
[11304., 15120., 11376.],
|
|
[15264., 20416., 15360.],
|
|
[11592., 15504., 11664.],
|
|
],
|
|
[
|
|
[8910., 11916., 8964.],
|
|
[12024., 16080., 12096.],
|
|
[9126., 12204., 9180.],
|
|
],
|
|
],
|
|
],
|
|
],
|
|
&device,
|
|
),
|
|
bias: TestTensor::from_floats([128., 128.], &device),
|
|
};
|
|
test.assert_grads(grads);
|
|
}
|
|
|
|
#[test]
|
|
fn test_conv3d_complex() {
|
|
let test = Conv3dTestCase {
|
|
batch_size: 1,
|
|
channels_in: 2,
|
|
channels_out: 3,
|
|
kernel_size_1: 2,
|
|
kernel_size_2: 3,
|
|
kernel_size_3: 4,
|
|
padding_1: 1,
|
|
padding_2: 2,
|
|
padding_3: 3,
|
|
stride_1: 1,
|
|
stride_2: 2,
|
|
stride_3: 3,
|
|
dilation_1: 2,
|
|
dilation_2: 3,
|
|
dilation_3: 4,
|
|
groups: 1,
|
|
depth: 5,
|
|
height: 6,
|
|
width: 7,
|
|
};
|
|
let device = Default::default();
|
|
let grads = Grads {
|
|
x: TestTensor::from_floats(
|
|
[[
|
|
[
|
|
[
|
|
[0., 147., 0., 0., 0., 150., 0.],
|
|
[0., 159., 0., 0., 0., 162., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
[0., 159., 0., 0., 0., 162., 0.],
|
|
[0., 171., 0., 0., 0., 174., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[0., 330., 0., 0., 0., 336., 0.],
|
|
[0., 354., 0., 0., 0., 360., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
[0., 354., 0., 0., 0., 360., 0.],
|
|
[0., 378., 0., 0., 0., 384., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[0., 330., 0., 0., 0., 336., 0.],
|
|
[0., 354., 0., 0., 0., 360., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
[0., 354., 0., 0., 0., 360., 0.],
|
|
[0., 378., 0., 0., 0., 384., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[0., 330., 0., 0., 0., 336., 0.],
|
|
[0., 354., 0., 0., 0., 360., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
[0., 354., 0., 0., 0., 360., 0.],
|
|
[0., 378., 0., 0., 0., 384., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[0., 183., 0., 0., 0., 186., 0.],
|
|
[0., 195., 0., 0., 0., 198., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
[0., 195., 0., 0., 0., 198., 0.],
|
|
[0., 207., 0., 0., 0., 210., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
],
|
|
],
|
|
[
|
|
[
|
|
[0., 219., 0., 0., 0., 222., 0.],
|
|
[0., 231., 0., 0., 0., 234., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
[0., 231., 0., 0., 0., 234., 0.],
|
|
[0., 243., 0., 0., 0., 246., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[0., 474., 0., 0., 0., 480., 0.],
|
|
[0., 498., 0., 0., 0., 504., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
[0., 498., 0., 0., 0., 504., 0.],
|
|
[0., 522., 0., 0., 0., 528., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[0., 474., 0., 0., 0., 480., 0.],
|
|
[0., 498., 0., 0., 0., 504., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
[0., 498., 0., 0., 0., 504., 0.],
|
|
[0., 522., 0., 0., 0., 528., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[0., 474., 0., 0., 0., 480., 0.],
|
|
[0., 498., 0., 0., 0., 504., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
[0., 498., 0., 0., 0., 504., 0.],
|
|
[0., 522., 0., 0., 0., 528., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[0., 255., 0., 0., 0., 258., 0.],
|
|
[0., 267., 0., 0., 0., 270., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
[0., 267., 0., 0., 0., 270., 0.],
|
|
[0., 279., 0., 0., 0., 282., 0.],
|
|
[0., 0., 0., 0., 0., 0., 0.],
|
|
],
|
|
],
|
|
]],
|
|
&device,
|
|
),
|
|
weight: TestTensor::from_floats(
|
|
[
|
|
[
|
|
[
|
|
[
|
|
[0., 256., 272., 0.],
|
|
[0., 624., 656., 0.],
|
|
[0., 368., 384., 0.],
|
|
],
|
|
[
|
|
[0., 424., 440., 0.],
|
|
[0., 960., 992., 0.],
|
|
[0., 536., 552., 0.],
|
|
],
|
|
],
|
|
[
|
|
[
|
|
[0., 1096., 1112., 0.],
|
|
[0., 2304., 2336., 0.],
|
|
[0., 1208., 1224., 0.],
|
|
],
|
|
[
|
|
[0., 1264., 1280., 0.],
|
|
[0., 2640., 2672., 0.],
|
|
[0., 1376., 1392., 0.],
|
|
],
|
|
],
|
|
],
|
|
[
|
|
[
|
|
[
|
|
[0., 256., 272., 0.],
|
|
[0., 624., 656., 0.],
|
|
[0., 368., 384., 0.],
|
|
],
|
|
[
|
|
[0., 424., 440., 0.],
|
|
[0., 960., 992., 0.],
|
|
[0., 536., 552., 0.],
|
|
],
|
|
],
|
|
[
|
|
[
|
|
[0., 1096., 1112., 0.],
|
|
[0., 2304., 2336., 0.],
|
|
[0., 1208., 1224., 0.],
|
|
],
|
|
[
|
|
[0., 1264., 1280., 0.],
|
|
[0., 2640., 2672., 0.],
|
|
[0., 1376., 1392., 0.],
|
|
],
|
|
],
|
|
],
|
|
[
|
|
[
|
|
[
|
|
[0., 256., 272., 0.],
|
|
[0., 624., 656., 0.],
|
|
[0., 368., 384., 0.],
|
|
],
|
|
[
|
|
[0., 424., 440., 0.],
|
|
[0., 960., 992., 0.],
|
|
[0., 536., 552., 0.],
|
|
],
|
|
],
|
|
[
|
|
[
|
|
[0., 1096., 1112., 0.],
|
|
[0., 2304., 2336., 0.],
|
|
[0., 1208., 1224., 0.],
|
|
],
|
|
[
|
|
[0., 1264., 1280., 0.],
|
|
[0., 2640., 2672., 0.],
|
|
[0., 1376., 1392., 0.],
|
|
],
|
|
],
|
|
],
|
|
],
|
|
&device,
|
|
),
|
|
bias: TestTensor::from_floats([10., 10., 10.], &device),
|
|
};
|
|
test.assert_grads(grads);
|
|
}
|
|
|
|
#[test]
|
|
fn test_conv3d_groups_stride_2_no_pad() {
|
|
let test = Conv3dTestCase {
|
|
batch_size: 1,
|
|
channels_in: 4,
|
|
channels_out: 2,
|
|
kernel_size_1: 3,
|
|
kernel_size_2: 3,
|
|
kernel_size_3: 3,
|
|
padding_1: 0,
|
|
padding_2: 0,
|
|
padding_3: 0,
|
|
stride_1: 2,
|
|
stride_2: 2,
|
|
stride_3: 2,
|
|
dilation_1: 1,
|
|
dilation_2: 1,
|
|
dilation_3: 1,
|
|
groups: 2,
|
|
depth: 4,
|
|
height: 4,
|
|
width: 4,
|
|
};
|
|
let device = Default::default();
|
|
let grads = Grads {
|
|
x: TestTensor::from_floats(
|
|
[[
|
|
[
|
|
[
|
|
[0., 1., 2., 0.],
|
|
[3., 4., 5., 0.],
|
|
[6., 7., 8., 0.],
|
|
[0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[9., 10., 11., 0.],
|
|
[12., 13., 14., 0.],
|
|
[15., 16., 17., 0.],
|
|
[0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[18., 19., 20., 0.],
|
|
[21., 22., 23., 0.],
|
|
[24., 25., 26., 0.],
|
|
[0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[0., 0., 0., 0.],
|
|
[0., 0., 0., 0.],
|
|
[0., 0., 0., 0.],
|
|
[0., 0., 0., 0.],
|
|
],
|
|
],
|
|
[
|
|
[
|
|
[27., 28., 29., 0.],
|
|
[30., 31., 32., 0.],
|
|
[33., 34., 35., 0.],
|
|
[0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[36., 37., 38., 0.],
|
|
[39., 40., 41., 0.],
|
|
[42., 43., 44., 0.],
|
|
[0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[45., 46., 47., 0.],
|
|
[48., 49., 50., 0.],
|
|
[51., 52., 53., 0.],
|
|
[0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[0., 0., 0., 0.],
|
|
[0., 0., 0., 0.],
|
|
[0., 0., 0., 0.],
|
|
[0., 0., 0., 0.],
|
|
],
|
|
],
|
|
[
|
|
[
|
|
[54., 55., 56., 0.],
|
|
[57., 58., 59., 0.],
|
|
[60., 61., 62., 0.],
|
|
[0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[63., 64., 65., 0.],
|
|
[66., 67., 68., 0.],
|
|
[69., 70., 71., 0.],
|
|
[0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[72., 73., 74., 0.],
|
|
[75., 76., 77., 0.],
|
|
[78., 79., 80., 0.],
|
|
[0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[0., 0., 0., 0.],
|
|
[0., 0., 0., 0.],
|
|
[0., 0., 0., 0.],
|
|
[0., 0., 0., 0.],
|
|
],
|
|
],
|
|
[
|
|
[
|
|
[81., 82., 83., 0.],
|
|
[84., 85., 86., 0.],
|
|
[87., 88., 89., 0.],
|
|
[0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[90., 91., 92., 0.],
|
|
[93., 94., 95., 0.],
|
|
[96., 97., 98., 0.],
|
|
[0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[99., 100., 101., 0.],
|
|
[102., 103., 104., 0.],
|
|
[105., 106., 107., 0.],
|
|
[0., 0., 0., 0.],
|
|
],
|
|
[
|
|
[0., 0., 0., 0.],
|
|
[0., 0., 0., 0.],
|
|
[0., 0., 0., 0.],
|
|
[0., 0., 0., 0.],
|
|
],
|
|
],
|
|
]],
|
|
&device,
|
|
),
|
|
weight: TestTensor::from_floats(
|
|
[
|
|
[
|
|
[
|
|
[[0., 1., 2.], [4., 5., 6.], [8., 9., 10.]],
|
|
[[16., 17., 18.], [20., 21., 22.], [24., 25., 26.]],
|
|
[[32., 33., 34.], [36., 37., 38.], [40., 41., 42.]],
|
|
],
|
|
[
|
|
[[64., 65., 66.], [68., 69., 70.], [72., 73., 74.]],
|
|
[[80., 81., 82.], [84., 85., 86.], [88., 89., 90.]],
|
|
[[96., 97., 98.], [100., 101., 102.], [104., 105., 106.]],
|
|
],
|
|
],
|
|
[
|
|
[
|
|
[[128., 129., 130.], [132., 133., 134.], [136., 137., 138.]],
|
|
[[144., 145., 146.], [148., 149., 150.], [152., 153., 154.]],
|
|
[[160., 161., 162.], [164., 165., 166.], [168., 169., 170.]],
|
|
],
|
|
[
|
|
[[192., 193., 194.], [196., 197., 198.], [200., 201., 202.]],
|
|
[[208., 209., 210.], [212., 213., 214.], [216., 217., 218.]],
|
|
[[224., 225., 226.], [228., 229., 230.], [232., 233., 234.]],
|
|
],
|
|
],
|
|
],
|
|
&device,
|
|
),
|
|
bias: TestTensor::from_floats([1., 1.], &device),
|
|
};
|
|
test.assert_grads(grads);
|
|
}
|
|
|
|
struct Conv3dTestCase {
|
|
batch_size: usize,
|
|
channels_in: usize,
|
|
channels_out: usize,
|
|
kernel_size_1: usize,
|
|
kernel_size_2: usize,
|
|
kernel_size_3: usize,
|
|
padding_1: usize,
|
|
padding_2: usize,
|
|
padding_3: usize,
|
|
stride_1: usize,
|
|
stride_2: usize,
|
|
stride_3: usize,
|
|
dilation_1: usize,
|
|
dilation_2: usize,
|
|
dilation_3: usize,
|
|
groups: usize,
|
|
depth: usize,
|
|
height: usize,
|
|
width: usize,
|
|
}
|
|
|
|
struct Grads {
|
|
x: TestTensor<5>,
|
|
weight: TestTensor<5>,
|
|
bias: TestTensor<1>,
|
|
}
|
|
|
|
impl Conv3dTestCase {
|
|
fn assert_grads(self, expected_grads: Grads) {
|
|
let shape_x = Shape::new([
|
|
self.batch_size,
|
|
self.channels_in,
|
|
self.depth,
|
|
self.height,
|
|
self.width,
|
|
]);
|
|
let shape_weight = Shape::new([
|
|
self.channels_out,
|
|
self.channels_in / self.groups,
|
|
self.kernel_size_1,
|
|
self.kernel_size_2,
|
|
self.kernel_size_3,
|
|
]);
|
|
let device = Default::default();
|
|
let weight = TestAutodiffTensor::from_data(
|
|
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
|
.reshape::<5, _>(shape_weight)
|
|
.into_data(),
|
|
&device,
|
|
)
|
|
.require_grad();
|
|
let bias = TestAutodiffTensor::from_data(
|
|
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
|
|
&device,
|
|
)
|
|
.require_grad();
|
|
let x = TestAutodiffTensor::from_data(
|
|
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
|
.reshape::<5, _>(shape_x)
|
|
.into_data(),
|
|
&device,
|
|
)
|
|
.require_grad();
|
|
let output = conv3d(
|
|
x.clone(),
|
|
weight.clone(),
|
|
Some(bias.clone()),
|
|
ConvOptions::new(
|
|
[self.stride_1, self.stride_2, self.stride_3],
|
|
[self.padding_1, self.padding_2, self.padding_3],
|
|
[self.dilation_1, self.dilation_2, self.dilation_3],
|
|
self.groups,
|
|
),
|
|
);
|
|
let grads = output.backward();
|
|
|
|
// Assert
|
|
let x_grad_actual = x.grad(&grads).unwrap();
|
|
let weight_grad_actual = weight.grad(&grads).unwrap();
|
|
let bias_grad_actual = bias.grad(&grads).unwrap();
|
|
|
|
let tolerance = Tolerance::default();
|
|
expected_grads
|
|
.bias
|
|
.to_data()
|
|
.assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);
|
|
expected_grads
|
|
.x
|
|
.to_data()
|
|
.assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);
|
|
expected_grads
|
|
.weight
|
|
.to_data()
|
|
.assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);
|
|
}
|
|
}
|