mirror of
https://gitea.hainer-ernst.de/rasmus/burn-stablediffusion-vibecode.git
synced 2026-06-11 02:09:21 +00:00
Add files via upload
Add initial project files
This commit is contained in:
134
src/model/autoencoder/load.rs
Normal file
134
src/model/autoencoder/load.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use super::GroupNorm;
|
||||
use crate::model::load::*;
|
||||
|
||||
use std::error::Error;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use crate::model::groupnorm::load::load_group_norm;
|
||||
|
||||
fn load_conv_self_attention_block<B: Backend>(path: &str, device: &B::Device) -> Result<ConvSelfAttentionBlock<B>, Box<dyn Error>> {
|
||||
let norm = load_group_norm(&format!("{}/{}", path, "norm"), device)?;
|
||||
let q = load_conv2d(&format!("{}/{}", path, "q"), device)?;
|
||||
let k = load_conv2d(&format!("{}/{}", path, "k"), device)?;
|
||||
let v = load_conv2d(&format!("{}/{}", path, "v"), device)?;
|
||||
let proj_out = load_conv2d(&format!("{}/{}", path, "proj_out"), device)?;
|
||||
|
||||
Ok(ConvSelfAttentionBlock { norm, q, k, v, proj_out })
|
||||
}
|
||||
|
||||
fn load_resnet_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResnetBlock<B>, Box<dyn Error>> {
|
||||
let norm1 = load_group_norm(&format!("{}/{}", path, "norm1"), device)?;
|
||||
let silu1 = SILU {};
|
||||
let conv1 = load_conv2d(&format!("{}/{}", path, "conv1"), device)?;
|
||||
let norm2 = load_group_norm(&format!("{}/{}", path, "norm2"), device)?;
|
||||
let silu2 = SILU {};
|
||||
let conv2 = load_conv2d(&format!("{}/{}", path, "conv2"), device)?;
|
||||
let nin_shortcut = load_conv2d(&format!("{}/{}", path, "nin_shortcut"), device).ok();
|
||||
|
||||
Ok(ResnetBlock { norm1, silu1, conv1, norm2, silu2, conv2, nin_shortcut })
|
||||
}
|
||||
|
||||
fn load_mid<B: Backend>(path: &str, device: &B::Device) -> Result<Mid<B>, Box<dyn Error>> {
|
||||
let block_1 = load_resnet_block(&format!("{}/{}", path, "block_1"), device)?;
|
||||
let attn = load_conv_self_attention_block(&format!("{}/{}", path, "attn"), device)?;
|
||||
let block_2 = load_resnet_block(&format!("{}/{}", path, "block_2"), device)?;
|
||||
|
||||
Ok(Mid { block_1, attn, block_2 })
|
||||
}
|
||||
|
||||
fn load_padded_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<PaddedConv2d<B>, Box<dyn Error>> {
|
||||
let conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?;
|
||||
|
||||
let channels = load_tensor::<B, 1>("channels", path, device)?;
|
||||
let channels = tensor_to_array_2(channels);
|
||||
|
||||
let kernel_size = load_usize::<B>("kernel_size", path, device)?;
|
||||
let stride = load_usize::<B>("stride", path, device)?;
|
||||
|
||||
let padding = load_tensor::<B, 1>("padding", path, device)?;
|
||||
let padding: [usize; 4] = tensor_to_array(padding);
|
||||
let padding = Padding::new(padding[0], padding[1], padding[2], padding[3]);
|
||||
|
||||
let mut record = conv.into_record();
|
||||
|
||||
let mut padded_conv: PaddedConv2d<B> = PaddedConv2dConfig::new(channels, kernel_size, padding).with_stride(stride).init();
|
||||
let padding_actual = PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]);
|
||||
|
||||
record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual);
|
||||
padded_conv.conv = padded_conv.conv.load_record(record);
|
||||
|
||||
|
||||
Ok(padded_conv)
|
||||
}
|
||||
|
||||
fn load_decoder_block<B: Backend>(path: &str, device: &B::Device) -> Result<DecoderBlock<B>, Box<dyn Error>> {
|
||||
let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?;
|
||||
let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?;
|
||||
let res3 = load_resnet_block(&format!("{}/{}", path, "res3"), device)?;
|
||||
let upsampler = load_conv2d(&format!("{}/{}", path, "upsampler"), device).ok();
|
||||
|
||||
Ok(DecoderBlock { res1, res2, res3, upsampler })
|
||||
}
|
||||
|
||||
fn load_encoder_block<B: Backend>(path: &str, device: &B::Device) -> Result<EncoderBlock<B>, Box<dyn Error>> {
|
||||
let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?;
|
||||
let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?;
|
||||
let downsampler = load_padded_conv2d(&format!("{}/{}", path, "downsampler"), device).ok();
|
||||
|
||||
Ok(EncoderBlock { res1, res2, downsampler })
|
||||
}
|
||||
|
||||
fn load_decoder<B: Backend>(path: &str, device: &B::Device) -> Result<Decoder<B>, Box<dyn Error>> {
|
||||
let conv_in = load_conv2d(&format!("{}/{}", path, "conv_in"), device)?;
|
||||
let mid = load_mid(&format!("{}/{}", path, "mid"), device)?;
|
||||
|
||||
let n_block = load_usize::<B>("n_block", path, device)?;
|
||||
let mut blocks = (0..n_block)
|
||||
.into_iter()
|
||||
.map(|i| {
|
||||
load_decoder_block::<B>(&format!("{}/blocks/{}", path, i), device)
|
||||
}).collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?;
|
||||
let silu = SILU {};
|
||||
let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?;
|
||||
|
||||
Ok(Decoder { conv_in, mid, blocks, norm_out, silu, conv_out })
|
||||
}
|
||||
|
||||
fn load_encoder<B: Backend>(path: &str, device: &B::Device) -> Result<Encoder<B>, Box<dyn Error>> {
|
||||
let conv_in = load_conv2d(&format!("{}/{}", path, "conv_in"), device)?;
|
||||
let mid = load_mid(&format!("{}/{}", path, "mid"), device)?;
|
||||
|
||||
let n_block = load_usize::<B>("n_block", path, device)?;
|
||||
let mut blocks = (0..n_block)
|
||||
.into_iter()
|
||||
.map(|i| {
|
||||
load_encoder_block::<B>(&format!("{}/blocks/{}", path, i), device)
|
||||
}).collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?;
|
||||
let silu = SILU {};
|
||||
let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?;
|
||||
|
||||
Ok(Encoder { conv_in, mid, blocks, norm_out, silu, conv_out })
|
||||
}
|
||||
|
||||
pub fn load_autoencoder<B: Backend>(path: &str, device: &B::Device) -> Result<Autoencoder<B>, Box<dyn Error>> {
|
||||
let encoder = load_encoder(&format!("{}/{}", path, "encoder"), device)?;
|
||||
let decoder = load_decoder(&format!("{}/{}", path, "decoder"), device)?;
|
||||
let quant_conv = load_conv2d(&format!("{}/{}", path, "quant_conv"), device)?;
|
||||
let post_quant_conv = load_conv2d(&format!("{}/{}", path, "post_quant_conv"), device)?;
|
||||
|
||||
Ok(Autoencoder { encoder, decoder, quant_conv, post_quant_conv })
|
||||
}
|
||||
Reference in New Issue
Block a user