Update to burn v0.14.0 and switch to .mpk model file

This commit is contained in:
Hermes
2024-10-05 14:19:49 -04:00
parent 9e4d7bd310
commit 893fb0950d
19 changed files with 366 additions and 311 deletions

View File

@@ -18,7 +18,8 @@ use burn::{
use super::groupnorm::*;
use super::silu::*;
use crate::backend::Backend as MyBackend;
//use crate::backend::Backend as MyBackend;
use crate::backend::{qkv_attention, attn_decoder_mask};
use std::iter;
@@ -26,13 +27,13 @@ use std::iter;
pub struct AutoencoderConfig {}
impl AutoencoderConfig {
pub fn init<B: Backend>(&self) -> Autoencoder<B> {
pub fn init<B: Backend>(&self, device: &B::Device) -> Autoencoder<B> {
let encoder =
EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init();
EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init(device);
let decoder =
DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init();
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init();
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init();
DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init(device);
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init(device);
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init(device);
Autoencoder {
encoder,
@@ -51,7 +52,7 @@ pub struct Autoencoder<B: Backend> {
post_quant_conv: Conv2d<B>,
}
impl<B: MyBackend> Autoencoder<B> {
impl<B: Backend> Autoencoder<B> {
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.decode_latent(self.encode_image(x))
}
@@ -78,7 +79,7 @@ pub struct EncoderConfig {
}
impl EncoderConfig {
fn init<B: Backend>(&self) -> Encoder<B> {
fn init<B: Backend>(&self, device: &B::Device) -> Encoder<B> {
let n_expanded_channels_initial = self
.channels
.first()
@@ -88,7 +89,7 @@ impl EncoderConfig {
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
let blocks = self
.channels
@@ -96,16 +97,16 @@ impl EncoderConfig {
.enumerate()
.map(|(i, &(n_channel_in, n_channel_out))| {
let downsample = i != self.channels.len() - 1;
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init()
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init(device)
})
.collect();
let mid = MidConfig::new(n_expanded_channels_final).init();
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init();
let mid = MidConfig::new(n_expanded_channels_final).init(device);
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init(device);
let silu = SILU::new();
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
Encoder {
conv_in,
@@ -128,7 +129,7 @@ pub struct Encoder<B: Backend> {
conv_out: Conv2d<B>,
}
impl<B: MyBackend> Encoder<B> {
impl<B: Backend> Encoder<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv_in.forward(x);
@@ -150,7 +151,7 @@ pub struct DecoderConfig {
}
impl DecoderConfig {
fn init<B: Backend>(&self) -> Decoder<B> {
fn init<B: Backend>(&self, device: &B::Device) -> Decoder<B> {
let n_expanded_channels = self
.channels
.first()
@@ -160,8 +161,8 @@ impl DecoderConfig {
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
let mid = MidConfig::new(n_expanded_channels).init();
.init(device);
let mid = MidConfig::new(n_expanded_channels).init(device);
let blocks = self
.channels
@@ -169,15 +170,15 @@ impl DecoderConfig {
.enumerate()
.map(|(i, &(n_channel_in, n_channel_out))| {
let upsample = i != self.channels.len() - 1;
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init()
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init(device)
})
.collect();
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init();
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init(device);
let silu = SILU::new();
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
Decoder {
conv_in,
@@ -200,7 +201,7 @@ pub struct Decoder<B: Backend> {
conv_out: Conv2d<B>,
}
impl<B: MyBackend> Decoder<B> {
impl<B: Backend> Decoder<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv_in.forward(x);
let x = self.mid.forward(x);
@@ -223,15 +224,15 @@ pub struct EncoderBlockConfig {
}
impl EncoderBlockConfig {
fn init<B: Backend>(&self) -> EncoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
fn init<B: Backend>(&self, device: &B::Device) -> EncoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
let downsampler = if self.downsample {
let padding = Padding::new(0, 1, 0, 1);
let padding = PaddingCfg::new(0, 1, 0, 1);
Some(
PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding)
.with_stride(2)
.init(),
.init(device),
)
} else {
None
@@ -272,15 +273,15 @@ pub struct DecoderBlockConfig {
}
impl DecoderBlockConfig {
fn init<B: Backend>(&self) -> DecoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
fn init<B: Backend>(&self, device: &B::Device) -> DecoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init(device);
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(device);
let upsampler = if self.upsample {
Some(
Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(),
.init(device),
)
} else {
None
@@ -313,8 +314,7 @@ impl<B: Backend> DecoderBlock<B> {
let [n_batch, n_channel, height, width] = x.dims();
let x = x
.reshape([n_batch, n_channel, height, 1, width, 1])
.repeat(3, 2)
.repeat(5, 2)
.repeat(&[1, 1, 1, 2, 1, 2])
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
d.forward(x)
} else {
@@ -329,11 +329,11 @@ pub struct PaddedConv2dConfig {
kernel_size: usize,
#[config(default = 1)]
stride: usize,
padding: Padding,
padding: PaddingCfg,
}
impl PaddedConv2dConfig {
fn init<B: Backend>(&self) -> PaddedConv2d<B> {
fn init<B: Backend>(&self, device: &B::Device) -> PaddedConv2d<B> {
let calc_padding = |p_left, p_right| {
let n = if p_left >= p_right {
0
@@ -351,12 +351,17 @@ impl PaddedConv2dConfig {
let conv = Conv2dConfig::new(self.channels, [self.kernel_size, self.kernel_size])
.with_stride([self.stride, self.stride])
.with_padding(PaddingConfig2d::Explicit(pad_vertical, pad_horizontal))
.init();
.init(device);
let kernel_size = self.kernel_size;
let stride = self.stride;
let padding = self.padding;
let padding = Padding {
pad_left: self.padding.pad_left,
pad_right: self.padding.pad_right,
pad_top: self.padding.pad_top,
pad_bottom: self.padding.pad_bottom,
};
PaddedConv2d {
conv,
@@ -406,7 +411,15 @@ impl<B: Backend> PaddedConv2d<B> {
}
}
#[derive(Config, Module, Copy, Debug)]
/// Configuration-side description of the four per-edge padding amounts used
/// by `PaddedConv2dConfig` (stored there as its `padding` field).
///
/// This type only derives `Config`; the runtime counterpart is the separate
/// `Padding` struct, which derives `Module`. `PaddedConv2dConfig::init`
/// copies the four fields one-to-one from this struct into a `Padding`.
// NOTE(review): the unit of each field is presumably elements/pixels along
// the corresponding edge of the input — confirm against `PaddedConv2d::forward`.
#[derive(Config, Debug)]
pub struct PaddingCfg {
/// Padding amount for the left edge.
pad_left: usize,
/// Padding amount for the right edge.
pad_right: usize,
/// Padding amount for the top edge.
pad_top: usize,
/// Padding amount for the bottom edge.
pad_bottom: usize,
}
#[derive(Module, Clone, Debug)]
pub struct Padding {
pad_left: usize,
pad_right: usize,
@@ -420,10 +433,10 @@ pub struct MidConfig {
}
impl MidConfig {
fn init<B: Backend>(&self) -> Mid<B> {
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init();
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
fn init<B: Backend>(&self, device: &B::Device) -> Mid<B> {
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init(device);
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(device);
Mid {
block_1,
@@ -440,7 +453,7 @@ pub struct Mid<B: Backend> {
block_2: ResnetBlock<B>,
}
impl<B: MyBackend> Mid<B> {
impl<B: Backend> Mid<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.block_1.forward(x);
let x = self.attn.forward(x);
@@ -456,17 +469,17 @@ pub struct ResnetBlockConfig {
}
impl ResnetBlockConfig {
fn init<B: Backend>(&self) -> ResnetBlock<B> {
let norm1 = GroupNormConfig::new(32, self.in_channels).init();
fn init<B: Backend>(&self, device: &B::Device) -> ResnetBlock<B> {
let norm1 = GroupNormConfig::new(32, self.in_channels).init(device);
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
let norm2 = GroupNormConfig::new(32, self.out_channels).init();
.init(device);
let norm2 = GroupNormConfig::new(32, self.out_channels).init(device);
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
.init(device);
let nin_shortcut = if self.in_channels != self.out_channels {
Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init())
Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init(device))
} else {
None
};
@@ -520,12 +533,12 @@ pub struct ConvSelfAttentionBlockConfig {
}
impl ConvSelfAttentionBlockConfig {
fn init<B: Backend>(&self) -> ConvSelfAttentionBlock<B> {
let norm = GroupNormConfig::new(32, self.n_channel).init();
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
fn init<B: Backend>(&self, device: &B::Device) -> ConvSelfAttentionBlock<B> {
let norm = GroupNormConfig::new(32, self.n_channel).init(device);
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(device);
ConvSelfAttentionBlock {
norm,
@@ -546,7 +559,7 @@ pub struct ConvSelfAttentionBlock<B: Backend> {
proj_out: Conv2d<B>,
}
impl<B: MyBackend> ConvSelfAttentionBlock<B> {
impl<B: Backend> ConvSelfAttentionBlock<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let [n_batch, n_channel, height, width] = x.dims();
@@ -568,7 +581,7 @@ impl<B: MyBackend> ConvSelfAttentionBlock<B> {
.reshape([n_batch, n_channel, height * width])
.swap_dims(1, 2);
let wv = Tensor::from_primitive(B::qkv_attention(
/*let wv = Tensor::from_primitive(B::qkv_attention(
q.into_primitive(),
k.into_primitive(),
v.into_primitive(),
@@ -576,6 +589,16 @@ impl<B: MyBackend> ConvSelfAttentionBlock<B> {
1,
))
.swap_dims(1, 2)
.reshape([n_batch, n_channel, height, width]);*/
let wv = qkv_attention(
q,
k,
v,
None,
1,
)
.swap_dims(1, 2)
.reshape([n_batch, n_channel, height, width]);
let projected = self.proj_out.forward(wv);