mirror of
https://gitea.hainer-ernst.de/rasmus/burn-stablediffusion-vibecode.git
synced 2026-06-11 02:09:21 +00:00
Add wgpu option
This commit is contained in:
15
Cargo.toml
15
Cargo.toml
@@ -5,9 +5,21 @@ edition = "2021"
|
|||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["torch-backend"]
|
||||||
|
torch-backend = ["burn-tch"]
|
||||||
|
wgpu-backend = ["burn-wgpu"]
|
||||||
|
|
||||||
|
[dependencies.burn-tch]
|
||||||
|
version = "0.8.0"
|
||||||
|
optional = true
|
||||||
|
|
||||||
|
[dependencies.burn-wgpu]
|
||||||
|
version = "0.8.0"
|
||||||
|
optional = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
burn = "0.8.0"
|
burn = "0.8.0"
|
||||||
burn-tch = "0.8.0"
|
|
||||||
serde = {version = "1.0.171", features = ["std", "derive"]}
|
serde = {version = "1.0.171", features = ["std", "derive"]}
|
||||||
npy = "0.4.0"
|
npy = "0.4.0"
|
||||||
num-traits = "0.2.15"
|
num-traits = "0.2.15"
|
||||||
@@ -15,3 +27,4 @@ rust_tokenizers = "8.1.0"
|
|||||||
regex = "1.9.1"
|
regex = "1.9.1"
|
||||||
image = "0.24.6"
|
image = "0.24.6"
|
||||||
bincode = {version = "2.0.0-alpha.0", features = ["std"]}
|
bincode = {version = "2.0.0-alpha.0", features = ["std"]}
|
||||||
|
cfg-if = "0.1"
|
||||||
@@ -20,7 +20,7 @@ Start by downloading the SDv1-4.bin model provided on HuggingFace.
|
|||||||
wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/V1/SDv1-4.bin
|
wget https://huggingface.co/Gadersd/Stable-Diffusion-Burn/resolve/main/V1/SDv1-4.bin
|
||||||
```
|
```
|
||||||
|
|
||||||
Next, set the appropriate CUDA version.
|
Next, set the appropriate CUDA version. It may be possible to run the model using wgpu without the need for torch in the future using `cargo run --features wgpu-backend...` but currently wgpu doesn't support buffer sizes large enough for Stable Diffusion.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export TORCH_CUDA_VERSION=cu113
|
export TORCH_CUDA_VERSION=cu113
|
||||||
|
|||||||
@@ -14,7 +14,13 @@ use burn::{
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
cfg_if::cfg_if! {
|
||||||
|
if #[cfg(feature = "torch-backend")] {
|
||||||
use burn_tch::{TchBackend, TchDevice};
|
use burn_tch::{TchBackend, TchDevice};
|
||||||
|
} else if #[cfg(feature = "wgpu-backend")] {
|
||||||
|
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
use burn::record::{self, Recorder, FullPrecisionSettings};
|
use burn::record::{self, Recorder, FullPrecisionSettings};
|
||||||
use stablediffusion::binrecorderfast::{BinFileRecorderBuffered};
|
use stablediffusion::binrecorderfast::{BinFileRecorderBuffered};
|
||||||
@@ -38,8 +44,15 @@ fn save_model_file<B: Backend>(model: StableDiffusion<B>, name: &str) -> Result<
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
|
cfg_if::cfg_if! {
|
||||||
|
if #[cfg(feature = "torch-backend")] {
|
||||||
type Backend = TchBackend<f32>;
|
type Backend = TchBackend<f32>;
|
||||||
let device = TchDevice::Cpu;
|
let device = TchDevice::Cpu;
|
||||||
|
} else if #[cfg(feature = "wgpu-backend")] {
|
||||||
|
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
|
||||||
|
let device = WgpuDevice::CPU;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let args: Vec<String> = env::args().collect();
|
let args: Vec<String> = env::args().collect();
|
||||||
if args.len() != 3 {
|
if args.len() != 3 {
|
||||||
|
|||||||
@@ -9,7 +9,14 @@ use burn::{
|
|||||||
Tensor,
|
Tensor,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
cfg_if::cfg_if! {
|
||||||
|
if #[cfg(feature = "torch-backend")] {
|
||||||
use burn_tch::{TchBackend, TchDevice};
|
use burn_tch::{TchBackend, TchDevice};
|
||||||
|
} else if #[cfg(feature = "wgpu-backend")] {
|
||||||
|
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
use std::env;
|
use std::env;
|
||||||
use std::io;
|
use std::io;
|
||||||
@@ -25,9 +32,15 @@ fn load_stable_diffusion_model_file<B: Backend>(filename: &str) -> Result<Stable
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
|
cfg_if::cfg_if! {
|
||||||
|
if #[cfg(feature = "torch-backend")] {
|
||||||
type Backend = TchBackend<f32>;
|
type Backend = TchBackend<f32>;
|
||||||
//let device = TchDevice::Cpu;
|
|
||||||
let device = TchDevice::Cuda(0);
|
let device = TchDevice::Cuda(0);
|
||||||
|
} else if #[cfg(feature = "wgpu-backend")] {
|
||||||
|
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
|
||||||
|
let device = WgpuDevice::BestAvailable;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let args: Vec<String> = std::env::args().collect();
|
let args: Vec<String> = std::env::args().collect();
|
||||||
if args.len() != 7 {
|
if args.len() != 7 {
|
||||||
|
|||||||
Reference in New Issue
Block a user