diff --git a/img0.png b/img0.png index 981a8c2..5f71d62 100644 Binary files a/img0.png and b/img0.png differ diff --git a/src/bin/convert/main.rs b/src/bin/convert/main.rs index 4885dab..f6b9697 100644 --- a/src/bin/convert/main.rs +++ b/src/bin/convert/main.rs @@ -1,26 +1,27 @@ use std::env; -use std::process; use std::error::Error; +use std::process; -use stablediffusion::model::stablediffusion::{StableDiffusion, load::load_stable_diffusion}; +use stablediffusion::model::stablediffusion::{load::load_stable_diffusion, StableDiffusion}; use burn::{ - config::Config, + config::Config, module::{Module, Param}, nn, - tensor::{ - backend::Backend, - Tensor, - }, + tensor::{backend::Backend, Tensor}, }; use burn_ndarray::{NdArrayBackend, NdArrayDevice}; -use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings}; +use burn::record::{self, BinFileRecorder, FullPrecisionSettings, Recorder}; -fn convert_dump_to_model(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box> { +fn convert_dump_to_model( + dump_path: &str, + model_name: &str, + device: &B::Device, +) -> Result<(), Box> { println!("Loading dump..."); - let model: StableDiffusion:: = load_stable_diffusion(dump_path, device)?; + let model: StableDiffusion = load_stable_diffusion(dump_path, device)?; println!("Saving model..."); save_model_file(model, model_name)?; @@ -28,12 +29,11 @@ fn convert_dump_to_model(dump_path: &str, model_name: &str, device: Ok(()) } -fn save_model_file(model: StableDiffusion, name: &str) -> Result<(), record::RecorderError> { - BinFileRecorder::::new() - .record( - model.into_record(), - name.into(), - ) +fn save_model_file( + model: StableDiffusion, + name: &str, +) -> Result<(), record::RecorderError> { + BinFileRecorder::::new().record(model.into_record(), name.into()) } fn main() { diff --git a/src/bin/sample/main.rs b/src/bin/sample/main.rs index c92cb3c..d9db921 100644 --- a/src/bin/sample/main.rs +++ b/src/bin/sample/main.rs @@ -1,13 +1,13 @@ -use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::{*, load::load_stable_diffusion}}; +use stablediffusion::{ + model::stablediffusion::{load::load_stable_diffusion, *}, + tokenizer::SimpleTokenizer, +}; use burn::{ - config::Config, + config::Config, module::{Module, Param}, nn, - tensor::{ - backend::Backend, - Tensor, - }, + tensor::{backend::Backend, Tensor}, }; cfg_if::cfg_if! { @@ -22,12 +22,14 @@ use std::env; use std::io; use std::process; -use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings}; +use burn::record::{self, BinFileRecorder, FullPrecisionSettings, Recorder}; -fn load_stable_diffusion_model_file(filename: &str) -> Result, record::RecorderError> { +fn load_stable_diffusion_model_file( + filename: &str, +) -> Result, record::RecorderError> { BinFileRecorder::::new() - .load(filename.into()) - .map(|record| StableDiffusionConfig::new().init().load_record(record)) + .load(filename.into()) + .map(|record| StableDiffusionConfig::new().init().load_record(record)) } fn main() { @@ -78,17 +80,22 @@ fn main() { let sd = sd.to_device(&device); let unconditional_context = sd.unconditional_context(&tokenizer); - let context = sd.context(&tokenizer, prompt).unsqueeze::<3>();//.repeat(0, 2); // generate 2 samples + let context = sd.context(&tokenizer, prompt).unsqueeze::<3>(); //.repeat(0, 2); // generate 2 samples println!("Sampling image..."); - let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps); + let images = sd.sample_image( + context, + unconditional_context, + unconditional_guidance_scale, + n_steps, + ); save_images(&images, output_image_name, 512, 512).unwrap_or_else(|err| { eprintln!("Error saving image: {}", err); process::exit(1); }); } -use image::{self, ImageResult, ColorType::Rgb8}; +use image::{self, ColorType::Rgb8, ImageResult}; fn save_images(images: &Vec>, basepath: &str, width: u32, height: u32) -> ImageResult<()> { for (index, img_data) in images.iter().enumerate() { @@ -103,12 +110,15 @@ fn save_images(images: &Vec>, basepath: &str, width: u32, height: u32) - fn save_test_image() -> ImageResult<()> { let width = 256; let height = 256; - let raw: Vec<_> = (0..width * height).into_iter().flat_map(|i| { - let row = i / width; - let red = (255.0 * row as f64 / height as f64) as u8; + let raw: Vec<_> = (0..width * height) + .into_iter() + .flat_map(|i| { + let row = i / width; + let red = (255.0 * row as f64 / height as f64) as u8; - [red, 0, 0] - }).collect(); + [red, 0, 0] + }) + .collect(); image::save_buffer("red.png", &raw[..], width, height, Rgb8) -} \ No newline at end of file +} diff --git a/src/helper.rs b/src/helper.rs deleted file mode 100644 index 0332a9d..0000000 --- a/src/helper.rs +++ /dev/null @@ -1,87 +0,0 @@ -use burn::{ - tensor::{ - backend::Backend, - activation::relu, - Tensor, - Int, - Bool, - Float, - TensorKind, - BasicOps, - Numeric, - Element, - }, -}; - -use num_traits::ToPrimitive; - - -pub fn tensor_max_scalar(x: Tensor, max: f64) -> Tensor { - relu(x.sub_scalar(max)).add_scalar(max) -} - -pub fn tensor_min_scalar(x: Tensor, min: f64) -> Tensor { - -tensor_max_scalar(-x, -min) -} - -pub fn tensor_max(x: Tensor, max: Tensor) -> Tensor { - relu(x - max.clone()) + max -} - -pub fn tensor_min(x: Tensor, min: Tensor) -> Tensor { - -tensor_max(-x, -min) -} - -pub fn tensor_log10(x: Tensor) -> Tensor { - let ln10 = (10.0f64).ln(); - x.log() / ln10 -} - -pub fn tensor_max_element(x: Tensor) -> f64 { - let flat: Tensor = x.flatten(0, D - 1); - let max_index = flat.clone().argmax(0); - - flat.select(0, max_index).into_scalar().to_f64().unwrap() -} - -pub fn all_zeros(x: Tensor) -> bool { - x.powf(2.0).sum().into_scalar().to_f64().unwrap() == 0.0 -} - -pub fn max_dim(x: Tensor, dim: usize) -> Tensor { - let indices = x.clone().argmax(dim).flatten(0, 1); - x.select(dim, indices) -} - -pub fn _10pow(x: Tensor) -> Tensor { - let log10 = (10.0f64).ln(); - (x * log10).exp() -} - -pub fn to_float(x: Tensor) -> Tensor { - let device = x.device(); - Tensor::from_data( - x - .into_data() - .convert() - ).to_device(&device) -} - -pub fn to_float_bool(x: Tensor) -> Tensor { - let device = x.device(); - Tensor::from_data( - x - .into_data() - .convert() - ).to_device(&device) -} - -pub fn reverse + BasicOps + Numeric>(x: Tensor, dim: usize) -> Tensor where >::Elem: Element { - let len = x.dims()[dim]; - let indices = -Tensor::arange_device(0..len, &x.device()) + (len - 1) as i64; - x.select(dim, indices) -} - -pub fn div_roundup(x: usize, y: usize) -> usize { - (x + y - 1) / y -} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 993bb59..0e2590d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,2 @@ pub mod model; pub mod tokenizer; -pub mod helper; \ No newline at end of file diff --git a/src/model/attention.rs b/src/model/attention.rs index e516c2e..158afee 100644 --- a/src/model/attention.rs +++ b/src/model/attention.rs @@ -1,23 +1,32 @@ -use burn::{ - tensor::{ - backend::Backend, - activation::softmax, - Tensor, - }, -}; +use burn::tensor::{activation::softmax, backend::Backend, Tensor}; use std::f32::NEG_INFINITY; -pub fn qkv_attention(q: Tensor, k: Tensor, v: Tensor, mask: Option>, n_head: usize) -> Tensor { +pub fn qkv_attention( + q: Tensor, + k: Tensor, + v: Tensor, + mask: Option>, + n_head: usize, +) -> Tensor { let [n_batch, n_qctx, n_state] = q.dims(); let [_, n_ctx, _] = k.dims(); let scale = (n_state as f64 / n_head as f64).powf(-0.25); let n_hstate = n_state / n_head; - let q = q.reshape([n_batch, n_qctx, n_head, n_hstate]).swap_dims(1, 2) * scale; - let k = k.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2).transpose() * scale; - let v = v.reshape([n_batch, n_ctx, n_head, n_hstate]).swap_dims(1, 2); + let q = q + .reshape([n_batch, n_qctx, n_head, n_hstate]) + .swap_dims(1, 2) + * scale; + let k = k + .reshape([n_batch, n_ctx, n_head, n_hstate]) + .swap_dims(1, 2) + .transpose() + * scale; + let v = v + .reshape([n_batch, n_ctx, n_head, n_hstate]) + .swap_dims(1, 2); let qk = q.matmul(k); @@ -44,4 +53,4 @@ pub fn attn_decoder_mask(seq_length: usize, device: &B::Device) -> T } return mask.to_device(device); -} \ No newline at end of file +} diff --git a/src/model/autoencoder/load.rs b/src/model/autoencoder/load.rs index 484fe0d..d06aa1e 100644 --- a/src/model/autoencoder/load.rs +++ b/src/model/autoencoder/load.rs @@ -4,29 +4,38 @@ use crate::model::load::*; use std::error::Error; use burn::{ - config::Config, + config::Config, module::{Module, Param}, nn, - tensor::{ - backend::Backend, - Tensor, - }, + tensor::{backend::Backend, Tensor}, }; use super::*; use crate::model::groupnorm::load::load_group_norm; -fn load_conv_self_attention_block(path: &str, device: &B::Device) -> Result, Box> { +fn load_conv_self_attention_block( + path: &str, + device: &B::Device, +) -> Result, Box> { let norm = load_group_norm(&format!("{}/{}", path, "norm"), device)?; let q = load_conv2d(&format!("{}/{}", path, "q"), device)?; let k = load_conv2d(&format!("{}/{}", path, "k"), device)?; let v = load_conv2d(&format!("{}/{}", path, "v"), device)?; let proj_out = load_conv2d(&format!("{}/{}", path, "proj_out"), device)?; - Ok(ConvSelfAttentionBlock { norm, q, k, v, proj_out }) + Ok(ConvSelfAttentionBlock { + norm, + q, + k, + v, + proj_out, + }) } -fn load_resnet_block(path: &str, device: &B::Device) -> Result, Box> { +fn load_resnet_block( + path: &str, + device: &B::Device, +) -> Result, Box> { let norm1 = load_group_norm(&format!("{}/{}", path, "norm1"), device)?; let silu1 = SILU {}; let conv1 = load_conv2d(&format!("{}/{}", path, "conv1"), device)?; @@ -35,7 +44,15 @@ fn load_resnet_block(path: &str, device: &B::Device) -> Result(path: &str, device: &B::Device) -> Result, Box> { @@ -43,14 +60,21 @@ fn load_mid(path: &str, device: &B::Device) -> Result, Box(path: &str, device: &B::Device) -> Result, Box> { +fn load_padded_conv2d( + path: &str, + device: &B::Device, +) -> Result, Box> { let conv = load_conv2d(&format!("{}/{}", path, "conv"), device)?; let channels = load_tensor::("channels", path, device)?; - let channels = tensor_to_array_2(channels); + let channels = tensor_to_array_2(channels); let kernel_size = load_usize::("kernel_size", path, device)?; let stride = load_usize::("stride", path, device)?; @@ -61,31 +85,48 @@ fn load_padded_conv2d(path: &str, device: &B::Device) -> Result = PaddedConv2dConfig::new(channels, kernel_size, padding).with_stride(stride).init(); - let padding_actual = PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]); + let mut padded_conv: PaddedConv2d = PaddedConv2dConfig::new(channels, kernel_size, padding) + .with_stride(stride) + .init(); + let padding_actual = + PaddingConfig2d::Explicit(padded_conv.padding_actual[0], padded_conv.padding_actual[1]); record.padding = >::into_record(padding_actual); padded_conv.conv = padded_conv.conv.load_record(record); - Ok(padded_conv) } -fn load_decoder_block(path: &str, device: &B::Device) -> Result, Box> { +fn load_decoder_block( + path: &str, + device: &B::Device, +) -> Result, Box> { let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?; let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?; let res3 = load_resnet_block(&format!("{}/{}", path, "res3"), device)?; let upsampler = load_conv2d(&format!("{}/{}", path, "upsampler"), device).ok(); - Ok(DecoderBlock { res1, res2, res3, upsampler }) + Ok(DecoderBlock { + res1, + res2, + res3, + upsampler, + }) } -fn load_encoder_block(path: &str, device: &B::Device) -> Result, Box> { +fn load_encoder_block( + path: &str, + device: &B::Device, +) -> Result, Box> { let res1 = load_resnet_block(&format!("{}/{}", path, "res1"), device)?; let res2 = load_resnet_block(&format!("{}/{}", path, "res2"), device)?; let downsampler = load_padded_conv2d(&format!("{}/{}", path, "downsampler"), device).ok(); - Ok(EncoderBlock { res1, res2, downsampler }) + Ok(EncoderBlock { + res1, + res2, + downsampler, + }) } fn load_decoder(path: &str, device: &B::Device) -> Result, Box> { @@ -95,15 +136,21 @@ fn load_decoder(path: &str, device: &B::Device) -> Result let n_block = load_usize::("n_block", path, device)?; let mut blocks = (0..n_block) .into_iter() - .map(|i| { - load_decoder_block::(&format!("{}/blocks/{}", path, i), device) - }).collect::, _>>()?; + .map(|i| load_decoder_block::(&format!("{}/blocks/{}", path, i), device)) + .collect::, _>>()?; let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?; let silu = SILU {}; let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?; - Ok(Decoder { conv_in, mid, blocks, norm_out, silu, conv_out }) + Ok(Decoder { + conv_in, + mid, + blocks, + norm_out, + silu, + conv_out, + }) } fn load_encoder(path: &str, device: &B::Device) -> Result, Box> { @@ -113,22 +160,36 @@ fn load_encoder(path: &str, device: &B::Device) -> Result let n_block = load_usize::("n_block", path, device)?; let mut blocks = (0..n_block) .into_iter() - .map(|i| { - load_encoder_block::(&format!("{}/blocks/{}", path, i), device) - }).collect::, _>>()?; + .map(|i| load_encoder_block::(&format!("{}/blocks/{}", path, i), device)) + .collect::, _>>()?; let norm_out = load_group_norm(&format!("{}/{}", path, "norm_out"), device)?; let silu = SILU {}; let conv_out = load_conv2d(&format!("{}/{}", path, "conv_out"), device)?; - Ok(Encoder { conv_in, mid, blocks, norm_out, silu, conv_out }) + Ok(Encoder { + conv_in, + mid, + blocks, + norm_out, + silu, + conv_out, + }) } -pub fn load_autoencoder(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_autoencoder( + path: &str, + device: &B::Device, +) -> Result, Box> { let encoder = load_encoder(&format!("{}/{}", path, "encoder"), device)?; let decoder = load_decoder(&format!("{}/{}", path, "decoder"), device)?; let quant_conv = load_conv2d(&format!("{}/{}", path, "quant_conv"), device)?; let post_quant_conv = load_conv2d(&format!("{}/{}", path, "post_quant_conv"), device)?; - Ok(Autoencoder { encoder, decoder, quant_conv, post_quant_conv }) -} \ No newline at end of file + Ok(Autoencoder { + encoder, + decoder, + quant_conv, + post_quant_conv, + }) +} diff --git a/src/model/autoencoder/mod.rs b/src/model/autoencoder/mod.rs index d3dbea7..ae9e2b5 100644 --- a/src/model/autoencoder/mod.rs +++ b/src/model/autoencoder/mod.rs @@ -1,59 +1,59 @@ pub mod load; use burn::{ - config::Config, + config::Config, module::{Module, Param}, - nn::{self, PaddingConfig2d, conv::{Conv2d, Conv2dConfig, Conv2dRecord}}, + nn::{ + self, + conv::{Conv2d, Conv2dConfig, Conv2dRecord}, + PaddingConfig2d, + }, tensor::{ + activation::{sigmoid, softmax}, backend::Backend, - activation::{softmax, sigmoid}, - module::embedding, - Tensor, - Distribution, - Int, + module::embedding, + Distribution, Int, Tensor, }, }; -use crate::helper::div_roundup; - -use super::silu::*; -use super::groupnorm::*; use super::attention::qkv_attention; +use super::groupnorm::*; +use super::silu::*; use std::iter; - #[derive(Config)] pub struct AutoencoderConfig {} impl AutoencoderConfig { pub fn init(&self) -> Autoencoder { - let encoder = EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init(); - let decoder = DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init(); + let encoder = + EncoderConfig::new(vec![(128, 128), (128, 256), (256, 512), (512, 512)], 32, 8).init(); + let decoder = + DecoderConfig::new(vec![(512, 512), (512, 512), (512, 256), (256, 128)], 32).init(); let quant_conv = Conv2dConfig::new([8, 8], [1, 1]).init(); let post_quant_conv = Conv2dConfig::new([4, 4], [1, 1]).init(); Autoencoder { - encoder, - decoder, - quant_conv, - post_quant_conv, + encoder, + decoder, + quant_conv, + post_quant_conv, } } } - #[derive(Module, Debug)] pub struct Autoencoder { - encoder: Encoder, - decoder: Decoder, - quant_conv: Conv2d, - post_quant_conv: Conv2d, + encoder: Encoder, + decoder: Decoder, + quant_conv: Conv2d, + post_quant_conv: Conv2d, } impl Autoencoder { pub fn forward(&self, x: Tensor) -> Tensor { - self.decode_latent( self.encode_image(x) ) + self.decode_latent(self.encode_image(x)) } pub fn encode_image(&self, x: Tensor) -> Tensor { @@ -72,48 +72,60 @@ impl Autoencoder { #[derive(Config)] pub struct EncoderConfig { - channels: Vec<(usize, usize)>, - n_group: usize, - n_channels_out: usize, + channels: Vec<(usize, usize)>, + n_group: usize, + n_channels_out: usize, } impl EncoderConfig { fn init(&self) -> Encoder { - let n_expanded_channels_initial = self.channels.first().map(|f| f.1).expect("Channels must not be empty."); + let n_expanded_channels_initial = self + .channels + .first() + .map(|f| f.1) + .expect("Channels must not be empty."); let n_expanded_channels_final = self.channels.first().unwrap().0; - let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); + let conv_in = Conv2dConfig::new([3, n_expanded_channels_initial], [3, 3]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(); - let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| { - let downsample = i != self.channels.len() - 1; - EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init() - }).collect(); + let blocks = self + .channels + .iter() + .enumerate() + .map(|(i, &(n_channel_in, n_channel_out))| { + let downsample = i != self.channels.len() - 1; + EncoderBlockConfig::new(n_channel_in, n_channel_out, downsample).init() + }) + .collect(); let mid = MidConfig::new(n_expanded_channels_final).init(); let norm_out = GroupNormConfig::new(self.n_group, n_expanded_channels_final).init(); let silu = SILU::new(); - let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); + let conv_out = Conv2dConfig::new([n_expanded_channels_final, self.n_channels_out], [3, 3]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(); Encoder { - conv_in, - mid, - blocks, - norm_out, - silu, - conv_out, + conv_in, + mid, + blocks, + norm_out, + silu, + conv_out, } } } - #[derive(Module, Debug)] pub struct Encoder { - conv_in: Conv2d, - mid: Mid, - blocks: Vec>, - norm_out: GroupNorm, - silu: SILU, - conv_out: Conv2d, + conv_in: Conv2d, + mid: Mid, + blocks: Vec>, + norm_out: GroupNorm, + silu: SILU, + conv_out: Conv2d, } impl Encoder { @@ -126,55 +138,66 @@ impl Encoder { } let x = self.mid.forward(x); - self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) ) + self.conv_out + .forward(self.silu.forward(self.norm_out.forward(x))) } } - - #[derive(Config)] pub struct DecoderConfig { - channels: Vec<(usize, usize)>, - n_group: usize, + channels: Vec<(usize, usize)>, + n_group: usize, } impl DecoderConfig { fn init(&self) -> Decoder { - let n_expanded_channels = self.channels.first().map(|f| f.0).expect("Channels must not be empty."); + let n_expanded_channels = self + .channels + .first() + .map(|f| f.0) + .expect("Channels must not be empty."); let n_condensed_channels = self.channels.last().unwrap().1; - let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); + let conv_in = Conv2dConfig::new([4, n_expanded_channels], [3, 3]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(); let mid = MidConfig::new(n_expanded_channels).init(); - let blocks = self.channels.iter().enumerate().map(|(i, &(n_channel_in, n_channel_out))| { - let upsample = i != self.channels.len() - 1; - DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init() - }).collect(); + let blocks = self + .channels + .iter() + .enumerate() + .map(|(i, &(n_channel_in, n_channel_out))| { + let upsample = i != self.channels.len() - 1; + DecoderBlockConfig::new(n_channel_in, n_channel_out, upsample).init() + }) + .collect(); let norm_out = GroupNormConfig::new(self.n_group, n_condensed_channels).init(); let silu = SILU::new(); - let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); + let conv_out = Conv2dConfig::new([n_condensed_channels, 3], [3, 3]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(); Decoder { - conv_in, - mid, - blocks, - norm_out, - silu, - conv_out, + conv_in, + mid, + blocks, + norm_out, + silu, + conv_out, } } } - #[derive(Module, Debug)] pub struct Decoder { - conv_in: Conv2d, - mid: Mid, - blocks: Vec>, - norm_out: GroupNorm, - silu: SILU, - conv_out: Conv2d, + conv_in: Conv2d, + mid: Mid, + blocks: Vec>, + norm_out: GroupNorm, + silu: SILU, + conv_out: Conv2d, } impl Decoder { @@ -187,15 +210,16 @@ impl Decoder { x = block.forward(x); } - self.conv_out.forward( self.silu.forward( self.norm_out.forward(x) ) ) + self.conv_out + .forward(self.silu.forward(self.norm_out.forward(x))) } } #[derive(Config)] pub struct EncoderBlockConfig { - n_channels_in: usize, - n_channels_out: usize, - downsample: bool, + n_channels_in: usize, + n_channels_out: usize, + downsample: bool, } impl EncoderBlockConfig { @@ -204,24 +228,28 @@ impl EncoderBlockConfig { let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(); let downsampler = if self.downsample { let padding = Padding::new(0, 1, 0, 1); - Some( PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding).with_stride(2).init() ) + Some( + PaddedConv2dConfig::new([self.n_channels_out, self.n_channels_out], 3, padding) + .with_stride(2) + .init(), + ) } else { None }; EncoderBlock { - res1, - res2, - downsampler, + res1, + res2, + downsampler, } } } #[derive(Module, Debug)] pub struct EncoderBlock { - res1: ResnetBlock, - res2: ResnetBlock, - downsampler: Option>, + res1: ResnetBlock, + res2: ResnetBlock, + downsampler: Option>, } impl EncoderBlock { @@ -238,9 +266,9 @@ impl EncoderBlock { #[derive(Config)] pub struct DecoderBlockConfig { - n_channels_in: usize, - n_channels_out: usize, - upsample: bool, + n_channels_in: usize, + n_channels_out: usize, + upsample: bool, } impl DecoderBlockConfig { @@ -249,26 +277,30 @@ impl DecoderBlockConfig { let res2 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(); let res3 = ResnetBlockConfig::new(self.n_channels_out, self.n_channels_out).init(); let upsampler = if self.upsample { - Some( Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init() ) + Some( + Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(), + ) } else { None }; DecoderBlock { - res1, - res2, - res3, - upsampler, + res1, + res2, + res3, + upsampler, } } } #[derive(Module, Debug)] pub struct DecoderBlock { - res1: ResnetBlock, - res2: ResnetBlock, - res3: ResnetBlock, - upsampler: Option>, + res1: ResnetBlock, + res2: ResnetBlock, + res3: ResnetBlock, + upsampler: Option>, } impl DecoderBlock { @@ -280,10 +312,10 @@ impl DecoderBlock { if let Some(d) = self.upsampler.as_ref() { let [n_batch, n_channel, height, width] = x.dims(); let x = x - .reshape([n_batch, n_channel, height, 1, width, 1]) - .repeat(3, 2) - .repeat(5, 2) - .reshape([n_batch, n_channel, 2 * height, 2 * width]); + .reshape([n_batch, n_channel, height, 1, width, 1]) + .repeat(3, 2) + .repeat(5, 2) + .reshape([n_batch, n_channel, 2 * height, 2 * width]); d.forward(x) } else { x @@ -291,14 +323,13 @@ impl DecoderBlock { } } - #[derive(Config)] pub struct PaddedConv2dConfig { - channels: [usize; 2], - kernel_size: usize, + channels: [usize; 2], + kernel_size: usize, #[config(default = 1)] - stride: usize, - padding: Padding, + stride: usize, + padding: Padding, } impl PaddedConv2dConfig { @@ -328,57 +359,68 @@ impl PaddedConv2dConfig { let padding = self.padding; PaddedConv2d { - conv, - kernel_size, - stride, - padding, - padding_actual, + conv, + kernel_size, + stride, + padding, + padding_actual, } } } +fn div_roundup(x: usize, y: usize) -> usize { + (x + y - 1) / y +} + #[derive(Module, Debug)] pub struct PaddedConv2d { - conv: Conv2d, - kernel_size: usize, - stride: usize, - padding: Padding, - padding_actual: [usize; 2], + conv: Conv2d, + kernel_size: usize, + stride: usize, + padding: Padding, + padding_actual: [usize; 2], } impl PaddedConv2d { fn forward(&self, x: Tensor) -> Tensor { - println!("{} {} {:?} {:?}", self.kernel_size, self.stride, self.padding, self.padding_actual); + println!( + "{} {} {:?} {:?}", + self.kernel_size, self.stride, self.padding, self.padding_actual + ); let [n_batch, n_channel, height, width] = x.dims(); - let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height - self.kernel_size) / self.stride + 1; - let desired_width = (self.padding.pad_left + self.padding.pad_right + width - self.kernel_size) / self.stride + 1; + let desired_height = (self.padding.pad_top + self.padding.pad_bottom + height + - self.kernel_size) + / self.stride + + 1; + let desired_width = (self.padding.pad_left + self.padding.pad_right + width + - self.kernel_size) + / self.stride + + 1; let skip_vert = (self.padding_actual[0] - self.padding.pad_top) / self.stride; let skip_hor = (self.padding_actual[1] - self.padding.pad_left) / self.stride; - self.conv - .forward(x) - .slice([ - 0..n_batch, - 0..n_channel, - skip_vert..(skip_vert + desired_height), - skip_hor..(skip_hor + desired_width) - ]) + self.conv.forward(x).slice([ + 0..n_batch, + 0..n_channel, + skip_vert..(skip_vert + desired_height), + skip_hor..(skip_hor + desired_width), + ]) } } #[derive(Config, Module, Copy, Debug)] pub struct Padding { - pad_left: usize, - pad_right: usize, - pad_top: usize, + pad_left: usize, + pad_right: usize, + pad_top: usize, pad_bottom: usize, } #[derive(Config)] pub struct MidConfig { - n_channel: usize, + n_channel: usize, } impl MidConfig { @@ -388,18 +430,18 @@ impl MidConfig { let block_2 = ResnetBlockConfig::new(self.n_channel, self.n_channel).init(); Mid { - block_1, - attn, - block_2, + block_1, + attn, + block_2, } } } #[derive(Module, Debug)] pub struct Mid { - block_1: ResnetBlock, - attn: ConvSelfAttentionBlock, - block_2: ResnetBlock, + block_1: ResnetBlock, + attn: ConvSelfAttentionBlock, + block_2: ResnetBlock, } impl Mid { @@ -411,21 +453,24 @@ impl Mid { } } - #[derive(Config)] pub struct ResnetBlockConfig { - in_channels: usize, - out_channels: usize, + in_channels: usize, + out_channels: usize, } impl ResnetBlockConfig { fn init(&self) -> ResnetBlock { let norm1 = GroupNormConfig::new(32, self.in_channels).init(); - let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); + let conv1 = Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(); let norm2 = GroupNormConfig::new(32, self.out_channels).init(); - let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); + let conv2 = Conv2dConfig::new([self.out_channels, self.out_channels], [3, 3]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(); let nin_shortcut = if self.in_channels != self.out_channels { - Some( Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init() ) + Some(Conv2dConfig::new([self.in_channels, self.out_channels], [1, 1]).init()) } else { None }; @@ -434,34 +479,37 @@ impl ResnetBlockConfig { let silu2 = SILU::new(); ResnetBlock { - norm1, - silu1, - conv1, - norm2, - silu2, - conv2, - nin_shortcut, + norm1, + silu1, + conv1, + norm2, + silu2, + conv2, + nin_shortcut, } } } #[derive(Module, Debug)] pub struct ResnetBlock { - norm1: GroupNorm, - silu1: SILU, - conv1: Conv2d, - norm2: GroupNorm, - silu2: SILU, - conv2: Conv2d, - nin_shortcut: Option>, + norm1: GroupNorm, + silu1: SILU, + conv1: Conv2d, + norm2: GroupNorm, + silu2: SILU, + conv2: Conv2d, + nin_shortcut: Option>, } impl ResnetBlock { fn forward(&self, x: Tensor) -> Tensor { - let h = self.conv1.forward( self.silu1.forward(self.norm1.forward(x.clone())) ); - let h = self.conv2.forward( self.silu2.forward(self.norm2.forward(h)) ); + let h = self + .conv1 + .forward(self.silu1.forward(self.norm1.forward(x.clone()))); + let h = self + .conv2 + .forward(self.silu2.forward(self.norm2.forward(h))); - if let Some(ns) = self.nin_shortcut.as_ref() { ns.forward(x) + h } else { @@ -472,7 +520,7 @@ impl ResnetBlock { #[derive(Config)] pub struct ConvSelfAttentionBlockConfig { - n_channel: usize, + n_channel: usize, } impl ConvSelfAttentionBlockConfig { @@ -484,22 +532,22 @@ impl ConvSelfAttentionBlockConfig { let proj_out = Conv2dConfig::new([self.n_channel, self.n_channel], [1, 1]).init(); ConvSelfAttentionBlock { - norm, - q, - k, - v, - proj_out, + norm, + q, + k, + v, + proj_out, } } } #[derive(Module, Debug)] pub struct ConvSelfAttentionBlock { - norm: GroupNorm, - q: Conv2d, - k: Conv2d, - v: Conv2d, - proj_out: Conv2d, + norm: GroupNorm, + q: Conv2d, + k: Conv2d, + v: Conv2d, + proj_out: Conv2d, } impl ConvSelfAttentionBlock { @@ -508,9 +556,21 @@ impl ConvSelfAttentionBlock { let h = self.norm.forward(x.clone()); - let q = self.q.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2); - let k = self.k.forward(h.clone()).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2); - let v = self.v.forward(h).reshape([n_batch, n_channel, height * width]).swap_dims(1, 2); + let q = self + .q + .forward(h.clone()) + .reshape([n_batch, n_channel, height * width]) + .swap_dims(1, 2); + let k = self + .k + .forward(h.clone()) + .reshape([n_batch, n_channel, height * width]) + .swap_dims(1, 2); + let v = self + .v + .forward(h) + .reshape([n_batch, n_channel, height * width]) + .swap_dims(1, 2); let wv = qkv_attention(q, k, v, None, 1) .swap_dims(1, 2) diff --git a/src/model/clip/load.rs b/src/model/clip/load.rs index 5395b52..d8f20d7 100644 --- a/src/model/clip/load.rs +++ b/src/model/clip/load.rs @@ -1,14 +1,11 @@ -use std::error::Error; use burn::tensor::ElementConversion; +use std::error::Error; use burn::{ - config::Config, + config::Config, module::{Module, Param}, nn, - tensor::{ - backend::Backend, - Tensor, - }, + tensor::{backend::Backend, Tensor}, }; use super::*; @@ -28,7 +25,10 @@ pub fn load_mlp(path: &str, device: &B::Device) -> Result, Bo Ok(mlp) } -pub fn load_multi_head_self_attention(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_multi_head_self_attention( + path: &str, + device: &B::Device, +) -> Result, Box> { let n_head = load_usize::("n_head", path, device)?; let query = load_linear(&format!("{}/{}", path, "query"), device)?; let key = load_linear(&format!("{}/{}", path, "key"), device)?; @@ -46,7 +46,10 @@ pub fn load_multi_head_self_attention(path: &str, device: &B::Device Ok(mhsa) } -pub fn load_residual_decoder_attention_block(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_residual_decoder_attention_block( + path: &str, + device: &B::Device, +) -> Result, Box> { let mlp = load_mlp(&format!("{}/{}", path, "mlp"), device)?; let attn = load_multi_head_self_attention(&format!("{}/{}", path, "attn"), device)?; let attn_ln = load_layer_norm(&format!("{}/{}", path, "attn_ln"), device)?; @@ -64,15 +67,17 @@ pub fn load_residual_decoder_attention_block(path: &str, device: &B: pub fn load_clip(path: &str, device: &B::Device) -> Result, Box> { let token_embedding = load_embedding(&format!("{}/{}", path, "token_embedding"), device)?; - let position_embedding = load_tensor("weight", &format!("{}/position_embedding", path), device)?.into(); + let position_embedding = + load_tensor("weight", &format!("{}/position_embedding", path), device)?.into(); let n_layer = load_usize::("n_layer", path, device)?; let mut blocks = (0..n_layer) .into_iter() .map(|i| { load_residual_decoder_attention_block::(&format!("{}/blocks/{}", path, i), device) - }).collect::, _>>()?; - + }) + .collect::, _>>()?; + let layer_norm = load_layer_norm(&format!("{}/{}", path, "layer_norm"), device)?; let clip = CLIP { @@ -81,6 +86,6 @@ pub fn load_clip(path: &str, device: &B::Device) -> Result, blocks: blocks, layer_norm: layer_norm, }; - + Ok(clip) } diff --git a/src/model/clip/mod.rs b/src/model/clip/mod.rs index 9c6e8f9..21c29fe 100644 --- a/src/model/clip/mod.rs +++ b/src/model/clip/mod.rs @@ -1,35 +1,33 @@ pub mod load; use burn::{ - config::Config, + config::Config, module::{Module, Param}, nn, tensor::{ + activation::{sigmoid, softmax}, backend::Backend, - activation::{softmax, sigmoid}, - module::embedding, - Tensor, - Distribution, - Int, + module::embedding, + Distribution, Int, Tensor, }, }; -use crate::model::attention::{qkv_attention, attn_decoder_mask}; - +use crate::model::attention::{attn_decoder_mask, qkv_attention}; #[derive(Config)] pub struct CLIPConfig { - n_vocab: usize, - n_state: usize, - n_head: usize, - n_ctx: usize, - n_layer: usize, + n_vocab: usize, + n_state: usize, + n_head: usize, + n_ctx: usize, + n_layer: usize, } impl CLIPConfig { pub fn init(&self) -> CLIP { let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init(); - let position_embedding = Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into(); + let position_embedding = + Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into(); let blocks = (0..self.n_layer) .into_iter() .map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init()) @@ -37,33 +35,35 @@ impl CLIPConfig { let layer_norm = nn::LayerNormConfig::new(self.n_state).init(); CLIP { - token_embedding, - position_embedding, - blocks, - layer_norm, + token_embedding, + position_embedding, + blocks, + layer_norm, } } } - - #[derive(Module, Debug)] pub struct CLIP { - token_embedding: nn::Embedding, - position_embedding: Param>, - blocks: Vec>, - layer_norm: nn::LayerNorm, + token_embedding: nn::Embedding, + position_embedding: Param>, + blocks: Vec>, + layer_norm: nn::LayerNorm, } impl CLIP { pub fn forward(&self, x: Tensor) -> Tensor { let [n_batch, seq_len] = x.dims(); - + let mask = attn_decoder_mask(seq_len, &x.device()); - let embedded = self.token_embedding.forward(x) - + self.position_embedding.val().slice([0..seq_len]).unsqueeze(); - + let embedded = self.token_embedding.forward(x) + + self + .position_embedding + .val() + .slice([0..seq_len]) + .unsqueeze(); + let mut x = embedded; for block in &self.blocks { x = block.forward(x, mask.clone()); @@ -73,37 +73,35 @@ impl CLIP { } } - - #[derive(Config)] pub struct ResidualDecoderAttentionBlockConfig { - n_state: usize, - n_head: usize, + n_state: usize, + n_head: usize, } impl ResidualDecoderAttentionBlockConfig { pub fn init(&self) -> ResidualDecoderAttentionBlock { let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init(); let attn_ln = nn::LayerNormConfig::new(self.n_state).init(); - + let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init(); let mlp_ln = nn::LayerNormConfig::new(self.n_state).init(); ResidualDecoderAttentionBlock { - attn, - attn_ln, - mlp, - mlp_ln, + attn, + attn_ln, + mlp, + mlp_ln, } } } #[derive(Module, Debug)] pub struct ResidualDecoderAttentionBlock { - attn: MultiHeadSelfAttention, - attn_ln: nn::LayerNorm, - mlp: MLP, - mlp_ln: nn::LayerNorm, + attn: MultiHeadSelfAttention, + attn_ln: nn::LayerNorm, + mlp: MLP, + mlp_ln: nn::LayerNorm, } impl ResidualDecoderAttentionBlock { @@ -117,12 +115,17 @@ impl ResidualDecoderAttentionBlock { #[derive(Config)] pub struct MultiHeadSelfAttentionConfig { n_state: usize, - n_head: usize, + n_head: usize, } impl MultiHeadSelfAttentionConfig { fn init(&self) -> MultiHeadSelfAttention { - assert!(self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", self.n_state, self.n_head); + assert!( + self.n_state % self.n_head == 0, + "State size {} must be a multiple of head size {}", + self.n_state, + self.n_head + ); let n_head = self.n_head; let query = nn::LinearConfig::new(self.n_state, self.n_state).init(); @@ -130,23 +133,23 @@ impl MultiHeadSelfAttentionConfig { let value = nn::LinearConfig::new(self.n_state, self.n_state).init(); let out = nn::LinearConfig::new(self.n_state, self.n_state).init(); - MultiHeadSelfAttention { - n_head, - query, - key, - value, - out + MultiHeadSelfAttention { + n_head, + query, + key, + value, + out, } } } #[derive(Module, Debug)] pub struct MultiHeadSelfAttention { - n_head: usize, - query: nn::Linear, - key: nn::Linear, - value: nn::Linear, - out: nn::Linear, + n_head: usize, + query: nn::Linear, + key: nn::Linear, + value: nn::Linear, + out: nn::Linear, } impl MultiHeadSelfAttention { @@ -161,17 +164,10 @@ impl MultiHeadSelfAttention { } } - - - - - - - #[derive(Config, Debug)] pub struct MLPConfig { - input_size: usize, - hidden_size: usize, + input_size: usize, + hidden_size: usize, } impl MLPConfig { @@ -180,19 +176,15 @@ impl MLPConfig { let gelu = QuickGELU::new(); let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init(); - MLP { - fc1, - gelu, - fc2, - } + MLP { fc1, gelu, fc2 } } } #[derive(Module, Debug)] pub struct MLP { - fc1: nn::Linear, - gelu: QuickGELU, - fc2: nn::Linear, + fc1: nn::Linear, + gelu: QuickGELU, + fc2: nn::Linear, } impl MLP { @@ -217,4 +209,3 @@ impl QuickGELU { x.clone() * sigmoid(x * 1.702) } } - diff --git a/src/model/groupnorm/load.rs b/src/model/groupnorm/load.rs index 2c96206..a57d9e8 100644 --- a/src/model/groupnorm/load.rs +++ b/src/model/groupnorm/load.rs @@ -4,30 +4,34 @@ use crate::model::load::*; use std::error::Error; use burn::{ - config::Config, + config::Config, module::{Module, Param}, nn, - tensor::{ - backend::Backend, - Tensor, - }, + tensor::{backend::Backend, Tensor}, }; -pub fn load_group_norm(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_group_norm( + path: &str, + device: &B::Device, +) -> Result, Box> { let n_group = load_usize::("n_group", path, device)?.into(); let n_channel = load_usize::("n_channel", path, device)?.into(); let eps = load_f32::("eps", path, device)?.into(); - let gamma = load_tensor::("weight", path, device).ok().unwrap_or_else(|| Tensor::ones_device([n_channel], device)).into(); - let beta = load_tensor::("bias", path, device).ok().unwrap_or_else(|| Tensor::zeros_device([n_channel], device)).into(); + let gamma = load_tensor::("weight", path, device) + .ok() + .unwrap_or_else(|| Tensor::ones_device([n_channel], device)) + .into(); + let beta = load_tensor::("bias", path, device) + .ok() + .unwrap_or_else(|| Tensor::zeros_device([n_channel], device)) + .into(); - Ok( - GroupNorm { - n_group, - n_channel, - gamma, - beta, - eps, - } - ) -} \ No newline at end of file + Ok(GroupNorm { + n_group, + n_channel, + gamma, + beta, + eps, + }) +} diff --git a/src/model/groupnorm/mod.rs b/src/model/groupnorm/mod.rs index 7dc125d..8d2947d 100644 --- a/src/model/groupnorm/mod.rs +++ b/src/model/groupnorm/mod.rs @@ -1,25 +1,27 @@ pub mod load; use burn::{ - config::Config, + config::Config, module::{Module, Param}, - tensor::{ - backend::Backend, - Tensor, - }, + tensor::{backend::Backend, Tensor}, }; #[derive(Config)] pub struct GroupNormConfig { - n_group: usize, - n_channel: usize, + n_group: usize, + n_channel: usize, #[config(default = 1e-5)] - eps: f64, + eps: f64, } impl GroupNormConfig { pub fn init(&self) -> GroupNorm { - assert!(self.n_channel % self.n_group == 0, "The number of channels {} must be divisible by the number of groups {}", self.n_channel, self.n_group); + assert!( + self.n_channel % self.n_group == 0, + "The number of channels {} must be divisible by the number of groups {}", + self.n_channel, + self.n_group + ); let n_per_group = self.n_channel / self.n_group; @@ -29,22 +31,22 @@ impl GroupNormConfig { let eps = self.eps; GroupNorm { - n_group: self.n_group, - n_channel: self.n_channel, - gamma, - beta, - eps, + n_group: self.n_group, + n_channel: self.n_channel, + gamma, + beta, + eps, } } } #[derive(Module, Debug)] pub struct GroupNorm { - n_group: usize, - n_channel: usize, - gamma: Param>, - beta: Param>, - eps: f64, + n_group: usize, + n_channel: usize, + gamma: Param>, + beta: Param>, + eps: f64, } impl GroupNorm { @@ -56,10 +58,17 @@ impl GroupNorm { let mut affine_shape = [1; D]; affine_shape[1] = self.n_channel; - layernorm( x.reshape([n_batch, self.n_group, num_elements / (n_batch * self.n_group) ]), self.eps ) - .reshape(shape) - .mul(self.gamma.val().reshape(affine_shape)) - .add(self.beta.val().reshape(affine_shape)) + layernorm( + x.reshape([ + n_batch, + self.n_group, + num_elements / (n_batch * self.n_group), + ]), + self.eps, + ) + .reshape(shape) + .mul(self.gamma.val().reshape(affine_shape)) + .add(self.beta.val().reshape(affine_shape)) } } @@ -68,5 +77,6 @@ pub fn layernorm(x: Tensor, eps: f64) -> Tenso //x.sub(mean).div(var.sqrt().add_scalar(eps)) let u = x.clone() - x.mean_dim(D - 1); - u.clone().div( (u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt() ) -} \ No newline at end of file + u.clone() + .div((u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt()) +} diff --git a/src/model/load.rs b/src/model/load.rs index 7009939..41ad96e 100644 --- a/src/model/load.rs +++ b/src/model/load.rs @@ -1,36 +1,38 @@ -use std::error::Error; -use std::io::Read; use npy::{self, NpyData}; use num_traits::cast::ToPrimitive; +use std::error::Error; +use std::io::Read; use burn::{ - config::Config, + config::Config, module::{Module, Param}, nn::{self, conv}, - tensor::{ - backend::Backend, - Tensor, - Data, - }, + tensor::{backend::Backend, Data, Tensor}, }; use burn::tensor::ElementConversion; -pub fn numpy_to_tensor(numpy_data: NpyData, device: &B::Device) -> Tensor { +pub fn numpy_to_tensor( + numpy_data: NpyData, + device: &B::Device, +) -> Tensor { let mut v = numpy_data.to_vec(); let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect(); let data: Vec = v[D..].into_iter().map(|e| e.elem()).collect(); - + Tensor::from_data_device(Data::new(data, shape.into()), device) } -pub fn load_tensor(name: &str, path: &str, device: &B::Device) -> Result, Box> { +pub fn load_tensor( + name: &str, + path: &str, + device: &B::Device, +) -> Result, Box> { let tensor_path = format!("{}/{}.npy", path, name); let mut buf = vec![]; - std::fs::File::open(&tensor_path)? - .read_to_end(&mut buf)?; + std::fs::File::open(&tensor_path)?.read_to_end(&mut buf)?; let tensor_numpy: NpyData = NpyData::from_bytes(&buf)?; @@ -41,15 +43,26 @@ pub fn load_tensor(name: &str, path: &str, device: & Ok(tensor) } -pub fn load_f32(name: &str, path: &str, device: &B::Device) -> Result> { +pub fn load_f32( + name: &str, + path: &str, + device: &B::Device, +) -> Result> { load_tensor::(name, path, device).map(|t| t.into_scalar().to_f32().unwrap()) } -pub fn load_usize(name: &str, path: &str, device: &B::Device) -> Result> { +pub fn load_usize( + name: &str, + path: &str, + device: &B::Device, +) -> Result> { load_tensor::(name, path, device).map(|t| t.into_scalar().to_usize().unwrap()) } -pub fn load_linear(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_linear( + path: &str, + device: &B::Device, +) -> Result, Box> { let weight = load_tensor::("weight", path, device)?; let bias = load_tensor::("bias", path, device).ok(); @@ -62,7 +75,10 @@ pub fn load_linear(path: &str, device: &B::Device) -> Result(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_embedding( + path: &str, + device: &B::Device, +) -> Result, Box> { let weight = load_tensor::("weight", path, device)?; let [n_vocab, n_state] = weight.dims(); @@ -74,7 +90,10 @@ pub fn load_embedding(path: &str, device: &B::Device) -> Result(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_layer_norm( + path: &str, + device: &B::Device, +) -> Result, Box> { let weight = load_tensor::("weight", path, device)?; let bias = load_tensor::("bias", path, device)?; let eps = load_f32::("eps", path, device)? as f64; @@ -84,7 +103,7 @@ pub fn load_layer_norm(path: &str, device: &B::Device) -> Result>::into_record(eps), + epsilon: >::into_record(eps), }; let layer_norm: nn::LayerNorm = nn::LayerNormConfig::new(n_state).init_with(record); @@ -92,20 +111,22 @@ pub fn load_layer_norm(path: &str, device: &B::Device) -> Result(path: &str, device: &B::Device) -> Result, Box> { let weight = load_tensor::("weight", path, device)?; let eps = load_f32::("eps", path, device)?.into(); - let rmsnorm = RMSNorm { - weight: weight.into(), + let rmsnorm = RMSNorm { + weight: weight.into(), eps: eps }; - + Ok(rmsnorm) }*/ -pub fn load_conv2d(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_conv2d( + path: &str, + device: &B::Device, +) -> Result, Box> { let weight = load_tensor::("weight", path, device)?; let bias = load_tensor::("bias", path, device).ok(); let has_bias = bias.is_some(); @@ -127,24 +148,24 @@ pub fn load_conv2d(path: &str, device: &B::Device) -> Result>::into_record(stride), - kernel_size: <[usize; 2] as Module>::into_record(kernel_size), - dilation: <[usize; 2] as Module>::into_record(dilation), + stride: <[usize; 2] as Module>::into_record(stride), + kernel_size: <[usize; 2] as Module>::into_record(kernel_size), + dilation: <[usize; 2] as Module>::into_record(dilation), groups: >::into_record(n_group), - padding: >::into_record(padding.clone()), + padding: >::into_record(padding.clone()), }; - let conv2d: conv::Conv2d = conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size) - .with_stride(stride) - .with_dilation(dilation) - .with_groups(n_group) - .with_padding(padding) - .with_bias(has_bias) - .init_with(record); + let conv2d: conv::Conv2d = + conv::Conv2dConfig::new([n_channels_in, n_channels_out], kernel_size) + .with_stride(stride) + .with_dilation(dilation) + .with_groups(n_group) + .with_padding(padding) + .with_bias(has_bias) + .init_with(record); Ok(conv2d) } @@ -164,4 +185,4 @@ pub fn tensor_to_array(x: Tensor) -> [usize; N } arr -} \ No newline at end of file +} diff --git a/src/model/mod.rs b/src/model/mod.rs index 832222e..b89f764 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -1,11 +1,11 @@ pub mod stablediffusion; pub mod autoencoder; -pub mod unet; pub mod clip; +pub mod unet; -pub mod silu; -pub mod groupnorm; pub mod attention; +pub mod groupnorm; +pub mod silu; -pub mod load; \ No newline at end of file +pub mod load; diff --git a/src/model/silu.rs b/src/model/silu.rs index 47da5a8..5766349 100644 --- a/src/model/silu.rs +++ b/src/model/silu.rs @@ -1,13 +1,8 @@ use burn::{ module::Module, - tensor::{ - backend::Backend, - activation::sigmoid, - Tensor, - }, + tensor::{activation::sigmoid, backend::Backend, Tensor}, }; - #[derive(Module, Clone, Debug)] pub struct SILU {} @@ -19,4 +14,4 @@ impl SILU { pub fn forward(&self, x: Tensor) -> Tensor { x.clone() * sigmoid(x) } -} \ No newline at end of file +} diff --git a/src/model/stablediffusion/load.rs b/src/model/stablediffusion/load.rs index 405ca8c..de31fab 100644 --- a/src/model/stablediffusion/load.rs +++ b/src/model/stablediffusion/load.rs @@ -1,20 +1,22 @@ -use std::error::Error; use burn::tensor::ElementConversion; +use std::error::Error; use burn::{ - config::Config, + config::Config, module::{Module, Param}, nn, - tensor::{ - backend::Backend, - Tensor, - }, + tensor::{backend::Backend, Tensor}, }; use super::*; -use crate::model::{load::*, autoencoder::load::load_autoencoder, unet::load::load_unet, clip::load::load_clip}; +use crate::model::{ + autoencoder::load::load_autoencoder, clip::load::load_clip, load::*, unet::load::load_unet, +}; -pub fn load_stable_diffusion(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_stable_diffusion( + path: &str, + device: &B::Device, +) -> Result, Box> { let n_steps = load_usize::("n_steps", path, device)?; let alpha_cumulative_products = load_tensor::("alphas_cumprod", path, device)?.into(); let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?; @@ -22,11 +24,10 @@ pub fn load_stable_diffusion(path: &str, device: &B::Device) -> Resu let clip = load_clip(&format!("{}/{}", path, "clip"), device)?; Ok(StableDiffusion { - n_steps, - alpha_cumulative_products, - autoencoder, - diffusion, - clip, + n_steps, + alpha_cumulative_products, + autoencoder, + diffusion, + clip, }) } - diff --git a/src/model/stablediffusion/mod.rs b/src/model/stablediffusion/mod.rs index fb3dc0a..3c9ffb3 100644 --- a/src/model/stablediffusion/mod.rs +++ b/src/model/stablediffusion/mod.rs @@ -1,30 +1,20 @@ pub mod load; use burn::{ - config::Config, + config::Config, module::{Module, Param}, - tensor::{ - backend::Backend, - Tensor, - Int, - Float, - BasicOps, - Data, - Distribution, - }, + tensor::{backend::Backend, BasicOps, Data, Distribution, Float, Int, Tensor}, }; use num_traits::ToPrimitive; use super::autoencoder::{Autoencoder, AutoencoderConfig}; +use super::clip::{CLIPConfig, CLIP}; use super::unet::{UNet, UNetConfig}; -use super::clip::{CLIP, CLIPConfig}; use crate::tokenizer::SimpleTokenizer; #[derive(Config)] -pub struct StableDiffusionConfig { - -} +pub struct StableDiffusionConfig {} impl StableDiffusionConfig { pub fn init(&self) -> StableDiffusion { @@ -36,29 +26,40 @@ impl StableDiffusionConfig { let clip = CLIPConfig::new(49408, 768, 12, 77, 12).init(); StableDiffusion { - n_steps, - alpha_cumulative_products, - autoencoder, - diffusion, - clip, + n_steps, + alpha_cumulative_products, + autoencoder, + diffusion, + clip, } } } #[derive(Module, Debug)] pub struct StableDiffusion { - n_steps: usize, - alpha_cumulative_products: Param>, - autoencoder: Autoencoder, - diffusion: UNet, - clip: CLIP, + n_steps: usize, + alpha_cumulative_products: Param>, + autoencoder: Autoencoder, + diffusion: UNet, + clip: CLIP, } impl StableDiffusion { - pub fn sample_image(&self, context: Tensor, unconditional_context: Tensor, unconditional_guidance_scale: f64, n_steps: usize) -> Vec> { + pub fn sample_image( + &self, + context: Tensor, + unconditional_context: Tensor, + unconditional_guidance_scale: f64, + n_steps: usize, + ) -> Vec> { let [n_batch, _, _] = context.dims(); - let latent = self.sample_latent(context, unconditional_context, unconditional_guidance_scale, n_steps); + let latent = self.sample_latent( + context, + unconditional_context, + unconditional_guidance_scale, + n_steps, + ); self.latent_to_image(latent) } @@ -71,7 +72,7 @@ impl StableDiffusion { let width = 512; let num_elements_per_image = n_channel * height * width; - // correct size and scale and reorder to + // correct size and scale and reorder to let image = (image + 1.0) / 2.0; let image = image .reshape([n_batch, n_channel, height, width]) @@ -79,19 +80,29 @@ impl StableDiffusion { .swap_dims(2, 3) .mul_scalar(255.0); - let flattened: Vec<_> = image. - into_data(). - value; + let flattened: Vec<_> = image.into_data().value; - (0..n_batch).into_iter().map(|b| { - let start = b * num_elements_per_image; - let end = start + num_elements_per_image; + (0..n_batch) + .into_iter() + .map(|b| { + let start = b * num_elements_per_image; + let end = start + num_elements_per_image; - flattened[start..end].into_iter().map(|v| v.to_f64().unwrap().min(255.0).max(0.0).to_u8().unwrap()).collect() - }).collect() + flattened[start..end] + .into_iter() + .map(|v| v.to_f64().unwrap().min(255.0).max(0.0).to_u8().unwrap()) + .collect() + }) + .collect() } - pub fn sample_latent(&self, context: Tensor, unconditional_context: Tensor, unconditional_guidance_scale: f64, n_steps: usize) -> Tensor { + pub fn sample_latent( + &self, + context: Tensor, + unconditional_context: Tensor, + unconditional_guidance_scale: f64, + n_steps: usize, + ) -> Tensor { let device = context.device(); let step_size = self.n_steps / n_steps; @@ -99,7 +110,8 @@ impl StableDiffusion { let [n_batches, _, _] = context.dims(); let gen_noise = || { - Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0)).to_device(&device) + Tensor::random([n_batches, 4, 64, 64], Distribution::Normal(0.0, 1.0)) + .to_device(&device) }; let sigma = 0.0; // Use deterministic diffusion @@ -107,10 +119,21 @@ impl StableDiffusion { let mut latent = gen_noise(); for t in (0..self.n_steps).rev().step_by(step_size) { - let current_alpha: f64 = self.alpha_cumulative_products.val().slice([t..t + 1]).into_scalar().to_f64().unwrap(); + let current_alpha: f64 = self + .alpha_cumulative_products + .val() + .slice([t..t + 1]) + .into_scalar() + .to_f64() + .unwrap(); let prev_alpha: f64 = if t >= step_size { let i = t - step_size; - self.alpha_cumulative_products.val().slice([i..i + 1]).into_scalar().to_f64().unwrap() + self.alpha_cumulative_products + .val() + .slice([i..i + 1]) + .into_scalar() + .to_f64() + .unwrap() } else { 1.0 }; @@ -118,7 +141,13 @@ impl StableDiffusion { let sqrt_noise = (1.0 - current_alpha).sqrt(); let timestep = Tensor::from_ints([t as i32]).to_device(&device); - let pred_noise = self.forward_diffuser(latent.clone(), timestep, context.clone(), unconditional_context.clone(), unconditional_guidance_scale); + let pred_noise = self.forward_diffuser( + latent.clone(), + timestep, + context.clone(), + unconditional_context.clone(), + unconditional_guidance_scale, + ); let predx0 = (latent - pred_noise.clone() * sqrt_noise) / current_alpha.sqrt(); let dir_latent = pred_noise * (1.0 - prev_alpha - sigma * sigma).sqrt(); @@ -129,32 +158,36 @@ impl StableDiffusion { latent } - fn forward_diffuser(&self, latent: Tensor, timestep: Tensor, context: Tensor, unconditional_context: Tensor, unconditional_guidance_scale: f64) -> Tensor { + fn forward_diffuser( + &self, + latent: Tensor, + timestep: Tensor, + context: Tensor, + unconditional_context: Tensor, + unconditional_guidance_scale: f64, + ) -> Tensor { let [n_batch, _, _, _] = latent.dims(); //let latent = latent.repeat(0, 2); let unconditional_latent = self.diffusion.forward( - latent.clone(), - timestep.clone(), - unconditional_context.unsqueeze().repeat(0, n_batch) + latent.clone(), + timestep.clone(), + unconditional_context.unsqueeze().repeat(0, n_batch), ); - let conditional_latent = self.diffusion.forward( - latent, - timestep, - context - ); + let conditional_latent = self.diffusion.forward(latent, timestep, context); /*let latent = self.diffusion.forward( - latent.repeat(0, 2), - timestep.repeat(0, 2), + latent.repeat(0, 2), + timestep.repeat(0, 2), Tensor::cat(vec![unconditional_context.unsqueeze::<3>(), context], 0) ); let unconditional_latent = latent.clone().slice([0..n_batch]); let conditional_latent = latent.slice([n_batch..2 * n_batch]);*/ - unconditional_latent.clone() + (conditional_latent - unconditional_latent) * unconditional_guidance_scale + unconditional_latent.clone() + + (conditional_latent - unconditional_latent) * unconditional_guidance_scale } pub fn unconditional_context(&self, tokenizer: &SimpleTokenizer) -> Tensor { @@ -164,17 +197,25 @@ impl StableDiffusion { pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor { let device = &self.clip.devices()[0]; let text = format!("<|startoftext|>{}<|endoftext|>", text); - let tokenized: Vec<_> = tokenizer.encode(&text).into_iter().map(|v| v as i32).collect(); + let tokenized: Vec<_> = tokenizer + .encode(&text) + .into_iter() + .map(|v| v as i32) + .collect(); - self.clip.forward(Tensor::from_ints(&tokenized[..]).to_device(device).unsqueeze()) + self.clip.forward( + Tensor::from_ints(&tokenized[..]) + .to_device(device) + .unsqueeze(), + ) } } -use crate::helper::to_float; use std::f64::consts::PI; fn cosine_schedule(n_steps: usize) -> Tensor { - to_float(Tensor::arange(1..n_steps + 1)) + Tensor::arange(1..n_steps + 1) + .float() .mul_scalar(PI * 0.5 / n_steps as f64) .cos() } @@ -185,12 +226,12 @@ fn offset_cosine_schedule(n_steps: usize) -> Tensor { let start_angle = max_signal_rate.acos(); let end_angle = min_signal_rate.acos(); - let times = Tensor::arange(1..n_steps + 1); + let times = Tensor::arange(1..n_steps + 1).float(); - let diffusion_angles = to_float(times) * ( (end_angle - start_angle) / n_steps as f64) + start_angle; + let diffusion_angles = times * ((end_angle - start_angle) / n_steps as f64) + start_angle; diffusion_angles.cos() } fn offset_cosine_schedule_cumprod(n_steps: usize) -> Tensor { offset_cosine_schedule::(n_steps).powf(2.0) -} \ No newline at end of file +} diff --git a/src/model/unet/load.rs b/src/model/unet/load.rs index 7fe0ab0..b2821dd 100644 --- a/src/model/unet/load.rs +++ b/src/model/unet/load.rs @@ -4,19 +4,19 @@ use crate::model::load::*; use std::error::Error; use burn::{ - config::Config, + config::Config, module::{Module, Param}, nn, - tensor::{ - backend::Backend, - Tensor, - }, + tensor::{backend::Backend, Tensor}, }; use super::*; use crate::model::groupnorm::load::load_group_norm; -pub fn load_res_block(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_res_block( + path: &str, + device: &B::Device, +) -> Result, Box> { let norm_in = load_group_norm::(&format!("{}/{}", path, "norm_in"), device)?; let conv_in = load_conv2d::(&format!("{}/{}", path, "conv_in"), device)?; let lin_embed = load_linear::(&format!("{}/{}", path, "lin_embed"), device)?; @@ -26,12 +26,12 @@ pub fn load_res_block(path: &str, device: &B::Device) -> Result(path: &str, device: &B::Device) -> Result(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_multi_head_attention( + path: &str, + device: &B::Device, +) -> Result, Box> { let n_head = load_usize::("n_head", path, device)?; let query = load_linear::(&format!("{}/{}", path, "query"), device)?; let key = load_linear::(&format!("{}/{}", path, "key"), device)?; @@ -53,11 +56,10 @@ pub fn load_multi_head_attention(path: &str, device: &B::Device) -> value: value, out: out, }; - + Ok(multi_head_attention) } - pub fn load_geglu(path: &str, device: &B::Device) -> Result, Box> { let proj = load_linear::(&format!("{}/{}", path, "proj"), device)?; @@ -65,11 +67,10 @@ pub fn load_geglu(path: &str, device: &B::Device) -> Result proj: proj, gelu: GELU::new(), // Assuming GELU::new() initializes a new GELU struct }; - + Ok(geglue) } - pub fn load_mlp(path: &str, device: &B::Device) -> Result, Box> { let geglu = load_geglu::(&format!("{}/{}", path, "geglu"), device)?; let lin = load_linear::(&format!("{}/{}", path, "lin"), device)?; @@ -78,12 +79,14 @@ pub fn load_mlp(path: &str, device: &B::Device) -> Result, Bo geglu: geglu, lin: lin, }; - + Ok(mlp) } - -pub fn load_transformer_block(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_transformer_block( + path: &str, + device: &B::Device, +) -> Result, Box> { let norm1 = load_layer_norm::(&format!("{}/{}", path, "norm1"), device)?; let attn1 = load_multi_head_attention::(&format!("{}/{}", path, "attn1"), device)?; let norm2 = load_layer_norm::(&format!("{}/{}", path, "norm2"), device)?; @@ -99,12 +102,14 @@ pub fn load_transformer_block(path: &str, device: &B::Device) -> Res norm3: norm3, mlp: mlp, }; - + Ok(transformer_block) } - -pub fn load_spatial_transformer(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_spatial_transformer( + path: &str, + device: &B::Device, +) -> Result, Box> { let norm = load_group_norm::(&format!("{}/{}", path, "norm"), device)?; let proj_in = load_conv2d::(&format!("{}/{}", path, "proj_in"), device)?; let transformer = load_transformer_block::(&format!("{}/{}", path, "transformer"), device)?; @@ -116,28 +121,35 @@ pub fn load_spatial_transformer(path: &str, device: &B::Device) -> R transformer: transformer, proj_out: proj_out, }; - + Ok(spatial_transformer) } - -pub fn load_upsample(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_upsample( + path: &str, + device: &B::Device, +) -> Result, Box> { let conv = load_conv2d::(&format!("{}/{}", path, "conv"), device)?; - let upsample = Upsample { - conv: conv, - }; - + let upsample = Upsample { conv: conv }; + Ok(upsample) } -pub fn load_downsample(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_downsample( + path: &str, + device: &B::Device, +) -> Result, Box> { load_conv2d(path, device) } -pub fn load_res_transformer_res(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_res_transformer_res( + path: &str, + device: &B::Device, +) -> Result, Box> { let res1 = load_res_block::(&format!("{}/{}", path, "res1"), device)?; // Assuming load_res_block function - let transformer = load_spatial_transformer::(&format!("{}/{}", path, "transformer"), device)?; + let transformer = + load_spatial_transformer::(&format!("{}/{}", path, "transformer"), device)?; let res2 = load_res_block::(&format!("{}/{}", path, "res2"), device)?; let res_transformer_res = ResTransformerRes { @@ -145,13 +157,17 @@ pub fn load_res_transformer_res(path: &str, device: &B::Device) -> R transformer: transformer, res2: res2, }; - + Ok(res_transformer_res) } -pub fn load_res_transformer_upsample(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_res_transformer_upsample( + path: &str, + device: &B::Device, +) -> Result, Box> { let res = load_res_block::(&format!("{}/{}", path, "res"), device)?; - let transformer = load_spatial_transformer::(&format!("{}/{}", path, "transformer"), device)?; + let transformer = + load_spatial_transformer::(&format!("{}/{}", path, "transformer"), device)?; let upsample = load_upsample::(&format!("{}/{}", path, "upsample"), device)?; let res_transformer_upsample = ResTransformerUpsample { @@ -159,12 +175,14 @@ pub fn load_res_transformer_upsample(path: &str, device: &B::Device) transformer: transformer, upsample: upsample, }; - + Ok(res_transformer_upsample) } - -pub fn load_res_upsample(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_res_upsample( + path: &str, + device: &B::Device, +) -> Result, Box> { let res = load_res_block::(&format!("{}/{}", path, "res"), device)?; let upsample = load_upsample::(&format!("{}/{}", path, "upsample"), device)?; @@ -172,25 +190,30 @@ pub fn load_res_upsample(path: &str, device: &B::Device) -> Result(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_res_transformer( + path: &str, + device: &B::Device, +) -> Result, Box> { let res = load_res_block::(&format!("{}/{}", path, "res"), device)?; - let transformer = load_spatial_transformer::(&format!("{}/{}", path, "transformer"), device)?; + let transformer = + load_spatial_transformer::(&format!("{}/{}", path, "transformer"), device)?; let res_transformer = ResTransformer { res: res, transformer: transformer, }; - + Ok(res_transformer) } - -pub fn load_unet_input_blocks(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_unet_input_blocks( + path: &str, + device: &B::Device, +) -> Result, Box> { let conv = load_conv2d::(&format!("{}/{}", path, "conv"), device)?; let rt1 = load_res_transformer::(&format!("{}/{}", path, "rt1"), device)?; let rt2 = load_res_transformer::(&format!("{}/{}", path, "rt2"), device)?; @@ -218,11 +241,14 @@ pub fn load_unet_input_blocks(path: &str, device: &B::Device) -> Res r1: r1, r2: r2, }; - + Ok(unet_input_blocks) } -pub fn load_unet_output_blocks(path: &str, device: &B::Device) -> Result, Box> { +pub fn load_unet_output_blocks( + path: &str, + device: &B::Device, +) -> Result, Box> { let r1 = load_res_block::(&format!("{}/{}", path, "r1"), device)?; let r2 = load_res_block::(&format!("{}/{}", path, "r2"), device)?; let ru = load_res_upsample::(&format!("{}/{}", path, "ru"), device)?; @@ -252,14 +278,16 @@ pub fn load_unet_output_blocks(path: &str, device: &B::Device) -> Re }) } - pub fn load_unet(path: &str, device: &B::Device) -> Result, Box> { let lin1_time_embed = load_linear::(&format!("{}/{}", path, "lin1_time_embed"), device)?; let silu_time_embed = SILU::new(); // Assuming SILU::new() initializes a new SILU struct let lin2_time_embed = load_linear::(&format!("{}/{}", path, "lin2_time_embed"), device)?; - let input_blocks = load_unet_input_blocks::(&format!("{}/{}", path, "input_blocks"), device)?; - let middle_block = load_res_transformer_res::(&format!("{}/{}", path, "middle_block"), device)?; - let output_blocks = load_unet_output_blocks::(&format!("{}/{}", path, "output_blocks"), device)?; + let input_blocks = + load_unet_input_blocks::(&format!("{}/{}", path, "input_blocks"), device)?; + let middle_block = + load_res_transformer_res::(&format!("{}/{}", path, "middle_block"), device)?; + let output_blocks = + load_unet_output_blocks::(&format!("{}/{}", path, "output_blocks"), device)?; let norm_out = load_group_norm::(&format!("{}/{}", path, "norm_out"), device)?; let silu_out = SILU::new(); // Assuming SILU::new() initializes a new SILU struct let conv_out = load_conv2d::(&format!("{}/{}", path, "conv_out"), device)?; diff --git a/src/model/unet/mod.rs b/src/model/unet/mod.rs index d22c170..5cec18a 100644 --- a/src/model/unet/mod.rs +++ b/src/model/unet/mod.rs @@ -1,34 +1,34 @@ pub mod load; use burn::{ - config::Config, + config::Config, module::{Module, Param}, - nn::{self, PaddingConfig2d, GELU, conv::{Conv2d, Conv2dConfig}}, - tensor::{ - backend::Backend, - activation::softmax, - module::embedding, - Tensor, - Distribution, - Int, + nn::{ + self, + conv::{Conv2d, Conv2dConfig}, + PaddingConfig2d, GELU, }, + tensor::{activation::softmax, backend::Backend, module::embedding, Distribution, Int, Tensor}, }; -use super::silu::*; use super::groupnorm::*; -use crate::helper::to_float; +use super::silu::*; use super::attention::qkv_attention; - -fn timestep_embedding(timesteps: Tensor, dim: usize, max_period: usize) -> Tensor { +fn timestep_embedding( + timesteps: Tensor, + dim: usize, + max_period: usize, +) -> Tensor { let half = dim / 2; - let freqs = ( to_float(Tensor::arange_device(0..half, ×teps.device())) * (-(max_period as f64).ln() / half as f64 ) ).exp(); - let args = to_float(timesteps) * freqs; + let freqs = (Tensor::arange_device(0..half, ×teps.device()).float() + * (-(max_period as f64).ln() / half as f64)) + .exp(); + let args = timesteps.float() * freqs; Tensor::cat(vec![args.clone().cos(), args.sin()], 0).unsqueeze() } - #[derive(Config)] pub struct UNetConfig {} @@ -39,7 +39,9 @@ impl UNetConfig { let lin2_time_embed = nn::LinearConfig::new(1280, 1280).init(); let input_blocks = UNetInputBlocks { - conv: Conv2dConfig::new([4, 320], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(), + conv: Conv2dConfig::new([4, 320], [3, 3]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(), rt1: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(), rt2: ResTransformerConfig::new(320, 1280, 320, 768, 8).init(), d1: DownsampleConfig::new(320).init(), @@ -52,7 +54,7 @@ impl UNetConfig { r1: ResBlockConfig::new(1280, 1280, 1280).init(), r2: ResBlockConfig::new(1280, 1280, 1280).init(), }; - + let middle_block = ResTransformerResConfig::new(1280, 1280, 1280, 768, 8).init(); let output_blocks = UNetOutputBlocks { @@ -72,37 +74,44 @@ impl UNetConfig { let norm_out = GroupNormConfig::new(32, 320).init(); let silu_out = SILU::new(); - let conv_out = Conv2dConfig::new([320, 4], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); + let conv_out = Conv2dConfig::new([320, 4], [3, 3]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(); UNet { - lin1_time_embed, - silu_time_embed, - lin2_time_embed, - input_blocks, - middle_block, - output_blocks, - norm_out, - silu_out, - conv_out, + lin1_time_embed, + silu_time_embed, + lin2_time_embed, + input_blocks, + middle_block, + output_blocks, + norm_out, + silu_out, + conv_out, } } } #[derive(Module, Debug)] pub struct UNet { - lin1_time_embed: nn::Linear, - silu_time_embed: SILU, - lin2_time_embed: nn::Linear, - input_blocks: UNetInputBlocks, - middle_block: ResTransformerRes, - output_blocks: UNetOutputBlocks, - norm_out: GroupNorm, - silu_out: SILU, - conv_out: Conv2d, + lin1_time_embed: nn::Linear, + silu_time_embed: SILU, + lin2_time_embed: nn::Linear, + input_blocks: UNetInputBlocks, + middle_block: ResTransformerRes, + output_blocks: UNetOutputBlocks, + norm_out: GroupNorm, + silu_out: SILU, + conv_out: Conv2d, } impl UNet { - pub fn forward(&self, x: Tensor, timesteps: Tensor, context: Tensor) -> Tensor { + pub fn forward( + &self, + x: Tensor, + timesteps: Tensor, + context: Tensor, + ) -> Tensor { let t_emb = timestep_embedding(timesteps, 320, 10000); let emb = self.lin1_time_embed.forward(t_emb); let emb = self.silu_time_embed.forward(emb); @@ -133,39 +142,27 @@ impl UNet { } } - - #[derive(Module, Debug)] pub struct UNetInputBlocks { - conv: Conv2d, - rt1: ResTransformer, - rt2: ResTransformer, - d1: Downsample, - rt3: ResTransformer, - rt4: ResTransformer, - d2: Downsample, - rt5: ResTransformer, - rt6: ResTransformer, - d3: Downsample, - r1: ResBlock, - r2: ResBlock, + conv: Conv2d, + rt1: ResTransformer, + rt2: ResTransformer, + d1: Downsample, + rt3: ResTransformer, + rt4: ResTransformer, + d2: Downsample, + rt5: ResTransformer, + rt6: ResTransformer, + d3: Downsample, + r1: ResBlock, + r2: ResBlock, } impl UNetInputBlocks { fn as_array(&self) -> [&dyn UNetBlock; 12] { [ - &self.conv, - &self.rt1, - &self.rt2, - &self.d1, - &self.rt3, - &self.rt4, - &self.d2, - &self.rt5, - &self.rt6, - &self.d3, - &self.r1, - &self.r2, + &self.conv, &self.rt1, &self.rt2, &self.d1, &self.rt3, &self.rt4, &self.d2, &self.rt5, + &self.rt6, &self.d3, &self.r1, &self.r2, ] } } @@ -177,67 +174,57 @@ pub struct UNetOutputBlocks { ru: ResUpSample, rt1: ResTransformer, rt2: ResTransformer, - rtu1: ResTransformerUpsample, + rtu1: ResTransformerUpsample, rt3: ResTransformer, rt4: ResTransformer, - rtu2: ResTransformerUpsample, - rt5: ResTransformer, - rt6: ResTransformer, - rt7: ResTransformer, + rtu2: ResTransformerUpsample, + rt5: ResTransformer, + rt6: ResTransformer, + rt7: ResTransformer, } impl UNetOutputBlocks { fn as_array(&self) -> [&dyn UNetBlock; 12] { [ - &self.r1, - &self.r2, - &self.ru, - &self.rt1, - &self.rt2, - &self.rtu1, - &self.rt3, - &self.rt4, - &self.rtu2, - &self.rt5, - &self.rt6, - &self.rt7, + &self.r1, &self.r2, &self.ru, &self.rt1, &self.rt2, &self.rtu1, &self.rt3, &self.rt4, + &self.rtu2, &self.rt5, &self.rt6, &self.rt7, ] } } - - - - trait UNetBlock { fn forward(&self, x: Tensor, emb: Tensor, context: Tensor) -> Tensor; } #[derive(Config)] pub struct ResTransformerConfig { - n_channels_in: usize, - n_channels_embed: usize, - n_channels_out: usize, - n_context_state: usize, - n_head: usize, + n_channels_in: usize, + n_channels_embed: usize, + n_channels_out: usize, + n_context_state: usize, + n_head: usize, } impl ResTransformerConfig { fn init(&self) -> ResTransformer { - let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init(); - let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init(); + let res = ResBlockConfig::new( + self.n_channels_in, + self.n_channels_embed, + self.n_channels_out, + ) + .init(); + let transformer = + SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head) + .init(); - ResTransformer { - res, - transformer, - } + ResTransformer { res, transformer } } } #[derive(Module, Debug)] pub struct ResTransformer { - res: ResBlock, - transformer: SpatialTransformer, + res: ResBlock, + transformer: SpatialTransformer, } impl UNetBlock for ResTransformer { @@ -250,27 +237,29 @@ impl UNetBlock for ResTransformer { #[derive(Config)] pub struct ResUpSampleConfig { - n_channels_in: usize, - n_channels_embed: usize, - n_channels_out: usize, + n_channels_in: usize, + n_channels_embed: usize, + n_channels_out: usize, } impl ResUpSampleConfig { fn init(&self) -> ResUpSample { - let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init(); + let res = ResBlockConfig::new( + self.n_channels_in, + self.n_channels_embed, + self.n_channels_out, + ) + .init(); let upsample = UpsampleConfig::new(self.n_channels_out).init(); - ResUpSample { - res, - upsample, - } + ResUpSample { res, upsample } } } #[derive(Module, Debug)] pub struct ResUpSample { - res: ResBlock, - upsample: Upsample, + res: ResBlock, + upsample: Upsample, } impl UNetBlock for ResUpSample { @@ -283,32 +272,39 @@ impl UNetBlock for ResUpSample { #[derive(Config)] pub struct ResTransformerUpsampleConfig { - n_channels_in: usize, - n_channels_embed: usize, - n_channels_out: usize, - n_context_state: usize, - n_head: usize, + n_channels_in: usize, + n_channels_embed: usize, + n_channels_out: usize, + n_context_state: usize, + n_head: usize, } impl ResTransformerUpsampleConfig { fn init(&self) -> ResTransformerUpsample { - let res = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init(); - let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init(); + let res = ResBlockConfig::new( + self.n_channels_in, + self.n_channels_embed, + self.n_channels_out, + ) + .init(); + let transformer = + SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head) + .init(); let upsample = UpsampleConfig::new(self.n_channels_out).init(); ResTransformerUpsample { - res, - transformer, - upsample, + res, + transformer, + upsample, } } } #[derive(Module, Debug)] pub struct ResTransformerUpsample { - res: ResBlock, - transformer: SpatialTransformer, - upsample: Upsample, + res: ResBlock, + transformer: SpatialTransformer, + upsample: Upsample, } impl UNetBlock for ResTransformerUpsample { @@ -322,32 +318,44 @@ impl UNetBlock for ResTransformerUpsample { #[derive(Config)] pub struct ResTransformerResConfig { - n_channels_in: usize, - n_channels_embed: usize, - n_channels_out: usize, - n_context_state: usize, - n_head: usize, + n_channels_in: usize, + n_channels_embed: usize, + n_channels_out: usize, + n_context_state: usize, + n_head: usize, } impl ResTransformerResConfig { fn init(&self) -> ResTransformerRes { - let res1 = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init(); - let transformer = SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head).init(); - let res2 = ResBlockConfig::new(self.n_channels_in, self.n_channels_embed, self.n_channels_out).init(); + let res1 = ResBlockConfig::new( + self.n_channels_in, + self.n_channels_embed, + self.n_channels_out, + ) + .init(); + let transformer = + SpatialTransformerConfig::new(self.n_channels_out, self.n_context_state, self.n_head) + .init(); + let res2 = ResBlockConfig::new( + self.n_channels_in, + self.n_channels_embed, + self.n_channels_out, + ) + .init(); ResTransformerRes { - res1, - transformer, - res2, + res1, + transformer, + res2, } } } #[derive(Module, Debug)] pub struct ResTransformerRes { - res1: ResBlock, - transformer: SpatialTransformer, - res2: ResBlock, + res1: ResBlock, + transformer: SpatialTransformer, + res2: ResBlock, } impl UNetBlock for ResTransformerRes { @@ -359,11 +367,9 @@ impl UNetBlock for ResTransformerRes { } } - - #[derive(Config)] pub struct UpsampleConfig { - n_channels: usize, + n_channels: usize, } impl UpsampleConfig { @@ -372,25 +378,23 @@ impl UpsampleConfig { .with_padding(PaddingConfig2d::Explicit(1, 1)) .init(); - Upsample { - conv, - } + Upsample { conv } } } #[derive(Module, Debug)] pub struct Upsample { - conv: Conv2d, + conv: Conv2d, } impl Upsample { fn forward(&self, x: Tensor) -> Tensor { let [n_batch, n_channel, height, width] = x.dims(); let x = x - .reshape([n_batch, n_channel, height, 1, width, 1]) - .repeat(3, 2) - .repeat(5, 2) - .reshape([n_batch, n_channel, 2 * height, 2 * width]); + .reshape([n_batch, n_channel, height, 1, width, 1]) + .repeat(3, 2) + .repeat(5, 2) + .reshape([n_batch, n_channel, 2 * height, 2 * width]); self.conv.forward(x) } } @@ -403,7 +407,7 @@ impl UNetBlock for Upsample { #[derive(Config)] pub struct DownsampleConfig { - n_channels: usize, + n_channels: usize, } impl DownsampleConfig { @@ -423,38 +427,36 @@ impl UNetBlock for Conv2d { } } - - - #[derive(Config)] pub struct SpatialTransformerConfig { - n_channels: usize, - n_context_state: usize, - n_head: usize, + n_channels: usize, + n_context_state: usize, + n_head: usize, } impl SpatialTransformerConfig { fn init(&self) -> SpatialTransformer { let norm = GroupNormConfig::new(32, self.n_channels).init(); let proj_in = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(); - let transformer = TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init(); + let transformer = + TransformerBlockConfig::new(self.n_channels, self.n_context_state, self.n_head).init(); let proj_out = Conv2dConfig::new([self.n_channels, self.n_channels], [1, 1]).init(); SpatialTransformer { - norm, - proj_in, - transformer, - proj_out, + norm, + proj_in, + transformer, + proj_out, } } } #[derive(Module, Debug)] pub struct SpatialTransformer { - norm: GroupNorm, - proj_in: Conv2d, - transformer: TransformerBlock, - proj_out: Conv2d, + norm: GroupNorm, + proj_in: Conv2d, + transformer: TransformerBlock, + proj_out: Conv2d, } impl SpatialTransformer { @@ -465,9 +467,13 @@ impl SpatialTransformer { let x = self.norm.forward(x); let x = self.proj_in.forward(x); - let x = x.reshape([n_batch, n_channel, height * width]).swap_dims(1, 2); + let x = x + .reshape([n_batch, n_channel, height * width]) + .swap_dims(1, 2); - let x = self.transformer.forward(x, context) + let x = self + .transformer + .forward(x, context) .swap_dims(1, 2) .reshape([n_batch, n_channel, height, width]); @@ -475,18 +481,11 @@ impl SpatialTransformer { } } - - - - - - - #[derive(Config)] pub struct TransformerBlockConfig { - n_state: usize, - n_context_state: usize, - n_head: usize, + n_state: usize, + n_context_state: usize, + n_head: usize, } impl TransformerBlockConfig { @@ -494,44 +493,44 @@ impl TransformerBlockConfig { let norm1 = nn::LayerNormConfig::new(self.n_state).init(); let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init(); let norm2 = nn::LayerNormConfig::new(self.n_state).init(); - let attn2 = MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init(); + let attn2 = + MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init(); let norm3 = nn::LayerNormConfig::new(self.n_state).init(); let mlp = MLPConfig::new(self.n_state, 4).init(); TransformerBlock { - norm1, - attn1, - norm2, - attn2, - norm3, - mlp, + norm1, + attn1, + norm2, + attn2, + norm3, + mlp, } } } #[derive(Module, Debug)] pub struct TransformerBlock { - norm1: nn::LayerNorm, - attn1: MultiHeadAttention, - norm2: nn::LayerNorm, - attn2: MultiHeadAttention, - norm3: nn::LayerNorm, - mlp: MLP, + norm1: nn::LayerNorm, + attn1: MultiHeadAttention, + norm2: nn::LayerNorm, + attn2: MultiHeadAttention, + norm3: nn::LayerNorm, + mlp: MLP, } impl TransformerBlock { fn forward(&self, x: Tensor, context: Tensor) -> Tensor { - let x = x.clone() + self.attn1.forward( self.norm1.forward(x), None); - let x = x.clone() + self.attn2.forward( self.norm2.forward(x), Some(context)); - x.clone() + self.mlp.forward( self.norm3.forward(x) ) + let x = x.clone() + self.attn1.forward(self.norm1.forward(x), None); + let x = x.clone() + self.attn2.forward(self.norm2.forward(x), Some(context)); + x.clone() + self.mlp.forward(self.norm3.forward(x)) } } - #[derive(Config)] pub struct MLPConfig { - n_state: usize, - mult: usize, + n_state: usize, + mult: usize, } impl MLPConfig { @@ -540,30 +539,26 @@ impl MLPConfig { let geglu = GEGLUConfig::new(self.n_state, n_state_hidden).init(); let lin = nn::LinearConfig::new(n_state_hidden, self.n_state).init(); - MLP { - geglu, - lin, - } + MLP { geglu, lin } } } #[derive(Module, Debug)] pub struct MLP { - geglu: GEGLU, - lin: nn::Linear, + geglu: GEGLU, + lin: nn::Linear, } impl MLP { pub fn forward(&self, x: Tensor) -> Tensor { - self.lin.forward( self.geglu.forward(x) ) + self.lin.forward(self.geglu.forward(x)) } } - #[derive(Config)] pub struct GEGLUConfig { - n_state_in: usize, - n_state_out: usize, + n_state_in: usize, + n_state_out: usize, } impl GEGLUConfig { @@ -571,17 +566,14 @@ impl GEGLUConfig { let proj = nn::LinearConfig::new(self.n_state_in, 2 * self.n_state_out).init(); let gelu = GELU::new(); - GEGLU { - proj, - gelu, - } + GEGLU { proj, gelu } } } #[derive(Module, Debug)] pub struct GEGLU { - proj: nn::Linear, - gelu: GELU, + proj: nn::Linear, + gelu: GELU, } impl GEGLU { @@ -591,51 +583,60 @@ impl GEGLU { let n_state_out = n_state / 2; - let x = projected.clone().slice([0..n_batch, 0..n_ctx, 0..n_state_out]); + let x = projected + .clone() + .slice([0..n_batch, 0..n_ctx, 0..n_state_out]); let gate = projected.slice([0..n_batch, 0..n_ctx, n_state_out..n_state]); x * self.gelu.forward(gate) } } - - - - #[derive(Config)] pub struct MultiHeadAttentionConfig { - n_state: usize, - n_context_state: usize, - n_head: usize, + n_state: usize, + n_context_state: usize, + n_head: usize, } impl MultiHeadAttentionConfig { fn init(&self) -> MultiHeadAttention { - assert!(self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", self.n_state, self.n_head); + assert!( + self.n_state % self.n_head == 0, + "State size {} must be a multiple of head size {}", + self.n_state, + self.n_head + ); let n_head = self.n_head; - let query = nn::LinearConfig::new(self.n_state, self.n_state).with_bias(false).init(); - let key = nn::LinearConfig::new(self.n_context_state, self.n_state).with_bias(false).init(); - let value = nn::LinearConfig::new(self.n_context_state, self.n_state).with_bias(false).init(); + let query = nn::LinearConfig::new(self.n_state, self.n_state) + .with_bias(false) + .init(); + let key = nn::LinearConfig::new(self.n_context_state, self.n_state) + .with_bias(false) + .init(); + let value = nn::LinearConfig::new(self.n_context_state, self.n_state) + .with_bias(false) + .init(); let out = nn::LinearConfig::new(self.n_state, self.n_state).init(); - MultiHeadAttention { - n_head, - query, - key, - value, - out + MultiHeadAttention { + n_head, + query, + key, + value, + out, } } } #[derive(Module, Debug)] pub struct MultiHeadAttention { - n_head: usize, - query: nn::Linear, - key: nn::Linear, - value: nn::Linear, - out: nn::Linear, + n_head: usize, + query: nn::Linear, + key: nn::Linear, + value: nn::Linear, + out: nn::Linear, } impl MultiHeadAttention { @@ -652,74 +653,61 @@ impl MultiHeadAttention { } } - - - - - - - - - - - - - - - #[derive(Config)] pub struct ResBlockConfig { - n_channels_in: usize, - n_channels_embed: usize, - n_channels_out: usize, + n_channels_in: usize, + n_channels_embed: usize, + n_channels_out: usize, } - impl ResBlockConfig { fn init(&self) -> ResBlock { let norm_in = GroupNormConfig::new(32, self.n_channels_in).init(); let silu_in = SILU::new(); - let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); + let conv_in = Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [3, 3]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(); let silu_embed = SILU::new(); let lin_embed = nn::LinearConfig::new(self.n_channels_embed, self.n_channels_out).init(); let norm_out = GroupNormConfig::new(32, self.n_channels_out).init(); let silu_out = SILU::new(); - let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]).with_padding(PaddingConfig2d::Explicit(1, 1)).init(); + let conv_out = Conv2dConfig::new([self.n_channels_out, self.n_channels_out], [3, 3]) + .with_padding(PaddingConfig2d::Explicit(1, 1)) + .init(); let skip_connection = if self.n_channels_in != self.n_channels_out { - Some( Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init() ) + Some(Conv2dConfig::new([self.n_channels_in, self.n_channels_out], [1, 1]).init()) } else { None }; ResBlock { - norm_in, - silu_in, - conv_in, - silu_embed, - lin_embed, - norm_out, - silu_out, - conv_out, - skip_connection, + norm_in, + silu_in, + conv_in, + silu_embed, + lin_embed, + norm_out, + silu_out, + conv_out, + skip_connection, } } } - #[derive(Module, Debug)] pub struct ResBlock { - norm_in: GroupNorm, - silu_in: SILU, - conv_in: Conv2d, - silu_embed: SILU, - lin_embed: nn::Linear, - norm_out: GroupNorm, - silu_out: SILU, - conv_out: Conv2d, - skip_connection: Option>, + norm_in: GroupNorm, + silu_in: SILU, + conv_in: Conv2d, + silu_embed: SILU, + lin_embed: nn::Linear, + norm_out: GroupNorm, + silu_out: SILU, + conv_out: Conv2d, + skip_connection: Option>, } impl ResBlock { @@ -730,7 +718,7 @@ impl ResBlock { let embed_out = self.silu_embed.forward(embed); let embed_out = self.lin_embed.forward(embed_out); - + let [n_batch_embed, n_state_embed] = embed_out.dims(); let h = h + embed_out.reshape([n_batch_embed, n_state_embed, 1, 1]); @@ -751,5 +739,3 @@ impl UNetBlock for ResBlock { self.forward(x, emb) } } - - diff --git a/src/tokenizer.rs b/src/tokenizer.rs index 0e5f223..0ac76bb 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -1,13 +1,14 @@ -use std::collections::HashMap; use regex::Regex; +use std::collections::HashMap; use std::fs::File; use std::io::{self, BufRead}; fn bytes_to_unicode() -> Vec<(u8, char)> { - let mut bs: Vec = ('!' as u8 ..= '~' as u8).into_iter() - .chain( ('¡' as u8..='¬' as u8).into_iter() ) - .chain( ('®' as u8..='ÿ' as u8).into_iter() ) + let mut bs: Vec = ('!' as u8..='~' as u8) + .into_iter() + .chain(('¡' as u8..='¬' as u8).into_iter()) + .chain(('®' as u8..='ÿ' as u8).into_iter()) .collect(); let mut cs: Vec<_> = bs.iter().cloned().map(char::from).collect(); @@ -16,25 +17,21 @@ fn bytes_to_unicode() -> Vec<(u8, char)> { for b in 0u8..=255u8 { if !bs.contains(&b) { bs.push(b); - cs.push( char::from_u32(256 + n).unwrap() ); + cs.push(char::from_u32(256 + n).unwrap()); n += 1; } } bs.into_iter() - .zip( - cs.into_iter() - .map(|c| c.into()) - ).collect() + .zip(cs.into_iter().map(|c| c.into())) + .collect() } fn get_pairs(word: &[String]) -> Vec<(String, String)> { let prev = word.into_iter().cloned(); let next = prev.clone().skip(1); - prev - .zip(next) - .collect() + prev.zip(next).collect() } fn whitespace_clean(text: &str) -> String { @@ -44,24 +41,27 @@ fn whitespace_clean(text: &str) -> String { fn load_merges(path: &str) -> io::Result> { let file = File::open(&path)?; let reader = io::BufReader::new(file); - + let mut merges = Vec::new(); - + for line in reader.lines() { let line = line?; let mut words = line.split_whitespace(); - + if let (Some(word1), Some(word2)) = (words.next(), words.next()) { merges.push((word1.into(), word2.into())); } } - + Ok(merges) } -fn construct_vocab(chars: impl Iterator + Clone, merges: &[(String, String)]) -> Vec { +fn construct_vocab( + chars: impl Iterator + Clone, + merges: &[(String, String)], +) -> Vec { let iter = chars.map(String::from); - let mut vocab: Vec<_> = iter.clone().chain( iter.map(|c| c + "") ).collect(); + let mut vocab: Vec<_> = iter.clone().chain(iter.map(|c| c + "")).collect(); for merge in merges { vocab.push(format!("{}{}", merge.0, merge.1)); @@ -79,7 +79,7 @@ pub struct SimpleTokenizer { decoder: HashMap, bpe_ranks: HashMap<(String, String), u32>, cache: HashMap, - pat: Regex, + pat: Regex, } impl SimpleTokenizer { @@ -87,10 +87,10 @@ impl SimpleTokenizer { let byte_unicode_values = bytes_to_unicode(); let byte_encoder: HashMap<_, _> = byte_unicode_values.iter().cloned().collect(); - let byte_decoder = byte_encoder.iter().map(|(k,v)| (*v,*k)).collect(); + let byte_decoder = byte_encoder.iter().map(|(k, v)| (*v, *k)).collect(); let merges = load_merges("bpe_simple_vocab_16e6.txt")?; - let merges = merges[1..49152-256-2+1].to_vec(); + let merges = merges[1..49152 - 256 - 2 + 1].to_vec(); let vocab = construct_vocab(byte_unicode_values.into_iter().map(|(_, u)| u), &merges[..]); @@ -98,38 +98,39 @@ impl SimpleTokenizer { let decoder: HashMap = encoder.iter().map(|(k, v)| (*v, k.clone())).collect(); let bpe_ranks = merges.iter().cloned().zip((0..).into_iter()).collect(); let cache = HashMap::from([ - ("<|startoftext|>".to_string(), "<|startoftext|>".to_string()), - ("<|endoftext|>".to_string(), "<|endoftext|>".to_string()), + ("<|startoftext|>".to_string(), "<|startoftext|>".to_string()), + ("<|endoftext|>".to_string(), "<|endoftext|>".to_string()), ]); let pat = Regex::new(r"(?i)<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|\p{L}+|\p{N}|[^\s\p{L}\p{N}]+").unwrap(); - Ok( SimpleTokenizer { + Ok(SimpleTokenizer { byte_encoder: byte_encoder, byte_decoder: byte_decoder, encoder: encoder, decoder: decoder, bpe_ranks: bpe_ranks, cache: cache, - pat: pat, - } ) + pat: pat, + }) } pub fn bpe(&self, token: &str) -> String { if let Some(word) = self.cache.get(token) { return word.clone(); } - + let mut word: Vec = token.chars().map(|c| c.to_string()).collect(); word.last_mut().map(|w| *w += ""); let mut pairs = get_pairs(&word); - + if pairs.is_empty() { return format!("{}{}", token, ""); } - + loop { - let bigram = pairs.iter() + let bigram = pairs + .iter() .filter(|pair| self.bpe_ranks.contains_key(pair)) .min_by_key(|&pair| self.bpe_ranks[pair]); @@ -141,14 +142,14 @@ impl SimpleTokenizer { let mut new_word = Vec::new(); let mut i = 0; while i < word.len() { - if let Some( (j, _) ) = word.iter().enumerate().skip(i).find(|(_, w)| w == &first) { + if let Some((j, _)) = word.iter().enumerate().skip(i).find(|(_, w)| w == &first) { new_word.extend(word[i..j].iter().cloned()); i = j; } else { new_word.extend(word[i..].iter().cloned()); break; } - + if &word[i] == first && i < word.len() - 1 && &word[i + 1] == second { new_word.push(format!("{}{}", first, second)); i += 2; @@ -157,7 +158,7 @@ impl SimpleTokenizer { i += 1; } } - + word = new_word; if word.len() == 1 { break; @@ -170,7 +171,7 @@ impl SimpleTokenizer { //self.cache.insert(token.into(), word); return word; } - + pub fn encode(&self, text: &str) -> Vec { let cleaned_text = whitespace_clean(text.trim()).to_lowercase(); @@ -178,8 +179,16 @@ impl SimpleTokenizer { for m in self.pat.find_iter(&cleaned_text) { let token = m.as_str(); - let token: String = token.as_bytes().into_iter().map(|b| self.byte_encoder[b]).collect(); - bpe_tokens.extend(self.bpe(&token).split(' ').map(|bpe_token| self.encoder[bpe_token])) + let token: String = token + .as_bytes() + .into_iter() + .map(|b| self.byte_encoder[b]) + .collect(); + bpe_tokens.extend( + self.bpe(&token) + .split(' ') + .map(|bpe_token| self.encoder[bpe_token]), + ) } return bpe_tokens; @@ -187,9 +196,7 @@ impl SimpleTokenizer { pub fn decode(&self, tokens: &[u32]) -> String { let text: String = tokens.iter().map(|t| self.decoder[t].as_str()).collect(); - let decoded_bytes: Vec = text.chars() - .map(|c| self.byte_decoder[&c]) - .collect(); + let decoded_bytes: Vec = text.chars().map(|c| self.byte_decoder[&c]).collect(); String::from_utf8_lossy(&decoded_bytes[..]).replace("", " ") } @@ -212,4 +219,4 @@ mod tests { let decoded = tokenizer.decode(&encoded[..]); assert_eq!(target_decode, decoded); } -} \ No newline at end of file +}