Add custom backend to enable flash attention

This commit is contained in:
Gadersd
2023-09-07 12:54:27 -04:00
parent f4c58c1790
commit 01b1aea897
7 changed files with 177 additions and 34 deletions

View File

@@ -11,10 +11,10 @@ use burn::{
};
cfg_if::cfg_if! {
if #[cfg(feature = "torch-backend")] {
use burn_tch::{TchBackend, TchDevice};
} else if #[cfg(feature = "wgpu-backend")] {
if #[cfg(feature = "wgpu-backend")] {
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
} else {
use burn_tch::{TchBackend, TchDevice};
}
}
@@ -34,12 +34,12 @@ fn load_stable_diffusion_model_file<B: Backend>(
fn main() {
cfg_if::cfg_if! {
if #[cfg(feature = "torch-backend")] {
type Backend = TchBackend<f32>;
let device = TchDevice::Cuda(0);
} else if #[cfg(feature = "wgpu-backend")] {
if #[cfg(feature = "wgpu-backend")] {
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
let device = WgpuDevice::BestAvailable;
} else {
type Backend = TchBackend<f32>;
let device = TchDevice::Cuda(0);
}
}