- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution - Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory - Added Cargo.lock to gitignore with appropriate comment - Reorganized IDE files section in gitignore for better clarity - Added newline at end of file for proper formatting
168 lines
5.2 KiB
Rust
168 lines
5.2 KiB
Rust
use burn_core as burn;
|
|
|
|
use burn::module::{Module, Quantizer};
|
|
use burn::tensor::{
|
|
Device, Distribution, Tensor, Tolerance,
|
|
ops::{FloatElem, QuantizedTensor},
|
|
quantization::{
|
|
Calibration, QTensorPrimitive, QuantLevel, QuantParam, QuantScheme, QuantValue,
|
|
},
|
|
};
|
|
use burn_nn::{
|
|
Linear, LinearConfig,
|
|
transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
|
|
};
|
|
|
|
#[cfg(all(
|
|
test,
|
|
not(feature = "test-wgpu"),
|
|
not(feature = "test-cuda"),
|
|
not(feature = "test-rocm")
|
|
))]
|
|
pub type B = burn_ndarray::NdArray<f32>;
|
|
|
|
#[cfg(all(test, feature = "test-wgpu"))]
|
|
/// Backend for test cases
|
|
pub type B = burn_wgpu::Wgpu;
|
|
|
|
#[cfg(all(test, feature = "test-cuda"))]
|
|
/// Backend for test cases
|
|
pub type B = burn_cuda::Cuda;
|
|
|
|
#[cfg(all(test, feature = "test-rocm"))]
|
|
/// Backend for test cases
|
|
pub type B = burn_rocm::Rocm;
|
|
|
|
fn should_quantize_module<M: Module<B>, const D: usize, F: Fn(&M) -> Tensor<B, D>>(
|
|
module: M,
|
|
scheme: QuantScheme,
|
|
func: F,
|
|
tolerance: Tolerance<FloatElem<B>>,
|
|
) {
|
|
let result = func(&module);
|
|
|
|
let calibration = Calibration::MinMax;
|
|
let mut quantizer = Quantizer {
|
|
calibration,
|
|
scheme,
|
|
};
|
|
let q_module = module.quantize_weights(&mut quantizer);
|
|
let q_result = func(&q_module);
|
|
|
|
result
|
|
.into_data()
|
|
.assert_approx_eq::<f32>(&q_result.into_data(), tolerance);
|
|
}
|
|
|
|
#[test]
|
|
fn should_quantize_transformer() {
|
|
let device: Device<B> = Default::default();
|
|
let transformer: TransformerEncoder<B> =
|
|
TransformerEncoderConfig::new(128, 256, 2, 2).init(&device);
|
|
let signal = Tensor::random([2, 32, 128], Distribution::Default, &device);
|
|
let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
|
|
.with_value(QuantValue::Q8S)
|
|
.with_level(QuantLevel::block([32]))
|
|
.with_param(QuantParam::F32);
|
|
|
|
should_quantize_module(
|
|
transformer,
|
|
scheme,
|
|
|tr| tr.forward(TransformerEncoderInput::new(signal.clone())),
|
|
Tolerance::rel_abs(1e-2, 2e-2), // slightly higher abs tolerance (permissive: 1e-2)
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn should_quantize_linear_128_256() {
|
|
let device: Device<B> = Default::default();
|
|
let transformer: Linear<B> = LinearConfig::new(128, 256).with_bias(false).init(&device);
|
|
let signal = Tensor::<B, 2>::random([1, 128], Distribution::Default, &device);
|
|
let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
|
|
.with_value(QuantValue::Q8S)
|
|
.with_level(QuantLevel::Tensor)
|
|
.with_param(QuantParam::F32);
|
|
|
|
should_quantize_module(
|
|
transformer,
|
|
scheme,
|
|
|tr| tr.forward(signal.clone()),
|
|
Tolerance::permissive(),
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn should_quantize_linear() {
|
|
let device: Device<B> = Default::default();
|
|
let transformer: Linear<B> = LinearConfig::new(32, 32).with_bias(false).init(&device);
|
|
let signal = Tensor::<B, 2>::random([1, 32], Distribution::Default, &device);
|
|
// Default scheme should select supported QuantStore default
|
|
// TODO: set native if dtype is supported by the test backend
|
|
let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
|
|
.with_value(QuantValue::Q8S)
|
|
.with_level(QuantLevel::Tensor)
|
|
// .with_store(QuantStore::Native)
|
|
.with_param(QuantParam::F32);
|
|
|
|
should_quantize_module(
|
|
transformer,
|
|
scheme,
|
|
|tr| tr.forward(signal.clone()),
|
|
Tolerance::permissive(),
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn should_quantize_linear_weights() {
|
|
let device: Device<B> = Default::default();
|
|
let transformer: Linear<B> = LinearConfig::new(32, 32).with_bias(false).init(&device);
|
|
let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
|
|
.with_value(QuantValue::Q8S)
|
|
.with_level(QuantLevel::Tensor)
|
|
.with_param(QuantParam::F32);
|
|
|
|
should_quantize_module(
|
|
transformer,
|
|
scheme,
|
|
|tr| tr.weight.val().dequantize(),
|
|
Tolerance::permissive(),
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn should_quantize_linear_blocks() {
|
|
let device: Device<B> = Default::default();
|
|
let transformer: Linear<B> = LinearConfig::new(32, 32).with_bias(false).init(&device);
|
|
let signal = Tensor::<B, 2>::random([1, 32], Distribution::Default, &device);
|
|
let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
|
|
.with_value(QuantValue::Q8S)
|
|
.with_level(QuantLevel::block([16]))
|
|
// .with_store(QuantStore::Native)
|
|
.with_param(QuantParam::F32);
|
|
|
|
should_quantize_module(
|
|
transformer,
|
|
scheme,
|
|
|tr| tr.forward(signal.clone()),
|
|
Tolerance::permissive(),
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn should_quantize_linear_weights_blocks() {
|
|
let device: Device<B> = Default::default();
|
|
let transformer: Linear<B> = LinearConfig::new(32, 32).with_bias(false).init(&device);
|
|
let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()
|
|
.with_value(QuantValue::Q8S)
|
|
.with_level(QuantLevel::block([16]))
|
|
// .with_store(QuantStore::Native)
|
|
.with_param(QuantParam::F32);
|
|
|
|
should_quantize_module(
|
|
transformer,
|
|
scheme,
|
|
|tr| tr.weight.val().dequantize(),
|
|
Tolerance::permissive(),
|
|
);
|
|
}
|