From ae65a6e4da6e8f18936a131507048954b825de03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Morten=20Gr=C3=B8ftehauge?= Date: Thu, 7 Mar 2024 11:15:50 +0100 Subject: [PATCH] dtypes support when reading csv files into Polars --- .../kedro_datasets/polars/csv_dataset.py | 22 ++++++++++++++++ .../polars/eager_polars_dataset.py | 26 +++++++++++++++++++ .../polars/lazy_polars_dataset.py | 26 +++++++++++++++++++ 3 files changed, 74 insertions(+) diff --git a/kedro-datasets/kedro_datasets/polars/csv_dataset.py b/kedro-datasets/kedro_datasets/polars/csv_dataset.py index 5c9a99433..098c13bc8 100644 --- a/kedro-datasets/kedro_datasets/polars/csv_dataset.py +++ b/kedro-datasets/kedro_datasets/polars/csv_dataset.py @@ -2,6 +2,7 @@ filesystem (e.g.: local, S3, GCS). It uses polars to handle the CSV file. """ import logging +from collections.abc import Mapping, Sequence from copy import deepcopy from io import BytesIO from pathlib import PurePosixPath @@ -144,6 +145,27 @@ def __init__( # noqa: PLR0913 self._save_args.pop("storage_options", None) self._load_args.pop("storage_options", None) + if "dtypes" in self._load_args: + if isinstance(self._load_args["dtypes"], Sequence) and not isinstance( + self._load_args["dtypes"], str + ): + self._load_args["dtypes"] = [ + getattr(pl, dtype.split(".")[-1]) + for dtype in self._load_args["dtypes"] + ] + elif isinstance(self._load_args["dtypes"], Mapping): + self._load_args["dtypes"] = { + key: getattr(pl, dtype.split(".")[-1]) + for key, dtype in self._load_args["dtypes"].items() + } + elif self._load_args["dtypes"] is None: + pass + else: + examples = "\nValid examples: None, [pl.Utf8, pl.Int64], {str_col: pl.Utf8, int_col: pl.Int64}" + raise ValueError( + f"Invalid type for 'dtypes' in load_args: {type(self._load_args['dtypes'])}. {examples}" + ) + def _describe(self) -> dict[str, Any]: return { "filepath": self._filepath, diff --git a/kedro-datasets/kedro_datasets/polars/eager_polars_dataset.py b/kedro-datasets/kedro_datasets/polars/eager_polars_dataset.py index b7f617fa8..ab13d841d 100644 --- a/kedro-datasets/kedro_datasets/polars/eager_polars_dataset.py +++ b/kedro-datasets/kedro_datasets/polars/eager_polars_dataset.py @@ -2,6 +2,7 @@ filesystem (e.g.: local, S3, GCS). It uses polars to handle the type of read/write target. """ +from collections.abc import Mapping, Sequence from copy import deepcopy from io import BytesIO from pathlib import PurePosixPath @@ -141,6 +142,31 @@ def __init__( # noqa: PLR0913 self._fs_open_args_load = _fs_open_args_load self._fs_open_args_save = _fs_open_args_save + if "dtypes" in self._load_args and self._file_format in ("csv"): + if isinstance(self._load_args["dtypes"], Sequence) and not isinstance( + self._load_args["dtypes"], str + ): + self._load_args["dtypes"] = [ + getattr(pl, dtype.split(".")[-1]) + for dtype in self._load_args["dtypes"] + ] + elif isinstance(self._load_args["dtypes"], Mapping): + self._load_args["dtypes"] = { + key: getattr(pl, dtype.split(".")[-1]) + for key, dtype in self._load_args["dtypes"].items() + } + elif self._load_args["dtypes"] is None: + pass + else: + examples = "\nValid examples: None, [pl.Utf8, pl.Int64], {str_col: pl.Utf8, int_col: pl.Int64}" + raise ValueError( + f"Invalid type for 'dtypes' in load_args: {type(self._load_args['dtypes'])}. {examples}" + ) + elif "dtypes" in self._load_args and self._file_format not in ("csv"): + raise ValueError( + f"Invalid argument 'dtypes' for file_format '{self._file_format}'" + ) + def _load(self) -> pl.DataFrame: load_path = get_filepath_str(self._get_load_path(), self._protocol) load_method = getattr(pl, f"read_{self._file_format}", None) diff --git a/kedro-datasets/kedro_datasets/polars/lazy_polars_dataset.py b/kedro-datasets/kedro_datasets/polars/lazy_polars_dataset.py index fd6248ef6..ff3ddfa5a 100644 --- a/kedro-datasets/kedro_datasets/polars/lazy_polars_dataset.py +++ b/kedro-datasets/kedro_datasets/polars/lazy_polars_dataset.py @@ -3,6 +3,7 @@ type of read/write target. """ import logging +from collections.abc import Mapping, Sequence from copy import deepcopy from io import BytesIO from pathlib import PurePosixPath @@ -179,6 +180,31 @@ def __init__( # noqa: PLR0913 self._save_args.pop("storage_options", None) self._load_args.pop("storage_options", None) + if "dtypes" in self._load_args and self._file_format in ("csv"): + if isinstance(self._load_args["dtypes"], Sequence) and not isinstance( + self._load_args["dtypes"], str + ): + self._load_args["dtypes"] = [ + getattr(pl, dtype.split(".")[-1]) + for dtype in self._load_args["dtypes"] + ] + elif isinstance(self._load_args["dtypes"], Mapping): + self._load_args["dtypes"] = { + key: getattr(pl, dtype.split(".")[-1]) + for key, dtype in self._load_args["dtypes"].items() + } + elif self._load_args["dtypes"] is None: + pass + else: + examples = "\nValid examples: None, [pl.Utf8, pl.Int64], {str_col: pl.Utf8, int_col: pl.Int64}" + raise ValueError( + f"Invalid type for 'dtypes' in load_args: {type(self._load_args['dtypes'])}. {examples}" + ) + elif "dtypes" in self._load_args and self._file_format not in ("csv"): + raise ValueError( + f"Invalid argument 'dtypes' for file_format '{self._file_format}'" + ) + def _describe(self) -> dict[str, Any]: return { "filepath": self._filepath,