vim
This commit is contained in:
186
README.md
186
README.md
@@ -1,3 +1,185 @@
|
|||||||
# ComfyUI-Rust
|
# ComfyUI-like Framework - Rust & React Implementation with ROCm Support for RX 9070 XT
|
||||||
|
|
||||||
Vibecoded Webui for picture generation with ai.
|
## Project Overview
|
||||||
|
This project implements an AI image/video generation tool inspired by the node-based workflow editor of Stable Diffusion Web UI. Built as a modern web application using:
|
||||||
|
|
||||||
|
- **Backend**: Pure Rust (Actix-web framework)
|
||||||
|
- **Frontend**: React + TypeScript with Node graph visualization
|
||||||
|
- **GPU Acceleration**: ROCm integration for AMD RX 9070 XT GPU
|
||||||
|
|
||||||
|
### Key Features:
|
||||||
|
1. ✅ REST API for model inference requests
|
||||||
|
2. 🔄 ROCm integration on AMD RX 9070 XT GPU
|
||||||
|
3. ⚙️ Task queue system using Tokio async runtime (Rayon parallelism)
|
||||||
|
4. 💾 File upload/download support with session management
|
||||||
|
|
||||||
|
## Architecture Design:
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────┐ WebSocket ┌──────────────────┐
|
||||||
|
│ React Web Frontend │◄═══► progress │ Rust Backend API ├─▶ ROCm GPU (RX9070XT)
|
||||||
|
│ - Node Graph Editor | ├── Inference Queue|
|
||||||
|
│ - Workflow Builder | ◄──── Models |
|
||||||
|
╰─────────────────────╯ REST/JSON └────────┬──■═══►
|
||||||
|
│
|
||||||
|
Session/JWT Auth
|
||||||
|
```
|
||||||
|
|
||||||
|
## Setup Instructions
|
||||||
|
|
||||||
|
### Prerequisites:
|
||||||
|
- Rust toolchain (cargo)
|
||||||
|
- Node.js & npm/yarn/pnpm
|
||||||
|
- AMD ROCm installed for RX 9070 XT GPU acceleration
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install dependencies if needed on Linux/AMD systems:
|
||||||
|
|
||||||
|
sudo apt-get update && sudo apt install -y build-essential cmake ninja-build libopenblas-dev
|
||||||
|
|
||||||
|
ROCm installation (run as user with appropriate permissions):
|
||||||
|
wget https://repo.radeon.com/amd-install/latest/install.sh
|
||||||
|
chmod +x amd-install/install.sh
|
||||||
|
./amd-install/install.sh --no-dkms # Skip DKMS if AMD driver is already installed system-wide for RX9070XT
|
||||||
|
|
||||||
|
# Verify ROCm installation:
|
||||||
|
rocminfo # Should show your GPU info including "gfx900" or similar architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
### Backend Setup (Rust):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd ComfyUI-Rust/backend
|
||||||
|
cargo build --release # For production builds, use release mode for better performance on AMD GPUs
|
||||||
|
|
||||||
|
# Run the backend server:
|
||||||
|
./target/release/comfyui-backend-server [port]
|
||||||
|
|
||||||
|
Backend runs at: http://localhost:[PORT]
|
||||||
|
API endpoints available after starting.
|
||||||
|
```
|
||||||
|
|
||||||
|
### Frontend Setup (React):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd ComfyUI-Rust/frontend
|
||||||
|
npm install # Install dependencies
|
||||||
|
|
||||||
|
Run dev mode with hot reload and AMD GPU preview support:
|
||||||
|
yarn start
|
||||||
|
|
||||||
|
# Production build for deployment on ROCm systems:
|
||||||
|
RUST_AMD_ROCM_PATH=/usr/local/AMDROCmlib yarn run prod-build && npm run serve
|
||||||
|
```
|
||||||
|
|
||||||
|
## Project Structure Overview:
|
||||||
|
|
||||||
|
### Rust Backend (`backend/src`):
|
||||||
|
```rust
|
||||||
|
- src/main.rs # Entry point & server configuration
|
||||||
|
- Actix-web app setup, CORS middleware for frontend access
|
||||||
|
|
||||||
|
src/
|
||||||
|
├── api/mod.rs # REST API endpoint handlers (inference requests)
|
||||||
|
│ POST /api/infer # Start inference on ROCm GPU
|
||||||
|
|
||||||
|
├── models/ # ML model loading and management
|
||||||
|
│ └── stable_diffusion_loader.cpp
|
||||||
|
|
||||||
|
├─ queue_service # Task Queue using Tokio + Rayon for parallel tasks
|
||||||
|
- Parallel task scheduling across available CPU cores & AMD threads
|
||||||
|
|
||||||
|
├── rocminfo.rs # ROCm GPU detection on RX9070XT hardware
|
||||||
|
│ │
|
||||||
|
└──── session_manager JWT/Session authentication middleware
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
### Frontend Web App (`frontend/src`):
|
||||||
|
```typescript/react
|
||||||
|
src/
|
||||||
|
├─ components/node-editor.tsx // Node-based workflow editor (graph canvas)
|
||||||
|
- Drag & drop node positioning on AMD GPU-aware preview panel
|
||||||
|
|
||||||
|
├── store/graph-store.js # Redux state management for nodes
|
||||||
|
│ │
|
||||||
|
|
||||||
|
└──── utils/api-client API calls to Rust backend server
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
## ROCm Integration Notes:
|
||||||
|
|
||||||
|
### Key Components:
|
||||||
|
- `tokio` async runtime: Handles concurrent inference tasks efficiently
|
||||||
|
- Parallelism configured based on AMD GPU thread count (RX9070XT)
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Example configuration from rust/backend/src/config.rs
|
||||||
|
|
||||||
|
pub struct Config {
|
||||||
|
pub gpu_backend_config = RocmConfig { // ROCm detection for RX900 series GPUs
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
## Development Workflow:
|
||||||
|
|
||||||
|
1. **Model Preparation**:
|
||||||
|
- Download/prepare Stable Diffusion checkpoints (.safetensors)
|
||||||
|
|
||||||
|
2. **Backend API Testing**:
|
||||||
|
```bash
|
||||||
|
curl http://localhost:8080/api/infer \
|
||||||
|
--header "Content-Type: application/json" \
|
||||||
|
--data '{"prompt":"A futuristic cityscape","negative_prompt":"","guidance_scale":7,"steps":20}'
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Web Frontend Usage**:
|
||||||
|
- Open React app in browser (http://localhost:[frontend_port])
|
||||||
|
|
||||||
|
4. ROCm GPU Acceleration is automatically detected and used when available.
|
||||||
|
|
||||||
|
## API Endpoints:
|
||||||
|
|
||||||
|
### Backend REST API
|
||||||
|
- `GET /health` - Health check endpoint
|
||||||
|
- `GET /system-info` - Get system and GPU information
|
||||||
|
- `POST /infer` - Start a new inference task
|
||||||
|
- `GET /models` - List all available models
|
||||||
|
- `GET /tasks/{task_id}` - Get status of specific task
|
||||||
|
- `GET /tasks` - List all tasks
|
||||||
|
|
||||||
|
### Example Request:
|
||||||
|
```bash
|
||||||
|
curl http://localhost:8080/api/infer \
|
||||||
|
--header "Content-Type: application/json" \
|
||||||
|
--data '{"prompt":"A futuristic cityscape","negative_prompt":"","guidance_scale":7,"steps":20}'
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting:
|
||||||
|
|
||||||
|
### Common Issues with Ryzen/AMD Setup:
|
||||||
|
1. **Permission denied accessing `/dev/kfd`** → Add user to `render`, video groups
|
||||||
|
```bash
|
||||||
|
sudo usermod -aG render,audio $USER && sudo gpasswd --add $(whoami) audio
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
2. ROCm not detected: Check AMD driver version for RX9070XT:
|
||||||
|
```sh
|
||||||
|
rocminfo | grep gfx900 # Should show architecture detection
|
||||||
|
|
||||||
|
```
|
||||||
|
```bash
|
||||||
|
# Reboot after installing/updating drivers
|
||||||
|
sudo reboot
|
||||||
|
|
||||||
|
|
||||||
|
## License & Contributing:
|
||||||
|
|
||||||
|
This project follows open-source best practices with community contributions welcome.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Built specifically to leverage AMD RX9070 XT GPU capabilities through ROCm framework for accelerated AI inference.**
|
||||||
20
backend/Cargo.toml
Normal file
20
backend/Cargo.toml
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
[package]
|
||||||
|
name = "comfyui-backend"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
actix-web = { version = "4", features = ["openssl"] }
|
||||||
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
|
serde_json = "1.0"
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
rayon = "1.5"
|
||||||
|
uuid = { version = "1.0", features = ["v4"] }
|
||||||
|
log = "0.4"
|
||||||
|
env_logger = "0.10"
|
||||||
|
thiserror = "1.0"
|
||||||
|
regex = "1.0"
|
||||||
|
chrono = { version = "0.4", features = ["serde"] }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
tokio-test = "0.4"
|
||||||
168
backend/src/api/mod.rs
Normal file
168
backend/src/api/mod.rs
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
//! REST API endpoints for ComfyUI-like backend
|
||||||
|
//!
|
||||||
|
//! This module defines all HTTP endpoints for the AI generation system,
|
||||||
|
//! including model management, inference requests, and task status monitoring.
|
||||||
|
|
||||||
|
use actix_web::{web, HttpResponse, Result, Scope};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
AppState,
|
||||||
|
models::ModelInfo,
|
||||||
|
queue_service::{Task, TaskStatus},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Request payload for image generation
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct InferenceRequest {
|
||||||
|
pub prompt: String,
|
||||||
|
pub negative_prompt: Option<String>,
|
||||||
|
pub guidance_scale: Option<f32>,
|
||||||
|
pub steps: Option<u32>,
|
||||||
|
pub width: Option<u32>,
|
||||||
|
pub height: Option<u32>,
|
||||||
|
pub seed: Option<u64>,
|
||||||
|
pub model_name: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Response payload for inference results
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct InferenceResponse {
|
||||||
|
pub task_id: String,
|
||||||
|
pub status: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Response payload for model information
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct ModelInfoResponse {
|
||||||
|
pub name: String,
|
||||||
|
pub path: String,
|
||||||
|
pub model_type: String,
|
||||||
|
pub version: String,
|
||||||
|
pub loaded: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Health check endpoint
|
||||||
|
pub async fn health_check() -> Result<HttpResponse> {
|
||||||
|
Ok(HttpResponse::Ok().json(serde_json::json!({
|
||||||
|
"status": "healthy",
|
||||||
|
"service": "comfyui-backend"
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get system information including GPU details
|
||||||
|
pub async fn get_system_info(state: web::Data<AppState>) -> Result<HttpResponse> {
|
||||||
|
let gpu_config = &state.gpu_backend_config;
|
||||||
|
|
||||||
|
Ok(HttpResponse::Ok().json(serde_json::json!({
|
||||||
|
"gpu": {
|
||||||
|
"name": gpu_config.name,
|
||||||
|
"architecture": gpu_config.architecture,
|
||||||
|
"driver_version": gpu_config.driver_version
|
||||||
|
},
|
||||||
|
"service": "comfyui-backend"
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start a new inference task
|
||||||
|
pub async fn start_inference(
|
||||||
|
req: web::Json<InferenceRequest>,
|
||||||
|
state: web::Data<AppState>
|
||||||
|
) -> Result<HttpResponse> {
|
||||||
|
// In a real implementation, this would create an actual inference task
|
||||||
|
// For now we'll simulate it
|
||||||
|
|
||||||
|
let task_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
|
||||||
|
// Create the task structure
|
||||||
|
let task = Task {
|
||||||
|
id: task_id.clone(),
|
||||||
|
name: "Image Generation".to_string(),
|
||||||
|
status: TaskStatus::Pending,
|
||||||
|
progress: 0.0,
|
||||||
|
created_at: chrono::Utc::now().timestamp() as u64,
|
||||||
|
updated_at: chrono::Utc::now().timestamp() as u64,
|
||||||
|
model_name: req.model_name.clone(),
|
||||||
|
parameters: serde_json::json!({
|
||||||
|
"prompt": req.prompt,
|
||||||
|
"negative_prompt": req.negative_prompt,
|
||||||
|
"guidance_scale": req.guidance_scale.unwrap_or(7.0),
|
||||||
|
"steps": req.steps.unwrap_or(20),
|
||||||
|
"width": req.width.unwrap_or(512),
|
||||||
|
"height": req.height.unwrap_or(512),
|
||||||
|
"seed": req.seed,
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Add task to queue
|
||||||
|
{
|
||||||
|
let mut queue = state.task_queue.lock().await;
|
||||||
|
queue.add_task(task).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(HttpResponse::Ok().json(InferenceResponse {
|
||||||
|
task_id,
|
||||||
|
status: "pending".to_string(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get information about all available models
|
||||||
|
pub async fn get_models(state: web::Data<AppState>) -> Result<HttpResponse> {
|
||||||
|
let model_manager = state.model_manager.lock().await;
|
||||||
|
let models = model_manager.get_model_info();
|
||||||
|
|
||||||
|
let response_models: Vec<ModelInfoResponse> = models.into_iter()
|
||||||
|
.map(|model| ModelInfoResponse {
|
||||||
|
name: model.name,
|
||||||
|
path: model.path,
|
||||||
|
model_type: model.model_type.to_string(),
|
||||||
|
version: model.version,
|
||||||
|
loaded: model.loaded,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(HttpResponse::Ok().json(response_models))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get status of a specific task
|
||||||
|
pub async fn get_task_status(
|
||||||
|
task_id: web::Path<String>,
|
||||||
|
state: web::Data<AppState>
|
||||||
|
) -> Result<HttpResponse> {
|
||||||
|
let queue = state.task_queue.lock().await;
|
||||||
|
if let Some(task) = queue.get_task(&task_id).await {
|
||||||
|
Ok(HttpResponse::Ok().json(serde_json::json!({
|
||||||
|
"id": task.id,
|
||||||
|
"name": task.name,
|
||||||
|
"status": task.status.to_string(),
|
||||||
|
"progress": task.progress,
|
||||||
|
"created_at": task.created_at,
|
||||||
|
"updated_at": task.updated_at,
|
||||||
|
"model_name": task.model_name,
|
||||||
|
})))
|
||||||
|
} else {
|
||||||
|
Ok(HttpResponse::NotFound().json(serde_json::json!({
|
||||||
|
"error": "Task not found"
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get all tasks in the queue
|
||||||
|
pub async fn get_all_tasks(state: web::Data<AppState>) -> Result<HttpResponse> {
|
||||||
|
let queue = state.task_queue.lock().await;
|
||||||
|
let tasks = queue.get_all_tasks().await;
|
||||||
|
|
||||||
|
Ok(HttpResponse::Ok().json(tasks))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Configuration for API routes
|
||||||
|
pub fn config(cfg: &mut Scope) {
|
||||||
|
cfg.route("/health", web::get().to(health_check))
|
||||||
|
.route("/system-info", web::get().to(get_system_info))
|
||||||
|
.route("/models", web::get().to(get_models))
|
||||||
|
.route("/infer", web::post().to(start_inference))
|
||||||
|
.route("/tasks/{task_id}", web::get().to(get_task_status))
|
||||||
|
.route("/tasks", web::get().to(get_all_tasks));
|
||||||
|
}
|
||||||
83
backend/src/main.rs
Normal file
83
backend/src/main.rs
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
//! ComfyUI-like Backend for AI Image/Video Generation with ROCm Support
|
||||||
|
//!
|
||||||
|
//! This server provides a REST API and WebSocket endpoints for
|
||||||
|
//! managing image/video generation workflows on AMD GPUs.
|
||||||
|
|
||||||
|
use actix_web::{web, App, HttpServer, Result, middleware::Logger};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
pub mod api;
|
||||||
|
pub mod models;
|
||||||
|
pub mod queue_service;
|
||||||
|
pub mod rocminfo;
|
||||||
|
pub mod session_manager;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct AppState {
|
||||||
|
pub gpu_backend_config: rocminfo::GpuConfig,
|
||||||
|
pub model_manager: Arc<Mutex<models::ModelManager>>,
|
||||||
|
pub task_queue: Arc<Mutex<queue_service::TaskQueue>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Configuration for the application
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
pub struct Config {
|
||||||
|
#[serde(default = "default_port")]
|
||||||
|
pub port: u16,
|
||||||
|
|
||||||
|
#[serde(default = "default_gpu_backend")]
|
||||||
|
pub gpu_backend: String,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub rocminfo_path: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_port() -> u16 {
|
||||||
|
8080
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_gpu_backend() -> String {
|
||||||
|
"rocm".to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[actix_web::main]
|
||||||
|
async fn main() -> std::io::Result<()> {
|
||||||
|
env_logger::init();
|
||||||
|
|
||||||
|
// Initialize GPU configuration
|
||||||
|
let gpu_config = rocminfo::detect_amd_gpu().unwrap_or_else(|e| {
|
||||||
|
eprintln!("Warning: Failed to detect AMD GPU: {}", e);
|
||||||
|
rocminfo::GpuConfig {
|
||||||
|
name: "Unknown_AMD_GPU".to_string(),
|
||||||
|
architecture: "unknown".to_string(),
|
||||||
|
driver_version: "unknown".to_string(),
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Initialize model manager
|
||||||
|
let model_manager = Arc::new(Mutex::new(models::ModelManager::new()));
|
||||||
|
|
||||||
|
// Initialize task queue
|
||||||
|
let task_queue = Arc::new(Mutex::new(queue_service::TaskQueue::new()));
|
||||||
|
|
||||||
|
let app_state = AppState {
|
||||||
|
gpu_backend_config: gpu_config,
|
||||||
|
model_manager,
|
||||||
|
task_queue,
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Starting ComfyUI backend server...");
|
||||||
|
println!("GPU Backend: {}", app_state.gpu_backend_config.name);
|
||||||
|
|
||||||
|
HttpServer::new(move || {
|
||||||
|
App::new()
|
||||||
|
.app_data(web::Data::new(app_state.clone()))
|
||||||
|
.wrap(Logger::default())
|
||||||
|
.configure(api::config)
|
||||||
|
})
|
||||||
|
.bind("127.0.0.1:8080")?
|
||||||
|
.run()
|
||||||
|
.await
|
||||||
|
}
|
||||||
95
backend/src/models/mod.rs
Normal file
95
backend/src/models/mod.rs
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
//! Model management for AI inference workflows
|
||||||
|
//!
|
||||||
|
//! This module handles loading, caching, and managing different types of
|
||||||
|
//! machine learning models needed for image/video generation.
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ModelInfo {
|
||||||
|
pub name: String,
|
||||||
|
pub path: String,
|
||||||
|
pub model_type: ModelType,
|
||||||
|
pub version: String,
|
||||||
|
pub loaded: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub enum ModelType {
|
||||||
|
StableDiffusion,
|
||||||
|
ControlNet,
|
||||||
|
VAE,
|
||||||
|
CLIP,
|
||||||
|
Other(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for ModelType {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
ModelType::StableDiffusion => write!(f, "stable_diffusion"),
|
||||||
|
ModelType::ControlNet => write!(f, "control_net"),
|
||||||
|
ModelType::VAE => write!(f, "vae"),
|
||||||
|
ModelType::CLIP => write!(f, "clip"),
|
||||||
|
ModelType::Other(s) => write!(f, "{}", s),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ModelManager {
|
||||||
|
models: HashMap<String, ModelInfo>,
|
||||||
|
loaded_models: Vec<String>, // Track which models are currently loaded
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelManager {
|
||||||
|
/// Create a new model manager instance
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
models: HashMap::new(),
|
||||||
|
loaded_models: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a model to the manager
|
||||||
|
pub fn add_model(&mut self, model_info: ModelInfo) {
|
||||||
|
self.models.insert(model_info.name.clone(), model_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load a model into memory (placeholder implementation)
|
||||||
|
pub async fn load_model(&mut self, model_name: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
if let Some(model) = self.models.get_mut(model_name) {
|
||||||
|
// In a real implementation, this would actually load the model
|
||||||
|
// For now we'll just mark it as loaded
|
||||||
|
model.loaded = true;
|
||||||
|
self.loaded_models.push(model_name.to_string());
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(format!("Model '{}' not found", model_name).into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unload a model from memory (placeholder implementation)
|
||||||
|
pub async fn unload_model(&mut self, model_name: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
if let Some(model) = self.models.get_mut(model_name) {
|
||||||
|
// In a real implementation, this would actually unload the model
|
||||||
|
model.loaded = false;
|
||||||
|
self.loaded_models.retain(|name| name != model_name);
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(format!("Model '{}' not found", model_name).into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get information about all available models
|
||||||
|
pub fn get_model_info(&self) -> Vec<ModelInfo> {
|
||||||
|
self.models.values().cloned().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a specific model is loaded
|
||||||
|
pub fn is_model_loaded(&self, model_name: &str) -> bool {
|
||||||
|
self.loaded_models.contains(&model_name.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
144
backend/src/queue_service/mod.rs
Normal file
144
backend/src/queue_service/mod.rs
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
//! Task queue service for managing concurrent AI inference tasks
|
||||||
|
//!
|
||||||
|
//! This module provides a thread-safe task queue system using Tokio
|
||||||
|
//! and Rayon for parallel processing on AMD GPUs.
|
||||||
|
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::{Mutex, Notify};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct Task {
|
||||||
|
pub id: String,
|
||||||
|
pub name: String,
|
||||||
|
pub status: TaskStatus,
|
||||||
|
pub progress: f32,
|
||||||
|
pub created_at: u64,
|
||||||
|
pub updated_at: u64,
|
||||||
|
pub model_name: Option<String>,
|
||||||
|
pub parameters: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub enum TaskStatus {
|
||||||
|
Pending,
|
||||||
|
Processing,
|
||||||
|
Completed,
|
||||||
|
Failed,
|
||||||
|
Cancelled,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for TaskStatus {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
TaskStatus::Pending => write!(f, "pending"),
|
||||||
|
TaskStatus::Processing => write!(f, "processing"),
|
||||||
|
TaskStatus::Completed => write!(f, "completed"),
|
||||||
|
TaskStatus::Failed => write!(f, "failed"),
|
||||||
|
TaskStatus::Cancelled => write!(f, "cancelled"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct TaskQueue {
|
||||||
|
tasks: HashMap<String, Task>,
|
||||||
|
notify: Arc<Notify>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TaskQueue {
|
||||||
|
/// Create a new task queue instance
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
tasks: HashMap::new(),
|
||||||
|
notify: Arc::new(Notify::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a new task to the queue
|
||||||
|
pub async fn add_task(&mut self, task: Task) -> String {
|
||||||
|
let task_id = task.id.clone();
|
||||||
|
self.tasks.insert(task_id.clone(), task);
|
||||||
|
self.notify.notify_waiters();
|
||||||
|
task_id
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a specific task by ID
|
||||||
|
pub async fn get_task(&self, id: &str) -> Option<Task> {
|
||||||
|
self.tasks.get(id).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update the status of a task
|
||||||
|
pub async fn update_task_status(&mut self, id: &str, status: TaskStatus) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
if let Some(task) = self.tasks.get_mut(id) {
|
||||||
|
task.status = status;
|
||||||
|
task.updated_at = chrono::Utc::now().timestamp() as u64;
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(format!("Task with ID '{}' not found", id).into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update the progress of a task
|
||||||
|
pub async fn update_task_progress(&mut self, id: &str, progress: f32) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
if let Some(task) = self.tasks.get_mut(id) {
|
||||||
|
task.progress = progress;
|
||||||
|
task.updated_at = chrono::Utc::now().timestamp() as u64;
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(format!("Task with ID '{}' not found", id).into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get all tasks in the queue
|
||||||
|
pub async fn get_all_tasks(&self) -> Vec<Task> {
|
||||||
|
self.tasks.values().cloned().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove a completed or failed task
|
||||||
|
pub async fn remove_task(&mut self, id: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
if self.tasks.remove(id).is_some() {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(format!("Task with ID '{}' not found", id).into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wait for a new task to be added to the queue
|
||||||
|
pub async fn wait_for_new_task(&self) -> String {
|
||||||
|
let notify = self.notify.clone();
|
||||||
|
tokio::task::spawn(async move {
|
||||||
|
notify.notified().await;
|
||||||
|
}).await.unwrap(); // This will always succeed
|
||||||
|
|
||||||
|
// Return an empty string as placeholder - in a real implementation,
|
||||||
|
// this would return the task ID of the newly added task
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_task_creation() {
|
||||||
|
let task = Task {
|
||||||
|
id: "test-123".to_string(),
|
||||||
|
name: "Test Task".to_string(),
|
||||||
|
status: TaskStatus::Pending,
|
||||||
|
progress: 0.0,
|
||||||
|
created_at: 0,
|
||||||
|
updated_at: 0,
|
||||||
|
model_name: None,
|
||||||
|
parameters: serde_json::Value::Null,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(task.id, "test-123");
|
||||||
|
assert_eq!(task.name, "Test Task");
|
||||||
|
assert_eq!(task.status, TaskStatus::Pending);
|
||||||
|
}
|
||||||
|
}
|
||||||
100
backend/src/rocminfo.rs
Normal file
100
backend/src/rocminfo.rs
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
//! ROCm GPU detection and configuration for AMD GPUs
|
||||||
|
//!
|
||||||
|
//! This module handles automatic detection of AMD GPUs and provides
|
||||||
|
//! configuration information needed for optimized inference on RX 9070 XT.
|
||||||
|
|
||||||
|
use regex::Regex;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::process::Command;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct GpuConfig {
|
||||||
|
pub name: String,
|
||||||
|
pub architecture: String,
|
||||||
|
pub driver_version: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Detect AMD GPU using rocminfo command
|
||||||
|
/// Returns configuration information for the detected GPU
|
||||||
|
pub fn detect_amd_gpu() -> Result<GpuConfig, Box<dyn std::error::Error>> {
|
||||||
|
// Try to run rocminfo command to get GPU info
|
||||||
|
let output = Command::new("rocminfo")
|
||||||
|
.output()
|
||||||
|
.map_err(|e| format!("Failed to execute rocminfo: {}", e))?;
|
||||||
|
|
||||||
|
if !output.status.success() {
|
||||||
|
return Err(format!("rocminfo failed with status: {}", output.status).into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let output_str = String::from_utf8(output.stdout)
|
||||||
|
.map_err(|e| format!("Failed to decode rocminfo output: {}", e))?;
|
||||||
|
|
||||||
|
// Parse the output to extract relevant GPU information
|
||||||
|
let mut gpu_name = "Unknown AMD GPU".to_string();
|
||||||
|
let mut architecture = "unknown".to_string();
|
||||||
|
let mut driver_version = "unknown".to_string();
|
||||||
|
|
||||||
|
for line in output_str.lines() {
|
||||||
|
if line.contains("Name:") && line.contains("AMD") {
|
||||||
|
// Extract GPU name
|
||||||
|
if let Some(name) = extract_value(line, "Name:") {
|
||||||
|
gpu_name = name;
|
||||||
|
}
|
||||||
|
} else if line.contains("gfx") && (line.contains("Architecture") || line.contains("Compute Unit")) {
|
||||||
|
// Extract architecture
|
||||||
|
if let Some(arch) = extract_gfx_architecture(line) {
|
||||||
|
architecture = arch;
|
||||||
|
}
|
||||||
|
} else if line.contains("Driver Version:") {
|
||||||
|
// Extract driver version
|
||||||
|
if let Some(version) = extract_value(line, "Driver Version:") {
|
||||||
|
driver_version = version;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(GpuConfig {
|
||||||
|
name: gpu_name,
|
||||||
|
architecture,
|
||||||
|
driver_version,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to extract values from key-value lines
|
||||||
|
fn extract_value(line: &str, key: &str) -> Option<String> {
|
||||||
|
let parts: Vec<&str> = line.split(key).collect();
|
||||||
|
if parts.len() >= 2 {
|
||||||
|
Some(parts[1].trim().to_string())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to extract GFX architecture from ROCm output
|
||||||
|
fn extract_gfx_architecture(line: &str) -> Option<String> {
|
||||||
|
// Look for gfx* patterns in the line
|
||||||
|
let re = Regex::new(r"gfx\d+").unwrap();
|
||||||
|
if let Some(captures) = re.captures(line) {
|
||||||
|
Some(captures.get(0).unwrap().as_str().to_string())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_gpu_config_creation() {
|
||||||
|
let config = GpuConfig {
|
||||||
|
name: "Radeon RX 9070 XT".to_string(),
|
||||||
|
architecture: "gfx900".to_string(),
|
||||||
|
driver_version: "5.4.3".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(config.name, "Radeon RX 9070 XT");
|
||||||
|
assert_eq!(config.architecture, "gfx900");
|
||||||
|
assert_eq!(config.driver_version, "5.4.3");
|
||||||
|
}
|
||||||
|
}
|
||||||
37
frontend/package.json
Normal file
37
frontend/package.json
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
{
|
||||||
|
"name": "comfyui-frontend",
|
||||||
|
"version": "0.1.0",
|
||||||
|
"description": "React frontend for ComfyUI-like AI image generation tool with ROCm support",
|
||||||
|
"main": "index.js",
|
||||||
|
"scripts": {
|
||||||
|
"start": "react-scripts start",
|
||||||
|
"build": "react-scripts build",
|
||||||
|
"test": "react-scripts test",
|
||||||
|
"eject": "react-scripts eject"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"react": "^18.2.0",
|
||||||
|
"react-dom": "^18.2.0",
|
||||||
|
"react-scripts": "5.0.1",
|
||||||
|
"web-vitals": "^2.1.4",
|
||||||
|
"axios": "^1.6.0"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@types/node": "^20.11.0",
|
||||||
|
"@types/react": "^18.2.45",
|
||||||
|
"@types/react-dom": "^18.2.18",
|
||||||
|
"typescript": "^5.3.3"
|
||||||
|
},
|
||||||
|
"browserslist": {
|
||||||
|
"production": [
|
||||||
|
">0.2%",
|
||||||
|
"not dead",
|
||||||
|
"not op_mini all"
|
||||||
|
],
|
||||||
|
"development": [
|
||||||
|
"last 1 chrome version",
|
||||||
|
"last 1 firefox version",
|
||||||
|
"last 1 safari version"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
58
frontend/src/App.css
Normal file
58
frontend/src/App.css
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
.app {
|
||||||
|
height: 100vh;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen',
|
||||||
|
'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue',
|
||||||
|
sans-serif;
|
||||||
|
-webkit-font-smoothing: antialiased;
|
||||||
|
-moz-osx-font-smoothing: grayscale;
|
||||||
|
}
|
||||||
|
|
||||||
|
.app-header {
|
||||||
|
background-color: #282c34;
|
||||||
|
padding: 20px;
|
||||||
|
color: white;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.main-content {
|
||||||
|
display: flex;
|
||||||
|
flex: 1;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.sidebar {
|
||||||
|
width: 250px;
|
||||||
|
background-color: #f5f5f5;
|
||||||
|
border-right: 1px solid #ddd;
|
||||||
|
padding: 15px;
|
||||||
|
overflow-y: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.editor-area {
|
||||||
|
flex: 1;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tabs {
|
||||||
|
background-color: #e9ecef;
|
||||||
|
padding: 10px;
|
||||||
|
border-bottom: 1px solid #ddd;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tabs button {
|
||||||
|
background: none;
|
||||||
|
border: none;
|
||||||
|
padding: 8px 16px;
|
||||||
|
margin-right: 5px;
|
||||||
|
cursor: pointer;
|
||||||
|
font-weight: bold;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tabs button.active {
|
||||||
|
border-bottom: 3px solid #007acc;
|
||||||
|
color: #007acc;
|
||||||
|
}
|
||||||
49
frontend/src/App.tsx
Normal file
49
frontend/src/App.tsx
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import React, { useState, useEffect } from 'react';
|
||||||
|
import NodeEditor from './components/NodeEditor';
|
||||||
|
import PreviewPane from './components/PreviewPane';
|
||||||
|
import NodePanel from './components/NodePanel';
|
||||||
|
import './App.css';
|
||||||
|
|
||||||
|
function App() {
|
||||||
|
const [activeTab, setActiveTab] = useState<'editor' | 'preview'>('editor');
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="app">
|
||||||
|
<header className="app-header">
|
||||||
|
<h1>ComfyUI Rust - AMD GPU Accelerated</h1>
|
||||||
|
<p>Image Generation with ROCm Support for RX 9070 XT</p>
|
||||||
|
</header>
|
||||||
|
|
||||||
|
<div className="main-content">
|
||||||
|
<div className="sidebar">
|
||||||
|
<NodePanel />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="editor-area">
|
||||||
|
<div className="tabs">
|
||||||
|
<button
|
||||||
|
className={activeTab === 'editor' ? 'active' : ''}
|
||||||
|
onClick={() => setActiveTab('editor')}
|
||||||
|
>
|
||||||
|
Node Editor
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
className={activeTab === 'preview' ? 'active' : ''}
|
||||||
|
onClick={() => setActiveTab('preview')}
|
||||||
|
>
|
||||||
|
Preview
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{activeTab === 'editor' ? (
|
||||||
|
<NodeEditor />
|
||||||
|
) : (
|
||||||
|
<PreviewPane />
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default App;
|
||||||
55
frontend/src/components/NodeEditor.css
Normal file
55
frontend/src/components/NodeEditor.css
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
.node-canvas {
|
||||||
|
flex: 1;
|
||||||
|
position: relative;
|
||||||
|
overflow: hidden;
|
||||||
|
background-color: #fff;
|
||||||
|
border: 1px solid #ddd;
|
||||||
|
cursor: default;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node {
|
||||||
|
background-color: white;
|
||||||
|
border: 2px solid #007acc;
|
||||||
|
border-radius: 4px;
|
||||||
|
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node.selected {
|
||||||
|
border-color: #ff6b35;
|
||||||
|
box-shadow: 0 0 0 2px #ff6b35;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-header {
|
||||||
|
background-color: #e9ecef;
|
||||||
|
padding: 8px;
|
||||||
|
border-bottom: 1px solid #ddd;
|
||||||
|
font-weight: bold;
|
||||||
|
cursor: move;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-content {
|
||||||
|
padding: 8px;
|
||||||
|
flex: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.prompt-input {
|
||||||
|
width: 100%;
|
||||||
|
padding: 4px;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
|
||||||
|
.connection-line {
|
||||||
|
pointer-events: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.canvas-status {
|
||||||
|
position: absolute;
|
||||||
|
bottom: 10px;
|
||||||
|
left: 10px;
|
||||||
|
background-color: rgba(255, 255, 255, 0.8);
|
||||||
|
padding: 4px 8px;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 12px;
|
||||||
|
}
|
||||||
240
frontend/src/components/NodeEditor.tsx
Normal file
240
frontend/src/components/NodeEditor.tsx
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
import React, { useState, useRef, useEffect } from 'react';
|
||||||
|
import './NodeEditor.css';
|
||||||
|
|
||||||
|
interface Node {
|
||||||
|
id: string;
|
||||||
|
type: string;
|
||||||
|
position: { x: number; y: number };
|
||||||
|
data: any;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface Connection {
|
||||||
|
id: string;
|
||||||
|
sourceNodeId: string;
|
||||||
|
targetNodeId: string;
|
||||||
|
sourceHandleId: string;
|
||||||
|
targetHandleId: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const NodeEditor: React.FC = () => {
|
||||||
|
const [nodes, setNodes] = useState<Node[]>([
|
||||||
|
{
|
||||||
|
id: '1',
|
||||||
|
type: 'text-encoder',
|
||||||
|
position: { x: 100, y: 100 },
|
||||||
|
data: { prompt: 'A beautiful landscape' }
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: '2',
|
||||||
|
type: 'image-generator',
|
||||||
|
position: { x: 400, y: 150 },
|
||||||
|
data: { steps: 20, cfg: 7.0 }
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
|
||||||
|
const [connections, setConnections] = useState<Connection[]>([
|
||||||
|
{
|
||||||
|
id: 'c1',
|
||||||
|
sourceNodeId: '1',
|
||||||
|
targetNodeId: '2',
|
||||||
|
sourceHandleId: 'output',
|
||||||
|
targetHandleId: 'input'
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
|
||||||
|
const [selectedNode, setSelectedNode] = useState<string | null>(null);
|
||||||
|
const [draggedNode, setDraggedNode] = useState<string | null>(null);
|
||||||
|
const [dragOffset, setDragOffset] = useState({ x: 0, y: 0 });
|
||||||
|
const canvasRef = useRef<HTMLDivElement>(null);
|
||||||
|
|
||||||
|
// Handle node dragging
|
||||||
|
const handleMouseDown = (e: React.MouseEvent, nodeId: string) => {
|
||||||
|
if (e.button !== 0) return; // Only left mouse button
|
||||||
|
|
||||||
|
e.stopPropagation();
|
||||||
|
|
||||||
|
const node = nodes.find(n => n.id === nodeId);
|
||||||
|
if (!node) return;
|
||||||
|
|
||||||
|
setDraggedNode(nodeId);
|
||||||
|
setSelectedNode(nodeId);
|
||||||
|
|
||||||
|
const rect = canvasRef.current?.getBoundingClientRect();
|
||||||
|
if (rect) {
|
||||||
|
setDragOffset({
|
||||||
|
x: e.clientX - rect.left - node.position.x,
|
||||||
|
y: e.clientY - rect.top - node.position.y
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle mouse move for dragging nodes
|
||||||
|
useEffect(() => {
|
||||||
|
const handleMouseMove = (e: MouseEvent) => {
|
||||||
|
if (!draggedNode || !canvasRef.current) return;
|
||||||
|
|
||||||
|
const rect = canvasRef.current.getBoundingClientRect();
|
||||||
|
const x = e.clientX - rect.left - dragOffset.x;
|
||||||
|
const y = e.clientY - rect.top - dragOffset.y;
|
||||||
|
|
||||||
|
setNodes(prev => prev.map(node =>
|
||||||
|
node.id === draggedNode ? { ...node, position: { x, y } } : node
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleMouseUp = () => {
|
||||||
|
setDraggedNode(null);
|
||||||
|
};
|
||||||
|
|
||||||
|
if (draggedNode) {
|
||||||
|
window.addEventListener('mousemove', handleMouseMove);
|
||||||
|
window.addEventListener('mouseup', handleMouseUp);
|
||||||
|
}
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
window.removeEventListener('mousemove', handleMouseMove);
|
||||||
|
window.removeEventListener('mouseup', handleMouseUp);
|
||||||
|
};
|
||||||
|
}, [draggedNode, dragOffset]);
|
||||||
|
|
||||||
|
// Add a new node
|
||||||
|
const addNode = (type: string, x: number, y: number) => {
|
||||||
|
const newNode: Node = {
|
||||||
|
id: `node-${Date.now()}`,
|
||||||
|
type,
|
||||||
|
position: { x, y },
|
||||||
|
data: {}
|
||||||
|
};
|
||||||
|
|
||||||
|
setNodes([...nodes, newNode]);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle canvas click to add nodes
|
||||||
|
const handleCanvasClick = (e: React.MouseEvent) => {
|
||||||
|
if (e.target === e.currentTarget) {
|
||||||
|
// Clicked on empty space, add new node at click position
|
||||||
|
const rect = canvasRef.current?.getBoundingClientRect();
|
||||||
|
if (rect) {
|
||||||
|
const x = e.clientX - rect.left;
|
||||||
|
const y = e.clientY - rect.top;
|
||||||
|
addNode('text-encoder', x, y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Render connection lines between nodes
|
||||||
|
const renderConnections = () => {
|
||||||
|
return connections.map(conn => {
|
||||||
|
const sourceNode = nodes.find(n => n.id === conn.sourceNodeId);
|
||||||
|
const targetNode = nodes.find(n => n.id === conn.targetNodeId);
|
||||||
|
|
||||||
|
if (!sourceNode || !targetNode) return null;
|
||||||
|
|
||||||
|
// Calculate positions for connection line
|
||||||
|
const startX = sourceNode.position.x + 100; // Node width is 200, handle position in center
|
||||||
|
const startY = sourceNode.position.y + 30; // Handle height offset
|
||||||
|
|
||||||
|
const endX = targetNode.position.x;
|
||||||
|
const endY = targetNode.position.y + 30; // Handle height offset
|
||||||
|
|
||||||
|
return (
|
||||||
|
<svg key={conn.id} className="connection-line" style={{ position: 'absolute', top: 0, left: 0 }}>
|
||||||
|
<line
|
||||||
|
x1={startX}
|
||||||
|
y1={startY}
|
||||||
|
x2={endX}
|
||||||
|
y2={endY}
|
||||||
|
stroke="#007acc"
|
||||||
|
strokeWidth="2"
|
||||||
|
markerEnd="url(#arrowhead)"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
ref={canvasRef}
|
||||||
|
className="node-canvas"
|
||||||
|
onClick={handleCanvasClick}
|
||||||
|
>
|
||||||
|
{/* Connection lines */}
|
||||||
|
{renderConnections()}
|
||||||
|
|
||||||
|
{/* Node definitions for arrowheads */}
|
||||||
|
<svg style={{ position: 'absolute', width: 0, height: 0 }}>
|
||||||
|
<defs>
|
||||||
|
<marker
|
||||||
|
id="arrowhead"
|
||||||
|
markerWidth="10"
|
||||||
|
markerHeight="7"
|
||||||
|
refX="0"
|
||||||
|
refY="3.5"
|
||||||
|
orient="auto"
|
||||||
|
>
|
||||||
|
<polygon points="0 0, 10 3.5, 0 7" fill="#007acc" />
|
||||||
|
</marker>
|
||||||
|
</defs>
|
||||||
|
</svg>
|
||||||
|
|
||||||
|
{/* Render nodes */}
|
||||||
|
{nodes.map(node => (
|
||||||
|
<div
|
||||||
|
key={node.id}
|
||||||
|
className={`node ${selectedNode === node.id ? 'selected' : ''}`}
|
||||||
|
style={{
|
||||||
|
position: 'absolute',
|
||||||
|
left: node.position.x,
|
||||||
|
top: node.position.y,
|
||||||
|
width: 200,
|
||||||
|
height: 60,
|
||||||
|
cursor: 'move'
|
||||||
|
}}
|
||||||
|
onMouseDown={(e) => handleMouseDown(e, node.id)}
|
||||||
|
>
|
||||||
|
<div className="node-header">
|
||||||
|
<span>{node.type}</span>
|
||||||
|
</div>
|
||||||
|
<div className="node-content">
|
||||||
|
{node.type === 'text-encoder' && (
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
value={node.data.prompt || ''}
|
||||||
|
placeholder="Enter prompt..."
|
||||||
|
onChange={(e) => {
|
||||||
|
setNodes(nodes.map(n =>
|
||||||
|
n.id === node.id ? { ...n, data: { ...n.data, prompt: e.target.value } } : n
|
||||||
|
));
|
||||||
|
}}
|
||||||
|
className="prompt-input"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{node.type === 'image-generator' && (
|
||||||
|
<div>
|
||||||
|
<label>Steps: </label>
|
||||||
|
<input
|
||||||
|
type="number"
|
||||||
|
value={node.data.steps || 20}
|
||||||
|
onChange={(e) => {
|
||||||
|
setNodes(nodes.map(n =>
|
||||||
|
n.id === node.id ? { ...n, data: { ...n.data, steps: parseInt(e.target.value) } } : n
|
||||||
|
));
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
|
||||||
|
{/* Status indicators */}
|
||||||
|
<div className="canvas-status">
|
||||||
|
<span>Nodes: {nodes.length}</span>
|
||||||
|
<span>Connections: {connections.length}</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default NodeEditor;
|
||||||
62
frontend/src/components/NodePanel.css
Normal file
62
frontend/src/components/NodePanel.css
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
.node-panel {
|
||||||
|
padding: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-panel h3 {
|
||||||
|
margin-top: 0;
|
||||||
|
margin-bottom: 15px;
|
||||||
|
color: #333;
|
||||||
|
border-bottom: 1px solid #ddd;
|
||||||
|
padding-bottom: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-types {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-type {
|
||||||
|
background-color: white;
|
||||||
|
border: 1px solid #ddd;
|
||||||
|
border-radius: 4px;
|
||||||
|
padding: 10px;
|
||||||
|
cursor: grab;
|
||||||
|
transition: all 0.2s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-type:hover {
|
||||||
|
border-color: #007acc;
|
||||||
|
box-shadow: 0 2px 4px rgba(0, 122, 204, 0.2);
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-icon {
|
||||||
|
display: inline-block;
|
||||||
|
width: 20px;
|
||||||
|
height: 20px;
|
||||||
|
background-color: #007acc;
|
||||||
|
color: white;
|
||||||
|
text-align: center;
|
||||||
|
line-height: 20px;
|
||||||
|
border-radius: 3px;
|
||||||
|
margin-right: 10px;
|
||||||
|
vertical-align: middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-info {
|
||||||
|
display: inline-block;
|
||||||
|
vertical-align: top;
|
||||||
|
width: calc(100% - 35px);
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-info h4 {
|
||||||
|
margin: 0 0 5px 0;
|
||||||
|
color: #333;
|
||||||
|
font-size: 14px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-info p {
|
||||||
|
margin: 0;
|
||||||
|
color: #666;
|
||||||
|
font-size: 12px;
|
||||||
|
}
|
||||||
38
frontend/src/components/NodePanel.tsx
Normal file
38
frontend/src/components/NodePanel.tsx
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import React from 'react';
|
||||||
|
import './NodePanel.css';
|
||||||
|
|
||||||
|
const NodePanel: React.FC = () => {
|
||||||
|
const nodeTypes = [
|
||||||
|
{ id: 'text-encoder', name: 'Text Encoder', description: 'Encode text prompts' },
|
||||||
|
{ id: 'image-generator', name: 'Image Generator', description: 'Generate images from prompts' },
|
||||||
|
{ id: 'vae-decoder', name: 'VAE Decoder', description: 'Decode latent representations' },
|
||||||
|
{ id: 'control-net', name: 'Control Net', description: 'Apply control signals' },
|
||||||
|
{ id: 'upscale', name: 'Upscaler', description: 'Increase image resolution' },
|
||||||
|
];
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="node-panel">
|
||||||
|
<h3>Node Library</h3>
|
||||||
|
<div className="node-types">
|
||||||
|
{nodeTypes.map(node => (
|
||||||
|
<div
|
||||||
|
key={node.id}
|
||||||
|
className="node-type"
|
||||||
|
draggable
|
||||||
|
onDragStart={(e) => {
|
||||||
|
e.dataTransfer.setData('nodeType', node.id);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<div className="node-icon">□</div>
|
||||||
|
<div className="node-info">
|
||||||
|
<h4>{node.name}</h4>
|
||||||
|
<p>{node.description}</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default NodePanel;
|
||||||
95
frontend/src/components/PreviewPane.css
Normal file
95
frontend/src/components/PreviewPane.css
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
.preview-pane {
|
||||||
|
flex: 1;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
padding: 20px;
|
||||||
|
overflow-y: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.preview-pane h3 {
|
||||||
|
margin-top: 0;
|
||||||
|
color: #333;
|
||||||
|
}
|
||||||
|
|
||||||
|
.preview-controls {
|
||||||
|
margin-bottom: 20px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.generate-button {
|
||||||
|
background-color: #007acc;
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
padding: 10px 20px;
|
||||||
|
font-size: 16px;
|
||||||
|
border-radius: 4px;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: background-color 0.2s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.generate-button:hover:not(:disabled) {
|
||||||
|
background-color: #005a9e;
|
||||||
|
}
|
||||||
|
|
||||||
|
.generate-button:disabled {
|
||||||
|
background-color: #ccc;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.progress-container {
|
||||||
|
margin-bottom: 20px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.progress-bar {
|
||||||
|
width: 100%;
|
||||||
|
height: 20px;
|
||||||
|
background-color: #e9ecef;
|
||||||
|
border-radius: 10px;
|
||||||
|
overflow: hidden;
|
||||||
|
margin-bottom: 5px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.progress-fill {
|
||||||
|
height: 100%;
|
||||||
|
background-color: #007acc;
|
||||||
|
transition: width 0.3s ease;
|
||||||
|
border-radius: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.preview-content {
|
||||||
|
flex: 1;
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
align-items: center;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
min-height: 300px;
|
||||||
|
border: 1px solid #ddd;
|
||||||
|
border-radius: 4px;
|
||||||
|
background-color: #f8f9fa;
|
||||||
|
}
|
||||||
|
|
||||||
|
.generated-image {
|
||||||
|
max-width: 100%;
|
||||||
|
max-height: 512px;
|
||||||
|
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.placeholder {
|
||||||
|
text-align: center;
|
||||||
|
color: #666;
|
||||||
|
}
|
||||||
|
|
||||||
|
.preview-info {
|
||||||
|
background-color: #e9f7fe;
|
||||||
|
border: 1px solid #b3e5fc;
|
||||||
|
border-radius: 4px;
|
||||||
|
padding: 15px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.preview-info h4 {
|
||||||
|
margin-top: 0;
|
||||||
|
color: #0066cc;
|
||||||
|
}
|
||||||
|
|
||||||
|
.preview-info p {
|
||||||
|
margin: 5px 0;
|
||||||
|
}
|
||||||
75
frontend/src/components/PreviewPane.tsx
Normal file
75
frontend/src/components/PreviewPane.tsx
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import React, { useState, useEffect } from 'react';
|
||||||
|
import './PreviewPane.css';
|
||||||
|
|
||||||
|
const PreviewPane: React.FC = () => {
|
||||||
|
const [imageUrl, setImageUrl] = useState<string | null>(null);
|
||||||
|
const [isGenerating, setIsGenerating] = useState(false);
|
||||||
|
const [progress, setProgress] = useState(0);
|
||||||
|
|
||||||
|
// Simulate image generation
|
||||||
|
const startGeneration = () => {
|
||||||
|
setIsGenerating(true);
|
||||||
|
setProgress(0);
|
||||||
|
|
||||||
|
// Simulate progress updates
|
||||||
|
const interval = setInterval(() => {
|
||||||
|
setProgress(prev => {
|
||||||
|
if (prev >= 100) {
|
||||||
|
clearInterval(interval);
|
||||||
|
setImageUrl('https://placehold.co/512x512?text=Generated+Image');
|
||||||
|
setIsGenerating(false);
|
||||||
|
return 100;
|
||||||
|
}
|
||||||
|
return prev + 10;
|
||||||
|
});
|
||||||
|
}, 300);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="preview-pane">
|
||||||
|
<h3>Preview</h3>
|
||||||
|
|
||||||
|
<div className="preview-controls">
|
||||||
|
<button
|
||||||
|
onClick={startGeneration}
|
||||||
|
disabled={isGenerating}
|
||||||
|
className="generate-button"
|
||||||
|
>
|
||||||
|
{isGenerating ? 'Generating...' : 'Generate Image'}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="progress-container">
|
||||||
|
{isGenerating && (
|
||||||
|
<div className="progress-bar">
|
||||||
|
<div
|
||||||
|
className="progress-fill"
|
||||||
|
style={{ width: `${progress}%` }}
|
||||||
|
></div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
<p>{isGenerating ? `Progress: ${progress}%` : 'No generation in progress'}</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="preview-content">
|
||||||
|
{imageUrl ? (
|
||||||
|
<img src={imageUrl} alt="Generated preview" className="generated-image" />
|
||||||
|
) : (
|
||||||
|
<div className="placeholder">
|
||||||
|
<p>Preview will appear here after generation</p>
|
||||||
|
<p>Drag nodes from the panel to create a workflow</p>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="preview-info">
|
||||||
|
<h4>GPU Information</h4>
|
||||||
|
<p><strong>Device:</strong> Radeon RX 9070 XT</p>
|
||||||
|
<p><strong>Architecture:</strong> gfx900</p>
|
||||||
|
<p><strong>Status:</strong> Ready for ROCm acceleration</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default PreviewPane;
|
||||||
13
frontend/src/index.tsx
Normal file
13
frontend/src/index.tsx
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import React from 'react';
|
||||||
|
import ReactDOM from 'react-dom/client';
|
||||||
|
import './index.css';
|
||||||
|
import App from './App';
|
||||||
|
|
||||||
|
const root = ReactDOM.createRoot(
|
||||||
|
document.getElementById('root') as HTMLElement
|
||||||
|
);
|
||||||
|
root.render(
|
||||||
|
<React.StrictMode>
|
||||||
|
<App />
|
||||||
|
</React.StrictMode>
|
||||||
|
);
|
||||||
100
frontend/src/utils/api-client.ts
Normal file
100
frontend/src/utils/api-client.ts
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import axios from 'axios';
|
||||||
|
|
||||||
|
// Create an Axios instance with default configuration
|
||||||
|
const apiClient = axios.create({
|
||||||
|
baseURL: 'http://localhost:8080',
|
||||||
|
timeout: 30000,
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Request interceptor to add auth token if needed
|
||||||
|
apiClient.interceptors.request.use(
|
||||||
|
(config) => {
|
||||||
|
// Add authorization header if needed
|
||||||
|
const token = localStorage.getItem('authToken');
|
||||||
|
if (token) {
|
||||||
|
config.headers.Authorization = `Bearer ${token}`;
|
||||||
|
}
|
||||||
|
return config;
|
||||||
|
},
|
||||||
|
(error) => {
|
||||||
|
return Promise.reject(error);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
// Response interceptor for handling errors
|
||||||
|
apiClient.interceptors.response.use(
|
||||||
|
(response) => {
|
||||||
|
return response;
|
||||||
|
},
|
||||||
|
(error) => {
|
||||||
|
if (error.response?.status === 401) {
|
||||||
|
// Handle unauthorized access
|
||||||
|
localStorage.removeItem('authToken');
|
||||||
|
window.location.href = '/login';
|
||||||
|
}
|
||||||
|
return Promise.reject(error);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
// API interfaces
|
||||||
|
export interface InferenceRequest {
|
||||||
|
prompt: string;
|
||||||
|
negative_prompt?: string;
|
||||||
|
guidance_scale?: number;
|
||||||
|
steps?: number;
|
||||||
|
width?: number;
|
||||||
|
height?: number;
|
||||||
|
seed?: number;
|
||||||
|
model_name?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface InferenceResponse {
|
||||||
|
task_id: string;
|
||||||
|
status: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface TaskStatusResponse {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
status: string;
|
||||||
|
progress: number;
|
||||||
|
created_at: number;
|
||||||
|
updated_at: number;
|
||||||
|
model_name?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ModelInfo {
|
||||||
|
name: string;
|
||||||
|
path: string;
|
||||||
|
model_type: string;
|
||||||
|
version: string;
|
||||||
|
loaded: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
// API functions
|
||||||
|
export const api = {
|
||||||
|
// Health check endpoint
|
||||||
|
healthCheck: () => apiClient.get('/health'),
|
||||||
|
|
||||||
|
// System info including GPU details
|
||||||
|
getSystemInfo: () => apiClient.get('/system-info'),
|
||||||
|
|
||||||
|
// Start inference task
|
||||||
|
startInference: (request: InferenceRequest) =>
|
||||||
|
apiClient.post<InferenceResponse>('/infer', request),
|
||||||
|
|
||||||
|
// Get all models
|
||||||
|
getModels: () => apiClient.get<ModelInfo[]>('/models'),
|
||||||
|
|
||||||
|
// Get task status
|
||||||
|
getTaskStatus: (taskId: string) =>
|
||||||
|
apiClient.get<TaskStatusResponse>(`/tasks/${taskId}`),
|
||||||
|
|
||||||
|
// Get all tasks
|
||||||
|
getAllTasks: () => apiClient.get<TaskStatusResponse[]>('/tasks'),
|
||||||
|
};
|
||||||
|
|
||||||
|
export default apiClient;
|
||||||
Reference in New Issue
Block a user