Add ability to load dump or burn model in sample binary
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::*};
|
use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::{*, load::load_stable_diffusion}};
|
||||||
|
|
||||||
use burn::{
|
use burn::{
|
||||||
config::Config,
|
config::Config,
|
||||||
@@ -30,30 +30,39 @@ fn main() {
|
|||||||
|
|
||||||
let args: Vec<String> = std::env::args().collect();
|
let args: Vec<String> = std::env::args().collect();
|
||||||
if args.len() != 6 {
|
if args.len() != 6 {
|
||||||
eprintln!("Usage: {} <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>", args[0]);
|
eprintln!("Usage: {} <model_type(burn or dump)> <model_name> <unconditional_guidance_scale> <n_diffusion_steps> <prompt> <output_image_name>", args[0]);
|
||||||
process::exit(1);
|
process::exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
let model_name = &args[1];
|
let model_type = &args[1];
|
||||||
let unconditional_guidance_scale: f64 = args[2].parse().unwrap_or_else(|_| {
|
let model_name = &args[2];
|
||||||
|
let unconditional_guidance_scale: f64 = args[3].parse().unwrap_or_else(|_| {
|
||||||
eprintln!("Error: Invalid unconditional guidance scale.");
|
eprintln!("Error: Invalid unconditional guidance scale.");
|
||||||
process::exit(1);
|
process::exit(1);
|
||||||
});
|
});
|
||||||
let n_steps: usize = args[3].parse().unwrap_or_else(|_| {
|
let n_steps: usize = args[4].parse().unwrap_or_else(|_| {
|
||||||
eprintln!("Error: Invalid number of diffusion steps.");
|
eprintln!("Error: Invalid number of diffusion steps.");
|
||||||
process::exit(1);
|
process::exit(1);
|
||||||
});
|
});
|
||||||
let prompt = &args[4];
|
let prompt = &args[5];
|
||||||
let output_image_name = &args[5];
|
let output_image_name = &args[6];
|
||||||
|
|
||||||
|
|
||||||
println!("Loading tokenizer...");
|
println!("Loading tokenizer...");
|
||||||
let tokenizer = SimpleTokenizer::new().unwrap();
|
let tokenizer = SimpleTokenizer::new().unwrap();
|
||||||
println!("Loading model...");
|
println!("Loading model...");
|
||||||
let sd: StableDiffusion<Backend> = load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| {
|
let sd: StableDiffusion<Backend> = if model_type == "burn" {
|
||||||
eprintln!("Error loading model: {}", err);
|
load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| {
|
||||||
process::exit(1);
|
eprintln!("Error loading model: {}", err);
|
||||||
});
|
process::exit(1);
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
load_stable_diffusion(model_name, &device).unwrap_or_else(|err| {
|
||||||
|
eprintln!("Error loading model dump: {}", err);
|
||||||
|
process::exit(1);
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
let sd = sd.to_device(&device);
|
let sd = sd.to_device(&device);
|
||||||
|
|
||||||
let unconditional_context = sd.unconditional_context(&tokenizer);
|
let unconditional_context = sd.unconditional_context(&tokenizer);
|
||||||
|
|||||||
Reference in New Issue
Block a user