Add wgpu option

This commit is contained in:
Gadersd
2023-08-06 13:28:16 -04:00
committed by Ben_Kosytorz
parent 1373918c2e
commit c596eab0e3
4 changed files with 49 additions and 10 deletions

View File

@@ -14,7 +14,13 @@ use burn::{
},
};
use burn_tch::{TchBackend, TchDevice};
cfg_if::cfg_if! {
if #[cfg(feature = "torch-backend")] {
use burn_tch::{TchBackend, TchDevice};
} else if #[cfg(feature = "wgpu-backend")] {
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
}
}
use burn::record::{self, Recorder, FullPrecisionSettings};
use stablediffusion::binrecorderfast::{BinFileRecorderBuffered};
@@ -38,8 +44,15 @@ fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<
}
fn main() {
type Backend = TchBackend<f32>;
let device = TchDevice::Cpu;
cfg_if::cfg_if! {
if #[cfg(feature = "torch-backend")] {
type Backend = TchBackend<f32>;
let device = TchDevice::Cpu;
} else if #[cfg(feature = "wgpu-backend")] {
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
let device = WgpuDevice::CPU;
}
}
let args: Vec<String> = env::args().collect();
if args.len() != 3 {