Add model loading, saving, and conversion functionality
This commit is contained in:
58
src/bin/convert/main.rs
Normal file
58
src/bin/convert/main.rs
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
use std::env;
|
||||||
|
use std::process;
|
||||||
|
use std::error::Error;
|
||||||
|
|
||||||
|
use stablediffusion::model::stablediffusion::{StableDiffusion, load::load_stable_diffusion};
|
||||||
|
|
||||||
|
use burn::{
|
||||||
|
config::Config,
|
||||||
|
module::{Module, Param},
|
||||||
|
nn,
|
||||||
|
tensor::{
|
||||||
|
backend::Backend,
|
||||||
|
Tensor,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
use burn_tch::{TchBackend, TchDevice};
|
||||||
|
|
||||||
|
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
|
||||||
|
|
||||||
|
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> {
|
||||||
|
println!("Loading dump...");
|
||||||
|
let model: StableDiffusion::<B> = load_stable_diffusion(dump_path, device)?;
|
||||||
|
|
||||||
|
println!("Saving model...");
|
||||||
|
save_model_file(model, model_name)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<(), record::RecorderError> {
|
||||||
|
BinFileRecorder::<FullPrecisionSettings>::new()
|
||||||
|
.record(
|
||||||
|
model.into_record(),
|
||||||
|
name.into(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
type Backend = TchBackend<f32>;
|
||||||
|
let device = TchDevice::Cpu;
|
||||||
|
|
||||||
|
let args: Vec<String> = env::args().collect();
|
||||||
|
if args.len() != 3 {
|
||||||
|
eprintln!("Usage: {} <dump_path> <model_name>", args[0]);
|
||||||
|
process::exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
let dump_path = &args[1];
|
||||||
|
let model_name = &args[2];
|
||||||
|
|
||||||
|
if let Err(e) = convert_dump_to_model::<Backend>(dump_path, model_name, &device) {
|
||||||
|
eprintln!("Failed to convert dump to model: {:?}", e);
|
||||||
|
process::exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("Successfully converted {} to {}", dump_path, model_name);
|
||||||
|
}
|
||||||
93
src/bin/sample/main.rs
Normal file
93
src/bin/sample/main.rs
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::*};
|
||||||
|
|
||||||
|
use burn::{
|
||||||
|
config::Config,
|
||||||
|
module::{Module, Param},
|
||||||
|
nn,
|
||||||
|
tensor::{
|
||||||
|
backend::Backend,
|
||||||
|
Tensor,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use burn_tch::{TchBackend, TchDevice};
|
||||||
|
|
||||||
|
use std::env;
|
||||||
|
use std::io;
|
||||||
|
use std::process;
|
||||||
|
|
||||||
|
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
|
||||||
|
|
||||||
|
fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> {
|
||||||
|
BinFileRecorder::<FullPrecisionSettings>::new()
|
||||||
|
.load(filename.into())
|
||||||
|
.map(|record| StableDiffusionConfig::new().init().load_record(record))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
type Backend = TchBackend<f32>;
|
||||||
|
//let device = TchDevice::Cpu;
|
||||||
|
let device = TchDevice::Cuda(0);
|
||||||
|
|
||||||
|
let args: Vec<String> = std::env::args().collect();
|
||||||
|
if args.len() != 6 {
|
||||||
|
eprintln!("Usage: {} <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>", args[0]);
|
||||||
|
process::exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
let model_name = &args[1];
|
||||||
|
let unconditional_guidance_scale: f64 = args[2].parse().unwrap_or_else(|_| {
|
||||||
|
eprintln!("Error: Invalid unconditional guidance scale.");
|
||||||
|
process::exit(1);
|
||||||
|
});
|
||||||
|
let n_steps: usize = args[3].parse().unwrap_or_else(|_| {
|
||||||
|
eprintln!("Error: Invalid number of diffusion steps.");
|
||||||
|
process::exit(1);
|
||||||
|
});
|
||||||
|
let prompt = &args[4];
|
||||||
|
let output_image_name = &args[5];
|
||||||
|
|
||||||
|
|
||||||
|
println!("Loading tokenizer...");
|
||||||
|
let tokenizer = SimpleTokenizer::new().unwrap();
|
||||||
|
println!("Loading model...");
|
||||||
|
let sd: StableDiffusion<Backend> = load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| {
|
||||||
|
eprintln!("Error loading model: {}", err);
|
||||||
|
process::exit(1);
|
||||||
|
});
|
||||||
|
let sd = sd.to_device(&device);
|
||||||
|
|
||||||
|
let unconditional_context = sd.unconditional_context(&tokenizer);
|
||||||
|
let context = sd.context(&tokenizer, prompt).unsqueeze();
|
||||||
|
|
||||||
|
println!("Sampling image...");
|
||||||
|
let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps);
|
||||||
|
save_images(&images, output_image_name, 512, 512).unwrap_or_else(|err| {
|
||||||
|
eprintln!("Error saving image: {}", err);
|
||||||
|
process::exit(1);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
use image::{self, ImageResult, ColorType::Rgb8};
|
||||||
|
|
||||||
|
fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -> ImageResult<()> {
|
||||||
|
for (index, img_data) in images.iter().enumerate() {
|
||||||
|
let path = format!("{}{}.png", basepath, index);
|
||||||
|
image::save_buffer(path, &img_data[..], width, height, Rgb8)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// save red test image
|
||||||
|
fn save_test_image() -> ImageResult<()> {
|
||||||
|
let width = 256;
|
||||||
|
let height = 256;
|
||||||
|
let raw: Vec<_> = (0..width * height).into_iter().flat_map(|i| {
|
||||||
|
let row = i / width;
|
||||||
|
let red = (255.0 * row as f64 / height as f64) as u8;
|
||||||
|
|
||||||
|
[red, 0, 0]
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
image::save_buffer("red.png", &raw[..], width, height, Rgb8)
|
||||||
|
}
|
||||||
111
src/main.rs
111
src/main.rs
@@ -1,111 +0,0 @@
|
|||||||
use stablediffusion::{tokenizer::SimpleTokenizer, model::clip::{*, load::*},
|
|
||||||
model::autoencoder::{*, load::*},
|
|
||||||
model::groupnorm::*,
|
|
||||||
model::unet::{*, load::*},
|
|
||||||
model::stablediffusion::{*, load::*}};
|
|
||||||
|
|
||||||
use burn::{
|
|
||||||
config::Config,
|
|
||||||
module::{Module, Param},
|
|
||||||
nn,
|
|
||||||
tensor::{
|
|
||||||
backend::Backend,
|
|
||||||
Tensor,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
use burn_tch::{TchBackend, TchDevice};
|
|
||||||
|
|
||||||
fn print_tensor<B: Backend>(x: Tensor<B, 4>) {
|
|
||||||
let data = x/*.slice([0..1, 0..4, 0..10])*/.into_data();
|
|
||||||
println!("{:?}", data);
|
|
||||||
}
|
|
||||||
|
|
||||||
use stablediffusion::helper::to_float;
|
|
||||||
|
|
||||||
fn main() {
|
|
||||||
type Backend = TchBackend<f32>;
|
|
||||||
//let device = TchDevice::Cpu;
|
|
||||||
let device = TchDevice::Cuda(0);
|
|
||||||
|
|
||||||
/*let norm: nn::LayerNorm<Backend> = nn::LayerNormConfig::new(3).init();
|
|
||||||
let tensor = Tensor::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape([2, 3]);
|
|
||||||
|
|
||||||
let out = norm.forward(tensor);
|
|
||||||
|
|
||||||
println!("{:?}", out.into_data());
|
|
||||||
|
|
||||||
return;*/
|
|
||||||
|
|
||||||
/*let n_channel = 6;
|
|
||||||
let norm: nn::LayerNorm<Backend> = nn::LayerNormConfig::new(10).init();
|
|
||||||
let height = 10;
|
|
||||||
let width = 10;
|
|
||||||
let n_elements = height * width * n_channel;
|
|
||||||
let t: Tensor<Backend, 4> = to_float(Tensor::arange(0..n_elements)).mul_scalar(10.0 / n_elements as f64).sin().reshape([1, n_channel, height, width]);
|
|
||||||
let out = layernorm(t, 1e-5); //norm.forward(t);
|
|
||||||
println!("{:?}", out.to_data());
|
|
||||||
return;*/
|
|
||||||
|
|
||||||
/*let clip: CLIP<Backend> = load_clip("params", &device).unwrap();
|
|
||||||
let input = Tensor::from_ints([3, 1]);
|
|
||||||
|
|
||||||
let output = clip.forward(input.unsqueeze());
|
|
||||||
print_tensor(output);*/
|
|
||||||
|
|
||||||
/*let autoencoder: Autoencoder<Backend> = load_autoencoder("params", &device).unwrap();
|
|
||||||
let input = Tensor::zeros([1, 3, 10, 10]);
|
|
||||||
let output = autoencoder.forward(input);
|
|
||||||
print_tensor(output);*/
|
|
||||||
|
|
||||||
/*let unet: UNet<Backend> = load_unet("params", &device).unwrap();
|
|
||||||
let input = Tensor::zeros([1, 4, 64, 64]);
|
|
||||||
let context = Tensor::from_floats([0.5, 1.3]).repeat(0, 768 / 2).unsqueeze();
|
|
||||||
let timesteps = Tensor::from_floats([1.0]);
|
|
||||||
|
|
||||||
let output = unet.forward(input, timesteps, context);*/
|
|
||||||
//print_tensor(output);
|
|
||||||
|
|
||||||
println!("Loading tokenizer...");
|
|
||||||
let tokenizer = SimpleTokenizer::new().unwrap();
|
|
||||||
|
|
||||||
println!("Loading Stable Diffusion...");
|
|
||||||
let sd: StableDiffusion<Backend> = load_stable_diffusion("params", &device).unwrap();
|
|
||||||
let sd = sd.to_device(&device);
|
|
||||||
|
|
||||||
let unconditional_guidance_scale = 7.5;
|
|
||||||
let unconditional_context = sd.unconditional_context(&tokenizer);
|
|
||||||
let context = sd.context(&tokenizer, "A wine glass filled with pink flower petals.").unsqueeze();
|
|
||||||
|
|
||||||
let n_steps = 100;
|
|
||||||
|
|
||||||
println!("Sampling images...");
|
|
||||||
let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps);
|
|
||||||
|
|
||||||
println!("Saving images...");
|
|
||||||
save_images(&images, "image_samples/", 512, 512).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
use image::{self, ImageResult, ColorType::Rgb8};
|
|
||||||
|
|
||||||
fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -> ImageResult<()> {
|
|
||||||
for (index, img_data) in images.iter().enumerate() {
|
|
||||||
let path = format!("{}{}.png", basepath, index);
|
|
||||||
image::save_buffer(path, &img_data[..], width, height, Rgb8)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
// save red test image
|
|
||||||
fn save_test_image() -> ImageResult<()> {
|
|
||||||
let width = 256;
|
|
||||||
let height = 256;
|
|
||||||
let raw: Vec<_> = (0..width * height).into_iter().flat_map(|i| {
|
|
||||||
let row = i / width;
|
|
||||||
let red = (255.0 * row as f64 / height as f64) as u8;
|
|
||||||
|
|
||||||
[red, 0, 0]
|
|
||||||
}).collect();
|
|
||||||
|
|
||||||
image::save_buffer("red.png", &raw[..], width, height, Rgb8)
|
|
||||||
}
|
|
||||||
@@ -16,9 +16,7 @@ use crate::model::{load::*, autoencoder::load::load_autoencoder, unet::load::loa
|
|||||||
|
|
||||||
pub fn load_stable_diffusion<B: Backend>(path: &str, device: &B::Device) -> Result<StableDiffusion<B>, Box<dyn Error>> {
|
pub fn load_stable_diffusion<B: Backend>(path: &str, device: &B::Device) -> Result<StableDiffusion<B>, Box<dyn Error>> {
|
||||||
let n_steps = load_usize::<B>("n_steps", path, device)?;
|
let n_steps = load_usize::<B>("n_steps", path, device)?;
|
||||||
let alpha_cumulative_products: Vec<_> = load_tensor::<B, 1>("alphas_cumprod", path, device)?.into_data().value.into_iter()
|
let alpha_cumulative_products = load_tensor::<B, 1>("alphas_cumprod", path, device)?.into();
|
||||||
.map(|v: <Float as BasicOps<B>>::Elem| v.to_f64().unwrap())
|
|
||||||
.collect();
|
|
||||||
let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?;
|
let autoencoder = load_autoencoder(&format!("{}/{}", path, "autoencoder"), device)?;
|
||||||
let diffusion = load_unet(&format!("{}/{}", path, "unet"), device)?;
|
let diffusion = load_unet(&format!("{}/{}", path, "unet"), device)?;
|
||||||
let clip = load_clip(&format!("{}/{}", path, "clip"), device)?;
|
let clip = load_clip(&format!("{}/{}", path, "clip"), device)?;
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ pub mod load;
|
|||||||
|
|
||||||
use burn::{
|
use burn::{
|
||||||
config::Config,
|
config::Config,
|
||||||
module::Module,
|
module::{Module, Param},
|
||||||
tensor::{
|
tensor::{
|
||||||
backend::Backend,
|
backend::Backend,
|
||||||
Tensor,
|
Tensor,
|
||||||
@@ -27,12 +27,9 @@ pub struct StableDiffusionConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl StableDiffusionConfig {
|
impl StableDiffusionConfig {
|
||||||
fn init<B: Backend>(&self) -> StableDiffusion<B> {
|
pub fn init<B: Backend>(&self) -> StableDiffusion<B> {
|
||||||
let n_steps = 1000;
|
let n_steps = 1000;
|
||||||
let alpha_cumulative_products = offset_cosine_schedule_cumprod::<B>(n_steps)
|
let alpha_cumulative_products = offset_cosine_schedule_cumprod::<B>(n_steps).into();
|
||||||
.into_data().value
|
|
||||||
.into_iter()
|
|
||||||
.map(|v: <Float as BasicOps<B>>::Elem| v.to_f64().unwrap()).collect();
|
|
||||||
|
|
||||||
let autoencoder = AutoencoderConfig::new().init();
|
let autoencoder = AutoencoderConfig::new().init();
|
||||||
let diffusion = UNetConfig::new().init();
|
let diffusion = UNetConfig::new().init();
|
||||||
@@ -51,7 +48,7 @@ impl StableDiffusionConfig {
|
|||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
pub struct StableDiffusion<B: Backend> {
|
pub struct StableDiffusion<B: Backend> {
|
||||||
n_steps: usize,
|
n_steps: usize,
|
||||||
alpha_cumulative_products: Vec<f64>,
|
alpha_cumulative_products: Param<Tensor<B, 1>>,
|
||||||
autoencoder: Autoencoder<B>,
|
autoencoder: Autoencoder<B>,
|
||||||
diffusion: UNet<B>,
|
diffusion: UNet<B>,
|
||||||
clip: CLIP<B>,
|
clip: CLIP<B>,
|
||||||
@@ -90,8 +87,6 @@ impl<B: Backend> StableDiffusion<B> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn sample_latent(&self, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64, n_steps: usize) -> Tensor<B, 4> {
|
pub fn sample_latent(&self, context: Tensor<B, 3>, unconditional_context: Tensor<B, 2>, unconditional_guidance_scale: f64, n_steps: usize) -> Tensor<B, 4> {
|
||||||
assert!(self.n_steps % n_steps == 0);
|
|
||||||
|
|
||||||
let device = context.device();
|
let device = context.device();
|
||||||
|
|
||||||
let step_size = self.n_steps / n_steps;
|
let step_size = self.n_steps / n_steps;
|
||||||
@@ -107,9 +102,10 @@ impl<B: Backend> StableDiffusion<B> {
|
|||||||
let mut latent = gen_noise();
|
let mut latent = gen_noise();
|
||||||
|
|
||||||
for t in (0..self.n_steps).rev().step_by(step_size) {
|
for t in (0..self.n_steps).rev().step_by(step_size) {
|
||||||
let current_alpha = self.alpha_cumulative_products[t];
|
let current_alpha: f64 = self.alpha_cumulative_products.val().slice([t..t + 1]).into_scalar().to_f64().unwrap();
|
||||||
let prev_alpha = if t >= step_size {
|
let prev_alpha: f64 = if t >= step_size {
|
||||||
self.alpha_cumulative_products[t - step_size]
|
let i = t - step_size;
|
||||||
|
self.alpha_cumulative_products.val().slice([i..i + 1]).into_scalar().to_f64().unwrap()
|
||||||
} else {
|
} else {
|
||||||
1.0
|
1.0
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user