Add custom backend to enable flash attention

This commit is contained in:
Gadersd
2023-09-07 12:54:27 -04:00
parent f4c58c1790
commit 01b1aea897
7 changed files with 177 additions and 34 deletions

View File

@@ -16,9 +16,9 @@ use burn::{
},
};
use super::attention::qkv_attention;
use super::groupnorm::*;
use super::silu::*;
use crate::backend::Backend as MyBackend;
use std::iter;
@@ -51,7 +51,7 @@ pub struct Autoencoder<B: Backend> {
post_quant_conv: Conv2d<B>,
}
impl<B: Backend> Autoencoder<B> {
impl<B: MyBackend> Autoencoder<B> {
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.decode_latent(self.encode_image(x))
}
@@ -128,7 +128,7 @@ pub struct Encoder<B: Backend> {
conv_out: Conv2d<B>,
}
impl<B: Backend> Encoder<B> {
impl<B: MyBackend> Encoder<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv_in.forward(x);
@@ -200,7 +200,7 @@ pub struct Decoder<B: Backend> {
conv_out: Conv2d<B>,
}
impl<B: Backend> Decoder<B> {
impl<B: MyBackend> Decoder<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv_in.forward(x);
let x = self.mid.forward(x);
@@ -383,10 +383,6 @@ pub struct PaddedConv2d<B: Backend> {
impl<B: Backend> PaddedConv2d<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
println!(
"{} {} {:?} {:?}",
self.kernel_size, self.stride, self.padding, self.padding_actual
);
let [n_batch, n_channel, height, width] = x.dims();
let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height
@@ -444,7 +440,7 @@ pub struct Mid<B: Backend> {
block_2: ResnetBlock<B>,
}
impl<B: Backend> Mid<B> {
impl<B: MyBackend> Mid<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.block_1.forward(x);
let x = self.attn.forward(x);
@@ -550,7 +546,7 @@ pub struct ConvSelfAttentionBlock<B: Backend> {
proj_out: Conv2d<B>,
}
impl<B: Backend> ConvSelfAttentionBlock<B> {
impl<B: MyBackend> ConvSelfAttentionBlock<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let [n_batch, n_channel, height, width] = x.dims();
@@ -572,9 +568,15 @@ impl<B: Backend> ConvSelfAttentionBlock<B> {
.reshape([n_batch, n_channel, height * width])
.swap_dims(1, 2);
let wv = qkv_attention(q, k, v, None, 1)
.swap_dims(1, 2)
.reshape([n_batch, n_channel, height, width]);
let wv = Tensor::from_primitive(B::qkv_attention(
q.into_primitive(),
k.into_primitive(),
v.into_primitive(),
None,
1,
))
.swap_dims(1, 2)
.reshape([n_batch, n_channel, height, width]);
let projected = self.proj_out.forward(wv);