14 | 14 | * Call Plugin.prerun() to generate new output |
15 | 15 | """ |
16 | 16 |
| 17 | +import decimal |
17 | 18 | import glob |
18 | 19 | import hashlib |
19 | 20 | import importlib |
20 | 21 | import importlib.metadata |
21 | 22 | import logging |
22 | 23 | import multiprocessing |
| 24 | +from multiprocessing.pool import ThreadPool |
23 | 25 | from typing import Any, Iterable, Mapping, Optional, Sequence, Tuple, Type, Union |
24 | 26 |
25 | 27 | import duckdb |
| 28 | +import psutil |
26 | 29 | import pyarrow |
27 | 30 | from duckdb import DuckDBPyConnection, DuckDBPyRelation |
28 | 31 |
29 | 32 | from countess.core.parameters import BaseParam, FileArrayParam, FileParam, HasSubParametersMixin, MultiParam |
30 | | -from countess.utils.duckdb import duckdb_concatenate, duckdb_escape_identifier, duckdb_source_to_view |
| 33 | +from countess.utils.duckdb import duckdb_combine, duckdb_concatenate, duckdb_escape_identifier, duckdb_source_to_view |
31 | 34 |
32 | 35 | PRERUN_ROW_LIMIT: int = 100000 |
33 | 36 |
@@ -112,29 +115,26 @@ class DuckdbPlugin(BasePlugin): |
112 | 115 |     # XXX expand this, or find in library somewhere
113 | 116 |     ALLOWED_TYPES = {"INTEGER", "VARCHAR", "FLOAT", "DOUBLE", "DECIMAL"}
114 | 117 |
| 118 | +    def prepare_multi(self, ddbc: DuckDBPyConnection, sources: Mapping[str, DuckDBPyRelation]) -> None:
| 119 | +        pass
| 120 | +
115 | 121 |     def execute_multi(
116 | 122 |         self, ddbc: DuckDBPyConnection, sources: Mapping[str, DuckDBPyRelation]
117 | 123 |     ) -> Optional[DuckDBPyRelation]:
118 | 124 |         raise NotImplementedError(f"{self.__class__}.execute_multi")
119 | 125 |
120 | 126 |
121 | 127 | class DuckdbSimplePlugin(DuckdbPlugin):
122 | | -    def execute_multi(
123 | | -        self, ddbc: DuckDBPyConnection, sources: Mapping[str, DuckDBPyRelation]
124 | | -    ) -> Optional[DuckDBPyRelation]:
125 | | -        tables = list(sources.values())
126 | | -        if len(sources) > 1:
127 | | -            source = duckdb_source_to_view(ddbc, duckdb_concatenate(tables))
128 | | -        elif len(sources) == 1:
129 | | -            source = tables[0]
130 | | -        else:
131 | | -            source = None
132 | | -
133 | | -        logger.debug("DuckdbSimplePlugin execute_multi %s", source.alias)
| 128 | +    def prepare_multi(self, ddbc: DuckDBPyConnection, sources: Mapping[str, DuckDBPyRelation]) -> None:
| 129 | +        self.prepare(ddbc, duckdb_combine(ddbc, list(sources.values())))
134 | 130 |
| 131 | +    def prepare(self, ddbc: DuckDBPyConnection, source: Optional[DuckDBPyRelation]) -> None:
135 | 132 |         self.set_column_choices([] if source is None else source.columns)
136 | 133 |
137 | | -        return self.execute(ddbc, source)
| 134 | +    def execute_multi(
| 135 | +        self, ddbc: DuckDBPyConnection, sources: Mapping[str, DuckDBPyRelation]
| 136 | +    ) -> Optional[DuckDBPyRelation]:
| 137 | +        return self.execute(ddbc, duckdb_combine(ddbc, list(sources.values())))
138 | 138 |
139 | 139 |     def execute(self, ddbc: DuckDBPyConnection, source: Optional[DuckDBPyRelation]) -> Optional[DuckDBPyRelation]:
140 | 140 |         raise NotImplementedError(f"{self.__class__}.execute")
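With this change a simple plugin no longer combines its own inputs: `prepare_multi`/`prepare` handle column-choice setup, and `execute_multi` funnels everything through `duckdb_combine` (presumably the concatenate-or-single-source logic previously inlined above, now shared). A minimal sketch of a subclass under the new split; the class name, the `name` column, and the `countess.core.plugins` import path are illustrative assumptions:

```python
from typing import Optional

from duckdb import DuckDBPyConnection, DuckDBPyRelation

from countess.core.plugins import DuckdbSimplePlugin  # assumed module path


class UpperCaseNames(DuckdbSimplePlugin):
    """Hypothetical plugin: upper-cases a 'name' column, passing other columns through."""

    def execute(
        self, ddbc: DuckDBPyConnection, source: Optional[DuckDBPyRelation]
    ) -> Optional[DuckDBPyRelation]:
        if source is None:
            return None
        # DuckDB's star-REPLACE keeps every column, substituting the rewritten one
        return source.project('* REPLACE (upper("name") AS "name")')
```

Subclasses usually only override `execute`: the base class's `prepare_multi` combines the inputs once, and `prepare` feeds the resulting columns to `set_column_choices`.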
@@ -248,94 +248,72 @@ def filter(self, *_) -> bool: |
248 | 248 |         raise NotImplementedError(f"{self.__class__}.transform")
249 | 249 |
250 | 250 |
251 | | -class DuckdbTransformPlugin(DuckdbSimplePlugin):
252 | | -    def dropped_columns(self) -> set[str]:
253 | | -        return set()
254 | | -
255 | | -    def output_columns(self) -> dict[str, str]:
256 | | -        """Return a dictionary of `column name` -> `dbtype`
257 | | -        which will be used to construct the user-defined
258 | | -        function. The columns returned by transform() must
259 | | -        match the columns declared here."""
| 251 | +def _python_type_to_arrow_dtype(ttype: type) -> pyarrow.DataType:
| 252 | +    if ttype in (float, decimal.Decimal):
| 253 | +        return pyarrow.float64()
| 254 | +    elif ttype is int:
| 255 | +        return pyarrow.int64()
| 256 | +    elif ttype is bool:
| 257 | +        return pyarrow.bool_()
| 258 | +    else:
| 259 | +        return pyarrow.string()
260 | 260 |
261 | | -        raise NotImplementedError(f"{self.__class__}.output_columns")
262 | 261 |
263 | | -    def execute(self, ddbc, source):
| 262 | +class DuckdbTransformPlugin(DuckdbSimplePlugin):
| 263 | +    def __init__(self, *args, **kwargs):
| 264 | +        super().__init__(*args, **kwargs)
| 265 | +        self.view_name = f"v_{id(self)}"
| 266 | +
| 267 | +    def get_reader(self, source):
| 268 | +        return source.to_arrow_table().to_reader(max_chunksize=2048)
| 269 | +
| 270 | +    def remove_fields(self, field_names: list[str]) -> list[str]:
| 271 | +        return []
| 272 | +
| 273 | +    def add_fields(self) -> Mapping[str, type]:
| 274 | +        return {}
| 275 | +
| 276 | +    def fix_schema(self, schema: pyarrow.Schema) -> pyarrow.Schema:
| 277 | +        logger.debug("DuckdbTransformPlugin.fix_schema in %s", schema.to_string())
| 278 | +        for field_name in self.remove_fields(schema.names):
| 279 | +            if field_name in schema.names:
| 280 | +                schema = schema.remove(schema.get_field_index(field_name))
| 281 | +        for field_name, ttype in self.add_fields().items():
| 282 | +            if field_name and ttype is not None:
| 283 | +                schema = schema.append(pyarrow.field(field_name, _python_type_to_arrow_dtype(ttype)))
| 284 | +        return schema
| 285 | +
| 286 | +    def execute(self, ddbc: DuckDBPyConnection, source: DuckDBPyRelation) -> DuckDBPyRelation:
264 | 287 |         """Read `source` and call `self.transform` for every row."""
265 | 288 |
266 | | -        # if you happen to have an output column with the same name as an
267 | | -        # input column this drops it, as well as any columns being explicitly
268 | | -        # dropped.
269 | | -        drop_columns_set = set(list(self.output_columns().keys()) + list(self.dropped_columns()))
| 289 | +        reader = self.get_reader(source)
| 290 | +        ddbc.register(self.view_name, pyarrow.Table.from_batches(self.transform_batch(batch) for batch in reader))
| 291 | +        return ddbc.view(self.view_name)
270 | 292 |
271 | | -        # Make up an arbitrary unique name for our temporary function
272 | | -        function_name = f"f_{id(self)}"
273 | | -
274 | | -        # Output type has to be completely defined, with types and all
275 | | -        output_type = (
276 | | -            "STRUCT("
277 | | -            + ",".join(
278 | | -                f"{duckdb_escape_identifier(k)} {str(v).upper()}"
279 | | -                for k, v in self.output_columns().items()
280 | | -                if k is not None and v is not None
281 | | -            )
282 | | -            + ")"
| 293 | +    def transform_batch(self, batch: pyarrow.RecordBatch) -> pyarrow.RecordBatch:
| 294 | +        schema = self.fix_schema(batch.schema)
| 295 | +        return pyarrow.RecordBatch.from_pylist(
| 296 | +            [t for t in (self.transform(row) for row in batch.to_pylist()) if t is not None], schema=schema
283 | 297 |         )
284 | 298 |
285 | | -        # source columns which aren't being dropped get copied into the projection
286 | | -        # in their original order, followed by the generated output columns.
287 | | -        keep_columns = " ".join(f"{duckdb_escape_identifier(k)}," for k in source.columns if k not in drop_columns_set)
288 | | -
289 | | -        logger.debug("DuckDbTransformPlugin.query function_name %s", function_name)
290 | | -        logger.debug("DuckDbTransformPlugin.query output_type %s", output_type)
291 | | -        logger.debug("DuckDbTransformPlugin.query keep_columns %s", keep_columns)
292 | | -
293 | | -        # if the function already exists, remove it
294 | | -        try:
295 | | -            ddbc.remove_function(function_name)
296 | | -            logger.debug("DuckDbTransformPlugin.query removed function %s", function_name)
297 | | -        except duckdb.InvalidInputException as exc:
298 | | -            if not str(exc).startswith("Invalid Input Error: No function by the name of '"):
299 | | -                # some other error
300 | | -                logger.debug("DuckDbTransformPlugin.query can't remove function %s: %s", function_name, exc)
301 | | -
302 | | -        # XXX it'd be nice to have an arrow version of this
303 | | -        # to allow easy parallelization, but see:
304 | | -        # https://github.com/duckdb/duckdb/issues/15626
305 | | -        # Appears to be fixed in 1.1.4.dev4815
306 | | -
307 | | -        ddbc.create_function(
308 | | -            name=function_name,
309 | | -            function=self.transform_arrow,
310 | | -            type="arrow",
311 | | -            return_type=output_type,
312 | | -            null_handling="special",
313 | | -            side_effects=False,
314 | | -        )
315 | | -
316 | | -        # the "SELECT func(_row) FROM {table} _row" bit passes
317 | | -        # a whole row to the function, sadly there's no way
318 | | -        # to express this in a `.project()`.
319 | | -
320 | | -        sql_command = f"SELECT {keep_columns} UNNEST({function_name}(_row)) FROM {source.alias} _row"
321 | | -        logger.debug("DuckDbTransformPlugin.query sql_command %s", sql_command)
322 | | -
323 | | -        self.prepare(source)
324 | | -
325 | | -        return duckdb_source_to_view(ddbc, ddbc.sql(sql_command))
326 | | -
327 | | -    def prepare(self, source: DuckDBPyRelation):
328 | | -        """Called before the transform functions are run, to prepare anything
329 | | -        which needs preparation ..."""
330 | | -        pass
331 | | -
332 | | -    def transform_arrow(self, data: pyarrow.array) -> pyarrow.array:
333 | | -        logger.debug("DuckDbTransformPlugin.transform_arrow %d", len(data))
334 | | -        pool = multiprocessing.Pool(processes=4)
335 | | -        return pyarrow.array(pool.imap_unordered(self.transform, data.to_pylist()))
336 | | -
337 | 299 |     def transform(self, data: dict[str, Any]) -> Optional[dict[str, Any]]:
338 | 300 |         """This will be called for each row.  Return a dictionary whose keys
339 | 301 |         and value types match the schema produced by `fix_schema` (the source
340 | 302 |         columns minus `remove_fields` plus `add_fields`), or None to drop the row."""
341 | 303 |         raise NotImplementedError(f"{self.__class__}.transform")
| 304 | +
| 305 | +
| 306 | +class DuckdbThreadedTransformPlugin(DuckdbTransformPlugin):
| 307 | +    def execute(self, ddbc: DuckDBPyConnection, source: DuckDBPyRelation) -> DuckDBPyRelation:
| 308 | +        with ThreadPool(processes=psutil.cpu_count()) as pool:
| 309 | +            reader = self.get_reader(source)
| 310 | +            ddbc.register(self.view_name, pyarrow.Table.from_batches(pool.imap_unordered(self.transform_batch, reader)))
| 311 | +        return ddbc.view(self.view_name)
| 312 | +
| 313 | +
| 314 | +class DuckdbParallelTransformPlugin(DuckdbTransformPlugin):
| 315 | +    def execute(self, ddbc: DuckDBPyConnection, source: DuckDBPyRelation) -> DuckDBPyRelation:
| 316 | +        with multiprocessing.Pool(processes=psutil.cpu_count()) as pool:
| 317 | +            reader = self.get_reader(source)
| 318 | +            ddbc.register(self.view_name, pyarrow.Table.from_batches(pool.imap_unordered(self.transform_batch, reader)))
| 319 | +        return ddbc.view(self.view_name)
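Putting the new transform pipeline together: `get_reader` slices the source relation into Arrow record batches, then `transform_batch` turns each batch into Python dicts, calls `transform` once per row, drops `None` results, and rebuilds the batch against the schema from `fix_schema`. A minimal sketch of a subclass, with hypothetical class and column names (the import path is also an assumption):

```python
from typing import Any, Mapping, Optional

from countess.core.plugins import DuckdbThreadedTransformPlugin  # assumed module path


class NameLength(DuckdbThreadedTransformPlugin):
    """Hypothetical plugin: drops 'comment' and adds an integer 'name_length' column."""

    def remove_fields(self, field_names: list[str]) -> list[str]:
        return ["comment"]

    def add_fields(self) -> Mapping[str, type]:
        # fix_schema() maps int to pyarrow.int64() via _python_type_to_arrow_dtype()
        return {"name_length": int}

    def transform(self, data: dict[str, Any]) -> Optional[dict[str, Any]]:
        if data.get("name") is None:
            return None  # returning None drops the row from the output batch
        data["name_length"] = len(data["name"])
        return data  # keys not in the fixed schema are ignored by from_pylist
```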
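Two design notes on the pooled variants: `imap_unordered` yields batches in completion order, so both subclasses may reorder rows relative to the input (`pool.imap` is the ordered drop-in if that matters), and `DuckdbParallelTransformPlugin` additionally requires the plugin instance and each `RecordBatch` to be picklable, since work is shipped to separate processes rather than threads.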