use burn::tensor::ElementConversion; use std::error::Error; use burn::{ config::Config, module::{Module, Param}, nn, tensor::{backend::Backend, Tensor}, }; use super::*; use crate::model::load::*; pub fn load_mlp(path: &str, device: &B::Device) -> Result, Box> { let fc1 = load_linear(&format!("{}/{}", path, "fc1"), device)?; let gelu = QuickGELU::new(); let fc2 = load_linear(&format!("{}/{}", path, "fc2"), device)?; let mlp = MLP { fc1: fc1, gelu: gelu, fc2: fc2, }; Ok(mlp) } pub fn load_multi_head_self_attention( path: &str, device: &B::Device, ) -> Result, Box> { let n_head = load_usize::("n_head", path, device)?; let query = load_linear(&format!("{}/{}", path, "query"), device)?; let key = load_linear(&format!("{}/{}", path, "key"), device)?; let value = load_linear(&format!("{}/{}", path, "value"), device)?; let out = load_linear(&format!("{}/{}", path, "out"), device)?; let mhsa = MultiHeadSelfAttention { n_head: n_head, query: query, key: key, value: value, out: out, }; Ok(mhsa) } pub fn load_residual_decoder_attention_block( path: &str, device: &B::Device, ) -> Result, Box> { let mlp = load_mlp(&format!("{}/{}", path, "mlp"), device)?; let attn = load_multi_head_self_attention(&format!("{}/{}", path, "attn"), device)?; let attn_ln = load_layer_norm(&format!("{}/{}", path, "attn_ln"), device)?; let mlp_ln = load_layer_norm(&format!("{}/{}", path, "mlp_ln"), device)?; let rdab = ResidualDecoderAttentionBlock { attn: attn, attn_ln: attn_ln, mlp: mlp, mlp_ln: mlp_ln, }; Ok(rdab) } pub fn load_clip(path: &str, device: &B::Device) -> Result, Box> { let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?; let position_embedding = load_tensor("weight", &format!("{}/position_embedding", path), device)?.into(); let n_layer = load_usize::("n_layer", path, device)?; let mut blocks = (0..n_layer) .into_iter() .map(|i| { load_residual_decoder_attention_block::(&format!("{}/blocks/{}", path, i), device) }) .collect::, _>>()?; let layer_norm = load_layer_norm(&format!("{}/{}", path, "layer_norm"), device)?; let clip = CLIP { token_embedding: token_embedding, position_embedding: position_embedding, blocks: blocks, layer_norm: layer_norm, }; Ok(clip) }