mirror of
https://gitea.hainer-ernst.de/rasmus/burn-stablediffusion-vibecode.git
synced 2026-06-11 02:09:21 +00:00
Update to burn v0.14.0 and switch to .mpk model file
This commit is contained in:
@@ -18,7 +18,8 @@ use burn::{
|
||||
|
||||
use super::groupnorm::*;
|
||||
use super::silu::*;
|
||||
use crate::backend::Backend as MyBackend;
|
||||
//use crate::backend::Backend as MyBackend;
|
||||
use crate::backend::{qkv_attention, attn_decoder_mask};
|
||||
|
||||
use std::iter;
|
||||
|
||||
@@ -26,13 +27,13 @@ use std::iter;
|
||||
pub struct AutoencoderConfig {}
|
||||
|
||||
impl AutoencoderConfig {
|
||||
pub fn init<B: Backend>(&self) -> Autoencoder<B> {
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> Autoencoder<B> {
|
||||
let encoder =
|
||||
EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init();
|
||||
EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init(device);
|
||||
let decoder =
|
||||
DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init();
|
||||
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init();
|
||||
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init();
|
||||
DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init(device);
|
||||
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init(device);
|
||||
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init(device);
|
||||
|
||||
Autoencoder {
|
||||
encoder,
|
||||
@@ -51,7 +52,7 @@ pub struct Autoencoder<B: Backend> {
|
||||
post_quant_conv: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: MyBackend> Autoencoder<B> {
|
||||
impl<B: Backend> Autoencoder<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
self.decode_latent(self.encode_image(x))
|
||||
}
|
||||
@@ -78,7 +79,7 @@ pub struct EncoderConfig {
|
||||
}
|
||||
|
||||
impl EncoderConfig {
|
||||
fn init<B: Backend>(&self) -> Encoder<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> Encoder<B> {
|
||||
let n_expanded_channels_initial = self
|
||||
.channels
|
||||
.first()
|
||||
@@ -88,7 +89,7 @@ impl EncoderConfig {
|
||||
|
||||
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
.init(device);
|
||||
|
||||
let blocks = self
|
||||
.channels
|
||||
@@ -96,16 +97,16 @@ impl EncoderConfig {
|
||||
.enumerate()
|
||||
.map(|(i, &(n_channel_in, n_channel_out))| {
|
||||
let downsample = i != self.channels.len() - 1;
|
||||
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init()
|
||||
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init(device)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mid = MidConfig::new(n_expanded_channels_final).init();
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init();
|
||||
let mid = MidConfig::new(n_expanded_channels_final).init(device);
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init(device);
|
||||
let silu = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
.init(device);
|
||||
|
||||
Encoder {
|
||||
conv_in,
|
||||
@@ -128,7 +129,7 @@ pub struct Encoder<B: Backend> {
|
||||
conv_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: MyBackend> Encoder<B> {
|
||||
impl<B: Backend> Encoder<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.conv_in.forward(x);
|
||||
|
||||
@@ -150,7 +151,7 @@ pub struct DecoderConfig {
|
||||
}
|
||||
|
||||
impl DecoderConfig {
|
||||
fn init<B: Backend>(&self) -> Decoder<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> Decoder<B> {
|
||||
let n_expanded_channels = self
|
||||
.channels
|
||||
.first()
|
||||
@@ -160,8 +161,8 @@ impl DecoderConfig {
|
||||
|
||||
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
let mid = MidConfig::new(n_expanded_channels).init();
|
||||
.init(device);
|
||||
let mid = MidConfig::new(n_expanded_channels).init(device);
|
||||
|
||||
let blocks = self
|
||||
.channels
|
||||
@@ -169,15 +170,15 @@ impl DecoderConfig {
|
||||
.enumerate()
|
||||
.map(|(i, &(n_channel_in, n_channel_out))| {
|
||||
let upsample = i != self.channels.len() - 1;
|
||||
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init()
|
||||
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init(device)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init();
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init(device);
|
||||
let silu = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
.init(device);
|
||||
|
||||
Decoder {
|
||||
conv_in,
|
||||
@@ -200,7 +201,7 @@ pub struct Decoder<B: Backend> {
|
||||
conv_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: MyBackend> Decoder<B> {
|
||||
impl<B: Backend> Decoder<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.conv_in.forward(x);
|
||||
let x = self.mid.forward(x);
|
||||
@@ -223,15 +224,15 @@ pub struct EncoderBlockConfig {
|
||||
}
|
||||
|
||||
impl EncoderBlockConfig {
|
||||
fn init<B: Backend>(&self) -> EncoderBlock<B> {
|
||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> EncoderBlock<B> {
|
||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
|
||||
let downsampler = if self.downsample {
|
||||
let padding = Padding::new(0, 1, 0, 1);
|
||||
let padding = PaddingCfg::new(0, 1, 0, 1);
|
||||
Some(
|
||||
PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding)
|
||||
.with_stride(2)
|
||||
.init(),
|
||||
.init(device),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
@@ -272,15 +273,15 @@ pub struct DecoderBlockConfig {
|
||||
}
|
||||
|
||||
impl DecoderBlockConfig {
|
||||
fn init<B: Backend>(&self) -> DecoderBlock<B> {
|
||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> DecoderBlock<B> {
|
||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
|
||||
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
|
||||
let upsampler = if self.upsample {
|
||||
Some(
|
||||
Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init(),
|
||||
.init(device),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
@@ -313,8 +314,7 @@ impl<B: Backend> DecoderBlock<B> {
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
let x = x
|
||||
.reshape([n_batch, n_channel, height, 1, width, 1])
|
||||
.repeat(3, 2)
|
||||
.repeat(5, 2)
|
||||
.repeat(&[1, 1, 1, 2, 1, 2])
|
||||
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
|
||||
d.forward(x)
|
||||
} else {
|
||||
@@ -329,11 +329,11 @@ pub struct PaddedConv2dConfig {
|
||||
kernel_size: usize,
|
||||
#[config(default = 1)]
|
||||
stride: usize,
|
||||
padding: Padding,
|
||||
padding: PaddingCfg,
|
||||
}
|
||||
|
||||
impl PaddedConv2dConfig {
|
||||
fn init<B: Backend>(&self) -> PaddedConv2d<B> {
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> PaddedConv2d<B> {
|
||||
let calc_padding = |p_left, p_right| {
|
||||
let n = if p_left >= p_right {
|
||||
0
|
||||
@@ -351,12 +351,17 @@ impl PaddedConv2dConfig {
|
||||
let conv = Conv2dConfig::new(self.channels, [self.kernel_size, self.kernel_size])
|
||||
.with_stride([self.stride, self.stride])
|
||||
.with_padding(PaddingConfig2d::Explicit(pad_vertical, pad_horizontal))
|
||||
.init();
|
||||
.init(device);
|
||||
|
||||
let kernel_size = self.kernel_size;
|
||||
let stride = self.stride;
|
||||
|
||||
let padding = self.padding;
|
||||
let padding = Padding {
|
||||
pad_left: self.padding.pad_left,
|
||||
pad_right: self.padding.pad_right,
|
||||
pad_top: self.padding.pad_top,
|
||||
pad_bottom: self.padding.pad_bottom,
|
||||
};
|
||||
|
||||
PaddedConv2d {
|
||||
conv,
|
||||
@@ -406,7 +411,15 @@ impl<B: Backend> PaddedConv2d<B> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config, Module, Copy, Debug)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct PaddingCfg {
|
||||
pad_left: usize,
|
||||
pad_right: usize,
|
||||
pad_top: usize,
|
||||
pad_bottom: usize,
|
||||
}
|
||||
|
||||
#[derive(Module, Clone, Debug)]
|
||||
pub struct Padding {
|
||||
pad_left: usize,
|
||||
pad_right: usize,
|
||||
@@ -420,10 +433,10 @@ pub struct MidConfig {
|
||||
}
|
||||
|
||||
impl MidConfig {
|
||||
fn init<B: Backend>(&self) -> Mid<B> {
|
||||
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
|
||||
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init();
|
||||
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> Mid<B> {
|
||||
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
|
||||
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init(device);
|
||||
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
|
||||
|
||||
Mid {
|
||||
block_1,
|
||||
@@ -440,7 +453,7 @@ pub struct Mid<B: Backend> {
|
||||
block_2: ResnetBlock<B>,
|
||||
}
|
||||
|
||||
impl<B: MyBackend> Mid<B> {
|
||||
impl<B: Backend> Mid<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.block_1.forward(x);
|
||||
let x = self.attn.forward(x);
|
||||
@@ -456,17 +469,17 @@ pub struct ResnetBlockConfig {
|
||||
}
|
||||
|
||||
impl ResnetBlockConfig {
|
||||
fn init<B: Backend>(&self) -> ResnetBlock<B> {
|
||||
let norm1 = GroupNormConfig::new(32, self.in_channels).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> ResnetBlock<B> {
|
||||
let norm1 = GroupNormConfig::new(32, self.in_channels).init(device);
|
||||
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
let norm2 = GroupNormConfig::new(32, self.out_channels).init();
|
||||
.init(device);
|
||||
let norm2 = GroupNormConfig::new(32, self.out_channels).init(device);
|
||||
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3])
|
||||
.with_padding(PaddingConfig2d::Explicit(1, 1))
|
||||
.init();
|
||||
.init(device);
|
||||
let nin_shortcut = if self.in_channels != self.out_channels {
|
||||
Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init())
|
||||
Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init(device))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -520,12 +533,12 @@ pub struct ConvSelfAttentionBlockConfig {
|
||||
}
|
||||
|
||||
impl ConvSelfAttentionBlockConfig {
|
||||
fn init<B: Backend>(&self) -> ConvSelfAttentionBlock<B> {
|
||||
let norm = GroupNormConfig::new(32, self.n_channel).init();
|
||||
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
fn init<B: Backend>(&self, device: &B::Device) -> ConvSelfAttentionBlock<B> {
|
||||
let norm = GroupNormConfig::new(32, self.n_channel).init(device);
|
||||
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
|
||||
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
|
||||
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
|
||||
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
|
||||
|
||||
ConvSelfAttentionBlock {
|
||||
norm,
|
||||
@@ -546,7 +559,7 @@ pub struct ConvSelfAttentionBlock<B: Backend> {
|
||||
proj_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: MyBackend> ConvSelfAttentionBlock<B> {
|
||||
impl<B: Backend> ConvSelfAttentionBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
|
||||
@@ -568,7 +581,7 @@ impl<B: MyBackend> ConvSelfAttentionBlock<B> {
|
||||
.reshape([n_batch, n_channel, height * width])
|
||||
.swap_dims(1, 2);
|
||||
|
||||
let wv = Tensor::from_primitive(B::qkv_attention(
|
||||
/*let wv = Tensor::from_primitive(B::qkv_attention(
|
||||
q.into_primitive(),
|
||||
k.into_primitive(),
|
||||
v.into_primitive(),
|
||||
@@ -576,6 +589,16 @@ impl<B: MyBackend> ConvSelfAttentionBlock<B> {
|
||||
1,
|
||||
))
|
||||
.swap_dims(1, 2)
|
||||
.reshape([n_batch, n_channel, height, width]);*/
|
||||
|
||||
let wv = qkv_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
None,
|
||||
1,
|
||||
)
|
||||
.swap_dims(1, 2)
|
||||
.reshape([n_batch, n_channel, height, width]);
|
||||
|
||||
let projected = self.proj_out.forward(wv);
|
||||
|
||||
Reference in New Issue
Block a user