//! GPU detection and configuration for multiple backends //! //! This module handles automatic detection of GPUs and provides //! configuration information needed for optimized inference. use regex::Regex; use serde::{Deserialize, Serialize}; use std::process::Command; #[derive(Debug, Clone, Deserialize, Serialize)] pub struct GpuConfig { pub name: String, pub architecture: String, pub driver_version: String, pub backend: String, } /// Detect AMD GPU using rocminfo command /// Returns configuration information for the detected GPU pub fn detect_amd_gpu() -> Result> { // Try to run rocminfo command to get GPU info let output = Command::new("rocminfo") .output() .map_err(|e| format!("Failed to execute rocminfo: {}", e))?; if !output.status.success() { return Err(format!("rocminfo failed with status: {}", output.status).into()); } let output_str = String::from_utf8(output.stdout) .map_err(|e| format!("Failed to decode rocminfo output: {}", e))?; // Parse the output to extract relevant GPU information let mut gpu_name = "Unknown AMD GPU".to_string(); let mut architecture = "unknown".to_string(); let mut driver_version = "unknown".to_string(); for line in output_str.lines() { if line.contains("Name:") && line.contains("AMD") { // Extract GPU name if let Some(name) = extract_value(line, "Name:") { gpu_name = name; } } else if line.contains("gfx") && (line.contains("Architecture") || line.contains("Compute Unit")) { // Extract architecture if let Some(arch) = extract_gfx_architecture(line) { architecture = arch; } } else if line.contains("Driver Version:") { // Extract driver version if let Some(version) = extract_value(line, "Driver Version:") { driver_version = version; } } } Ok(GpuConfig { name: gpu_name, architecture, driver_version, backend: "rocm".to_string(), }) } /// Attempt to detect Vulkan-compatible GPU pub fn detect_vulkan_gpu() -> Result> { // For now, we'll create a basic Vulkan config - in a real implementation // this would involve checking for Vulkan support using vulkaninfo or similar Ok(GpuConfig { name: "Vulkan GPU".to_string(), architecture: "unknown".to_string(), driver_version: "unknown".to_string(), backend: "vulkan".to_string(), }) } /// Helper function to extract values from key-value lines fn extract_value(line: &str, key: &str) -> Option { let parts: Vec<&str> = line.split(key).collect(); if parts.len() >= 2 { Some(parts[1].trim().to_string()) } else { None } } /// Helper function to extract GFX architecture from ROCm output fn extract_gfx_architecture(line: &str) -> Option { // Look for gfx* patterns in the line let re = Regex::new(r"gfx\d+").unwrap(); if let Some(captures) = re.captures(line) { Some(captures.get(0).unwrap().as_str().to_string()) } else { None } } #[cfg(test)] mod tests { use super::*; #[test] fn test_gpu_config_creation() { let config = GpuConfig { name: "Radeon RX 9070 XT".to_string(), architecture: "gfx900".to_string(), driver_version: "5.4.3".to_string(), backend: "rocm".to_string(), }; assert_eq!(config.name, "Radeon RX 9070 XT"); assert_eq!(config.architecture, "gfx900"); assert_eq!(config.driver_version, "5.4.3"); assert_eq!(config.backend, "rocm"); } }