Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
conf.get(PYTHON_DAEMON_KILL_WORKER_ON_FLUSH_FAILURE)
protected val hideTraceback: Boolean = false
protected val simplifiedTraceback: Boolean = false
protected val tracebackWithLocals: Boolean = false

protected def runnerConf: Map[String, String] = Map.empty
protected def evalConf: Map[String, String] = Map.empty
Expand Down Expand Up @@ -302,6 +303,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
if (simplifiedTraceback) {
envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
}
if (tracebackWithLocals) {
envVars.put("SPARK_TRACEBACK_WITH_LOCALS", "1")
}
// SPARK-30299 this could be wrong with standalone mode when executor
// cores might not be correct because it defaults to all cores on the box.
val execCores = execCoresProp.map(_.toInt).getOrElse(conf.get(EXECUTOR_CORES))
Expand Down
14 changes: 14 additions & 0 deletions python/pyspark/sql/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,20 @@ def test_udf(a):
with self.assertRaisesRegex(PythonException, "StopIteration"):
self.spark.range(10).select(test_udf(col("id"))).show()

def test_udf_traceback_with_locals(self):
with self.sql_conf({"spark.sql.execution.pyspark.udf.tracebackWithLocals.enabled": True}):

@udf("int")
def test_udf(a):
local_marker = a + 1
if local_marker:
raise ValueError("boom")
return local_marker

# The captured locals should include the local variable and its value.
with self.assertRaisesRegex(PythonException, "local_marker = 1"):
self.spark.range(1).select(test_udf(col("id"))).collect()

def test_python_udf_segfault(self):
with self.sql_conf({"spark.sql.execution.pyspark.udf.faulthandler.enabled": True}):
with self.assertRaisesRegex(Exception, "Segmentation fault"):
Expand Down
20 changes: 19 additions & 1 deletion python/pyspark/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ def run_handle_worker_exception(self, hide_traceback=None):
from pyspark.util import handle_worker_exception

try:
raise ValueError("test_message")
local_marker = "marker_value_42"
if local_marker:
raise ValueError("test_message")
except Exception as e:
with io.BytesIO() as stream:
handle_worker_exception(e, stream, hide_traceback)
Expand Down Expand Up @@ -206,6 +208,22 @@ def test_hide_traceback(self):
self.assertIn(self.exception_bytes, result)
self.assertNotIn(self.traceback_bytes, result)

@patch.dict(os.environ, {"SPARK_TRACEBACK_WITH_LOCALS": "1"})
def test_env_traceback_with_locals(self):
result = self.run_handle_worker_exception()
self.assertIn(self.exception_bytes, result)
self.assertIn(self.traceback_bytes, result)
# The local variable's value should be captured in the traceback.
self.assertIn(b"marker_value_42", result)

@patch.dict(os.environ, {"SPARK_TRACEBACK_WITH_LOCALS": ""})
def test_env_no_traceback_with_locals(self):
result = self.run_handle_worker_exception()
self.assertIn(self.exception_bytes, result)
self.assertIn(self.traceback_bytes, result)
# Without the environment variable, locals must not be captured.
self.assertNotIn(b"marker_value_42", result)


if __name__ == "__main__":
from pyspark.testing import main
Expand Down
15 changes: 11 additions & 4 deletions python/pyspark/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,12 +527,19 @@ def handle_worker_exception(
def format_exception() -> str:
if hide_traceback:
return "".join(traceback.format_exception_only(type(e), e))
tb = sys.exc_info()[-1]
if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
tb = try_simplify_traceback(sys.exc_info()[-1]) # type: ignore[arg-type]
if tb is not None:
simplified_tb = try_simplify_traceback(tb) # type: ignore[arg-type]
if simplified_tb is not None:
tb = simplified_tb
e.__cause__ = None
return "".join(traceback.format_exception(type(e), e, tb))
return traceback.format_exc()
# We only set SPARK_TRACEBACK_WITH_LOCALS=1 for now. This equivalent to a
# check for the existence of the environment variable.
capture_locals = bool(os.environ.get("SPARK_TRACEBACK_WITH_LOCALS", False))
te = traceback.TracebackException(
type(e), e, tb, compact=True, capture_locals=capture_locals
)
return "".join(te.format())

try:
exc_info = format_exception()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4768,6 +4768,17 @@ object SQLConf {
// show full stacktrace in tests but hide in production by default.
.createWithDefault(!Utils.isTesting)

val PYSPARK_TRACEBACK_WITH_LOCALS =
buildConf("spark.sql.execution.pyspark.udf.tracebackWithLocals.enabled")
.doc(
"When true, include the local variables in the traceback from Python UDFs. " +
"Note that this config will print the value of every local variable in the call stack, " +
"including sensitive data like passwords or API keys. Please use this config with caution.")
.version("4.3.0")
.withBindingPolicy(ConfigBindingPolicy.SESSION)
.booleanConf
.createWithDefault(false)

val PYSPARK_ARROW_VALIDATE_SCHEMA =
buildConf("spark.sql.execution.arrow.pyspark.validateSchema.enabled")
.doc(
Expand Down Expand Up @@ -8387,6 +8398,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def pysparkSimplifiedTraceback: Boolean = getConf(PYSPARK_SIMPLIFIED_TRACEBACK)

def pysparkTracebackWithLocals: Boolean = getConf(PYSPARK_TRACEBACK_WITH_LOCALS)

def pysparkArrowValidateSchema: Boolean = getConf(PYSPARK_ARROW_VALIDATE_SCHEMA)

def pandasGroupedMapAssignColumnsByName: Boolean =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ abstract class BaseArrowPythonRunner[IN, OUT <: AnyRef](

override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
override val tracebackWithLocals: Boolean = SQLConf.get.pysparkTracebackWithLocals

override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
require(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class ArrowPythonUDTFRunner(

override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
override val tracebackWithLocals: Boolean = SQLConf.get.pysparkTracebackWithLocals

override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
require(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class CoGroupedArrowPythonRunner(

override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
override val tracebackWithLocals: Boolean = SQLConf.get.pysparkTracebackWithLocals

private val maxRecordsPerBatch: Int = {
val v = SQLConf.get.arrowMaxRecordsPerBatch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) extends Logging {
val killWorkerOnFlushFailure: Boolean = SQLConf.get.pythonUDFDaemonKillWorkerOnFlushFailure
val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
val tracebackWithLocals: Boolean = SQLConf.get.pysparkTracebackWithLocals
val workerMemoryMb = SQLConf.get.pythonPlannerExecMemory

val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ abstract class BasePythonUDFRunner(

override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
override val tracebackWithLocals: Boolean = SQLConf.get.pysparkTracebackWithLocals

override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
override val idleTimeoutSeconds: Long = SQLConf.get.pythonUDFWorkerIdleTimeoutSeconds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class ApplyInPandasWithStatePythonRunner(

override val hideTraceback: Boolean = sqlConf.pysparkHideTraceback
override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback
override val tracebackWithLocals: Boolean = sqlConf.pysparkTracebackWithLocals

override protected val largeVarTypes: Boolean = sqlConf.arrowUseLargeVarTypes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType)

override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
override val tracebackWithLocals: Boolean = SQLConf.get.pysparkTracebackWithLocals
}
}

Expand Down