Add custom backend to enable flash attention

This commit is contained in:
Gadersd
2023-09-07 12:54:27 -04:00
committed by Ben_Kosytorz
parent 32a3ad9b3c
commit ccbf062514
7 changed files with 177 additions and 34 deletions

View File

@@ -12,7 +12,7 @@ use burn::{
},
};
use crate::model::attention::{attn_decoder_mask, qkv_attention};
use crate::backend::Backend as MyBackend;
#[derive(Config)]
pub struct CLIPConfig {
@@ -51,11 +51,11 @@ pub struct CLIP<B: Backend> {
layer_norm: nn::LayerNorm<B>,
}
impl<B: Backend> CLIP<B> {
impl<B: MyBackend> CLIP<B> {
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
let [n_batch, seq_len] = x.dims();
let mask = attn_decoder_mask(seq_len, &x.device());
let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device()));
let embedded = self.token_embedding.forward(x)
+ self
@@ -104,7 +104,7 @@ pub struct ResidualDecoderAttentionBlock<B: Backend> {
mlp_ln: nn::LayerNorm<B>,
}
impl<B: Backend> ResidualDecoderAttentionBlock<B> {
impl<B: MyBackend> ResidualDecoderAttentionBlock<B> {
fn forward(&self, x: Tensor<B, 3>, mask: Tensor<B, 2>) -> Tensor<B, 3> {
let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), Some(mask));
let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x));
@@ -152,13 +152,19 @@ pub struct MultiHeadSelfAttention<B: Backend> {
out: nn::Linear<B>,
}
impl<B: Backend> MultiHeadSelfAttention<B> {
impl<B: MyBackend> MultiHeadSelfAttention<B> {
pub fn forward(&self, x: Tensor<B, 3>, mask: Option<Tensor<B, 2>>) -> Tensor<B, 3> {
let q = self.query.forward(x.clone());
let k = self.key.forward(x.clone());
let v = self.value.forward(x);
let wv = qkv_attention(q, k, v, mask, self.n_head);
let wv = Tensor::from_primitive(B::qkv_attention(
q.into_primitive(),
k.into_primitive(),
v.into_primitive(),
mask.map(|m| m.into_primitive()),
self.n_head,
));
return self.out.forward(wv);
}