This commit is contained in:
2026-03-02 23:06:24 +01:00
parent 27a13e3084
commit c1d1fc94ba
18 changed files with 1616 additions and 2 deletions

100
backend/src/rocminfo.rs Normal file
View File

@@ -0,0 +1,100 @@
//! ROCm GPU detection and configuration for AMD GPUs
//!
//! This module handles automatic detection of AMD GPUs and provides
//! configuration information needed for optimized inference on RX 9070 XT.
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::process::Command;
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GpuConfig {
pub name: String,
pub architecture: String,
pub driver_version: String,
}
/// Detect AMD GPU using rocminfo command
/// Returns configuration information for the detected GPU
pub fn detect_amd_gpu() -> Result<GpuConfig, Box<dyn std::error::Error>> {
// Try to run rocminfo command to get GPU info
let output = Command::new("rocminfo")
.output()
.map_err(|e| format!("Failed to execute rocminfo: {}", e))?;
if !output.status.success() {
return Err(format!("rocminfo failed with status: {}", output.status).into());
}
let output_str = String::from_utf8(output.stdout)
.map_err(|e| format!("Failed to decode rocminfo output: {}", e))?;
// Parse the output to extract relevant GPU information
let mut gpu_name = "Unknown AMD GPU".to_string();
let mut architecture = "unknown".to_string();
let mut driver_version = "unknown".to_string();
for line in output_str.lines() {
if line.contains("Name:") && line.contains("AMD") {
// Extract GPU name
if let Some(name) = extract_value(line, "Name:") {
gpu_name = name;
}
} else if line.contains("gfx") && (line.contains("Architecture") || line.contains("Compute Unit")) {
// Extract architecture
if let Some(arch) = extract_gfx_architecture(line) {
architecture = arch;
}
} else if line.contains("Driver Version:") {
// Extract driver version
if let Some(version) = extract_value(line, "Driver Version:") {
driver_version = version;
}
}
}
Ok(GpuConfig {
name: gpu_name,
architecture,
driver_version,
})
}
/// Helper function to extract values from key-value lines
fn extract_value(line: &str, key: &str) -> Option<String> {
let parts: Vec<&str> = line.split(key).collect();
if parts.len() >= 2 {
Some(parts[1].trim().to_string())
} else {
None
}
}
/// Helper function to extract GFX architecture from ROCm output
fn extract_gfx_architecture(line: &str) -> Option<String> {
// Look for gfx* patterns in the line
let re = Regex::new(r"gfx\d+").unwrap();
if let Some(captures) = re.captures(line) {
Some(captures.get(0).unwrap().as_str().to_string())
} else {
None
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_gpu_config_creation() {
let config = GpuConfig {
name: "Radeon RX 9070 XT".to_string(),
architecture: "gfx900".to_string(),
driver_version: "5.4.3".to_string(),
};
assert_eq!(config.name, "Radeon RX 9070 XT");
assert_eq!(config.architecture, "gfx900");
assert_eq!(config.driver_version, "5.4.3");
}
}