- Updated burn framework dependencies from 0.21.0-pre.2 to 0.14.0
- Added optional Vulkan backend support via the burn-wgpu feature
- Replaced the React/TypeScript frontend with a native Rust egui frontend
- Added Vulkan GPU support documentation and setup instructions
- Updated README with the new architecture and backend configuration
- Refactored GPU backend detection and reporting to include the backend type
- Added vulkan-backend feature flag for conditional compilation
- Updated dependency installation instructions for Vulkan support
117 lines
3.8 KiB
Rust
117 lines
3.8 KiB
Rust
//! GPU detection and configuration for multiple backends.
//!
//! This module handles automatic detection of GPUs and provides the
//! configuration information needed for optimized inference.
|
|
|
|
use regex::Regex;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::process::Command;
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct GpuConfig {
|
|
pub name: String,
|
|
pub architecture: String,
|
|
pub driver_version: String,
|
|
pub backend: String,
|
|
}
|
|
|
|
/// Detect AMD GPU using rocminfo command
|
|
/// Returns configuration information for the detected GPU
|
|
pub fn detect_amd_gpu() -> Result<GpuConfig, Box<dyn std::error::Error>> {
|
|
// Try to run rocminfo command to get GPU info
|
|
let output = Command::new("rocminfo")
|
|
.output()
|
|
.map_err(|e| format!("Failed to execute rocminfo: {}", e))?;
|
|
|
|
if !output.status.success() {
|
|
return Err(format!("rocminfo failed with status: {}", output.status).into());
|
|
}
|
|
|
|
let output_str = String::from_utf8(output.stdout)
|
|
.map_err(|e| format!("Failed to decode rocminfo output: {}", e))?;
|
|
|
|
// Parse the output to extract relevant GPU information
|
|
let mut gpu_name = "Unknown AMD GPU".to_string();
|
|
let mut architecture = "unknown".to_string();
|
|
let mut driver_version = "unknown".to_string();
|
|
|
|
for line in output_str.lines() {
|
|
if line.contains("Name:") && line.contains("AMD") {
|
|
// Extract GPU name
|
|
if let Some(name) = extract_value(line, "Name:") {
|
|
gpu_name = name;
|
|
}
|
|
} else if line.contains("gfx") && (line.contains("Architecture") || line.contains("Compute Unit")) {
|
|
// Extract architecture
|
|
if let Some(arch) = extract_gfx_architecture(line) {
|
|
architecture = arch;
|
|
}
|
|
} else if line.contains("Driver Version:") {
|
|
// Extract driver version
|
|
if let Some(version) = extract_value(line, "Driver Version:") {
|
|
driver_version = version;
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(GpuConfig {
|
|
name: gpu_name,
|
|
architecture,
|
|
driver_version,
|
|
backend: "rocm".to_string(),
|
|
})
|
|
}
|
|
|
|
/// Attempt to detect Vulkan-compatible GPU
|
|
pub fn detect_vulkan_gpu() -> Result<GpuConfig, Box<dyn std::error::Error>> {
|
|
// For now, we'll create a basic Vulkan config - in a real implementation
|
|
// this would involve checking for Vulkan support using vulkaninfo or similar
|
|
Ok(GpuConfig {
|
|
name: "Vulkan GPU".to_string(),
|
|
architecture: "unknown".to_string(),
|
|
driver_version: "unknown".to_string(),
|
|
backend: "vulkan".to_string(),
|
|
})
|
|
}
|
|
|
|
/// Return the trimmed text following the first occurrence of `key` in `line`.
///
/// Everything after the key is kept, so a value that itself contains the key
/// text is not truncated. Returns `None` when `key` does not occur in `line`.
fn extract_value(line: &str, key: &str) -> Option<String> {
    line.split_once(key)
        .map(|(_, value)| value.trim().to_string())
}
|
|
|
|
/// Helper function to extract GFX architecture from ROCm output
|
|
fn extract_gfx_architecture(line: &str) -> Option<String> {
|
|
// Look for gfx* patterns in the line
|
|
let re = Regex::new(r"gfx\d+").unwrap();
|
|
if let Some(captures) = re.captures(line) {
|
|
Some(captures.get(0).unwrap().as_str().to_string())
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Values given to a `GpuConfig` literal are stored unchanged.
    #[test]
    fn test_gpu_config_creation() {
        let (name, architecture, driver_version, backend) =
            ("Radeon RX 9070 XT", "gfx900", "5.4.3", "rocm");

        let config = GpuConfig {
            name: name.to_owned(),
            architecture: architecture.to_owned(),
            driver_version: driver_version.to_owned(),
            backend: backend.to_owned(),
        };

        assert_eq!(config.name, name);
        assert_eq!(config.architecture, architecture);
        assert_eq!(config.driver_version, driver_version);
        assert_eq!(config.backend, backend);
    }
}
|