vim
This commit is contained in:
100
backend/src/rocminfo.rs
Normal file
100
backend/src/rocminfo.rs
Normal file
@@ -0,0 +1,100 @@
|
||||
//! ROCm GPU detection and configuration for AMD GPUs
|
||||
//!
|
||||
//! This module handles automatic detection of AMD GPUs and provides
|
||||
//! configuration information needed for optimized inference on RX 9070 XT.
|
||||
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Command;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct GpuConfig {
|
||||
pub name: String,
|
||||
pub architecture: String,
|
||||
pub driver_version: String,
|
||||
}
|
||||
|
||||
/// Detect AMD GPU using rocminfo command
|
||||
/// Returns configuration information for the detected GPU
|
||||
pub fn detect_amd_gpu() -> Result<GpuConfig, Box<dyn std::error::Error>> {
|
||||
// Try to run rocminfo command to get GPU info
|
||||
let output = Command::new("rocminfo")
|
||||
.output()
|
||||
.map_err(|e| format!("Failed to execute rocminfo: {}", e))?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Err(format!("rocminfo failed with status: {}", output.status).into());
|
||||
}
|
||||
|
||||
let output_str = String::from_utf8(output.stdout)
|
||||
.map_err(|e| format!("Failed to decode rocminfo output: {}", e))?;
|
||||
|
||||
// Parse the output to extract relevant GPU information
|
||||
let mut gpu_name = "Unknown AMD GPU".to_string();
|
||||
let mut architecture = "unknown".to_string();
|
||||
let mut driver_version = "unknown".to_string();
|
||||
|
||||
for line in output_str.lines() {
|
||||
if line.contains("Name:") && line.contains("AMD") {
|
||||
// Extract GPU name
|
||||
if let Some(name) = extract_value(line, "Name:") {
|
||||
gpu_name = name;
|
||||
}
|
||||
} else if line.contains("gfx") && (line.contains("Architecture") || line.contains("Compute Unit")) {
|
||||
// Extract architecture
|
||||
if let Some(arch) = extract_gfx_architecture(line) {
|
||||
architecture = arch;
|
||||
}
|
||||
} else if line.contains("Driver Version:") {
|
||||
// Extract driver version
|
||||
if let Some(version) = extract_value(line, "Driver Version:") {
|
||||
driver_version = version;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(GpuConfig {
|
||||
name: gpu_name,
|
||||
architecture,
|
||||
driver_version,
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper function to extract values from key-value lines
|
||||
fn extract_value(line: &str, key: &str) -> Option<String> {
|
||||
let parts: Vec<&str> = line.split(key).collect();
|
||||
if parts.len() >= 2 {
|
||||
Some(parts[1].trim().to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to extract GFX architecture from ROCm output
|
||||
fn extract_gfx_architecture(line: &str) -> Option<String> {
|
||||
// Look for gfx* patterns in the line
|
||||
let re = Regex::new(r"gfx\d+").unwrap();
|
||||
if let Some(captures) = re.captures(line) {
|
||||
Some(captures.get(0).unwrap().as_str().to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_gpu_config_creation() {
|
||||
let config = GpuConfig {
|
||||
name: "Radeon RX 9070 XT".to_string(),
|
||||
architecture: "gfx900".to_string(),
|
||||
driver_version: "5.4.3".to_string(),
|
||||
};
|
||||
|
||||
assert_eq!(config.name, "Radeon RX 9070 XT");
|
||||
assert_eq!(config.architecture, "gfx900");
|
||||
assert_eq!(config.driver_version, "5.4.3");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user