mirror of
https://gitea.hainer-ernst.de/rasmus/burn-stablediffusion-vibecode.git
synced 2026-06-11 02:09:21 +00:00
Add custom backend to enable flash attention
This commit is contained in:
10
Cargo.toml
10
Cargo.toml
@@ -6,15 +6,8 @@ edition = "2021"
|
|||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["torch-backend"]
|
|
||||||
torch-backend = ["burn-tch"]
|
|
||||||
wgpu-backend = ["burn-wgpu"]
|
wgpu-backend = ["burn-wgpu"]
|
||||||
|
|
||||||
[dependencies.burn-tch]
|
|
||||||
package = "burn-tch"
|
|
||||||
git = "https://github.com/burn-rs/burn.git"
|
|
||||||
optional = true
|
|
||||||
|
|
||||||
[dependencies.burn-wgpu]
|
[dependencies.burn-wgpu]
|
||||||
package = "burn-wgpu"
|
package = "burn-wgpu"
|
||||||
git = "https://github.com/burn-rs/burn.git"
|
git = "https://github.com/burn-rs/burn.git"
|
||||||
@@ -23,6 +16,9 @@ optional = true
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
burn = { git = "https://github.com/burn-rs/burn.git" }
|
burn = { git = "https://github.com/burn-rs/burn.git" }
|
||||||
burn-ndarray = { package = "burn-ndarray", git = "https://github.com/burn-rs/burn.git" }
|
burn-ndarray = { package = "burn-ndarray", git = "https://github.com/burn-rs/burn.git" }
|
||||||
|
burn-tch = { package = "burn-tch", git = "https://github.com/burn-rs/burn.git" }
|
||||||
|
burn-autodiff = { package = "burn-autodiff", git = "https://github.com/burn-rs/burn.git" }
|
||||||
|
tch = "0.13.0"
|
||||||
serde = {version = "1.0.171", features = ["std", "derive"]}
|
serde = {version = "1.0.171", features = ["std", "derive"]}
|
||||||
npy = "0.4.0"
|
npy = "0.4.0"
|
||||||
num-traits = "0.2.15"
|
num-traits = "0.2.15"
|
||||||
|
|||||||
136
src/backend.rs
Normal file
136
src/backend.rs
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
use burn::tensor::{activation::softmax, Tensor};
|
||||||
|
|
||||||
|
pub trait Backend: burn::tensor::backend::Backend {
|
||||||
|
fn qkv_attention(
|
||||||
|
q: Self::TensorPrimitive<3>,
|
||||||
|
k: Self::TensorPrimitive<3>,
|
||||||
|
v: Self::TensorPrimitive<3>,
|
||||||
|
mask: Option<Self::TensorPrimitive<2>>,
|
||||||
|
n_head: usize,
|
||||||
|
) -> Self::TensorPrimitive<3> {
|
||||||
|
qkv_attention(
|
||||||
|
Tensor::<Self, 3>::from_primitive(q),
|
||||||
|
Tensor::from_primitive(k),
|
||||||
|
Tensor::from_primitive(v),
|
||||||
|
mask.map(|m| Tensor::from_primitive(m)),
|
||||||
|
n_head,
|
||||||
|
)
|
||||||
|
.into_primitive()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> Self::TensorPrimitive<2> {
|
||||||
|
attn_decoder_mask::<Self>(seq_length, device).into_primitive()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
use burn::tensor::ops::TensorOps;
|
||||||
|
use burn::tensor::Float;
|
||||||
|
use burn_tch::{self, TchElement, TchTensor};
|
||||||
|
use tch;
|
||||||
|
|
||||||
|
impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
|
||||||
|
fn qkv_attention(
|
||||||
|
q: Self::TensorPrimitive<3>,
|
||||||
|
k: Self::TensorPrimitive<3>,
|
||||||
|
v: Self::TensorPrimitive<3>,
|
||||||
|
mask: Option<Self::TensorPrimitive<2>>,
|
||||||
|
n_head: usize,
|
||||||
|
) -> Self::TensorPrimitive<3> {
|
||||||
|
let q = Tensor::from_primitive(q);
|
||||||
|
let k = Tensor::from_primitive(k);
|
||||||
|
let v = Tensor::from_primitive(v);
|
||||||
|
|
||||||
|
let [n_batch, q_ctx, n_state] = q.dims();
|
||||||
|
let [_, k_ctx, _] = k.dims();
|
||||||
|
let n_hstate = n_state / n_head;
|
||||||
|
|
||||||
|
let rearrange = |t: Tensor<Self, 3>| {
|
||||||
|
let [_, n_ctx, _] = t.dims();
|
||||||
|
t.reshape([n_batch, n_ctx, n_head, n_hstate])
|
||||||
|
.swap_dims(1, 2)
|
||||||
|
};
|
||||||
|
|
||||||
|
let q = rearrange(q).into_primitive();
|
||||||
|
let k = rearrange(k).into_primitive();
|
||||||
|
let v = rearrange(v).into_primitive();
|
||||||
|
|
||||||
|
// for some reason torch crashes when mask is None
|
||||||
|
let mask = mask.unwrap_or_else(|| {
|
||||||
|
Tensor::<Self, 2, Float>::zeros_device([q_ctx, k_ctx], &Self::device(&v))
|
||||||
|
.into_primitive()
|
||||||
|
});
|
||||||
|
|
||||||
|
Tensor::<Self, 4>::from_primitive(TchTensor::new(
|
||||||
|
tch::Tensor::scaled_dot_product_attention(
|
||||||
|
&q.tensor,
|
||||||
|
&k.tensor,
|
||||||
|
&v.tensor,
|
||||||
|
Some(mask.tensor),
|
||||||
|
0.0,
|
||||||
|
false,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
.swap_dims(1, 2)
|
||||||
|
.flatten(2, 3)
|
||||||
|
.into_primitive()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
use burn_autodiff;
|
||||||
|
|
||||||
|
impl<B: Backend> Backend for burn_autodiff::ADBackendDecorator<B> {}
|
||||||
|
|
||||||
|
use std::f32::NEG_INFINITY;
|
||||||
|
|
||||||
|
fn qkv_attention<B: Backend>(
|
||||||
|
q: Tensor<B, 3>,
|
||||||
|
k: Tensor<B, 3>,
|
||||||
|
v: Tensor<B, 3>,
|
||||||
|
mask: Option<Tensor<B, 2>>,
|
||||||
|
n_head: usize,
|
||||||
|
) -> Tensor<B, 3> {
|
||||||
|
let [n_batch, n_qctx, n_state] = q.dims();
|
||||||
|
let [_, n_ctx, _] = k.dims();
|
||||||
|
|
||||||
|
let scale = (n_state as f64 / n_head as f64).powf(-0.25);
|
||||||
|
let n_hstate = n_state / n_head;
|
||||||
|
|
||||||
|
let q = q
|
||||||
|
.reshape([n_batch, n_qctx, n_head, n_hstate])
|
||||||
|
.swap_dims(1, 2)
|
||||||
|
* scale;
|
||||||
|
let k = k
|
||||||
|
.reshape([n_batch, n_ctx, n_head, n_hstate])
|
||||||
|
.swap_dims(1, 2)
|
||||||
|
.transpose()
|
||||||
|
* scale;
|
||||||
|
let v = v
|
||||||
|
.reshape([n_batch, n_ctx, n_head, n_hstate])
|
||||||
|
.swap_dims(1, 2);
|
||||||
|
|
||||||
|
let qk = q.matmul(k);
|
||||||
|
|
||||||
|
// apply mask
|
||||||
|
let qk = if let Some(mask) = mask {
|
||||||
|
qk + mask.slice([0..n_qctx, 0..n_ctx]).unsqueeze::<4>()
|
||||||
|
} else {
|
||||||
|
qk
|
||||||
|
};
|
||||||
|
|
||||||
|
// normalize value weightings
|
||||||
|
let w = softmax(qk, 3);
|
||||||
|
let o = w.matmul(v).swap_dims(1, 2).flatten(2, 3);
|
||||||
|
|
||||||
|
return o;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
|
||||||
|
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]);
|
||||||
|
|
||||||
|
for i in 0..(seq_length - 1) {
|
||||||
|
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY);
|
||||||
|
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
|
||||||
|
}
|
||||||
|
|
||||||
|
return mask.to_device(device);
|
||||||
|
}
|
||||||
@@ -11,10 +11,10 @@ use burn::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
cfg_if::cfg_if! {
|
cfg_if::cfg_if! {
|
||||||
if #[cfg(feature = "torch-backend")] {
|
if #[cfg(feature = "wgpu-backend")] {
|
||||||
use burn_tch::{TchBackend, TchDevice};
|
|
||||||
} else if #[cfg(feature = "wgpu-backend")] {
|
|
||||||
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
|
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
|
||||||
|
} else {
|
||||||
|
use burn_tch::{TchBackend, TchDevice};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -34,12 +34,12 @@ fn load_stable_diffusion_model_file<B: Backend>(
|
|||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
cfg_if::cfg_if! {
|
cfg_if::cfg_if! {
|
||||||
if #[cfg(feature = "torch-backend")] {
|
if #[cfg(feature = "wgpu-backend")] {
|
||||||
type Backend = TchBackend<f32>;
|
|
||||||
let device = TchDevice::Cuda(0);
|
|
||||||
} else if #[cfg(feature = "wgpu-backend")] {
|
|
||||||
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
|
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
|
||||||
let device = WgpuDevice::BestAvailable;
|
let device = WgpuDevice::BestAvailable;
|
||||||
|
} else {
|
||||||
|
type Backend = TchBackend<f32>;
|
||||||
|
let device = TchDevice::Cuda(0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,2 +1,3 @@
|
|||||||
|
pub mod backend;
|
||||||
pub mod model;
|
pub mod model;
|
||||||
pub mod tokenizer;
|
pub mod tokenizer;
|
||||||
|
|||||||
@@ -16,9 +16,9 @@ use burn::{
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::attention::qkv_attention;
|
|
||||||
use super::groupnorm::*;
|
use super::groupnorm::*;
|
||||||
use super::silu::*;
|
use super::silu::*;
|
||||||
|
use crate::backend::Backend as MyBackend;
|
||||||
|
|
||||||
use std::iter;
|
use std::iter;
|
||||||
|
|
||||||
@@ -51,7 +51,7 @@ pub struct Autoencoder<B: Backend> {
|
|||||||
post_quant_conv: Conv2d<B>,
|
post_quant_conv: Conv2d<B>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> Autoencoder<B> {
|
impl<B: MyBackend> Autoencoder<B> {
|
||||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||||
self.decode_latent(self.encode_image(x))
|
self.decode_latent(self.encode_image(x))
|
||||||
}
|
}
|
||||||
@@ -128,7 +128,7 @@ pub struct Encoder<B: Backend> {
|
|||||||
conv_out: Conv2d<B>,
|
conv_out: Conv2d<B>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> Encoder<B> {
|
impl<B: MyBackend> Encoder<B> {
|
||||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||||
let x = self.conv_in.forward(x);
|
let x = self.conv_in.forward(x);
|
||||||
|
|
||||||
@@ -200,7 +200,7 @@ pub struct Decoder<B: Backend> {
|
|||||||
conv_out: Conv2d<B>,
|
conv_out: Conv2d<B>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> Decoder<B> {
|
impl<B: MyBackend> Decoder<B> {
|
||||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||||
let x = self.conv_in.forward(x);
|
let x = self.conv_in.forward(x);
|
||||||
let x = self.mid.forward(x);
|
let x = self.mid.forward(x);
|
||||||
@@ -383,10 +383,6 @@ pub struct PaddedConv2d<B: Backend> {
|
|||||||
|
|
||||||
impl<B: Backend> PaddedConv2d<B> {
|
impl<B: Backend> PaddedConv2d<B> {
|
||||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||||
println!(
|
|
||||||
"{} {} {:?} {:?}",
|
|
||||||
self.kernel_size, self.stride, self.padding, self.padding_actual
|
|
||||||
);
|
|
||||||
let [n_batch, n_channel, height, width] = x.dims();
|
let [n_batch, n_channel, height, width] = x.dims();
|
||||||
|
|
||||||
let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height
|
let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height
|
||||||
@@ -444,7 +440,7 @@ pub struct Mid<B: Backend> {
|
|||||||
block_2: ResnetBlock<B>,
|
block_2: ResnetBlock<B>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> Mid<B> {
|
impl<B: MyBackend> Mid<B> {
|
||||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||||
let x = self.block_1.forward(x);
|
let x = self.block_1.forward(x);
|
||||||
let x = self.attn.forward(x);
|
let x = self.attn.forward(x);
|
||||||
@@ -550,7 +546,7 @@ pub struct ConvSelfAttentionBlock<B: Backend> {
|
|||||||
proj_out: Conv2d<B>,
|
proj_out: Conv2d<B>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> ConvSelfAttentionBlock<B> {
|
impl<B: MyBackend> ConvSelfAttentionBlock<B> {
|
||||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||||
let [n_batch, n_channel, height, width] = x.dims();
|
let [n_batch, n_channel, height, width] = x.dims();
|
||||||
|
|
||||||
@@ -572,7 +568,13 @@ impl<B: Backend> ConvSelfAttentionBlock<B> {
|
|||||||
.reshape([n_batch, n_channel, height * width])
|
.reshape([n_batch, n_channel, height * width])
|
||||||
.swap_dims(1, 2);
|
.swap_dims(1, 2);
|
||||||
|
|
||||||
let wv = qkv_attention(q, k, v, None, 1)
|
let wv = Tensor::from_primitive(B::qkv_attention(
|
||||||
|
q.into_primitive(),
|
||||||
|
k.into_primitive(),
|
||||||
|
v.into_primitive(),
|
||||||
|
None,
|
||||||
|
1,
|
||||||
|
))
|
||||||
.swap_dims(1, 2)
|
.swap_dims(1, 2)
|
||||||
.reshape([n_batch, n_channel, height, width]);
|
.reshape([n_batch, n_channel, height, width]);
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ use burn::{
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::model::attention::{attn_decoder_mask, qkv_attention};
|
use crate::backend::Backend as MyBackend;
|
||||||
|
|
||||||
#[derive(Config)]
|
#[derive(Config)]
|
||||||
pub struct CLIPConfig {
|
pub struct CLIPConfig {
|
||||||
@@ -51,11 +51,11 @@ pub struct CLIP<B: Backend> {
|
|||||||
layer_norm: nn::LayerNorm<B>,
|
layer_norm: nn::LayerNorm<B>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> CLIP<B> {
|
impl<B: MyBackend> CLIP<B> {
|
||||||
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
|
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
|
||||||
let [n_batch, seq_len] = x.dims();
|
let [n_batch, seq_len] = x.dims();
|
||||||
|
|
||||||
let mask = attn_decoder_mask(seq_len, &x.device());
|
let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device()));
|
||||||
|
|
||||||
let embedded = self.token_embedding.forward(x)
|
let embedded = self.token_embedding.forward(x)
|
||||||
+ self
|
+ self
|
||||||
@@ -104,7 +104,7 @@ pub struct ResidualDecoderAttentionBlock<B: Backend> {
|
|||||||
mlp_ln: nn::LayerNorm<B>,
|
mlp_ln: nn::LayerNorm<B>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> ResidualDecoderAttentionBlock<B> {
|
impl<B: MyBackend> ResidualDecoderAttentionBlock<B> {
|
||||||
fn forward(&self, x: Tensor<B, 3>, mask: Tensor<B, 2>) -> Tensor<B, 3> {
|
fn forward(&self, x: Tensor<B, 3>, mask: Tensor<B, 2>) -> Tensor<B, 3> {
|
||||||
let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), Some(mask));
|
let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), Some(mask));
|
||||||
let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x));
|
let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x));
|
||||||
@@ -152,13 +152,19 @@ pub struct MultiHeadSelfAttention<B: Backend> {
|
|||||||
out: nn::Linear<B>,
|
out: nn::Linear<B>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> MultiHeadSelfAttention<B> {
|
impl<B: MyBackend> MultiHeadSelfAttention<B> {
|
||||||
pub fn forward(&self, x: Tensor<B, 3>, mask: Option<Tensor<B, 2>>) -> Tensor<B, 3> {
|
pub fn forward(&self, x: Tensor<B, 3>, mask: Option<Tensor<B, 2>>) -> Tensor<B, 3> {
|
||||||
let q = self.query.forward(x.clone());
|
let q = self.query.forward(x.clone());
|
||||||
let k = self.key.forward(x.clone());
|
let k = self.key.forward(x.clone());
|
||||||
let v = self.value.forward(x);
|
let v = self.value.forward(x);
|
||||||
|
|
||||||
let wv = qkv_attention(q, k, v, mask, self.n_head);
|
let wv = Tensor::from_primitive(B::qkv_attention(
|
||||||
|
q.into_primitive(),
|
||||||
|
k.into_primitive(),
|
||||||
|
v.into_primitive(),
|
||||||
|
mask.map(|m| m.into_primitive()),
|
||||||
|
self.n_head,
|
||||||
|
));
|
||||||
|
|
||||||
return self.out.forward(wv);
|
return self.out.forward(wv);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ use burn::{
|
|||||||
|
|
||||||
use num_traits::ToPrimitive;
|
use num_traits::ToPrimitive;
|
||||||
|
|
||||||
|
use crate::backend::Backend as MyBackend;
|
||||||
|
|
||||||
use super::autoencoder::{Autoencoder, AutoencoderConfig};
|
use super::autoencoder::{Autoencoder, AutoencoderConfig};
|
||||||
use super::clip::{CLIPConfig, CLIP};
|
use super::clip::{CLIPConfig, CLIP};
|
||||||
use super::unet::{UNet, UNetConfig};
|
use super::unet::{UNet, UNetConfig};
|
||||||
@@ -44,7 +46,7 @@ pub struct StableDiffusion<B: Backend> {
|
|||||||
clip: CLIP<B>,
|
clip: CLIP<B>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> StableDiffusion<B> {
|
impl<B: MyBackend> StableDiffusion<B> {
|
||||||
pub fn sample_image(
|
pub fn sample_image(
|
||||||
&self,
|
&self,
|
||||||
context: Tensor<B, 3>,
|
context: Tensor<B, 3>,
|
||||||
|
|||||||
Reference in New Issue
Block a user