Replace helper functions with native burn functions
This commit is contained in:
@@ -1,23 +1,32 @@
|
||||
use burn::{
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
activation::softmax,
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
use burn::tensor::{activation::softmax, backend::Backend, Tensor};
|
||||
|
||||
use std::f32::NEG_INFINITY;
|
||||
|
||||
pub fn qkv_attention<B: Backend>(q: Tensor<B, 3>, k: Tensor<B, 3>, v: Tensor<B, 3>, mask: Option<Tensor<B, 2>>, n_head: usize) -> Tensor<B, 3> {
|
||||
pub fn qkv_attention<B: Backend>(
|
||||
q: Tensor<B, 3>,
|
||||
k: Tensor<B, 3>,
|
||||
v: Tensor<B, 3>,
|
||||
mask: Option<Tensor<B, 2>>,
|
||||
n_head: usize,
|
||||
) -> Tensor<B, 3> {
|
||||
let [n_batch, n_qctx, n_state] = q.dims();
|
||||
let [_, n_ctx, _] = k.dims();
|
||||
|
||||
let scale = (n_state as f64 / n_head as f64).powf(-0.25);
|
||||
let n_hstate = n_state / n_head;
|
||||
|
||||
let q = q.reshape([n_batch, n_qctx, n_head, n_hstate]).swap_dims(1, 2) * scale;
|
||||
let k = k.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2).transpose() * scale;
|
||||
let v = v.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2);
|
||||
let q = q
|
||||
.reshape([n_batch, n_qctx, n_head, n_hstate])
|
||||
.swap_dims(1, 2)
|
||||
* scale;
|
||||
let k = k
|
||||
.reshape([n_batch, n_ctx, n_head, n_hstate])
|
||||
.swap_dims(1, 2)
|
||||
.transpose()
|
||||
* scale;
|
||||
let v = v
|
||||
.reshape([n_batch, n_ctx, n_head, n_hstate])
|
||||
.swap_dims(1, 2);
|
||||
|
||||
let qk = q.matmul(k);
|
||||
|
||||
@@ -44,4 +53,4 @@ pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> T
|
||||
}
|
||||
|
||||
return mask.to_device(device);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user