use super::*; use burn_tensor::{Shape, Tolerance, module::conv3d, ops::ConvOptions}; #[test] fn test_conv3d_basic() { let test = Conv3dTestCase { batch_size: 2, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, kernel_size_3: 3, padding_1: 1, padding_2: 1, padding_3: 1, stride_1: 1, stride_2: 1, stride_3: 1, dilation_1: 1, dilation_2: 1, dilation_3: 1, groups: 1, depth: 4, height: 4, width: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [ [ [ [536., 816., 816., 552.], [840., 1278., 1278., 864.], [840., 1278., 1278., 864.], [584., 888., 888., 600.], ], [ [912., 1386., 1386., 936.], [1422., 2160., 2160., 1458.], [1422., 2160., 2160., 1458.], [984., 1494., 1494., 1008.], ], [ [912., 1386., 1386., 936.], [1422., 2160., 2160., 1458.], [1422., 2160., 2160., 1458.], [984., 1494., 1494., 1008.], ], [ [680., 1032., 1032., 696.], [1056., 1602., 1602., 1080.], [1056., 1602., 1602., 1080.], [728., 1104., 1104., 744.], ], ], [ [ [968., 1464., 1464., 984.], [1488., 2250., 2250., 1512.], [1488., 2250., 2250., 1512.], [1016., 1536., 1536., 1032.], ], [ [1560., 2358., 2358., 1584.], [2394., 3618., 3618., 2430.], [2394., 3618., 3618., 2430.], [1632., 2466., 2466., 1656.], ], [ [1560., 2358., 2358., 1584.], [2394., 3618., 3618., 2430.], [2394., 3618., 3618., 2430.], [1632., 2466., 2466., 1656.], ], [ [1112., 1680., 1680., 1128.], [1704., 2574., 2574., 1728.], [1704., 2574., 2574., 1728.], [1160., 1752., 1752., 1176.], ], ], ], [ [ [ [536., 816., 816., 552.], [840., 1278., 1278., 864.], [840., 1278., 1278., 864.], [584., 888., 888., 600.], ], [ [912., 1386., 1386., 936.], [1422., 2160., 2160., 1458.], [1422., 2160., 2160., 1458.], [984., 1494., 1494., 1008.], ], [ [912., 1386., 1386., 936.], [1422., 2160., 2160., 1458.], [1422., 2160., 2160., 1458.], [984., 1494., 1494., 1008.], ], [ [680., 1032., 1032., 696.], [1056., 1602., 1602., 1080.], [1056., 1602., 1602., 1080.], [728., 1104., 1104., 744.], ], ], [ [ [968., 1464., 1464., 984.], [1488., 2250., 2250., 1512.], [1488., 2250., 2250., 1512.], [1016., 1536., 1536., 1032.], ], [ [1560., 2358., 2358., 1584.], [2394., 3618., 3618., 2430.], [2394., 3618., 3618., 2430.], [1632., 2466., 2466., 1656.], ], [ [1560., 2358., 2358., 1584.], [2394., 3618., 3618., 2430.], [2394., 3618., 3618., 2430.], [1632., 2466., 2466., 1656.], ], [ [1112., 1680., 1680., 1128.], [1704., 2574., 2574., 1728.], [1704., 2574., 2574., 1728.], [1160., 1752., 1752., 1176.], ], ], ], ], &device, ), weight: TestTensor::from_floats( [ [ [ [ [4590., 6156., 4644.], [6264., 8400., 6336.], [4806., 6444., 4860.], ], [ [6696., 8976., 6768.], [9120., 12224., 9216.], [6984., 9360., 7056.], ], [ [5454., 7308., 5508.], [7416., 9936., 7488.], [5670., 7596., 5724.], ], ], [ [ [8046., 10764., 8100.], [10872., 14544., 10944.], [8262., 11052., 8316.], ], [ [11304., 15120., 11376.], [15264., 20416., 15360.], [11592., 15504., 11664.], ], [ [8910., 11916., 8964.], [12024., 16080., 12096.], [9126., 12204., 9180.], ], ], ], [ [ [ [4590., 6156., 4644.], [6264., 8400., 6336.], [4806., 6444., 4860.], ], [ [6696., 8976., 6768.], [9120., 12224., 9216.], [6984., 9360., 7056.], ], [ [5454., 7308., 5508.], [7416., 9936., 7488.], [5670., 7596., 5724.], ], ], [ [ [8046., 10764., 8100.], [10872., 14544., 10944.], [8262., 11052., 8316.], ], [ [11304., 15120., 11376.], [15264., 20416., 15360.], [11592., 15504., 11664.], ], [ [8910., 11916., 8964.], [12024., 16080., 12096.], [9126., 12204., 9180.], ], ], ], ], &device, ), bias: TestTensor::from_floats([128., 128.], &device), }; test.assert_grads(grads); } #[test] fn test_conv3d_complex() { let test = Conv3dTestCase { batch_size: 1, channels_in: 2, channels_out: 3, kernel_size_1: 2, kernel_size_2: 3, kernel_size_3: 4, padding_1: 1, padding_2: 2, padding_3: 3, stride_1: 1, stride_2: 2, stride_3: 3, dilation_1: 2, dilation_2: 3, dilation_3: 4, groups: 1, depth: 5, height: 6, width: 7, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [ [0., 147., 0., 0., 0., 150., 0.], [0., 159., 0., 0., 0., 162., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 159., 0., 0., 0., 162., 0.], [0., 171., 0., 0., 0., 174., 0.], [0., 0., 0., 0., 0., 0., 0.], ], [ [0., 330., 0., 0., 0., 336., 0.], [0., 354., 0., 0., 0., 360., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 354., 0., 0., 0., 360., 0.], [0., 378., 0., 0., 0., 384., 0.], [0., 0., 0., 0., 0., 0., 0.], ], [ [0., 330., 0., 0., 0., 336., 0.], [0., 354., 0., 0., 0., 360., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 354., 0., 0., 0., 360., 0.], [0., 378., 0., 0., 0., 384., 0.], [0., 0., 0., 0., 0., 0., 0.], ], [ [0., 330., 0., 0., 0., 336., 0.], [0., 354., 0., 0., 0., 360., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 354., 0., 0., 0., 360., 0.], [0., 378., 0., 0., 0., 384., 0.], [0., 0., 0., 0., 0., 0., 0.], ], [ [0., 183., 0., 0., 0., 186., 0.], [0., 195., 0., 0., 0., 198., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 195., 0., 0., 0., 198., 0.], [0., 207., 0., 0., 0., 210., 0.], [0., 0., 0., 0., 0., 0., 0.], ], ], [ [ [0., 219., 0., 0., 0., 222., 0.], [0., 231., 0., 0., 0., 234., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 231., 0., 0., 0., 234., 0.], [0., 243., 0., 0., 0., 246., 0.], [0., 0., 0., 0., 0., 0., 0.], ], [ [0., 474., 0., 0., 0., 480., 0.], [0., 498., 0., 0., 0., 504., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 498., 0., 0., 0., 504., 0.], [0., 522., 0., 0., 0., 528., 0.], [0., 0., 0., 0., 0., 0., 0.], ], [ [0., 474., 0., 0., 0., 480., 0.], [0., 498., 0., 0., 0., 504., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 498., 0., 0., 0., 504., 0.], [0., 522., 0., 0., 0., 528., 0.], [0., 0., 0., 0., 0., 0., 0.], ], [ [0., 474., 0., 0., 0., 480., 0.], [0., 498., 0., 0., 0., 504., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 498., 0., 0., 0., 504., 0.], [0., 522., 0., 0., 0., 528., 0.], [0., 0., 0., 0., 0., 0., 0.], ], [ [0., 255., 0., 0., 0., 258., 0.], [0., 267., 0., 0., 0., 270., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 267., 0., 0., 0., 270., 0.], [0., 279., 0., 0., 0., 282., 0.], [0., 0., 0., 0., 0., 0., 0.], ], ], ]], &device, ), weight: TestTensor::from_floats( [ [ [ [ [0., 256., 272., 0.], [0., 624., 656., 0.], [0., 368., 384., 0.], ], [ [0., 424., 440., 0.], [0., 960., 992., 0.], [0., 536., 552., 0.], ], ], [ [ [0., 1096., 1112., 0.], [0., 2304., 2336., 0.], [0., 1208., 1224., 0.], ], [ [0., 1264., 1280., 0.], [0., 2640., 2672., 0.], [0., 1376., 1392., 0.], ], ], ], [ [ [ [0., 256., 272., 0.], [0., 624., 656., 0.], [0., 368., 384., 0.], ], [ [0., 424., 440., 0.], [0., 960., 992., 0.], [0., 536., 552., 0.], ], ], [ [ [0., 1096., 1112., 0.], [0., 2304., 2336., 0.], [0., 1208., 1224., 0.], ], [ [0., 1264., 1280., 0.], [0., 2640., 2672., 0.], [0., 1376., 1392., 0.], ], ], ], [ [ [ [0., 256., 272., 0.], [0., 624., 656., 0.], [0., 368., 384., 0.], ], [ [0., 424., 440., 0.], [0., 960., 992., 0.], [0., 536., 552., 0.], ], ], [ [ [0., 1096., 1112., 0.], [0., 2304., 2336., 0.], [0., 1208., 1224., 0.], ], [ [0., 1264., 1280., 0.], [0., 2640., 2672., 0.], [0., 1376., 1392., 0.], ], ], ], ], &device, ), bias: TestTensor::from_floats([10., 10., 10.], &device), }; test.assert_grads(grads); } #[test] fn test_conv3d_groups_stride_2_no_pad() { let test = Conv3dTestCase { batch_size: 1, channels_in: 4, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, kernel_size_3: 3, padding_1: 0, padding_2: 0, padding_3: 0, stride_1: 2, stride_2: 2, stride_3: 2, dilation_1: 1, dilation_2: 1, dilation_3: 1, groups: 2, depth: 4, height: 4, width: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [ [0., 1., 2., 0.], [3., 4., 5., 0.], [6., 7., 8., 0.], [0., 0., 0., 0.], ], [ [9., 10., 11., 0.], [12., 13., 14., 0.], [15., 16., 17., 0.], [0., 0., 0., 0.], ], [ [18., 19., 20., 0.], [21., 22., 23., 0.], [24., 25., 26., 0.], [0., 0., 0., 0.], ], [ [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], ], ], [ [ [27., 28., 29., 0.], [30., 31., 32., 0.], [33., 34., 35., 0.], [0., 0., 0., 0.], ], [ [36., 37., 38., 0.], [39., 40., 41., 0.], [42., 43., 44., 0.], [0., 0., 0., 0.], ], [ [45., 46., 47., 0.], [48., 49., 50., 0.], [51., 52., 53., 0.], [0., 0., 0., 0.], ], [ [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], ], ], [ [ [54., 55., 56., 0.], [57., 58., 59., 0.], [60., 61., 62., 0.], [0., 0., 0., 0.], ], [ [63., 64., 65., 0.], [66., 67., 68., 0.], [69., 70., 71., 0.], [0., 0., 0., 0.], ], [ [72., 73., 74., 0.], [75., 76., 77., 0.], [78., 79., 80., 0.], [0., 0., 0., 0.], ], [ [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], ], ], [ [ [81., 82., 83., 0.], [84., 85., 86., 0.], [87., 88., 89., 0.], [0., 0., 0., 0.], ], [ [90., 91., 92., 0.], [93., 94., 95., 0.], [96., 97., 98., 0.], [0., 0., 0., 0.], ], [ [99., 100., 101., 0.], [102., 103., 104., 0.], [105., 106., 107., 0.], [0., 0., 0., 0.], ], [ [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], ], ], ]], &device, ), weight: TestTensor::from_floats( [ [ [ [[0., 1., 2.], [4., 5., 6.], [8., 9., 10.]], [[16., 17., 18.], [20., 21., 22.], [24., 25., 26.]], [[32., 33., 34.], [36., 37., 38.], [40., 41., 42.]], ], [ [[64., 65., 66.], [68., 69., 70.], [72., 73., 74.]], [[80., 81., 82.], [84., 85., 86.], [88., 89., 90.]], [[96., 97., 98.], [100., 101., 102.], [104., 105., 106.]], ], ], [ [ [[128., 129., 130.], [132., 133., 134.], [136., 137., 138.]], [[144., 145., 146.], [148., 149., 150.], [152., 153., 154.]], [[160., 161., 162.], [164., 165., 166.], [168., 169., 170.]], ], [ [[192., 193., 194.], [196., 197., 198.], [200., 201., 202.]], [[208., 209., 210.], [212., 213., 214.], [216., 217., 218.]], [[224., 225., 226.], [228., 229., 230.], [232., 233., 234.]], ], ], ], &device, ), bias: TestTensor::from_floats([1., 1.], &device), }; test.assert_grads(grads); } struct Conv3dTestCase { batch_size: usize, channels_in: usize, channels_out: usize, kernel_size_1: usize, kernel_size_2: usize, kernel_size_3: usize, padding_1: usize, padding_2: usize, padding_3: usize, stride_1: usize, stride_2: usize, stride_3: usize, dilation_1: usize, dilation_2: usize, dilation_3: usize, groups: usize, depth: usize, height: usize, width: usize, } struct Grads { x: TestTensor<5>, weight: TestTensor<5>, bias: TestTensor<1>, } impl Conv3dTestCase { fn assert_grads(self, expected_grads: Grads) { let shape_x = Shape::new([ self.batch_size, self.channels_in, self.depth, self.height, self.width, ]); let shape_weight = Shape::new([ self.channels_out, self.channels_in / self.groups, self.kernel_size_1, self.kernel_size_2, self.kernel_size_3, ]); let device = Default::default(); let weight = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device) .reshape::<5, _>(shape_weight) .into_data(), &device, ) .require_grad(); let bias = TestAutodiffTensor::from_data( TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(), &device, ) .require_grad(); let x = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<5, _>(shape_x) .into_data(), &device, ) .require_grad(); let output = conv3d( x.clone(), weight.clone(), Some(bias.clone()), ConvOptions::new( [self.stride_1, self.stride_2, self.stride_3], [self.padding_1, self.padding_2, self.padding_3], [self.dilation_1, self.dilation_2, self.dilation_3], self.groups, ), ); let grads = output.backward(); // Assert let x_grad_actual = x.grad(&grads).unwrap(); let weight_grad_actual = weight.grad(&grads).unwrap(); let bias_grad_actual = bias.grad(&grads).unwrap(); let tolerance = Tolerance::default(); expected_grads .bias .to_data() .assert_approx_eq::(&bias_grad_actual.to_data(), tolerance); expected_grads .x .to_data() .assert_approx_eq::(&x_grad_actual.to_data(), tolerance); expected_grads .weight .to_data() .assert_approx_eq::(&weight_grad_actual.to_data(), tolerance); } }