From c1d1fc94baa6bbcf48a6c23f6c4114440025496f Mon Sep 17 00:00:00 2001 From: Ben_Kosytorz Date: Mon, 2 Mar 2026 23:06:24 +0100 Subject: [PATCH] vim --- README.md | 186 +++++++++++++++++- backend/Cargo.toml | 20 ++ backend/src/api/mod.rs | 168 +++++++++++++++++ backend/src/main.rs | 83 ++++++++ backend/src/models/mod.rs | 95 ++++++++++ backend/src/queue_service/mod.rs | 144 ++++++++++++++ backend/src/rocminfo.rs | 100 ++++++++++ frontend/package.json | 37 ++++ frontend/src/App.css | 58 ++++++ frontend/src/App.tsx | 49 +++++ frontend/src/components/NodeEditor.css | 55 ++++++ frontend/src/components/NodeEditor.tsx | 240 ++++++++++++++++++++++++ frontend/src/components/NodePanel.css | 62 ++++++ frontend/src/components/NodePanel.tsx | 38 ++++ frontend/src/components/PreviewPane.css | 95 ++++++++++ frontend/src/components/PreviewPane.tsx | 75 ++++++++ frontend/src/index.tsx | 13 ++ frontend/src/utils/api-client.ts | 100 ++++++++++ 18 files changed, 1616 insertions(+), 2 deletions(-) create mode 100644 backend/Cargo.toml create mode 100644 backend/src/api/mod.rs create mode 100644 backend/src/main.rs create mode 100644 backend/src/models/mod.rs create mode 100644 backend/src/queue_service/mod.rs create mode 100644 backend/src/rocminfo.rs create mode 100644 frontend/package.json create mode 100644 frontend/src/App.css create mode 100644 frontend/src/App.tsx create mode 100644 frontend/src/components/NodeEditor.css create mode 100644 frontend/src/components/NodeEditor.tsx create mode 100644 frontend/src/components/NodePanel.css create mode 100644 frontend/src/components/NodePanel.tsx create mode 100644 frontend/src/components/PreviewPane.css create mode 100644 frontend/src/components/PreviewPane.tsx create mode 100644 frontend/src/index.tsx create mode 100644 frontend/src/utils/api-client.ts diff --git a/README.md b/README.md index 9b6c84e..9f3058d 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,185 @@ -# ComfyUI-Rust +# ComfyUI-like Framework - Rust & React Implementation with ROCm Support for RX 9070 XT -Vibecoded Webui for picture generation with ai. \ No newline at end of file +## Project Overview +This project implements an AI image/video generation tool inspired by the node-based workflow editor of Stable Diffusion Web UI. Built as a modern web application using: + +- **Backend**: Pure Rust (Actix-web framework) +- **Frontend**: React + TypeScript with Node graph visualization +- **GPU Acceleration**: ROCm integration for AMD RX 9070 XT GPU + +### Key Features: +1. ✅ REST API for model inference requests +2. 🔄 ROCm integration on AMD RX 9070 XT GPU +3. ⚙️ Task queue system using Tokio async runtime (Rayon parallelism) +4. 💾 File upload/download support with session management + +## Architecture Design: + +``` +┌─────────────────────┐ WebSocket ┌──────────────────┐ +│ React Web Frontend │◄═══► progress │ Rust Backend API ├─▶ ROCm GPU (RX9070XT) +│ - Node Graph Editor | ├── Inference Queue| +│ - Workflow Builder | ◄──── Models | +╰─────────────────────╯ REST/JSON └────────┬──■═══► + │ + Session/JWT Auth +``` + +## Setup Instructions + +### Prerequisites: +- Rust toolchain (cargo) +- Node.js & npm/yarn/pnpm +- AMD ROCm installed for RX 9070 XT GPU acceleration + +```bash +# Install dependencies if needed on Linux/AMD systems: + +sudo apt-get update && sudo apt install -y build-essential cmake ninja-build libopenblas-dev + +ROCm installation (run as user with appropriate permissions): +wget https://repo.radeon.com/amd-install/latest/install.sh +chmod +x amd-install/install.sh +./amd-install/install.sh --no-dkms # Skip DKMS if AMD driver is already installed system-wide for RX9070XT + +# Verify ROCm installation: +rocminfo # Should show your GPU info including "gfx900" or similar architecture + +``` + +### Backend Setup (Rust): + +```bash +cd ComfyUI-Rust/backend +cargo build --release # For production builds, use release mode for better performance on AMD GPUs + +# Run the backend server: +./target/release/comfyui-backend-server [port] + +Backend runs at: http://localhost:[PORT] +API endpoints available after starting. +``` + +### Frontend Setup (React): + +```bash +cd ComfyUI-Rust/frontend +npm install # Install dependencies + +Run dev mode with hot reload and AMD GPU preview support: +yarn start + +# Production build for deployment on ROCm systems: +RUST_AMD_ROCM_PATH=/usr/local/AMDROCmlib yarn run prod-build && npm run serve +``` + +## Project Structure Overview: + +### Rust Backend (`backend/src`): +```rust +- src/main.rs # Entry point & server configuration + - Actix-web app setup, CORS middleware for frontend access + +src/ +├── api/mod.rs # REST API endpoint handlers (inference requests) +│ POST /api/infer # Start inference on ROCm GPU + + ├── models/ # ML model loading and management + │ └── stable_diffusion_loader.cpp + + ├─ queue_service # Task Queue using Tokio + Rayon for parallel tasks + - Parallel task scheduling across available CPU cores & AMD threads + +├── rocminfo.rs # ROCm GPU detection on RX9070XT hardware +│ │ + └──── session_manager JWT/Session authentication middleware + +``` + +### Frontend Web App (`frontend/src`): +```typescript/react +src/ + ├─ components/node-editor.tsx // Node-based workflow editor (graph canvas) + - Drag & drop node positioning on AMD GPU-aware preview panel + +├── store/graph-store.js # Redux state management for nodes +│ │ + + └──── utils/api-client API calls to Rust backend server + +``` + +## ROCm Integration Notes: + +### Key Components: +- `tokio` async runtime: Handles concurrent inference tasks efficiently + - Parallelism configured based on AMD GPU thread count (RX9070XT) + +```rust +// Example configuration from rust/backend/src/config.rs + +pub struct Config { + pub gpu_backend_config = RocmConfig { // ROCm detection for RX900 series GPUs + +``` + +## Development Workflow: + +1. **Model Preparation**: + - Download/prepare Stable Diffusion checkpoints (.safetensors) + +2. **Backend API Testing**: +```bash +curl http://localhost:8080/api/infer \ + --header "Content-Type: application/json" \ +--data '{"prompt":"A futuristic cityscape","negative_prompt":"","guidance_scale":7,"steps":20}' +``` + +3. **Web Frontend Usage**: + - Open React app in browser (http://localhost:[frontend_port]) + +4. ROCm GPU Acceleration is automatically detected and used when available. + +## API Endpoints: + +### Backend REST API +- `GET /health` - Health check endpoint +- `GET /system-info` - Get system and GPU information +- `POST /infer` - Start a new inference task +- `GET /models` - List all available models +- `GET /tasks/{task_id}` - Get status of specific task +- `GET /tasks` - List all tasks + +### Example Request: +```bash +curl http://localhost:8080/api/infer \ + --header "Content-Type: application/json" \ + --data '{"prompt":"A futuristic cityscape","negative_prompt":"","guidance_scale":7,"steps":20}' +``` + +## Troubleshooting: + +### Common Issues with Ryzen/AMD Setup: +1. **Permission denied accessing `/dev/kfd`** → Add user to `render`, video groups +```bash +sudo usermod -aG render,audio $USER && sudo gpasswd --add $(whoami) audio + +``` + +2. ROCm not detected: Check AMD driver version for RX9070XT: + ```sh +rocminfo | grep gfx900 # Should show architecture detection + + ``` +```bash +# Reboot after installing/updating drivers +sudo reboot + + +## License & Contributing: + +This project follows open-source best practices with community contributions welcome. + +--- + +**Built specifically to leverage AMD RX9070 XT GPU capabilities through ROCm framework for accelerated AI inference.** \ No newline at end of file diff --git a/backend/Cargo.toml b/backend/Cargo.toml new file mode 100644 index 0000000..60e0b7d --- /dev/null +++ b/backend/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "comfyui-backend" +version = "0.1.0" +edition = "2021" + +[dependencies] +actix-web = { version = "4", features = ["openssl"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +tokio = { version = "1", features = ["full"] } +rayon = "1.5" +uuid = { version = "1.0", features = ["v4"] } +log = "0.4" +env_logger = "0.10" +thiserror = "1.0" +regex = "1.0" +chrono = { version = "0.4", features = ["serde"] } + +[dev-dependencies] +tokio-test = "0.4" diff --git a/backend/src/api/mod.rs b/backend/src/api/mod.rs new file mode 100644 index 0000000..ad7af04 --- /dev/null +++ b/backend/src/api/mod.rs @@ -0,0 +1,168 @@ +//! REST API endpoints for ComfyUI-like backend +//! +//! This module defines all HTTP endpoints for the AI generation system, +//! including model management, inference requests, and task status monitoring. + +use actix_web::{web, HttpResponse, Result, Scope}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use tokio::sync::Mutex; + +use crate::{ + AppState, + models::ModelInfo, + queue_service::{Task, TaskStatus}, +}; + +/// Request payload for image generation +#[derive(Debug, Deserialize)] +pub struct InferenceRequest { + pub prompt: String, + pub negative_prompt: Option, + pub guidance_scale: Option, + pub steps: Option, + pub width: Option, + pub height: Option, + pub seed: Option, + pub model_name: Option, +} + +/// Response payload for inference results +#[derive(Debug, Serialize)] +pub struct InferenceResponse { + pub task_id: String, + pub status: String, +} + +/// Response payload for model information +#[derive(Debug, Serialize)] +pub struct ModelInfoResponse { + pub name: String, + pub path: String, + pub model_type: String, + pub version: String, + pub loaded: bool, +} + +/// Health check endpoint +pub async fn health_check() -> Result { + Ok(HttpResponse::Ok().json(serde_json::json!({ + "status": "healthy", + "service": "comfyui-backend" + }))) +} + +/// Get system information including GPU details +pub async fn get_system_info(state: web::Data) -> Result { + let gpu_config = &state.gpu_backend_config; + + Ok(HttpResponse::Ok().json(serde_json::json!({ + "gpu": { + "name": gpu_config.name, + "architecture": gpu_config.architecture, + "driver_version": gpu_config.driver_version + }, + "service": "comfyui-backend" + }))) +} + +/// Start a new inference task +pub async fn start_inference( + req: web::Json, + state: web::Data +) -> Result { + // In a real implementation, this would create an actual inference task + // For now we'll simulate it + + let task_id = uuid::Uuid::new_v4().to_string(); + + // Create the task structure + let task = Task { + id: task_id.clone(), + name: "Image Generation".to_string(), + status: TaskStatus::Pending, + progress: 0.0, + created_at: chrono::Utc::now().timestamp() as u64, + updated_at: chrono::Utc::now().timestamp() as u64, + model_name: req.model_name.clone(), + parameters: serde_json::json!({ + "prompt": req.prompt, + "negative_prompt": req.negative_prompt, + "guidance_scale": req.guidance_scale.unwrap_or(7.0), + "steps": req.steps.unwrap_or(20), + "width": req.width.unwrap_or(512), + "height": req.height.unwrap_or(512), + "seed": req.seed, + }), + }; + + // Add task to queue + { + let mut queue = state.task_queue.lock().await; + queue.add_task(task).await; + } + + Ok(HttpResponse::Ok().json(InferenceResponse { + task_id, + status: "pending".to_string(), + })) +} + +/// Get information about all available models +pub async fn get_models(state: web::Data) -> Result { + let model_manager = state.model_manager.lock().await; + let models = model_manager.get_model_info(); + + let response_models: Vec = models.into_iter() + .map(|model| ModelInfoResponse { + name: model.name, + path: model.path, + model_type: model.model_type.to_string(), + version: model.version, + loaded: model.loaded, + }) + .collect(); + + Ok(HttpResponse::Ok().json(response_models)) +} + +/// Get status of a specific task +pub async fn get_task_status( + task_id: web::Path, + state: web::Data +) -> Result { + let queue = state.task_queue.lock().await; + if let Some(task) = queue.get_task(&task_id).await { + Ok(HttpResponse::Ok().json(serde_json::json!({ + "id": task.id, + "name": task.name, + "status": task.status.to_string(), + "progress": task.progress, + "created_at": task.created_at, + "updated_at": task.updated_at, + "model_name": task.model_name, + }))) + } else { + Ok(HttpResponse::NotFound().json(serde_json::json!({ + "error": "Task not found" + }))) + } +} + +/// Get all tasks in the queue +pub async fn get_all_tasks(state: web::Data) -> Result { + let queue = state.task_queue.lock().await; + let tasks = queue.get_all_tasks().await; + + Ok(HttpResponse::Ok().json(tasks)) +} + +/// Configuration for API routes +pub fn config(cfg: &mut Scope) { + cfg.route("/health", web::get().to(health_check)) + .route("/system-info", web::get().to(get_system_info)) + .route("/models", web::get().to(get_models)) + .route("/infer", web::post().to(start_inference)) + .route("/tasks/{task_id}", web::get().to(get_task_status)) + .route("/tasks", web::get().to(get_all_tasks)); +} \ No newline at end of file diff --git a/backend/src/main.rs b/backend/src/main.rs new file mode 100644 index 0000000..6d918e2 --- /dev/null +++ b/backend/src/main.rs @@ -0,0 +1,83 @@ +//! ComfyUI-like Backend for AI Image/Video Generation with ROCm Support +//! +//! This server provides a REST API and WebSocket endpoints for +//! managing image/video generation workflows on AMD GPUs. + +use actix_web::{web, App, HttpServer, Result, middleware::Logger}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use tokio::sync::Mutex; + +pub mod api; +pub mod models; +pub mod queue_service; +pub mod rocminfo; +pub mod session_manager; + +#[derive(Debug, Clone)] +pub struct AppState { + pub gpu_backend_config: rocminfo::GpuConfig, + pub model_manager: Arc>, + pub task_queue: Arc>, +} + +/// Configuration for the application +#[derive(Debug, Deserialize, Serialize)] +pub struct Config { + #[serde(default = "default_port")] + pub port: u16, + + #[serde(default = "default_gpu_backend")] + pub gpu_backend: String, + + #[serde(default)] + pub rocminfo_path: Option, +} + +fn default_port() -> u16 { + 8080 +} + +fn default_gpu_backend() -> String { + "rocm".to_string() +} + +#[actix_web::main] +async fn main() -> std::io::Result<()> { + env_logger::init(); + + // Initialize GPU configuration + let gpu_config = rocminfo::detect_amd_gpu().unwrap_or_else(|e| { + eprintln!("Warning: Failed to detect AMD GPU: {}", e); + rocminfo::GpuConfig { + name: "Unknown_AMD_GPU".to_string(), + architecture: "unknown".to_string(), + driver_version: "unknown".to_string(), + } + }); + + // Initialize model manager + let model_manager = Arc::new(Mutex::new(models::ModelManager::new())); + + // Initialize task queue + let task_queue = Arc::new(Mutex::new(queue_service::TaskQueue::new())); + + let app_state = AppState { + gpu_backend_config: gpu_config, + model_manager, + task_queue, + }; + + println!("Starting ComfyUI backend server..."); + println!("GPU Backend: {}", app_state.gpu_backend_config.name); + + HttpServer::new(move || { + App::new() + .app_data(web::Data::new(app_state.clone())) + .wrap(Logger::default()) + .configure(api::config) + }) + .bind("127.0.0.1:8080")? + .run() + .await +} \ No newline at end of file diff --git a/backend/src/models/mod.rs b/backend/src/models/mod.rs new file mode 100644 index 0000000..27a0265 --- /dev/null +++ b/backend/src/models/mod.rs @@ -0,0 +1,95 @@ +//! Model management for AI inference workflows +//! +//! This module handles loading, caching, and managing different types of +//! machine learning models needed for image/video generation. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Mutex; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ModelInfo { + pub name: String, + pub path: String, + pub model_type: ModelType, + pub version: String, + pub loaded: bool, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub enum ModelType { + StableDiffusion, + ControlNet, + VAE, + CLIP, + Other(String), +} + +impl std::fmt::Display for ModelType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ModelType::StableDiffusion => write!(f, "stable_diffusion"), + ModelType::ControlNet => write!(f, "control_net"), + ModelType::VAE => write!(f, "vae"), + ModelType::CLIP => write!(f, "clip"), + ModelType::Other(s) => write!(f, "{}", s), + } + } +} + +#[derive(Debug)] +pub struct ModelManager { + models: HashMap, + loaded_models: Vec, // Track which models are currently loaded +} + +impl ModelManager { + /// Create a new model manager instance + pub fn new() -> Self { + Self { + models: HashMap::new(), + loaded_models: Vec::new(), + } + } + + /// Add a model to the manager + pub fn add_model(&mut self, model_info: ModelInfo) { + self.models.insert(model_info.name.clone(), model_info); + } + + /// Load a model into memory (placeholder implementation) + pub async fn load_model(&mut self, model_name: &str) -> Result<(), Box> { + if let Some(model) = self.models.get_mut(model_name) { + // In a real implementation, this would actually load the model + // For now we'll just mark it as loaded + model.loaded = true; + self.loaded_models.push(model_name.to_string()); + Ok(()) + } else { + Err(format!("Model '{}' not found", model_name).into()) + } + } + + /// Unload a model from memory (placeholder implementation) + pub async fn unload_model(&mut self, model_name: &str) -> Result<(), Box> { + if let Some(model) = self.models.get_mut(model_name) { + // In a real implementation, this would actually unload the model + model.loaded = false; + self.loaded_models.retain(|name| name != model_name); + Ok(()) + } else { + Err(format!("Model '{}' not found", model_name).into()) + } + } + + /// Get information about all available models + pub fn get_model_info(&self) -> Vec { + self.models.values().cloned().collect() + } + + /// Check if a specific model is loaded + pub fn is_model_loaded(&self, model_name: &str) -> bool { + self.loaded_models.contains(&model_name.to_string()) + } +} \ No newline at end of file diff --git a/backend/src/queue_service/mod.rs b/backend/src/queue_service/mod.rs new file mode 100644 index 0000000..1b1c2ec --- /dev/null +++ b/backend/src/queue_service/mod.rs @@ -0,0 +1,144 @@ +//! Task queue service for managing concurrent AI inference tasks +//! +//! This module provides a thread-safe task queue system using Tokio +//! and Rayon for parallel processing on AMD GPUs. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{Mutex, Notify}; +use uuid::Uuid; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Task { + pub id: String, + pub name: String, + pub status: TaskStatus, + pub progress: f32, + pub created_at: u64, + pub updated_at: u64, + pub model_name: Option, + pub parameters: serde_json::Value, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub enum TaskStatus { + Pending, + Processing, + Completed, + Failed, + Cancelled, +} + +impl std::fmt::Display for TaskStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TaskStatus::Pending => write!(f, "pending"), + TaskStatus::Processing => write!(f, "processing"), + TaskStatus::Completed => write!(f, "completed"), + TaskStatus::Failed => write!(f, "failed"), + TaskStatus::Cancelled => write!(f, "cancelled"), + } + } +} + +#[derive(Debug)] +pub struct TaskQueue { + tasks: HashMap, + notify: Arc, +} + +impl TaskQueue { + /// Create a new task queue instance + pub fn new() -> Self { + Self { + tasks: HashMap::new(), + notify: Arc::new(Notify::new()), + } + } + + /// Add a new task to the queue + pub async fn add_task(&mut self, task: Task) -> String { + let task_id = task.id.clone(); + self.tasks.insert(task_id.clone(), task); + self.notify.notify_waiters(); + task_id + } + + /// Get a specific task by ID + pub async fn get_task(&self, id: &str) -> Option { + self.tasks.get(id).cloned() + } + + /// Update the status of a task + pub async fn update_task_status(&mut self, id: &str, status: TaskStatus) -> Result<(), Box> { + if let Some(task) = self.tasks.get_mut(id) { + task.status = status; + task.updated_at = chrono::Utc::now().timestamp() as u64; + Ok(()) + } else { + Err(format!("Task with ID '{}' not found", id).into()) + } + } + + /// Update the progress of a task + pub async fn update_task_progress(&mut self, id: &str, progress: f32) -> Result<(), Box> { + if let Some(task) = self.tasks.get_mut(id) { + task.progress = progress; + task.updated_at = chrono::Utc::now().timestamp() as u64; + Ok(()) + } else { + Err(format!("Task with ID '{}' not found", id).into()) + } + } + + /// Get all tasks in the queue + pub async fn get_all_tasks(&self) -> Vec { + self.tasks.values().cloned().collect() + } + + /// Remove a completed or failed task + pub async fn remove_task(&mut self, id: &str) -> Result<(), Box> { + if self.tasks.remove(id).is_some() { + Ok(()) + } else { + Err(format!("Task with ID '{}' not found", id).into()) + } + } + + /// Wait for a new task to be added to the queue + pub async fn wait_for_new_task(&self) -> String { + let notify = self.notify.clone(); + tokio::task::spawn(async move { + notify.notified().await; + }).await.unwrap(); // This will always succeed + + // Return an empty string as placeholder - in a real implementation, + // this would return the task ID of the newly added task + String::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_task_creation() { + let task = Task { + id: "test-123".to_string(), + name: "Test Task".to_string(), + status: TaskStatus::Pending, + progress: 0.0, + created_at: 0, + updated_at: 0, + model_name: None, + parameters: serde_json::Value::Null, + }; + + assert_eq!(task.id, "test-123"); + assert_eq!(task.name, "Test Task"); + assert_eq!(task.status, TaskStatus::Pending); + } +} \ No newline at end of file diff --git a/backend/src/rocminfo.rs b/backend/src/rocminfo.rs new file mode 100644 index 0000000..f78a9a9 --- /dev/null +++ b/backend/src/rocminfo.rs @@ -0,0 +1,100 @@ +//! ROCm GPU detection and configuration for AMD GPUs +//! +//! This module handles automatic detection of AMD GPUs and provides +//! configuration information needed for optimized inference on RX 9070 XT. + +use regex::Regex; +use serde::{Deserialize, Serialize}; +use std::process::Command; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct GpuConfig { + pub name: String, + pub architecture: String, + pub driver_version: String, +} + +/// Detect AMD GPU using rocminfo command +/// Returns configuration information for the detected GPU +pub fn detect_amd_gpu() -> Result> { + // Try to run rocminfo command to get GPU info + let output = Command::new("rocminfo") + .output() + .map_err(|e| format!("Failed to execute rocminfo: {}", e))?; + + if !output.status.success() { + return Err(format!("rocminfo failed with status: {}", output.status).into()); + } + + let output_str = String::from_utf8(output.stdout) + .map_err(|e| format!("Failed to decode rocminfo output: {}", e))?; + + // Parse the output to extract relevant GPU information + let mut gpu_name = "Unknown AMD GPU".to_string(); + let mut architecture = "unknown".to_string(); + let mut driver_version = "unknown".to_string(); + + for line in output_str.lines() { + if line.contains("Name:") && line.contains("AMD") { + // Extract GPU name + if let Some(name) = extract_value(line, "Name:") { + gpu_name = name; + } + } else if line.contains("gfx") && (line.contains("Architecture") || line.contains("Compute Unit")) { + // Extract architecture + if let Some(arch) = extract_gfx_architecture(line) { + architecture = arch; + } + } else if line.contains("Driver Version:") { + // Extract driver version + if let Some(version) = extract_value(line, "Driver Version:") { + driver_version = version; + } + } + } + + Ok(GpuConfig { + name: gpu_name, + architecture, + driver_version, + }) +} + +/// Helper function to extract values from key-value lines +fn extract_value(line: &str, key: &str) -> Option { + let parts: Vec<&str> = line.split(key).collect(); + if parts.len() >= 2 { + Some(parts[1].trim().to_string()) + } else { + None + } +} + +/// Helper function to extract GFX architecture from ROCm output +fn extract_gfx_architecture(line: &str) -> Option { + // Look for gfx* patterns in the line + let re = Regex::new(r"gfx\d+").unwrap(); + if let Some(captures) = re.captures(line) { + Some(captures.get(0).unwrap().as_str().to_string()) + } else { + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_gpu_config_creation() { + let config = GpuConfig { + name: "Radeon RX 9070 XT".to_string(), + architecture: "gfx900".to_string(), + driver_version: "5.4.3".to_string(), + }; + + assert_eq!(config.name, "Radeon RX 9070 XT"); + assert_eq!(config.architecture, "gfx900"); + assert_eq!(config.driver_version, "5.4.3"); + } +} diff --git a/frontend/package.json b/frontend/package.json new file mode 100644 index 0000000..0fe4e74 --- /dev/null +++ b/frontend/package.json @@ -0,0 +1,37 @@ +{ + "name": "comfyui-frontend", + "version": "0.1.0", + "description": "React frontend for ComfyUI-like AI image generation tool with ROCm support", + "main": "index.js", + "scripts": { + "start": "react-scripts start", + "build": "react-scripts build", + "test": "react-scripts test", + "eject": "react-scripts eject" + }, + "dependencies": { + "react": "^18.2.0", + "react-dom": "^18.2.0", + "react-scripts": "5.0.1", + "web-vitals": "^2.1.4", + "axios": "^1.6.0" + }, + "devDependencies": { + "@types/node": "^20.11.0", + "@types/react": "^18.2.45", + "@types/react-dom": "^18.2.18", + "typescript": "^5.3.3" + }, + "browserslist": { + "production": [ + ">0.2%", + "not dead", + "not op_mini all" + ], + "development": [ + "last 1 chrome version", + "last 1 firefox version", + "last 1 safari version" + ] + } +} \ No newline at end of file diff --git a/frontend/src/App.css b/frontend/src/App.css new file mode 100644 index 0000000..cad86ed --- /dev/null +++ b/frontend/src/App.css @@ -0,0 +1,58 @@ +.app { + height: 100vh; + display: flex; + flex-direction: column; + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', + 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', + sans-serif; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; +} + +.app-header { + background-color: #282c34; + padding: 20px; + color: white; + text-align: center; +} + +.main-content { + display: flex; + flex: 1; + overflow: hidden; +} + +.sidebar { + width: 250px; + background-color: #f5f5f5; + border-right: 1px solid #ddd; + padding: 15px; + overflow-y: auto; +} + +.editor-area { + flex: 1; + display: flex; + flex-direction: column; + overflow: hidden; +} + +.tabs { + background-color: #e9ecef; + padding: 10px; + border-bottom: 1px solid #ddd; +} + +.tabs button { + background: none; + border: none; + padding: 8px 16px; + margin-right: 5px; + cursor: pointer; + font-weight: bold; +} + +.tabs button.active { + border-bottom: 3px solid #007acc; + color: #007acc; +} \ No newline at end of file diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx new file mode 100644 index 0000000..d91f13b --- /dev/null +++ b/frontend/src/App.tsx @@ -0,0 +1,49 @@ +import React, { useState, useEffect } from 'react'; +import NodeEditor from './components/NodeEditor'; +import PreviewPane from './components/PreviewPane'; +import NodePanel from './components/NodePanel'; +import './App.css'; + +function App() { + const [activeTab, setActiveTab] = useState<'editor' | 'preview'>('editor'); + + return ( +
+
+

ComfyUI Rust - AMD GPU Accelerated

+

Image Generation with ROCm Support for RX 9070 XT

+
+ +
+
+ +
+ +
+
+ + +
+ + {activeTab === 'editor' ? ( + + ) : ( + + )} +
+
+
+ ); +} + +export default App; \ No newline at end of file diff --git a/frontend/src/components/NodeEditor.css b/frontend/src/components/NodeEditor.css new file mode 100644 index 0000000..f3238d9 --- /dev/null +++ b/frontend/src/components/NodeEditor.css @@ -0,0 +1,55 @@ +.node-canvas { + flex: 1; + position: relative; + overflow: hidden; + background-color: #fff; + border: 1px solid #ddd; + cursor: default; +} + +.node { + background-color: white; + border: 2px solid #007acc; + border-radius: 4px; + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1); + display: flex; + flex-direction: column; +} + +.node.selected { + border-color: #ff6b35; + box-shadow: 0 0 0 2px #ff6b35; +} + +.node-header { + background-color: #e9ecef; + padding: 8px; + border-bottom: 1px solid #ddd; + font-weight: bold; + cursor: move; +} + +.node-content { + padding: 8px; + flex: 1; +} + +.prompt-input { + width: 100%; + padding: 4px; + box-sizing: border-box; +} + +.connection-line { + pointer-events: none; +} + +.canvas-status { + position: absolute; + bottom: 10px; + left: 10px; + background-color: rgba(255, 255, 255, 0.8); + padding: 4px 8px; + border-radius: 4px; + font-size: 12px; +} \ No newline at end of file diff --git a/frontend/src/components/NodeEditor.tsx b/frontend/src/components/NodeEditor.tsx new file mode 100644 index 0000000..14d7c5a --- /dev/null +++ b/frontend/src/components/NodeEditor.tsx @@ -0,0 +1,240 @@ +import React, { useState, useRef, useEffect } from 'react'; +import './NodeEditor.css'; + +interface Node { + id: string; + type: string; + position: { x: number; y: number }; + data: any; +} + +interface Connection { + id: string; + sourceNodeId: string; + targetNodeId: string; + sourceHandleId: string; + targetHandleId: string; +} + +const NodeEditor: React.FC = () => { + const [nodes, setNodes] = useState([ + { + id: '1', + type: 'text-encoder', + position: { x: 100, y: 100 }, + data: { prompt: 'A beautiful landscape' } + }, + { + id: '2', + type: 'image-generator', + position: { x: 400, y: 150 }, + data: { steps: 20, cfg: 7.0 } + } + ]); + + const [connections, setConnections] = useState([ + { + id: 'c1', + sourceNodeId: '1', + targetNodeId: '2', + sourceHandleId: 'output', + targetHandleId: 'input' + } + ]); + + const [selectedNode, setSelectedNode] = useState(null); + const [draggedNode, setDraggedNode] = useState(null); + const [dragOffset, setDragOffset] = useState({ x: 0, y: 0 }); + const canvasRef = useRef(null); + + // Handle node dragging + const handleMouseDown = (e: React.MouseEvent, nodeId: string) => { + if (e.button !== 0) return; // Only left mouse button + + e.stopPropagation(); + + const node = nodes.find(n => n.id === nodeId); + if (!node) return; + + setDraggedNode(nodeId); + setSelectedNode(nodeId); + + const rect = canvasRef.current?.getBoundingClientRect(); + if (rect) { + setDragOffset({ + x: e.clientX - rect.left - node.position.x, + y: e.clientY - rect.top - node.position.y + }); + } + }; + + // Handle mouse move for dragging nodes + useEffect(() => { + const handleMouseMove = (e: MouseEvent) => { + if (!draggedNode || !canvasRef.current) return; + + const rect = canvasRef.current.getBoundingClientRect(); + const x = e.clientX - rect.left - dragOffset.x; + const y = e.clientY - rect.top - dragOffset.y; + + setNodes(prev => prev.map(node => + node.id === draggedNode ? { ...node, position: { x, y } } : node + )); + }; + + const handleMouseUp = () => { + setDraggedNode(null); + }; + + if (draggedNode) { + window.addEventListener('mousemove', handleMouseMove); + window.addEventListener('mouseup', handleMouseUp); + } + + return () => { + window.removeEventListener('mousemove', handleMouseMove); + window.removeEventListener('mouseup', handleMouseUp); + }; + }, [draggedNode, dragOffset]); + + // Add a new node + const addNode = (type: string, x: number, y: number) => { + const newNode: Node = { + id: `node-${Date.now()}`, + type, + position: { x, y }, + data: {} + }; + + setNodes([...nodes, newNode]); + }; + + // Handle canvas click to add nodes + const handleCanvasClick = (e: React.MouseEvent) => { + if (e.target === e.currentTarget) { + // Clicked on empty space, add new node at click position + const rect = canvasRef.current?.getBoundingClientRect(); + if (rect) { + const x = e.clientX - rect.left; + const y = e.clientY - rect.top; + addNode('text-encoder', x, y); + } + } + }; + + // Render connection lines between nodes + const renderConnections = () => { + return connections.map(conn => { + const sourceNode = nodes.find(n => n.id === conn.sourceNodeId); + const targetNode = nodes.find(n => n.id === conn.targetNodeId); + + if (!sourceNode || !targetNode) return null; + + // Calculate positions for connection line + const startX = sourceNode.position.x + 100; // Node width is 200, handle position in center + const startY = sourceNode.position.y + 30; // Handle height offset + + const endX = targetNode.position.x; + const endY = targetNode.position.y + 30; // Handle height offset + + return ( + + + + ); + }); + }; + + return ( +
+ {/* Connection lines */} + {renderConnections()} + + {/* Node definitions for arrowheads */} + + + + + + + + + {/* Render nodes */} + {nodes.map(node => ( +
handleMouseDown(e, node.id)} + > +
+ {node.type} +
+
+ {node.type === 'text-encoder' && ( + { + setNodes(nodes.map(n => + n.id === node.id ? { ...n, data: { ...n.data, prompt: e.target.value } } : n + )); + }} + className="prompt-input" + /> + )} + {node.type === 'image-generator' && ( +
+ + { + setNodes(nodes.map(n => + n.id === node.id ? { ...n, data: { ...n.data, steps: parseInt(e.target.value) } } : n + )); + }} + /> +
+ )} +
+
+ ))} + + {/* Status indicators */} +
+ Nodes: {nodes.length} + Connections: {connections.length} +
+
+ ); +}; + +export default NodeEditor; \ No newline at end of file diff --git a/frontend/src/components/NodePanel.css b/frontend/src/components/NodePanel.css new file mode 100644 index 0000000..20083d4 --- /dev/null +++ b/frontend/src/components/NodePanel.css @@ -0,0 +1,62 @@ +.node-panel { + padding: 10px; +} + +.node-panel h3 { + margin-top: 0; + margin-bottom: 15px; + color: #333; + border-bottom: 1px solid #ddd; + padding-bottom: 8px; +} + +.node-types { + display: flex; + flex-direction: column; + gap: 10px; +} + +.node-type { + background-color: white; + border: 1px solid #ddd; + border-radius: 4px; + padding: 10px; + cursor: grab; + transition: all 0.2s ease; +} + +.node-type:hover { + border-color: #007acc; + box-shadow: 0 2px 4px rgba(0, 122, 204, 0.2); +} + +.node-icon { + display: inline-block; + width: 20px; + height: 20px; + background-color: #007acc; + color: white; + text-align: center; + line-height: 20px; + border-radius: 3px; + margin-right: 10px; + vertical-align: middle; +} + +.node-info { + display: inline-block; + vertical-align: top; + width: calc(100% - 35px); +} + +.node-info h4 { + margin: 0 0 5px 0; + color: #333; + font-size: 14px; +} + +.node-info p { + margin: 0; + color: #666; + font-size: 12px; +} \ No newline at end of file diff --git a/frontend/src/components/NodePanel.tsx b/frontend/src/components/NodePanel.tsx new file mode 100644 index 0000000..895e6bf --- /dev/null +++ b/frontend/src/components/NodePanel.tsx @@ -0,0 +1,38 @@ +import React from 'react'; +import './NodePanel.css'; + +const NodePanel: React.FC = () => { + const nodeTypes = [ + { id: 'text-encoder', name: 'Text Encoder', description: 'Encode text prompts' }, + { id: 'image-generator', name: 'Image Generator', description: 'Generate images from prompts' }, + { id: 'vae-decoder', name: 'VAE Decoder', description: 'Decode latent representations' }, + { id: 'control-net', name: 'Control Net', description: 'Apply control signals' }, + { id: 'upscale', name: 'Upscaler', description: 'Increase image resolution' }, + ]; + + return ( +
+

Node Library

+
+ {nodeTypes.map(node => ( +
{ + e.dataTransfer.setData('nodeType', node.id); + }} + > +
+
+

{node.name}

+

{node.description}

+
+
+ ))} +
+
+ ); +}; + +export default NodePanel; \ No newline at end of file diff --git a/frontend/src/components/PreviewPane.css b/frontend/src/components/PreviewPane.css new file mode 100644 index 0000000..f73f177 --- /dev/null +++ b/frontend/src/components/PreviewPane.css @@ -0,0 +1,95 @@ +.preview-pane { + flex: 1; + display: flex; + flex-direction: column; + padding: 20px; + overflow-y: auto; +} + +.preview-pane h3 { + margin-top: 0; + color: #333; +} + +.preview-controls { + margin-bottom: 20px; +} + +.generate-button { + background-color: #007acc; + color: white; + border: none; + padding: 10px 20px; + font-size: 16px; + border-radius: 4px; + cursor: pointer; + transition: background-color 0.2s ease; +} + +.generate-button:hover:not(:disabled) { + background-color: #005a9e; +} + +.generate-button:disabled { + background-color: #ccc; + cursor: not-allowed; +} + +.progress-container { + margin-bottom: 20px; +} + +.progress-bar { + width: 100%; + height: 20px; + background-color: #e9ecef; + border-radius: 10px; + overflow: hidden; + margin-bottom: 5px; +} + +.progress-fill { + height: 100%; + background-color: #007acc; + transition: width 0.3s ease; + border-radius: 10px; +} + +.preview-content { + flex: 1; + display: flex; + justify-content: center; + align-items: center; + margin-bottom: 20px; + min-height: 300px; + border: 1px solid #ddd; + border-radius: 4px; + background-color: #f8f9fa; +} + +.generated-image { + max-width: 100%; + max-height: 512px; + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1); +} + +.placeholder { + text-align: center; + color: #666; +} + +.preview-info { + background-color: #e9f7fe; + border: 1px solid #b3e5fc; + border-radius: 4px; + padding: 15px; +} + +.preview-info h4 { + margin-top: 0; + color: #0066cc; +} + +.preview-info p { + margin: 5px 0; +} \ No newline at end of file diff --git a/frontend/src/components/PreviewPane.tsx b/frontend/src/components/PreviewPane.tsx new file mode 100644 index 0000000..7575b3d --- /dev/null +++ b/frontend/src/components/PreviewPane.tsx @@ -0,0 +1,75 @@ +import React, { useState, useEffect } from 'react'; +import './PreviewPane.css'; + +const PreviewPane: React.FC = () => { + const [imageUrl, setImageUrl] = useState(null); + const [isGenerating, setIsGenerating] = useState(false); + const [progress, setProgress] = useState(0); + + // Simulate image generation + const startGeneration = () => { + setIsGenerating(true); + setProgress(0); + + // Simulate progress updates + const interval = setInterval(() => { + setProgress(prev => { + if (prev >= 100) { + clearInterval(interval); + setImageUrl('https://placehold.co/512x512?text=Generated+Image'); + setIsGenerating(false); + return 100; + } + return prev + 10; + }); + }, 300); + }; + + return ( +
+

Preview

+ +
+ +
+ +
+ {isGenerating && ( +
+
+
+ )} +

{isGenerating ? `Progress: ${progress}%` : 'No generation in progress'}

+
+ +
+ {imageUrl ? ( + Generated preview + ) : ( +
+

Preview will appear here after generation

+

Drag nodes from the panel to create a workflow

+
+ )} +
+ +
+

GPU Information

+

Device: Radeon RX 9070 XT

+

Architecture: gfx900

+

Status: Ready for ROCm acceleration

+
+
+ ); +}; + +export default PreviewPane; \ No newline at end of file diff --git a/frontend/src/index.tsx b/frontend/src/index.tsx new file mode 100644 index 0000000..23b3422 --- /dev/null +++ b/frontend/src/index.tsx @@ -0,0 +1,13 @@ +import React from 'react'; +import ReactDOM from 'react-dom/client'; +import './index.css'; +import App from './App'; + +const root = ReactDOM.createRoot( + document.getElementById('root') as HTMLElement +); +root.render( + + + +); \ No newline at end of file diff --git a/frontend/src/utils/api-client.ts b/frontend/src/utils/api-client.ts new file mode 100644 index 0000000..e37ec93 --- /dev/null +++ b/frontend/src/utils/api-client.ts @@ -0,0 +1,100 @@ +import axios from 'axios'; + +// Create an Axios instance with default configuration +const apiClient = axios.create({ + baseURL: 'http://localhost:8080', + timeout: 30000, + headers: { + 'Content-Type': 'application/json', + } +}); + +// Request interceptor to add auth token if needed +apiClient.interceptors.request.use( + (config) => { + // Add authorization header if needed + const token = localStorage.getItem('authToken'); + if (token) { + config.headers.Authorization = `Bearer ${token}`; + } + return config; + }, + (error) => { + return Promise.reject(error); + } +); + +// Response interceptor for handling errors +apiClient.interceptors.response.use( + (response) => { + return response; + }, + (error) => { + if (error.response?.status === 401) { + // Handle unauthorized access + localStorage.removeItem('authToken'); + window.location.href = '/login'; + } + return Promise.reject(error); + } +); + +// API interfaces +export interface InferenceRequest { + prompt: string; + negative_prompt?: string; + guidance_scale?: number; + steps?: number; + width?: number; + height?: number; + seed?: number; + model_name?: string; +} + +export interface InferenceResponse { + task_id: string; + status: string; +} + +export interface TaskStatusResponse { + id: string; + name: string; + status: string; + progress: number; + created_at: number; + updated_at: number; + model_name?: string; +} + +export interface ModelInfo { + name: string; + path: string; + model_type: string; + version: string; + loaded: boolean; +} + +// API functions +export const api = { + // Health check endpoint + healthCheck: () => apiClient.get('/health'), + + // System info including GPU details + getSystemInfo: () => apiClient.get('/system-info'), + + // Start inference task + startInference: (request: InferenceRequest) => + apiClient.post('/infer', request), + + // Get all models + getModels: () => apiClient.get('/models'), + + // Get task status + getTaskStatus: (taskId: string) => + apiClient.get(`/tasks/${taskId}`), + + // Get all tasks + getAllTasks: () => apiClient.get('/tasks'), +}; + +export default apiClient; \ No newline at end of file