- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
29 lines
717 B
Rust
29 lines
717 B
Rust
// Skipped on Metal: F64 is not supported there.
|
|
#![cfg(all(feature = "std", not(feature = "metal")))]
|
|
|
|
use super::*;
|
|
use burn_backend_tests::might_panic;
|
|
use burn_tensor::{DType, Tensor, TensorData};
|
|
|
|
#[might_panic(reason = "Unsupported precision for fusion")]
|
|
#[test]
|
|
fn cast_keeps_gradient_flow() {
|
|
let device = Default::default();
|
|
|
|
let x = Tensor::<TestAutodiffBackend, 2>::from_data(
|
|
TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
|
|
&device,
|
|
)
|
|
.require_grad();
|
|
|
|
let y = x.clone().cast(DType::F64);
|
|
let z = y.sum();
|
|
|
|
let grads = z.backward();
|
|
let grad_x = x.grad(&grads).unwrap();
|
|
|
|
grad_x
|
|
.to_data()
|
|
.assert_eq(&TensorData::from([[1., 1.], [1., 1.]]), false);
|
|
}
|