From ccbf0625145b8f6424d23cb1e9b0c61a230fee5b Mon Sep 17 00:00:00 2001 From: Gadersd Date: Thu, 7 Sep 2023 12:54:27 -0400 Subject: [PATCH] Add custom backend to enable flash attention --- Cargo.toml | 10 +-- src/backend.rs | 136 +++++++++++++++++++++++++++++++ src/bin/sample/main.rs | 14 ++-- src/lib.rs | 1 + src/model/autoencoder/mod.rs | 28 ++++--- src/model/clip/mod.rs | 18 ++-- src/model/stablediffusion/mod.rs | 4 +- 7 files changed, 177 insertions(+), 34 deletions(-) create mode 100644 src/backend.rs diff --git a/Cargo.toml b/Cargo.toml index 89905f2..3737c0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,15 +6,8 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["torch-backend"] -torch-backend = ["burn-tch"] wgpu-backend = ["burn-wgpu"] -[dependencies.burn-tch] -package = "burn-tch" -git = "https://github.com/burn-rs/burn.git" -optional = true - [dependencies.burn-wgpu] package = "burn-wgpu" git = "https://github.com/burn-rs/burn.git" @@ -23,6 +16,9 @@ optional = true [dependencies] burn = { git = "https://github.com/burn-rs/burn.git" } burn-ndarray = { package = "burn-ndarray", git = "https://github.com/burn-rs/burn.git" } +burn-tch = { package = "burn-tch", git = "https://github.com/burn-rs/burn.git" } +burn-autodiff = { package = "burn-autodiff", git = "https://github.com/burn-rs/burn.git" } +tch = "0.13.0" serde = {version = "1.0.171", features = ["std", "derive"]} npy = "0.4.0" num-traits = "0.2.15" diff --git a/src/backend.rs b/src/backend.rs new file mode 100644 index 0000000..c33710d --- /dev/null +++ b/src/backend.rs @@ -0,0 +1,136 @@ +use burn::tensor::{activation::softmax, Tensor}; + +pub trait Backend: burn::tensor::backend::Backend { + fn qkv_attention( + q: Self::TensorPrimitive<3>, + k: Self::TensorPrimitive<3>, + v: Self::TensorPrimitive<3>, + mask: Option>, + n_head: usize, + ) -> Self::TensorPrimitive<3> { + qkv_attention( + Tensor::::from_primitive(q), + Tensor::from_primitive(k), + Tensor::from_primitive(v), + mask.map(|m| Tensor::from_primitive(m)), + n_head, + ) + .into_primitive() + } + + fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> Self::TensorPrimitive<2> { + attn_decoder_mask::(seq_length, device).into_primitive() + } +} + +use burn::tensor::ops::TensorOps; +use burn::tensor::Float; +use burn_tch::{self, TchElement, TchTensor}; +use tch; + +impl Backend for burn_tch::TchBackend { + fn qkv_attention( + q: Self::TensorPrimitive<3>, + k: Self::TensorPrimitive<3>, + v: Self::TensorPrimitive<3>, + mask: Option>, + n_head: usize, + ) -> Self::TensorPrimitive<3> { + let q = Tensor::from_primitive(q); + let k = Tensor::from_primitive(k); + let v = Tensor::from_primitive(v); + + let [n_batch, q_ctx, n_state] = q.dims(); + let [_, k_ctx, _] = k.dims(); + let n_hstate = n_state / n_head; + + let rearrange = |t: Tensor| { + let [_, n_ctx, _] = t.dims(); + t.reshape([n_batch, n_ctx, n_head, n_hstate]) + .swap_dims(1, 2) + }; + + let q = rearrange(q).into_primitive(); + let k = rearrange(k).into_primitive(); + let v = rearrange(v).into_primitive(); + + // for some reason torch crashes when mask is None + let mask = mask.unwrap_or_else(|| { + Tensor::::zeros_device([q_ctx, k_ctx], &Self::device(&v)) + .into_primitive() + }); + + Tensor::::from_primitive(TchTensor::new( + tch::Tensor::scaled_dot_product_attention( + &q.tensor, + &k.tensor, + &v.tensor, + Some(mask.tensor), + 0.0, + false, + ), + )) + .swap_dims(1, 2) + .flatten(2, 3) + .into_primitive() + } +} + +use burn_autodiff; + +impl Backend for burn_autodiff::ADBackendDecorator {} + +use std::f32::NEG_INFINITY; + +fn qkv_attention( + q: Tensor, + k: Tensor, + v: Tensor, + mask: Option>, + n_head: usize, +) -> Tensor { + let [n_batch, n_qctx, n_state] = q.dims(); + let [_, n_ctx, _] = k.dims(); + + let scale = (n_state as f64 / n_head as f64).powf(-0.25); + let n_hstate = n_state / n_head; + + let q = q + .reshape([n_batch, n_qctx, n_head, n_hstate]) + .swap_dims(1, 2) + * scale; + let k = k + .reshape([n_batch, n_ctx, n_head, n_hstate]) + .swap_dims(1, 2) + .transpose() + * scale; + let v = v + .reshape([n_batch, n_ctx, n_head, n_hstate]) + .swap_dims(1, 2); + + let qk = q.matmul(k); + + // apply mask + let qk = if let Some(mask) = mask { + qk + mask.slice([0..n_qctx, 0..n_ctx]).unsqueeze::<4>() + } else { + qk + }; + + // normalize value weightings + let w = softmax(qk, 3); + let o = w.matmul(v).swap_dims(1, 2).flatten(2, 3); + + return o; +} + +fn attn_decoder_mask(seq_length: usize, device: &B::Device) -> Tensor { + let mut mask = Tensor::::zeros([seq_length, seq_length]); + + for i in 0..(seq_length - 1) { + let values = Tensor::::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY); + mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values); + } + + return mask.to_device(device); +} diff --git a/src/bin/sample/main.rs b/src/bin/sample/main.rs index d9db921..214dc5c 100644 --- a/src/bin/sample/main.rs +++ b/src/bin/sample/main.rs @@ -11,10 +11,10 @@ use burn::{ }; cfg_if::cfg_if! { - if #[cfg(feature = "torch-backend")] { - use burn_tch::{TchBackend, TchDevice}; - } else if #[cfg(feature = "wgpu-backend")] { + if #[cfg(feature = "wgpu-backend")] { use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi}; + } else { + use burn_tch::{TchBackend, TchDevice}; } } @@ -34,12 +34,12 @@ fn load_stable_diffusion_model_file( fn main() { cfg_if::cfg_if! { - if #[cfg(feature = "torch-backend")] { - type Backend = TchBackend; - let device = TchDevice::Cuda(0); - } else if #[cfg(feature = "wgpu-backend")] { + if #[cfg(feature = "wgpu-backend")] { type Backend = WgpuBackend; let device = WgpuDevice::BestAvailable; + } else { + type Backend = TchBackend; + let device = TchDevice::Cuda(0); } } diff --git a/src/lib.rs b/src/lib.rs index 0e2590d..88a3f7f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,2 +1,3 @@ +pub mod backend; pub mod model; pub mod tokenizer; diff --git a/src/model/autoencoder/mod.rs b/src/model/autoencoder/mod.rs index ae9e2b5..8a7a58b 100644 --- a/src/model/autoencoder/mod.rs +++ b/src/model/autoencoder/mod.rs @@ -16,9 +16,9 @@ use burn::{ }, }; -use super::attention::qkv_attention; use super::groupnorm::*; use super::silu::*; +use crate::backend::Backend as MyBackend; use std::iter; @@ -51,7 +51,7 @@ pub struct Autoencoder { post_quant_conv: Conv2d, } -impl Autoencoder { +impl Autoencoder { pub fn forward(&self, x: Tensor) -> Tensor { self.decode_latent(self.encode_image(x)) } @@ -128,7 +128,7 @@ pub struct Encoder { conv_out: Conv2d, } -impl Encoder { +impl Encoder { fn forward(&self, x: Tensor) -> Tensor { let x = self.conv_in.forward(x); @@ -200,7 +200,7 @@ pub struct Decoder { conv_out: Conv2d, } -impl Decoder { +impl Decoder { fn forward(&self, x: Tensor) -> Tensor { let x = self.conv_in.forward(x); let x = self.mid.forward(x); @@ -383,10 +383,6 @@ pub struct PaddedConv2d { impl PaddedConv2d { fn forward(&self, x: Tensor) -> Tensor { - println!( - "{} {} {:?} {:?}", - self.kernel_size, self.stride, self.padding, self.padding_actual - ); let [n_batch, n_channel, height, width] = x.dims(); let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height @@ -444,7 +440,7 @@ pub struct Mid { block_2: ResnetBlock, } -impl Mid { +impl Mid { fn forward(&self, x: Tensor) -> Tensor { let x = self.block_1.forward(x); let x = self.attn.forward(x); @@ -550,7 +546,7 @@ pub struct ConvSelfAttentionBlock { proj_out: Conv2d, } -impl ConvSelfAttentionBlock { +impl ConvSelfAttentionBlock { fn forward(&self, x: Tensor) -> Tensor { let [n_batch, n_channel, height, width] = x.dims(); @@ -572,9 +568,15 @@ impl ConvSelfAttentionBlock { .reshape([n_batch, n_channel, height * width]) .swap_dims(1, 2); - let wv = qkv_attention(q, k, v, None, 1) - .swap_dims(1, 2) - .reshape([n_batch, n_channel, height, width]); + let wv = Tensor::from_primitive(B::qkv_attention( + q.into_primitive(), + k.into_primitive(), + v.into_primitive(), + None, + 1, + )) + .swap_dims(1, 2) + .reshape([n_batch, n_channel, height, width]); let projected = self.proj_out.forward(wv); diff --git a/src/model/clip/mod.rs b/src/model/clip/mod.rs index 21c29fe..583e3ef 100644 --- a/src/model/clip/mod.rs +++ b/src/model/clip/mod.rs @@ -12,7 +12,7 @@ use burn::{ }, }; -use crate::model::attention::{attn_decoder_mask, qkv_attention}; +use crate::backend::Backend as MyBackend; #[derive(Config)] pub struct CLIPConfig { @@ -51,11 +51,11 @@ pub struct CLIP { layer_norm: nn::LayerNorm, } -impl CLIP { +impl CLIP { pub fn forward(&self, x: Tensor) -> Tensor { let [n_batch, seq_len] = x.dims(); - let mask = attn_decoder_mask(seq_len, &x.device()); + let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device())); let embedded = self.token_embedding.forward(x) + self @@ -104,7 +104,7 @@ pub struct ResidualDecoderAttentionBlock { mlp_ln: nn::LayerNorm, } -impl ResidualDecoderAttentionBlock { +impl ResidualDecoderAttentionBlock { fn forward(&self, x: Tensor, mask: Tensor) -> Tensor { let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), Some(mask)); let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x)); @@ -152,13 +152,19 @@ pub struct MultiHeadSelfAttention { out: nn::Linear, } -impl MultiHeadSelfAttention { +impl MultiHeadSelfAttention { pub fn forward(&self, x: Tensor, mask: Option>) -> Tensor { let q = self.query.forward(x.clone()); let k = self.key.forward(x.clone()); let v = self.value.forward(x); - let wv = qkv_attention(q, k, v, mask, self.n_head); + let wv = Tensor::from_primitive(B::qkv_attention( + q.into_primitive(), + k.into_primitive(), + v.into_primitive(), + mask.map(|m| m.into_primitive()), + self.n_head, + )); return self.out.forward(wv); } diff --git a/src/model/stablediffusion/mod.rs b/src/model/stablediffusion/mod.rs index 3c9ffb3..32c708d 100644 --- a/src/model/stablediffusion/mod.rs +++ b/src/model/stablediffusion/mod.rs @@ -8,6 +8,8 @@ use burn::{ use num_traits::ToPrimitive; +use crate::backend::Backend as MyBackend; + use super::autoencoder::{Autoencoder, AutoencoderConfig}; use super::clip::{CLIPConfig, CLIP}; use super::unet::{UNet, UNetConfig}; @@ -44,7 +46,7 @@ pub struct StableDiffusion { clip: CLIP, } -impl StableDiffusion { +impl StableDiffusion { pub fn sample_image( &self, context: Tensor,