diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index e35a330974..2cb84475d0 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -71,7 +71,7 @@ from narwhals.typing import CompliantLazyFrame -class ArrowDataFrame(CompliantDataFrame, CompliantLazyFrame): +class ArrowDataFrame(CompliantDataFrame["ArrowSeries"], CompliantLazyFrame): # --- not in the spec --- def __init__( self: Self, @@ -354,24 +354,24 @@ def simple_select(self, *column_names: str) -> Self: self._native_frame.select(list(column_names)), validate_column_names=False ) - def aggregate(self: Self, *exprs: ArrowExpr) -> Self: + def aggregate(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame: return self.select(*exprs) - def select(self: Self, *exprs: ArrowExpr) -> Self: - new_series: Sequence[ArrowSeries] = evaluate_into_exprs(self, *exprs) + def select(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame: + new_series = evaluate_into_exprs(self, *exprs) if not new_series: # return empty dataframe, like Polars does return self._from_native_frame( self._native_frame.__class__.from_arrays([]), validate_column_names=False ) names = [s.name for s in new_series] - new_series = align_series_full_broadcast(*new_series) - df = pa.Table.from_arrays([s._native_series for s in new_series], names=names) + reshaped = align_series_full_broadcast(*new_series) + df = pa.Table.from_arrays([s._native_series for s in reshaped], names=names) return self._from_native_frame(df, validate_column_names=True) - def with_columns(self: Self, *exprs: ArrowExpr) -> Self: + def with_columns(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame: native_frame = self._native_frame - new_columns: list[ArrowSeries] = evaluate_into_exprs(self, *exprs) + new_columns = evaluate_into_exprs(self, *exprs) length = len(self) columns = self.columns @@ -469,7 +469,7 @@ def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001 self._native_frame.drop(to_drop), validate_column_names=False ) - def drop_nulls(self: Self, subset: list[str] | None) -> Self: + def drop_nulls(self: ArrowDataFrame, subset: list[str] | None) -> ArrowDataFrame: if subset is None: return self._from_native_frame( self._native_frame.drop_null(), validate_column_names=False @@ -551,7 +551,9 @@ def with_row_index(self: Self, name: str) -> Self: df.append_column(name, row_indices).select([name, *cols]) ) - def filter(self: Self, predicate: ArrowExpr | list[bool | None]) -> Self: + def filter( + self: ArrowDataFrame, predicate: ArrowExpr | list[bool | None] + ) -> ArrowDataFrame: if isinstance(predicate, list): mask_native: Mask | ArrowChunkedArray = predicate else: @@ -627,7 +629,7 @@ def collect( self: Self, backend: Implementation | None, **kwargs: Any, - ) -> CompliantDataFrame: + ) -> CompliantDataFrame[Any]: if backend is Implementation.PYARROW or backend is None: from narwhals._arrow.dataframe import ArrowDataFrame @@ -743,12 +745,12 @@ def is_unique(self: Self) -> ArrowSeries: ) def unique( - self: Self, + self: ArrowDataFrame, subset: list[str] | None, *, keep: Literal["any", "first", "last", "none"], maintain_order: bool | None = None, - ) -> Self: + ) -> ArrowDataFrame: # The param `maintain_order` is only here for compatibility with the Polars API # and has no effect on the output. import numpy as np # ignore-banned-import diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 2409d7316a..e77d05d742 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -89,7 +89,7 @@ def collect( self: Self, backend: Implementation | None, **kwargs: Any, - ) -> CompliantDataFrame: + ) -> CompliantDataFrame[Any]: import pandas as pd result = self._native_frame.compute(**kwargs) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index e194edd4f3..dc74eae824 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -89,7 +89,7 @@ def collect( self: Self, backend: ModuleType | Implementation | str | None, **kwargs: Any, - ) -> CompliantDataFrame: + ) -> CompliantDataFrame[Any]: if backend is None or backend is Implementation.PYARROW: import pyarrow as pa # ignore-banned-import diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index c3323a5f2e..11819495e6 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -30,7 +30,7 @@ from narwhals.expr import Expr from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantExpr - from narwhals.typing import CompliantFrameT_contra + from narwhals.typing import CompliantFrameT from narwhals.typing import CompliantLazyFrame from narwhals.typing import CompliantNamespace from narwhals.typing import CompliantSeries @@ -52,8 +52,8 @@ def is_expr(obj: Any) -> TypeIs[Expr]: def evaluate_into_expr( - df: CompliantFrameT_contra, - expr: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co], + df: CompliantFrameT, + expr: CompliantExpr[CompliantFrameT, CompliantSeriesT_co], ) -> Sequence[CompliantSeriesT_co]: """Return list of raw columns. @@ -73,9 +73,9 @@ def evaluate_into_expr( def evaluate_into_exprs( - df: CompliantFrameT_contra, + df: CompliantFrameT, /, - *exprs: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co], + *exprs: CompliantExpr[CompliantFrameT, CompliantSeriesT_co], ) -> list[CompliantSeriesT_co]: """Evaluate each expr into Series.""" return [ @@ -87,13 +87,13 @@ def evaluate_into_exprs( @overload def maybe_evaluate_expr( - df: CompliantFrameT_contra, - expr: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co], + df: CompliantFrameT, + expr: CompliantExpr[CompliantFrameT, CompliantSeriesT_co], ) -> CompliantSeriesT_co: ... @overload -def maybe_evaluate_expr(df: CompliantDataFrame, expr: T) -> T: ... +def maybe_evaluate_expr(df: CompliantDataFrame[Any], expr: T) -> T: ... def maybe_evaluate_expr( @@ -155,7 +155,7 @@ def reuse_series_implementation( """ plx = expr.__narwhals_namespace__() - def func(df: CompliantDataFrame) -> Sequence[CompliantSeries]: + def func(df: CompliantDataFrame[Any]) -> Sequence[CompliantSeries]: _kwargs = { **(call_kwargs or {}), **{ @@ -258,15 +258,15 @@ def is_simple_aggregation(expr: CompliantExpr[Any, Any]) -> bool: def combine_evaluate_output_names( - *exprs: CompliantExpr[CompliantFrameT_contra, Any], -) -> Callable[[CompliantFrameT_contra], Sequence[str]]: + *exprs: CompliantExpr[CompliantFrameT, Any], +) -> Callable[[CompliantFrameT], Sequence[str]]: # Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1, expr2)` takes the # first name of `expr1`. if not is_compliant_expr(exprs[0]): # pragma: no cover msg = f"Safety assertion failed, expected expression, got: {type(exprs[0])}. Please report a bug." raise AssertionError(msg) - def evaluate_output_names(df: CompliantFrameT_contra) -> Sequence[str]: + def evaluate_output_names(df: CompliantFrameT) -> Sequence[str]: return exprs[0]._evaluate_output_names(df)[:1] return evaluate_output_names @@ -287,11 +287,11 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: def extract_compliant( - plx: CompliantNamespace[CompliantFrameT_contra, CompliantSeriesT_co], + plx: CompliantNamespace[CompliantFrameT, CompliantSeriesT_co], other: Any, *, str_as_lit: bool, -) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | object: +) -> CompliantExpr[CompliantFrameT, CompliantSeriesT_co] | object: if is_expr(other): return other._to_compliant_expr(plx) if isinstance(other, str) and not str_as_lit: @@ -306,7 +306,7 @@ def extract_compliant( def evaluate_output_names_and_aliases( expr: CompliantExpr[Any, Any], - df: CompliantDataFrame | CompliantLazyFrame, + df: CompliantDataFrame[Any] | CompliantLazyFrame, exclude: Sequence[str], ) -> tuple[Sequence[str], Sequence[str]]: output_names = expr._evaluate_output_names(df) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 7142372982..f2a336146e 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -83,7 +83,7 @@ ) -class PandasLikeDataFrame(CompliantDataFrame, CompliantLazyFrame): +class PandasLikeDataFrame(CompliantDataFrame["PandasLikeSeries"], CompliantLazyFrame): # --- not in the spec --- def __init__( self: Self, @@ -396,11 +396,13 @@ def simple_select(self: Self, *column_names: str) -> Self: validate_column_names=False, ) - def aggregate(self: Self, *exprs: PandasLikeExpr) -> Self: + def aggregate( + self: PandasLikeDataFrame, *exprs: PandasLikeExpr + ) -> PandasLikeDataFrame: return self.select(*exprs) - def select(self: Self, *exprs: PandasLikeExpr) -> Self: - new_series: list[PandasLikeSeries] = evaluate_into_exprs(self, *exprs) + def select(self: PandasLikeDataFrame, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: + new_series = evaluate_into_exprs(self, *exprs) if not new_series: # return empty dataframe, like Polars does return self._from_native_frame( @@ -414,7 +416,9 @@ def select(self: Self, *exprs: PandasLikeExpr) -> Self: ) return self._from_native_frame(df, validate_column_names=True) - def drop_nulls(self: Self, subset: list[str] | None) -> Self: + def drop_nulls( + self: PandasLikeDataFrame, subset: list[str] | None + ) -> PandasLikeDataFrame: if subset is None: return self._from_native_frame( self._native_frame.dropna(axis=0), validate_column_names=False @@ -445,7 +449,9 @@ def with_row_index(self: Self, name: str) -> Self: def row(self: Self, row: int) -> tuple[Any, ...]: return tuple(x for x in self._native_frame.iloc[row]) - def filter(self: Self, predicate: PandasLikeExpr | list[bool]) -> Self: + def filter( + self: PandasLikeDataFrame, predicate: PandasLikeExpr | list[bool] + ) -> PandasLikeDataFrame: if isinstance(predicate, list): mask_native: pd.Series[Any] | list[bool] = predicate else: @@ -457,9 +463,11 @@ def filter(self: Self, predicate: PandasLikeExpr | list[bool]) -> Self: self._native_frame.loc[mask_native], validate_column_names=False ) - def with_columns(self: Self, *exprs: PandasLikeExpr) -> Self: + def with_columns( + self: PandasLikeDataFrame, *exprs: PandasLikeExpr + ) -> PandasLikeDataFrame: index = self._native_frame.index - new_columns: list[PandasLikeSeries] = evaluate_into_exprs(self, *exprs) + new_columns = evaluate_into_exprs(self, *exprs) if not new_columns and len(self) == 0: return self @@ -528,7 +536,7 @@ def collect( self: Self, backend: Implementation | None, **kwargs: Any, - ) -> CompliantDataFrame: + ) -> CompliantDataFrame[Any]: if backend is None: return PandasLikeDataFrame( self._native_frame, diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 30a7590b4c..bc25184e11 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -474,7 +474,7 @@ def collect( self: Self, backend: Implementation | None, **kwargs: Any, - ) -> CompliantDataFrame: + ) -> CompliantDataFrame[Any]: try: result = self._native_frame.collect(**kwargs) except Exception as e: # noqa: BLE001 diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 1f91fbdbad..73133a3bea 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -115,6 +115,9 @@ def dtype(self: Self) -> DType: self._native_series.dtype, self._version, self._backend_version ) + def alias(self, name: str) -> Self: + return self._from_native_object(self._native_series.alias(name)) + @overload def __getitem__(self: Self, item: int) -> Any: ... diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index be9c100606..d4a792e678 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -203,7 +203,7 @@ def collect( self: Self, backend: ModuleType | Implementation | str | None, **kwargs: Any, - ) -> CompliantDataFrame: + ) -> CompliantDataFrame[Any]: if backend is Implementation.PANDAS: import pandas as pd # ignore-banned-import diff --git a/narwhals/typing.py b/narwhals/typing.py index 9c868694a7..ceb4677220 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Generic from typing import Literal from typing import Protocol from typing import Sequence @@ -52,7 +51,12 @@ def __narwhals_series__(self) -> CompliantSeries: ... def alias(self, name: str) -> Self: ... -class CompliantDataFrame(Protocol): +CompliantSeriesT_co = TypeVar( + "CompliantSeriesT_co", bound=CompliantSeries, covariant=True +) + + +class CompliantDataFrame(Protocol[CompliantSeriesT_co]): def __narwhals_dataframe__(self) -> Self: ... def __narwhals_namespace__(self) -> Any: ... def simple_select( @@ -64,6 +68,7 @@ def aggregate(self, *exprs: Any) -> Self: @property def columns(self) -> Sequence[str]: ... + def get_column(self, name: str) -> CompliantSeriesT_co: ... class CompliantLazyFrame(Protocol): @@ -80,30 +85,25 @@ def aggregate(self, *exprs: Any) -> Self: def columns(self) -> Sequence[str]: ... -CompliantFrameT_contra = TypeVar( - "CompliantFrameT_contra", - bound="CompliantDataFrame | CompliantLazyFrame", - contravariant=True, -) -CompliantSeriesT_co = TypeVar( - "CompliantSeriesT_co", bound=CompliantSeries, covariant=True +CompliantFrameT = TypeVar( + "CompliantFrameT", bound="CompliantDataFrame[Any] | CompliantLazyFrame" ) -class CompliantExpr(Protocol, Generic[CompliantFrameT_contra, CompliantSeriesT_co]): +class CompliantExpr(Protocol[CompliantFrameT, CompliantSeriesT_co]): _implementation: Implementation _backend_version: tuple[int, ...] _version: Version - _evaluate_output_names: Callable[[CompliantFrameT_contra], Sequence[str]] + _evaluate_output_names: Callable[[CompliantFrameT], Sequence[str]] _alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None _depth: int _function_name: str - def __call__(self, df: Any) -> Sequence[CompliantSeriesT_co]: ... + def __call__(self, df: CompliantFrameT) -> Sequence[CompliantSeriesT_co]: ... def __narwhals_expr__(self) -> None: ... def __narwhals_namespace__( self, - ) -> CompliantNamespace[CompliantFrameT_contra, CompliantSeriesT_co]: ... + ) -> CompliantNamespace[CompliantFrameT, CompliantSeriesT_co]: ... def is_null(self) -> Self: ... def alias(self, name: str) -> Self: ... def cast(self, dtype: DType) -> Self: ... @@ -125,13 +125,13 @@ def broadcast( ) -> Self: ... -class CompliantNamespace(Protocol, Generic[CompliantFrameT_contra, CompliantSeriesT_co]): +class CompliantNamespace(Protocol[CompliantFrameT, CompliantSeriesT_co]): def col( self, *column_names: str - ) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]: ... + ) -> CompliantExpr[CompliantFrameT, CompliantSeriesT_co]: ... def lit( self, value: Any, dtype: DType | None - ) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]: ... + ) -> CompliantExpr[CompliantFrameT, CompliantSeriesT_co]: ... class SupportsNativeNamespace(Protocol): @@ -139,7 +139,7 @@ def __native_namespace__(self) -> ModuleType: ... IntoCompliantExpr: TypeAlias = ( - "CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | CompliantSeriesT_co" + "CompliantExpr[CompliantFrameT, CompliantSeriesT_co] | CompliantSeriesT_co" ) IntoExpr: TypeAlias = Union["Expr", str, "Series[Any]"] diff --git a/narwhals/utils.py b/narwhals/utils.py index 9d286398c7..6cf39d7ba4 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -54,7 +54,7 @@ from narwhals.series import Series from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantExpr - from narwhals.typing import CompliantFrameT_contra + from narwhals.typing import CompliantFrameT from narwhals.typing import CompliantLazyFrame from narwhals.typing import CompliantSeries from narwhals.typing import CompliantSeriesT_co @@ -1357,7 +1357,9 @@ def _hasattr_static(obj: Any, attr: str) -> bool: return getattr_static(obj, attr, sentinel) is not sentinel -def is_compliant_dataframe(obj: Any) -> TypeIs[CompliantDataFrame]: +def is_compliant_dataframe( + obj: CompliantDataFrame[CompliantSeriesT_co] | Any, +) -> TypeIs[CompliantDataFrame[CompliantSeriesT_co]]: return _hasattr_static(obj, "__narwhals_dataframe__") @@ -1370,8 +1372,8 @@ def is_compliant_series(obj: Any) -> TypeIs[CompliantSeries]: def is_compliant_expr( - obj: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | Any, -) -> TypeIs[CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]]: + obj: CompliantExpr[CompliantFrameT, CompliantSeriesT_co] | Any, +) -> TypeIs[CompliantExpr[CompliantFrameT, CompliantSeriesT_co]]: return hasattr(obj, "__narwhals_expr__")