Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ jobs:
cache-dependency-glob: "pyproject.toml"
- name: install-reqs
# we are not testing pyspark on Windows here because it is very slow
run: uv pip install -e ".[tests, core, extra, dask, modin, sqlframe]" --system
run: uv pip install -e ".[tests, core, extra, dask, modin]" --system
- name: show-deps
run: uv pip freeze
- name: Run pytest
Expand Down Expand Up @@ -83,7 +83,7 @@ jobs:
cache-suffix: ${{ matrix.python-version }}
cache-dependency-glob: "pyproject.toml"
- name: install-reqs
run: uv pip install -e ".[tests, core, extra, modin, dask, sqlframe]" --system
run: uv pip install -e ".[tests, core, extra, modin, dask]" --system
- name: install pyspark
run: uv pip install -e ".[pyspark]" --system
# PySpark is not yet available on Python3.12+
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/typing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
mypy:
strategy:
matrix:
python-version: ["3.11"]
python-version: ["3.12"]
os: [ubuntu-latest]
runs-on: ${{ matrix.os }}
steps:
Expand All @@ -32,7 +32,7 @@ jobs:
# TODO: add more dependencies/backends incrementally
run: |
source .venv/bin/activate
uv pip install -e ".[tests, typing, core, pyspark, sqlframe]"
uv pip install -e ".[typing, core, pyspark]"
- name: show-deps
run: |
source .venv/bin/activate
Expand Down
20 changes: 20 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,26 @@ then their tests will run too.

We can't currently test in CI against cuDF, but you can test it manually in Kaggle using GPUs. Please follow this [Kaggle notebook](https://www.kaggle.com/code/marcogorelli/testing-cudf-in-narwhals) to run the tests.

### Static typing

We run both `mypy` and `pyright` in CI. To run them locally, make sure to install

```terminal
uv pip install -U -e ".[typing]"
```

You can then run
- `mypy narwhals tests`
- `pyright narwhals tests`

to verify type completeness / correctness.

Note that:
- In `_pandas_like`, we type all native objects as if they are pandas ones, though
in reality this package is shared between pandas, Modin, and cuDF.
- In `_spark_like`, we type all native objects as if they are SQLFrame ones, though
in reality this package is shared between SQLFrame and PySpark.

### 8. Writing the doc(strings)

If you are adding a new feature or changing an existing one, you should also update the documentation and the docstrings
Expand Down
126 changes: 51 additions & 75 deletions narwhals/_spark_like/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from __future__ import annotations

import warnings
from importlib import import_module
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterator
from typing import Literal
from typing import Sequence
from typing import cast

from narwhals._spark_like.utils import evaluate_exprs
from narwhals._spark_like.utils import import_functions
from narwhals._spark_like.utils import import_native_dtypes
from narwhals._spark_like.utils import import_window
from narwhals._spark_like.utils import native_to_narwhals_dtype
from narwhals.exceptions import InvalidOperationError
from narwhals.typing import CompliantDataFrame
Expand All @@ -26,11 +27,10 @@
from types import ModuleType

import pyarrow as pa
from pyspark.sql import Column
from pyspark.sql import DataFrame
from pyspark.sql import Window
from pyspark.sql.session import SparkSession
from sqlframe.base.dataframe import BaseDataFrame as _SQLFrameDataFrame
from sqlframe.base.column import Column
from sqlframe.base.dataframe import BaseDataFrame
from sqlframe.base.session import _BaseSession
from sqlframe.base.window import Window
from typing_extensions import Self
from typing_extensions import TypeAlias

Expand All @@ -40,8 +40,8 @@
from narwhals.dtypes import DType
from narwhals.utils import Version

SQLFrameDataFrame: TypeAlias = _SQLFrameDataFrame[Any, Any, Any, Any, Any]
_NativeDataFrame: TypeAlias = "DataFrame | SQLFrameDataFrame"
SQLFrameDataFrame = BaseDataFrame[Any, Any, Any, Any, Any]
SQLFrameSession = _BaseSession[Any, Any, Any, Any, Any, Any, Any]

Incomplete: TypeAlias = Any # pragma: no cover
"""Marker for working code that fails type checking."""
Expand All @@ -50,15 +50,15 @@
class SparkLikeLazyFrame(CompliantLazyFrame):
def __init__(
self: Self,
native_dataframe: _NativeDataFrame,
native_dataframe: SQLFrameDataFrame,
*,
backend_version: tuple[int, ...],
version: Version,
implementation: Implementation,
# Unused, just for compatibility. We only validate when collecting.
validate_column_names: bool = False,
) -> None:
self._native_frame = native_dataframe
self._native_frame: SQLFrameDataFrame = native_dataframe
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added an annotation to try and trigger an [arg-type] from mypy.

Still not sure why pyright is the only one to detect (d5e899a)

self._backend_version = backend_version
self._implementation = implementation
self._version = version
Expand All @@ -68,58 +68,38 @@ def __init__(
@property
def _F(self: Self): # type: ignore[no-untyped-def] # noqa: ANN202, N802
if TYPE_CHECKING:
from pyspark.sql import functions
from sqlframe.base import functions

return functions
if self._implementation is Implementation.SQLFRAME:
from sqlframe.base.session import _BaseSession

return import_module(
f"sqlframe.{_BaseSession().execution_dialect_name}.functions"
)

from pyspark.sql import functions

return functions
else:
return import_functions(self._implementation)

@property
def _native_dtypes(self: Self): # type: ignore[no-untyped-def] # noqa: ANN202
if TYPE_CHECKING:
from pyspark.sql import types
from sqlframe.base import types

return types

if self._implementation is Implementation.SQLFRAME:
from sqlframe.base.session import _BaseSession

return import_module(
f"sqlframe.{_BaseSession().execution_dialect_name}.types"
)

from pyspark.sql import types

return types
else:
return import_native_dtypes(self._implementation)

@property
def _Window(self: Self) -> type[Window]: # noqa: N802
if self._implementation is Implementation.SQLFRAME:
from sqlframe.base.session import _BaseSession

_window = import_module(
f"sqlframe.{_BaseSession().execution_dialect_name}.window"
)
return _window.Window

from pyspark.sql import Window
if TYPE_CHECKING:
from sqlframe.base.window import Window

return Window
return Window
else:
return import_window(self._implementation)

@property
def _session(self: Self) -> SparkSession:
def _session(self: Self) -> SQLFrameSession:
if TYPE_CHECKING:
return self._native_frame.session
if self._implementation is Implementation.SQLFRAME:
return cast("SQLFrameDataFrame", self._native_frame).session
return self._native_frame.session

return cast("DataFrame", self._native_frame).sparkSession
return self._native_frame.sparkSession

def __native_namespace__(self: Self) -> ModuleType: # pragma: no cover
return self._implementation.to_native_namespace()
Expand All @@ -144,7 +124,7 @@ def _change_version(self: Self, version: Version) -> Self:
implementation=self._implementation,
)

def _from_native_frame(self: Self, df: DataFrame) -> Self:
def _from_native_frame(self: Self, df: SQLFrameDataFrame) -> Self:
return self.__class__(
df,
backend_version=self._backend_version,
Expand All @@ -158,7 +138,7 @@ def _collect_to_arrow(self) -> pa.Table:
):
import pyarrow as pa # ignore-banned-import

native_frame = cast("DataFrame", self._native_frame)
native_frame = self._native_frame
try:
return pa.Table.from_batches(native_frame._collect_as_arrow())
except ValueError as exc:
Expand All @@ -174,13 +154,12 @@ def _collect_to_arrow(self) -> pa.Table:
try:
native_dtype = narwhals_to_native_dtype(value, self._version)
except Exception as exc: # noqa: BLE001
native_spark_dtype = native_frame.schema[key].dataType
native_spark_dtype = native_frame.schema[key].dataType # type: ignore[index]
# If we can't convert the type, just set it to `pa.null`, and warn.
# Avoid the warning if we're starting from PySpark's void type.
# We can avoid the check when we introduce `nw.Null` dtype.
if not isinstance(
native_spark_dtype, self._native_dtypes.NullType
):
null_type = self._native_dtypes.NullType # pyright: ignore[reportAttributeAccessIssue]
if not isinstance(native_spark_dtype, null_type):
warnings.warn(
f"Could not convert dtype {native_spark_dtype} to PyArrow dtype, {exc!r}",
stacklevel=find_stacklevel(),
Expand All @@ -192,9 +171,7 @@ def _collect_to_arrow(self) -> pa.Table:
else: # pragma: no cover
raise
else:
# NOTE: See https://github.com/narwhals-dev/narwhals/pull/2051#discussion_r1969224309
to_arrow: Incomplete = self._native_frame.toArrow
return to_arrow()
return self._native_frame.toArrow()

def _iter_columns(self) -> Iterator[Column]:
for col in self.columns:
Expand Down Expand Up @@ -250,7 +227,7 @@ def collect(
raise ValueError(msg) # pragma: no cover

def simple_select(self: Self, *column_names: str) -> Self:
return self._from_native_frame(self._native_frame.select(*column_names)) # pyright: ignore[reportArgumentType]
return self._from_native_frame(self._native_frame.select(*column_names))

def aggregate(
self: Self,
Expand All @@ -259,7 +236,7 @@ def aggregate(
new_columns = evaluate_exprs(self, *exprs)

new_columns_list = [col.alias(col_name) for col_name, col in new_columns]
return self._from_native_frame(self._native_frame.agg(*new_columns_list)) # pyright: ignore[reportArgumentType]
return self._from_native_frame(self._native_frame.agg(*new_columns_list))

def select(
self: Self,
Expand All @@ -274,17 +251,17 @@ def select(
return self._from_native_frame(spark_df)

new_columns_list = [col.alias(col_name) for (col_name, col) in new_columns]
return self._from_native_frame(self._native_frame.select(*new_columns_list)) # pyright: ignore[reportArgumentType]
return self._from_native_frame(self._native_frame.select(*new_columns_list))

def with_columns(self: Self, *exprs: SparkLikeExpr) -> Self:
new_columns = evaluate_exprs(self, *exprs)
return self._from_native_frame(self._native_frame.withColumns(dict(new_columns))) # pyright: ignore[reportArgumentType]
return self._from_native_frame(self._native_frame.withColumns(dict(new_columns)))

def filter(self: Self, predicate: SparkLikeExpr) -> Self:
# `[0]` is safe as the predicate's expression only returns a single column
condition = predicate._call(self)[0]
spark_df = self._native_frame.where(condition) # pyright: ignore[reportArgumentType]
return self._from_native_frame(spark_df) # pyright: ignore[reportArgumentType]
spark_df = self._native_frame.where(condition)
return self._from_native_frame(spark_df)

@property
def schema(self: Self) -> dict[str, DType]:
Expand All @@ -293,8 +270,7 @@ def schema(self: Self) -> dict[str, DType]:
field.name: native_to_narwhals_dtype(
dtype=field.dataType,
version=self._version,
# NOTE: Unclear if this is an unsafe hash (https://github.com/narwhals-dev/narwhals/pull/2051#discussion_r1970074662)
spark_types=self._native_dtypes, # pyright: ignore[reportArgumentType]
spark_types=self._native_dtypes,
)
for field in self._native_frame.schema
}
Expand All @@ -307,10 +283,10 @@ def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
columns_to_drop = parse_columns_to_drop(
compliant_frame=self, columns=columns, strict=strict
)
return self._from_native_frame(self._native_frame.drop(*columns_to_drop)) # pyright: ignore[reportArgumentType]
return self._from_native_frame(self._native_frame.drop(*columns_to_drop))

def head(self: Self, n: int) -> Self:
return self._from_native_frame(self._native_frame.limit(num=n)) # pyright: ignore[reportArgumentType]
return self._from_native_frame(self._native_frame.limit(num=n))

def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroupBy:
from narwhals._spark_like.group_by import SparkLikeLazyGroupBy
Expand Down Expand Up @@ -340,18 +316,18 @@ def sort(
)

sort_cols = [sort_f(col) for col, sort_f in zip(by, sort_funcs)]
return self._from_native_frame(self._native_frame.sort(*sort_cols)) # pyright: ignore[reportArgumentType]
return self._from_native_frame(self._native_frame.sort(*sort_cols))

def drop_nulls(self: Self, subset: list[str] | None) -> Self:
return self._from_native_frame(self._native_frame.dropna(subset=subset)) # pyright: ignore[reportArgumentType]
return self._from_native_frame(self._native_frame.dropna(subset=subset))

def rename(self: Self, mapping: dict[str, str]) -> Self:
rename_mapping = {
colname: mapping.get(colname, colname) for colname in self.columns
}
return self._from_native_frame(
self._native_frame.select(
[self._F.col(old).alias(new) for old, new in rename_mapping.items()] # pyright: ignore[reportArgumentType]
[self._F.col(old).alias(new) for old, new in rename_mapping.items()]
)
)

Expand All @@ -365,7 +341,7 @@ def unique(
msg = "`LazyFrame.unique` with PySpark backend only supports `keep='any'`."
raise ValueError(msg)
check_column_exists(self.columns, subset)
return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset)) # pyright: ignore[reportArgumentType]
return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset))

def join(
self: Self,
Expand Down Expand Up @@ -409,7 +385,7 @@ def join(
]
)
return self._from_native_frame(
self_native.join(other_native, on=left_on, how=how).select(col_order) # pyright: ignore[reportArgumentType]
self_native.join(other_native, on=left_on, how=how).select(col_order)
)

def explode(self: Self, columns: list[str]) -> Self:
Expand Down Expand Up @@ -445,7 +421,7 @@ def explode(self: Self, columns: list[str]) -> Self:
else self._F.explode_outer(col_name).alias(col_name)
for col_name in column_names
]
), # pyright: ignore[reportArgumentType]
)
)
elif self._implementation.is_sqlframe():
# Not every sqlframe dialect supports `explode_outer` function
Expand All @@ -466,14 +442,14 @@ def null_condition(col_name: str) -> Column:
for col_name in column_names
]
).union(
native_frame.filter(null_condition(columns[0])).select( # pyright: ignore[reportArgumentType]
native_frame.filter(null_condition(columns[0])).select(
*[
self._F.col(col_name).alias(col_name)
if col_name != columns[0]
else self._F.lit(None).alias(col_name)
for col_name in column_names
]
) # pyright: ignore[reportArgumentType]
)
),
)
else: # pragma: no cover
Expand Down Expand Up @@ -508,4 +484,4 @@ def unpivot(
)
if index is None:
unpivoted_native_frame = unpivoted_native_frame.drop(*ids)
return self._from_native_frame(unpivoted_native_frame) # pyright: ignore[reportArgumentType]
return self._from_native_frame(unpivoted_native_frame)
Loading
Loading