Update to burn v0.14.0 and switch to .mpk model file
This commit is contained in:
@@ -1,13 +1,16 @@
|
||||
use burn::tensor::{activation::softmax, Tensor};
|
||||
use burn::prelude::Backend;
|
||||
|
||||
/*pub type FloatTensor<B, const D: usize> = <B as burn::tensor::backend::Backend>::TensorPrimitive<D>;
|
||||
|
||||
pub trait Backend: burn::tensor::backend::Backend {
|
||||
fn qkv_attention(
|
||||
q: Self::TensorPrimitive<3>,
|
||||
k: Self::TensorPrimitive<3>,
|
||||
v: Self::TensorPrimitive<3>,
|
||||
mask: Option<Self::TensorPrimitive<2>>,
|
||||
q: FloatTensor<Self, 3>,
|
||||
k: FloatTensor<Self, 3>,
|
||||
v: FloatTensor<Self, 3>,
|
||||
mask: Option<FloatTensor<Self, 2>>,
|
||||
n_head: usize,
|
||||
) -> Self::TensorPrimitive<3> {
|
||||
) -> FloatTensor<Self, 3> {
|
||||
qkv_attention(
|
||||
Tensor::<Self, 3>::from_primitive(q),
|
||||
Tensor::from_primitive(k),
|
||||
@@ -18,24 +21,23 @@ pub trait Backend: burn::tensor::backend::Backend {
|
||||
.into_primitive()
|
||||
}
|
||||
|
||||
fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> Self::TensorPrimitive<2> {
|
||||
fn attn_decoder_mask(seq_length: usize, device: &Self::Device) -> FloatTensor<Self, 2> {
|
||||
attn_decoder_mask::<Self>(seq_length, device).into_primitive()
|
||||
}
|
||||
}
|
||||
|
||||
use burn::tensor::ops::TensorOps;
|
||||
use burn::tensor::Float;
|
||||
use burn_tch::{self, TchElement, TchTensor};
|
||||
use tch;
|
||||
|
||||
impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
|
||||
impl<E: TchElement> Backend for burn_tch::LibTorch<E> {
|
||||
fn qkv_attention(
|
||||
q: Self::TensorPrimitive<3>,
|
||||
k: Self::TensorPrimitive<3>,
|
||||
v: Self::TensorPrimitive<3>,
|
||||
mask: Option<Self::TensorPrimitive<2>>,
|
||||
q: FloatTensor<Self, 3>,
|
||||
k: FloatTensor<Self, 3>,
|
||||
v: FloatTensor<Self, 3>,
|
||||
mask: Option<FloatTensor<Self, 2>>,
|
||||
n_head: usize,
|
||||
) -> Self::TensorPrimitive<3> {
|
||||
) -> FloatTensor<Self, 2> {
|
||||
let q = Tensor::from_primitive(q);
|
||||
let k = Tensor::from_primitive(k);
|
||||
let v = Tensor::from_primitive(v);
|
||||
@@ -56,7 +58,7 @@ impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
|
||||
|
||||
// for some reason torch crashes when mask is None
|
||||
let mask = mask.unwrap_or_else(|| {
|
||||
Tensor::<Self, 2, Float>::zeros_device([q_ctx, k_ctx], &Self::device(&v))
|
||||
Tensor::<Self, 2, Float>::zeros([q_ctx, k_ctx], &Self::device(&v))
|
||||
.into_primitive()
|
||||
});
|
||||
|
||||
@@ -68,6 +70,7 @@ impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
|
||||
Some(mask.tensor),
|
||||
0.0,
|
||||
false,
|
||||
None,
|
||||
),
|
||||
))
|
||||
.swap_dims(1, 2)
|
||||
@@ -78,11 +81,11 @@ impl<E: TchElement> Backend for burn_tch::TchBackend<E> {
|
||||
|
||||
use burn_autodiff;
|
||||
|
||||
impl<B: Backend> Backend for burn_autodiff::ADBackendDecorator<B> {}
|
||||
impl<B: Backend> Backend for burn_autodiff::Autodiff<B> {}*/
|
||||
|
||||
use std::f32::NEG_INFINITY;
|
||||
|
||||
fn qkv_attention<B: Backend>(
|
||||
pub fn qkv_attention<B: Backend>(
|
||||
q: Tensor<B, 3>,
|
||||
k: Tensor<B, 3>,
|
||||
v: Tensor<B, 3>,
|
||||
@@ -124,13 +127,13 @@ fn qkv_attention<B: Backend>(
|
||||
return o;
|
||||
}
|
||||
|
||||
fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
|
||||
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]);
|
||||
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
|
||||
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
|
||||
|
||||
for i in 0..(seq_length - 1) {
|
||||
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY);
|
||||
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
|
||||
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
|
||||
}
|
||||
|
||||
return mask.to_device(device);
|
||||
return mask;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user