- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
57 lines
1.4 KiB
Rust
57 lines
1.4 KiB
Rust
use super::*;
|
|
|
|
/// Runs the shared broadcast backward check using the `*` operator.
#[test]
fn mul_broadcast() {
    test_ops_broadcast_backward(|lhs, rhs| lhs * rhs);
}
|
|
|
|
/// Runs the shared broadcast backward check using the `/` operator.
#[test]
fn div_broadcast() {
    test_ops_broadcast_backward(|lhs, rhs| lhs / rhs);
}
|
|
|
|
/// Runs the shared broadcast backward check using the `-` operator.
#[test]
fn sub_broadcast() {
    test_ops_broadcast_backward(|lhs, rhs| lhs - rhs);
}
|
|
|
|
/// Runs the shared broadcast backward check using the `+` operator.
#[test]
fn add_broadcast() {
    test_ops_broadcast_backward(|lhs, rhs| lhs + rhs);
}
|
|
|
|
/// Runs the shared broadcast backward check using `matmul`.
#[test]
fn matmul_broadcast() {
    test_ops_broadcast_backward(|lhs, rhs| lhs.matmul(rhs));
}
|
|
|
|
/// Runs the shared broadcast backward check using `mask_where`.
///
/// The mask is derived from the second operand via `equal_elem(4)`, and the
/// second operand is also passed to `mask_where` as the value tensor.
#[test]
fn mask_where_broadcast() {
    test_ops_broadcast_backward(|lhs, rhs| lhs.mask_where(rhs.clone().equal_elem(4), rhs));
}
|
|
|
|
fn test_ops_broadcast_backward<F>(func: F)
|
|
where
|
|
F: Fn(TestAutodiffTensor<3>, TestAutodiffTensor<3>) -> TestAutodiffTensor<3>,
|
|
{
|
|
let device = Default::default();
|
|
let w = TestAutodiffTensor::zeros([16, 5, 5], &device).require_grad();
|
|
let x = TestAutodiffTensor::zeros([4, 5, 5], &device).require_grad();
|
|
|
|
// Slice isn't a broadcastable operation, so it will fail when the previous backward pass
|
|
// of an operation that support broadcast doesn't support it during the backward pass.
|
|
let y = func(w.clone().slice([0..1]), x.clone());
|
|
|
|
// Will panic if broadcast isn't supported!
|
|
let grads = y.backward();
|
|
|
|
let w_grad = w.grad(&grads).unwrap();
|
|
let x_grad = x.grad(&grads).unwrap();
|
|
|
|
assert_eq!(w_grad.shape(), w.shape());
|
|
assert_eq!(x_grad.shape(), x.shape());
|
|
}
|