From 5cb96137f8cd1b5664a3eaaf32768b408e86e0d1 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 17 Mar 2026 19:08:45 +0100 Subject: [PATCH 01/11] Move away from Arrow Flight in favor of a custom worker gRPC service --- .gitignore | 2 + benchmarks/cdk/bin/worker.rs | 4 +- benchmarks/src/run.rs | 2 +- console/examples/cluster.rs | 2 +- console/examples/console_worker.rs | 2 +- docs/source/user-guide/channel-resolver.md | 18 +- docs/source/user-guide/concepts.md | 4 +- docs/source/user-guide/getting-started.md | 6 +- docs/source/user-guide/worker.md | 16 +- examples/in_memory_cluster.rs | 14 +- examples/localhost_worker.rs | 2 +- src/distributed_ext.rs | 12 +- .../benchmarks/shuffle_bench.rs | 14 +- src/execution_plans/distributed.rs | 32 +- src/execution_plans/network_broadcast.rs | 5 +- src/execution_plans/network_coalesce.rs | 5 +- src/execution_plans/network_shuffle.rs | 5 +- src/lib.rs | 17 +- src/metrics/task_metrics_collector.rs | 2 +- src/metrics/task_metrics_rewriter.rs | 32 +- src/networking/channel_resolver.rs | 62 ++- src/networking/mod.rs | 2 +- src/observability/gen/src/main.rs | 4 + src/observability/generated/observability.rs | 152 ++++--- src/observability/mod.rs | 4 +- src/observability/service.rs | 24 +- src/protobuf/app_metadata.rs | 2 +- src/protobuf/distributed_codec.rs | 27 +- src/protobuf/mod.rs | 1 - src/test_utils/in_memory_channel_resolver.rs | 14 +- src/test_utils/localhost.rs | 6 +- src/test_utils/plans.rs | 15 +- .../do_get.rs => worker/execute_task.rs} | 50 +-- src/worker/gen/Cargo.toml | 12 + src/worker/gen/regen.sh | 6 + src/worker/gen/src/main.rs | 29 ++ src/worker/generated/mod.rs | 1 + src/worker/generated/worker.rs | 411 ++++++++++++++++++ src/{flight_service => worker}/mod.rs | 12 +- .../session_builder.rs | 0 .../do_action.rs => worker/set_plan.rs} | 38 +- .../single_write_multi_read.rs | 0 .../spawn_select_all.rs | 0 .../test_utils/memory_worker.rs | 8 +- .../test_utils/mod.rs | 0 src/worker/worker.proto | 57 +++ .../worker_connection_pool.rs | 39 +- .../worker.rs => worker/worker_service.rs} | 126 ++---- 48 files changed, 839 insertions(+), 459 deletions(-) rename src/{flight_service/do_get.rs => worker/execute_task.rs} (87%) create mode 100644 src/worker/gen/Cargo.toml create mode 100755 src/worker/gen/regen.sh create mode 100644 src/worker/gen/src/main.rs create mode 100644 src/worker/generated/mod.rs create mode 100644 src/worker/generated/worker.rs rename src/{flight_service => worker}/mod.rs (75%) rename src/{flight_service => worker}/session_builder.rs (100%) rename src/{flight_service/do_action.rs => worker/set_plan.rs} (75%) rename src/{flight_service => worker}/single_write_multi_read.rs (100%) rename src/{flight_service => worker}/spawn_select_all.rs (100%) rename src/{flight_service => worker}/test_utils/memory_worker.rs (95%) rename src/{flight_service => worker}/test_utils/mod.rs (100%) create mode 100644 src/worker/worker.proto rename src/{flight_service => worker}/worker_connection_pool.rs (95%) rename src/{flight_service/worker.rs => worker/worker_service.rs} (60%) diff --git a/.gitignore b/.gitignore index 9aecd384..479b6815 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,5 @@ testdata/clickbench/* !testdata/clickbench/queries src/observability/gen/target/* src/observability/gen/Cargo.lock +src/worker/gen/target/* +src/worker/gen/Cargo.lock diff --git a/benchmarks/cdk/bin/worker.rs b/benchmarks/cdk/bin/worker.rs index 5d019c5e..171aca82 100644 --- a/benchmarks/cdk/bin/worker.rs +++ b/benchmarks/cdk/bin/worker.rs @@ -119,7 +119,7 @@ async fn main() -> Result<(), Box> { let mut errors = vec![]; for worker_url in worker_resolver.get_urls().map_err(err)? { if let Err(err) = channel_resolver - .get_flight_client_for_url(&worker_url) + .get_worker_client_for_url(&worker_url) .await { errors.push(err.to_string()) @@ -206,7 +206,7 @@ async fn main() -> Result<(), Box> { ), ); let grpc_server = Server::builder() - .add_service(worker.into_flight_server()) + .add_service(worker.into_worker_server()) .serve(WORKER_ADDR.parse()?); info!("Started listener HTTP server in {LISTENER_ADDR}"); diff --git a/benchmarks/src/run.rs b/benchmarks/src/run.rs index 6b40c3ad..3b5c26d8 100644 --- a/benchmarks/src/run.rs +++ b/benchmarks/src/run.rs @@ -163,7 +163,7 @@ impl RunOpt { let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener); Ok::<_, Box>( Server::builder() - .add_service(Worker::default().into_flight_server()) + .add_service(Worker::default().into_worker_server()) .serve_with_incoming(incoming) .await?, ) diff --git a/console/examples/cluster.rs b/console/examples/cluster.rs index 73260c05..de4b57cc 100644 --- a/console/examples/cluster.rs +++ b/console/examples/cluster.rs @@ -47,7 +47,7 @@ async fn main() -> Result<(), Box> { Server::builder() .add_service(worker.with_observability_service()) - .add_service(worker.into_flight_server()) + .add_service(worker.into_worker_server()) .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) .await .expect("worker server failed"); diff --git a/console/examples/console_worker.rs b/console/examples/console_worker.rs index 189e8230..955d7034 100644 --- a/console/examples/console_worker.rs +++ b/console/examples/console_worker.rs @@ -22,7 +22,7 @@ async fn main() -> Result<(), Box> { Server::builder() .add_service(worker.with_observability_service()) - .add_service(worker.into_flight_server()) + .add_service(worker.into_worker_server()) .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), args.port)) .await?; diff --git a/docs/source/user-guide/channel-resolver.md b/docs/source/user-guide/channel-resolver.md index 749cabf0..9fcdac18 100644 --- a/docs/source/user-guide/channel-resolver.md +++ b/docs/source/user-guide/channel-resolver.md @@ -2,10 +2,10 @@ This trait is optional—a sensible default implementation exists that handles most use cases. -The `ChannelResolver` trait controls how Distributed DataFusion builds Arrow Flight clients backed by +The `ChannelResolver` trait controls how Distributed DataFusion builds Worker gRPC clients backed by [Tonic](https://github.com/hyperium/tonic) channels for worker URLs. -The default implementation connects to each URL, builds an Arrow Flight client, and caches it for reuse on +The default implementation connects to each URL, builds a Worker client, and caches it for reuse on subsequent requests to the same URL. ## Providing your own ChannelResolver @@ -15,7 +15,7 @@ For providing your own implementation, you'll need to take into account the foll - You will need to provide your own implementation in two places: - in the `SessionContext` that first initiates and plans your queries. - while instantiating the `Worker` with the `from_session_builder()` constructor. -- If building from scratch, ensure Arrow Flight clients are reused across requests rather than recreated each time. +- If building from scratch, ensure Worker clients are reused across requests rather than recreated each time. - You can extend `DefaultChannelResolver` as a foundation for custom implementations. This automatically handles gRPC channel reuse. @@ -25,11 +25,11 @@ struct CustomChannelResolver; #[async_trait] impl ChannelResolver for CustomChannelResolver { - async fn get_flight_client_for_url( + async fn get_worker_client_for_url( &self, url: &Url, - ) -> Result, DataFusionError> { - // Build a custom FlightServiceClient wrapped with tower + ) -> Result, DataFusionError> { + // Build a custom WorkerServiceClient wrapped with tower // layers or something similar. todo!() } @@ -37,7 +37,7 @@ impl ChannelResolver for CustomChannelResolver { async fn main() { // Build a single instance for your application's lifetime - // to enable Arrow Flight client reuse across queries. + // to enable Worker client reuse across queries. let channel_resolver = CustomChannelResolver; let state = SessionStateBuilder::new() @@ -56,10 +56,10 @@ async fn main() { } }); Server::builder() - .add_service(endpoint.into_flight_server()) + .add_service(endpoint.into_worker_server()) .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000)) .await?; Ok(()) } -``` \ No newline at end of file +``` diff --git a/docs/source/user-guide/concepts.md b/docs/source/user-guide/concepts.md index 88cbd0aa..d66235ad 100644 --- a/docs/source/user-guide/concepts.md +++ b/docs/source/user-guide/concepts.md @@ -32,9 +32,9 @@ a fully formed physical plan and injects the appropriate nodes to execute the qu It builds the distributed plan from bottom to top, injecting network boundaries at appropriate locations based on the nodes present in the original plan. -## [Worker](https://github.com/datafusion-contrib/datafusion-distributed/blob/main/src/flight_service/worker.rs) +## [Worker](https://github.com/datafusion-contrib/datafusion-distributed/blob/main/src/worker/worker_service.rs) -Arrow Flight server implementation that integrates with the Tonic ecosystem and listens to serialized plans that get +gRPC server implementation that integrates with the Tonic ecosystem and listens to serialized plans that get executed over the wire. Users are expected to build these and spawn them in ports so that the network boundary nodes can reach them. diff --git a/docs/source/user-guide/getting-started.md b/docs/source/user-guide/getting-started.md index 33f5fa02..9b7dd089 100644 --- a/docs/source/user-guide/getting-started.md +++ b/docs/source/user-guide/getting-started.md @@ -11,7 +11,7 @@ Rather than imposing constraints on your infrastructure or query serving pattern allows you to plug in your own networking stack and spawn your own gRPC servers that act as workers in the cluster. This project heavily relies on the [Tonic](https://github.com/hyperium/tonic) ecosystem for the networking layer. -Users of this library are responsible for building their own Tonic server, adding the Arrow Flight distributed +Users of this library are responsible for building their own Tonic server, adding the distributed DataFusion service to it and spawning it on a port so that it can be reached by other workers in the cluster. A very basic example of this would be: @@ -21,7 +21,7 @@ async fn main() -> Result<(), Box> { let worker = Worker::default(); Server::builder() - .add_service(worker.into_flight_server()) + .add_service(worker.into_worker_server()) .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000)) .await?; @@ -74,7 +74,7 @@ This will leave a DataFusion `SessionContext` ready for executing distributed qu Depending on your needs, your setup can get more complicated, for example: - You may want to resolve worker URLs dynamically using the Kubernetes API. -- You may want to wrap the Arrow Flight clients that connect workers with an observability layer. +- You may want to wrap the Worker clients that connect workers with an observability layer. - You may want to be able to execute your own custom ExecutionPlans in a distributed manner. - etc... diff --git a/docs/source/user-guide/worker.md b/docs/source/user-guide/worker.md index e57bc329..91fa3bf9 100644 --- a/docs/source/user-guide/worker.md +++ b/docs/source/user-guide/worker.md @@ -1,18 +1,18 @@ # Spawn a Worker -The `Worker` is a gRPC server implementing the Arrow Flight protocol for distributed query execution. Worker nodes +The `Worker` is a gRPC server that handles distributed query execution. Worker nodes run these endpoints to receive execution plans, execute them, and stream results back. ## Overview The `Worker` is the core worker component in Distributed DataFusion. It: -- Receives serialized execution plans via Arrow Flight's `do_get` method +- Receives serialized execution plans via gRPC - Deserializes plans using protobuf and user-provided codecs - Executes plans using the local DataFusion runtime -- Streams results back as Arrow record batches through the gRPC Arrow Flight interface +- Streams results back as Arrow record batches through the gRPC interface -## Launching the Arrow Flight server +## Launching the Worker server The default `Worker` implementation satisfies most basic use cases: @@ -23,7 +23,7 @@ async fn main() { let endpoint = Worker::default(); Server::builder() - .add_service(endpoint.into_flight_server()) + .add_service(endpoint.into_worker_server()) .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000)) .await?; @@ -47,7 +47,7 @@ async fn main() { let endpoint = Worker::from_session_builder(build_sate); Server::builder() - .add_service(endpoint.into_flight_server()) + .add_service(endpoint.into_worker_server()) .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000)) .await?; @@ -88,10 +88,10 @@ async fn main() { let endpoint = Worker::default(); Server::builder() - .add_service(endpoint.into_flight_server()) + .add_service(endpoint.into_worker_server()) .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080)) .await?; } ``` -The `into_flight_server()` method builds a `FlightServiceServer` ready to be added as a Tonic service. +The `into_worker_server()` method builds a `WorkerServiceServer` ready to be added as a Tonic service. diff --git a/examples/in_memory_cluster.rs b/examples/in_memory_cluster.rs index e5f02fdc..6ae97e28 100644 --- a/examples/in_memory_cluster.rs +++ b/examples/in_memory_cluster.rs @@ -1,12 +1,12 @@ use arrow::util::pretty::pretty_format_batches; -use arrow_flight::flight_service_client::FlightServiceClient; use async_trait::async_trait; use datafusion::common::DataFusionError; use datafusion::execution::SessionStateBuilder; use datafusion::prelude::{ParquetReadOptions, SessionContext}; use datafusion_distributed::{ BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedPhysicalOptimizerRule, Worker, - WorkerQueryContext, WorkerResolver, create_flight_client, display_plan_ascii, + WorkerQueryContext, WorkerResolver, WorkerServiceClient, create_worker_client, + display_plan_ascii, }; use futures::TryStreamExt; use hyper_util::rt::TokioIo; @@ -66,7 +66,7 @@ const DUMMY_URL: &str = "http://localhost:50051"; /// tokio duplex rather than a TCP connection. #[derive(Clone)] struct InMemoryChannelResolver { - channel: FlightServiceClient, + channel: WorkerServiceClient, } impl InMemoryChannelResolver { @@ -84,7 +84,7 @@ impl InMemoryChannelResolver { })); let this = Self { - channel: create_flight_client(BoxCloneSyncChannel::new(channel)), + channel: create_worker_client(BoxCloneSyncChannel::new(channel)), }; let this_clone = this.clone(); @@ -95,7 +95,7 @@ impl InMemoryChannelResolver { tokio::spawn(async move { Server::builder() - .add_service(endpoint.into_flight_server()) + .add_service(endpoint.into_worker_server()) .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) .await }); @@ -106,10 +106,10 @@ impl InMemoryChannelResolver { #[async_trait] impl ChannelResolver for InMemoryChannelResolver { - async fn get_flight_client_for_url( + async fn get_worker_client_for_url( &self, _: &url::Url, - ) -> Result, DataFusionError> { + ) -> Result, DataFusionError> { Ok(self.channel.clone()) } } diff --git a/examples/localhost_worker.rs b/examples/localhost_worker.rs index 58d329e0..2ed0d6d3 100644 --- a/examples/localhost_worker.rs +++ b/examples/localhost_worker.rs @@ -16,7 +16,7 @@ async fn main() -> Result<(), Box> { let args = Args::from_args(); Server::builder() - .add_service(Worker::default().into_flight_server()) + .add_service(Worker::default().into_worker_server()) .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), args.port)) .await?; diff --git a/src/distributed_ext.rs b/src/distributed_ext.rs index f9a997a4..7c957926 100644 --- a/src/distributed_ext.rs +++ b/src/distributed_ext.rs @@ -191,7 +191,6 @@ pub trait DistributedExt: Sized { /// Example: /// /// ``` - /// # use arrow_flight::flight_service_client::FlightServiceClient; /// # use async_trait::async_trait; /// # use datafusion::common::DataFusionError; /// # use datafusion::execution::{SessionState, SessionStateBuilder}; @@ -228,29 +227,28 @@ pub trait DistributedExt: Sized { resolver: T, ); - /// This is what tells Distributed DataFusion how to build an Arrow Flight client out of a worker URL. + /// This is what tells Distributed DataFusion how to build a Worker gRPC client out of a worker URL. /// - /// There's a default implementation that caches the Arrow Flight client instances so that there's + /// There's a default implementation that caches the Worker client instances so that there's /// only one per URL, but users can decide to override that behavior in favor of their own solution. /// /// Example: /// /// ``` - /// # use arrow_flight::flight_service_client::FlightServiceClient; /// # use async_trait::async_trait; /// # use datafusion::common::DataFusionError; /// # use datafusion::execution::{SessionState, SessionStateBuilder}; /// # use datafusion::prelude::SessionConfig; /// # use url::Url; /// # use std::sync::Arc; - /// # use datafusion_distributed::{BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedPhysicalOptimizerRule, WorkerQueryContext}; + /// # use datafusion_distributed::{BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedPhysicalOptimizerRule, WorkerQueryContext, WorkerServiceClient}; /// /// struct CustomChannelResolver; /// /// #[async_trait] /// impl ChannelResolver for CustomChannelResolver { - /// async fn get_flight_client_for_url(&self, url: &Url) -> Result, DataFusionError> { - /// // Build a custom FlightServiceClient wrapped with tower layers or something similar. + /// async fn get_worker_client_for_url(&self, url: &Url) -> Result, DataFusionError> { + /// // Build a custom WorkerServiceClient wrapped with tower layers or something similar. /// todo!() /// } /// } diff --git a/src/execution_plans/benchmarks/shuffle_bench.rs b/src/execution_plans/benchmarks/shuffle_bench.rs index a7d5cf83..4e630533 100644 --- a/src/execution_plans/benchmarks/shuffle_bench.rs +++ b/src/execution_plans/benchmarks/shuffle_bench.rs @@ -1,16 +1,16 @@ use crate::common::task_ctx_with_extension; -use crate::flight_service::WorkerConnectionPool; -use crate::flight_service::test_utils::memory_worker::MemoryWorker; +use crate::worker::WorkerConnectionPool; +use crate::worker::generated::worker::worker_service_client::WorkerServiceClient; +use crate::worker::test_utils::memory_worker::MemoryWorker; use crate::{ BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedTaskContext, ExecutionTask, - NetworkShuffleExec, Stage, create_flight_client, + NetworkShuffleExec, Stage, create_worker_client, }; use arrow::datatypes::DataType::{ Boolean, Dictionary, Float64, Int32, Int64, List, Timestamp, UInt8, Utf8, }; use arrow::datatypes::{Field, Schema, TimeUnit}; use arrow::util::data_gen::create_random_batch; -use arrow_flight::flight_service_client::FlightServiceClient; use arrow_ipc::CompressionType; use datafusion::common::{Result, exec_err}; use datafusion::execution::SessionStateBuilder; @@ -33,14 +33,14 @@ pub struct InMemoryChannelsResolver { #[async_trait::async_trait] impl ChannelResolver for InMemoryChannelsResolver { - async fn get_flight_client_for_url( + async fn get_worker_client_for_url( &self, url: &Url, - ) -> Result> { + ) -> Result> { let Some(port) = url.port() else { return exec_err!("Missing port in url {url}"); }; - Ok(create_flight_client(self.channels[port as usize].clone())) + Ok(create_worker_client(self.channels[port as usize].clone())) } } diff --git a/src/execution_plans/distributed.rs b/src/execution_plans/distributed.rs index 56752166..ff66d7d2 100644 --- a/src/execution_plans/distributed.rs +++ b/src/execution_plans/distributed.rs @@ -1,17 +1,14 @@ use crate::common::require_one_child; use crate::config_extension_ext::get_config_extension_propagation_headers; use crate::distributed_planner::NetworkBoundaryExt; -use crate::flight_service::{INIT_ACTION_TYPE, InitAction}; use crate::networking::get_distributed_worker_resolver; use crate::passthrough_headers::get_passthrough_headers; use crate::protobuf::{DistributedCodec, tonic_status_to_datafusion_error}; use crate::stage::{ExecutionTask, Stage}; +use crate::worker::generated::worker::{SetPlanRequest, StageKey}; use crate::{ - ChannelResolver, DISTRIBUTED_DATAFUSION_TASK_ID_LABEL, StageKey, WorkerResolver, - get_distributed_channel_resolver, + DISTRIBUTED_DATAFUSION_TASK_ID_LABEL, WorkerResolver, get_distributed_channel_resolver, }; -use arrow_flight::Action; -use bytes::Bytes; use datafusion::common::instant::Instant; use datafusion::common::runtime::JoinSet; use datafusion::common::tree_node::{Transformed, TreeNode}; @@ -124,10 +121,8 @@ impl DistributedExec { // This assumes the plan is the same for all the tasks within a stage. This is fine for // now, but it should be possible to send different versions of the subplan to the // different tasks. - let bytes: Bytes = - PhysicalPlanNode::try_from_physical_plan(Arc::clone(input_plan), &codec)? - .encode_to_vec() - .into(); + let bytes = PhysicalPlanNode::try_from_physical_plan(Arc::clone(input_plan), &codec)? + .encode_to_vec(); let tasks = stage .tasks @@ -138,10 +133,10 @@ impl DistributedExec { let execution_task = ExecutionTask { url: Some(url.clone()), }; - let action = InitAction { + let request = SetPlanRequest { plan_proto: bytes.clone(), stage_key: Some(StageKey { - query_id: stage.query_id.as_bytes().to_vec().into(), + query_id: stage.query_id.as_bytes().to_vec(), stage_id: stage.num as _, task_number: i as _, }), @@ -152,7 +147,7 @@ impl DistributedExec { // Spawns the task that feeds this subplan to this worker. There will be as // many as this spawned tasks as workers. join_set.spawn(async move { - send_plan_task(ctx, url, action).await?; + send_plan_task(ctx, url, request).await?; plan_send_latency.record(&start); Ok(()) }); @@ -258,24 +253,19 @@ impl ExecutionPlan for DistributedExec { } } -async fn send_plan_task(ctx: Arc, url: Url, init_action: InitAction) -> Result<()> { +async fn send_plan_task(ctx: Arc, url: Url, request: SetPlanRequest) -> Result<()> { let channel_resolver = get_distributed_channel_resolver(ctx.as_ref()); - let mut client = channel_resolver.get_flight_client_for_url(&url).await?; - - let body = init_action.encode_to_vec().into(); + let mut client = channel_resolver.get_worker_client_for_url(&url).await?; let mut headers = get_config_extension_propagation_headers(ctx.session_config())?; headers.extend(get_passthrough_headers(ctx.session_config())); let request = Request::from_parts( MetadataMap::from_headers(headers), Extensions::default(), - Action { - r#type: INIT_ACTION_TYPE.to_string(), - body, - }, + request, ); - client.do_action(request).await.map_err(|e| { + client.set_plan(request).await.map_err(|e| { tonic_status_to_datafusion_error(&e) .unwrap_or_else(|| exec_datafusion_err!("Error sending plan to worker {url}: {e}")) })?; diff --git a/src/execution_plans/network_broadcast.rs b/src/execution_plans/network_broadcast.rs index 13511635..0d0243b6 100644 --- a/src/execution_plans/network_broadcast.rs +++ b/src/execution_plans/network_broadcast.rs @@ -1,10 +1,11 @@ use crate::DistributedTaskContext; use crate::common::require_one_child; use crate::distributed_planner::NetworkBoundary; -use crate::flight_service::WorkerConnectionPool; use crate::metrics::proto::MetricsSetProto; -use crate::protobuf::{AppMetadata, StageKey}; +use crate::protobuf::AppMetadata; use crate::stage::Stage; +use crate::worker::WorkerConnectionPool; +use crate::worker::generated::worker::StageKey; use dashmap::DashMap; use datafusion::common::internal_datafusion_err; use datafusion::error::DataFusionError; diff --git a/src/execution_plans/network_coalesce.rs b/src/execution_plans/network_coalesce.rs index 96fcccd9..19908e33 100644 --- a/src/execution_plans/network_coalesce.rs +++ b/src/execution_plans/network_coalesce.rs @@ -1,10 +1,11 @@ use crate::common::require_one_child; use crate::distributed_planner::NetworkBoundary; use crate::execution_plans::common::scale_partitioning_props; -use crate::flight_service::WorkerConnectionPool; use crate::metrics::proto::MetricsSetProto; -use crate::protobuf::{AppMetadata, StageKey}; +use crate::protobuf::AppMetadata; use crate::stage::Stage; +use crate::worker::WorkerConnectionPool; +use crate::worker::generated::worker::StageKey; use crate::{DistributedTaskContext, ExecutionTask}; use dashmap::DashMap; use datafusion::common::{exec_err, plan_err}; diff --git a/src/execution_plans/network_shuffle.rs b/src/execution_plans/network_shuffle.rs index 3bf515a4..2ebe6634 100644 --- a/src/execution_plans/network_shuffle.rs +++ b/src/execution_plans/network_shuffle.rs @@ -1,9 +1,10 @@ use crate::common::require_one_child; use crate::execution_plans::common::scale_partitioning; -use crate::flight_service::WorkerConnectionPool; use crate::metrics::proto::MetricsSetProto; -use crate::protobuf::{AppMetadata, StageKey}; +use crate::protobuf::AppMetadata; use crate::stage::Stage; +use crate::worker::WorkerConnectionPool; +use crate::worker::generated::worker::StageKey; use crate::{DistributedTaskContext, ExecutionTask, NetworkBoundary}; use dashmap::DashMap; use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; diff --git a/src/lib.rs b/src/lib.rs index f56f0384..dc8ed777 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,10 +4,10 @@ mod common; mod config_extension_ext; mod distributed_ext; mod execution_plans; -mod flight_service; mod metrics; mod passthrough_headers; mod stage; +mod worker; mod distributed_planner; mod networking; @@ -26,10 +26,6 @@ pub use execution_plans::{ BroadcastExec, DistributedExec, NetworkBroadcastExec, NetworkCoalesceExec, NetworkShuffleExec, PartitionIsolatorExec, }; -pub use flight_service::{ - DefaultSessionBuilder, MappedWorkerSessionBuilder, MappedWorkerSessionBuilderExt, TaskData, - Worker, WorkerQueryContext, WorkerSessionBuilder, -}; pub use metrics::{ AvgLatencyMetric, BytesCounterMetric, BytesMetricExt, DISTRIBUTED_DATAFUSION_TASK_ID_LABEL, DistributedMetricsFormat, FirstLatencyMetric, LatencyMetricExt, MaxLatencyMetric, @@ -38,20 +34,23 @@ pub use metrics::{ }; pub use networking::{ BoxCloneSyncChannel, ChannelResolver, DefaultChannelResolver, WorkerResolver, - create_flight_client, get_distributed_channel_resolver, get_distributed_worker_resolver, + create_worker_client, get_distributed_channel_resolver, get_distributed_worker_resolver, }; +pub use worker::generated::worker::worker_service_client::WorkerServiceClient; pub use stage::{ DistributedTaskContext, ExecutionTask, Stage, display_plan_ascii, display_plan_graphviz, explain_analyze, }; +pub use worker::{ + DefaultSessionBuilder, MappedWorkerSessionBuilder, MappedWorkerSessionBuilderExt, TaskData, + Worker, WorkerQueryContext, WorkerSessionBuilder, +}; pub use observability::{ GetTaskProgressRequest, GetTaskProgressResponse, ObservabilityService, ObservabilityServiceClient, ObservabilityServiceImpl, ObservabilityServiceServer, PingRequest, - PingResponse, StageKey as ObservabilityStageKey, TaskProgress, TaskStatus, WorkerMetrics, + PingResponse, TaskProgress, TaskStatus, WorkerMetrics, }; -pub use protobuf::StageKey; - #[cfg(any(feature = "integration", test))] pub use execution_plans::benchmarks::ShuffleBench; diff --git a/src/metrics/task_metrics_collector.rs b/src/metrics/task_metrics_collector.rs index a566f809..5517bfce 100644 --- a/src/metrics/task_metrics_collector.rs +++ b/src/metrics/task_metrics_collector.rs @@ -2,7 +2,7 @@ use crate::NetworkBroadcastExec; use crate::execution_plans::NetworkCoalesceExec; use crate::execution_plans::NetworkShuffleExec; use crate::metrics::proto::MetricsSetProto; -use crate::protobuf::StageKey; +use crate::worker::generated::worker::StageKey; use datafusion::common::HashMap; use datafusion::common::tree_node::Transformed; use datafusion::common::tree_node::TreeNode; diff --git a/src/metrics/task_metrics_rewriter.rs b/src/metrics/task_metrics_rewriter.rs index 9c5e8b28..f06d7dc2 100644 --- a/src/metrics/task_metrics_rewriter.rs +++ b/src/metrics/task_metrics_rewriter.rs @@ -5,9 +5,8 @@ use crate::metrics::DISTRIBUTED_DATAFUSION_TASK_ID_LABEL; use crate::metrics::MetricsCollectorResult; use crate::metrics::TaskMetricsCollector; use crate::metrics::proto::{MetricsSetProto, metrics_set_proto_to_df}; -use crate::protobuf::StageKey; use crate::stage::Stage; -use bytes::Bytes; +use crate::worker::generated::worker::StageKey; use datafusion::common::HashMap; use datafusion::common::tree_node::Transformed; use datafusion::common::tree_node::TreeNode; @@ -217,7 +216,11 @@ pub fn stage_metrics_rewriter( let mut stage_metrics = MetricsSet::new(); for task_id in 0..stage.tasks.len() { - let stage_key = StageKey::new(Bytes::from(stage.query_id.as_bytes().to_vec()), stage.num as u64, task_id as u64); + let stage_key = StageKey { + query_id: stage.query_id.as_bytes().to_vec(), + stage_id: stage.num as u64, + task_number: task_id as u64, + }; match metrics_collection.get(&stage_key) { Some(task_metrics) => { if node_idx >= task_metrics.len() { @@ -275,7 +278,6 @@ mod tests { annotate_metrics_set_with_task_id, stage_metrics_rewriter, }; use crate::metrics::{DistributedMetricsFormat, rewrite_distributed_plan_with_metrics}; - use crate::protobuf::StageKey; use crate::test_utils::in_memory_channel_resolver::{ InMemoryChannelResolver, InMemoryWorkerResolver, }; @@ -283,7 +285,6 @@ mod tests { use crate::test_utils::plans::count_plan_nodes_up_to_network_boundary; use crate::test_utils::session_context::register_temp_parquet_table; use crate::{DistributedExec, DistributedPhysicalOptimizerRule}; - use bytes::Bytes; use datafusion::arrow::array::{Int32Array, StringArray}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::arrow::record_batch::RecordBatch; @@ -298,6 +299,7 @@ mod tests { use crate::DistributedExt; use crate::metrics::task_metrics_rewriter::MetricsWrapperExec; + use crate::worker::generated::worker::StageKey; use datafusion::physical_plan::empty::EmptyExec; use datafusion::prelude::SessionConfig; use datafusion::prelude::SessionContext; @@ -434,11 +436,11 @@ mod tests { // Generate metrics for each task and store them in the map. let mut metrics_collection = HashMap::new(); for task_id in 0..stage.tasks.len() { - let stage_key = StageKey::new( - Bytes::from(stage.query_id.as_bytes().to_vec()), - stage.num as u64, - task_id as u64, - ); + let stage_key = StageKey { + query_id: stage.query_id.as_bytes().to_vec(), + stage_id: stage.num as u64, + task_number: task_id as u64, + }; let metrics = (0..count_plan_nodes_up_to_network_boundary(&plan)) .map(|node_id| { make_test_metrics_set_proto_from_seed( @@ -475,11 +477,11 @@ mod tests { .enumerate() { let expected_task_node_metrics = metrics_collection - .get(&StageKey::new( - Bytes::from(stage.query_id.as_bytes().to_vec()), - stage.num as u64, - task_id as u64, - )) + .get(&StageKey { + query_id: stage.query_id.as_bytes().to_vec(), + stage_id: stage.num as u64, + task_number: task_id as u64, + }) .unwrap()[node_id] .clone(); diff --git a/src/networking/channel_resolver.rs b/src/networking/channel_resolver.rs index 5e44546b..50b5fd93 100644 --- a/src/networking/channel_resolver.rs +++ b/src/networking/channel_resolver.rs @@ -1,6 +1,6 @@ use crate::DistributedConfig; use crate::config_extension_ext::set_distributed_option_extension; -use arrow_flight::flight_service_client::FlightServiceClient; +use crate::worker::generated::worker::worker_service_client::WorkerServiceClient; use async_trait::async_trait; use datafusion::common::{DataFusionError, config_datafusion_err, exec_datafusion_err}; use datafusion::execution::TaskContext; @@ -15,36 +15,35 @@ use tonic::transport::Channel; use tower::ServiceExt; use url::Url; -/// Allows users to customize the way Arrow Flight clients are created. A common use case is to +/// Allows users to customize the way Worker clients are created. A common use case is to /// wrap the client with tower layers or schedule it in an IO-specific tokio runtime. /// /// There is a default implementation of this trait that should be enough for the most common /// use-cases. /// /// # Implementation Notes -/// - This is called per Arrow Flight request, so implementors of this trait should make sure that -/// clients are reused across method calls instead of building a new Arrow Flight client -/// every time. +/// - This is called per gRPC request, so implementors of this trait should make sure that +/// clients are reused across method calls instead of building a new Worker client every time. /// -/// - When implementing `get_flight_client_for_url`, it is recommended to use the -/// [`create_flight_client`] helper function to ensure clients are configured with +/// - When implementing `get_worker_client_for_url`, it is recommended to use the +/// [`create_worker_client`] helper function to ensure clients are configured with /// appropriate message size limits for internal communication. This helps avoid message /// size errors when transferring large datasets. #[async_trait] pub trait ChannelResolver { - /// For a given URL, get an Arrow Flight client for communicating to it. + /// For a given URL, get a Worker gRPC client for communicating to it. /// - /// *WARNING*: This method is called for every Arrow Flight gRPC request, so to not create + /// *WARNING*: This method is called for every gRPC request, so to not create /// one client connection for each request, users are required to reuse generated clients. /// It's recommended to rely on [DefaultChannelResolver] either by delegating method calls /// to it or by copying the implementation. /// - /// Consider using [`create_flight_client`] to create the client with appropriate + /// Consider using [`create_worker_client`] to create the client with appropriate /// default message size limits. - async fn get_flight_client_for_url( + async fn get_worker_client_for_url( &self, url: &Url, - ) -> Result, DataFusionError>; + ) -> Result, DataFusionError>; } pub(crate) fn set_distributed_channel_resolver( @@ -109,7 +108,7 @@ pub(crate) struct ChannelResolverExtension(Option>, @@ -173,56 +172,55 @@ impl DefaultChannelResolver { #[async_trait] impl ChannelResolver for DefaultChannelResolver { - async fn get_flight_client_for_url( + async fn get_worker_client_for_url( &self, url: &Url, - ) -> Result, DataFusionError> { - self.get_channel(url).await.map(create_flight_client) + ) -> Result, DataFusionError> { + self.get_channel(url).await.map(create_worker_client) } } #[async_trait] impl ChannelResolver for Arc { - async fn get_flight_client_for_url( + async fn get_worker_client_for_url( &self, url: &Url, - ) -> Result, DataFusionError> { - self.as_ref().get_flight_client_for_url(url).await + ) -> Result, DataFusionError> { + self.as_ref().get_worker_client_for_url(url).await } } -/// Creates a [`FlightServiceClient`] with high default message size limits. +/// Creates a [`WorkerServiceClient`] with high default message size limits. /// -/// This is a convenience function that wraps [`FlightServiceClient::new`] and configures +/// This is a convenience function that wraps [`WorkerServiceClient::new`] and configures /// it with `max_decoding_message_size(usize::MAX)` and `max_encoding_message_size(usize::MAX)` /// to avoid message size limitations for internal communication. /// /// Users implementing custom [`ChannelResolver`]s should use this function in their -/// `get_flight_client_for_url` implementations to ensure consistent behavior with built-in +/// `get_worker_client_for_url` implementations to ensure consistent behavior with built-in /// implementations. /// /// # Example /// /// ```rust,ignore -/// use datafusion_distributed::{create_flight_client, BoxCloneSyncChannel, ChannelResolver}; -/// use arrow_flight::flight_service_client::FlightServiceClient; -/// use tonic::transport::Channel; +/// use datafusion_distributed::{create_worker_client, BoxCloneSyncChannel, ChannelResolver}; +/// /// use tonic::transport::Channel; /// /// #[async_trait] /// impl ChannelResolver for MyResolver { -/// async fn get_flight_client_for_url( +/// async fn get_worker_client_for_url( /// &self, /// url: &Url, -/// ) -> Result, DataFusionError> { +/// ) -> Result, DataFusionError> { /// let channel = Channel::from_shared(url.to_string())?.connect().await?; -/// Ok(create_flight_client(BoxCloneSyncChannel::new(channel))) +/// Ok(create_worker_client(BoxCloneSyncChannel::new(channel))) /// } /// } /// ``` -pub fn create_flight_client( +pub fn create_worker_client( channel: BoxCloneSyncChannel, -) -> FlightServiceClient { - FlightServiceClient::new(channel) +) -> WorkerServiceClient { + WorkerServiceClient::new(channel) .max_decoding_message_size(usize::MAX) .max_encoding_message_size(usize::MAX) } @@ -285,7 +283,7 @@ mod tests { let worker = Worker::default(); let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener); if let Err(err) = Server::builder() - .add_service(worker.into_flight_server()) + .add_service(worker.into_worker_server()) .serve_with_incoming(incoming) .await { diff --git a/src/networking/mod.rs b/src/networking/mod.rs index 3da47b0d..6bab9ae0 100644 --- a/src/networking/mod.rs +++ b/src/networking/mod.rs @@ -2,7 +2,7 @@ mod channel_resolver; mod worker_resolver; pub use channel_resolver::{ - BoxCloneSyncChannel, ChannelResolver, DefaultChannelResolver, create_flight_client, + BoxCloneSyncChannel, ChannelResolver, DefaultChannelResolver, create_worker_client, get_distributed_channel_resolver, }; pub(crate) use channel_resolver::{ChannelResolverExtension, set_distributed_channel_resolver}; diff --git a/src/observability/gen/src/main.rs b/src/observability/gen/src/main.rs index 93aef11c..37488f89 100644 --- a/src/observability/gen/src/main.rs +++ b/src/observability/gen/src/main.rs @@ -19,6 +19,10 @@ fn main() -> Result<(), Box> { .build_server(true) .build_client(true) .out_dir(&out_dir) + .extern_path( + ".observability.StageKey", + "crate::worker::generated::worker::StageKey", + ) .compile_protos(&[proto_file], &[proto_dir])?; println!("Successfully generated observability proto code"); diff --git a/src/observability/generated/observability.rs b/src/observability/generated/observability.rs index a032eb07..b7a07f0e 100644 --- a/src/observability/generated/observability.rs +++ b/src/observability/generated/observability.rs @@ -8,20 +8,11 @@ pub struct PingResponse { } #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct GetTaskProgressRequest {} -#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] -pub struct StageKey { - #[prost(bytes = "vec", tag = "1")] - pub query_id: ::prost::alloc::vec::Vec, - #[prost(uint64, tag = "2")] - pub stage_id: u64, - #[prost(uint64, tag = "3")] - pub task_number: u64, -} /// Progress information for a single task #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct TaskProgress { #[prost(message, optional, tag = "1")] - pub stage_key: ::core::option::Option, + pub stage_key: ::core::option::Option, #[prost(uint64, tag = "2")] pub total_partitions: u64, #[prost(uint64, tag = "3")] @@ -79,10 +70,10 @@ pub mod observability_service_client { dead_code, missing_docs, clippy::wildcard_imports, - clippy::let_unit_value + clippy::let_unit_value, )] - use tonic::codegen::http::Uri; use tonic::codegen::*; + use tonic::codegen::http::Uri; #[derive(Debug, Clone)] pub struct ObservabilityServiceClient { inner: tonic::client::Grpc, @@ -121,13 +112,14 @@ pub mod observability_service_client { F: tonic::service::Interceptor, T::ResponseBody: Default, T: tonic::codegen::Service< - http::Request, - Response = http::Response< - >::ResponseBody, - >, + http::Request, + Response = http::Response< + >::ResponseBody, >, - >>::Error: - Into + std::marker::Send + std::marker::Sync, + >, + , + >>::Error: Into + std::marker::Send + std::marker::Sync, { ObservabilityServiceClient::new(InterceptedService::new(inner, interceptor)) } @@ -166,36 +158,50 @@ pub mod observability_service_client { &mut self, request: impl tonic::IntoRequest, ) -> std::result::Result, tonic::Status> { - self.inner.ready().await.map_err(|e| { - tonic::Status::unknown(format!("Service was not ready: {}", e.into())) - })?; + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::unknown( + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic_prost::ProstCodec::default(); - let path = - http::uri::PathAndQuery::from_static("/observability.ObservabilityService/Ping"); + let path = http::uri::PathAndQuery::from_static( + "/observability.ObservabilityService/Ping", + ); let mut req = request.into_request(); - req.extensions_mut().insert(GrpcMethod::new( - "observability.ObservabilityService", - "Ping", - )); + req.extensions_mut() + .insert(GrpcMethod::new("observability.ObservabilityService", "Ping")); self.inner.unary(req, path, codec).await } pub async fn get_task_progress( &mut self, request: impl tonic::IntoRequest, - ) -> std::result::Result, tonic::Status> - { - self.inner.ready().await.map_err(|e| { - tonic::Status::unknown(format!("Service was not ready: {}", e.into())) - })?; + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::unknown( + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic_prost::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/observability.ObservabilityService/GetTaskProgress", ); let mut req = request.into_request(); - req.extensions_mut().insert(GrpcMethod::new( - "observability.ObservabilityService", - "GetTaskProgress", - )); + req.extensions_mut() + .insert( + GrpcMethod::new( + "observability.ObservabilityService", + "GetTaskProgress", + ), + ); self.inner.unary(req, path, codec).await } } @@ -207,7 +213,7 @@ pub mod observability_service_server { dead_code, missing_docs, clippy::wildcard_imports, - clippy::let_unit_value + clippy::let_unit_value, )] use tonic::codegen::*; /// Generated trait containing gRPC methods that should be implemented for use with ObservabilityServiceServer. @@ -220,7 +226,10 @@ pub mod observability_service_server { async fn get_task_progress( &self, request: tonic::Request, - ) -> std::result::Result, tonic::Status>; + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; } #[derive(Debug)] pub struct ObservabilityServiceServer { @@ -243,7 +252,10 @@ pub mod observability_service_server { max_encoding_message_size: None, } } - pub fn with_interceptor(inner: T, interceptor: F) -> InterceptedService + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> InterceptedService where F: tonic::service::Interceptor, { @@ -278,7 +290,8 @@ pub mod observability_service_server { self } } - impl tonic::codegen::Service> for ObservabilityServiceServer + impl tonic::codegen::Service> + for ObservabilityServiceServer where T: ObservabilityService, B: Body + std::marker::Send + 'static, @@ -298,9 +311,14 @@ pub mod observability_service_server { "/observability.ObservabilityService/Ping" => { #[allow(non_camel_case_types)] struct PingSvc(pub Arc); - impl tonic::server::UnaryService for PingSvc { + impl< + T: ObservabilityService, + > tonic::server::UnaryService for PingSvc { type Response = super::PingResponse; - type Future = BoxFuture, tonic::Status>; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; fn call( &mut self, request: tonic::Request, @@ -337,19 +355,25 @@ pub mod observability_service_server { "/observability.ObservabilityService/GetTaskProgress" => { #[allow(non_camel_case_types)] struct GetTaskProgressSvc(pub Arc); - impl - tonic::server::UnaryService - for GetTaskProgressSvc - { + impl< + T: ObservabilityService, + > tonic::server::UnaryService + for GetTaskProgressSvc { type Response = super::GetTaskProgressResponse; - type Future = BoxFuture, tonic::Status>; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; fn call( &mut self, request: tonic::Request, ) -> Self::Future { let inner = Arc::clone(&self.0); let fut = async move { - ::get_task_progress(&inner, request) + ::get_task_progress( + &inner, + request, + ) .await }; Box::pin(fut) @@ -377,19 +401,25 @@ pub mod observability_service_server { }; Box::pin(fut) } - _ => Box::pin(async move { - let mut response = http::Response::new(tonic::body::Body::default()); - let headers = response.headers_mut(); - headers.insert( - tonic::Status::GRPC_STATUS, - (tonic::Code::Unimplemented as i32).into(), - ); - headers.insert( - http::header::CONTENT_TYPE, - tonic::metadata::GRPC_CONTENT_TYPE, - ); - Ok(response) - }), + _ => { + Box::pin(async move { + let mut response = http::Response::new( + tonic::body::Body::default(), + ); + let headers = response.headers_mut(); + headers + .insert( + tonic::Status::GRPC_STATUS, + (tonic::Code::Unimplemented as i32).into(), + ); + headers + .insert( + http::header::CONTENT_TYPE, + tonic::metadata::GRPC_CONTENT_TYPE, + ); + Ok(response) + }) + } } } } diff --git a/src/observability/mod.rs b/src/observability/mod.rs index 6c07b7b5..20eb775a 100644 --- a/src/observability/mod.rs +++ b/src/observability/mod.rs @@ -7,7 +7,7 @@ pub use generated::observability::observability_service_server::{ }; pub use generated::observability::{ - GetTaskProgressRequest, GetTaskProgressResponse, PingRequest, PingResponse, StageKey, - TaskProgress, TaskStatus, WorkerMetrics, + GetTaskProgressRequest, GetTaskProgressResponse, PingRequest, PingResponse, TaskProgress, + TaskStatus, WorkerMetrics, }; pub use service::ObservabilityServiceImpl; diff --git a/src/observability/service.rs b/src/observability/service.rs index 6a2834ad..a191b783 100644 --- a/src/observability/service.rs +++ b/src/observability/service.rs @@ -1,5 +1,9 @@ -use crate::flight_service::{SingleWriteMultiRead, TaskData}; -use crate::protobuf::StageKey; +use super::{ + GetTaskProgressResponse, ObservabilityService, TaskProgress, TaskStatus, WorkerMetrics, + generated::observability::{GetTaskProgressRequest, PingRequest, PingResponse}, +}; +use crate::worker::generated::worker::StageKey; +use crate::worker::{SingleWriteMultiRead, TaskData}; use datafusion::error::DataFusionError; use datafusion::physical_plan::ExecutionPlan; use moka::future::Cache; @@ -14,11 +18,6 @@ use sysinfo::{Pid, ProcessRefreshKind}; use tokio::sync::watch; use tonic::{Request, Response, Status}; -use super::{ - GetTaskProgressResponse, ObservabilityService, TaskProgress, TaskStatus, WorkerMetrics, - generated::observability::{GetTaskProgressRequest, PingRequest, PingResponse}, -}; - type ResultTaskData = Result>; pub struct ObservabilityServiceImpl { @@ -100,7 +99,7 @@ impl ObservabilityService for ObservabilityServiceImpl { let output_rows = output_rows_from_plan(&task_data.plan); tasks.push(TaskProgress { - stage_key: Some(convert_stage_key(&internal_key)), + stage_key: Some((*internal_key).clone()), total_partitions, completed_partitions, status: TaskStatus::Running as i32, @@ -130,15 +129,6 @@ impl ObservabilityServiceImpl { } } -/// Converts internal StageKey to observability proto StageKey -fn convert_stage_key(key: &StageKey) -> super::StageKey { - super::StageKey { - query_id: key.query_id.to_vec(), - stage_id: key.stage_id, - task_number: key.task_number, - } -} - /// Extracts output rows from the root plan node's metrics. fn output_rows_from_plan(plan: &Arc) -> u64 { plan.metrics().and_then(|m| m.output_rows()).unwrap_or(0) as u64 diff --git a/src/protobuf/app_metadata.rs b/src/protobuf/app_metadata.rs index 07d793dd..82077265 100644 --- a/src/protobuf/app_metadata.rs +++ b/src/protobuf/app_metadata.rs @@ -1,5 +1,5 @@ use crate::metrics::proto::MetricsSetProto; -use crate::protobuf::distributed_codec::StageKey; +use crate::worker::generated::worker::StageKey; use std::time::{SystemTime, UNIX_EPOCH}; /// A collection of metrics for a set of tasks in an ExecutionPlan. each diff --git a/src/protobuf/distributed_codec.rs b/src/protobuf/distributed_codec.rs index 05079c66..80fb9eba 100644 --- a/src/protobuf/distributed_codec.rs +++ b/src/protobuf/distributed_codec.rs @@ -2,8 +2,8 @@ use super::get_distributed_user_codecs; use crate::execution_plans::{ BroadcastExec, ChildrenIsolatorUnionExec, NetworkBroadcastExec, NetworkCoalesceExec, }; -use crate::flight_service::WorkerConnectionPool; use crate::stage::{ExecutionTask, Stage}; +use crate::worker::WorkerConnectionPool; use crate::{DistributedTaskContext, NetworkBoundary}; use crate::{NetworkShuffleExec, PartitionIsolatorExec}; use bytes::Bytes; @@ -324,31 +324,6 @@ impl PhysicalExtensionCodec for DistributedCodec { } } -/// A key that uniquely identifies a stage in a query. -#[derive(Clone, Hash, Eq, PartialEq, ::prost::Message)] -pub struct StageKey { - /// Our query id - #[prost(bytes, tag = "1")] - pub query_id: Bytes, - /// Our stage id - #[prost(uint64, tag = "2")] - pub stage_id: u64, - /// The task number within the stage - #[prost(uint64, tag = "3")] - pub task_number: u64, -} - -impl StageKey { - /// Creates a new `StageKey`. - pub fn new(query_id: Bytes, stage_id: u64, task_number: u64) -> StageKey { - Self { - query_id, - stage_id, - task_number, - } - } -} - #[derive(Clone, PartialEq, ::prost::Message)] pub struct StageProto { /// Our query id diff --git a/src/protobuf/mod.rs b/src/protobuf/mod.rs index aeecb8ac..9c908a97 100644 --- a/src/protobuf/mod.rs +++ b/src/protobuf/mod.rs @@ -5,7 +5,6 @@ mod user_codec; pub(crate) use app_metadata::{AppMetadata, FlightAppMetadata, MetricsCollection, TaskMetrics}; pub(crate) use distributed_codec::DistributedCodec; -pub use distributed_codec::StageKey; pub(crate) use errors::{ datafusion_error_to_tonic_status, map_flight_to_datafusion_error, tonic_status_to_datafusion_error, diff --git a/src/test_utils/in_memory_channel_resolver.rs b/src/test_utils/in_memory_channel_resolver.rs index 96baef2d..85638248 100644 --- a/src/test_utils/in_memory_channel_resolver.rs +++ b/src/test_utils/in_memory_channel_resolver.rs @@ -1,9 +1,9 @@ +use crate::worker::generated::worker::worker_service_client::WorkerServiceClient; use crate::{ BoxCloneSyncChannel, ChannelResolver, DefaultSessionBuilder, DistributedExt, MappedWorkerSessionBuilderExt, Worker, WorkerResolver, WorkerSessionBuilder, - create_flight_client, + create_worker_client, }; -use arrow_flight::flight_service_client::FlightServiceClient; use async_trait::async_trait; use datafusion::common::DataFusionError; use hyper_util::rt::TokioIo; @@ -15,7 +15,7 @@ const DUMMY_URL: &str = "http://localhost:50051"; /// tokio duplex rather than a TCP connection. #[derive(Clone)] pub struct InMemoryChannelResolver { - channel: FlightServiceClient, + channel: WorkerServiceClient, } impl InMemoryChannelResolver { @@ -38,7 +38,7 @@ impl InMemoryChannelResolver { })); let this = Self { - channel: create_flight_client(BoxCloneSyncChannel::new(channel)), + channel: create_worker_client(BoxCloneSyncChannel::new(channel)), }; let this_clone = this.clone(); @@ -49,7 +49,7 @@ impl InMemoryChannelResolver { tokio::spawn(async move { Server::builder() - .add_service(endpoint.into_flight_server()) + .add_service(endpoint.into_worker_server()) .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) .await }); @@ -66,10 +66,10 @@ impl Default for InMemoryChannelResolver { #[async_trait] impl ChannelResolver for InMemoryChannelResolver { - async fn get_flight_client_for_url( + async fn get_worker_client_for_url( &self, _: &url::Url, - ) -> Result, DataFusionError> { + ) -> Result, DataFusionError> { Ok(self.channel.clone()) } } diff --git a/src/test_utils/localhost.rs b/src/test_utils/localhost.rs index 09a1ee12..2db20968 100644 --- a/src/test_utils/localhost.rs +++ b/src/test_utils/localhost.rs @@ -57,7 +57,7 @@ where join_set.spawn(async move { Server::builder() - .add_service(worker.into_flight_server()) + .add_service(worker.into_worker_server()) .serve_with_incoming(incoming) .await .unwrap(); @@ -103,7 +103,7 @@ impl WorkerResolver for LocalHostWorkerResolver { } } -pub async fn spawn_flight_service( +pub async fn spawn_worker_service( session_builder: impl WorkerSessionBuilder + Send + Sync + 'static, incoming: TcpListener, ) -> Result<(), Box> { @@ -112,7 +112,7 @@ pub async fn spawn_flight_service( let incoming = tokio_stream::wrappers::TcpListenerStream::new(incoming); Ok(Server::builder() - .add_service(endpoint.into_flight_server()) + .add_service(endpoint.into_worker_server()) .serve_with_incoming(incoming) .await?) } diff --git a/src/test_utils/plans.rs b/src/test_utils/plans.rs index 2291e00c..09f409af 100644 --- a/src/test_utils/plans.rs +++ b/src/test_utils/plans.rs @@ -1,9 +1,10 @@ +use super::parquet::register_parquet_tables; use crate::NetworkBoundaryExt; use crate::distributed_ext::DistributedExt; use crate::execution_plans::DistributedExec; -use crate::protobuf::StageKey; use crate::stage::Stage; use crate::test_utils::in_memory_channel_resolver::InMemoryWorkerResolver; +use crate::worker::generated::worker::StageKey; #[cfg(test)] use crate::{DistributedConfig, TaskEstimation, TaskEstimator}; #[cfg(test)] @@ -18,8 +19,6 @@ use datafusion::{ use itertools::Itertools; use std::sync::Arc; -use super::parquet::register_parquet_tables; - /// count_plan_nodes counts the number of execution plan nodes in a plan using BFS traversal. /// This does NOT traverse child stages, only the execution plan tree within this stage. /// Network boundary nodes are counted but their children (which belong to child stages) are not traversed. @@ -62,11 +61,11 @@ pub fn get_stages_and_stage_keys( // Add each task. for j in 0..stage.tasks.len() { - stage_keys.insert(StageKey::new( - stage.query_id.as_bytes().to_vec().into(), - stage.num as u64, - j as u64, - )); + stage_keys.insert(StageKey { + query_id: stage.query_id.as_bytes().to_vec(), + stage_id: stage.num as u64, + task_number: j as u64, + }); } // Add any child stages diff --git a/src/flight_service/do_get.rs b/src/worker/execute_task.rs similarity index 87% rename from src/flight_service/do_get.rs rename to src/worker/execute_task.rs index fb9c665a..b014f75f 100644 --- a/src/flight_service/do_get.rs +++ b/src/worker/execute_task.rs @@ -1,20 +1,20 @@ use crate::common::{map_last_stream, on_drop_stream, task_ctx_with_extension}; -use crate::flight_service::worker::Worker; use crate::metrics::TaskMetricsCollector; use crate::metrics::proto::df_metrics_set_to_proto; use crate::protobuf::{ - AppMetadata, FlightAppMetadata, MetricsCollection, StageKey, TaskMetrics, + AppMetadata, FlightAppMetadata, MetricsCollection, TaskMetrics, datafusion_error_to_tonic_status, }; +use crate::worker::worker_service::Worker; use crate::{DistributedConfig, DistributedTaskContext}; -use arrow_flight::Ticket; use arrow_flight::encode::{DictionaryHandling, FlightDataEncoder, FlightDataEncoderBuilder}; use arrow_flight::error::FlightError; -use arrow_flight::flight_service_server::FlightService; use arrow_select::dictionary::garbage_collect_any_dictionary; use datafusion::arrow::array::{Array, AsArray, RecordBatch}; -use crate::flight_service::spawn_select_all::spawn_select_all; +use crate::worker::generated::worker::worker_service_server::WorkerService; +use crate::worker::generated::worker::{ExecuteTaskRequest, StageKey}; +use crate::worker::spawn_select_all::spawn_select_all; use datafusion::arrow::ipc::CompressionType; use datafusion::arrow::ipc::writer::IpcWriteOptions; use datafusion::common::exec_datafusion_err; @@ -32,38 +32,14 @@ use tonic::{Request, Response, Status}; const RECORD_BATCH_BUFFER_SIZE: usize = 2; const WAIT_PLAN_TIMEOUT_SECS: u64 = 10; -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct DoGet { - /// The index to the task within the stage that we want to execute - #[prost(uint64, tag = "2")] - pub target_task_index: u64, - #[prost(uint64, tag = "3")] - pub target_task_count: u64, - /// lower bound for the list of partitions to execute (inclusive). - #[prost(uint64, tag = "4")] - pub target_partition_start: u64, - /// upper bound for the list of partitions to execute (exclusive). - #[prost(uint64, tag = "5")] - pub target_partition_end: u64, - /// The stage key that identifies the stage. This is useful to keep - /// outside of the stage proto as it is used to store the stage - /// and we may not need to deserialize the entire stage proto - /// if we already have stored it - #[prost(message, optional, tag = "6")] - pub stage_key: Option, -} - impl Worker { - pub(super) async fn get( + pub(crate) async fn execute_task( &self, - request: Request, - ) -> Result::DoGetStream>, Status> { + request: Request, + ) -> Result::ExecuteTaskStream>, Status> { let body = request.into_inner(); - let doget = DoGet::decode(body.ticket).map_err(|err| { - Status::invalid_argument(format!("Cannot decode DoGet message: {err}")) - })?; - let key = doget.stage_key.ok_or_else(missing("stage_key"))?; + let key = body.stage_key.ok_or_else(missing("stage_key"))?; let entry = self .task_data_entries .get_with(key.clone(), async { Default::default() }) @@ -94,8 +70,8 @@ impl Worker { let task_ctx = Arc::new(task_ctx_with_extension( &task_ctx, DistributedTaskContext { - task_index: doget.target_task_index as usize, - task_count: doget.target_task_count as usize, + task_index: body.target_task_index as usize, + task_count: body.target_task_count as usize, }, )); @@ -104,9 +80,9 @@ impl Worker { // Execute all the requested partitions at once, and collect all the streams so that they // can be merged into a single one at the end of this function. - let n_streams = doget.target_partition_end - doget.target_partition_start; + let n_streams = body.target_partition_end - body.target_partition_start; let mut streams = Vec::with_capacity(n_streams as usize); - for partition in doget.target_partition_start..doget.target_partition_end { + for partition in body.target_partition_start..body.target_partition_end { if partition >= partition_count as u64 { return Err(datafusion_error_to_tonic_status(exec_datafusion_err!( "partition {partition} not available. The head plan {plan_name} of the stage just has {partition_count} partitions" diff --git a/src/worker/gen/Cargo.toml b/src/worker/gen/Cargo.toml new file mode 100644 index 00000000..364b9cf3 --- /dev/null +++ b/src/worker/gen/Cargo.toml @@ -0,0 +1,12 @@ +[workspace] +# Empty workspace table to exclude this crate from parent workspace + +[package] +name = "worker-gen" +version = "0.1.0" +edition = "2024" +description = "Protobuf code generation for Worker service" +publish = false + +[dependencies] +tonic-prost-build = "0.14.2" diff --git a/src/worker/gen/regen.sh b/src/worker/gen/regen.sh new file mode 100755 index 00000000..39e04fdc --- /dev/null +++ b/src/worker/gen/regen.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +repo_root=$(git rev-parse --show-toplevel) +cd "$repo_root" && cargo run --manifest-path src/worker/gen/Cargo.toml diff --git a/src/worker/gen/src/main.rs b/src/worker/gen/src/main.rs new file mode 100644 index 00000000..1c982b17 --- /dev/null +++ b/src/worker/gen/src/main.rs @@ -0,0 +1,29 @@ +use std::env; +use std::fs; + +fn main() -> Result<(), Box> { + let repo_root = env::current_dir()?; + + let proto_dir = repo_root.join("src/worker"); + let proto_file = proto_dir.join("worker.proto"); + let out_dir = repo_root.join("src/worker/generated"); + + fs::create_dir_all(&out_dir)?; + + println!("Generating protobuf code..."); + println!("Proto dir: {proto_dir:?}"); + println!("Proto file: {proto_file:?}"); + println!("Output dir: {out_dir:?}"); + + tonic_prost_build::configure() + .build_server(true) + .build_client(true) + .out_dir(&out_dir) + .extern_path(".worker.FlightData", "::arrow_flight::FlightData") + .extern_path(".worker.FlightDescriptor", "::arrow_flight::FlightDescriptor") + .compile_protos(&[proto_file], &[proto_dir])?; + + println!("Successfully generated observability proto code"); + + Ok(()) +} diff --git a/src/worker/generated/mod.rs b/src/worker/generated/mod.rs new file mode 100644 index 00000000..844c269c --- /dev/null +++ b/src/worker/generated/mod.rs @@ -0,0 +1 @@ +pub(crate) mod worker; diff --git a/src/worker/generated/worker.rs b/src/worker/generated/worker.rs new file mode 100644 index 00000000..57a78615 --- /dev/null +++ b/src/worker/generated/worker.rs @@ -0,0 +1,411 @@ +// This file is @generated by prost-build. +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct SetPlanRequest { + #[prost(bytes = "vec", tag = "1")] + pub plan_proto: ::prost::alloc::vec::Vec, + #[prost(message, optional, tag = "2")] + pub stage_key: ::core::option::Option, +} +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct SetPlanResponse {} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct ExecuteTaskRequest { + #[prost(uint64, tag = "1")] + pub target_task_index: u64, + #[prost(uint64, tag = "2")] + pub target_task_count: u64, + #[prost(uint64, tag = "3")] + pub target_partition_start: u64, + #[prost(uint64, tag = "4")] + pub target_partition_end: u64, + #[prost(message, optional, tag = "5")] + pub stage_key: ::core::option::Option, +} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct StageKey { + #[prost(bytes = "vec", tag = "1")] + pub query_id: ::prost::alloc::vec::Vec, + #[prost(uint64, tag = "2")] + pub stage_id: u64, + #[prost(uint64, tag = "3")] + pub task_number: u64, +} +/// Generated client implementations. +pub mod worker_service_client { + #![allow( + unused_variables, + dead_code, + missing_docs, + clippy::wildcard_imports, + clippy::let_unit_value, + )] + use tonic::codegen::*; + use tonic::codegen::http::Uri; + #[derive(Debug, Clone)] + pub struct WorkerServiceClient { + inner: tonic::client::Grpc, + } + impl WorkerServiceClient { + /// Attempt to create a new client by connecting to a given endpoint. + pub async fn connect(dst: D) -> Result + where + D: TryInto, + D::Error: Into, + { + let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; + Ok(Self::new(conn)) + } + } + impl WorkerServiceClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + std::marker::Send + 'static, + ::Error: Into + std::marker::Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> WorkerServiceClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + , + >>::Error: Into + std::marker::Send + std::marker::Sync, + { + WorkerServiceClient::new(InterceptedService::new(inner, interceptor)) + } + /// Compress requests with the given encoding. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + pub async fn set_plan( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::unknown( + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic_prost::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/worker.WorkerService/SetPlan", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("worker.WorkerService", "SetPlan")); + self.inner.unary(req, path, codec).await + } + pub async fn execute_task( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::unknown( + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic_prost::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/worker.WorkerService/ExecuteTask", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("worker.WorkerService", "ExecuteTask")); + self.inner.server_streaming(req, path, codec).await + } + } +} +/// Generated server implementations. +pub mod worker_service_server { + #![allow( + unused_variables, + dead_code, + missing_docs, + clippy::wildcard_imports, + clippy::let_unit_value, + )] + use tonic::codegen::*; + /// Generated trait containing gRPC methods that should be implemented for use with WorkerServiceServer. + #[async_trait] + pub trait WorkerService: std::marker::Send + std::marker::Sync + 'static { + async fn set_plan( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the ExecuteTask method. + type ExecuteTaskStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result<::arrow_flight::FlightData, tonic::Status>, + > + + std::marker::Send + + 'static; + async fn execute_task( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; + } + #[derive(Debug)] + pub struct WorkerServiceServer { + inner: Arc, + accept_compression_encodings: EnabledCompressionEncodings, + send_compression_encodings: EnabledCompressionEncodings, + max_decoding_message_size: Option, + max_encoding_message_size: Option, + } + impl WorkerServiceServer { + pub fn new(inner: T) -> Self { + Self::from_arc(Arc::new(inner)) + } + pub fn from_arc(inner: Arc) -> Self { + Self { + inner, + accept_compression_encodings: Default::default(), + send_compression_encodings: Default::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, + } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> InterceptedService + where + F: tonic::service::Interceptor, + { + InterceptedService::new(Self::new(inner), interceptor) + } + /// Enable decompressing requests with the given encoding. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.accept_compression_encodings.enable(encoding); + self + } + /// Compress responses with the given encoding, if the client supports it. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.send_compression_encodings.enable(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } + } + impl tonic::codegen::Service> for WorkerServiceServer + where + T: WorkerService, + B: Body + std::marker::Send + 'static, + B::Error: Into + std::marker::Send + 'static, + { + type Response = http::Response; + type Error = std::convert::Infallible; + type Future = BoxFuture; + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: http::Request) -> Self::Future { + match req.uri().path() { + "/worker.WorkerService/SetPlan" => { + #[allow(non_camel_case_types)] + struct SetPlanSvc(pub Arc); + impl< + T: WorkerService, + > tonic::server::UnaryService + for SetPlanSvc { + type Response = super::SetPlanResponse; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::set_plan(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = SetPlanSvc(inner); + let codec = tonic_prost::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/worker.WorkerService/ExecuteTask" => { + #[allow(non_camel_case_types)] + struct ExecuteTaskSvc(pub Arc); + impl< + T: WorkerService, + > tonic::server::ServerStreamingService + for ExecuteTaskSvc { + type Response = ::arrow_flight::FlightData; + type ResponseStream = T::ExecuteTaskStream; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::execute_task(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = ExecuteTaskSvc(inner); + let codec = tonic_prost::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.server_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + _ => { + Box::pin(async move { + let mut response = http::Response::new( + tonic::body::Body::default(), + ); + let headers = response.headers_mut(); + headers + .insert( + tonic::Status::GRPC_STATUS, + (tonic::Code::Unimplemented as i32).into(), + ); + headers + .insert( + http::header::CONTENT_TYPE, + tonic::metadata::GRPC_CONTENT_TYPE, + ); + Ok(response) + }) + } + } + } + } + impl Clone for WorkerServiceServer { + fn clone(&self) -> Self { + let inner = self.inner.clone(); + Self { + inner, + accept_compression_encodings: self.accept_compression_encodings, + send_compression_encodings: self.send_compression_encodings, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, + } + } + } + /// Generated gRPC service name + pub const SERVICE_NAME: &str = "worker.WorkerService"; + impl tonic::server::NamedService for WorkerServiceServer { + const NAME: &'static str = SERVICE_NAME; + } +} diff --git a/src/flight_service/mod.rs b/src/worker/mod.rs similarity index 75% rename from src/flight_service/mod.rs rename to src/worker/mod.rs index 3e5b67b5..d25d7e85 100644 --- a/src/flight_service/mod.rs +++ b/src/worker/mod.rs @@ -1,14 +1,14 @@ -mod do_action; -mod do_get; +mod execute_task; +pub(crate) mod generated; mod session_builder; +mod set_plan; mod single_write_multi_read; mod spawn_select_all; #[cfg(any(test, feature = "integration"))] pub(crate) mod test_utils; -mod worker; mod worker_connection_pool; +mod worker_service; -pub(crate) use do_action::{INIT_ACTION_TYPE, InitAction}; pub(crate) use single_write_multi_read::SingleWriteMultiRead; pub(crate) use worker_connection_pool::WorkerConnectionPool; @@ -16,6 +16,6 @@ pub use session_builder::{ DefaultSessionBuilder, MappedWorkerSessionBuilder, MappedWorkerSessionBuilderExt, WorkerQueryContext, WorkerSessionBuilder, }; -pub use worker::Worker; +pub use worker_service::Worker; -pub use do_action::TaskData; +pub use set_plan::TaskData; diff --git a/src/flight_service/session_builder.rs b/src/worker/session_builder.rs similarity index 100% rename from src/flight_service/session_builder.rs rename to src/worker/session_builder.rs diff --git a/src/flight_service/do_action.rs b/src/worker/set_plan.rs similarity index 75% rename from src/flight_service/do_action.rs rename to src/worker/set_plan.rs index 14a8cc2a..719a6954 100644 --- a/src/flight_service/do_action.rs +++ b/src/worker/set_plan.rs @@ -1,36 +1,17 @@ use crate::config_extension_ext::set_distributed_option_extension_from_headers; use crate::protobuf::DistributedCodec; -use crate::{DistributedConfig, StageKey, Worker, WorkerQueryContext}; -use arrow_flight::Action; -use arrow_flight::flight_service_server::FlightService; -use bytes::Bytes; +use crate::worker::generated::worker::{SetPlanRequest, SetPlanResponse}; +use crate::{DistributedConfig, Worker, WorkerQueryContext}; use datafusion::error::DataFusionError; use datafusion::execution::{SessionStateBuilder, TaskContext}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionConfig; use datafusion_proto::physical_plan::AsExecutionPlan; use datafusion_proto::protobuf::PhysicalPlanNode; -use futures::StreamExt; -use prost::Message; use std::sync::Arc; use std::sync::atomic::AtomicUsize; use tonic::{Request, Response, Status}; -pub(crate) const INIT_ACTION_TYPE: &str = "init"; - -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct InitAction { - /// The ExecutionPlan we are going to execute encoded as protobuf bytes. - #[prost(bytes, tag = "1")] - pub plan_proto: Bytes, - /// The stage key that identifies the stage. This is useful to keep - /// outside of the stage proto as it is used to store the stage - /// and we may not need to deserialize the entire stage proto - /// if we already have stored it - #[prost(message, optional, tag = "2")] - pub stage_key: Option, -} - #[derive(Clone, Debug)] /// TaskData stores state for a single task being executed by this Endpoint. It may be shared /// by concurrent requests for the same task which execute separate partitions. @@ -61,15 +42,12 @@ impl TaskData { } impl Worker { - pub(super) async fn init( + pub(crate) async fn set_plan( &self, - request: Request, - ) -> Result::DoActionStream>, Status> { + request: Request, + ) -> Result, Status> { let (metadata, _ext, body) = request.into_parts(); - let init = InitAction::decode(body.body).map_err(|err| { - Status::invalid_argument(format!("Cannot decode InitAction message: {err}")) - })?; - let key = init.stage_key.ok_or_else(missing("stage_key"))?; + let key = body.stage_key.ok_or_else(missing("stage_key"))?; let entry = self .task_data_entries @@ -94,7 +72,7 @@ impl Worker { let codec = DistributedCodec::new_combined_with_user(session_state.config()); let task_ctx = session_state.task_ctx(); - let proto_node = PhysicalPlanNode::try_decode(init.plan_proto.as_ref())?; + let proto_node = PhysicalPlanNode::try_decode(body.plan_proto.as_ref())?; let mut plan = proto_node.try_into_physical_plan(&task_ctx, &codec)?; for hook in self.hooks.on_plan.iter() { @@ -115,7 +93,7 @@ impl Worker { "Logic error while setting plan for Stage key {key:?}: the plan was set twice. This is a bug in datafusion-distributed, please report it." )) })?; - Ok(Response::new(futures::stream::empty().boxed())) + Ok(Response::new(SetPlanResponse {})) } } diff --git a/src/flight_service/single_write_multi_read.rs b/src/worker/single_write_multi_read.rs similarity index 100% rename from src/flight_service/single_write_multi_read.rs rename to src/worker/single_write_multi_read.rs diff --git a/src/flight_service/spawn_select_all.rs b/src/worker/spawn_select_all.rs similarity index 100% rename from src/flight_service/spawn_select_all.rs rename to src/worker/spawn_select_all.rs diff --git a/src/flight_service/test_utils/memory_worker.rs b/src/worker/test_utils/memory_worker.rs similarity index 95% rename from src/flight_service/test_utils/memory_worker.rs rename to src/worker/test_utils/memory_worker.rs index df9aabc2..de6b87e4 100644 --- a/src/flight_service/test_utils/memory_worker.rs +++ b/src/worker/test_utils/memory_worker.rs @@ -1,7 +1,7 @@ use crate::config_extension_ext::set_distributed_option_extension; -use crate::{BoxCloneSyncChannel, DistributedConfig, DistributedExt, StageKey, TaskData, Worker}; +use crate::worker::generated::worker::StageKey; +use crate::{BoxCloneSyncChannel, DistributedConfig, DistributedExt, TaskData, Worker}; use arrow_ipc::CompressionType; -use bytes::Bytes; use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::record_batch::RecordBatch; use datafusion::datasource::memory::MemorySourceConfig; @@ -14,7 +14,7 @@ use uuid::Uuid; pub fn test_stage_key(task_number: u64) -> StageKey { StageKey { - query_id: Bytes::from(Uuid::from_u128(0).as_bytes().to_vec()), + query_id: Uuid::from_u128(0).as_bytes().to_vec(), stage_id: 0, task_number, } @@ -97,7 +97,7 @@ impl MemoryWorker { tokio::spawn(async move { Server::builder() - .add_service(worker.into_flight_server()) + .add_service(worker.into_worker_server()) .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) .await }); diff --git a/src/flight_service/test_utils/mod.rs b/src/worker/test_utils/mod.rs similarity index 100% rename from src/flight_service/test_utils/mod.rs rename to src/worker/test_utils/mod.rs diff --git a/src/worker/worker.proto b/src/worker/worker.proto new file mode 100644 index 00000000..e752311d --- /dev/null +++ b/src/worker/worker.proto @@ -0,0 +1,57 @@ +syntax = "proto3"; +package worker; + +service WorkerService { + rpc SetPlan(SetPlanRequest) returns (SetPlanResponse); + rpc ExecuteTask(ExecuteTaskRequest) returns (stream FlightData); +} + +message SetPlanRequest { + bytes plan_proto = 1; + StageKey stage_key = 2; +} + +message SetPlanResponse { + +} + +message ExecuteTaskRequest { + uint64 target_task_index = 1; + uint64 target_task_count = 2; + uint64 target_partition_start = 3; + uint64 target_partition_end = 4; + StageKey stage_key = 5; +} + +message StageKey { + bytes query_id = 1; + uint64 stage_id = 2; + uint64 task_number = 3; +} + +// Messages from https://github.com/apache/arrow/blob/main/format/Flight.proto. These +// will match the structs shipped by the `arrow-flight` crate, so that we can use its +// tools for dealing with FlightData streams. + +// Matches arrow.flight.protocol.FlightDescriptor from Flight.proto. +// Mapped to arrow_flight::FlightDescriptor via extern_path at codegen time. +message FlightDescriptor { + enum DescriptorType { + UNKNOWN = 0; + PATH = 1; + CMD = 2; + } + + DescriptorType type = 1; + bytes cmd = 2; + repeated string path = 3; +} + +// Matches arrow.flight.protocol.FlightData from Flight.proto. +// Mapped to arrow_flight::FlightData via extern_path at codegen time. +message FlightData { + FlightDescriptor flight_descriptor = 1; + bytes data_header = 2; + bytes app_metadata = 3; + bytes data_body = 1000; +} \ No newline at end of file diff --git a/src/flight_service/worker_connection_pool.rs b/src/worker/worker_connection_pool.rs similarity index 95% rename from src/flight_service/worker_connection_pool.rs rename to src/worker/worker_connection_pool.rs index 6a456ea0..aa34741c 100644 --- a/src/flight_service/worker_connection_pool.rs +++ b/src/worker/worker_connection_pool.rs @@ -1,16 +1,15 @@ use crate::common::on_drop_stream; -use crate::flight_service::do_get::DoGet; use crate::metrics::LatencyMetricExt; use crate::networking::get_distributed_channel_resolver; use crate::passthrough_headers::get_passthrough_headers; use crate::protobuf::{ - FlightAppMetadata, StageKey, datafusion_error_to_tonic_status, map_flight_to_datafusion_error, + FlightAppMetadata, datafusion_error_to_tonic_status, map_flight_to_datafusion_error, }; -use crate::{BytesMetricExt, Stage}; +use crate::worker::generated::worker::{ExecuteTaskRequest, StageKey}; +use crate::{BytesMetricExt, ChannelResolver, Stage}; +use arrow_flight::FlightData; use arrow_flight::decode::FlightRecordBatchStream; use arrow_flight::error::FlightError; -use arrow_flight::{FlightData, Ticket}; -use bytes::Bytes; use dashmap::DashMap; use datafusion::arrow::array::RecordBatch; use datafusion::common::instant::Instant; @@ -160,23 +159,19 @@ impl WorkerConnection { // Building the actual request that will be sent to the worker. let headers = get_passthrough_headers(ctx.session_config()); - let ticket = Request::from_parts( + let request = Request::from_parts( MetadataMap::from_headers(headers), Extensions::default(), - Ticket { - ticket: DoGet { - target_partition_start: target_partition_range.start as u64, - target_partition_end: target_partition_range.end as u64, - stage_key: Some(StageKey::new( - Bytes::from(input_stage.query_id.as_bytes().to_vec()), - input_stage.num as u64, - target_task as u64, - )), - target_task_index: target_task as u64, - target_task_count: input_stage.tasks.len() as u64, - } - .encode_to_vec() - .into(), + ExecuteTaskRequest { + target_partition_start: target_partition_range.start as u64, + target_partition_end: target_partition_range.end as u64, + stage_key: Some(StageKey { + query_id: input_stage.query_id.as_bytes().to_vec(), + stage_id: input_stage.num as u64, + task_number: target_task as u64, + }), + target_task_index: target_task as u64, + target_task_count: input_stage.tasks.len() as u64, }, ); @@ -212,14 +207,14 @@ impl WorkerConnection { // fan them out to the appropriate `per_partition_rx` based on the "partition" declared // in each individual record batch flight metadata. let task = SpawnedTask::spawn(async move { - let mut client = match channel_resolver.get_flight_client_for_url(&url).await { + let mut client = match channel_resolver.get_worker_client_for_url(&url).await { Ok(v) => v, Err(err) => { return fanout(&per_partition_tx, datafusion_error_to_tonic_status(&err)); } }; - let mut interleaved_stream = match client.do_get(ticket).await { + let mut interleaved_stream = match client.execute_task(request).await { Ok(v) => v.into_inner(), Err(err) => return fanout(&per_partition_tx, err), }; diff --git a/src/flight_service/worker.rs b/src/worker/worker_service.rs similarity index 60% rename from src/flight_service/worker.rs rename to src/worker/worker_service.rs index b277b400..ac270bcb 100644 --- a/src/flight_service/worker.rs +++ b/src/worker/worker_service.rs @@ -1,22 +1,21 @@ -use crate::flight_service::WorkerSessionBuilder; -use crate::flight_service::do_action::{INIT_ACTION_TYPE, TaskData}; -use crate::flight_service::single_write_multi_read::SingleWriteMultiRead; -use crate::protobuf::StageKey; -use crate::{DefaultSessionBuilder, ObservabilityServiceImpl, ObservabilityServiceServer}; -use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; -use arrow_flight::{ - Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, - HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket, +use crate::worker::WorkerSessionBuilder; +use crate::worker::generated::worker::worker_service_server::{WorkerService, WorkerServiceServer}; +use crate::worker::generated::worker::{ + ExecuteTaskRequest, SetPlanRequest, SetPlanResponse, StageKey, }; +use crate::worker::set_plan::TaskData; +use crate::worker::single_write_multi_read::SingleWriteMultiRead; +use crate::{DefaultSessionBuilder, ObservabilityServiceImpl, ObservabilityServiceServer}; +use arrow_flight::FlightData; use async_trait::async_trait; use datafusion::common::DataFusionError; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::ExecutionPlan; -use futures::stream::BoxStream; use moka::future::Cache; use std::sync::Arc; use std::time::Duration; -use tonic::{Request, Response, Status, Streaming}; +use tonic::codegen::BoxStream; +use tonic::{Request, Response, Status}; #[allow(clippy::type_complexity)] #[derive(Clone, Default)] @@ -100,9 +99,9 @@ impl Worker { self } - /// Converts this [Worker] into a [`FlightServiceServer`] with high default message size limits. + /// Converts this [Worker] into a [`WorkerServiceServer`] with high default message size limits. /// - /// This is a convenience method that wraps the endpoint in a [`FlightServiceServer`] and + /// This is a convenience method that wraps the endpoint in a [`WorkerServiceServer`] and /// configures it with `max_decoding_message_size(usize::MAX)` and /// `max_encoding_message_size(usize::MAX)` to avoid message size limitations for internal /// communication. @@ -118,17 +117,17 @@ impl Worker { /// # async fn f() { /// /// let worker = Worker::default(); - /// let server = worker.into_flight_server(); + /// let server = worker.into_worker_server(); /// /// Server::builder() - /// .add_service(Worker::default().into_flight_server()) + /// .add_service(Worker::default().into_worker_server()) /// .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080)) /// .await; /// /// # } /// ``` - pub fn into_flight_server(self) -> FlightServiceServer { - FlightServiceServer::new(self) + pub fn into_worker_server(self) -> WorkerServiceServer { + WorkerServiceServer::new(self) .max_decoding_message_size(usize::MAX) .max_encoding_message_size(usize::MAX) } @@ -152,93 +151,20 @@ impl Worker { } #[async_trait] -impl FlightService for Worker { - type HandshakeStream = BoxStream<'static, Result>; - - async fn handshake( +impl WorkerService for Worker { + async fn set_plan( &self, - _: Request>, - ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) - } - - type ListFlightsStream = BoxStream<'static, Result>; - - async fn list_flights( - &self, - _: Request, - ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) - } - - async fn get_flight_info( - &self, - _: Request, - ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) - } - - async fn poll_flight_info( - &self, - _: Request, - ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) - } - - async fn get_schema( - &self, - _: Request, - ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) - } - - type DoGetStream = BoxStream<'static, Result>; - - async fn do_get( - &self, - request: Request, - ) -> Result, Status> { - self.get(request).await - } - - type DoPutStream = BoxStream<'static, Result>; - - async fn do_put( - &self, - _: Request>, - ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) - } - - type DoExchangeStream = BoxStream<'static, Result>; - - async fn do_exchange( - &self, - _: Request>, - ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) - } - - type DoActionStream = BoxStream<'static, Result>; - - async fn do_action( - &self, - request: Request, - ) -> Result, Status> { - match request.get_ref().r#type.as_str() { - INIT_ACTION_TYPE => self.init(request).await, - v => Err(Status::unimplemented(format!( - "Action {v} not yet implemented" - ))), - } + request: Request, + ) -> Result, Status> { + self.set_plan(request).await } - type ListActionsStream = BoxStream<'static, Result>; + type ExecuteTaskStream = BoxStream; - async fn list_actions( + async fn execute_task( &self, - _: Request, - ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + request: Request, + ) -> Result, Status> { + self.execute_task(request).await } } From 31ac44c5e7c265a094650f0acf1e4a23a945fc3f Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Tue, 17 Mar 2026 19:31:01 +0100 Subject: [PATCH 02/11] Take more manually generated protobuf messages and generate them automatically --- src/execution_plans/network_broadcast.rs | 8 +- src/execution_plans/network_coalesce.rs | 8 +- src/execution_plans/network_shuffle.rs | 8 +- src/lib.rs | 2 +- src/metrics/proto.rs | 422 ++++--------------- src/metrics/task_metrics_collector.rs | 6 +- src/metrics/task_metrics_rewriter.rs | 12 +- src/observability/generated/observability.rs | 141 +++---- src/protobuf/app_metadata.rs | 67 --- src/protobuf/mod.rs | 2 - src/test_utils/metrics.rs | 27 +- src/worker/execute_task.rs | 29 +- src/worker/generated/worker.rs | 369 ++++++++++++---- src/worker/worker.proto | 167 +++++++- src/worker/worker_connection_pool.rs | 5 +- 15 files changed, 648 insertions(+), 625 deletions(-) delete mode 100644 src/protobuf/app_metadata.rs diff --git a/src/execution_plans/network_broadcast.rs b/src/execution_plans/network_broadcast.rs index 0d0243b6..dfac5704 100644 --- a/src/execution_plans/network_broadcast.rs +++ b/src/execution_plans/network_broadcast.rs @@ -1,11 +1,11 @@ use crate::DistributedTaskContext; use crate::common::require_one_child; use crate::distributed_planner::NetworkBoundary; -use crate::metrics::proto::MetricsSetProto; -use crate::protobuf::AppMetadata; use crate::stage::Stage; use crate::worker::WorkerConnectionPool; +use crate::worker::generated::worker as pb; use crate::worker::generated::worker::StageKey; +use crate::worker::generated::worker::flight_app_metadata; use dashmap::DashMap; use datafusion::common::internal_datafusion_err; use datafusion::error::DataFusionError; @@ -124,7 +124,7 @@ pub struct NetworkBroadcastExec { pub(crate) properties: PlanProperties, pub(crate) input_stage: Stage, pub(crate) worker_connections: WorkerConnectionPool, - pub(crate) metrics_collection: Arc>>, + pub(crate) metrics_collection: Arc>>, } impl NetworkBroadcastExec { @@ -248,7 +248,7 @@ impl ExecutionPlan for NetworkBroadcastExec { let metrics_collection = Arc::clone(&self.metrics_collection); let stream = worker_connection.stream_partition(off + partition, move |meta| { - if let Some(AppMetadata::MetricsCollection(m)) = meta.content { + if let Some(flight_app_metadata::Content::MetricsCollection(m)) = meta.content { for task_metrics in m.tasks { if let Some(stage_key) = task_metrics.stage_key { metrics_collection.insert(stage_key, task_metrics.metrics); diff --git a/src/execution_plans/network_coalesce.rs b/src/execution_plans/network_coalesce.rs index 19908e33..e4c691de 100644 --- a/src/execution_plans/network_coalesce.rs +++ b/src/execution_plans/network_coalesce.rs @@ -1,11 +1,11 @@ use crate::common::require_one_child; use crate::distributed_planner::NetworkBoundary; use crate::execution_plans::common::scale_partitioning_props; -use crate::metrics::proto::MetricsSetProto; -use crate::protobuf::AppMetadata; use crate::stage::Stage; use crate::worker::WorkerConnectionPool; +use crate::worker::generated::worker as pb; use crate::worker::generated::worker::StageKey; +use crate::worker::generated::worker::flight_app_metadata; use crate::{DistributedTaskContext, ExecutionTask}; use dashmap::DashMap; use datafusion::common::{exec_err, plan_err}; @@ -89,7 +89,7 @@ pub struct NetworkCoalesceExec { /// the stage it is reading from. This is because, by convention, the Worker sends metrics for /// a task to the last NetworkCoalesceExec to read from it, which may or may not be this /// instance. - pub(crate) metrics_collection: Arc>>, + pub(crate) metrics_collection: Arc>>, } impl NetworkCoalesceExec { @@ -251,7 +251,7 @@ impl ExecutionPlan for NetworkCoalesceExec { let metrics_collection = Arc::clone(&self.metrics_collection); let stream = worker_connection.stream_partition(target_partition, move |meta| { - if let Some(AppMetadata::MetricsCollection(m)) = meta.content { + if let Some(flight_app_metadata::Content::MetricsCollection(m)) = meta.content { for task_metrics in m.tasks { if let Some(stage_key) = task_metrics.stage_key { metrics_collection.insert(stage_key, task_metrics.metrics); diff --git a/src/execution_plans/network_shuffle.rs b/src/execution_plans/network_shuffle.rs index 2ebe6634..7b472428 100644 --- a/src/execution_plans/network_shuffle.rs +++ b/src/execution_plans/network_shuffle.rs @@ -1,10 +1,10 @@ use crate::common::require_one_child; use crate::execution_plans::common::scale_partitioning; -use crate::metrics::proto::MetricsSetProto; -use crate::protobuf::AppMetadata; use crate::stage::Stage; use crate::worker::WorkerConnectionPool; +use crate::worker::generated::worker as pb; use crate::worker::generated::worker::StageKey; +use crate::worker::generated::worker::flight_app_metadata; use crate::{DistributedTaskContext, ExecutionTask, NetworkBoundary}; use dashmap::DashMap; use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; @@ -117,7 +117,7 @@ pub struct NetworkShuffleExec { /// the stage it is reading from. This is because, by convention, the Worker sends metrics for /// a task to the last NetworkCoalesceExec to read from it, which may or may not be this /// instance. - pub(crate) metrics_collection: Arc>>, + pub(crate) metrics_collection: Arc>>, } impl NetworkShuffleExec { @@ -243,7 +243,7 @@ impl ExecutionPlan for NetworkShuffleExec { let metrics_collection = Arc::clone(&self.metrics_collection); let stream = worker_connection.stream_partition(off + partition, move |meta| { - if let Some(AppMetadata::MetricsCollection(m)) = meta.content { + if let Some(flight_app_metadata::Content::MetricsCollection(m)) = meta.content { for task_metrics in m.tasks { if let Some(stage_key) = task_metrics.stage_key { metrics_collection.insert(stage_key, task_metrics.metrics); diff --git a/src/lib.rs b/src/lib.rs index dc8ed777..735803e7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,11 +36,11 @@ pub use networking::{ BoxCloneSyncChannel, ChannelResolver, DefaultChannelResolver, WorkerResolver, create_worker_client, get_distributed_channel_resolver, get_distributed_worker_resolver, }; -pub use worker::generated::worker::worker_service_client::WorkerServiceClient; pub use stage::{ DistributedTaskContext, ExecutionTask, Stage, display_plan_ascii, display_plan_graphviz, explain_analyze, }; +pub use worker::generated::worker::worker_service_client::WorkerServiceClient; pub use worker::{ DefaultSessionBuilder, MappedWorkerSessionBuilder, MappedWorkerSessionBuilderExt, TaskData, Worker, WorkerQueryContext, WorkerSessionBuilder, diff --git a/src/metrics/proto.rs b/src/metrics/proto.rs index dab28627..e87d49c0 100644 --- a/src/metrics/proto.rs +++ b/src/metrics/proto.rs @@ -13,264 +13,14 @@ use super::latency_metric::{ AvgLatencyMetric, FirstLatencyMetric, MaxLatencyMetric, MinLatencyMetric, P50LatencyMetric, P75LatencyMetric, P95LatencyMetric, P99LatencyMetric, }; +use crate::worker::generated::worker as pb; -/// A MetricProto is a protobuf mirror of [datafusion::physical_plan::metrics::Metric]. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct MetricProto { - #[prost(message, repeated, tag = "1")] - pub labels: Vec, - #[prost(uint64, optional, tag = "2")] - pub partition: Option, - #[prost( - oneof = "MetricValueProto", - tags = "10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33" - )] - // This field is *always* set. It is marked optional due to protobuf "oneof" requirements. - pub metric: Option, -} - -/// A MetricsSetProto is a protobuf mirror of [datafusion::physical_plan::metrics::MetricsSet]. It represents -/// a collection of metrics for one `ExecutionPlan` node. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct MetricsSetProto { - #[prost(message, repeated, tag = "1")] - pub metrics: Vec, -} - -impl MetricsSetProto { - pub fn new() -> Self { - Self { - metrics: Vec::new(), - } - } - - pub fn push(&mut self, metric: MetricProto) { - self.metrics.push(metric) - } -} - -/// MetricValueProto is a protobuf mirror of the [datafusion::physical_plan::metrics::MetricValue] enum. -#[derive(Clone, PartialEq, Eq, ::prost::Oneof)] -pub enum MetricValueProto { - #[prost(message, tag = "10")] - OutputRows(OutputRows), - #[prost(message, tag = "11")] - ElapsedCompute(ElapsedCompute), - #[prost(message, tag = "12")] - SpillCount(SpillCount), - #[prost(message, tag = "13")] - SpilledBytes(SpilledBytes), - #[prost(message, tag = "14")] - SpilledRows(SpilledRows), - #[prost(message, tag = "15")] - CurrentMemoryUsage(CurrentMemoryUsage), - #[prost(message, tag = "16")] - Count(NamedCount), - #[prost(message, tag = "17")] - Gauge(NamedGauge), - #[prost(message, tag = "18")] - Time(NamedTime), - #[prost(message, tag = "19")] - StartTimestamp(StartTimestamp), - #[prost(message, tag = "20")] - EndTimestamp(EndTimestamp), - #[prost(message, tag = "21")] - OutputBytes(OutputBytes), - #[prost(message, tag = "22")] - OutputBatches(OutputBatches), - #[prost(message, tag = "23")] - PruningMetrics(NamedPruningMetrics), - #[prost(message, tag = "24")] - Ratio(NamedRatio), - #[prost(message, tag = "25")] - CustomMinLatency(MinLatency), - #[prost(message, tag = "26")] - CustomMaxLatency(MaxLatency), - #[prost(message, tag = "27")] - CustomAvgLatency(AvgLatency), - #[prost(message, tag = "28")] - CustomFirstLatency(FirstLatency), - #[prost(message, tag = "29")] - CustomBytesCount(BytesCount), - #[prost(message, tag = "30")] - CustomP50Latency(PercentileLatency), - #[prost(message, tag = "31")] - CustomP75Latency(PercentileLatency), - #[prost(message, tag = "32")] - CustomP95Latency(PercentileLatency), - #[prost(message, tag = "33")] - CustomP99Latency(PercentileLatency), -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct OutputRows { - #[prost(uint64, tag = "1")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct ElapsedCompute { - #[prost(uint64, tag = "1")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct SpillCount { - #[prost(uint64, tag = "1")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct SpilledBytes { - #[prost(uint64, tag = "1")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct SpilledRows { - #[prost(uint64, tag = "1")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct CurrentMemoryUsage { - #[prost(uint64, tag = "1")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct NamedCount { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct NamedGauge { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct NamedTime { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct StartTimestamp { - #[prost(int64, optional, tag = "1")] - pub value: Option, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct EndTimestamp { - #[prost(int64, optional, tag = "1")] - pub value: Option, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct OutputBytes { - #[prost(uint64, tag = "1")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct OutputBatches { - #[prost(uint64, tag = "1")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct NamedPruningMetrics { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub pruned: u64, - #[prost(uint64, tag = "3")] - pub matched: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct NamedRatio { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub part: u64, - #[prost(uint64, tag = "3")] - pub total: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct BytesCount { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct MinLatency { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct MaxLatency { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct AvgLatency { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub nanos_sum: u64, - #[prost(uint64, tag = "3")] - pub count: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct FirstLatency { - #[prost(string, tag = "1")] - pub name: String, - #[prost(uint64, tag = "2")] - pub value: u64, -} - -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct PercentileLatency { - #[prost(string, tag = "1")] - pub name: String, - #[prost(bytes = "vec", tag = "4")] - pub sketch_bytes: Vec, -} - -/// A ProtoLabel mirrors [datafusion::physical_plan::metrics::Label]. -#[derive(Clone, PartialEq, Eq, ::prost::Message)] -pub struct ProtoLabel { - #[prost(string, tag = "1")] - pub name: String, - #[prost(string, tag = "2")] - pub value: String, -} - -/// df_metrics_set_to_proto converts a [datafusion::physical_plan::metrics::MetricsSet] to a [MetricsSetProto]. +/// df_metrics_set_to_proto converts a [datafusion::physical_plan::metrics::MetricsSet] to a [pb::MetricsSet]. /// Custom metrics are filtered out, but any other errors are returned. /// TODO(#140): Support custom metrics. pub fn df_metrics_set_to_proto( metrics_set: &MetricsSet, -) -> Result { +) -> Result { let mut metrics = Vec::new(); for metric in metrics_set.iter() { @@ -290,12 +40,12 @@ pub fn df_metrics_set_to_proto( } } - Ok(MetricsSetProto { metrics }) + Ok(pb::MetricsSet { metrics }) } -/// metrics_set_proto_to_df converts a [MetricsSetProto] to a [datafusion::physical_plan::metrics::MetricsSet]. +/// metrics_set_proto_to_df converts a [pb::MetricsSet] to a [datafusion::physical_plan::metrics::MetricsSet]. pub fn metrics_set_proto_to_df( - metrics_set_proto: &MetricsSetProto, + metrics_set_proto: &pb::MetricsSet, ) -> Result { let mut metrics_set = MetricsSet::new(); metrics_set_proto.metrics.iter().try_for_each(|metric| { @@ -313,75 +63,75 @@ const CUSTOM_METRICS_NOT_SUPPORTED: &str = /// New DataFusion metrics that are not yet supported in proto conversion. const UNSUPPORTED_METRICS: &str = "metric type not supported in proto conversion"; -/// df_metric_to_proto converts a `datafusion::physical_plan::metrics::Metric` to a `MetricProto`. It does not consume the Arc. -pub fn df_metric_to_proto(metric: Arc) -> Result { +/// df_metric_to_proto converts a `datafusion::physical_plan::metrics::Metric` to a `pb::Metric`. It does not consume the Arc. +pub fn df_metric_to_proto(metric: Arc) -> Result { let partition = metric.partition().map(|p| p as u64); let labels = metric .labels() .iter() - .map(|label| ProtoLabel { + .map(|label| pb::Label { name: label.name().to_string(), value: label.value().to_string(), }) .collect(); match metric.value() { - MetricValue::OutputRows(rows) => Ok(MetricProto { - metric: Some(MetricValueProto::OutputRows(OutputRows { value: rows.value() as u64 })), + MetricValue::OutputRows(rows) => Ok(pb::Metric { + value: Some(pb::metric::Value::OutputRows(pb::OutputRows { value: rows.value() as u64 })), partition, labels, }), - MetricValue::ElapsedCompute(time) => Ok(MetricProto { - metric: Some(MetricValueProto::ElapsedCompute(ElapsedCompute { value: time.value() as u64 })), + MetricValue::ElapsedCompute(time) => Ok(pb::Metric { + value: Some(pb::metric::Value::ElapsedCompute(pb::ElapsedCompute { value: time.value() as u64 })), partition, labels, }), - MetricValue::SpillCount(count) => Ok(MetricProto { - metric: Some(MetricValueProto::SpillCount(SpillCount { value: count.value() as u64 })), + MetricValue::SpillCount(count) => Ok(pb::Metric { + value: Some(pb::metric::Value::SpillCount(pb::SpillCount { value: count.value() as u64 })), partition, labels, }), - MetricValue::SpilledBytes(count) => Ok(MetricProto { - metric: Some(MetricValueProto::SpilledBytes(SpilledBytes { value: count.value() as u64 })), + MetricValue::SpilledBytes(count) => Ok(pb::Metric { + value: Some(pb::metric::Value::SpilledBytes(pb::SpilledBytes { value: count.value() as u64 })), partition, labels, }), - MetricValue::SpilledRows(count) => Ok(MetricProto { - metric: Some(MetricValueProto::SpilledRows(SpilledRows { value: count.value() as u64 })), + MetricValue::SpilledRows(count) => Ok(pb::Metric { + value: Some(pb::metric::Value::SpilledRows(pb::SpilledRows { value: count.value() as u64 })), partition, labels, }), - MetricValue::CurrentMemoryUsage(gauge) => Ok(MetricProto { - metric: Some(MetricValueProto::CurrentMemoryUsage(CurrentMemoryUsage { value: gauge.value() as u64 })), + MetricValue::CurrentMemoryUsage(gauge) => Ok(pb::Metric { + value: Some(pb::metric::Value::CurrentMemoryUsage(pb::CurrentMemoryUsage { value: gauge.value() as u64 })), partition, labels, }), - MetricValue::Count { name, count } => Ok(MetricProto { - metric: Some(MetricValueProto::Count(NamedCount { + MetricValue::Count { name, count } => Ok(pb::Metric { + value: Some(pb::metric::Value::Count(pb::NamedCount { name: name.to_string(), value: count.value() as u64 })), partition, labels, }), - MetricValue::Gauge { name, gauge } => Ok(MetricProto { - metric: Some(MetricValueProto::Gauge(NamedGauge { + MetricValue::Gauge { name, gauge } => Ok(pb::Metric { + value: Some(pb::metric::Value::Gauge(pb::NamedGauge { name: name.to_string(), value: gauge.value() as u64 })), partition, labels, }), - MetricValue::Time { name, time } => Ok(MetricProto { - metric: Some(MetricValueProto::Time(NamedTime { + MetricValue::Time { name, time } => Ok(pb::Metric { + value: Some(pb::metric::Value::Time(pb::NamedTime { name: name.to_string(), value: time.value() as u64 })), partition, labels, }), - MetricValue::StartTimestamp(timestamp) => Ok(MetricProto { - metric: Some(MetricValueProto::StartTimestamp(StartTimestamp { + MetricValue::StartTimestamp(timestamp) => Ok(pb::Metric { + value: Some(pb::metric::Value::StartTimestamp(pb::StartTimestamp { value: match timestamp.value() { Some(dt) => Some( dt.timestamp_nanos_opt().ok_or(DataFusionError::Internal( @@ -393,8 +143,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result Ok(MetricProto { - metric: Some(MetricValueProto::EndTimestamp(EndTimestamp { + MetricValue::EndTimestamp(timestamp) => Ok(pb::Metric { + value: Some(pb::metric::Value::EndTimestamp(pb::EndTimestamp { value: match timestamp.value() { Some(dt) => Some( dt.timestamp_nanos_opt().ok_or(DataFusionError::Internal( @@ -408,8 +158,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result { if let Some(min) = value.as_any().downcast_ref::() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomMinLatency(MinLatency { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomMinLatency(pb::MinLatency { name: name.to_string(), value: min.value() as u64, })), @@ -417,8 +167,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomMaxLatency(MaxLatency { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomMaxLatency(pb::MaxLatency { name: name.to_string(), value: max.value() as u64, })), @@ -426,8 +176,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomAvgLatency(AvgLatency { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomAvgLatency(pb::AvgLatency { name: name.to_string(), nanos_sum: avg.nanos_sum() as u64, count: avg.count() as u64, @@ -436,8 +186,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomFirstLatency(FirstLatency { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomFirstLatency(pb::FirstLatency { name: name.to_string(), value: first.value() as u64, })), @@ -445,8 +195,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomBytesCount(BytesCount { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomBytesCount(pb::BytesCount { name: name.to_string(), value: bytes.value() as u64, })), @@ -454,8 +204,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomP50Latency(PercentileLatency { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomP50Latency(pb::PercentileLatency { name: name.to_string(), sketch_bytes: p50.serialize_sketch()?, })), @@ -463,8 +213,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomP75Latency(PercentileLatency { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomP75Latency(pb::PercentileLatency { name: name.to_string(), sketch_bytes: p75.serialize_sketch()?, })), @@ -472,8 +222,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomP95Latency(PercentileLatency { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomP95Latency(pb::PercentileLatency { name: name.to_string(), sketch_bytes: p95.serialize_sketch()?, })), @@ -481,8 +231,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result() { - Ok(MetricProto { - metric: Some(MetricValueProto::CustomP99Latency(PercentileLatency { + Ok(pb::Metric { + value: Some(pb::metric::Value::CustomP99Latency(pb::PercentileLatency { name: name.to_string(), sketch_bytes: p99.serialize_sketch()?, })), @@ -493,18 +243,18 @@ pub fn df_metric_to_proto(metric: Arc) -> Result Ok(MetricProto { - metric: Some(MetricValueProto::OutputBytes(OutputBytes { value: count.value() as u64 })), + MetricValue::OutputBytes(count) => Ok(pb::Metric { + value: Some(pb::metric::Value::OutputBytes(pb::OutputBytes { value: count.value() as u64 })), partition, labels, }), - MetricValue::OutputBatches(count) => Ok(MetricProto { - metric: Some(MetricValueProto::OutputBatches(OutputBatches { value: count.value() as u64 })), + MetricValue::OutputBatches(count) => Ok(pb::Metric { + value: Some(pb::metric::Value::OutputBatches(pb::OutputBatches { value: count.value() as u64 })), partition, labels, }), - MetricValue::PruningMetrics { name, pruning_metrics } => Ok(MetricProto { - metric: Some(MetricValueProto::PruningMetrics(NamedPruningMetrics { + MetricValue::PruningMetrics { name, pruning_metrics } => Ok(pb::Metric { + value: Some(pb::metric::Value::PruningMetrics(pb::NamedPruningMetrics { name: name.to_string(), pruned: pruning_metrics.pruned() as u64, matched: pruning_metrics.matched() as u64, @@ -512,8 +262,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result Ok(MetricProto { - metric: Some(MetricValueProto::Ratio(NamedRatio { + MetricValue::Ratio { name, ratio_metrics } => Ok(pb::Metric { + value: Some(pb::metric::Value::Ratio(pb::NamedRatio { name: name.to_string(), part: ratio_metrics.part() as u64, total: ratio_metrics.total() as u64, @@ -524,8 +274,8 @@ pub fn df_metric_to_proto(metric: Arc) -> Result Result, DataFusionError> { +/// metric_proto_to_df converts a `pb::Metric` to a `datafusion::physical_plan::metrics::Metric`. It consumes the pb::Metric. +pub fn metric_proto_to_df(metric: pb::Metric) -> Result, DataFusionError> { let partition = metric.partition.map(|p| p as usize); let labels = metric .labels @@ -533,8 +283,8 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion .map(|proto_label| Label::new(proto_label.name, proto_label.value)) .collect(); - match metric.metric { - Some(MetricValueProto::OutputRows(rows)) => { + match metric.value { + Some(pb::metric::Value::OutputRows(rows)) => { let count = Count::new(); count.add(rows.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -543,7 +293,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::ElapsedCompute(elapsed)) => { + Some(pb::metric::Value::ElapsedCompute(elapsed)) => { let time = Time::new(); time.add_duration(std::time::Duration::from_nanos(elapsed.value)); Ok(Arc::new(Metric::new_with_labels( @@ -552,7 +302,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::SpillCount(spill_count)) => { + Some(pb::metric::Value::SpillCount(spill_count)) => { let count = Count::new(); count.add(spill_count.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -561,7 +311,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::SpilledBytes(spilled_bytes)) => { + Some(pb::metric::Value::SpilledBytes(spilled_bytes)) => { let count = Count::new(); count.add(spilled_bytes.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -570,7 +320,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::SpilledRows(spilled_rows)) => { + Some(pb::metric::Value::SpilledRows(spilled_rows)) => { let count = Count::new(); count.add(spilled_rows.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -579,7 +329,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CurrentMemoryUsage(memory)) => { + Some(pb::metric::Value::CurrentMemoryUsage(memory)) => { let gauge = Gauge::new(); gauge.set(memory.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -588,7 +338,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::Count(named_count)) => { + Some(pb::metric::Value::Count(named_count)) => { let count = Count::new(); count.add(named_count.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -600,7 +350,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::Gauge(named_gauge)) => { + Some(pb::metric::Value::Gauge(named_gauge)) => { let gauge = Gauge::new(); gauge.set(named_gauge.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -612,7 +362,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::Time(named_time)) => { + Some(pb::metric::Value::Time(named_time)) => { let time = Time::new(); time.add_duration(std::time::Duration::from_nanos(named_time.value)); Ok(Arc::new(Metric::new_with_labels( @@ -624,7 +374,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::StartTimestamp(start_ts)) => { + Some(pb::metric::Value::StartTimestamp(start_ts)) => { let timestamp = Timestamp::new(); if let Some(value) = start_ts.value { timestamp.set(DateTime::from_timestamp_nanos(value)); @@ -635,7 +385,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::EndTimestamp(end_ts)) => { + Some(pb::metric::Value::EndTimestamp(end_ts)) => { let timestamp = Timestamp::new(); if let Some(value) = end_ts.value { timestamp.set(DateTime::from_timestamp_nanos(value)); @@ -646,7 +396,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::OutputBytes(output_bytes)) => { + Some(pb::metric::Value::OutputBytes(output_bytes)) => { let count = Count::new(); count.add(output_bytes.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -655,7 +405,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::OutputBatches(output_batches)) => { + Some(pb::metric::Value::OutputBatches(output_batches)) => { let count = Count::new(); count.add(output_batches.value as usize); Ok(Arc::new(Metric::new_with_labels( @@ -664,7 +414,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::PruningMetrics(named_pruning)) => { + Some(pb::metric::Value::PruningMetrics(named_pruning)) => { let pruning_metrics = DfPruningMetrics::new(); pruning_metrics.add_pruned(named_pruning.pruned as usize); pruning_metrics.add_matched(named_pruning.matched as usize); @@ -677,7 +427,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::Ratio(named_ratio)) => { + Some(pb::metric::Value::Ratio(named_ratio)) => { let ratio_metrics = RatioMetrics::new(); ratio_metrics.set_part(named_ratio.part as usize); ratio_metrics.set_total(named_ratio.total as usize); @@ -690,7 +440,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomMinLatency(min_latency)) => { + Some(pb::metric::Value::CustomMinLatency(min_latency)) => { let value = MinLatencyMetric::from_nanos(min_latency.value as usize); Ok(Arc::new(Metric::new_with_labels( MetricValue::Custom { @@ -701,7 +451,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomMaxLatency(max_latency)) => { + Some(pb::metric::Value::CustomMaxLatency(max_latency)) => { let value = MaxLatencyMetric::from_nanos(max_latency.value as usize); Ok(Arc::new(Metric::new_with_labels( MetricValue::Custom { @@ -712,7 +462,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomAvgLatency(avg_latency)) => { + Some(pb::metric::Value::CustomAvgLatency(avg_latency)) => { let value = AvgLatencyMetric::from_raw( avg_latency.nanos_sum as usize, avg_latency.count as usize, @@ -726,7 +476,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomFirstLatency(first_latency)) => { + Some(pb::metric::Value::CustomFirstLatency(first_latency)) => { let value = FirstLatencyMetric::from_nanos(first_latency.value as usize); Ok(Arc::new(Metric::new_with_labels( MetricValue::Custom { @@ -737,7 +487,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomBytesCount(bytes_count)) => { + Some(pb::metric::Value::CustomBytesCount(bytes_count)) => { let value = BytesCounterMetric::from_value(bytes_count.value as usize); Ok(Arc::new(Metric::new_with_labels( MetricValue::Custom { @@ -748,7 +498,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomP50Latency(p)) => { + Some(pb::metric::Value::CustomP50Latency(p)) => { let sketch: DDSketch = bincode::deserialize(&p.sketch_bytes).map_err(|e| { DataFusionError::Internal(format!("failed to deserialize DDSketch: {e}")) })?; @@ -762,7 +512,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomP75Latency(p)) => { + Some(pb::metric::Value::CustomP75Latency(p)) => { let sketch: DDSketch = bincode::deserialize(&p.sketch_bytes).map_err(|e| { DataFusionError::Internal(format!("failed to deserialize DDSketch: {e}")) })?; @@ -776,7 +526,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomP95Latency(p)) => { + Some(pb::metric::Value::CustomP95Latency(p)) => { let sketch: DDSketch = bincode::deserialize(&p.sketch_bytes).map_err(|e| { DataFusionError::Internal(format!("failed to deserialize DDSketch: {e}")) })?; @@ -790,7 +540,7 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::CustomP99Latency(p)) => { + Some(pb::metric::Value::CustomP99Latency(p)) => { let sketch: DDSketch = bincode::deserialize(&p.sketch_bytes).map_err(|e| { DataFusionError::Internal(format!("failed to deserialize DDSketch: {e}")) })?; @@ -1258,8 +1008,8 @@ mod tests { let remaining_metric = &metrics_set_proto.metrics[0]; assert!(matches!( - remaining_metric.metric, - Some(MetricValueProto::OutputRows(_)) + remaining_metric.value, + Some(pb::metric::Value::OutputRows(_)) )); } diff --git a/src/metrics/task_metrics_collector.rs b/src/metrics/task_metrics_collector.rs index 5517bfce..b295db12 100644 --- a/src/metrics/task_metrics_collector.rs +++ b/src/metrics/task_metrics_collector.rs @@ -1,7 +1,7 @@ use crate::NetworkBroadcastExec; use crate::execution_plans::NetworkCoalesceExec; use crate::execution_plans::NetworkShuffleExec; -use crate::metrics::proto::MetricsSetProto; +use crate::worker::generated::worker as pb; use crate::worker::generated::worker::StageKey; use datafusion::common::HashMap; use datafusion::common::tree_node::Transformed; @@ -23,7 +23,7 @@ pub struct TaskMetricsCollector { task_metrics: Vec, /// input_task_metrics contains metrics for tasks from child [StageExec]s if they were /// collected. - input_task_metrics: HashMap>, + input_task_metrics: HashMap>, } /// MetricsCollectorResult is the result of collecting metrics from a task. @@ -31,7 +31,7 @@ pub struct MetricsCollectorResult { // metrics is a collection of metrics for a task ordered using a pre-order traversal of the task's plan. pub task_metrics: Vec, // input_task_metrics contains metrics for child tasks if they were collected. - pub input_task_metrics: HashMap>, + pub input_task_metrics: HashMap>, } impl TreeNodeRewriter for TaskMetricsCollector { diff --git a/src/metrics/task_metrics_rewriter.rs b/src/metrics/task_metrics_rewriter.rs index f06d7dc2..f8bd755a 100644 --- a/src/metrics/task_metrics_rewriter.rs +++ b/src/metrics/task_metrics_rewriter.rs @@ -4,8 +4,9 @@ use crate::execution_plans::MetricsWrapperExec; use crate::metrics::DISTRIBUTED_DATAFUSION_TASK_ID_LABEL; use crate::metrics::MetricsCollectorResult; use crate::metrics::TaskMetricsCollector; -use crate::metrics::proto::{MetricsSetProto, metrics_set_proto_to_df}; +use crate::metrics::proto::metrics_set_proto_to_df; use crate::stage::Stage; +use crate::worker::generated::worker as pb; use crate::worker::generated::worker::StageKey; use datafusion::common::HashMap; use datafusion::common::tree_node::Transformed; @@ -202,7 +203,7 @@ pub fn rewrite_local_plan_with_metrics( /// Note: Metrics may be aggregated by name (ex. output_rows) automatically by various datafusion utils. pub fn stage_metrics_rewriter( stage: &Stage, - metrics_collection: Arc>>, + metrics_collection: Arc>>, format: DistributedMetricsFormat, ) -> Result> { let mut node_idx = 0; @@ -271,9 +272,7 @@ pub fn stage_metrics_rewriter( mod tests { use crate::Stage; use crate::metrics::DISTRIBUTED_DATAFUSION_TASK_ID_LABEL; - use crate::metrics::proto::{ - MetricsSetProto, df_metrics_set_to_proto, metrics_set_proto_to_df, - }; + use crate::metrics::proto::{df_metrics_set_to_proto, metrics_set_proto_to_df}; use crate::metrics::task_metrics_rewriter::{ annotate_metrics_set_with_task_id, stage_metrics_rewriter, }; @@ -284,6 +283,7 @@ mod tests { use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed; use crate::test_utils::plans::count_plan_nodes_up_to_network_boundary; use crate::test_utils::session_context::register_temp_parquet_table; + use crate::worker::generated::worker as pb; use crate::{DistributedExec, DistributedPhysicalOptimizerRule}; use datafusion::arrow::array::{Int32Array, StringArray}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; @@ -448,7 +448,7 @@ mod tests { num_metrics_per_task_per_node, ) }) - .collect::>(); + .collect::>(); metrics_collection.insert(stage_key, metrics); } diff --git a/src/observability/generated/observability.rs b/src/observability/generated/observability.rs index b7a07f0e..b36f99c8 100644 --- a/src/observability/generated/observability.rs +++ b/src/observability/generated/observability.rs @@ -70,10 +70,10 @@ pub mod observability_service_client { dead_code, missing_docs, clippy::wildcard_imports, - clippy::let_unit_value, + clippy::let_unit_value )] - use tonic::codegen::*; use tonic::codegen::http::Uri; + use tonic::codegen::*; #[derive(Debug, Clone)] pub struct ObservabilityServiceClient { inner: tonic::client::Grpc, @@ -112,14 +112,13 @@ pub mod observability_service_client { F: tonic::service::Interceptor, T::ResponseBody: Default, T: tonic::codegen::Service< - http::Request, - Response = http::Response< - >::ResponseBody, + http::Request, + Response = http::Response< + >::ResponseBody, + >, >, - >, - , - >>::Error: Into + std::marker::Send + std::marker::Sync, + >>::Error: + Into + std::marker::Send + std::marker::Sync, { ObservabilityServiceClient::new(InterceptedService::new(inner, interceptor)) } @@ -158,50 +157,36 @@ pub mod observability_service_client { &mut self, request: impl tonic::IntoRequest, ) -> std::result::Result, tonic::Status> { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::unknown( - format!("Service was not ready: {}", e.into()), - ) - })?; + self.inner.ready().await.map_err(|e| { + tonic::Status::unknown(format!("Service was not ready: {}", e.into())) + })?; let codec = tonic_prost::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/observability.ObservabilityService/Ping", - ); + let path = + http::uri::PathAndQuery::from_static("/observability.ObservabilityService/Ping"); let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("observability.ObservabilityService", "Ping")); + req.extensions_mut().insert(GrpcMethod::new( + "observability.ObservabilityService", + "Ping", + )); self.inner.unary(req, path, codec).await } pub async fn get_task_progress( &mut self, request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::unknown( - format!("Service was not ready: {}", e.into()), - ) - })?; + ) -> std::result::Result, tonic::Status> + { + self.inner.ready().await.map_err(|e| { + tonic::Status::unknown(format!("Service was not ready: {}", e.into())) + })?; let codec = tonic_prost::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/observability.ObservabilityService/GetTaskProgress", ); let mut req = request.into_request(); - req.extensions_mut() - .insert( - GrpcMethod::new( - "observability.ObservabilityService", - "GetTaskProgress", - ), - ); + req.extensions_mut().insert(GrpcMethod::new( + "observability.ObservabilityService", + "GetTaskProgress", + )); self.inner.unary(req, path, codec).await } } @@ -213,7 +198,7 @@ pub mod observability_service_server { dead_code, missing_docs, clippy::wildcard_imports, - clippy::let_unit_value, + clippy::let_unit_value )] use tonic::codegen::*; /// Generated trait containing gRPC methods that should be implemented for use with ObservabilityServiceServer. @@ -226,10 +211,7 @@ pub mod observability_service_server { async fn get_task_progress( &self, request: tonic::Request, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - >; + ) -> std::result::Result, tonic::Status>; } #[derive(Debug)] pub struct ObservabilityServiceServer { @@ -252,10 +234,7 @@ pub mod observability_service_server { max_encoding_message_size: None, } } - pub fn with_interceptor( - inner: T, - interceptor: F, - ) -> InterceptedService + pub fn with_interceptor(inner: T, interceptor: F) -> InterceptedService where F: tonic::service::Interceptor, { @@ -290,8 +269,7 @@ pub mod observability_service_server { self } } - impl tonic::codegen::Service> - for ObservabilityServiceServer + impl tonic::codegen::Service> for ObservabilityServiceServer where T: ObservabilityService, B: Body + std::marker::Send + 'static, @@ -311,14 +289,9 @@ pub mod observability_service_server { "/observability.ObservabilityService/Ping" => { #[allow(non_camel_case_types)] struct PingSvc(pub Arc); - impl< - T: ObservabilityService, - > tonic::server::UnaryService for PingSvc { + impl tonic::server::UnaryService for PingSvc { type Response = super::PingResponse; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; + type Future = BoxFuture, tonic::Status>; fn call( &mut self, request: tonic::Request, @@ -355,25 +328,19 @@ pub mod observability_service_server { "/observability.ObservabilityService/GetTaskProgress" => { #[allow(non_camel_case_types)] struct GetTaskProgressSvc(pub Arc); - impl< - T: ObservabilityService, - > tonic::server::UnaryService - for GetTaskProgressSvc { + impl + tonic::server::UnaryService + for GetTaskProgressSvc + { type Response = super::GetTaskProgressResponse; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; + type Future = BoxFuture, tonic::Status>; fn call( &mut self, request: tonic::Request, ) -> Self::Future { let inner = Arc::clone(&self.0); let fut = async move { - ::get_task_progress( - &inner, - request, - ) + ::get_task_progress(&inner, request) .await }; Box::pin(fut) @@ -401,25 +368,19 @@ pub mod observability_service_server { }; Box::pin(fut) } - _ => { - Box::pin(async move { - let mut response = http::Response::new( - tonic::body::Body::default(), - ); - let headers = response.headers_mut(); - headers - .insert( - tonic::Status::GRPC_STATUS, - (tonic::Code::Unimplemented as i32).into(), - ); - headers - .insert( - http::header::CONTENT_TYPE, - tonic::metadata::GRPC_CONTENT_TYPE, - ); - Ok(response) - }) - } + _ => Box::pin(async move { + let mut response = http::Response::new(tonic::body::Body::default()); + let headers = response.headers_mut(); + headers.insert( + tonic::Status::GRPC_STATUS, + (tonic::Code::Unimplemented as i32).into(), + ); + headers.insert( + http::header::CONTENT_TYPE, + tonic::metadata::GRPC_CONTENT_TYPE, + ); + Ok(response) + }), } } } diff --git a/src/protobuf/app_metadata.rs b/src/protobuf/app_metadata.rs deleted file mode 100644 index 82077265..00000000 --- a/src/protobuf/app_metadata.rs +++ /dev/null @@ -1,67 +0,0 @@ -use crate::metrics::proto::MetricsSetProto; -use crate::worker::generated::worker::StageKey; -use std::time::{SystemTime, UNIX_EPOCH}; - -/// A collection of metrics for a set of tasks in an ExecutionPlan. each -/// entry should have a distinct [StageKey]. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct MetricsCollection { - #[prost(message, repeated, tag = "1")] - pub tasks: Vec, -} - -/// TaskMetrics represents the metrics for a single task. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct TaskMetrics { - /// stage_key uniquely identifies this task. - /// - /// This field is always present. It's marked optional due to protobuf rules. - #[prost(message, optional, tag = "1")] - pub stage_key: Option, - /// metrics[i] is the set of metrics for plan node `i` where plan nodes are in pre-order - /// traversal order. - #[prost(message, repeated, tag = "2")] - pub metrics: Vec, -} - -// FlightAppMetadata represents all types of app_metadata which we use in the distributed execution. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct FlightAppMetadata { - #[prost(uint64, tag = "1")] - pub partition: u64, - // Unix timestamp in nanoseconds at which this message was created. - #[prost(uint64, tag = "2")] - pub created_timestamp_unix_nanos: u64, - // content should always be Some, but it is optional due to protobuf rules. - #[prost(oneof = "AppMetadata", tags = "10")] - pub content: Option, -} - -impl FlightAppMetadata { - pub fn new(partition: u64) -> Self { - Self { - partition, - created_timestamp_unix_nanos: current_unix_timestamp_nanos(), - content: None, - } - } - - pub fn set_content(&mut self, content: AppMetadata) { - self.content = Some(content); - } -} - -fn current_unix_timestamp_nanos() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| duration.as_nanos() as u64) - .unwrap_or(0) -} - -#[derive(Clone, PartialEq, ::prost::Oneof)] -pub enum AppMetadata { - #[prost(message, tag = "10")] - MetricsCollection(MetricsCollection), - // Note: For every additional enum variant, ensure to add tags to [FlightAppMetadata]. ex. `#[prost(oneof = "AppMetadata", tags = "1,2,3")]` etc. - // If you don't the proto will compile but you may encounter errors during serialization/deserialization. -} diff --git a/src/protobuf/mod.rs b/src/protobuf/mod.rs index 9c908a97..4518dbb8 100644 --- a/src/protobuf/mod.rs +++ b/src/protobuf/mod.rs @@ -1,9 +1,7 @@ -mod app_metadata; mod distributed_codec; mod errors; mod user_codec; -pub(crate) use app_metadata::{AppMetadata, FlightAppMetadata, MetricsCollection, TaskMetrics}; pub(crate) use distributed_codec::DistributedCodec; pub(crate) use errors::{ datafusion_error_to_tonic_status, map_flight_to_datafusion_error, diff --git a/src/test_utils/metrics.rs b/src/test_utils/metrics.rs index a768379d..e744f1f1 100644 --- a/src/test_utils/metrics.rs +++ b/src/test_utils/metrics.rs @@ -1,35 +1,36 @@ -use crate::metrics::proto::{ElapsedCompute, EndTimestamp, OutputRows, StartTimestamp}; -use crate::metrics::proto::{MetricProto, MetricValueProto, MetricsSetProto}; +use crate::worker::generated::worker as pb; /// creates a "distinct" set of metrics from the provided seed -pub fn make_test_metrics_set_proto_from_seed(seed: u64, num_metrics: usize) -> MetricsSetProto { +pub fn make_test_metrics_set_proto_from_seed(seed: u64, num_metrics: usize) -> pb::MetricsSet { const TEST_TIMESTAMP: i64 = 1758200400000000000; // 2025-09-18 13:00:00 UTC - let mut result = MetricsSetProto { metrics: vec![] }; + let mut result = pb::MetricsSet { metrics: vec![] }; for i in 0..num_metrics { let value = seed + i as u64; - result.push(match i % 4 { - 0 => MetricProto { - metric: Some(MetricValueProto::OutputRows(OutputRows { value })), + result.metrics.push(match i % 4 { + 0 => pb::Metric { + value: Some(pb::metric::Value::OutputRows(pb::OutputRows { value })), labels: vec![], partition: None, }, - 1 => MetricProto { - metric: Some(MetricValueProto::ElapsedCompute(ElapsedCompute { value })), + 1 => pb::Metric { + value: Some(pb::metric::Value::ElapsedCompute(pb::ElapsedCompute { + value, + })), labels: vec![], partition: None, }, - 2 => MetricProto { - metric: Some(MetricValueProto::StartTimestamp(StartTimestamp { + 2 => pb::Metric { + value: Some(pb::metric::Value::StartTimestamp(pb::StartTimestamp { value: Some(TEST_TIMESTAMP + (value as i64 * 1_000_000_000)), })), labels: vec![], partition: None, }, - 3 => MetricProto { - metric: Some(MetricValueProto::EndTimestamp(EndTimestamp { + 3 => pb::Metric { + value: Some(pb::metric::Value::EndTimestamp(pb::EndTimestamp { value: Some(TEST_TIMESTAMP + (value as i64 * 1_000_000_000)), })), labels: vec![], diff --git a/src/worker/execute_task.rs b/src/worker/execute_task.rs index b014f75f..a8027d76 100644 --- a/src/worker/execute_task.rs +++ b/src/worker/execute_task.rs @@ -1,9 +1,9 @@ use crate::common::{map_last_stream, on_drop_stream, task_ctx_with_extension}; use crate::metrics::TaskMetricsCollector; use crate::metrics::proto::df_metrics_set_to_proto; -use crate::protobuf::{ - AppMetadata, FlightAppMetadata, MetricsCollection, TaskMetrics, - datafusion_error_to_tonic_status, +use crate::protobuf::datafusion_error_to_tonic_status; +use crate::worker::generated::worker::{ + FlightAppMetadata, MetricsCollection, TaskMetrics, flight_app_metadata, }; use crate::worker::worker_service::Worker; use crate::{DistributedConfig, DistributedTaskContext}; @@ -25,7 +25,7 @@ use futures::TryStreamExt; use prost::Message; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; -use std::time::Duration; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tonic::{Request, Response, Status}; /// How many record batches to buffer from the plan execution. @@ -108,7 +108,14 @@ impl Worker { // partition. This stream will be merged with several others from other partitions, // so marking it with the original partition allows it to be deconstructed into // the original per-partition streams in later steps. - let mut flight_data = FlightAppMetadata::new(partition); + let mut flight_data = FlightAppMetadata { + partition, + created_timestamp_unix_nanos: SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_nanos() as u64) + .unwrap_or(0), + content: None, + }; if last_msg_in_stream { // If it's the last message from the last partition, clean up the entry from @@ -122,7 +129,7 @@ impl Worker { if send_metrics { // Last message of the last partition. This is the moment to send // the metrics back. - flight_data.set_content(collect_and_create_metrics_flight_data( + flight_data.content = Some(collect_and_create_metrics_flight_data( key.clone(), plan.clone(), )?); @@ -216,7 +223,7 @@ fn missing(field: &'static str) -> impl FnOnce() -> Status { fn collect_and_create_metrics_flight_data( stage_key: StageKey, plan: Arc, -) -> Result { +) -> Result { // Get the metrics for the task executed on this worker + child tasks. let mut result = TaskMetricsCollector::new() .collect(plan) @@ -245,9 +252,11 @@ fn collect_and_create_metrics_flight_data( }); } - Ok(AppMetadata::MetricsCollection(MetricsCollection { - tasks: task_metrics_set, - })) + Ok(flight_app_metadata::Content::MetricsCollection( + MetricsCollection { + tasks: task_metrics_set, + }, + )) } /// Garbage collects values sub-arrays. diff --git a/src/worker/generated/worker.rs b/src/worker/generated/worker.rs index 57a78615..f1399e21 100644 --- a/src/worker/generated/worker.rs +++ b/src/worker/generated/worker.rs @@ -30,6 +30,252 @@ pub struct StageKey { #[prost(uint64, tag = "3")] pub task_number: u64, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FlightAppMetadata { + #[prost(uint64, tag = "1")] + pub partition: u64, + #[prost(uint64, tag = "2")] + pub created_timestamp_unix_nanos: u64, + #[prost(oneof = "flight_app_metadata::Content", tags = "10")] + pub content: ::core::option::Option, +} +/// Nested message and enum types in `FlightAppMetadata`. +pub mod flight_app_metadata { + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Content { + #[prost(message, tag = "10")] + MetricsCollection(super::MetricsCollection), + } +} +/// A collection of metrics for a set of tasks in an ExecutionPlan. +/// Each entry should have a distinct StageKey. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct MetricsCollection { + #[prost(message, repeated, tag = "1")] + pub tasks: ::prost::alloc::vec::Vec, +} +/// Metrics for a single task. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TaskMetrics { + #[prost(message, optional, tag = "1")] + pub stage_key: ::core::option::Option, + /// metrics\[i\] is the set of metrics for plan node i where plan nodes are + /// in pre-order traversal order. + #[prost(message, repeated, tag = "2")] + pub metrics: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct Label { + #[prost(string, tag = "1")] + pub name: ::prost::alloc::string::String, + #[prost(string, tag = "2")] + pub value: ::prost::alloc::string::String, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Metric { + #[prost(message, repeated, tag = "1")] + pub labels: ::prost::alloc::vec::Vec