feat: migrate to pure Rust backend with WGPU/Vulkan support
- Remove libtorch dependency and related features - Switch to ndarray backend as default for CPU execution - Update README to reflect new WGPU/Vulkan backend usage - Simplify device selection to use only CPU backend - Enable WGPU backend via feature flag for GPU acceleration
This commit is contained in:
@@ -14,7 +14,7 @@ cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "wgpu-backend")] {
|
||||
use burn_wgpu::{Wgpu, WgpuDevice};
|
||||
} else {
|
||||
use burn_tch::{LibTorch, LibTorchDevice};
|
||||
use burn_ndarray::NdArrayDevice;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,24 +61,8 @@ fn main() {
|
||||
type Backend = Wgpu;
|
||||
let device = WgpuDevice::BestAvailable;
|
||||
} else {
|
||||
type Backend = LibTorch<f32>;
|
||||
|
||||
let device = if let Some(dev_str) = device_arg {
|
||||
match dev_str.to_lowercase().as_str() {
|
||||
"cpu" => LibTorchDevice::Cpu,
|
||||
"mps" => LibTorchDevice::Mps,
|
||||
s if s.starts_with("cuda") => {
|
||||
let idx = s[4..].parse().unwrap_or(0);
|
||||
LibTorchDevice::Cuda(idx)
|
||||
}
|
||||
_ => {
|
||||
eprintln!("Unknown device: {}", dev_str);
|
||||
process::exit(1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LibTorchDevice::Cuda(0)
|
||||
};
|
||||
type Backend = burn::backend::ndarray::NdArray<f32>;
|
||||
let device = NdArrayDevice::Cpu;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user