Skip to content

Commit bf68b09

Browse files
committed
Minor improvements like single db connection, removal of default creds & validation for mutually exclusive fields
Signed-off-by: Aniket Paluskar <apaluska@redhat.com>
1 parent 8a3d2db commit bf68b09

File tree

1 file changed

+30
-14
lines changed
  • sdk/python/feast/infra/offline_stores/contrib/oracle_offline_store

1 file changed

+30
-14
lines changed

sdk/python/feast/infra/offline_stores/contrib/oracle_offline_store/oracle.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pandas as pd
77
import pyarrow
88
from ibis.expr.types import Table
9-
from pydantic import StrictInt, StrictStr
9+
from pydantic import StrictInt, StrictStr, model_validator
1010

1111
from feast.data_source import DataSource
1212
from feast.feature_logging import LoggingConfig, LoggingSource
@@ -61,29 +61,31 @@ def _read_oracle_table(con, data_source: DataSource) -> Table:
6161
return con.table(data_source.table_ref)
6262

6363

64-
def _build_data_source_reader(config: RepoConfig):
64+
def _build_data_source_reader(config: RepoConfig, con=None):
6565
"""Build a reader that returns Oracle-backend ibis tables.
6666
6767
Used by ``pull_latest`` and ``pull_all`` where all operations happen on a
6868
single backend (Oracle) and no cross-backend joins are needed.
6969
"""
70-
con = get_ibis_connection(config)
70+
if con is None:
71+
con = get_ibis_connection(config)
7172

7273
def _read_data_source(data_source: DataSource, repo_path: str = "") -> Table:
7374
return _read_oracle_table(con, data_source)
7475

7576
return _read_data_source
7677

7778

78-
def _build_data_source_reader_for_retrieval(config: RepoConfig):
79+
def _build_data_source_reader_for_retrieval(config: RepoConfig, con=None):
7980
"""Build a reader that materializes Oracle data into an in-memory table.
8081
8182
Used by ``get_historical_features`` which joins feature tables with an
8283
in-memory entity table (``ibis.memtable``). Both sides must be on the
8384
same backend for computed columns like ``entity_row_id`` to survive the
8485
join — converting to memtable ensures this.
8586
"""
86-
con = get_ibis_connection(config)
87+
if con is None:
88+
con = get_ibis_connection(config)
8789

8890
def _read_data_source(data_source: DataSource, repo_path: str = "") -> Table:
8991
table = _read_oracle_table(con, data_source)
@@ -92,9 +94,10 @@ def _read_data_source(data_source: DataSource, repo_path: str = "") -> Table:
9294
return _read_data_source
9395

9496

95-
def _build_data_source_writer(config: RepoConfig):
97+
def _build_data_source_writer(config: RepoConfig, con=None):
9698
"""Build a function that writes data to an Oracle table via ibis."""
97-
con = get_ibis_connection(config)
99+
if con is None:
100+
con = get_ibis_connection(config)
98101

99102
def _write_data_source(
100103
table: Table,
@@ -125,10 +128,10 @@ class OracleOfflineStoreConfig(FeastConfigBaseModel):
125128
type: Literal["oracle"] = "oracle"
126129
"""Offline store type selector"""
127130

128-
user: StrictStr = "system"
131+
user: StrictStr
129132
"""Oracle database user"""
130133

131-
password: StrictStr = "oracle123"
134+
password: StrictStr
132135
"""Oracle database password"""
133136

134137
host: StrictStr = "localhost"
@@ -149,6 +152,18 @@ class OracleOfflineStoreConfig(FeastConfigBaseModel):
149152
dsn: Optional[StrictStr] = None
150153
"""Oracle DSN string (mutually exclusive with service_name and sid)"""
151154

155+
@model_validator(mode="after")
156+
def _validate_connection_params(self):
157+
exclusive = [
158+
f for f in ("service_name", "sid", "dsn") if getattr(self, f) is not None
159+
]
160+
if len(exclusive) > 1:
161+
raise ValueError(
162+
f"Only one of 'service_name', 'sid', or 'dsn' may be set, "
163+
f"but got: {', '.join(exclusive)}"
164+
)
165+
return self
166+
152167

153168
class OracleOfflineStore(OfflineStore):
154169
@staticmethod
@@ -186,6 +201,9 @@ def get_historical_features(
186201
full_feature_names: bool = False,
187202
**kwargs,
188203
) -> RetrievalJob:
204+
# Single connection reused across the entire call.
205+
con = get_ibis_connection(config)
206+
189207
# Handle non-entity retrieval mode (start_date/end_date only)
190208
if entity_df is None:
191209
start_date: Optional[datetime] = kwargs.get("start_date")
@@ -212,7 +230,6 @@ def get_historical_features(
212230
start_date = start_date.replace(tzinfo=timezone.utc)
213231

214232
# Build a synthetic entity_df from the feature source data
215-
con = get_ibis_connection(config)
216233
all_entities: set = set()
217234
for fv in feature_views:
218235
all_entities.update(e.name for e in fv.entity_columns)
@@ -234,8 +251,7 @@ def get_historical_features(
234251
entity_df = pd.concat(entity_dfs, ignore_index=True).drop_duplicates()
235252

236253
# If entity_df is a SQL string, execute it to get a DataFrame
237-
if type(entity_df) == str:
238-
con = get_ibis_connection(config)
254+
if isinstance(entity_df, str):
239255
entity_df = con.sql(entity_df).execute()
240256

241257
# Use the retrieval reader which materializes Oracle data into
@@ -249,8 +265,8 @@ def get_historical_features(
249265
registry=registry,
250266
project=project,
251267
full_feature_names=full_feature_names,
252-
data_source_reader=_build_data_source_reader_for_retrieval(config),
253-
data_source_writer=_build_data_source_writer(config),
268+
data_source_reader=_build_data_source_reader_for_retrieval(config, con=con),
269+
data_source_writer=_build_data_source_writer(config, con=con),
254270
)
255271

256272
@staticmethod

0 commit comments

Comments
 (0)