Replace helper functions with native burn functions

This commit is contained in:
Gadersd
2023-09-07 12:23:18 -04:00
committed by Ben_Kosytorz
parent 167e45fc30
commit 32a3ad9b3c
20 changed files with 1091 additions and 950 deletions

BIN
img0.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 671 KiB

After

Width:  |  Height:  |  Size: 677 KiB

View File

@@ -1,26 +1,27 @@
use std::env; use std::env;
use std::process;
use std::error::Error; use std::error::Error;
use std::process;
use stablediffusion::model::stablediffusion::{StableDiffusion, load::load_stable_diffusion}; use stablediffusion::model::stablediffusion::{load::load_stable_diffusion, StableDiffusion};
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn, nn,
tensor::{ tensor::{backend::Backend, Tensor},
backend::Backend,
Tensor,
},
}; };
use burn_ndarray::{NdArrayBackend, NdArrayDevice}; use burn_ndarray::{NdArrayBackend, NdArrayDevice};
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings}; use burn::record::{self, BinFileRecorder, FullPrecisionSettings, Recorder};
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> { fn convert_dump_to_model<B: Backend>(
dump_path: &str,
model_name: &str,
device: &B::Device,
) -> Result<(), Box<dyn Error>> {
println!("Loading dump..."); println!("Loading dump...");
let model: StableDiffusion::<B> = load_stable_diffusion(dump_path, device)?; let model: StableDiffusion<B> = load_stable_diffusion(dump_path, device)?;
println!("Saving model..."); println!("Saving model...");
save_model_file(model, model_name)?; save_model_file(model, model_name)?;
@@ -28,12 +29,11 @@ fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device:
Ok(()) Ok(())
} }
fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<(), record::RecorderError> { fn save_model_file<B: Backend>(
BinFileRecorder::<FullPrecisionSettings>::new() model: StableDiffusion<B>,
.record( name: &str,
model.into_record(), ) -> Result<(), record::RecorderError> {
name.into(), BinFileRecorder::<FullPrecisionSettings>::new().record(model.into_record(), name.into())
)
} }
fn main() { fn main() {

View File

@@ -1,13 +1,13 @@
use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::{*, load::load_stable_diffusion}}; use stablediffusion::{
model::stablediffusion::{load::load_stable_diffusion, *},
tokenizer::SimpleTokenizer,
};
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn, nn,
tensor::{ tensor::{backend::Backend, Tensor},
backend::Backend,
Tensor,
},
}; };
cfg_if::cfg_if! { cfg_if::cfg_if! {
@@ -22,9 +22,11 @@ use std::env;
use std::io; use std::io;
use std::process; use std::process;
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings}; use burn::record::{self, BinFileRecorder, FullPrecisionSettings, Recorder};
fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> { fn load_stable_diffusion_model_file<B: Backend>(
filename: &str,
) -> Result<StableDiffusion<B>, record::RecorderError> {
BinFileRecorder::<FullPrecisionSettings>::new() BinFileRecorder::<FullPrecisionSettings>::new()
.load(filename.into()) .load(filename.into())
.map(|record| StableDiffusionConfig::new().init().load_record(record)) .map(|record| StableDiffusionConfig::new().init().load_record(record))
@@ -78,17 +80,22 @@ fn main() {
let sd = sd.to_device(&device); let sd = sd.to_device(&device);
let unconditional_context = sd.unconditional_context(&tokenizer); let unconditional_context = sd.unconditional_context(&tokenizer);
let context = sd.context(&tokenizer, prompt).unsqueeze::<3>();//.repeat(0, 2); // generate 2 samples let context = sd.context(&tokenizer, prompt).unsqueeze::<3>(); //.repeat(0, 2); // generate 2 samples
println!("Sampling image..."); println!("Sampling image...");
let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps); let images = sd.sample_image(
context,
unconditional_context,
unconditional_guidance_scale,
n_steps,
);
save_images(&images, output_image_name, 512, 512).unwrap_or_else(|err| { save_images(&images, output_image_name, 512, 512).unwrap_or_else(|err| {
eprintln!("Error saving image: {}", err); eprintln!("Error saving image: {}", err);
process::exit(1); process::exit(1);
}); });
} }
use image::{self, ImageResult, ColorType::Rgb8}; use image::{self, ColorType::Rgb8, ImageResult};
fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -> ImageResult<()> { fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -> ImageResult<()> {
for (index, img_data) in images.iter().enumerate() { for (index, img_data) in images.iter().enumerate() {
@@ -103,12 +110,15 @@ fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -
fn save_test_image() -> ImageResult<()> { fn save_test_image() -> ImageResult<()> {
let width = 256; let width = 256;
let height = 256; let height = 256;
let raw: Vec<_> = (0..width * height).into_iter().flat_map(|i| { let raw: Vec<_> = (0..width * height)
.into_iter()
.flat_map(|i| {
let row = i / width; let row = i / width;
let red = (255.0 * row as f64 / height as f64) as u8; let red = (255.0 * row as f64 / height as f64) as u8;
[red, 0, 0] [red, 0, 0]
}).collect(); })
.collect();
image::save_buffer("red.png", &raw[..], width, height, Rgb8) image::save_buffer("red.png", &raw[..], width, height, Rgb8)
} }

View File

@@ -1,87 +0,0 @@
use burn::{
tensor::{
backend::Backend,
activation::relu,
Tensor,
Int,
Bool,
Float,
TensorKind,
BasicOps,
Numeric,
Element,
},
};
use num_traits::ToPrimitive;
pub fn tensor_max_scalar<B: Backend, const D: usize>(x: Tensor<B, D>, max: f64) -> Tensor<B, D> {
relu(x.sub_scalar(max)).add_scalar(max)
}
pub fn tensor_min_scalar<B: Backend, const D: usize>(x: Tensor<B, D>, min: f64) -> Tensor<B, D> {
-tensor_max_scalar(-x, -min)
}
pub fn tensor_max<B: Backend, const D: usize>(x: Tensor<B, D>, max: Tensor<B, D>) -> Tensor<B, D> {
relu(x - max.clone()) + max
}
pub fn tensor_min<B: Backend, const D: usize>(x: Tensor<B, D>, min: Tensor<B, D>) -> Tensor<B, D> {
-tensor_max(-x, -min)
}
pub fn tensor_log10<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
let ln10 = (10.0f64).ln();
x.log() / ln10
}
pub fn tensor_max_element<B: Backend, const D: usize>(x: Tensor<B, D>) -> f64 {
let flat: Tensor<B, 1> = x.flatten(0, D - 1);
let max_index = flat.clone().argmax(0);
flat.select(0, max_index).into_scalar().to_f64().unwrap()
}
pub fn all_zeros<B: Backend, const D: usize>(x: Tensor<B, D>) -> bool {
x.powf(2.0).sum().into_scalar().to_f64().unwrap() == 0.0
}
pub fn max_dim<B: Backend>(x: Tensor<B, 2>, dim: usize) -> Tensor<B, 2> {
let indices = x.clone().argmax(dim).flatten(0, 1);
x.select(dim, indices)
}
pub fn _10pow<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
let log10 = (10.0f64).ln();
(x * log10).exp()
}
pub fn to_float<B: Backend, const D: usize>(x: Tensor<B, D, Int>) -> Tensor<B, D, Float> {
let device = x.device();
Tensor::from_data(
x
.into_data()
.convert()
).to_device(&device)
}
pub fn to_float_bool<B: Backend, const D: usize>(x: Tensor<B, D, Bool>) -> Tensor<B, D, Float> {
let device = x.device();
Tensor::from_data(
x
.into_data()
.convert()
).to_device(&device)
}
pub fn reverse<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B> + Numeric<B>>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K> where <K as BasicOps<B>>::Elem: Element {
let len = x.dims()[dim];
let indices = -Tensor::arange_device(0..len, &x.device()) + (len - 1) as i64;
x.select(dim, indices)
}
pub fn div_roundup(x: usize, y: usize) -> usize {
(x + y - 1) / y
}

View File

@@ -1,3 +1,2 @@
pub mod model; pub mod model;
pub mod tokenizer; pub mod tokenizer;
pub mod helper;

View File

@@ -1,23 +1,32 @@
use burn::{ use burn::tensor::{activation::softmax, backend::Backend, Tensor};
tensor::{
backend::Backend,
activation::softmax,
Tensor,
},
};
use std::f32::NEG_INFINITY; use std::f32::NEG_INFINITY;
pub fn qkv_attention<B: Backend>(q: Tensor<B, 3>, k: Tensor<B, 3>, v: Tensor<B, 3>, mask: Option<Tensor<B, 2>>, n_head: usize) -> Tensor<B, 3> { pub fn qkv_attention<B: Backend>(
q: Tensor<B, 3>,
k: Tensor<B, 3>,
v: Tensor<B, 3>,
mask: Option<Tensor<B, 2>>,
n_head: usize,
) -> Tensor<B, 3> {
let [n_batch, n_qctx, n_state] = q.dims(); let [n_batch, n_qctx, n_state] = q.dims();
let [_, n_ctx, _] = k.dims(); let [_, n_ctx, _] = k.dims();
let scale = (n_state as f64 / n_head as f64).powf(-0.25); let scale = (n_state as f64 / n_head as f64).powf(-0.25);
let n_hstate = n_state / n_head; let n_hstate = n_state / n_head;
let q = q.reshape([n_batch, n_qctx, n_head, n_hstate]).swap_dims(1, 2) * scale; let q = q
let k = k.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2).transpose() * scale; .reshape([n_batch, n_qctx, n_head, n_hstate])
let v = v.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2); .swap_dims(1, 2)
* scale;
let k = k
.reshape([n_batch, n_ctx, n_head, n_hstate])
.swap_dims(1, 2)
.transpose()
* scale;
let v = v
.reshape([n_batch, n_ctx, n_head, n_hstate])
.swap_dims(1, 2);
let qk = q.matmul(k); let qk = q.matmul(k);

View File

@@ -7,26 +7,35 @@ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn, nn,
tensor::{ tensor::{backend::Backend, Tensor},
backend::Backend,
Tensor,
},
}; };
use super::*; use super::*;
use crate::model::groupnorm::load::load_group_norm; use crate::model::groupnorm::load::load_group_norm;
fn load_conv_self_attention_block<B: Backend>(path: &str, device: &B::Device) -> Result<ConvSelfAttentionBlock<B>, Box<dyn Error>> { fn load_conv_self_attention_block<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<ConvSelfAttentionBlock<B>, Box<dyn Error>> {
let norm = load_group_norm(&format!("{}/{}", path, "norm"), device)?; let norm = load_group_norm(&format!("{}/{}", path, "norm"), device)?;
let q = load_conv2d(&format!("{}/{}", path, "q"), device)?; let q = load_conv2d(&format!("{}/{}", path, "q"), device)?;
let k = load_conv2d(&format!("{}/{}", path, "k"), device)?; let k = load_conv2d(&format!("{}/{}", path, "k"), device)?;
let v = load_conv2d(&format!("{}/{}", path, "v"), device)?; let v = load_conv2d(&format!("{}/{}", path, "v"), device)?;
let proj_out = load_conv2d(&format!("{}/{}", path, "proj_out"), device)?; let proj_out = load_conv2d(&format!("{}/{}", path, "proj_out"), device)?;
Ok(ConvSelfAttentionBlock { norm, q, k, v, proj_out }) Ok(ConvSelfAttentionBlock {
norm,
q,
k,
v,
proj_out,
})
} }
fn load_resnet_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResnetBlock<B>, Box<dyn Error>> { fn load_resnet_block<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<ResnetBlock<B>, Box<dyn Error>> {
let norm1 = load_group_norm(&format!("{}/{}", path, "norm1"), device)?; let norm1 = load_group_norm(&format!("{}/{}", path, "norm1"), device)?;
let silu1 = SILU {}; let silu1 = SILU {};
let conv1 = load_conv2d(&format!("{}/{}", path, "conv1"), device)?; let conv1 = load_conv2d(&format!("{}/{}", path, "conv1"), device)?;
@@ -35,7 +44,15 @@ fn load_resnet_block<B: Backend>(path: &str, device: &B::Device) -> Result<Resne
let conv2 = load_conv2d(&format!("{}/{}", path, "conv2"), device)?; let conv2 = load_conv2d(&format!("{}/{}", path, "conv2"), device)?;
let nin_shortcut = load_conv2d(&format!("{}/{}", path, "nin_shortcut"), device).ok(); let nin_shortcut = load_conv2d(&format!("{}/{}", path, "nin_shortcut"), device).ok();
Ok(ResnetBlock { norm1, silu1, conv1, norm2, silu2, conv2, nin_shortcut }) Ok(ResnetBlock {
norm1,
silu1,
conv1,
norm2,
silu2,
conv2,
nin_shortcut,
})
} }
fn load_mid<B: Backend>(path: &str, device: &B::Device) -> Result<Mid<B>, Box<dyn Error>> { fn load_mid<B: Backend>(path: &str, device: &B::Device) -> Result<Mid<B>, Box<dyn Error>> {
@@ -43,10 +60,17 @@ fn load_mid<B: Backend>(path: &str, device: &B::Device) -> Result<Mid<B>, Box<dy
let attn = load_conv_self_attention_block(&format!("{}/{}", path, "attn"), device)?; let attn = load_conv_self_attention_block(&format!("{}/{}", path, "attn"), device)?;
let block_2 = load_resnet_block(&format!("{}/{}", path, "block_2"), device)?; let block_2 = load_resnet_block(&format!("{}/{}", path, "block_2"), device)?;
Ok(Mid { block_1, attn, block_2 }) Ok(Mid {
block_1,
attn,
block_2,
})
} }
fn load_padded_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<PaddedConv2d<B>, Box<dyn Error>> { fn load_padded_conv2d<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<PaddedConv2d<B>, Box<dyn Error>> {
let conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?; let conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?;
let channels = load_tensor::<B, 1>("channels", path, device)?; let channels = load_tensor::<B, 1>("channels", path, device)?;
@@ -61,31 +85,48 @@ fn load_padded_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<Padd
let mut record = conv.into_record(); let mut record = conv.into_record();
let mut padded_conv: PaddedConv2d<B> = PaddedConv2dConfig::new(channels, kernel_size, padding).with_stride(stride).init(); let mut padded_conv: PaddedConv2d<B> = PaddedConv2dConfig::new(channels, kernel_size, padding)
let padding_actual = PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]); .with_stride(stride)
.init();
let padding_actual =
PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]);
record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual); record.padding = <PaddingConfig2d as Module<B>>::into_record(padding_actual);
padded_conv.conv = padded_conv.conv.load_record(record); padded_conv.conv = padded_conv.conv.load_record(record);
Ok(padded_conv) Ok(padded_conv)
} }
fn load_decoder_block<B: Backend>(path: &str, device: &B::Device) -> Result<DecoderBlock<B>, Box<dyn Error>> { fn load_decoder_block<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<DecoderBlock<B>, Box<dyn Error>> {
let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?; let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?;
let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?; let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?;
let res3 = load_resnet_block(&format!("{}/{}", path, "res3"), device)?; let res3 = load_resnet_block(&format!("{}/{}", path, "res3"), device)?;
let upsampler = load_conv2d(&format!("{}/{}", path, "upsampler"), device).ok(); let upsampler = load_conv2d(&format!("{}/{}", path, "upsampler"), device).ok();
Ok(DecoderBlock { res1, res2, res3, upsampler }) Ok(DecoderBlock {
res1,
res2,
res3,
upsampler,
})
} }
fn load_encoder_block<B: Backend>(path: &str, device: &B::Device) -> Result<EncoderBlock<B>, Box<dyn Error>> { fn load_encoder_block<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<EncoderBlock<B>, Box<dyn Error>> {
let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?; let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?;
let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?; let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?;
let downsampler = load_padded_conv2d(&format!("{}/{}", path, "downsampler"), device).ok(); let downsampler = load_padded_conv2d(&format!("{}/{}", path, "downsampler"), device).ok();
Ok(EncoderBlock { res1, res2, downsampler }) Ok(EncoderBlock {
res1,
res2,
downsampler,
})
} }
fn load_decoder<B: Backend>(path: &str, device: &B::Device) -> Result<Decoder<B>, Box<dyn Error>> { fn load_decoder<B: Backend>(path: &str, device: &B::Device) -> Result<Decoder<B>, Box<dyn Error>> {
@@ -95,15 +136,21 @@ fn load_decoder<B: Backend>(path: &str, device: &B::Device) -> Result<Decoder<B>
let n_block = load_usize::<B>("n_block", path, device)?; let n_block = load_usize::<B>("n_block", path, device)?;
let mut blocks = (0..n_block) let mut blocks = (0..n_block)
.into_iter() .into_iter()
.map(|i| { .map(|i| load_decoder_block::<B>(&format!("{}/blocks/{}", path, i), device))
load_decoder_block::<B>(&format!("{}/blocks/{}", path, i), device) .collect::<Result<Vec<_>, _>>()?;
}).collect::<Result<Vec<_>, _>>()?;
let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?; let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?;
let silu = SILU {}; let silu = SILU {};
let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?; let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?;
Ok(Decoder { conv_in, mid, blocks, norm_out, silu, conv_out }) Ok(Decoder {
conv_in,
mid,
blocks,
norm_out,
silu,
conv_out,
})
} }
fn load_encoder<B: Backend>(path: &str, device: &B::Device) -> Result<Encoder<B>, Box<dyn Error>> { fn load_encoder<B: Backend>(path: &str, device: &B::Device) -> Result<Encoder<B>, Box<dyn Error>> {
@@ -113,22 +160,36 @@ fn load_encoder<B: Backend>(path: &str, device: &B::Device) -> Result<Encoder<B>
let n_block = load_usize::<B>("n_block", path, device)?; let n_block = load_usize::<B>("n_block", path, device)?;
let mut blocks = (0..n_block) let mut blocks = (0..n_block)
.into_iter() .into_iter()
.map(|i| { .map(|i| load_encoder_block::<B>(&format!("{}/blocks/{}", path, i), device))
load_encoder_block::<B>(&format!("{}/blocks/{}", path, i), device) .collect::<Result<Vec<_>, _>>()?;
}).collect::<Result<Vec<_>, _>>()?;
let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?; let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?;
let silu = SILU {}; let silu = SILU {};
let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?; let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?;
Ok(Encoder { conv_in, mid, blocks, norm_out, silu, conv_out }) Ok(Encoder {
conv_in,
mid,
blocks,
norm_out,
silu,
conv_out,
})
} }
pub fn load_autoencoder<B: Backend>(path: &str, device: &B::Device) -> Result<Autoencoder<B>, Box<dyn Error>> { pub fn load_autoencoder<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<Autoencoder<B>, Box<dyn Error>> {
let encoder = load_encoder(&format!("{}/{}", path, "encoder"), device)?; let encoder = load_encoder(&format!("{}/{}", path, "encoder"), device)?;
let decoder = load_decoder(&format!("{}/{}", path, "decoder"), device)?; let decoder = load_decoder(&format!("{}/{}", path, "decoder"), device)?;
let quant_conv = load_conv2d(&format!("{}/{}", path, "quant_conv"), device)?; let quant_conv = load_conv2d(&format!("{}/{}", path, "quant_conv"), device)?;
let post_quant_conv = load_conv2d(&format!("{}/{}", path, "post_quant_conv"), device)?; let post_quant_conv = load_conv2d(&format!("{}/{}", path, "post_quant_conv"), device)?;
Ok(Autoencoder { encoder, decoder, quant_conv, post_quant_conv }) Ok(Autoencoder {
encoder,
decoder,
quant_conv,
post_quant_conv,
})
} }

View File

@@ -3,33 +3,34 @@ pub mod load;
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn::{self, PaddingConfig2d, conv::{Conv2d, Conv2dConfig, Conv2dRecord}}, nn::{
self,
conv::{Conv2d, Conv2dConfig, Conv2dRecord},
PaddingConfig2d,
},
tensor::{ tensor::{
activation::{sigmoid, softmax},
backend::Backend, backend::Backend,
activation::{softmax, sigmoid},
module::embedding, module::embedding,
Tensor, Distribution, Int, Tensor,
Distribution,
Int,
}, },
}; };
use crate::helper::div_roundup;
use super::silu::*;
use super::groupnorm::*;
use super::attention::qkv_attention; use super::attention::qkv_attention;
use super::groupnorm::*;
use super::silu::*;
use std::iter; use std::iter;
#[derive(Config)] #[derive(Config)]
pub struct AutoencoderConfig {} pub struct AutoencoderConfig {}
impl AutoencoderConfig { impl AutoencoderConfig {
pub fn init<B: Backend>(&self) -> Autoencoder<B> { pub fn init<B: Backend>(&self) -> Autoencoder<B> {
let encoder = EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init(); let encoder =
let decoder = DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init(); EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init();
let decoder =
DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init();
let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init(); let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init();
let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init(); let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init();
@@ -42,7 +43,6 @@ impl AutoencoderConfig {
} }
} }
#[derive(Module, Debug)] #[derive(Module, Debug)]
pub struct Autoencoder<B: Backend> { pub struct Autoencoder<B: Backend> {
encoder: Encoder<B>, encoder: Encoder<B>,
@@ -53,7 +53,7 @@ pub struct Autoencoder<B: Backend> {
impl<B: Backend> Autoencoder<B> { impl<B: Backend> Autoencoder<B> {
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> { pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.decode_latent( self.encode_image(x) ) self.decode_latent(self.encode_image(x))
} }
pub fn encode_image(&self, x: Tensor<B, 4>) -> Tensor<B, 4> { pub fn encode_image(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
@@ -79,20 +79,33 @@ pub struct EncoderConfig {
impl EncoderConfig { impl EncoderConfig {
fn init<B: Backend>(&self) -> Encoder<B> { fn init<B: Backend>(&self) -> Encoder<B> {
let n_expanded_channels_initial = self.channels.first().map(|f| f.1).expect("Channels must not be empty."); let n_expanded_channels_initial = self
.channels
.first()
.map(|f| f.1)
.expect("Channels must not be empty.");
let n_expanded_channels_final = self.channels.first().unwrap().0; let n_expanded_channels_final = self.channels.first().unwrap().0;
let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| { let blocks = self
.channels
.iter()
.enumerate()
.map(|(i, &(n_channel_in, n_channel_out))| {
let downsample = i != self.channels.len() - 1; let downsample = i != self.channels.len() - 1;
EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init() EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init()
}).collect(); })
.collect();
let mid = MidConfig::new(n_expanded_channels_final).init(); let mid = MidConfig::new(n_expanded_channels_final).init();
let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init(); let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init();
let silu = SILU::new(); let silu = SILU::new();
let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
Encoder { Encoder {
conv_in, conv_in,
@@ -105,7 +118,6 @@ impl EncoderConfig {
} }
} }
#[derive(Module, Debug)] #[derive(Module, Debug)]
pub struct Encoder<B: Backend> { pub struct Encoder<B: Backend> {
conv_in: Conv2d<B>, conv_in: Conv2d<B>,
@@ -126,12 +138,11 @@ impl<B: Backend> Encoder<B> {
} }
let x = self.mid.forward(x); let x = self.mid.forward(x);
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) ) self.conv_out
.forward(self.silu.forward(self.norm_out.forward(x)))
} }
} }
#[derive(Config)] #[derive(Config)]
pub struct DecoderConfig { pub struct DecoderConfig {
channels: Vec<(usize, usize)>, channels: Vec<(usize, usize)>,
@@ -140,20 +151,33 @@ pub struct DecoderConfig {
impl DecoderConfig { impl DecoderConfig {
fn init<B: Backend>(&self) -> Decoder<B> { fn init<B: Backend>(&self) -> Decoder<B> {
let n_expanded_channels = self.channels.first().map(|f| f.0).expect("Channels must not be empty."); let n_expanded_channels = self
.channels
.first()
.map(|f| f.0)
.expect("Channels must not be empty.");
let n_condensed_channels = self.channels.last().unwrap().1; let n_condensed_channels = self.channels.last().unwrap().1;
let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
let mid = MidConfig::new(n_expanded_channels).init(); let mid = MidConfig::new(n_expanded_channels).init();
let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| { let blocks = self
.channels
.iter()
.enumerate()
.map(|(i, &(n_channel_in, n_channel_out))| {
let upsample = i != self.channels.len() - 1; let upsample = i != self.channels.len() - 1;
DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init() DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init()
}).collect(); })
.collect();
let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init(); let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init();
let silu = SILU::new(); let silu = SILU::new();
let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
Decoder { Decoder {
conv_in, conv_in,
@@ -166,7 +190,6 @@ impl DecoderConfig {
} }
} }
#[derive(Module, Debug)] #[derive(Module, Debug)]
pub struct Decoder<B: Backend> { pub struct Decoder<B: Backend> {
conv_in: Conv2d<B>, conv_in: Conv2d<B>,
@@ -187,7 +210,8 @@ impl<B: Backend> Decoder<B> {
x = block.forward(x); x = block.forward(x);
} }
self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) ) self.conv_out
.forward(self.silu.forward(self.norm_out.forward(x)))
} }
} }
@@ -204,7 +228,11 @@ impl EncoderBlockConfig {
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(); let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
let downsampler = if self.downsample { let downsampler = if self.downsample {
let padding = Padding::new(0, 1, 0, 1); let padding = Padding::new(0, 1, 0, 1);
Some( PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding).with_stride(2).init() ) Some(
PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding)
.with_stride(2)
.init(),
)
} else { } else {
None None
}; };
@@ -249,7 +277,11 @@ impl DecoderBlockConfig {
let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(); let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(); let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init();
let upsampler = if self.upsample { let upsampler = if self.upsample {
Some( Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init() ) Some(
Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(),
)
} else { } else {
None None
}; };
@@ -291,7 +323,6 @@ impl<B: Backend> DecoderBlock<B> {
} }
} }
#[derive(Config)] #[derive(Config)]
pub struct PaddedConv2dConfig { pub struct PaddedConv2dConfig {
channels: [usize; 2], channels: [usize; 2],
@@ -337,6 +368,10 @@ impl PaddedConv2dConfig {
} }
} }
fn div_roundup(x: usize, y: usize) -> usize {
(x + y - 1) / y
}
#[derive(Module, Debug)] #[derive(Module, Debug)]
pub struct PaddedConv2d<B: Backend> { pub struct PaddedConv2d<B: Backend> {
conv: Conv2d<B>, conv: Conv2d<B>,
@@ -348,22 +383,29 @@ pub struct PaddedConv2d<B: Backend> {
impl<B: Backend> PaddedConv2d<B> { impl<B: Backend> PaddedConv2d<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> { fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
println!("{} {} {:?} {:?}", self.kernel_size, self.stride, self.padding, self.padding_actual); println!(
"{} {} {:?} {:?}",
self.kernel_size, self.stride, self.padding, self.padding_actual
);
let [n_batch, n_channel, height, width] = x.dims(); let [n_batch, n_channel, height, width] = x.dims();
let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height - self.kernel_size) / self.stride + 1; let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height
let desired_width = (self.padding.pad_left + self.padding.pad_right + width - self.kernel_size) / self.stride + 1; - self.kernel_size)
/ self.stride
+ 1;
let desired_width = (self.padding.pad_left + self.padding.pad_right + width
- self.kernel_size)
/ self.stride
+ 1;
let skip_vert = (self.padding_actual[0] - self.padding.pad_top) / self.stride; let skip_vert = (self.padding_actual[0] - self.padding.pad_top) / self.stride;
let skip_hor = (self.padding_actual[1] - self.padding.pad_left) / self.stride; let skip_hor = (self.padding_actual[1] - self.padding.pad_left) / self.stride;
self.conv self.conv.forward(x).slice([
.forward(x)
.slice([
0..n_batch, 0..n_batch,
0..n_channel, 0..n_channel,
skip_vert..(skip_vert + desired_height), skip_vert..(skip_vert + desired_height),
skip_hor..(skip_hor + desired_width) skip_hor..(skip_hor + desired_width),
]) ])
} }
} }
@@ -411,7 +453,6 @@ impl<B: Backend> Mid<B> {
} }
} }
#[derive(Config)] #[derive(Config)]
pub struct ResnetBlockConfig { pub struct ResnetBlockConfig {
in_channels: usize, in_channels: usize,
@@ -421,11 +462,15 @@ pub struct ResnetBlockConfig {
impl ResnetBlockConfig { impl ResnetBlockConfig {
fn init<B: Backend>(&self) -> ResnetBlock<B> { fn init<B: Backend>(&self) -> ResnetBlock<B> {
let norm1 = GroupNormConfig::new(32, self.in_channels).init(); let norm1 = GroupNormConfig::new(32, self.in_channels).init();
let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
let norm2 = GroupNormConfig::new(32, self.out_channels).init(); let norm2 = GroupNormConfig::new(32, self.out_channels).init();
let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
let nin_shortcut = if self.in_channels != self.out_channels { let nin_shortcut = if self.in_channels != self.out_channels {
Some( Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init() ) Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init())
} else { } else {
None None
}; };
@@ -458,9 +503,12 @@ pub struct ResnetBlock<B: Backend> {
impl<B: Backend> ResnetBlock<B> { impl<B: Backend> ResnetBlock<B> {
fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> { fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let h = self.conv1.forward( self.silu1.forward(self.norm1.forward(x.clone())) ); let h = self
let h = self.conv2.forward( self.silu2.forward(self.norm2.forward(h)) ); .conv1
.forward(self.silu1.forward(self.norm1.forward(x.clone())));
let h = self
.conv2
.forward(self.silu2.forward(self.norm2.forward(h)));
if let Some(ns) = self.nin_shortcut.as_ref() { if let Some(ns) = self.nin_shortcut.as_ref() {
ns.forward(x) + h ns.forward(x) + h
@@ -508,9 +556,21 @@ impl<B: Backend> ConvSelfAttentionBlock<B> {
let h = self.norm.forward(x.clone()); let h = self.norm.forward(x.clone());
let q = self.q.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2); let q = self
let k = self.k.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2); .q
let v = self.v.forward(h).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2); .forward(h.clone())
.reshape([n_batch, n_channel, height * width])
.swap_dims(1, 2);
let k = self
.k
.forward(h.clone())
.reshape([n_batch, n_channel, height * width])
.swap_dims(1, 2);
let v = self
.v
.forward(h)
.reshape([n_batch, n_channel, height * width])
.swap_dims(1, 2);
let wv = qkv_attention(q, k, v, None, 1) let wv = qkv_attention(q, k, v, None, 1)
.swap_dims(1, 2) .swap_dims(1, 2)

View File

@@ -1,14 +1,11 @@
use std::error::Error;
use burn::tensor::ElementConversion; use burn::tensor::ElementConversion;
use std::error::Error;
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn, nn,
tensor::{ tensor::{backend::Backend, Tensor},
backend::Backend,
Tensor,
},
}; };
use super::*; use super::*;
@@ -28,7 +25,10 @@ pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Bo
Ok(mlp) Ok(mlp)
} }
pub fn load_multi_head_self_attention<B: Backend>(path: &str, device: &B::Device) -> Result<MultiHeadSelfAttention<B>, Box<dyn Error>> { pub fn load_multi_head_self_attention<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<MultiHeadSelfAttention<B>, Box<dyn Error>> {
let n_head = load_usize::<B>("n_head", path, device)?; let n_head = load_usize::<B>("n_head", path, device)?;
let query = load_linear(&format!("{}/{}", path, "query"), device)?; let query = load_linear(&format!("{}/{}", path, "query"), device)?;
let key = load_linear(&format!("{}/{}", path, "key"), device)?; let key = load_linear(&format!("{}/{}", path, "key"), device)?;
@@ -46,7 +46,10 @@ pub fn load_multi_head_self_attention<B: Backend>(path: &str, device: &B::Device
Ok(mhsa) Ok(mhsa)
} }
pub fn load_residual_decoder_attention_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResidualDecoderAttentionBlock<B>, Box<dyn Error>> { pub fn load_residual_decoder_attention_block<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<ResidualDecoderAttentionBlock<B>, Box<dyn Error>> {
let mlp = load_mlp(&format!("{}/{}", path, "mlp"), device)?; let mlp = load_mlp(&format!("{}/{}", path, "mlp"), device)?;
let attn = load_multi_head_self_attention(&format!("{}/{}", path, "attn"), device)?; let attn = load_multi_head_self_attention(&format!("{}/{}", path, "attn"), device)?;
let attn_ln = load_layer_norm(&format!("{}/{}", path, "attn_ln"), device)?; let attn_ln = load_layer_norm(&format!("{}/{}", path, "attn_ln"), device)?;
@@ -64,14 +67,16 @@ pub fn load_residual_decoder_attention_block<B: Backend>(path: &str, device: &B:
pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>, Box<dyn Error>> { pub fn load_clip<B: Backend>(path: &str, device: &B::Device) -> Result<CLIP<B>, Box<dyn Error>> {
let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?; let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?;
let position_embedding = load_tensor("weight", &format!("{}/position_embedding", path), device)?.into(); let position_embedding =
load_tensor("weight", &format!("{}/position_embedding", path), device)?.into();
let n_layer = load_usize::<B>("n_layer", path, device)?; let n_layer = load_usize::<B>("n_layer", path, device)?;
let mut blocks = (0..n_layer) let mut blocks = (0..n_layer)
.into_iter() .into_iter()
.map(|i| { .map(|i| {
load_residual_decoder_attention_block::<B>(&format!("{}/blocks/{}", path, i), device) load_residual_decoder_attention_block::<B>(&format!("{}/blocks/{}", path, i), device)
}).collect::<Result<Vec<_>, _>>()?; })
.collect::<Result<Vec<_>, _>>()?;
let layer_norm = load_layer_norm(&format!("{}/{}", path, "layer_norm"), device)?; let layer_norm = load_layer_norm(&format!("{}/{}", path, "layer_norm"), device)?;

View File

@@ -5,17 +5,14 @@ use burn::{
module::{Module, Param}, module::{Module, Param},
nn, nn,
tensor::{ tensor::{
activation::{sigmoid, softmax},
backend::Backend, backend::Backend,
activation::{softmax, sigmoid},
module::embedding, module::embedding,
Tensor, Distribution, Int, Tensor,
Distribution,
Int,
}, },
}; };
use crate::model::attention::{qkv_attention, attn_decoder_mask}; use crate::model::attention::{attn_decoder_mask, qkv_attention};
#[derive(Config)] #[derive(Config)]
pub struct CLIPConfig { pub struct CLIPConfig {
@@ -29,7 +26,8 @@ pub struct CLIPConfig {
impl CLIPConfig { impl CLIPConfig {
pub fn init<B: Backend>(&self) -> CLIP<B> { pub fn init<B: Backend>(&self) -> CLIP<B> {
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init(); let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init();
let position_embedding = Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into(); let position_embedding =
Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into();
let blocks = (0..self.n_layer) let blocks = (0..self.n_layer)
.into_iter() .into_iter()
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init()) .map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init())
@@ -45,8 +43,6 @@ impl CLIPConfig {
} }
} }
#[derive(Module, Debug)] #[derive(Module, Debug)]
pub struct CLIP<B: Backend> { pub struct CLIP<B: Backend> {
token_embedding: nn::Embedding<B>, token_embedding: nn::Embedding<B>,
@@ -62,7 +58,11 @@ impl<B: Backend> CLIP<B> {
let mask = attn_decoder_mask(seq_len, &x.device()); let mask = attn_decoder_mask(seq_len, &x.device());
let embedded = self.token_embedding.forward(x) let embedded = self.token_embedding.forward(x)
+ self.position_embedding.val().slice([0..seq_len]).unsqueeze(); + self
.position_embedding
.val()
.slice([0..seq_len])
.unsqueeze();
let mut x = embedded; let mut x = embedded;
for block in &self.blocks { for block in &self.blocks {
@@ -73,8 +73,6 @@ impl<B: Backend> CLIP<B> {
} }
} }
#[derive(Config)] #[derive(Config)]
pub struct ResidualDecoderAttentionBlockConfig { pub struct ResidualDecoderAttentionBlockConfig {
n_state: usize, n_state: usize,
@@ -122,7 +120,12 @@ pub struct MultiHeadSelfAttentionConfig {
impl MultiHeadSelfAttentionConfig { impl MultiHeadSelfAttentionConfig {
fn init<B: Backend>(&self) -> MultiHeadSelfAttention<B> { fn init<B: Backend>(&self) -> MultiHeadSelfAttention<B> {
assert!(self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", self.n_state, self.n_head); assert!(
self.n_state % self.n_head == 0,
"State size {} must be a multiple of head size {}",
self.n_state,
self.n_head
);
let n_head = self.n_head; let n_head = self.n_head;
let query = nn::LinearConfig::new(self.n_state, self.n_state).init(); let query = nn::LinearConfig::new(self.n_state, self.n_state).init();
@@ -135,7 +138,7 @@ impl MultiHeadSelfAttentionConfig {
query, query,
key, key,
value, value,
out out,
} }
} }
} }
@@ -161,13 +164,6 @@ impl<B: Backend> MultiHeadSelfAttention<B> {
} }
} }
#[derive(Config, Debug)] #[derive(Config, Debug)]
pub struct MLPConfig { pub struct MLPConfig {
input_size: usize, input_size: usize,
@@ -180,11 +176,7 @@ impl MLPConfig {
let gelu = QuickGELU::new(); let gelu = QuickGELU::new();
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init(); let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init();
MLP { MLP { fc1, gelu, fc2 }
fc1,
gelu,
fc2,
}
} }
} }
@@ -217,4 +209,3 @@ impl QuickGELU {
x.clone() * sigmoid(x * 1.702) x.clone() * sigmoid(x * 1.702)
} }
} }

View File

@@ -7,27 +7,31 @@ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn, nn,
tensor::{ tensor::{backend::Backend, Tensor},
backend::Backend,
Tensor,
},
}; };
pub fn load_group_norm<B: Backend>(path: &str, device: &B::Device) -> Result<GroupNorm<B>, Box<dyn Error>> { pub fn load_group_norm<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<GroupNorm<B>, Box<dyn Error>> {
let n_group = load_usize::<B>("n_group", path, device)?.into(); let n_group = load_usize::<B>("n_group", path, device)?.into();
let n_channel = load_usize::<B>("n_channel", path, device)?.into(); let n_channel = load_usize::<B>("n_channel", path, device)?.into();
let eps = load_f32::<B>("eps", path, device)?.into(); let eps = load_f32::<B>("eps", path, device)?.into();
let gamma = load_tensor::<B, 1>("weight", path, device).ok().unwrap_or_else(|| Tensor::ones_device([n_channel], device)).into(); let gamma = load_tensor::<B, 1>("weight", path, device)
let beta = load_tensor::<B, 1>("bias", path, device).ok().unwrap_or_else(|| Tensor::zeros_device([n_channel], device)).into(); .ok()
.unwrap_or_else(|| Tensor::ones_device([n_channel], device))
.into();
let beta = load_tensor::<B, 1>("bias", path, device)
.ok()
.unwrap_or_else(|| Tensor::zeros_device([n_channel], device))
.into();
Ok( Ok(GroupNorm {
GroupNorm {
n_group, n_group,
n_channel, n_channel,
gamma, gamma,
beta, beta,
eps, eps,
} })
)
} }

View File

@@ -3,10 +3,7 @@ pub mod load;
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
tensor::{ tensor::{backend::Backend, Tensor},
backend::Backend,
Tensor,
},
}; };
#[derive(Config)] #[derive(Config)]
@@ -19,7 +16,12 @@ pub struct GroupNormConfig {
impl GroupNormConfig { impl GroupNormConfig {
pub fn init<B: Backend>(&self) -> GroupNorm<B> { pub fn init<B: Backend>(&self) -> GroupNorm<B> {
assert!(self.n_channel % self.n_group == 0, "The number of channels {} must be divisible by the number of groups {}", self.n_channel, self.n_group); assert!(
self.n_channel % self.n_group == 0,
"The number of channels {} must be divisible by the number of groups {}",
self.n_channel,
self.n_group
);
let n_per_group = self.n_channel / self.n_group; let n_per_group = self.n_channel / self.n_group;
@@ -56,7 +58,14 @@ impl<B: Backend> GroupNorm<B> {
let mut affine_shape = [1; D]; let mut affine_shape = [1; D];
affine_shape[1] = self.n_channel; affine_shape[1] = self.n_channel;
layernorm( x.reshape([n_batch, self.n_group, num_elements / (n_batch * self.n_group) ]), self.eps ) layernorm(
x.reshape([
n_batch,
self.n_group,
num_elements / (n_batch * self.n_group),
]),
self.eps,
)
.reshape(shape) .reshape(shape)
.mul(self.gamma.val().reshape(affine_shape)) .mul(self.gamma.val().reshape(affine_shape))
.add(self.beta.val().reshape(affine_shape)) .add(self.beta.val().reshape(affine_shape))
@@ -68,5 +77,6 @@ pub fn layernorm<B: Backend, const D: usize>(x: Tensor<B, D>, eps: f64) -> Tenso
//x.sub(mean).div(var.sqrt().add_scalar(eps)) //x.sub(mean).div(var.sqrt().add_scalar(eps))
let u = x.clone() - x.mean_dim(D - 1); let u = x.clone() - x.mean_dim(D - 1);
u.clone().div( (u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt() ) u.clone()
.div((u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt())
} }

View File

@@ -1,22 +1,21 @@
use std::error::Error;
use std::io::Read;
use npy::{self, NpyData}; use npy::{self, NpyData};
use num_traits::cast::ToPrimitive; use num_traits::cast::ToPrimitive;
use std::error::Error;
use std::io::Read;
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn::{self, conv}, nn::{self, conv},
tensor::{ tensor::{backend::Backend, Data, Tensor},
backend::Backend,
Tensor,
Data,
},
}; };
use burn::tensor::ElementConversion; use burn::tensor::ElementConversion;
pub fn numpy_to_tensor<B: Backend, const D: usize>(numpy_data: NpyData<f32>, device: &B::Device) -> Tensor<B, D> { pub fn numpy_to_tensor<B: Backend, const D: usize>(
numpy_data: NpyData<f32>,
device: &B::Device,
) -> Tensor<B, D> {
let mut v = numpy_data.to_vec(); let mut v = numpy_data.to_vec();
let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect(); let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect();
@@ -25,12 +24,15 @@ pub fn numpy_to_tensor<B: Backend, const D: usize>(numpy_data: NpyData<f32>, dev
Tensor::from_data_device(Data::new(data, shape.into()), device) Tensor::from_data_device(Data::new(data, shape.into()), device)
} }
pub fn load_tensor<B: Backend, const D: usize>(name: &str, path: &str, device: &B::Device) -> Result<Tensor<B, D>, Box<dyn Error>> { pub fn load_tensor<B: Backend, const D: usize>(
name: &str,
path: &str,
device: &B::Device,
) -> Result<Tensor<B, D>, Box<dyn Error>> {
let tensor_path = format!("{}/{}.npy", path, name); let tensor_path = format!("{}/{}.npy", path, name);
let mut buf = vec![]; let mut buf = vec![];
std::fs::File::open(&tensor_path)? std::fs::File::open(&tensor_path)?.read_to_end(&mut buf)?;
.read_to_end(&mut buf)?;
let tensor_numpy: NpyData<f32> = NpyData::from_bytes(&buf)?; let tensor_numpy: NpyData<f32> = NpyData::from_bytes(&buf)?;
@@ -41,15 +43,26 @@ pub fn load_tensor<B: Backend, const D: usize>(name: &str, path: &str, device: &
Ok(tensor) Ok(tensor)
} }
pub fn load_f32<B: Backend>(name: &str, path: &str, device: &B::Device) -> Result<f32, Box<dyn Error>> { pub fn load_f32<B: Backend>(
name: &str,
path: &str,
device: &B::Device,
) -> Result<f32, Box<dyn Error>> {
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32().unwrap()) load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_f32().unwrap())
} }
pub fn load_usize<B: Backend>(name: &str, path: &str, device: &B::Device) -> Result<usize, Box<dyn Error>> { pub fn load_usize<B: Backend>(
name: &str,
path: &str,
device: &B::Device,
) -> Result<usize, Box<dyn Error>> {
load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize().unwrap()) load_tensor::<B, 1>(name, path, device).map(|t| t.into_scalar().to_usize().unwrap())
} }
pub fn load_linear<B: Backend>(path: &str, device: &B::Device) -> Result<nn::Linear<B>, Box<dyn Error>> { pub fn load_linear<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<nn::Linear<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 2>("weight", path, device)?; let weight = load_tensor::<B, 2>("weight", path, device)?;
let bias = load_tensor::<B, 1>("bias", path, device).ok(); let bias = load_tensor::<B, 1>("bias", path, device).ok();
@@ -62,7 +75,10 @@ pub fn load_linear<B: Backend>(path: &str, device: &B::Device) -> Result<nn::Lin
Ok(linear) Ok(linear)
} }
pub fn load_embedding<B: Backend>(path: &str, device: &B::Device) -> Result<nn::Embedding<B>, Box<dyn Error>> { pub fn load_embedding<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<nn::Embedding<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 2>("weight", path, device)?; let weight = load_tensor::<B, 2>("weight", path, device)?;
let [n_vocab, n_state] = weight.dims(); let [n_vocab, n_state] = weight.dims();
@@ -74,7 +90,10 @@ pub fn load_embedding<B: Backend>(path: &str, device: &B::Device) -> Result<nn::
Ok(embedding) Ok(embedding)
} }
pub fn load_layer_norm<B: Backend>(path: &str, device: &B::Device) -> Result<nn::LayerNorm<B>, Box<dyn Error>> { pub fn load_layer_norm<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<nn::LayerNorm<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 1>("weight", path, device)?; let weight = load_tensor::<B, 1>("weight", path, device)?;
let bias = load_tensor::<B, 1>("bias", path, device)?; let bias = load_tensor::<B, 1>("bias", path, device)?;
let eps = load_f32::<B>("eps", path, device)? as f64; let eps = load_f32::<B>("eps", path, device)? as f64;
@@ -92,7 +111,6 @@ pub fn load_layer_norm<B: Backend>(path: &str, device: &B::Device) -> Result<nn:
Ok(layer_norm) Ok(layer_norm)
} }
/*pub fn load_rmsnorm<B: Backend>(path: &str, device: &B::Device) -> Result<RMSNorm<B>, Box<dyn Error>> { /*pub fn load_rmsnorm<B: Backend>(path: &str, device: &B::Device) -> Result<RMSNorm<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 1>("weight", path, device)?; let weight = load_tensor::<B, 1>("weight", path, device)?;
let eps = load_f32::<B>("eps", path, device)?.into(); let eps = load_f32::<B>("eps", path, device)?.into();
@@ -105,7 +123,10 @@ pub fn load_layer_norm<B: Backend>(path: &str, device: &B::Device) -> Result<nn:
Ok(rmsnorm) Ok(rmsnorm)
}*/ }*/
pub fn load_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<conv::Conv2d<B>, Box<dyn Error>> { pub fn load_conv2d<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<conv::Conv2d<B>, Box<dyn Error>> {
let weight = load_tensor::<B, 4>("weight", path, device)?; let weight = load_tensor::<B, 4>("weight", path, device)?;
let bias = load_tensor::<B, 1>("bias", path, device).ok(); let bias = load_tensor::<B, 1>("bias", path, device).ok();
let has_bias = bias.is_some(); let has_bias = bias.is_some();
@@ -127,7 +148,6 @@ pub fn load_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<conv::C
let padding = tensor_to_array_2(padding); let padding = tensor_to_array_2(padding);
let padding = nn::PaddingConfig2d::Explicit(padding[0], padding[1]); let padding = nn::PaddingConfig2d::Explicit(padding[0], padding[1]);
let record = conv::Conv2dRecord { let record = conv::Conv2dRecord {
weight: weight.into(), weight: weight.into(),
bias: bias.map(|t| t.into()), bias: bias.map(|t| t.into()),
@@ -138,7 +158,8 @@ pub fn load_conv2d<B: Backend>(path: &str, device: &B::Device) -> Result<conv::C
padding: <nn::PaddingConfig2d as Module<B>>::into_record(padding.clone()), padding: <nn::PaddingConfig2d as Module<B>>::into_record(padding.clone()),
}; };
let conv2d: conv::Conv2d<B> = conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size) let conv2d: conv::Conv2d<B> =
conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size)
.with_stride(stride) .with_stride(stride)
.with_dilation(dilation) .with_dilation(dilation)
.with_groups(n_group) .with_groups(n_group)

View File

@@ -1,11 +1,11 @@
pub mod stablediffusion; pub mod stablediffusion;
pub mod autoencoder; pub mod autoencoder;
pub mod unet;
pub mod clip; pub mod clip;
pub mod unet;
pub mod silu;
pub mod groupnorm;
pub mod attention; pub mod attention;
pub mod groupnorm;
pub mod silu;
pub mod load; pub mod load;

View File

@@ -1,13 +1,8 @@
use burn::{ use burn::{
module::Module, module::Module,
tensor::{ tensor::{activation::sigmoid, backend::Backend, Tensor},
backend::Backend,
activation::sigmoid,
Tensor,
},
}; };
#[derive(Module, Clone, Debug)] #[derive(Module, Clone, Debug)]
pub struct SILU {} pub struct SILU {}

View File

@@ -1,20 +1,22 @@
use std::error::Error;
use burn::tensor::ElementConversion; use burn::tensor::ElementConversion;
use std::error::Error;
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn, nn,
tensor::{ tensor::{backend::Backend, Tensor},
backend::Backend,
Tensor,
},
}; };
use super::*; use super::*;
use crate::model::{load::*, autoencoder::load::load_autoencoder, unet::load::load_unet, clip::load::load_clip}; use crate::model::{
autoencoder::load::load_autoencoder, clip::load::load_clip, load::*, unet::load::load_unet,
};
pub fn load_stable_diffusion<B: Backend>(path: &str, device: &B::Device) -> Result<StableDiffusion<B>, Box<dyn Error>> { pub fn load_stable_diffusion<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<StableDiffusion<B>, Box<dyn Error>> {
let n_steps = load_usize::<B>("n_steps", path, device)?; let n_steps = load_usize::<B>("n_steps", path, device)?;
let alpha_cumulative_products = load_tensor::<B, 1>("alphas_cumprod", path, device)?.into(); let alpha_cumulative_products = load_tensor::<B, 1>("alphas_cumprod", path, device)?.into();
let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?; let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?;
@@ -29,4 +31,3 @@ pub fn load_stable_diffusion<B: Backend>(path: &str, device: &B::Device) -> Resu
clip, clip,
}) })
} }

View File

@@ -3,28 +3,18 @@ pub mod load;
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
tensor::{ tensor::{backend::Backend, BasicOps, Data, Distribution, Float, Int, Tensor},
backend::Backend,
Tensor,
Int,
Float,
BasicOps,
Data,
Distribution,
},
}; };
use num_traits::ToPrimitive; use num_traits::ToPrimitive;
use super::autoencoder::{Autoencoder, AutoencoderConfig}; use super::autoencoder::{Autoencoder, AutoencoderConfig};
use super::clip::{CLIPConfig, CLIP};
use super::unet::{UNet, UNetConfig}; use super::unet::{UNet, UNetConfig};
use super::clip::{CLIP, CLIPConfig};
use crate::tokenizer::SimpleTokenizer; use crate::tokenizer::SimpleTokenizer;
#[derive(Config)] #[derive(Config)]
pub struct StableDiffusionConfig { pub struct StableDiffusionConfig {}
}
impl StableDiffusionConfig { impl StableDiffusionConfig {
pub fn init<B: Backend>(&self) -> StableDiffusion<B> { pub fn init<B: Backend>(&self) -> StableDiffusion<B> {
@@ -55,10 +45,21 @@ pub struct StableDiffusion<B: Backend> {
} }
impl<B: Backend> StableDiffusion<B> { impl<B: Backend> StableDiffusion<B> {
pub fn sample_image(&self, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64, n_steps: usize) -> Vec<Vec<u8>> { pub fn sample_image(
&self,
context: Tensor<B, 3>,
unconditional_context: Tensor<B, 2>,
unconditional_guidance_scale: f64,
n_steps: usize,
) -> Vec<Vec<u8>> {
let [n_batch, _, _] = context.dims(); let [n_batch, _, _] = context.dims();
let latent = self.sample_latent(context, unconditional_context, unconditional_guidance_scale, n_steps); let latent = self.sample_latent(
context,
unconditional_context,
unconditional_guidance_scale,
n_steps,
);
self.latent_to_image(latent) self.latent_to_image(latent)
} }
@@ -79,19 +80,29 @@ impl<B: Backend> StableDiffusion<B> {
.swap_dims(2, 3) .swap_dims(2, 3)
.mul_scalar(255.0); .mul_scalar(255.0);
let flattened: Vec<_> = image. let flattened: Vec<_> = image.into_data().value;
into_data().
value;
(0..n_batch).into_iter().map(|b| { (0..n_batch)
.into_iter()
.map(|b| {
let start = b * num_elements_per_image; let start = b * num_elements_per_image;
let end = start + num_elements_per_image; let end = start + num_elements_per_image;
flattened[start..end].into_iter().map(|v| v.to_f64().unwrap().min(255.0).max(0.0).to_u8().unwrap()).collect() flattened[start..end]
}).collect() .into_iter()
.map(|v| v.to_f64().unwrap().min(255.0).max(0.0).to_u8().unwrap())
.collect()
})
.collect()
} }
pub fn sample_latent(&self, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64, n_steps: usize) -> Tensor<B, 4> { pub fn sample_latent(
&self,
context: Tensor<B, 3>,
unconditional_context: Tensor<B, 2>,
unconditional_guidance_scale: f64,
n_steps: usize,
) -> Tensor<B, 4> {
let device = context.device(); let device = context.device();
let step_size = self.n_steps / n_steps; let step_size = self.n_steps / n_steps;
@@ -99,7 +110,8 @@ impl<B: Backend> StableDiffusion<B> {
let [n_batches, _, _] = context.dims(); let [n_batches, _, _] = context.dims();
let gen_noise = || { let gen_noise = || {
Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0)).to_device(&device) Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0))
.to_device(&device)
}; };
let sigma = 0.0; // Use deterministic diffusion let sigma = 0.0; // Use deterministic diffusion
@@ -107,10 +119,21 @@ impl<B: Backend> StableDiffusion<B> {
let mut latent = gen_noise(); let mut latent = gen_noise();
for t in (0..self.n_steps).rev().step_by(step_size) { for t in (0..self.n_steps).rev().step_by(step_size) {
let current_alpha: f64 = self.alpha_cumulative_products.val().slice([t..t + 1]).into_scalar().to_f64().unwrap(); let current_alpha: f64 = self
.alpha_cumulative_products
.val()
.slice([t..t + 1])
.into_scalar()
.to_f64()
.unwrap();
let prev_alpha: f64 = if t >= step_size { let prev_alpha: f64 = if t >= step_size {
let i = t - step_size; let i = t - step_size;
self.alpha_cumulative_products.val().slice([i..i + 1]).into_scalar().to_f64().unwrap() self.alpha_cumulative_products
.val()
.slice([i..i + 1])
.into_scalar()
.to_f64()
.unwrap()
} else { } else {
1.0 1.0
}; };
@@ -118,7 +141,13 @@ impl<B: Backend> StableDiffusion<B> {
let sqrt_noise = (1.0 - current_alpha).sqrt(); let sqrt_noise = (1.0 - current_alpha).sqrt();
let timestep = Tensor::from_ints([t as i32]).to_device(&device); let timestep = Tensor::from_ints([t as i32]).to_device(&device);
let pred_noise = self.forward_diffuser(latent.clone(), timestep, context.clone(), unconditional_context.clone(), unconditional_guidance_scale); let pred_noise = self.forward_diffuser(
latent.clone(),
timestep,
context.clone(),
unconditional_context.clone(),
unconditional_guidance_scale,
);
let predx0 = (latent - pred_noise.clone() * sqrt_noise) / current_alpha.sqrt(); let predx0 = (latent - pred_noise.clone() * sqrt_noise) / current_alpha.sqrt();
let dir_latent = pred_noise * (1.0 - prev_alpha - sigma * sigma).sqrt(); let dir_latent = pred_noise * (1.0 - prev_alpha - sigma * sigma).sqrt();
@@ -129,21 +158,24 @@ impl<B: Backend> StableDiffusion<B> {
latent latent
} }
fn forward_diffuser(&self, latent: Tensor<B, 4>, timestep: Tensor<B, 1, Int>, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64) -> Tensor<B, 4> { fn forward_diffuser(
&self,
latent: Tensor<B, 4>,
timestep: Tensor<B, 1, Int>,
context: Tensor<B, 3>,
unconditional_context: Tensor<B, 2>,
unconditional_guidance_scale: f64,
) -> Tensor<B, 4> {
let [n_batch, _, _, _] = latent.dims(); let [n_batch, _, _, _] = latent.dims();
//let latent = latent.repeat(0, 2); //let latent = latent.repeat(0, 2);
let unconditional_latent = self.diffusion.forward( let unconditional_latent = self.diffusion.forward(
latent.clone(), latent.clone(),
timestep.clone(), timestep.clone(),
unconditional_context.unsqueeze().repeat(0, n_batch) unconditional_context.unsqueeze().repeat(0, n_batch),
); );
let conditional_latent = self.diffusion.forward( let conditional_latent = self.diffusion.forward(latent, timestep, context);
latent,
timestep,
context
);
/*let latent = self.diffusion.forward( /*let latent = self.diffusion.forward(
latent.repeat(0, 2), latent.repeat(0, 2),
@@ -154,7 +186,8 @@ impl<B: Backend> StableDiffusion<B> {
let unconditional_latent = latent.clone().slice([0..n_batch]); let unconditional_latent = latent.clone().slice([0..n_batch]);
let conditional_latent = latent.slice([n_batch..2 * n_batch]);*/ let conditional_latent = latent.slice([n_batch..2 * n_batch]);*/
unconditional_latent.clone() + (conditional_latent - unconditional_latent) * unconditional_guidance_scale unconditional_latent.clone()
+ (conditional_latent - unconditional_latent) * unconditional_guidance_scale
} }
pub fn unconditional_context(&self, tokenizer: &SimpleTokenizer) -> Tensor<B, 2> { pub fn unconditional_context(&self, tokenizer: &SimpleTokenizer) -> Tensor<B, 2> {
@@ -164,17 +197,25 @@ impl<B: Backend> StableDiffusion<B> {
pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor<B, 3> { pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor<B, 3> {
let device = &self.clip.devices()[0]; let device = &self.clip.devices()[0];
let text = format!("<|startoftext|>{}<|endoftext|>", text); let text = format!("<|startoftext|>{}<|endoftext|>", text);
let tokenized: Vec<_> = tokenizer.encode(&text).into_iter().map(|v| v as i32).collect(); let tokenized: Vec<_> = tokenizer
.encode(&text)
.into_iter()
.map(|v| v as i32)
.collect();
self.clip.forward(Tensor::from_ints(&tokenized[..]).to_device(device).unsqueeze()) self.clip.forward(
Tensor::from_ints(&tokenized[..])
.to_device(device)
.unsqueeze(),
)
} }
} }
use crate::helper::to_float;
use std::f64::consts::PI; use std::f64::consts::PI;
fn cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> { fn cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
to_float(Tensor::arange(1..n_steps + 1)) Tensor::arange(1..n_steps + 1)
.float()
.mul_scalar(PI * 0.5 / n_steps as f64) .mul_scalar(PI * 0.5 / n_steps as f64)
.cos() .cos()
} }
@@ -185,9 +226,9 @@ fn offset_cosine_schedule<B: Backend>(n_steps: usize) -> Tensor<B, 1> {
let start_angle = max_signal_rate.acos(); let start_angle = max_signal_rate.acos();
let end_angle = min_signal_rate.acos(); let end_angle = min_signal_rate.acos();
let times = Tensor::arange(1..n_steps + 1); let times = Tensor::arange(1..n_steps + 1).float();
let diffusion_angles = to_float(times) * ( (end_angle - start_angle) / n_steps as f64) + start_angle; let diffusion_angles = times * ((end_angle - start_angle) / n_steps as f64) + start_angle;
diffusion_angles.cos() diffusion_angles.cos()
} }

View File

@@ -7,16 +7,16 @@ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn, nn,
tensor::{ tensor::{backend::Backend, Tensor},
backend::Backend,
Tensor,
},
}; };
use super::*; use super::*;
use crate::model::groupnorm::load::load_group_norm; use crate::model::groupnorm::load::load_group_norm;
pub fn load_res_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResBlock<B>, Box<dyn Error>> { pub fn load_res_block<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<ResBlock<B>, Box<dyn Error>> {
let norm_in = load_group_norm::<B>(&format!("{}/{}", path, "norm_in"), device)?; let norm_in = load_group_norm::<B>(&format!("{}/{}", path, "norm_in"), device)?;
let conv_in = load_conv2d::<B>(&format!("{}/{}", path, "conv_in"), device)?; let conv_in = load_conv2d::<B>(&format!("{}/{}", path, "conv_in"), device)?;
let lin_embed = load_linear::<B>(&format!("{}/{}", path, "lin_embed"), device)?; let lin_embed = load_linear::<B>(&format!("{}/{}", path, "lin_embed"), device)?;
@@ -39,7 +39,10 @@ pub fn load_res_block<B: Backend>(path: &str, device: &B::Device) -> Result<ResB
Ok(res_block) Ok(res_block)
} }
pub fn load_multi_head_attention<B: Backend>(path: &str, device: &B::Device) -> Result<MultiHeadAttention<B>, Box<dyn Error>> { pub fn load_multi_head_attention<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<MultiHeadAttention<B>, Box<dyn Error>> {
let n_head = load_usize::<B>("n_head", path, device)?; let n_head = load_usize::<B>("n_head", path, device)?;
let query = load_linear::<B>(&format!("{}/{}", path, "query"), device)?; let query = load_linear::<B>(&format!("{}/{}", path, "query"), device)?;
let key = load_linear::<B>(&format!("{}/{}", path, "key"), device)?; let key = load_linear::<B>(&format!("{}/{}", path, "key"), device)?;
@@ -57,7 +60,6 @@ pub fn load_multi_head_attention<B: Backend>(path: &str, device: &B::Device) ->
Ok(multi_head_attention) Ok(multi_head_attention)
} }
pub fn load_geglu<B: Backend>(path: &str, device: &B::Device) -> Result<GEGLU<B>, Box<dyn Error>> { pub fn load_geglu<B: Backend>(path: &str, device: &B::Device) -> Result<GEGLU<B>, Box<dyn Error>> {
let proj = load_linear::<B>(&format!("{}/{}", path, "proj"), device)?; let proj = load_linear::<B>(&format!("{}/{}", path, "proj"), device)?;
@@ -69,7 +71,6 @@ pub fn load_geglu<B: Backend>(path: &str, device: &B::Device) -> Result<GEGLU<B>
Ok(geglue) Ok(geglue)
} }
pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Box<dyn Error>> { pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Box<dyn Error>> {
let geglu = load_geglu::<B>(&format!("{}/{}", path, "geglu"), device)?; let geglu = load_geglu::<B>(&format!("{}/{}", path, "geglu"), device)?;
let lin = load_linear::<B>(&format!("{}/{}", path, "lin"), device)?; let lin = load_linear::<B>(&format!("{}/{}", path, "lin"), device)?;
@@ -82,8 +83,10 @@ pub fn load_mlp<B: Backend>(path: &str, device: &B::Device) -> Result<MLP<B>, Bo
Ok(mlp) Ok(mlp)
} }
pub fn load_transformer_block<B: Backend>(
pub fn load_transformer_block<B: Backend>(path: &str, device: &B::Device) -> Result<TransformerBlock<B>, Box<dyn Error>> { path: &str,
device: &B::Device,
) -> Result<TransformerBlock<B>, Box<dyn Error>> {
let norm1 = load_layer_norm::<B>(&format!("{}/{}", path, "norm1"), device)?; let norm1 = load_layer_norm::<B>(&format!("{}/{}", path, "norm1"), device)?;
let attn1 = load_multi_head_attention::<B>(&format!("{}/{}", path, "attn1"), device)?; let attn1 = load_multi_head_attention::<B>(&format!("{}/{}", path, "attn1"), device)?;
let norm2 = load_layer_norm::<B>(&format!("{}/{}", path, "norm2"), device)?; let norm2 = load_layer_norm::<B>(&format!("{}/{}", path, "norm2"), device)?;
@@ -103,8 +106,10 @@ pub fn load_transformer_block<B: Backend>(path: &str, device: &B::Device) -> Res
Ok(transformer_block) Ok(transformer_block)
} }
pub fn load_spatial_transformer<B: Backend>(
pub fn load_spatial_transformer<B: Backend>(path: &str, device: &B::Device) -> Result<SpatialTransformer<B>, Box<dyn Error>> { path: &str,
device: &B::Device,
) -> Result<SpatialTransformer<B>, Box<dyn Error>> {
let norm = load_group_norm::<B>(&format!("{}/{}", path, "norm"), device)?; let norm = load_group_norm::<B>(&format!("{}/{}", path, "norm"), device)?;
let proj_in = load_conv2d::<B>(&format!("{}/{}", path, "proj_in"), device)?; let proj_in = load_conv2d::<B>(&format!("{}/{}", path, "proj_in"), device)?;
let transformer = load_transformer_block::<B>(&format!("{}/{}", path, "transformer"), device)?; let transformer = load_transformer_block::<B>(&format!("{}/{}", path, "transformer"), device)?;
@@ -120,24 +125,31 @@ pub fn load_spatial_transformer<B: Backend>(path: &str, device: &B::Device) -> R
Ok(spatial_transformer) Ok(spatial_transformer)
} }
pub fn load_upsample<B: Backend>(
pub fn load_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<Upsample<B>, Box<dyn Error>> { path: &str,
device: &B::Device,
) -> Result<Upsample<B>, Box<dyn Error>> {
let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?; let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?;
let upsample = Upsample { let upsample = Upsample { conv: conv };
conv: conv,
};
Ok(upsample) Ok(upsample)
} }
pub fn load_downsample<B: Backend>(path: &str, device: &B::Device) -> Result<Downsample<B>, Box<dyn Error>> { pub fn load_downsample<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<Downsample<B>, Box<dyn Error>> {
load_conv2d(path, device) load_conv2d(path, device)
} }
pub fn load_res_transformer_res<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformerRes<B>, Box<dyn Error>> { pub fn load_res_transformer_res<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<ResTransformerRes<B>, Box<dyn Error>> {
let res1 = load_res_block::<B>(&format!("{}/{}", path, "res1"), device)?; // Assuming load_res_block function let res1 = load_res_block::<B>(&format!("{}/{}", path, "res1"), device)?; // Assuming load_res_block function
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?; let transformer =
load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
let res2 = load_res_block::<B>(&format!("{}/{}", path, "res2"), device)?; let res2 = load_res_block::<B>(&format!("{}/{}", path, "res2"), device)?;
let res_transformer_res = ResTransformerRes { let res_transformer_res = ResTransformerRes {
@@ -149,9 +161,13 @@ pub fn load_res_transformer_res<B: Backend>(path: &str, device: &B::Device) -> R
Ok(res_transformer_res) Ok(res_transformer_res)
} }
pub fn load_res_transformer_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformerUpsample<B>, Box<dyn Error>> { pub fn load_res_transformer_upsample<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<ResTransformerUpsample<B>, Box<dyn Error>> {
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?; let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?; let transformer =
load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?; let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?;
let res_transformer_upsample = ResTransformerUpsample { let res_transformer_upsample = ResTransformerUpsample {
@@ -163,8 +179,10 @@ pub fn load_res_transformer_upsample<B: Backend>(path: &str, device: &B::Device)
Ok(res_transformer_upsample) Ok(res_transformer_upsample)
} }
pub fn load_res_upsample<B: Backend>(
pub fn load_res_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<ResUpSample<B>, Box<dyn Error>> { path: &str,
device: &B::Device,
) -> Result<ResUpSample<B>, Box<dyn Error>> {
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?; let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?; let upsample = load_upsample::<B>(&format!("{}/{}", path, "upsample"), device)?;
@@ -176,10 +194,13 @@ pub fn load_res_upsample<B: Backend>(path: &str, device: &B::Device) -> Result<R
Ok(res_upsample) Ok(res_upsample)
} }
pub fn load_res_transformer<B: Backend>(
pub fn load_res_transformer<B: Backend>(path: &str, device: &B::Device) -> Result<ResTransformer<B>, Box<dyn Error>> { path: &str,
device: &B::Device,
) -> Result<ResTransformer<B>, Box<dyn Error>> {
let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?; let res = load_res_block::<B>(&format!("{}/{}", path, "res"), device)?;
let transformer = load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?; let transformer =
load_spatial_transformer::<B>(&format!("{}/{}", path, "transformer"), device)?;
let res_transformer = ResTransformer { let res_transformer = ResTransformer {
res: res, res: res,
@@ -189,8 +210,10 @@ pub fn load_res_transformer<B: Backend>(path: &str, device: &B::Device) -> Resul
Ok(res_transformer) Ok(res_transformer)
} }
pub fn load_unet_input_blocks<B: Backend>(
pub fn load_unet_input_blocks<B: Backend>(path: &str, device: &B::Device) -> Result<UNetInputBlocks<B>, Box<dyn Error>> { path: &str,
device: &B::Device,
) -> Result<UNetInputBlocks<B>, Box<dyn Error>> {
let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?; let conv = load_conv2d::<B>(&format!("{}/{}", path, "conv"), device)?;
let rt1 = load_res_transformer::<B>(&format!("{}/{}", path, "rt1"), device)?; let rt1 = load_res_transformer::<B>(&format!("{}/{}", path, "rt1"), device)?;
let rt2 = load_res_transformer::<B>(&format!("{}/{}", path, "rt2"), device)?; let rt2 = load_res_transformer::<B>(&format!("{}/{}", path, "rt2"), device)?;
@@ -222,7 +245,10 @@ pub fn load_unet_input_blocks<B: Backend>(path: &str, device: &B::Device) -> Res
Ok(unet_input_blocks) Ok(unet_input_blocks)
} }
pub fn load_unet_output_blocks<B: Backend>(path: &str, device: &B::Device) -> Result<UNetOutputBlocks<B>, Box<dyn Error>> { pub fn load_unet_output_blocks<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<UNetOutputBlocks<B>, Box<dyn Error>> {
let r1 = load_res_block::<B>(&format!("{}/{}", path, "r1"), device)?; let r1 = load_res_block::<B>(&format!("{}/{}", path, "r1"), device)?;
let r2 = load_res_block::<B>(&format!("{}/{}", path, "r2"), device)?; let r2 = load_res_block::<B>(&format!("{}/{}", path, "r2"), device)?;
let ru = load_res_upsample::<B>(&format!("{}/{}", path, "ru"), device)?; let ru = load_res_upsample::<B>(&format!("{}/{}", path, "ru"), device)?;
@@ -252,14 +278,16 @@ pub fn load_unet_output_blocks<B: Backend>(path: &str, device: &B::Device) -> Re
}) })
} }
pub fn load_unet<B: Backend>(path: &str, device: &B::Device) -> Result<UNet<B>, Box<dyn Error>> { pub fn load_unet<B: Backend>(path: &str, device: &B::Device) -> Result<UNet<B>, Box<dyn Error>> {
let lin1_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin1_time_embed"), device)?; let lin1_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin1_time_embed"), device)?;
let silu_time_embed = SILU::new(); // Assuming SILU::new() initializes a new SILU struct let silu_time_embed = SILU::new(); // Assuming SILU::new() initializes a new SILU struct
let lin2_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin2_time_embed"), device)?; let lin2_time_embed = load_linear::<B>(&format!("{}/{}", path, "lin2_time_embed"), device)?;
let input_blocks = load_unet_input_blocks::<B>(&format!("{}/{}", path, "input_blocks"), device)?; let input_blocks =
let middle_block = load_res_transformer_res::<B>(&format!("{}/{}", path, "middle_block"), device)?; load_unet_input_blocks::<B>(&format!("{}/{}", path, "input_blocks"), device)?;
let output_blocks = load_unet_output_blocks::<B>(&format!("{}/{}", path, "output_blocks"), device)?; let middle_block =
load_res_transformer_res::<B>(&format!("{}/{}", path, "middle_block"), device)?;
let output_blocks =
load_unet_output_blocks::<B>(&format!("{}/{}", path, "output_blocks"), device)?;
let norm_out = load_group_norm::<B>(&format!("{}/{}", path, "norm_out"), device)?; let norm_out = load_group_norm::<B>(&format!("{}/{}", path, "norm_out"), device)?;
let silu_out = SILU::new(); // Assuming SILU::new() initializes a new SILU struct let silu_out = SILU::new(); // Assuming SILU::new() initializes a new SILU struct
let conv_out = load_conv2d::<B>(&format!("{}/{}", path, "conv_out"), device)?; let conv_out = load_conv2d::<B>(&format!("{}/{}", path, "conv_out"), device)?;

View File

@@ -3,32 +3,32 @@ pub mod load;
use burn::{ use burn::{
config::Config, config::Config,
module::{Module, Param}, module::{Module, Param},
nn::{self, PaddingConfig2d, GELU, conv::{Conv2d, Conv2dConfig}}, nn::{
tensor::{ self,
backend::Backend, conv::{Conv2d, Conv2dConfig},
activation::softmax, PaddingConfig2d, GELU,
module::embedding,
Tensor,
Distribution,
Int,
}, },
tensor::{activation::softmax, backend::Backend, module::embedding, Distribution, Int, Tensor},
}; };
use super::silu::*;
use super::groupnorm::*; use super::groupnorm::*;
use crate::helper::to_float; use super::silu::*;
use super::attention::qkv_attention; use super::attention::qkv_attention;
fn timestep_embedding<B: Backend>(
fn timestep_embedding<B: Backend>(timesteps: Tensor<B, 1, Int>, dim: usize, max_period: usize) -> Tensor<B, 2> { timesteps: Tensor<B, 1, Int>,
dim: usize,
max_period: usize,
) -> Tensor<B, 2> {
let half = dim / 2; let half = dim / 2;
let freqs = ( to_float(Tensor::arange_device(0..half, &timesteps.device())) * (-(max_period as f64).ln() / half as f64 ) ).exp(); let freqs = (Tensor::arange_device(0..half, &timesteps.device()).float()
let args = to_float(timesteps) * freqs; * (-(max_period as f64).ln() / half as f64))
.exp();
let args = timesteps.float() * freqs;
Tensor::cat(vec![args.clone().cos(), args.sin()], 0).unsqueeze() Tensor::cat(vec![args.clone().cos(), args.sin()], 0).unsqueeze()
} }
#[derive(Config)] #[derive(Config)]
pub struct UNetConfig {} pub struct UNetConfig {}
@@ -39,7 +39,9 @@ impl UNetConfig {
let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init(); let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init();
let input_blocks = UNetInputBlocks { let input_blocks = UNetInputBlocks {
conv: Conv2dConfig::new([4, 320], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(), conv: Conv2dConfig::new([4, 320], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(),
rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(), rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(),
rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(), rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(),
d1: DownsampleConfig::new(320).init(), d1: DownsampleConfig::new(320).init(),
@@ -72,7 +74,9 @@ impl UNetConfig {
let norm_out = GroupNormConfig::new(32, 320).init(); let norm_out = GroupNormConfig::new(32, 320).init();
let silu_out = SILU::new(); let silu_out = SILU::new();
let conv_out = Conv2dConfig::new([320, 4], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); let conv_out = Conv2dConfig::new([320, 4], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
UNet { UNet {
lin1_time_embed, lin1_time_embed,
@@ -102,7 +106,12 @@ pub struct UNet<B: Backend> {
} }
impl<B: Backend> UNet<B> { impl<B: Backend> UNet<B> {
pub fn forward(&self, x: Tensor<B, 4>, timesteps: Tensor<B, 1, Int>, context: Tensor<B, 3>) -> Tensor<B, 4> { pub fn forward(
&self,
x: Tensor<B, 4>,
timesteps: Tensor<B, 1, Int>,
context: Tensor<B, 3>,
) -> Tensor<B, 4> {
let t_emb = timestep_embedding(timesteps, 320, 10000); let t_emb = timestep_embedding(timesteps, 320, 10000);
let emb = self.lin1_time_embed.forward(t_emb); let emb = self.lin1_time_embed.forward(t_emb);
let emb = self.silu_time_embed.forward(emb); let emb = self.silu_time_embed.forward(emb);
@@ -133,8 +142,6 @@ impl<B: Backend> UNet<B> {
} }
} }
#[derive(Module, Debug)] #[derive(Module, Debug)]
pub struct UNetInputBlocks<B: Backend> { pub struct UNetInputBlocks<B: Backend> {
conv: Conv2d<B>, conv: Conv2d<B>,
@@ -154,18 +161,8 @@ pub struct UNetInputBlocks<B: Backend> {
impl<B: Backend> UNetInputBlocks<B> { impl<B: Backend> UNetInputBlocks<B> {
fn as_array(&self) -> [&dyn UNetBlock<B>; 12] { fn as_array(&self) -> [&dyn UNetBlock<B>; 12] {
[ [
&self.conv, &self.conv, &self.rt1, &self.rt2, &self.d1, &self.rt3, &self.rt4, &self.d2, &self.rt5,
&self.rt1, &self.rt6, &self.d3, &self.r1, &self.r2,
&self.rt2,
&self.d1,
&self.rt3,
&self.rt4,
&self.d2,
&self.rt5,
&self.rt6,
&self.d3,
&self.r1,
&self.r2,
] ]
} }
} }
@@ -189,26 +186,12 @@ pub struct UNetOutputBlocks<B: Backend> {
impl<B: Backend> UNetOutputBlocks<B> { impl<B: Backend> UNetOutputBlocks<B> {
fn as_array(&self) -> [&dyn UNetBlock<B>; 12] { fn as_array(&self) -> [&dyn UNetBlock<B>; 12] {
[ [
&self.r1, &self.r1, &self.r2, &self.ru, &self.rt1, &self.rt2, &self.rtu1, &self.rt3, &self.rt4,
&self.r2, &self.rtu2, &self.rt5, &self.rt6, &self.rt7,
&self.ru,
&self.rt1,
&self.rt2,
&self.rtu1,
&self.rt3,
&self.rt4,
&self.rtu2,
&self.rt5,
&self.rt6,
&self.rt7,
] ]
} }
} }
trait UNetBlock<B: Backend> { trait UNetBlock<B: Backend> {
fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4>; fn forward(&self, x: Tensor<B, 4>, emb: Tensor<B, 2>, context: Tensor<B, 3>) -> Tensor<B, 4>;
} }
@@ -224,13 +207,17 @@ pub struct ResTransformerConfig {
impl ResTransformerConfig { impl ResTransformerConfig {
fn init<B: Backend>(&self) -> ResTransformer<B> { fn init<B: Backend>(&self) -> ResTransformer<B> {
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init(); let res = ResBlockConfig::new(
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init(); self.n_channels_in,
self.n_channels_embed,
self.n_channels_out,
)
.init();
let transformer =
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
.init();
ResTransformer { ResTransformer { res, transformer }
res,
transformer,
}
} }
} }
@@ -257,13 +244,15 @@ pub struct ResUpSampleConfig {
impl ResUpSampleConfig { impl ResUpSampleConfig {
fn init<B: Backend>(&self) -> ResUpSample<B> { fn init<B: Backend>(&self) -> ResUpSample<B> {
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init(); let res = ResBlockConfig::new(
self.n_channels_in,
self.n_channels_embed,
self.n_channels_out,
)
.init();
let upsample = UpsampleConfig::new(self.n_channels_out).init(); let upsample = UpsampleConfig::new(self.n_channels_out).init();
ResUpSample { ResUpSample { res, upsample }
res,
upsample,
}
} }
} }
@@ -292,8 +281,15 @@ pub struct ResTransformerUpsampleConfig {
impl ResTransformerUpsampleConfig { impl ResTransformerUpsampleConfig {
fn init<B: Backend>(&self) -> ResTransformerUpsample<B> { fn init<B: Backend>(&self) -> ResTransformerUpsample<B> {
let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init(); let res = ResBlockConfig::new(
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init(); self.n_channels_in,
self.n_channels_embed,
self.n_channels_out,
)
.init();
let transformer =
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
.init();
let upsample = UpsampleConfig::new(self.n_channels_out).init(); let upsample = UpsampleConfig::new(self.n_channels_out).init();
ResTransformerUpsample { ResTransformerUpsample {
@@ -331,9 +327,21 @@ pub struct ResTransformerResConfig {
impl ResTransformerResConfig { impl ResTransformerResConfig {
fn init<B: Backend>(&self) -> ResTransformerRes<B> { fn init<B: Backend>(&self) -> ResTransformerRes<B> {
let res1 = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init(); let res1 = ResBlockConfig::new(
let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init(); self.n_channels_in,
let res2 = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init(); self.n_channels_embed,
self.n_channels_out,
)
.init();
let transformer =
SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head)
.init();
let res2 = ResBlockConfig::new(
self.n_channels_in,
self.n_channels_embed,
self.n_channels_out,
)
.init();
ResTransformerRes { ResTransformerRes {
res1, res1,
@@ -359,8 +367,6 @@ impl<B: Backend> UNetBlock<B> for ResTransformerRes<B> {
} }
} }
#[derive(Config)] #[derive(Config)]
pub struct UpsampleConfig { pub struct UpsampleConfig {
n_channels: usize, n_channels: usize,
@@ -372,9 +378,7 @@ impl UpsampleConfig {
.with_padding(PaddingConfig2d::Explicit(1, 1)) .with_padding(PaddingConfig2d::Explicit(1, 1))
.init(); .init();
Upsample { Upsample { conv }
conv,
}
} }
} }
@@ -423,9 +427,6 @@ impl<B: Backend> UNetBlock<B> for Conv2d<B> {
} }
} }
#[derive(Config)] #[derive(Config)]
pub struct SpatialTransformerConfig { pub struct SpatialTransformerConfig {
n_channels: usize, n_channels: usize,
@@ -437,7 +438,8 @@ impl SpatialTransformerConfig {
fn init<B: Backend>(&self) -> SpatialTransformer<B> { fn init<B: Backend>(&self) -> SpatialTransformer<B> {
let norm = GroupNormConfig::new(32, self.n_channels).init(); let norm = GroupNormConfig::new(32, self.n_channels).init();
let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(); let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init();
let transformer = TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init(); let transformer =
TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init();
let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(); let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init();
SpatialTransformer { SpatialTransformer {
@@ -465,9 +467,13 @@ impl<B: Backend> SpatialTransformer<B> {
let x = self.norm.forward(x); let x = self.norm.forward(x);
let x = self.proj_in.forward(x); let x = self.proj_in.forward(x);
let x = x.reshape([n_batch, n_channel, height * width]).swap_dims(1, 2); let x = x
.reshape([n_batch, n_channel, height * width])
.swap_dims(1, 2);
let x = self.transformer.forward(x, context) let x = self
.transformer
.forward(x, context)
.swap_dims(1, 2) .swap_dims(1, 2)
.reshape([n_batch, n_channel, height, width]); .reshape([n_batch, n_channel, height, width]);
@@ -475,13 +481,6 @@ impl<B: Backend> SpatialTransformer<B> {
} }
} }
#[derive(Config)] #[derive(Config)]
pub struct TransformerBlockConfig { pub struct TransformerBlockConfig {
n_state: usize, n_state: usize,
@@ -494,7 +493,8 @@ impl TransformerBlockConfig {
let norm1 = nn::LayerNormConfig::new(self.n_state).init(); let norm1 = nn::LayerNormConfig::new(self.n_state).init();
let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init(); let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init();
let norm2 = nn::LayerNormConfig::new(self.n_state).init(); let norm2 = nn::LayerNormConfig::new(self.n_state).init();
let attn2 = MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init(); let attn2 =
MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init();
let norm3 = nn::LayerNormConfig::new(self.n_state).init(); let norm3 = nn::LayerNormConfig::new(self.n_state).init();
let mlp = MLPConfig::new(self.n_state, 4).init(); let mlp = MLPConfig::new(self.n_state, 4).init();
@@ -521,13 +521,12 @@ pub struct TransformerBlock<B: Backend> {
impl<B: Backend> TransformerBlock<B> { impl<B: Backend> TransformerBlock<B> {
fn forward(&self, x: Tensor<B, 3>, context: Tensor<B, 3>) -> Tensor<B, 3> { fn forward(&self, x: Tensor<B, 3>, context: Tensor<B, 3>) -> Tensor<B, 3> {
let x = x.clone() + self.attn1.forward( self.norm1.forward(x), None); let x = x.clone() + self.attn1.forward(self.norm1.forward(x), None);
let x = x.clone() + self.attn2.forward( self.norm2.forward(x), Some(context)); let x = x.clone() + self.attn2.forward(self.norm2.forward(x), Some(context));
x.clone() + self.mlp.forward( self.norm3.forward(x) ) x.clone() + self.mlp.forward(self.norm3.forward(x))
} }
} }
#[derive(Config)] #[derive(Config)]
pub struct MLPConfig { pub struct MLPConfig {
n_state: usize, n_state: usize,
@@ -540,10 +539,7 @@ impl MLPConfig {
let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init(); let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init();
let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init(); let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init();
MLP { MLP { geglu, lin }
geglu,
lin,
}
} }
} }
@@ -555,11 +551,10 @@ pub struct MLP<B: Backend> {
impl<B: Backend> MLP<B> { impl<B: Backend> MLP<B> {
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> { pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
self.lin.forward( self.geglu.forward(x) ) self.lin.forward(self.geglu.forward(x))
} }
} }
#[derive(Config)] #[derive(Config)]
pub struct GEGLUConfig { pub struct GEGLUConfig {
n_state_in: usize, n_state_in: usize,
@@ -571,10 +566,7 @@ impl GEGLUConfig {
let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init(); let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init();
let gelu = GELU::new(); let gelu = GELU::new();
GEGLU { GEGLU { proj, gelu }
proj,
gelu,
}
} }
} }
@@ -591,17 +583,15 @@ impl<B: Backend> GEGLU<B> {
let n_state_out = n_state / 2; let n_state_out = n_state / 2;
let x = projected.clone().slice([0..n_batch, 0..n_ctx, 0..n_state_out]); let x = projected
.clone()
.slice([0..n_batch, 0..n_ctx, 0..n_state_out]);
let gate = projected.slice([0..n_batch, 0..n_ctx, n_state_out..n_state]); let gate = projected.slice([0..n_batch, 0..n_ctx, n_state_out..n_state]);
x * self.gelu.forward(gate) x * self.gelu.forward(gate)
} }
} }
#[derive(Config)] #[derive(Config)]
pub struct MultiHeadAttentionConfig { pub struct MultiHeadAttentionConfig {
n_state: usize, n_state: usize,
@@ -611,12 +601,23 @@ pub struct MultiHeadAttentionConfig {
impl MultiHeadAttentionConfig { impl MultiHeadAttentionConfig {
fn init<B: Backend>(&self) -> MultiHeadAttention<B> { fn init<B: Backend>(&self) -> MultiHeadAttention<B> {
assert!(self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", self.n_state, self.n_head); assert!(
self.n_state % self.n_head == 0,
"State size {} must be a multiple of head size {}",
self.n_state,
self.n_head
);
let n_head = self.n_head; let n_head = self.n_head;
let query = nn::LinearConfig::new(self.n_state, self.n_state).with_bias(false).init(); let query = nn::LinearConfig::new(self.n_state, self.n_state)
let key = nn::LinearConfig::new(self.n_context_state, self.n_state).with_bias(false).init(); .with_bias(false)
let value = nn::LinearConfig::new(self.n_context_state, self.n_state).with_bias(false).init(); .init();
let key = nn::LinearConfig::new(self.n_context_state, self.n_state)
.with_bias(false)
.init();
let value = nn::LinearConfig::new(self.n_context_state, self.n_state)
.with_bias(false)
.init();
let out = nn::LinearConfig::new(self.n_state, self.n_state).init(); let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
MultiHeadAttention { MultiHeadAttention {
@@ -624,7 +625,7 @@ impl MultiHeadAttentionConfig {
query, query,
key, key,
value, value,
out out,
} }
} }
} }
@@ -652,21 +653,6 @@ impl<B: Backend> MultiHeadAttention<B> {
} }
} }
#[derive(Config)] #[derive(Config)]
pub struct ResBlockConfig { pub struct ResBlockConfig {
n_channels_in: usize, n_channels_in: usize,
@@ -674,22 +660,25 @@ pub struct ResBlockConfig {
n_channels_out: usize, n_channels_out: usize,
} }
impl ResBlockConfig { impl ResBlockConfig {
fn init<B: Backend>(&self) -> ResBlock<B> { fn init<B: Backend>(&self) -> ResBlock<B> {
let norm_in = GroupNormConfig::new(32, self.n_channels_in).init(); let norm_in = GroupNormConfig::new(32, self.n_channels_in).init();
let silu_in = SILU::new(); let silu_in = SILU::new();
let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
let silu_embed = SILU::new(); let silu_embed = SILU::new();
let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init(); let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init();
let norm_out = GroupNormConfig::new(32, self.n_channels_out).init(); let norm_out = GroupNormConfig::new(32, self.n_channels_out).init();
let silu_out = SILU::new(); let silu_out = SILU::new();
let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
let skip_connection = if self.n_channels_in != self.n_channels_out { let skip_connection = if self.n_channels_in != self.n_channels_out {
Some( Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init() ) Some(Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init())
} else { } else {
None None
}; };
@@ -708,7 +697,6 @@ impl ResBlockConfig {
} }
} }
#[derive(Module, Debug)] #[derive(Module, Debug)]
pub struct ResBlock<B: Backend> { pub struct ResBlock<B: Backend> {
norm_in: GroupNorm<B>, norm_in: GroupNorm<B>,
@@ -751,5 +739,3 @@ impl<B: Backend> UNetBlock<B> for ResBlock<B> {
self.forward(x, emb) self.forward(x, emb)
} }
} }

View File

@@ -1,13 +1,14 @@
use std::collections::HashMap;
use regex::Regex; use regex::Regex;
use std::collections::HashMap;
use std::fs::File; use std::fs::File;
use std::io::{self, BufRead}; use std::io::{self, BufRead};
fn bytes_to_unicode() -> Vec<(u8, char)> { fn bytes_to_unicode() -> Vec<(u8, char)> {
let mut bs: Vec<u8> = ('!' as u8 ..= '~' as u8).into_iter() let mut bs: Vec<u8> = ('!' as u8..='~' as u8)
.chain( ('¡' as u8..='¬' as u8).into_iter() ) .into_iter()
.chain( ('®' as u8..='ÿ' as u8).into_iter() ) .chain(('¡' as u8..='¬' as u8).into_iter())
.chain(('®' as u8..='ÿ' as u8).into_iter())
.collect(); .collect();
let mut cs: Vec<_> = bs.iter().cloned().map(char::from).collect(); let mut cs: Vec<_> = bs.iter().cloned().map(char::from).collect();
@@ -16,25 +17,21 @@ fn bytes_to_unicode() -> Vec<(u8, char)> {
for b in 0u8..=255u8 { for b in 0u8..=255u8 {
if !bs.contains(&b) { if !bs.contains(&b) {
bs.push(b); bs.push(b);
cs.push( char::from_u32(256 + n).unwrap() ); cs.push(char::from_u32(256 + n).unwrap());
n += 1; n += 1;
} }
} }
bs.into_iter() bs.into_iter()
.zip( .zip(cs.into_iter().map(|c| c.into()))
cs.into_iter() .collect()
.map(|c| c.into())
).collect()
} }
fn get_pairs(word: &[String]) -> Vec<(String, String)> { fn get_pairs(word: &[String]) -> Vec<(String, String)> {
let prev = word.into_iter().cloned(); let prev = word.into_iter().cloned();
let next = prev.clone().skip(1); let next = prev.clone().skip(1);
prev prev.zip(next).collect()
.zip(next)
.collect()
} }
fn whitespace_clean(text: &str) -> String { fn whitespace_clean(text: &str) -> String {
@@ -59,9 +56,12 @@ fn load_merges(path: &str) -> io::Result<Vec<(String, String)>> {
Ok(merges) Ok(merges)
} }
fn construct_vocab(chars: impl Iterator<Item=char> + Clone, merges: &[(String, String)]) -> Vec<String> { fn construct_vocab(
chars: impl Iterator<Item = char> + Clone,
merges: &[(String, String)],
) -> Vec<String> {
let iter = chars.map(String::from); let iter = chars.map(String::from);
let mut vocab: Vec<_> = iter.clone().chain( iter.map(|c| c + "</w>") ).collect(); let mut vocab: Vec<_> = iter.clone().chain(iter.map(|c| c + "</w>")).collect();
for merge in merges { for merge in merges {
vocab.push(format!("{}{}", merge.0, merge.1)); vocab.push(format!("{}{}", merge.0, merge.1));
@@ -87,10 +87,10 @@ impl SimpleTokenizer {
let byte_unicode_values = bytes_to_unicode(); let byte_unicode_values = bytes_to_unicode();
let byte_encoder: HashMap<_, _> = byte_unicode_values.iter().cloned().collect(); let byte_encoder: HashMap<_, _> = byte_unicode_values.iter().cloned().collect();
let byte_decoder = byte_encoder.iter().map(|(k,v)| (*v,*k)).collect(); let byte_decoder = byte_encoder.iter().map(|(k, v)| (*v, *k)).collect();
let merges = load_merges("bpe_simple_vocab_16e6.txt")?; let merges = load_merges("bpe_simple_vocab_16e6.txt")?;
let merges = merges[1..49152-256-2+1].to_vec(); let merges = merges[1..49152 - 256 - 2 + 1].to_vec();
let vocab = construct_vocab(byte_unicode_values.into_iter().map(|(_, u)| u), &merges[..]); let vocab = construct_vocab(byte_unicode_values.into_iter().map(|(_, u)| u), &merges[..]);
@@ -104,7 +104,7 @@ impl SimpleTokenizer {
let pat = Regex::new(r"(?i)<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|\p{L}+|\p{N}|[^\s\p{L}\p{N}]+").unwrap(); let pat = Regex::new(r"(?i)<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|\p{L}+|\p{N}|[^\s\p{L}\p{N}]+").unwrap();
Ok( SimpleTokenizer { Ok(SimpleTokenizer {
byte_encoder: byte_encoder, byte_encoder: byte_encoder,
byte_decoder: byte_decoder, byte_decoder: byte_decoder,
encoder: encoder, encoder: encoder,
@@ -112,7 +112,7 @@ impl SimpleTokenizer {
bpe_ranks: bpe_ranks, bpe_ranks: bpe_ranks,
cache: cache, cache: cache,
pat: pat, pat: pat,
} ) })
} }
pub fn bpe(&self, token: &str) -> String { pub fn bpe(&self, token: &str) -> String {
@@ -129,7 +129,8 @@ impl SimpleTokenizer {
} }
loop { loop {
let bigram = pairs.iter() let bigram = pairs
.iter()
.filter(|pair| self.bpe_ranks.contains_key(pair)) .filter(|pair| self.bpe_ranks.contains_key(pair))
.min_by_key(|&pair| self.bpe_ranks[pair]); .min_by_key(|&pair| self.bpe_ranks[pair]);
@@ -141,7 +142,7 @@ impl SimpleTokenizer {
let mut new_word = Vec::new(); let mut new_word = Vec::new();
let mut i = 0; let mut i = 0;
while i < word.len() { while i < word.len() {
if let Some( (j, _) ) = word.iter().enumerate().skip(i).find(|(_, w)| w == &first) { if let Some((j, _)) = word.iter().enumerate().skip(i).find(|(_, w)| w == &first) {
new_word.extend(word[i..j].iter().cloned()); new_word.extend(word[i..j].iter().cloned());
i = j; i = j;
} else { } else {
@@ -178,8 +179,16 @@ impl SimpleTokenizer {
for m in self.pat.find_iter(&cleaned_text) { for m in self.pat.find_iter(&cleaned_text) {
let token = m.as_str(); let token = m.as_str();
let token: String = token.as_bytes().into_iter().map(|b| self.byte_encoder[b]).collect(); let token: String = token
bpe_tokens.extend(self.bpe(&token).split(' ').map(|bpe_token| self.encoder[bpe_token])) .as_bytes()
.into_iter()
.map(|b| self.byte_encoder[b])
.collect();
bpe_tokens.extend(
self.bpe(&token)
.split(' ')
.map(|bpe_token| self.encoder[bpe_token]),
)
} }
return bpe_tokens; return bpe_tokens;
@@ -187,9 +196,7 @@ impl SimpleTokenizer {
pub fn decode(&self, tokens: &[u32]) -> String { pub fn decode(&self, tokens: &[u32]) -> String {
let text: String = tokens.iter().map(|t| self.decoder[t].as_str()).collect(); let text: String = tokens.iter().map(|t| self.decoder[t].as_str()).collect();
let decoded_bytes: Vec<u8> = text.chars() let decoded_bytes: Vec<u8> = text.chars().map(|c| self.byte_decoder[&c]).collect();
.map(|c| self.byte_decoder[&c])
.collect();
String::from_utf8_lossy(&decoded_bytes[..]).replace("</w>", " ") String::from_utf8_lossy(&decoded_bytes[..]).replace("</w>", " ")
} }