This commit is contained in:
2026-03-02 23:06:24 +01:00
parent 27a13e3084
commit c1d1fc94ba
18 changed files with 1616 additions and 2 deletions

20
backend/Cargo.toml Normal file
View File

@@ -0,0 +1,20 @@
# Crate manifest for the backend HTTP service.
[package]
name = "comfyui-backend"
version = "0.1.0"
edition = "2021"
[dependencies]
# HTTP framework used for the REST API.
# NOTE(review): the "openssl" feature is enabled, but src/main.rs binds a plain
# (non-TLS) listener — confirm the feature is actually needed.
actix-web = { version = "4", features = ["openssl"] }
# Derive-based (de)serialization of request/response payloads and task records.
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
# Async runtime and synchronization primitives (Mutex, Notify).
tokio = { version = "1", features = ["full"] }
# NOTE(review): no source file visible in this commit view references rayon — confirm usage.
rayon = "1.5"
# v4 UUIDs are used as task identifiers.
uuid = { version = "1.0", features = ["v4"] }
log = "0.4"
env_logger = "0.10"
thiserror = "1.0"
# Used by src/rocminfo.rs to pick the gfx architecture token out of rocminfo output.
regex = "1.0"
# Source of the Unix timestamps stored on tasks.
chrono = { version = "0.4", features = ["serde"] }
[dev-dependencies]
tokio-test = "0.4"

168
backend/src/api/mod.rs Normal file
View File

@@ -0,0 +1,168 @@
//! REST API endpoints for ComfyUI-like backend
//!
//! This module defines all HTTP endpoints for the AI generation system,
//! including model management, inference requests, and task status monitoring.
use actix_web::{web, HttpResponse, Result, Scope};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::{
AppState,
models::ModelInfo,
queue_service::{Task, TaskStatus},
};
/// Request payload for image generation.
///
/// Deserialized from the JSON body of the inference endpoint. Only `prompt`
/// is mandatory; every `Option` field may be omitted by the client.
#[derive(Debug, Deserialize)]
pub struct InferenceRequest {
/// Text prompt describing the desired output (required).
pub prompt: String,
/// Optional negative prompt; forwarded unchanged into the task parameters.
pub negative_prompt: Option<String>,
/// Guidance scale; `start_inference` substitutes `7.0` when absent.
pub guidance_scale: Option<f32>,
/// Number of steps; `start_inference` substitutes `20` when absent.
pub steps: Option<u32>,
/// Output width (presumably pixels — confirm); defaults to `512` when absent.
pub width: Option<u32>,
/// Output height (presumably pixels — confirm); defaults to `512` when absent.
pub height: Option<u32>,
/// Optional seed; forwarded unchanged (serialized as `null` when absent).
pub seed: Option<u64>,
/// Optional name of the model to use; stored on the created task.
pub model_name: Option<String>,
}
/// Response payload returned after an inference task has been queued.
#[derive(Debug, Serialize)]
pub struct InferenceResponse {
/// Identifier of the queued task; usable with the `/tasks/{task_id}` route.
pub task_id: String,
/// Status of the task at the time of the response, as a lowercase string.
pub status: String,
}
/// Response payload describing one model known to the model manager.
///
/// Mirrors `ModelInfo`, except that the model type is flattened to its
/// string form.
#[derive(Debug, Serialize)]
pub struct ModelInfoResponse {
/// Name the model is registered under.
pub name: String,
/// Path recorded for the model.
pub path: String,
/// String form of the model type (its `Display` output).
pub model_type: String,
/// Version string recorded for the model.
pub version: String,
/// Whether the model is flagged as loaded.
pub loaded: bool,
}
/// Liveness probe.
///
/// Always answers `200 OK` with a small JSON document containing a fixed
/// `status` of `"healthy"` and the service name.
pub async fn health_check() -> Result<HttpResponse> {
    let body = serde_json::json!({
        "status": "healthy",
        "service": "comfyui-backend"
    });
    let response = HttpResponse::Ok().json(body);
    Ok(response)
}
/// Report the GPU configuration stored in the application state.
///
/// Answers `200 OK` with a JSON object holding a `gpu` section (name,
/// architecture, driver version) and the service name.
pub async fn get_system_info(state: web::Data<AppState>) -> Result<HttpResponse> {
    let gpu_section = serde_json::json!({
        "name": state.gpu_backend_config.name,
        "architecture": state.gpu_backend_config.architecture,
        "driver_version": state.gpu_backend_config.driver_version
    });
    let body = serde_json::json!({
        "gpu": gpu_section,
        "service": "comfyui-backend"
    });
    Ok(HttpResponse::Ok().json(body))
}
/// Queue a new inference task.
///
/// Builds a [`Task`] in the `Pending` state from the request, registers it in
/// the shared task queue, and answers `200 OK` with the generated task id and
/// the task's initial status. No inference is executed here; the task is only
/// recorded.
///
/// Omitted request fields fall back to: guidance scale `7.0`, `20` steps and
/// a `512`x`512` output size. `negative_prompt` and `seed` are forwarded as
/// given (`null` when absent).
pub async fn start_inference(
    req: web::Json<InferenceRequest>,
    state: web::Data<AppState>,
) -> Result<HttpResponse> {
    let task_id = uuid::Uuid::new_v4().to_string();

    // Read the clock once so a brand-new task always has
    // `created_at == updated_at`, even across a second boundary.
    let now = chrono::Utc::now().timestamp() as u64;

    // Derive the reported status from the enum so the response cannot drift
    // from what is stored on the task.
    let initial_status = TaskStatus::Pending;
    let status_label = initial_status.to_string();

    let task = Task {
        id: task_id.clone(),
        name: "Image Generation".to_string(),
        status: initial_status,
        progress: 0.0,
        created_at: now,
        updated_at: now,
        model_name: req.model_name.clone(),
        parameters: serde_json::json!({
            "prompt": req.prompt,
            "negative_prompt": req.negative_prompt,
            "guidance_scale": req.guidance_scale.unwrap_or(7.0),
            "steps": req.steps.unwrap_or(20),
            "width": req.width.unwrap_or(512),
            "height": req.height.unwrap_or(512),
            "seed": req.seed,
        }),
    };

    // Keep the queue lock only for the duration of the insertion.
    {
        let mut queue = state.task_queue.lock().await;
        queue.add_task(task).await;
    }

    Ok(HttpResponse::Ok().json(InferenceResponse {
        task_id,
        status: status_label,
    }))
}
/// List every model known to the model manager.
///
/// Answers `200 OK` with a JSON array of [`ModelInfoResponse`] entries, one
/// per registered model, with the model type rendered as a string.
pub async fn get_models(state: web::Data<AppState>) -> Result<HttpResponse> {
    let manager = state.model_manager.lock().await;

    let mut payload: Vec<ModelInfoResponse> = Vec::new();
    for info in manager.get_model_info() {
        payload.push(ModelInfoResponse {
            name: info.name,
            path: info.path,
            model_type: info.model_type.to_string(),
            version: info.version,
            loaded: info.loaded,
        });
    }

    Ok(HttpResponse::Ok().json(payload))
}
/// Look up a single task by the id given in the URL path.
///
/// Answers `200 OK` with the task's id, name, status, progress, timestamps
/// and model name when the task exists, or `404 Not Found` with an `error`
/// message otherwise.
pub async fn get_task_status(
    task_id: web::Path<String>,
    state: web::Data<AppState>,
) -> Result<HttpResponse> {
    let queue = state.task_queue.lock().await;

    let response = match queue.get_task(&task_id).await {
        Some(found) => HttpResponse::Ok().json(serde_json::json!({
            "id": found.id,
            "name": found.name,
            "status": found.status.to_string(),
            "progress": found.progress,
            "created_at": found.created_at,
            "updated_at": found.updated_at,
            "model_name": found.model_name,
        })),
        None => HttpResponse::NotFound().json(serde_json::json!({
            "error": "Task not found"
        })),
    };

    Ok(response)
}
/// Return every task currently held by the queue.
///
/// Answers `200 OK` with a JSON array of the serialized tasks.
pub async fn get_all_tasks(state: web::Data<AppState>) -> Result<HttpResponse> {
    let guard = state.task_queue.lock().await;
    let snapshot = guard.get_all_tasks().await;
    let response = HttpResponse::Ok().json(snapshot);
    Ok(response)
}
/// Configuration for API routes
pub fn config(cfg: &mut Scope) {
cfg.route("/health", web::get().to(health_check))
.route("/system-info", web::get().to(get_system_info))
.route("/models", web::get().to(get_models))
.route("/infer", web::post().to(start_inference))
.route("/tasks/{task_id}", web::get().to(get_task_status))
.route("/tasks", web::get().to(get_all_tasks));
}

83
backend/src/main.rs Normal file
View File

@@ -0,0 +1,83 @@
//! ComfyUI-like Backend for AI Image/Video Generation with ROCm Support
//!
//! This server provides a REST API and WebSocket endpoints for
//! managing image/video generation workflows on AMD GPUs.
use actix_web::{web, App, HttpServer, Result, middleware::Logger};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::Mutex;
pub mod api;
pub mod models;
pub mod queue_service;
pub mod rocminfo;
pub mod session_manager;
/// State shared with every request handler.
///
/// Cloning is cheap for the manager and the queue: both are behind `Arc`, so
/// clones refer to the same underlying instances. The GPU configuration is
/// copied by value.
#[derive(Debug, Clone)]
pub struct AppState {
/// GPU description obtained at startup (or the fallback values used when detection fails).
pub gpu_backend_config: rocminfo::GpuConfig,
/// Registry of known models, guarded by an async mutex.
pub model_manager: Arc<Mutex<models::ModelManager>>,
/// Registry of inference tasks, guarded by an async mutex.
pub task_queue: Arc<Mutex<queue_service::TaskQueue>>,
}
/// Configuration for the application.
///
/// Every field has a serde default, so an empty document deserializes
/// successfully.
/// NOTE(review): `main` in this file neither constructs nor reads `Config` —
/// confirm whether it is wired up elsewhere.
#[derive(Debug, Deserialize, Serialize)]
pub struct Config {
/// TCP port to listen on; defaults to the value of `default_port()`.
#[serde(default = "default_port")]
pub port: u16,
/// Name of the GPU backend; defaults to the value of `default_gpu_backend()`.
#[serde(default = "default_gpu_backend")]
pub gpu_backend: String,
/// Optional path to the `rocminfo` executable; `None` when omitted.
#[serde(default)]
pub rocminfo_path: Option<String>,
}
/// Port used when the configuration does not specify one.
fn default_port() -> u16 {
    const FALLBACK_PORT: u16 = 8080;
    FALLBACK_PORT
}
/// GPU backend name used when the configuration does not specify one.
fn default_gpu_backend() -> String {
    String::from("rocm")
}
#[actix_web::main]
async fn main() -> std::io::Result<()> {
env_logger::init();
// Initialize GPU configuration
let gpu_config = rocminfo::detect_amd_gpu().unwrap_or_else(|e| {
eprintln!("Warning: Failed to detect AMD GPU: {}", e);
rocminfo::GpuConfig {
name: "Unknown_AMD_GPU".to_string(),
architecture: "unknown".to_string(),
driver_version: "unknown".to_string(),
}
});
// Initialize model manager
let model_manager = Arc::new(Mutex::new(models::ModelManager::new()));
// Initialize task queue
let task_queue = Arc::new(Mutex::new(queue_service::TaskQueue::new()));
let app_state = AppState {
gpu_backend_config: gpu_config,
model_manager,
task_queue,
};
println!("Starting ComfyUI backend server...");
println!("GPU Backend: {}", app_state.gpu_backend_config.name);
HttpServer::new(move || {
App::new()
.app_data(web::Data::new(app_state.clone()))
.wrap(Logger::default())
.configure(api::config)
})
.bind("127.0.0.1:8080")?
.run()
.await
}

95
backend/src/models/mod.rs Normal file
View File

@@ -0,0 +1,95 @@
//! Model management for AI inference workflows
//!
//! This module handles loading, caching, and managing different types of
//! machine learning models needed for image/video generation.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Metadata describing a single model known to the `ModelManager`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelInfo {
/// Name of the model; the manager uses it as the lookup key.
pub name: String,
/// Path recorded for the model (presumably its location on disk — confirm).
pub path: String,
/// Kind of model.
pub model_type: ModelType,
/// Version string recorded for the model.
pub version: String,
/// Flag set by `load_model` and cleared by `unload_model`.
pub loaded: bool,
}
/// Kind of model a `ModelInfo` entry describes.
///
/// The `Display` implementation renders the named variants as lowercase
/// snake_case identifiers and `Other` as its inner string.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub enum ModelType {
/// Displayed as `stable_diffusion`.
StableDiffusion,
/// Displayed as `control_net`.
ControlNet,
/// Displayed as `vae`.
VAE,
/// Displayed as `clip`.
CLIP,
/// Any other kind; displayed as the contained string.
Other(String),
}
impl std::fmt::Display for ModelType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ModelType::StableDiffusion => write!(f, "stable_diffusion"),
ModelType::ControlNet => write!(f, "control_net"),
ModelType::VAE => write!(f, "vae"),
ModelType::CLIP => write!(f, "clip"),
ModelType::Other(s) => write!(f, "{}", s),
}
}
}
/// In-memory registry of models and of which ones are marked as loaded.
#[derive(Debug)]
pub struct ModelManager {
// Registered models, keyed by model name.
models: HashMap<String, ModelInfo>,
loaded_models: Vec<String>, // Track which models are currently loaded
}
impl ModelManager {
    /// Create an empty model manager with no registered or loaded models.
    pub fn new() -> Self {
        Self {
            models: HashMap::new(),
            loaded_models: Vec::new(),
        }
    }

    /// Register a model under its name, replacing any existing entry with
    /// the same name.
    pub fn add_model(&mut self, model_info: ModelInfo) {
        self.models.insert(model_info.name.clone(), model_info);
    }

    /// Mark a registered model as loaded (placeholder implementation: no
    /// model data is actually read).
    ///
    /// Loading a model that is already loaded is a no-op, so the list of
    /// loaded models never accumulates duplicate names.
    ///
    /// # Errors
    /// Returns an error when no model named `model_name` is registered.
    pub async fn load_model(&mut self, model_name: &str) -> Result<(), Box<dyn std::error::Error>> {
        match self.models.get_mut(model_name) {
            Some(model) => {
                model.loaded = true;
                // Record the name once only; repeated loads must not grow the list.
                if !self.loaded_models.iter().any(|name| name == model_name) {
                    self.loaded_models.push(model_name.to_string());
                }
                Ok(())
            }
            None => Err(format!("Model '{}' not found", model_name).into()),
        }
    }

    /// Mark a registered model as unloaded (placeholder implementation).
    ///
    /// # Errors
    /// Returns an error when no model named `model_name` is registered.
    pub async fn unload_model(&mut self, model_name: &str) -> Result<(), Box<dyn std::error::Error>> {
        match self.models.get_mut(model_name) {
            Some(model) => {
                model.loaded = false;
                self.loaded_models.retain(|name| name != model_name);
                Ok(())
            }
            None => Err(format!("Model '{}' not found", model_name).into()),
        }
    }

    /// Return a copy of the metadata of every registered model
    /// (in unspecified order, as it comes from a `HashMap`).
    pub fn get_model_info(&self) -> Vec<ModelInfo> {
        self.models.values().cloned().collect()
    }

    /// Check whether `model_name` is in the list of loaded models.
    pub fn is_model_loaded(&self, model_name: &str) -> bool {
        // Compare by reference instead of allocating a `String` per lookup.
        self.loaded_models.iter().any(|name| name == model_name)
    }
}

impl Default for ModelManager {
    /// Equivalent to [`ModelManager::new`].
    fn default() -> Self {
        Self::new()
    }
}

View File

@@ -0,0 +1,144 @@
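//! Task queue service for tracking AI inference tasks
//!
//! This module provides an in-memory task registry (`TaskQueue`) with status
//! and progress tracking, using a Tokio `Notify` to signal newly added tasks.
//! Parallel execution of tasks (e.g. via Rayon on AMD GPUs) is not implemented
//! in this module.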
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, Notify};
use uuid::Uuid;
/// A single unit of work tracked by the `TaskQueue`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Task {
/// Unique identifier; the queue uses it as the lookup key.
pub id: String,
/// Human-readable task name.
pub name: String,
/// Current lifecycle state.
pub status: TaskStatus,
/// Progress value (range is not fixed by this module — presumably 0.0..=1.0, confirm).
pub progress: f32,
/// Creation time (presumably Unix seconds, matching `updated_at` — confirm with callers).
pub created_at: u64,
/// Time of the last status/progress update; the queue writes Unix seconds here.
pub updated_at: u64,
/// Name of the model the task should use, if any.
pub model_name: Option<String>,
/// Free-form JSON parameters for the task.
pub parameters: serde_json::Value,
}
/// Lifecycle state of a [`Task`].
///
/// Derives `PartialEq`/`Eq` so that statuses can be compared directly
/// (for example with `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub enum TaskStatus {
    /// Queued, not yet started. Displayed as `pending`.
    Pending,
    /// Currently being worked on. Displayed as `processing`.
    Processing,
    /// Finished successfully. Displayed as `completed`.
    Completed,
    /// Finished with an error. Displayed as `failed`.
    Failed,
    /// Cancelled before completion. Displayed as `cancelled`.
    Cancelled,
}
impl std::fmt::Display for TaskStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TaskStatus::Pending => write!(f, "pending"),
TaskStatus::Processing => write!(f, "processing"),
TaskStatus::Completed => write!(f, "completed"),
TaskStatus::Failed => write!(f, "failed"),
TaskStatus::Cancelled => write!(f, "cancelled"),
}
}
}
/// In-memory registry of tasks plus a notifier that is signalled whenever a
/// task is added.
#[derive(Debug)]
pub struct TaskQueue {
// All known tasks, keyed by task id.
tasks: HashMap<String, Task>,
// Signalled by `add_task`; awaited by `wait_for_new_task`.
notify: Arc<Notify>,
}
impl TaskQueue {
/// Create a new task queue instance
pub fn new() -> Self {
Self {
tasks: HashMap::new(),
notify: Arc::new(Notify::new()),
}
}
/// Add a new task to the queue
pub async fn add_task(&mut self, task: Task) -> String {
let task_id = task.id.clone();
self.tasks.insert(task_id.clone(), task);
self.notify.notify_waiters();
task_id
}
/// Get a specific task by ID
pub async fn get_task(&self, id: &str) -> Option<Task> {
self.tasks.get(id).cloned()
}
/// Update the status of a task
pub async fn update_task_status(&mut self, id: &str, status: TaskStatus) -> Result<(), Box<dyn std::error::Error>> {
if let Some(task) = self.tasks.get_mut(id) {
task.status = status;
task.updated_at = chrono::Utc::now().timestamp() as u64;
Ok(())
} else {
Err(format!("Task with ID '{}' not found", id).into())
}
}
/// Update the progress of a task
pub async fn update_task_progress(&mut self, id: &str, progress: f32) -> Result<(), Box<dyn std::error::Error>> {
if let Some(task) = self.tasks.get_mut(id) {
task.progress = progress;
task.updated_at = chrono::Utc::now().timestamp() as u64;
Ok(())
} else {
Err(format!("Task with ID '{}' not found", id).into())
}
}
/// Get all tasks in the queue
pub async fn get_all_tasks(&self) -> Vec<Task> {
self.tasks.values().cloned().collect()
}
/// Remove a completed or failed task
pub async fn remove_task(&mut self, id: &str) -> Result<(), Box<dyn std::error::Error>> {
if self.tasks.remove(id).is_some() {
Ok(())
} else {
Err(format!("Task with ID '{}' not found", id).into())
}
}
/// Wait for a new task to be added to the queue
pub async fn wait_for_new_task(&self) -> String {
let notify = self.notify.clone();
tokio::task::spawn(async move {
notify.notified().await;
}).await.unwrap(); // This will always succeed
// Return an empty string as placeholder - in a real implementation,
// this would return the task ID of the newly added task
String::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the task used by the tests below.
    fn sample_task() -> Task {
        Task {
            id: "test-123".to_string(),
            name: "Test Task".to_string(),
            status: TaskStatus::Pending,
            progress: 0.0,
            created_at: 0,
            updated_at: 0,
            model_name: None,
            parameters: serde_json::Value::Null,
        }
    }

    // Plain `#[test]`: nothing here awaits, so no async runtime is needed.
    #[test]
    fn test_task_creation() {
        let task = sample_task();
        assert_eq!(task.id, "test-123");
        assert_eq!(task.name, "Test Task");
        // `matches!` needs no `PartialEq` implementation on `TaskStatus`.
        assert!(matches!(task.status, TaskStatus::Pending));
    }

    #[tokio::test]
    async fn test_add_get_and_remove_task() {
        let mut queue = TaskQueue::new();

        let id = queue.add_task(sample_task()).await;
        assert_eq!(id, "test-123");

        let fetched = queue.get_task(&id).await.expect("task was just added");
        assert_eq!(fetched.name, "Test Task");
        assert!(queue.get_task("missing").await.is_none());

        assert!(queue.remove_task(&id).await.is_ok());
        assert!(queue.remove_task(&id).await.is_err());
    }

    #[tokio::test]
    async fn test_update_unknown_task_fails() {
        let mut queue = TaskQueue::new();
        assert!(queue
            .update_task_status("missing", TaskStatus::Failed)
            .await
            .is_err());
        assert!(queue.update_task_progress("missing", 0.5).await.is_err());
    }
}

100
backend/src/rocminfo.rs Normal file
View File

@@ -0,0 +1,100 @@
//! ROCm GPU detection and configuration for AMD GPUs
//!
//! This module handles automatic detection of AMD GPUs and provides
//! configuration information needed for optimized inference on RX 9070 XT.
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::process::Command;
/// Description of the detected GPU.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GpuConfig {
/// GPU name as parsed from the `rocminfo` output.
pub name: String,
/// GFX architecture token (a string starting with `gfx`), or `"unknown"` when none was found.
pub architecture: String,
/// Driver version string, or `"unknown"` when none was found.
pub driver_version: String,
}
/// Detect the AMD GPU by running the `rocminfo` command and parsing its
/// standard output.
///
/// Fields that cannot be found in the output keep their placeholder values
/// (`"Unknown AMD GPU"` / `"unknown"`). Output that is not valid UTF-8 is
/// decoded lossily, so a stray byte does not abort detection.
///
/// # Errors
/// Returns an error when `rocminfo` cannot be executed or exits with a
/// non-success status.
pub fn detect_amd_gpu() -> Result<GpuConfig, Box<dyn std::error::Error>> {
    let output = Command::new("rocminfo")
        .output()
        .map_err(|e| format!("Failed to execute rocminfo: {}", e))?;

    if !output.status.success() {
        return Err(format!("rocminfo failed with status: {}", output.status).into());
    }

    let text = String::from_utf8_lossy(&output.stdout);
    Ok(parse_rocminfo_output(&text))
}

/// Extract name, architecture and driver version from `rocminfo` text.
///
/// When several lines match the same field, the last one wins.
fn parse_rocminfo_output(output: &str) -> GpuConfig {
    let mut gpu_name = "Unknown AMD GPU".to_string();
    let mut architecture = "unknown".to_string();
    let mut driver_version = "unknown".to_string();

    for line in output.lines() {
        // NOTE(review): rocminfo is expected to print `Vendor Name: AMD` after
        // `Marketing Name: ...` for a GPU agent; without this exclusion the
        // vendor line would overwrite the marketing name with just "AMD".
        // Confirm against the ROCm versions you target.
        let is_vendor_line = line.contains("Vendor Name:");

        if line.contains("Name:") && line.contains("AMD") && !is_vendor_line {
            if let Some(name) = extract_value(line, "Name:") {
                gpu_name = name;
            }
        } else if line.contains("gfx")
            && (line.contains("Architecture")
                || line.contains("Compute Unit")
                || is_gfx_name_line(line))
        {
            if let Some(arch) = extract_gfx_architecture(line) {
                architecture = arch;
            }
        } else if line.contains("Driver Version:") {
            if let Some(version) = extract_value(line, "Driver Version:") {
                driver_version = version;
            }
        }
    }

    GpuConfig {
        name: gpu_name,
        architecture,
        driver_version,
    }
}

/// True for a `Name:` line whose value itself starts with `gfx`.
///
/// NOTE(review): this targets the agent-level `Name: gfxNNNN` line that
/// rocminfo is expected to print for GPU agents (ISA lines such as
/// `amdgcn-amd-amdhsa--gfx...` do not start with `gfx` and are skipped) —
/// confirm against real output.
fn is_gfx_name_line(line: &str) -> bool {
    extract_value(line, "Name:").map_or(false, |value| value.starts_with("gfx"))
}
/// Return the trimmed text that follows the first occurrence of `key` in
/// `line`, or `None` when `line` does not contain `key`.
///
/// Everything after the first occurrence is kept, so a value that happens to
/// contain the key text again is not cut short.
fn extract_value(line: &str, key: &str) -> Option<String> {
    line.split_once(key)
        .map(|(_, value)| value.trim().to_string())
}
/// Return the first GFX architecture token found in `line`, if any.
///
/// A token is `gfx` followed by a decimal digit and then any further
/// lowercase hexadecimal characters, so targets with a letter suffix such as
/// `gfx90a` are returned in full instead of being cut at the first letter.
/// Implemented with plain string scanning, so nothing is compiled per call.
fn extract_gfx_architecture(line: &str) -> Option<String> {
    const PREFIX: &str = "gfx";

    line.match_indices(PREFIX).find_map(|(start, _)| {
        let suffix = &line[start + PREFIX.len()..];

        // Require at least one digit right after the prefix.
        if !suffix.starts_with(|c: char| c.is_ascii_digit()) {
            return None;
        }

        // Only ASCII bytes are counted, so the end offset is a char boundary.
        let suffix_len = suffix
            .bytes()
            .take_while(|b| b.is_ascii_digit() || (b'a'..=b'f').contains(b))
            .count();

        Some(line[start..start + PREFIX.len() + suffix_len].to_string())
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_config_creation() {
        let config = GpuConfig {
            name: "Radeon RX 9070 XT".to_string(),
            architecture: "gfx900".to_string(),
            driver_version: "5.4.3".to_string(),
        };
        assert_eq!(config.name, "Radeon RX 9070 XT");
        assert_eq!(config.architecture, "gfx900");
        assert_eq!(config.driver_version, "5.4.3");
    }

    // The parsing helpers carry the real logic of this module, so cover them too.
    #[test]
    fn test_extract_value() {
        assert_eq!(
            extract_value("  Marketing Name:   AMD Radeon RX 9070 XT  ", "Name:"),
            Some("AMD Radeon RX 9070 XT".to_string())
        );
        assert_eq!(extract_value("Uuid: GPU-XX", "Name:"), None);
    }

    #[test]
    fn test_extract_gfx_architecture() {
        assert_eq!(
            extract_gfx_architecture("  Name:  gfx1201"),
            Some("gfx1201".to_string())
        );
        assert_eq!(extract_gfx_architecture("Vendor Name: AMD"), None);
    }
}