Add custom backend to enable flash attention

This commit is contained in:
Gadersd
2023-09-07 12:54:27 -04:00
parent f4c58c1790
commit 01b1aea897
7 changed files with 177 additions and 34 deletions

View File

@@ -8,6 +8,8 @@ use burn::{
use num_traits::ToPrimitive;
use crate::backend::Backend as MyBackend;
use super::autoencoder::{Autoencoder, AutoencoderConfig};
use super::clip::{CLIPConfig, CLIP};
use super::unet::{UNet, UNetConfig};
@@ -44,7 +46,7 @@ pub struct StableDiffusion<B: Backend> {
clip: CLIP<B>,
}
impl<B: Backend> StableDiffusion<B> {
impl<B: MyBackend> StableDiffusion<B> {
pub fn sample_image(
&self,
context: Tensor<B, 3>,