diff --git a/.gitignore b/.gitignore index 9aecd384..479b6815 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,5 @@ testdata/clickbench/* !testdata/clickbench/queries src/observability/gen/target/* src/observability/gen/Cargo.lock +src/worker/gen/target/* +src/worker/gen/Cargo.lock diff --git a/benchmarks/cdk/bin/worker.rs b/benchmarks/cdk/bin/worker.rs index 978522de..8ab21c1a 100644 --- a/benchmarks/cdk/bin/worker.rs +++ b/benchmarks/cdk/bin/worker.rs @@ -119,7 +119,7 @@ async fn main() -> Result<(), Box> { let mut errors = vec![]; for worker_url in worker_resolver.get_urls().map_err(err)? { if let Err(err) = channel_resolver - .get_flight_client_for_url(&worker_url) + .get_worker_client_for_url(&worker_url) .await { errors.push(err.to_string()) @@ -208,7 +208,7 @@ async fn main() -> Result<(), Box> { let ec2_worker_resolver = Arc::new(Ec2WorkerResolver::new()); let grpc_server = Server::builder() .add_service(worker.with_observability_service(ec2_worker_resolver)) - .add_service(worker.into_flight_server()) + .add_service(worker.into_worker_server()) .serve(WORKER_ADDR.parse()?); info!("Started listener HTTP server in {LISTENER_ADDR}"); diff --git a/benchmarks/src/run.rs b/benchmarks/src/run.rs index 6b40c3ad..3b5c26d8 100644 --- a/benchmarks/src/run.rs +++ b/benchmarks/src/run.rs @@ -163,7 +163,7 @@ impl RunOpt { let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener); Ok::<_, Box>( Server::builder() - .add_service(Worker::default().into_flight_server()) + .add_service(Worker::default().into_worker_server()) .serve_with_incoming(incoming) .await?, ) diff --git a/console/examples/cluster.rs b/console/examples/cluster.rs index cb541240..821a24db 100644 --- a/console/examples/cluster.rs +++ b/console/examples/cluster.rs @@ -56,7 +56,7 @@ async fn main() -> Result<(), Box> { Server::builder() .add_service(worker.with_observability_service(resolver)) - .add_service(worker.into_flight_server()) + .add_service(worker.into_worker_server()) .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) .await .expect("worker server failed"); diff --git a/console/examples/console_worker.rs b/console/examples/console_worker.rs index 80178e4d..01cc1dca 100644 --- a/console/examples/console_worker.rs +++ b/console/examples/console_worker.rs @@ -33,7 +33,7 @@ async fn main() -> Result<(), Box> { Server::builder() .add_service(worker.with_observability_service(localhost_resolver)) - .add_service(worker.into_flight_server()) + .add_service(worker.into_worker_server()) .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), args.port)) .await?; diff --git a/console/src/app.rs b/console/src/app.rs index 04ec2b19..e12cb24b 100644 --- a/console/src/app.rs +++ b/console/src/app.rs @@ -156,7 +156,7 @@ impl App { let mut query_ids = HashSet::new(); for worker in &self.workers { for task in &worker.tasks { - if let Some(sk) = &task.stage_key { + if let Some(sk) = &task.task_key { query_ids.insert(&sk.query_id); } } diff --git a/console/src/ui/worker.rs b/console/src/ui/worker.rs index 9bee0d9e..d22c9c48 100644 --- a/console/src/ui/worker.rs +++ b/console/src/ui/worker.rs @@ -94,12 +94,12 @@ fn render_active_tasks(frame: &mut Frame, area: Rect, app: &mut App, idx: usize) let mut task_indices: Vec = (0..worker.tasks.len()).collect(); task_indices.sort_by(|&a, &b| { let dur_a = worker.tasks[a] - .stage_key + .task_key .as_ref() .map(|sk| worker.task_duration(&sk.query_id, sk.stage_id, sk.task_number)) .unwrap_or_default(); let dur_b = worker.tasks[b] - .stage_key + .task_key .as_ref() .map(|sk| worker.task_duration(&sk.query_id, sk.stage_id, sk.task_number)) .unwrap_or_default(); @@ -110,7 +110,7 @@ fn render_active_tasks(frame: &mut Frame, area: Rect, app: &mut App, idx: usize) .iter() .map(|&i| { let task = &worker.tasks[i]; - if let Some(sk) = &task.stage_key { + if let Some(sk) = &task.task_key { let query_hex = hex_prefix(&sk.query_id, 8); let duration = worker.task_duration(&sk.query_id, sk.stage_id, sk.task_number); let dur_str = format_duration(duration); diff --git a/console/src/worker.rs b/console/src/worker.rs index ffc4c14e..be39315e 100644 --- a/console/src/worker.rs +++ b/console/src/worker.rs @@ -167,7 +167,7 @@ impl WorkerConn { let new_task_keys: HashSet = new_tasks .iter() .filter_map(|t| { - t.stage_key + t.task_key .as_ref() .map(|sk| (sk.query_id.clone(), sk.stage_id, sk.task_number)) }) @@ -176,7 +176,7 @@ impl WorkerConn { // Detect completed tasks: tasks that were running but disappeared for old_task in &self.tasks { if old_task.status == TaskStatus::Running as i32 { - if let Some(sk) = &old_task.stage_key { + if let Some(sk) = &old_task.task_key { let key = (sk.query_id.clone(), sk.stage_id, sk.task_number); if !new_task_keys.contains(&key) { // Task disappeared — assume completed @@ -208,7 +208,7 @@ impl WorkerConn { // Track first_seen for new tasks let now = Instant::now(); for task in &new_tasks { - if let Some(sk) = &task.stage_key { + if let Some(sk) = &task.task_key { let key = (sk.query_id.clone(), sk.stage_id, sk.task_number); self.task_first_seen.entry(key).or_insert(now); } @@ -226,7 +226,7 @@ impl WorkerConn { let mut has_running = false; for task in &self.tasks { - if let Some(sk) = &task.stage_key { + if let Some(sk) = &task.task_key { current_query_ids.insert(sk.query_id.clone()); if task.status == TaskStatus::Running as i32 { has_running = true; @@ -338,7 +338,7 @@ impl WorkerConn { let ids: HashSet<_> = self .tasks .iter() - .filter_map(|t| t.stage_key.as_ref().map(|sk| &sk.query_id)) + .filter_map(|t| t.task_key.as_ref().map(|sk| &sk.query_id)) .collect(); ids.len() } diff --git a/docs/source/user-guide/channel-resolver.md b/docs/source/user-guide/channel-resolver.md index 749cabf0..9fcdac18 100644 --- a/docs/source/user-guide/channel-resolver.md +++ b/docs/source/user-guide/channel-resolver.md @@ -2,10 +2,10 @@ This trait is optional—a sensible default implementation exists that handles most use cases. -The `ChannelResolver` trait controls how Distributed DataFusion builds Arrow Flight clients backed by +The `ChannelResolver` trait controls how Distributed DataFusion builds Worker gRPC clients backed by [Tonic](https://github.com/hyperium/tonic) channels for worker URLs. -The default implementation connects to each URL, builds an Arrow Flight client, and caches it for reuse on +The default implementation connects to each URL, builds a Worker client, and caches it for reuse on subsequent requests to the same URL. ## Providing your own ChannelResolver @@ -15,7 +15,7 @@ For providing your own implementation, you'll need to take into account the foll - You will need to provide your own implementation in two places: - in the `SessionContext` that first initiates and plans your queries. - while instantiating the `Worker` with the `from_session_builder()` constructor. -- If building from scratch, ensure Arrow Flight clients are reused across requests rather than recreated each time. +- If building from scratch, ensure Worker clients are reused across requests rather than recreated each time. - You can extend `DefaultChannelResolver` as a foundation for custom implementations. This automatically handles gRPC channel reuse. @@ -25,11 +25,11 @@ struct CustomChannelResolver; #[async_trait] impl ChannelResolver for CustomChannelResolver { - async fn get_flight_client_for_url( + async fn get_worker_client_for_url( &self, url: &Url, - ) -> Result, DataFusionError> { - // Build a custom FlightServiceClient wrapped with tower + ) -> Result, DataFusionError> { + // Build a custom WorkerServiceClient wrapped with tower // layers or something similar. todo!() } @@ -37,7 +37,7 @@ impl ChannelResolver for CustomChannelResolver { async fn main() { // Build a single instance for your application's lifetime - // to enable Arrow Flight client reuse across queries. + // to enable Worker client reuse across queries. let channel_resolver = CustomChannelResolver; let state = SessionStateBuilder::new() @@ -56,10 +56,10 @@ async fn main() { } }); Server::builder() - .add_service(endpoint.into_flight_server()) + .add_service(endpoint.into_worker_server()) .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000)) .await?; Ok(()) } -``` \ No newline at end of file +``` diff --git a/docs/source/user-guide/concepts.md b/docs/source/user-guide/concepts.md index 88cbd0aa..d66235ad 100644 --- a/docs/source/user-guide/concepts.md +++ b/docs/source/user-guide/concepts.md @@ -32,9 +32,9 @@ a fully formed physical plan and injects the appropriate nodes to execute the qu It builds the distributed plan from bottom to top, injecting network boundaries at appropriate locations based on the nodes present in the original plan. -## [Worker](https://github.com/datafusion-contrib/datafusion-distributed/blob/main/src/flight_service/worker.rs) +## [Worker](https://github.com/datafusion-contrib/datafusion-distributed/blob/main/src/worker/worker_service.rs) -Arrow Flight server implementation that integrates with the Tonic ecosystem and listens to serialized plans that get +gRPC server implementation that integrates with the Tonic ecosystem and listens to serialized plans that get executed over the wire. Users are expected to build these and spawn them in ports so that the network boundary nodes can reach them. diff --git a/docs/source/user-guide/getting-started.md b/docs/source/user-guide/getting-started.md index 33f5fa02..9b7dd089 100644 --- a/docs/source/user-guide/getting-started.md +++ b/docs/source/user-guide/getting-started.md @@ -11,7 +11,7 @@ Rather than imposing constraints on your infrastructure or query serving pattern allows you to plug in your own networking stack and spawn your own gRPC servers that act as workers in the cluster. This project heavily relies on the [Tonic](https://github.com/hyperium/tonic) ecosystem for the networking layer. -Users of this library are responsible for building their own Tonic server, adding the Arrow Flight distributed +Users of this library are responsible for building their own Tonic server, adding the distributed DataFusion service to it and spawning it on a port so that it can be reached by other workers in the cluster. A very basic example of this would be: @@ -21,7 +21,7 @@ async fn main() -> Result<(), Box> { let worker = Worker::default(); Server::builder() - .add_service(worker.into_flight_server()) + .add_service(worker.into_worker_server()) .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000)) .await?; @@ -74,7 +74,7 @@ This will leave a DataFusion `SessionContext` ready for executing distributed qu Depending on your needs, your setup can get more complicated, for example: - You may want to resolve worker URLs dynamically using the Kubernetes API. -- You may want to wrap the Arrow Flight clients that connect workers with an observability layer. +- You may want to wrap the Worker clients that connect workers with an observability layer. - You may want to be able to execute your own custom ExecutionPlans in a distributed manner. - etc... diff --git a/docs/source/user-guide/worker.md b/docs/source/user-guide/worker.md index e57bc329..91fa3bf9 100644 --- a/docs/source/user-guide/worker.md +++ b/docs/source/user-guide/worker.md @@ -1,18 +1,18 @@ # Spawn a Worker -The `Worker` is a gRPC server implementing the Arrow Flight protocol for distributed query execution. Worker nodes +The `Worker` is a gRPC server that handles distributed query execution. Worker nodes run these endpoints to receive execution plans, execute them, and stream results back. ## Overview The `Worker` is the core worker component in Distributed DataFusion. It: -- Receives serialized execution plans via Arrow Flight's `do_get` method +- Receives serialized execution plans via gRPC - Deserializes plans using protobuf and user-provided codecs - Executes plans using the local DataFusion runtime -- Streams results back as Arrow record batches through the gRPC Arrow Flight interface +- Streams results back as Arrow record batches through the gRPC interface -## Launching the Arrow Flight server +## Launching the Worker server The default `Worker` implementation satisfies most basic use cases: @@ -23,7 +23,7 @@ async fn main() { let endpoint = Worker::default(); Server::builder() - .add_service(endpoint.into_flight_server()) + .add_service(endpoint.into_worker_server()) .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000)) .await?; @@ -47,7 +47,7 @@ async fn main() { let endpoint = Worker::from_session_builder(build_sate); Server::builder() - .add_service(endpoint.into_flight_server()) + .add_service(endpoint.into_worker_server()) .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000)) .await?; @@ -88,10 +88,10 @@ async fn main() { let endpoint = Worker::default(); Server::builder() - .add_service(endpoint.into_flight_server()) + .add_service(endpoint.into_worker_server()) .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080)) .await?; } ``` -The `into_flight_server()` method builds a `FlightServiceServer` ready to be added as a Tonic service. +The `into_worker_server()` method builds a `WorkerServiceServer` ready to be added as a Tonic service. diff --git a/examples/in_memory_cluster.rs b/examples/in_memory_cluster.rs index e5f02fdc..6ae97e28 100644 --- a/examples/in_memory_cluster.rs +++ b/examples/in_memory_cluster.rs @@ -1,12 +1,12 @@ use arrow::util::pretty::pretty_format_batches; -use arrow_flight::flight_service_client::FlightServiceClient; use async_trait::async_trait; use datafusion::common::DataFusionError; use datafusion::execution::SessionStateBuilder; use datafusion::prelude::{ParquetReadOptions, SessionContext}; use datafusion_distributed::{ BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedPhysicalOptimizerRule, Worker, - WorkerQueryContext, WorkerResolver, create_flight_client, display_plan_ascii, + WorkerQueryContext, WorkerResolver, WorkerServiceClient, create_worker_client, + display_plan_ascii, }; use futures::TryStreamExt; use hyper_util::rt::TokioIo; @@ -66,7 +66,7 @@ const DUMMY_URL: &str = "http://localhost:50051"; /// tokio duplex rather than a TCP connection. #[derive(Clone)] struct InMemoryChannelResolver { - channel: FlightServiceClient, + channel: WorkerServiceClient, } impl InMemoryChannelResolver { @@ -84,7 +84,7 @@ impl InMemoryChannelResolver { })); let this = Self { - channel: create_flight_client(BoxCloneSyncChannel::new(channel)), + channel: create_worker_client(BoxCloneSyncChannel::new(channel)), }; let this_clone = this.clone(); @@ -95,7 +95,7 @@ impl InMemoryChannelResolver { tokio::spawn(async move { Server::builder() - .add_service(endpoint.into_flight_server()) + .add_service(endpoint.into_worker_server()) .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) .await }); @@ -106,10 +106,10 @@ impl InMemoryChannelResolver { #[async_trait] impl ChannelResolver for InMemoryChannelResolver { - async fn get_flight_client_for_url( + async fn get_worker_client_for_url( &self, _: &url::Url, - ) -> Result, DataFusionError> { + ) -> Result, DataFusionError> { Ok(self.channel.clone()) } } diff --git a/examples/localhost_worker.rs b/examples/localhost_worker.rs index 58d329e0..2ed0d6d3 100644 --- a/examples/localhost_worker.rs +++ b/examples/localhost_worker.rs @@ -16,7 +16,7 @@ async fn main() -> Result<(), Box> { let args = Args::from_args(); Server::builder() - .add_service(Worker::default().into_flight_server()) + .add_service(Worker::default().into_worker_server()) .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), args.port)) .await?; diff --git a/src/distributed_ext.rs b/src/distributed_ext.rs index f9a997a4..7c957926 100644 --- a/src/distributed_ext.rs +++ b/src/distributed_ext.rs @@ -191,7 +191,6 @@ pub trait DistributedExt: Sized { /// Example: /// /// ``` - /// # use arrow_flight::flight_service_client::FlightServiceClient; /// # use async_trait::async_trait; /// # use datafusion::common::DataFusionError; /// # use datafusion::execution::{SessionState, SessionStateBuilder}; @@ -228,29 +227,28 @@ pub trait DistributedExt: Sized { resolver: T, ); - /// This is what tells Distributed DataFusion how to build an Arrow Flight client out of a worker URL. + /// This is what tells Distributed DataFusion how to build a Worker gRPC client out of a worker URL. /// - /// There's a default implementation that caches the Arrow Flight client instances so that there's + /// There's a default implementation that caches the Worker client instances so that there's /// only one per URL, but users can decide to override that behavior in favor of their own solution. /// /// Example: /// /// ``` - /// # use arrow_flight::flight_service_client::FlightServiceClient; /// # use async_trait::async_trait; /// # use datafusion::common::DataFusionError; /// # use datafusion::execution::{SessionState, SessionStateBuilder}; /// # use datafusion::prelude::SessionConfig; /// # use url::Url; /// # use std::sync::Arc; - /// # use datafusion_distributed::{BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedPhysicalOptimizerRule, WorkerQueryContext}; + /// # use datafusion_distributed::{BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedPhysicalOptimizerRule, WorkerQueryContext, WorkerServiceClient}; /// /// struct CustomChannelResolver; /// /// #[async_trait] /// impl ChannelResolver for CustomChannelResolver { - /// async fn get_flight_client_for_url(&self, url: &Url) -> Result, DataFusionError> { - /// // Build a custom FlightServiceClient wrapped with tower layers or something similar. + /// async fn get_worker_client_for_url(&self, url: &Url) -> Result, DataFusionError> { + /// // Build a custom WorkerServiceClient wrapped with tower layers or something similar. /// todo!() /// } /// } diff --git a/src/execution_plans/benchmarks/shuffle_bench.rs b/src/execution_plans/benchmarks/shuffle_bench.rs index a7d5cf83..4e630533 100644 --- a/src/execution_plans/benchmarks/shuffle_bench.rs +++ b/src/execution_plans/benchmarks/shuffle_bench.rs @@ -1,16 +1,16 @@ use crate::common::task_ctx_with_extension; -use crate::flight_service::WorkerConnectionPool; -use crate::flight_service::test_utils::memory_worker::MemoryWorker; +use crate::worker::WorkerConnectionPool; +use crate::worker::generated::worker::worker_service_client::WorkerServiceClient; +use crate::worker::test_utils::memory_worker::MemoryWorker; use crate::{ BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedTaskContext, ExecutionTask, - NetworkShuffleExec, Stage, create_flight_client, + NetworkShuffleExec, Stage, create_worker_client, }; use arrow::datatypes::DataType::{ Boolean, Dictionary, Float64, Int32, Int64, List, Timestamp, UInt8, Utf8, }; use arrow::datatypes::{Field, Schema, TimeUnit}; use arrow::util::data_gen::create_random_batch; -use arrow_flight::flight_service_client::FlightServiceClient; use arrow_ipc::CompressionType; use datafusion::common::{Result, exec_err}; use datafusion::execution::SessionStateBuilder; @@ -33,14 +33,14 @@ pub struct InMemoryChannelsResolver { #[async_trait::async_trait] impl ChannelResolver for InMemoryChannelsResolver { - async fn get_flight_client_for_url( + async fn get_worker_client_for_url( &self, url: &Url, - ) -> Result> { + ) -> Result> { let Some(port) = url.port() else { return exec_err!("Missing port in url {url}"); }; - Ok(create_flight_client(self.channels[port as usize].clone())) + Ok(create_worker_client(self.channels[port as usize].clone())) } } diff --git a/src/execution_plans/distributed.rs b/src/execution_plans/distributed.rs index 56752166..220d5394 100644 --- a/src/execution_plans/distributed.rs +++ b/src/execution_plans/distributed.rs @@ -1,17 +1,16 @@ use crate::common::require_one_child; use crate::config_extension_ext::get_config_extension_propagation_headers; use crate::distributed_planner::NetworkBoundaryExt; -use crate::flight_service::{INIT_ACTION_TYPE, InitAction}; use crate::networking::get_distributed_worker_resolver; use crate::passthrough_headers::get_passthrough_headers; use crate::protobuf::{DistributedCodec, tonic_status_to_datafusion_error}; use crate::stage::{ExecutionTask, Stage}; +use crate::worker::generated::worker::{ + CoordinatorToWorkerMsg, SetPlanRequest, TaskKey, coordinator_to_worker_msg::Inner, +}; use crate::{ - ChannelResolver, DISTRIBUTED_DATAFUSION_TASK_ID_LABEL, StageKey, WorkerResolver, - get_distributed_channel_resolver, + DISTRIBUTED_DATAFUSION_TASK_ID_LABEL, WorkerResolver, get_distributed_channel_resolver, }; -use arrow_flight::Action; -use bytes::Bytes; use datafusion::common::instant::Instant; use datafusion::common::runtime::JoinSet; use datafusion::common::tree_node::{Transformed, TreeNode}; @@ -124,10 +123,8 @@ impl DistributedExec { // This assumes the plan is the same for all the tasks within a stage. This is fine for // now, but it should be possible to send different versions of the subplan to the // different tasks. - let bytes: Bytes = - PhysicalPlanNode::try_from_physical_plan(Arc::clone(input_plan), &codec)? - .encode_to_vec() - .into(); + let bytes = PhysicalPlanNode::try_from_physical_plan(Arc::clone(input_plan), &codec)? + .encode_to_vec(); let tasks = stage .tasks @@ -138,10 +135,11 @@ impl DistributedExec { let execution_task = ExecutionTask { url: Some(url.clone()), }; - let action = InitAction { + let request = SetPlanRequest { plan_proto: bytes.clone(), - stage_key: Some(StageKey { - query_id: stage.query_id.as_bytes().to_vec().into(), + task_count: stage.tasks.len() as _, + task_key: Some(TaskKey { + query_id: stage.query_id.as_bytes().to_vec(), stage_id: stage.num as _, task_number: i as _, }), @@ -152,7 +150,7 @@ impl DistributedExec { // Spawns the task that feeds this subplan to this worker. There will be as // many as this spawned tasks as workers. join_set.spawn(async move { - send_plan_task(ctx, url, action).await?; + send_plan_task(ctx, url, request).await?; plan_send_latency.record(&start); Ok(()) }); @@ -258,24 +256,23 @@ impl ExecutionPlan for DistributedExec { } } -async fn send_plan_task(ctx: Arc, url: Url, init_action: InitAction) -> Result<()> { +async fn send_plan_task(ctx: Arc, url: Url, request: SetPlanRequest) -> Result<()> { let channel_resolver = get_distributed_channel_resolver(ctx.as_ref()); - let mut client = channel_resolver.get_flight_client_for_url(&url).await?; - - let body = init_action.encode_to_vec().into(); + let mut client = channel_resolver.get_worker_client_for_url(&url).await?; let mut headers = get_config_extension_propagation_headers(ctx.session_config())?; headers.extend(get_passthrough_headers(ctx.session_config())); + + let msg = CoordinatorToWorkerMsg { + inner: Some(Inner::SetPlanRequest(request)), + }; let request = Request::from_parts( MetadataMap::from_headers(headers), Extensions::default(), - Action { - r#type: INIT_ACTION_TYPE.to_string(), - body, - }, + futures::stream::once(async { msg }), ); - client.do_action(request).await.map_err(|e| { + client.coordinator_channel(request).await.map_err(|e| { tonic_status_to_datafusion_error(&e) .unwrap_or_else(|| exec_datafusion_err!("Error sending plan to worker {url}: {e}")) })?; diff --git a/src/execution_plans/network_broadcast.rs b/src/execution_plans/network_broadcast.rs index 13511635..585a1fb8 100644 --- a/src/execution_plans/network_broadcast.rs +++ b/src/execution_plans/network_broadcast.rs @@ -1,10 +1,11 @@ use crate::DistributedTaskContext; use crate::common::require_one_child; use crate::distributed_planner::NetworkBoundary; -use crate::flight_service::WorkerConnectionPool; -use crate::metrics::proto::MetricsSetProto; -use crate::protobuf::{AppMetadata, StageKey}; use crate::stage::Stage; +use crate::worker::WorkerConnectionPool; +use crate::worker::generated::worker as pb; +use crate::worker::generated::worker::TaskKey; +use crate::worker::generated::worker::flight_app_metadata; use dashmap::DashMap; use datafusion::common::internal_datafusion_err; use datafusion::error::DataFusionError; @@ -123,7 +124,7 @@ pub struct NetworkBroadcastExec { pub(crate) properties: PlanProperties, pub(crate) input_stage: Stage, pub(crate) worker_connections: WorkerConnectionPool, - pub(crate) metrics_collection: Arc>>, + pub(crate) metrics_collection: Arc>>, } impl NetworkBroadcastExec { @@ -247,10 +248,10 @@ impl ExecutionPlan for NetworkBroadcastExec { let metrics_collection = Arc::clone(&self.metrics_collection); let stream = worker_connection.stream_partition(off + partition, move |meta| { - if let Some(AppMetadata::MetricsCollection(m)) = meta.content { + if let Some(flight_app_metadata::Content::MetricsCollection(m)) = meta.content { for task_metrics in m.tasks { - if let Some(stage_key) = task_metrics.stage_key { - metrics_collection.insert(stage_key, task_metrics.metrics); + if let Some(task_key) = task_metrics.task_key { + metrics_collection.insert(task_key, task_metrics.metrics); }; } } diff --git a/src/execution_plans/network_coalesce.rs b/src/execution_plans/network_coalesce.rs index 96fcccd9..c1b2b5b7 100644 --- a/src/execution_plans/network_coalesce.rs +++ b/src/execution_plans/network_coalesce.rs @@ -1,10 +1,11 @@ use crate::common::require_one_child; use crate::distributed_planner::NetworkBoundary; use crate::execution_plans::common::scale_partitioning_props; -use crate::flight_service::WorkerConnectionPool; -use crate::metrics::proto::MetricsSetProto; -use crate::protobuf::{AppMetadata, StageKey}; use crate::stage::Stage; +use crate::worker::WorkerConnectionPool; +use crate::worker::generated::worker as pb; +use crate::worker::generated::worker::TaskKey; +use crate::worker::generated::worker::flight_app_metadata; use crate::{DistributedTaskContext, ExecutionTask}; use dashmap::DashMap; use datafusion::common::{exec_err, plan_err}; @@ -88,7 +89,7 @@ pub struct NetworkCoalesceExec { /// the stage it is reading from. This is because, by convention, the Worker sends metrics for /// a task to the last NetworkCoalesceExec to read from it, which may or may not be this /// instance. - pub(crate) metrics_collection: Arc>>, + pub(crate) metrics_collection: Arc>>, } impl NetworkCoalesceExec { @@ -250,10 +251,10 @@ impl ExecutionPlan for NetworkCoalesceExec { let metrics_collection = Arc::clone(&self.metrics_collection); let stream = worker_connection.stream_partition(target_partition, move |meta| { - if let Some(AppMetadata::MetricsCollection(m)) = meta.content { + if let Some(flight_app_metadata::Content::MetricsCollection(m)) = meta.content { for task_metrics in m.tasks { - if let Some(stage_key) = task_metrics.stage_key { - metrics_collection.insert(stage_key, task_metrics.metrics); + if let Some(task_key) = task_metrics.task_key { + metrics_collection.insert(task_key, task_metrics.metrics); }; } } diff --git a/src/execution_plans/network_shuffle.rs b/src/execution_plans/network_shuffle.rs index 3bf515a4..14687376 100644 --- a/src/execution_plans/network_shuffle.rs +++ b/src/execution_plans/network_shuffle.rs @@ -1,9 +1,10 @@ use crate::common::require_one_child; use crate::execution_plans::common::scale_partitioning; -use crate::flight_service::WorkerConnectionPool; -use crate::metrics::proto::MetricsSetProto; -use crate::protobuf::{AppMetadata, StageKey}; use crate::stage::Stage; +use crate::worker::WorkerConnectionPool; +use crate::worker::generated::worker as pb; +use crate::worker::generated::worker::TaskKey; +use crate::worker::generated::worker::flight_app_metadata; use crate::{DistributedTaskContext, ExecutionTask, NetworkBoundary}; use dashmap::DashMap; use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; @@ -116,7 +117,7 @@ pub struct NetworkShuffleExec { /// the stage it is reading from. This is because, by convention, the Worker sends metrics for /// a task to the last NetworkCoalesceExec to read from it, which may or may not be this /// instance. - pub(crate) metrics_collection: Arc>>, + pub(crate) metrics_collection: Arc>>, } impl NetworkShuffleExec { @@ -242,10 +243,10 @@ impl ExecutionPlan for NetworkShuffleExec { let metrics_collection = Arc::clone(&self.metrics_collection); let stream = worker_connection.stream_partition(off + partition, move |meta| { - if let Some(AppMetadata::MetricsCollection(m)) = meta.content { + if let Some(flight_app_metadata::Content::MetricsCollection(m)) = meta.content { for task_metrics in m.tasks { - if let Some(stage_key) = task_metrics.stage_key { - metrics_collection.insert(stage_key, task_metrics.metrics); + if let Some(task_key) = task_metrics.task_key { + metrics_collection.insert(task_key, task_metrics.metrics); }; } } diff --git a/src/lib.rs b/src/lib.rs index f14a32e3..e1771d3b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,10 +4,10 @@ mod common; mod config_extension_ext; mod distributed_ext; mod execution_plans; -mod flight_service; mod metrics; mod passthrough_headers; mod stage; +mod worker; mod distributed_planner; mod networking; @@ -26,10 +26,6 @@ pub use execution_plans::{ BroadcastExec, DistributedExec, NetworkBroadcastExec, NetworkCoalesceExec, NetworkShuffleExec, PartitionIsolatorExec, }; -pub use flight_service::{ - DefaultSessionBuilder, MappedWorkerSessionBuilder, MappedWorkerSessionBuilderExt, TaskData, - Worker, WorkerQueryContext, WorkerSessionBuilder, -}; pub use metrics::{ AvgLatencyMetric, BytesCounterMetric, BytesMetricExt, DISTRIBUTED_DATAFUSION_TASK_ID_LABEL, DistributedMetricsFormat, FirstLatencyMetric, LatencyMetricExt, MaxLatencyMetric, @@ -38,21 +34,25 @@ pub use metrics::{ }; pub use networking::{ BoxCloneSyncChannel, ChannelResolver, DefaultChannelResolver, WorkerResolver, - create_flight_client, get_distributed_channel_resolver, get_distributed_worker_resolver, + create_worker_client, get_distributed_channel_resolver, get_distributed_worker_resolver, }; pub use stage::{ DistributedTaskContext, ExecutionTask, Stage, display_plan_ascii, display_plan_graphviz, explain_analyze, }; +pub use worker::generated::worker::TaskKey; +pub use worker::generated::worker::worker_service_client::WorkerServiceClient; +pub use worker::{ + DefaultSessionBuilder, MappedWorkerSessionBuilder, MappedWorkerSessionBuilderExt, TaskData, + Worker, WorkerQueryContext, WorkerSessionBuilder, +}; pub use observability::{ GetClusterWorkersRequest, GetClusterWorkersResponse, GetTaskProgressRequest, GetTaskProgressResponse, ObservabilityService, ObservabilityServiceClient, - ObservabilityServiceImpl, ObservabilityServiceServer, PingRequest, PingResponse, - StageKey as ObservabilityStageKey, TaskProgress, TaskStatus, WorkerMetrics, + ObservabilityServiceImpl, ObservabilityServiceServer, PingRequest, PingResponse, TaskProgress, + TaskStatus, WorkerMetrics, }; -pub use protobuf::StageKey; - #[cfg(any(feature = "integration", test))] pub use execution_plans::benchmarks::ShuffleBench; diff --git a/src/metrics/proto.rs b/src/metrics/proto.rs index dab28627..e87d49c0 100644 --- a/src/metrics/proto.rs +++ b/src/metrics/proto.rs @@ -13,264 +13,14 @@ use super::latency_metric::{ AvgLatencyMetric, FirstLatencyMetric, MaxLatencyMetric, MinLatencyMetric, P50LatencyMetric, P75LatencyMetric, P95LatencyMetric, P99LatencyMetric, }; +use crate::worker::generated::worker as pb; -/// A MetricProto is a protobuf mirror of [datafusion::physical_plan::metrics::Metric]. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct MetricProto { - #[prost(message, repeated, tag = "1")] - pub labels: Vec, - #[prost(uint64, optional, tag = "2")] - pub partition: Option, - #[prost( - oneof = "MetricValueProto", - tags = "10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33" - )] - // This field is *always* set. It is marked optional due to protobuf "oneof" requirements. - pub metric: Option, -} - -/// A MetricsSetProto is a protobuf mirror of [datafusion::physical_plan::metrics::MetricsSet]. It represents -/// a collection of metrics for one `ExecutionPlan` node. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct MetricsSetProto { - #[prost(message, repeated, tag = "1")] - pub metrics: Vec, -} - -impl MetricsSetProto { - pub fn new() -> Self { - Self { - metrics: Vec::new(), - } - } - - pub fn push(&mut self, metric: MetricProto) { - self.metrics.push(metric) - } -} - -/// MetricValueProto is a protobuf mirror of the [datafusion::physical_plan::metrics::MetricValue] enum. -#[derive(Clone, PartialEq, Eq, ::prost::Oneof)] -pub enum MetricValueProto { - #[prost(message, tag = "10")] - OutputRows(OutputRows), - #[prost(message, tag = "11")] - ElapsedCompute(ElapsedCompute), - #[prost(message, tag = "12")] - SpillCount(SpillCount), - #[prost(message, tag = "13")] - SpilledBytes(SpilledBytes), - #[prost(message, tag = "14")] - SpilledRows(SpilledRows), - #[prost(message, tag = "15")] - CurrentMemoryUsage(CurrentMemoryUsage), - #[prost(message, tag = "16")] - Count(NamedCount), - #[prost(message, tag = "17")] - Gauge(NamedGauge), - #[prost(message, tag = "18")] - Time(NamedTime), - #[prost(message, tag = "19")] - StartTimestamp(StartTimestamp), - #[prost(message, tag = "20")] - EndTimestamp(EndTimestamp), - #[prost(message, tag = "21")] - OutputBytes(OutputBytes), - #[prost(message, tag = "22")] - OutputBatches(OutputBatches), - #[prost(message, tag = "23")] - PruningMetrics(NamedPruningMetrics), - #[prost(message, tag = "24")] - Ratio(NamedRatio), - #[prost(message, tag = "25")] - CustomMinLatency(MinLatency), - #[prost(message, tag = "26")] - CustomMaxLatency(MaxLatency), - #[prost(message, tag = "27")] - CustomAvgLatency(AvgLatency), - #[prost(message, tag = "28")] - CustomFirstLatency(FirstLatency), - #[prost(message, tag = "29")] - CustomBytesCount(BytesCount), - #[prost(message, tag = "30")] - CustomP50Latency(PercentileLatency), - #[prost(message, tag = "31")] - CustomP75Latency(PercentileLatency), - #[prost(message, tag = "32")] - CustomP95Latency(PercentileLatency), - #[prost(message, tag = "33")] - CustomP99Latency(PercentileLatency), -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct OutputRows { - #[prost(uint64, tag = "1")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct ElapsedCompute { - #[prost(uint64, tag = "1")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct SpillCount { - #[prost(uint64, tag = "1")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct SpilledBytes { - #[prost(uint64, tag = "1")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct SpilledRows { - #[prost(uint64, tag = "1")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct CurrentMemoryUsage { - #[prost(uint64, tag = "1")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct NamedCount { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct NamedGauge { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct NamedTime { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct StartTimestamp { - #[prost(int64, optional, tag = "1")] - pub value: Option, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct EndTimestamp { - #[prost(int64, optional, tag = "1")] - pub value: Option, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct OutputBytes { - #[prost(uint64, tag = "1")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct OutputBatches { - #[prost(uint64, tag = "1")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct NamedPruningMetrics { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub pruned: u64, - #[prost(uint64, tag = "3")] - pub matched: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct NamedRatio { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub part: u64, - #[prost(uint64, tag = "3")] - pub total: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct BytesCount { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct MinLatency { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct MaxLatency { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct AvgLatency { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub nanos_sum: u64, - #[prost(uint64, tag = "3")] - pub count: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct FirstLatency { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct PercentileLatency { - #[prost(string, tag = "1")] - pub name: String, - #[prost(bytes = "vec", tag = "4")] - pub sketch_bytes: Vec, -} - -/// A ProtoLabel mirrors [datafusion::physical_plan::metrics::Label]. -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct ProtoLabel { - #[prost(string, tag = "1")] - pub name: String, - #[prost(string, tag = "2")] - pub value: String, -} - -/// df_metrics_set_to_proto converts a [datafusion::physical_plan::metrics::MetricsSet] to a [MetricsSetProto]. +/// df_metrics_set_to_proto converts a [datafusion::physical_plan::metrics::MetricsSet] to a [pb::MetricsSet]. /// Custom metrics are filtered out, but any other errors are returned. /// TODO(#140): Support custom metrics. pub fn df_metrics_set_to_proto( metrics_set: &MetricsSet, -) -> Result { +) -> Result { let mut metrics = Vec::new(); for metric in metrics_set.iter() { @@ -290,12 +40,12 @@ pub fn df_metrics_set_to_proto( } } - Ok(MetricsSetProto { metrics }) + Ok(pb::MetricsSet { metrics }) } -/// metrics_set_proto_to_df converts a [MetricsSetProto] to a [datafusion::physical_plan::metrics::MetricsSet]. +/// metrics_set_proto_to_df converts a [pb::MetricsSet] to a [datafusion::physical_plan::metrics::MetricsSet]. pub fn metrics_set_proto_to_df( - metrics_set_proto: &MetricsSetProto, + metrics_set_proto: &pb::MetricsSet, ) -> Result { let mut metrics_set = MetricsSet::new(); metrics_set_proto.metrics.iter().try_for_each(|metric| { @@ -313,75 +63,75 @@ const CUSTOM_METRICS_NOT_SUPPORTED: &str = /// New DataFusion metrics that are not yet supported in proto conversion. const UNSUPPORTED_METRICS: &str = "metric type not supported in proto conversion"; -/// df_metric_to_proto converts a `datafusion::physical_plan::metrics::Metric` to a `MetricProto`. It does not consume the Arc. -pub fn df_metric_to_proto(metric: Arc) -> Result { +/// df_metric_to_proto converts a `datafusion::physical_plan::metrics::Metric` to a `pb::Metric`. It does not consume the Arc. +pub fn df_metric_to_proto(metric: Arc) -> Result { let partition = metric.partition().map(|p| p as u64); let labels = metric .labels() .iter() - .map(|label| ProtoLabel { + .map(|label| pb::Label { name: label.name().to_string(), value: label.value().to_string(), }) .collect(); match metric.value() { - MetricValue::OutputRows(rows) => Ok(MetricProto { - metric: Some(MetricValueProto::OutputRows(OutputRows { value: rows.value() as u64 })), + MetricValue::OutputRows(rows) => Ok(pb::Metric { + value: Some(pb::metric::Value::OutputRows(pb::OutputRows { value: rows.value() as u64 })), partition, labels, }), - MetricValue::ElapsedCompute(time) => Ok(MetricProto { - metric: Some(MetricValueProto::ElapsedCompute(ElapsedCompute { value: time.value() as u64 })), + MetricValue::ElapsedCompute(time) => Ok(pb::Metric { + value: Some(pb::metric::Value::ElapsedCompute(pb::ElapsedCompute { value: time.value() as u64 })), partition, labels, }), - MetricValue::SpillCount(count) => Ok(MetricProto { - metric: Some(MetricValueProto::SpillCount(SpillCount { value: count.value() as u64 })), + MetricValue::SpillCount(count) => Ok(pb::Metric { + value: Some(pb::metric::Value::SpillCount(pb::SpillCount { value: count.value() as u64 })), partition, labels, }), - MetricValue::SpilledBytes(count) => Ok(MetricProto { - metric: Some(MetricValueProto::SpilledBytes(SpilledBytes { value: count.value() as u64 })), + MetricValue::SpilledBytes(count) => Ok(pb::Metric { + value: Some(pb::metric::Value::SpilledBytes(pb::SpilledBytes { value: count.value() as u64 })), partition, labels, }), - MetricValue::SpilledRows(count) => Ok(MetricProto { - metric: Some(MetricValueProto::SpilledRows(SpilledRows { value: count.value() as u64 })), + MetricValue::SpilledRows(count) => Ok(pb::Metric { + value: Some(pb::metric::Value::SpilledRows(pb::SpilledRows { value: count.value() as u64 })), partition, labels, }), - MetricValue::CurrentMemoryUsage(gauge) => Ok(MetricProto { - metric: Some(MetricValueProto::CurrentMemoryUsage(CurrentMemoryUsage { value: gauge.value() as u64 })), + MetricValue::CurrentMemoryUsage(gauge) => Ok(pb::Metric { + value: Some(pb::metric::Value::CurrentMemoryUsage(pb::CurrentMemoryUsage { value: gauge.value() as u64 })), partition, labels, }), - MetricValue::Count { name, count } => Ok(MetricProto { - metric: Some(MetricValueProto::Count(NamedCount { + MetricValue::Count { name, count } => Ok(pb::Metric { + value: Some(pb::metric::Value::Count(pb::NamedCount { name: name.to_string(), value: count.value() as u64 })), partition, labels, }), - MetricValue::Gauge { name, gauge } => Ok(MetricProto { - metric: Some(MetricValueProto::Gauge(NamedGauge { + MetricValue::Gauge { name, gauge } => Ok(pb::Metric { + value: Some(pb::metric::Value::Gauge(pb::NamedGauge { name: name.to_string(), value: gauge.value() as u64 })), partition, labels, }), - MetricValue::Time { name, time } => Ok(MetricProto { - metric: Some(MetricValueProto::Time(NamedTime { + MetricValue::Time { name, time } => Ok(pb::Metric { + value: Some(pb::metric::Value::Time(pb::NamedTime { name: name.to_string(), value: time.value() as u64 })), partition, labels, }), - MetricValue::StartTimestamp(timestamp) => Ok(MetricProto { - metric: Some(MetricValueProto::StartTimestamp(StartTimestamp { + MetricValue::StartTimestamp(timestamp) => Ok(pb::Metric { + value: Some(pb::metric::Value::StartTimestamp(pb::StartTimestamp { value: match timestamp.value() { Some(dt) => Some( dt.timestamp_nanos_opt().ok_or(DataFusionError::Internal( @@ -393,8 +143,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result Ok(MetricProto { - metric: Some(MetricValueProto::EndTimestamp(EndTimestamp { + MetricValue::EndTimestamp(timestamp) => Ok(pb::Metric { + value: Some(pb::metric::Value::EndTimestamp(pb::EndTimestamp { value: match timestamp.value() { Some(dt) => Some( dt.timestamp_nanos_opt().ok_or(DataFusionError::Internal( @@ -408,8 +158,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result { if let Some(min) = value.as_any().downcast_ref::() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomMinLatency(MinLatency { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomMinLatency(pb::MinLatency { name: name.to_string(), value: min.value() as u64, })), @@ -417,8 +167,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomMaxLatency(MaxLatency { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomMaxLatency(pb::MaxLatency { name: name.to_string(), value: max.value() as u64, })), @@ -426,8 +176,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomAvgLatency(AvgLatency { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomAvgLatency(pb::AvgLatency { name: name.to_string(), nanos_sum: avg.nanos_sum() as u64, count: avg.count() as u64, @@ -436,8 +186,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomFirstLatency(FirstLatency { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomFirstLatency(pb::FirstLatency { name: name.to_string(), value: first.value() as u64, })), @@ -445,8 +195,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomBytesCount(BytesCount { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomBytesCount(pb::BytesCount { name: name.to_string(), value: bytes.value() as u64, })), @@ -454,8 +204,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomP50Latency(PercentileLatency { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomP50Latency(pb::PercentileLatency { name: name.to_string(), sketch_bytes: p50.serialize_sketch()?, })), @@ -463,8 +213,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomP75Latency(PercentileLatency { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomP75Latency(pb::PercentileLatency { name: name.to_string(), sketch_bytes: p75.serialize_sketch()?, })), @@ -472,8 +222,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomP95Latency(PercentileLatency { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomP95Latency(pb::PercentileLatency { name: name.to_string(), sketch_bytes: p95.serialize_sketch()?, })), @@ -481,8 +231,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomP99Latency(PercentileLatency { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomP99Latency(pb::PercentileLatency { name: name.to_string(), sketch_bytes: p99.serialize_sketch()?, })), @@ -493,18 +243,18 @@ pub fn df_metric_to_proto(metric: Arc) -> Result Ok(MetricProto { - metric: Some(MetricValueProto::OutputBytes(OutputBytes { value: count.value() as u64 })), + MetricValue::OutputBytes(count) => Ok(pb::Metric { + value: Some(pb::metric::Value::OutputBytes(pb::OutputBytes { value: count.value() as u64 })), partition, labels, }), - MetricValue::OutputBatches(count) => Ok(MetricProto { - metric: Some(MetricValueProto::OutputBatches(OutputBatches { value: count.value() as u64 })), + MetricValue::OutputBatches(count) => Ok(pb::Metric { + value: Some(pb::metric::Value::OutputBatches(pb::OutputBatches { value: count.value() as u64 })), partition, labels, }), - MetricValue::PruningMetrics { name, pruning_metrics } => Ok(MetricProto { - metric: Some(MetricValueProto::PruningMetrics(NamedPruningMetrics { + MetricValue::PruningMetrics { name, pruning_metrics } => Ok(pb::Metric { + value: Some(pb::metric::Value::PruningMetrics(pb::NamedPruningMetrics { name: name.to_string(), pruned: pruning_metrics.pruned() as u64, matched: pruning_metrics.matched() as u64, @@ -512,8 +262,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result Ok(MetricProto { - metric: Some(MetricValueProto::Ratio(NamedRatio { + MetricValue::Ratio { name, ratio_metrics } => Ok(pb::Metric { + value: Some(pb::metric::Value::Ratio(pb::NamedRatio { name: name.to_string(), part: ratio_metrics.part() as u64, total: ratio_metrics.total() as u64, @@ -524,8 +274,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result Result, DataFusionError> { +/// metric_proto_to_df converts a `pb::Metric` to a `datafusion::physical_plan::metrics::Metric`. It consumes the pb::Metric. +pub fn metric_proto_to_df(metric: pb::Metric) -> Result, DataFusionError> { let partition = metric.partition.map(|p| p as usize); let labels = metric .labels @@ -533,8 +283,8 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion .map(|proto_label| Label::new(proto_label.name, proto_label.value)) .collect(); - match metric.metric { - Some(MetricValueProto::OutputRows(rows)) => { + match metric.value { + Some(pb::metric::Value::OutputRows(rows)) => { let count = Count::new(); count.add(rows.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -543,7 +293,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::ElapsedCompute(elapsed)) => { + Some(pb::metric::Value::ElapsedCompute(elapsed)) => { let time = Time::new(); time.add_duration(std::time::Duration::from_nanos(elapsed.value)); Ok(Arc::new(Metric::new_with_labels( @@ -552,7 +302,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::SpillCount(spill_count)) => { + Some(pb::metric::Value::SpillCount(spill_count)) => { let count = Count::new(); count.add(spill_count.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -561,7 +311,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::SpilledBytes(spilled_bytes)) => { + Some(pb::metric::Value::SpilledBytes(spilled_bytes)) => { let count = Count::new(); count.add(spilled_bytes.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -570,7 +320,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::SpilledRows(spilled_rows)) => { + Some(pb::metric::Value::SpilledRows(spilled_rows)) => { let count = Count::new(); count.add(spilled_rows.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -579,7 +329,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CurrentMemoryUsage(memory)) => { + Some(pb::metric::Value::CurrentMemoryUsage(memory)) => { let gauge = Gauge::new(); gauge.set(memory.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -588,7 +338,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::Count(named_count)) => { + Some(pb::metric::Value::Count(named_count)) => { let count = Count::new(); count.add(named_count.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -600,7 +350,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::Gauge(named_gauge)) => { + Some(pb::metric::Value::Gauge(named_gauge)) => { let gauge = Gauge::new(); gauge.set(named_gauge.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -612,7 +362,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::Time(named_time)) => { + Some(pb::metric::Value::Time(named_time)) => { let time = Time::new(); time.add_duration(std::time::Duration::from_nanos(named_time.value)); Ok(Arc::new(Metric::new_with_labels( @@ -624,7 +374,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::StartTimestamp(start_ts)) => { + Some(pb::metric::Value::StartTimestamp(start_ts)) => { let timestamp = Timestamp::new(); if let Some(value) = start_ts.value { timestamp.set(DateTime::from_timestamp_nanos(value)); @@ -635,7 +385,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::EndTimestamp(end_ts)) => { + Some(pb::metric::Value::EndTimestamp(end_ts)) => { let timestamp = Timestamp::new(); if let Some(value) = end_ts.value { timestamp.set(DateTime::from_timestamp_nanos(value)); @@ -646,7 +396,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::OutputBytes(output_bytes)) => { + Some(pb::metric::Value::OutputBytes(output_bytes)) => { let count = Count::new(); count.add(output_bytes.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -655,7 +405,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::OutputBatches(output_batches)) => { + Some(pb::metric::Value::OutputBatches(output_batches)) => { let count = Count::new(); count.add(output_batches.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -664,7 +414,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::PruningMetrics(named_pruning)) => { + Some(pb::metric::Value::PruningMetrics(named_pruning)) => { let pruning_metrics = DfPruningMetrics::new(); pruning_metrics.add_pruned(named_pruning.pruned as usize); pruning_metrics.add_matched(named_pruning.matched as usize); @@ -677,7 +427,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::Ratio(named_ratio)) => { + Some(pb::metric::Value::Ratio(named_ratio)) => { let ratio_metrics = RatioMetrics::new(); ratio_metrics.set_part(named_ratio.part as usize); ratio_metrics.set_total(named_ratio.total as usize); @@ -690,7 +440,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomMinLatency(min_latency)) => { + Some(pb::metric::Value::CustomMinLatency(min_latency)) => { let value = MinLatencyMetric::from_nanos(min_latency.value as usize); Ok(Arc::new(Metric::new_with_labels( MetricValue::Custom { @@ -701,7 +451,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomMaxLatency(max_latency)) => { + Some(pb::metric::Value::CustomMaxLatency(max_latency)) => { let value = MaxLatencyMetric::from_nanos(max_latency.value as usize); Ok(Arc::new(Metric::new_with_labels( MetricValue::Custom { @@ -712,7 +462,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomAvgLatency(avg_latency)) => { + Some(pb::metric::Value::CustomAvgLatency(avg_latency)) => { let value = AvgLatencyMetric::from_raw( avg_latency.nanos_sum as usize, avg_latency.count as usize, @@ -726,7 +476,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomFirstLatency(first_latency)) => { + Some(pb::metric::Value::CustomFirstLatency(first_latency)) => { let value = FirstLatencyMetric::from_nanos(first_latency.value as usize); Ok(Arc::new(Metric::new_with_labels( MetricValue::Custom { @@ -737,7 +487,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomBytesCount(bytes_count)) => { + Some(pb::metric::Value::CustomBytesCount(bytes_count)) => { let value = BytesCounterMetric::from_value(bytes_count.value as usize); Ok(Arc::new(Metric::new_with_labels( MetricValue::Custom { @@ -748,7 +498,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomP50Latency(p)) => { + Some(pb::metric::Value::CustomP50Latency(p)) => { let sketch: DDSketch = bincode::deserialize(&p.sketch_bytes).map_err(|e| { DataFusionError::Internal(format!("failed to deserialize DDSketch: {e}")) })?; @@ -762,7 +512,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomP75Latency(p)) => { + Some(pb::metric::Value::CustomP75Latency(p)) => { let sketch: DDSketch = bincode::deserialize(&p.sketch_bytes).map_err(|e| { DataFusionError::Internal(format!("failed to deserialize DDSketch: {e}")) })?; @@ -776,7 +526,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomP95Latency(p)) => { + Some(pb::metric::Value::CustomP95Latency(p)) => { let sketch: DDSketch = bincode::deserialize(&p.sketch_bytes).map_err(|e| { DataFusionError::Internal(format!("failed to deserialize DDSketch: {e}")) })?; @@ -790,7 +540,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomP99Latency(p)) => { + Some(pb::metric::Value::CustomP99Latency(p)) => { let sketch: DDSketch = bincode::deserialize(&p.sketch_bytes).map_err(|e| { DataFusionError::Internal(format!("failed to deserialize DDSketch: {e}")) })?; @@ -1258,8 +1008,8 @@ mod tests { let remaining_metric = &metrics_set_proto.metrics[0]; assert!(matches!( - remaining_metric.metric, - Some(MetricValueProto::OutputRows(_)) + remaining_metric.value, + Some(pb::metric::Value::OutputRows(_)) )); } diff --git a/src/metrics/task_metrics_collector.rs b/src/metrics/task_metrics_collector.rs index a566f809..d524d158 100644 --- a/src/metrics/task_metrics_collector.rs +++ b/src/metrics/task_metrics_collector.rs @@ -1,8 +1,8 @@ use crate::NetworkBroadcastExec; use crate::execution_plans::NetworkCoalesceExec; use crate::execution_plans::NetworkShuffleExec; -use crate::metrics::proto::MetricsSetProto; -use crate::protobuf::StageKey; +use crate::worker::generated::worker as pb; +use crate::worker::generated::worker::TaskKey; use datafusion::common::HashMap; use datafusion::common::tree_node::Transformed; use datafusion::common::tree_node::TreeNode; @@ -23,7 +23,7 @@ pub struct TaskMetricsCollector { task_metrics: Vec, /// input_task_metrics contains metrics for tasks from child [StageExec]s if they were /// collected. - input_task_metrics: HashMap>, + input_task_metrics: HashMap>, } /// MetricsCollectorResult is the result of collecting metrics from a task. @@ -31,7 +31,7 @@ pub struct MetricsCollectorResult { // metrics is a collection of metrics for a task ordered using a pre-order traversal of the task's plan. pub task_metrics: Vec, // input_task_metrics contains metrics for child tasks if they were collected. - pub input_task_metrics: HashMap>, + pub input_task_metrics: HashMap>, } impl TreeNodeRewriter for TaskMetricsCollector { @@ -62,21 +62,21 @@ impl TreeNodeRewriter for TaskMetricsCollector { if let Some(metrics_collection) = metrics_collection { for mut entry in metrics_collection.iter_mut() { - let stage_key = entry.key().clone(); + let task_key = entry.key().clone(); let task_metrics = std::mem::take(entry.value_mut()); // Avoid copy. - match self.input_task_metrics.get(&stage_key) { - // There should never be two NetworkShuffleExec with metrics for the same stage_key. + match self.input_task_metrics.get(&task_key) { + // There should never be two NetworkShuffleExec with metrics for the same task_key. // By convention, the NetworkShuffleExec which runs the last partition in a task should be // sent metrics (the NetworkShuffleExec tracks it for us). Some(_) => { return internal_err!( "duplicate task metrics for key {:?} during metrics collection", - stage_key + task_key ); } None => { self.input_task_metrics - .insert(stage_key.clone(), task_metrics); + .insert(task_key.clone(), task_metrics); } } } @@ -127,7 +127,7 @@ mod tests { }; use crate::test_utils::parquet::register_parquet_tables; use crate::test_utils::plans::{ - count_plan_nodes_up_to_network_boundary, get_stages_and_stage_keys, + count_plan_nodes_up_to_network_boundary, get_stages_and_task_keys, }; use crate::test_utils::session_context::register_temp_parquet_table; use crate::{DistributedExt, DistributedPhysicalOptimizerRule}; @@ -259,10 +259,10 @@ mod tests { .expect("expected DistributedExec"); // Assert to ensure the distributed test case is sufficiently complex. - let (stages, expected_stage_keys) = get_stages_and_stage_keys(dist_exec); + let (stages, expected_task_keys) = get_stages_and_task_keys(dist_exec); assert!( - expected_stage_keys.len() > 1, - "expected more than 1 stage key in test. the plan was not distributed):\n{}", + expected_task_keys.len() > 1, + "expected more than 1 task key in test. the plan was not distributed):\n{}", DisplayableExecutionPlan::new(plan.as_ref()).indent(true) ); @@ -272,20 +272,20 @@ mod tests { let result = collector.collect(dist_exec.plan.clone()).unwrap(); // Ensure that there's metrics for each node for each task for each stage. - for expected_stage_key in expected_stage_keys { + for expected_task_key in expected_task_keys { // Get the collected metrics for this task. - let actual_metrics = result.input_task_metrics.get(&expected_stage_key).unwrap(); + let actual_metrics = result.input_task_metrics.get(&expected_task_key).unwrap(); // Verify that metrics were collected for all nodes. Some nodes may legitimately have // empty metrics (e.g., custom execution plans without metrics), which is fine - we // just verify that a metrics set exists for each node. The count assertion above // ensures all nodes are included in the metrics collection. - let stage = stages.get(&(expected_stage_key.stage_id as usize)).unwrap(); + let stage = stages.get(&(expected_task_key.stage_id as usize)).unwrap(); let stage_plan = stage.plan.as_ref().unwrap(); assert_eq!( actual_metrics.len(), count_plan_nodes_up_to_network_boundary(stage_plan), - "Mismatch between collected metrics and actual nodes for {expected_stage_key:?}" + "Mismatch between collected metrics and actual nodes for {expected_task_key:?}" ); } } @@ -349,7 +349,7 @@ mod tests { /// Issue: https://github.com/datafusion-contrib/datafusion-distributed/issues/187 /// /// Metrics are piggybacked on the last FlightData message of the last partition stream - /// (see `do_get.rs`). If a LIMIT causes the client-side stream to be dropped before the + /// (see `impl_execute_task.rs`). If a LIMIT causes the client-side stream to be dropped before the /// worker finishes, the last message (carrying metrics) is never received. /// /// This uses the `flights_1m` dataset (1M rows) so the worker is still producing data @@ -375,10 +375,10 @@ mod tests { .downcast_ref::() .expect("expected DistributedExec"); - let (stages, expected_stage_keys) = get_stages_and_stage_keys(dist_exec); + let (stages, expected_task_keys) = get_stages_and_task_keys(dist_exec); assert!( - expected_stage_keys.len() > 1, - "expected more than 1 stage key. Plan was not distributed:\n{}", + expected_task_keys.len() > 1, + "expected more than 1 task key. Plan was not distributed:\n{}", DisplayableExecutionPlan::new(plan.as_ref()).indent(true) ); @@ -387,24 +387,24 @@ mod tests { let collector = TaskMetricsCollector::new(); let result = collector.collect(dist_exec.plan.clone()).unwrap(); - for expected_stage_key in expected_stage_keys { + for expected_task_key in expected_task_keys { let actual_metrics = result .input_task_metrics - .get(&expected_stage_key) + .get(&expected_task_key) .unwrap_or_else(|| { panic!( - "Missing metrics for stage key {expected_stage_key:?}. \ + "Missing metrics for task key {expected_task_key:?}. \ The LIMIT caused the stream to be dropped before the worker \ sent the last FlightData message with metrics." ) }); - let stage = stages.get(&(expected_stage_key.stage_id as usize)).unwrap(); + let stage = stages.get(&(expected_task_key.stage_id as usize)).unwrap(); let stage_plan = stage.plan.as_ref().unwrap(); assert_eq!( actual_metrics.len(), count_plan_nodes_up_to_network_boundary(stage_plan), - "Mismatch between collected metrics and actual nodes for {expected_stage_key:?}" + "Mismatch between collected metrics and actual nodes for {expected_task_key:?}" ); } } @@ -438,14 +438,14 @@ mod tests { .downcast_ref::() .expect("expected DistributedExec"); - let (stages, expected_stage_keys) = get_stages_and_stage_keys(dist_exec); + let (stages, expected_task_keys) = get_stages_and_task_keys(dist_exec); let collector = TaskMetricsCollector::new(); let result = collector.collect(dist_exec.plan.clone()).unwrap(); // Verify all nodes (including PartitionIsolatorExec) are preserved in metrics collection - for expected_stage_key in expected_stage_keys { - let actual_metrics = result.input_task_metrics.get(&expected_stage_key).unwrap(); - let stage = stages.get(&(expected_stage_key.stage_id as usize)).unwrap(); + for expected_task_key in expected_task_keys { + let actual_metrics = result.input_task_metrics.get(&expected_task_key).unwrap(); + let stage = stages.get(&(expected_task_key.stage_id as usize)).unwrap(); let stage_plan = stage.plan.as_ref().unwrap(); // Verify metrics count matches - this ensures all nodes are included in metrics collection @@ -453,7 +453,7 @@ mod tests { assert_eq!( actual_metrics.len(), count_plan_nodes_up_to_network_boundary(stage_plan), - "Metrics count must match plan nodes for stage {expected_stage_key:?}" + "Metrics count must match plan nodes for stage {expected_task_key:?}" ); } } diff --git a/src/metrics/task_metrics_rewriter.rs b/src/metrics/task_metrics_rewriter.rs index 9c5e8b28..5e71a9fe 100644 --- a/src/metrics/task_metrics_rewriter.rs +++ b/src/metrics/task_metrics_rewriter.rs @@ -4,10 +4,10 @@ use crate::execution_plans::MetricsWrapperExec; use crate::metrics::DISTRIBUTED_DATAFUSION_TASK_ID_LABEL; use crate::metrics::MetricsCollectorResult; use crate::metrics::TaskMetricsCollector; -use crate::metrics::proto::{MetricsSetProto, metrics_set_proto_to_df}; -use crate::protobuf::StageKey; +use crate::metrics::proto::metrics_set_proto_to_df; use crate::stage::Stage; -use bytes::Bytes; +use crate::worker::generated::worker as pb; +use crate::worker::generated::worker::TaskKey; use datafusion::common::HashMap; use datafusion::common::tree_node::Transformed; use datafusion::common::tree_node::TreeNode; @@ -203,7 +203,7 @@ pub fn rewrite_local_plan_with_metrics( /// Note: Metrics may be aggregated by name (ex. output_rows) automatically by various datafusion utils. pub fn stage_metrics_rewriter( stage: &Stage, - metrics_collection: Arc>>, + metrics_collection: Arc>>, format: DistributedMetricsFormat, ) -> Result> { let mut node_idx = 0; @@ -217,8 +217,12 @@ pub fn stage_metrics_rewriter( let mut stage_metrics = MetricsSet::new(); for task_id in 0..stage.tasks.len() { - let stage_key = StageKey::new(Bytes::from(stage.query_id.as_bytes().to_vec()), stage.num as u64, task_id as u64); - match metrics_collection.get(&stage_key) { + let task_key = TaskKey { + query_id: stage.query_id.as_bytes().to_vec(), + stage_id: stage.num as u64, + task_number: task_id as u64, + }; + match metrics_collection.get(&task_key) { Some(task_metrics) => { if node_idx >= task_metrics.len() { return internal_err!( @@ -268,22 +272,19 @@ pub fn stage_metrics_rewriter( mod tests { use crate::Stage; use crate::metrics::DISTRIBUTED_DATAFUSION_TASK_ID_LABEL; - use crate::metrics::proto::{ - MetricsSetProto, df_metrics_set_to_proto, metrics_set_proto_to_df, - }; + use crate::metrics::proto::{df_metrics_set_to_proto, metrics_set_proto_to_df}; use crate::metrics::task_metrics_rewriter::{ annotate_metrics_set_with_task_id, stage_metrics_rewriter, }; use crate::metrics::{DistributedMetricsFormat, rewrite_distributed_plan_with_metrics}; - use crate::protobuf::StageKey; use crate::test_utils::in_memory_channel_resolver::{ InMemoryChannelResolver, InMemoryWorkerResolver, }; use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed; use crate::test_utils::plans::count_plan_nodes_up_to_network_boundary; use crate::test_utils::session_context::register_temp_parquet_table; + use crate::worker::generated::worker as pb; use crate::{DistributedExec, DistributedPhysicalOptimizerRule}; - use bytes::Bytes; use datafusion::arrow::array::{Int32Array, StringArray}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::arrow::record_batch::RecordBatch; @@ -298,6 +299,7 @@ mod tests { use crate::DistributedExt; use crate::metrics::task_metrics_rewriter::MetricsWrapperExec; + use crate::worker::generated::worker::TaskKey; use datafusion::physical_plan::empty::EmptyExec; use datafusion::prelude::SessionConfig; use datafusion::prelude::SessionContext; @@ -434,11 +436,11 @@ mod tests { // Generate metrics for each task and store them in the map. let mut metrics_collection = HashMap::new(); for task_id in 0..stage.tasks.len() { - let stage_key = StageKey::new( - Bytes::from(stage.query_id.as_bytes().to_vec()), - stage.num as u64, - task_id as u64, - ); + let task_key = TaskKey { + query_id: stage.query_id.as_bytes().to_vec(), + stage_id: stage.num as u64, + task_number: task_id as u64, + }; let metrics = (0..count_plan_nodes_up_to_network_boundary(&plan)) .map(|node_id| { make_test_metrics_set_proto_from_seed( @@ -446,9 +448,9 @@ mod tests { num_metrics_per_task_per_node, ) }) - .collect::>(); + .collect::>(); - metrics_collection.insert(stage_key, metrics); + metrics_collection.insert(task_key, metrics); } let metrics_collection = Arc::new(metrics_collection); @@ -475,11 +477,11 @@ mod tests { .enumerate() { let expected_task_node_metrics = metrics_collection - .get(&StageKey::new( - Bytes::from(stage.query_id.as_bytes().to_vec()), - stage.num as u64, - task_id as u64, - )) + .get(&TaskKey { + query_id: stage.query_id.as_bytes().to_vec(), + stage_id: stage.num as u64, + task_number: task_id as u64, + }) .unwrap()[node_id] .clone(); diff --git a/src/networking/channel_resolver.rs b/src/networking/channel_resolver.rs index 5e44546b..50b5fd93 100644 --- a/src/networking/channel_resolver.rs +++ b/src/networking/channel_resolver.rs @@ -1,6 +1,6 @@ use crate::DistributedConfig; use crate::config_extension_ext::set_distributed_option_extension; -use arrow_flight::flight_service_client::FlightServiceClient; +use crate::worker::generated::worker::worker_service_client::WorkerServiceClient; use async_trait::async_trait; use datafusion::common::{DataFusionError, config_datafusion_err, exec_datafusion_err}; use datafusion::execution::TaskContext; @@ -15,36 +15,35 @@ use tonic::transport::Channel; use tower::ServiceExt; use url::Url; -/// Allows users to customize the way Arrow Flight clients are created. A common use case is to +/// Allows users to customize the way Worker clients are created. A common use case is to /// wrap the client with tower layers or schedule it in an IO-specific tokio runtime. /// /// There is a default implementation of this trait that should be enough for the most common /// use-cases. /// /// # Implementation Notes -/// - This is called per Arrow Flight request, so implementors of this trait should make sure that -/// clients are reused across method calls instead of building a new Arrow Flight client -/// every time. +/// - This is called per gRPC request, so implementors of this trait should make sure that +/// clients are reused across method calls instead of building a new Worker client every time. /// -/// - When implementing `get_flight_client_for_url`, it is recommended to use the -/// [`create_flight_client`] helper function to ensure clients are configured with +/// - When implementing `get_worker_client_for_url`, it is recommended to use the +/// [`create_worker_client`] helper function to ensure clients are configured with /// appropriate message size limits for internal communication. This helps avoid message /// size errors when transferring large datasets. #[async_trait] pub trait ChannelResolver { - /// For a given URL, get an Arrow Flight client for communicating to it. + /// For a given URL, get a Worker gRPC client for communicating to it. /// - /// *WARNING*: This method is called for every Arrow Flight gRPC request, so to not create + /// *WARNING*: This method is called for every gRPC request, so to not create /// one client connection for each request, users are required to reuse generated clients. /// It's recommended to rely on [DefaultChannelResolver] either by delegating method calls /// to it or by copying the implementation. /// - /// Consider using [`create_flight_client`] to create the client with appropriate + /// Consider using [`create_worker_client`] to create the client with appropriate /// default message size limits. - async fn get_flight_client_for_url( + async fn get_worker_client_for_url( &self, url: &Url, - ) -> Result, DataFusionError>; + ) -> Result, DataFusionError>; } pub(crate) fn set_distributed_channel_resolver( @@ -109,7 +108,7 @@ pub(crate) struct ChannelResolverExtension(Option>, @@ -173,56 +172,55 @@ impl DefaultChannelResolver { #[async_trait] impl ChannelResolver for DefaultChannelResolver { - async fn get_flight_client_for_url( + async fn get_worker_client_for_url( &self, url: &Url, - ) -> Result, DataFusionError> { - self.get_channel(url).await.map(create_flight_client) + ) -> Result, DataFusionError> { + self.get_channel(url).await.map(create_worker_client) } } #[async_trait] impl ChannelResolver for Arc { - async fn get_flight_client_for_url( + async fn get_worker_client_for_url( &self, url: &Url, - ) -> Result, DataFusionError> { - self.as_ref().get_flight_client_for_url(url).await + ) -> Result, DataFusionError> { + self.as_ref().get_worker_client_for_url(url).await } } -/// Creates a [`FlightServiceClient`] with high default message size limits. +/// Creates a [`WorkerServiceClient`] with high default message size limits. /// -/// This is a convenience function that wraps [`FlightServiceClient::new`] and configures +/// This is a convenience function that wraps [`WorkerServiceClient::new`] and configures /// it with `max_decoding_message_size(usize::MAX)` and `max_encoding_message_size(usize::MAX)` /// to avoid message size limitations for internal communication. /// /// Users implementing custom [`ChannelResolver`]s should use this function in their -/// `get_flight_client_for_url` implementations to ensure consistent behavior with built-in +/// `get_worker_client_for_url` implementations to ensure consistent behavior with built-in /// implementations. /// /// # Example /// /// ```rust,ignore -/// use datafusion_distributed::{create_flight_client, BoxCloneSyncChannel, ChannelResolver}; -/// use arrow_flight::flight_service_client::FlightServiceClient; -/// use tonic::transport::Channel; +/// use datafusion_distributed::{create_worker_client, BoxCloneSyncChannel, ChannelResolver}; +/// /// use tonic::transport::Channel; /// /// #[async_trait] /// impl ChannelResolver for MyResolver { -/// async fn get_flight_client_for_url( +/// async fn get_worker_client_for_url( /// &self, /// url: &Url, -/// ) -> Result, DataFusionError> { +/// ) -> Result, DataFusionError> { /// let channel = Channel::from_shared(url.to_string())?.connect().await?; -/// Ok(create_flight_client(BoxCloneSyncChannel::new(channel))) +/// Ok(create_worker_client(BoxCloneSyncChannel::new(channel))) /// } /// } /// ``` -pub fn create_flight_client( +pub fn create_worker_client( channel: BoxCloneSyncChannel, -) -> FlightServiceClient { - FlightServiceClient::new(channel) +) -> WorkerServiceClient { + WorkerServiceClient::new(channel) .max_decoding_message_size(usize::MAX) .max_encoding_message_size(usize::MAX) } @@ -285,7 +283,7 @@ mod tests { let worker = Worker::default(); let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener); if let Err(err) = Server::builder() - .add_service(worker.into_flight_server()) + .add_service(worker.into_worker_server()) .serve_with_incoming(incoming) .await { diff --git a/src/networking/mod.rs b/src/networking/mod.rs index 3da47b0d..6bab9ae0 100644 --- a/src/networking/mod.rs +++ b/src/networking/mod.rs @@ -2,7 +2,7 @@ mod channel_resolver; mod worker_resolver; pub use channel_resolver::{ - BoxCloneSyncChannel, ChannelResolver, DefaultChannelResolver, create_flight_client, + BoxCloneSyncChannel, ChannelResolver, DefaultChannelResolver, create_worker_client, get_distributed_channel_resolver, }; pub(crate) use channel_resolver::{ChannelResolverExtension, set_distributed_channel_resolver}; diff --git a/src/observability/gen/src/main.rs b/src/observability/gen/src/main.rs index 93aef11c..d414fc53 100644 --- a/src/observability/gen/src/main.rs +++ b/src/observability/gen/src/main.rs @@ -19,6 +19,10 @@ fn main() -> Result<(), Box> { .build_server(true) .build_client(true) .out_dir(&out_dir) + .extern_path( + ".observability.TaskKey", + "crate::worker::generated::worker::TaskKey", + ) .compile_protos(&[proto_file], &[proto_dir])?; println!("Successfully generated observability proto code"); diff --git a/src/observability/generated/observability.rs b/src/observability/generated/observability.rs index cc7fe513..c8da3372 100644 --- a/src/observability/generated/observability.rs +++ b/src/observability/generated/observability.rs @@ -8,20 +8,11 @@ pub struct PingResponse { } #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct GetTaskProgressRequest {} -#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] -pub struct StageKey { - #[prost(bytes = "vec", tag = "1")] - pub query_id: ::prost::alloc::vec::Vec, - #[prost(uint64, tag = "2")] - pub stage_id: u64, - #[prost(uint64, tag = "3")] - pub task_number: u64, -} /// Progress information for a single task #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct TaskProgress { #[prost(message, optional, tag = "1")] - pub stage_key: ::core::option::Option, + pub task_key: ::core::option::Option, #[prost(uint64, tag = "2")] pub total_partitions: u64, #[prost(uint64, tag = "3")] diff --git a/src/observability/mod.rs b/src/observability/mod.rs index c48906ae..555c6577 100644 --- a/src/observability/mod.rs +++ b/src/observability/mod.rs @@ -8,7 +8,6 @@ pub use generated::observability::observability_service_server::{ pub use generated::observability::{ GetClusterWorkersRequest, GetClusterWorkersResponse, GetTaskProgressRequest, - GetTaskProgressResponse, PingRequest, PingResponse, StageKey, TaskProgress, TaskStatus, - WorkerMetrics, + GetTaskProgressResponse, PingRequest, PingResponse, TaskProgress, TaskStatus, WorkerMetrics, }; pub use service::ObservabilityServiceImpl; diff --git a/src/observability/proto/observability.proto b/src/observability/proto/observability.proto index 6c0c403d..07925d54 100644 --- a/src/observability/proto/observability.proto +++ b/src/observability/proto/observability.proto @@ -15,7 +15,7 @@ message PingResponse { message GetTaskProgressRequest {} -message StageKey { +message TaskKey { bytes query_id = 1; uint64 stage_id = 2; uint64 task_number = 3; @@ -23,7 +23,7 @@ message StageKey { // Progress information for a single task message TaskProgress { - StageKey stage_key = 1; + TaskKey task_key = 1; uint64 total_partitions = 2; uint64 completed_partitions = 3; TaskStatus status = 4; diff --git a/src/observability/service.rs b/src/observability/service.rs index 00672677..14caaa17 100644 --- a/src/observability/service.rs +++ b/src/observability/service.rs @@ -1,6 +1,10 @@ -use crate::flight_service::{SingleWriteMultiRead, TaskData}; -use crate::networking::WorkerResolver; -use crate::protobuf::StageKey; +use super::{ + GetTaskProgressResponse, ObservabilityService, TaskProgress, TaskStatus, WorkerMetrics, + generated::observability::{GetTaskProgressRequest, PingRequest, PingResponse}, +}; +use crate::worker::generated::worker::TaskKey; +use crate::worker::{SingleWriteMultiRead, TaskData}; +use crate::{GetClusterWorkersRequest, GetClusterWorkersResponse, WorkerResolver}; use datafusion::error::DataFusionError; use datafusion::physical_plan::ExecutionPlan; use moka::future::Cache; @@ -13,18 +17,10 @@ use sysinfo::{Pid, ProcessRefreshKind}; use tokio::sync::watch; use tonic::{Request, Response, Status}; -use super::{ - GetClusterWorkersResponse, GetTaskProgressResponse, ObservabilityService, TaskProgress, - TaskStatus, WorkerMetrics, - generated::observability::{ - GetClusterWorkersRequest, GetTaskProgressRequest, PingRequest, PingResponse, - }, -}; - type ResultTaskData = Result>; pub struct ObservabilityServiceImpl { - task_data_entries: Arc>>>, + task_data_entries: Arc>>>, worker_resolver: Arc, #[cfg(feature = "system-metrics")] system: watch::Receiver, @@ -32,7 +28,7 @@ pub struct ObservabilityServiceImpl { impl ObservabilityServiceImpl { pub fn new( - task_data_entries: Arc>>>, + task_data_entries: Arc>>>, worker_resolver: Arc, ) -> Self { #[cfg(feature = "system-metrics")] @@ -102,7 +98,7 @@ impl ObservabilityService for ObservabilityServiceImpl { let output_rows = output_rows_from_plan(&task_data.plan); tasks.push(TaskProgress { - stage_key: Some(convert_stage_key(&internal_key)), + task_key: Some((*internal_key).clone()), total_partitions, completed_partitions, status: TaskStatus::Running as i32, @@ -146,15 +142,6 @@ impl ObservabilityServiceImpl { } } -/// Converts internal StageKey to observability proto StageKey -fn convert_stage_key(key: &StageKey) -> super::StageKey { - super::StageKey { - query_id: key.query_id.to_vec(), - stage_id: key.stage_id, - task_number: key.task_number, - } -} - /// Extracts output rows from the root plan node's metrics. fn output_rows_from_plan(plan: &Arc) -> u64 { plan.metrics().and_then(|m| m.output_rows()).unwrap_or(0) as u64 diff --git a/src/protobuf/app_metadata.rs b/src/protobuf/app_metadata.rs deleted file mode 100644 index 07d793dd..00000000 --- a/src/protobuf/app_metadata.rs +++ /dev/null @@ -1,67 +0,0 @@ -use crate::metrics::proto::MetricsSetProto; -use crate::protobuf::distributed_codec::StageKey; -use std::time::{SystemTime, UNIX_EPOCH}; - -/// A collection of metrics for a set of tasks in an ExecutionPlan. each -/// entry should have a distinct [StageKey]. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct MetricsCollection { - #[prost(message, repeated, tag = "1")] - pub tasks: Vec, -} - -/// TaskMetrics represents the metrics for a single task. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct TaskMetrics { - /// stage_key uniquely identifies this task. - /// - /// This field is always present. It's marked optional due to protobuf rules. - #[prost(message, optional, tag = "1")] - pub stage_key: Option, - /// metrics[i] is the set of metrics for plan node `i` where plan nodes are in pre-order - /// traversal order. - #[prost(message, repeated, tag = "2")] - pub metrics: Vec, -} - -// FlightAppMetadata represents all types of app_metadata which we use in the distributed execution. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct FlightAppMetadata { - #[prost(uint64, tag = "1")] - pub partition: u64, - // Unix timestamp in nanoseconds at which this message was created. - #[prost(uint64, tag = "2")] - pub created_timestamp_unix_nanos: u64, - // content should always be Some, but it is optional due to protobuf rules. - #[prost(oneof = "AppMetadata", tags = "10")] - pub content: Option, -} - -impl FlightAppMetadata { - pub fn new(partition: u64) -> Self { - Self { - partition, - created_timestamp_unix_nanos: current_unix_timestamp_nanos(), - content: None, - } - } - - pub fn set_content(&mut self, content: AppMetadata) { - self.content = Some(content); - } -} - -fn current_unix_timestamp_nanos() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| duration.as_nanos() as u64) - .unwrap_or(0) -} - -#[derive(Clone, PartialEq, ::prost::Oneof)] -pub enum AppMetadata { - #[prost(message, tag = "10")] - MetricsCollection(MetricsCollection), - // Note: For every additional enum variant, ensure to add tags to [FlightAppMetadata]. ex. `#[prost(oneof = "AppMetadata", tags = "1,2,3")]` etc. - // If you don't the proto will compile but you may encounter errors during serialization/deserialization. -} diff --git a/src/protobuf/distributed_codec.rs b/src/protobuf/distributed_codec.rs index 05079c66..80fb9eba 100644 --- a/src/protobuf/distributed_codec.rs +++ b/src/protobuf/distributed_codec.rs @@ -2,8 +2,8 @@ use super::get_distributed_user_codecs; use crate::execution_plans::{ BroadcastExec, ChildrenIsolatorUnionExec, NetworkBroadcastExec, NetworkCoalesceExec, }; -use crate::flight_service::WorkerConnectionPool; use crate::stage::{ExecutionTask, Stage}; +use crate::worker::WorkerConnectionPool; use crate::{DistributedTaskContext, NetworkBoundary}; use crate::{NetworkShuffleExec, PartitionIsolatorExec}; use bytes::Bytes; @@ -324,31 +324,6 @@ impl PhysicalExtensionCodec for DistributedCodec { } } -/// A key that uniquely identifies a stage in a query. -#[derive(Clone, Hash, Eq, PartialEq, ::prost::Message)] -pub struct StageKey { - /// Our query id - #[prost(bytes, tag = "1")] - pub query_id: Bytes, - /// Our stage id - #[prost(uint64, tag = "2")] - pub stage_id: u64, - /// The task number within the stage - #[prost(uint64, tag = "3")] - pub task_number: u64, -} - -impl StageKey { - /// Creates a new `StageKey`. - pub fn new(query_id: Bytes, stage_id: u64, task_number: u64) -> StageKey { - Self { - query_id, - stage_id, - task_number, - } - } -} - #[derive(Clone, PartialEq, ::prost::Message)] pub struct StageProto { /// Our query id diff --git a/src/protobuf/mod.rs b/src/protobuf/mod.rs index aeecb8ac..4518dbb8 100644 --- a/src/protobuf/mod.rs +++ b/src/protobuf/mod.rs @@ -1,11 +1,8 @@ -mod app_metadata; mod distributed_codec; mod errors; mod user_codec; -pub(crate) use app_metadata::{AppMetadata, FlightAppMetadata, MetricsCollection, TaskMetrics}; pub(crate) use distributed_codec::DistributedCodec; -pub use distributed_codec::StageKey; pub(crate) use errors::{ datafusion_error_to_tonic_status, map_flight_to_datafusion_error, tonic_status_to_datafusion_error, diff --git a/src/test_utils/in_memory_channel_resolver.rs b/src/test_utils/in_memory_channel_resolver.rs index 96baef2d..85638248 100644 --- a/src/test_utils/in_memory_channel_resolver.rs +++ b/src/test_utils/in_memory_channel_resolver.rs @@ -1,9 +1,9 @@ +use crate::worker::generated::worker::worker_service_client::WorkerServiceClient; use crate::{ BoxCloneSyncChannel, ChannelResolver, DefaultSessionBuilder, DistributedExt, MappedWorkerSessionBuilderExt, Worker, WorkerResolver, WorkerSessionBuilder, - create_flight_client, + create_worker_client, }; -use arrow_flight::flight_service_client::FlightServiceClient; use async_trait::async_trait; use datafusion::common::DataFusionError; use hyper_util::rt::TokioIo; @@ -15,7 +15,7 @@ const DUMMY_URL: &str = "http://localhost:50051"; /// tokio duplex rather than a TCP connection. #[derive(Clone)] pub struct InMemoryChannelResolver { - channel: FlightServiceClient, + channel: WorkerServiceClient, } impl InMemoryChannelResolver { @@ -38,7 +38,7 @@ impl InMemoryChannelResolver { })); let this = Self { - channel: create_flight_client(BoxCloneSyncChannel::new(channel)), + channel: create_worker_client(BoxCloneSyncChannel::new(channel)), }; let this_clone = this.clone(); @@ -49,7 +49,7 @@ impl InMemoryChannelResolver { tokio::spawn(async move { Server::builder() - .add_service(endpoint.into_flight_server()) + .add_service(endpoint.into_worker_server()) .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) .await }); @@ -66,10 +66,10 @@ impl Default for InMemoryChannelResolver { #[async_trait] impl ChannelResolver for InMemoryChannelResolver { - async fn get_flight_client_for_url( + async fn get_worker_client_for_url( &self, _: &url::Url, - ) -> Result, DataFusionError> { + ) -> Result, DataFusionError> { Ok(self.channel.clone()) } } diff --git a/src/test_utils/localhost.rs b/src/test_utils/localhost.rs index 09a1ee12..2db20968 100644 --- a/src/test_utils/localhost.rs +++ b/src/test_utils/localhost.rs @@ -57,7 +57,7 @@ where join_set.spawn(async move { Server::builder() - .add_service(worker.into_flight_server()) + .add_service(worker.into_worker_server()) .serve_with_incoming(incoming) .await .unwrap(); @@ -103,7 +103,7 @@ impl WorkerResolver for LocalHostWorkerResolver { } } -pub async fn spawn_flight_service( +pub async fn spawn_worker_service( session_builder: impl WorkerSessionBuilder + Send + Sync + 'static, incoming: TcpListener, ) -> Result<(), Box> { @@ -112,7 +112,7 @@ pub async fn spawn_flight_service( let incoming = tokio_stream::wrappers::TcpListenerStream::new(incoming); Ok(Server::builder() - .add_service(endpoint.into_flight_server()) + .add_service(endpoint.into_worker_server()) .serve_with_incoming(incoming) .await?) } diff --git a/src/test_utils/metrics.rs b/src/test_utils/metrics.rs index a768379d..e744f1f1 100644 --- a/src/test_utils/metrics.rs +++ b/src/test_utils/metrics.rs @@ -1,35 +1,36 @@ -use crate::metrics::proto::{ElapsedCompute, EndTimestamp, OutputRows, StartTimestamp}; -use crate::metrics::proto::{MetricProto, MetricValueProto, MetricsSetProto}; +use crate::worker::generated::worker as pb; /// creates a "distinct" set of metrics from the provided seed -pub fn make_test_metrics_set_proto_from_seed(seed: u64, num_metrics: usize) -> MetricsSetProto { +pub fn make_test_metrics_set_proto_from_seed(seed: u64, num_metrics: usize) -> pb::MetricsSet { const TEST_TIMESTAMP: i64 = 1758200400000000000; // 2025-09-18 13:00:00 UTC - let mut result = MetricsSetProto { metrics: vec![] }; + let mut result = pb::MetricsSet { metrics: vec![] }; for i in 0..num_metrics { let value = seed + i as u64; - result.push(match i % 4 { - 0 => MetricProto { - metric: Some(MetricValueProto::OutputRows(OutputRows { value })), + result.metrics.push(match i % 4 { + 0 => pb::Metric { + value: Some(pb::metric::Value::OutputRows(pb::OutputRows { value })), labels: vec![], partition: None, }, - 1 => MetricProto { - metric: Some(MetricValueProto::ElapsedCompute(ElapsedCompute { value })), + 1 => pb::Metric { + value: Some(pb::metric::Value::ElapsedCompute(pb::ElapsedCompute { + value, + })), labels: vec![], partition: None, }, - 2 => MetricProto { - metric: Some(MetricValueProto::StartTimestamp(StartTimestamp { + 2 => pb::Metric { + value: Some(pb::metric::Value::StartTimestamp(pb::StartTimestamp { value: Some(TEST_TIMESTAMP + (value as i64 * 1_000_000_000)), })), labels: vec![], partition: None, }, - 3 => MetricProto { - metric: Some(MetricValueProto::EndTimestamp(EndTimestamp { + 3 => pb::Metric { + value: Some(pb::metric::Value::EndTimestamp(pb::EndTimestamp { value: Some(TEST_TIMESTAMP + (value as i64 * 1_000_000_000)), })), labels: vec![], diff --git a/src/test_utils/plans.rs b/src/test_utils/plans.rs index 2291e00c..635bed47 100644 --- a/src/test_utils/plans.rs +++ b/src/test_utils/plans.rs @@ -1,9 +1,10 @@ +use super::parquet::register_parquet_tables; use crate::NetworkBoundaryExt; use crate::distributed_ext::DistributedExt; use crate::execution_plans::DistributedExec; -use crate::protobuf::StageKey; use crate::stage::Stage; use crate::test_utils::in_memory_channel_resolver::InMemoryWorkerResolver; +use crate::worker::generated::worker::TaskKey; #[cfg(test)] use crate::{DistributedConfig, TaskEstimation, TaskEstimator}; #[cfg(test)] @@ -18,8 +19,6 @@ use datafusion::{ use itertools::Itertools; use std::sync::Arc; -use super::parquet::register_parquet_tables; - /// count_plan_nodes counts the number of execution plan nodes in a plan using BFS traversal. /// This does NOT traverse child stages, only the execution plan tree within this stage. /// Network boundary nodes are counted but their children (which belong to child stages) are not traversed. @@ -46,13 +45,13 @@ pub fn count_plan_nodes_up_to_network_boundary(plan: &Arc) -> /// Returns /// - a map of all stages -/// - a set of all the stage keys (one per task) -pub fn get_stages_and_stage_keys( +/// - a set of all the task keys (one per task) +pub fn get_stages_and_task_keys( stage: &DistributedExec, -) -> (HashMap, HashSet) { +) -> (HashMap, HashSet) { let mut i = 0; let mut queue = find_input_stages(stage); - let mut stage_keys = HashSet::new(); + let mut task_keys = HashSet::new(); let mut stages_map = HashMap::new(); while i < queue.len() { @@ -62,17 +61,17 @@ pub fn get_stages_and_stage_keys( // Add each task. for j in 0..stage.tasks.len() { - stage_keys.insert(StageKey::new( - stage.query_id.as_bytes().to_vec().into(), - stage.num as u64, - j as u64, - )); + task_keys.insert(TaskKey { + query_id: stage.query_id.as_bytes().to_vec(), + stage_id: stage.num as u64, + task_number: j as u64, + }); } // Add any child stages queue.extend(find_input_stages(stage.plan.as_ref().unwrap().as_ref())); } - (stages_map, stage_keys) + (stages_map, task_keys) } fn find_input_stages(plan: &dyn ExecutionPlan) -> Vec<&Stage> { diff --git a/src/worker/gen/Cargo.toml b/src/worker/gen/Cargo.toml new file mode 100644 index 00000000..364b9cf3 --- /dev/null +++ b/src/worker/gen/Cargo.toml @@ -0,0 +1,12 @@ +[workspace] +# Empty workspace table to exclude this crate from parent workspace + +[package] +name = "worker-gen" +version = "0.1.0" +edition = "2024" +description = "Protobuf code generation for Worker service" +publish = false + +[dependencies] +tonic-prost-build = "0.14.2" diff --git a/src/worker/gen/regen.sh b/src/worker/gen/regen.sh new file mode 100755 index 00000000..39e04fdc --- /dev/null +++ b/src/worker/gen/regen.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +repo_root=$(git rev-parse --show-toplevel) +cd "$repo_root" && cargo run --manifest-path src/worker/gen/Cargo.toml diff --git a/src/worker/gen/src/main.rs b/src/worker/gen/src/main.rs new file mode 100644 index 00000000..1c982b17 --- /dev/null +++ b/src/worker/gen/src/main.rs @@ -0,0 +1,29 @@ +use std::env; +use std::fs; + +fn main() -> Result<(), Box> { + let repo_root = env::current_dir()?; + + let proto_dir = repo_root.join("src/worker"); + let proto_file = proto_dir.join("worker.proto"); + let out_dir = repo_root.join("src/worker/generated"); + + fs::create_dir_all(&out_dir)?; + + println!("Generating protobuf code..."); + println!("Proto dir: {proto_dir:?}"); + println!("Proto file: {proto_file:?}"); + println!("Output dir: {out_dir:?}"); + + tonic_prost_build::configure() + .build_server(true) + .build_client(true) + .out_dir(&out_dir) + .extern_path(".worker.FlightData", "::arrow_flight::FlightData") + .extern_path(".worker.FlightDescriptor", "::arrow_flight::FlightDescriptor") + .compile_protos(&[proto_file], &[proto_dir])?; + + println!("Successfully generated observability proto code"); + + Ok(()) +} diff --git a/src/worker/generated/mod.rs b/src/worker/generated/mod.rs new file mode 100644 index 00000000..844c269c --- /dev/null +++ b/src/worker/generated/mod.rs @@ -0,0 +1 @@ +pub(crate) mod worker; diff --git a/src/worker/generated/worker.rs b/src/worker/generated/worker.rs new file mode 100644 index 00000000..e659236b --- /dev/null +++ b/src/worker/generated/worker.rs @@ -0,0 +1,675 @@ +// This file is @generated by prost-build. +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct CoordinatorToWorkerMsg { + #[prost(oneof = "coordinator_to_worker_msg::Inner", tags = "1")] + pub inner: ::core::option::Option, +} +/// Nested message and enum types in `CoordinatorToWorkerMsg`. +pub mod coordinator_to_worker_msg { + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] + pub enum Inner { + /// Sends a subplan to a worker so that a future ExecuteTask call can actually execute it. + /// The plan is identified by a TaskKey. + #[prost(message, tag = "1")] + SetPlanRequest(super::SetPlanRequest), + } +} +/// For now, there are no messages that can flow back from worker to coordinator. +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct WorkerToCoordinatorMsg {} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct SetPlanRequest { + /// The unique identifier of the task to which the subplan belongs to. + #[prost(message, optional, tag = "1")] + pub task_key: ::core::option::Option, + /// The amount of tasks that share the same subplan. Necessary for building the DistributedTaskContext during execution. + #[prost(uint64, tag = "2")] + pub task_count: u64, + /// The serialized subplan the worker is expected to execute on an ExecuteTask gRPC call. + #[prost(bytes = "vec", tag = "3")] + pub plan_proto: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct ExecuteTaskRequest { + /// The unique identifier of the task that is going to get executed. + #[prost(message, optional, tag = "1")] + pub task_key: ::core::option::Option, + /// The start of the partition range of the specified task that is going to be executed. + #[prost(uint64, tag = "2")] + pub target_partition_start: u64, + /// The end of the partition range of the specified task that is going to be executed. + #[prost(uint64, tag = "3")] + pub target_partition_end: u64, +} +/// A key that uniquely identifies a task in a query. +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct TaskKey { + /// Our query id. + #[prost(bytes = "vec", tag = "1")] + pub query_id: ::prost::alloc::vec::Vec, + /// Our stage id. + #[prost(uint64, tag = "2")] + pub stage_id: u64, + /// The task number within the stage. + #[prost(uint64, tag = "3")] + pub task_number: u64, +} +/// FlightAppMetadata represents all types of app_metadata which we use in the distributed execution. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FlightAppMetadata { + #[prost(uint64, tag = "1")] + pub partition: u64, + /// Unix timestamp in nanoseconds at which this message was created. + #[prost(uint64, tag = "2")] + pub created_timestamp_unix_nanos: u64, + /// content should always be Some, but it is optional due to protobuf rules. + #[prost(oneof = "flight_app_metadata::Content", tags = "10")] + pub content: ::core::option::Option, +} +/// Nested message and enum types in `FlightAppMetadata`. +pub mod flight_app_metadata { + /// content should always be Some, but it is optional due to protobuf rules. + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Content { + #[prost(message, tag = "10")] + MetricsCollection(super::MetricsCollection), + } +} +/// A collection of metrics for a set of tasks in an ExecutionPlan. Each +/// entry should have a distinct TaskKey. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct MetricsCollection { + #[prost(message, repeated, tag = "1")] + pub tasks: ::prost::alloc::vec::Vec, +} +/// TaskMetrics represents the metrics for a single task. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TaskMetrics { + /// task_key uniquely identifies this task. + /// This field is always present. It's marked optional due to protobuf rules. + #[prost(message, optional, tag = "1")] + pub task_key: ::core::option::Option, + /// metrics\[i\] is the set of metrics for plan node i where plan nodes are + /// in pre-order traversal order. + #[prost(message, repeated, tag = "2")] + pub metrics: ::prost::alloc::vec::Vec, +} +/// A Label mirrors datafusion::physical_plan::metrics::Label. +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct Label { + #[prost(string, tag = "1")] + pub name: ::prost::alloc::string::String, + #[prost(string, tag = "2")] + pub value: ::prost::alloc::string::String, +} +/// A Metric is a protobuf mirror of datafusion::physical_plan::metrics::Metric. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Metric { + #[prost(message, repeated, tag = "1")] + pub labels: ::prost::alloc::vec::Vec