Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
from narwhals.typing import CompliantLazyFrame


class ArrowDataFrame(CompliantDataFrame, CompliantLazyFrame):
class ArrowDataFrame(CompliantDataFrame["ArrowSeries"], CompliantLazyFrame):
# --- not in the spec ---
def __init__(
self: Self,
Expand Down Expand Up @@ -354,24 +354,24 @@ def simple_select(self, *column_names: str) -> Self:
self._native_frame.select(list(column_names)), validate_column_names=False
)

def aggregate(self: Self, *exprs: ArrowExpr) -> Self:
def aggregate(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does this need changing?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MarcoGorelli
IIRC, pyright is fine with `Self`, but mypy still needs all the extra inline annotations that we have on `main`:

# Aggregate by delegating directly to `select` with the same expressions.
def aggregate(self: Self, *exprs: ArrowExpr) -> Self:
return self.select(*exprs)
# Evaluate `exprs` against this frame and build a new frame from the resulting series.
# The explicit `Sequence[ArrowSeries]` annotation is one of the inline annotations
# this thread says mypy still needs on `main`.
def select(self: Self, *exprs: ArrowExpr) -> Self:
new_series: Sequence[ArrowSeries] = evaluate_into_exprs(self, *exprs)
if not new_series:
# return empty dataframe, like Polars does
return self._from_native_frame(
self._native_frame.__class__.from_arrays([]), validate_column_names=False
)
# Column names are captured before the series are passed through
# `align_series_full_broadcast`.
names = [s.name for s in new_series]
new_series = align_series_full_broadcast(*new_series)
df = pa.Table.from_arrays([s._native_series for s in new_series], names=names)
# Unlike the empty-frame branch above, column names are validated here.
return self._from_native_frame(df, validate_column_names=True)
def with_columns(self: Self, *exprs: ArrowExpr) -> Self:
native_frame = self._native_frame
new_columns: list[ArrowSeries] = evaluate_into_exprs(self, *exprs)

I had a tough time with that as well in (#2055), where I tried moving it into a method:

# Protocol sketch from (#2055): the expression-evaluation helpers are written as
# methods on the frame, generic over the series type `ReuseSeriesT`.
class ReuseDataFrame(CompliantDataFrame, Protocol[ReuseSeriesT]):
# Evaluate one expression against this frame and return the series it produces.
# Raises AssertionError when the names of the returned series do not match the
# aliases computed by `evaluate_output_names_and_aliases`.
def _evaluate_into_expr(
self: ReuseDataFrameT, expr: ReuseExpr[ReuseDataFrameT, ReuseSeriesT], /
) -> Sequence[ReuseSeriesT]:
_, aliases = evaluate_output_names_and_aliases(expr, self, [])
result = expr(self)
if list(aliases) != [s.name for s in result]:
msg = f"Safety assertion failed, expected {aliases}, got {result}"
raise AssertionError(msg)
return result
# Evaluate every expression via `_evaluate_into_expr` and flatten the
# per-expression results into a single list, in argument order.
def _evaluate_into_exprs(
self, *exprs: ReuseExpr[Self, ReuseSeriesT]
) -> Sequence[ReuseSeriesT]:
return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs))
# NOTE: `mypy` is **very** fragile here in what is permitted
# DON'T CHANGE UNLESS IT SOLVES ANOTHER ISSUE
# If `expr` is recognised by `is_reuse_expr`, call it on this frame and return
# its single output series; raise ValueError when it yields more than one.
# Any other value is returned unchanged.
# NOTE(review): an empty result would make `result[0]` raise IndexError —
# confirm that expressions always yield at least one series.
def _maybe_evaluate_expr(
self, expr: ReuseExpr[ReuseDataFrame[ReuseSeriesT], ReuseSeriesT] | T, /
) -> ReuseSeriesT | T:
if is_reuse_expr(expr):
result: Sequence[ReuseSeriesT] = expr(self)
if len(result) > 1:
msg = (
"Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) "
"are not supported in this context"
)
raise ValueError(msg)
return result[0]
return expr
# NOTE: DON'T CHANGE THIS EITHER
# Runtime check: `obj` counts as an expression when it has a `__narwhals_expr__`
# attribute. The `TypeIs` return type lets type checkers narrow `obj` to
# `ReuseExpr[Any, ReuseSeriesT]` in the branch where this returns True.
def is_reuse_expr(
obj: ReuseExpr[Any, ReuseSeriesT] | Any,
) -> TypeIs[ReuseExpr[Any, ReuseSeriesT]]:
return hasattr(obj, "__narwhals_expr__")

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the log from (c11dc95)

narwhals/_arrow/dataframe.py:350: error: Cannot infer type argument 1 of
"evaluate_into_exprs"  [misc]
            new_series: Sequence[ArrowSeries] = evaluate_into_exprs(self, ...
                                                ^~~~~~~~~~~~~~~~~~~~~~~~~~...
narwhals/_arrow/dataframe.py:363: error: Cannot infer type argument 1 of
"evaluate_into_exprs"  [misc]
            new_columns: list[ArrowSeries] = evaluate_into_exprs(self, *ex...
                                             ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~...
narwhals/_arrow/dataframe.py:548: error: Cannot infer type argument 1 of
"evaluate_into_exprs"  [misc]
                mask_native = evaluate_into_exprs(self, predicate)[0]._nat...
                              ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
narwhals/_pandas_like/dataframe.py:395: error: Cannot infer type argument 1 of
"evaluate_into_exprs"  [misc]
            new_series: list[PandasLikeSeries] = evaluate_into_exprs(self,...
                                                 ^~~~~~~~~~~~~~~~~~~~~~~~~...
narwhals/_pandas_like/dataframe.py:445: error: Cannot infer type argument 1 of
"evaluate_into_exprs"  [misc]
                mask = evaluate_into_exprs(self, predicate)[0]
                       ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
narwhals/_pandas_like/dataframe.py:454: error: Cannot infer type argument 1 of
"evaluate_into_exprs"  [misc]
            new_columns: list[PandasLikeSeries] = evaluate_into_exprs(self...
                                                  ^~~~~~~~~~~~~~~~~~~~~~~~...

return self.select(*exprs)

def select(self: Self, *exprs: ArrowExpr) -> Self:
new_series: Sequence[ArrowSeries] = evaluate_into_exprs(self, *exprs)
def select(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame:
new_series = evaluate_into_exprs(self, *exprs)
if not new_series:
# return empty dataframe, like Polars does
return self._from_native_frame(
self._native_frame.__class__.from_arrays([]), validate_column_names=False
)
names = [s.name for s in new_series]
new_series = align_series_full_broadcast(*new_series)
df = pa.Table.from_arrays([s._native_series for s in new_series], names=names)
reshaped = align_series_full_broadcast(*new_series)
df = pa.Table.from_arrays([s._native_series for s in reshaped], names=names)
return self._from_native_frame(df, validate_column_names=True)

def with_columns(self: Self, *exprs: ArrowExpr) -> Self:
def with_columns(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame:
native_frame = self._native_frame
new_columns: list[ArrowSeries] = evaluate_into_exprs(self, *exprs)
new_columns = evaluate_into_exprs(self, *exprs)

length = len(self)
columns = self.columns
Expand Down Expand Up @@ -469,7 +469,7 @@ def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
self._native_frame.drop(to_drop), validate_column_names=False
)

def drop_nulls(self: Self, subset: list[str] | None) -> Self:
def drop_nulls(self: ArrowDataFrame, subset: list[str] | None) -> ArrowDataFrame:
if subset is None:
return self._from_native_frame(
self._native_frame.drop_null(), validate_column_names=False
Expand Down Expand Up @@ -551,7 +551,9 @@ def with_row_index(self: Self, name: str) -> Self:
df.append_column(name, row_indices).select([name, *cols])
)

def filter(self: Self, predicate: ArrowExpr | list[bool | None]) -> Self:
def filter(
self: ArrowDataFrame, predicate: ArrowExpr | list[bool | None]
) -> ArrowDataFrame:
if isinstance(predicate, list):
mask_native: Mask | ArrowChunkedArray = predicate
else:
Expand Down Expand Up @@ -627,7 +629,7 @@ def collect(
self: Self,
backend: Implementation | None,
**kwargs: Any,
) -> CompliantDataFrame:
) -> CompliantDataFrame[Any]:
if backend is Implementation.PYARROW or backend is None:
from narwhals._arrow.dataframe import ArrowDataFrame

Expand Down Expand Up @@ -743,12 +745,12 @@ def is_unique(self: Self) -> ArrowSeries:
)

def unique(
self: Self,
self: ArrowDataFrame,
subset: list[str] | None,
*,
keep: Literal["any", "first", "last", "none"],
maintain_order: bool | None = None,
) -> Self:
) -> ArrowDataFrame:
# The param `maintain_order` is only here for compatibility with the Polars API
# and has no effect on the output.
import numpy as np # ignore-banned-import
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def collect(
self: Self,
backend: Implementation | None,
**kwargs: Any,
) -> CompliantDataFrame:
) -> CompliantDataFrame[Any]:
import pandas as pd

result = self._native_frame.compute(**kwargs)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def collect(
self: Self,
backend: ModuleType | Implementation | str | None,
**kwargs: Any,
) -> CompliantDataFrame:
) -> CompliantDataFrame[Any]:
if backend is None or backend is Implementation.PYARROW:
import pyarrow as pa # ignore-banned-import

Expand Down
30 changes: 15 additions & 15 deletions narwhals/_expression_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from narwhals.expr import Expr
from narwhals.typing import CompliantDataFrame
from narwhals.typing import CompliantExpr
from narwhals.typing import CompliantFrameT_contra
from narwhals.typing import CompliantFrameT
from narwhals.typing import CompliantLazyFrame
from narwhals.typing import CompliantNamespace
from narwhals.typing import CompliantSeries
Expand All @@ -52,8 +52,8 @@ def is_expr(obj: Any) -> TypeIs[Expr]:


def evaluate_into_expr(
df: CompliantFrameT_contra,
expr: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co],
df: CompliantFrameT,
expr: CompliantExpr[CompliantFrameT, CompliantSeriesT_co],
) -> Sequence[CompliantSeriesT_co]:
"""Return list of raw columns.

Expand All @@ -73,9 +73,9 @@ def evaluate_into_expr(


def evaluate_into_exprs(
df: CompliantFrameT_contra,
df: CompliantFrameT,
/,
*exprs: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co],
*exprs: CompliantExpr[CompliantFrameT, CompliantSeriesT_co],
) -> list[CompliantSeriesT_co]:
"""Evaluate each expr into Series."""
return [
Expand All @@ -87,13 +87,13 @@ def evaluate_into_exprs(

@overload
def maybe_evaluate_expr(
df: CompliantFrameT_contra,
expr: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co],
df: CompliantFrameT,
expr: CompliantExpr[CompliantFrameT, CompliantSeriesT_co],
) -> CompliantSeriesT_co: ...


@overload
def maybe_evaluate_expr(df: CompliantDataFrame, expr: T) -> T: ...
def maybe_evaluate_expr(df: CompliantDataFrame[Any], expr: T) -> T: ...


def maybe_evaluate_expr(
Expand Down Expand Up @@ -155,7 +155,7 @@ def reuse_series_implementation(
"""
plx = expr.__narwhals_namespace__()

def func(df: CompliantDataFrame) -> Sequence[CompliantSeries]:
def func(df: CompliantDataFrame[Any]) -> Sequence[CompliantSeries]:
_kwargs = {
**(call_kwargs or {}),
**{
Expand Down Expand Up @@ -258,15 +258,15 @@ def is_simple_aggregation(expr: CompliantExpr[Any, Any]) -> bool:


def combine_evaluate_output_names(
*exprs: CompliantExpr[CompliantFrameT_contra, Any],
) -> Callable[[CompliantFrameT_contra], Sequence[str]]:
*exprs: CompliantExpr[CompliantFrameT, Any],
) -> Callable[[CompliantFrameT], Sequence[str]]:
# Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1, expr2)` takes the
# first name of `expr1`.
if not is_compliant_expr(exprs[0]): # pragma: no cover
msg = f"Safety assertion failed, expected expression, got: {type(exprs[0])}. Please report a bug."
raise AssertionError(msg)

def evaluate_output_names(df: CompliantFrameT_contra) -> Sequence[str]:
def evaluate_output_names(df: CompliantFrameT) -> Sequence[str]:
return exprs[0]._evaluate_output_names(df)[:1]

return evaluate_output_names
Expand All @@ -287,11 +287,11 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]:


def extract_compliant(
plx: CompliantNamespace[CompliantFrameT_contra, CompliantSeriesT_co],
plx: CompliantNamespace[CompliantFrameT, CompliantSeriesT_co],
other: Any,
*,
str_as_lit: bool,
) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | object:
) -> CompliantExpr[CompliantFrameT, CompliantSeriesT_co] | object:
if is_expr(other):
return other._to_compliant_expr(plx)
if isinstance(other, str) and not str_as_lit:
Expand All @@ -306,7 +306,7 @@ def extract_compliant(

def evaluate_output_names_and_aliases(
expr: CompliantExpr[Any, Any],
df: CompliantDataFrame | CompliantLazyFrame,
df: CompliantDataFrame[Any] | CompliantLazyFrame,
exclude: Sequence[str],
) -> tuple[Sequence[str], Sequence[str]]:
output_names = expr._evaluate_output_names(df)
Expand Down
26 changes: 17 additions & 9 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
)


class PandasLikeDataFrame(CompliantDataFrame, CompliantLazyFrame):
class PandasLikeDataFrame(CompliantDataFrame["PandasLikeSeries"], CompliantLazyFrame):
# --- not in the spec ---
def __init__(
self: Self,
Expand Down Expand Up @@ -396,11 +396,13 @@ def simple_select(self: Self, *column_names: str) -> Self:
validate_column_names=False,
)

def aggregate(self: Self, *exprs: PandasLikeExpr) -> Self:
def aggregate(
self: PandasLikeDataFrame, *exprs: PandasLikeExpr
) -> PandasLikeDataFrame:
return self.select(*exprs)

def select(self: Self, *exprs: PandasLikeExpr) -> Self:
new_series: list[PandasLikeSeries] = evaluate_into_exprs(self, *exprs)
def select(self: PandasLikeDataFrame, *exprs: PandasLikeExpr) -> PandasLikeDataFrame:
new_series = evaluate_into_exprs(self, *exprs)
if not new_series:
# return empty dataframe, like Polars does
return self._from_native_frame(
Expand All @@ -414,7 +416,9 @@ def select(self: Self, *exprs: PandasLikeExpr) -> Self:
)
return self._from_native_frame(df, validate_column_names=True)

def drop_nulls(self: Self, subset: list[str] | None) -> Self:
def drop_nulls(
self: PandasLikeDataFrame, subset: list[str] | None
) -> PandasLikeDataFrame:
if subset is None:
return self._from_native_frame(
self._native_frame.dropna(axis=0), validate_column_names=False
Expand Down Expand Up @@ -445,7 +449,9 @@ def with_row_index(self: Self, name: str) -> Self:
def row(self: Self, row: int) -> tuple[Any, ...]:
return tuple(x for x in self._native_frame.iloc[row])

def filter(self: Self, predicate: PandasLikeExpr | list[bool]) -> Self:
def filter(
self: PandasLikeDataFrame, predicate: PandasLikeExpr | list[bool]
) -> PandasLikeDataFrame:
if isinstance(predicate, list):
mask_native: pd.Series[Any] | list[bool] = predicate
else:
Expand All @@ -457,9 +463,11 @@ def filter(self: Self, predicate: PandasLikeExpr | list[bool]) -> Self:
self._native_frame.loc[mask_native], validate_column_names=False
)

def with_columns(self: Self, *exprs: PandasLikeExpr) -> Self:
def with_columns(
self: PandasLikeDataFrame, *exprs: PandasLikeExpr
) -> PandasLikeDataFrame:
index = self._native_frame.index
new_columns: list[PandasLikeSeries] = evaluate_into_exprs(self, *exprs)
new_columns = evaluate_into_exprs(self, *exprs)
if not new_columns and len(self) == 0:
return self

Expand Down Expand Up @@ -528,7 +536,7 @@ def collect(
self: Self,
backend: Implementation | None,
**kwargs: Any,
) -> CompliantDataFrame:
) -> CompliantDataFrame[Any]:
if backend is None:
return PandasLikeDataFrame(
self._native_frame,
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def collect(
self: Self,
backend: Implementation | None,
**kwargs: Any,
) -> CompliantDataFrame:
) -> CompliantDataFrame[Any]:
try:
result = self._native_frame.collect(**kwargs)
except Exception as e: # noqa: BLE001
Expand Down
3 changes: 3 additions & 0 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ def dtype(self: Self) -> DType:
self._native_series.dtype, self._version, self._backend_version
)

# Rename the series: call the native series' `alias(name)` and re-wrap the
# result with `_from_native_object`.
def alias(self, name: str) -> Self:
return self._from_native_object(self._native_series.alias(name))

@overload
def __getitem__(self: Self, item: int) -> Any: ...

Expand Down
2 changes: 1 addition & 1 deletion narwhals/_spark_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def collect(
self: Self,
backend: ModuleType | Implementation | str | None,
**kwargs: Any,
) -> CompliantDataFrame:
) -> CompliantDataFrame[Any]:
if backend is Implementation.PANDAS:
import pandas as pd # ignore-banned-import

Expand Down
34 changes: 17 additions & 17 deletions narwhals/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Generic
from typing import Literal
from typing import Protocol
from typing import Sequence
Expand Down Expand Up @@ -52,7 +51,12 @@ def __narwhals_series__(self) -> CompliantSeries: ...
def alias(self, name: str) -> Self: ...


class CompliantDataFrame(Protocol):
CompliantSeriesT_co = TypeVar(
"CompliantSeriesT_co", bound=CompliantSeries, covariant=True
)


class CompliantDataFrame(Protocol[CompliantSeriesT_co]):
def __narwhals_dataframe__(self) -> Self: ...
def __narwhals_namespace__(self) -> Any: ...
def simple_select(
Expand All @@ -64,6 +68,7 @@ def aggregate(self, *exprs: Any) -> Self:

@property
def columns(self) -> Sequence[str]: ...
def get_column(self, name: str) -> CompliantSeriesT_co: ...
Comment on lines 69 to +71
Copy link
Copy Markdown
Member Author

@dangotbanned dangotbanned Feb 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After this, the only new part in (#2064) should be adding `.schema` to this protocol (`CompliantDataFrame`) and to `CompliantLazyFrame`.

Should I add that here, split it into another PR, or just include it in (#2064)?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MarcoGorelli I chose get_column just to provide a method that uses CompliantSeriesT_co.

The eventual goal will be (#2104 (comment))



class CompliantLazyFrame(Protocol):
Expand All @@ -80,30 +85,25 @@ def aggregate(self, *exprs: Any) -> Self:
def columns(self) -> Sequence[str]: ...


CompliantFrameT_contra = TypeVar(
"CompliantFrameT_contra",
bound="CompliantDataFrame | CompliantLazyFrame",
contravariant=True,
)
CompliantSeriesT_co = TypeVar(
"CompliantSeriesT_co", bound=CompliantSeries, covariant=True
CompliantFrameT = TypeVar(
"CompliantFrameT", bound="CompliantDataFrame[Any] | CompliantLazyFrame"
)


class CompliantExpr(Protocol, Generic[CompliantFrameT_contra, CompliantSeriesT_co]):
class CompliantExpr(Protocol[CompliantFrameT, CompliantSeriesT_co]):
_implementation: Implementation
_backend_version: tuple[int, ...]
_version: Version
_evaluate_output_names: Callable[[CompliantFrameT_contra], Sequence[str]]
_evaluate_output_names: Callable[[CompliantFrameT], Sequence[str]]
_alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None
_depth: int
_function_name: str

def __call__(self, df: Any) -> Sequence[CompliantSeriesT_co]: ...
def __call__(self, df: CompliantFrameT) -> Sequence[CompliantSeriesT_co]: ...
def __narwhals_expr__(self) -> None: ...
def __narwhals_namespace__(
self,
) -> CompliantNamespace[CompliantFrameT_contra, CompliantSeriesT_co]: ...
) -> CompliantNamespace[CompliantFrameT, CompliantSeriesT_co]: ...
def is_null(self) -> Self: ...
def alias(self, name: str) -> Self: ...
def cast(self, dtype: DType) -> Self: ...
Expand All @@ -125,21 +125,21 @@ def broadcast(
) -> Self: ...


class CompliantNamespace(Protocol, Generic[CompliantFrameT_contra, CompliantSeriesT_co]):
class CompliantNamespace(Protocol[CompliantFrameT, CompliantSeriesT_co]):
def col(
self, *column_names: str
) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]: ...
) -> CompliantExpr[CompliantFrameT, CompliantSeriesT_co]: ...
def lit(
self, value: Any, dtype: DType | None
) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]: ...
) -> CompliantExpr[CompliantFrameT, CompliantSeriesT_co]: ...


class SupportsNativeNamespace(Protocol):
def __native_namespace__(self) -> ModuleType: ...


IntoCompliantExpr: TypeAlias = (
"CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | CompliantSeriesT_co"
"CompliantExpr[CompliantFrameT, CompliantSeriesT_co] | CompliantSeriesT_co"
)

IntoExpr: TypeAlias = Union["Expr", str, "Series[Any]"]
Expand Down
Loading
Loading