Skip to content

Commit fc2b9f3

Browse files
committed
Fix primary_id and add test cases for in-memory metadata store.
1 parent 876979d commit fc2b9f3

File tree

11 files changed

+70
-67
lines changed

11 files changed

+70
-67
lines changed

.github/workflows/ci_code.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,10 @@ jobs:
7474
run: |
7575
make unit_testing pytest_arguments="--cov=superduper --cov-report=xml" SUPERDUPER_CONFIG=test/configs/default.yaml
7676
make unit_testing pytest_arguments="--cov=superduper --cov-report=xml" SUPERDUPER_CONFIG=test/configs/sql.yaml
77+
make unit_testing pytest_arguments="--cov=superduper --cov-report=xml" SUPERDUPER_CONFIG=test/configs/inmemory.yaml
7778
7879
- name: Usecase Testing
7980
run: |
8081
make usecase_testing SUPERDUPER_CONFIG=test/configs/default.yaml
8182
make usecase_testing SUPERDUPER_CONFIG=test/configs/sql.yaml
83+
make usecase_testing SUPERDUPER_CONFIG=test/configs/inmemory.yaml

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
6262
- Fix the primary_id of ArtifactRelations
6363
- Fix the snowflake and KeyedDatabackend
6464
- Fix a bug where Ibis throws Ibis.TableError but the framework waits for MetadataNoExists
65+
- Fix primary_id and add test cases for in-memory metadata store.
6566

6667
## [0.5.0](https://github.com/superduper-io/superduper/compare/0.5.0...0.4.0) (2024-Nov-02)
6768

plugins/snowflake/superduper_snowflake/data_backend.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def missing_outputs(self, query, predict_id):
329329
:param query: The query to get the missing outputs of.
330330
:param predict_id: The identifier of the output destination.
331331
"""
332-
pid = self.primary_id(query)
332+
pid = self.primary_id(query.table)
333333
df = map_superduper_query_to_snowpark_query(self.session, query, pid)
334334
output_df = self.session.table(f'"{CFG.output_prefix + predict_id}"')
335335
columns = output_df.columns
@@ -348,18 +348,6 @@ def missing_outputs(self, query, predict_id):
348348
.tolist()
349349
)
350350

351-
def primary_id(self, query: Query) -> str:
352-
"""Get the primary id of a query.
353-
354-
:param query: The query to get the primary id of.
355-
"""
356-
return (
357-
self.get_table(query.table)
358-
.schema[0]
359-
.name.removeprefix('"')
360-
.removesuffix('"')
361-
)
362-
363351
def _build_schema(self, query: Query):
364352
"""Build the schema of a query.
365353
@@ -375,7 +363,7 @@ def select(self, query: Query, primary_id: str | None = None) -> t.List[t.Dict]:
375363
q = map_superduper_query_to_snowpark_query(
376364
self.session,
377365
query,
378-
primary_id or self.primary_id(query),
366+
primary_id or self.primary_id(query.table),
379367
)
380368
start = time.time()
381369
logging.info(f"Executing query: {query}")

plugins/sql/superduper_sql/data_backend.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def delete(self, table, condition):
224224

225225
def insert(self, table, documents):
226226
"""Insert data into the database."""
227-
primary_id = self.primary_id(self.db[table])
227+
primary_id = self.primary_id(table)
228228
for r in documents:
229229
if primary_id not in r:
230230
r[primary_id] = str(uuid.uuid4())
@@ -244,19 +244,13 @@ def insert(self, table, documents):
244244
def missing_outputs(self, query, predict_id: str) -> t.List[str]:
245245
"""Get missing outputs from the database."""
246246
with self.connection_manager.get_connection() as conn:
247-
pid = self.primary_id(query)
247+
pid = self.primary_id(query.table)
248248
query = self._build_native_query(conn, query)
249249
output_table = conn.table(f"{CFG.output_prefix}{predict_id}")
250250
q = query.anti_join(output_table, output_table["_source"] == query[pid])
251251
rows = q.execute().to_dict(orient="records")
252252
return [r[pid] for r in rows]
253253

254-
def primary_id(self, query):
255-
"""Get the primary ID of the query."""
256-
return self.db.metadata.get_component(
257-
component="Table", identifier=query.table
258-
)["primary_id"]
259-
260254
def select(self, query):
261255
"""Select data from the database."""
262256
with self.connection_manager.get_connection() as conn:
@@ -279,7 +273,7 @@ def _build_native_query(self, conn, query):
279273
args = []
280274
for a in part.args:
281275
if isinstance(a, Query) and str(a).endswith(".primary_id"):
282-
args.append(self.primary_id(query))
276+
args.append(self.primary_id(query.table))
283277
elif isinstance(a, Query):
284278
args.append(self._build_native_query(conn, a))
285279
else:
@@ -288,7 +282,7 @@ def _build_native_query(self, conn, query):
288282
kwargs = {}
289283
for k, v in part.kwargs.items():
290284
if isinstance(a, Query) and str(a).endswith(".primary_id"):
291-
args.append(self.primary_id(query))
285+
args.append(self.primary_id(query.table))
292286
elif isinstance(v, Query):
293287
kwargs[k] = self._build_native_query(conn, v)
294288
else:
@@ -311,7 +305,7 @@ def _build_native_query(self, conn, query):
311305

312306
elif isinstance(part, QueryPart) and part.name == "outputs":
313307
if pid is None:
314-
pid = self.primary_id(query)
308+
pid = self.primary_id(query.table)
315309

316310
original_q = q
317311
for predict_id in part.args:
@@ -323,7 +317,7 @@ def _build_native_query(self, conn, query):
323317
elif isinstance(part, str):
324318
if part == "primary_id":
325319
if pid is None:
326-
pid = self.primary_id(query)
320+
pid = self.primary_id(query.table)
327321
part = pid
328322
q = q[part]
329323
else:

superduper/backends/base/data_backend.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,12 @@ def missing_outputs(self, query: Query, predict_id: str) -> t.List[str]:
177177
:param predict_id: The predict id.
178178
"""
179179

180-
@abstractmethod
181-
def primary_id(self, query: Query) -> str:
182-
"""Get the primary id of a query.
180+
def primary_id(self, table: str) -> str:
181+
"""Get the primary id of a table.
183182
184-
:param query: The query to get the primary id of.
183+
:param table: The table to get the primary id of.
185184
"""
185+
return self.db.metadata.get_primary_id(table)
186186

187187
@abstractmethod
188188
def select(self, query: Query) -> t.List[t.Dict]:
@@ -222,7 +222,7 @@ def get(self, query: Query, raw: bool = False):
222222
return None
223223

224224
def _wrap_results(self, query: Query, result, schema, raw: bool = False):
225-
pid = self.primary_id(query)
225+
pid = self.primary_id(query.table)
226226
for r in result:
227227
if pid in r:
228228
r[pid] = str(r[pid])
@@ -326,7 +326,7 @@ def pre_like(self, query: Query, **kwargs):
326326

327327
results = new.execute(**kwargs)
328328

329-
pid = self.primary_id(query)
329+
pid = self.primary_id(query.table)
330330
for r in results:
331331
r['score'] = lookup[r[pid]]
332332

@@ -356,7 +356,7 @@ def post_like(self, query: Query, **kwargs):
356356

357357
results = prepare_query.filter(t.primary_id.isin(ids)).execute(**kwargs)
358358

359-
pid = self.primary_id(query)
359+
pid = self.primary_id(query.table)
360360

361361
for r in results:
362362
r['score'] = lookup[r[pid]]
@@ -484,7 +484,7 @@ def delete(self, table, condition):
484484
r_table = self._get_with_component_identifier('Table', table)
485485

486486
if not r_table['is_component']:
487-
pid = r_table['primary_id']
487+
pid = self.primary_id(table)
488488
if pid in condition:
489489
docs = self.get_many(table, condition[pid])
490490
else:
@@ -571,7 +571,7 @@ def replace(self, table, condition, r):
571571
r_table = self._get_with_component_identifier('Table', table)
572572

573573
if not r_table['is_component']:
574-
pid = r_table['primary_id']
574+
pid = self.primary_id(table)
575575
docs = self.get_many(table, condition[pid])
576576
docs = self._do_filter(docs, condition)
577577
for s in docs:
@@ -603,7 +603,7 @@ def update(self, table, condition, key, value):
603603
r_table = self._get_with_component_identifier('Table', table)
604604

605605
if not r_table['is_component']:
606-
pid = r_table['primary_id']
606+
pid = self.primary_id(table)
607607
docs = self.get_many(table, condition[pid])
608608
docs = self._do_filter(docs, condition)
609609
for s in docs:
@@ -695,20 +695,6 @@ def _get_with_component_identifier_version(
695695
def __delitem__(self, key: t.Tuple[str, str, str]):
696696
pass
697697

698-
def primary_id(self, query):
699-
"""Get the primary id of a query.
700-
701-
:param query: The query to get the primary id of.
702-
"""
703-
r = max(
704-
self.get_many('Table', query.table, '*'),
705-
key=lambda x: x['version'],
706-
default=None,
707-
)
708-
if r is None:
709-
raise exceptions.NotFound("Table", query.table)
710-
return r['primary_id']
711-
712698
def insert(self, table, documents):
713699
"""Insert data into the database.
714700
@@ -717,7 +703,7 @@ def insert(self, table, documents):
717703
"""
718704
ids = []
719705
try:
720-
pid = self.primary_id(self.db[table])
706+
pid = self.primary_id(table)
721707
except exceptions.NotFound:
722708
pid = None
723709

@@ -726,7 +712,7 @@ def insert(self, table, documents):
726712
self[table, r['identifier'], r['uuid']] = r
727713
ids.append(r['uuid'])
728714
elif pid:
729-
pid = self.primary_id(self.db[table])
715+
pid = self.primary_id(table)
730716
for r in documents:
731717
if pid not in r:
732718
r[pid] = self.random_id()
@@ -790,7 +776,7 @@ def do_test(r):
790776
is_component = max(tables, key=lambda x: x['version'])['is_component']
791777

792778
if not is_component:
793-
pid = self.primary_id(query)
779+
pid = self.primary_id(query.table)
794780
if pid in filter_kwargs:
795781
keys = self.keys(query.table, filter_kwargs[pid]['value'])
796782
del filter_kwargs[pid]

superduper/base/build.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,18 +100,20 @@ def build_datalayer(
100100

101101
artifact_store = _build_artifact_store()
102102

103-
metadata = _build_databackend(cfg.metadata_store or cfg.data_backend)
104-
105103
backend = getattr(load_plugin(cfg.cluster_engine), 'Cluster')
106104

107105
cluster = backend.build(cfg, compute=compute)
108106

109-
metadata = Datalayer(
110-
databackend=metadata,
111-
cluster=None,
112-
artifact_store=artifact_store,
113-
metadata=None,
114-
)
107+
if cfg.metadata_store:
108+
metadata = _build_databackend(cfg.metadata_store)
109+
metadata = Datalayer(
110+
databackend=metadata,
111+
cluster=cluster,
112+
artifact_store=artifact_store,
113+
metadata=None,
114+
)
115+
else:
116+
metadata = None
115117

116118
datalayer = Datalayer(
117119
databackend=databackend_obj,

superduper/base/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ class Config(BaseConfig):
162162
data_backend: str = "mongodb://localhost:27017/test_db"
163163

164164
artifact_store: str = 'filesystem://./artifact_store'
165-
metadata_store: str = 'inmemory://'
165+
metadata_store: str = ''
166166

167167
cache: str | None = None
168168
vector_search_engine: str = 'local'

superduper/base/datalayer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,13 @@ def __init__(
6767

6868
self._cfg = s.CFG
6969
self.startup_cache: t.Dict[str, t.Any] = {}
70+
self._component_cache: t.Dict[t.Tuple[str, str], Component] = {}
7071

7172
if metadata:
7273
self.metadata = MetaDataStore(metadata, parent_db=self) # type: ignore[arg-type]
73-
self.metadata.init()
7474
else:
7575
self.metadata = MetaDataStore(self, parent_db=self)
76-
77-
self._component_cache: t.Dict[t.Tuple[str, str], Component] = {}
76+
self.metadata.init()
7877

7978
logging.info("Data Layer built")
8079

superduper/base/metadata.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,15 @@ class ParentChildAssociations(Base):
218218
class ArtifactRelations(Base):
219219
"""Artifact relations table.
220220
221+
:param relation_id: relation identifier
221222
:param component: component type
222223
:param identifier: identifier of component
223224
:param uuid: UUID of component version
224225
:param artifact_id: UUID of component version
225226
"""
226227

228+
primary_id: t.ClassVar[str] = 'relation'
229+
relation_id: str
227230
component: str
228231
identifier: str
229232
uuid: str
@@ -250,6 +253,12 @@ def __init__(self, db: 'Datalayer', parent_db: 'Datalayer'):
250253
self.db = db
251254
self.parent_db = parent_db
252255
self._schema_cache: t.Dict[str, Schema] = {}
256+
self.primary_ids = {
257+
"Table": "uuid",
258+
"ParentChildAssociations": "uuid",
259+
"ArtifactRelations": "relation_id",
260+
"Job": "job_id",
261+
}
253262

254263
def __getitem__(self, item: str):
255264
return self.db[item]
@@ -286,6 +295,21 @@ def init(self):
286295
self.create(ArtifactRelations)
287296
self.create(Job)
288297

298+
def get_primary_id(self, table: str):
299+
"""Get the primary id of a table.
300+
301+
:param table: table name.
302+
"""
303+
pid = self.primary_ids.get(table)
304+
305+
if pid is None:
306+
pid = self.get_component(component="Table", identifier=table, version=0)[
307+
"primary_id"
308+
]
309+
self.primary_ids[table] = pid
310+
311+
return pid
312+
289313
def create_table_and_schema(
290314
self,
291315
identifier: str,
@@ -870,12 +894,12 @@ def get_component(
870894
identifier=identifier,
871895
)
872896

873-
metadata = self.db['Table'].get(identifier=component)
874897
r = self.db[component].get(identifier=identifier, version=version, raw=True)
875898

876899
if r is None:
877900
raise exceptions.NotFound(component, identifier)
878901

902+
metadata = self.db['Table'].get(identifier=component)
879903
r['_path'] = metadata['path']
880904

881905
return r

superduper/base/query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ def execute(self, raw: bool = False):
971971
if self.table in db.metadata.db.databackend.list_tables():
972972
db = db.metadata.db
973973
if self.parts and self.parts[0] == 'primary_id':
974-
return db.databackend.primary_id(self)
974+
return db.databackend.primary_id(self.table)
975975
results = db.databackend.execute(self, raw=raw)
976976
return results
977977

0 commit comments

Comments
 (0)