Replace helper functions with native burn functions

This commit is contained in:
Gadersd
2023-09-07 12:23:18 -04:00
committed by Ben_Kosytorz
parent 167e45fc30
commit 32a3ad9b3c
20 changed files with 1091 additions and 950 deletions

View File

@@ -1,26 +1,27 @@
use std::env;
use std::process;
use std::error::Error;
use std::process;
use stablediffusion::model::stablediffusion::{StableDiffusion, load::load_stable_diffusion};
use stablediffusion::model::stablediffusion::{load::load_stable_diffusion, StableDiffusion};
use burn::{
config::Config,
config::Config,
module::{Module, Param},
nn,
tensor::{
backend::Backend,
Tensor,
},
tensor::{backend::Backend, Tensor},
};
use burn_ndarray::{NdArrayBackend, NdArrayDevice};
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
use burn::record::{self, BinFileRecorder, FullPrecisionSettings, Recorder};
fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device: &B::Device) -> Result<(), Box<dyn Error>> {
fn convert_dump_to_model<B: Backend>(
dump_path: &str,
model_name: &str,
device: &B::Device,
) -> Result<(), Box<dyn Error>> {
println!("Loading dump...");
let model: StableDiffusion::<B> = load_stable_diffusion(dump_path, device)?;
let model: StableDiffusion<B> = load_stable_diffusion(dump_path, device)?;
println!("Saving model...");
save_model_file(model, model_name)?;
@@ -28,12 +29,11 @@ fn convert_dump_to_model<B: Backend>(dump_path: &str, model_name: &str, device:
Ok(())
}
fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<(), record::RecorderError> {
BinFileRecorder::<FullPrecisionSettings>::new()
.record(
model.into_record(),
name.into(),
)
fn save_model_file<B: Backend>(
model: StableDiffusion<B>,
name: &str,
) -> Result<(), record::RecorderError> {
BinFileRecorder::<FullPrecisionSettings>::new().record(model.into_record(), name.into())
}
fn main() {