diff --git a/src/bin/sample/main.rs b/src/bin/sample/main.rs index 69f9326..14b4169 100644 --- a/src/bin/sample/main.rs +++ b/src/bin/sample/main.rs @@ -1,4 +1,4 @@ -use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::*}; +use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::{*, load::load_stable_diffusion}}; use burn::{ config::Config, @@ -30,30 +30,39 @@ fn main() { let args: Vec = std::env::args().collect(); if args.len() != 6 { - eprintln!("Usage: {} ", args[0]); + eprintln!("Usage: {} ", args[0]); process::exit(1); } - let model_name = &args[1]; - let unconditional_guidance_scale: f64 = args[2].parse().unwrap_or_else(|_| { + let model_type = &args[1]; + let model_name = &args[2]; + let unconditional_guidance_scale: f64 = args[3].parse().unwrap_or_else(|_| { eprintln!("Error: Invalid unconditional guidance scale."); process::exit(1); }); - let n_steps: usize = args[3].parse().unwrap_or_else(|_| { + let n_steps: usize = args[4].parse().unwrap_or_else(|_| { eprintln!("Error: Invalid number of diffusion steps."); process::exit(1); }); - let prompt = &args[4]; - let output_image_name = &args[5]; - + let prompt = &args[5]; + let output_image_name = &args[6]; println!("Loading tokenizer..."); let tokenizer = SimpleTokenizer::new().unwrap(); println!("Loading model..."); - let sd: StableDiffusion = load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| { - eprintln!("Error loading model: {}", err); - process::exit(1); - }); + let sd: StableDiffusion = if model_type == "burn" { + load_stable_diffusion_model_file(model_name).unwrap_or_else(|err| { + eprintln!("Error loading model: {}", err); + process::exit(1); + }) + } else { + load_stable_diffusion(model_name, &device).unwrap_or_else(|err| { + eprintln!("Error loading model dump: {}", err); + process::exit(1); + }) + }; + + let sd = sd.to_device(&device); let unconditional_context = sd.unconditional_context(&tokenizer);