From 0a513b6c20b3cffe9e3cea8fc1572b438ae9c269 Mon Sep 17 00:00:00 2001 From: IGN-Styly Date: Thu, 12 Mar 2026 02:01:34 +0000 Subject: [PATCH 1/3] Add proxy-based cluster routing for backend engine nodes - introduce Cluster gRPC service and node registration/heartbeat messages - add new `proxy` binary with routing, backend discovery, and stale-node reaping - refactor `server` into backend service module and add cluster client auto-registration - add proxy cluster integration tests and RFC 1004 documentation --- Cargo.lock | 43 ++- engine/Cargo.toml | 3 + engine/proto/engine.proto | 42 +++ engine/src/bin/proxy.rs | 36 ++ engine/src/bin/server.rs | 663 ++-------------------------------- engine/src/cluster_client.rs | 179 +++++++++ engine/src/lib.rs | 24 +- engine/src/proto.rs | 4 + engine/src/proxy_config.rs | 101 ++++++ engine/src/routing.rs | 214 +++++++++++ engine/src/service/backend.rs | 567 +++++++++++++++++++++++++++++ engine/src/service/mod.rs | 2 + engine/src/service/proxy.rs | 500 +++++++++++++++++++++++++ engine/tests/common/mod.rs | 291 +++++++++++++++ engine/tests/proxy_cluster.rs | 507 ++++++++++++++++++++++++++ enginelib/src/config.rs | 15 + rfc/rfc1004.md | 157 ++++++++ 17 files changed, 2703 insertions(+), 645 deletions(-) create mode 100644 engine/src/bin/proxy.rs create mode 100644 engine/src/cluster_client.rs create mode 100644 engine/src/proto.rs create mode 100644 engine/src/proxy_config.rs create mode 100644 engine/src/routing.rs create mode 100644 engine/src/service/backend.rs create mode 100644 engine/src/service/mod.rs create mode 100644 engine/src/service/proxy.rs create mode 100644 engine/tests/common/mod.rs create mode 100644 engine/tests/proxy_cluster.rs create mode 100644 rfc/rfc1004.md diff --git a/Cargo.lock b/Cargo.lock index 0add590..1b13ca9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -543,8 +543,8 @@ version = "0.1.0" source = "git+https://github.com/GrandEngineering/druid.git#801365f097b6b437314d0deaa72fd85be6af0a8f" dependencies = 
[ "chrono", - "rand", - "rand_chacha", + "rand 0.9.2", + "rand_chacha 0.9.0", "uuid", ] @@ -591,14 +591,17 @@ dependencies = [ "druid", "enginelib", "prost", + "rand 0.8.5", "serde", "tokio", + "tokio-stream", "toml", "tonic", "tonic-build", "tonic-prost", "tonic-prost-build", "tonic-reflection", + "tracing", ] [[package]] @@ -2307,14 +2310,35 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + [[package]] name = "rand" version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ - "rand_chacha", - "rand_core", + "rand_chacha 0.9.0", + "rand_core 0.9.3", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", ] [[package]] @@ -2324,7 +2348,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.9.3", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.16", ] [[package]] diff --git a/engine/Cargo.toml b/engine/Cargo.toml index 9e82df7..eca766c 100644 --- a/engine/Cargo.toml +++ b/engine/Cargo.toml @@ -17,13 +17,16 @@ druid = { git = 
"https://github.com/GrandEngineering/druid.git" } enginelib = { path = "../enginelib" } # libloading = "0.8.6" prost = "0.14" +rand = "0.8.5" serde = { workspace = true } # serde = "1.0.219" tokio = { version = "1.50.0", features = ["rt-multi-thread", "macros"] } +tokio-stream = { version = "0.1.17", features = ["net"] } toml = { workspace = true } # toml = "0.8.19" tonic = "0.14" tonic-prost = "0.14.5" +tracing = "0.1.44" tonic-reflection = "0.14" [build-dependencies] diff --git a/engine/proto/engine.proto b/engine/proto/engine.proto index 188c371..e532eeb 100644 --- a/engine/proto/engine.proto +++ b/engine/proto/engine.proto @@ -10,6 +10,13 @@ service Engine { rpc GetTasks(TaskPageRequest) returns (TaskPage); rpc CheckAuth(empty) returns (empty); } + +// Cluster membership and discovery service, used between backend nodes and a proxy. +service Cluster { + rpc RegisterNode(NodeRegisterRequest) returns (NodeRegisterResponse); + rpc Heartbeat(NodeHeartbeat) returns (NodeHeartbeatAck); + rpc ListNodes(empty) returns (NodeList); +} message TaskSelector { TaskState state = 1; string namespace = 2; @@ -58,3 +65,38 @@ message Task { string task_id = 2; // namespace:task bytes payload = 3; } + +message NodeRegisterRequest { + string node_id = 1; + string advertise_addr = 2; // e.g. "http://10.0.0.12:50051" + repeated string tags = 3; // e.g. ["gpu", "control"] + repeated string tasks = 4; // ["ns:task", ...] 
from local task registry +} + +message NodeRegisterResponse { + string session_id = 1; + uint64 heartbeat_interval_seconds = 2; // default 5 + uint64 ttl_seconds = 3; // default 15 +} + +message NodeHeartbeat { + string node_id = 1; + string session_id = 2; +} + +message NodeHeartbeatAck { + uint64 server_time_unix = 1; +} + +message NodeInfo { + string node_id = 1; + string advertise_addr = 2; + repeated string tags = 3; + repeated string tasks = 4; + uint64 last_seen_unix = 5; + bool healthy = 6; +} + +message NodeList { + repeated NodeInfo nodes = 1; +} diff --git a/engine/src/bin/proxy.rs b/engine/src/bin/proxy.rs new file mode 100644 index 0000000..996dcd3 --- /dev/null +++ b/engine/src/bin/proxy.rs @@ -0,0 +1,36 @@ +use std::sync::Arc; + +use engine::{ + proto, + proxy_config::ProxyConfigToml, + routing::ProxyState, + service::proxy::{ProxyService, spawn_reaper}, +}; +use tokio::net::TcpListener; +use tokio_stream::wrappers::TcpListenerStream; +use tonic::transport::Server; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let config = ProxyConfigToml::load()?; + let listener = TcpListener::bind(config.listen.as_str()).await?; + let state = Arc::new(ProxyState::new(config)?); + let _reaper = spawn_reaper(state.clone()); + let cluster_service = ProxyService::new(state.clone()); + let engine_service = ProxyService::new(state); + + let reflection_service = tonic_reflection::server::Builder::configure() + .register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET) + .build_v1() + .map_err(|e| Box::new(e) as Box)?; + + Server::builder() + .add_service(reflection_service) + .add_service(proto::cluster_server::ClusterServer::new(cluster_service)) + .add_service(proto::engine_server::EngineServer::new(engine_service)) + .serve_with_incoming(TcpListenerStream::new(listener)) + .await + .map_err(|e| Box::new(e) as Box)?; + + Ok(()) +} diff --git a/engine/src/bin/server.rs b/engine/src/bin/server.rs index 5d6679a..0c11598 100644 --- 
a/engine/src/bin/server.rs +++ b/engine/src/bin/server.rs @@ -1,638 +1,16 @@ -use engine::{get_auth, get_uid}; -use enginelib::api::postcard; -use enginelib::{ - Identifier, RawIdentier, Registry, - api::EngineAPI, - chrono::Utc, - event::{debug, info, warn}, - events::{self, Events, ID}, - plugin::LibraryManager, - task::{SolvedTasks, StoredExecutingTask, StoredTask, Task, TaskQueue}, -}; -use proto::{ - TaskState, - engine_server::{Engine, EngineServer}, +use engine::{ + cluster_client::{registration_from_api, spawn_registration}, + proto, + service::backend::BackendEngineService, }; +use enginelib::{api::EngineAPI, events::Events}; use std::{ - collections::HashMap, - env::consts::OS, - io::Read, - net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}, - sync::{Arc, RwLock as RS_RwLock}, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + sync::Arc, }; -use tokio::sync::RwLock; -use tonic::{Request, Response, Status, metadata::MetadataValue, transport::Server}; - -mod proto { - tonic::include_proto!("engine"); - pub(crate) const FILE_DESCRIPTOR_SET: &[u8] = - tonic::include_file_descriptor_set!("engine_descriptor"); -} -#[allow(non_snake_case)] -struct EngineService { - pub EngineAPI: Arc>, -} -#[tonic::async_trait] -impl Engine for EngineService { - async fn check_auth( - &self, - request: tonic::Request, - ) -> Result, Status> { - let challenge = get_auth(&request); - let mut api = self.EngineAPI.write().await; - let db = api.db.clone(); - let output = Events::CheckAdminAuth(&mut api, challenge, ("".into(), "".into()), db); - if !output { - warn!("Auth check failed - permission denied"); - return Err(tonic::Status::permission_denied("Invalid Auth")); - }; - return Ok(tonic::Response::new(proto::Empty {})); - } - async fn delete_task( - &self, - request: tonic::Request, - ) -> Result, Status> { - let mut api = self.EngineAPI.write().await; - let data = request.get_ref(); - let challenge = get_auth(&request); - let db = api.db.clone(); - let id = ID(&data.namespace, 
&data.task); - - let output = Events::CheckAdminAuth(&mut api, challenge, ("".into(), "".into()), db); - if !output { - warn!("Auth check failed - permission denied"); - return Err(tonic::Status::permission_denied("Invalid Auth")); - }; - // Generic helper for removing a task by id from a collection, using an id extractor closure - fn delete_task_from_collection( - collection: &mut HashMap<(String, String), Vec>, - id: &(String, String), - task_id: &str, - state_name: &str, - namespace: &str, - task: &str, - id_extractor: F, - ) -> Result<(), Status> - where - F: Fn(&T) -> &str, - { - match collection.get_mut(id) { - Some(query) => { - let orig_len = query.len(); - query.retain(|f| id_extractor(f) != task_id); - if query.len() == orig_len { - info!( - "DeleteTask: Task with id {} not found in {} state for namespace: {}, task: {}", - task_id, state_name, namespace, task - ); - return Err(Status::not_found(format!( - "Task with id {} not found in {} state", - task_id, state_name - ))); - } - Ok(()) - } - None => { - info!( - "DeleteTask: No tasks found in {} state for namespace: {}, task: {}", - state_name, namespace, task - ); - Err(Status::not_found(format!( - "No tasks found in {} state for given namespace and task", - state_name - ))) - } - } - } - - // Use the helper for each state - let result = match data.state() { - TaskState::Processing => delete_task_from_collection( - &mut api.executing_tasks.tasks, - &id, - &data.id, - "Processing", - &data.namespace, - &data.task, - |f| &f.id, - ), - TaskState::Solved => delete_task_from_collection( - &mut api.solved_tasks.tasks, - &id, - &data.id, - "Solved", - &data.namespace, - &data.task, - |f| &f.id, - ), - TaskState::Queued => delete_task_from_collection( - &mut api.task_queue.tasks, - &id, - &data.id, - "Queued", - &data.namespace, - &data.task, - |f| &f.id, - ), - }; - - if let Err(e) = result { - return Err(e); - } - - // Sync running memory into DB - EngineAPI::sync_db(&mut api); - info!( - "DeleteTask: 
Successfully deleted task with id {} in state {:?} for namespace: {}, task: {}", - data.id, - data.state(), - data.namespace, - data.task - ); - Ok(tonic::Response::new(proto::Empty {})) - } - /// Retrieves a paginated list of tasks filtered by namespace, task name, and state. - /// - /// Authenticates the request and, if authorized, returns tasks in the specified state - /// (`Processing`, `Queued`, or `Solved`) for the given namespace and task name. The results - /// are sorted by task ID and paginated according to the requested page and page size. - /// - /// Returns a `TaskPage` containing the filtered tasks and pagination metadata, or a - /// permission denied error if authentication fails. - /// - /// # Examples - /// - /// ``` - /// // Example usage within a tonic gRPC client context: - /// let request = proto::TaskPageRequest { - /// namespace: "example_ns".to_string(), - /// task: "example_task".to_string(), - /// state: proto::TaskState::Queued as i32, - /// page: 0, - /// page_size: 10, - /// }; - /// let response = engine_client.get_tasks(request).await?; - /// assert!(response.get_ref().tasks.len() <= 10); - /// ``` - async fn get_tasks( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status> { - let mut api = self.EngineAPI.write().await; - let challenge = get_auth(&request); - - let db = api.db.clone(); - if !Events::CheckAdminAuth(&mut api, challenge, ("".into(), "".into()), db) { - info!("GetTask denied due to Invalid Auth"); - return Err(Status::permission_denied("Invalid authentication")); - }; - let data = request.get_ref(); - - let q: Vec = match data.clone().state() { - TaskState::Processing => { - match api - .executing_tasks - .tasks - .get(&(data.namespace.clone(), data.task.clone())) - { - Some(tasks) => { - let mut task_refs: Vec<_> = tasks.iter().collect(); - task_refs.sort_by_key(|f| &f.id); - task_refs - .iter() - .map(|f| proto::Task { - id: f.id.clone(), - task_id: format!("{}:{}", data.namespace, data.task), - 
task_payload: f.bytes.clone(), - payload: Vec::new(), - }) - .collect() - } - None => { - info!( - "Namespace {:?} and task {:?} not found in Processing state", - data.namespace, data.task - ); - Vec::new() - } - } - } - TaskState::Queued => { - match api - .task_queue - .tasks - .get(&(data.namespace.clone(), data.task.clone())) - { - Some(tasks) => { - let mut d = tasks.clone(); - d.sort_by_key(|f| f.id.clone()); - d.iter() - .map(|f| proto::Task { - id: f.id.clone(), - task_id: format!("{}:{}", data.namespace, data.task), - task_payload: f.bytes.clone(), - payload: Vec::new(), - }) - .collect() - } - None => { - info!( - "Namespace {:?} and task {:?} not found in Queued state", - data.namespace, data.task - ); - Vec::new() - } - } - } - TaskState::Solved => { - match api - .solved_tasks - .tasks - .get(&(data.namespace.clone(), data.task.clone())) - { - Some(tasks) => { - let mut d = tasks.clone(); - d.sort_by_key(|f| f.id.clone()); - d.iter() - .map(|f| proto::Task { - id: f.id.clone(), - task_id: format!("{}:{}", data.namespace, data.task), - task_payload: f.bytes.clone(), - payload: Vec::new(), - }) - .collect() - } - None => { - info!( - "Namespace {:?} and task {:?} not found in Solved state", - data.namespace, data.task - ); - Vec::new() - } - } - } - }; - let index = data.page * data.page_size as u64; - let end = index + (api.cfg.config_toml.pagination_limit.min(data.page_size) as u64); - let final_vec: Vec<_> = q - .iter() - .skip(index as usize) - .take(data.page_size as usize) - .cloned() - .collect(); - return Ok(tonic::Response::new(proto::TaskPage { - namespace: data.namespace.clone(), - task: data.task.clone(), - page: data.page, - page_size: data.page_size, - state: data.state, - tasks: final_vec, - })); - } - /// Handles custom gRPC messages with admin-level authentication. - /// - /// Processes a CGRPC request by verifying admin credentials and dispatching the event payload to the appropriate handler. 
Returns the processed event payload in the response. If authentication fails, returns a permission denied error. - /// - /// # Returns - /// A `Cgrpcmsg` response containing the processed event payload, or a permission denied gRPC status on failed authentication. - /// - /// # Examples - /// - /// ``` - /// // Example usage within a gRPC client context: - /// let request = proto::Cgrpcmsg { - /// handler_mod_id: "mod".to_string(), - /// handler_id: "handler".to_string(), - /// event_payload: vec![1, 2, 3], - /// // ... other fields ... - /// }; - /// let response = engine_service.cgrpc(tonic::Request::new(request)).await?; - /// assert_eq!(response.get_ref().handler_mod_id, "mod"); - /// ``` - async fn cgrpc( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status> { - info!( - "CGRPC request received for handler: {}:{}", - request.get_ref().handler_mod_id, - request.get_ref().handler_id - ); - let mut api = self.EngineAPI.write().await; - let challenge = get_auth(&request); - let db = api.db.clone(); - debug!("Checking admin authentication for CGRPC request"); - let output = Events::CheckAdminAuth( - &mut api, - challenge, - ( - request.get_ref().handler_mod_id.clone(), - request.get_ref().handler_id.clone(), - ), - db, - ); - if !output { - warn!("CGRPC auth check failed - permission denied"); - return Err(tonic::Status::permission_denied("Invalid CGRPC Auth")); - }; - let out = Arc::new(std::sync::RwLock::new(Vec::new())); - debug!("Dispatching CGRPC event to handler"); - Events::CgrpcEvent( - &mut api, - ID("engine_core", "grpc"), - request.get_ref().event_payload.clone(), - out.clone(), - ); - let mut res = request.get_ref().clone(); - res.event_payload = match out.read() { - Ok(g) => g.clone(), - Err(_) => { - warn!("CGRPC response lock poisoned, returning empty payload"); - Vec::new() - } - }; - info!("CGRPC request processed successfully"); - return Ok(tonic::Response::new(res)); - } - async fn aquire_task_reg( - &self, - request: 
tonic::Request, - ) -> Result, tonic::Status> { - let uid = get_uid(&request); - let challenge = get_auth(&request); - info!("Task registry request received from user: {}", uid); - let mut api = self.EngineAPI.write().await; - let db = api.db.clone(); - - debug!("Validating authentication for task registry request"); - if !Events::CheckAuth(&mut api, uid.clone(), challenge, db) { - info!( - "Task registry request denied - invalid authentication for user: {}", - uid - ); - return Err(Status::permission_denied("Invalid authentication")); - }; - let mut tasks: Vec = Vec::new(); - for (k, v) in &api.task_registry.tasks { - let js: Vec = vec![k.0.clone(), k.1.clone()]; - let jstr = js.join(":"); - tasks.push(jstr); - } - info!("Returning task registry with {} tasks", tasks.len()); - let response = proto::TaskRegistry { tasks }; - Ok(tonic::Response::new(response)) - } - - async fn aquire_task( - &self, - request: tonic::Request, - ) -> Result, tonic::Status> { - let challenge = get_auth(&request); - let input = request.get_ref(); - let task_id = input.task_id.clone(); - let uid = get_uid(&request); - info!( - "Task acquisition request received from user: {} for task: {}", - uid, task_id - ); - - let mut api = self.EngineAPI.write().await; - let db = api.db.clone(); - debug!("Validating authentication for task acquisition"); - if !Events::CheckAuth(&mut api, uid.clone(), challenge, db) { - info!( - "Task acquisition denied - invalid authentication for user: {}", - uid - ); - return Err(Status::permission_denied("Invalid authentication")); - }; - - // Todo: check for wrong input to not cause a Panic out of bounds. 
- let alen = &task_id.split(":").collect::>().len(); - if *alen != 2 { - info!("Invalid task ID format: {}", task_id); - return Err(Status::invalid_argument( - "Invalid task ID format, expected 'namespace:name", - )); - } - let namespace = &task_id.split(":").collect::>()[0]; - let task_name = &task_id.split(":").collect::>()[1]; - debug!("Looking up task definition for {}:{}", namespace, task_name); - let tsx = api - .task_registry - .get(&(namespace.to_string(), task_name.to_string())); - if tsx.is_none() { - warn!( - "Task acquisition failed - task does not exist: {}:{}", - namespace, task_name - ); - return Err(Status::invalid_argument("Task Does not Exist")); - } - let key = ID(namespace, task_name); - let mut map = match api.task_queue.tasks.get(&key) { - Some(v) if !v.is_empty() => v.clone(), - _ => { - info!("No queued tasks for {}:{}", namespace, task_name); - return Err(Status::not_found("No queued tasks available")); - } - }; - let ttask = map.remove(0); - let task_payload = ttask.bytes.clone(); - // Get Task and remove it from queue - api.task_queue.tasks.insert(key.clone(), map); - match postcard::to_allocvec(&api.task_queue.clone()) { - Ok(store) => { - if let Err(e) = api.db.insert("tasks", store) { - return Err(Status::internal(format!("DB insert error: {}", e))); - } - } - Err(e) => { - return Err(Status::internal(format!("Serialization error: {}", e))); - } - } - // Move it to exec queue - let mut exec_tsks = api - .executing_tasks - .tasks - .get(&key) - .cloned() - .unwrap_or_default(); - exec_tsks.push(enginelib::task::StoredExecutingTask { - bytes: task_payload.clone(), - user_id: uid.clone(), - given_at: Utc::now(), - id: ttask.id.clone(), - }); - api.executing_tasks.tasks.insert(key.clone(), exec_tsks); - match postcard::to_allocvec(&api.executing_tasks.clone()) { - Ok(store) => { - if let Err(e) = api.db.insert("executing_tasks", store) { - return Err(Status::internal(format!("DB insert error: {}", e))); - } - } - Err(e) => { - return 
Err(Status::internal(format!("Serialization error: {}", e))); - } - } - let response = proto::Task { - id: ttask.id, - task_id: input.task_id.clone(), - task_payload, - payload: Vec::new(), - }; - Ok(tonic::Response::new(response)) - } - async fn publish_task( - &self, - request: tonic::Request, - ) -> Result, tonic::Status> { - let mut api = self.EngineAPI.write().await; - let challenge = get_auth(&request); - let uid = get_uid(&request); - let db = api.db.clone(); - - let task_id = request.get_ref().task_id.clone(); - let alen = &task_id.split(":").collect::>().len(); - if *alen != 2 { - return Err(Status::invalid_argument("Invalid Params")); - } - let namespace = &task_id.split(":").collect::>()[0]; - let task_name = &task_id.split(":").collect::>()[1]; - - if !Events::CheckAuth(&mut api, uid.clone(), challenge, db) { - info!("Aquire Task denied due to Invalid Auth"); - return Err(Status::permission_denied("Invalid authentication")); - }; - if !api - .task_registry - .tasks - .contains_key(&ID(namespace, task_name)) - { - warn!( - "Task acquisition failed - task does not exist: {}:{}", - namespace, task_name - ); - return Err(Status::invalid_argument("Task Does not Exist")); - } - let key = ID(namespace, task_name); - let mem_tsk = api - .executing_tasks - .tasks - .get(&key) - .cloned() - .unwrap_or_default(); - let tsk_opt = mem_tsk - .iter() - .find(|f| f.id == task_id.clone() && f.user_id == uid.clone()); - if let Some(tsk) = tsk_opt { - let reg_tsk = match api.task_registry.get(&key) { - Some(r) => r.clone(), - None => { - warn!("Task registry missing for {}:{}", namespace, task_name); - return Err(Status::invalid_argument("Task Does not Exist")); - } - }; - if !reg_tsk.verify(request.get_ref().task_payload.clone()) { - info!("Failed to parse task"); - return Err(Status::invalid_argument("Failed to parse given task bytes")); - } - // Exec Tasks -> DB - let mut nmem_tsk = mem_tsk.clone(); - nmem_tsk.retain(|f| f.id != task_id.clone() && f.user_id != 
uid.clone()); - api.executing_tasks - .tasks - .insert(key.clone(), nmem_tsk.clone()); - let t_mem_execs = api.executing_tasks.clone(); - match postcard::to_allocvec(&t_mem_execs) { - Ok(store) => { - if let Err(e) = api.db.insert("executing_tasks", store) { - return Err(Status::internal(format!("DB insert error: {}", e))); - } - } - Err(e) => return Err(Status::internal(format!("Serialization error: {}", e))), - } - // tsk-> solved Tsks - let mut mem_solv = api - .solved_tasks - .tasks - .get(&key) - .cloned() - .unwrap_or_default(); - mem_solv.push(enginelib::task::StoredTask { - bytes: tsk.bytes.clone(), - id: tsk.id.clone(), - }); - api.solved_tasks.tasks.insert(key.clone(), mem_solv); - // Solved tsks -> DB - match postcard::to_allocvec(&api.solved_tasks.tasks) { - Ok(e_solv) => { - if let Err(e) = api.db.insert("solved_tasks", e_solv) { - return Err(Status::internal(format!("DB insert error: {}", e))); - } - } - Err(e) => return Err(Status::internal(format!("Serialization error: {}", e))), - } - info!("Task published successfully: {} by user: {}", task_id, uid); - return Ok(tonic::Response::new(proto::Empty {})); - } else { - return Err(tonic::Status::not_found("Invalid taskid or userid")); - } - } - async fn create_task( - &self, - request: tonic::Request, - ) -> Result, tonic::Status> { - let mut api = self.EngineAPI.write().await; - let challenge = get_auth(&request); - let uid = get_uid(&request); - let db = api.db.clone(); - if !Events::CheckAuth(&mut api, uid, challenge, db) { - //TODO: change to AdminSpecific Auth - info!("Create Task denied due to Invalid Auth"); - return Err(Status::permission_denied("Invalid authentication")); - }; - let task = request.get_ref(); - let task_id = task.task_id.clone(); - let parts: Vec<&str> = task_id.splitn(2, ':').collect(); - if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() { - return Err(Status::invalid_argument( - "Invalid task ID format, expected 'namespace:task'", - )); - } - let id: Identifier 
= (parts[0].to_string(), parts[1].to_string()); - let tsk_reg = api.task_registry.get(&id); - if let Some(tsk_reg) = tsk_reg { - if !tsk_reg.clone().verify(task.task_payload.clone()) { - warn!("Failed to parse given task bytes"); - return Err(Status::invalid_argument("Failed to parse given task bytes")); - } - let tbp_tsk = StoredTask { - bytes: task.task_payload.clone(), - id: druid::Druid::default().to_hex(), - }; - let mut mem_tsks = api.task_queue.clone(); - let mut mem_tsk = mem_tsks.tasks.get(&id).cloned().unwrap_or_default(); - mem_tsk.push(tbp_tsk.clone()); - mem_tsks.tasks.insert(id.clone(), mem_tsk); - api.task_queue = mem_tsks; - match postcard::to_allocvec(&api.task_queue.clone()) { - Ok(store) => { - if let Err(e) = api.db.insert("tasks", store) { - return Err(Status::internal(format!("DB insert error: {}", e))); - } - } - Err(e) => return Err(Status::internal(format!("Serialization error: {}", e))), - } - return Ok(tonic::Response::new(proto::Task { - id: tbp_tsk.id.clone(), - task_id: task_id.clone(), - payload: Vec::new(), - task_payload: tbp_tsk.bytes.clone(), - })); - } - Err(tonic::Status::aborted("Error")) - } -} +use tokio::{net::TcpListener, sync::RwLock}; +use tokio_stream::wrappers::TcpListenerStream; +use tonic::transport::Server; #[tokio::main] async fn main() -> Result<(), Box> { @@ -640,6 +18,7 @@ async fn main() -> Result<(), Box> { EngineAPI::init(&mut api); Events::init_auth(&mut api); Events::StartEvent(&mut api); + let addr = api .cfg .config_toml @@ -649,21 +28,27 @@ async fn main() -> Result<(), Box> { Ipv4Addr::new(127, 0, 0, 1), 50051, ))); - let apii = Arc::new(RwLock::new(api)); - EngineAPI::init_chron(apii.clone()); - let engine = EngineService { EngineAPI: apii }; + let listener = TcpListener::bind(addr).await?; + + let api = Arc::new(RwLock::new(api)); + let registration = { + let api_guard = api.read().await; + registration_from_api(&api_guard)? 
+ }; + let _registration_task = registration.map(spawn_registration); + + EngineAPI::init_chron(api.clone()); + let engine = BackendEngineService::new(api); - // Build reflection service, mapping its concrete error into Box let reflection_service = tonic_reflection::server::Builder::configure() .register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET) .build_v1() .map_err(|e| Box::new(e) as Box)?; - // Start server and map transport errors into Box so `?` works with our return type. Server::builder() .add_service(reflection_service) - .add_service(EngineServer::new(engine)) - .serve(addr) + .add_service(proto::engine_server::EngineServer::new(engine)) + .serve_with_incoming(TcpListenerStream::new(listener)) .await .map_err(|e| Box::new(e) as Box)?; diff --git a/engine/src/cluster_client.rs b/engine/src/cluster_client.rs new file mode 100644 index 0000000..146bf20 --- /dev/null +++ b/engine/src/cluster_client.rs @@ -0,0 +1,179 @@ +use std::time::Duration; + +use enginelib::api::EngineAPI; +use tokio::{task::JoinHandle, time::sleep}; +use tonic::{Request, transport::Endpoint}; +use tracing::{error, info, warn}; + +use crate::proto::{self, cluster_client::ClusterClient}; + +#[derive(Debug, Clone)] +pub struct NodeRegistration { + pub node_id: String, + pub advertise_addr: String, + pub cluster_proxy_addr: String, + pub cluster_token: String, + pub node_tags: Vec, + pub tasks: Vec, +} + +pub fn registration_from_api(api: &EngineAPI) -> Result, String> { + let cfg = &api.cfg.config_toml; + let Some(cluster_proxy_addr) = cfg.cluster_proxy_addr.clone() else { + return Ok(None); + }; + + let node_id = cfg + .node_id + .clone() + .filter(|value| !value.trim().is_empty()) + .ok_or_else(|| "clustered backend requires config.server.node_id".to_string())?; + if node_id.contains('@') { + return Err("clustered backend node_id must not contain '@'".into()); + } + + let cluster_token = cfg + .cluster_token + .clone() + .filter(|value| !value.trim().is_empty()) + 
.ok_or_else(|| "clustered backend requires config.server.cluster_token".to_string())?; + + let advertise_addr = + normalize_advertise_addr(cfg.advertise_addr.as_deref(), cfg.host.as_str())?; + + validate_endpoint(&cluster_proxy_addr, "cluster_proxy_addr")?; + validate_endpoint(&advertise_addr, "advertise_addr")?; + + let mut tasks: Vec = api + .task_registry + .tasks + .keys() + .map(|(namespace, task)| format!("{namespace}:{task}")) + .collect(); + tasks.sort(); + tasks.dedup(); + + let mut node_tags = cfg.node_tags.clone().unwrap_or_default(); + node_tags.sort(); + node_tags.dedup(); + + Ok(Some(NodeRegistration { + node_id, + advertise_addr, + cluster_proxy_addr, + cluster_token, + node_tags, + tasks, + })) +} + +pub fn normalize_advertise_addr( + advertise_addr: Option<&str>, + host: &str, +) -> Result { + let value = advertise_addr.unwrap_or(host).trim(); + if value.is_empty() { + return Err("advertise address must not be empty".into()); + } + + if value.starts_with("http://") || value.starts_with("https://") { + Ok(value.to_string()) + } else { + Ok(format!("http://{value}")) + } +} + +pub fn spawn_registration(registration: NodeRegistration) -> JoinHandle<()> { + tokio::spawn(async move { + loop { + match ClusterClient::connect(registration.cluster_proxy_addr.clone()).await { + Ok(mut client) => { + let register_response = match register_node(&mut client, ®istration).await { + Ok(response) => response, + Err(err) => { + warn!( + "cluster registration failed for {}: {}", + registration.node_id, err + ); + sleep(Duration::from_secs(1)).await; + continue; + } + }; + + info!( + "registered backend node {} with proxy {}", + registration.node_id, registration.cluster_proxy_addr + ); + + let session_id = register_response.session_id; + let interval_seconds = register_response.heartbeat_interval_seconds.max(1); + + loop { + sleep(Duration::from_secs(interval_seconds)).await; + match heartbeat(&mut client, ®istration, &session_id).await { + Ok(_) => {} + Err(err) => 
{ + warn!( + "cluster heartbeat failed for {}: {}", + registration.node_id, err + ); + break; + } + } + } + } + Err(err) => { + error!( + "failed to connect backend node {} to cluster proxy {}: {}", + registration.node_id, registration.cluster_proxy_addr, err + ); + sleep(Duration::from_secs(1)).await; + } + } + } + }) +} + +async fn register_node( + client: &mut ClusterClient, + registration: &NodeRegistration, +) -> Result { + let mut request = Request::new(proto::NodeRegisterRequest { + node_id: registration.node_id.clone(), + advertise_addr: registration.advertise_addr.clone(), + tags: registration.node_tags.clone(), + tasks: registration.tasks.clone(), + }); + request.metadata_mut().insert( + "authorization", + format!("Bearer {}", registration.cluster_token) + .parse() + .map_err(|_| tonic::Status::internal("failed to encode cluster authorization"))?, + ); + Ok(client.register_node(request).await?.into_inner()) +} + +async fn heartbeat( + client: &mut ClusterClient, + registration: &NodeRegistration, + session_id: &str, +) -> Result<(), tonic::Status> { + let mut request = Request::new(proto::NodeHeartbeat { + node_id: registration.node_id.clone(), + session_id: session_id.to_string(), + }); + request.metadata_mut().insert( + "authorization", + format!("Bearer {}", registration.cluster_token) + .parse() + .map_err(|_| tonic::Status::internal("failed to encode cluster authorization"))?, + ); + client.heartbeat(request).await?; + Ok(()) +} + +fn validate_endpoint(value: &str, field_name: &str) -> Result<(), String> { + Endpoint::from_shared(value.to_string()) + .map(|_| ()) + .map_err(|err| format!("invalid {field_name} '{value}': {err}")) +} diff --git a/engine/src/lib.rs b/engine/src/lib.rs index 5b529f0..9d09966 100644 --- a/engine/src/lib.rs +++ b/engine/src/lib.rs @@ -1,4 +1,13 @@ -use tonic::Request; +use tonic::{ + Request, + metadata::{KeyAndValueRef, MetadataMap}, +}; + +pub mod cluster_client; +pub mod proto; +pub mod proxy_config; +pub mod routing; 
+pub mod service; pub fn get_uid(req: &Request) -> String { req.metadata() @@ -8,6 +17,19 @@ pub fn get_uid(req: &Request) -> String { .unwrap_or_default() } +pub fn copy_metadata(source: &MetadataMap, target: &mut MetadataMap) { + for entry in source.iter() { + match entry { + KeyAndValueRef::Ascii(key, value) => { + target.insert(key, value.clone()); + } + KeyAndValueRef::Binary(key, value) => { + target.insert_bin(key, value.clone()); + } + } + } +} + pub fn get_auth(req: &Request) -> String { req.metadata() .get("authorization") diff --git a/engine/src/proto.rs b/engine/src/proto.rs new file mode 100644 index 0000000..86acc1e --- /dev/null +++ b/engine/src/proto.rs @@ -0,0 +1,4 @@ +// Shared gRPC/protobuf bindings for the engine crate (bins + tests). +tonic::include_proto!("engine"); + +pub const FILE_DESCRIPTOR_SET: &[u8] = tonic::include_file_descriptor_set!("engine_descriptor"); diff --git a/engine/src/proxy_config.rs b/engine/src/proxy_config.rs new file mode 100644 index 0000000..0415760 --- /dev/null +++ b/engine/src/proxy_config.rs @@ -0,0 +1,101 @@ +use std::fs; + +use serde::Deserialize; + +fn default_listen() -> String { + "0.0.0.0:50052".into() +} + +fn default_node_ttl_seconds() -> u64 { + 15 +} + +fn default_heartbeat_interval_seconds() -> u64 { + 5 +} + +fn default_max_acquire_hops() -> u32 { + 3 +} + +fn default_admin_fanout_limit() -> u32 { + 5000 +} + +#[derive(Debug, Clone, Deserialize, Default)] +pub struct RouteRuleToml { + pub r#match: String, + #[serde(default)] + pub require_tags: Vec, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ProxyConfigToml { + #[serde(default = "default_listen")] + pub listen: String, + pub cluster_token: String, + #[serde(default = "default_node_ttl_seconds")] + pub node_ttl_seconds: u64, + #[serde(default = "default_heartbeat_interval_seconds")] + pub heartbeat_interval_seconds: u64, + #[serde(default = "default_max_acquire_hops")] + pub max_acquire_hops: u32, + #[serde(default = 
"default_admin_fanout_limit")] + pub admin_fanout_limit: u32, + #[serde(default)] + pub rules: Vec, +} + +impl Default for ProxyConfigToml { + fn default() -> Self { + Self { + listen: default_listen(), + cluster_token: String::new(), + node_ttl_seconds: default_node_ttl_seconds(), + heartbeat_interval_seconds: default_heartbeat_interval_seconds(), + max_acquire_hops: default_max_acquire_hops(), + admin_fanout_limit: default_admin_fanout_limit(), + rules: Vec::new(), + } + } +} + +impl ProxyConfigToml { + pub fn load() -> Result { + let content = match fs::read_to_string("proxy.toml") { + Ok(content) => content, + Err(err) if err.kind() == ErrorKind::NotFound => String::new(), + Err(err) => return Err(format!("Failed to read proxy.toml: {err}")), + }; + + let config = if content.trim().is_empty() { + Self::default() + } else { + toml::from_str(&content).map_err(|err| format!("Failed to parse proxy.toml: {err}"))? + }; + + config.validate()?; + Ok(config) + } + + pub fn validate(&self) -> Result<(), String> { + if self.cluster_token.trim().is_empty() { + return Err("proxy cluster_token must not be empty".into()); + } + if self.node_ttl_seconds == 0 { + return Err("proxy node_ttl_seconds must be greater than zero".into()); + } + if self.heartbeat_interval_seconds == 0 { + return Err("proxy heartbeat_interval_seconds must be greater than zero".into()); + } + if self.max_acquire_hops == 0 { + return Err("proxy max_acquire_hops must be greater than zero".into()); + } + if self.admin_fanout_limit == 0 { + return Err("proxy admin_fanout_limit must be greater than zero".into()); + } + Ok(()) + } +} + +use std::io::ErrorKind; diff --git a/engine/src/routing.rs b/engine/src/routing.rs new file mode 100644 index 0000000..4100071 --- /dev/null +++ b/engine/src/routing.rs @@ -0,0 +1,214 @@ +use std::{ + collections::{BTreeSet, HashMap}, + sync::{ + Arc, + atomic::{AtomicU64, Ordering}, + }, + time::{SystemTime, UNIX_EPOCH}, +}; + +use tokio::sync::RwLock; +use 
tonic::transport::{Channel, Endpoint}; + +use crate::proxy_config::{ProxyConfigToml, RouteRuleToml}; + +#[derive(Debug, Clone)] +pub struct RouteRequirement { + pub require_tags: Vec, +} + +#[derive(Debug, Clone)] +pub struct RouteRule { + pub namespace_pattern: String, + pub task_pattern: String, + pub requirement: RouteRequirement, +} + +impl RouteRule { + pub fn from_toml(rule: RouteRuleToml) -> Result { + let Some((namespace_pattern, task_pattern)) = rule.r#match.split_once(':') else { + return Err(format!( + "invalid proxy rule match '{}', expected namespace:task", + rule.r#match + )); + }; + + Ok(Self { + namespace_pattern: namespace_pattern.to_string(), + task_pattern: task_pattern.to_string(), + requirement: RouteRequirement { + require_tags: rule.require_tags, + }, + }) + } + + pub fn matches_task(&self, task_key: &str) -> bool { + let Some((namespace, task)) = task_key.split_once(':') else { + return false; + }; + + matches_pattern(&self.namespace_pattern, namespace) + && matches_pattern(&self.task_pattern, task) + } +} + +fn matches_pattern(pattern: &str, value: &str) -> bool { + pattern == "*" || pattern == value +} + +#[derive(Debug)] +pub struct NodeState { + pub node_id: String, + pub advertise_addr: String, + pub tags: Vec, + pub tasks: Vec, + pub session_id: String, + pub channel: Channel, + pub last_seen_unix: AtomicU64, + pub in_flight_create: AtomicU64, + tag_set: BTreeSet, + task_set: BTreeSet, +} + +impl NodeState { + pub fn new( + node_id: String, + advertise_addr: String, + tags: Vec, + tasks: Vec, + session_id: String, + now_unix: u64, + ) -> Result { + let channel = Endpoint::from_shared(advertise_addr.clone()) + .map_err(|err| format!("invalid advertise_addr '{advertise_addr}': {err}"))? 
+ .connect_lazy(); + + Ok(Self { + node_id, + advertise_addr, + tag_set: tags.iter().cloned().collect(), + task_set: tasks.iter().cloned().collect(), + tags, + tasks, + session_id, + channel, + last_seen_unix: AtomicU64::new(now_unix), + in_flight_create: AtomicU64::new(0), + }) + } + + pub fn has_task(&self, task_key: &str) -> bool { + self.task_set.contains(task_key) + } + + pub fn has_tags(&self, tags: &[String]) -> bool { + tags.iter().all(|tag| self.tag_set.contains(tag)) + } + + pub fn last_seen_unix(&self) -> u64 { + self.last_seen_unix.load(Ordering::Relaxed) + } +} + +#[derive(Debug)] +pub struct ProxyState { + pub config: ProxyConfigToml, + pub rules: Vec, + pub nodes: RwLock>>, +} + +impl ProxyState { + pub fn new(config: ProxyConfigToml) -> Result { + let mut rules = if config.rules.is_empty() { + vec![RouteRule::from_toml(RouteRuleToml { + r#match: "*:*".into(), + require_tags: Vec::new(), + })?] + } else { + let mut rules = Vec::with_capacity(config.rules.len()); + for rule in config.rules.iter().cloned() { + rules.push(RouteRule::from_toml(rule)?); + } + rules + }; + + if rules.is_empty() { + rules.push(RouteRule::from_toml(RouteRuleToml { + r#match: "*:*".into(), + require_tags: Vec::new(), + })?); + } + + Ok(Self { + config, + rules, + nodes: RwLock::new(HashMap::new()), + }) + } + + pub async fn upsert_node(&self, node: NodeState) -> Arc { + let node = Arc::new(node); + self.nodes + .write() + .await + .insert(node.node_id.clone(), node.clone()); + node + } + + pub async fn healthy_nodes(&self) -> Vec> { + let mut nodes: Vec<_> = self.nodes.read().await.values().cloned().collect(); + nodes.sort_by(|left, right| left.node_id.cmp(&right.node_id)); + nodes + } + + pub async fn candidate_nodes_for_task(&self, task_key: &str) -> Vec> { + let required_tags = self + .rules + .iter() + .find(|rule| rule.matches_task(task_key)) + .map(|rule| rule.requirement.require_tags.clone()) + .unwrap_or_default(); + + self.healthy_nodes() + .await + .into_iter() + 
.filter(|node| node.has_task(task_key) && node.has_tags(&required_tags)) + .collect() + } + + pub async fn preferred_nodes_by_tag(&self, tag: &str) -> Vec> { + let nodes = self.healthy_nodes().await; + let mut preferred = Vec::new(); + let mut fallback = Vec::new(); + + for node in nodes { + if node.has_tags(&[tag.to_string()]) { + preferred.push(node); + } else { + fallback.push(node); + } + } + + preferred.extend(fallback); + preferred + } + + pub async fn node_by_id(&self, node_id: &str) -> Option> { + self.nodes.read().await.get(node_id).cloned() + } + + pub async fn reap_stale_nodes(&self, now_unix: u64) { + let ttl_seconds = self.config.node_ttl_seconds; + self.nodes + .write() + .await + .retain(|_, node| now_unix.saturating_sub(node.last_seen_unix()) <= ttl_seconds); + } +} + +pub fn now_unix() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} diff --git a/engine/src/service/backend.rs b/engine/src/service/backend.rs new file mode 100644 index 0000000..ba139dc --- /dev/null +++ b/engine/src/service/backend.rs @@ -0,0 +1,567 @@ +use std::{ + collections::HashMap, + sync::{Arc, RwLock as StdRwLock}, +}; + +use enginelib::api::postcard; +use enginelib::{ + Identifier, RawIdentier, Registry, + api::EngineAPI, + chrono::Utc, + events::{Events, ID}, + task::{StoredExecutingTask, StoredTask}, +}; +use tokio::sync::RwLock; +use tonic::{Response, Status}; +use tracing::{debug, info, warn}; + +use crate::{ + get_auth, get_uid, + proto::{self, TaskState, engine_server::Engine}, +}; + +#[allow(non_snake_case)] +pub struct BackendEngineService { + pub EngineAPI: Arc>, +} + +impl BackendEngineService { + pub fn new(api: Arc>) -> Self { + Self { EngineAPI: api } + } +} + +pub fn mint_task_instance_id(node_id: Option<&str>) -> String { + let random_hex = druid::Druid::default().to_hex(); + match node_id.filter(|value| !value.is_empty()) { + Some(node_id) => format!("{node_id}@{random_hex}"), + None => random_hex, + } +} + 
/// Split a task instance id of the form `node@local` into its owning node id
/// and the node-local id. Returns `None` unless there is exactly one `@`
/// separator with non-empty text on both sides (ids minted without a
/// configured node id carry no `@` and are not cluster-routable).
pub fn parse_owner_node(task_instance_id: &str) -> Option<(&str, &str)> {
    let mut parts = task_instance_id.splitn(2, '@');
    let node_id = parts.next()?;
    let local_id = parts.next()?;
    if node_id.is_empty() || local_id.is_empty() || local_id.contains('@') {
        None
    } else {
        Some((node_id, local_id))
    }
}
denied"); + return Err(Status::permission_denied("Invalid Auth")); + } + Ok(Response::new(proto::Empty {})) + } + + async fn delete_task( + &self, + request: tonic::Request, + ) -> Result, Status> { + let mut api = self.EngineAPI.write().await; + let data = request.get_ref(); + let challenge = get_auth(&request); + let db = api.db.clone(); + let id = ID(&data.namespace, &data.task); + + let output = Events::CheckAdminAuth(&mut api, challenge, ("".into(), "".into()), db); + if !output { + warn!("Auth check failed - permission denied"); + return Err(Status::permission_denied("Invalid Auth")); + } + + let result = match data.state() { + TaskState::Processing => delete_task_from_collection( + &mut api.executing_tasks.tasks, + &id, + &data.id, + "Processing", + &data.namespace, + &data.task, + |f| &f.id, + ), + TaskState::Solved => delete_task_from_collection( + &mut api.solved_tasks.tasks, + &id, + &data.id, + "Solved", + &data.namespace, + &data.task, + |f| &f.id, + ), + TaskState::Queued => delete_task_from_collection( + &mut api.task_queue.tasks, + &id, + &data.id, + "Queued", + &data.namespace, + &data.task, + |f| &f.id, + ), + }; + + result?; + EngineAPI::sync_db(&mut api); + info!( + "DeleteTask: Successfully deleted task with id {} in state {:?} for namespace: {}, task: {}", + data.id, + data.state(), + data.namespace, + data.task + ); + Ok(Response::new(proto::Empty {})) + } + + async fn get_tasks( + &self, + request: tonic::Request, + ) -> Result, Status> { + let mut api = self.EngineAPI.write().await; + let challenge = get_auth(&request); + let db = api.db.clone(); + if !Events::CheckAdminAuth(&mut api, challenge, ("".into(), "".into()), db) { + info!("GetTask denied due to Invalid Auth"); + return Err(Status::permission_denied("Invalid authentication")); + } + + let data = request.get_ref(); + let q: Vec = match data.state() { + TaskState::Processing => match api + .executing_tasks + .tasks + .get(&(data.namespace.clone(), data.task.clone())) + { + 
Some(tasks) => { + let mut task_refs: Vec<_> = tasks.iter().collect(); + task_refs.sort_by_key(|f| &f.id); + task_refs + .iter() + .map(|f| proto::Task { + id: f.id.clone(), + task_id: format!("{}:{}", data.namespace, data.task), + task_payload: f.bytes.clone(), + payload: Vec::new(), + }) + .collect() + } + None => Vec::new(), + }, + TaskState::Queued => match api + .task_queue + .tasks + .get(&(data.namespace.clone(), data.task.clone())) + { + Some(tasks) => { + let mut tasks = tasks.clone(); + tasks.sort_by_key(|f| f.id.clone()); + tasks + .iter() + .map(|f| proto::Task { + id: f.id.clone(), + task_id: format!("{}:{}", data.namespace, data.task), + task_payload: f.bytes.clone(), + payload: Vec::new(), + }) + .collect() + } + None => Vec::new(), + }, + TaskState::Solved => match api + .solved_tasks + .tasks + .get(&(data.namespace.clone(), data.task.clone())) + { + Some(tasks) => { + let mut tasks = tasks.clone(); + tasks.sort_by_key(|f| f.id.clone()); + tasks + .iter() + .map(|f| proto::Task { + id: f.id.clone(), + task_id: format!("{}:{}", data.namespace, data.task), + task_payload: f.bytes.clone(), + payload: Vec::new(), + }) + .collect() + } + None => Vec::new(), + }, + }; + + let index = data.page.saturating_mul(data.page_size as u64) as usize; + let limit = api.cfg.config_toml.pagination_limit.min(data.page_size) as usize; + let final_vec: Vec<_> = q.iter().skip(index).take(limit).cloned().collect(); + Ok(Response::new(proto::TaskPage { + namespace: data.namespace.clone(), + task: data.task.clone(), + page: data.page, + page_size: data.page_size, + state: data.state, + tasks: final_vec, + })) + } + + async fn cgrpc( + &self, + request: tonic::Request, + ) -> Result, Status> { + info!( + "CGRPC request received for handler: {}:{}", + request.get_ref().handler_mod_id, + request.get_ref().handler_id + ); + let mut api = self.EngineAPI.write().await; + let challenge = get_auth(&request); + let db = api.db.clone(); + debug!("Checking admin authentication for 
CGRPC request"); + let output = Events::CheckAdminAuth( + &mut api, + challenge, + ( + request.get_ref().handler_mod_id.clone(), + request.get_ref().handler_id.clone(), + ), + db, + ); + if !output { + warn!("CGRPC auth check failed - permission denied"); + return Err(Status::permission_denied("Invalid CGRPC Auth")); + } + let out = Arc::new(StdRwLock::new(Vec::new())); + debug!("Dispatching CGRPC event to handler"); + Events::CgrpcEvent( + &mut api, + ID("engine_core", "grpc"), + request.get_ref().event_payload.clone(), + out.clone(), + ); + let mut res = request.get_ref().clone(); + res.event_payload = match out.read() { + Ok(g) => g.clone(), + Err(_) => { + warn!("CGRPC response lock poisoned, returning empty payload"); + Vec::new() + } + }; + info!("CGRPC request processed successfully"); + Ok(Response::new(res)) + } + + async fn aquire_task_reg( + &self, + request: tonic::Request, + ) -> Result, Status> { + let uid = get_uid(&request); + let challenge = get_auth(&request); + info!("Task registry request received from user: {}", uid); + let mut api = self.EngineAPI.write().await; + let db = api.db.clone(); + + debug!("Validating authentication for task registry request"); + if !Events::CheckAuth(&mut api, uid.clone(), challenge, db) { + info!( + "Task registry request denied - invalid authentication for user: {}", + uid + ); + return Err(Status::permission_denied("Invalid authentication")); + } + let mut tasks: Vec = api + .task_registry + .tasks + .keys() + .map(|(namespace, task)| format!("{namespace}:{task}")) + .collect(); + tasks.sort(); + info!("Returning task registry with {} tasks", tasks.len()); + Ok(Response::new(proto::TaskRegistry { tasks })) + } + + async fn aquire_task( + &self, + request: tonic::Request, + ) -> Result, Status> { + let challenge = get_auth(&request); + let input = request.get_ref(); + let task_id = input.task_id.clone(); + let uid = get_uid(&request); + info!( + "Task acquisition request received from user: {} for task: {}", + 
uid, task_id + ); + + let mut api = self.EngineAPI.write().await; + let db = api.db.clone(); + debug!("Validating authentication for task acquisition"); + if !Events::CheckAuth(&mut api, uid.clone(), challenge, db) { + info!( + "Task acquisition denied - invalid authentication for user: {}", + uid + ); + return Err(Status::permission_denied("Invalid authentication")); + } + + let key = parse_task_key(&task_id)?; + if api.task_registry.get(&key).is_none() { + warn!( + "Task acquisition failed - task does not exist: {}:{}", + key.0, key.1 + ); + return Err(Status::invalid_argument("Task Does not Exist")); + } + + let mut map = match api.task_queue.tasks.get(&key) { + Some(v) if !v.is_empty() => v.clone(), + _ => { + info!("No queued tasks for {}:{}", key.0, key.1); + return Err(Status::not_found("No queued tasks available")); + } + }; + + let ttask = map.remove(0); + let task_payload = ttask.bytes.clone(); + api.task_queue.tasks.insert(key.clone(), map); + match postcard::to_allocvec(&api.task_queue.clone()) { + Ok(store) => { + if let Err(e) = api.db.insert("tasks", store) { + return Err(Status::internal(format!("DB insert error: {}", e))); + } + } + Err(e) => { + return Err(Status::internal(format!("Serialization error: {}", e))); + } + } + + let mut exec_tsks = api + .executing_tasks + .tasks + .get(&key) + .cloned() + .unwrap_or_default(); + exec_tsks.push(StoredExecutingTask { + bytes: task_payload.clone(), + user_id: uid, + given_at: Utc::now(), + id: ttask.id.clone(), + }); + api.executing_tasks.tasks.insert(key.clone(), exec_tsks); + match postcard::to_allocvec(&api.executing_tasks.clone()) { + Ok(store) => { + if let Err(e) = api.db.insert("executing_tasks", store) { + return Err(Status::internal(format!("DB insert error: {}", e))); + } + } + Err(e) => { + return Err(Status::internal(format!("Serialization error: {}", e))); + } + } + + Ok(Response::new(proto::Task { + id: ttask.id, + task_id, + task_payload, + payload: Vec::new(), + })) + } + + async fn 
publish_task( + &self, + request: tonic::Request, + ) -> Result, Status> { + let mut api = self.EngineAPI.write().await; + let challenge = get_auth(&request); + let uid = get_uid(&request); + let db = api.db.clone(); + + let task = request.get_ref(); + let key = parse_task_key(&task.task_id)?; + let instance_id = task.id.clone(); + + if !Events::CheckAuth(&mut api, uid.clone(), challenge, db) { + info!("Aquire Task denied due to Invalid Auth"); + return Err(Status::permission_denied("Invalid authentication")); + } + if !api.task_registry.tasks.contains_key(&key) { + warn!( + "Task acquisition failed - task does not exist: {}:{}", + key.0, key.1 + ); + return Err(Status::invalid_argument("Task Does not Exist")); + } + + let mem_tsk = api + .executing_tasks + .tasks + .get(&key) + .cloned() + .unwrap_or_default(); + let executing_task = mem_tsk + .iter() + .find(|f| f.id == instance_id && f.user_id == uid) + .cloned(); + let Some(executing_task) = executing_task else { + return Err(Status::not_found("Invalid taskid or userid")); + }; + + let reg_tsk = match api.task_registry.get(&key) { + Some(r) => r, + None => { + warn!("Task registry missing for {}:{}", key.0, key.1); + return Err(Status::invalid_argument("Task Does not Exist")); + } + }; + if !reg_tsk.verify(task.task_payload.clone()) { + info!("Failed to parse task"); + return Err(Status::invalid_argument("Failed to parse given task bytes")); + } + + let mut nmem_tsk = mem_tsk.clone(); + nmem_tsk.retain(|f| !(f.id == instance_id && f.user_id == uid)); + api.executing_tasks.tasks.insert(key.clone(), nmem_tsk); + match postcard::to_allocvec(&api.executing_tasks.clone()) { + Ok(store) => { + if let Err(e) = api.db.insert("executing_tasks", store) { + return Err(Status::internal(format!("DB insert error: {}", e))); + } + } + Err(e) => return Err(Status::internal(format!("Serialization error: {}", e))), + } + + let mut mem_solv = api + .solved_tasks + .tasks + .get(&key) + .cloned() + .unwrap_or_default(); + 
mem_solv.push(StoredTask { + bytes: task.task_payload.clone(), + id: executing_task.id, + }); + api.solved_tasks.tasks.insert(key.clone(), mem_solv); + match postcard::to_allocvec(&api.solved_tasks.clone()) { + Ok(store) => { + if let Err(e) = api.db.insert("solved_tasks", store) { + return Err(Status::internal(format!("DB insert error: {}", e))); + } + } + Err(e) => return Err(Status::internal(format!("Serialization error: {}", e))), + } + + info!("Task published successfully: {} by user: {}", task.id, uid); + Ok(Response::new(proto::Empty {})) + } + + async fn create_task( + &self, + request: tonic::Request, + ) -> Result, Status> { + let mut api = self.EngineAPI.write().await; + let challenge = get_auth(&request); + let uid = get_uid(&request); + let db = api.db.clone(); + if !Events::CheckAuth(&mut api, uid, challenge, db) { + info!("Create Task denied due to Invalid Auth"); + return Err(Status::permission_denied("Invalid authentication")); + } + let task = request.get_ref(); + let task_id = task.task_id.clone(); + let id = parse_task_key(&task_id)?; + let tsk_reg = api.task_registry.get(&id); + if let Some(tsk_reg) = tsk_reg { + if !tsk_reg.verify(task.task_payload.clone()) { + warn!("Failed to parse given task bytes"); + return Err(Status::invalid_argument("Failed to parse given task bytes")); + } + let stored_task = StoredTask { + bytes: task.task_payload.clone(), + id: mint_task_instance_id(api.cfg.config_toml.node_id.as_deref()), + }; + let mut mem_task_queue = api.task_queue.clone(); + let mut mem_tasks = mem_task_queue.tasks.get(&id).cloned().unwrap_or_default(); + mem_tasks.push(stored_task.clone()); + mem_task_queue.tasks.insert(id.clone(), mem_tasks); + api.task_queue = mem_task_queue; + match postcard::to_allocvec(&api.task_queue.clone()) { + Ok(store) => { + if let Err(e) = api.db.insert("tasks", store) { + return Err(Status::internal(format!("DB insert error: {}", e))); + } + } + Err(e) => return Err(Status::internal(format!("Serialization error: 
{}", e))), + } + return Ok(Response::new(proto::Task { + id: stored_task.id.clone(), + task_id, + payload: Vec::new(), + task_payload: stored_task.bytes.clone(), + })); + } + Err(Status::aborted("Error")) + } +} diff --git a/engine/src/service/mod.rs b/engine/src/service/mod.rs new file mode 100644 index 0000000..cfea61a --- /dev/null +++ b/engine/src/service/mod.rs @@ -0,0 +1,2 @@ +pub mod backend; +pub mod proxy; diff --git a/engine/src/service/proxy.rs b/engine/src/service/proxy.rs new file mode 100644 index 0000000..9c12b5e --- /dev/null +++ b/engine/src/service/proxy.rs @@ -0,0 +1,500 @@ +use std::{ + cmp::min, + sync::{Arc, atomic::Ordering}, + time::Duration, +}; + +use crate::{ + copy_metadata, get_auth, + proto::{self, cluster_server::Cluster, engine_client::EngineClient, engine_server::Engine}, + routing::{NodeState, ProxyState, now_unix}, + service::backend::parse_owner_node, +}; +use rand::{seq::SliceRandom, thread_rng}; +use tokio::{task::JoinHandle, time::sleep}; +use tonic::{Code, Request, Response, Status}; + +pub struct ProxyService { + state: Arc, +} + +impl ProxyService { + pub fn new(state: Arc) -> Self { + Self { state } + } +} + +pub fn spawn_reaper(state: Arc) -> JoinHandle<()> { + tokio::spawn(async move { + loop { + sleep(Duration::from_secs(1)).await; + state.reap_stale_nodes(now_unix()).await; + } + }) +} + +struct InFlightCreateGuard { + node: Arc, +} + +impl InFlightCreateGuard { + fn new(node: Arc) -> Self { + node.in_flight_create.fetch_add(1, Ordering::SeqCst); + Self { node } + } +} + +impl Drop for InFlightCreateGuard { + fn drop(&mut self) { + self.node.in_flight_create.fetch_sub(1, Ordering::SeqCst); + } +} + +fn require_cluster_auth(request: &Request, cluster_token: &str) -> Result<(), Status> { + let expected = format!("Bearer {cluster_token}"); + if get_auth(request) != expected { + return Err(Status::permission_denied("Invalid cluster authorization")); + } + Ok(()) +} + +fn parse_task_key(task_id: &str) -> Result { + let 
Some((namespace, task)) = task_id.split_once(':') else { + return Err(Status::invalid_argument( + "Invalid task ID format, expected 'namespace:task'", + )); + }; + if namespace.is_empty() || task.is_empty() { + return Err(Status::invalid_argument( + "Invalid task ID format, expected 'namespace:task'", + )); + } + Ok(format!("{namespace}:{task}")) +} + +fn build_request(request: &Request, message: T) -> Request { + let mut outbound = Request::new(message); + copy_metadata(request.metadata(), outbound.metadata_mut()); + outbound +} + +async fn forward_create( + node: Arc, + request: &Request, +) -> Result, Status> { + let mut client = EngineClient::new(node.channel.clone()); + client + .create_task(build_request(request, request.get_ref().clone())) + .await +} + +async fn forward_acquire( + node: Arc, + request: &Request, +) -> Result, Status> { + let mut client = EngineClient::new(node.channel.clone()); + client + .aquire_task(build_request(request, request.get_ref().clone())) + .await +} + +async fn forward_publish( + node: Arc, + request: &Request, +) -> Result, Status> { + let mut client = EngineClient::new(node.channel.clone()); + client + .publish_task(build_request(request, request.get_ref().clone())) + .await +} + +async fn forward_delete( + node: Arc, + request: &Request, +) -> Result, Status> { + let mut client = EngineClient::new(node.channel.clone()); + client + .delete_task(build_request(request, request.get_ref().clone())) + .await +} + +async fn forward_get_tasks( + node: Arc, + request: &Request, + page: u64, + page_size: u32, +) -> Result { + let mut payload = request.get_ref().clone(); + payload.page = page; + payload.page_size = page_size; + let mut client = EngineClient::new(node.channel.clone()); + Ok(client + .get_tasks(build_request(request, payload)) + .await? 
+ .into_inner()) +} + +async fn forward_task_registry_probe( + node: Arc, + request: &Request, +) -> Result<(), Status> { + let mut client = EngineClient::new(node.channel.clone()); + client + .aquire_task_reg(build_request(request, request.get_ref().clone())) + .await?; + Ok(()) +} + +async fn forward_check_auth( + node: Arc, + request: &Request, +) -> Result, Status> { + let mut client = EngineClient::new(node.channel.clone()); + client + .check_auth(build_request(request, request.get_ref().clone())) + .await +} + +async fn forward_cgrpc( + node: Arc, + request: &Request, +) -> Result, Status> { + let mut client = EngineClient::new(node.channel.clone()); + client + .cgrpc(build_request(request, request.get_ref().clone())) + .await +} + +async fn broadcast_publish( + nodes: Vec>, + request: &Request, +) -> Result, Status> { + let mut last_not_found = None; + for node in nodes { + match forward_publish(node, request).await { + Ok(response) => return Ok(response), + Err(status) if status.code() == Code::NotFound => { + last_not_found = Some(status); + } + Err(status) => return Err(status), + } + } + + Err(last_not_found.unwrap_or_else(|| Status::not_found("Task not found"))) +} + +async fn broadcast_delete( + nodes: Vec>, + request: &Request, +) -> Result, Status> { + let mut last_not_found = None; + for node in nodes { + match forward_delete(node, request).await { + Ok(response) => return Ok(response), + Err(status) if status.code() == Code::NotFound => { + last_not_found = Some(status); + } + Err(status) => return Err(status), + } + } + + Err(last_not_found.unwrap_or_else(|| Status::not_found("Task not found"))) +} + +fn unavailable() -> Status { + Status::unavailable("No healthy backend nodes available") +} + +fn select_create_candidate(mut candidates: Vec>) -> Option> { + if candidates.is_empty() { + return None; + } + let mut rng = thread_rng(); + candidates.shuffle(&mut rng); + if candidates.len() == 1 { + return candidates.into_iter().next(); + } + + let left 
= candidates[0].clone(); + let right = candidates[1].clone(); + let left_load = left.in_flight_create.load(Ordering::SeqCst); + let right_load = right.in_flight_create.load(Ordering::SeqCst); + if left_load < right_load { + Some(left) + } else if right_load < left_load { + Some(right) + } else if left.node_id <= right.node_id { + Some(left) + } else { + Some(right) + } +} + +#[tonic::async_trait] +impl Engine for ProxyService { + async fn aquire_task( + &self, + request: Request, + ) -> Result, Status> { + let task_key = parse_task_key(&request.get_ref().task_id)?; + let mut candidates = self.state.candidate_nodes_for_task(&task_key).await; + if candidates.is_empty() { + return Err(unavailable()); + } + + candidates.shuffle(&mut thread_rng()); + let hops = min( + self.state.config.max_acquire_hops as usize, + candidates.len(), + ); + for node in candidates.into_iter().take(hops) { + match forward_acquire(node, &request).await { + Ok(response) => return Ok(response), + Err(status) if status.code() == Code::NotFound => continue, + Err(status) => return Err(status), + } + } + + Err(Status::not_found("No queued tasks available")) + } + + async fn aquire_task_reg( + &self, + request: Request, + ) -> Result, Status> { + let nodes = self.state.healthy_nodes().await; + let Some(probe_node) = nodes.first().cloned() else { + return Err(unavailable()); + }; + forward_task_registry_probe(probe_node, &request).await?; + + let mut tasks: Vec = nodes + .iter() + .flat_map(|node| node.tasks.iter().cloned()) + .collect(); + tasks.sort(); + tasks.dedup(); + Ok(Response::new(proto::TaskRegistry { tasks })) + } + + async fn publish_task( + &self, + request: Request, + ) -> Result, Status> { + let task_key = parse_task_key(&request.get_ref().task_id)?; + if let Some((owner_node, _)) = parse_owner_node(&request.get_ref().id) { + if let Some(node) = self.state.node_by_id(owner_node).await { + return forward_publish(node, &request).await; + } + } + + let candidates = 
self.state.candidate_nodes_for_task(&task_key).await; + if candidates.is_empty() { + return Err(unavailable()); + } + broadcast_publish(candidates, &request).await + } + + async fn cgrpc( + &self, + request: Request, + ) -> Result, Status> { + let Some(node) = self + .state + .preferred_nodes_by_tag("control") + .await + .into_iter() + .next() + else { + return Err(unavailable()); + }; + forward_cgrpc(node, &request).await + } + + async fn create_task( + &self, + request: Request, + ) -> Result, Status> { + let task_key = parse_task_key(&request.get_ref().task_id)?; + let candidates = self.state.candidate_nodes_for_task(&task_key).await; + let Some(node) = select_create_candidate(candidates) else { + return Err(unavailable()); + }; + + let _guard = InFlightCreateGuard::new(node.clone()); + forward_create(node, &request).await + } + + async fn delete_task( + &self, + request: Request, + ) -> Result, Status> { + let task_key = format!("{}:{}", request.get_ref().namespace, request.get_ref().task); + if let Some((owner_node, _)) = parse_owner_node(&request.get_ref().id) { + if let Some(node) = self.state.node_by_id(owner_node).await { + return forward_delete(node, &request).await; + } + } + + let candidates = self.state.candidate_nodes_for_task(&task_key).await; + if candidates.is_empty() { + return Err(unavailable()); + } + broadcast_delete(candidates, &request).await + } + + async fn get_tasks( + &self, + request: Request, + ) -> Result, Status> { + let task_key = format!("{}:{}", request.get_ref().namespace, request.get_ref().task); + let candidates = self.state.candidate_nodes_for_task(&task_key).await; + if candidates.is_empty() { + return Err(unavailable()); + } + + let requested_end = request + .get_ref() + .page + .saturating_add(1) + .saturating_mul(request.get_ref().page_size as u64); + let fanout_limit = min(requested_end, self.state.config.admin_fanout_limit as u64) as u32; + + let mut tasks = Vec::new(); + let mut join_set = tokio::task::JoinSet::new(); + 
for node in candidates { + let outbound_request = build_request(&request, request.get_ref().clone()); + join_set.spawn(async move { + forward_get_tasks(node, &outbound_request, 0, fanout_limit).await + }); + } + + while let Some(joined) = join_set.join_next().await { + match joined { + Ok(Ok(page)) => tasks.extend(page.tasks), + Ok(Err(status)) => return Err(status), + Err(err) => { + return Err(Status::internal(format!( + "proxy get_tasks join error: {err}" + ))); + } + } + } + + tasks.sort_by(|left, right| left.id.cmp(&right.id)); + let start = request + .get_ref() + .page + .saturating_mul(request.get_ref().page_size as u64) as usize; + let end = start.saturating_add(request.get_ref().page_size as usize); + let page_tasks = tasks + .into_iter() + .skip(start) + .take(end.saturating_sub(start)) + .collect(); + + Ok(Response::new(proto::TaskPage { + namespace: request.get_ref().namespace.clone(), + task: request.get_ref().task.clone(), + page: request.get_ref().page, + page_size: request.get_ref().page_size, + state: request.get_ref().state, + tasks: page_tasks, + })) + } + + async fn check_auth( + &self, + request: Request, + ) -> Result, Status> { + let Some(node) = self + .state + .preferred_nodes_by_tag("auth") + .await + .into_iter() + .next() + else { + return Err(unavailable()); + }; + forward_check_auth(node, &request).await + } +} + +#[tonic::async_trait] +impl Cluster for ProxyService { + async fn register_node( + &self, + request: Request, + ) -> Result, Status> { + require_cluster_auth(&request, &self.state.config.cluster_token)?; + + let node = request.get_ref(); + if node.node_id.trim().is_empty() { + return Err(Status::invalid_argument("node_id must not be empty")); + } + if node.node_id.contains('@') { + return Err(Status::invalid_argument("node_id must not contain '@'")); + } + + let session_id = druid::Druid::default().to_hex(); + let state = NodeState::new( + node.node_id.clone(), + node.advertise_addr.clone(), + node.tags.clone(), + 
node.tasks.clone(), + session_id.clone(), + now_unix(), + ) + .map_err(Status::invalid_argument)?; + self.state.upsert_node(state).await; + + Ok(Response::new(proto::NodeRegisterResponse { + session_id, + heartbeat_interval_seconds: self.state.config.heartbeat_interval_seconds, + ttl_seconds: self.state.config.node_ttl_seconds, + })) + } + + async fn heartbeat( + &self, + request: Request, + ) -> Result, Status> { + require_cluster_auth(&request, &self.state.config.cluster_token)?; + + let heartbeat = request.get_ref(); + let Some(node) = self.state.node_by_id(&heartbeat.node_id).await else { + return Err(Status::not_found("Unknown node")); + }; + if node.session_id != heartbeat.session_id { + return Err(Status::permission_denied("Invalid cluster session")); + } + + node.last_seen_unix.store(now_unix(), Ordering::Relaxed); + Ok(Response::new(proto::NodeHeartbeatAck { + server_time_unix: now_unix(), + })) + } + + async fn list_nodes( + &self, + request: Request, + ) -> Result, Status> { + require_cluster_auth(&request, &self.state.config.cluster_token)?; + + let nodes = self + .state + .healthy_nodes() + .await + .into_iter() + .map(|node| proto::NodeInfo { + node_id: node.node_id.clone(), + advertise_addr: node.advertise_addr.clone(), + tags: node.tags.clone(), + tasks: node.tasks.clone(), + last_seen_unix: node.last_seen_unix(), + healthy: true, + }) + .collect(); + Ok(Response::new(proto::NodeList { nodes })) + } +} diff --git a/engine/tests/common/mod.rs b/engine/tests/common/mod.rs new file mode 100644 index 0000000..36c128c --- /dev/null +++ b/engine/tests/common/mod.rs @@ -0,0 +1,291 @@ +use std::{error::Error, sync::Arc, time::Duration}; + +use engine::{ + cluster_client::{registration_from_api, spawn_registration}, + proto::{self, cluster_client::ClusterClient, engine_client::EngineClient}, + proxy_config::{ProxyConfigToml, RouteRuleToml}, + routing::ProxyState, + service::{ + backend::BackendEngineService, + proxy::{ProxyService, spawn_reaper}, + }, +}; 
+use enginelib::{ + Registry, + api::{EngineAPI, postcard}, + event::register_inventory_handlers, + events::ID, + task::{Task, Verifiable}, +}; +use serde::{Deserialize, Serialize}; +use tokio::{ + net::TcpListener, + sync::{RwLock, oneshot}, + task::JoinHandle, + time::{sleep, timeout}, +}; +use tokio_stream::wrappers::TcpListenerStream; +use tonic::{Request, transport::Server}; + +pub type BoxError = Box; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TestTask { + pub value: u32, + pub id: (String, String), +} + +impl Verifiable for TestTask { + fn verify(&self, bytes: Vec) -> bool { + postcard::from_bytes::(&bytes).is_ok() + } +} + +impl Task for TestTask { + fn get_id(&self) -> (String, String) { + self.id.clone() + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } + + fn to_bytes(&self) -> Vec { + postcard::to_allocvec(self).unwrap() + } + + fn from_bytes(&self, bytes: &[u8]) -> Box { + Box::new(postcard::from_bytes::(bytes).unwrap()) + } + + fn from_toml(&self, _d: String) -> Box { + Box::new(self.clone()) + } + + fn to_toml(&self) -> String { + String::new() + } +} + +pub struct TestBackend { + pub api: Arc>, + shutdown: Option>, + server_task: JoinHandle>, + registration_task: Option>, +} + +impl TestBackend { + pub async fn shutdown(mut self) { + if let Some(shutdown) = self.shutdown.take() { + let _ = shutdown.send(()); + } + if let Some(task) = self.registration_task.take() { + task.abort(); + } + let _ = self.server_task.await; + } +} + +pub struct TestProxy { + pub addr: String, + shutdown: Option>, + server_task: JoinHandle>, + reaper_task: JoinHandle<()>, +} + +impl TestProxy { + pub async fn shutdown(mut self) { + if let Some(shutdown) = self.shutdown.take() { + let _ = shutdown.send(()); + } + self.reaper_task.abort(); + let _ = self.server_task.await; + } +} + +pub async fn spawn_proxy() -> Result { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + let config = 
ProxyConfigToml { + listen: addr.to_string(), + cluster_token: "cluster-secret".into(), + node_ttl_seconds: 2, + heartbeat_interval_seconds: 1, + max_acquire_hops: 3, + admin_fanout_limit: 5000, + rules: vec![ + RouteRuleToml { + r#match: "ml:*".into(), + require_tags: vec!["gpu".into()], + }, + RouteRuleToml { + r#match: "*:*".into(), + require_tags: Vec::new(), + }, + ], + }; + let state = Arc::new(ProxyState::new(config)?); + let reaper_task = spawn_reaper(state.clone()); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let cluster_service = ProxyService::new(state.clone()); + let engine_service = ProxyService::new(state); + let server_task = tokio::spawn(async move { + Server::builder() + .add_service(proto::cluster_server::ClusterServer::new(cluster_service)) + .add_service(proto::engine_server::EngineServer::new(engine_service)) + .serve_with_incoming_shutdown(TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + }); + + Ok(TestProxy { + addr: format!("http://{addr}"), + shutdown: Some(shutdown_tx), + server_task, + reaper_task, + }) +} + +pub async fn spawn_backend( + node_id: &str, + tags: &[&str], + proxy_addr: &str, +) -> Result { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + let mut api = EngineAPI::test_default(); + register_inventory_handlers(&mut api); + api.cfg.config_toml.host = addr.to_string(); + api.cfg.config_toml.node_id = Some(node_id.to_string()); + api.cfg.config_toml.advertise_addr = Some(format!("http://{addr}")); + api.cfg.config_toml.cluster_proxy_addr = Some(proxy_addr.to_string()); + api.cfg.config_toml.cluster_token = Some("cluster-secret".into()); + api.cfg.config_toml.node_tags = Some(tags.iter().map(|tag| (*tag).to_string()).collect()); + + register_task(&mut api, "dist", "work"); + register_task(&mut api, "ml", "train"); + register_task(&mut api, "node", node_id); + + let api = Arc::new(RwLock::new(api)); + let registration_task = { + let 
api_guard = api.read().await; + registration_from_api(&api_guard)? + } + .map(spawn_registration); + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let engine_service = BackendEngineService::new(api.clone()); + let server_task = tokio::spawn(async move { + Server::builder() + .add_service(proto::engine_server::EngineServer::new(engine_service)) + .serve_with_incoming_shutdown(TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + }); + + wait_for_node(proxy_addr, node_id).await?; + + Ok(TestBackend { + api, + shutdown: Some(shutdown_tx), + server_task, + registration_task, + }) +} + +pub fn task_bytes(value: u32, namespace: &str, task: &str) -> Vec { + TestTask { + value, + id: ID(namespace, task), + } + .to_bytes() +} + +pub async fn engine_client( + addr: &str, +) -> Result, tonic::transport::Error> { + EngineClient::connect(addr.to_string()).await +} + +pub async fn cluster_client( + addr: &str, +) -> Result, tonic::transport::Error> { + ClusterClient::connect(addr.to_string()).await +} + +pub fn worker_request(message: T) -> Result, BoxError> { + let mut request = Request::new(message); + request.metadata_mut().insert("uid", "worker-1".parse()?); + request + .metadata_mut() + .insert("authorization", "worker-token".parse()?); + Ok(request) +} + +pub fn admin_request(message: T) -> Result, BoxError> { + let mut request = Request::new(message); + request + .metadata_mut() + .insert("authorization", "admin-token".parse()?); + Ok(request) +} + +pub fn cluster_request(message: T) -> Result, BoxError> { + let mut request = Request::new(message); + request + .metadata_mut() + .insert("authorization", "Bearer cluster-secret".parse()?); + Ok(request) +} + +pub async fn wait_for_node(proxy_addr: &str, node_id: &str) -> Result<(), BoxError> { + timeout(Duration::from_secs(5), async move { + loop { + let mut client = cluster_client(proxy_addr).await?; + let response = client.list_nodes(cluster_request(proto::Empty {})?).await?; + 
if response + .get_ref() + .nodes + .iter() + .any(|node| node.node_id == node_id) + { + return Ok::<(), BoxError>(()); + } + sleep(Duration::from_millis(50)).await; + } + }) + .await??; + Ok(()) +} + +pub async fn wait_for_node_count(proxy_addr: &str, expected: usize) -> Result<(), BoxError> { + timeout(Duration::from_secs(5), async move { + loop { + let mut client = cluster_client(proxy_addr).await?; + let response = client.list_nodes(cluster_request(proto::Empty {})?).await?; + if response.get_ref().nodes.len() == expected { + return Ok::<(), BoxError>(()); + } + sleep(Duration::from_millis(50)).await; + } + }) + .await??; + Ok(()) +} + +fn register_task(api: &mut EngineAPI, namespace: &str, task: &str) { + let id = ID(namespace, task); + api.task_registry.register( + Arc::new(TestTask { + value: 0, + id: id.clone(), + }), + id.clone(), + ); + api.task_queue.tasks.entry(id.clone()).or_default(); + api.executing_tasks.tasks.entry(id.clone()).or_default(); + api.solved_tasks.tasks.entry(id).or_default(); +} diff --git a/engine/tests/proxy_cluster.rs b/engine/tests/proxy_cluster.rs new file mode 100644 index 0000000..0a55bdc --- /dev/null +++ b/engine/tests/proxy_cluster.rs @@ -0,0 +1,507 @@ +mod common; + +use common::{ + BoxError, admin_request, cluster_client, cluster_request, engine_client, spawn_backend, + spawn_proxy, task_bytes, wait_for_node_count, worker_request, +}; +use engine::proto; +use enginelib::{ + chrono::Utc, + events::ID, + task::{StoredExecutingTask, StoredTask}, +}; +use tokio::task::JoinSet; +use tonic::Code; + +#[tokio::test] +async fn create_task_distributes_across_nodes() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let backend_a = spawn_backend("node-a", &["auth", "control"], &proxy.addr).await?; + let backend_b = spawn_backend("node-b", &["gpu"], &proxy.addr).await?; + + let mut join_set = JoinSet::new(); + for value in 0..32 { + let proxy_addr = proxy.addr.clone(); + join_set.spawn(async move { + let mut client = 
engine_client(&proxy_addr).await.unwrap(); + client + .create_task( + worker_request(proto::Task { + id: String::new(), + task_id: "dist:work".into(), + task_payload: task_bytes(value, "dist", "work"), + payload: Vec::new(), + }) + .unwrap(), + ) + .await + }); + } + while let Some(result) = join_set.join_next().await { + result??; + } + + let queued_a = backend_a + .api + .read() + .await + .task_queue + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(); + let queued_b = backend_b + .api + .read() + .await + .task_queue + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(); + assert!(queued_a > 0, "expected node-a to receive tasks"); + assert!(queued_b > 0, "expected node-b to receive tasks"); + + backend_a.shutdown().await; + backend_b.shutdown().await; + proxy.shutdown().await; + Ok(()) +} + +#[tokio::test] +async fn aquire_task_retries_until_a_node_has_work() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let backend_a = spawn_backend("node-a", &["auth", "control"], &proxy.addr).await?; + let backend_b = spawn_backend("node-b", &["gpu"], &proxy.addr).await?; + + backend_b + .api + .write() + .await + .task_queue + .tasks + .get_mut(&ID("dist", "work")) + .unwrap() + .push(StoredTask { + id: "node-b@queued-1".into(), + bytes: task_bytes(1, "dist", "work"), + }); + + let mut client = engine_client(&proxy.addr).await?; + let response = client + .aquire_task(worker_request(proto::TaskRequest { + task_id: "dist:work".into(), + })?) + .await? 
+ .into_inner(); + + assert_eq!(response.id, "node-b@queued-1"); + assert_eq!( + backend_b + .api + .read() + .await + .executing_tasks + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(), + 1 + ); + + backend_a.shutdown().await; + backend_b.shutdown().await; + proxy.shutdown().await; + Ok(()) +} + +#[tokio::test] +async fn publish_task_routes_by_owner_prefix() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let backend_a = spawn_backend("node-a", &["auth", "control"], &proxy.addr).await?; + let backend_b = spawn_backend("node-b", &["gpu"], &proxy.addr).await?; + + let mut client = engine_client(&proxy.addr).await?; + let created = client + .create_task(worker_request(proto::Task { + id: String::new(), + task_id: "dist:work".into(), + task_payload: task_bytes(7, "dist", "work"), + payload: Vec::new(), + })?) + .await? + .into_inner(); + let acquired = client + .aquire_task(worker_request(proto::TaskRequest { + task_id: "dist:work".into(), + })?) + .await? + .into_inner(); + assert_eq!(created.id, acquired.id); + + client + .publish_task(worker_request(proto::Task { + id: acquired.id.clone(), + task_id: acquired.task_id.clone(), + task_payload: task_bytes(99, "dist", "work"), + payload: Vec::new(), + })?) 
+ .await?; + + let owner = acquired.id.split('@').next().unwrap(); + let solved_a = backend_a + .api + .read() + .await + .solved_tasks + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(); + let solved_b = backend_b + .api + .read() + .await + .solved_tasks + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(); + match owner { + "node-a" => { + assert_eq!(solved_a, 1); + assert_eq!(solved_b, 0); + } + "node-b" => { + assert_eq!(solved_a, 0); + assert_eq!(solved_b, 1); + } + _ => panic!("unexpected owner {owner}"), + } + + backend_a.shutdown().await; + backend_b.shutdown().await; + proxy.shutdown().await; + Ok(()) +} + +#[tokio::test] +async fn delete_task_routes_by_owner_prefix() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let backend_a = spawn_backend("node-a", &["auth", "control"], &proxy.addr).await?; + let backend_b = spawn_backend("node-b", &["gpu"], &proxy.addr).await?; + + let mut client = engine_client(&proxy.addr).await?; + let created = client + .create_task(worker_request(proto::Task { + id: String::new(), + task_id: "dist:work".into(), + task_payload: task_bytes(8, "dist", "work"), + payload: Vec::new(), + })?) + .await? + .into_inner(); + + client + .delete_task(admin_request(proto::TaskSelector { + state: proto::TaskState::Queued as i32, + namespace: "dist".into(), + task: "work".into(), + id: created.id.clone(), + })?) 
+ .await?; + + let owner = created.id.split('@').next().unwrap(); + let queued_a = backend_a + .api + .read() + .await + .task_queue + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(); + let queued_b = backend_b + .api + .read() + .await + .task_queue + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(); + match owner { + "node-a" => assert_eq!(queued_a + queued_b, 0), + "node-b" => assert_eq!(queued_a + queued_b, 0), + _ => panic!("unexpected owner {owner}"), + } + + backend_a.shutdown().await; + backend_b.shutdown().await; + proxy.shutdown().await; + Ok(()) +} + +#[tokio::test] +async fn aquire_task_reg_returns_union_of_registered_tasks() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let backend_a = spawn_backend("node-a", &["auth", "control"], &proxy.addr).await?; + let backend_b = spawn_backend("node-b", &["gpu"], &proxy.addr).await?; + + let mut client = engine_client(&proxy.addr).await?; + let response = client + .aquire_task_reg(worker_request(proto::Empty {})?) + .await? 
+ .into_inner(); + + assert_eq!( + response.tasks, + vec![ + "dist:work".to_string(), + "ml:train".to_string(), + "node:node-a".to_string(), + "node:node-b".to_string(), + ] + ); + + backend_a.shutdown().await; + backend_b.shutdown().await; + proxy.shutdown().await; + Ok(()) +} + +#[tokio::test] +async fn get_tasks_fanout_merges_and_paginates() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let backend_a = spawn_backend("node-a", &["auth", "control"], &proxy.addr).await?; + let backend_b = spawn_backend("node-b", &["gpu"], &proxy.addr).await?; + + backend_a + .api + .write() + .await + .task_queue + .tasks + .get_mut(&ID("dist", "work")) + .unwrap() + .extend([ + StoredTask { + id: "node-a@0001".into(), + bytes: task_bytes(1, "dist", "work"), + }, + StoredTask { + id: "node-a@0003".into(), + bytes: task_bytes(3, "dist", "work"), + }, + ]); + backend_b + .api + .write() + .await + .task_queue + .tasks + .get_mut(&ID("dist", "work")) + .unwrap() + .extend([ + StoredTask { + id: "node-b@0002".into(), + bytes: task_bytes(2, "dist", "work"), + }, + StoredTask { + id: "node-b@0004".into(), + bytes: task_bytes(4, "dist", "work"), + }, + ]); + + let mut client = engine_client(&proxy.addr).await?; + let page0 = client + .get_tasks(admin_request(proto::TaskPageRequest { + namespace: "dist".into(), + task: "work".into(), + page: 0, + page_size: 2, + state: proto::TaskState::Queued as i32, + })?) + .await? + .into_inner(); + let page1 = client + .get_tasks(admin_request(proto::TaskPageRequest { + namespace: "dist".into(), + task: "work".into(), + page: 1, + page_size: 2, + state: proto::TaskState::Queued as i32, + })?) + .await? 
+ .into_inner(); + + assert_eq!( + page0 + .tasks + .iter() + .map(|task| task.id.clone()) + .collect::>(), + vec!["node-a@0001".to_string(), "node-a@0003".to_string()] + ); + assert_eq!( + page1 + .tasks + .iter() + .map(|task| task.id.clone()) + .collect::>(), + vec!["node-b@0002".to_string(), "node-b@0004".to_string()] + ); + + backend_a.shutdown().await; + backend_b.shutdown().await; + proxy.shutdown().await; + Ok(()) +} + +#[tokio::test] +async fn cluster_membership_expires_after_heartbeats_stop() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let backend_a = spawn_backend("node-a", &["auth", "control"], &proxy.addr).await?; + let backend_b = spawn_backend("node-b", &["gpu"], &proxy.addr).await?; + + wait_for_node_count(&proxy.addr, 2).await?; + backend_b.shutdown().await; + wait_for_node_count(&proxy.addr, 1).await?; + + let mut client = cluster_client(&proxy.addr).await?; + let nodes = client + .list_nodes(cluster_request(proto::Empty {})?) + .await? + .into_inner() + .nodes; + assert_eq!(nodes.len(), 1); + assert_eq!(nodes[0].node_id, "node-a"); + + backend_a.shutdown().await; + proxy.shutdown().await; + Ok(()) +} + +#[tokio::test] +async fn reregister_replaces_prior_session_id() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let mut client = cluster_client(&proxy.addr).await?; + + let first = client + .register_node(cluster_request(proto::NodeRegisterRequest { + node_id: "node-z".into(), + advertise_addr: "http://127.0.0.1:59999".into(), + tags: vec!["gpu".into()], + tasks: vec!["dist:work".into()], + })?) + .await? + .into_inner(); + let second = client + .register_node(cluster_request(proto::NodeRegisterRequest { + node_id: "node-z".into(), + advertise_addr: "http://127.0.0.1:59999".into(), + tags: vec!["gpu".into()], + tasks: vec!["dist:work".into()], + })?) + .await? 
+ .into_inner(); + + assert_ne!(first.session_id, second.session_id); + + let old_session = client + .heartbeat(cluster_request(proto::NodeHeartbeat { + node_id: "node-z".into(), + session_id: first.session_id, + })?) + .await + .unwrap_err(); + assert_eq!(old_session.code(), Code::PermissionDenied); + + client + .heartbeat(cluster_request(proto::NodeHeartbeat { + node_id: "node-z".into(), + session_id: second.session_id, + })?) + .await?; + + proxy.shutdown().await; + Ok(()) +} + +#[tokio::test] +async fn legacy_ids_use_broadcast_fallback_for_publish_and_delete() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let backend_a = spawn_backend("node-a", &["auth", "control"], &proxy.addr).await?; + let backend_b = spawn_backend("node-b", &["gpu"], &proxy.addr).await?; + + backend_b + .api + .write() + .await + .executing_tasks + .tasks + .get_mut(&ID("dist", "work")) + .unwrap() + .push(StoredExecutingTask { + bytes: task_bytes(11, "dist", "work"), + id: "legacy-1".into(), + user_id: "worker-1".into(), + given_at: Utc::now(), + }); + backend_b + .api + .write() + .await + .task_queue + .tasks + .get_mut(&ID("dist", "work")) + .unwrap() + .push(StoredTask { + id: "legacy-2".into(), + bytes: task_bytes(12, "dist", "work"), + }); + + let mut client = engine_client(&proxy.addr).await?; + client + .publish_task(worker_request(proto::Task { + id: "legacy-1".into(), + task_id: "dist:work".into(), + task_payload: task_bytes(111, "dist", "work"), + payload: Vec::new(), + })?) + .await?; + client + .delete_task(admin_request(proto::TaskSelector { + state: proto::TaskState::Queued as i32, + namespace: "dist".into(), + task: "work".into(), + id: "legacy-2".into(), + })?) 
+ .await?; + + let solved = backend_b + .api + .read() + .await + .solved_tasks + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(); + let queued = backend_b + .api + .read() + .await + .task_queue + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(); + assert_eq!(solved, 1); + assert_eq!(queued, 0); + + backend_a.shutdown().await; + backend_b.shutdown().await; + proxy.shutdown().await; + Ok(()) +} diff --git a/enginelib/src/config.rs b/enginelib/src/config.rs index 4e2a18a..89156e9 100644 --- a/enginelib/src/config.rs +++ b/enginelib/src/config.rs @@ -25,6 +25,16 @@ pub struct ConfigTomlServer { pub clean_tasks: u64, #[serde(default = "default_pagination_limit")] pub pagination_limit: u32, + #[serde(default)] + pub node_id: Option, + #[serde(default)] + pub advertise_addr: Option, + #[serde(default)] + pub cluster_proxy_addr: Option, + #[serde(default)] + pub cluster_token: Option, + #[serde(default)] + pub node_tags: Option>, } impl Default for ConfigTomlServer { fn default() -> Self { @@ -33,6 +43,11 @@ impl Default for ConfigTomlServer { cgrpc_token: None, clean_tasks: 60, pagination_limit: u32::MAX, + node_id: None, + advertise_addr: None, + cluster_proxy_addr: None, + cluster_token: None, + node_tags: None, } } } diff --git a/rfc/rfc1004.md b/rfc/rfc1004.md new file mode 100644 index 0000000..632c245 --- /dev/null +++ b/rfc/rfc1004.md @@ -0,0 +1,157 @@ +# RFC 1004: Proxy and Clustering + +## Summary + +Add a gRPC proxy in front of engine backends plus a lightweight cluster membership protocol so: + +- clients and workers talk to one stable endpoint, +- tasks can be routed to capability-specific nodes, +- task instances can be distributed across multiple nodes, +- `PublishTask` and `DeleteTask` can route back to the owning node without sticky proxy state. + +The design keeps backend storage shard-local. Each backend keeps its own sled DB and dynamically registers itself with one or more stateless proxies. 
+ +## Motivation + +The current engine exposes a single backend directly. That creates three constraints: + +- all workers must know about a specific backend, +- there is no built-in way to spread the same `namespace:task` across multiple nodes, +- follow-up RPCs like `PublishTask` depend on whichever backend originally created the task. + +By making `Task.id` self-routing and adding in-memory cluster membership, the proxy can stay stateless while still routing task lifecycle RPCs correctly. + +## Design + +### Task instance IDs + +`Task.id` is now the task instance identifier and, when clustering is enabled, is minted as: + +```text +{node_id}@{random_hex} +``` + +Legacy ids without `@` remain supported. The proxy falls back to broadcast routing for `PublishTask` and `DeleteTask` when it sees a legacy id. + +### Cluster service + +`engine/proto/engine.proto` now contains a `Cluster` gRPC service with: + +- `RegisterNode` +- `Heartbeat` +- `ListNodes` + +These RPCs are used only between the proxy and backends. All three require `authorization: Bearer `. + +### Backend registration + +On startup, a backend: + +1. initializes `EngineAPI`, +2. loads tasks and modules, +3. snapshots its local task registry into `namespace:task` strings, +4. registers with the proxy, +5. sends heartbeats on the server-provided interval. + +If heartbeats fail, the backend reconnects and re-registers to get a fresh session id. + +### Proxy routing + +The proxy stores node membership entirely in memory and applies first-match routing rules from `proxy.toml`. + +Task-specific RPCs: + +- filter to healthy nodes, +- require the node to advertise the task key, +- require the node to satisfy the matched rule tags. 
+ +Routing strategies: + +- `CreateTask`: power-of-two choices on in-flight create counts +- `AquireTask`: random candidate order, retry on `not_found` +- `PublishTask`: owner-node route when `Task.id` is prefixed, otherwise broadcast fallback +- `DeleteTask`: owner-node route when selector `id` is prefixed, otherwise broadcast fallback +- `GetTasks`: fan-out, merge by lexicographic task id, paginate after merge +- `AquireTaskReg`: sorted unique union of registered tasks +- `CheckAuth`: prefer nodes tagged `auth` +- `cgrpc`: prefer nodes tagged `control` + +### Config + +Backend `config.toml` adds: + +- `node_id` +- `advertise_addr` +- `cluster_proxy_addr` +- `cluster_token` +- `node_tags` + +Proxy `proxy.toml` contains: + +- `listen` +- `cluster_token` +- `node_ttl_seconds` +- `heartbeat_interval_seconds` +- `max_acquire_hops` +- `admin_fanout_limit` +- ordered `rules` + +Example proxy config: + +```toml +listen = "0.0.0.0:50052" +cluster_token = "cluster-secret" +node_ttl_seconds = 15 +heartbeat_interval_seconds = 5 +max_acquire_hops = 3 +admin_fanout_limit = 5000 + +[[rules]] +match = "ml:*" +require_tags = ["gpu"] + +[[rules]] +match = "*:*" +require_tags = [] +``` + +Example backend config additions: + +```toml +host = "127.0.0.1:50051" +node_id = "gpu-1" +advertise_addr = "http://127.0.0.1:50051" +cluster_proxy_addr = "http://127.0.0.1:50052" +cluster_token = "cluster-secret" +node_tags = ["gpu", "control"] +``` + +## Backwards compatibility + +This RFC is additive for clients: + +- the existing `Engine` service surface is unchanged, +- task ids remain opaque to workers, +- legacy ids without `@` still work through proxy broadcast fallback. + +Proxy membership is intentionally in-memory only. If a backend disappears, its local tasks are unavailable until it returns. + +## Alternatives + +1. Keep clients pinned to specific backends. + This avoids a proxy, but it does not solve distribution or owner routing. + +2. 
Store proxy membership in shared durable storage. + This would support persistent coordination, but it is not required for the current shard-local design. + +3. Add distributed task migration or failover. + This would improve recovery, but it is explicitly out of scope for this RFC. + +## Implementation plan + +1. Refactor backend `Engine` RPC handling into reusable library services. +2. Add proxy config loading, in-memory node membership, and cluster-authenticated registration. +3. Introduce a proxy binary implementing both `Engine` and `Cluster`. +4. Update backend startup to validate cluster config and start registration plus heartbeats. +5. Change backend task lifecycle handling so `Task.id` is always the instance id. +6. Add integration tests for distribution, routing, pagination, membership expiry, re-registration, and legacy fallback. From 89e3332ce4de66f0dec5a465d8e50ca0cbd0d2aa Mon Sep 17 00:00:00 2001 From: IGN-Styly Date: Thu, 12 Mar 2026 10:30:17 +0000 Subject: [PATCH 2/3] Harden proxy routing and background task shutdown - Add shared task ID parsing utilities and use them across backend/proxy services - Track and react to reaper/registration task lifecycle with cancellation-aware loops - Improve cluster test stability (retry node-list polling, deterministic task list ordering) - Redact sensitive config tokens in debug output --- Cargo.lock | 5 ++-- engine/Cargo.toml | 1 + engine/src/bin/proxy.rs | 21 ++++++++++---- engine/src/bin/server.rs | 5 +++- engine/src/cluster_client.rs | 24 +++++++++++++--- engine/src/lib.rs | 1 + engine/src/routing.rs | 9 +----- engine/src/service/backend.rs | 21 ++------------ engine/src/service/proxy.rs | 21 +++----------- engine/src/task_id.rs | 23 +++++++++++++++ engine/tests/common/mod.rs | 53 +++++++++++++++++++++++++++-------- engine/tests/proxy_cluster.rs | 8 ++++-- enginelib/macros/src/lib.rs | 15 ++++++---- enginelib/src/api.rs | 24 ++++++++-------- enginelib/src/config.rs | 27 ++++++++++++++++-- 
enginelib/src/events/mod.rs | 7 ++++- rfc/rfc1004.md | 2 +- 17 files changed, 177 insertions(+), 90 deletions(-) create mode 100644 engine/src/task_id.rs diff --git a/Cargo.lock b/Cargo.lock index 1b13ca9..312984f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -595,6 +595,7 @@ dependencies = [ "serde", "tokio", "tokio-stream", + "tokio-util", "toml", "tonic", "tonic-build", @@ -2832,9 +2833,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.16" +version = "0.7.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14307c986784f72ef81c89db7d9e28d6ac26d16213b109ea501696195e6e3ce5" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" dependencies = [ "bytes", "futures-core", diff --git a/engine/Cargo.toml b/engine/Cargo.toml index eca766c..3b8f47b 100644 --- a/engine/Cargo.toml +++ b/engine/Cargo.toml @@ -22,6 +22,7 @@ serde = { workspace = true } # serde = "1.0.219" tokio = { version = "1.50.0", features = ["rt-multi-thread", "macros"] } tokio-stream = { version = "0.1.17", features = ["net"] } +tokio-util = "0.7.17" toml = { workspace = true } # toml = "0.8.19" tonic = "0.14" diff --git a/engine/src/bin/proxy.rs b/engine/src/bin/proxy.rs index 996dcd3..90aaf53 100644 --- a/engine/src/bin/proxy.rs +++ b/engine/src/bin/proxy.rs @@ -15,7 +15,7 @@ async fn main() -> Result<(), Box> { let config = ProxyConfigToml::load()?; let listener = TcpListener::bind(config.listen.as_str()).await?; let state = Arc::new(ProxyState::new(config)?); - let _reaper = spawn_reaper(state.clone()); + let reaper_handle = spawn_reaper(state.clone()); let cluster_service = ProxyService::new(state.clone()); let engine_service = ProxyService::new(state); @@ -24,13 +24,24 @@ async fn main() -> Result<(), Box> { .build_v1() .map_err(|e| Box::new(e) as Box)?; - Server::builder() + let server = Server::builder() .add_service(reflection_service) .add_service(proto::cluster_server::ClusterServer::new(cluster_service)) 
.add_service(proto::engine_server::EngineServer::new(engine_service)) - .serve_with_incoming(TcpListenerStream::new(listener)) - .await - .map_err(|e| Box::new(e) as Box)?; + .serve_with_incoming(TcpListenerStream::new(listener)); + tokio::pin!(server); + + tokio::select! { + result = &mut server => { + result.map_err(|e| Box::new(e) as Box)?; + } + result = reaper_handle => { + match result { + Ok(()) => return Err("proxy reaper exited unexpectedly".into()), + Err(err) => return Err(format!("proxy reaper failed: {err}").into()), + } + } + } Ok(()) } diff --git a/engine/src/bin/server.rs b/engine/src/bin/server.rs index 0c11598..bba46b2 100644 --- a/engine/src/bin/server.rs +++ b/engine/src/bin/server.rs @@ -10,6 +10,7 @@ use std::{ }; use tokio::{net::TcpListener, sync::RwLock}; use tokio_stream::wrappers::TcpListenerStream; +use tokio_util::sync::CancellationToken; use tonic::transport::Server; #[tokio::main] @@ -35,7 +36,9 @@ async fn main() -> Result<(), Box> { let api_guard = api.read().await; registration_from_api(&api_guard)? 
}; - let _registration_task = registration.map(spawn_registration); + let _registration_shutdown = CancellationToken::new(); + let _registration_task = registration + .map(|registration| spawn_registration(registration, _registration_shutdown.clone())); EngineAPI::init_chron(api.clone()); let engine = BackendEngineService::new(api); diff --git a/engine/src/cluster_client.rs b/engine/src/cluster_client.rs index 146bf20..33348ea 100644 --- a/engine/src/cluster_client.rs +++ b/engine/src/cluster_client.rs @@ -2,6 +2,7 @@ use std::time::Duration; use enginelib::api::EngineAPI; use tokio::{task::JoinHandle, time::sleep}; +use tokio_util::sync::CancellationToken; use tonic::{Request, transport::Endpoint}; use tracing::{error, info, warn}; @@ -83,9 +84,15 @@ pub fn normalize_advertise_addr( } } -pub fn spawn_registration(registration: NodeRegistration) -> JoinHandle<()> { +pub fn spawn_registration( + registration: NodeRegistration, + shutdown: CancellationToken, +) -> JoinHandle<()> { tokio::spawn(async move { loop { + if shutdown.is_cancelled() { + return; + } match ClusterClient::connect(registration.cluster_proxy_addr.clone()).await { Ok(mut client) => { let register_response = match register_node(&mut client, ®istration).await { @@ -95,7 +102,10 @@ pub fn spawn_registration(registration: NodeRegistration) -> JoinHandle<()> { "cluster registration failed for {}: {}", registration.node_id, err ); - sleep(Duration::from_secs(1)).await; + tokio::select! { + _ = shutdown.cancelled() => return, + _ = sleep(Duration::from_secs(1)) => {} + } continue; } }; @@ -109,7 +119,10 @@ pub fn spawn_registration(registration: NodeRegistration) -> JoinHandle<()> { let interval_seconds = register_response.heartbeat_interval_seconds.max(1); loop { - sleep(Duration::from_secs(interval_seconds)).await; + tokio::select! 
{ + _ = shutdown.cancelled() => return, + _ = sleep(Duration::from_secs(interval_seconds)) => {} + } match heartbeat(&mut client, ®istration, &session_id).await { Ok(_) => {} Err(err) => { @@ -127,7 +140,10 @@ pub fn spawn_registration(registration: NodeRegistration) -> JoinHandle<()> { "failed to connect backend node {} to cluster proxy {}: {}", registration.node_id, registration.cluster_proxy_addr, err ); - sleep(Duration::from_secs(1)).await; + tokio::select! { + _ = shutdown.cancelled() => return, + _ = sleep(Duration::from_secs(1)) => {} + } } } } diff --git a/engine/src/lib.rs b/engine/src/lib.rs index 9d09966..62121ed 100644 --- a/engine/src/lib.rs +++ b/engine/src/lib.rs @@ -8,6 +8,7 @@ pub mod proto; pub mod proxy_config; pub mod routing; pub mod service; +pub mod task_id; pub fn get_uid(req: &Request) -> String { req.metadata() diff --git a/engine/src/routing.rs b/engine/src/routing.rs index 4100071..d40a7eb 100644 --- a/engine/src/routing.rs +++ b/engine/src/routing.rs @@ -119,7 +119,7 @@ pub struct ProxyState { impl ProxyState { pub fn new(config: ProxyConfigToml) -> Result { - let mut rules = if config.rules.is_empty() { + let rules = if config.rules.is_empty() { vec![RouteRule::from_toml(RouteRuleToml { r#match: "*:*".into(), require_tags: Vec::new(), @@ -132,13 +132,6 @@ impl ProxyState { rules }; - if rules.is_empty() { - rules.push(RouteRule::from_toml(RouteRuleToml { - r#match: "*:*".into(), - require_tags: Vec::new(), - })?); - } - Ok(Self { config, rules, diff --git a/engine/src/service/backend.rs b/engine/src/service/backend.rs index ba139dc..1325169 100644 --- a/engine/src/service/backend.rs +++ b/engine/src/service/backend.rs @@ -5,7 +5,7 @@ use std::{ use enginelib::api::postcard; use enginelib::{ - Identifier, RawIdentier, Registry, + RawIdentier, Registry, api::EngineAPI, chrono::Utc, events::{Events, ID}, @@ -18,6 +18,7 @@ use tracing::{debug, info, warn}; use crate::{ get_auth, get_uid, proto::{self, TaskState, engine_server::Engine}, + 
task_id::parse_task_key, }; #[allow(non_snake_case)] @@ -47,22 +48,6 @@ pub fn parse_owner_node(task_instance_id: &str) -> Option<(&str, &str)> { Some((node_id, local_id)) } -fn parse_task_key(task_id: &str) -> Result { - let Some((namespace, task)) = task_id.split_once(':') else { - return Err(Status::invalid_argument( - "Invalid task ID format, expected 'namespace:task'", - )); - }; - - if namespace.is_empty() || task.is_empty() { - return Err(Status::invalid_argument( - "Invalid task ID format, expected 'namespace:task'", - )); - } - - Ok(ID(namespace, task)) -} - fn delete_task_from_collection( collection: &mut HashMap<(String, String), Vec>, id: &(String, String), @@ -444,7 +429,7 @@ impl Engine for BackendEngineService { let instance_id = task.id.clone(); if !Events::CheckAuth(&mut api, uid.clone(), challenge, db) { - info!("Aquire Task denied due to Invalid Auth"); + info!("Publish Task denied due to Invalid Auth"); return Err(Status::permission_denied("Invalid authentication")); } if !api.task_registry.tasks.contains_key(&key) { diff --git a/engine/src/service/proxy.rs b/engine/src/service/proxy.rs index 9c12b5e..5c9f21f 100644 --- a/engine/src/service/proxy.rs +++ b/engine/src/service/proxy.rs @@ -9,6 +9,7 @@ use crate::{ proto::{self, cluster_server::Cluster, engine_client::EngineClient, engine_server::Engine}, routing::{NodeState, ProxyState, now_unix}, service::backend::parse_owner_node, + task_id::parse_task_key_string, }; use rand::{seq::SliceRandom, thread_rng}; use tokio::{task::JoinHandle, time::sleep}; @@ -58,20 +59,6 @@ fn require_cluster_auth(request: &Request, cluster_token: &str) -> Result< Ok(()) } -fn parse_task_key(task_id: &str) -> Result { - let Some((namespace, task)) = task_id.split_once(':') else { - return Err(Status::invalid_argument( - "Invalid task ID format, expected 'namespace:task'", - )); - }; - if namespace.is_empty() || task.is_empty() { - return Err(Status::invalid_argument( - "Invalid task ID format, expected 
'namespace:task'", - )); - } - Ok(format!("{namespace}:{task}")) -} - fn build_request(request: &Request, message: T) -> Request { let mut outbound = Request::new(message); copy_metadata(request.metadata(), outbound.metadata_mut()); @@ -236,7 +223,7 @@ impl Engine for ProxyService { &self, request: Request, ) -> Result, Status> { - let task_key = parse_task_key(&request.get_ref().task_id)?; + let task_key = parse_task_key_string(&request.get_ref().task_id)?; let mut candidates = self.state.candidate_nodes_for_task(&task_key).await; if candidates.is_empty() { return Err(unavailable()); @@ -281,7 +268,7 @@ impl Engine for ProxyService { &self, request: Request, ) -> Result, Status> { - let task_key = parse_task_key(&request.get_ref().task_id)?; + let task_key = parse_task_key_string(&request.get_ref().task_id)?; if let Some((owner_node, _)) = parse_owner_node(&request.get_ref().id) { if let Some(node) = self.state.node_by_id(owner_node).await { return forward_publish(node, &request).await; @@ -315,7 +302,7 @@ impl Engine for ProxyService { &self, request: Request, ) -> Result, Status> { - let task_key = parse_task_key(&request.get_ref().task_id)?; + let task_key = parse_task_key_string(&request.get_ref().task_id)?; let candidates = self.state.candidate_nodes_for_task(&task_key).await; let Some(node) = select_create_candidate(candidates) else { return Err(unavailable()); diff --git a/engine/src/task_id.rs b/engine/src/task_id.rs new file mode 100644 index 0000000..3ec09eb --- /dev/null +++ b/engine/src/task_id.rs @@ -0,0 +1,23 @@ +use enginelib::Identifier; +use tonic::Status; + +pub fn parse_task_key(task_id: &str) -> Result { + let Some((namespace, task)) = task_id.split_once(':') else { + return Err(Status::invalid_argument( + "Invalid task ID format, expected 'namespace:task'", + )); + }; + + if namespace.is_empty() || task.is_empty() { + return Err(Status::invalid_argument( + "Invalid task ID format, expected 'namespace:task'", + )); + } + + 
Ok((namespace.to_string(), task.to_string())) +} + +pub fn parse_task_key_string(task_id: &str) -> Result { + let (namespace, task) = parse_task_key(task_id)?; + Ok(format!("{namespace}:{task}")) +} diff --git a/engine/tests/common/mod.rs b/engine/tests/common/mod.rs index 36c128c..3641f4c 100644 --- a/engine/tests/common/mod.rs +++ b/engine/tests/common/mod.rs @@ -25,6 +25,7 @@ use tokio::{ time::{sleep, timeout}, }; use tokio_stream::wrappers::TcpListenerStream; +use tokio_util::sync::CancellationToken; use tonic::{Request, transport::Server}; pub type BoxError = Box; @@ -71,6 +72,7 @@ pub struct TestBackend { pub api: Arc>, shutdown: Option>, server_task: JoinHandle>, + registration_shutdown: Option, registration_task: Option>, } @@ -79,8 +81,11 @@ impl TestBackend { if let Some(shutdown) = self.shutdown.take() { let _ = shutdown.send(()); } + if let Some(shutdown) = self.registration_shutdown.take() { + shutdown.cancel(); + } if let Some(task) = self.registration_task.take() { - task.abort(); + let _ = task.await; } let _ = self.server_task.await; } @@ -169,12 +174,6 @@ pub async fn spawn_backend( register_task(&mut api, "node", node_id); let api = Arc::new(RwLock::new(api)); - let registration_task = { - let api_guard = api.read().await; - registration_from_api(&api_guard)? - } - .map(spawn_registration); - let (shutdown_tx, shutdown_rx) = oneshot::channel(); let engine_service = BackendEngineService::new(api.clone()); let server_task = tokio::spawn(async move { @@ -186,12 +185,20 @@ pub async fn spawn_backend( .await }); + let registration_shutdown = CancellationToken::new(); + let registration_task = { + let api_guard = api.read().await; + registration_from_api(&api_guard)? 
+ } + .map(|registration| spawn_registration(registration, registration_shutdown.clone())); + wait_for_node(proxy_addr, node_id).await?; Ok(TestBackend { api, shutdown: Some(shutdown_tx), server_task, + registration_shutdown: Some(registration_shutdown), registration_task, }) } @@ -244,8 +251,20 @@ pub fn cluster_request(message: T) -> Result, BoxError> { pub async fn wait_for_node(proxy_addr: &str, node_id: &str) -> Result<(), BoxError> { timeout(Duration::from_secs(5), async move { loop { - let mut client = cluster_client(proxy_addr).await?; - let response = client.list_nodes(cluster_request(proto::Empty {})?).await?; + let mut client = match cluster_client(proxy_addr).await { + Ok(client) => client, + Err(_) => { + sleep(Duration::from_millis(50)).await; + continue; + } + }; + let response = match client.list_nodes(cluster_request(proto::Empty {})?).await { + Ok(response) => response, + Err(_) => { + sleep(Duration::from_millis(50)).await; + continue; + } + }; if response .get_ref() .nodes @@ -264,8 +283,20 @@ pub async fn wait_for_node(proxy_addr: &str, node_id: &str) -> Result<(), BoxErr pub async fn wait_for_node_count(proxy_addr: &str, expected: usize) -> Result<(), BoxError> { timeout(Duration::from_secs(5), async move { loop { - let mut client = cluster_client(proxy_addr).await?; - let response = client.list_nodes(cluster_request(proto::Empty {})?).await?; + let mut client = match cluster_client(proxy_addr).await { + Ok(client) => client, + Err(_) => { + sleep(Duration::from_millis(50)).await; + continue; + } + }; + let response = match client.list_nodes(cluster_request(proto::Empty {})?).await { + Ok(response) => response, + Err(_) => { + sleep(Duration::from_millis(50)).await; + continue; + } + }; if response.get_ref().nodes.len() == expected { return Ok::<(), BoxError>(()); } diff --git a/engine/tests/proxy_cluster.rs b/engine/tests/proxy_cluster.rs index 0a55bdc..e17d7b0 100644 --- a/engine/tests/proxy_cluster.rs +++ b/engine/tests/proxy_cluster.rs @@ 
-249,13 +249,15 @@ async fn aquire_task_reg_returns_union_of_registered_tasks() -> Result<(), BoxEr let backend_b = spawn_backend("node-b", &["gpu"], &proxy.addr).await?; let mut client = engine_client(&proxy.addr).await?; - let response = client + let mut tasks = client .aquire_task_reg(worker_request(proto::Empty {})?) .await? - .into_inner(); + .into_inner() + .tasks; + tasks.sort(); assert_eq!( - response.tasks, + tasks, vec![ "dist:work".to_string(), "ml:train".to_string(), diff --git a/enginelib/macros/src/lib.rs b/enginelib/macros/src/lib.rs index 5730298..bf2f01a 100644 --- a/enginelib/macros/src/lib.rs +++ b/enginelib/macros/src/lib.rs @@ -47,11 +47,13 @@ pub fn module(_attr: TokenStream, item: TokenStream) -> TokenStream { env!("CARGO_PKG_NAME"), );), ); + item_fn.block.stmts.insert( + 0, + parse_quote!(::enginelib::api::EngineAPI::setup_logger();), + ); item_fn - .block - .stmts - .insert(0, parse_quote!(::enginelib::api::EngineAPI::setup_logger();)); - item_fn.attrs.push(parse_quote!(#[unsafe(export_name="run")])); + .attrs + .push(parse_quote!(#[unsafe(export_name="run")])); quote!(#item_fn).into() } @@ -181,7 +183,10 @@ impl Parse for EventHandlerArgs { } let namespace = namespace.ok_or_else(|| { - syn::Error::new(proc_macro2::Span::call_site(), "missing `namespace = \"...\"`") + syn::Error::new( + proc_macro2::Span::call_site(), + "missing `namespace = \"...\"`", + ) })?; let name = name.ok_or_else(|| { syn::Error::new(proc_macro2::Span::call_site(), "missing `name = \"...\"`") diff --git a/enginelib/src/api.rs b/enginelib/src/api.rs index 368eebd..0f83ba2 100644 --- a/enginelib/src/api.rs +++ b/enginelib/src/api.rs @@ -153,20 +153,20 @@ impl EngineAPI { static INIT: OnceLock<()> = OnceLock::new(); INIT.get_or_init(|| { - #[cfg(debug_assertions)] + #[cfg(debug_assertions)] let _ = tracing_subscriber::FmtSubscriber::builder() - // all spans/events with a level higher than TRACE (e.g, debug, info, warn, etc.) - // will be written to stdout. 
- .with_max_level(Level::DEBUG) - // builds the subscriber. - .try_init(); - #[cfg(not(debug_assertions))] + // all spans/events with a level higher than TRACE (e.g, debug, info, warn, etc.) + // will be written to stdout. + .with_max_level(Level::DEBUG) + // builds the subscriber. + .try_init(); + #[cfg(not(debug_assertions))] let _ = tracing_subscriber::FmtSubscriber::builder() - // all spans/events with a level higher than TRACE (e.g, debug, info, warn, etc.) - // will be written to stdout. - .with_max_level(Level::INFO) - // builds the subscriber. - .try_init(); + // all spans/events with a level higher than TRACE (e.g, debug, info, warn, etc.) + // will be written to stdout. + .with_max_level(Level::INFO) + // builds the subscriber. + .try_init(); }); } } diff --git a/enginelib/src/config.rs b/enginelib/src/config.rs index 89156e9..b5963bc 100644 --- a/enginelib/src/config.rs +++ b/enginelib/src/config.rs @@ -1,4 +1,4 @@ -use std::{fs, io::Error, u32}; +use std::{fmt, fs, io::Error, u32}; use serde::{Deserialize, Serialize}; use tracing::{error, instrument}; @@ -15,7 +15,7 @@ fn default_pagination_limit() -> u32 { u32::MAX } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone)] pub struct ConfigTomlServer { #[serde(default)] pub cgrpc_token: Option, // Administrator Token, used to invoke cgrpc reqs. If not preset will default to no protection. 
@@ -36,6 +36,29 @@ pub struct ConfigTomlServer { #[serde(default)] pub node_tags: Option>, } + +impl fmt::Debug for ConfigTomlServer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ConfigTomlServer") + .field( + "cgrpc_token", + &self.cgrpc_token.as_ref().map(|_| ""), + ) + .field("host", &self.host) + .field("clean_tasks", &self.clean_tasks) + .field("pagination_limit", &self.pagination_limit) + .field("node_id", &self.node_id) + .field("advertise_addr", &self.advertise_addr) + .field("cluster_proxy_addr", &self.cluster_proxy_addr) + .field( + "cluster_token", + &self.cluster_token.as_ref().map(|_| ""), + ) + .field("node_tags", &self.node_tags) + .finish() + } +} + impl Default for ConfigTomlServer { fn default() -> Self { Self { diff --git a/enginelib/src/events/mod.rs b/enginelib/src/events/mod.rs index dbc9f54..acec1a9 100644 --- a/enginelib/src/events/mod.rs +++ b/enginelib/src/events/mod.rs @@ -22,7 +22,12 @@ impl Events { auth_event::AuthEvent::check(api, uid, challenge, db) } - pub fn CheckAdminAuth(api: &mut EngineAPI, payload: String, target: Identifier, db: Db) -> bool { + pub fn CheckAdminAuth( + api: &mut EngineAPI, + payload: String, + target: Identifier, + db: Db, + ) -> bool { admin_auth_event::AdminAuthEvent::check(api, payload, target, db) } diff --git a/rfc/rfc1004.md b/rfc/rfc1004.md index 632c245..0979483 100644 --- a/rfc/rfc1004.md +++ b/rfc/rfc1004.md @@ -17,7 +17,7 @@ The current engine exposes a single backend directly. That creates three constra - all workers must know about a specific backend, - there is no built-in way to spread the same `namespace:task` across multiple nodes, -- follow-up RPCs like `PublishTask` depend on whichever backend originally created the task. +- follow-up RPCs like `PublishTask` depend on whichever backend created the task. By making `Task.id` self-routing and adding in-memory cluster membership, the proxy can stay stateless while still routing task lifecycle RPCs correctly. 
From 1926381adfca94ebf79a80926bc0a74ca1d9c91d Mon Sep 17 00:00:00 2001 From: IGN-Styly Date: Thu, 12 Mar 2026 11:10:32 +0000 Subject: [PATCH 3/3] Inline backend gRPC service into server binary - Move `BackendEngineService` implementation from `service/backend.rs` into `bin/server.rs` - Switch task creation to `task_id::mint_task_instance_id` and keep task key parsing centralized - Update service module/proxy/tests to match the new backend service layout --- engine/src/bin/server.rs | 537 +++++++++++++++++++++++++++++++-- engine/src/service/backend.rs | 552 ---------------------------------- engine/src/service/mod.rs | 1 - engine/src/service/proxy.rs | 3 +- engine/src/task_id.rs | 16 + engine/tests/common/mod.rs | 10 +- 6 files changed, 539 insertions(+), 580 deletions(-) delete mode 100644 engine/src/service/backend.rs diff --git a/engine/src/bin/server.rs b/engine/src/bin/server.rs index bba46b2..622a2ef 100644 --- a/engine/src/bin/server.rs +++ b/engine/src/bin/server.rs @@ -1,17 +1,524 @@ use engine::{ cluster_client::{registration_from_api, spawn_registration}, - proto, - service::backend::BackendEngineService, + get_auth, get_uid, proto, + task_id::{mint_task_instance_id, parse_task_key}, +}; +use enginelib::api::postcard; +use enginelib::{ + RawIdentier, Registry, + api::EngineAPI, + chrono::Utc, + events::{Events, ID}, + task::{StoredExecutingTask, StoredTask}, }; -use enginelib::{api::EngineAPI, events::Events}; use std::{ - net::{Ipv4Addr, SocketAddr, SocketAddrV4}, - sync::Arc, + collections::HashMap, + sync::{Arc, RwLock as StdRwLock}, }; use tokio::{net::TcpListener, sync::RwLock}; use tokio_stream::wrappers::TcpListenerStream; use tokio_util::sync::CancellationToken; -use tonic::transport::Server; +use tonic::{Response, Status, transport::Server}; +use tracing::{debug, info, warn}; + +#[allow(non_snake_case)] +pub struct BackendEngineService { + pub EngineAPI: Arc>, +} + +impl BackendEngineService { + pub fn new(api: Arc>) -> Self { + Self { EngineAPI: 
api } + } +} + +fn delete_task_from_collection( + collection: &mut HashMap<(String, String), Vec>, + id: &(String, String), + task_id: &str, + state_name: &str, + namespace: &str, + task: &str, + id_extractor: F, +) -> Result<(), Status> +where + F: Fn(&T) -> &str, +{ + match collection.get_mut(id) { + Some(query) => { + let orig_len = query.len(); + query.retain(|f| id_extractor(f) != task_id); + if query.len() == orig_len { + info!( + "DeleteTask: Task with id {} not found in {} state for namespace: {}, task: {}", + task_id, state_name, namespace, task + ); + return Err(Status::not_found(format!( + "Task with id {} not found in {} state", + task_id, state_name + ))); + } + Ok(()) + } + None => { + info!( + "DeleteTask: No tasks found in {} state for namespace: {}, task: {}", + state_name, namespace, task + ); + Err(Status::not_found(format!( + "No tasks found in {} state for given namespace and task", + state_name + ))) + } + } +} + +#[tonic::async_trait] +impl proto::engine_server::Engine for BackendEngineService { + async fn check_auth( + &self, + request: tonic::Request, + ) -> Result, Status> { + let challenge = get_auth(&request); + let mut api = self.EngineAPI.write().await; + let db = api.db.clone(); + let output = Events::CheckAdminAuth(&mut api, challenge, ("".into(), "".into()), db); + if !output { + warn!("Auth check failed - permission denied"); + return Err(Status::permission_denied("Invalid Auth")); + } + Ok(Response::new(proto::Empty {})) + } + + async fn delete_task( + &self, + request: tonic::Request, + ) -> Result, Status> { + let mut api = self.EngineAPI.write().await; + let data = request.get_ref(); + let challenge = get_auth(&request); + let db = api.db.clone(); + let id = ID(&data.namespace, &data.task); + + let output = Events::CheckAdminAuth(&mut api, challenge, ("".into(), "".into()), db); + if !output { + warn!("Auth check failed - permission denied"); + return Err(Status::permission_denied("Invalid Auth")); + } + + let result = match 
data.state() { + proto::TaskState::Processing => delete_task_from_collection( + &mut api.executing_tasks.tasks, + &id, + &data.id, + "Processing", + &data.namespace, + &data.task, + |f| &f.id, + ), + proto::TaskState::Solved => delete_task_from_collection( + &mut api.solved_tasks.tasks, + &id, + &data.id, + "Solved", + &data.namespace, + &data.task, + |f| &f.id, + ), + proto::TaskState::Queued => delete_task_from_collection( + &mut api.task_queue.tasks, + &id, + &data.id, + "Queued", + &data.namespace, + &data.task, + |f| &f.id, + ), + }; + + result?; + EngineAPI::sync_db(&mut api); + info!( + "DeleteTask: Successfully deleted task with id {} in state {:?} for namespace: {}, task: {}", + data.id, + data.state(), + data.namespace, + data.task + ); + Ok(Response::new(proto::Empty {})) + } + + async fn get_tasks( + &self, + request: tonic::Request, + ) -> Result, Status> { + let mut api = self.EngineAPI.write().await; + let challenge = get_auth(&request); + let db = api.db.clone(); + if !Events::CheckAdminAuth(&mut api, challenge, ("".into(), "".into()), db) { + info!("GetTask denied due to Invalid Auth"); + return Err(Status::permission_denied("Invalid authentication")); + } + + let data = request.get_ref(); + let q: Vec = match data.state() { + proto::TaskState::Processing => match api + .executing_tasks + .tasks + .get(&(data.namespace.clone(), data.task.clone())) + { + Some(tasks) => { + let mut task_refs: Vec<_> = tasks.iter().collect(); + task_refs.sort_by_key(|f| &f.id); + task_refs + .iter() + .map(|f| proto::Task { + id: f.id.clone(), + task_id: format!("{}:{}", data.namespace, data.task), + task_payload: f.bytes.clone(), + payload: Vec::new(), + }) + .collect() + } + None => Vec::new(), + }, + proto::TaskState::Queued => match api + .task_queue + .tasks + .get(&(data.namespace.clone(), data.task.clone())) + { + Some(tasks) => { + let mut tasks = tasks.clone(); + tasks.sort_by_key(|f| f.id.clone()); + tasks + .iter() + .map(|f| proto::Task { + id: 
f.id.clone(), + task_id: format!("{}:{}", data.namespace, data.task), + task_payload: f.bytes.clone(), + payload: Vec::new(), + }) + .collect() + } + None => Vec::new(), + }, + proto::TaskState::Solved => match api + .solved_tasks + .tasks + .get(&(data.namespace.clone(), data.task.clone())) + { + Some(tasks) => { + let mut tasks = tasks.clone(); + tasks.sort_by_key(|f| f.id.clone()); + tasks + .iter() + .map(|f| proto::Task { + id: f.id.clone(), + task_id: format!("{}:{}", data.namespace, data.task), + task_payload: f.bytes.clone(), + payload: Vec::new(), + }) + .collect() + } + None => Vec::new(), + }, + }; + + let index = data.page.saturating_mul(data.page_size as u64) as usize; + let limit = api.cfg.config_toml.pagination_limit.min(data.page_size) as usize; + let final_vec: Vec<_> = q.iter().skip(index).take(limit).cloned().collect(); + Ok(Response::new(proto::TaskPage { + namespace: data.namespace.clone(), + task: data.task.clone(), + page: data.page, + page_size: data.page_size, + state: data.state, + tasks: final_vec, + })) + } + + async fn cgrpc( + &self, + request: tonic::Request, + ) -> Result, Status> { + info!( + "CGRPC request received for handler: {}:{}", + request.get_ref().handler_mod_id, + request.get_ref().handler_id + ); + let mut api = self.EngineAPI.write().await; + let challenge = get_auth(&request); + let db = api.db.clone(); + debug!("Checking admin authentication for CGRPC request"); + let output = Events::CheckAdminAuth( + &mut api, + challenge, + ( + request.get_ref().handler_mod_id.clone(), + request.get_ref().handler_id.clone(), + ), + db, + ); + if !output { + warn!("CGRPC auth check failed - permission denied"); + return Err(Status::permission_denied("Invalid CGRPC Auth")); + } + let out = Arc::new(StdRwLock::new(Vec::new())); + debug!("Dispatching CGRPC event to handler"); + Events::CgrpcEvent( + &mut api, + ID("engine_core", "grpc"), + request.get_ref().event_payload.clone(), + out.clone(), + ); + let mut res = 
request.get_ref().clone(); + res.event_payload = match out.read() { + Ok(g) => g.clone(), + Err(_) => { + warn!("CGRPC response lock poisoned, returning empty payload"); + Vec::new() + } + }; + info!("CGRPC request processed successfully"); + Ok(Response::new(res)) + } + + async fn aquire_task_reg( + &self, + request: tonic::Request, + ) -> Result, Status> { + let uid = get_uid(&request); + let challenge = get_auth(&request); + info!("Task registry request received from user: {}", uid); + let mut api = self.EngineAPI.write().await; + let db = api.db.clone(); + + debug!("Validating authentication for task registry request"); + if !Events::CheckAuth(&mut api, uid.clone(), challenge, db) { + info!( + "Task registry request denied - invalid authentication for user: {}", + uid + ); + return Err(Status::permission_denied("Invalid authentication")); + } + let mut tasks: Vec = api + .task_registry + .tasks + .keys() + .map(|(namespace, task)| format!("{namespace}:{task}")) + .collect(); + tasks.sort(); + info!("Returning task registry with {} tasks", tasks.len()); + Ok(Response::new(proto::TaskRegistry { tasks })) + } + + async fn aquire_task( + &self, + request: tonic::Request, + ) -> Result, Status> { + let challenge = get_auth(&request); + let input = request.get_ref(); + let task_id = input.task_id.clone(); + let uid = get_uid(&request); + info!( + "Task acquisition request received from user: {} for task: {}", + uid, task_id + ); + + let mut api = self.EngineAPI.write().await; + let db = api.db.clone(); + debug!("Validating authentication for task acquisition"); + if !Events::CheckAuth(&mut api, uid.clone(), challenge, db) { + info!( + "Task acquisition denied - invalid authentication for user: {}", + uid + ); + return Err(Status::permission_denied("Invalid authentication")); + } + + let key = parse_task_key(&task_id)?; + if api.task_registry.get(&key).is_none() { + warn!( + "Task acquisition failed - task does not exist: {}:{}", + key.0, key.1 + ); + return 
Err(Status::invalid_argument("Task Does not Exist")); + } + + let mut map = match api.task_queue.tasks.get(&key) { + Some(v) if !v.is_empty() => v.clone(), + _ => { + info!("No queued tasks for {}:{}", key.0, key.1); + return Err(Status::not_found("No queued tasks available")); + } + }; + + let ttask = map.remove(0); + let task_payload = ttask.bytes.clone(); + api.task_queue.tasks.insert(key.clone(), map); + match postcard::to_allocvec(&api.task_queue.clone()) { + Ok(store) => { + if let Err(e) = api.db.insert("tasks", store) { + return Err(Status::internal(format!("DB insert error: {}", e))); + } + } + Err(e) => { + return Err(Status::internal(format!("Serialization error: {}", e))); + } + } + + let mut exec_tsks = api.executing_tasks.tasks.get(&key).cloned().unwrap_or_default(); + exec_tsks.push(StoredExecutingTask { + bytes: task_payload.clone(), + user_id: uid, + given_at: Utc::now(), + id: ttask.id.clone(), + }); + api.executing_tasks.tasks.insert(key.clone(), exec_tsks); + match postcard::to_allocvec(&api.executing_tasks.clone()) { + Ok(store) => { + if let Err(e) = api.db.insert("executing_tasks", store) { + return Err(Status::internal(format!("DB insert error: {}", e))); + } + } + Err(e) => { + return Err(Status::internal(format!("Serialization error: {}", e))); + } + } + + Ok(Response::new(proto::Task { + id: ttask.id, + task_id, + task_payload, + payload: Vec::new(), + })) + } + + async fn publish_task( + &self, + request: tonic::Request, + ) -> Result, Status> { + let mut api = self.EngineAPI.write().await; + let challenge = get_auth(&request); + let uid = get_uid(&request); + let db = api.db.clone(); + + let task = request.get_ref(); + let key = parse_task_key(&task.task_id)?; + let instance_id = task.id.clone(); + + if !Events::CheckAuth(&mut api, uid.clone(), challenge, db) { + info!("Publish Task denied due to Invalid Auth"); + return Err(Status::permission_denied("Invalid authentication")); + } + if !api.task_registry.tasks.contains_key(&key) { + 
warn!( + "Task acquisition failed - task does not exist: {}:{}", + key.0, key.1 + ); + return Err(Status::invalid_argument("Task Does not Exist")); + } + + let mem_tsk = api.executing_tasks.tasks.get(&key).cloned().unwrap_or_default(); + let executing_task = mem_tsk + .iter() + .find(|f| f.id == instance_id && f.user_id == uid) + .cloned(); + let Some(executing_task) = executing_task else { + return Err(Status::not_found("Invalid taskid or userid")); + }; + + let reg_tsk = match api.task_registry.get(&key) { + Some(r) => r, + None => { + warn!("Task registry missing for {}:{}", key.0, key.1); + return Err(Status::invalid_argument("Task Does not Exist")); + } + }; + if !reg_tsk.verify(task.task_payload.clone()) { + info!("Failed to parse task"); + return Err(Status::invalid_argument("Failed to parse given task bytes")); + } + + let mut nmem_tsk = mem_tsk.clone(); + nmem_tsk.retain(|f| !(f.id == instance_id && f.user_id == uid)); + api.executing_tasks.tasks.insert(key.clone(), nmem_tsk); + match postcard::to_allocvec(&api.executing_tasks.clone()) { + Ok(store) => { + if let Err(e) = api.db.insert("executing_tasks", store) { + return Err(Status::internal(format!("DB insert error: {}", e))); + } + } + Err(e) => return Err(Status::internal(format!("Serialization error: {}", e))), + } + + let mut mem_solv = api.solved_tasks.tasks.get(&key).cloned().unwrap_or_default(); + mem_solv.push(StoredTask { + bytes: task.task_payload.clone(), + id: executing_task.id, + }); + api.solved_tasks.tasks.insert(key.clone(), mem_solv); + match postcard::to_allocvec(&api.solved_tasks.clone()) { + Ok(store) => { + if let Err(e) = api.db.insert("solved_tasks", store) { + return Err(Status::internal(format!("DB insert error: {}", e))); + } + } + Err(e) => return Err(Status::internal(format!("Serialization error: {}", e))), + } + + info!("Task published successfully: {} by user: {}", task.id, uid); + Ok(Response::new(proto::Empty {})) + } + + async fn create_task( + &self, + request: 
tonic::Request, + ) -> Result, Status> { + let mut api = self.EngineAPI.write().await; + let challenge = get_auth(&request); + let uid = get_uid(&request); + let db = api.db.clone(); + if !Events::CheckAuth(&mut api, uid, challenge, db) { + info!("Create Task denied due to Invalid Auth"); + return Err(Status::permission_denied("Invalid authentication")); + } + let task = request.get_ref(); + let task_id = task.task_id.clone(); + let id = parse_task_key(&task_id)?; + let tsk_reg = api.task_registry.get(&id); + if let Some(tsk_reg) = tsk_reg { + if !tsk_reg.verify(task.task_payload.clone()) { + warn!("Failed to parse given task bytes"); + return Err(Status::invalid_argument("Failed to parse given task bytes")); + } + let stored_task = StoredTask { + bytes: task.task_payload.clone(), + id: mint_task_instance_id(api.cfg.config_toml.node_id.as_deref()), + }; + let mut mem_task_queue = api.task_queue.clone(); + let mut mem_tasks = mem_task_queue.tasks.get(&id).cloned().unwrap_or_default(); + mem_tasks.push(stored_task.clone()); + mem_task_queue.tasks.insert(id.clone(), mem_tasks); + api.task_queue = mem_task_queue; + match postcard::to_allocvec(&api.task_queue.clone()) { + Ok(store) => { + if let Err(e) = api.db.insert("tasks", store) { + return Err(Status::internal(format!("DB insert error: {}", e))); + } + } + Err(e) => return Err(Status::internal(format!("Serialization error: {}", e))), + } + return Ok(Response::new(proto::Task { + id: stored_task.id.clone(), + task_id, + payload: Vec::new(), + task_payload: stored_task.bytes.clone(), + })); + } + Err(Status::aborted("Error")) + } +} #[tokio::main] async fn main() -> Result<(), Box> { @@ -20,25 +527,15 @@ async fn main() -> Result<(), Box> { Events::init_auth(&mut api); Events::StartEvent(&mut api); - let addr = api - .cfg - .config_toml - .host - .parse() - .unwrap_or(SocketAddr::V4(SocketAddrV4::new( - Ipv4Addr::new(127, 0, 0, 1), - 50051, - ))); - let listener = TcpListener::bind(addr).await?; - + let listener = 
TcpListener::bind(api.cfg.config_toml.host.as_str()).await?; let api = Arc::new(RwLock::new(api)); let registration = { let api_guard = api.read().await; registration_from_api(&api_guard)? }; - let _registration_shutdown = CancellationToken::new(); - let _registration_task = registration - .map(|registration| spawn_registration(registration, _registration_shutdown.clone())); + let registration_shutdown = CancellationToken::new(); + let _registration_task = + registration.map(|registration| spawn_registration(registration, registration_shutdown)); EngineAPI::init_chron(api.clone()); let engine = BackendEngineService::new(api); diff --git a/engine/src/service/backend.rs b/engine/src/service/backend.rs deleted file mode 100644 index 1325169..0000000 --- a/engine/src/service/backend.rs +++ /dev/null @@ -1,552 +0,0 @@ -use std::{ - collections::HashMap, - sync::{Arc, RwLock as StdRwLock}, -}; - -use enginelib::api::postcard; -use enginelib::{ - RawIdentier, Registry, - api::EngineAPI, - chrono::Utc, - events::{Events, ID}, - task::{StoredExecutingTask, StoredTask}, -}; -use tokio::sync::RwLock; -use tonic::{Response, Status}; -use tracing::{debug, info, warn}; - -use crate::{ - get_auth, get_uid, - proto::{self, TaskState, engine_server::Engine}, - task_id::parse_task_key, -}; - -#[allow(non_snake_case)] -pub struct BackendEngineService { - pub EngineAPI: Arc>, -} - -impl BackendEngineService { - pub fn new(api: Arc>) -> Self { - Self { EngineAPI: api } - } -} - -pub fn mint_task_instance_id(node_id: Option<&str>) -> String { - let random_hex = druid::Druid::default().to_hex(); - match node_id.filter(|value| !value.is_empty()) { - Some(node_id) => format!("{node_id}@{random_hex}"), - None => random_hex, - } -} - -pub fn parse_owner_node(task_instance_id: &str) -> Option<(&str, &str)> { - let (node_id, local_id) = task_instance_id.split_once('@')?; - if node_id.is_empty() || local_id.is_empty() || local_id.contains('@') { - return None; - } - Some((node_id, local_id)) -} 
- -fn delete_task_from_collection( - collection: &mut HashMap<(String, String), Vec>, - id: &(String, String), - task_id: &str, - state_name: &str, - namespace: &str, - task: &str, - id_extractor: F, -) -> Result<(), Status> -where - F: Fn(&T) -> &str, -{ - match collection.get_mut(id) { - Some(query) => { - let orig_len = query.len(); - query.retain(|f| id_extractor(f) != task_id); - if query.len() == orig_len { - info!( - "DeleteTask: Task with id {} not found in {} state for namespace: {}, task: {}", - task_id, state_name, namespace, task - ); - return Err(Status::not_found(format!( - "Task with id {} not found in {} state", - task_id, state_name - ))); - } - Ok(()) - } - None => { - info!( - "DeleteTask: No tasks found in {} state for namespace: {}, task: {}", - state_name, namespace, task - ); - Err(Status::not_found(format!( - "No tasks found in {} state for given namespace and task", - state_name - ))) - } - } -} - -#[tonic::async_trait] -impl Engine for BackendEngineService { - async fn check_auth( - &self, - request: tonic::Request, - ) -> Result, Status> { - let challenge = get_auth(&request); - let mut api = self.EngineAPI.write().await; - let db = api.db.clone(); - let output = Events::CheckAdminAuth(&mut api, challenge, ("".into(), "".into()), db); - if !output { - warn!("Auth check failed - permission denied"); - return Err(Status::permission_denied("Invalid Auth")); - } - Ok(Response::new(proto::Empty {})) - } - - async fn delete_task( - &self, - request: tonic::Request, - ) -> Result, Status> { - let mut api = self.EngineAPI.write().await; - let data = request.get_ref(); - let challenge = get_auth(&request); - let db = api.db.clone(); - let id = ID(&data.namespace, &data.task); - - let output = Events::CheckAdminAuth(&mut api, challenge, ("".into(), "".into()), db); - if !output { - warn!("Auth check failed - permission denied"); - return Err(Status::permission_denied("Invalid Auth")); - } - - let result = match data.state() { - 
TaskState::Processing => delete_task_from_collection( - &mut api.executing_tasks.tasks, - &id, - &data.id, - "Processing", - &data.namespace, - &data.task, - |f| &f.id, - ), - TaskState::Solved => delete_task_from_collection( - &mut api.solved_tasks.tasks, - &id, - &data.id, - "Solved", - &data.namespace, - &data.task, - |f| &f.id, - ), - TaskState::Queued => delete_task_from_collection( - &mut api.task_queue.tasks, - &id, - &data.id, - "Queued", - &data.namespace, - &data.task, - |f| &f.id, - ), - }; - - result?; - EngineAPI::sync_db(&mut api); - info!( - "DeleteTask: Successfully deleted task with id {} in state {:?} for namespace: {}, task: {}", - data.id, - data.state(), - data.namespace, - data.task - ); - Ok(Response::new(proto::Empty {})) - } - - async fn get_tasks( - &self, - request: tonic::Request, - ) -> Result, Status> { - let mut api = self.EngineAPI.write().await; - let challenge = get_auth(&request); - let db = api.db.clone(); - if !Events::CheckAdminAuth(&mut api, challenge, ("".into(), "".into()), db) { - info!("GetTask denied due to Invalid Auth"); - return Err(Status::permission_denied("Invalid authentication")); - } - - let data = request.get_ref(); - let q: Vec = match data.state() { - TaskState::Processing => match api - .executing_tasks - .tasks - .get(&(data.namespace.clone(), data.task.clone())) - { - Some(tasks) => { - let mut task_refs: Vec<_> = tasks.iter().collect(); - task_refs.sort_by_key(|f| &f.id); - task_refs - .iter() - .map(|f| proto::Task { - id: f.id.clone(), - task_id: format!("{}:{}", data.namespace, data.task), - task_payload: f.bytes.clone(), - payload: Vec::new(), - }) - .collect() - } - None => Vec::new(), - }, - TaskState::Queued => match api - .task_queue - .tasks - .get(&(data.namespace.clone(), data.task.clone())) - { - Some(tasks) => { - let mut tasks = tasks.clone(); - tasks.sort_by_key(|f| f.id.clone()); - tasks - .iter() - .map(|f| proto::Task { - id: f.id.clone(), - task_id: format!("{}:{}", data.namespace, 
data.task), - task_payload: f.bytes.clone(), - payload: Vec::new(), - }) - .collect() - } - None => Vec::new(), - }, - TaskState::Solved => match api - .solved_tasks - .tasks - .get(&(data.namespace.clone(), data.task.clone())) - { - Some(tasks) => { - let mut tasks = tasks.clone(); - tasks.sort_by_key(|f| f.id.clone()); - tasks - .iter() - .map(|f| proto::Task { - id: f.id.clone(), - task_id: format!("{}:{}", data.namespace, data.task), - task_payload: f.bytes.clone(), - payload: Vec::new(), - }) - .collect() - } - None => Vec::new(), - }, - }; - - let index = data.page.saturating_mul(data.page_size as u64) as usize; - let limit = api.cfg.config_toml.pagination_limit.min(data.page_size) as usize; - let final_vec: Vec<_> = q.iter().skip(index).take(limit).cloned().collect(); - Ok(Response::new(proto::TaskPage { - namespace: data.namespace.clone(), - task: data.task.clone(), - page: data.page, - page_size: data.page_size, - state: data.state, - tasks: final_vec, - })) - } - - async fn cgrpc( - &self, - request: tonic::Request, - ) -> Result, Status> { - info!( - "CGRPC request received for handler: {}:{}", - request.get_ref().handler_mod_id, - request.get_ref().handler_id - ); - let mut api = self.EngineAPI.write().await; - let challenge = get_auth(&request); - let db = api.db.clone(); - debug!("Checking admin authentication for CGRPC request"); - let output = Events::CheckAdminAuth( - &mut api, - challenge, - ( - request.get_ref().handler_mod_id.clone(), - request.get_ref().handler_id.clone(), - ), - db, - ); - if !output { - warn!("CGRPC auth check failed - permission denied"); - return Err(Status::permission_denied("Invalid CGRPC Auth")); - } - let out = Arc::new(StdRwLock::new(Vec::new())); - debug!("Dispatching CGRPC event to handler"); - Events::CgrpcEvent( - &mut api, - ID("engine_core", "grpc"), - request.get_ref().event_payload.clone(), - out.clone(), - ); - let mut res = request.get_ref().clone(); - res.event_payload = match out.read() { - Ok(g) => 
g.clone(), - Err(_) => { - warn!("CGRPC response lock poisoned, returning empty payload"); - Vec::new() - } - }; - info!("CGRPC request processed successfully"); - Ok(Response::new(res)) - } - - async fn aquire_task_reg( - &self, - request: tonic::Request, - ) -> Result, Status> { - let uid = get_uid(&request); - let challenge = get_auth(&request); - info!("Task registry request received from user: {}", uid); - let mut api = self.EngineAPI.write().await; - let db = api.db.clone(); - - debug!("Validating authentication for task registry request"); - if !Events::CheckAuth(&mut api, uid.clone(), challenge, db) { - info!( - "Task registry request denied - invalid authentication for user: {}", - uid - ); - return Err(Status::permission_denied("Invalid authentication")); - } - let mut tasks: Vec = api - .task_registry - .tasks - .keys() - .map(|(namespace, task)| format!("{namespace}:{task}")) - .collect(); - tasks.sort(); - info!("Returning task registry with {} tasks", tasks.len()); - Ok(Response::new(proto::TaskRegistry { tasks })) - } - - async fn aquire_task( - &self, - request: tonic::Request, - ) -> Result, Status> { - let challenge = get_auth(&request); - let input = request.get_ref(); - let task_id = input.task_id.clone(); - let uid = get_uid(&request); - info!( - "Task acquisition request received from user: {} for task: {}", - uid, task_id - ); - - let mut api = self.EngineAPI.write().await; - let db = api.db.clone(); - debug!("Validating authentication for task acquisition"); - if !Events::CheckAuth(&mut api, uid.clone(), challenge, db) { - info!( - "Task acquisition denied - invalid authentication for user: {}", - uid - ); - return Err(Status::permission_denied("Invalid authentication")); - } - - let key = parse_task_key(&task_id)?; - if api.task_registry.get(&key).is_none() { - warn!( - "Task acquisition failed - task does not exist: {}:{}", - key.0, key.1 - ); - return Err(Status::invalid_argument("Task Does not Exist")); - } - - let mut map = match 
api.task_queue.tasks.get(&key) { - Some(v) if !v.is_empty() => v.clone(), - _ => { - info!("No queued tasks for {}:{}", key.0, key.1); - return Err(Status::not_found("No queued tasks available")); - } - }; - - let ttask = map.remove(0); - let task_payload = ttask.bytes.clone(); - api.task_queue.tasks.insert(key.clone(), map); - match postcard::to_allocvec(&api.task_queue.clone()) { - Ok(store) => { - if let Err(e) = api.db.insert("tasks", store) { - return Err(Status::internal(format!("DB insert error: {}", e))); - } - } - Err(e) => { - return Err(Status::internal(format!("Serialization error: {}", e))); - } - } - - let mut exec_tsks = api - .executing_tasks - .tasks - .get(&key) - .cloned() - .unwrap_or_default(); - exec_tsks.push(StoredExecutingTask { - bytes: task_payload.clone(), - user_id: uid, - given_at: Utc::now(), - id: ttask.id.clone(), - }); - api.executing_tasks.tasks.insert(key.clone(), exec_tsks); - match postcard::to_allocvec(&api.executing_tasks.clone()) { - Ok(store) => { - if let Err(e) = api.db.insert("executing_tasks", store) { - return Err(Status::internal(format!("DB insert error: {}", e))); - } - } - Err(e) => { - return Err(Status::internal(format!("Serialization error: {}", e))); - } - } - - Ok(Response::new(proto::Task { - id: ttask.id, - task_id, - task_payload, - payload: Vec::new(), - })) - } - - async fn publish_task( - &self, - request: tonic::Request, - ) -> Result, Status> { - let mut api = self.EngineAPI.write().await; - let challenge = get_auth(&request); - let uid = get_uid(&request); - let db = api.db.clone(); - - let task = request.get_ref(); - let key = parse_task_key(&task.task_id)?; - let instance_id = task.id.clone(); - - if !Events::CheckAuth(&mut api, uid.clone(), challenge, db) { - info!("Publish Task denied due to Invalid Auth"); - return Err(Status::permission_denied("Invalid authentication")); - } - if !api.task_registry.tasks.contains_key(&key) { - warn!( - "Task acquisition failed - task does not exist: {}:{}", - 
key.0, key.1 - ); - return Err(Status::invalid_argument("Task Does not Exist")); - } - - let mem_tsk = api - .executing_tasks - .tasks - .get(&key) - .cloned() - .unwrap_or_default(); - let executing_task = mem_tsk - .iter() - .find(|f| f.id == instance_id && f.user_id == uid) - .cloned(); - let Some(executing_task) = executing_task else { - return Err(Status::not_found("Invalid taskid or userid")); - }; - - let reg_tsk = match api.task_registry.get(&key) { - Some(r) => r, - None => { - warn!("Task registry missing for {}:{}", key.0, key.1); - return Err(Status::invalid_argument("Task Does not Exist")); - } - }; - if !reg_tsk.verify(task.task_payload.clone()) { - info!("Failed to parse task"); - return Err(Status::invalid_argument("Failed to parse given task bytes")); - } - - let mut nmem_tsk = mem_tsk.clone(); - nmem_tsk.retain(|f| !(f.id == instance_id && f.user_id == uid)); - api.executing_tasks.tasks.insert(key.clone(), nmem_tsk); - match postcard::to_allocvec(&api.executing_tasks.clone()) { - Ok(store) => { - if let Err(e) = api.db.insert("executing_tasks", store) { - return Err(Status::internal(format!("DB insert error: {}", e))); - } - } - Err(e) => return Err(Status::internal(format!("Serialization error: {}", e))), - } - - let mut mem_solv = api - .solved_tasks - .tasks - .get(&key) - .cloned() - .unwrap_or_default(); - mem_solv.push(StoredTask { - bytes: task.task_payload.clone(), - id: executing_task.id, - }); - api.solved_tasks.tasks.insert(key.clone(), mem_solv); - match postcard::to_allocvec(&api.solved_tasks.clone()) { - Ok(store) => { - if let Err(e) = api.db.insert("solved_tasks", store) { - return Err(Status::internal(format!("DB insert error: {}", e))); - } - } - Err(e) => return Err(Status::internal(format!("Serialization error: {}", e))), - } - - info!("Task published successfully: {} by user: {}", task.id, uid); - Ok(Response::new(proto::Empty {})) - } - - async fn create_task( - &self, - request: tonic::Request, - ) -> Result, Status> { - let 
mut api = self.EngineAPI.write().await; - let challenge = get_auth(&request); - let uid = get_uid(&request); - let db = api.db.clone(); - if !Events::CheckAuth(&mut api, uid, challenge, db) { - info!("Create Task denied due to Invalid Auth"); - return Err(Status::permission_denied("Invalid authentication")); - } - let task = request.get_ref(); - let task_id = task.task_id.clone(); - let id = parse_task_key(&task_id)?; - let tsk_reg = api.task_registry.get(&id); - if let Some(tsk_reg) = tsk_reg { - if !tsk_reg.verify(task.task_payload.clone()) { - warn!("Failed to parse given task bytes"); - return Err(Status::invalid_argument("Failed to parse given task bytes")); - } - let stored_task = StoredTask { - bytes: task.task_payload.clone(), - id: mint_task_instance_id(api.cfg.config_toml.node_id.as_deref()), - }; - let mut mem_task_queue = api.task_queue.clone(); - let mut mem_tasks = mem_task_queue.tasks.get(&id).cloned().unwrap_or_default(); - mem_tasks.push(stored_task.clone()); - mem_task_queue.tasks.insert(id.clone(), mem_tasks); - api.task_queue = mem_task_queue; - match postcard::to_allocvec(&api.task_queue.clone()) { - Ok(store) => { - if let Err(e) = api.db.insert("tasks", store) { - return Err(Status::internal(format!("DB insert error: {}", e))); - } - } - Err(e) => return Err(Status::internal(format!("Serialization error: {}", e))), - } - return Ok(Response::new(proto::Task { - id: stored_task.id.clone(), - task_id, - payload: Vec::new(), - task_payload: stored_task.bytes.clone(), - })); - } - Err(Status::aborted("Error")) - } -} diff --git a/engine/src/service/mod.rs b/engine/src/service/mod.rs index cfea61a..44dcc92 100644 --- a/engine/src/service/mod.rs +++ b/engine/src/service/mod.rs @@ -1,2 +1 @@ -pub mod backend; pub mod proxy; diff --git a/engine/src/service/proxy.rs b/engine/src/service/proxy.rs index 5c9f21f..9ec0705 100644 --- a/engine/src/service/proxy.rs +++ b/engine/src/service/proxy.rs @@ -8,8 +8,7 @@ use crate::{ copy_metadata, get_auth, 
proto::{self, cluster_server::Cluster, engine_client::EngineClient, engine_server::Engine}, routing::{NodeState, ProxyState, now_unix}, - service::backend::parse_owner_node, - task_id::parse_task_key_string, + task_id::{parse_owner_node, parse_task_key_string}, }; use rand::{seq::SliceRandom, thread_rng}; use tokio::{task::JoinHandle, time::sleep}; diff --git a/engine/src/task_id.rs b/engine/src/task_id.rs index 3ec09eb..d43d3b0 100644 --- a/engine/src/task_id.rs +++ b/engine/src/task_id.rs @@ -1,6 +1,22 @@ use enginelib::Identifier; use tonic::Status; +pub fn mint_task_instance_id(node_id: Option<&str>) -> String { + let random_hex = druid::Druid::default().to_hex(); + match node_id.filter(|value| !value.is_empty()) { + Some(node_id) => format!("{node_id}@{random_hex}"), + None => random_hex, + } +} + +pub fn parse_owner_node(task_instance_id: &str) -> Option<(&str, &str)> { + let (node_id, local_id) = task_instance_id.split_once('@')?; + if node_id.is_empty() || local_id.is_empty() || local_id.contains('@') { + return None; + } + Some((node_id, local_id)) +} + pub fn parse_task_key(task_id: &str) -> Result { let Some((namespace, task)) = task_id.split_once(':') else { return Err(Status::invalid_argument( diff --git a/engine/tests/common/mod.rs b/engine/tests/common/mod.rs index 3641f4c..2307d9a 100644 --- a/engine/tests/common/mod.rs +++ b/engine/tests/common/mod.rs @@ -1,14 +1,14 @@ use std::{error::Error, sync::Arc, time::Duration}; +#[path = "../../src/bin/server.rs"] +mod server_bin; + use engine::{ cluster_client::{registration_from_api, spawn_registration}, proto::{self, cluster_client::ClusterClient, engine_client::EngineClient}, proxy_config::{ProxyConfigToml, RouteRuleToml}, routing::ProxyState, - service::{ - backend::BackendEngineService, - proxy::{ProxyService, spawn_reaper}, - }, + service::proxy::{ProxyService, spawn_reaper}, }; use enginelib::{ Registry, @@ -175,7 +175,7 @@ pub async fn spawn_backend( let api = Arc::new(RwLock::new(api)); let 
(shutdown_tx, shutdown_rx) = oneshot::channel(); - let engine_service = BackendEngineService::new(api.clone()); + let engine_service = server_bin::BackendEngineService::new(api.clone()); let server_task = tokio::spawn(async move { Server::builder() .add_service(proto::engine_server::EngineServer::new(engine_service))