diff --git a/Cargo.lock b/Cargo.lock index 0add590..312984f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -543,8 +543,8 @@ version = "0.1.0" source = "git+https://github.com/GrandEngineering/druid.git#801365f097b6b437314d0deaa72fd85be6af0a8f" dependencies = [ "chrono", - "rand", - "rand_chacha", + "rand 0.9.2", + "rand_chacha 0.9.0", "uuid", ] @@ -591,14 +591,18 @@ dependencies = [ "druid", "enginelib", "prost", + "rand 0.8.5", "serde", "tokio", + "tokio-stream", + "tokio-util", "toml", "tonic", "tonic-build", "tonic-prost", "tonic-prost-build", "tonic-reflection", + "tracing", ] [[package]] @@ -2307,14 +2311,35 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + [[package]] name = "rand" version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ - "rand_chacha", - "rand_core", + "rand_chacha 0.9.0", + "rand_core 0.9.3", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", ] [[package]] @@ -2324,7 +2349,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.9.3", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.16", ] [[package]] @@ -2799,9 +2833,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.16" +version = "0.7.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14307c986784f72ef81c89db7d9e28d6ac26d16213b109ea501696195e6e3ce5" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" dependencies = [ "bytes", "futures-core", diff --git a/engine/Cargo.toml b/engine/Cargo.toml index 9e82df7..3b8f47b 100644 --- a/engine/Cargo.toml +++ b/engine/Cargo.toml @@ -17,13 +17,17 @@ druid = { git = "https://github.com/GrandEngineering/druid.git" } enginelib = { path = "../enginelib" } # libloading = "0.8.6" prost = "0.14" +rand = "0.8.5" serde = { workspace = true } # serde = "1.0.219" tokio = { version = "1.50.0", features = ["rt-multi-thread", "macros"] } +tokio-stream = { version = "0.1.17", features = ["net"] } +tokio-util = "0.7.17" toml = { workspace = true } # toml = "0.8.19" tonic = "0.14" tonic-prost = "0.14.5" +tracing = "0.1.44" tonic-reflection = "0.14" [build-dependencies] diff --git a/engine/proto/engine.proto b/engine/proto/engine.proto index 188c371..e532eeb 100644 --- a/engine/proto/engine.proto +++ b/engine/proto/engine.proto @@ -10,6 +10,13 @@ service Engine { rpc GetTasks(TaskPageRequest) returns (TaskPage); rpc CheckAuth(empty) returns (empty); } + +// Cluster membership and discovery service, used between backend nodes and a proxy. 
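+// Backends call RegisterNode once at startup, then Heartbeat on the returned
+// interval; the proxy drops a node from routing once ttl_seconds elapse
+// without a heartbeat.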
+service Cluster { + rpc RegisterNode(NodeRegisterRequest) returns (NodeRegisterResponse); + rpc Heartbeat(NodeHeartbeat) returns (NodeHeartbeatAck); + rpc ListNodes(empty) returns (NodeList); +} message TaskSelector { TaskState state = 1; string namespace = 2; @@ -58,3 +65,38 @@ message Task { string task_id = 2; // namespace:task bytes payload = 3; } + +message NodeRegisterRequest { + string node_id = 1; + string advertise_addr = 2; // e.g. "http://10.0.0.12:50051" + repeated string tags = 3; // e.g. ["gpu", "control"] + repeated string tasks = 4; // ["ns:task", ...] from local task registry +} + +message NodeRegisterResponse { + string session_id = 1; + uint64 heartbeat_interval_seconds = 2; // default 5 + uint64 ttl_seconds = 3; // default 15 +} + +message NodeHeartbeat { + string node_id = 1; + string session_id = 2; +} + +message NodeHeartbeatAck { + uint64 server_time_unix = 1; +} + +message NodeInfo { + string node_id = 1; + string advertise_addr = 2; + repeated string tags = 3; + repeated string tasks = 4; + uint64 last_seen_unix = 5; + bool healthy = 6; +} + +message NodeList { + repeated NodeInfo nodes = 1; +} diff --git a/engine/src/bin/proxy.rs b/engine/src/bin/proxy.rs new file mode 100644 index 0000000..90aaf53 --- /dev/null +++ b/engine/src/bin/proxy.rs @@ -0,0 +1,47 @@ +use std::sync::Arc; + +use engine::{ + proto, + proxy_config::ProxyConfigToml, + routing::ProxyState, + service::proxy::{ProxyService, spawn_reaper}, +}; +use tokio::net::TcpListener; +use tokio_stream::wrappers::TcpListenerStream; +use tonic::transport::Server; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let config = ProxyConfigToml::load()?; + let listener = TcpListener::bind(config.listen.as_str()).await?; + let state = Arc::new(ProxyState::new(config)?); + let reaper_handle = spawn_reaper(state.clone()); + let cluster_service = ProxyService::new(state.clone()); + let engine_service = ProxyService::new(state); + + let reflection_service = tonic_reflection::server::Builder::configure() + .register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET) + .build_v1() + .map_err(|e| Box::new(e) as Box)?; + + let server = Server::builder() + .add_service(reflection_service) + .add_service(proto::cluster_server::ClusterServer::new(cluster_service)) + .add_service(proto::engine_server::EngineServer::new(engine_service)) + .serve_with_incoming(TcpListenerStream::new(listener)); + tokio::pin!(server); + + tokio::select! 
{ + result = &mut server => { + result.map_err(|e| Box::new(e) as Box)?; + } + result = reaper_handle => { + match result { + Ok(()) => return Err("proxy reaper exited unexpectedly".into()), + Err(err) => return Err(format!("proxy reaper failed: {err}").into()), + } + } + } + + Ok(()) +} diff --git a/engine/src/bin/server.rs b/engine/src/bin/server.rs index 5d6679a..622a2ef 100644 --- a/engine/src/bin/server.rs +++ b/engine/src/bin/server.rs @@ -1,39 +1,80 @@ -use engine::{get_auth, get_uid}; +use engine::{ + cluster_client::{registration_from_api, spawn_registration}, + get_auth, get_uid, proto, + task_id::{mint_task_instance_id, parse_task_key}, +}; use enginelib::api::postcard; use enginelib::{ - Identifier, RawIdentier, Registry, + RawIdentier, Registry, api::EngineAPI, chrono::Utc, - event::{debug, info, warn}, - events::{self, Events, ID}, - plugin::LibraryManager, - task::{SolvedTasks, StoredExecutingTask, StoredTask, Task, TaskQueue}, -}; -use proto::{ - TaskState, - engine_server::{Engine, EngineServer}, + events::{Events, ID}, + task::{StoredExecutingTask, StoredTask}, }; use std::{ collections::HashMap, - env::consts::OS, - io::Read, - net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}, - sync::{Arc, RwLock as RS_RwLock}, + sync::{Arc, RwLock as StdRwLock}, }; -use tokio::sync::RwLock; -use tonic::{Request, Response, Status, metadata::MetadataValue, transport::Server}; +use tokio::{net::TcpListener, sync::RwLock}; +use tokio_stream::wrappers::TcpListenerStream; +use tokio_util::sync::CancellationToken; +use tonic::{Response, Status, transport::Server}; +use tracing::{debug, info, warn}; -mod proto { - tonic::include_proto!("engine"); - pub(crate) const FILE_DESCRIPTOR_SET: &[u8] = - tonic::include_file_descriptor_set!("engine_descriptor"); -} #[allow(non_snake_case)] -struct EngineService { +pub struct BackendEngineService { pub EngineAPI: Arc>, } + +impl BackendEngineService { + pub fn new(api: Arc>) -> Self { + Self { EngineAPI: api } + } +} + +fn delete_task_from_collection( + collection: &mut HashMap<(String, String), Vec>, + id: &(String, String), + task_id: &str, + state_name: &str, + namespace: &str, + task: &str, + id_extractor: F, +) -> Result<(), Status> +where + F: Fn(&T) -> &str, +{ + match collection.get_mut(id) { + Some(query) => { + let orig_len = query.len(); + query.retain(|f| id_extractor(f) != task_id); + if query.len() == orig_len { + info!( + "DeleteTask: Task with id {} not found in {} state for namespace: {}, task: {}", + task_id, state_name, namespace, task + ); + return Err(Status::not_found(format!( + "Task with id {} not found in {} state", + task_id, state_name + ))); + } + Ok(()) + } + None => { + info!( + "DeleteTask: No tasks found in {} state for namespace: {}, task: {}", + state_name, namespace, task + ); + Err(Status::not_found(format!( + "No tasks found in {} state for given namespace and task", + state_name + ))) + } + } +} + #[tonic::async_trait] -impl Engine for EngineService { +impl proto::engine_server::Engine for BackendEngineService { async fn check_auth( &self, request: tonic::Request, @@ -44,10 +85,11 @@ impl Engine for EngineService { let output = Events::CheckAdminAuth(&mut api, challenge, ("".into(), "".into()), db); if !output { warn!("Auth check failed - permission denied"); - return Err(tonic::Status::permission_denied("Invalid Auth")); - }; - return Ok(tonic::Response::new(proto::Empty {})); + return Err(Status::permission_denied("Invalid Auth")); + } + Ok(Response::new(proto::Empty {})) } + async fn delete_task( &self, 
request: tonic::Request, @@ -61,53 +103,11 @@ impl Engine for EngineService { let output = Events::CheckAdminAuth(&mut api, challenge, ("".into(), "".into()), db); if !output { warn!("Auth check failed - permission denied"); - return Err(tonic::Status::permission_denied("Invalid Auth")); - }; - // Generic helper for removing a task by id from a collection, using an id extractor closure - fn delete_task_from_collection( - collection: &mut HashMap<(String, String), Vec>, - id: &(String, String), - task_id: &str, - state_name: &str, - namespace: &str, - task: &str, - id_extractor: F, - ) -> Result<(), Status> - where - F: Fn(&T) -> &str, - { - match collection.get_mut(id) { - Some(query) => { - let orig_len = query.len(); - query.retain(|f| id_extractor(f) != task_id); - if query.len() == orig_len { - info!( - "DeleteTask: Task with id {} not found in {} state for namespace: {}, task: {}", - task_id, state_name, namespace, task - ); - return Err(Status::not_found(format!( - "Task with id {} not found in {} state", - task_id, state_name - ))); - } - Ok(()) - } - None => { - info!( - "DeleteTask: No tasks found in {} state for namespace: {}, task: {}", - state_name, namespace, task - ); - Err(Status::not_found(format!( - "No tasks found in {} state for given namespace and task", - state_name - ))) - } - } + return Err(Status::permission_denied("Invalid Auth")); } - // Use the helper for each state let result = match data.state() { - TaskState::Processing => delete_task_from_collection( + proto::TaskState::Processing => delete_task_from_collection( &mut api.executing_tasks.tasks, &id, &data.id, @@ -116,7 +116,7 @@ impl Engine for EngineService { &data.task, |f| &f.id, ), - TaskState::Solved => delete_task_from_collection( + proto::TaskState::Solved => delete_task_from_collection( &mut api.solved_tasks.tasks, &id, &data.id, @@ -125,7 +125,7 @@ impl Engine for EngineService { &data.task, |f| &f.id, ), - TaskState::Queued => delete_task_from_collection( + proto::TaskState::Queued => delete_task_from_collection( &mut api.task_queue.tasks, &id, &data.id, @@ -136,11 +136,7 @@ impl Engine for EngineService { ), }; - if let Err(e) = result { - return Err(e); - } - - // Sync running memory into DB + result?; EngineAPI::sync_db(&mut api); info!( "DeleteTask: Successfully deleted task with id {} in state {:?} for namespace: {}, task: {}", @@ -149,170 +145,102 @@ impl Engine for EngineService { data.namespace, data.task ); - Ok(tonic::Response::new(proto::Empty {})) + Ok(Response::new(proto::Empty {})) } - /// Retrieves a paginated list of tasks filtered by namespace, task name, and state. - /// - /// Authenticates the request and, if authorized, returns tasks in the specified state - /// (`Processing`, `Queued`, or `Solved`) for the given namespace and task name. The results - /// are sorted by task ID and paginated according to the requested page and page size. - /// - /// Returns a `TaskPage` containing the filtered tasks and pagination metadata, or a - /// permission denied error if authentication fails. 
- /// - /// # Examples - /// - /// ``` - /// // Example usage within a tonic gRPC client context: - /// let request = proto::TaskPageRequest { - /// namespace: "example_ns".to_string(), - /// task: "example_task".to_string(), - /// state: proto::TaskState::Queued as i32, - /// page: 0, - /// page_size: 10, - /// }; - /// let response = engine_client.get_tasks(request).await?; - /// assert!(response.get_ref().tasks.len() <= 10); - /// ``` + async fn get_tasks( &self, request: tonic::Request, - ) -> std::result::Result, tonic::Status> { + ) -> Result, Status> { let mut api = self.EngineAPI.write().await; let challenge = get_auth(&request); - let db = api.db.clone(); if !Events::CheckAdminAuth(&mut api, challenge, ("".into(), "".into()), db) { info!("GetTask denied due to Invalid Auth"); return Err(Status::permission_denied("Invalid authentication")); - }; - let data = request.get_ref(); + } - let q: Vec = match data.clone().state() { - TaskState::Processing => { - match api - .executing_tasks - .tasks - .get(&(data.namespace.clone(), data.task.clone())) - { - Some(tasks) => { - let mut task_refs: Vec<_> = tasks.iter().collect(); - task_refs.sort_by_key(|f| &f.id); - task_refs - .iter() - .map(|f| proto::Task { - id: f.id.clone(), - task_id: format!("{}:{}", data.namespace, data.task), - task_payload: f.bytes.clone(), - payload: Vec::new(), - }) - .collect() - } - None => { - info!( - "Namespace {:?} and task {:?} not found in Processing state", - data.namespace, data.task - ); - Vec::new() - } + let data = request.get_ref(); + let q: Vec = match data.state() { + proto::TaskState::Processing => match api + .executing_tasks + .tasks + .get(&(data.namespace.clone(), data.task.clone())) + { + Some(tasks) => { + let mut task_refs: Vec<_> = tasks.iter().collect(); + task_refs.sort_by_key(|f| &f.id); + task_refs + .iter() + .map(|f| proto::Task { + id: f.id.clone(), + task_id: format!("{}:{}", data.namespace, data.task), + task_payload: f.bytes.clone(), + payload: Vec::new(), + }) + .collect() } - } - TaskState::Queued => { - match api - .task_queue - .tasks - .get(&(data.namespace.clone(), data.task.clone())) - { - Some(tasks) => { - let mut d = tasks.clone(); - d.sort_by_key(|f| f.id.clone()); - d.iter() - .map(|f| proto::Task { - id: f.id.clone(), - task_id: format!("{}:{}", data.namespace, data.task), - task_payload: f.bytes.clone(), - payload: Vec::new(), - }) - .collect() - } - None => { - info!( - "Namespace {:?} and task {:?} not found in Queued state", - data.namespace, data.task - ); - Vec::new() - } + None => Vec::new(), + }, + proto::TaskState::Queued => match api + .task_queue + .tasks + .get(&(data.namespace.clone(), data.task.clone())) + { + Some(tasks) => { + let mut tasks = tasks.clone(); + tasks.sort_by_key(|f| f.id.clone()); + tasks + .iter() + .map(|f| proto::Task { + id: f.id.clone(), + task_id: format!("{}:{}", data.namespace, data.task), + task_payload: f.bytes.clone(), + payload: Vec::new(), + }) + .collect() } - } - TaskState::Solved => { - match api - .solved_tasks - .tasks - .get(&(data.namespace.clone(), data.task.clone())) - { - Some(tasks) => { - let mut d = tasks.clone(); - d.sort_by_key(|f| f.id.clone()); - d.iter() - .map(|f| proto::Task { - id: f.id.clone(), - task_id: format!("{}:{}", data.namespace, data.task), - task_payload: f.bytes.clone(), - payload: Vec::new(), - }) - .collect() - } - None => { - info!( - "Namespace {:?} and task {:?} not found in Solved state", - data.namespace, data.task - ); - Vec::new() - } + None => Vec::new(), + }, + 
proto::TaskState::Solved => match api + .solved_tasks + .tasks + .get(&(data.namespace.clone(), data.task.clone())) + { + Some(tasks) => { + let mut tasks = tasks.clone(); + tasks.sort_by_key(|f| f.id.clone()); + tasks + .iter() + .map(|f| proto::Task { + id: f.id.clone(), + task_id: format!("{}:{}", data.namespace, data.task), + task_payload: f.bytes.clone(), + payload: Vec::new(), + }) + .collect() } - } + None => Vec::new(), + }, }; - let index = data.page * data.page_size as u64; - let end = index + (api.cfg.config_toml.pagination_limit.min(data.page_size) as u64); - let final_vec: Vec<_> = q - .iter() - .skip(index as usize) - .take(data.page_size as usize) - .cloned() - .collect(); - return Ok(tonic::Response::new(proto::TaskPage { + + let index = data.page.saturating_mul(data.page_size as u64) as usize; + let limit = api.cfg.config_toml.pagination_limit.min(data.page_size) as usize; + let final_vec: Vec<_> = q.iter().skip(index).take(limit).cloned().collect(); + Ok(Response::new(proto::TaskPage { namespace: data.namespace.clone(), task: data.task.clone(), page: data.page, page_size: data.page_size, state: data.state, tasks: final_vec, - })); + })) } - /// Handles custom gRPC messages with admin-level authentication. - /// - /// Processes a CGRPC request by verifying admin credentials and dispatching the event payload to the appropriate handler. Returns the processed event payload in the response. If authentication fails, returns a permission denied error. - /// - /// # Returns - /// A `Cgrpcmsg` response containing the processed event payload, or a permission denied gRPC status on failed authentication. - /// - /// # Examples - /// - /// ``` - /// // Example usage within a gRPC client context: - /// let request = proto::Cgrpcmsg { - /// handler_mod_id: "mod".to_string(), - /// handler_id: "handler".to_string(), - /// event_payload: vec![1, 2, 3], - /// // ... other fields ... 
- /// }; - /// let response = engine_service.cgrpc(tonic::Request::new(request)).await?; - /// assert_eq!(response.get_ref().handler_mod_id, "mod"); - /// ``` + async fn cgrpc( &self, request: tonic::Request, - ) -> std::result::Result, tonic::Status> { + ) -> Result, Status> { info!( "CGRPC request received for handler: {}:{}", request.get_ref().handler_mod_id, @@ -333,9 +261,9 @@ impl Engine for EngineService { ); if !output { warn!("CGRPC auth check failed - permission denied"); - return Err(tonic::Status::permission_denied("Invalid CGRPC Auth")); - }; - let out = Arc::new(std::sync::RwLock::new(Vec::new())); + return Err(Status::permission_denied("Invalid CGRPC Auth")); + } + let out = Arc::new(StdRwLock::new(Vec::new())); debug!("Dispatching CGRPC event to handler"); Events::CgrpcEvent( &mut api, @@ -352,12 +280,13 @@ impl Engine for EngineService { } }; info!("CGRPC request processed successfully"); - return Ok(tonic::Response::new(res)); + Ok(Response::new(res)) } + async fn aquire_task_reg( &self, request: tonic::Request, - ) -> Result, tonic::Status> { + ) -> Result, Status> { let uid = get_uid(&request); let challenge = get_auth(&request); info!("Task registry request received from user: {}", uid); @@ -371,22 +300,22 @@ impl Engine for EngineService { uid ); return Err(Status::permission_denied("Invalid authentication")); - }; - let mut tasks: Vec = Vec::new(); - for (k, v) in &api.task_registry.tasks { - let js: Vec = vec![k.0.clone(), k.1.clone()]; - let jstr = js.join(":"); - tasks.push(jstr); } + let mut tasks: Vec = api + .task_registry + .tasks + .keys() + .map(|(namespace, task)| format!("{namespace}:{task}")) + .collect(); + tasks.sort(); info!("Returning task registry with {} tasks", tasks.len()); - let response = proto::TaskRegistry { tasks }; - Ok(tonic::Response::new(response)) + Ok(Response::new(proto::TaskRegistry { tasks })) } async fn aquire_task( &self, request: tonic::Request, - ) -> Result, tonic::Status> { + ) -> Result, Status> { let challenge = get_auth(&request); let input = request.get_ref(); let task_id = input.task_id.clone(); @@ -405,40 +334,27 @@ impl Engine for EngineService { uid ); return Err(Status::permission_denied("Invalid authentication")); - }; - - // Todo: check for wrong input to not cause a Panic out of bounds. 
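+        // parse_task_key (engine/src/task_id.rs) now validates the id shape up
+        // front and returns InvalidArgument, closing the out-of-bounds panic
+        // this TODO tracked.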
- let alen = &task_id.split(":").collect::>().len(); - if *alen != 2 { - info!("Invalid task ID format: {}", task_id); - return Err(Status::invalid_argument( - "Invalid task ID format, expected 'namespace:name", - )); } - let namespace = &task_id.split(":").collect::>()[0]; - let task_name = &task_id.split(":").collect::>()[1]; - debug!("Looking up task definition for {}:{}", namespace, task_name); - let tsx = api - .task_registry - .get(&(namespace.to_string(), task_name.to_string())); - if tsx.is_none() { + + let key = parse_task_key(&task_id)?; + if api.task_registry.get(&key).is_none() { warn!( "Task acquisition failed - task does not exist: {}:{}", - namespace, task_name + key.0, key.1 ); return Err(Status::invalid_argument("Task Does not Exist")); } - let key = ID(namespace, task_name); + let mut map = match api.task_queue.tasks.get(&key) { Some(v) if !v.is_empty() => v.clone(), _ => { - info!("No queued tasks for {}:{}", namespace, task_name); + info!("No queued tasks for {}:{}", key.0, key.1); return Err(Status::not_found("No queued tasks available")); } }; + let ttask = map.remove(0); let task_payload = ttask.bytes.clone(); - // Get Task and remove it from queue api.task_queue.tasks.insert(key.clone(), map); match postcard::to_allocvec(&api.task_queue.clone()) { Ok(store) => { @@ -450,16 +366,11 @@ impl Engine for EngineService { return Err(Status::internal(format!("Serialization error: {}", e))); } } - // Move it to exec queue - let mut exec_tsks = api - .executing_tasks - .tasks - .get(&key) - .cloned() - .unwrap_or_default(); - exec_tsks.push(enginelib::task::StoredExecutingTask { + + let mut exec_tsks = api.executing_tasks.tasks.get(&key).cloned().unwrap_or_default(); + exec_tsks.push(StoredExecutingTask { bytes: task_payload.clone(), - user_id: uid.clone(), + user_id: uid, given_at: Utc::now(), id: ttask.id.clone(), }); @@ -474,147 +385,122 @@ impl Engine for EngineService { return Err(Status::internal(format!("Serialization error: {}", e))); } } - let response = proto::Task { + + Ok(Response::new(proto::Task { id: ttask.id, - task_id: input.task_id.clone(), + task_id, task_payload, payload: Vec::new(), - }; - Ok(tonic::Response::new(response)) + })) } + async fn publish_task( &self, request: tonic::Request, - ) -> Result, tonic::Status> { + ) -> Result, Status> { let mut api = self.EngineAPI.write().await; let challenge = get_auth(&request); let uid = get_uid(&request); let db = api.db.clone(); - let task_id = request.get_ref().task_id.clone(); - let alen = &task_id.split(":").collect::>().len(); - if *alen != 2 { - return Err(Status::invalid_argument("Invalid Params")); - } - let namespace = &task_id.split(":").collect::>()[0]; - let task_name = &task_id.split(":").collect::>()[1]; + let task = request.get_ref(); + let key = parse_task_key(&task.task_id)?; + let instance_id = task.id.clone(); if !Events::CheckAuth(&mut api, uid.clone(), challenge, db) { - info!("Aquire Task denied due to Invalid Auth"); + info!("Publish Task denied due to Invalid Auth"); return Err(Status::permission_denied("Invalid authentication")); - }; - if !api - .task_registry - .tasks - .contains_key(&ID(namespace, task_name)) - { + } + if !api.task_registry.tasks.contains_key(&key) { warn!( "Task acquisition failed - task does not exist: {}:{}", - namespace, task_name + key.0, key.1 ); return Err(Status::invalid_argument("Task Does not Exist")); } - let key = ID(namespace, task_name); - let mem_tsk = api - .executing_tasks - .tasks - .get(&key) - .cloned() - .unwrap_or_default(); - let tsk_opt = 
mem_tsk + + let mem_tsk = api.executing_tasks.tasks.get(&key).cloned().unwrap_or_default(); + let executing_task = mem_tsk .iter() - .find(|f| f.id == task_id.clone() && f.user_id == uid.clone()); - if let Some(tsk) = tsk_opt { - let reg_tsk = match api.task_registry.get(&key) { - Some(r) => r.clone(), - None => { - warn!("Task registry missing for {}:{}", namespace, task_name); - return Err(Status::invalid_argument("Task Does not Exist")); - } - }; - if !reg_tsk.verify(request.get_ref().task_payload.clone()) { - info!("Failed to parse task"); - return Err(Status::invalid_argument("Failed to parse given task bytes")); + .find(|f| f.id == instance_id && f.user_id == uid) + .cloned(); + let Some(executing_task) = executing_task else { + return Err(Status::not_found("Invalid taskid or userid")); + }; + + let reg_tsk = match api.task_registry.get(&key) { + Some(r) => r, + None => { + warn!("Task registry missing for {}:{}", key.0, key.1); + return Err(Status::invalid_argument("Task Does not Exist")); } - // Exec Tasks -> DB - let mut nmem_tsk = mem_tsk.clone(); - nmem_tsk.retain(|f| f.id != task_id.clone() && f.user_id != uid.clone()); - api.executing_tasks - .tasks - .insert(key.clone(), nmem_tsk.clone()); - let t_mem_execs = api.executing_tasks.clone(); - match postcard::to_allocvec(&t_mem_execs) { - Ok(store) => { - if let Err(e) = api.db.insert("executing_tasks", store) { - return Err(Status::internal(format!("DB insert error: {}", e))); - } + }; + if !reg_tsk.verify(task.task_payload.clone()) { + info!("Failed to parse task"); + return Err(Status::invalid_argument("Failed to parse given task bytes")); + } + + let mut nmem_tsk = mem_tsk.clone(); + nmem_tsk.retain(|f| !(f.id == instance_id && f.user_id == uid)); + api.executing_tasks.tasks.insert(key.clone(), nmem_tsk); + match postcard::to_allocvec(&api.executing_tasks.clone()) { + Ok(store) => { + if let Err(e) = api.db.insert("executing_tasks", store) { + return Err(Status::internal(format!("DB insert error: {}", e))); } - Err(e) => return Err(Status::internal(format!("Serialization error: {}", e))), } - // tsk-> solved Tsks - let mut mem_solv = api - .solved_tasks - .tasks - .get(&key) - .cloned() - .unwrap_or_default(); - mem_solv.push(enginelib::task::StoredTask { - bytes: tsk.bytes.clone(), - id: tsk.id.clone(), - }); - api.solved_tasks.tasks.insert(key.clone(), mem_solv); - // Solved tsks -> DB - match postcard::to_allocvec(&api.solved_tasks.tasks) { - Ok(e_solv) => { - if let Err(e) = api.db.insert("solved_tasks", e_solv) { - return Err(Status::internal(format!("DB insert error: {}", e))); - } + Err(e) => return Err(Status::internal(format!("Serialization error: {}", e))), + } + + let mut mem_solv = api.solved_tasks.tasks.get(&key).cloned().unwrap_or_default(); + mem_solv.push(StoredTask { + bytes: task.task_payload.clone(), + id: executing_task.id, + }); + api.solved_tasks.tasks.insert(key.clone(), mem_solv); + match postcard::to_allocvec(&api.solved_tasks.clone()) { + Ok(store) => { + if let Err(e) = api.db.insert("solved_tasks", store) { + return Err(Status::internal(format!("DB insert error: {}", e))); } - Err(e) => return Err(Status::internal(format!("Serialization error: {}", e))), } - info!("Task published successfully: {} by user: {}", task_id, uid); - return Ok(tonic::Response::new(proto::Empty {})); - } else { - return Err(tonic::Status::not_found("Invalid taskid or userid")); + Err(e) => return Err(Status::internal(format!("Serialization error: {}", e))), } + + info!("Task published successfully: {} by user: {}", 
task.id, uid); + Ok(Response::new(proto::Empty {})) } + async fn create_task( &self, request: tonic::Request, - ) -> Result, tonic::Status> { + ) -> Result, Status> { let mut api = self.EngineAPI.write().await; let challenge = get_auth(&request); let uid = get_uid(&request); let db = api.db.clone(); if !Events::CheckAuth(&mut api, uid, challenge, db) { - //TODO: change to AdminSpecific Auth info!("Create Task denied due to Invalid Auth"); return Err(Status::permission_denied("Invalid authentication")); - }; + } let task = request.get_ref(); let task_id = task.task_id.clone(); - let parts: Vec<&str> = task_id.splitn(2, ':').collect(); - if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() { - return Err(Status::invalid_argument( - "Invalid task ID format, expected 'namespace:task'", - )); - } - let id: Identifier = (parts[0].to_string(), parts[1].to_string()); + let id = parse_task_key(&task_id)?; let tsk_reg = api.task_registry.get(&id); if let Some(tsk_reg) = tsk_reg { - if !tsk_reg.clone().verify(task.task_payload.clone()) { + if !tsk_reg.verify(task.task_payload.clone()) { warn!("Failed to parse given task bytes"); return Err(Status::invalid_argument("Failed to parse given task bytes")); } - let tbp_tsk = StoredTask { + let stored_task = StoredTask { bytes: task.task_payload.clone(), - id: druid::Druid::default().to_hex(), + id: mint_task_instance_id(api.cfg.config_toml.node_id.as_deref()), }; - let mut mem_tsks = api.task_queue.clone(); - let mut mem_tsk = mem_tsks.tasks.get(&id).cloned().unwrap_or_default(); - mem_tsk.push(tbp_tsk.clone()); - mem_tsks.tasks.insert(id.clone(), mem_tsk); - api.task_queue = mem_tsks; + let mut mem_task_queue = api.task_queue.clone(); + let mut mem_tasks = mem_task_queue.tasks.get(&id).cloned().unwrap_or_default(); + mem_tasks.push(stored_task.clone()); + mem_task_queue.tasks.insert(id.clone(), mem_tasks); + api.task_queue = mem_task_queue; match postcard::to_allocvec(&api.task_queue.clone()) { Ok(store) => { if let Err(e) = api.db.insert("tasks", store) { @@ -623,14 +509,14 @@ impl Engine for EngineService { } Err(e) => return Err(Status::internal(format!("Serialization error: {}", e))), } - return Ok(tonic::Response::new(proto::Task { - id: tbp_tsk.id.clone(), - task_id: task_id.clone(), + return Ok(Response::new(proto::Task { + id: stored_task.id.clone(), + task_id, payload: Vec::new(), - task_payload: tbp_tsk.bytes.clone(), + task_payload: stored_task.bytes.clone(), })); } - Err(tonic::Status::aborted("Error")) + Err(Status::aborted("Error")) } } @@ -640,30 +526,29 @@ async fn main() -> Result<(), Box> { EngineAPI::init(&mut api); Events::init_auth(&mut api); Events::StartEvent(&mut api); - let addr = api - .cfg - .config_toml - .host - .parse() - .unwrap_or(SocketAddr::V4(SocketAddrV4::new( - Ipv4Addr::new(127, 0, 0, 1), - 50051, - ))); - let apii = Arc::new(RwLock::new(api)); - EngineAPI::init_chron(apii.clone()); - let engine = EngineService { EngineAPI: apii }; - - // Build reflection service, mapping its concrete error into Box + + let listener = TcpListener::bind(api.cfg.config_toml.host.as_str()).await?; + let api = Arc::new(RwLock::new(api)); + let registration = { + let api_guard = api.read().await; + registration_from_api(&api_guard)? 
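+        // Ok(None) means no cluster_proxy_addr is configured; the backend then
+        // runs standalone and the registration task below is never spawned.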
+ }; + let registration_shutdown = CancellationToken::new(); + let _registration_task = + registration.map(|registration| spawn_registration(registration, registration_shutdown)); + + EngineAPI::init_chron(api.clone()); + let engine = BackendEngineService::new(api); + let reflection_service = tonic_reflection::server::Builder::configure() .register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET) .build_v1() .map_err(|e| Box::new(e) as Box)?; - // Start server and map transport errors into Box so `?` works with our return type. Server::builder() .add_service(reflection_service) - .add_service(EngineServer::new(engine)) - .serve(addr) + .add_service(proto::engine_server::EngineServer::new(engine)) + .serve_with_incoming(TcpListenerStream::new(listener)) .await .map_err(|e| Box::new(e) as Box)?; diff --git a/engine/src/cluster_client.rs b/engine/src/cluster_client.rs new file mode 100644 index 0000000..33348ea --- /dev/null +++ b/engine/src/cluster_client.rs @@ -0,0 +1,195 @@ +use std::time::Duration; + +use enginelib::api::EngineAPI; +use tokio::{task::JoinHandle, time::sleep}; +use tokio_util::sync::CancellationToken; +use tonic::{Request, transport::Endpoint}; +use tracing::{error, info, warn}; + +use crate::proto::{self, cluster_client::ClusterClient}; + +#[derive(Debug, Clone)] +pub struct NodeRegistration { + pub node_id: String, + pub advertise_addr: String, + pub cluster_proxy_addr: String, + pub cluster_token: String, + pub node_tags: Vec, + pub tasks: Vec, +} + +pub fn registration_from_api(api: &EngineAPI) -> Result, String> { + let cfg = &api.cfg.config_toml; + let Some(cluster_proxy_addr) = cfg.cluster_proxy_addr.clone() else { + return Ok(None); + }; + + let node_id = cfg + .node_id + .clone() + .filter(|value| !value.trim().is_empty()) + .ok_or_else(|| "clustered backend requires config.server.node_id".to_string())?; + if node_id.contains('@') { + return Err("clustered backend node_id must not contain '@'".into()); + } + + let cluster_token = cfg + .cluster_token + .clone() + .filter(|value| !value.trim().is_empty()) + .ok_or_else(|| "clustered backend requires config.server.cluster_token".to_string())?; + + let advertise_addr = + normalize_advertise_addr(cfg.advertise_addr.as_deref(), cfg.host.as_str())?; + + validate_endpoint(&cluster_proxy_addr, "cluster_proxy_addr")?; + validate_endpoint(&advertise_addr, "advertise_addr")?; + + let mut tasks: Vec = api + .task_registry + .tasks + .keys() + .map(|(namespace, task)| format!("{namespace}:{task}")) + .collect(); + tasks.sort(); + tasks.dedup(); + + let mut node_tags = cfg.node_tags.clone().unwrap_or_default(); + node_tags.sort(); + node_tags.dedup(); + + Ok(Some(NodeRegistration { + node_id, + advertise_addr, + cluster_proxy_addr, + cluster_token, + node_tags, + tasks, + })) +} + +pub fn normalize_advertise_addr( + advertise_addr: Option<&str>, + host: &str, +) -> Result { + let value = advertise_addr.unwrap_or(host).trim(); + if value.is_empty() { + return Err("advertise address must not be empty".into()); + } + + if value.starts_with("http://") || value.starts_with("https://") { + Ok(value.to_string()) + } else { + Ok(format!("http://{value}")) + } +} + +pub fn spawn_registration( + registration: NodeRegistration, + shutdown: CancellationToken, +) -> JoinHandle<()> { + tokio::spawn(async move { + loop { + if shutdown.is_cancelled() { + return; + } + match ClusterClient::connect(registration.cluster_proxy_addr.clone()).await { + Ok(mut client) => { + let register_response = match register_node(&mut client, 
&registration).await {
+                        Ok(response) => response,
+                        Err(err) => {
+                            warn!(
+                                "cluster registration failed for {}: {}",
+                                registration.node_id, err
+                            );
+                            tokio::select! {
+                                _ = shutdown.cancelled() => return,
+                                _ = sleep(Duration::from_secs(1)) => {}
+                            }
+                            continue;
+                        }
+                    };
+
+                    info!(
+                        "registered backend node {} with proxy {}",
+                        registration.node_id, registration.cluster_proxy_addr
+                    );
+
+                    let session_id = register_response.session_id;
+                    let interval_seconds = register_response.heartbeat_interval_seconds.max(1);
+
+                    loop {
+                        tokio::select! {
+                            _ = shutdown.cancelled() => return,
+                            _ = sleep(Duration::from_secs(interval_seconds)) => {}
+                        }
+                        match heartbeat(&mut client, &registration, &session_id).await {
+                            Ok(_) => {}
+                            Err(err) => {
+                                warn!(
+                                    "cluster heartbeat failed for {}: {}",
+                                    registration.node_id, err
+                                );
+                                break;
+                            }
+                        }
+                    }
+                }
+                Err(err) => {
+                    error!(
+                        "failed to connect backend node {} to cluster proxy {}: {}",
+                        registration.node_id, registration.cluster_proxy_addr, err
+                    );
+                    tokio::select! {
+                        _ = shutdown.cancelled() => return,
+                        _ = sleep(Duration::from_secs(1)) => {}
+                    }
+                }
+            }
+        }
+    })
+}
+
+async fn register_node(
+    client: &mut ClusterClient<tonic::transport::Channel>,
+    registration: &NodeRegistration,
+) -> Result<proto::NodeRegisterResponse, tonic::Status> {
+    let mut request = Request::new(proto::NodeRegisterRequest {
+        node_id: registration.node_id.clone(),
+        advertise_addr: registration.advertise_addr.clone(),
+        tags: registration.node_tags.clone(),
+        tasks: registration.tasks.clone(),
+    });
+    request.metadata_mut().insert(
+        "authorization",
+        format!("Bearer {}", registration.cluster_token)
+            .parse()
+            .map_err(|_| tonic::Status::internal("failed to encode cluster authorization"))?,
+    );
+    Ok(client.register_node(request).await?.into_inner())
+}
+
+async fn heartbeat(
+    client: &mut ClusterClient<tonic::transport::Channel>,
+    registration: &NodeRegistration,
+    session_id: &str,
+) -> Result<(), tonic::Status> {
+    let mut request = Request::new(proto::NodeHeartbeat {
+        node_id: registration.node_id.clone(),
+        session_id: session_id.to_string(),
+    });
+    request.metadata_mut().insert(
+        "authorization",
+        format!("Bearer {}", registration.cluster_token)
+            .parse()
+            .map_err(|_| tonic::Status::internal("failed to encode cluster authorization"))?,
+    );
+    client.heartbeat(request).await?;
+    Ok(())
+}
+
+fn validate_endpoint(value: &str, field_name: &str) -> Result<(), String> {
+    Endpoint::from_shared(value.to_string())
+        .map(|_| ())
+        .map_err(|err| format!("invalid {field_name} '{value}': {err}"))
+}
diff --git a/engine/src/lib.rs b/engine/src/lib.rs
index 5b529f0..62121ed 100644
--- a/engine/src/lib.rs
+++ b/engine/src/lib.rs
@@ -1,4 +1,14 @@
-use tonic::Request;
+use tonic::{
+    Request,
+    metadata::{KeyAndValueRef, MetadataMap},
+};
+
+pub mod cluster_client;
+pub mod proto;
+pub mod proxy_config;
+pub mod routing;
+pub mod service;
+pub mod task_id;
 
 pub fn get_uid<T>(req: &Request<T>) -> String {
     req.metadata()
@@ -8,6 +18,19 @@ pub fn get_uid<T>(req: &Request<T>) -> String {
         .unwrap_or_default()
 }
 
+pub fn copy_metadata(source: &MetadataMap, target: &mut MetadataMap) {
+    for entry in source.iter() {
+        match entry {
+            KeyAndValueRef::Ascii(key, value) => {
+                target.insert(key, value.clone());
+            }
+            KeyAndValueRef::Binary(key, value) => {
+                target.insert_bin(key, value.clone());
+            }
+        }
+    }
+}
+
 pub fn get_auth<T>(req: &Request<T>) -> String {
     req.metadata()
         .get("authorization")
diff --git a/engine/src/proto.rs b/engine/src/proto.rs
new file mode 100644
index 0000000..86acc1e
--- /dev/null
+++ b/engine/src/proto.rs
@@ -0,0 +1,4 @@
+// Shared gRPC/protobuf bindings for 
the engine crate (bins + tests).
+tonic::include_proto!("engine");
+
+pub const FILE_DESCRIPTOR_SET: &[u8] = tonic::include_file_descriptor_set!("engine_descriptor");
diff --git a/engine/src/proxy_config.rs b/engine/src/proxy_config.rs
new file mode 100644
index 0000000..0415760
--- /dev/null
+++ b/engine/src/proxy_config.rs
@@ -0,0 +1,101 @@
+use std::{fs, io::ErrorKind};
+
+use serde::Deserialize;
+
+fn default_listen() -> String {
+    "0.0.0.0:50052".into()
+}
+
+fn default_node_ttl_seconds() -> u64 {
+    15
+}
+
+fn default_heartbeat_interval_seconds() -> u64 {
+    5
+}
+
+fn default_max_acquire_hops() -> u32 {
+    3
+}
+
+fn default_admin_fanout_limit() -> u32 {
+    5000
+}
+
+#[derive(Debug, Clone, Deserialize, Default)]
+pub struct RouteRuleToml {
+    pub r#match: String,
+    #[serde(default)]
+    pub require_tags: Vec<String>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ProxyConfigToml {
+    #[serde(default = "default_listen")]
+    pub listen: String,
+    pub cluster_token: String,
+    #[serde(default = "default_node_ttl_seconds")]
+    pub node_ttl_seconds: u64,
+    #[serde(default = "default_heartbeat_interval_seconds")]
+    pub heartbeat_interval_seconds: u64,
+    #[serde(default = "default_max_acquire_hops")]
+    pub max_acquire_hops: u32,
+    #[serde(default = "default_admin_fanout_limit")]
+    pub admin_fanout_limit: u32,
+    #[serde(default)]
+    pub rules: Vec<RouteRuleToml>,
+}
+
+impl Default for ProxyConfigToml {
+    fn default() -> Self {
+        Self {
+            listen: default_listen(),
+            cluster_token: String::new(),
+            node_ttl_seconds: default_node_ttl_seconds(),
+            heartbeat_interval_seconds: default_heartbeat_interval_seconds(),
+            max_acquire_hops: default_max_acquire_hops(),
+            admin_fanout_limit: default_admin_fanout_limit(),
+            rules: Vec::new(),
+        }
+    }
+}
+
+impl ProxyConfigToml {
+    pub fn load() -> Result<Self, String> {
+        let content = match fs::read_to_string("proxy.toml") {
+            Ok(content) => content,
+            Err(err) if err.kind() == ErrorKind::NotFound => String::new(),
+            Err(err) => return Err(format!("Failed to read proxy.toml: {err}")),
+        };
+
+        let config = if content.trim().is_empty() {
+            Self::default()
+        } else {
+            toml::from_str(&content).map_err(|err| format!("Failed to parse proxy.toml: {err}"))?
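+            // A missing or empty proxy.toml falls back to Self::default() above,
+            // but validate() rejects the empty default cluster_token, so a real
+            // token must still be provided.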
+ }; + + config.validate()?; + Ok(config) + } + + pub fn validate(&self) -> Result<(), String> { + if self.cluster_token.trim().is_empty() { + return Err("proxy cluster_token must not be empty".into()); + } + if self.node_ttl_seconds == 0 { + return Err("proxy node_ttl_seconds must be greater than zero".into()); + } + if self.heartbeat_interval_seconds == 0 { + return Err("proxy heartbeat_interval_seconds must be greater than zero".into()); + } + if self.max_acquire_hops == 0 { + return Err("proxy max_acquire_hops must be greater than zero".into()); + } + if self.admin_fanout_limit == 0 { + return Err("proxy admin_fanout_limit must be greater than zero".into()); + } + Ok(()) + } +} + +use std::io::ErrorKind; diff --git a/engine/src/routing.rs b/engine/src/routing.rs new file mode 100644 index 0000000..d40a7eb --- /dev/null +++ b/engine/src/routing.rs @@ -0,0 +1,207 @@ +use std::{ + collections::{BTreeSet, HashMap}, + sync::{ + Arc, + atomic::{AtomicU64, Ordering}, + }, + time::{SystemTime, UNIX_EPOCH}, +}; + +use tokio::sync::RwLock; +use tonic::transport::{Channel, Endpoint}; + +use crate::proxy_config::{ProxyConfigToml, RouteRuleToml}; + +#[derive(Debug, Clone)] +pub struct RouteRequirement { + pub require_tags: Vec, +} + +#[derive(Debug, Clone)] +pub struct RouteRule { + pub namespace_pattern: String, + pub task_pattern: String, + pub requirement: RouteRequirement, +} + +impl RouteRule { + pub fn from_toml(rule: RouteRuleToml) -> Result { + let Some((namespace_pattern, task_pattern)) = rule.r#match.split_once(':') else { + return Err(format!( + "invalid proxy rule match '{}', expected namespace:task", + rule.r#match + )); + }; + + Ok(Self { + namespace_pattern: namespace_pattern.to_string(), + task_pattern: task_pattern.to_string(), + requirement: RouteRequirement { + require_tags: rule.require_tags, + }, + }) + } + + pub fn matches_task(&self, task_key: &str) -> bool { + let Some((namespace, task)) = task_key.split_once(':') else { + return false; + }; + + matches_pattern(&self.namespace_pattern, namespace) + && matches_pattern(&self.task_pattern, task) + } +} + +fn matches_pattern(pattern: &str, value: &str) -> bool { + pattern == "*" || pattern == value +} + +#[derive(Debug)] +pub struct NodeState { + pub node_id: String, + pub advertise_addr: String, + pub tags: Vec, + pub tasks: Vec, + pub session_id: String, + pub channel: Channel, + pub last_seen_unix: AtomicU64, + pub in_flight_create: AtomicU64, + tag_set: BTreeSet, + task_set: BTreeSet, +} + +impl NodeState { + pub fn new( + node_id: String, + advertise_addr: String, + tags: Vec, + tasks: Vec, + session_id: String, + now_unix: u64, + ) -> Result { + let channel = Endpoint::from_shared(advertise_addr.clone()) + .map_err(|err| format!("invalid advertise_addr '{advertise_addr}': {err}"))? 
+ .connect_lazy(); + + Ok(Self { + node_id, + advertise_addr, + tag_set: tags.iter().cloned().collect(), + task_set: tasks.iter().cloned().collect(), + tags, + tasks, + session_id, + channel, + last_seen_unix: AtomicU64::new(now_unix), + in_flight_create: AtomicU64::new(0), + }) + } + + pub fn has_task(&self, task_key: &str) -> bool { + self.task_set.contains(task_key) + } + + pub fn has_tags(&self, tags: &[String]) -> bool { + tags.iter().all(|tag| self.tag_set.contains(tag)) + } + + pub fn last_seen_unix(&self) -> u64 { + self.last_seen_unix.load(Ordering::Relaxed) + } +} + +#[derive(Debug)] +pub struct ProxyState { + pub config: ProxyConfigToml, + pub rules: Vec, + pub nodes: RwLock>>, +} + +impl ProxyState { + pub fn new(config: ProxyConfigToml) -> Result { + let rules = if config.rules.is_empty() { + vec![RouteRule::from_toml(RouteRuleToml { + r#match: "*:*".into(), + require_tags: Vec::new(), + })?] + } else { + let mut rules = Vec::with_capacity(config.rules.len()); + for rule in config.rules.iter().cloned() { + rules.push(RouteRule::from_toml(rule)?); + } + rules + }; + + Ok(Self { + config, + rules, + nodes: RwLock::new(HashMap::new()), + }) + } + + pub async fn upsert_node(&self, node: NodeState) -> Arc { + let node = Arc::new(node); + self.nodes + .write() + .await + .insert(node.node_id.clone(), node.clone()); + node + } + + pub async fn healthy_nodes(&self) -> Vec> { + let mut nodes: Vec<_> = self.nodes.read().await.values().cloned().collect(); + nodes.sort_by(|left, right| left.node_id.cmp(&right.node_id)); + nodes + } + + pub async fn candidate_nodes_for_task(&self, task_key: &str) -> Vec> { + let required_tags = self + .rules + .iter() + .find(|rule| rule.matches_task(task_key)) + .map(|rule| rule.requirement.require_tags.clone()) + .unwrap_or_default(); + + self.healthy_nodes() + .await + .into_iter() + .filter(|node| node.has_task(task_key) && node.has_tags(&required_tags)) + .collect() + } + + pub async fn preferred_nodes_by_tag(&self, tag: &str) -> Vec> { + let nodes = self.healthy_nodes().await; + let mut preferred = Vec::new(); + let mut fallback = Vec::new(); + + for node in nodes { + if node.has_tags(&[tag.to_string()]) { + preferred.push(node); + } else { + fallback.push(node); + } + } + + preferred.extend(fallback); + preferred + } + + pub async fn node_by_id(&self, node_id: &str) -> Option> { + self.nodes.read().await.get(node_id).cloned() + } + + pub async fn reap_stale_nodes(&self, now_unix: u64) { + let ttl_seconds = self.config.node_ttl_seconds; + self.nodes + .write() + .await + .retain(|_, node| now_unix.saturating_sub(node.last_seen_unix()) <= ttl_seconds); + } +} + +pub fn now_unix() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} diff --git a/engine/src/service/mod.rs b/engine/src/service/mod.rs new file mode 100644 index 0000000..44dcc92 --- /dev/null +++ b/engine/src/service/mod.rs @@ -0,0 +1 @@ +pub mod proxy; diff --git a/engine/src/service/proxy.rs b/engine/src/service/proxy.rs new file mode 100644 index 0000000..9ec0705 --- /dev/null +++ b/engine/src/service/proxy.rs @@ -0,0 +1,486 @@ +use std::{ + cmp::min, + sync::{Arc, atomic::Ordering}, + time::Duration, +}; + +use crate::{ + copy_metadata, get_auth, + proto::{self, cluster_server::Cluster, engine_client::EngineClient, engine_server::Engine}, + routing::{NodeState, ProxyState, now_unix}, + task_id::{parse_owner_node, parse_task_key_string}, +}; +use rand::{seq::SliceRandom, thread_rng}; +use tokio::{task::JoinHandle, time::sleep}; +use 
tonic::{Code, Request, Response, Status}; + +pub struct ProxyService { + state: Arc, +} + +impl ProxyService { + pub fn new(state: Arc) -> Self { + Self { state } + } +} + +pub fn spawn_reaper(state: Arc) -> JoinHandle<()> { + tokio::spawn(async move { + loop { + sleep(Duration::from_secs(1)).await; + state.reap_stale_nodes(now_unix()).await; + } + }) +} + +struct InFlightCreateGuard { + node: Arc, +} + +impl InFlightCreateGuard { + fn new(node: Arc) -> Self { + node.in_flight_create.fetch_add(1, Ordering::SeqCst); + Self { node } + } +} + +impl Drop for InFlightCreateGuard { + fn drop(&mut self) { + self.node.in_flight_create.fetch_sub(1, Ordering::SeqCst); + } +} + +fn require_cluster_auth(request: &Request, cluster_token: &str) -> Result<(), Status> { + let expected = format!("Bearer {cluster_token}"); + if get_auth(request) != expected { + return Err(Status::permission_denied("Invalid cluster authorization")); + } + Ok(()) +} + +fn build_request(request: &Request, message: T) -> Request { + let mut outbound = Request::new(message); + copy_metadata(request.metadata(), outbound.metadata_mut()); + outbound +} + +async fn forward_create( + node: Arc, + request: &Request, +) -> Result, Status> { + let mut client = EngineClient::new(node.channel.clone()); + client + .create_task(build_request(request, request.get_ref().clone())) + .await +} + +async fn forward_acquire( + node: Arc, + request: &Request, +) -> Result, Status> { + let mut client = EngineClient::new(node.channel.clone()); + client + .aquire_task(build_request(request, request.get_ref().clone())) + .await +} + +async fn forward_publish( + node: Arc, + request: &Request, +) -> Result, Status> { + let mut client = EngineClient::new(node.channel.clone()); + client + .publish_task(build_request(request, request.get_ref().clone())) + .await +} + +async fn forward_delete( + node: Arc, + request: &Request, +) -> Result, Status> { + let mut client = EngineClient::new(node.channel.clone()); + client + .delete_task(build_request(request, request.get_ref().clone())) + .await +} + +async fn forward_get_tasks( + node: Arc, + request: &Request, + page: u64, + page_size: u32, +) -> Result { + let mut payload = request.get_ref().clone(); + payload.page = page; + payload.page_size = page_size; + let mut client = EngineClient::new(node.channel.clone()); + Ok(client + .get_tasks(build_request(request, payload)) + .await? 
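+        // page/page_size were overridden above so every backend returns the
+        // same bounded prefix; the proxy's get_tasks re-sorts and re-paginates
+        // the merged result.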
+ .into_inner()) +} + +async fn forward_task_registry_probe( + node: Arc, + request: &Request, +) -> Result<(), Status> { + let mut client = EngineClient::new(node.channel.clone()); + client + .aquire_task_reg(build_request(request, request.get_ref().clone())) + .await?; + Ok(()) +} + +async fn forward_check_auth( + node: Arc, + request: &Request, +) -> Result, Status> { + let mut client = EngineClient::new(node.channel.clone()); + client + .check_auth(build_request(request, request.get_ref().clone())) + .await +} + +async fn forward_cgrpc( + node: Arc, + request: &Request, +) -> Result, Status> { + let mut client = EngineClient::new(node.channel.clone()); + client + .cgrpc(build_request(request, request.get_ref().clone())) + .await +} + +async fn broadcast_publish( + nodes: Vec>, + request: &Request, +) -> Result, Status> { + let mut last_not_found = None; + for node in nodes { + match forward_publish(node, request).await { + Ok(response) => return Ok(response), + Err(status) if status.code() == Code::NotFound => { + last_not_found = Some(status); + } + Err(status) => return Err(status), + } + } + + Err(last_not_found.unwrap_or_else(|| Status::not_found("Task not found"))) +} + +async fn broadcast_delete( + nodes: Vec>, + request: &Request, +) -> Result, Status> { + let mut last_not_found = None; + for node in nodes { + match forward_delete(node, request).await { + Ok(response) => return Ok(response), + Err(status) if status.code() == Code::NotFound => { + last_not_found = Some(status); + } + Err(status) => return Err(status), + } + } + + Err(last_not_found.unwrap_or_else(|| Status::not_found("Task not found"))) +} + +fn unavailable() -> Status { + Status::unavailable("No healthy backend nodes available") +} + +fn select_create_candidate(mut candidates: Vec>) -> Option> { + if candidates.is_empty() { + return None; + } + let mut rng = thread_rng(); + candidates.shuffle(&mut rng); + if candidates.len() == 1 { + return candidates.into_iter().next(); + } + + let left = candidates[0].clone(); + let right = candidates[1].clone(); + let left_load = left.in_flight_create.load(Ordering::SeqCst); + let right_load = right.in_flight_create.load(Ordering::SeqCst); + if left_load < right_load { + Some(left) + } else if right_load < left_load { + Some(right) + } else if left.node_id <= right.node_id { + Some(left) + } else { + Some(right) + } +} + +#[tonic::async_trait] +impl Engine for ProxyService { + async fn aquire_task( + &self, + request: Request, + ) -> Result, Status> { + let task_key = parse_task_key_string(&request.get_ref().task_id)?; + let mut candidates = self.state.candidate_nodes_for_task(&task_key).await; + if candidates.is_empty() { + return Err(unavailable()); + } + + candidates.shuffle(&mut thread_rng()); + let hops = min( + self.state.config.max_acquire_hops as usize, + candidates.len(), + ); + for node in candidates.into_iter().take(hops) { + match forward_acquire(node, &request).await { + Ok(response) => return Ok(response), + Err(status) if status.code() == Code::NotFound => continue, + Err(status) => return Err(status), + } + } + + Err(Status::not_found("No queued tasks available")) + } + + async fn aquire_task_reg( + &self, + request: Request, + ) -> Result, Status> { + let nodes = self.state.healthy_nodes().await; + let Some(probe_node) = nodes.first().cloned() else { + return Err(unavailable()); + }; + forward_task_registry_probe(probe_node, &request).await?; + + let mut tasks: Vec = nodes + .iter() + .flat_map(|node| node.tasks.iter().cloned()) + .collect(); + tasks.sort(); 
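+        // Vec::dedup only removes adjacent duplicates, so the sort above is
+        // what makes this a true cross-node de-duplication.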
+ tasks.dedup(); + Ok(Response::new(proto::TaskRegistry { tasks })) + } + + async fn publish_task( + &self, + request: Request, + ) -> Result, Status> { + let task_key = parse_task_key_string(&request.get_ref().task_id)?; + if let Some((owner_node, _)) = parse_owner_node(&request.get_ref().id) { + if let Some(node) = self.state.node_by_id(owner_node).await { + return forward_publish(node, &request).await; + } + } + + let candidates = self.state.candidate_nodes_for_task(&task_key).await; + if candidates.is_empty() { + return Err(unavailable()); + } + broadcast_publish(candidates, &request).await + } + + async fn cgrpc( + &self, + request: Request, + ) -> Result, Status> { + let Some(node) = self + .state + .preferred_nodes_by_tag("control") + .await + .into_iter() + .next() + else { + return Err(unavailable()); + }; + forward_cgrpc(node, &request).await + } + + async fn create_task( + &self, + request: Request, + ) -> Result, Status> { + let task_key = parse_task_key_string(&request.get_ref().task_id)?; + let candidates = self.state.candidate_nodes_for_task(&task_key).await; + let Some(node) = select_create_candidate(candidates) else { + return Err(unavailable()); + }; + + let _guard = InFlightCreateGuard::new(node.clone()); + forward_create(node, &request).await + } + + async fn delete_task( + &self, + request: Request, + ) -> Result, Status> { + let task_key = format!("{}:{}", request.get_ref().namespace, request.get_ref().task); + if let Some((owner_node, _)) = parse_owner_node(&request.get_ref().id) { + if let Some(node) = self.state.node_by_id(owner_node).await { + return forward_delete(node, &request).await; + } + } + + let candidates = self.state.candidate_nodes_for_task(&task_key).await; + if candidates.is_empty() { + return Err(unavailable()); + } + broadcast_delete(candidates, &request).await + } + + async fn get_tasks( + &self, + request: Request, + ) -> Result, Status> { + let task_key = format!("{}:{}", request.get_ref().namespace, request.get_ref().task); + let candidates = self.state.candidate_nodes_for_task(&task_key).await; + if candidates.is_empty() { + return Err(unavailable()); + } + + let requested_end = request + .get_ref() + .page + .saturating_add(1) + .saturating_mul(request.get_ref().page_size as u64); + let fanout_limit = min(requested_end, self.state.config.admin_fanout_limit as u64) as u32; + + let mut tasks = Vec::new(); + let mut join_set = tokio::task::JoinSet::new(); + for node in candidates { + let outbound_request = build_request(&request, request.get_ref().clone()); + join_set.spawn(async move { + forward_get_tasks(node, &outbound_request, 0, fanout_limit).await + }); + } + + while let Some(joined) = join_set.join_next().await { + match joined { + Ok(Ok(page)) => tasks.extend(page.tasks), + Ok(Err(status)) => return Err(status), + Err(err) => { + return Err(Status::internal(format!( + "proxy get_tasks join error: {err}" + ))); + } + } + } + + tasks.sort_by(|left, right| left.id.cmp(&right.id)); + let start = request + .get_ref() + .page + .saturating_mul(request.get_ref().page_size as u64) as usize; + let end = start.saturating_add(request.get_ref().page_size as usize); + let page_tasks = tasks + .into_iter() + .skip(start) + .take(end.saturating_sub(start)) + .collect(); + + Ok(Response::new(proto::TaskPage { + namespace: request.get_ref().namespace.clone(), + task: request.get_ref().task.clone(), + page: request.get_ref().page, + page_size: request.get_ref().page_size, + state: request.get_ref().state, + tasks: page_tasks, + })) + } + + async fn 
check_auth( + &self, + request: Request, + ) -> Result, Status> { + let Some(node) = self + .state + .preferred_nodes_by_tag("auth") + .await + .into_iter() + .next() + else { + return Err(unavailable()); + }; + forward_check_auth(node, &request).await + } +} + +#[tonic::async_trait] +impl Cluster for ProxyService { + async fn register_node( + &self, + request: Request, + ) -> Result, Status> { + require_cluster_auth(&request, &self.state.config.cluster_token)?; + + let node = request.get_ref(); + if node.node_id.trim().is_empty() { + return Err(Status::invalid_argument("node_id must not be empty")); + } + if node.node_id.contains('@') { + return Err(Status::invalid_argument("node_id must not contain '@'")); + } + + let session_id = druid::Druid::default().to_hex(); + let state = NodeState::new( + node.node_id.clone(), + node.advertise_addr.clone(), + node.tags.clone(), + node.tasks.clone(), + session_id.clone(), + now_unix(), + ) + .map_err(Status::invalid_argument)?; + self.state.upsert_node(state).await; + + Ok(Response::new(proto::NodeRegisterResponse { + session_id, + heartbeat_interval_seconds: self.state.config.heartbeat_interval_seconds, + ttl_seconds: self.state.config.node_ttl_seconds, + })) + } + + async fn heartbeat( + &self, + request: Request, + ) -> Result, Status> { + require_cluster_auth(&request, &self.state.config.cluster_token)?; + + let heartbeat = request.get_ref(); + let Some(node) = self.state.node_by_id(&heartbeat.node_id).await else { + return Err(Status::not_found("Unknown node")); + }; + if node.session_id != heartbeat.session_id { + return Err(Status::permission_denied("Invalid cluster session")); + } + + node.last_seen_unix.store(now_unix(), Ordering::Relaxed); + Ok(Response::new(proto::NodeHeartbeatAck { + server_time_unix: now_unix(), + })) + } + + async fn list_nodes( + &self, + request: Request, + ) -> Result, Status> { + require_cluster_auth(&request, &self.state.config.cluster_token)?; + + let nodes = self + .state + .healthy_nodes() + .await + .into_iter() + .map(|node| proto::NodeInfo { + node_id: node.node_id.clone(), + advertise_addr: node.advertise_addr.clone(), + tags: node.tags.clone(), + tasks: node.tasks.clone(), + last_seen_unix: node.last_seen_unix(), + healthy: true, + }) + .collect(); + Ok(Response::new(proto::NodeList { nodes })) + } +} diff --git a/engine/src/task_id.rs b/engine/src/task_id.rs new file mode 100644 index 0000000..d43d3b0 --- /dev/null +++ b/engine/src/task_id.rs @@ -0,0 +1,39 @@ +use enginelib::Identifier; +use tonic::Status; + +pub fn mint_task_instance_id(node_id: Option<&str>) -> String { + let random_hex = druid::Druid::default().to_hex(); + match node_id.filter(|value| !value.is_empty()) { + Some(node_id) => format!("{node_id}@{random_hex}"), + None => random_hex, + } +} + +pub fn parse_owner_node(task_instance_id: &str) -> Option<(&str, &str)> { + let (node_id, local_id) = task_instance_id.split_once('@')?; + if node_id.is_empty() || local_id.is_empty() || local_id.contains('@') { + return None; + } + Some((node_id, local_id)) +} + +pub fn parse_task_key(task_id: &str) -> Result { + let Some((namespace, task)) = task_id.split_once(':') else { + return Err(Status::invalid_argument( + "Invalid task ID format, expected 'namespace:task'", + )); + }; + + if namespace.is_empty() || task.is_empty() { + return Err(Status::invalid_argument( + "Invalid task ID format, expected 'namespace:task'", + )); + } + + Ok((namespace.to_string(), task.to_string())) +} + +pub fn parse_task_key_string(task_id: &str) -> Result { + let 
(namespace, task) = parse_task_key(task_id)?; + Ok(format!("{namespace}:{task}")) +} diff --git a/engine/tests/common/mod.rs b/engine/tests/common/mod.rs new file mode 100644 index 0000000..2307d9a --- /dev/null +++ b/engine/tests/common/mod.rs @@ -0,0 +1,322 @@ +use std::{error::Error, sync::Arc, time::Duration}; + +#[path = "../../src/bin/server.rs"] +mod server_bin; + +use engine::{ + cluster_client::{registration_from_api, spawn_registration}, + proto::{self, cluster_client::ClusterClient, engine_client::EngineClient}, + proxy_config::{ProxyConfigToml, RouteRuleToml}, + routing::ProxyState, + service::proxy::{ProxyService, spawn_reaper}, +}; +use enginelib::{ + Registry, + api::{EngineAPI, postcard}, + event::register_inventory_handlers, + events::ID, + task::{Task, Verifiable}, +}; +use serde::{Deserialize, Serialize}; +use tokio::{ + net::TcpListener, + sync::{RwLock, oneshot}, + task::JoinHandle, + time::{sleep, timeout}, +}; +use tokio_stream::wrappers::TcpListenerStream; +use tokio_util::sync::CancellationToken; +use tonic::{Request, transport::Server}; + +pub type BoxError = Box; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TestTask { + pub value: u32, + pub id: (String, String), +} + +impl Verifiable for TestTask { + fn verify(&self, bytes: Vec) -> bool { + postcard::from_bytes::(&bytes).is_ok() + } +} + +impl Task for TestTask { + fn get_id(&self) -> (String, String) { + self.id.clone() + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } + + fn to_bytes(&self) -> Vec { + postcard::to_allocvec(self).unwrap() + } + + fn from_bytes(&self, bytes: &[u8]) -> Box { + Box::new(postcard::from_bytes::(bytes).unwrap()) + } + + fn from_toml(&self, _d: String) -> Box { + Box::new(self.clone()) + } + + fn to_toml(&self) -> String { + String::new() + } +} + +pub struct TestBackend { + pub api: Arc>, + shutdown: Option>, + server_task: JoinHandle>, + registration_shutdown: Option, + registration_task: Option>, +} + +impl TestBackend { + pub async fn shutdown(mut self) { + if let Some(shutdown) = self.shutdown.take() { + let _ = shutdown.send(()); + } + if let Some(shutdown) = self.registration_shutdown.take() { + shutdown.cancel(); + } + if let Some(task) = self.registration_task.take() { + let _ = task.await; + } + let _ = self.server_task.await; + } +} + +pub struct TestProxy { + pub addr: String, + shutdown: Option>, + server_task: JoinHandle>, + reaper_task: JoinHandle<()>, +} + +impl TestProxy { + pub async fn shutdown(mut self) { + if let Some(shutdown) = self.shutdown.take() { + let _ = shutdown.send(()); + } + self.reaper_task.abort(); + let _ = self.server_task.await; + } +} + +pub async fn spawn_proxy() -> Result { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + let config = ProxyConfigToml { + listen: addr.to_string(), + cluster_token: "cluster-secret".into(), + node_ttl_seconds: 2, + heartbeat_interval_seconds: 1, + max_acquire_hops: 3, + admin_fanout_limit: 5000, + rules: vec![ + RouteRuleToml { + r#match: "ml:*".into(), + require_tags: vec!["gpu".into()], + }, + RouteRuleToml { + r#match: "*:*".into(), + require_tags: Vec::new(), + }, + ], + }; + let state = Arc::new(ProxyState::new(config)?); + let reaper_task = spawn_reaper(state.clone()); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let cluster_service = ProxyService::new(state.clone()); + let engine_service = ProxyService::new(state); + let server_task = tokio::spawn(async move { + Server::builder() + 
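+            // Expose both the Cluster (membership) and Engine (task routing) services on one listener.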
.add_service(proto::cluster_server::ClusterServer::new(cluster_service))
+            .add_service(proto::engine_server::EngineServer::new(engine_service))
+            .serve_with_incoming_shutdown(TcpListenerStream::new(listener), async {
+                let _ = shutdown_rx.await;
+            })
+            .await
+    });
+
+    Ok(TestProxy {
+        addr: format!("http://{addr}"),
+        shutdown: Some(shutdown_tx),
+        server_task,
+        reaper_task,
+    })
+}
+
+pub async fn spawn_backend(
+    node_id: &str,
+    tags: &[&str],
+    proxy_addr: &str,
+) -> Result<TestBackend, BoxError> {
+    let listener = TcpListener::bind("127.0.0.1:0").await?;
+    let addr = listener.local_addr()?;
+    let mut api = EngineAPI::test_default();
+    register_inventory_handlers(&mut api);
+    api.cfg.config_toml.host = addr.to_string();
+    api.cfg.config_toml.node_id = Some(node_id.to_string());
+    api.cfg.config_toml.advertise_addr = Some(format!("http://{addr}"));
+    api.cfg.config_toml.cluster_proxy_addr = Some(proxy_addr.to_string());
+    api.cfg.config_toml.cluster_token = Some("cluster-secret".into());
+    api.cfg.config_toml.node_tags = Some(tags.iter().map(|tag| (*tag).to_string()).collect());
+
+    register_task(&mut api, "dist", "work");
+    register_task(&mut api, "ml", "train");
+    register_task(&mut api, "node", node_id);
+
+    let api = Arc::new(RwLock::new(api));
+    let (shutdown_tx, shutdown_rx) = oneshot::channel();
+    let engine_service = server_bin::BackendEngineService::new(api.clone());
+    let server_task = tokio::spawn(async move {
+        Server::builder()
+            .add_service(proto::engine_server::EngineServer::new(engine_service))
+            .serve_with_incoming_shutdown(TcpListenerStream::new(listener), async {
+                let _ = shutdown_rx.await;
+            })
+            .await
+    });
+
+    let registration_shutdown = CancellationToken::new();
+    let registration_task = {
+        let api_guard = api.read().await;
+        registration_from_api(&api_guard)?
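+        // Block tail is an Option; `map` below spawns the registration/heartbeat loop when present.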
+ } + .map(|registration| spawn_registration(registration, registration_shutdown.clone())); + + wait_for_node(proxy_addr, node_id).await?; + + Ok(TestBackend { + api, + shutdown: Some(shutdown_tx), + server_task, + registration_shutdown: Some(registration_shutdown), + registration_task, + }) +} + +pub fn task_bytes(value: u32, namespace: &str, task: &str) -> Vec { + TestTask { + value, + id: ID(namespace, task), + } + .to_bytes() +} + +pub async fn engine_client( + addr: &str, +) -> Result, tonic::transport::Error> { + EngineClient::connect(addr.to_string()).await +} + +pub async fn cluster_client( + addr: &str, +) -> Result, tonic::transport::Error> { + ClusterClient::connect(addr.to_string()).await +} + +pub fn worker_request(message: T) -> Result, BoxError> { + let mut request = Request::new(message); + request.metadata_mut().insert("uid", "worker-1".parse()?); + request + .metadata_mut() + .insert("authorization", "worker-token".parse()?); + Ok(request) +} + +pub fn admin_request(message: T) -> Result, BoxError> { + let mut request = Request::new(message); + request + .metadata_mut() + .insert("authorization", "admin-token".parse()?); + Ok(request) +} + +pub fn cluster_request(message: T) -> Result, BoxError> { + let mut request = Request::new(message); + request + .metadata_mut() + .insert("authorization", "Bearer cluster-secret".parse()?); + Ok(request) +} + +pub async fn wait_for_node(proxy_addr: &str, node_id: &str) -> Result<(), BoxError> { + timeout(Duration::from_secs(5), async move { + loop { + let mut client = match cluster_client(proxy_addr).await { + Ok(client) => client, + Err(_) => { + sleep(Duration::from_millis(50)).await; + continue; + } + }; + let response = match client.list_nodes(cluster_request(proto::Empty {})?).await { + Ok(response) => response, + Err(_) => { + sleep(Duration::from_millis(50)).await; + continue; + } + }; + if response + .get_ref() + .nodes + .iter() + .any(|node| node.node_id == node_id) + { + return Ok::<(), BoxError>(()); + } + sleep(Duration::from_millis(50)).await; + } + }) + .await??; + Ok(()) +} + +pub async fn wait_for_node_count(proxy_addr: &str, expected: usize) -> Result<(), BoxError> { + timeout(Duration::from_secs(5), async move { + loop { + let mut client = match cluster_client(proxy_addr).await { + Ok(client) => client, + Err(_) => { + sleep(Duration::from_millis(50)).await; + continue; + } + }; + let response = match client.list_nodes(cluster_request(proto::Empty {})?).await { + Ok(response) => response, + Err(_) => { + sleep(Duration::from_millis(50)).await; + continue; + } + }; + if response.get_ref().nodes.len() == expected { + return Ok::<(), BoxError>(()); + } + sleep(Duration::from_millis(50)).await; + } + }) + .await??; + Ok(()) +} + +fn register_task(api: &mut EngineAPI, namespace: &str, task: &str) { + let id = ID(namespace, task); + api.task_registry.register( + Arc::new(TestTask { + value: 0, + id: id.clone(), + }), + id.clone(), + ); + api.task_queue.tasks.entry(id.clone()).or_default(); + api.executing_tasks.tasks.entry(id.clone()).or_default(); + api.solved_tasks.tasks.entry(id).or_default(); +} diff --git a/engine/tests/proxy_cluster.rs b/engine/tests/proxy_cluster.rs new file mode 100644 index 0000000..e17d7b0 --- /dev/null +++ b/engine/tests/proxy_cluster.rs @@ -0,0 +1,509 @@ +mod common; + +use common::{ + BoxError, admin_request, cluster_client, cluster_request, engine_client, spawn_backend, + spawn_proxy, task_bytes, wait_for_node_count, worker_request, +}; +use engine::proto; +use enginelib::{ + chrono::Utc, + 
events::ID, + task::{StoredExecutingTask, StoredTask}, +}; +use tokio::task::JoinSet; +use tonic::Code; + +#[tokio::test] +async fn create_task_distributes_across_nodes() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let backend_a = spawn_backend("node-a", &["auth", "control"], &proxy.addr).await?; + let backend_b = spawn_backend("node-b", &["gpu"], &proxy.addr).await?; + + let mut join_set = JoinSet::new(); + for value in 0..32 { + let proxy_addr = proxy.addr.clone(); + join_set.spawn(async move { + let mut client = engine_client(&proxy_addr).await.unwrap(); + client + .create_task( + worker_request(proto::Task { + id: String::new(), + task_id: "dist:work".into(), + task_payload: task_bytes(value, "dist", "work"), + payload: Vec::new(), + }) + .unwrap(), + ) + .await + }); + } + while let Some(result) = join_set.join_next().await { + result??; + } + + let queued_a = backend_a + .api + .read() + .await + .task_queue + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(); + let queued_b = backend_b + .api + .read() + .await + .task_queue + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(); + assert!(queued_a > 0, "expected node-a to receive tasks"); + assert!(queued_b > 0, "expected node-b to receive tasks"); + + backend_a.shutdown().await; + backend_b.shutdown().await; + proxy.shutdown().await; + Ok(()) +} + +#[tokio::test] +async fn aquire_task_retries_until_a_node_has_work() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let backend_a = spawn_backend("node-a", &["auth", "control"], &proxy.addr).await?; + let backend_b = spawn_backend("node-b", &["gpu"], &proxy.addr).await?; + + backend_b + .api + .write() + .await + .task_queue + .tasks + .get_mut(&ID("dist", "work")) + .unwrap() + .push(StoredTask { + id: "node-b@queued-1".into(), + bytes: task_bytes(1, "dist", "work"), + }); + + let mut client = engine_client(&proxy.addr).await?; + let response = client + .aquire_task(worker_request(proto::TaskRequest { + task_id: "dist:work".into(), + })?) + .await? + .into_inner(); + + assert_eq!(response.id, "node-b@queued-1"); + assert_eq!( + backend_b + .api + .read() + .await + .executing_tasks + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(), + 1 + ); + + backend_a.shutdown().await; + backend_b.shutdown().await; + proxy.shutdown().await; + Ok(()) +} + +#[tokio::test] +async fn publish_task_routes_by_owner_prefix() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let backend_a = spawn_backend("node-a", &["auth", "control"], &proxy.addr).await?; + let backend_b = spawn_backend("node-b", &["gpu"], &proxy.addr).await?; + + let mut client = engine_client(&proxy.addr).await?; + let created = client + .create_task(worker_request(proto::Task { + id: String::new(), + task_id: "dist:work".into(), + task_payload: task_bytes(7, "dist", "work"), + payload: Vec::new(), + })?) + .await? + .into_inner(); + let acquired = client + .aquire_task(worker_request(proto::TaskRequest { + task_id: "dist:work".into(), + })?) + .await? + .into_inner(); + assert_eq!(created.id, acquired.id); + + client + .publish_task(worker_request(proto::Task { + id: acquired.id.clone(), + task_id: acquired.task_id.clone(), + task_payload: task_bytes(99, "dist", "work"), + payload: Vec::new(), + })?) 
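+        // The proxy must route this publish to whichever node minted `acquired.id` (owner prefix).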
+ .await?; + + let owner = acquired.id.split('@').next().unwrap(); + let solved_a = backend_a + .api + .read() + .await + .solved_tasks + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(); + let solved_b = backend_b + .api + .read() + .await + .solved_tasks + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(); + match owner { + "node-a" => { + assert_eq!(solved_a, 1); + assert_eq!(solved_b, 0); + } + "node-b" => { + assert_eq!(solved_a, 0); + assert_eq!(solved_b, 1); + } + _ => panic!("unexpected owner {owner}"), + } + + backend_a.shutdown().await; + backend_b.shutdown().await; + proxy.shutdown().await; + Ok(()) +} + +#[tokio::test] +async fn delete_task_routes_by_owner_prefix() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let backend_a = spawn_backend("node-a", &["auth", "control"], &proxy.addr).await?; + let backend_b = spawn_backend("node-b", &["gpu"], &proxy.addr).await?; + + let mut client = engine_client(&proxy.addr).await?; + let created = client + .create_task(worker_request(proto::Task { + id: String::new(), + task_id: "dist:work".into(), + task_payload: task_bytes(8, "dist", "work"), + payload: Vec::new(), + })?) + .await? + .into_inner(); + + client + .delete_task(admin_request(proto::TaskSelector { + state: proto::TaskState::Queued as i32, + namespace: "dist".into(), + task: "work".into(), + id: created.id.clone(), + })?) + .await?; + + let owner = created.id.split('@').next().unwrap(); + let queued_a = backend_a + .api + .read() + .await + .task_queue + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(); + let queued_b = backend_b + .api + .read() + .await + .task_queue + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(); + match owner { + "node-a" => assert_eq!(queued_a + queued_b, 0), + "node-b" => assert_eq!(queued_a + queued_b, 0), + _ => panic!("unexpected owner {owner}"), + } + + backend_a.shutdown().await; + backend_b.shutdown().await; + proxy.shutdown().await; + Ok(()) +} + +#[tokio::test] +async fn aquire_task_reg_returns_union_of_registered_tasks() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let backend_a = spawn_backend("node-a", &["auth", "control"], &proxy.addr).await?; + let backend_b = spawn_backend("node-b", &["gpu"], &proxy.addr).await?; + + let mut client = engine_client(&proxy.addr).await?; + let mut tasks = client + .aquire_task_reg(worker_request(proto::Empty {})?) + .await? 
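+        // Union of both nodes' task registries; sorted below for a stable assertion.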
+ .into_inner() + .tasks; + tasks.sort(); + + assert_eq!( + tasks, + vec![ + "dist:work".to_string(), + "ml:train".to_string(), + "node:node-a".to_string(), + "node:node-b".to_string(), + ] + ); + + backend_a.shutdown().await; + backend_b.shutdown().await; + proxy.shutdown().await; + Ok(()) +} + +#[tokio::test] +async fn get_tasks_fanout_merges_and_paginates() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let backend_a = spawn_backend("node-a", &["auth", "control"], &proxy.addr).await?; + let backend_b = spawn_backend("node-b", &["gpu"], &proxy.addr).await?; + + backend_a + .api + .write() + .await + .task_queue + .tasks + .get_mut(&ID("dist", "work")) + .unwrap() + .extend([ + StoredTask { + id: "node-a@0001".into(), + bytes: task_bytes(1, "dist", "work"), + }, + StoredTask { + id: "node-a@0003".into(), + bytes: task_bytes(3, "dist", "work"), + }, + ]); + backend_b + .api + .write() + .await + .task_queue + .tasks + .get_mut(&ID("dist", "work")) + .unwrap() + .extend([ + StoredTask { + id: "node-b@0002".into(), + bytes: task_bytes(2, "dist", "work"), + }, + StoredTask { + id: "node-b@0004".into(), + bytes: task_bytes(4, "dist", "work"), + }, + ]); + + let mut client = engine_client(&proxy.addr).await?; + let page0 = client + .get_tasks(admin_request(proto::TaskPageRequest { + namespace: "dist".into(), + task: "work".into(), + page: 0, + page_size: 2, + state: proto::TaskState::Queued as i32, + })?) + .await? + .into_inner(); + let page1 = client + .get_tasks(admin_request(proto::TaskPageRequest { + namespace: "dist".into(), + task: "work".into(), + page: 1, + page_size: 2, + state: proto::TaskState::Queued as i32, + })?) + .await? + .into_inner(); + + assert_eq!( + page0 + .tasks + .iter() + .map(|task| task.id.clone()) + .collect::>(), + vec!["node-a@0001".to_string(), "node-a@0003".to_string()] + ); + assert_eq!( + page1 + .tasks + .iter() + .map(|task| task.id.clone()) + .collect::>(), + vec!["node-b@0002".to_string(), "node-b@0004".to_string()] + ); + + backend_a.shutdown().await; + backend_b.shutdown().await; + proxy.shutdown().await; + Ok(()) +} + +#[tokio::test] +async fn cluster_membership_expires_after_heartbeats_stop() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let backend_a = spawn_backend("node-a", &["auth", "control"], &proxy.addr).await?; + let backend_b = spawn_backend("node-b", &["gpu"], &proxy.addr).await?; + + wait_for_node_count(&proxy.addr, 2).await?; + backend_b.shutdown().await; + wait_for_node_count(&proxy.addr, 1).await?; + + let mut client = cluster_client(&proxy.addr).await?; + let nodes = client + .list_nodes(cluster_request(proto::Empty {})?) + .await? + .into_inner() + .nodes; + assert_eq!(nodes.len(), 1); + assert_eq!(nodes[0].node_id, "node-a"); + + backend_a.shutdown().await; + proxy.shutdown().await; + Ok(()) +} + +#[tokio::test] +async fn reregister_replaces_prior_session_id() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let mut client = cluster_client(&proxy.addr).await?; + + let first = client + .register_node(cluster_request(proto::NodeRegisterRequest { + node_id: "node-z".into(), + advertise_addr: "http://127.0.0.1:59999".into(), + tags: vec!["gpu".into()], + tasks: vec!["dist:work".into()], + })?) + .await? + .into_inner(); + let second = client + .register_node(cluster_request(proto::NodeRegisterRequest { + node_id: "node-z".into(), + advertise_addr: "http://127.0.0.1:59999".into(), + tags: vec!["gpu".into()], + tasks: vec!["dist:work".into()], + })?) + .await? 
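+        // Re-registering the same node_id should mint a fresh session id; the old one is rejected below.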
+ .into_inner(); + + assert_ne!(first.session_id, second.session_id); + + let old_session = client + .heartbeat(cluster_request(proto::NodeHeartbeat { + node_id: "node-z".into(), + session_id: first.session_id, + })?) + .await + .unwrap_err(); + assert_eq!(old_session.code(), Code::PermissionDenied); + + client + .heartbeat(cluster_request(proto::NodeHeartbeat { + node_id: "node-z".into(), + session_id: second.session_id, + })?) + .await?; + + proxy.shutdown().await; + Ok(()) +} + +#[tokio::test] +async fn legacy_ids_use_broadcast_fallback_for_publish_and_delete() -> Result<(), BoxError> { + let proxy = spawn_proxy().await?; + let backend_a = spawn_backend("node-a", &["auth", "control"], &proxy.addr).await?; + let backend_b = spawn_backend("node-b", &["gpu"], &proxy.addr).await?; + + backend_b + .api + .write() + .await + .executing_tasks + .tasks + .get_mut(&ID("dist", "work")) + .unwrap() + .push(StoredExecutingTask { + bytes: task_bytes(11, "dist", "work"), + id: "legacy-1".into(), + user_id: "worker-1".into(), + given_at: Utc::now(), + }); + backend_b + .api + .write() + .await + .task_queue + .tasks + .get_mut(&ID("dist", "work")) + .unwrap() + .push(StoredTask { + id: "legacy-2".into(), + bytes: task_bytes(12, "dist", "work"), + }); + + let mut client = engine_client(&proxy.addr).await?; + client + .publish_task(worker_request(proto::Task { + id: "legacy-1".into(), + task_id: "dist:work".into(), + task_payload: task_bytes(111, "dist", "work"), + payload: Vec::new(), + })?) + .await?; + client + .delete_task(admin_request(proto::TaskSelector { + state: proto::TaskState::Queued as i32, + namespace: "dist".into(), + task: "work".into(), + id: "legacy-2".into(), + })?) + .await?; + + let solved = backend_b + .api + .read() + .await + .solved_tasks + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(); + let queued = backend_b + .api + .read() + .await + .task_queue + .tasks + .get(&ID("dist", "work")) + .unwrap() + .len(); + assert_eq!(solved, 1); + assert_eq!(queued, 0); + + backend_a.shutdown().await; + backend_b.shutdown().await; + proxy.shutdown().await; + Ok(()) +} diff --git a/enginelib/macros/src/lib.rs b/enginelib/macros/src/lib.rs index 5730298..bf2f01a 100644 --- a/enginelib/macros/src/lib.rs +++ b/enginelib/macros/src/lib.rs @@ -47,11 +47,13 @@ pub fn module(_attr: TokenStream, item: TokenStream) -> TokenStream { env!("CARGO_PKG_NAME"), );), ); + item_fn.block.stmts.insert( + 0, + parse_quote!(::enginelib::api::EngineAPI::setup_logger();), + ); item_fn - .block - .stmts - .insert(0, parse_quote!(::enginelib::api::EngineAPI::setup_logger();)); - item_fn.attrs.push(parse_quote!(#[unsafe(export_name="run")])); + .attrs + .push(parse_quote!(#[unsafe(export_name="run")])); quote!(#item_fn).into() } @@ -181,7 +183,10 @@ impl Parse for EventHandlerArgs { } let namespace = namespace.ok_or_else(|| { - syn::Error::new(proc_macro2::Span::call_site(), "missing `namespace = \"...\"`") + syn::Error::new( + proc_macro2::Span::call_site(), + "missing `namespace = \"...\"`", + ) })?; let name = name.ok_or_else(|| { syn::Error::new(proc_macro2::Span::call_site(), "missing `name = \"...\"`") diff --git a/enginelib/src/api.rs b/enginelib/src/api.rs index 368eebd..0f83ba2 100644 --- a/enginelib/src/api.rs +++ b/enginelib/src/api.rs @@ -153,20 +153,20 @@ impl EngineAPI { static INIT: OnceLock<()> = OnceLock::new(); INIT.get_or_init(|| { - #[cfg(debug_assertions)] + #[cfg(debug_assertions)] let _ = tracing_subscriber::FmtSubscriber::builder() - // all spans/events with a level higher than TRACE 
(e.g, debug, info, warn, etc.) - // will be written to stdout. - .with_max_level(Level::DEBUG) - // builds the subscriber. - .try_init(); - #[cfg(not(debug_assertions))] + // all spans/events with a level higher than TRACE (e.g, debug, info, warn, etc.) + // will be written to stdout. + .with_max_level(Level::DEBUG) + // builds the subscriber. + .try_init(); + #[cfg(not(debug_assertions))] let _ = tracing_subscriber::FmtSubscriber::builder() - // all spans/events with a level higher than TRACE (e.g, debug, info, warn, etc.) - // will be written to stdout. - .with_max_level(Level::INFO) - // builds the subscriber. - .try_init(); + // all spans/events with a level higher than TRACE (e.g, debug, info, warn, etc.) + // will be written to stdout. + .with_max_level(Level::INFO) + // builds the subscriber. + .try_init(); }); } } diff --git a/enginelib/src/config.rs b/enginelib/src/config.rs index 4e2a18a..b5963bc 100644 --- a/enginelib/src/config.rs +++ b/enginelib/src/config.rs @@ -1,4 +1,4 @@ -use std::{fs, io::Error, u32}; +use std::{fmt, fs, io::Error, u32}; use serde::{Deserialize, Serialize}; use tracing::{error, instrument}; @@ -15,7 +15,7 @@ fn default_pagination_limit() -> u32 { u32::MAX } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone)] pub struct ConfigTomlServer { #[serde(default)] pub cgrpc_token: Option, // Administrator Token, used to invoke cgrpc reqs. If not preset will default to no protection. @@ -25,7 +25,40 @@ pub struct ConfigTomlServer { pub clean_tasks: u64, #[serde(default = "default_pagination_limit")] pub pagination_limit: u32, + #[serde(default)] + pub node_id: Option, + #[serde(default)] + pub advertise_addr: Option, + #[serde(default)] + pub cluster_proxy_addr: Option, + #[serde(default)] + pub cluster_token: Option, + #[serde(default)] + pub node_tags: Option>, } + +impl fmt::Debug for ConfigTomlServer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ConfigTomlServer") + .field( + "cgrpc_token", + &self.cgrpc_token.as_ref().map(|_| ""), + ) + .field("host", &self.host) + .field("clean_tasks", &self.clean_tasks) + .field("pagination_limit", &self.pagination_limit) + .field("node_id", &self.node_id) + .field("advertise_addr", &self.advertise_addr) + .field("cluster_proxy_addr", &self.cluster_proxy_addr) + .field( + "cluster_token", + &self.cluster_token.as_ref().map(|_| ""), + ) + .field("node_tags", &self.node_tags) + .finish() + } +} + impl Default for ConfigTomlServer { fn default() -> Self { Self { @@ -33,6 +66,11 @@ impl Default for ConfigTomlServer { cgrpc_token: None, clean_tasks: 60, pagination_limit: u32::MAX, + node_id: None, + advertise_addr: None, + cluster_proxy_addr: None, + cluster_token: None, + node_tags: None, } } } diff --git a/enginelib/src/events/mod.rs b/enginelib/src/events/mod.rs index dbc9f54..acec1a9 100644 --- a/enginelib/src/events/mod.rs +++ b/enginelib/src/events/mod.rs @@ -22,7 +22,12 @@ impl Events { auth_event::AuthEvent::check(api, uid, challenge, db) } - pub fn CheckAdminAuth(api: &mut EngineAPI, payload: String, target: Identifier, db: Db) -> bool { + pub fn CheckAdminAuth( + api: &mut EngineAPI, + payload: String, + target: Identifier, + db: Db, + ) -> bool { admin_auth_event::AdminAuthEvent::check(api, payload, target, db) } diff --git a/rfc/rfc1004.md b/rfc/rfc1004.md new file mode 100644 index 0000000..0979483 --- /dev/null +++ b/rfc/rfc1004.md @@ -0,0 +1,157 @@ +# RFC 1004: Proxy and Clustering + +## Summary + +Add a gRPC proxy in front of engine backends 
plus a lightweight cluster membership protocol so:
+
+- clients and workers talk to one stable endpoint,
+- tasks can be routed to capability-specific nodes,
+- task instances can be distributed across multiple nodes,
+- `PublishTask` and `DeleteTask` can route back to the owning node without sticky proxy state.
+
+The design keeps backend storage shard-local. Each backend keeps its own sled DB and dynamically registers itself with one or more stateless proxies.
+
+## Motivation
+
+The current engine exposes a single backend directly. That creates three constraints:
+
+- all workers must know about a specific backend,
+- there is no built-in way to spread the same `namespace:task` across multiple nodes,
+- follow-up RPCs like `PublishTask` depend on whichever backend created the task.
+
+By making `Task.id` self-routing and adding in-memory cluster membership, the proxy can stay stateless while still routing task lifecycle RPCs correctly.
+
+## Design
+
+### Task instance IDs
+
+`Task.id` is now the task instance identifier and, when clustering is enabled, is minted as:
+
+```text
+{node_id}@{random_hex}
+```
+
+Legacy IDs without `@` remain supported. The proxy falls back to broadcast routing for `PublishTask` and `DeleteTask` when it sees a legacy ID.
+
+### Cluster service
+
+`engine/proto/engine.proto` now contains a `Cluster` gRPC service with:
+
+- `RegisterNode`
+- `Heartbeat`
+- `ListNodes`
+
+These RPCs are used only between the proxy and backends. All three require an `authorization: Bearer <token>` header carrying the shared cluster token.
+
+### Backend registration
+
+On startup, a backend:
+
+1. initializes `EngineAPI`,
+2. loads tasks and modules,
+3. snapshots its local task registry into `namespace:task` strings,
+4. registers with the proxy,
+5. sends heartbeats on the server-provided interval.
+
+If heartbeats fail, the backend reconnects and re-registers to obtain a fresh session ID.
+
+### Proxy routing
+
+The proxy stores node membership entirely in memory and applies first-match routing rules from `proxy.toml`.
+
+Task-specific RPCs:
+
+- filter to healthy nodes,
+- require the node to advertise the task key,
+- require the node to satisfy the matched rule tags.
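+
+As a rough sketch, the candidate filter amounts to the following. The `NodeState` and `RouteRule` shapes here are simplified stand-ins for the real `routing` module types, not the exact implementation:
+
+```rust
+struct NodeState {
+    tags: Vec<String>,
+    tasks: Vec<String>, // advertised "namespace:task" keys
+    healthy: bool,
+}
+
+struct RouteRule {
+    require_tags: Vec<String>, // from the first matching `proxy.toml` rule
+}
+
+/// Keep only healthy nodes that advertise `task_key` and carry every
+/// tag required by the matched routing rule.
+fn candidates<'a>(
+    nodes: &'a [NodeState],
+    rule: &RouteRule,
+    task_key: &str,
+) -> Vec<&'a NodeState> {
+    nodes
+        .iter()
+        .filter(|node| node.healthy)
+        .filter(|node| node.tasks.iter().any(|t| t == task_key))
+        .filter(|node| rule.require_tags.iter().all(|tag| node.tags.contains(tag)))
+        .collect()
+}
+```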
+
+Routing strategies:
+
+- `CreateTask`: power-of-two choices on in-flight create counts (a sketch appears at the end of this RFC)
+- `AquireTask`: random candidate order, retry on `not_found`
+- `PublishTask`: owner-node route when `Task.id` is prefixed, otherwise broadcast fallback
+- `DeleteTask`: owner-node route when selector `id` is prefixed, otherwise broadcast fallback
+- `GetTasks`: fan-out, merge by lexicographic task ID, paginate after merge
+- `AquireTaskReg`: sorted unique union of registered tasks
+- `CheckAuth`: prefer nodes tagged `auth`
+- `cgrpc`: prefer nodes tagged `control`
+
+### Config
+
+Backend `config.toml` adds:
+
+- `node_id`
+- `advertise_addr`
+- `cluster_proxy_addr`
+- `cluster_token`
+- `node_tags`
+
+Proxy `proxy.toml` contains:
+
+- `listen`
+- `cluster_token`
+- `node_ttl_seconds`
+- `heartbeat_interval_seconds`
+- `max_acquire_hops`
+- `admin_fanout_limit`
+- ordered `rules`
+
+Example proxy config:
+
+```toml
+listen = "0.0.0.0:50052"
+cluster_token = "cluster-secret"
+node_ttl_seconds = 15
+heartbeat_interval_seconds = 5
+max_acquire_hops = 3
+admin_fanout_limit = 5000
+
+[[rules]]
+match = "ml:*"
+require_tags = ["gpu"]
+
+[[rules]]
+match = "*:*"
+require_tags = []
+```
+
+Example backend config additions:
+
+```toml
+host = "127.0.0.1:50051"
+node_id = "gpu-1"
+advertise_addr = "http://127.0.0.1:50051"
+cluster_proxy_addr = "http://127.0.0.1:50052"
+cluster_token = "cluster-secret"
+node_tags = ["gpu", "control"]
+```
+
+## Backwards compatibility
+
+This RFC is additive for clients:
+
+- the existing `Engine` service surface is unchanged,
+- task IDs remain opaque to workers,
+- legacy IDs without `@` still work through proxy broadcast fallback.
+
+Proxy membership is intentionally in-memory only. If a backend disappears, its local tasks are unavailable until it returns.
+
+## Alternatives
+
+1. Keep clients pinned to specific backends.
+   This avoids a proxy, but it does not solve distribution or owner routing.
+
+2. Store proxy membership in shared durable storage.
+   This would support persistent coordination, but it is not required for the current shard-local design.
+
+3. Add distributed task migration or failover.
+   This would improve recovery, but it is explicitly out of scope for this RFC.
+
+## Implementation plan
+
+1. Refactor backend `Engine` RPC handling into reusable library services.
+2. Add proxy config loading, in-memory node membership, and cluster-authenticated registration.
+3. Introduce a proxy binary implementing both `Engine` and `Cluster`.
+4. Update backend startup to validate cluster config and start registration plus heartbeats.
+5. Change backend task lifecycle handling so `Task.id` is always the instance ID.
+6. Add integration tests for distribution, routing, pagination, membership expiry, re-registration, and legacy fallback.
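+
+As a non-normative illustration of the `CreateTask` strategy above, a power-of-two-choices pick over in-flight create counts could look like this. The `select_create_candidate` name mirrors the helper in the proxy code; the `Node` shape, the counter, and the body are illustrative assumptions (using `rand` 0.8), not the exact implementation:
+
+```rust
+use std::sync::{
+    atomic::{AtomicUsize, Ordering},
+    Arc,
+};
+
+use rand::seq::SliceRandom;
+
+// Illustrative stand-in for the proxy's per-node state.
+struct Node {
+    in_flight_creates: AtomicUsize,
+}
+
+/// Sample two candidates at random and keep the one with fewer
+/// in-flight CreateTask requests; degrades gracefully to the single
+/// candidate (or `None`) when fewer than two are available.
+fn select_create_candidate(candidates: Vec<Arc<Node>>) -> Option<Arc<Node>> {
+    let mut rng = rand::thread_rng();
+    candidates
+        .choose_multiple(&mut rng, 2)
+        .cloned()
+        .min_by_key(|node| node.in_flight_creates.load(Ordering::Relaxed))
+}
+```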