Skip to content

Commit 1f32e61

Browse files
committed
fix parallelization of TransformPlugins
1 parent e54b351 commit 1f32e61

File tree

5 files changed

+131
-126
lines changed

5 files changed

+131
-126
lines changed

countess/core/pipeline.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def run(self, ddbc):
8585
assert isinstance(self.plugin, DuckdbPlugin)
8686
if self.is_dirty:
8787
sources = {pn.name: pn.run(ddbc) for pn in self.parent_nodes}
88-
88+
self.plugin.prepare_multi(ddbc, sources)
8989
result = self.plugin.execute_multi(ddbc, sources)
9090
if result is not None:
9191
try:
@@ -189,7 +189,9 @@ def run(self):
189189
start_time = time.time()
190190
for node in self.traverse_nodes():
191191
node.load_config()
192-
result = node.plugin.execute_multi(self.ddbc, {pn.name: pn.result for pn in node.parent_nodes})
192+
sources = {pn.name: pn.result for pn in node.parent_nodes}
193+
node.plugin.prepare_multi(self.ddbc, sources)
194+
result = node.plugin.execute_multi(self.ddbc, sources)
193195
if result:
194196
node.result = duckdb_source_to_view(self.ddbc, result)
195197
else:

countess/core/plugins.py

Lines changed: 71 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,23 @@
1414
* Call Plugin.prerun() to generate new output
1515
"""
1616

17+
import decimal
1718
import glob
1819
import hashlib
1920
import importlib
2021
import importlib.metadata
2122
import logging
2223
import multiprocessing
24+
from multiprocessing.pool import ThreadPool
2325
from typing import Any, Iterable, Mapping, Optional, Sequence, Tuple, Type, Union
2426

2527
import duckdb
28+
import psutil
2629
import pyarrow
2730
from duckdb import DuckDBPyConnection, DuckDBPyRelation
2831

2932
from countess.core.parameters import BaseParam, FileArrayParam, FileParam, HasSubParametersMixin, MultiParam
30-
from countess.utils.duckdb import duckdb_concatenate, duckdb_escape_identifier, duckdb_source_to_view
33+
from countess.utils.duckdb import duckdb_combine, duckdb_concatenate, duckdb_escape_identifier, duckdb_source_to_view
3134

3235
PRERUN_ROW_LIMIT: int = 100000
3336

@@ -112,29 +115,26 @@ class DuckdbPlugin(BasePlugin):
112115
# XXX expand this, or find in library somewhere
113116
ALLOWED_TYPES = {"INTEGER", "VARCHAR", "FLOAT", "DOUBLE", "DECIMAL"}
114117

118+
def prepare_multi(self, ddbc: DuckDBPyConnection, sources: Mapping[str, DuckDBPyRelation]) -> None:
    """Hook called before execute_multi() with the named input relations.

    Subclasses override this to inspect the incoming sources (e.g. to
    refresh column-choice parameters) before execution.  The base
    implementation deliberately does nothing."""
    pass
120+
115121
def execute_multi(
    self, ddbc: DuckDBPyConnection, sources: Mapping[str, DuckDBPyRelation]
) -> Optional[DuckDBPyRelation]:
    """Run the plugin over the named input relations.

    Returns the resulting relation, or None if the plugin produces no
    output.  Subclasses must override this (or inherit from a subclass
    which maps it onto a simpler interface)."""
    raise NotImplementedError(f"{self.__class__}.execute_multi")
119125

120126

121127
class DuckdbSimplePlugin(DuckdbPlugin):
122-
def execute_multi(
123-
self, ddbc: DuckDBPyConnection, sources: Mapping[str, DuckDBPyRelation]
124-
) -> Optional[DuckDBPyRelation]:
125-
tables = list(sources.values())
126-
if len(sources) > 1:
127-
source = duckdb_source_to_view(ddbc, duckdb_concatenate(tables))
128-
elif len(sources) == 1:
129-
source = tables[0]
130-
else:
131-
source = None
132-
133-
logger.debug("DuckdbSimplePlugin execute_multi %s", source.alias)
128+
def prepare_multi(self, ddbc: DuckDBPyConnection, sources: Mapping[str, DuckDBPyRelation]) -> None:
    """Collapse all input relations into one and delegate to prepare()."""
    combined = duckdb_combine(ddbc, list(sources.values()))
    self.prepare(ddbc, combined)
134130

131+
def prepare(self, ddbc: DuckDBPyConnection, source: Optional[DuckDBPyRelation]) -> None:
    """Refresh column-choice parameters from the source's columns.

    With no source, the choices are cleared (empty list)."""
    columns = source.columns if source is not None else []
    self.set_column_choices(columns)
136133

137-
return self.execute(ddbc, source)
134+
def execute_multi(
    self, ddbc: DuckDBPyConnection, sources: Mapping[str, DuckDBPyRelation]
) -> Optional[DuckDBPyRelation]:
    """Collapse all inputs into a single relation and run execute() on it."""
    combined = duckdb_combine(ddbc, list(sources.values()))
    return self.execute(ddbc, combined)
138138

139139
def execute(self, ddbc: DuckDBPyConnection, source: Optional[DuckDBPyRelation]) -> Optional[DuckDBPyRelation]:
    """Run the plugin over a single (already combined) source relation.

    Subclasses of DuckdbSimplePlugin must override this."""
    raise NotImplementedError(f"{self.__class__}.execute")
@@ -248,94 +248,72 @@ def filter(self, *_) -> bool:
248248
raise NotImplementedError(f"{self.__class__}.transform")
249249

250250

251-
class DuckdbTransformPlugin(DuckdbSimplePlugin):
252-
def dropped_columns(self) -> set[str]:
253-
return set()
254-
255-
def output_columns(self) -> dict[str, str]:
256-
"""Return a dictionary of `column name` -> `dbtype`
257-
which will be used to construct the user-defined
258-
function. The columns returned by transform() must
259-
match the columns declared here."""
251+
def _python_type_to_arrow_dtype(ttype: type) -> pyarrow.DataType:
    """Map a Python value type to the Arrow datatype used for new output fields.

    float and Decimal map to float64, int to int64, bool to the Arrow
    boolean type; anything else falls back to string."""
    if ttype in (float, decimal.Decimal):
        return pyarrow.float64()
    elif ttype is int:
        return pyarrow.int64()
    elif ttype is bool:
        # bool_() is the canonical (bit-packed) Arrow boolean; bool8() is a
        # one-byte extension type that only exists in recent pyarrow releases.
        return pyarrow.bool_()
    else:
        return pyarrow.string()
260260

261-
raise NotImplementedError(f"{self.__class__}.output_columns")
262261

263-
def execute(self, ddbc, source):
262+
class DuckdbTransformPlugin(DuckdbSimplePlugin):
    """Applies self.transform() to every row of the input.

    The input relation is streamed as Arrow record batches, each batch is
    transformed row-by-row in Python, and the result is registered back
    with DuckDB as a view."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Arbitrary name, unique per plugin instance, for the registered view.
        self.view_name = f"v_{id(self)}"

    def get_reader(self, source):
        """Return a RecordBatchReader over the source relation.

        Modest chunks keep per-batch work small enough to parallelize in
        the threaded/process subclasses."""
        return source.to_arrow_table().to_reader(max_chunksize=2048)

    def remove_fields(self, field_names: list[str]) -> list[str]:
        """Names of input fields to drop from the output schema.  Default: none."""
        return []

    def add_fields(self) -> Mapping[str, type]:
        """Mapping of new output field name -> Python type.  Default: none."""
        return {}

    def fix_schema(self, schema: pyarrow.Schema) -> pyarrow.Schema:
        """Derive the output schema from the input schema.

        Fields named by self.remove_fields() are dropped, then fields from
        self.add_fields() are appended (skipping empty names / None types)."""
        logger.debug("DuckdbTransformPlugin.fix_schema in %s", schema.to_string())
        for field_name in self.remove_fields(schema.names):
            if field_name in schema.names:
                schema = schema.remove(schema.get_field_index(field_name))
        for field_name, ttype in self.add_fields().items():
            if field_name and ttype is not None:
                schema = schema.append(pyarrow.field(field_name, _python_type_to_arrow_dtype(ttype)))
        return schema

    def execute(self, ddbc: DuckDBPyConnection, source: DuckDBPyRelation) -> DuckDBPyRelation:
        """Perform a query which calls `self.transform` for every row."""
        reader = self.get_reader(source)
        # Passing the schema explicitly means a reader which yields no
        # batches still produces a valid (empty) table instead of
        # Table.from_batches() failing for lack of a schema.
        table = pyarrow.Table.from_batches(
            (self.transform_batch(batch) for batch in reader),
            schema=self.fix_schema(reader.schema),
        )
        ddbc.register(self.view_name, table)
        return ddbc.view(self.view_name)

    def transform_batch(self, batch: pyarrow.RecordBatch) -> pyarrow.RecordBatch:
        """Transform one record batch row-by-row.

        Rows for which transform() returns None are dropped from the output."""
        schema = self.fix_schema(batch.schema)
        return pyarrow.RecordBatch.from_pylist(
            [t for t in (self.transform(row) for row in batch.to_pylist()) if t is not None], schema=schema
        )

    def transform(self, data: dict[str, Any]) -> Union[dict[str, Any], Tuple[Any], None]:
        """Called once per row with the row as a dict.

        Return a dict whose keys and value types match the schema produced
        by fix_schema() (i.e. the input fields minus remove_fields() plus
        add_fields()), or None to drop the row from the output."""
        raise NotImplementedError(f"{self.__class__}.transform")
304+
305+
306+
class DuckdbThreadedTransformPlugin(DuckdbTransformPlugin):
    """Transform plugin which maps transform_batch() over batches in a thread pool.

    Threads help when transform() releases the GIL or waits on I/O.
    NOTE(review): imap_unordered makes the output batch (and hence row)
    order nondeterministic -- confirm downstream does not rely on order."""

    def execute(self, ddbc: DuckDBPyConnection, source: DuckDBPyRelation) -> DuckDBPyRelation:
        reader = self.get_reader(source)
        # Use the directly imported ThreadPool name; reaching it as
        # multiprocessing.pool.ThreadPool only worked as a side effect of
        # the from-import having loaded the submodule.
        with ThreadPool(processes=psutil.cpu_count()) as pool:
            # Explicit schema so an empty reader still yields a valid table.
            table = pyarrow.Table.from_batches(
                pool.imap_unordered(self.transform_batch, reader),
                schema=self.fix_schema(reader.schema),
            )
            ddbc.register(self.view_name, table)
        return ddbc.view(self.view_name)
312+
313+
314+
class DuckdbParallelTransformPlugin(DuckdbTransformPlugin):
    """Transform plugin which maps transform_batch() over batches in a
    process pool, for CPU-bound transforms.

    NOTE(review): the plugin instance and each RecordBatch must be
    picklable to cross the process boundary -- confirm for each subclass.
    Output batch order is nondeterministic (imap_unordered)."""

    def execute(self, ddbc: DuckDBPyConnection, source: DuckDBPyRelation) -> DuckDBPyRelation:
        reader = self.get_reader(source)
        with multiprocessing.Pool(processes=psutil.cpu_count()) as pool:
            # Explicit schema so an empty reader still yields a valid table.
            table = pyarrow.Table.from_batches(
                pool.imap_unordered(self.transform_batch, reader),
                schema=self.fix_schema(reader.schema),
            )
            ddbc.register(self.view_name, table)
        return ddbc.view(self.view_name)

countess/plugins/score.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from math import log
23
from typing import Any, Optional, Union
34

@@ -13,7 +14,9 @@
1314
ColumnGroupOrNoneChoiceParam,
1415
StringParam,
1516
)
16-
from countess.core.plugins import DuckdbTransformPlugin
17+
from countess.core.plugins import DuckdbParallelTransformPlugin
18+
19+
logger = logging.getLogger(__name__)
1720

1821

1922
def float_or_none(s: Any) -> Optional[float]:
@@ -37,7 +40,7 @@ def score(xs: list[float], ys: list[float]) -> Optional[tuple[float, float]]:
3740
return None
3841

3942

40-
class ScoringPlugin(DuckdbTransformPlugin):
43+
class ScoringPlugin(DuckdbParallelTransformPlugin):
4144
name = "Scoring"
4245
description = "Score variants using counts or frequencies"
4346
version = VERSION
@@ -61,7 +64,9 @@ def output_columns(self) -> dict[str, str]:
6164
else:
6265
return {self.output.value: "DOUBLE"}
6366

64-
def prepare(self, source):
67+
def prepare(self, ddbc, source):
68+
logger.debug("ScoringPlugin.prepare")
69+
super().prepare(ddbc, source)
6570
yaxis_prefix = self.columns.get_column_prefix()
6671
suffix_set = {k.removeprefix(yaxis_prefix) for k in source.columns if k.startswith(yaxis_prefix)}
6772

@@ -70,9 +75,24 @@ def prepare(self, source):
7075
suffix_set.update([k.removeprefix(xaxis_prefix) for k in source.columns if k.startswith(xaxis_prefix)])
7176

7277
self.suffixes = sorted(suffix_set)
78+
logger.debug("ScoringPlugin.prepare suffixes %s", self.suffixes)
79+
80+
def add_fields(self):
    """Declare the score and variance output columns, both floats.

    Both are always declared; transform() only fills the variance column
    when the variance parameter is set, so it may be all-null."""
    return {self.output.value: float, self.variance.value: float}
82+
83+
def remove_fields(self, field_names: list[str]):
    """Input columns to drop from the output.

    When drop_input is set: every column in the y-axis column group and,
    if an x-axis column group is selected, every column in that group.
    Otherwise nothing is dropped."""
    if self.drop_input:
        return [
            name
            for name in field_names
            if name.startswith(self.columns.get_column_prefix())
            or (self.xaxis.is_not_none() and name.startswith(self.xaxis.get_column_prefix()))
        ]
    else:
        return []
7393

7494
def transform(self, data: dict[str, Any]) -> Optional[dict[str, Any]]:
75-
assert self.suffixes
95+
assert self.suffixes is not None
7696

7797
if self.xaxis.is_not_none():
7898
xaxis_prefix = self.xaxis.get_column_prefix()
@@ -99,9 +119,9 @@ def transform(self, data: dict[str, Any]) -> Optional[dict[str, Any]]:
99119

100120
try:
101121
s, v = score(x_values, y_values)
122+
data[self.output.value] = s
102123
if self.variance:
103-
return {self.output.value: s, self.variance.value: v}
104-
else:
105-
return {self.output.value: s}
124+
data[self.variance.value] = v
125+
return data
106126
except TypeError:
107127
return None

0 commit comments

Comments (0)