Add files via upload
Add initial project files
This commit is contained in:
530
src/model/autoencoder/mod.rs
Normal file
530
src/model/autoencoder/mod.rs
Normal file
@@ -0,0 +1,530 @@
|
||||
pub mod load;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn::{self, PaddingConfig2d, conv::{Conv2d, Conv2dConfig, Conv2dRecord}},
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
activation::{softmax, sigmoid},
|
||||
module::embedding,
|
||||
Tensor,
|
||||
Distribution,
|
||||
Int,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::helper::div_roundup;
|
||||
|
||||
use super::silu::*;
|
||||
use super::groupnorm::*;
|
||||
use super::attention::qkv_attention;
|
||||
|
||||
use std::iter;
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct AutoencoderConfig {}
|
||||
|
||||
impl AutoencoderConfig {
|
||||
pub fn init<B: Backend>(&self) -> Autoencoder<B> {
|
||||
let encoder = EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init();
|
||||
let decoder = DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init();
|
||||
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init();
|
||||
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init();
|
||||
|
||||
Autoencoder {
|
||||
encoder,
|
||||
decoder,
|
||||
quant_conv,
|
||||
post_quant_conv,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Prints, with `Debug` formatting, the data of the first batch entry of `x`,
/// restricted to at most its first 10 channels.
fn print_tensor<B: Backend>(x: Tensor<B, 4>) {
    let dims = x.dims();
    let shown_channels = std::cmp::min(dims[1], 10);
    let ranges = [0..1, 0..shown_channels, 0..dims[2], 0..dims[3]];

    println!("{:?}", x.slice(ranges).into_data());
}
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Autoencoder<B: Backend> {
|
||||
encoder: Encoder<B>,
|
||||
decoder: Decoder<B>,
|
||||
quant_conv: Conv2d<B>,
|
||||
post_quant_conv: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Autoencoder<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
self.decode_latent( self.encode_image(x) )
|
||||
}
|
||||
|
||||
pub fn encode_image(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let [n_batch, _, _, _] = x.dims();
|
||||
let latent = self.encoder.forward(x);
|
||||
let latent = self.quant_conv.forward(latent);
|
||||
let latent = latent.slice([0..n_batch, 0..4]);
|
||||
latent
|
||||
}
|
||||
|
||||
pub fn decode_latent(&self, latent: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let latent = self.post_quant_conv.forward(latent);
|
||||
self.decoder.forward(latent)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct EncoderConfig {
|
||||
channels: Vec<(usize, usize)>,
|
||||
n_group: usize,
|
||||
n_channels_out: usize,
|
||||
}
|
||||
|
||||
impl EncoderConfig {
|
||||
fn init<B: Backend>(&self) -> Encoder<B> {
|
||||
let n_expanded_channels_initial = self.channels.first().map(|f| f.1).expect("Channels must not be empty.");
|
||||
let n_expanded_channels_final = self.channels.first().unwrap().0;
|
||||
|
||||
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
|
||||
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| {
|
||||
let downsample = i != self.channels.len() - 1;
|
||||
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init()
|
||||
}).collect();
|
||||
|
||||
let mid = MidConfig::new(n_expanded_channels_final).init();
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init();
|
||||
let silu = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
|
||||
Encoder {
|
||||
conv_in,
|
||||
mid,
|
||||
blocks,
|
||||
norm_out,
|
||||
silu,
|
||||
conv_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Encoder<B: Backend> {
|
||||
conv_in: Conv2d<B>,
|
||||
mid: Mid<B>,
|
||||
blocks: Vec<EncoderBlock<B>>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Encoder<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.conv_in.forward(x);
|
||||
|
||||
let mut x = x;
|
||||
for block in &self.blocks {
|
||||
x = block.forward(x);
|
||||
}
|
||||
|
||||
let x = self.mid.forward(x);
|
||||
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) )
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct DecoderConfig {
|
||||
channels: Vec<(usize, usize)>,
|
||||
n_group: usize,
|
||||
}
|
||||
|
||||
impl DecoderConfig {
|
||||
fn init<B: Backend>(&self) -> Decoder<B> {
|
||||
let n_expanded_channels = self.channels.first().map(|f| f.0).expect("Channels must not be empty.");
|
||||
let n_condensed_channels = self.channels.last().unwrap().1;
|
||||
|
||||
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let mid = MidConfig::new(n_expanded_channels).init();
|
||||
|
||||
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| {
|
||||
let upsample = i != self.channels.len() - 1;
|
||||
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init()
|
||||
}).collect();
|
||||
|
||||
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init();
|
||||
let silu = SILU::new();
|
||||
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
|
||||
Decoder {
|
||||
conv_in,
|
||||
mid,
|
||||
blocks,
|
||||
norm_out,
|
||||
silu,
|
||||
conv_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Decoder<B: Backend> {
|
||||
conv_in: Conv2d<B>,
|
||||
mid: Mid<B>,
|
||||
blocks: Vec<DecoderBlock<B>>,
|
||||
norm_out: GroupNorm<B>,
|
||||
silu: SILU,
|
||||
conv_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Decoder<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.conv_in.forward(x);
|
||||
let x = self.mid.forward(x);
|
||||
|
||||
let mut x = x;
|
||||
for block in &self.blocks {
|
||||
x = block.forward(x);
|
||||
}
|
||||
|
||||
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) )
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct EncoderBlockConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_out: usize,
|
||||
downsample: bool,
|
||||
}
|
||||
|
||||
impl EncoderBlockConfig {
|
||||
fn init<B: Backend>(&self) -> EncoderBlock<B> {
|
||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
let downsampler = if self.downsample {
|
||||
let padding = Padding::new(0, 1, 0, 1);
|
||||
Some( PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding).with_stride(2).init() )
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
EncoderBlock {
|
||||
res1,
|
||||
res2,
|
||||
downsampler,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct EncoderBlock<B: Backend> {
|
||||
res1: ResnetBlock<B>,
|
||||
res2: ResnetBlock<B>,
|
||||
downsampler: Option<PaddedConv2d<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> EncoderBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.res1.forward(x);
|
||||
let x = self.res2.forward(x);
|
||||
if let Some(d) = self.downsampler.as_ref() {
|
||||
d.forward(x)
|
||||
} else {
|
||||
x
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct DecoderBlockConfig {
|
||||
n_channels_in: usize,
|
||||
n_channels_out: usize,
|
||||
upsample: bool,
|
||||
}
|
||||
|
||||
impl DecoderBlockConfig {
|
||||
fn init<B: Backend>(&self) -> DecoderBlock<B> {
|
||||
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
|
||||
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
|
||||
let upsampler = if self.upsample {
|
||||
Some( Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init() )
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
DecoderBlock {
|
||||
res1,
|
||||
res2,
|
||||
res3,
|
||||
upsampler,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct DecoderBlock<B: Backend> {
|
||||
res1: ResnetBlock<B>,
|
||||
res2: ResnetBlock<B>,
|
||||
res3: ResnetBlock<B>,
|
||||
upsampler: Option<Conv2d<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> DecoderBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.res1.forward(x);
|
||||
let x = self.res2.forward(x);
|
||||
let x = self.res3.forward(x);
|
||||
|
||||
if let Some(d) = self.upsampler.as_ref() {
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
let x = x
|
||||
.reshape([n_batch, n_channel, height, 1, width, 1])
|
||||
.repeat(3, 2)
|
||||
.repeat(5, 2)
|
||||
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
|
||||
d.forward(x)
|
||||
} else {
|
||||
x
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct PaddedConv2dConfig {
|
||||
channels: [usize; 2],
|
||||
kernel_size: usize,
|
||||
#[config(default = 1)]
|
||||
stride: usize,
|
||||
padding: Padding,
|
||||
}
|
||||
|
||||
impl PaddedConv2dConfig {
|
||||
fn init<B: Backend>(&self) -> PaddedConv2d<B> {
|
||||
let calc_padding = |p_left, p_right| {
|
||||
let n = if p_left >= p_right {
|
||||
0
|
||||
} else {
|
||||
div_roundup(p_right - p_left, self.stride)
|
||||
};
|
||||
|
||||
n * self.stride + p_left
|
||||
};
|
||||
|
||||
let pad_vertical = calc_padding(self.padding.pad_top, self.padding.pad_bottom);
|
||||
let pad_horizontal = calc_padding(self.padding.pad_left, self.padding.pad_right);
|
||||
let padding_actual = [pad_vertical, pad_horizontal];
|
||||
|
||||
let conv = Conv2dConfig::new(self.channels, [self.kernel_size, self.kernel_size])
|
||||
.with_stride([self.stride, self.stride])
|
||||
.with_padding(PaddingConfig2d::Explicit(pad_vertical, pad_horizontal))
|
||||
.init();
|
||||
|
||||
let kernel_size = self.kernel_size;
|
||||
let stride = self.stride;
|
||||
|
||||
let padding = self.padding;
|
||||
|
||||
PaddedConv2d {
|
||||
conv,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
padding_actual,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct PaddedConv2d<B: Backend> {
|
||||
conv: Conv2d<B>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: Padding,
|
||||
padding_actual: [usize; 2],
|
||||
}
|
||||
|
||||
impl<B: Backend> PaddedConv2d<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
|
||||
let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height - self.kernel_size) / self.stride + 1;
|
||||
let desired_width = (self.padding.pad_left + self.padding.pad_right + width - self.kernel_size) / self.stride + 1;
|
||||
|
||||
let skip_vert = (self.padding_actual[0] - self.padding.pad_top) / self.stride;
|
||||
let skip_hor = (self.padding_actual[1] - self.padding.pad_left) / self.stride;
|
||||
|
||||
self.conv
|
||||
.forward(x)
|
||||
.slice([
|
||||
0..n_batch,
|
||||
0..n_channel,
|
||||
skip_vert..(skip_vert + desired_height),
|
||||
skip_hor..(skip_hor + desired_width)
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config, Module, Copy, Debug)]
|
||||
pub struct Padding {
|
||||
pad_left: usize,
|
||||
pad_right: usize,
|
||||
pad_top: usize,
|
||||
pad_bottom: usize,
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct MidConfig {
|
||||
n_channel: usize,
|
||||
}
|
||||
|
||||
impl MidConfig {
|
||||
fn init<B: Backend>(&self) -> Mid<B> {
|
||||
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
|
||||
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init();
|
||||
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
|
||||
|
||||
Mid {
|
||||
block_1,
|
||||
attn,
|
||||
block_2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Mid<B: Backend> {
|
||||
block_1: ResnetBlock<B>,
|
||||
attn: ConvSelfAttentionBlock<B>,
|
||||
block_2: ResnetBlock<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Mid<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.block_1.forward(x);
|
||||
let x = self.attn.forward(x);
|
||||
let x = self.block_2.forward(x);
|
||||
x
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct ResnetBlockConfig {
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
}
|
||||
|
||||
impl ResnetBlockConfig {
|
||||
fn init<B: Backend>(&self) -> ResnetBlock<B> {
|
||||
let norm1 = GroupNormConfig::new(32, self.in_channels).init();
|
||||
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let norm2 = GroupNormConfig::new(32, self.out_channels).init();
|
||||
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
|
||||
let nin_shortcut = if self.in_channels != self.out_channels {
|
||||
Some( Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init() )
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let silu1 = SILU::new();
|
||||
let silu2 = SILU::new();
|
||||
|
||||
ResnetBlock {
|
||||
norm1,
|
||||
silu1,
|
||||
conv1,
|
||||
norm2,
|
||||
silu2,
|
||||
conv2,
|
||||
nin_shortcut,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResnetBlock<B: Backend> {
|
||||
norm1: GroupNorm<B>,
|
||||
silu1: SILU,
|
||||
conv1: Conv2d<B>,
|
||||
norm2: GroupNorm<B>,
|
||||
silu2: SILU,
|
||||
conv2: Conv2d<B>,
|
||||
nin_shortcut: Option<Conv2d<B>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ResnetBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let h = self.conv1.forward( self.silu1.forward(self.norm1.forward(x.clone())) );
|
||||
let h = self.conv2.forward( self.silu2.forward(self.norm2.forward(h)) );
|
||||
|
||||
|
||||
if let Some(ns) = self.nin_shortcut.as_ref() {
|
||||
ns.forward(x) + h
|
||||
} else {
|
||||
x + h
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct ConvSelfAttentionBlockConfig {
|
||||
n_channel: usize,
|
||||
}
|
||||
|
||||
impl ConvSelfAttentionBlockConfig {
|
||||
fn init<B: Backend>(&self) -> ConvSelfAttentionBlock<B> {
|
||||
let norm = GroupNormConfig::new(32, self.n_channel).init();
|
||||
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
|
||||
|
||||
ConvSelfAttentionBlock {
|
||||
norm,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
proj_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ConvSelfAttentionBlock<B: Backend> {
|
||||
norm: GroupNorm<B>,
|
||||
q: Conv2d<B>,
|
||||
k: Conv2d<B>,
|
||||
v: Conv2d<B>,
|
||||
proj_out: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ConvSelfAttentionBlock<B> {
|
||||
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let [n_batch, n_channel, height, width] = x.dims();
|
||||
|
||||
let h = self.norm.forward(x.clone());
|
||||
|
||||
let q = self.q.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
||||
let k = self.k.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
||||
let v = self.v.forward(h).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
|
||||
|
||||
let wv = qkv_attention(q, k, v, None, 1)
|
||||
.swap_dims(1, 2)
|
||||
.reshape([n_batch, n_channel, height, width]);
|
||||
|
||||
let projected = self.proj_out.forward(wv);
|
||||
|
||||
x + projected
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user