Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions kedro-datasets/kedro_datasets/polars/csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
filesystem (e.g.: local, S3, GCS). It uses polars to handle the CSV file.
"""
import logging
from collections.abc import Mapping, Sequence
from copy import deepcopy
from io import BytesIO
from pathlib import PurePosixPath
Expand Down Expand Up @@ -144,6 +145,27 @@ def __init__( # noqa: PLR0913
self._save_args.pop("storage_options", None)
self._load_args.pop("storage_options", None)

# Normalise ``load_args["dtypes"]``: entries given as strings such as
# "pl.Utf8" or "Utf8" are replaced by the attribute of ``pl`` with that name.
if "dtypes" in self._load_args:
    dtypes = self._load_args["dtypes"]
    # ``bytes`` is a Sequence too; exclude it so it reaches the ValueError
    # below instead of failing on ``int.split`` while iterating.
    if isinstance(dtypes, Sequence) and not isinstance(dtypes, (str, bytes)):
        # Entries that are not strings are passed through unchanged, so
        # values like ``pl.Utf8`` (advertised in the error message below)
        # no longer crash on ``.split``.
        self._load_args["dtypes"] = [
            getattr(pl, dtype.split(".")[-1]) if isinstance(dtype, str) else dtype
            for dtype in dtypes
        ]
    elif isinstance(dtypes, Mapping):
        self._load_args["dtypes"] = {
            key: getattr(pl, dtype.split(".")[-1])
            if isinstance(dtype, str)
            else dtype
            for key, dtype in dtypes.items()
        }
    elif dtypes is not None:
        # ``None`` is accepted as-is; any other type is rejected.
        examples = "\nValid examples: None, [pl.Utf8, pl.Int64], {str_col: pl.Utf8, int_col: pl.Int64}"
        raise ValueError(
            f"Invalid type for 'dtypes' in load_args: {type(dtypes)}. {examples}"
        )

def _describe(self) -> dict[str, Any]:
return {
"filepath": self._filepath,
Expand Down
26 changes: 26 additions & 0 deletions kedro-datasets/kedro_datasets/polars/eager_polars_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
filesystem (e.g.: local, S3, GCS). It uses polars to handle the
type of read/write target.
"""
from collections.abc import Mapping, Sequence
from copy import deepcopy
from io import BytesIO
from pathlib import PurePosixPath
Expand Down Expand Up @@ -141,6 +142,31 @@ def __init__( # noqa: PLR0913
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save

# Normalise ``load_args["dtypes"]``: only meaningful for the "csv" file
# format; entries given as strings such as "pl.Utf8" or "Utf8" are replaced
# by the attribute of ``pl`` with that name.
if "dtypes" in self._load_args:
    # NOTE: the previous check ``self._file_format in ("csv")`` was a
    # substring test against the string "csv" (the parentheses do not make
    # a tuple), so formats like "" or "c" were treated as CSV.
    if self._file_format != "csv":
        raise ValueError(
            f"Invalid argument 'dtypes' for file_format '{self._file_format}'"
        )

    dtypes = self._load_args["dtypes"]
    # ``bytes`` is a Sequence too; exclude it so it reaches the ValueError
    # below instead of failing on ``int.split`` while iterating.
    if isinstance(dtypes, Sequence) and not isinstance(dtypes, (str, bytes)):
        # Entries that are not strings are passed through unchanged, so
        # values like ``pl.Utf8`` (advertised in the error message below)
        # no longer crash on ``.split``.
        self._load_args["dtypes"] = [
            getattr(pl, dtype.split(".")[-1]) if isinstance(dtype, str) else dtype
            for dtype in dtypes
        ]
    elif isinstance(dtypes, Mapping):
        self._load_args["dtypes"] = {
            key: getattr(pl, dtype.split(".")[-1])
            if isinstance(dtype, str)
            else dtype
            for key, dtype in dtypes.items()
        }
    elif dtypes is not None:
        # ``None`` is accepted as-is; any other type is rejected.
        examples = "\nValid examples: None, [pl.Utf8, pl.Int64], {str_col: pl.Utf8, int_col: pl.Int64}"
        raise ValueError(
            f"Invalid type for 'dtypes' in load_args: {type(dtypes)}. {examples}"
        )

def _load(self) -> pl.DataFrame:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
load_method = getattr(pl, f"read_{self._file_format}", None)
Expand Down
26 changes: 26 additions & 0 deletions kedro-datasets/kedro_datasets/polars/lazy_polars_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
type of read/write target.
"""
import logging
from collections.abc import Mapping, Sequence
from copy import deepcopy
from io import BytesIO
from pathlib import PurePosixPath
Expand Down Expand Up @@ -179,6 +180,31 @@ def __init__( # noqa: PLR0913
self._save_args.pop("storage_options", None)
self._load_args.pop("storage_options", None)

# Normalise ``load_args["dtypes"]``: only meaningful for the "csv" file
# format; entries given as strings such as "pl.Utf8" or "Utf8" are replaced
# by the attribute of ``pl`` with that name.
if "dtypes" in self._load_args:
    # NOTE: the previous check ``self._file_format in ("csv")`` was a
    # substring test against the string "csv" (the parentheses do not make
    # a tuple), so formats like "" or "c" were treated as CSV.
    if self._file_format != "csv":
        raise ValueError(
            f"Invalid argument 'dtypes' for file_format '{self._file_format}'"
        )

    dtypes = self._load_args["dtypes"]
    # ``bytes`` is a Sequence too; exclude it so it reaches the ValueError
    # below instead of failing on ``int.split`` while iterating.
    if isinstance(dtypes, Sequence) and not isinstance(dtypes, (str, bytes)):
        # Entries that are not strings are passed through unchanged, so
        # values like ``pl.Utf8`` (advertised in the error message below)
        # no longer crash on ``.split``.
        self._load_args["dtypes"] = [
            getattr(pl, dtype.split(".")[-1]) if isinstance(dtype, str) else dtype
            for dtype in dtypes
        ]
    elif isinstance(dtypes, Mapping):
        self._load_args["dtypes"] = {
            key: getattr(pl, dtype.split(".")[-1])
            if isinstance(dtype, str)
            else dtype
            for key, dtype in dtypes.items()
        }
    elif dtypes is not None:
        # ``None`` is accepted as-is; any other type is rejected.
        examples = "\nValid examples: None, [pl.Utf8, pl.Int64], {str_col: pl.Utf8, int_col: pl.Int64}"
        raise ValueError(
            f"Invalid type for 'dtypes' in load_args: {type(dtypes)}. {examples}"
        )

def _describe(self) -> dict[str, Any]:
return {
"filepath": self._filepath,
Expand Down