mirror of
https://gitea.hainer-ernst.de/rasmus/burn-stablediffusion-vibecode.git
synced 2026-06-11 02:09:21 +00:00
Add wgpu option
This commit is contained in:
@@ -14,7 +14,13 @@ use burn::{
|
||||
},
|
||||
};
|
||||
|
||||
use burn_tch::{TchBackend, TchDevice};
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "torch-backend")] {
|
||||
use burn_tch::{TchBackend, TchDevice};
|
||||
} else if #[cfg(feature = "wgpu-backend")] {
|
||||
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
|
||||
}
|
||||
}
|
||||
|
||||
use burn::record::{self, Recorder, FullPrecisionSettings};
|
||||
use stablediffusion::binrecorderfast::{BinFileRecorderBuffered};
|
||||
@@ -38,8 +44,15 @@ fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<
|
||||
}
|
||||
|
||||
fn main() {
|
||||
type Backend = TchBackend<f32>;
|
||||
let device = TchDevice::Cpu;
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "torch-backend")] {
|
||||
type Backend = TchBackend<f32>;
|
||||
let device = TchDevice::Cpu;
|
||||
} else if #[cfg(feature = "wgpu-backend")] {
|
||||
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
|
||||
let device = WgpuDevice::CPU;
|
||||
}
|
||||
}
|
||||
|
||||
let args: Vec<String> = env::args().collect();
|
||||
if args.len() != 3 {
|
||||
|
||||
@@ -9,7 +9,14 @@ use burn::{
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
use burn_tch::{TchBackend, TchDevice};
|
||||
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "torch-backend")] {
|
||||
use burn_tch::{TchBackend, TchDevice};
|
||||
} else if #[cfg(feature = "wgpu-backend")] {
|
||||
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
|
||||
}
|
||||
}
|
||||
|
||||
use std::env;
|
||||
use std::io;
|
||||
@@ -25,9 +32,15 @@ fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<Stable
|
||||
}
|
||||
|
||||
fn main() {
|
||||
type Backend = TchBackend<f32>;
|
||||
//let device = TchDevice::Cpu;
|
||||
let device = TchDevice::Cuda(0);
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "torch-backend")] {
|
||||
type Backend = TchBackend<f32>;
|
||||
let device = TchDevice::Cuda(0);
|
||||
} else if #[cfg(feature = "wgpu-backend")] {
|
||||
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
|
||||
let device = WgpuDevice::BestAvailable;
|
||||
}
|
||||
}
|
||||
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() != 7 {
|
||||
|
||||
Reference in New Issue
Block a user