Add files via upload

Add initial project files
This commit is contained in:
Gadersd
2023-08-04 14:32:47 -04:00
committed by GitHub
parent 1aed8b655a
commit e4145441eb
31 changed files with 266571 additions and 0 deletions

47
src/model/attention.rs Normal file
View File

@@ -0,0 +1,47 @@
use burn::{
tensor::{
backend::Backend,
activation::softmax,
Tensor,
},
};
use std::f32::NEG_INFINITY;
pub fn qkv_attention<B: Backend>(q: Tensor<B, 3>, k: Tensor<B, 3>, v: Tensor<B, 3>, mask: Option<Tensor<B, 2>>, n_head: usize) -> Tensor<B, 3> {
let [n_batch, n_qctx, n_state] = q.dims();
let [_, n_ctx, _] = k.dims();
let scale = (n_state as f64 / n_head as f64).powf(-0.25);
let n_hstate = n_state / n_head;
let q = q.reshape([n_batch, n_qctx, n_head, n_hstate]).swap_dims(1, 2) * scale;
let k = k.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2).transpose() * scale;
let v = v.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2);
let qk = q.matmul(k);
// apply mask
let qk = if let Some(mask) = mask {
qk + mask.slice([0..n_qctx, 0..n_ctx]).unsqueeze::<4>()
} else {
qk
};
// normalize value weightings
let w = softmax(qk, 3);
let o = w.matmul(v).swap_dims(1, 2).flatten(2, 3);
return o;
}
pub fn attn_decoder_mask<B: Backend>(seq_length: usize) -> Tensor<B, 2> {
let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length]);
for i in 0..(seq_length - 1) {
let values = Tensor::<B, 2>::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY);
mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
}
return mask;
}

View File

@@ -0,0 +1,134 @@
use super::GroupNorm;
use crate::model::load::*;
use std::error::Error;
use burn::{
config::Config,
module::{Module, Param},
nn,
tensor::{
backend::Backend,
Tensor,
},
};
use super::*;
use crate::model::groupnorm::load::load_group_norm;
fn load_conv_self_attention_block<B: Backend>(path: &str, device: &B::Device) -> Result<ConvSelfAttentionBlock<B>, Box<dyn Error>> {
let norm = load_group_norm(&format!("{}/{}", path, "norm"), device)?;
let q = load_conv2d(&format!("{}/{}", path, "q"), device)?;
let k = load_conv2d(&format!("{}/{}", path, "k"), device)?;
let v = load_conv2d(&format!("{}/{}", path, "v"), device)?;
let proj_out = load_conv2d(&format!("{}/{}", path, "proj_out"), device)?;
Ok(ConvSelfAttentionBlock { norm, q, k, v, proj_out })
}
fn load_resnet_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResnetBlock<B>, Box<dyn Error>> {
let norm1 = load_group_norm(&format!("{}/{}", path, "norm1"), device)?;
let silu1 = SILU {};
let conv1 = load_conv2d(&format!("{}/{}", path, "conv1"), device)?;
let norm2 = load_group_norm(&format!("{}/{}", path, "norm2"), device)?;
let silu2 = SILU {};
let conv2 = load_conv2d(&format!("{}/{}", path, "conv2"), device)?;
let nin_shortcut = load_conv2d(&format!("{}/{}", path, "nin_shortcut"), device).ok();
Ok(ResnetBlock { norm1, silu1, conv1, norm2, silu2, conv2, nin_shortcut })
}
fn load_mid<B: Backend>(path: &str, device: &B::Device) -> Result<Mid<B>, Box<dyn Error>> {
let block_1 = load_resnet_block(&format!("{}/{}", path, "block_1"), device)?;
let attn = load_conv_self_attention_block(&format!("{}/{}", path, "attn"), device)?;
let block_2 = load_resnet_block(&format!("{}/{}", path, "block_2"), device)?;
Ok(Mid { block_1, attn, block_2 })
}
fn load_padded_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<PaddedConv2d<B>, Box<dyn Error>> {
let conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?;
let channels = load_tensor::<B, 1>("channels", path, device)?;
let channels = tensor_to_array_2(channels);
let kernel_size = load_usize::<B>("kernel_size", path, device)?;
let stride = load_usize::<B>("stride", path, device)?;
let padding = load_tensor::<B, 1>("padding", path, device)?;
let padding: [usize; 4] = tensor_to_array(padding);
let padding = Padding::new(padding[0], padding[1], padding[2], padding[3]);
let mut record = conv.into_record();
let mut padded_conv: PaddedConv2d<B> = PaddedConv2dConfig::new(channels, kernel_size, padding).with_stride(stride).init();
let padding_actual = PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]);
record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual);
padded_conv.conv = padded_conv.conv.load_record(record);
Ok(padded_conv)
}
fn load_decoder_block<B: Backend>(path: &str, device: &B::Device) -> Result<DecoderBlock<B>, Box<dyn Error>> {
let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?;
let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?;
let res3 = load_resnet_block(&format!("{}/{}", path, "res3"), device)?;
let upsampler = load_conv2d(&format!("{}/{}", path, "upsampler"), device).ok();
Ok(DecoderBlock { res1, res2, res3, upsampler })
}
fn load_encoder_block<B: Backend>(path: &str, device: &B::Device) -> Result<EncoderBlock<B>, Box<dyn Error>> {
let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?;
let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?;
let downsampler = load_padded_conv2d(&format!("{}/{}", path, "downsampler"), device).ok();
Ok(EncoderBlock { res1, res2, downsampler })
}
fn load_decoder<B: Backend>(path: &str, device: &B::Device) -> Result<Decoder<B>, Box<dyn Error>> {
let conv_in = load_conv2d(&format!("{}/{}", path, "conv_in"), device)?;
let mid = load_mid(&format!("{}/{}", path, "mid"), device)?;
let n_block = load_usize::<B>("n_block", path, device)?;
let mut blocks = (0..n_block)
.into_iter()
.map(|i| {
load_decoder_block::<B>(&format!("{}/blocks/{}", path, i), device)
}).collect::<Result<Vec<_>, _>>()?;
let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?;
let silu = SILU {};
let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?;
Ok(Decoder { conv_in, mid, blocks, norm_out, silu, conv_out })
}
fn load_encoder<B: Backend>(path: &str, device: &B::Device) -> Result<Encoder<B>, Box<dyn Error>> {
let conv_in = load_conv2d(&format!("{}/{}", path, "conv_in"), device)?;
let mid = load_mid(&format!("{}/{}", path, "mid"), device)?;
let n_block = load_usize::<B>("n_block", path, device)?;
let mut blocks = (0..n_block)
.into_iter()
.map(|i| {
load_encoder_block::<B>(&format!("{}/blocks/{}", path, i), device)
}).collect::<Result<Vec<_>, _>>()?;
let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?;
let silu = SILU {};
let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?;
Ok(Encoder { conv_in, mid, blocks, norm_out, silu, conv_out })
}
pub fn load_autoencoder<B: Backend>(path: &str, device: &B::Device) -> Result<Autoencoder<B>, Box<dyn Error>> {
let encoder = load_encoder(&format!("{}/{}", path, "encoder"), device)?;
let decoder = load_decoder(&format!("{}/{}", path, "decoder"), device)?;
let quant_conv = load_conv2d(&format!("{}/{}", path, "quant_conv"), device)?;
let post_quant_conv = load_conv2d(&format!("{}/{}", path, "post_quant_conv"), device)?;
Ok(Autoencoder { encoder, decoder, quant_conv, post_quant_conv })
}

View File

@@ -0,0 +1,530 @@
pub mod load;
use burn::{
config::Config,
module::{Module, Param},
nn::{self, PaddingConfig2d, conv::{Conv2d, Conv2dConfig, Conv2dRecord}},
tensor::{
backend::Backend,
activation::{softmax, sigmoid},
module::embedding,
Tensor,
Distribution,
Int,
},
};
use crate::helper::div_roundup;
use super::silu::*;
use super::groupnorm::*;
use super::attention::qkv_attention;
use std::iter;
#[derive(Config)]
pub struct AutoencoderConfig {}
impl AutoencoderConfig {
pub fn init<B: Backend>(&self) -> Autoencoder<B> {
let encoder = EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init();
let decoder = DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init();
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init();
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init();
Autoencoder {
encoder,
decoder,
quant_conv,
post_quant_conv,
}
}
}
fn print_tensor<B: Backend>(x: Tensor<B, 4>) {
let [_, channels, height, width] = x.dims();
let channels = channels.min(10);
let data = x.slice([0..1, 0..channels, 0..height, 0..width]).into_data();
println!("{:?}", data);
}
#[derive(Module, Debug)]
pub struct Autoencoder<B: Backend> {
encoder: Encoder<B>,
decoder: Decoder<B>,
quant_conv: Conv2d<B>,
post_quant_conv: Conv2d<B>,
}
impl<B: Backend> Autoencoder<B> {
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.decode_latent( self.encode_image(x) )
}
pub fn encode_image(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let [n_batch, _, _, _] = x.dims();
let latent = self.encoder.forward(x);
let latent = self.quant_conv.forward(latent);
let latent = latent.slice([0..n_batch, 0..4]);
latent
}
pub fn decode_latent(&self, latent: Tensor<B, 4>) -> Tensor<B, 4> {
let latent = self.post_quant_conv.forward(latent);
self.decoder.forward(latent)
}
}
#[derive(Config)]
pub struct EncoderConfig {
channels: Vec<(usize, usize)>,
n_group: usize,
n_channels_out: usize,
}
impl EncoderConfig {
fn init<B: Backend>(&self) -> Encoder<B> {
let n_expanded_channels_initial = self.channels.first().map(|f| f.1).expect("Channels must not be empty.");
let n_expanded_channels_final = self.channels.first().unwrap().0;
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| {
let downsample = i != self.channels.len() - 1;
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init()
}).collect();
let mid = MidConfig::new(n_expanded_channels_final).init();
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init();
let silu = SILU::new();
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
Encoder {
conv_in,
mid,
blocks,
norm_out,
silu,
conv_out,
}
}
}
#[derive(Module, Debug)]
pub struct Encoder<B: Backend> {
conv_in: Conv2d<B>,
mid: Mid<B>,
blocks: Vec<EncoderBlock<B>>,
norm_out: GroupNorm<B>,
silu: SILU,
conv_out: Conv2d<B>,
}
impl<B: Backend> Encoder<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv_in.forward(x);
let mut x = x;
for block in &self.blocks {
x = block.forward(x);
}
let x = self.mid.forward(x);
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) )
}
}
#[derive(Config)]
pub struct DecoderConfig {
channels: Vec<(usize, usize)>,
n_group: usize,
}
impl DecoderConfig {
fn init<B: Backend>(&self) -> Decoder<B> {
let n_expanded_channels = self.channels.first().map(|f| f.0).expect("Channels must not be empty.");
let n_condensed_channels = self.channels.last().unwrap().1;
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
let mid = MidConfig::new(n_expanded_channels).init();
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| {
let upsample = i != self.channels.len() - 1;
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init()
}).collect();
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init();
let silu = SILU::new();
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
Decoder {
conv_in,
mid,
blocks,
norm_out,
silu,
conv_out,
}
}
}
#[derive(Module, Debug)]
pub struct Decoder<B: Backend> {
conv_in: Conv2d<B>,
mid: Mid<B>,
blocks: Vec<DecoderBlock<B>>,
norm_out: GroupNorm<B>,
silu: SILU,
conv_out: Conv2d<B>,
}
impl<B: Backend> Decoder<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv_in.forward(x);
let x = self.mid.forward(x);
let mut x = x;
for block in &self.blocks {
x = block.forward(x);
}
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) )
}
}
#[derive(Config)]
pub struct EncoderBlockConfig {
n_channels_in: usize,
n_channels_out: usize,
downsample: bool,
}
impl EncoderBlockConfig {
fn init<B: Backend>(&self) -> EncoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
let downsampler = if self.downsample {
let padding = Padding::new(0, 1, 0, 1);
Some( PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding).with_stride(2).init() )
} else {
None
};
EncoderBlock {
res1,
res2,
downsampler,
}
}
}
#[derive(Module, Debug)]
pub struct EncoderBlock<B: Backend> {
res1: ResnetBlock<B>,
res2: ResnetBlock<B>,
downsampler: Option<PaddedConv2d<B>>,
}
impl<B: Backend> EncoderBlock<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.res1.forward(x);
let x = self.res2.forward(x);
if let Some(d) = self.downsampler.as_ref() {
d.forward(x)
} else {
x
}
}
}
#[derive(Config)]
pub struct DecoderBlockConfig {
n_channels_in: usize,
n_channels_out: usize,
upsample: bool,
}
impl DecoderBlockConfig {
fn init<B: Backend>(&self) -> DecoderBlock<B> {
let res1 = ResnetBlockConfig::new(self.n_channels_in, self.n_channels_out).init();
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
let upsampler = if self.upsample {
Some( Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init() )
} else {
None
};
DecoderBlock {
res1,
res2,
res3,
upsampler,
}
}
}
#[derive(Module, Debug)]
pub struct DecoderBlock<B: Backend> {
res1: ResnetBlock<B>,
res2: ResnetBlock<B>,
res3: ResnetBlock<B>,
upsampler: Option<Conv2d<B>>,
}
impl<B: Backend> DecoderBlock<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.res1.forward(x);
let x = self.res2.forward(x);
let x = self.res3.forward(x);
if let Some(d) = self.upsampler.as_ref() {
let [n_batch, n_channel, height, width] = x.dims();
let x = x
.reshape([n_batch, n_channel, height, 1, width, 1])
.repeat(3, 2)
.repeat(5, 2)
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
d.forward(x)
} else {
x
}
}
}
#[derive(Config)]
pub struct PaddedConv2dConfig {
channels: [usize; 2],
kernel_size: usize,
#[config(default = 1)]
stride: usize,
padding: Padding,
}
impl PaddedConv2dConfig {
fn init<B: Backend>(&self) -> PaddedConv2d<B> {
let calc_padding = |p_left, p_right| {
let n = if p_left >= p_right {
0
} else {
div_roundup(p_right - p_left, self.stride)
};
n * self.stride + p_left
};
let pad_vertical = calc_padding(self.padding.pad_top, self.padding.pad_bottom);
let pad_horizontal = calc_padding(self.padding.pad_left, self.padding.pad_right);
let padding_actual = [pad_vertical, pad_horizontal];
let conv = Conv2dConfig::new(self.channels, [self.kernel_size, self.kernel_size])
.with_stride([self.stride, self.stride])
.with_padding(PaddingConfig2d::Explicit(pad_vertical, pad_horizontal))
.init();
let kernel_size = self.kernel_size;
let stride = self.stride;
let padding = self.padding;
PaddedConv2d {
conv,
kernel_size,
stride,
padding,
padding_actual,
}
}
}
#[derive(Module, Debug)]
pub struct PaddedConv2d<B: Backend> {
conv: Conv2d<B>,
kernel_size: usize,
stride: usize,
padding: Padding,
padding_actual: [usize; 2],
}
impl<B: Backend> PaddedConv2d<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let [n_batch, n_channel, height, width] = x.dims();
let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height - self.kernel_size) / self.stride + 1;
let desired_width = (self.padding.pad_left + self.padding.pad_right + width - self.kernel_size) / self.stride + 1;
let skip_vert = (self.padding_actual[0] - self.padding.pad_top) / self.stride;
let skip_hor = (self.padding_actual[1] - self.padding.pad_left) / self.stride;
self.conv
.forward(x)
.slice([
0..n_batch,
0..n_channel,
skip_vert..(skip_vert + desired_height),
skip_hor..(skip_hor + desired_width)
])
}
}
#[derive(Config, Module, Copy, Debug)]
pub struct Padding {
pad_left: usize,
pad_right: usize,
pad_top: usize,
pad_bottom: usize,
}
#[derive(Config)]
pub struct MidConfig {
n_channel: usize,
}
impl MidConfig {
fn init<B: Backend>(&self) -> Mid<B> {
let block_1 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
let attn = ConvSelfAttentionBlockConfig::new(self.n_channel).init();
let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init();
Mid {
block_1,
attn,
block_2,
}
}
}
#[derive(Module, Debug)]
pub struct Mid<B: Backend> {
block_1: ResnetBlock<B>,
attn: ConvSelfAttentionBlock<B>,
block_2: ResnetBlock<B>,
}
impl<B: Backend> Mid<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.block_1.forward(x);
let x = self.attn.forward(x);
let x = self.block_2.forward(x);
x
}
}
#[derive(Config)]
pub struct ResnetBlockConfig {
in_channels: usize,
out_channels: usize,
}
impl ResnetBlockConfig {
fn init<B: Backend>(&self) -> ResnetBlock<B> {
let norm1 = GroupNormConfig::new(32, self.in_channels).init();
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
let norm2 = GroupNormConfig::new(32, self.out_channels).init();
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
let nin_shortcut = if self.in_channels != self.out_channels {
Some( Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init() )
} else {
None
};
let silu1 = SILU::new();
let silu2 = SILU::new();
ResnetBlock {
norm1,
silu1,
conv1,
norm2,
silu2,
conv2,
nin_shortcut,
}
}
}
#[derive(Module, Debug)]
pub struct ResnetBlock<B: Backend> {
norm1: GroupNorm<B>,
silu1: SILU,
conv1: Conv2d<B>,
norm2: GroupNorm<B>,
silu2: SILU,
conv2: Conv2d<B>,
nin_shortcut: Option<Conv2d<B>>,
}
impl<B: Backend> ResnetBlock<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let h = self.conv1.forward( self.silu1.forward(self.norm1.forward(x.clone())) );
let h = self.conv2.forward( self.silu2.forward(self.norm2.forward(h)) );
if let Some(ns) = self.nin_shortcut.as_ref() {
ns.forward(x) + h
} else {
x + h
}
}
}
#[derive(Config)]
pub struct ConvSelfAttentionBlockConfig {
n_channel: usize,
}
impl ConvSelfAttentionBlockConfig {
fn init<B: Backend>(&self) -> ConvSelfAttentionBlock<B> {
let norm = GroupNormConfig::new(32, self.n_channel).init();
let q = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
let k = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
let v = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init();
ConvSelfAttentionBlock {
norm,
q,
k,
v,
proj_out,
}
}
}
#[derive(Module, Debug)]
pub struct ConvSelfAttentionBlock<B: Backend> {
norm: GroupNorm<B>,
q: Conv2d<B>,
k: Conv2d<B>,
v: Conv2d<B>,
proj_out: Conv2d<B>,
}
impl<B: Backend> ConvSelfAttentionBlock<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let [n_batch, n_channel, height, width] = x.dims();
let h = self.norm.forward(x.clone());
let q = self.q.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
let k = self.k.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
let v = self.v.forward(h).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
let wv = qkv_attention(q, k, v, None, 1)
.swap_dims(1, 2)
.reshape([n_batch, n_channel, height, width]);
let projected = self.proj_out.forward(wv);
x + projected
}
}

86
src/model/clip/load.rs Normal file
View File

@@ -0,0 +1,86 @@
use std::error::Error;
use burn::tensor::ElementConversion;
use burn::{
config::Config,
module::{Module, Param},
nn,
tensor::{
backend::Backend,
Tensor,
},
};
use super::*;
use crate::model::load::*;
pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Box<dyn Error>> {
let fc1 = load_linear(&format!("{}/{}", path, "fc1"), device)?;
let gelu = QuickGELU::new();
let fc2 = load_linear(&format!("{}/{}", path, "fc2"), device)?;
let mlp = MLP {
fc1: fc1,
gelu: gelu,
fc2: fc2,
};
Ok(mlp)
}
pub fn load_multi_head_self_attention<B: Backend>(path: &str, device: &B::Device) -> Result<MultiHeadSelfAttention<B>, Box<dyn Error>> {
let n_head = load_usize::<B>("n_head", path, device)?;
let query = load_linear(&format!("{}/{}", path, "query"), device)?;
let key = load_linear(&format!("{}/{}", path, "key"), device)?;
let value = load_linear(&format!("{}/{}", path, "value"), device)?;
let out = load_linear(&format!("{}/{}", path, "out"), device)?;
let mhsa = MultiHeadSelfAttention {
n_head: n_head,
query: query,
key: key,
value: value,
out: out,
};
Ok(mhsa)
}
pub fn load_residual_decoder_attention_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResidualDecoderAttentionBlock<B>, Box<dyn Error>> {
let mlp = load_mlp(&format!("{}/{}", path, "mlp"), device)?;
let attn = load_multi_head_self_attention(&format!("{}/{}", path, "attn"), device)?;
let attn_ln = load_layer_norm(&format!("{}/{}", path, "attn_ln"), device)?;
let mlp_ln = load_layer_norm(&format!("{}/{}", path, "mlp_ln"), device)?;
let rdab = ResidualDecoderAttentionBlock {
attn: attn,
attn_ln: attn_ln,
mlp: mlp,
mlp_ln: mlp_ln,
};
Ok(rdab)
}
pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>, Box<dyn Error>> {
let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?;
let position_embedding = load_tensor("weight", &format!("{}/position_embedding", path), device)?.into();
let n_layer = load_usize::<B>("n_layer", path, device)?;
let mut blocks = (0..n_layer)
.into_iter()
.map(|i| {
load_residual_decoder_attention_block::<B>(&format!("{}/blocks/{}", path, i), device)
}).collect::<Result<Vec<_>, _>>()?;
let layer_norm = load_layer_norm(&format!("{}/{}", path, "layer_norm"), device)?;
let clip = CLIP {
token_embedding: token_embedding,
position_embedding: position_embedding,
blocks: blocks,
layer_norm: layer_norm,
};
Ok(clip)
}

220
src/model/clip/mod.rs Normal file
View File

@@ -0,0 +1,220 @@
pub mod load;
use burn::{
config::Config,
module::{Module, Param},
nn,
tensor::{
backend::Backend,
activation::{softmax, sigmoid},
module::embedding,
Tensor,
Distribution,
Int,
},
};
use crate::model::attention::{qkv_attention, attn_decoder_mask};
#[derive(Config)]
pub struct CLIPConfig {
n_vocab: usize,
n_state: usize,
n_head: usize,
n_ctx: usize,
n_layer: usize,
}
impl CLIPConfig {
pub fn init<B: Backend>(&self) -> CLIP<B> {
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init();
let position_embedding = Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into();
let blocks = (0..self.n_layer)
.into_iter()
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init())
.collect();
let layer_norm = nn::LayerNormConfig::new(self.n_state).init();
CLIP {
token_embedding,
position_embedding,
blocks,
layer_norm,
}
}
}
#[derive(Module, Debug)]
pub struct CLIP<B: Backend> {
token_embedding: nn::Embedding<B>,
position_embedding: Param<Tensor<B, 2>>,
blocks: Vec<ResidualDecoderAttentionBlock<B>>,
layer_norm: nn::LayerNorm<B>,
}
impl<B: Backend> CLIP<B> {
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
let [n_batch, seq_len] = x.dims();
let mask = attn_decoder_mask(seq_len);
let embedded = self.token_embedding.forward(x)
+ self.position_embedding.val().slice([0..seq_len]).unsqueeze();
let mut x = embedded;
for block in &self.blocks {
x = block.forward(x, mask.clone());
}
self.layer_norm.forward(x)
}
}
#[derive(Config)]
pub struct ResidualDecoderAttentionBlockConfig {
n_state: usize,
n_head: usize,
}
impl ResidualDecoderAttentionBlockConfig {
pub fn init<B: Backend>(&self) -> ResidualDecoderAttentionBlock<B> {
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init();
let attn_ln = nn::LayerNormConfig::new(self.n_state).init();
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init();
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init();
ResidualDecoderAttentionBlock {
attn,
attn_ln,
mlp,
mlp_ln,
}
}
}
#[derive(Module, Debug)]
pub struct ResidualDecoderAttentionBlock<B: Backend> {
attn: MultiHeadSelfAttention<B>,
attn_ln: nn::LayerNorm<B>,
mlp: MLP<B>,
mlp_ln: nn::LayerNorm<B>,
}
impl<B: Backend> ResidualDecoderAttentionBlock<B> {
fn forward(&self, x: Tensor<B, 3>, mask: Tensor<B, 2>) -> Tensor<B, 3> {
let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), Some(mask));
let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x));
return x;
}
}
#[derive(Config)]
pub struct MultiHeadSelfAttentionConfig {
n_state: usize,
n_head: usize,
}
impl MultiHeadSelfAttentionConfig {
fn init<B: Backend>(&self) -> MultiHeadSelfAttention<B> {
assert!(self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", self.n_state, self.n_head);
let n_head = self.n_head;
let query = nn::LinearConfig::new(self.n_state, self.n_state).init();
let key = nn::LinearConfig::new(self.n_state, self.n_state).init();
let value = nn::LinearConfig::new(self.n_state, self.n_state).init();
let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
MultiHeadSelfAttention {
n_head,
query,
key,
value,
out
}
}
}
#[derive(Module, Debug)]
pub struct MultiHeadSelfAttention<B: Backend> {
n_head: usize,
query: nn::Linear<B>,
key: nn::Linear<B>,
value: nn::Linear<B>,
out: nn::Linear<B>,
}
impl<B: Backend> MultiHeadSelfAttention<B> {
pub fn forward(&self, x: Tensor<B, 3>, mask: Option<Tensor<B, 2>>) -> Tensor<B, 3> {
let q = self.query.forward(x.clone());
let k = self.key.forward(x.clone());
let v = self.value.forward(x);
let wv = qkv_attention(q, k, v, mask, self.n_head);
return self.out.forward(wv);
}
}
#[derive(Config, Debug)]
pub struct MLPConfig {
input_size: usize,
hidden_size: usize,
}
impl MLPConfig {
fn init<B: Backend>(&self) -> MLP<B> {
let fc1 = nn::LinearConfig::new(self.input_size, self.hidden_size).init();
let gelu = QuickGELU::new();
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init();
MLP {
fc1,
gelu,
fc2,
}
}
}
#[derive(Module, Debug)]
pub struct MLP<B: Backend> {
fc1: nn::Linear<B>,
gelu: QuickGELU,
fc2: nn::Linear<B>,
}
impl<B: Backend> MLP<B> {
fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
let x = self.fc1.forward(x);
let x = self.gelu.forward(x);
let x = self.fc2.forward(x);
x
}
}
#[derive(Module, Clone, Debug)]
pub struct QuickGELU {}
impl QuickGELU {
fn new() -> Self {
Self {}
}
fn forward<B: Backend, const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
x.clone() * sigmoid(x * 1.702)
}
}

View File

@@ -0,0 +1,33 @@
use super::GroupNorm;
use crate::model::load::*;
use std::error::Error;
use burn::{
config::Config,
module::{Module, Param},
nn,
tensor::{
backend::Backend,
Tensor,
},
};
pub fn load_group_norm<B: Backend>(path: &str, device: &B::Device) -> Result<GroupNorm<B>, Box<dyn Error>> {
let n_group = load_usize::<B>("n_group", path, device)?.into();
let n_channel = load_usize::<B>("n_channel", path, device)?.into();
let eps = load_f32::<B>("eps", path, device)?.into();
let gamma = load_tensor::<B, 1>("weight", path, device).ok().unwrap_or_else(|| Tensor::ones_device([n_channel], device)).into();
let beta = load_tensor::<B, 1>("bias", path, device).ok().unwrap_or_else(|| Tensor::zeros_device([n_channel], device)).into();
Ok(
GroupNorm {
n_group,
n_channel,
gamma,
beta,
eps,
}
)
}

View File

@@ -0,0 +1,72 @@
pub mod load;
use burn::{
config::Config,
module::{Module, Param},
tensor::{
backend::Backend,
Tensor,
},
};
#[derive(Config)]
pub struct GroupNormConfig {
n_group: usize,
n_channel: usize,
#[config(default = 1e-5)]
eps: f64,
}
impl GroupNormConfig {
pub fn init<B: Backend>(&self) -> GroupNorm<B> {
assert!(self.n_channel % self.n_group == 0, "The number of channels {} must be divisible by the number of groups {}", self.n_channel, self.n_group);
let n_per_group = self.n_channel / self.n_group;
let gamma = Tensor::ones([self.n_channel]).into();
let beta = Tensor::zeros([self.n_channel]).into();
let eps = self.eps;
GroupNorm {
n_group: self.n_group,
n_channel: self.n_channel,
gamma,
beta,
eps,
}
}
}
#[derive(Module, Debug)]
pub struct GroupNorm<B: Backend> {
n_group: usize,
n_channel: usize,
gamma: Param<Tensor<B, 1>>,
beta: Param<Tensor<B, 1>>,
eps: f64,
}
impl<B: Backend> GroupNorm<B> {
pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
let shape = x.shape();
let n_batch = shape.dims[0];
let num_elements = shape.num_elements();
let mut affine_shape = [1; D];
affine_shape[1] = self.n_channel;
layernorm( x.reshape([n_batch, self.n_group, num_elements / (n_batch * self.n_group) ]), self.eps )
.reshape(shape)
.mul(self.gamma.val().reshape(affine_shape))
.add(self.beta.val().reshape(affine_shape))
}
}
pub fn layernorm<B: Backend, const D: usize>(x: Tensor<B, D>, eps: f64) -> Tensor<B, D> {
//let (var, mean) = x.clone().var_mean_bias(D - 1);
//x.sub(mean).div(var.sqrt().add_scalar(eps))
let u = x.clone() - x.mean_dim(D - 1);
u.clone().div( (u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt() )
}

167
src/model/load.rs Normal file
View File

@@ -0,0 +1,167 @@
use std::error::Error;
use std::io::Read;
use npy::{self, NpyData};
use num_traits::cast::ToPrimitive;
use burn::{
config::Config,
module::{Module, Param},
nn::{self, conv},
tensor::{
backend::Backend,
Tensor,
Data,
},
};
use burn::tensor::ElementConversion;
pub fn numpy_to_tensor<B: Backend, const D: usize>(numpy_data: NpyData<f32>, device: &B::Device) -> Tensor<B, D> {
let mut v = numpy_data.to_vec();
let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect();
let data: Vec<B::FloatElem> = v[D..].into_iter().map(|e| e.elem()).collect();
Tensor::from_data_device(Data::new(data, shape.into()), device)
}
pub fn load_tensor<B: Backend, const D: usize>(name: &str, path: &str, device: &B::Device) -> Result<Tensor<B, D>, Box<dyn Error>> {
let tensor_path = format!("{}/{}.npy", path, name);
let mut buf = vec![];
std::fs::File::open(&tensor_path)?
.read_to_end(&mut buf)?;
let tensor_numpy: NpyData<f32> = NpyData::from_bytes(&buf)?;
let tensor = numpy_to_tensor(tensor_numpy, device);
println!("{}", tensor_path);
Ok(tensor)
}
pub fn load_f32<B: Backend>(name: &str, path: &str, device: &B::Device) -> Result<f32, Box<dyn Error>> {
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32().unwrap())
}
pub fn load_usize<B: Backend>(name: &str, path: &str, device: &B::Device) -> Result<usize, Box<dyn Error>> {
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize().unwrap())
}
pub fn load_linear<B: Backend>(path: &str, device: &B::Device) -> Result<nn::Linear<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 2>("weight", path, device)?;
let bias = load_tensor::<B, 1>("bias", path, device).ok();
let record = nn::LinearRecord {
weight: weight.into(),
bias: bias.map(|t| t.into()),
};
let linear: nn::Linear<B> = nn::LinearConfig::new(3, 3).init_with(record);
Ok(linear)
}
pub fn load_embedding<B: Backend>(path: &str, device: &B::Device) -> Result<nn::Embedding<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 2>("weight", path, device)?;
let [n_vocab, n_state] = weight.dims();
let record = nn::EmbeddingRecord {
weight: weight.into(),
};
let embedding = nn::EmbeddingConfig::new(n_vocab, n_state).init_with(record);
Ok(embedding)
}
pub fn load_layer_norm<B: Backend>(path: &str, device: &B::Device) -> Result<nn::LayerNorm<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 1>("weight", path, device)?;
let bias = load_tensor::<B, 1>("bias", path, device)?;
let eps = load_f32::<B>("eps", path, device)? as f64;
let [n_state] = weight.dims();
let record = nn::LayerNormRecord {
gamma: weight.into(),
beta: bias.into(),
epsilon: <f64 as Module<B>>::into_record(eps),
};
let layer_norm: nn::LayerNorm<B> = nn::LayerNormConfig::new(n_state).init_with(record);
Ok(layer_norm)
}
/*pub fn load_rmsnorm<B: Backend>(path: &str, device: &B::Device) -> Result<RMSNorm<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 1>("weight", path, device)?;
let eps = load_f32::<B>("eps", path, device)?.into();
let rmsnorm = RMSNorm {
weight: weight.into(),
eps: eps
};
Ok(rmsnorm)
}*/
pub fn load_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<conv::Conv2d<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 4>("weight", path, device)?;
let bias = load_tensor::<B, 1>("bias", path, device).ok();
let has_bias = bias.is_some();
let stride = load_tensor::<B, 1>("stride", path, device)?;
let stride = tensor_to_array_2(stride);
let kernel_size = load_tensor::<B, 1>("kernel_size", path, device)?;
let kernel_size = tensor_to_array_2(kernel_size);
let dilation = load_tensor::<B, 1>("dilation", path, device)?;
let dilation = tensor_to_array_2(dilation);
let n_group = load_usize::<B>("n_group", path, device)?.into();
let n_channels_in = load_usize::<B>("n_channels_in", path, device)?.into();
let n_channels_out = load_usize::<B>("n_channels_out", path, device)?.into();
let padding = load_tensor::<B, 1>("padding", path, device)?;
let padding = tensor_to_array_2(padding);
let padding = nn::PaddingConfig2d::Explicit(padding[0], padding[1]);
let record = conv::Conv2dRecord {
weight: weight.into(),
bias: bias.map(|t| t.into()),
stride: <[usize; 2] as Module<B>>::into_record(stride),
kernel_size: <[usize; 2] as Module<B>>::into_record(kernel_size),
dilation: <[usize; 2] as Module<B>>::into_record(dilation),
groups: <usize as Module<B>>::into_record(n_group),
padding: <nn::PaddingConfig2d as Module<B>>::into_record(padding.clone()),
};
let conv2d: conv::Conv2d<B> = conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
.with_stride(stride)
.with_dilation(dilation)
.with_groups(n_group)
.with_padding(padding)
.with_bias(has_bias)
.init_with(record);
Ok(conv2d)
}
pub fn tensor_to_array_2<B: Backend>(x: Tensor<B, 1>) -> [usize; 2] {
let vec = x.into_data().value;
assert!(vec.len() == 2, "Tensor length must be 2.");
[vec[0].to_usize().unwrap(), vec[1].to_usize().unwrap()]
}
pub fn tensor_to_array<const N: usize, B: Backend>(x: Tensor<B, 1>) -> [usize; N] {
let vec = x.into_data().value;
assert!(vec.len() == N, "Tensor length must be {}.", N);
let mut arr = [0; N];
for (a, t) in arr.iter_mut().zip(vec) {
*a = t.to_usize().unwrap();
}
arr
}

11
src/model/mod.rs Normal file
View File

@@ -0,0 +1,11 @@
pub mod stablediffusion;
pub mod autoencoder;
pub mod unet;
pub mod clip;
pub mod silu;
pub mod groupnorm;
pub mod attention;
pub mod load;

22
src/model/silu.rs Normal file
View File

@@ -0,0 +1,22 @@
use burn::{
module::Module,
tensor::{
backend::Backend,
activation::sigmoid,
Tensor,
},
};
#[derive(Module, Clone, Debug)]
pub struct SILU {}
impl SILU {
pub fn new() -> Self {
Self {}
}
pub fn forward<B: Backend, const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
x.clone() * sigmoid(x)
}
}

View File

@@ -0,0 +1,34 @@
use std::error::Error;
use burn::tensor::ElementConversion;
use burn::{
config::Config,
module::{Module, Param},
nn,
tensor::{
backend::Backend,
Tensor,
},
};
use super::*;
use crate::model::{load::*, autoencoder::load::load_autoencoder, unet::load::load_unet, clip::load::load_clip};
pub fn load_stable_diffusion<B: Backend>(path: &str, device: &B::Device) -> Result<StableDiffusion<B>, Box<dyn Error>> {
let n_steps = load_usize::<B>("n_steps", path, device)?;
let alpha_cumulative_products: Vec<_> = load_tensor::<B, 1>("alphas_cumprod", path, device)?.into_data().value.into_iter()
.map(|v: <Float as BasicOps<B>>::Elem| v.to_f64().unwrap())
.collect();
let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?;
let diffusion = load_unet(&format!("{}/{}", path, "unet"), device)?;
let clip = load_clip(&format!("{}/{}", path, "clip"), device)?;
Ok(StableDiffusion {
n_steps,
alpha_cumulative_products,
autoencoder,
diffusion,
clip,
})
}

View File

@@ -0,0 +1,181 @@
pub mod load;
use burn::{
config::Config,
module::Module,
tensor::{
backend::Backend,
Tensor,
Int,
Float,
BasicOps,
Data,
Distribution,
},
};
use num_traits::ToPrimitive;
use super::autoencoder::{Autoencoder, AutoencoderConfig};
use super::unet::{UNet, UNetConfig};
use super::clip::{CLIP, CLIPConfig};
use crate::tokenizer::SimpleTokenizer;
#[derive(Config)]
pub struct StableDiffusionConfig {
}
impl StableDiffusionConfig {
fn init<B: Backend>(&self) -> StableDiffusion<B> {
let n_steps = 1000;
let alpha_cumulative_products = offset_cosine_schedule_cumprod::<B>(n_steps)
.into_data().value
.into_iter()
.map(|v: <Float as BasicOps<B>>::Elem| v.to_f64().unwrap()).collect();
let autoencoder = AutoencoderConfig::new().init();
let diffusion = UNetConfig::new().init();
let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init();
StableDiffusion {
n_steps,
alpha_cumulative_products,
autoencoder,
diffusion,
clip,
}
}
}
#[derive(Module, Debug)]
pub struct StableDiffusion<B: Backend> {
n_steps: usize,
alpha_cumulative_products: Vec<f64>,
autoencoder: Autoencoder<B>,
diffusion: UNet<B>,
clip: CLIP<B>,
}
impl<B: Backend> StableDiffusion<B> {
pub fn sample_image(&self, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64, n_steps: usize) -> Vec<Vec<u8>> {
let [n_batch, _, _] = context.dims();
let latent = self.sample_latent(context, unconditional_context, unconditional_guidance_scale, n_steps);
let image = self.autoencoder.decode_latent(latent * (1.0 / 0.18215));
let n_channel = 3;
let height = 512;
let width = 512;
let num_elements_per_image = n_channel * height * width;
// correct size and scale and reorder to
let image = (image + 1.0) / 2.0;
let image = image
.reshape([n_batch, n_channel, height, width])
.swap_dims(1, 2)
.swap_dims(2, 3)
.mul_scalar(255.0);
let flattened: Vec<_> = image.
into_data().
value;
(0..n_batch).into_iter().map(|b| {
let start = b * num_elements_per_image;
let end = start + num_elements_per_image;
flattened[start..end].into_iter().map(|v| v.to_u8().unwrap()).collect()
}).collect()
}
pub fn sample_latent(&self, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64, n_steps: usize) -> Tensor<B, 4> {
assert!(self.n_steps % n_steps == 0);
let step_size = self.n_steps / n_steps;
let [n_batches, _, _] = context.dims();
let gen_noise = || {
Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0) )
};
let sigma = 0.0; // Use deterministic diffusion
let mut latent = gen_noise();
for t in (0..self.n_steps).rev().step_by(step_size) {
let current_alpha = self.alpha_cumulative_products[t];
let prev_alpha = if t >= step_size {
self.alpha_cumulative_products[t - step_size]
} else {
1.0
};
let sqrt_noise = (1.0 - current_alpha).sqrt();
let timestep = Tensor::from_ints([t as i32]);
let pred_noise = self.forward_diffuser(latent.clone(), timestep, context.clone(), unconditional_context.clone(), unconditional_guidance_scale);
let predx0 = (latent - pred_noise.clone() * sqrt_noise) / current_alpha.sqrt();
let dir_latent = pred_noise * (1.0 - prev_alpha - sigma * sigma).sqrt();
let prev_latent = predx0 * prev_alpha.sqrt() + dir_latent + gen_noise() * sigma;
latent = prev_latent;
}
latent
}
fn forward_diffuser(&self, latent: Tensor<B, 4>, timestep: Tensor<B, 1, Int>, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64) -> Tensor<B, 4> {
let [n_batch, n_channel, height, width] = latent.dims();
let latent = latent.repeat(0, 2);
let latent = self.diffusion.forward(
latent.repeat(0, 2),
timestep.repeat(0, 2),
Tensor::cat(vec![unconditional_context.unsqueeze::<3>(), context], 0)
);
let unconditional_latent = latent.clone().slice([0..n_batch]);
let conditional_latent = latent.slice([n_batch..2 * n_batch]);
unconditional_latent.clone() + (conditional_latent - unconditional_latent) * unconditional_guidance_scale
}
pub fn unconditional_context(&self, tokenizer: &SimpleTokenizer) -> Tensor<B, 2> {
self.context(tokenizer, "").squeeze(0)
}
pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor<B, 3> {
let text = format!("<|startoftext|>{}<|endoftext|>", text);
let tokenized: Vec<_> = tokenizer.encode(&text).into_iter().map(|v| v as i32).collect();
self.clip.forward(Tensor::from_ints(&tokenized[..]).unsqueeze())
}
}
use crate::helper::to_float;
use std::f64::consts::PI;
fn cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
to_float(Tensor::arange(1..n_steps + 1))
.mul_scalar(PI * 0.5 / n_steps as f64)
.cos()
}
fn offset_cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
let min_signal_rate: f64 = 0.02;
let max_signal_rate: f64 = 0.95;
let start_angle = max_signal_rate.acos();
let end_angle = min_signal_rate.acos();
let times = Tensor::arange(1..n_steps + 1);
let diffusion_angles = to_float(times) * ( (end_angle - start_angle) / n_steps as f64) + start_angle;
diffusion_angles.cos()
}
fn offset_cosine_schedule_cumprod<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
offset_cosine_schedule::<B>(n_steps).powf(2.0)
}

278
src/model/unet/load.rs Normal file
View File

@@ -0,0 +1,278 @@
use super::GroupNorm;
use crate::model::load::*;
use std::error::Error;
use burn::{
config::Config,
module::{Module, Param},
nn,
tensor::{
backend::Backend,
Tensor,
},
};
use super::*;
use crate::model::groupnorm::load::load_group_norm;
pub fn load_res_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResBlock<B>, Box<dyn Error>> {
let norm_in = load_group_norm::<B>(&format!("{}/{}", path, "norm_in"), device)?;
let conv_in = load_conv2d::<B>(&format!("{}/{}", path, "conv_in"), device)?;
let lin_embed = load_linear::<B>(&format!("{}/{}", path, "lin_embed"), device)?;
let norm_out = load_group_norm::<B>(&format!("{}/{}", path, "norm_out"), device)?;
let conv_out = load_conv2d::<B>(&format!("{}/{}", path, "conv_out"), device)?;
let skip_connection = load_conv2d::<B>(&format!("{}/{}", path, "skip_connection"), device).ok();
let res_block = ResBlock {
norm_in: norm_in,
silu_in: SILU::new(),
conv_in: conv_in,
silu_embed: SILU::new(),
lin_embed: lin_embed,
norm_out: norm_out,
silu_out: SILU::new(),
conv_out: conv_out,
skip_connection: skip_connection,
};
Ok(res_block)
}
pub fn load_multi_head_attention<B: Backend>(path: &str, device: &B::Device) -> Result<MultiHeadAttention<B>, Box<dyn Error>> {
let n_head = load_usize::<B>("n_head", path, device)?;
let query = load_linear::<B>(&format!("{}/{}", path, "query"), device)?;
let key = load_linear::<B>(&format!("{}/{}", path, "key"), device)?;
let value = load_linear::<B>(&format!("{}/{}", path, "value"), device)?;
let out = load_linear::<B>(&format!("{}/{}", path, "out"), device)?;
let multi_head_attention = MultiHeadAttention {
n_head: n_head,
query: query,
key: key,
value: value,
out: out,
};
Ok(multi_head_attention)
}
pub fn load_geglu<B: Backend>(path: &str, device: &B::Device) -> Result<GEGLU<B>, Box<dyn Error>> {
let proj = load_linear::<B>(&format!("{}/{}", path, "proj"), device)?;
let geglue = GEGLU {
proj: proj,
gelu: GELU::new(), // Assuming GELU::new() initializes a new GELU struct
};
Ok(geglue)
}
pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Box<dyn Error>> {
let geglu = load_geglu::<B>(&format!("{}/{}", path, "geglu"), device)?;
let lin = load_linear::<B>(&format!("{}/{}", path, "lin"), device)?;
let mlp = MLP {
geglu: geglu,
lin: lin,
};
Ok(mlp)
}
pub fn load_transformer_block<B: Backend>(path: &str, device: &B::Device) -> Result<TransformerBlock<B>, Box<dyn Error>> {
let norm1 = load_layer_norm::<B>(&format!("{}/{}", path, "norm1"), device)?;
let attn1 = load_multi_head_attention::<B>(&format!("{}/{}", path, "attn1"), device)?;
let norm2 = load_layer_norm::<B>(&format!("{}/{}", path, "norm2"), device)?;
let attn2 = load_multi_head_attention::<B>(&format!("{}/{}", path, "attn2"), device)?;
let norm3 = load_layer_norm::<B>(&format!("{}/{}", path, "norm3"), device)?;
let mlp = load_mlp::<B>(&format!("{}/{}", path, "mlp"), device)?;
let transformer_block = TransformerBlock {
norm1: norm1,
attn1: attn1,
norm2: norm2,
attn2: attn2,
norm3: norm3,
mlp: mlp,
};
Ok(transformer_block)
}
pub fn load_spatial_transformer<B: Backend>(path: &str, device: &B::Device) -> Result<SpatialTransformer<B>, Box<dyn Error>> {
let norm = load_group_norm::<B>(&format!("{}/{}", path, "norm"), device)?;
let proj_in = load_conv2d::<B>(&format!("{}/{}", path, "proj_in"), device)?;
let transformer = load_transformer_block::<B>(&format!("{}/{}", path, "transformer"), device)?;
let proj_out = load_conv2d::<B>(&format!("{}/{}", path, "proj_out"), device)?;
let spatial_transformer = SpatialTransformer {
norm: norm,
proj_in: proj_in,
transformer: transformer,
proj_out: proj_out,
};
Ok(spatial_transformer)
}
pub fn load_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<Upsample<B>, Box<dyn Error>> {
let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?;
let upsample = Upsample {
conv: conv,
};
Ok(upsample)
}
pub fn load_downsample<B: Backend>(path: &str, device: &B::Device) -> Result<Downsample<B>, Box<dyn Error>> {
load_conv2d(path, device)
}
pub fn load_res_transformer_res<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformerRes<B>, Box<dyn Error>> {
let res1 = load_res_block::<B>(&format!("{}/{}", path, "res1"), device)?; // Assuming load_res_block function
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
let res2 = load_res_block::<B>(&format!("{}/{}", path, "res2"), device)?;
let res_transformer_res = ResTransformerRes {
res1: res1,
transformer: transformer,
res2: res2,
};
Ok(res_transformer_res)
}
pub fn load_res_transformer_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformerUpsample<B>, Box<dyn Error>> {
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?;
let res_transformer_upsample = ResTransformerUpsample {
res: res,
transformer: transformer,
upsample: upsample,
};
Ok(res_transformer_upsample)
}
pub fn load_res_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<ResUpSample<B>, Box<dyn Error>> {
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?;
let res_upsample = ResUpSample {
res: res,
upsample: upsample,
};
Ok(res_upsample)
}
pub fn load_res_transformer<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformer<B>, Box<dyn Error>> {
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
let res_transformer = ResTransformer {
res: res,
transformer: transformer,
};
Ok(res_transformer)
}
pub fn load_unet_input_blocks<B: Backend>(path: &str, device: &B::Device) -> Result<UNetInputBlocks<B>, Box<dyn Error>> {
let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?;
let rt1 = load_res_transformer::<B>(&format!("{}/{}", path, "rt1"), device)?;
let rt2 = load_res_transformer::<B>(&format!("{}/{}", path, "rt2"), device)?;
let d1 = load_downsample::<B>(&format!("{}/{}", path, "d1"), device)?;
let rt3 = load_res_transformer::<B>(&format!("{}/{}", path, "rt3"), device)?;
let rt4 = load_res_transformer::<B>(&format!("{}/{}", path, "rt4"), device)?;
let d2 = load_downsample::<B>(&format!("{}/{}", path, "d2"), device)?;
let rt5 = load_res_transformer::<B>(&format!("{}/{}", path, "rt5"), device)?;
let rt6 = load_res_transformer::<B>(&format!("{}/{}", path, "rt6"), device)?;
let d3 = load_downsample::<B>(&format!("{}/{}", path, "d3"), device)?;
let r1 = load_res_block::<B>(&format!("{}/{}", path, "r1"), device)?;
let r2 = load_res_block::<B>(&format!("{}/{}", path, "r2"), device)?;
let unet_input_blocks = UNetInputBlocks {
conv: conv,
rt1: rt1,
rt2: rt2,
d1: d1,
rt3: rt3,
rt4: rt4,
d2: d2,
rt5: rt5,
rt6: rt6,
d3: d3,
r1: r1,
r2: r2,
};
Ok(unet_input_blocks)
}
pub fn load_unet_output_blocks<B: Backend>(path: &str, device: &B::Device) -> Result<UNetOutputBlocks<B>, Box<dyn Error>> {
let r1 = load_res_block::<B>(&format!("{}/{}", path, "r1"), device)?;
let r2 = load_res_block::<B>(&format!("{}/{}", path, "r2"), device)?;
let ru = load_res_upsample::<B>(&format!("{}/{}", path, "ru"), device)?;
let rt1 = load_res_transformer::<B>(&format!("{}/{}", path, "rt1"), device)?;
let rt2 = load_res_transformer::<B>(&format!("{}/{}", path, "rt2"), device)?;
let rtu1 = load_res_transformer_upsample::<B>(&format!("{}/{}", path, "rtu1"), device)?;
let rt3 = load_res_transformer::<B>(&format!("{}/{}", path, "rt3"), device)?;
let rt4 = load_res_transformer::<B>(&format!("{}/{}", path, "rt4"), device)?;
let rtu2 = load_res_transformer_upsample::<B>(&format!("{}/{}", path, "rtu2"), device)?;
let rt5 = load_res_transformer::<B>(&format!("{}/{}", path, "rt5"), device)?;
let rt6 = load_res_transformer::<B>(&format!("{}/{}", path, "rt6"), device)?;
let rt7 = load_res_transformer::<B>(&format!("{}/{}", path, "rt7"), device)?;
Ok(UNetOutputBlocks {
r1,
r2,
ru,
rt1,
rt2,
rtu1,
rt3,
rt4,
rtu2,
rt5,
rt6,
rt7,
})
}
pub fn load_unet<B: Backend>(path: &str, device: &B::Device) -> Result<UNet<B>, Box<dyn Error>> {
let lin1_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin1_time_embed"), device)?;
let silu_time_embed = SILU::new(); // Assuming SILU::new() initializes a new SILU struct
let lin2_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin2_time_embed"), device)?;
let input_blocks = load_unet_input_blocks::<B>(&format!("{}/{}", path, "input_blocks"), device)?;
let middle_block = load_res_transformer_res::<B>(&format!("{}/{}", path, "middle_block"), device)?;
let output_blocks = load_unet_output_blocks::<B>(&format!("{}/{}", path, "output_blocks"), device)?;
let norm_out = load_group_norm::<B>(&format!("{}/{}", path, "norm_out"), device)?;
let silu_out = SILU::new(); // Assuming SILU::new() initializes a new SILU struct
let conv_out = load_conv2d::<B>(&format!("{}/{}", path, "conv_out"), device)?;
Ok(UNet {
lin1_time_embed,
silu_time_embed,
lin2_time_embed,
input_blocks,
middle_block,
output_blocks,
norm_out,
silu_out,
conv_out,
})
}

757
src/model/unet/mod.rs Normal file
View File

@@ -0,0 +1,757 @@
pub mod load;
use burn::{
config::Config,
module::{Module, Param},
nn::{self, PaddingConfig2d, GELU, conv::{Conv2d, Conv2dConfig}},
tensor::{
backend::Backend,
activation::softmax,
module::embedding,
Tensor,
Distribution,
Int,
},
};
use super::silu::*;
use super::groupnorm::*;
use crate::helper::to_float;
use super::attention::qkv_attention;
fn timestep_embedding<B: Backend>(timesteps: Tensor<B, 1, Int>, dim: usize, max_period: usize) -> Tensor<B, 2> {
let half = dim / 2;
let freqs = ( to_float(Tensor::arange_device(0..half, &timesteps.device())) * (-(max_period as f64).ln() / half as f64 ) ).exp();
let args = to_float(timesteps) * freqs;
Tensor::cat(vec![args.clone().cos(), args.sin()], 0).unsqueeze()
}
#[derive(Config)]
pub struct UNetConfig {}
impl UNetConfig {
pub fn init<B: Backend>(&self) -> UNet<B> {
let lin1_time_embed = nn::LinearConfig::new(320, 1280).init();
let silu_time_embed = SILU::new();
let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init();
let input_blocks = UNetInputBlocks {
conv: Conv2dConfig::new([4, 320], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(),
rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(),
rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(),
d1: DownsampleConfig::new(320).init(),
rt3: ResTransformerConfig::new(320, 1280, 640, 768, 8).init(),
rt4: ResTransformerConfig::new(640, 1280, 640, 768, 8).init(),
d2: DownsampleConfig::new(640).init(),
rt5: ResTransformerConfig::new(640, 1280, 1280, 768, 8).init(),
rt6: ResTransformerConfig::new(1280, 1280, 1280, 768, 8).init(),
d3: DownsampleConfig::new(1280).init(),
r1: ResBlockConfig::new(1280, 1280, 1280).init(),
r2: ResBlockConfig::new(1280, 1280, 1280).init(),
};
let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init();
let output_blocks = UNetOutputBlocks {
r1: ResBlockConfig::new(2560, 1280, 1280).init(),
r2: ResBlockConfig::new(2560, 1280, 1280).init(),
ru: ResUpSampleConfig::new(2560, 1280, 1280).init(),
rt1: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(),
rt2: ResTransformerConfig::new(2560, 1280, 1280, 768, 8).init(),
rtu1: ResTransformerUpsampleConfig::new(1920, 1280, 1280, 768, 8).init(),
rt3: ResTransformerConfig::new(1920, 1280, 640, 768, 8).init(),
rt4: ResTransformerConfig::new(1280, 1280, 640, 768, 8).init(),
rtu2: ResTransformerUpsampleConfig::new(960, 1280, 640, 768, 8).init(),
rt5: ResTransformerConfig::new(960, 1280, 320, 768, 8).init(),
rt6: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(),
rt7: ResTransformerConfig::new(640, 1280, 320, 768, 8).init(),
};
let norm_out = GroupNormConfig::new(32, 320).init();
let silu_out = SILU::new();
let conv_out = Conv2dConfig::new([320, 4], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
UNet {
lin1_time_embed,
silu_time_embed,
lin2_time_embed,
input_blocks,
middle_block,
output_blocks,
norm_out,
silu_out,
conv_out,
}
}
}
#[derive(Module, Debug)]
pub struct UNet<B: Backend> {
lin1_time_embed: nn::Linear<B>,
silu_time_embed: SILU,
lin2_time_embed: nn::Linear<B>,
input_blocks: UNetInputBlocks<B>,
middle_block: ResTransformerRes<B>,
output_blocks: UNetOutputBlocks<B>,
norm_out: GroupNorm<B>,
silu_out: SILU,
conv_out: Conv2d<B>,
}
impl<B: Backend> UNet<B> {
pub fn forward(&self, x: Tensor<B, 4>, timesteps: Tensor<B, 1, Int>, context: Tensor<B, 3>) -> Tensor<B, 4> {
let t_emb = timestep_embedding(timesteps, 320, 10000);
let emb = self.lin1_time_embed.forward(t_emb);
let emb = self.silu_time_embed.forward(emb);
let emb = self.lin2_time_embed.forward(emb);
let mut saved_inputs = Vec::new();
let mut x = x;
// input blocks
for block in self.input_blocks.as_array() {
println!("{:?}", x.clone().flatten::<1>(0, 3).slice([0..100]).into_data());
x = block.forward(x, emb.clone(), context.clone());
saved_inputs.push(x.clone())
}
// middle block
x = self.middle_block.forward(x, emb.clone(), context.clone());
// output blocks
for block in self.output_blocks.as_array() {
x = Tensor::cat(vec![x, saved_inputs.pop().unwrap()], 1);
x = block.forward(x, emb.clone(), context.clone());
}
let x = self.norm_out.forward(x);
let x = self.silu_out.forward(x);
let x = self.conv_out.forward(x);
x
}
}
#[derive(Module, Debug)]
pub struct UNetInputBlocks<B: Backend> {
conv: Conv2d<B>,
rt1: ResTransformer<B>,
rt2: ResTransformer<B>,
d1: Downsample<B>,
rt3: ResTransformer<B>,
rt4: ResTransformer<B>,
d2: Downsample<B>,
rt5: ResTransformer<B>,
rt6: ResTransformer<B>,
d3: Downsample<B>,
r1: ResBlock<B>,
r2: ResBlock<B>,
}
impl<B: Backend> UNetInputBlocks<B> {
fn as_array(&self) -> [&dyn UNetBlock<B>; 12] {
[
&self.conv,
&self.rt1,
&self.rt2,
&self.d1,
&self.rt3,
&self.rt4,
&self.d2,
&self.rt5,
&self.rt6,
&self.d3,
&self.r1,
&self.r2,
]
}
}
#[derive(Module, Debug)]
pub struct UNetOutputBlocks<B: Backend> {
r1: ResBlock<B>,
r2: ResBlock<B>,
ru: ResUpSample<B>,
rt1: ResTransformer<B>,
rt2: ResTransformer<B>,
rtu1: ResTransformerUpsample<B>,
rt3: ResTransformer<B>,
rt4: ResTransformer<B>,
rtu2: ResTransformerUpsample<B>,
rt5: ResTransformer<B>,
rt6: ResTransformer<B>,
rt7: ResTransformer<B>,
}
impl<B: Backend> UNetOutputBlocks<B> {
fn as_array(&self) -> [&dyn UNetBlock<B>; 12] {
[
&self.r1,
&self.r2,
&self.ru,
&self.rt1,
&self.rt2,
&self.rtu1,
&self.rt3,
&self.rt4,
&self.rtu2,
&self.rt5,
&self.rt6,
&self.rt7,
]
}
}
trait UNetBlock<B: Backend> {
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4>;
}
#[derive(Config)]
pub struct ResTransformerConfig {
n_channels_in: usize,
n_channels_embed: usize,
n_channels_out: usize,
n_context_state: usize,
n_head: usize,
}
impl ResTransformerConfig {
fn init<B: Backend>(&self) -> ResTransformer<B> {
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init();
ResTransformer {
res,
transformer,
}
}
}
#[derive(Module, Debug)]
pub struct ResTransformer<B: Backend> {
res: ResBlock<B>,
transformer: SpatialTransformer<B>,
}
impl<B: Backend> UNetBlock<B> for ResTransformer<B> {
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4> {
let x = self.res.forward(x, emb);
let x = self.transformer.forward(x, context);
x
}
}
#[derive(Config)]
pub struct ResUpSampleConfig {
n_channels_in: usize,
n_channels_embed: usize,
n_channels_out: usize,
}
impl ResUpSampleConfig {
fn init<B: Backend>(&self) -> ResUpSample<B> {
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
let upsample = UpsampleConfig::new(self.n_channels_out).init();
ResUpSample {
res,
upsample,
}
}
}
#[derive(Module, Debug)]
pub struct ResUpSample<B: Backend> {
res: ResBlock<B>,
upsample: Upsample<B>,
}
impl<B: Backend> UNetBlock<B> for ResUpSample<B> {
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4> {
let x = self.res.forward(x, emb);
let x = self.upsample.forward(x);
x
}
}
#[derive(Config)]
pub struct ResTransformerUpsampleConfig {
n_channels_in: usize,
n_channels_embed: usize,
n_channels_out: usize,
n_context_state: usize,
n_head: usize,
}
impl ResTransformerUpsampleConfig {
fn init<B: Backend>(&self) -> ResTransformerUpsample<B> {
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init();
let upsample = UpsampleConfig::new(self.n_channels_out).init();
ResTransformerUpsample {
res,
transformer,
upsample,
}
}
}
#[derive(Module, Debug)]
pub struct ResTransformerUpsample<B: Backend> {
res: ResBlock<B>,
transformer: SpatialTransformer<B>,
upsample: Upsample<B>,
}
impl<B: Backend> UNetBlock<B> for ResTransformerUpsample<B> {
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4> {
let x = self.res.forward(x, emb);
let x = self.transformer.forward(x, context);
let x = self.upsample.forward(x);
x
}
}
#[derive(Config)]
pub struct ResTransformerResConfig {
n_channels_in: usize,
n_channels_embed: usize,
n_channels_out: usize,
n_context_state: usize,
n_head: usize,
}
impl ResTransformerResConfig {
fn init<B: Backend>(&self) -> ResTransformerRes<B> {
let res1 = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init();
let res2 = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init();
ResTransformerRes {
res1,
transformer,
res2,
}
}
}
#[derive(Module, Debug)]
pub struct ResTransformerRes<B: Backend> {
res1: ResBlock<B>,
transformer: SpatialTransformer<B>,
res2: ResBlock<B>,
}
impl<B: Backend> UNetBlock<B> for ResTransformerRes<B> {
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4> {
let x = self.res1.forward(x, emb.clone());
let x = self.transformer.forward(x, context);
let x = self.res2.forward(x, emb);
x
}
}
#[derive(Config)]
pub struct UpsampleConfig {
n_channels: usize,
}
impl UpsampleConfig {
fn init<B: Backend>(&self) -> Upsample<B> {
let conv = Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
.with_stride([2, 2])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
Upsample {
conv,
}
}
}
#[derive(Module, Debug)]
pub struct Upsample<B: Backend> {
conv: Conv2d<B>,
}
impl<B: Backend> Upsample<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let [n_batch, n_channel, height, width] = x.dims();
let x = x
.reshape([n_batch, n_channel, height, 1, width, 1])
.repeat(3, 2)
.repeat(5, 2)
.reshape([n_batch, n_channel, 2 * height, 2 * width]);
self.conv.forward(x)
}
}
impl<B: Backend> UNetBlock<B> for Upsample<B> {
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4> {
self.forward(x)
}
}
#[derive(Config)]
pub struct DownsampleConfig {
n_channels: usize,
}
impl DownsampleConfig {
fn init<B: Backend>(&self) -> Conv2d<B> {
Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
.with_stride([2, 2])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init()
}
}
type Downsample<B> = Conv2d<B>;
impl<B: Backend> UNetBlock<B> for Conv2d<B> {
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4> {
self.forward(x)
}
}
#[derive(Config)]
pub struct SpatialTransformerConfig {
n_channels: usize,
n_context_state: usize,
n_head: usize,
}
impl SpatialTransformerConfig {
fn init<B: Backend>(&self) -> SpatialTransformer<B> {
let norm = GroupNormConfig::new(32, self.n_channels).init();
let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init();
let transformer = TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init();
let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init();
SpatialTransformer {
norm,
proj_in,
transformer,
proj_out,
}
}
}
#[derive(Module, Debug)]
pub struct SpatialTransformer<B: Backend> {
norm: GroupNorm<B>,
proj_in: Conv2d<B>,
transformer: TransformerBlock<B>,
proj_out: Conv2d<B>,
}
impl<B: Backend> SpatialTransformer<B> {
fn forward(&self, x: Tensor<B, 4>, context: Tensor<B, 3>) -> Tensor<B, 4> {
let [n_batch, n_channel, height, width] = x.dims();
let x_in = x.clone();
let x = self.norm.forward(x);
let x = self.proj_in.forward(x);
let x = x.reshape([n_batch, n_channel, height * width]).swap_dims(1, 2);
let x = self.transformer.forward(x, context)
.swap_dims(1, 2)
.reshape([n_batch, n_channel, height, width]);
x_in + self.proj_out.forward(x)
}
}
#[derive(Config)]
pub struct TransformerBlockConfig {
n_state: usize,
n_context_state: usize,
n_head: usize,
}
impl TransformerBlockConfig {
fn init<B: Backend>(&self) -> TransformerBlock<B> {
let norm1 = nn::LayerNormConfig::new(self.n_state).init();
let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init();
let norm2 = nn::LayerNormConfig::new(self.n_state).init();
let attn2 = MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init();
let norm3 = nn::LayerNormConfig::new(self.n_state).init();
let mlp = MLPConfig::new(self.n_state, 4).init();
TransformerBlock {
norm1,
attn1,
norm2,
attn2,
norm3,
mlp,
}
}
}
#[derive(Module, Debug)]
pub struct TransformerBlock<B: Backend> {
norm1: nn::LayerNorm<B>,
attn1: MultiHeadAttention<B>,
norm2: nn::LayerNorm<B>,
attn2: MultiHeadAttention<B>,
norm3: nn::LayerNorm<B>,
mlp: MLP<B>,
}
impl<B: Backend> TransformerBlock<B> {
fn forward(&self, x: Tensor<B, 3>, context: Tensor<B, 3>) -> Tensor<B, 3> {
let x = x.clone() + self.attn1.forward( self.norm1.forward(x), None);
let x = x.clone() + self.attn2.forward( self.norm2.forward(x), Some(context));
x.clone() + self.mlp.forward( self.norm3.forward(x) )
}
}
#[derive(Config)]
pub struct MLPConfig {
n_state: usize,
mult: usize,
}
impl MLPConfig {
pub fn init<B: Backend>(&self) -> MLP<B> {
let n_state_hidden = self.n_state * self.mult;
let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init();
let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init();
MLP {
geglu,
lin,
}
}
}
#[derive(Module, Debug)]
pub struct MLP<B: Backend> {
geglu: GEGLU<B>,
lin: nn::Linear<B>,
}
impl<B: Backend> MLP<B> {
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
self.lin.forward( self.geglu.forward(x) )
}
}
#[derive(Config)]
pub struct GEGLUConfig {
n_state_in: usize,
n_state_out: usize,
}
impl GEGLUConfig {
fn init<B: Backend>(&self) -> GEGLU<B> {
let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init();
let gelu = GELU::new();
GEGLU {
proj,
gelu,
}
}
}
#[derive(Module, Debug)]
pub struct GEGLU<B: Backend> {
proj: nn::Linear<B>,
gelu: GELU,
}
impl<B: Backend> GEGLU<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
let projected = self.proj.forward(x);
let [n_batch, n_ctx, n_state] = projected.dims();
let n_state_out = n_state / 2;
let x = projected.clone().slice([0..n_batch, 0..n_ctx, 0..n_state_out]);
let gate = projected.slice([0..n_batch, 0..n_ctx, n_state_out..n_state]);
x * self.gelu.forward(gate)
}
}
#[derive(Config)]
pub struct MultiHeadAttentionConfig {
n_state: usize,
n_context_state: usize,
n_head: usize,
}
impl MultiHeadAttentionConfig {
fn init<B: Backend>(&self) -> MultiHeadAttention<B> {
assert!(self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", self.n_state, self.n_head);
let n_head = self.n_head;
let query = nn::LinearConfig::new(self.n_state, self.n_state).with_bias(false).init();
let key = nn::LinearConfig::new(self.n_context_state, self.n_state).with_bias(false).init();
let value = nn::LinearConfig::new(self.n_context_state, self.n_state).with_bias(false).init();
let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
MultiHeadAttention {
n_head,
query,
key,
value,
out
}
}
}
#[derive(Module, Debug)]
pub struct MultiHeadAttention<B: Backend> {
n_head: usize,
query: nn::Linear<B>,
key: nn::Linear<B>,
value: nn::Linear<B>,
out: nn::Linear<B>,
}
impl<B: Backend> MultiHeadAttention<B> {
pub fn forward(&self, x: Tensor<B, 3>, context: Option<Tensor<B, 3>>) -> Tensor<B, 3> {
let xa = context.unwrap_or_else(|| x.clone());
let q = self.query.forward(x);
let k = self.key.forward(xa.clone());
let v = self.value.forward(xa);
let wv = qkv_attention(q, k, v, None, self.n_head);
self.out.forward(wv)
}
}
#[derive(Config)]
pub struct ResBlockConfig {
n_channels_in: usize,
n_channels_embed: usize,
n_channels_out: usize,
}
impl ResBlockConfig {
fn init<B: Backend>(&self) -> ResBlock<B> {
let norm_in = GroupNormConfig::new(32, self.n_channels_in).init();
let silu_in = SILU::new();
let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
let silu_embed = SILU::new();
let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init();
let norm_out = GroupNormConfig::new(32, self.n_channels_out).init();
let silu_out = SILU::new();
let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init();
let skip_connection = if self.n_channels_in != self.n_channels_out {
Some( Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init() )
} else {
None
};
ResBlock {
norm_in,
silu_in,
conv_in,
silu_embed,
lin_embed,
norm_out,
silu_out,
conv_out,
skip_connection,
}
}
}
#[derive(Module, Debug)]
pub struct ResBlock<B: Backend> {
norm_in: GroupNorm<B>,
silu_in: SILU,
conv_in: Conv2d<B>,
silu_embed: SILU,
lin_embed: nn::Linear<B>,
norm_out: GroupNorm<B>,
silu_out: SILU,
conv_out: Conv2d<B>,
skip_connection: Option<Conv2d<B>>,
}
impl<B: Backend> ResBlock<B> {
fn forward(&self, x: Tensor<B, 4>, embed: Tensor<B, 2>) -> Tensor<B, 4> {
let h = self.norm_in.forward(x.clone());
let h = self.silu_in.forward(h);
let h = self.conv_in.forward(h);
let embed_out = self.silu_embed.forward(embed);
let embed_out = self.lin_embed.forward(embed_out);
let [n_batch_embed, n_state_embed] = embed_out.dims();
let h = h + embed_out.reshape([n_batch_embed, n_state_embed, 1, 1]);
let h = self.norm_out.forward(h);
let h = self.silu_out.forward(h);
let h = self.conv_out.forward(h);
if let Some(skipc) = self.skip_connection.as_ref() {
skipc.forward(x) + h
} else {
x + h
}
}
}
impl<B: Backend> UNetBlock<B> for ResBlock<B> {
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4> {
self.forward(x, emb)
}
}