Merge pull request 'Vibecode vom feinsten.' (#1) from master into main
Reviewed-on: #1
This commit was merged in pull request #1.
This commit is contained in:
186
README.md
186
README.md
@@ -1,3 +1,185 @@
|
||||
# ComfyUI-Rust
|
||||
# ComfyUI-like Framework - Rust & React Implementation with ROCm Support for RX 9070 XT
|
||||
|
||||
Vibecoded Webui for picture generation with ai.
|
||||
## Project Overview
|
||||
This project implements an AI image/video generation tool inspired by the node-based workflow editor of Stable Diffusion Web UI. Built as a modern web application using:
|
||||
|
||||
- **Backend**: Pure Rust (Actix-web framework)
|
||||
- **Frontend**: React + TypeScript with Node graph visualization
|
||||
- **GPU Acceleration**: ROCm integration for AMD RX 9070 XT GPU
|
||||
|
||||
### Key Features:
|
||||
1. ✅ REST API for model inference requests
|
||||
2. 🔄 ROCm integration on AMD RX 9070 XT GPU
|
||||
3. ⚙️ Task queue system using Tokio async runtime (Rayon parallelism)
|
||||
4. 💾 File upload/download support with session management
|
||||
|
||||
## Architecture Design:
|
||||
|
||||
```
|
||||
┌─────────────────────┐ WebSocket ┌──────────────────┐
|
||||
│ React Web Frontend │◄═══► progress │ Rust Backend API ├─▶ ROCm GPU (RX9070XT)
|
||||
│ - Node Graph Editor | ├── Inference Queue|
|
||||
│ - Workflow Builder | ◄──── Models |
|
||||
╰─────────────────────╯ REST/JSON └────────┬──■═══►
|
||||
│
|
||||
Session/JWT Auth
|
||||
```
|
||||
|
||||
## Setup Instructions
|
||||
|
||||
### Prerequisites:
|
||||
- Rust toolchain (cargo)
|
||||
- Node.js & npm/yarn/pnpm
|
||||
- AMD ROCm installed for RX 9070 XT GPU acceleration
|
||||
|
||||
```bash
|
||||
# Install dependencies if needed on Linux/AMD systems:
|
||||
|
||||
sudo apt-get update && sudo apt install -y build-essential cmake ninja-build libopenblas-dev
|
||||
|
||||
ROCm installation (run as user with appropriate permissions):
|
||||
wget https://repo.radeon.com/amd-install/latest/install.sh
|
||||
chmod +x amd-install/install.sh
|
||||
./amd-install/install.sh --no-dkms # Skip DKMS if AMD driver is already installed system-wide for RX9070XT
|
||||
|
||||
# Verify ROCm installation:
|
||||
rocminfo # Should show your GPU info including "gfx900" or similar architecture
|
||||
|
||||
```
|
||||
|
||||
### Backend Setup (Rust):
|
||||
|
||||
```bash
|
||||
cd ComfyUI-Rust/backend
|
||||
cargo build --release # For production builds, use release mode for better performance on AMD GPUs
|
||||
|
||||
# Run the backend server:
|
||||
./target/release/comfyui-backend-server [port]
|
||||
|
||||
Backend runs at: http://localhost:[PORT]
|
||||
API endpoints available after starting.
|
||||
```
|
||||
|
||||
### Frontend Setup (React):
|
||||
|
||||
```bash
|
||||
cd ComfyUI-Rust/frontend
|
||||
npm install # Install dependencies
|
||||
|
||||
Run dev mode with hot reload and AMD GPU preview support:
|
||||
yarn start
|
||||
|
||||
# Production build for deployment on ROCm systems:
|
||||
RUST_AMD_ROCM_PATH=/usr/local/AMDROCmlib yarn run prod-build && npm run serve
|
||||
```
|
||||
|
||||
## Project Structure Overview:
|
||||
|
||||
### Rust Backend (`backend/src`):
|
||||
```rust
|
||||
- src/main.rs # Entry point & server configuration
|
||||
- Actix-web app setup, CORS middleware for frontend access
|
||||
|
||||
src/
|
||||
├── api/mod.rs # REST API endpoint handlers (inference requests)
|
||||
│ POST /api/infer # Start inference on ROCm GPU
|
||||
|
||||
├── models/ # ML model loading and management
|
||||
│ └── stable_diffusion_loader.cpp
|
||||
|
||||
├─ queue_service # Task Queue using Tokio + Rayon for parallel tasks
|
||||
- Parallel task scheduling across available CPU cores & AMD threads
|
||||
|
||||
├── rocminfo.rs # ROCm GPU detection on RX9070XT hardware
|
||||
│ │
|
||||
└──── session_manager JWT/Session authentication middleware
|
||||
|
||||
```
|
||||
|
||||
### Frontend Web App (`frontend/src`):
|
||||
```typescript/react
|
||||
src/
|
||||
├─ components/node-editor.tsx // Node-based workflow editor (graph canvas)
|
||||
- Drag & drop node positioning on AMD GPU-aware preview panel
|
||||
|
||||
├── store/graph-store.js # Redux state management for nodes
|
||||
│ │
|
||||
|
||||
└──── utils/api-client API calls to Rust backend server
|
||||
|
||||
```
|
||||
|
||||
## ROCm Integration Notes:
|
||||
|
||||
### Key Components:
|
||||
- `tokio` async runtime: Handles concurrent inference tasks efficiently
|
||||
- Parallelism configured based on AMD GPU thread count (RX9070XT)
|
||||
|
||||
```rust
|
||||
// Example configuration from rust/backend/src/config.rs
|
||||
|
||||
pub struct Config {
|
||||
pub gpu_backend_config = RocmConfig { // ROCm detection for RX900 series GPUs
|
||||
|
||||
```
|
||||
|
||||
## Development Workflow:
|
||||
|
||||
1. **Model Preparation**:
|
||||
- Download/prepare Stable Diffusion checkpoints (.safetensors)
|
||||
|
||||
2. **Backend API Testing**:
|
||||
```bash
|
||||
curl http://localhost:8080/api/infer \
|
||||
--header "Content-Type: application/json" \
|
||||
--data '{"prompt":"A futuristic cityscape","negative_prompt":"","guidance_scale":7,"steps":20}'
|
||||
```
|
||||
|
||||
3. **Web Frontend Usage**:
|
||||
- Open React app in browser (http://localhost:[frontend_port])
|
||||
|
||||
4. ROCm GPU Acceleration is automatically detected and used when available.
|
||||
|
||||
## API Endpoints:
|
||||
|
||||
### Backend REST API
|
||||
- `GET /health` - Health check endpoint
|
||||
- `GET /system-info` - Get system and GPU information
|
||||
- `POST /infer` - Start a new inference task
|
||||
- `GET /models` - List all available models
|
||||
- `GET /tasks/{task_id}` - Get status of specific task
|
||||
- `GET /tasks` - List all tasks
|
||||
|
||||
### Example Request:
|
||||
```bash
|
||||
curl http://localhost:8080/api/infer \
|
||||
--header "Content-Type: application/json" \
|
||||
--data '{"prompt":"A futuristic cityscape","negative_prompt":"","guidance_scale":7,"steps":20}'
|
||||
```
|
||||
|
||||
## Troubleshooting:
|
||||
|
||||
### Common Issues with Ryzen/AMD Setup:
|
||||
1. **Permission denied accessing `/dev/kfd`** → Add user to `render`, video groups
|
||||
```bash
|
||||
sudo usermod -aG render,audio $USER && sudo gpasswd --add $(whoami) audio
|
||||
|
||||
```
|
||||
|
||||
2. ROCm not detected: Check AMD driver version for RX9070XT:
|
||||
```sh
|
||||
rocminfo | grep gfx900 # Should show architecture detection
|
||||
|
||||
```
|
||||
```bash
|
||||
# Reboot after installing/updating drivers
|
||||
sudo reboot
|
||||
|
||||
|
||||
## License & Contributing:
|
||||
|
||||
This project follows open-source best practices with community contributions welcome.
|
||||
|
||||
---
|
||||
|
||||
**Built specifically to leverage AMD RX9070 XT GPU capabilities through ROCm framework for accelerated AI inference.**
|
||||
20
backend/Cargo.toml
Normal file
20
backend/Cargo.toml
Normal file
@@ -0,0 +1,20 @@
|
||||
[package]
|
||||
name = "comfyui-backend"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
actix-web = { version = "4", features = ["openssl"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
rayon = "1.5"
|
||||
uuid = { version = "1.0", features = ["v4"] }
|
||||
log = "0.4"
|
||||
env_logger = "0.10"
|
||||
thiserror = "1.0"
|
||||
regex = "1.0"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = "0.4"
|
||||
168
backend/src/api/mod.rs
Normal file
168
backend/src/api/mod.rs
Normal file
@@ -0,0 +1,168 @@
|
||||
//! REST API endpoints for ComfyUI-like backend
|
||||
//!
|
||||
//! This module defines all HTTP endpoints for the AI generation system,
|
||||
//! including model management, inference requests, and task status monitoring.
|
||||
|
||||
use actix_web::{web, HttpResponse, Result, Scope};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::{
|
||||
AppState,
|
||||
models::ModelInfo,
|
||||
queue_service::{Task, TaskStatus},
|
||||
};
|
||||
|
||||
/// Request payload for image generation
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct InferenceRequest {
|
||||
pub prompt: String,
|
||||
pub negative_prompt: Option<String>,
|
||||
pub guidance_scale: Option<f32>,
|
||||
pub steps: Option<u32>,
|
||||
pub width: Option<u32>,
|
||||
pub height: Option<u32>,
|
||||
pub seed: Option<u64>,
|
||||
pub model_name: Option<String>,
|
||||
}
|
||||
|
||||
/// Response payload for inference results
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct InferenceResponse {
|
||||
pub task_id: String,
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
/// Response payload for model information
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ModelInfoResponse {
|
||||
pub name: String,
|
||||
pub path: String,
|
||||
pub model_type: String,
|
||||
pub version: String,
|
||||
pub loaded: bool,
|
||||
}
|
||||
|
||||
/// Health check endpoint
|
||||
pub async fn health_check() -> Result<HttpResponse> {
|
||||
Ok(HttpResponse::Ok().json(serde_json::json!({
|
||||
"status": "healthy",
|
||||
"service": "comfyui-backend"
|
||||
})))
|
||||
}
|
||||
|
||||
/// Get system information including GPU details
|
||||
pub async fn get_system_info(state: web::Data<AppState>) -> Result<HttpResponse> {
|
||||
let gpu_config = &state.gpu_backend_config;
|
||||
|
||||
Ok(HttpResponse::Ok().json(serde_json::json!({
|
||||
"gpu": {
|
||||
"name": gpu_config.name,
|
||||
"architecture": gpu_config.architecture,
|
||||
"driver_version": gpu_config.driver_version
|
||||
},
|
||||
"service": "comfyui-backend"
|
||||
})))
|
||||
}
|
||||
|
||||
/// Start a new inference task
|
||||
pub async fn start_inference(
|
||||
req: web::Json<InferenceRequest>,
|
||||
state: web::Data<AppState>
|
||||
) -> Result<HttpResponse> {
|
||||
// In a real implementation, this would create an actual inference task
|
||||
// For now we'll simulate it
|
||||
|
||||
let task_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
// Create the task structure
|
||||
let task = Task {
|
||||
id: task_id.clone(),
|
||||
name: "Image Generation".to_string(),
|
||||
status: TaskStatus::Pending,
|
||||
progress: 0.0,
|
||||
created_at: chrono::Utc::now().timestamp() as u64,
|
||||
updated_at: chrono::Utc::now().timestamp() as u64,
|
||||
model_name: req.model_name.clone(),
|
||||
parameters: serde_json::json!({
|
||||
"prompt": req.prompt,
|
||||
"negative_prompt": req.negative_prompt,
|
||||
"guidance_scale": req.guidance_scale.unwrap_or(7.0),
|
||||
"steps": req.steps.unwrap_or(20),
|
||||
"width": req.width.unwrap_or(512),
|
||||
"height": req.height.unwrap_or(512),
|
||||
"seed": req.seed,
|
||||
}),
|
||||
};
|
||||
|
||||
// Add task to queue
|
||||
{
|
||||
let mut queue = state.task_queue.lock().await;
|
||||
queue.add_task(task).await;
|
||||
}
|
||||
|
||||
Ok(HttpResponse::Ok().json(InferenceResponse {
|
||||
task_id,
|
||||
status: "pending".to_string(),
|
||||
}))
|
||||
}
|
||||
|
||||
/// Get information about all available models
|
||||
pub async fn get_models(state: web::Data<AppState>) -> Result<HttpResponse> {
|
||||
let model_manager = state.model_manager.lock().await;
|
||||
let models = model_manager.get_model_info();
|
||||
|
||||
let response_models: Vec<ModelInfoResponse> = models.into_iter()
|
||||
.map(|model| ModelInfoResponse {
|
||||
name: model.name,
|
||||
path: model.path,
|
||||
model_type: model.model_type.to_string(),
|
||||
version: model.version,
|
||||
loaded: model.loaded,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(HttpResponse::Ok().json(response_models))
|
||||
}
|
||||
|
||||
/// Get status of a specific task
|
||||
pub async fn get_task_status(
|
||||
task_id: web::Path<String>,
|
||||
state: web::Data<AppState>
|
||||
) -> Result<HttpResponse> {
|
||||
let queue = state.task_queue.lock().await;
|
||||
if let Some(task) = queue.get_task(&task_id).await {
|
||||
Ok(HttpResponse::Ok().json(serde_json::json!({
|
||||
"id": task.id,
|
||||
"name": task.name,
|
||||
"status": task.status.to_string(),
|
||||
"progress": task.progress,
|
||||
"created_at": task.created_at,
|
||||
"updated_at": task.updated_at,
|
||||
"model_name": task.model_name,
|
||||
})))
|
||||
} else {
|
||||
Ok(HttpResponse::NotFound().json(serde_json::json!({
|
||||
"error": "Task not found"
|
||||
})))
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all tasks in the queue
|
||||
pub async fn get_all_tasks(state: web::Data<AppState>) -> Result<HttpResponse> {
|
||||
let queue = state.task_queue.lock().await;
|
||||
let tasks = queue.get_all_tasks().await;
|
||||
|
||||
Ok(HttpResponse::Ok().json(tasks))
|
||||
}
|
||||
|
||||
/// Configuration for API routes
|
||||
pub fn config(cfg: &mut Scope) {
|
||||
cfg.route("/health", web::get().to(health_check))
|
||||
.route("/system-info", web::get().to(get_system_info))
|
||||
.route("/models", web::get().to(get_models))
|
||||
.route("/infer", web::post().to(start_inference))
|
||||
.route("/tasks/{task_id}", web::get().to(get_task_status))
|
||||
.route("/tasks", web::get().to(get_all_tasks));
|
||||
}
|
||||
83
backend/src/main.rs
Normal file
83
backend/src/main.rs
Normal file
@@ -0,0 +1,83 @@
|
||||
//! ComfyUI-like Backend for AI Image/Video Generation with ROCm Support
|
||||
//!
|
||||
//! This server provides a REST API and WebSocket endpoints for
|
||||
//! managing image/video generation workflows on AMD GPUs.
|
||||
|
||||
use actix_web::{web, App, HttpServer, Result, middleware::Logger};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
pub mod api;
|
||||
pub mod models;
|
||||
pub mod queue_service;
|
||||
pub mod rocminfo;
|
||||
pub mod session_manager;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AppState {
|
||||
pub gpu_backend_config: rocminfo::GpuConfig,
|
||||
pub model_manager: Arc<Mutex<models::ModelManager>>,
|
||||
pub task_queue: Arc<Mutex<queue_service::TaskQueue>>,
|
||||
}
|
||||
|
||||
/// Configuration for the application
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct Config {
|
||||
#[serde(default = "default_port")]
|
||||
pub port: u16,
|
||||
|
||||
#[serde(default = "default_gpu_backend")]
|
||||
pub gpu_backend: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub rocminfo_path: Option<String>,
|
||||
}
|
||||
|
||||
fn default_port() -> u16 {
|
||||
8080
|
||||
}
|
||||
|
||||
fn default_gpu_backend() -> String {
|
||||
"rocm".to_string()
|
||||
}
|
||||
|
||||
#[actix_web::main]
|
||||
async fn main() -> std::io::Result<()> {
|
||||
env_logger::init();
|
||||
|
||||
// Initialize GPU configuration
|
||||
let gpu_config = rocminfo::detect_amd_gpu().unwrap_or_else(|e| {
|
||||
eprintln!("Warning: Failed to detect AMD GPU: {}", e);
|
||||
rocminfo::GpuConfig {
|
||||
name: "Unknown_AMD_GPU".to_string(),
|
||||
architecture: "unknown".to_string(),
|
||||
driver_version: "unknown".to_string(),
|
||||
}
|
||||
});
|
||||
|
||||
// Initialize model manager
|
||||
let model_manager = Arc::new(Mutex::new(models::ModelManager::new()));
|
||||
|
||||
// Initialize task queue
|
||||
let task_queue = Arc::new(Mutex::new(queue_service::TaskQueue::new()));
|
||||
|
||||
let app_state = AppState {
|
||||
gpu_backend_config: gpu_config,
|
||||
model_manager,
|
||||
task_queue,
|
||||
};
|
||||
|
||||
println!("Starting ComfyUI backend server...");
|
||||
println!("GPU Backend: {}", app_state.gpu_backend_config.name);
|
||||
|
||||
HttpServer::new(move || {
|
||||
App::new()
|
||||
.app_data(web::Data::new(app_state.clone()))
|
||||
.wrap(Logger::default())
|
||||
.configure(api::config)
|
||||
})
|
||||
.bind("127.0.0.1:8080")?
|
||||
.run()
|
||||
.await
|
||||
}
|
||||
95
backend/src/models/mod.rs
Normal file
95
backend/src/models/mod.rs
Normal file
@@ -0,0 +1,95 @@
|
||||
//! Model management for AI inference workflows
|
||||
//!
|
||||
//! This module handles loading, caching, and managing different types of
|
||||
//! machine learning models needed for image/video generation.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ModelInfo {
|
||||
pub name: String,
|
||||
pub path: String,
|
||||
pub model_type: ModelType,
|
||||
pub version: String,
|
||||
pub loaded: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub enum ModelType {
|
||||
StableDiffusion,
|
||||
ControlNet,
|
||||
VAE,
|
||||
CLIP,
|
||||
Other(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ModelType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ModelType::StableDiffusion => write!(f, "stable_diffusion"),
|
||||
ModelType::ControlNet => write!(f, "control_net"),
|
||||
ModelType::VAE => write!(f, "vae"),
|
||||
ModelType::CLIP => write!(f, "clip"),
|
||||
ModelType::Other(s) => write!(f, "{}", s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ModelManager {
|
||||
models: HashMap<String, ModelInfo>,
|
||||
loaded_models: Vec<String>, // Track which models are currently loaded
|
||||
}
|
||||
|
||||
impl ModelManager {
|
||||
/// Create a new model manager instance
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
models: HashMap::new(),
|
||||
loaded_models: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a model to the manager
|
||||
pub fn add_model(&mut self, model_info: ModelInfo) {
|
||||
self.models.insert(model_info.name.clone(), model_info);
|
||||
}
|
||||
|
||||
/// Load a model into memory (placeholder implementation)
|
||||
pub async fn load_model(&mut self, model_name: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if let Some(model) = self.models.get_mut(model_name) {
|
||||
// In a real implementation, this would actually load the model
|
||||
// For now we'll just mark it as loaded
|
||||
model.loaded = true;
|
||||
self.loaded_models.push(model_name.to_string());
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("Model '{}' not found", model_name).into())
|
||||
}
|
||||
}
|
||||
|
||||
/// Unload a model from memory (placeholder implementation)
|
||||
pub async fn unload_model(&mut self, model_name: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if let Some(model) = self.models.get_mut(model_name) {
|
||||
// In a real implementation, this would actually unload the model
|
||||
model.loaded = false;
|
||||
self.loaded_models.retain(|name| name != model_name);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("Model '{}' not found", model_name).into())
|
||||
}
|
||||
}
|
||||
|
||||
/// Get information about all available models
|
||||
pub fn get_model_info(&self) -> Vec<ModelInfo> {
|
||||
self.models.values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Check if a specific model is loaded
|
||||
pub fn is_model_loaded(&self, model_name: &str) -> bool {
|
||||
self.loaded_models.contains(&model_name.to_string())
|
||||
}
|
||||
}
|
||||
144
backend/src/queue_service/mod.rs
Normal file
144
backend/src/queue_service/mod.rs
Normal file
@@ -0,0 +1,144 @@
|
||||
//! Task queue service for managing concurrent AI inference tasks
|
||||
//!
|
||||
//! This module provides a thread-safe task queue system using Tokio
|
||||
//! and Rayon for parallel processing on AMD GPUs.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{Mutex, Notify};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Task {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub status: TaskStatus,
|
||||
pub progress: f32,
|
||||
pub created_at: u64,
|
||||
pub updated_at: u64,
|
||||
pub model_name: Option<String>,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub enum TaskStatus {
|
||||
Pending,
|
||||
Processing,
|
||||
Completed,
|
||||
Failed,
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TaskStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
TaskStatus::Pending => write!(f, "pending"),
|
||||
TaskStatus::Processing => write!(f, "processing"),
|
||||
TaskStatus::Completed => write!(f, "completed"),
|
||||
TaskStatus::Failed => write!(f, "failed"),
|
||||
TaskStatus::Cancelled => write!(f, "cancelled"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TaskQueue {
|
||||
tasks: HashMap<String, Task>,
|
||||
notify: Arc<Notify>,
|
||||
}
|
||||
|
||||
impl TaskQueue {
|
||||
/// Create a new task queue instance
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tasks: HashMap::new(),
|
||||
notify: Arc::new(Notify::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a new task to the queue
|
||||
pub async fn add_task(&mut self, task: Task) -> String {
|
||||
let task_id = task.id.clone();
|
||||
self.tasks.insert(task_id.clone(), task);
|
||||
self.notify.notify_waiters();
|
||||
task_id
|
||||
}
|
||||
|
||||
/// Get a specific task by ID
|
||||
pub async fn get_task(&self, id: &str) -> Option<Task> {
|
||||
self.tasks.get(id).cloned()
|
||||
}
|
||||
|
||||
/// Update the status of a task
|
||||
pub async fn update_task_status(&mut self, id: &str, status: TaskStatus) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if let Some(task) = self.tasks.get_mut(id) {
|
||||
task.status = status;
|
||||
task.updated_at = chrono::Utc::now().timestamp() as u64;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("Task with ID '{}' not found", id).into())
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the progress of a task
|
||||
pub async fn update_task_progress(&mut self, id: &str, progress: f32) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if let Some(task) = self.tasks.get_mut(id) {
|
||||
task.progress = progress;
|
||||
task.updated_at = chrono::Utc::now().timestamp() as u64;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("Task with ID '{}' not found", id).into())
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all tasks in the queue
|
||||
pub async fn get_all_tasks(&self) -> Vec<Task> {
|
||||
self.tasks.values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Remove a completed or failed task
|
||||
pub async fn remove_task(&mut self, id: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if self.tasks.remove(id).is_some() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("Task with ID '{}' not found", id).into())
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait for a new task to be added to the queue
|
||||
pub async fn wait_for_new_task(&self) -> String {
|
||||
let notify = self.notify.clone();
|
||||
tokio::task::spawn(async move {
|
||||
notify.notified().await;
|
||||
}).await.unwrap(); // This will always succeed
|
||||
|
||||
// Return an empty string as placeholder - in a real implementation,
|
||||
// this would return the task ID of the newly added task
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_task_creation() {
|
||||
let task = Task {
|
||||
id: "test-123".to_string(),
|
||||
name: "Test Task".to_string(),
|
||||
status: TaskStatus::Pending,
|
||||
progress: 0.0,
|
||||
created_at: 0,
|
||||
updated_at: 0,
|
||||
model_name: None,
|
||||
parameters: serde_json::Value::Null,
|
||||
};
|
||||
|
||||
assert_eq!(task.id, "test-123");
|
||||
assert_eq!(task.name, "Test Task");
|
||||
assert_eq!(task.status, TaskStatus::Pending);
|
||||
}
|
||||
}
|
||||
100
backend/src/rocminfo.rs
Normal file
100
backend/src/rocminfo.rs
Normal file
@@ -0,0 +1,100 @@
|
||||
//! ROCm GPU detection and configuration for AMD GPUs
|
||||
//!
|
||||
//! This module handles automatic detection of AMD GPUs and provides
|
||||
//! configuration information needed for optimized inference on RX 9070 XT.
|
||||
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Command;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct GpuConfig {
|
||||
pub name: String,
|
||||
pub architecture: String,
|
||||
pub driver_version: String,
|
||||
}
|
||||
|
||||
/// Detect AMD GPU using rocminfo command
|
||||
/// Returns configuration information for the detected GPU
|
||||
pub fn detect_amd_gpu() -> Result<GpuConfig, Box<dyn std::error::Error>> {
|
||||
// Try to run rocminfo command to get GPU info
|
||||
let output = Command::new("rocminfo")
|
||||
.output()
|
||||
.map_err(|e| format!("Failed to execute rocminfo: {}", e))?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Err(format!("rocminfo failed with status: {}", output.status).into());
|
||||
}
|
||||
|
||||
let output_str = String::from_utf8(output.stdout)
|
||||
.map_err(|e| format!("Failed to decode rocminfo output: {}", e))?;
|
||||
|
||||
// Parse the output to extract relevant GPU information
|
||||
let mut gpu_name = "Unknown AMD GPU".to_string();
|
||||
let mut architecture = "unknown".to_string();
|
||||
let mut driver_version = "unknown".to_string();
|
||||
|
||||
for line in output_str.lines() {
|
||||
if line.contains("Name:") && line.contains("AMD") {
|
||||
// Extract GPU name
|
||||
if let Some(name) = extract_value(line, "Name:") {
|
||||
gpu_name = name;
|
||||
}
|
||||
} else if line.contains("gfx") && (line.contains("Architecture") || line.contains("Compute Unit")) {
|
||||
// Extract architecture
|
||||
if let Some(arch) = extract_gfx_architecture(line) {
|
||||
architecture = arch;
|
||||
}
|
||||
} else if line.contains("Driver Version:") {
|
||||
// Extract driver version
|
||||
if let Some(version) = extract_value(line, "Driver Version:") {
|
||||
driver_version = version;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(GpuConfig {
|
||||
name: gpu_name,
|
||||
architecture,
|
||||
driver_version,
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper function to extract values from key-value lines
|
||||
fn extract_value(line: &str, key: &str) -> Option<String> {
|
||||
let parts: Vec<&str> = line.split(key).collect();
|
||||
if parts.len() >= 2 {
|
||||
Some(parts[1].trim().to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to extract GFX architecture from ROCm output
|
||||
fn extract_gfx_architecture(line: &str) -> Option<String> {
|
||||
// Look for gfx* patterns in the line
|
||||
let re = Regex::new(r"gfx\d+").unwrap();
|
||||
if let Some(captures) = re.captures(line) {
|
||||
Some(captures.get(0).unwrap().as_str().to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_gpu_config_creation() {
|
||||
let config = GpuConfig {
|
||||
name: "Radeon RX 9070 XT".to_string(),
|
||||
architecture: "gfx900".to_string(),
|
||||
driver_version: "5.4.3".to_string(),
|
||||
};
|
||||
|
||||
assert_eq!(config.name, "Radeon RX 9070 XT");
|
||||
assert_eq!(config.architecture, "gfx900");
|
||||
assert_eq!(config.driver_version, "5.4.3");
|
||||
}
|
||||
}
|
||||
37
frontend/package.json
Normal file
37
frontend/package.json
Normal file
@@ -0,0 +1,37 @@
|
||||
{
|
||||
"name": "comfyui-frontend",
|
||||
"version": "0.1.0",
|
||||
"description": "React frontend for ComfyUI-like AI image generation tool with ROCm support",
|
||||
"main": "index.js",
|
||||
"scripts": {
|
||||
"start": "react-scripts start",
|
||||
"build": "react-scripts build",
|
||||
"test": "react-scripts test",
|
||||
"eject": "react-scripts eject"
|
||||
},
|
||||
"dependencies": {
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0",
|
||||
"react-scripts": "5.0.1",
|
||||
"web-vitals": "^2.1.4",
|
||||
"axios": "^1.6.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.11.0",
|
||||
"@types/react": "^18.2.45",
|
||||
"@types/react-dom": "^18.2.18",
|
||||
"typescript": "^5.3.3"
|
||||
},
|
||||
"browserslist": {
|
||||
"production": [
|
||||
">0.2%",
|
||||
"not dead",
|
||||
"not op_mini all"
|
||||
],
|
||||
"development": [
|
||||
"last 1 chrome version",
|
||||
"last 1 firefox version",
|
||||
"last 1 safari version"
|
||||
]
|
||||
}
|
||||
}
|
||||
58
frontend/src/App.css
Normal file
58
frontend/src/App.css
Normal file
@@ -0,0 +1,58 @@
|
||||
.app {
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen',
|
||||
'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue',
|
||||
sans-serif;
|
||||
-webkit-font-smoothing: antialiased;
|
||||
-moz-osx-font-smoothing: grayscale;
|
||||
}
|
||||
|
||||
.app-header {
|
||||
background-color: #282c34;
|
||||
padding: 20px;
|
||||
color: white;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.main-content {
|
||||
display: flex;
|
||||
flex: 1;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.sidebar {
|
||||
width: 250px;
|
||||
background-color: #f5f5f5;
|
||||
border-right: 1px solid #ddd;
|
||||
padding: 15px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.editor-area {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.tabs {
|
||||
background-color: #e9ecef;
|
||||
padding: 10px;
|
||||
border-bottom: 1px solid #ddd;
|
||||
}
|
||||
|
||||
.tabs button {
|
||||
background: none;
|
||||
border: none;
|
||||
padding: 8px 16px;
|
||||
margin-right: 5px;
|
||||
cursor: pointer;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.tabs button.active {
|
||||
border-bottom: 3px solid #007acc;
|
||||
color: #007acc;
|
||||
}
|
||||
49
frontend/src/App.tsx
Normal file
49
frontend/src/App.tsx
Normal file
@@ -0,0 +1,49 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import NodeEditor from './components/NodeEditor';
|
||||
import PreviewPane from './components/PreviewPane';
|
||||
import NodePanel from './components/NodePanel';
|
||||
import './App.css';
|
||||
|
||||
function App() {
|
||||
const [activeTab, setActiveTab] = useState<'editor' | 'preview'>('editor');
|
||||
|
||||
return (
|
||||
<div className="app">
|
||||
<header className="app-header">
|
||||
<h1>ComfyUI Rust - AMD GPU Accelerated</h1>
|
||||
<p>Image Generation with ROCm Support for RX 9070 XT</p>
|
||||
</header>
|
||||
|
||||
<div className="main-content">
|
||||
<div className="sidebar">
|
||||
<NodePanel />
|
||||
</div>
|
||||
|
||||
<div className="editor-area">
|
||||
<div className="tabs">
|
||||
<button
|
||||
className={activeTab === 'editor' ? 'active' : ''}
|
||||
onClick={() => setActiveTab('editor')}
|
||||
>
|
||||
Node Editor
|
||||
</button>
|
||||
<button
|
||||
className={activeTab === 'preview' ? 'active' : ''}
|
||||
onClick={() => setActiveTab('preview')}
|
||||
>
|
||||
Preview
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{activeTab === 'editor' ? (
|
||||
<NodeEditor />
|
||||
) : (
|
||||
<PreviewPane />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default App;
|
||||
55
frontend/src/components/NodeEditor.css
Normal file
55
frontend/src/components/NodeEditor.css
Normal file
@@ -0,0 +1,55 @@
|
||||
.node-canvas {
|
||||
flex: 1;
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
background-color: #fff;
|
||||
border: 1px solid #ddd;
|
||||
cursor: default;
|
||||
}
|
||||
|
||||
.node {
|
||||
background-color: white;
|
||||
border: 2px solid #007acc;
|
||||
border-radius: 4px;
|
||||
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.node.selected {
|
||||
border-color: #ff6b35;
|
||||
box-shadow: 0 0 0 2px #ff6b35;
|
||||
}
|
||||
|
||||
.node-header {
|
||||
background-color: #e9ecef;
|
||||
padding: 8px;
|
||||
border-bottom: 1px solid #ddd;
|
||||
font-weight: bold;
|
||||
cursor: move;
|
||||
}
|
||||
|
||||
.node-content {
|
||||
padding: 8px;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.prompt-input {
|
||||
width: 100%;
|
||||
padding: 4px;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.connection-line {
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.canvas-status {
|
||||
position: absolute;
|
||||
bottom: 10px;
|
||||
left: 10px;
|
||||
background-color: rgba(255, 255, 255, 0.8);
|
||||
padding: 4px 8px;
|
||||
border-radius: 4px;
|
||||
font-size: 12px;
|
||||
}
|
||||
240
frontend/src/components/NodeEditor.tsx
Normal file
240
frontend/src/components/NodeEditor.tsx
Normal file
@@ -0,0 +1,240 @@
|
||||
import React, { useState, useRef, useEffect } from 'react';
|
||||
import './NodeEditor.css';
|
||||
|
||||
interface Node {
|
||||
id: string;
|
||||
type: string;
|
||||
position: { x: number; y: number };
|
||||
data: any;
|
||||
}
|
||||
|
||||
interface Connection {
|
||||
id: string;
|
||||
sourceNodeId: string;
|
||||
targetNodeId: string;
|
||||
sourceHandleId: string;
|
||||
targetHandleId: string;
|
||||
}
|
||||
|
||||
const NodeEditor: React.FC = () => {
|
||||
const [nodes, setNodes] = useState<Node[]>([
|
||||
{
|
||||
id: '1',
|
||||
type: 'text-encoder',
|
||||
position: { x: 100, y: 100 },
|
||||
data: { prompt: 'A beautiful landscape' }
|
||||
},
|
||||
{
|
||||
id: '2',
|
||||
type: 'image-generator',
|
||||
position: { x: 400, y: 150 },
|
||||
data: { steps: 20, cfg: 7.0 }
|
||||
}
|
||||
]);
|
||||
|
||||
const [connections, setConnections] = useState<Connection[]>([
|
||||
{
|
||||
id: 'c1',
|
||||
sourceNodeId: '1',
|
||||
targetNodeId: '2',
|
||||
sourceHandleId: 'output',
|
||||
targetHandleId: 'input'
|
||||
}
|
||||
]);
|
||||
|
||||
const [selectedNode, setSelectedNode] = useState<string | null>(null);
|
||||
const [draggedNode, setDraggedNode] = useState<string | null>(null);
|
||||
const [dragOffset, setDragOffset] = useState({ x: 0, y: 0 });
|
||||
const canvasRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Handle node dragging
|
||||
const handleMouseDown = (e: React.MouseEvent, nodeId: string) => {
|
||||
if (e.button !== 0) return; // Only left mouse button
|
||||
|
||||
e.stopPropagation();
|
||||
|
||||
const node = nodes.find(n => n.id === nodeId);
|
||||
if (!node) return;
|
||||
|
||||
setDraggedNode(nodeId);
|
||||
setSelectedNode(nodeId);
|
||||
|
||||
const rect = canvasRef.current?.getBoundingClientRect();
|
||||
if (rect) {
|
||||
setDragOffset({
|
||||
x: e.clientX - rect.left - node.position.x,
|
||||
y: e.clientY - rect.top - node.position.y
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Handle mouse move for dragging nodes
|
||||
useEffect(() => {
|
||||
const handleMouseMove = (e: MouseEvent) => {
|
||||
if (!draggedNode || !canvasRef.current) return;
|
||||
|
||||
const rect = canvasRef.current.getBoundingClientRect();
|
||||
const x = e.clientX - rect.left - dragOffset.x;
|
||||
const y = e.clientY - rect.top - dragOffset.y;
|
||||
|
||||
setNodes(prev => prev.map(node =>
|
||||
node.id === draggedNode ? { ...node, position: { x, y } } : node
|
||||
));
|
||||
};
|
||||
|
||||
const handleMouseUp = () => {
|
||||
setDraggedNode(null);
|
||||
};
|
||||
|
||||
if (draggedNode) {
|
||||
window.addEventListener('mousemove', handleMouseMove);
|
||||
window.addEventListener('mouseup', handleMouseUp);
|
||||
}
|
||||
|
||||
return () => {
|
||||
window.removeEventListener('mousemove', handleMouseMove);
|
||||
window.removeEventListener('mouseup', handleMouseUp);
|
||||
};
|
||||
}, [draggedNode, dragOffset]);
|
||||
|
||||
// Add a new node
|
||||
const addNode = (type: string, x: number, y: number) => {
|
||||
const newNode: Node = {
|
||||
id: `node-${Date.now()}`,
|
||||
type,
|
||||
position: { x, y },
|
||||
data: {}
|
||||
};
|
||||
|
||||
setNodes([...nodes, newNode]);
|
||||
};
|
||||
|
||||
// Handle canvas click to add nodes
|
||||
const handleCanvasClick = (e: React.MouseEvent) => {
|
||||
if (e.target === e.currentTarget) {
|
||||
// Clicked on empty space, add new node at click position
|
||||
const rect = canvasRef.current?.getBoundingClientRect();
|
||||
if (rect) {
|
||||
const x = e.clientX - rect.left;
|
||||
const y = e.clientY - rect.top;
|
||||
addNode('text-encoder', x, y);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Render connection lines between nodes
|
||||
const renderConnections = () => {
|
||||
return connections.map(conn => {
|
||||
const sourceNode = nodes.find(n => n.id === conn.sourceNodeId);
|
||||
const targetNode = nodes.find(n => n.id === conn.targetNodeId);
|
||||
|
||||
if (!sourceNode || !targetNode) return null;
|
||||
|
||||
// Calculate positions for connection line
|
||||
const startX = sourceNode.position.x + 100; // Node width is 200, handle position in center
|
||||
const startY = sourceNode.position.y + 30; // Handle height offset
|
||||
|
||||
const endX = targetNode.position.x;
|
||||
const endY = targetNode.position.y + 30; // Handle height offset
|
||||
|
||||
return (
|
||||
<svg key={conn.id} className="connection-line" style={{ position: 'absolute', top: 0, left: 0 }}>
|
||||
<line
|
||||
x1={startX}
|
||||
y1={startY}
|
||||
x2={endX}
|
||||
y2={endY}
|
||||
stroke="#007acc"
|
||||
strokeWidth="2"
|
||||
markerEnd="url(#arrowhead)"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={canvasRef}
|
||||
className="node-canvas"
|
||||
onClick={handleCanvasClick}
|
||||
>
|
||||
{/* Connection lines */}
|
||||
{renderConnections()}
|
||||
|
||||
{/* Node definitions for arrowheads */}
|
||||
<svg style={{ position: 'absolute', width: 0, height: 0 }}>
|
||||
<defs>
|
||||
<marker
|
||||
id="arrowhead"
|
||||
markerWidth="10"
|
||||
markerHeight="7"
|
||||
refX="0"
|
||||
refY="3.5"
|
||||
orient="auto"
|
||||
>
|
||||
<polygon points="0 0, 10 3.5, 0 7" fill="#007acc" />
|
||||
</marker>
|
||||
</defs>
|
||||
</svg>
|
||||
|
||||
{/* Render nodes */}
|
||||
{nodes.map(node => (
|
||||
<div
|
||||
key={node.id}
|
||||
className={`node ${selectedNode === node.id ? 'selected' : ''}`}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
left: node.position.x,
|
||||
top: node.position.y,
|
||||
width: 200,
|
||||
height: 60,
|
||||
cursor: 'move'
|
||||
}}
|
||||
onMouseDown={(e) => handleMouseDown(e, node.id)}
|
||||
>
|
||||
<div className="node-header">
|
||||
<span>{node.type}</span>
|
||||
</div>
|
||||
<div className="node-content">
|
||||
{node.type === 'text-encoder' && (
|
||||
<input
|
||||
type="text"
|
||||
value={node.data.prompt || ''}
|
||||
placeholder="Enter prompt..."
|
||||
onChange={(e) => {
|
||||
setNodes(nodes.map(n =>
|
||||
n.id === node.id ? { ...n, data: { ...n.data, prompt: e.target.value } } : n
|
||||
));
|
||||
}}
|
||||
className="prompt-input"
|
||||
/>
|
||||
)}
|
||||
{node.type === 'image-generator' && (
|
||||
<div>
|
||||
<label>Steps: </label>
|
||||
<input
|
||||
type="number"
|
||||
value={node.data.steps || 20}
|
||||
onChange={(e) => {
|
||||
setNodes(nodes.map(n =>
|
||||
n.id === node.id ? { ...n, data: { ...n.data, steps: parseInt(e.target.value) } } : n
|
||||
));
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
|
||||
{/* Status indicators */}
|
||||
<div className="canvas-status">
|
||||
<span>Nodes: {nodes.length}</span>
|
||||
<span>Connections: {connections.length}</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default NodeEditor;
|
||||
62
frontend/src/components/NodePanel.css
Normal file
62
frontend/src/components/NodePanel.css
Normal file
@@ -0,0 +1,62 @@
|
||||
.node-panel {
|
||||
padding: 10px;
|
||||
}
|
||||
|
||||
.node-panel h3 {
|
||||
margin-top: 0;
|
||||
margin-bottom: 15px;
|
||||
color: #333;
|
||||
border-bottom: 1px solid #ddd;
|
||||
padding-bottom: 8px;
|
||||
}
|
||||
|
||||
.node-types {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.node-type {
|
||||
background-color: white;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
padding: 10px;
|
||||
cursor: grab;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.node-type:hover {
|
||||
border-color: #007acc;
|
||||
box-shadow: 0 2px 4px rgba(0, 122, 204, 0.2);
|
||||
}
|
||||
|
||||
.node-icon {
|
||||
display: inline-block;
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
background-color: #007acc;
|
||||
color: white;
|
||||
text-align: center;
|
||||
line-height: 20px;
|
||||
border-radius: 3px;
|
||||
margin-right: 10px;
|
||||
vertical-align: middle;
|
||||
}
|
||||
|
||||
.node-info {
|
||||
display: inline-block;
|
||||
vertical-align: top;
|
||||
width: calc(100% - 35px);
|
||||
}
|
||||
|
||||
.node-info h4 {
|
||||
margin: 0 0 5px 0;
|
||||
color: #333;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.node-info p {
|
||||
margin: 0;
|
||||
color: #666;
|
||||
font-size: 12px;
|
||||
}
|
||||
38
frontend/src/components/NodePanel.tsx
Normal file
38
frontend/src/components/NodePanel.tsx
Normal file
@@ -0,0 +1,38 @@
|
||||
import React from 'react';
|
||||
import './NodePanel.css';
|
||||
|
||||
const NodePanel: React.FC = () => {
|
||||
const nodeTypes = [
|
||||
{ id: 'text-encoder', name: 'Text Encoder', description: 'Encode text prompts' },
|
||||
{ id: 'image-generator', name: 'Image Generator', description: 'Generate images from prompts' },
|
||||
{ id: 'vae-decoder', name: 'VAE Decoder', description: 'Decode latent representations' },
|
||||
{ id: 'control-net', name: 'Control Net', description: 'Apply control signals' },
|
||||
{ id: 'upscale', name: 'Upscaler', description: 'Increase image resolution' },
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="node-panel">
|
||||
<h3>Node Library</h3>
|
||||
<div className="node-types">
|
||||
{nodeTypes.map(node => (
|
||||
<div
|
||||
key={node.id}
|
||||
className="node-type"
|
||||
draggable
|
||||
onDragStart={(e) => {
|
||||
e.dataTransfer.setData('nodeType', node.id);
|
||||
}}
|
||||
>
|
||||
<div className="node-icon">□</div>
|
||||
<div className="node-info">
|
||||
<h4>{node.name}</h4>
|
||||
<p>{node.description}</p>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default NodePanel;
|
||||
95
frontend/src/components/PreviewPane.css
Normal file
95
frontend/src/components/PreviewPane.css
Normal file
@@ -0,0 +1,95 @@
|
||||
.preview-pane {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
padding: 20px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.preview-pane h3 {
|
||||
margin-top: 0;
|
||||
color: #333;
|
||||
}
|
||||
|
||||
.preview-controls {
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.generate-button {
|
||||
background-color: #007acc;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 10px 20px;
|
||||
font-size: 16px;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.2s ease;
|
||||
}
|
||||
|
||||
.generate-button:hover:not(:disabled) {
|
||||
background-color: #005a9e;
|
||||
}
|
||||
|
||||
.generate-button:disabled {
|
||||
background-color: #ccc;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.progress-container {
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.progress-bar {
|
||||
width: 100%;
|
||||
height: 20px;
|
||||
background-color: #e9ecef;
|
||||
border-radius: 10px;
|
||||
overflow: hidden;
|
||||
margin-bottom: 5px;
|
||||
}
|
||||
|
||||
.progress-fill {
|
||||
height: 100%;
|
||||
background-color: #007acc;
|
||||
transition: width 0.3s ease;
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
.preview-content {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
margin-bottom: 20px;
|
||||
min-height: 300px;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
background-color: #f8f9fa;
|
||||
}
|
||||
|
||||
.generated-image {
|
||||
max-width: 100%;
|
||||
max-height: 512px;
|
||||
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
.placeholder {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
}
|
||||
|
||||
.preview-info {
|
||||
background-color: #e9f7fe;
|
||||
border: 1px solid #b3e5fc;
|
||||
border-radius: 4px;
|
||||
padding: 15px;
|
||||
}
|
||||
|
||||
.preview-info h4 {
|
||||
margin-top: 0;
|
||||
color: #0066cc;
|
||||
}
|
||||
|
||||
.preview-info p {
|
||||
margin: 5px 0;
|
||||
}
|
||||
75
frontend/src/components/PreviewPane.tsx
Normal file
75
frontend/src/components/PreviewPane.tsx
Normal file
@@ -0,0 +1,75 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import './PreviewPane.css';
|
||||
|
||||
const PreviewPane: React.FC = () => {
|
||||
const [imageUrl, setImageUrl] = useState<string | null>(null);
|
||||
const [isGenerating, setIsGenerating] = useState(false);
|
||||
const [progress, setProgress] = useState(0);
|
||||
|
||||
// Simulate image generation
|
||||
const startGeneration = () => {
|
||||
setIsGenerating(true);
|
||||
setProgress(0);
|
||||
|
||||
// Simulate progress updates
|
||||
const interval = setInterval(() => {
|
||||
setProgress(prev => {
|
||||
if (prev >= 100) {
|
||||
clearInterval(interval);
|
||||
setImageUrl('https://placehold.co/512x512?text=Generated+Image');
|
||||
setIsGenerating(false);
|
||||
return 100;
|
||||
}
|
||||
return prev + 10;
|
||||
});
|
||||
}, 300);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="preview-pane">
|
||||
<h3>Preview</h3>
|
||||
|
||||
<div className="preview-controls">
|
||||
<button
|
||||
onClick={startGeneration}
|
||||
disabled={isGenerating}
|
||||
className="generate-button"
|
||||
>
|
||||
{isGenerating ? 'Generating...' : 'Generate Image'}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className="progress-container">
|
||||
{isGenerating && (
|
||||
<div className="progress-bar">
|
||||
<div
|
||||
className="progress-fill"
|
||||
style={{ width: `${progress}%` }}
|
||||
></div>
|
||||
</div>
|
||||
)}
|
||||
<p>{isGenerating ? `Progress: ${progress}%` : 'No generation in progress'}</p>
|
||||
</div>
|
||||
|
||||
<div className="preview-content">
|
||||
{imageUrl ? (
|
||||
<img src={imageUrl} alt="Generated preview" className="generated-image" />
|
||||
) : (
|
||||
<div className="placeholder">
|
||||
<p>Preview will appear here after generation</p>
|
||||
<p>Drag nodes from the panel to create a workflow</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="preview-info">
|
||||
<h4>GPU Information</h4>
|
||||
<p><strong>Device:</strong> Radeon RX 9070 XT</p>
|
||||
<p><strong>Architecture:</strong> gfx900</p>
|
||||
<p><strong>Status:</strong> Ready for ROCm acceleration</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default PreviewPane;
|
||||
13
frontend/src/index.tsx
Normal file
13
frontend/src/index.tsx
Normal file
@@ -0,0 +1,13 @@
|
||||
import React from 'react';
|
||||
import ReactDOM from 'react-dom/client';
|
||||
import './index.css';
|
||||
import App from './App';
|
||||
|
||||
const root = ReactDOM.createRoot(
|
||||
document.getElementById('root') as HTMLElement
|
||||
);
|
||||
root.render(
|
||||
<React.StrictMode>
|
||||
<App />
|
||||
</React.StrictMode>
|
||||
);
|
||||
100
frontend/src/utils/api-client.ts
Normal file
100
frontend/src/utils/api-client.ts
Normal file
@@ -0,0 +1,100 @@
|
||||
import axios from 'axios';
|
||||
|
||||
// Create an Axios instance with default configuration
|
||||
const apiClient = axios.create({
|
||||
baseURL: 'http://localhost:8080',
|
||||
timeout: 30000,
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
});
|
||||
|
||||
// Request interceptor to add auth token if needed
|
||||
apiClient.interceptors.request.use(
|
||||
(config) => {
|
||||
// Add authorization header if needed
|
||||
const token = localStorage.getItem('authToken');
|
||||
if (token) {
|
||||
config.headers.Authorization = `Bearer ${token}`;
|
||||
}
|
||||
return config;
|
||||
},
|
||||
(error) => {
|
||||
return Promise.reject(error);
|
||||
}
|
||||
);
|
||||
|
||||
// Response interceptor for handling errors
|
||||
apiClient.interceptors.response.use(
|
||||
(response) => {
|
||||
return response;
|
||||
},
|
||||
(error) => {
|
||||
if (error.response?.status === 401) {
|
||||
// Handle unauthorized access
|
||||
localStorage.removeItem('authToken');
|
||||
window.location.href = '/login';
|
||||
}
|
||||
return Promise.reject(error);
|
||||
}
|
||||
);
|
||||
|
||||
// API interfaces
|
||||
export interface InferenceRequest {
|
||||
prompt: string;
|
||||
negative_prompt?: string;
|
||||
guidance_scale?: number;
|
||||
steps?: number;
|
||||
width?: number;
|
||||
height?: number;
|
||||
seed?: number;
|
||||
model_name?: string;
|
||||
}
|
||||
|
||||
export interface InferenceResponse {
|
||||
task_id: string;
|
||||
status: string;
|
||||
}
|
||||
|
||||
export interface TaskStatusResponse {
|
||||
id: string;
|
||||
name: string;
|
||||
status: string;
|
||||
progress: number;
|
||||
created_at: number;
|
||||
updated_at: number;
|
||||
model_name?: string;
|
||||
}
|
||||
|
||||
export interface ModelInfo {
|
||||
name: string;
|
||||
path: string;
|
||||
model_type: string;
|
||||
version: string;
|
||||
loaded: boolean;
|
||||
}
|
||||
|
||||
// API functions
|
||||
export const api = {
|
||||
// Health check endpoint
|
||||
healthCheck: () => apiClient.get('/health'),
|
||||
|
||||
// System info including GPU details
|
||||
getSystemInfo: () => apiClient.get('/system-info'),
|
||||
|
||||
// Start inference task
|
||||
startInference: (request: InferenceRequest) =>
|
||||
apiClient.post<InferenceResponse>('/infer', request),
|
||||
|
||||
// Get all models
|
||||
getModels: () => apiClient.get<ModelInfo[]>('/models'),
|
||||
|
||||
// Get task status
|
||||
getTaskStatus: (taskId: string) =>
|
||||
apiClient.get<TaskStatusResponse>(`/tasks/${taskId}`),
|
||||
|
||||
// Get all tasks
|
||||
getAllTasks: () => apiClient.get<TaskStatusResponse[]>('/tasks'),
|
||||
};
|
||||
|
||||
export default apiClient;
|
||||
Reference in New Issue
Block a user