Add first successful sampling implementation

This commit is contained in:
Gadersd
2023-08-04 17:01:44 -04:00
committed by Ben_Kosytorz
parent b794e9a9ec
commit 8e7a8d9be4
9 changed files with 42 additions and 34 deletions

View File

@@ -59,7 +59,7 @@ impl<B: Backend> CLIP<B> {
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
let [n_batch, seq_len] = x.dims();
let mask = attn_decoder_mask(seq_len);
let mask = attn_decoder_mask(seq_len, &x.device());
let embedded = self.token_embedding.forward(x)
+ self.position_embedding.val().slice([0..seq_len]).unsqueeze();