use std::env; use std::process; use std::error::Error; use stablediffusion::model::stablediffusion::{StableDiffusion, load::load_stable_diffusion}; use burn::{ config::Config, module::{Module, Param}, nn, tensor::{ backend::Backend, Tensor, }, }; cfg_if::cfg_if! { if #[cfg(feature = "torch-backend")] { use burn_tch::{TchBackend, TchDevice}; } else if #[cfg(feature = "wgpu-backend")] { use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi}; } } use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings}; fn convert_dump_to_model(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box> { println!("Loading dump..."); let model: StableDiffusion:: = load_stable_diffusion(dump_path, device)?; println!("Saving model..."); save_model_file(model, model_name)?; Ok(()) } fn save_model_file(model: StableDiffusion, name: &str) -> Result<(), record::RecorderError> { BinFileRecorder::::new() .record( model.into_record(), name.into(), ) } fn main() { cfg_if::cfg_if! { if #[cfg(feature = "torch-backend")] { type Backend = TchBackend; let device = TchDevice::Cpu; } else if #[cfg(feature = "wgpu-backend")] { type Backend = WgpuBackend; let device = WgpuDevice::CPU; } } let args: Vec = env::args().collect(); if args.len() != 3 { eprintln!("Usage: {} ", args[0]); process::exit(1); } let dump_path = &args[1]; let model_name = &args[2]; if let Err(e) = convert_dump_to_model::(dump_path, model_name, &device) { eprintln!("Failed to convert dump to model: {:?}", e); process::exit(1); } println!("Successfully converted {} to {}", dump_path, model_name); }