Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci_code.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,10 @@ jobs:
run: |
make unit_testing pytest_arguments="--cov=superduper --cov-report=xml" SUPERDUPER_CONFIG=test/configs/default.yaml
make unit_testing pytest_arguments="--cov=superduper --cov-report=xml" SUPERDUPER_CONFIG=test/configs/sql.yaml
make unit_testing pytest_arguments="--cov=superduper --cov-report=xml" SUPERDUPER_CONFIG=test/configs/inmemory.yaml

- name: Usecase Testing
run: |
make usecase_testing SUPERDUPER_CONFIG=test/configs/default.yaml
make usecase_testing SUPERDUPER_CONFIG=test/configs/sql.yaml
make usecase_testing SUPERDUPER_CONFIG=test/configs/inmemory.yaml
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix the primary_id of ArtifactRelations
- Fix the snowflake and KeyedDatabackend
- Fix a bug where Ibis throws Ibis.TableError but the framework waits for MetadataNoExists
- Fix primary_id and add test cases for in-memory metadata store.

## [0.5.0](https://github.com/superduper-io/superduper/compare/0.5.0...0.4.0) (2024-Nov-02)

Expand Down
16 changes: 2 additions & 14 deletions plugins/snowflake/superduper_snowflake/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def missing_outputs(self, query, predict_id):
:param query: The query to get the missing outputs of.
:param predict_id: The identifier of the output destination.
"""
pid = self.primary_id(query)
pid = self.primary_id(query.table)
df = map_superduper_query_to_snowpark_query(self.session, query, pid)
output_df = self.session.table(f'"{CFG.output_prefix + predict_id}"')
columns = output_df.columns
Expand All @@ -348,18 +348,6 @@ def missing_outputs(self, query, predict_id):
.tolist()
)

def primary_id(self, query: Query) -> str:
"""Get the primary id of a query.

:param query: The query to get the primary id of.
"""
return (
self.get_table(query.table)
.schema[0]
.name.removeprefix('"')
.removesuffix('"')
)

def _build_schema(self, query: Query):
"""Build the schema of a query.

Expand All @@ -375,7 +363,7 @@ def select(self, query: Query, primary_id: str | None = None) -> t.List[t.Dict]:
q = map_superduper_query_to_snowpark_query(
self.session,
query,
primary_id or self.primary_id(query),
primary_id or self.primary_id(query.table),
)
start = time.time()
logging.info(f"Executing query: {query}")
Expand Down
18 changes: 6 additions & 12 deletions plugins/sql/superduper_sql/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def delete(self, table, condition):

def insert(self, table, documents):
"""Insert data into the database."""
primary_id = self.primary_id(self.db[table])
primary_id = self.primary_id(table)
for r in documents:
if primary_id not in r:
r[primary_id] = str(uuid.uuid4())
Expand All @@ -244,19 +244,13 @@ def insert(self, table, documents):
def missing_outputs(self, query, predict_id: str) -> t.List[str]:
"""Get missing outputs from the database."""
with self.connection_manager.get_connection() as conn:
pid = self.primary_id(query)
pid = self.primary_id(query.table)
query = self._build_native_query(conn, query)
output_table = conn.table(f"{CFG.output_prefix}{predict_id}")
q = query.anti_join(output_table, output_table["_source"] == query[pid])
rows = q.execute().to_dict(orient="records")
return [r[pid] for r in rows]

def primary_id(self, query):
"""Get the primary ID of the query."""
return self.db.metadata.get_component(
component="Table", identifier=query.table
)["primary_id"]

def select(self, query):
"""Select data from the database."""
with self.connection_manager.get_connection() as conn:
Expand All @@ -279,7 +273,7 @@ def _build_native_query(self, conn, query):
args = []
for a in part.args:
if isinstance(a, Query) and str(a).endswith(".primary_id"):
args.append(self.primary_id(query))
args.append(self.primary_id(query.table))
elif isinstance(a, Query):
args.append(self._build_native_query(conn, a))
else:
Expand All @@ -288,7 +282,7 @@ def _build_native_query(self, conn, query):
kwargs = {}
for k, v in part.kwargs.items():
if isinstance(a, Query) and str(a).endswith(".primary_id"):
args.append(self.primary_id(query))
args.append(self.primary_id(query.table))
elif isinstance(v, Query):
kwargs[k] = self._build_native_query(conn, v)
else:
Expand All @@ -311,7 +305,7 @@ def _build_native_query(self, conn, query):

elif isinstance(part, QueryPart) and part.name == "outputs":
if pid is None:
pid = self.primary_id(query)
pid = self.primary_id(query.table)

original_q = q
for predict_id in part.args:
Expand All @@ -323,7 +317,7 @@ def _build_native_query(self, conn, query):
elif isinstance(part, str):
if part == "primary_id":
if pid is None:
pid = self.primary_id(query)
pid = self.primary_id(query.table)
part = pid
q = q[part]
else:
Expand Down
40 changes: 13 additions & 27 deletions superduper/backends/base/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,12 @@ def missing_outputs(self, query: Query, predict_id: str) -> t.List[str]:
:param predict_id: The predict id.
"""

def primary_id(self, table: str) -> str:
    """Get the primary id of a table.

    Concrete default: delegates the lookup to the metadata store
    (``db.metadata.get_primary_id``), so backends no longer need to
    implement their own table-schema probing.

    :param table: The table to get the primary id of.
    :return: Name of the table's primary-id column.
    """
    return self.db.metadata.get_primary_id(table)

@abstractmethod
def select(self, query: Query) -> t.List[t.Dict]:
Expand Down Expand Up @@ -222,7 +222,7 @@ def get(self, query: Query, raw: bool = False):
return None

def _wrap_results(self, query: Query, result, schema, raw: bool = False):
pid = self.primary_id(query)
pid = self.primary_id(query.table)
for r in result:
if pid in r:
r[pid] = str(r[pid])
Expand Down Expand Up @@ -326,7 +326,7 @@ def pre_like(self, query: Query, **kwargs):

results = new.execute(**kwargs)

pid = self.primary_id(query)
pid = self.primary_id(query.table)
for r in results:
r['score'] = lookup[r[pid]]

Expand Down Expand Up @@ -356,7 +356,7 @@ def post_like(self, query: Query, **kwargs):

results = prepare_query.filter(t.primary_id.isin(ids)).execute(**kwargs)

pid = self.primary_id(query)
pid = self.primary_id(query.table)

for r in results:
r['score'] = lookup[r[pid]]
Expand Down Expand Up @@ -484,7 +484,7 @@ def delete(self, table, condition):
r_table = self._get_with_component_identifier('Table', table)

if not r_table['is_component']:
pid = r_table['primary_id']
pid = self.primary_id(table)
if pid in condition:
docs = self.get_many(table, condition[pid])
else:
Expand Down Expand Up @@ -571,7 +571,7 @@ def replace(self, table, condition, r):
r_table = self._get_with_component_identifier('Table', table)

if not r_table['is_component']:
pid = r_table['primary_id']
pid = self.primary_id(table)
docs = self.get_many(table, condition[pid])
docs = self._do_filter(docs, condition)
for s in docs:
Expand Down Expand Up @@ -603,7 +603,7 @@ def update(self, table, condition, key, value):
r_table = self._get_with_component_identifier('Table', table)

if not r_table['is_component']:
pid = r_table['primary_id']
pid = self.primary_id(table)
docs = self.get_many(table, condition[pid])
docs = self._do_filter(docs, condition)
for s in docs:
Expand Down Expand Up @@ -695,20 +695,6 @@ def _get_with_component_identifier_version(
def __delitem__(self, key: t.Tuple[str, str, str]):
pass

def primary_id(self, query):
"""Get the primary id of a query.

:param query: The query to get the primary id of.
"""
r = max(
self.get_many('Table', query.table, '*'),
key=lambda x: x['version'],
default=None,
)
if r is None:
raise exceptions.NotFound("Table", query.table)
return r['primary_id']

def insert(self, table, documents):
"""Insert data into the database.

Expand All @@ -717,7 +703,7 @@ def insert(self, table, documents):
"""
ids = []
try:
pid = self.primary_id(self.db[table])
pid = self.primary_id(table)
except exceptions.NotFound:
pid = None

Expand All @@ -726,7 +712,7 @@ def insert(self, table, documents):
self[table, r['identifier'], r['uuid']] = r
ids.append(r['uuid'])
elif pid:
pid = self.primary_id(self.db[table])
pid = self.primary_id(table)
for r in documents:
if pid not in r:
r[pid] = self.random_id()
Expand Down Expand Up @@ -790,7 +776,7 @@ def do_test(r):
is_component = max(tables, key=lambda x: x['version'])['is_component']

if not is_component:
pid = self.primary_id(query)
pid = self.primary_id(query.table)
if pid in filter_kwargs:
keys = self.keys(query.table, filter_kwargs[pid]['value'])
del filter_kwargs[pid]
Expand Down
18 changes: 10 additions & 8 deletions superduper/base/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,20 @@ def build_datalayer(

artifact_store = _build_artifact_store()

metadata = _build_databackend(cfg.metadata_store or cfg.data_backend)

backend = getattr(load_plugin(cfg.cluster_engine), 'Cluster')

cluster = backend.build(cfg, compute=compute)

metadata = Datalayer(
databackend=metadata,
cluster=None,
artifact_store=artifact_store,
metadata=None,
)
if cfg.metadata_store:
metadata = _build_databackend(cfg.metadata_store)
metadata = Datalayer(
databackend=metadata,
cluster=cluster,
artifact_store=artifact_store,
metadata=None,
)
else:
metadata = None

datalayer = Datalayer(
databackend=databackend_obj,
Expand Down
2 changes: 1 addition & 1 deletion superduper/base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ class Config(BaseConfig):
data_backend: str = "mongodb://localhost:27017/test_db"

artifact_store: str = 'filesystem://./artifact_store'
metadata_store: str = 'inmemory://'
metadata_store: str = ''

cache: str | None = None
vector_search_engine: str = 'local'
Expand Down
5 changes: 2 additions & 3 deletions superduper/base/datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,13 @@ def __init__(

self._cfg = s.CFG
self.startup_cache: t.Dict[str, t.Any] = {}
self._component_cache: t.Dict[t.Tuple[str, str], Component] = {}

if metadata:
self.metadata = MetaDataStore(metadata, parent_db=self) # type: ignore[arg-type]
self.metadata.init()
else:
self.metadata = MetaDataStore(self, parent_db=self)

self._component_cache: t.Dict[t.Tuple[str, str], Component] = {}
self.metadata.init()

logging.info("Data Layer built")

Expand Down
26 changes: 25 additions & 1 deletion superduper/base/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,15 @@ class ParentChildAssociations(Base):
class ArtifactRelations(Base):
"""Artifact relations table.

:param relation_id: relation identifier
:param component: component type
:param identifier: identifier of component
:param uuid: UUID of component version
:param artifact_id: UUID of the referenced artifact
"""

primary_id: t.ClassVar[str] = 'relation'
relation_id: str
component: str
identifier: str
uuid: str
Expand All @@ -250,6 +253,12 @@ def __init__(self, db: 'Datalayer', parent_db: 'Datalayer'):
self.db = db
self.parent_db = parent_db
self._schema_cache: t.Dict[str, Schema] = {}
self.primary_ids = {
"Table": "uuid",
"ParentChildAssociations": "uuid",
"ArtifactRelations": "relation_id",
"Job": "job_id",
}

def __getitem__(self, item: str):
return self.db[item]
Expand Down Expand Up @@ -286,6 +295,21 @@ def init(self):
self.create(ArtifactRelations)
self.create(Job)

def get_primary_id(self, table: str):
    """Return the name of the primary-id column for *table*.

    Known tables (built-ins plus anything resolved earlier) are served
    from the ``primary_ids`` cache; otherwise the table's component
    record is consulted and the result is memoized.

    :param table: table name.
    """
    # Fast path: answer already cached on this store.
    cached = self.primary_ids.get(table)
    if cached is not None:
        return cached

    # Slow path: read the table's component record, then memoize.
    record = self.get_component(component="Table", identifier=table, version=0)
    self.primary_ids[table] = record["primary_id"]
    return record["primary_id"]

def create_table_and_schema(
self,
identifier: str,
Expand Down Expand Up @@ -870,12 +894,12 @@ def get_component(
identifier=identifier,
)

metadata = self.db['Table'].get(identifier=component)
r = self.db[component].get(identifier=identifier, version=version, raw=True)

if r is None:
raise exceptions.NotFound(component, identifier)

metadata = self.db['Table'].get(identifier=component)
r['_path'] = metadata['path']

return r
Expand Down
2 changes: 1 addition & 1 deletion superduper/base/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,7 @@ def execute(self, raw: bool = False):
if self.table in db.metadata.db.databackend.list_tables():
db = db.metadata.db
if self.parts and self.parts[0] == 'primary_id':
return db.databackend.primary_id(self)
return db.databackend.primary_id(self.table)
results = db.databackend.execute(self, raw=raw)
return results

Expand Down
7 changes: 7 additions & 0 deletions test/configs/inmemory.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Test configuration pairing a SQLite data backend with the
# in-memory ("inmemory://") metadata store plugin.
data_backend: 'sqlite://'
metadata_store: "inmemory://"
auto_schema: false
force_apply: true
json_native: false
datatype_presets:
  # NOTE(review): presumably maps the 'vector' datatype to the Array
  # implementation — confirm against superduper.base.datatype.
  vector: superduper.base.datatype.Array
Loading