Files
ComfyUI-Rust/backend/src/rocminfo.rs
Ben_Kosytorz bd198bb8e9 feat: update to burn 0.14.0 with vulkan support and refactor frontend to egui
- Updated burn framework dependencies from 0.21.0-pre.2 to 0.14.0
- Added optional vulkan backend support with burn-wgpu feature
- Replaced React/TypeScript frontend with native Rust egui frontend
- Added Vulkan GPU support documentation and setup instructions
- Updated README with new architecture and backend configuration
- Refactored GPU backend detection and reporting to include backend type
- Added vulkan-backend feature flag for conditional compilation
- Updated dependency installation instructions for Vulkan support
2026-03-03 22:04:45 +01:00

117 lines
3.8 KiB
Rust

//! GPU detection and configuration for multiple backends
//!
//! This module handles automatic detection of GPUs and provides
//! configuration information needed for optimized inference.
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::process::Command;
/// Description of a detected GPU together with the compute backend chosen for it.
///
/// Every field is a plain string; the detection functions in this module fill
/// fields they cannot determine with the literal `"unknown"`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GpuConfig {
    /// Human-readable device name.
    pub name: String,
    /// Architecture identifier (a `gfx…` target for ROCm devices), or `"unknown"`.
    pub architecture: String,
    /// Driver version string, or `"unknown"` when it could not be determined.
    pub driver_version: String,
    /// Backend identifier; this module produces `"rocm"` and `"vulkan"`.
    pub backend: String,
}
/// Detects an AMD GPU by running the `rocminfo` command and parsing its output.
///
/// Fields that cannot be found in the output keep their defaults
/// (`"Unknown AMD GPU"` for the name, `"unknown"` for the architecture and
/// driver version). The returned configuration always has `backend` set to
/// `"rocm"`.
///
/// # Errors
///
/// Returns an error when `rocminfo` cannot be executed, exits with a
/// non-success status, or writes output that is not valid UTF-8.
pub fn detect_amd_gpu() -> Result<GpuConfig, Box<dyn std::error::Error>> {
    let output = Command::new("rocminfo")
        .output()
        .map_err(|e| format!("Failed to execute rocminfo: {}", e))?;

    if !output.status.success() {
        return Err(format!("rocminfo failed with status: {}", output.status).into());
    }

    let output_str = String::from_utf8(output.stdout)
        .map_err(|e| format!("Failed to decode rocminfo output: {}", e))?;

    Ok(parse_rocminfo_output(&output_str))
}

/// Builds a ROCm [`GpuConfig`] from the textual output of `rocminfo`.
///
/// The output is scanned line by line; when several lines match the same
/// field, the last match wins.
fn parse_rocminfo_output(output: &str) -> GpuConfig {
    let mut gpu_name = "Unknown AMD GPU".to_string();
    let mut architecture = "unknown".to_string();
    let mut driver_version = "unknown".to_string();

    for line in output.lines() {
        // `Vendor Name: AMD` also contains "Name:" and "AMD". It is printed
        // after the marketing name, so accepting it would replace the real
        // device name with the bare vendor string "AMD".
        let is_vendor_line = line.contains("Vendor Name:");

        if !is_vendor_line && line.contains("Name:") && line.contains("AMD") {
            if let Some(name) = extract_value(line, "Name:") {
                gpu_name = name;
            }
        } else if line.contains("gfx")
            && (line.contains("Architecture") || line.contains("Compute Unit"))
        {
            // NOTE(review): rocminfo appears to print the gfx target on plain
            // `Name:` lines, which this condition does not match — confirm
            // against real output before relying on `architecture`.
            if let Some(arch) = extract_gfx_architecture(line) {
                architecture = arch;
            }
        } else if line.contains("Driver Version:") {
            if let Some(version) = extract_value(line, "Driver Version:") {
                driver_version = version;
            }
        }
    }

    GpuConfig {
        name: gpu_name,
        architecture,
        driver_version,
        backend: "rocm".to_string(),
    }
}
/// Reports a generic Vulkan GPU configuration.
///
/// No probing is performed yet: the function always succeeds and returns
/// fixed placeholder values. A real implementation would query the system
/// (for example through `vulkaninfo`) to fill in the device details.
pub fn detect_vulkan_gpu() -> Result<GpuConfig, Box<dyn std::error::Error>> {
    // Placeholder description used until real Vulkan detection exists.
    let placeholder = GpuConfig {
        name: String::from("Vulkan GPU"),
        architecture: String::from("unknown"),
        driver_version: String::from("unknown"),
        backend: String::from("vulkan"),
    };
    Ok(placeholder)
}
/// Extracts the value that follows `key` on a `key value` style line.
///
/// Returns the text after the first occurrence of `key` with surrounding
/// whitespace trimmed, or `None` when `key` does not occur in `line`.
/// A line that ends right after `key` yields `Some` of an empty string.
fn extract_value(line: &str, key: &str) -> Option<String> {
    // `split_once` cuts at the first occurrence only, so a value that itself
    // contains `key` is returned whole instead of being truncated, and no
    // intermediate `Vec` is allocated.
    line.split_once(key)
        .map(|(_, value)| value.trim().to_string())
}
/// Helper function to extract GFX architecture from ROCm output
fn extract_gfx_architecture(line: &str) -> Option<String> {
// Look for gfx* patterns in the line
let re = Regex::new(r"gfx\d+").unwrap();
if let Some(captures) = re.captures(line) {
Some(captures.get(0).unwrap().as_str().to_string())
} else {
None
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `GpuConfig` built from literals exposes exactly the values it was given.
    #[test]
    fn test_gpu_config_creation() {
        let GpuConfig {
            name,
            architecture,
            driver_version,
            backend,
        } = GpuConfig {
            name: String::from("Radeon RX 9070 XT"),
            architecture: String::from("gfx900"),
            driver_version: String::from("5.4.3"),
            backend: String::from("rocm"),
        };

        assert_eq!(name, "Radeon RX 9070 XT");
        assert_eq!(architecture, "gfx900");
        assert_eq!(driver_version, "5.4.3");
        assert_eq!(backend, "rocm");
    }
}