Files
RustyUI/crates/stable-diffusion-burn/burn-crates/burn-backend-tests/tests/autodiff/cross.rs
Ben_Kosytorz 3a67c0979c feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
2026-03-05 19:39:14 +01:00

104 lines
3.7 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use super::*;
use burn_tensor::{TensorData, Tolerance};
#[cfg(feature = "std")]
use burn_backend_tests::might_panic;
#[test]
fn backward_basic() {
let device = Default::default();
let a = TestAutodiffTensor::<2>::from_data(
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
&device,
)
.require_grad();
let b = TestAutodiffTensor::<2>::from_data(
TensorData::from([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),
&device,
)
.require_grad();
// Simple cross product; grad is a vector of ones.
let c = a.clone().cross(b.clone(), 1);
let grads = c.backward();
let a_grad = a.grad(&grads).unwrap().to_data();
let b_grad = b.grad(&grads).unwrap().to_data();
// For a: b×grad_out, where grad_out = [1,1,1]
let expected_a = TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 2.0, -1.0]]);
// For b: grad_out×a
let expected_b = TensorData::from([[1.0, -2.0, 1.0], [1.0, -2.0, 1.0]]);
a_grad.assert_approx_eq::<FloatElem>(&expected_a, Tolerance::default());
b_grad.assert_approx_eq::<FloatElem>(&expected_b, Tolerance::default());
}
#[test]
fn backward_after_sum() {
let device = Default::default();
let a = TestAutodiffTensor::<2>::from_data(
TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
&device,
)
.require_grad();
let b = TestAutodiffTensor::<2>::from_data(
TensorData::from([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),
&device,
)
.require_grad();
// Sum reduces to scalar, but the gradient should be the same.
let c = a.clone().cross(b.clone(), 1).sum();
let grads = c.backward();
let a_grad = a.grad(&grads).unwrap().to_data();
let b_grad = b.grad(&grads).unwrap().to_data();
let expected_a = TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 2.0, -1.0]]);
let expected_b = TensorData::from([[1.0, -2.0, 1.0], [1.0, -2.0, 1.0]]);
a_grad.assert_approx_eq::<FloatElem>(&expected_a, Tolerance::default());
b_grad.assert_approx_eq::<FloatElem>(&expected_b, Tolerance::default());
}
#[cfg(feature = "std")]
#[might_panic(reason = "not implemented: Cross product on non-last dimension")]
#[test]
fn different_dim() {
// Also check when the cross is along a different dimension (e.g. dim 0).
let device = Default::default();
let a_raw = [[1.0, 4.0, 7.0], [2.0, 5.0, 8.0], [3.0, 6.0, 9.0]];
let b_raw = [[9.0, 6.0, 3.0], [8.0, 5.0, 2.0], [7.0, 4.0, 1.0]];
let a = TestTensor::<2>::from_data(TensorData::from(a_raw), &device);
let b = TestTensor::<2>::from_data(TensorData::from(b_raw), &device);
// Cross along dim 0. Some backends (for example CubeCL) may not support
// cross on non-last dimensions and will intentionally panic with a
// message like "Cross product on non-last dimension not yet implemented".
// In that case we treat the panic as a skipped test for that backend.
let out = a.cross(b.clone(), 0);
// Manually compute cross of each column vector using raw arrays
let expected = [
[
a_raw[1][0] * b_raw[2][0] - a_raw[2][0] * b_raw[1][0],
a_raw[1][1] * b_raw[2][1] - a_raw[2][1] * b_raw[1][1],
a_raw[1][2] * b_raw[2][2] - a_raw[2][2] * b_raw[1][2],
],
[
a_raw[2][0] * b_raw[0][0] - a_raw[0][0] * b_raw[2][0],
a_raw[2][1] * b_raw[0][1] - a_raw[0][1] * b_raw[2][1],
a_raw[2][2] * b_raw[0][2] - a_raw[0][2] * b_raw[2][2],
],
[
a_raw[0][0] * b_raw[1][0] - a_raw[1][0] * b_raw[0][0],
a_raw[0][1] * b_raw[1][1] - a_raw[1][1] * b_raw[0][1],
a_raw[0][2] * b_raw[1][2] - a_raw[1][2] * b_raw[0][2],
],
];
out.to_data()
.assert_approx_eq::<FloatElem>(&TensorData::from(expected), Tolerance::default());
}