Replace helper functions with native burn functions
This commit is contained in:
@@ -1,14 +1,11 @@
|
||||
use std::error::Error;
|
||||
use burn::tensor::ElementConversion;
|
||||
use std::error::Error;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
@@ -28,7 +25,10 @@ pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Bo
|
||||
Ok(mlp)
|
||||
}
|
||||
|
||||
pub fn load_multi_head_self_attention<B: Backend>(path: &str, device: &B::Device) -> Result<MultiHeadSelfAttention<B>, Box<dyn Error>> {
|
||||
pub fn load_multi_head_self_attention<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<MultiHeadSelfAttention<B>, Box<dyn Error>> {
|
||||
let n_head = load_usize::<B>("n_head", path, device)?;
|
||||
let query = load_linear(&format!("{}/{}", path, "query"), device)?;
|
||||
let key = load_linear(&format!("{}/{}", path, "key"), device)?;
|
||||
@@ -46,7 +46,10 @@ pub fn load_multi_head_self_attention<B: Backend>(path: &str, device: &B::Device
|
||||
Ok(mhsa)
|
||||
}
|
||||
|
||||
pub fn load_residual_decoder_attention_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResidualDecoderAttentionBlock<B>, Box<dyn Error>> {
|
||||
pub fn load_residual_decoder_attention_block<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<ResidualDecoderAttentionBlock<B>, Box<dyn Error>> {
|
||||
let mlp = load_mlp(&format!("{}/{}", path, "mlp"), device)?;
|
||||
let attn = load_multi_head_self_attention(&format!("{}/{}", path, "attn"), device)?;
|
||||
let attn_ln = load_layer_norm(&format!("{}/{}", path, "attn_ln"), device)?;
|
||||
@@ -64,15 +67,17 @@ pub fn load_residual_decoder_attention_block<B: Backend>(path: &str, device: &B:
|
||||
|
||||
pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>, Box<dyn Error>> {
|
||||
let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?;
|
||||
let position_embedding = load_tensor("weight", &format!("{}/position_embedding", path), device)?.into();
|
||||
let position_embedding =
|
||||
load_tensor("weight", &format!("{}/position_embedding", path), device)?.into();
|
||||
|
||||
let n_layer = load_usize::<B>("n_layer", path, device)?;
|
||||
let mut blocks = (0..n_layer)
|
||||
.into_iter()
|
||||
.map(|i| {
|
||||
load_residual_decoder_attention_block::<B>(&format!("{}/blocks/{}", path, i), device)
|
||||
}).collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let layer_norm = load_layer_norm(&format!("{}/{}", path, "layer_norm"), device)?;
|
||||
|
||||
let clip = CLIP {
|
||||
@@ -81,6 +86,6 @@ pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>,
|
||||
blocks: blocks,
|
||||
layer_norm: layer_norm,
|
||||
};
|
||||
|
||||
|
||||
Ok(clip)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user