use super::GroupNorm; use crate::model::load::*; use std::error::Error; use burn::{ config::Config, module::{Module, Param}, nn, tensor::{backend::Backend, Tensor}, }; use super::*; use crate::model::groupnorm::load::load_group_norm; fn load_conv_self_attention_block( path: &str, device: &B::Device, ) -> Result, Box> { let norm = load_group_norm(&format!("{}/{}", path, "norm"), device)?; let q = load_conv2d(&format!("{}/{}", path, "q"), device)?; let k = load_conv2d(&format!("{}/{}", path, "k"), device)?; let v = load_conv2d(&format!("{}/{}", path, "v"), device)?; let proj_out = load_conv2d(&format!("{}/{}", path, "proj_out"), device)?; Ok(ConvSelfAttentionBlock { norm, q, k, v, proj_out, }) } fn load_resnet_block( path: &str, device: &B::Device, ) -> Result, Box> { let norm1 = load_group_norm(&format!("{}/{}", path, "norm1"), device)?; let silu1 = SILU {}; let conv1 = load_conv2d(&format!("{}/{}", path, "conv1"), device)?; let norm2 = load_group_norm(&format!("{}/{}", path, "norm2"), device)?; let silu2 = SILU {}; let conv2 = load_conv2d(&format!("{}/{}", path, "conv2"), device)?; let nin_shortcut = load_conv2d(&format!("{}/{}", path, "nin_shortcut"), device).ok(); Ok(ResnetBlock { norm1, silu1, conv1, norm2, silu2, conv2, nin_shortcut, }) } fn load_mid(path: &str, device: &B::Device) -> Result, Box> { let block_1 = load_resnet_block(&format!("{}/{}", path, "block_1"), device)?; let attn = load_conv_self_attention_block(&format!("{}/{}", path, "attn"), device)?; let block_2 = load_resnet_block(&format!("{}/{}", path, "block_2"), device)?; Ok(Mid { block_1, attn, block_2, }) } fn load_padded_conv2d( path: &str, device: &B::Device, ) -> Result, Box> { let conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?; let channels = load_tensor::("channels", path, device)?; let channels = tensor_to_array_2(channels); let kernel_size = load_usize::("kernel_size", path, device)?; let stride = load_usize::("stride", path, device)?; let padding = load_tensor::("padding", path, device)?; let padding: [usize; 4] = tensor_to_array(padding); let padding = Padding::new(padding[0], padding[1], padding[2], padding[3]); let mut record = conv.into_record(); let mut padded_conv: PaddedConv2d = PaddedConv2dConfig::new(channels, kernel_size, padding) .with_stride(stride) .init(); let padding_actual = PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]); record.padding = >::into_record(padding_actual); padded_conv.conv = padded_conv.conv.load_record(record); Ok(padded_conv) } fn load_decoder_block( path: &str, device: &B::Device, ) -> Result, Box> { let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?; let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?; let res3 = load_resnet_block(&format!("{}/{}", path, "res3"), device)?; let upsampler = load_conv2d(&format!("{}/{}", path, "upsampler"), device).ok(); Ok(DecoderBlock { res1, res2, res3, upsampler, }) } fn load_encoder_block( path: &str, device: &B::Device, ) -> Result, Box> { let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?; let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?; let downsampler = load_padded_conv2d(&format!("{}/{}", path, "downsampler"), device).ok(); Ok(EncoderBlock { res1, res2, downsampler, }) } fn load_decoder(path: &str, device: &B::Device) -> Result, Box> { let conv_in = load_conv2d(&format!("{}/{}", path, "conv_in"), device)?; let mid = load_mid(&format!("{}/{}", path, "mid"), device)?; let n_block = load_usize::("n_block", path, device)?; let mut blocks = (0..n_block) .into_iter() .map(|i| load_decoder_block::(&format!("{}/blocks/{}", path, i), device)) .collect::, _>>()?; let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?; let silu = SILU {}; let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?; Ok(Decoder { conv_in, mid, blocks, norm_out, silu, conv_out, }) } fn load_encoder(path: &str, device: &B::Device) -> Result, Box> { let conv_in = load_conv2d(&format!("{}/{}", path, "conv_in"), device)?; let mid = load_mid(&format!("{}/{}", path, "mid"), device)?; let n_block = load_usize::("n_block", path, device)?; let mut blocks = (0..n_block) .into_iter() .map(|i| load_encoder_block::(&format!("{}/blocks/{}", path, i), device)) .collect::, _>>()?; let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?; let silu = SILU {}; let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?; Ok(Encoder { conv_in, mid, blocks, norm_out, silu, conv_out, }) } pub fn load_autoencoder( path: &str, device: &B::Device, ) -> Result, Box> { let encoder = load_encoder(&format!("{}/{}", path, "encoder"), device)?; let decoder = load_decoder(&format!("{}/{}", path, "decoder"), device)?; let quant_conv = load_conv2d(&format!("{}/{}", path, "quant_conv"), device)?; let post_quant_conv = load_conv2d(&format!("{}/{}", path, "post_quant_conv"), device)?; Ok(Autoencoder { encoder, decoder, quant_conv, post_quant_conv, }) }