use burn::tensor::{activation::softmax, Tensor}; use burn::prelude::Backend; /*pub type FloatTensor = ::TensorPrimitive; pub trait Backend: burn::tensor::backend::Backend { fn qkv_attention( q: FloatTensor, k: FloatTensor, v: FloatTensor, mask: Option>, n_head: usize, ) -> FloatTensor { qkv_attention( Tensor::::from_primitive(q), Tensor::from_primitive(k), Tensor::from_primitive(v), mask.map(|m| Tensor::from_primitive(m)), n_head, ) .into_primitive() } fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> FloatTensor { attn_decoder_mask::(seq_length, device).into_primitive() } } use burn::tensor::Float; use burn_tch::{self, TchElement, TchTensor}; use tch; impl Backend for burn_tch::LibTorch { fn qkv_attention( q: FloatTensor, k: FloatTensor, v: FloatTensor, mask: Option>, n_head: usize, ) -> FloatTensor { let q = Tensor::from_primitive(q); let k = Tensor::from_primitive(k); let v = Tensor::from_primitive(v); let [n_batch, q_ctx, n_state] = q.dims(); let [_, k_ctx, _] = k.dims(); let n_hstate = n_state / n_head; let rearrange = |t: Tensor| { let [_, n_ctx, _] = t.dims(); t.reshape([n_batch, n_ctx, n_head, n_hstate]) .swap_dims(1, 2) }; let q = rearrange(q).into_primitive(); let k = rearrange(k).into_primitive(); let v = rearrange(v).into_primitive(); // for some reason torch crashes when mask is None let mask = mask.unwrap_or_else(|| { Tensor::::zeros([q_ctx, k_ctx], &Self::device(&v)) .into_primitive() }); Tensor::::from_primitive(TchTensor::new( tch::Tensor::scaled_dot_product_attention( &q.tensor, &k.tensor, &v.tensor, Some(mask.tensor), 0.0, false, None, ), )) .swap_dims(1, 2) .flatten(2, 3) .into_primitive() } } use burn_autodiff; impl Backend for burn_autodiff::Autodiff {}*/ use std::f32::NEG_INFINITY; pub fn qkv_attention( q: Tensor, k: Tensor, v: Tensor, mask: Option>, n_head: usize, ) -> Tensor { let [n_batch, n_qctx, n_state] = q.dims(); let [_, n_ctx, _] = k.dims(); let scale = (n_state as f64 / n_head as f64).powf(-0.25); let n_hstate = n_state / n_head; let q = q .reshape([n_batch, n_qctx, n_head, n_hstate]) .swap_dims(1, 2) * scale; let k = k .reshape([n_batch, n_ctx, n_head, n_hstate]) .swap_dims(1, 2) .transpose() * scale; let v = v .reshape([n_batch, n_ctx, n_head, n_hstate]) .swap_dims(1, 2); let qk = q.matmul(k); // apply mask let qk = if let Some(mask) = mask { qk + mask.slice([0..n_qctx, 0..n_ctx]).unsqueeze::<4>() } else { qk }; // normalize value weightings let w = softmax(qk, 3); let o = w.matmul(v).swap_dims(1, 2).flatten(2, 3); return o; } pub fn attn_decoder_mask(seq_length: usize, device: &B::Device) -> Tensor { let mut mask = Tensor::::zeros([seq_length, seq_length], device); for i in 0..(seq_length - 1) { let values = Tensor::::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY); mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values); } return mask; }