Replace helper functions with native burn functions
This commit is contained in:
@@ -1,26 +1,27 @@
|
||||
use std::env;
|
||||
use std::process;
|
||||
use std::error::Error;
|
||||
use std::process;
|
||||
|
||||
use stablediffusion::model::stablediffusion::{StableDiffusion, load::load_stable_diffusion};
|
||||
use stablediffusion::model::stablediffusion::{load::load_stable_diffusion, StableDiffusion};
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
use burn_ndarray::{NdArrayBackend, NdArrayDevice};
|
||||
|
||||
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
|
||||
use burn::record::{self, BinFileRecorder, FullPrecisionSettings, Recorder};
|
||||
|
||||
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> {
|
||||
fn convert_dump_to_model<B: Backend>(
|
||||
dump_path: &str,
|
||||
model_name: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
println!("Loading dump...");
|
||||
let model: StableDiffusion::<B> = load_stable_diffusion(dump_path, device)?;
|
||||
let model: StableDiffusion<B> = load_stable_diffusion(dump_path, device)?;
|
||||
|
||||
println!("Saving model...");
|
||||
save_model_file(model, model_name)?;
|
||||
@@ -28,12 +29,11 @@ fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device:
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<(), record::RecorderError> {
|
||||
BinFileRecorder::<FullPrecisionSettings>::new()
|
||||
.record(
|
||||
model.into_record(),
|
||||
name.into(),
|
||||
)
|
||||
fn save_model_file<B: Backend>(
|
||||
model: StableDiffusion<B>,
|
||||
name: &str,
|
||||
) -> Result<(), record::RecorderError> {
|
||||
BinFileRecorder::<FullPrecisionSettings>::new().record(model.into_record(), name.into())
|
||||
}
|
||||
|
||||
fn main() {
|
||||
|
||||
Reference in New Issue
Block a user