Add first successful sampling implementation
This commit is contained in:
@@ -35,7 +35,7 @@ pub fn qkv_attention<B: Backend>(q: Tensor<B, 3>, k: Tensor<B, 3>, v: Tensor<B,
|
||||
return o;
|
||||
}
|
||||
|
||||
pub fn attn_decoder_mask<B: Backend>(seq_length: usize) -> Tensor<B, 2> {
|
||||
pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
|
||||
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]);
|
||||
|
||||
for i in 0..(seq_length - 1) {
|
||||
@@ -43,5 +43,5 @@ pub fn attn_decoder_mask<B: Backend>(seq_length: usize) -> Tensor<B, 2> {
|
||||
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
|
||||
}
|
||||
|
||||
return mask;
|
||||
return mask.to_device(device);
|
||||
}
|
||||
Reference in New Issue
Block a user