Update to burn v0.14.0 and switch to .mpk model file
This commit is contained in:
@@ -12,7 +12,8 @@ use burn::{
|
||||
},
|
||||
};
|
||||
|
||||
use crate::backend::Backend as MyBackend;
|
||||
//use crate::backend::Backend as MyBackend;
|
||||
use crate::backend::{qkv_attention, attn_decoder_mask};
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct CLIPConfig {
|
||||
@@ -24,15 +25,15 @@ pub struct CLIPConfig {
|
||||
}
|
||||
|
||||
impl CLIPConfig {
|
||||
pub fn init<B: Backend>(&self) -> CLIP<B> {
|
||||
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init();
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> CLIP<B> {
|
||||
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init(device);
|
||||
let position_embedding =
|
||||
Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into();
|
||||
Param::from_tensor(Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0), device));
|
||||
let blocks = (0..self.n_layer)
|
||||
.into_iter()
|
||||
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init())
|
||||
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init(device))
|
||||
.collect();
|
||||
let layer_norm = nn::LayerNormConfig::new(self.n_state).init();
|
||||
let layer_norm = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||
|
||||
CLIP {
|
||||
token_embedding,
|
||||
@@ -51,11 +52,12 @@ pub struct CLIP<B: Backend> {
|
||||
layer_norm: nn::LayerNorm<B>,
|
||||
}
|
||||
|
||||
impl<B: MyBackend> CLIP<B> {
|
||||
impl<B: Backend> CLIP<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
|
||||
let [n_batch, seq_len] = x.dims();
|
||||
|
||||
let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device()));
|
||||
//let mask = Tensor::from_primitive(B::attn_decoder_mask(seq_len, &x.device()));
|
||||
let mask = attn_decoder_mask(seq_len, &x.device());
|
||||
|
||||
let embedded = self.token_embedding.forward(x)
|
||||
+ self
|
||||
@@ -80,12 +82,12 @@ pub struct ResidualDecoderAttentionBlockConfig {
|
||||
}
|
||||
|
||||
impl ResidualDecoderAttentionBlockConfig {
|
||||
pub fn init<B: Backend>(&self) -> ResidualDecoderAttentionBlock<B> {
|
||||
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init();
|
||||
let attn_ln = nn::LayerNormConfig::new(self.n_state).init();
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> ResidualDecoderAttentionBlock<B> {
|
||||
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init(device);
|
||||
let attn_ln = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||
|
||||
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init();
|
||||
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init();
|
||||
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init(device);
|
||||
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init(device);
|
||||
|
||||
ResidualDecoderAttentionBlock {
|
||||
attn,
|
||||
@@ -104,7 +106,7 @@ pub struct ResidualDecoderAttentionBlock<B: Backend> {
|
||||
mlp_ln: nn::LayerNorm<B>,
|
||||
}
|
||||
|
||||
impl<B: MyBackend> ResidualDecoderAttentionBlock<B> {
|
||||
impl<B: Backend> ResidualDecoderAttentionBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 3>, mask: Tensor<B, 2>) -> Tensor<B, 3> {
|
||||
let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), Some(mask));
|
||||
let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x));
|
||||
@@ -119,7 +121,7 @@ pub struct MultiHeadSelfAttentionConfig {
|
||||
}
|
||||
|
||||
impl MultiHeadSelfAttentionConfig {
|
||||
fn init<B: Backend>(&self) -> MultiHeadSelfAttention<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadSelfAttention<B> {
|
||||
assert!(
|
||||
self.n_state % self.n_head == 0,
|
||||
"State size {} must be a multiple of head size {}",
|
||||
@@ -128,10 +130,10 @@ impl MultiHeadSelfAttentionConfig {
|
||||
);
|
||||
|
||||
let n_head = self.n_head;
|
||||
let query = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
let key = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
let value = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
let query = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||
let key = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||
let value = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||
let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
|
||||
|
||||
MultiHeadSelfAttention {
|
||||
n_head,
|
||||
@@ -152,19 +154,27 @@ pub struct MultiHeadSelfAttention<B: Backend> {
|
||||
out: nn::Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: MyBackend> MultiHeadSelfAttention<B> {
|
||||
impl<B: Backend> MultiHeadSelfAttention<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 3>, mask: Option<Tensor<B, 2>>) -> Tensor<B, 3> {
|
||||
let q = self.query.forward(x.clone());
|
||||
let k = self.key.forward(x.clone());
|
||||
let v = self.value.forward(x);
|
||||
|
||||
let wv = Tensor::from_primitive(B::qkv_attention(
|
||||
/*let wv = Tensor::from_primitive(B::qkv_attention(
|
||||
q.into_primitive(),
|
||||
k.into_primitive(),
|
||||
v.into_primitive(),
|
||||
mask.map(|m| m.into_primitive()),
|
||||
self.n_head,
|
||||
));
|
||||
));*/
|
||||
|
||||
let wv = qkv_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
mask,
|
||||
self.n_head,
|
||||
);
|
||||
|
||||
return self.out.forward(wv);
|
||||
}
|
||||
@@ -177,10 +187,10 @@ pub struct MLPConfig {
|
||||
}
|
||||
|
||||
impl MLPConfig {
|
||||
fn init<B: Backend>(&self) -> MLP<B> {
|
||||
let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
|
||||
let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init(device);
|
||||
let gelu = QuickGELU::new();
|
||||
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init();
|
||||
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init(device);
|
||||
|
||||
MLP { fc1, gelu, fc2 }
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user