Add wgpu option

This commit is contained in:
Gadersd
2023-08-06 13:28:16 -04:00
parent 8a76c234e7
commit daba33ebf9
4 changed files with 49 additions and 10 deletions

View File

@@ -9,7 +9,14 @@ use burn::{
Tensor,
},
};
use burn_tch::{TchBackend, TchDevice};
cfg_if::cfg_if! {
if #[cfg(feature = "torch-backend")] {
use burn_tch::{TchBackend, TchDevice};
} else if #[cfg(feature = "wgpu-backend")] {
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
}
}
use std::env;
use std::io;
@@ -25,9 +32,15 @@ fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<Stable
}
fn main() {
type Backend = TchBackend<f32>;
//let device = TchDevice::Cpu;
let device = TchDevice::Cuda(0);
cfg_if::cfg_if! {
if #[cfg(feature = "torch-backend")] {
type Backend = TchBackend<f32>;
let device = TchDevice::Cuda(0);
} else if #[cfg(feature = "wgpu-backend")] {
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
let device = WgpuDevice::BestAvailable;
}
}
let args: Vec<String> = std::env::args().collect();
if args.len() != 7 {