Use wgpu by default and ndarray for convert

This commit is contained in:
Gadersd
2023-08-08 15:32:21 -04:00
committed by Ben_Kosytorz
parent 0101e8f930
commit d4afd71fda
5 changed files with 20 additions and 26 deletions

View File

@@ -14,13 +14,7 @@ use burn::{
},
};
cfg_if::cfg_if! {
if #[cfg(feature = "torch-backend")] {
use burn_tch::{TchBackend, TchDevice};
} else if #[cfg(feature = "wgpu-backend")] {
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
}
}
use burn_ndarray::{NdArrayBackend, NdArrayDevice};
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
@@ -43,15 +37,8 @@ fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<
}
fn main() {
cfg_if::cfg_if! {
if #[cfg(feature = "torch-backend")] {
type Backend = TchBackend<f32>;
let device = TchDevice::Cpu;
} else if #[cfg(feature = "wgpu-backend")] {
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
let device = WgpuDevice::CPU;
}
}
type Backend = NdArrayBackend<f32>;
let device = NdArrayDevice::Cpu;
let args: Vec<String> = env::args().collect();
if args.len() != 3 {

View File

@@ -74,11 +74,11 @@ fn main() {
process::exit(1);
})
};
let sd = sd.to_device(&device);
let unconditional_context = sd.unconditional_context(&tokenizer);
let context = sd.context(&tokenizer, prompt).unsqueeze().repeat(0, 2); // generate 2 samples
let context = sd.context(&tokenizer, prompt).unsqueeze::<3>();//.repeat(0, 2); // generate 2 samples
println!("Sampling image...");
let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps);

View File

@@ -59,6 +59,11 @@ impl<B: Backend> StableDiffusion<B> {
let [n_batch, _, _] = context.dims();
let latent = self.sample_latent(context, unconditional_context, unconditional_guidance_scale, n_steps);
self.latent_to_image(latent)
}
pub fn latent_to_image(&self, latent: Tensor<B, 4>) -> Vec<Vec<u8>> {
let [n_batch, _, _, _] = latent.dims();
let image = self.autoencoder.decode_latent(latent * (1.0 / 0.18215));
let n_channel = 3;
@@ -157,7 +162,7 @@ impl<B: Backend> StableDiffusion<B> {
}
pub fn context(&self, tokenizer: &SimpleTokenizer, text: &str) -> Tensor<B, 3> {
let device = &self.devices()[0];
let device = &self.clip.devices()[0];
let text = format!("<|startoftext|>{}<|endoftext|>", text);
let tokenized: Vec<_> = tokenizer.encode(&text).into_iter().map(|v| v as i32).collect();