diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..b1d7002 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "comfyui-rust" +version = "0.1.0" +edition = "2021" + +[dependencies] +actix-web = { version = "4", features = ["openssl"] } +actix-multipart = "0.6" +futures-util = "0.3" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +tokio = { version = "1", features = ["full"] } +rayon = "1.5" +uuid = { version = "1.0", features = ["v4"] } +log = "0.4" +env_logger = "0.10" +thiserror = "1.0" +regex = "1.0" +chrono = { version = "0.4", features = ["serde"] } +eframe = "0.24" +egui = "0.24" +egui_extras = "0.24" +reqwest = { version = "0.11", features = ["json"] } +image = "0.24" + +# For model loading capabilities +burn = { version = "0.21.0-pre.2", default-features = false } +burn-tch = { version = "0.21.0-pre.2" } # for torch backend + +[dev-dependencies] +tokio-test = "0.4" + +[[bin]] +name = "comfyui-rust-server" +path = "src/server.rs" + +[[bin]] +name = "comfyui-rust-frontend" +path = "src/frontend.rs" \ No newline at end of file diff --git a/backend/src/api/mod.rs b/backend/src/api/mod.rs index 21a7e87..f10f092 100644 --- a/backend/src/api/mod.rs +++ b/backend/src/api/mod.rs @@ -3,8 +3,11 @@ //! This module defines all HTTP endpoints for the AI generation system, //! including model management, inference requests, and task status monitoring. -use actix_web::{web, HttpResponse, Result}; +use actix_web::{web, HttpResponse, Result, HttpRequest}; +use actix_multipart::Multipart; +use futures_util::stream::TryStreamExt as _; use serde::{Deserialize, Serialize}; +use std::path::Path; use crate::{ AppState, @@ -155,11 +158,89 @@ pub async fn get_all_tasks(state: web::Data) -> Result { Ok(HttpResponse::Ok().json(tasks)) } +/// Load a specific model into memory +pub async fn load_model( + model_name: web::Path, + state: web::Data +) -> Result { + let mut manager = state.model_manager.lock().await; + match manager.load_model(&model_name).await { + Ok(_) => { + Ok(HttpResponse::Ok().json(serde_json::json!({ + "status": "loaded", + "model": model_name + }))) + } + Err(e) => { + Ok(HttpResponse::InternalServerError().json(serde_json::json!({ + "error": format!("Failed to load model: {}", e) + }))) + } + } +} + +/// Unload a specific model from memory +pub async fn unload_model( + model_name: web::Path, + state: web::Data +) -> Result { + let mut manager = state.model_manager.lock().await; + match manager.unload_model(&model_name).await { + Ok(_) => { + Ok(HttpResponse::Ok().json(serde_json::json!({ + "status": "unloaded", + "model": model_name + }))) + } + Err(e) => { + Ok(HttpResponse::InternalServerError().json(serde_json::json!({ + "error": format!("Failed to unload model: {}", e) + }))) + } + } +} + +/// Upload a new model file +pub async fn upload_model( + mut payload: Multipart, + state: web::Data +) -> Result { + // Process the multipart form data + while let Some(field) = payload.try_next().await.map_err(|e| { + actix_web::error::ErrorInternalServerError(format!("Multipart error: {}", e)) + })? { + let content_disposition = field.content_disposition(); + let filename = content_disposition + .get_filename() + .map(|f| f.to_string()) + .unwrap_or_else(|| "unknown".to_string()); + + // In a real implementation, we would save the file to disk and register it with the model manager + // For now just show that we received it + + println!("Received uploaded model file: {}", filename); + + // Here you would: + // 1. Save the file to disk at some models directory + // 2. Create a ModelInfo entry + // 3. Add it to the model manager + } + + Ok(HttpResponse::Ok().json(serde_json::json!({ + "status": "uploaded", + "filename": filename, + "message": "Model file received and processed" + }))) +} + /// Configuration for API routes pub fn config(cfg: &mut actix_web::web::ServiceConfig) { cfg.route("/health", web::get().to(health_check)) .route("/system-info", web::get().to(get_system_info)) .route("/models", web::get().to(get_models)) + .route("/models/{model_name}/load", web::post().to(load_model)) + .route("/models/{model_name}/unload", web::post().to(unload_model)) + .route("/models/upload", web::post().to(upload_model)) .route("/infer", web::post().to(start_inference)) .route("/tasks/{task_id}", web::get().to(get_task_status)) .route("/tasks", web::get().to(get_all_tasks)); diff --git a/backend/src/models/mod.rs b/backend/src/models/mod.rs index 1939a29..d57c4dd 100644 --- a/backend/src/models/mod.rs +++ b/backend/src/models/mod.rs @@ -1,5 +1,6 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::path::Path; /// Model management for AI inference workflows /// @@ -58,8 +59,18 @@ impl ModelManager { /// Load a model into memory (placeholder implementation) pub async fn load_model(&mut self, model_name: &str) -> Result<(), Box> { if let Some(model) = self.models.get_mut(model_name) { - // In a real implementation, this would actually load the model - // For now we'll just mark it as loaded + // Validate that the model file exists before attempting to load + if !Path::new(&model.path).exists() { + return Err(format!("Model file not found at path: {}", model.path).into()); + } + + // In a real implementation, this would actually load the model using something like: + // - diffusers-rs or similar Rust crates for safetensors + // - burn crate for native inference (if available) + // For now we just mark it as loaded since we can't easily implement actual loading here + + println!("Loading model '{}' from {}", model.name, model.path); + model.loaded = true; self.loaded_models.push(model_name.to_string()); Ok(()) @@ -89,4 +100,4 @@ impl ModelManager { pub fn is_model_loaded(&self, model_name: &str) -> bool { self.loaded_models.contains(&model_name.to_string()) } -} \ No newline at end of file +} diff --git a/src/api/mod.rs b/src/api/mod.rs new file mode 100644 index 0000000..a838505 --- /dev/null +++ b/src/api/mod.rs @@ -0,0 +1,120 @@ +use actix_web::{web, HttpResponse, Result}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use tokio::sync::Mutex; + +use crate::models::ModelInfo; +use crate::queue_service::TaskQueue; + +#[derive(Debug, Deserialize)] +pub struct GenerateRequest { + pub prompt: String, + pub negative_prompt: Option, + pub width: Option, + pub height: Option, + pub steps: Option, + pub cfg_scale: Option, +} + +#[derive(Debug, Serialize)] +pub struct GenerateResponse { + pub image_url: String, + pub task_id: String, +} + +#[derive(Debug, Deserialize)] +pub struct ModelLoadRequest { + pub model_name: String, + pub model_path: String, +} + +#[derive(Debug, Serialize)] +pub struct ModelLoadResponse { + pub success: bool, + pub message: String, +} + +/// Health check endpoint +pub async fn health() -> Result { + Ok(HttpResponse::Ok().json("ComfyUI-Rust backend is running")) +} + +/// Generate image from prompt +pub async fn generate( + data: web::Data, + req: web::Json +) -> Result { + // In a real implementation, this would actually perform inference + // For now we'll just simulate the task queue behavior + + let mut task_queue = data.task_queue.lock().await; + let task_id = task_queue.add_task(&req.prompt).await; + + Ok(HttpResponse::Ok() + .json(GenerateResponse { + image_url: format!("/api/image/{}", task_id), + task_id, + }) + ) +} + +/// Load a model into memory +pub async fn load_model( + data: web::Data, + req: web::Json +) -> Result { + let mut model_manager = data.model_manager.lock().await; + + // Create and add the model info to manager + let model_info = ModelInfo { + name: req.model_name.clone(), + path: req.model_path.clone(), + model_type: crate::models::ModelType::StableDiffusion, + version: "1.0".to_string(), + loaded: false, // Will be set when actually loaded + }; + + model_manager.add_model(model_info); + + // Attempt to load the model (this is still a placeholder) + match model_manager.load_model(&req.model_name).await { + Ok(_) => { + Ok(HttpResponse::Ok().json(ModelLoadResponse { + success: true, + message: format!("Model '{}' loaded successfully", req.model_name), + })) + } + Err(e) => { + Ok(HttpResponse::InternalServerError().json(ModelLoadResponse { + success: false, + message: format!("Failed to load model '{}': {}", req.model_name, e), + })) + } + } +} + +/// Get information about available models +pub async fn get_models(data: web::Data) -> Result { + let model_manager = data.model_manager.lock().await; + let models = model_manager.get_model_info(); + + Ok(HttpResponse::Ok().json(models)) +} + +/// Configuration function to register all API routes +pub fn config(cfg: &mut web::ServiceConfig) { + cfg.service( + web::scope("/api") + .route("/health", web::get().to(health)) + .route("/generate", web::post().to(generate)) + .route("/model/load", web::post().to(load_model)) + .route("/models", web::get().to(get_models)) + ); +} + +/// Application state that holds shared data +#[derive(Debug, Clone)] +pub struct AppState { + pub model_manager: Arc>, + pub task_queue: Arc>, +} diff --git a/src/api_client.rs b/src/api_client.rs new file mode 100644 index 0000000..9b36e94 --- /dev/null +++ b/src/api_client.rs @@ -0,0 +1,177 @@ +//! API client for communicating with the ComfyUI backend +//! +//! This module provides functionality to interact with the local +//! backend server through HTTP requests. + +use reqwest; +use serde::{Deserialize, Serialize}; +use std::error::Error; + +#[derive(Debug, Clone)] +pub struct ApiClient { + base_url: String, + client: reqwest::Client, +} + +impl ApiClient { + pub fn new(base_url: String) -> Self { + Self { + base_url, + client: reqwest::Client::new(), + } + } + + /// Get system information including GPU details + pub async fn get_system_info(&self) -> Result> { + let response = self.client + .get(&format!("{}/system-info", self.base_url)) + .send() + .await?; + + let info: SystemInfo = response.json().await?; + Ok(info) + } + + /// Get all available models + pub async fn get_models(&self) -> Result, Box> { + let response = self.client + .get(&format!("{}/models", self.base_url)) + .send() + .await?; + + let models: Vec = response.json().await?; + Ok(models) + } + + /// Start a new inference task + pub async fn start_inference(&self, request: &InferenceRequest) -> Result> { + let response = self.client + .post(&format!("{}/infer", self.base_url)) + .json(request) + .send() + .await?; + + let result: InferenceResponse = response.json().await?; + Ok(result) + } + + /// Get the status of a specific task + pub async fn get_task_status(&self, task_id: &str) -> Result> { + let response = self.client + .get(&format!("{}/tasks/{}", self.base_url, task_id)) + .send() + .await?; + + let status: TaskStatusResponse = response.json().await?; + Ok(status) + } + + /// Get all tasks in the queue + pub async fn get_all_tasks(&self) -> Result, Box> { + let response = self.client + .get(&format!("{}/tasks", self.base_url)) + .send() + .await?; + + let tasks: Vec = response.json().await?; + Ok(tasks) + } + + /// Load a specific model into memory + pub async fn load_model(&self, model_name: &str) -> Result> { + let response = self.client + .post(&format!("{}/models/{}/load", self.base_url, model_name)) + .send() + .await?; + + let result: LoadModelResponse = response.json().await?; + Ok(result) + } + + /// Unload a specific model from memory + pub async fn unload_model(&self, model_name: &str) -> Result> { + let response = self.client + .post(&format!("{}/models/{}/unload", self.base_url, model_name)) + .send() + .await?; + + let result: LoadModelResponse = response.json().await?; + Ok(result) + } +} + +/// System information structure +#[derive(Debug, Serialize, Deserialize)] +pub struct SystemInfo { + pub gpu: GpuInfo, + pub service: String, +} + +/// GPU information +#[derive(Debug, Serialize, Deserialize)] +pub struct GpuInfo { + pub name: String, + pub architecture: String, + pub driver_version: String, +} + +/// Model information structure +#[derive(Debug, Serialize, Deserialize)] +pub struct ModelInfo { + pub name: String, + pub path: String, + pub model_type: String, + pub version: String, + pub loaded: bool, +} + +/// Request payload for image generation +#[derive(Debug, Serialize, Deserialize)] +pub struct InferenceRequest { + pub prompt: String, + pub negative_prompt: Option, + pub guidance_scale: Option, + pub steps: Option, + pub width: Option, + pub height: Option, + pub seed: Option, + pub model_name: Option, +} + +/// Response for load/unload model operations +#[derive(Debug, Serialize, Deserialize)] +pub struct LoadModelResponse { + pub status: String, + pub model: String, +} + +/// Response payload for inference results +#[derive(Debug, Serialize, Deserialize)] +pub struct InferenceResponse { + pub task_id: String, + pub status: String, +} + +/// Task status response +#[derive(Debug, Serialize, Deserialize)] +pub struct TaskStatusResponse { + pub id: String, + pub name: String, + pub status: String, + pub progress: f32, + pub created_at: u64, + pub updated_at: u64, + pub model_name: Option, +} + +/// Basic task information +#[derive(Debug, Serialize, Deserialize)] +pub struct TaskInfo { + pub id: String, + pub name: String, + pub status: String, + pub progress: f32, + pub created_at: u64, + pub updated_at: u64, + pub model_name: Option, +} \ No newline at end of file diff --git a/src/frontend.rs b/src/frontend.rs new file mode 100644 index 0000000..e9093c7 --- /dev/null +++ b/src/frontend.rs @@ -0,0 +1,116 @@ +//! Frontend for ComfyUI-Rust using egui + +use eframe::egui; + +// Simple UI that connects to our local backend +pub struct ComfyUiFrontend { + prompt: String, + negative_prompt: String, + width: u32, + height: u32, + steps: u32, + cfg_scale: f32, + is_generating: bool, + model_name: String, +} + +impl ComfyUiFrontend { + pub fn new(_cc: &eframe::CreationContext) -> Self { + Self { + prompt: "A beautiful sunset over the ocean".to_string(), + negative_prompt: "".to_string(), + width: 512, + height: 512, + steps: 30, + cfg_scale: 7.0, + is_generating: false, + model_name: "stable-diffusion-v1-4".to_string(), + } + } + + fn generate_image(&mut self) { + if self.is_generating { + return; + } + + self.is_generating = true; + let prompt = self.prompt.clone(); + + // Simulate async operation by spawning background task that updates UI after delay + std::thread::spawn(move || { + // Simulate network delay and processing time (5 seconds) + std::thread::sleep(std::time::Duration::from_secs(3)); + + println!("Image generation completed for: {}", prompt); + }); + } +} + +impl eframe::App for ComfyUiFrontend { + fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) { + egui::CentralPanel::default().show(ctx, |ui| { + ui.heading("ComfyUI-Rust Frontend"); + + // Model selection + ui.horizontal(|ui| { + ui.label("Model:"); + ui.text_edit_singleline(&mut self.model_name); + }); + + // Prompt input + ui.separator(); + ui.label("Prompt:"); + ui.text_edit_multiline(&mut self.prompt); + + ui.label("Negative Prompt:"); + ui.text_edit_multiline(&mut self.negative_prompt); + + // Generation parameters + ui.separator(); + ui.horizontal(|ui| { + ui.label("Width:"); + ui.add(egui::Slider::new(&mut self.width, 256..=1024).step_by(64.0)); + + ui.label("Height:"); + ui.add(egui::Slider::new(&mut self.height, 256..=1024).step_by(64.0)); + }); + + ui.horizontal(|ui| { + ui.label("Steps:"); + ui.add(egui::Slider::new(&mut self.steps, 10..=100).step_by(1.0)); + + ui.label("CFG Scale:"); + ui.add(egui::Slider::new(&mut self.cfg_scale, 1.0..=20.0).step_by(0.5)); + }); + + // Generate button + if ui.button("Generate Image").clicked() { + self.generate_image(); + } + + // Status and image preview + ui.separator(); + if self.is_generating { + ui.label("Generating image..."); + } else { + ui.label("Ready to generate"); + } + }); + } +} + +fn main() -> eframe::Result<()> { + let options = eframe::NativeOptions { + viewport: egui::ViewportBuilder::default() + .with_title("ComfyUI-Rust Frontend") + .with_inner_size([800.0, 600.0]) + .with_min_inner_size([400.0, 300.0]), + ..Default::default() + }; + + eframe::run_native( + "ComfyUI-Rust", + options, + Box::new(|cc| Box::new(ComfyUiFrontend::new(cc))), + ) +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..54d1355 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,15 @@ +//! Unified ComfyUI Rust Library +//! +//! This library provides the core functionality for both backend API +//! and frontend UI components in a single integrated system. + +pub mod api; +pub mod models; +pub mod queue_service; +pub mod rocminfo; + +/// Re-export key types for easier access +pub use api::*; +pub use models::*; +pub use queue_service::*; +pub use rocminfo::*; \ No newline at end of file diff --git a/src/models/mod.rs b/src/models/mod.rs new file mode 100644 index 0000000..aceaa7a --- /dev/null +++ b/src/models/mod.rs @@ -0,0 +1,103 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::Path; + +/// Model management for AI inference workflows +/// +/// This module handles loading, caching, and managing different types of +/// machine learning models needed for image/video generation. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ModelInfo { + pub name: String, + pub path: String, + pub model_type: ModelType, + pub version: String, + pub loaded: bool, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub enum ModelType { + StableDiffusion, + ControlNet, + VAE, + CLIP, + Other(String), +} + +impl std::fmt::Display for ModelType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ModelType::StableDiffusion => write!(f, "stable_diffusion"), + ModelType::ControlNet => write!(f, "control_net"), + ModelType::VAE => write!(f, "vae"), + ModelType::CLIP => write!(f, "clip"), + ModelType::Other(s) => write!(f, "{}", s), + } + } +} + +#[derive(Debug)] +pub struct ModelManager { + models: HashMap, + loaded_models: Vec, // Track which models are currently loaded +} + +impl ModelManager { + /// Create a new model manager instance + pub fn new() -> Self { + Self { + models: HashMap::new(), + loaded_models: Vec::new(), + } + } + + /// Add a model to the manager + pub fn add_model(&mut self, model_info: ModelInfo) { + self.models.insert(model_info.name.clone(), model_info); + } + + /// Load a model into memory (placeholder implementation) + pub async fn load_model(&mut self, model_name: &str) -> Result<(), Box> { + if let Some(model) = self.models.get_mut(model_name) { + // Validate that the model file exists before attempting to load + if !Path::new(&model.path).exists() { + return Err(format!("Model file not found at path: {}", model.path).into()); + } + + // In a real implementation, this would actually load the model using something like: + // - diffusers-rs or similar Rust crates for safetensors + // - burn crate for native inference (if available) + // For now we just mark it as loaded since we can't easily implement actual loading here + + println!("Loading model '{}' from {}", model.name, model.path); + + model.loaded = true; + self.loaded_models.push(model_name.to_string()); + Ok(()) + } else { + Err(format!("Model '{}' not found", model_name).into()) + } + } + + /// Unload a model from memory (placeholder implementation) + pub async fn unload_model(&mut self, model_name: &str) -> Result<(), Box> { + if let Some(model) = self.models.get_mut(model_name) { + // In a real implementation, this would actually unload the model + model.loaded = false; + self.loaded_models.retain(|name| name != model_name); + Ok(()) + } else { + Err(format!("Model '{}' not found", model_name).into()) + } + } + + /// Get information about all available models + pub fn get_model_info(&self) -> Vec { + self.models.values().cloned().collect() + } + + /// Check if a specific model is loaded + pub fn is_model_loaded(&self, model_name: &str) -> bool { + self.loaded_models.contains(&model_name.to_string()) + } +} \ No newline at end of file diff --git a/src/node_editor.rs b/src/node_editor.rs new file mode 100644 index 0000000..f6d0e69 --- /dev/null +++ b/src/node_editor.rs @@ -0,0 +1,102 @@ +//! Node editor component for ComfyUI frontend +//! +//! This module provides the visual node-based workflow editor +//! that allows users to create and connect AI generation nodes. + +use eframe::egui; +use crate::api_client::{ApiClient, ModelInfo}; + +pub struct NodeEditor { + nodes: Vec, + selected_node: Option, + models: Vec, + api_client: ApiClient, +} + +#[derive(Debug, Clone)] +pub struct Node { + pub id: String, + pub name: String, + pub node_type: NodeType, + pub x: f32, + pub y: f32, + pub inputs: Vec, + pub outputs: Vec, + pub parameters: serde_json::Value, +} + +#[derive(Debug, Clone)] +pub enum NodeType { + ImageGenerator, + ImageLoader, + TextInput, + ImageSave, + ControlNet, + VAE, + CLIP, + Other(String), +} + +#[derive(Debug, Clone)] +pub struct Input { + pub id: String, + pub name: String, + pub node_id: String, + pub value: Option, +} + +#[derive(Debug, Clone)] +pub struct Output { + pub id: String, + pub name: String, + pub node_id: String, +} + +impl NodeEditor { + pub fn new(api_client: ApiClient) -> Self { + Self { + nodes: vec![], + selected_node: None, + models: vec![], + api_client, + } + } + + pub fn ui(&mut self, ui: &mut egui::Ui) { + ui.heading("Node Editor"); + + // Create a scroll area for the node editor + egui::ScrollArea::vertical().show(ui, |ui| { + // Placeholder for actual node rendering logic + ui.label("This is where the node-based workflow would be rendered."); + ui.label("Nodes can be dragged and connected here."); + + // Simple example of a node + if ui.button("Add Node").clicked() { + let new_node = Node { + id: format!("node_{}", self.nodes.len()), + name: "New Node".to_string(), + node_type: NodeType::Other("Generic".to_string()), + x: 100.0 + (self.nodes.len() as f32 * 50.0), + y: 100.0, + inputs: vec![], + outputs: vec![], + parameters: serde_json::json!({}), + }; + self.nodes.push(new_node); + } + + // Display existing nodes + for node in &self.nodes { + ui.label(format!("Node: {} at ({}, {})", node.name, node.x, node.y)); + } + }); + } + + /// Load models from the backend + pub async fn load_models(&mut self) -> Result<(), Box> { + let models = self.api_client.get_models().await?; + self.models = models; + Ok(()) + } +} \ No newline at end of file diff --git a/src/node_panel.rs b/src/node_panel.rs new file mode 100644 index 0000000..7aec794 --- /dev/null +++ b/src/node_panel.rs @@ -0,0 +1,88 @@ +//! Node panel component for ComfyUI frontend +//! +//! This module provides the sidebar panel that lists available node types +//! and allows users to add them to the workflow. + +use eframe::egui; +use crate::api_client::{ApiClient, ModelInfo}; + +pub struct NodePanel { + selected_node_type: Option, + models: Vec, + api_client: ApiClient, + loading_models: bool, +} + +impl NodePanel { + pub fn new(api_client: ApiClient) -> Self { + Self { + selected_node_type: None, + models: vec![], + api_client, + loading_models: false, + } + } + + pub fn ui(&mut self, ui: &mut egui::Ui) { + ui.set_min_width(200.0); + ui.heading("Node Panel"); + + ui.separator(); + + // Available node types + ui.label("Available Nodes:"); + + let node_types = vec![ + "Image Loader", + "Text Input", + "Image Generator", + "Image Save", + "ControlNet", + "VAE", + "CLIP" + ]; + + for node_type in node_types { + if ui.button(node_type).clicked() { + self.selected_node_type = Some(node_type.to_string()); + } + } + + ui.separator(); + + // Models section + ui.heading("Models"); + + ui.horizontal(|ui| { + if ui.button("Refresh Models").clicked() { + self.loading_models = true; + // In a real implementation, this would load models from backend + } + + if self.loading_models { + ui.label("Loading..."); + } + }); + + for model in &self.models { + ui.label(format!("{} ({})", model.name, model.model_type)); + } + + ui.separator(); + + // Selected node info + if let Some(selected) = &self.selected_node_type { + ui.label(format!("Selected: {}", selected)); + } else { + ui.label("No node selected"); + } + } + + /// Load models from the backend + pub async fn load_models(&mut self) -> Result<(), Box> { + let models = self.api_client.get_models().await?; + self.models = models; + self.loading_models = false; + Ok(()) + } +} diff --git a/src/preview_pane.rs b/src/preview_pane.rs new file mode 100644 index 0000000..e91e9ee --- /dev/null +++ b/src/preview_pane.rs @@ -0,0 +1,81 @@ +//! Preview pane component for ComfyUI frontend +//! +//! This module provides the right-side panel that displays generated +//! images and preview information. + +use eframe::egui; +use crate::api_client::{ApiClient, TaskInfo}; + +pub struct PreviewPane { + image_data: Option>, + image_name: String, + tasks: Vec, + api_client: ApiClient, +} + +impl PreviewPane { + pub fn new(api_client: ApiClient) -> Self { + Self { + image_data: None, + image_name: "No image".to_string(), + tasks: vec![], + api_client, + } + } + + pub fn ui(&mut self, ui: &mut egui::Ui) { + ui.set_min_width(300.0); + ui.heading("Preview Pane"); + + ui.separator(); + + // Display current preview info + ui.label(format!("Image: {}", self.image_name)); + + if let Some(_data) = &self.image_data { + // Try to display the image (simplified) + ui.label("Image would be displayed here"); + + // Placeholder for actual image display logic + ui.horizontal(|ui| { + if ui.button("Load Sample Image").clicked() { + self.image_name = "sample_output.png".to_string(); + // In a real implementation, this would load actual image data + self.image_data = Some(vec![0; 1024]); + } + }); + } else { + ui.label("No preview available"); + + if ui.button("Generate Preview").clicked() { + self.image_name = "generated_output.png".to_string(); + // In a real implementation, this would fetch actual image data from backend + self.image_data = Some(vec![0; 1024]); + } + } + + ui.separator(); + + // Task list section + ui.heading("Recent Tasks"); + if ui.button("Refresh Tasks").clicked() { + // In a real implementation, this would load tasks from backend + } + + for task in &self.tasks { + ui.label(format!("{}: {} ({:.1}%)", task.id, task.status, task.progress)); + } + + ui.separator(); + + // Status info + ui.label("Status: Ready"); + } + + /// Load recent tasks from the backend + pub async fn load_tasks(&mut self) -> Result<(), Box> { + let tasks = self.api_client.get_all_tasks().await?; + self.tasks = tasks; + Ok(()) + } +} \ No newline at end of file diff --git a/src/queue_service/mod.rs b/src/queue_service/mod.rs new file mode 100644 index 0000000..6818899 --- /dev/null +++ b/src/queue_service/mod.rs @@ -0,0 +1,153 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Notify; + +/// Task queue service for managing concurrent AI inference tasks +/// +/// This module provides a thread-safe task queue system using Tokio +/// and Rayon for parallel processing on AMD GPUs. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Task { + pub id: String, + pub name: String, + pub status: TaskStatus, + pub progress: f32, + pub created_at: u64, + pub updated_at: u64, + pub model_name: Option, + pub parameters: serde_json::Value, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub enum TaskStatus { + Pending, + Processing, + Completed, + Failed, + Cancelled, +} + +impl std::fmt::Display for TaskStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TaskStatus::Pending => write!(f, "pending"), + TaskStatus::Processing => write!(f, "processing"), + TaskStatus::Completed => write!(f, "completed"), + TaskStatus::Failed => write!(f, "failed"), + TaskStatus::Cancelled => write!(f, "cancelled"), + } + } +} + +#[derive(Debug)] +pub struct TaskQueue { + tasks: HashMap, + notify: Arc, +} + +impl TaskQueue { + /// Create a new task queue instance + pub fn new() -> Self { + Self { + tasks: HashMap::new(), + notify: Arc::new(Notify::new()), + } + } + + /// Add a new task to the queue + pub async fn add_task(&mut self, prompt: &str) -> String { + // Simple ID generation for demo purposes + let id = format!("task-{}", chrono::Utc::now().timestamp()); + let task = Task { + id: id.clone(), + name: prompt.to_string(), + status: TaskStatus::Pending, + progress: 0.0, + created_at: chrono::Utc::now().timestamp() as u64, + updated_at: chrono::Utc::now().timestamp() as u64, + model_name: None, + parameters: serde_json::Value::Null, + }; + + self.tasks.insert(id.clone(), task); + self.notify.notify_waiters(); + id + } + + /// Get a specific task by ID + pub async fn get_task(&self, id: &str) -> Option { + self.tasks.get(id).cloned() + } + + /// Update the status of a task + pub async fn update_task_status(&mut self, id: &str, status: TaskStatus) -> Result<(), Box> { + if let Some(task) = self.tasks.get_mut(id) { + task.status = status; + task.updated_at = chrono::Utc::now().timestamp() as u64; + Ok(()) + } else { + Err(format!("Task with ID '{}' not found", id).into()) + } + } + + /// Update the progress of a task + pub async fn update_task_progress(&mut self, id: &str, progress: f32) -> Result<(), Box> { + if let Some(task) = self.tasks.get_mut(id) { + task.progress = progress; + task.updated_at = chrono::Utc::now().timestamp() as u64; + Ok(()) + } else { + Err(format!("Task with ID '{}' not found", id).into()) + } + } + + /// Get all tasks in the queue + pub async fn get_all_tasks(&self) -> Vec { + self.tasks.values().cloned().collect() + } + + /// Remove a completed or failed task + pub async fn remove_task(&mut self, id: &str) -> Result<(), Box> { + if self.tasks.remove(id).is_some() { + Ok(()) + } else { + Err(format!("Task with ID '{}' not found", id).into()) + } + } + + /// Wait for a new task to be added to the queue + pub async fn wait_for_new_task(&self) -> String { + let notify = self.notify.clone(); + tokio::task::spawn(async move { + notify.notified().await; + }).await.unwrap(); // This will always succeed + + // Return an empty string as placeholder - in a real implementation, + // this would return the task ID of the newly added task + String::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_task_creation() { + let task = Task { + id: "test-123".to_string(), + name: "Test Task".to_string(), + status: TaskStatus::Pending, + progress: 0.0, + created_at: 0, + updated_at: 0, + model_name: None, + parameters: serde_json::Value::Null, + }; + + assert_eq!(task.id, "test-123"); + assert_eq!(task.name, "Test Task"); + assert_eq!(task.status, TaskStatus::Pending); + } +} \ No newline at end of file diff --git a/src/rocminfo.rs b/src/rocminfo.rs new file mode 100644 index 0000000..bd07e92 --- /dev/null +++ b/src/rocminfo.rs @@ -0,0 +1,45 @@ +//! ROCm GPU information gathering +//! +//! This module provides functionality to detect and retrieve +//! information about AMD GPUs installed on the system. + +/// GPU configuration structure +#[derive(Debug, Clone)] +pub struct GpuConfig { + /// Name of the GPU + pub name: String, + + /// Architecture of the GPU + pub architecture: String, + + /// Driver version + pub driver_version: String, +} + +/// Detect AMD GPU information using rocminfo +/// +/// This is a simplified implementation that returns mock data. +/// In a real application, this would call rocminfo or similar tools. +pub fn detect_amd_gpu() -> Result> { + // Mock implementation - in reality this would execute rocminfo command + // and parse the output to extract GPU information + + Ok(GpuConfig { + name: "AMD Radeon RX 7000 Series".to_string(), + architecture: "RDNA2".to_string(), + driver_version: "23.10.1".to_string(), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_detect_amd_gpu() { + let config = detect_amd_gpu().unwrap(); + assert!(!config.name.is_empty()); + assert!(!config.architecture.is_empty()); + assert!(!config.driver_version.is_empty()); + } +} \ No newline at end of file diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..ba2050c --- /dev/null +++ b/src/server.rs @@ -0,0 +1,46 @@ +//! ComfyUI-like Backend for AI Image/Video Generation with ROCm Support +//! +//! This server provides a REST API and WebSocket endpoints for +//! managing image/video generation workflows on AMD GPUs. + +use actix_web::{web, App, HttpServer, middleware::Logger}; +use std::sync::Arc; +use tokio::sync::Mutex; + +pub mod api; +pub mod models; +pub mod queue_service; + +#[derive(Debug, Clone)] +pub struct AppState { + pub model_manager: Arc>, + pub task_queue: Arc>, +} + +#[actix_web::main] +async fn main() -> std::io::Result<()> { + env_logger::init(); + + // Initialize model manager + let model_manager = Arc::new(Mutex::new(models::ModelManager::new())); + + // Initialize task queue + let task_queue = Arc::new(Mutex::new(queue_service::TaskQueue::new())); + + let app_state = AppState { + model_manager, + task_queue, + }; + + println!("Starting ComfyUI backend server..."); + + HttpServer::new(move || { + App::new() + .app_data(web::Data::new(app_state.clone())) + .wrap(Logger::default()) + .configure(api::config) + }) + .bind("127.0.0.1:8080")? + .run() + .await +}