From c774f80f1370190d0c91d8b98d25fa97c8d3ea06 Mon Sep 17 00:00:00 2001 From: Gadersd Date: Sat, 5 Aug 2023 10:14:03 -0400 Subject: [PATCH] Add model loading, saving, and conversion functionality --- src/bin/convert/main.rs | 58 ++++++++++++++++ src/bin/sample/main.rs | 93 +++++++++++++++++++++++++ src/main.rs | 111 ------------------------------ src/model/stablediffusion/load.rs | 4 +- src/model/stablediffusion/mod.rs | 20 +++--- 5 files changed, 160 insertions(+), 126 deletions(-) create mode 100644 src/bin/convert/main.rs create mode 100644 src/bin/sample/main.rs delete mode 100644 src/main.rs diff --git a/src/bin/convert/main.rs b/src/bin/convert/main.rs new file mode 100644 index 0000000..924c9d4 --- /dev/null +++ b/src/bin/convert/main.rs @@ -0,0 +1,58 @@ +use std::env; +use std::process; +use std::error::Error; + +use stablediffusion::model::stablediffusion::{StableDiffusion, load::load_stable_diffusion}; + +use burn::{ + config::Config, + module::{Module, Param}, + nn, + tensor::{ + backend::Backend, + Tensor, + }, +}; + +use burn_tch::{TchBackend, TchDevice}; + +use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings}; + +fn convert_dump_to_model(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box> { + println!("Loading dump..."); + let model: StableDiffusion:: = load_stable_diffusion(dump_path, device)?; + + println!("Saving model..."); + save_model_file(model, model_name)?; + + Ok(()) +} + +fn save_model_file(model: StableDiffusion, name: &str) -> Result<(), record::RecorderError> { + BinFileRecorder::::new() + .record( + model.into_record(), + name.into(), + ) +} + +fn main() { + type Backend = TchBackend; + let device = TchDevice::Cpu; + + let args: Vec = env::args().collect(); + if args.len() != 3 { + eprintln!("Usage: {} ", args[0]); + process::exit(1); + } + + let dump_path = &args[1]; + let model_name = &args[2]; + + if let Err(e) = convert_dump_to_model::(dump_path, model_name, &device) { + eprintln!("Failed to convert dump to model: {:?}", e); + process::exit(1); + } + + println!("Successfully converted {} to {}", dump_path, model_name); +} diff --git a/src/bin/sample/main.rs b/src/bin/sample/main.rs new file mode 100644 index 0000000..69f9326 --- /dev/null +++ b/src/bin/sample/main.rs @@ -0,0 +1,93 @@ +use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::*}; + +use burn::{ + config::Config, + module::{Module, Param}, + nn, + tensor::{ + backend::Backend, + Tensor, + }, +}; +use burn_tch::{TchBackend, TchDevice}; + +use std::env; +use std::io; +use std::process; + +use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings}; + +fn load_stable_diffusion_model_file(filename: &str) -> Result, record::RecorderError> { + BinFileRecorder::::new() + .load(filename.into()) + .map(|record| StableDiffusionConfig::new().init().load_record(record)) +} + +fn main() { + type Backend = TchBackend; + //let device = TchDevice::Cpu; + let device = TchDevice::Cuda(0); + + let args: Vec = std::env::args().collect(); + if args.len() != 6 { + eprintln!("Usage: {} ", args[0]); + process::exit(1); + } + + let model_name = &args[1]; + let unconditional_guidance_scale: f64 = args[2].parse().unwrap_or_else(|_| { + eprintln!("Error: Invalid unconditional guidance scale."); + process::exit(1); + }); + let n_steps: usize = args[3].parse().unwrap_or_else(|_| { + eprintln!("Error: Invalid number of diffusion steps."); + process::exit(1); + }); + let prompt = &args[4]; + let output_image_name = &args[5]; + + + println!("Loading tokenizer..."); + let tokenizer = SimpleTokenizer::new().unwrap(); + println!("Loading model..."); + let sd: StableDiffusion = load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| { + eprintln!("Error loading model: {}", err); + process::exit(1); + }); + let sd = sd.to_device(&device); + + let unconditional_context = sd.unconditional_context(&tokenizer); + let context = sd.context(&tokenizer, prompt).unsqueeze(); + + println!("Sampling image..."); + let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps); + save_images(&images, output_image_name, 512, 512).unwrap_or_else(|err| { + eprintln!("Error saving image: {}", err); + process::exit(1); + }); +} + +use image::{self, ImageResult, ColorType::Rgb8}; + +fn save_images(images: &Vec>, basepath: &str, width: u32, height: u32) -> ImageResult<()> { + for (index, img_data) in images.iter().enumerate() { + let path = format!("{}{}.png", basepath, index); + image::save_buffer(path, &img_data[..], width, height, Rgb8)?; + } + + Ok(()) +} + +// save red test image +fn save_test_image() -> ImageResult<()> { + let width = 256; + let height = 256; + let raw: Vec<_> = (0..width * height).into_iter().flat_map(|i| { + let row = i / width; + let red = (255.0 * row as f64 / height as f64) as u8; + + [red, 0, 0] + }).collect(); + + image::save_buffer("red.png", &raw[..], width, height, Rgb8) +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs deleted file mode 100644 index 98020e6..0000000 --- a/src/main.rs +++ /dev/null @@ -1,111 +0,0 @@ -use stablediffusion::{tokenizer::SimpleTokenizer, model::clip::{*, load::*}, -model::autoencoder::{*, load::*}, -model::groupnorm::*, -model::unet::{*, load::*}, -model::stablediffusion::{*, load::*}}; - -use burn::{ - config::Config, - module::{Module, Param}, - nn, - tensor::{ - backend::Backend, - Tensor, - }, -}; -use burn_tch::{TchBackend, TchDevice}; - -fn print_tensor(x: Tensor) { - let data = x/*.slice([0..1, 0..4, 0..10])*/.into_data(); - println!("{:?}", data); -} - -use stablediffusion::helper::to_float; - -fn main() { - type Backend = TchBackend; - //let device = TchDevice::Cpu; - let device = TchDevice::Cuda(0); - - /*let norm: nn::LayerNorm = nn::LayerNormConfig::new(3).init(); - let tensor = Tensor::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape([2, 3]); - - let out = norm.forward(tensor); - - println!("{:?}", out.into_data()); - - return;*/ - - /*let n_channel = 6; - let norm: nn::LayerNorm = nn::LayerNormConfig::new(10).init(); - let height = 10; - let width = 10; - let n_elements = height * width * n_channel; - let t: Tensor = to_float(Tensor::arange(0..n_elements)).mul_scalar(10.0 / n_elements as f64).sin().reshape([1, n_channel, height, width]); - let out = layernorm(t, 1e-5); //norm.forward(t); - println!("{:?}", out.to_data()); - return;*/ - - /*let clip: CLIP = load_clip("params", &device).unwrap(); - let input = Tensor::from_ints([3, 1]); - - let output = clip.forward(input.unsqueeze()); - print_tensor(output);*/ - - /*let autoencoder: Autoencoder = load_autoencoder("params", &device).unwrap(); - let input = Tensor::zeros([1, 3, 10, 10]); - let output = autoencoder.forward(input); - print_tensor(output);*/ - - /*let unet: UNet = load_unet("params", &device).unwrap(); - let input = Tensor::zeros([1, 4, 64, 64]); - let context = Tensor::from_floats([0.5, 1.3]).repeat(0, 768 / 2).unsqueeze(); - let timesteps = Tensor::from_floats([1.0]); - - let output = unet.forward(input, timesteps, context);*/ - //print_tensor(output); - - println!("Loading tokenizer..."); - let tokenizer = SimpleTokenizer::new().unwrap(); - - println!("Loading Stable Diffusion..."); - let sd: StableDiffusion = load_stable_diffusion("params", &device).unwrap(); - let sd = sd.to_device(&device); - - let unconditional_guidance_scale = 7.5; - let unconditional_context = sd.unconditional_context(&tokenizer); - let context = sd.context(&tokenizer, "A wine glass filled with pink flower petals.").unsqueeze(); - - let n_steps = 100; - - println!("Sampling images..."); - let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps); - - println!("Saving images..."); - save_images(&images, "image_samples/", 512, 512).unwrap(); -} - -use image::{self, ImageResult, ColorType::Rgb8}; - -fn save_images(images: &Vec>, basepath: &str, width: u32, height: u32) -> ImageResult<()> { - for (index, img_data) in images.iter().enumerate() { - let path = format!("{}{}.png", basepath, index); - image::save_buffer(path, &img_data[..], width, height, Rgb8)?; - } - - Ok(()) -} - -// save red test image -fn save_test_image() -> ImageResult<()> { - let width = 256; - let height = 256; - let raw: Vec<_> = (0..width * height).into_iter().flat_map(|i| { - let row = i / width; - let red = (255.0 * row as f64 / height as f64) as u8; - - [red, 0, 0] - }).collect(); - - image::save_buffer("red.png", &raw[..], width, height, Rgb8) -} \ No newline at end of file diff --git a/src/model/stablediffusion/load.rs b/src/model/stablediffusion/load.rs index 0512c0e..405ca8c 100644 --- a/src/model/stablediffusion/load.rs +++ b/src/model/stablediffusion/load.rs @@ -16,9 +16,7 @@ use crate::model::{load::*, autoencoder::load::load_autoencoder, unet::load::loa pub fn load_stable_diffusion(path: &str, device: &B::Device) -> Result, Box> { let n_steps = load_usize::("n_steps", path, device)?; - let alpha_cumulative_products: Vec<_> = load_tensor::("alphas_cumprod", path, device)?.into_data().value.into_iter() - .map(|v: >::Elem| v.to_f64().unwrap()) - .collect(); + let alpha_cumulative_products = load_tensor::("alphas_cumprod", path, device)?.into(); let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?; let diffusion = load_unet(&format!("{}/{}", path, "unet"), device)?; let clip = load_clip(&format!("{}/{}", path, "clip"), device)?; diff --git a/src/model/stablediffusion/mod.rs b/src/model/stablediffusion/mod.rs index aa28d05..c3958bc 100644 --- a/src/model/stablediffusion/mod.rs +++ b/src/model/stablediffusion/mod.rs @@ -2,7 +2,7 @@ pub mod load; use burn::{ config::Config, - module::Module, + module::{Module, Param}, tensor::{ backend::Backend, Tensor, @@ -27,12 +27,9 @@ pub struct StableDiffusionConfig { } impl StableDiffusionConfig { - fn init(&self) -> StableDiffusion { + pub fn init(&self) -> StableDiffusion { let n_steps = 1000; - let alpha_cumulative_products = offset_cosine_schedule_cumprod::(n_steps) - .into_data().value - .into_iter() - .map(|v: >::Elem| v.to_f64().unwrap()).collect(); + let alpha_cumulative_products = offset_cosine_schedule_cumprod::(n_steps).into(); let autoencoder = AutoencoderConfig::new().init(); let diffusion = UNetConfig::new().init(); @@ -51,7 +48,7 @@ impl StableDiffusionConfig { #[derive(Module, Debug)] pub struct StableDiffusion { n_steps: usize, - alpha_cumulative_products: Vec, + alpha_cumulative_products: Param>, autoencoder: Autoencoder, diffusion: UNet, clip: CLIP, @@ -90,8 +87,6 @@ impl StableDiffusion { } pub fn sample_latent(&self, context: Tensor, unconditional_context: Tensor, unconditional_guidance_scale: f64, n_steps: usize) -> Tensor { - assert!(self.n_steps % n_steps == 0); - let device = context.device(); let step_size = self.n_steps / n_steps; @@ -107,9 +102,10 @@ impl StableDiffusion { let mut latent = gen_noise(); for t in (0..self.n_steps).rev().step_by(step_size) { - let current_alpha = self.alpha_cumulative_products[t]; - let prev_alpha = if t >= step_size { - self.alpha_cumulative_products[t - step_size] + let current_alpha: f64 = self.alpha_cumulative_products.val().slice([t..t + 1]).into_scalar().to_f64().unwrap(); + let prev_alpha: f64 = if t >= step_size { + let i = t - step_size; + self.alpha_cumulative_products.val().slice([i..i + 1]).into_scalar().to_f64().unwrap() } else { 1.0 };