Use wgpu by default and ndarray for convert
This commit is contained in:
@@ -14,13 +14,7 @@ use burn::{
|
||||
},
|
||||
};
|
||||
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "torch-backend")] {
|
||||
use burn_tch::{TchBackend, TchDevice};
|
||||
} else if #[cfg(feature = "wgpu-backend")] {
|
||||
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
|
||||
}
|
||||
}
|
||||
use burn_ndarray::{NdArrayBackend, NdArrayDevice};
|
||||
|
||||
use burn::record::{self, Recorder, BinFileRecorder, FullPrecisionSettings};
|
||||
|
||||
@@ -43,15 +37,8 @@ fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<
|
||||
}
|
||||
|
||||
fn main() {
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "torch-backend")] {
|
||||
type Backend = TchBackend<f32>;
|
||||
let device = TchDevice::Cpu;
|
||||
} else if #[cfg(feature = "wgpu-backend")] {
|
||||
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
|
||||
let device = WgpuDevice::CPU;
|
||||
}
|
||||
}
|
||||
type Backend = NdArrayBackend<f32>;
|
||||
let device = NdArrayDevice::Cpu;
|
||||
|
||||
let args: Vec<String> = env::args().collect();
|
||||
if args.len() != 3 {
|
||||
|
||||
@@ -74,11 +74,11 @@ fn main() {
|
||||
process::exit(1);
|
||||
})
|
||||
};
|
||||
|
||||
|
||||
let sd = sd.to_device(&device);
|
||||
|
||||
let unconditional_context = sd.unconditional_context(&tokenizer);
|
||||
let context = sd.context(&tokenizer, prompt).unsqueeze().repeat(0, 2); // generate 2 samples
|
||||
let context = sd.context(&tokenizer, prompt).unsqueeze::<3>();//.repeat(0, 2); // generate 2 samples
|
||||
|
||||
println!("Sampling image...");
|
||||
let images = sd.sample_image(context, unconditional_context, unconditional_guidance_scale, n_steps);
|
||||
|
||||
Reference in New Issue
Block a user