Skip to content
2 changes: 1 addition & 1 deletion aeon/base/_base_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _preprocess_series(self, X, axis, store_metadata):
self.metadata_ = meta
return self._convert_X(X, axis)

def _check_X(self, X, axis):
def _check_X(self, X, axis: int = 0):
"""Check input X is valid.

Check if the input data is a compatible type, and that this estimator is
Expand Down
83 changes: 45 additions & 38 deletions aeon/forecasting/_ets.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ class ETSForecaster(BaseForecaster):

def __init__(
self,
error_type=ADDITIVE,
trend_type=NONE,
seasonality_type=NONE,
seasonal_period=1,
alpha=0.1,
beta=0.01,
gamma=0.01,
phi=0.99,
horizon=1,
error_type: int = ADDITIVE,
trend_type: int = NONE,
seasonality_type: int = NONE,
seasonal_period: int = 1,
alpha: float = 0.1,
beta: float = 0.01,
gamma: float = 0.01,
phi: float = 0.99,
horizon: int = 1,
):
self.error_type = error_type
self.trend_type = trend_type
Expand Down Expand Up @@ -190,14 +190,14 @@ def _predict(self, y=None, exog=None):
@njit(nogil=NOGIL, cache=CACHE)
def _fit_numba(
data,
error_type,
trend_type,
seasonality_type,
seasonal_period,
alpha,
beta,
gamma,
phi,
error_type: int,
trend_type: int,
seasonality_type: int,
seasonal_period: int,
alpha: float,
beta: float,
gamma: float,
phi: float,
):
n_timepoints = len(data)
level, trend, seasonality = _initialise(
Expand Down Expand Up @@ -236,15 +236,15 @@ def _fit_numba(


def _predict_numba(
trend_type,
seasonality_type,
level,
trend,
seasonality,
phi,
horizon,
n_timepoints,
seasonal_period,
trend_type: int,
seasonality_type: int,
level: float,
trend: float,
seasonality: float,
phi: float,
horizon: int,
n_timepoints: int,
seasonal_period: int,
):
# Generate forecasts based on the final values of level, trend, and seasonals
if phi == 1: # No damping case
Expand All @@ -264,7 +264,7 @@ def _predict_numba(


@njit(nogil=NOGIL, cache=CACHE)
def _initialise(trend_type, seasonality_type, seasonal_period, data):
def _initialise(trend_type: int, seasonality_type: int, seasonal_period: int, data):
"""
Initialize level, trend, and seasonality values for the ETS model.

Expand Down Expand Up @@ -307,17 +307,17 @@ def _initialise(trend_type, seasonality_type, seasonal_period, data):

@njit(nogil=NOGIL, cache=CACHE)
def _update_states(
error_type,
trend_type,
seasonality_type,
level,
trend,
seasonality,
error_type: int,
trend_type: int,
seasonality_type: int,
level: float,
trend: float,
seasonality: float,
data_item: int,
alpha,
beta,
gamma,
phi,
alpha: float,
beta: float,
gamma: float,
phi: float,
):
"""
Update level, trend, and seasonality components.
Expand Down Expand Up @@ -374,7 +374,14 @@ def _update_states(


@njit(nogil=NOGIL, cache=CACHE)
def _predict_value(trend_type, seasonality_type, level, trend, seasonality, phi):
def _predict_value(
trend_type: int,
seasonality_type: int,
level: float,
trend: float,
seasonality: float,
phi: float,
):
"""

Generate various useful values, including the next fitted value.
Expand Down
4 changes: 2 additions & 2 deletions aeon/forecasting/_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class RegressionForecaster(BaseForecaster):
with sklearn regressors.
"""

def __init__(self, window, horizon=1, regressor=None):
def __init__(self, window: int, horizon: int = 1, regressor=None):
self.window = window
self.regressor = regressor
super().__init__(horizon=horizon, axis=1)
Expand Down Expand Up @@ -123,7 +123,7 @@ def _forecast(self, y, exog=None):
return self.predict()

@classmethod
def _get_test_params(cls, parameter_set="default"):
def _get_test_params(cls, parameter_set: str = "default"):
"""Return testing parameter settings for the estimator.

Parameters
Expand Down
2 changes: 1 addition & 1 deletion aeon/forecasting/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class BaseForecaster(BaseSeriesEstimator):
"y_inner_type": "np.ndarray",
}

def __init__(self, horizon, axis):
def __init__(self, horizon: int, axis: int):
self.horizon = horizon
self.meta_ = None # Meta data related to y on the last fit
super().__init__(axis)
Expand Down