Add model loading, saving, and conversion functionality

This commit is contained in:
Gadersd
2023-08-05 10:14:03 -04:00
parent 41fce2a47e
commit 9e247777fa
5 changed files with 160 additions and 126 deletions

58
src/bin/convert/main.rs Normal file
View File

@@ -0,0 +1,58 @@
use std::env;
use std::process;
use std::error::Error;
use stablediffusion::model::stablediffusion::{StableDiffusion, load::load_stable_diffusion};
use burn::{
config::Config,
module::{Module, Param},
nn,
tensor::{
backend::Backend,
Tensor,
},
};
use burn_tch::{TchBackend, TchDevice};
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> {
println!("Loading dump...");
let model: StableDiffusion::<B> = load_stable_diffusion(dump_path, device)?;
println!("Saving model...");
save_model_file(model, model_name)?;
Ok(())
}
fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<(), record::RecorderError> {
BinFileRecorder::<FullPrecisionSettings>::new()
.record(
model.into_record(),
name.into(),
)
}
fn main() {
type Backend = TchBackend<f32>;
let device = TchDevice::Cpu;
let args: Vec<String> = env::args().collect();
if args.len() != 3 {
eprintln!("Usage: {} <dump_path> <model_name>", args[0]);
process::exit(1);
}
let dump_path = &args[1];
let model_name = &args[2];
if let Err(e) = convert_dump_to_model::<Backend>(dump_path, model_name, &device) {
eprintln!("Failed to convert dump to model: {:?}", e);
process::exit(1);
}
println!("Successfully converted {} to {}", dump_path, model_name);
}

93
src/bin/sample/main.rs Normal file
View File

@@ -0,0 +1,93 @@
use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::*};
use burn::{
config::Config,
module::{Module, Param},
nn,
tensor::{
backend::Backend,
Tensor,
},
};
use burn_tch::{TchBackend, TchDevice};
use std::env;
use std::io;
use std::process;
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> {
BinFileRecorder::<FullPrecisionSettings>::new()
.load(filename.into())
.map(|record| StableDiffusionConfig::new().init().load_record(record))
}
fn main() {
type Backend = TchBackend<f32>;
//let device = TchDevice::Cpu;
let device = TchDevice::Cuda(0);
let args: Vec<String> = std::env::args().collect();
if args.len() != 6 {
eprintln!("Usage: {} <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>", args[0]);
process::exit(1);
}
let model_name = &args[1];
let unconditional_guidance_scale: f64 = args[2].parse().unwrap_or_else(|_| {
eprintln!("Error: Invalid unconditional guidance scale.");
process::exit(1);
});
let n_steps: usize = args[3].parse().unwrap_or_else(|_| {
eprintln!("Error: Invalid number of diffusion steps.");
process::exit(1);
});
let prompt = &args[4];
let output_image_name = &args[5];
println!("Loading tokenizer...");
let tokenizer = SimpleTokenizer::new().unwrap();
println!("Loading model...");
let sd: StableDiffusion<Backend> = load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| {
eprintln!("Error loading model: {}", err);
process::exit(1);
});
let sd = sd.to_device(&device);
let unconditional_context = sd.unconditional_context(&tokenizer);
let context = sd.context(&tokenizer, prompt).unsqueeze();
println!("Sampling image...");
let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps);
save_images(&images, output_image_name, 512, 512).unwrap_or_else(|err| {
eprintln!("Error saving image: {}", err);
process::exit(1);
});
}
use image::{self, ImageResult, ColorType::Rgb8};
fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -> ImageResult<()> {
for (index, img_data) in images.iter().enumerate() {
let path = format!("{}{}.png", basepath, index);
image::save_buffer(path, &img_data[..], width, height, Rgb8)?;
}
Ok(())
}
// save red test image
fn save_test_image() -> ImageResult<()> {
let width = 256;
let height = 256;
let raw: Vec<_> = (0..width * height).into_iter().flat_map(|i| {
let row = i / width;
let red = (255.0 * row as f64 / height as f64) as u8;
[red, 0, 0]
}).collect();
image::save_buffer("red.png", &raw[..], width, height, Rgb8)
}