Add model loading, saving, and conversion functionality
This commit is contained in:
58
src/bin/convert/main.rs
Normal file
58
src/bin/convert/main.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
use std::env;
|
||||
use std::process;
|
||||
use std::error::Error;
|
||||
|
||||
use stablediffusion::model::stablediffusion::{StableDiffusion, load::load_stable_diffusion};
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
|
||||
use burn_tch::{TchBackend, TchDevice};
|
||||
|
||||
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
|
||||
|
||||
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> {
|
||||
println!("Loading dump...");
|
||||
let model: StableDiffusion::<B> = load_stable_diffusion(dump_path, device)?;
|
||||
|
||||
println!("Saving model...");
|
||||
save_model_file(model, model_name)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<(), record::RecorderError> {
|
||||
BinFileRecorder::<FullPrecisionSettings>::new()
|
||||
.record(
|
||||
model.into_record(),
|
||||
name.into(),
|
||||
)
|
||||
}
|
||||
|
||||
fn main() {
|
||||
type Backend = TchBackend<f32>;
|
||||
let device = TchDevice::Cpu;
|
||||
|
||||
let args: Vec<String> = env::args().collect();
|
||||
if args.len() != 3 {
|
||||
eprintln!("Usage: {} <dump_path> <model_name>", args[0]);
|
||||
process::exit(1);
|
||||
}
|
||||
|
||||
let dump_path = &args[1];
|
||||
let model_name = &args[2];
|
||||
|
||||
if let Err(e) = convert_dump_to_model::<Backend>(dump_path, model_name, &device) {
|
||||
eprintln!("Failed to convert dump to model: {:?}", e);
|
||||
process::exit(1);
|
||||
}
|
||||
|
||||
println!("Successfully converted {} to {}", dump_path, model_name);
|
||||
}
|
||||
93
src/bin/sample/main.rs
Normal file
93
src/bin/sample/main.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::*};
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
use burn_tch::{TchBackend, TchDevice};
|
||||
|
||||
use std::env;
|
||||
use std::io;
|
||||
use std::process;
|
||||
|
||||
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
|
||||
|
||||
fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> {
|
||||
BinFileRecorder::<FullPrecisionSettings>::new()
|
||||
.load(filename.into())
|
||||
.map(|record| StableDiffusionConfig::new().init().load_record(record))
|
||||
}
|
||||
|
||||
fn main() {
|
||||
type Backend = TchBackend<f32>;
|
||||
//let device = TchDevice::Cpu;
|
||||
let device = TchDevice::Cuda(0);
|
||||
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() != 6 {
|
||||
eprintln!("Usage: {} <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>", args[0]);
|
||||
process::exit(1);
|
||||
}
|
||||
|
||||
let model_name = &args[1];
|
||||
let unconditional_guidance_scale: f64 = args[2].parse().unwrap_or_else(|_| {
|
||||
eprintln!("Error: Invalid unconditional guidance scale.");
|
||||
process::exit(1);
|
||||
});
|
||||
let n_steps: usize = args[3].parse().unwrap_or_else(|_| {
|
||||
eprintln!("Error: Invalid number of diffusion steps.");
|
||||
process::exit(1);
|
||||
});
|
||||
let prompt = &args[4];
|
||||
let output_image_name = &args[5];
|
||||
|
||||
|
||||
println!("Loading tokenizer...");
|
||||
let tokenizer = SimpleTokenizer::new().unwrap();
|
||||
println!("Loading model...");
|
||||
let sd: StableDiffusion<Backend> = load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| {
|
||||
eprintln!("Error loading model: {}", err);
|
||||
process::exit(1);
|
||||
});
|
||||
let sd = sd.to_device(&device);
|
||||
|
||||
let unconditional_context = sd.unconditional_context(&tokenizer);
|
||||
let context = sd.context(&tokenizer, prompt).unsqueeze();
|
||||
|
||||
println!("Sampling image...");
|
||||
let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps);
|
||||
save_images(&images, output_image_name, 512, 512).unwrap_or_else(|err| {
|
||||
eprintln!("Error saving image: {}", err);
|
||||
process::exit(1);
|
||||
});
|
||||
}
|
||||
|
||||
use image::{self, ImageResult, ColorType::Rgb8};
|
||||
|
||||
fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -> ImageResult<()> {
|
||||
for (index, img_data) in images.iter().enumerate() {
|
||||
let path = format!("{}{}.png", basepath, index);
|
||||
image::save_buffer(path, &img_data[..], width, height, Rgb8)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// save red test image
|
||||
fn save_test_image() -> ImageResult<()> {
|
||||
let width = 256;
|
||||
let height = 256;
|
||||
let raw: Vec<_> = (0..width * height).into_iter().flat_map(|i| {
|
||||
let row = i / width;
|
||||
let red = (255.0 * row as f64 / height as f64) as u8;
|
||||
|
||||
[red, 0, 0]
|
||||
}).collect();
|
||||
|
||||
image::save_buffer("red.png", &raw[..], width, height, Rgb8)
|
||||
}
|
||||
Reference in New Issue
Block a user