mirror of
https://gitea.hainer-ernst.de/rasmus/burn-stablediffusion-vibecode.git
synced 2026-06-11 02:09:21 +00:00
Replace helper functions with native burn functions
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
use stablediffusion::{tokenizer::SimpleTokenizer, model::stablediffusion::{*, load::load_stable_diffusion}};
|
||||
use stablediffusion::{
|
||||
model::stablediffusion::{load::load_stable_diffusion, *},
|
||||
tokenizer::SimpleTokenizer,
|
||||
};
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
cfg_if::cfg_if! {
|
||||
@@ -22,12 +22,14 @@ use std::env;
|
||||
use std::io;
|
||||
use std::process;
|
||||
|
||||
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
|
||||
use burn::record::{self, BinFileRecorder, FullPrecisionSettings, Recorder};
|
||||
|
||||
fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<StableDiffusion<B>, record::RecorderError> {
|
||||
fn load_stable_diffusion_model_file<B: Backend>(
|
||||
filename: &str,
|
||||
) -> Result<StableDiffusion<B>, record::RecorderError> {
|
||||
BinFileRecorder::<FullPrecisionSettings>::new()
|
||||
.load(filename.into())
|
||||
.map(|record| StableDiffusionConfig::new().init().load_record(record))
|
||||
.load(filename.into())
|
||||
.map(|record| StableDiffusionConfig::new().init().load_record(record))
|
||||
}
|
||||
|
||||
fn main() {
|
||||
@@ -78,17 +80,22 @@ fn main() {
|
||||
let sd = sd.to_device(&device);
|
||||
|
||||
let unconditional_context = sd.unconditional_context(&tokenizer);
|
||||
let context = sd.context(&tokenizer, prompt).unsqueeze::<3>();//.repeat(0, 2); // generate 2 samples
|
||||
let context = sd.context(&tokenizer, prompt).unsqueeze::<3>(); //.repeat(0, 2); // generate 2 samples
|
||||
|
||||
println!("Sampling image...");
|
||||
let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps);
|
||||
let images = sd.sample_image(
|
||||
context,
|
||||
unconditional_context,
|
||||
unconditional_guidance_scale,
|
||||
n_steps,
|
||||
);
|
||||
save_images(&images, output_image_name, 512, 512).unwrap_or_else(|err| {
|
||||
eprintln!("Error saving image: {}", err);
|
||||
process::exit(1);
|
||||
});
|
||||
}
|
||||
|
||||
use image::{self, ImageResult, ColorType::Rgb8};
|
||||
use image::{self, ColorType::Rgb8, ImageResult};
|
||||
|
||||
fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -> ImageResult<()> {
|
||||
for (index, img_data) in images.iter().enumerate() {
|
||||
@@ -103,12 +110,15 @@ fn save_images(images: &Vec<Vec<u8>>, basepath: &str, width: u32, height: u32) -
|
||||
fn save_test_image() -> ImageResult<()> {
|
||||
let width = 256;
|
||||
let height = 256;
|
||||
let raw: Vec<_> = (0..width * height).into_iter().flat_map(|i| {
|
||||
let row = i / width;
|
||||
let red = (255.0 * row as f64 / height as f64) as u8;
|
||||
let raw: Vec<_> = (0..width * height)
|
||||
.into_iter()
|
||||
.flat_map(|i| {
|
||||
let row = i / width;
|
||||
let red = (255.0 * row as f64 / height as f64) as u8;
|
||||
|
||||
[red, 0, 0]
|
||||
}).collect();
|
||||
[red, 0, 0]
|
||||
})
|
||||
.collect();
|
||||
|
||||
image::save_buffer("red.png", &raw[..], width, height, Rgb8)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user