Add ability to load dump or burn model in sample binary

This commit is contained in:
Gadersd
2023-08-05 13:27:37 -04:00
committed by Ben_Kosytorz
parent bcf6b01d3a
commit 6f9f4f1f61

View File

@@ -1,4 +1,4 @@
use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::*}; use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::{*, load::load_stable_diffusion}};
use burn::{ use burn::{
config::Config, config::Config,
@@ -30,30 +30,39 @@ fn main() {
let args: Vec<String> = std::env::args().collect(); let args: Vec<String> = std::env::args().collect();
if args.len() != 6 { if args.len() != 6 {
eprintln!("Usage: {} <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>", args[0]); eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>", args[0]);
process::exit(1); process::exit(1);
} }
let model_name = &args[1]; let model_type = &args[1];
let unconditional_guidance_scale: f64 = args[2].parse().unwrap_or_else(|_| { let model_name = &args[2];
let unconditional_guidance_scale: f64 = args[3].parse().unwrap_or_else(|_| {
eprintln!("Error: Invalid unconditional guidance scale."); eprintln!("Error: Invalid unconditional guidance scale.");
process::exit(1); process::exit(1);
}); });
let n_steps: usize = args[3].parse().unwrap_or_else(|_| { let n_steps: usize = args[4].parse().unwrap_or_else(|_| {
eprintln!("Error: Invalid number of diffusion steps."); eprintln!("Error: Invalid number of diffusion steps.");
process::exit(1); process::exit(1);
}); });
let prompt = &args[4]; let prompt = &args[5];
let output_image_name = &args[5]; let output_image_name = &args[6];
println!("Loading tokenizer..."); println!("Loading tokenizer...");
let tokenizer = SimpleTokenizer::new().unwrap(); let tokenizer = SimpleTokenizer::new().unwrap();
println!("Loading model..."); println!("Loading model...");
let sd: StableDiffusion<Backend> = load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| { let sd: StableDiffusion<Backend> = if model_type == "burn" {
eprintln!("Error loading model: {}", err); load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| {
process::exit(1); eprintln!("Error loading model: {}", err);
}); process::exit(1);
})
} else {
load_stable_diffusion(model_name, &device).unwrap_or_else(|err| {
eprintln!("Error loading model dump: {}", err);
process::exit(1);
})
};
let sd = sd.to_device(&device); let sd = sd.to_device(&device);
let unconditional_context = sd.unconditional_context(&tokenizer); let unconditional_context = sd.unconditional_context(&tokenizer);