diff --git a/.github/trigger_files/beam_PostCommit_Python_Versions.json b/.github/trigger_files/beam_PostCommit_Python_Versions.json index 541dc4ea8e87..8ed972c9f579 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Versions.json +++ b/.github/trigger_files/beam_PostCommit_Python_Versions.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "revision": 2 + "revision": 3 } diff --git a/runners/prism/java/build.gradle b/runners/prism/java/build.gradle index c89974cb6ea5..9357515f36c2 100644 --- a/runners/prism/java/build.gradle +++ b/runners/prism/java/build.gradle @@ -185,6 +185,7 @@ def sickbayTests = [ def createPrismValidatesRunnerTask = { name, environmentType -> Task vrTask = tasks.create(name: name, type: Test, group: "Verification") { description "PrismRunner Java $environmentType ValidatesRunner suite" + outputs.upToDateWhen { false } classpath = configurations.validatesRunner var prismBuildTask = dependsOn(':runners:prism:build') diff --git a/sdks/go/pkg/beam/core/runtime/harness/datamgr.go b/sdks/go/pkg/beam/core/runtime/harness/datamgr.go index ab93dcba51b1..e65de99f6197 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/datamgr.go +++ b/sdks/go/pkg/beam/core/runtime/harness/datamgr.go @@ -136,9 +136,13 @@ func (m *DataChannelManager) Open(ctx context.Context, port exec.Port) (*DataCha default: log.Warnf(ctx, "forcing DataChannel[%v] reconnection on port %v due to %v", id, port, err) } - m.mu.Lock() - delete(m.ports, port.URL) - m.mu.Unlock() + go func() { + m.mu.Lock() + defer m.mu.Unlock() + if curr, ok := m.ports[port.URL]; ok && curr == ch { + delete(m.ports, port.URL) + } + }() } m.ports[port.URL] = ch return ch, nil diff --git a/sdks/go/pkg/beam/core/runtime/harness/statemgr.go b/sdks/go/pkg/beam/core/runtime/harness/statemgr.go index 269ded372998..f1671d56b748 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/statemgr.go +++ b/sdks/go/pkg/beam/core/runtime/harness/statemgr.go @@ -619,9 +619,13 @@ func (m *StateChannelManager) Open(ctx context.Context, port exec.Port) (*StateC default: log.Warnf(ctx, "forcing StateChannel[%v] reconnection on port %v due to %v", id, port, err) } - m.mu.Lock() - delete(m.ports, port.URL) - m.mu.Unlock() + go func() { + m.mu.Lock() + defer m.mu.Unlock() + if curr, ok := m.ports[port.URL]; ok && curr == ch { + delete(m.ports, port.URL) + } + }() } m.ports[port.URL] = ch return ch, nil diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go index f720be20e375..862f596feea4 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go @@ -240,8 +240,10 @@ type ElementManager struct { } func (em *ElementManager) addPending(v int) { + prev := em.livePending.Load() em.livePending.Add(int64(v)) em.pendingElements.Add(v) + slog.Info("em.addPending", "delta", v, "prev", prev, "current", em.livePending.Load()) } // LinkID represents a fully qualified input or output. @@ -530,7 +532,7 @@ func (em *ElementManager) DumpStages() string { stageState = append(stageState, fmt.Sprintf("TestStreamHandler: completed %v, curIndex %v of %v events: %+v, processingTime %v, %v, ptEvents %v \n", em.testStreamHandler.completed, em.testStreamHandler.nextEventIndex, len(em.testStreamHandler.events), em.testStreamHandler.events, em.testStreamHandler.processingTime, mtime.FromTime(em.testStreamHandler.processingTime), em.processTimeEvents)) } else { - stageState = append(stageState, fmt.Sprintf("ElementManager Now: %v processingTimeEvents: %v injectedBundles: %v\n", em.processingTimeNow(), em.processTimeEvents.events, em.injectedBundles)) + stageState = append(stageState, fmt.Sprintf("ElementManager Now: %v processingTimeEvents: %v injectedBundles: %v livePending: %v\n", em.processingTimeNow(), em.processTimeEvents.events, em.injectedBundles, em.livePending.Load())) } sort.Strings(ids) for _, id := range ids { @@ -1091,18 +1093,25 @@ func (em *ElementManager) FailBundle(rb RunBundle) { em.markChangedAndClearBundle(rb.StageID, rb.BundleID, nil) } -// ReturnResiduals is called after a successful split, so the remaining work -// can be re-assigned to a new bundle. func (em *ElementManager) ReturnResiduals(rb RunBundle, firstRsIndex int, inputInfo PColInfo, residuals Residuals) { stage := em.stages[rb.StageID] + slog.Info("ElementManager.ReturnResiduals start", "bundle", rb, "firstRsIndex", firstRsIndex) + + stage.mu.Lock() + completed := stage.inprogress[rb.BundleID] + originalRemainingCount := len(completed.es) - firstRsIndex + stage.mu.Unlock() + stage.splitBundle(rb, firstRsIndex, em) unprocessedElements := reElementResiduals(residuals.Data, inputInfo, rb) - if len(unprocessedElements) > 0 { - slog.Debug("ReturnResiduals: unprocessed elements", "bundle", rb, "count", len(unprocessedElements)) - count := stage.AddPending(em, unprocessedElements) + if len(unprocessedElements) > originalRemainingCount { + newResiduals := unprocessedElements[originalRemainingCount:] + slog.Info("ReturnResiduals: new residuals added back", "bundle", rb, "count", len(newResiduals)) + count := stage.AddPending(em, newResiduals) em.addPending(count) } + slog.Info("ElementManager.ReturnResiduals end", "bundle", rb, "unprocessedCount", len(unprocessedElements), "livePending", em.livePending.Load()) em.markStagesAsChanged(singleSet(rb.StageID)) } @@ -2187,7 +2196,7 @@ func (ss *stageState) splitBundle(rb RunBundle, firstResidual int, em *ElementMa defer ss.mu.Unlock() es := ss.inprogress[rb.BundleID] - slog.Debug("split elements", "bundle", rb, "elem count", len(es.es), "res", firstResidual) + slog.Info("splitBundle start", "bundle", rb, "elem count", len(es.es), "firstResidual", firstResidual, "livePending", em.livePending.Load()) prim := es.es[:firstResidual] res := es.es[firstResidual:] @@ -2207,6 +2216,7 @@ func (ss *stageState) splitBundle(rb RunBundle, firstResidual int, em *ElementMa // we don't need to increment pending count in em, since it is already pending ss.kind.addPending(ss, em, res) ss.inprogress[rb.BundleID] = es + slog.Info("splitBundle completed", "bundle", rb, "primaryCount", len(prim), "residualCount", len(res), "livePending", em.livePending.Load()) } // minimumPendingTimestamp returns the minimum pending timestamp from all pending elements, diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go index 0d7da5ea163f..f7376915de88 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go @@ -684,3 +684,73 @@ func TestElementManager_OnWindowExpiration(t *testing.T) { validateSideBundles(t, singleSet("\u0004key5")) // still exist.. }) } + +func TestElementManager_ReturnResidualsPendingCount(t *testing.T) { + tests := []struct { + name string + firstRsIndex int + wantFinalPending int64 + }{ + { + name: "ChannelSplit", + firstRsIndex: 0, + wantFinalPending: 1, + }, + { + name: "SDFCheckpoint", + firstRsIndex: 1, + wantFinalPending: 2, // Incremented by 1 because the active portion (index 0) is still in progress and will be completed/decremented in PersistBundle. + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + em := NewElementManager(Config{}) + em.AddStage("impulse", nil, []string{"input"}, nil) + em.AddStage("dofn", []string{"input"}, nil, nil) + em.Impulse("impulse") + + stage := em.stages["dofn"] + info := PColInfo{ + GlobalID: "generic_info", + WDec: exec.MakeWindowDecoder(coder.NewGlobalWindow()), + WEnc: exec.MakeWindowEncoder(coder.NewGlobalWindow()), + EDec: func(r io.Reader) []byte { + b, _ := io.ReadAll(r) + return b + }, + } + + // Initial state should have 1 pending element from impulse + if got, want := em.livePending.Load(), int64(1); got != want { + t.Fatalf("initial livePending = %v, want %v", got, want) + } + + // Start a bundle + bundID, ok, _, _ := stage.startEventTimeBundle(mtime.MaxTimestamp, func() string { return "inst0" }) + if !ok { + t.Fatalf("failed to start bundle") + } + + // Waitgroup/livePending shouldn't change on starting a bundle (it's still pending) + if got, want := em.livePending.Load(), int64(1); got != want { + t.Fatalf("livePending after startEventTimeBundle = %v, want %v", got, want) + } + + // Prepare residuals + residBytes := []byte{127, 223, 59, 100, 90, 28, 172, 9, 0, 0, 0, 1, 15, 3, 65, 66, 67} // windowed value header + ABC + residuals := Residuals{ + Data: []Residual{{Element: residBytes}}, + } + + rb := RunBundle{StageID: "dofn", BundleID: bundID} + + // Return residuals (Simulates splitting) + em.ReturnResiduals(rb, test.firstRsIndex, info, residuals) + + if got, want := em.livePending.Load(), test.wantFinalPending; got != want { + t.Errorf("livePending after ReturnResiduals = %v, want %v", got, want) + } + }) + } +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/environments.go b/sdks/go/pkg/beam/runners/prism/internal/environments.go index 1f852e0862f1..59dea7c71cd2 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/environments.go +++ b/sdks/go/pkg/beam/runners/prism/internal/environments.go @@ -160,9 +160,12 @@ func externalEnvironment(ctx context.Context, ep *pipepb.ExternalPayload, wk *wo // Previous context cancelled so we need a new one // for this request. - pool.StopWorker(bgContext, &fnpb.StopWorkerRequest{ + _, err = pool.StopWorker(bgContext, &fnpb.StopWorkerRequest{ WorkerId: wk.ID, }) + if err != nil { + slog.Warn("StopWorker failed", "worker", wk, "error", err) + } wk.Stop() } diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go index f6e148f9f3f6..7bc63646fdb7 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/execute.go +++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go @@ -391,6 +391,16 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic // Log a heartbeat every 60 seconds case <-ticker.C: j.Logger.Info("pipeline is running", slog.String("job", j.String())) + j.Logger.Info("pipeline stages state", slog.String("stages", em.DumpStages())) + for envID, wk := range wks { + if wk != nil && wk.Connected() && !wk.Stopped() { + j.Logger.Info("worker status", + slog.String("workerID", wk.ID), + slog.String("envID", envID), + slog.Duration("uptime", wk.Uptime()), + slog.Any("active_bundles", wk.ActiveBundles())) + } + } } } } diff --git a/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go index 3ac0d98850df..56e57bed336c 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go +++ b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go @@ -210,7 +210,7 @@ func (h *runner) handleReshuffle(tid string, t *pipepb.PTransform, comps *pipepb } // And all the sub transforms. - toRemove = append(toRemove, t.GetSubtransforms()...) + toRemove = append(toRemove, removeSubTransforms(comps, t.GetSubtransforms())...) // Return the new components which is the transforms consumer return prepareResult{ diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go index b46c9c2fd5b1..8ebeb76f9e54 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/stage.go +++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go @@ -184,6 +184,8 @@ func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, c panic(err) } + bundleStart := time.Now() + // Progress + split loop. previousIndex := int64(-2) previousTotalCount := int64(-2) // Total count of all pcollection elements. @@ -232,7 +234,11 @@ progress: md := wk.MonitoringMetadata(ctx, unknownIDs) j.AddMetricShortIDs(md) } - slog.Debug("progress report", "bundle", rb, "index", index, "prevIndex", previousIndex) + runningFor := time.Since(bundleStart) + slog.Debug("progress report", "bundle", rb, "runningFor", runningFor, "index", index, "prevIndex", previousIndex) + if runningFor > 5*time.Minute { + slog.Warn("Bundle has been running for a long time", "bundle", rb, "runningFor", runningFor, "worker", wk.ID) + } // Check if there has been any measurable progress by the input, or all output pcollections since last report. slow := previousIndex == index["index"] && previousTotalCount == index["totalCount"] diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go index 33c8c3a7de5f..6a0ede344f5e 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go @@ -70,6 +70,7 @@ type W struct { // These are the ID sources inst uint64 connected, stopped atomic.Bool + StartTime time.Time StoppedChan chan struct{} // Channel to Broadcast stopped state. InstReqs chan *fnpb.InstructionRequest @@ -292,11 +293,37 @@ func (wk *W) Stopped() bool { return wk.stopped.Load() } +// Uptime returns how long the worker has been connected. +func (wk *W) Uptime() time.Duration { + wk.mu.Lock() + defer wk.mu.Unlock() + if wk.StartTime.IsZero() { + return 0 + } + return time.Since(wk.StartTime) +} + +// ActiveBundles returns a list of active bundles currently processing on this worker. +func (wk *W) ActiveBundles() []string { + wk.mu.Lock() + defer wk.mu.Unlock() + var bundles []string + for id, responder := range wk.activeInstructions { + if b, ok := responder.(*B); ok { + bundles = append(bundles, fmt.Sprintf("%s (%s)", id, b.PBDID)) + } + } + return bundles +} + // Control relays instructions to SDKs and back again, coordinated via unique instructionIDs. // // Requests come from the runner, and are sent to the client in the SDK. func (wk *W) Control(ctrl fnpb.BeamFnControl_ControlServer) error { wk.connected.Store(true) + wk.mu.Lock() + wk.StartTime = time.Now() + wk.mu.Unlock() done := make(chan error, 1) go func() { for { diff --git a/sdks/python/apache_beam/runners/portability/expansion_service_main.py b/sdks/python/apache_beam/runners/portability/expansion_service_main.py index f2d03e0e898c..269d02b3efbd 100644 --- a/sdks/python/apache_beam/runners/portability/expansion_service_main.py +++ b/sdks/python/apache_beam/runners/portability/expansion_service_main.py @@ -55,9 +55,7 @@ def main(argv): with fully_qualified_named_transform.FullyQualifiedNamedTransform.with_filter( known_args.fully_qualified_name_glob): - # Bind to localhost instead of 0.0.0.0 to ensure compatibility with loopback - # connections on dual-stack (IPv4/IPv6) systems. - address = 'localhost:{}'.format(known_args.port) + address = '0.0.0.0:{}'.format(known_args.port) server = grpc.server(thread_pool_executor.shared_unbounded_instance()) if known_args.serve_loopback_worker: beam_fn_api_pb2_grpc.add_BeamFnExternalWorkerPoolServicer_to_server( @@ -73,15 +71,9 @@ def main(argv): artifact_service.ArtifactRetrievalService( artifact_service.BeamFilesystemHandler(None).file_reader), server) - # Ensure gRPC server successfully binds. If this fails (e.g., due to port collision), - # add_insecure_port returns 0. We raise an error to crash the subprocess immediately, - # allowing the parent process to detect it and fail fast rather than hanging. - bound_port = server.add_insecure_port(address) - if not bound_port: - raise RuntimeError( - "Failed to bind expansion service to {}".format(address)) + server.add_insecure_port(address) server.start() - _LOGGER.info('Listening for expansion requests at %d', bound_port) + _LOGGER.info('Listening for expansion requests at %d', known_args.port) def cleanup(unused_signum, unused_frame): _LOGGER.info('Shutting down expansion service.') diff --git a/sdks/python/apache_beam/runners/portability/stager.py b/sdks/python/apache_beam/runners/portability/stager.py index e862fde4efef..8fab55410fd5 100644 --- a/sdks/python/apache_beam/runners/portability/stager.py +++ b/sdks/python/apache_beam/runners/portability/stager.py @@ -732,9 +732,11 @@ def _get_platform_for_default_sdk_container(): # addressed, download wheel based on glibc version in Beam's Python # Base image pip_version = distribution('pip').version - if version.parse(pip_version) >= version.parse('19.3'): - # pip can only recognize manylinux2014_x86_64 wheels - # from version 19.3. + # See more information about manylinux at + # https://github.com/pypa/manylinux + if version.parse(pip_version) >= version.parse('20.3'): + return 'manylinux_2_28_x86_64' + elif version.parse(pip_version) >= version.parse('19.3'): return 'manylinux2014_x86_64' else: return 'manylinux2010_x86_64' @@ -795,7 +797,7 @@ def _populate_requirements_cache( platform_tag ]) _LOGGER.info('Executing command: %s', cmd_args) - processes.check_output(cmd_args, stderr=processes.STDOUT) + processes.check_call(cmd_args) # Get list of downloaded packages and copy them to the cache downloaded_packages = set() diff --git a/sdks/python/apache_beam/runners/portability/stager_test.py b/sdks/python/apache_beam/runners/portability/stager_test.py index 3d625fb287ae..e04f4ad716ee 100644 --- a/sdks/python/apache_beam/runners/portability/stager_test.py +++ b/sdks/python/apache_beam/runners/portability/stager_test.py @@ -985,13 +985,13 @@ def test_populate_requirements_cache_uses_find_links(self): captured_cmd_args = [] - def mock_check_output(cmd_args, **kwargs): + def mock_check_call(cmd_args, **kwargs): captured_cmd_args.extend(cmd_args) - return b'' + return 0 with mock.patch( - 'apache_beam.runners.portability.stager.processes.check_output', - side_effect=mock_check_output): + 'apache_beam.runners.portability.stager.processes.check_call', + side_effect=mock_check_call): stager.Stager._populate_requirements_cache( requirements_file, requirements_cache_dir) diff --git a/sdks/python/apache_beam/utils/subprocess_server.py b/sdks/python/apache_beam/utils/subprocess_server.py index b22e6badb5e7..00c71125e5d8 100644 --- a/sdks/python/apache_beam/utils/subprocess_server.py +++ b/sdks/python/apache_beam/utils/subprocess_server.py @@ -72,7 +72,7 @@ class _SharedCache: def __init__(self, constructor, destructor): self._constructor = constructor self._destructor = destructor - self._live_owners = set() + self._live_owners = {} self._cache = {} self._lock = threading.Lock() self._counter = 0 @@ -82,10 +82,10 @@ def _next_id(self): self._counter += 1 return self._counter - def register(self): + def register(self, is_context=False): with self._lock: owner = self._next_id() - self._live_owners.add(owner) + self._live_owners[owner] = is_context return owner def purge(self, owner): @@ -97,7 +97,7 @@ def purge(self, owner): "shutdown, the subprocess was already cleaned up earlier.", owner) return - self._live_owners.remove(owner) + del self._live_owners[owner] for key, entry in list(self._cache.items()): if owner in entry.owners: entry.owners.remove(owner) @@ -108,16 +108,30 @@ def purge(self, owner): for value in to_delete: self._destructor(value) - def get(self, *key): + def get(self, *key, owner=None): if not self._live_owners: raise RuntimeError("At least one owner must be registered.") with self._lock: if key not in self._cache: self._cache[key] = _SharedCacheEntry(self._constructor(*key), set()) - for owner in self._live_owners: + if owner is not None: + if owner not in self._live_owners: + raise RuntimeError("The requesting owner must be registered.") self._cache[key].owners.add(owner) + for live_owner, is_context in self._live_owners.items(): + if is_context: + self._cache[key].owners.add(live_owner) + else: + for live_owner in self._live_owners: + self._cache[key].owners.add(live_owner) return self._cache[key].obj + def force_remove(self, *key): + with self._lock: + entry = self._cache.pop(key, None) + if entry is not None: + self._destructor(entry.obj) + class JavaHelper: @classmethod @@ -174,7 +188,7 @@ def cache_subprocesses(cls): These subprocesses may be shared with other contexts as well. """ try: - unique_id = cls._cache.register() + unique_id = cls._cache.register(is_context=True) yield finally: cls._cache.purge(unique_id) @@ -186,66 +200,59 @@ def __exit__(self, *unused_args): self.stop() def start(self): - max_attempts = 3 - for attempt in range(max_attempts): - try: - process, endpoint = self.start_process() - wait_secs = .1 - channel_options = [ - ("grpc.max_receive_message_length", -1), - ("grpc.max_send_message_length", -1), - # Default: 20000ms (20s), increased to 10 minutes for stability - ("grpc.keepalive_timeout_ms", 600_000), - # Default: 2, set to 0 to allow unlimited pings without data - ("grpc.http2.max_pings_without_data", 0), - # Default: False, set to True to allow keepalive pings when no calls - ("grpc.keepalive_permit_without_calls", True), - # Default: 2, set to 0 to allow unlimited ping strikes - ("grpc.http2.max_ping_strikes", 0), - # Default: 0 (disabled), enable socket reuse for better handling - ("grpc.so_reuseport", 1), - ] - self._grpc_channel = grpc.insecure_channel( - endpoint, options=channel_options) - channel_ready = grpc.channel_ready_future(self._grpc_channel) - while True: - if process is not None and process.poll() is not None: - _LOGGER.error("Started job service with %s", process.args) - raise RuntimeError( - 'Service failed to start up with error %s' % process.poll()) - try: - channel_ready.result(timeout=wait_secs) - break - except (grpc.FutureTimeoutError, grpc.RpcError): - wait_secs *= 1.2 - logging.log( - logging.WARNING if wait_secs > 1 else logging.DEBUG, - 'Waiting for grpc channel to be ready at %s.', - endpoint) - return self._stub_class(self._grpc_channel) - except Exception as e: - _LOGGER.warning( - "Error bringing up service on attempt %d: %s", - attempt + 1, - e, - exc_info=True) - self.stop() - if attempt == max_attempts - 1: - raise - time.sleep(1) + try: + process, endpoint = self.start_process() + wait_secs = .1 + channel_options = [ + ("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1), + # Default: 20000ms (20s), increased to 10 minutes for stability + ("grpc.keepalive_timeout_ms", 600_000), + # Default: 2, set to 0 to allow unlimited pings without data + ("grpc.http2.max_pings_without_data", 0), + # Default: False, set to True to allow keepalive pings when no calls + ("grpc.keepalive_permit_without_calls", True), + # Default: 2, set to 0 to allow unlimited ping strikes + ("grpc.http2.max_ping_strikes", 0), + # Default: 0 (disabled), enable socket reuse for better handling + ("grpc.so_reuseport", 1), + ] + self._grpc_channel = grpc.insecure_channel( + endpoint, options=channel_options) + channel_ready = grpc.channel_ready_future(self._grpc_channel) + while True: + if process is not None and process.poll() is not None: + _LOGGER.error("Failed to start job service with %s", process.args) + raise RuntimeError( + 'Service failed to start up with error %s' % process.poll()) + try: + channel_ready.result(timeout=wait_secs) + break + except (grpc.FutureTimeoutError, grpc.RpcError): + wait_secs *= 1.2 + logging.log( + logging.WARNING if wait_secs > 1 else logging.DEBUG, + 'Waiting for grpc channel to be ready at %s.', + endpoint) + return self._stub_class(self._grpc_channel) + except: # pylint: disable=bare-except + _LOGGER.exception("Error bringing up service") + self.stop_force() + raise def start_process(self): if self._owner_id is not None: self._cache.purge(self._owner_id) - self._owner_id = self._cache.register() - return self._cache.get(tuple(self._cmd), self._port, self._logger) + self._owner_id = self._cache.register(is_context=False) + return self._cache.get( + tuple(self._cmd), self._port, self._logger, owner=self._owner_id) def _really_start_process(cmd, port, logger): if not port: port, = pick_port(None) cmd = [arg.replace('{{PORT}}', str(port)) for arg in cmd] # pylint: disable=not-an-iterable endpoint = 'localhost:%s' % port - _LOGGER.info("Starting service with %s", str(cmd).replace("',", "'")) + _LOGGER.warning("Really starting service at %s with cmd: %s", endpoint, cmd) process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) @@ -255,7 +262,7 @@ def log_stdout(): while line: # The log obtained from stdout is bytes, decode it into string. # Remove newline via rstrip() to not print an empty line. - logger.info(line.decode(errors='backslashreplace').rstrip()) + _LOGGER.warning(line.decode(errors='backslashreplace').rstrip()) line = process.stdout.readline() t = threading.Thread(target=log_stdout) @@ -282,10 +289,26 @@ def stop_process(self): finally: self._grpc_channel = None + def stop_force(self): + try: + self._cache.force_remove(tuple(self._cmd), self._port, self._logger) + finally: + self._owner_id = None + if self._grpc_channel: + try: + self._grpc_channel.close() + except: # pylint: disable=bare-except + _LOGGER.error( + "Could not close the gRPC channel started with cmd %s", self._cmd) + finally: + self._grpc_channel = None + def _really_stop_process(process_and_endpoint): - process, _ = process_and_endpoint # pylint: disable=unpacking-non-sequence + process, endpoint = process_and_endpoint # pylint: disable=unpacking-non-sequence if not process: return + _LOGGER.warning( + "Really destroying service at %s with cmd: %s", endpoint, process.args) for _ in range(5): if process.poll() is not None: break diff --git a/sdks/python/apache_beam/utils/subprocess_server_test.py b/sdks/python/apache_beam/utils/subprocess_server_test.py index 073b8b3bcbe8..a008ae05c52d 100644 --- a/sdks/python/apache_beam/utils/subprocess_server_test.py +++ b/sdks/python/apache_beam/utils/subprocess_server_test.py @@ -402,16 +402,16 @@ def mock_unregister(cb): self.assertEqual(len(registered_callbacks), 1) def test_concurrent_purge_race_condition(self): - # Concurrent threads attempting to check memebership and call purge for the same owner. - # Here we explicitly define a synchronized set to mimic the behavior of _live_owners. - # This set will block two threads on __contains__, allowing us to test the race condition. + # Concurrent threads attempting to check membership and call purge for the same owner. + # Here we explicitly define a synchronized dict to mimic the behavior of _live_owners. + # This dict will block two threads on __contains__, allowing us to test the race condition. cache = subprocess_server._SharedCache(lambda x: "obj", lambda x: None) owner = cache.register() barrier = threading.Barrier(2) exceptions = [] - class SynchronizedSet(set): + class SynchronizedDict(dict): def __contains__(self, item): res = super().__contains__(item) try: @@ -421,7 +421,7 @@ def __contains__(self, item): pass return res - cache._live_owners = SynchronizedSet(cache._live_owners) + cache._live_owners = SynchronizedDict(cache._live_owners) def purge_worker(): try: @@ -464,6 +464,140 @@ def __init__(self): # without raising ValueError. server.stop_process() + def test_force_remove(self): + destructor_calls = [] + + def custom_destructor(obj): + destructor_calls.append(obj) + + cache = subprocess_server._SharedCache(self.with_prefix, custom_destructor) + + owner1 = cache.register() + owner2 = cache.register() + + # Get object 'a' under both active owners + a = cache.get('a') + self.assertEqual(a[0], 'a') + self.assertIn(('a', ), cache._cache) + + # force_remove on a non-existent key should be a safe no-op + cache.force_remove('non_existent') + + # Call force_remove, which should bypass the owners check and delete it immediately + cache.force_remove('a') + + # The cache entry should be gone + self.assertNotIn(('a', ), cache._cache) + + # Destructor must be called on 'a' + self.assertEqual(destructor_calls, [a]) + + # Retrieving 'a' again under the active owners should construct a new object + new_a = cache.get('a') + self.assertNotEqual(new_a, a) + self.assertEqual(new_a[0], 'a') + + # Clean up + cache.purge(owner1) + cache.purge(owner2) + + def test_subprocess_server_start_failed_no_leak(self): + destructor_calls = [] + + def custom_destructor(obj): + destructor_calls.append(obj) + + class DummyProcess: + def __init__(self): + self.args = ["dummy_cmd"] + + def poll(self): + return 1 # Simulate that process exited/failed + + dummy_process = DummyProcess() + cache = subprocess_server._SharedCache( + lambda *args: (dummy_process, "localhost:12345"), custom_destructor) + + # 1. Register an independent, unrelated owner in the cache first. + other_owner = cache.register() + + class CustomServer(subprocess_server.SubprocessServer): + _cache = cache + + def __init__(self): + super().__init__(lambda channel: None, ["dummy_cmd"], port=12345) + + server = CustomServer() + # Fetch the process using other_owner, creating the cache entry and registering other_owner on it. + cache.get(tuple(server._cmd), server._port, server._logger) + + cache_key = (tuple(server._cmd), server._port, server._logger) + self.assertIn(cache_key, cache._cache) + self.assertEqual(cache._cache[cache_key].owners, {other_owner}) + + # 2. Verify starting the server (which registers its own owner and retrieves from cache) raises RuntimeError + with self.assertRaises(RuntimeError): + server.start() + + # 3. Verify that the destructor was called on the process, meaning no leak (even though other_owner was still registered!) + self.assertEqual(destructor_calls, [(dummy_process, "localhost:12345")]) + + # 4. Verify that the server has cleaned up its owner_id + self.assertIsNone(server._owner_id) + + # 5. Verify the cache entry has been removed completely + self.assertNotIn(cache_key, cache._cache) + + # Clean up the other owner + cache.purge(other_owner) + + def test_non_context_owners_do_not_share_keys(self): + cache = subprocess_server._SharedCache(self.with_prefix, lambda x: None) + # owner1 is a non-context owner (e.g., prism) + owner1 = cache.register(is_context=False) + a = cache.get('a', owner=owner1) + + # owner2 is another non-context owner (e.g., short-lived expansion service) + owner2 = cache.register(is_context=False) + b = cache.get('b', owner=owner2) + + # Verify that owner1 does not own 'b' + self.assertNotIn(owner1, cache._cache[('b', )].owners) + + # Verify that owner2 does not own 'a' + self.assertNotIn(owner2, cache._cache[('a', )].owners) + + # Purging owner2 should immediately destroy/remove 'b' + cache.purge(owner2) + self.assertNotIn(('b', ), cache._cache) + + # 'a' is still alive because owner1 is still registered + self.assertIn(('a', ), cache._cache) + + # Purging owner1 should destroy/remove 'a' + cache.purge(owner1) + self.assertNotIn(('a', ), cache._cache) + + def test_context_owner_owns_all_keys(self): + cache = subprocess_server._SharedCache(self.with_prefix, lambda x: None) + # owner1 is a non-context owner (e.g., prism) + owner1 = cache.register(is_context=False) + + # owner2 is a context owner (e.g., cache_subprocesses) + owner2 = cache.register(is_context=True) + + # owner3 is another non-context owner (e.g., short-lived service) + owner3 = cache.register(is_context=False) + + # owner3 requests 'b' + b = cache.get('b', owner=owner3) + + # owner2 (context) should own 'b' + self.assertIn(owner2, cache._cache[('b', )].owners) + + # owner1 (non-context) should NOT own 'b' + self.assertNotIn(owner1, cache._cache[('b', )].owners) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py index bbb60b185c01..192af63a9871 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py @@ -287,7 +287,7 @@ def test_csv_to_json(self): output_shard = all_output[0] result = pd.read_json( output_shard, orient='records', - lines=True).sort_values('rank').reindex() + lines=True).sort_values('rank').reset_index(drop=True) pd.testing.assert_frame_equal(data, result) def test_circular_reference_validation(self): diff --git a/sdks/python/conftest.py b/sdks/python/conftest.py index 92d90769fb15..9dc010cb7e34 100644 --- a/sdks/python/conftest.py +++ b/sdks/python/conftest.py @@ -56,7 +56,7 @@ def configure_beam_rpc_timeouts(): timeout_env_vars = { 'GRPC_ARG_KEEPALIVE_TIME_MS': '30000', - 'GRPC_ARG_KEEPALIVE_TIMEOUT_MS': '10000', + 'GRPC_ARG_KEEPALIVE_TIMEOUT_MS': '60000', 'GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA': '0', 'GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS': '1', 'GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS': '300000',