diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 1524ff455ec75..896ab6e3b19df 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -213,6 +213,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( conf.get(PYTHON_DAEMON_KILL_WORKER_ON_FLUSH_FAILURE) protected val hideTraceback: Boolean = false protected val simplifiedTraceback: Boolean = false + protected val tracebackWithLocals: Boolean = false protected def runnerConf: Map[String, String] = Map.empty protected def evalConf: Map[String, String] = Map.empty @@ -302,6 +303,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( if (simplifiedTraceback) { envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1") } + if (tracebackWithLocals) { + envVars.put("SPARK_TRACEBACK_WITH_LOCALS", "1") + } // SPARK-30299 this could be wrong with standalone mode when executor // cores might not be correct because it defaults to all cores on the box. val execCores = execCoresProp.map(_.toInt).getOrElse(conf.get(EXECUTOR_CORES)) diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index ce7a42469e18b..5701fb6dcda97 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -1232,6 +1232,20 @@ def test_udf(a): with self.assertRaisesRegex(PythonException, "StopIteration"): self.spark.range(10).select(test_udf(col("id"))).show() + def test_udf_traceback_with_locals(self): + with self.sql_conf({"spark.sql.execution.pyspark.udf.tracebackWithLocals.enabled": True}): + + @udf("int") + def test_udf(a): + local_marker = a + 1 + if local_marker: + raise ValueError("boom") + return local_marker + + # The captured locals should include the local variable and its value. 
+ with self.assertRaisesRegex(PythonException, "local_marker = 1"): + self.spark.range(1).select(test_udf(col("id"))).collect() + def test_python_udf_segfault(self): with self.sql_conf({"spark.sql.execution.pyspark.udf.faulthandler.enabled": True}): with self.assertRaisesRegex(Exception, "Segmentation fault"): diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py index 9fd0135eff4eb..55412b3914fdd 100644 --- a/python/pyspark/tests/test_util.py +++ b/python/pyspark/tests/test_util.py @@ -176,7 +176,9 @@ def run_handle_worker_exception(self, hide_traceback=None): from pyspark.util import handle_worker_exception try: - raise ValueError("test_message") + local_marker = "marker_value_42" + if local_marker: + raise ValueError("test_message") except Exception as e: with io.BytesIO() as stream: handle_worker_exception(e, stream, hide_traceback) @@ -206,6 +208,22 @@ def test_hide_traceback(self): self.assertIn(self.exception_bytes, result) self.assertNotIn(self.traceback_bytes, result) + @patch.dict(os.environ, {"SPARK_TRACEBACK_WITH_LOCALS": "1"}) + def test_env_traceback_with_locals(self): + result = self.run_handle_worker_exception() + self.assertIn(self.exception_bytes, result) + self.assertIn(self.traceback_bytes, result) + # The local variable's value should be captured in the traceback. + self.assertIn(b"marker_value_42", result) + + @patch.dict(os.environ, {"SPARK_TRACEBACK_WITH_LOCALS": ""}) + def test_env_no_traceback_with_locals(self): + result = self.run_handle_worker_exception() + self.assertIn(self.exception_bytes, result) + self.assertIn(self.traceback_bytes, result) + # With the environment variable set to an empty string, locals must not be captured. 
+ self.assertNotIn(b"marker_value_42", result) + if __name__ == "__main__": from pyspark.testing import main diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 4d3c64e6a1f51..e66f6283ead18 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -527,12 +527,19 @@ def handle_worker_exception( def format_exception() -> str: if hide_traceback: return "".join(traceback.format_exception_only(type(e), e)) + tb = sys.exc_info()[-1] if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False): - tb = try_simplify_traceback(sys.exc_info()[-1]) # type: ignore[arg-type] - if tb is not None: + simplified_tb = try_simplify_traceback(tb) # type: ignore[arg-type] + if simplified_tb is not None: + tb = simplified_tb e.__cause__ = None - return "".join(traceback.format_exception(type(e), e, tb)) - return traceback.format_exc() + # We only set SPARK_TRACEBACK_WITH_LOCALS=1 for now. Any non-empty value enables + # capturing locals; an unset or empty variable leaves it disabled. + capture_locals = bool(os.environ.get("SPARK_TRACEBACK_WITH_LOCALS", False)) + te = traceback.TracebackException( + type(e), e, tb, compact=True, capture_locals=capture_locals + ) + return "".join(te.format()) try: exc_info = format_exception() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a21653a011b34..a1a91bbb22a19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4768,6 +4768,17 @@ object SQLConf { // show full stacktrace in tests but hide in production by default. .createWithDefault(!Utils.isTesting) + val PYSPARK_TRACEBACK_WITH_LOCALS = + buildConf("spark.sql.execution.pyspark.udf.tracebackWithLocals.enabled") + .doc( + "When true, include the local variables in the traceback from Python UDFs. 
" + + "Note that this config will print the value of every local variable in the call stack, " + + "including sensitive data like passwords or API keys. Please use this config with caution.") + .version("4.3.0") + .withBindingPolicy(ConfigBindingPolicy.SESSION) + .booleanConf + .createWithDefault(false) + val PYSPARK_ARROW_VALIDATE_SCHEMA = buildConf("spark.sql.execution.arrow.pyspark.validateSchema.enabled") .doc( @@ -8387,6 +8398,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def pysparkSimplifiedTraceback: Boolean = getConf(PYSPARK_SIMPLIFIED_TRACEBACK) + def pysparkTracebackWithLocals: Boolean = getConf(PYSPARK_TRACEBACK_WITH_LOCALS) + def pysparkArrowValidateSchema: Boolean = getConf(PYSPARK_ARROW_VALIDATE_SCHEMA) def pandasGroupedMapAssignColumnsByName: Boolean = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index df9b0e241b748..75b8465e2607a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -67,6 +67,7 @@ abstract class BaseArrowPythonRunner[IN, OUT <: AnyRef]( override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + override val tracebackWithLocals: Boolean = SQLConf.get.pysparkTracebackWithLocals override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize require( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala index a3c9f2f459266..59e78ae81e3bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala @@ -93,6 +93,7 @@ class ArrowPythonUDTFRunner( override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + override val tracebackWithLocals: Boolean = SQLConf.get.pysparkTracebackWithLocals override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize require( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala index df108187c9f03..940cab4a50696 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala @@ -80,6 +80,7 @@ class CoGroupedArrowPythonRunner( override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + override val tracebackWithLocals: Boolean = SQLConf.get.pysparkTracebackWithLocals private val maxRecordsPerBatch: Int = { val v = SQLConf.get.arrowMaxRecordsPerBatch diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala index 5311f34f41c07..80887dd469d69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala @@ -63,6 +63,7 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) extends Logging { val killWorkerOnFlushFailure: Boolean = SQLConf.get.pythonUDFDaemonKillWorkerOnFlushFailure val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback val simplifiedTraceback: Boolean = 
SQLConf.get.pysparkSimplifiedTraceback + val tracebackWithLocals: Boolean = SQLConf.get.pysparkTracebackWithLocals val workerMemoryMb = SQLConf.get.pythonPlannerExecMemory val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index e59b942ab1b41..b2eaec1674b70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -60,6 +60,7 @@ abstract class BasePythonUDFRunner( override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + override val tracebackWithLocals: Boolean = SQLConf.get.pysparkTracebackWithLocals override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled override val idleTimeoutSeconds: Long = SQLConf.get.pythonUDFWorkerIdleTimeoutSeconds diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala index 786c5fc408dcd..70d23ff89c355 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala @@ -90,6 +90,7 @@ class ApplyInPandasWithStatePythonRunner( override val hideTraceback: Boolean = sqlConf.pysparkHideTraceback override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback + override val tracebackWithLocals: Boolean = sqlConf.pysparkTracebackWithLocals override protected val largeVarTypes: Boolean = 
sqlConf.arrowUseLargeVarTypes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala index cc7745210a4d3..ebcd49931da4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala @@ -111,6 +111,7 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType) override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + override val tracebackWithLocals: Boolean = SQLConf.get.pysparkTracebackWithLocals } }