Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions pytorch_forecasting/data/_encoders_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import warnings

import numpy as np
import pandas as pd


class D1CategoricalEncoder:
"""
Categorical Label Encoder for the v2 D1 Layer.
Scans specified categorical columns and maps string/text values to integers.
"""

def __init__(
self,
columns: str | list[str] = None,
handle_unknown: str = "assign_new",
):
"""
Args:
columns: List of column names to encode. If None, it will encode
all object/category columns.
handle_unknown: How to handle unseen categories during transform.
'assign_new' gives them a new integer (like 0).
"""
self.columns = [columns] if isinstance(columns, str) else columns
self.handle_unknown = handle_unknown

self.mapping_: dict[str, dict[str, int]] = {}
self.inverse_mapping_: dict[str, dict[int, str]] = {}
self._is_fitted = False
self._warned_cols = set()

def fit(self, df: pd.DataFrame):
"""Learns the vocabulary from the dataframe."""
if self.columns is None:
self.columns = df.select_dtypes(

Check warning on line 36 in pytorch_forecasting/data/_encoders_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (ubuntu-latest, 3.14)

For backward compatibility, 'str' dtypes are included by select_dtypes when 'object' dtype is specified. This behavior is deprecated and will be removed in a future version. Explicitly pass 'str' to `include` to select them, or to `exclude` to remove them and silence this warning. See https://pandas.pydata.org/docs/user_guide/migration-3-strings.html#string-migration-select-dtypes for details on how to write code that works with pandas 2 and 3.

Check warning on line 36 in pytorch_forecasting/data/_encoders_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (macos-latest, 3.11)

For backward compatibility, 'str' dtypes are included by select_dtypes when 'object' dtype is specified. This behavior is deprecated and will be removed in a future version. Explicitly pass 'str' to `include` to select them, or to `exclude` to remove them and silence this warning. See https://pandas.pydata.org/docs/user_guide/migration-3-strings.html#string-migration-select-dtypes for details on how to write code that works with pandas 2 and 3.

Check warning on line 36 in pytorch_forecasting/data/_encoders_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (macos-latest, 3.13)

For backward compatibility, 'str' dtypes are included by select_dtypes when 'object' dtype is specified. This behavior is deprecated and will be removed in a future version. Explicitly pass 'str' to `include` to select them, or to `exclude` to remove them and silence this warning. See https://pandas.pydata.org/docs/user_guide/migration-3-strings.html#string-migration-select-dtypes for details on how to write code that works with pandas 2 and 3.

Check warning on line 36 in pytorch_forecasting/data/_encoders_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (macos-latest, 3.14)

For backward compatibility, 'str' dtypes are included by select_dtypes when 'object' dtype is specified. This behavior is deprecated and will be removed in a future version. Explicitly pass 'str' to `include` to select them, or to `exclude` to remove them and silence this warning. See https://pandas.pydata.org/docs/user_guide/migration-3-strings.html#string-migration-select-dtypes for details on how to write code that works with pandas 2 and 3.

Check warning on line 36 in pytorch_forecasting/data/_encoders_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (ubuntu-latest, 3.11)

For backward compatibility, 'str' dtypes are included by select_dtypes when 'object' dtype is specified. This behavior is deprecated and will be removed in a future version. Explicitly pass 'str' to `include` to select them, or to `exclude` to remove them and silence this warning. See https://pandas.pydata.org/docs/user_guide/migration-3-strings.html#string-migration-select-dtypes for details on how to write code that works with pandas 2 and 3.

Check warning on line 36 in pytorch_forecasting/data/_encoders_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (ubuntu-latest, 3.12)

For backward compatibility, 'str' dtypes are included by select_dtypes when 'object' dtype is specified. This behavior is deprecated and will be removed in a future version. Explicitly pass 'str' to `include` to select them, or to `exclude` to remove them and silence this warning. See https://pandas.pydata.org/docs/user_guide/migration-3-strings.html#string-migration-select-dtypes for details on how to write code that works with pandas 2 and 3.

Check warning on line 36 in pytorch_forecasting/data/_encoders_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (macos-latest, 3.12)

For backward compatibility, 'str' dtypes are included by select_dtypes when 'object' dtype is specified. This behavior is deprecated and will be removed in a future version. Explicitly pass 'str' to `include` to select them, or to `exclude` to remove them and silence this warning. See https://pandas.pydata.org/docs/user_guide/migration-3-strings.html#string-migration-select-dtypes for details on how to write code that works with pandas 2 and 3.

Check warning on line 36 in pytorch_forecasting/data/_encoders_v2.py

View workflow job for this annotation

GitHub Actions / no-softdeps (ubuntu-latest, 3.13)

For backward compatibility, 'str' dtypes are included by select_dtypes when 'object' dtype is specified. This behavior is deprecated and will be removed in a future version. Explicitly pass 'str' to `include` to select them, or to `exclude` to remove them and silence this warning. See https://pandas.pydata.org/docs/user_guide/migration-3-strings.html#string-migration-select-dtypes for details on how to write code that works with pandas 2 and 3.

Check warning on line 36 in pytorch_forecasting/data/_encoders_v2.py

View workflow job for this annotation

GitHub Actions / Run pytest (macos-latest, 3.11)

For backward compatibility, 'str' dtypes are included by select_dtypes when 'object' dtype is specified. This behavior is deprecated and will be removed in a future version. Explicitly pass 'str' to `include` to select them, or to `exclude` to remove them and silence this warning. See https://pandas.pydata.org/docs/user_guide/migration-3-strings.html#string-migration-select-dtypes for details on how to write code that works with pandas 2 and 3.

Check warning on line 36 in pytorch_forecasting/data/_encoders_v2.py

View workflow job for this annotation

GitHub Actions / Run pytest (macos-latest, 3.14)

For backward compatibility, 'str' dtypes are included by select_dtypes when 'object' dtype is specified. This behavior is deprecated and will be removed in a future version. Explicitly pass 'str' to `include` to select them, or to `exclude` to remove them and silence this warning. See https://pandas.pydata.org/docs/user_guide/migration-3-strings.html#string-migration-select-dtypes for details on how to write code that works with pandas 2 and 3.

Check warning on line 36 in pytorch_forecasting/data/_encoders_v2.py

View workflow job for this annotation

GitHub Actions / Run pytest (macos-latest, 3.12)

For backward compatibility, 'str' dtypes are included by select_dtypes when 'object' dtype is specified. This behavior is deprecated and will be removed in a future version. Explicitly pass 'str' to `include` to select them, or to `exclude` to remove them and silence this warning. See https://pandas.pydata.org/docs/user_guide/migration-3-strings.html#string-migration-select-dtypes for details on how to write code that works with pandas 2 and 3.

Check warning on line 36 in pytorch_forecasting/data/_encoders_v2.py

View workflow job for this annotation

GitHub Actions / Run pytest (macos-latest, 3.13)

For backward compatibility, 'str' dtypes are included by select_dtypes when 'object' dtype is specified. This behavior is deprecated and will be removed in a future version. Explicitly pass 'str' to `include` to select them, or to `exclude` to remove them and silence this warning. See https://pandas.pydata.org/docs/user_guide/migration-3-strings.html#string-migration-select-dtypes for details on how to write code that works with pandas 2 and 3.

Check warning on line 36 in pytorch_forecasting/data/_encoders_v2.py

View workflow job for this annotation

GitHub Actions / Run pytest (ubuntu-latest, 3.14)

For backward compatibility, 'str' dtypes are included by select_dtypes when 'object' dtype is specified. This behavior is deprecated and will be removed in a future version. Explicitly pass 'str' to `include` to select them, or to `exclude` to remove them and silence this warning. See https://pandas.pydata.org/docs/user_guide/migration-3-strings.html#string-migration-select-dtypes for details on how to write code that works with pandas 2 and 3.

Check warning on line 36 in pytorch_forecasting/data/_encoders_v2.py

View workflow job for this annotation

GitHub Actions / Run pytest (ubuntu-latest, 3.11)

For backward compatibility, 'str' dtypes are included by select_dtypes when 'object' dtype is specified. This behavior is deprecated and will be removed in a future version. Explicitly pass 'str' to `include` to select them, or to `exclude` to remove them and silence this warning. See https://pandas.pydata.org/docs/user_guide/migration-3-strings.html#string-migration-select-dtypes for details on how to write code that works with pandas 2 and 3.

Check warning on line 36 in pytorch_forecasting/data/_encoders_v2.py

View workflow job for this annotation

GitHub Actions / Run pytest (ubuntu-latest, 3.12)

For backward compatibility, 'str' dtypes are included by select_dtypes when 'object' dtype is specified. This behavior is deprecated and will be removed in a future version. Explicitly pass 'str' to `include` to select them, or to `exclude` to remove them and silence this warning. See https://pandas.pydata.org/docs/user_guide/migration-3-strings.html#string-migration-select-dtypes for details on how to write code that works with pandas 2 and 3.

Check warning on line 36 in pytorch_forecasting/data/_encoders_v2.py

View workflow job for this annotation

GitHub Actions / Run pytest (ubuntu-latest, 3.13)

For backward compatibility, 'str' dtypes are included by select_dtypes when 'object' dtype is specified. This behavior is deprecated and will be removed in a future version. Explicitly pass 'str' to `include` to select them, or to `exclude` to remove them and silence this warning. See https://pandas.pydata.org/docs/user_guide/migration-3-strings.html#string-migration-select-dtypes for details on how to write code that works with pandas 2 and 3.
include=["object", "category"]
).columns.tolist()

for col in self.columns:
if col not in df.columns:
raise ValueError(f"Column '{col}' not found in dataframe.")

series = df[col].fillna("NaN_CATEGORY")

_, uniques = pd.factorize(series, sort=True)

self.mapping_[col] = {val: idx + 1 for idx, val in enumerate(uniques)}

self.inverse_mapping_[col] = {
idx: val for val, idx in self.mapping_[col].items()
}

self._is_fitted = True
return self

def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""Applies the integer translation to the dataframe."""
if not self._is_fitted:
raise RuntimeError("You must call fit() before transform().")

df_encoded = df.copy()

for col in self.columns:
if col not in df_encoded.columns:
continue

series = df_encoded[col].fillna("NaN_CATEGORY")

encoded_col = series.map(self.mapping_[col])

if encoded_col.isna().any():
if self.handle_unknown == "assign_new":
encoded_col = encoded_col.fillna(0)
if col not in self._warned_cols:
warnings.warn(
f"Unseen categories found in column '{col}'. "
"Assigned to index 0."
)
self._warned_cols.add(col)
else:
raise ValueError(
f"Unseen categories found in column '{col}' "
"and handle_unknown!='assign_new'"
)

df_encoded[col] = encoded_col.astype(int)

return df_encoded

def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""Translates the integers back into original text/categories."""
if not self._is_fitted:
raise RuntimeError("You must call fit() before inverse_transform().")

df_decoded = df.copy()

for col in self.columns:
if col not in df_decoded.columns:
continue

df_decoded[col] = df_decoded[col].map(self.inverse_mapping_[col])
df_decoded[col] = df_decoded[col].replace("NaN_CATEGORY", np.nan)

return df_decoded
12 changes: 12 additions & 0 deletions pytorch_forecasting/data/timeseries/_timeseries_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch
from torch.utils.data import Dataset

from pytorch_forecasting.data._encoders_v2 import D1CategoricalEncoder
from pytorch_forecasting.utils._coerce import _coerce_to_list

#######################################################################################
Expand Down Expand Up @@ -125,6 +126,17 @@
self._unknown = _coerce_to_list(unknown)
self._static = _coerce_to_list(static)

self.categorical_encoder = None
if self._cat:
self.categorical_encoder = D1CategoricalEncoder(columns=self._cat)

self.categorical_encoder.fit(self.data)

self.data = self.categorical_encoder.transform(self.data)

if self.data_future is not None:
self.data_future = self.categorical_encoder.transform(self.data_future)

self.feature_cols = [
col
for col in data.columns
Expand All @@ -136,7 +148,7 @@
if isinstance(self._group, (list, tuple)) and len(self._group) == 1
else self._group
)
self._groups = self.data.groupby(group_arg).groups

Check warning on line 151 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / test-deps-2025 (ubuntu-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 151 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / test-deps-2025 (ubuntu-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 151 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / test-deps-2025 (ubuntu-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 151 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / test-deps-2025 (ubuntu-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.
self._group_ids = list(self._groups.keys())
else:
self._groups = {"_single_group": self.data.index}
Expand Down
91 changes: 91 additions & 0 deletions tests/test_data/test_encoders_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import numpy as np
import pandas as pd
import pytest

from pytorch_forecasting.data._encoders_v2 import D1CategoricalEncoder


@pytest.fixture
def sample_data():
    """Build a small frame with two categorical columns and one numeric one."""
    columns = {
        "cat1": ["a", "b", "c", "a", np.nan],
        "cat2": ["x", "y", "z", "y", "x"],
        "num1": [1.1, 2.2, 3.3, 4.4, 5.5],
    }
    return pd.DataFrame(columns)


def test_encoder_fit_transform(sample_data):
    """Encoded columns become integers while numeric columns stay untouched."""
    encoder = D1CategoricalEncoder(columns=["cat1", "cat2"])
    encoder.fit(sample_data)
    result = encoder.transform(sample_data)

    # Both categorical columns must come back as integer codes.
    for name in ("cat1", "cat2"):
        assert pd.api.types.is_integer_dtype(result[name])

    # The numeric column is passed through unchanged.
    assert result["num1"].equals(sample_data["num1"])

    # The missing value in cat1 must have received a code of its own.
    assert not result["cat1"].isna().any()


def test_encoder_inverse_transform(sample_data):
    """A transform/inverse_transform round trip restores values and NaNs."""
    encoder = D1CategoricalEncoder(columns=["cat1"])
    encoder.fit(sample_data)

    round_tripped = encoder.inverse_transform(encoder.transform(sample_data))

    pd.testing.assert_series_equal(round_tripped["cat1"], sample_data["cat1"])


def test_unseen_variables_warning(sample_data):
    """An unseen category triggers a warning and is encoded as 0."""
    encoder = D1CategoricalEncoder(columns=["cat1"], handle_unknown="assign_new")
    encoder.fit(sample_data)

    # "q" was never seen during fit, "a" was.
    unseen_frame = pd.DataFrame({"cat1": ["q", "a"]})
    expected_message = "Unseen categories found in column 'cat1'"

    with pytest.warns(UserWarning, match=expected_message):
        result = encoder.transform(unseen_frame)

    assert result.loc[0, "cat1"] == 0


def test_only_categorical_columns_selected(sample_data):
    """Ensures exactly the categorical columns are encoded when columns=None."""
    encoder = D1CategoricalEncoder()
    encoder.fit(sample_data)

    assert "num1" not in encoder.mapping_
    # Checking only for the absence of "num1" would also pass if nothing had
    # been detected at all, so pin the detected set as well.
    assert set(encoder.mapping_) == {"cat1", "cat2"}


def test_unfitted_errors(sample_data):
    """Both transform directions refuse to run on an unfitted encoder."""
    encoder = D1CategoricalEncoder(columns=["cat1"])

    for method in (encoder.transform, encoder.inverse_transform):
        with pytest.raises(RuntimeError, match="You must call fit"):
            method(sample_data)


def test_missing_column_error(sample_data):
    """Fitting on a column absent from the dataframe raises a ValueError."""
    requested = ["phantom_column"]
    encoder = D1CategoricalEncoder(columns=requested)

    with pytest.raises(ValueError, match="not found in dataframe"):
        encoder.fit(sample_data)


def test_invalid_handle_unknown_strategy(sample_data):
    """Unseen categories raise when the strategy is not 'assign_new'."""
    encoder = D1CategoricalEncoder(columns=["cat1"], handle_unknown="strict")
    encoder.fit(sample_data)

    unseen_frame = pd.DataFrame({"cat1": ["unseen_string"]})
    expected_message = "handle_unknown!='assign_new'"

    with pytest.raises(ValueError, match=expected_message):
        encoder.transform(unseen_frame)
Loading