Update to burn v0.14.0 and switch to .mpk model file

This commit is contained in:
Hermes
2024-10-05 14:19:49 -04:00
committed by Ben_Kosytorz
parent 3c49b0a151
commit 75f0cedd9f
19 changed files with 366 additions and 311 deletions

View File

@@ -45,12 +45,12 @@ pub fn qkv_attention<B: Backend>(
}
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]);
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
for i in 0..(seq_length - 1) {
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY);
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
}
return mask.to_device(device);
return mask;
}