
Commit 4717460

Move datatype_presets to databackend logic
1 parent 409bc0c commit 4717460

File tree

11 files changed: +174 −79 lines


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Remove the dependencies property from `Component`
 - Fix the bug where the parent of dependent_tables was incorrect when deleting a component.
 - Add `outputs` parameter to `@trigger` to show outputs location and link streaming tasks
+- Remove `json_native` and `datatype_presets`
 
 ### New features
 

plugins/sql/superduper_sql/data_backend.py

Lines changed: 16 additions & 1 deletion
@@ -15,6 +15,7 @@
 from superduper.backends.base.data_backend import BaseDataBackend
 from superduper.base import exceptions
 from superduper.base.artifacts import FileSystemArtifactStore
+from superduper.base.datatype import Array, NativeVector
 from superduper.base.query import Query, QueryPart
 from superduper.base.schema import Schema
 
@@ -91,6 +92,13 @@ def __init__(self, uri: str, plugin: t.Any, flavour: t.Optional[str] = None):
         # Get a connection to initialize
         self.reconnect()
 
+    @property
+    def vector_impl(self):
+        """Get the vector implementation based on the URI."""
+        if self.uri.startswith("snowflake"):
+            return NativeVector
+        return Array
+
     def reconnect(self):
         """Reconnect to the database client."""
         with self.connection_manager.get_connection() as conn:
@@ -138,7 +146,7 @@ def create_table_and_schema(self, identifier: str, schema: Schema, primary_id: s
         :param mapping: The mapping of the schema.
         """
         with self.connection_manager.get_connection() as conn:
-            mapping = convert_schema_to_fields(schema)
+            mapping = convert_schema_to_fields(schema, self.json_native)
             if primary_id not in mapping:
                 mapping[primary_id] = "string"
             try:
@@ -341,6 +349,13 @@ def __init__(self, uri, plugin, flavour=None):
         self._create_sqlalchemy_engine()
         self.sm = sessionmaker(bind=self.alchemy_engine)
 
+    @property
+    def json_native(self):
+        """Check if the database supports JSON natively."""
+        if self.uri.startswith("postgres"):
+            return True
+        return False
+
     def update(self, table, condition, key, value):
         """Update data in the database."""
         with self.sm() as session:
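
In short, the SQL plugin now derives both knobs from the connection URI rather than from global config: Snowflake gets NativeVector, everything else falls back to Array, and only Postgres-style URIs are treated as JSON-native. A minimal sketch of that dispatch (the _ExampleBackend class is hypothetical; only the two property bodies mirror the diff, and running it assumes superduper is installed for the import):

    from superduper.base.datatype import Array, NativeVector  # import added by this commit


    class _ExampleBackend:
        """Hypothetical stand-in for the SQL data backend classes patched above."""

        def __init__(self, uri: str):
            self.uri = uri

        @property
        def vector_impl(self):
            # Snowflake URIs get the native vector type; other SQL backends fall back to Array.
            if self.uri.startswith("snowflake"):
                return NativeVector
            return Array

        @property
        def json_native(self):
            # Only Postgres-style URIs are treated as having a native JSON column type.
            return self.uri.startswith("postgres")


    for uri in ("snowflake://account/db", "postgres://localhost/test", "sqlite://"):
        backend = _ExampleBackend(uri)
        print(uri, "->", backend.vector_impl.__name__, "json_native:", backend.json_native)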

plugins/sql/superduper_sql/utils.py

Lines changed: 3 additions & 7 deletions
@@ -1,12 +1,10 @@
 from ibis.expr.datatypes import dtype
-from superduper import CFG
 from superduper.base.datatype import (
     ID,
     Array,
     BaseDataType,
     FieldType,
     FileItem,
-    Vector,
 )
 from superduper.base.schema import Schema
 
@@ -23,7 +21,7 @@ def _convert_field_type_to_ibis_type(field_type: FieldType):
     return dtype(ibis_type)
 
 
-def convert_schema_to_fields(schema: Schema):
+def convert_schema_to_fields(schema: Schema, json_native: bool) -> dict:
     """Return the raw fields.
 
     Get a dictionary of fields as keys and datatypes as values.
@@ -39,11 +37,9 @@ def convert_schema_to_fields(schema: Schema):
         else:
             assert isinstance(schema.fields[k], BaseDataType)
 
-            if not CFG.json_native and schema.fields[k].dtype == "json":
+            if not json_native and schema.fields[k].dtype == "json":
                 fields[k] = dtype("str")
-            elif isinstance(v, Array):
-                fields[k] = dtype("str")
-            elif isinstance(v, Vector) and isinstance(v.datatype_impl, Array):
+            elif isinstance(schema.fields[k], Array):
                 fields[k] = dtype("str")
             else:
                 fields[k] = dtype(schema.fields[k].dtype)
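
convert_schema_to_fields no longer reads CFG.json_native; the caller (create_table_and_schema above) passes the backend's own json_native flag. A rough, illustrative sketch of the branch logic with a toy field object standing in for superduper's datatypes (names here are made up for the example):

    from ibis.expr.datatypes import dtype


    class _ToyField:
        """Illustrative stand-in for a BaseDataType exposing a .dtype attribute."""

        def __init__(self, dt):
            self.dtype = dt


    def pick_ibis_type(field, json_native: bool):
        # Mirrors the first branch in convert_schema_to_fields: JSON becomes a
        # string column unless the backend reports native JSON support.
        if not json_native and field.dtype == "json":
            return dtype("str")
        return dtype(field.dtype)


    print(pick_ibis_type(_ToyField("json"), json_native=False))   # falls back to a string column
    print(pick_ibis_type(_ToyField("int64"), json_native=True))   # maps straight through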

superduper/backends/base/data_backend.py

Lines changed: 119 additions & 7 deletions
@@ -1,14 +1,14 @@
 import functools
-import json
 import hashlib
+import json
 import typing as t
 import uuid
 from abc import ABC, abstractmethod
 
 from superduper import CFG, logging
 from superduper.base import exceptions
-from superduper.base.constant import KEY_BLOBS, KEY_BUILDS, KEY_FILES
-from superduper.base.datatype import JSON, BaseDataType
+from superduper.base.constant import KEY_BLOBS, KEY_BUILDS, KEY_FILES, KEY_PATH
+from superduper.base.datatype import JSON, BaseDataType, NativeVector, Vector
 from superduper.base.document import Document
 from superduper.base.query import Query
 
@@ -31,6 +31,7 @@ class BaseDataBackend(ABC):
 
     batched: bool = False
     id_field: str = 'id'
+    vector_impl: t.Type = NativeVector
 
     # TODO plugin not required
     # TODO flavour required?
@@ -160,6 +161,64 @@ def insert(self, table: str, documents: t.Sequence[t.Dict]) -> t.List[str]:
         :param documents: The documents to insert.
         """
 
+    def do_replace(self, table: str, condition: t.Dict, r: t.Dict):
+        """Replace data in the database.
+
+        This method is a wrapper around the `replace` method to ensure
+        that the datatype is set to `None` by default.
+
+        :param table: The table to insert into.
+        :param condition: The condition to update.
+        :param r: The document to replace.
+        """
+        schema = self.get_schema(self.db[table])
+
+        if isinstance(r, Document) and not schema.trivial:
+            schema = self.get_schema(self.db[table])
+            r = Document(r).encode(schema=schema, db=self.db)
+            if r.get(KEY_BLOBS) or r.get(KEY_FILES):
+                self.db.artifact_store.save_artifact(r)
+
+        try:
+            r.pop(KEY_BUILDS)
+        except KeyError:
+            pass
+        try:
+            r.pop(KEY_BLOBS)
+        except KeyError:
+            pass
+        try:
+            r.pop(KEY_FILES)
+        except KeyError:
+            pass
+        try:
+            r.pop(KEY_PATH)
+        except KeyError:
+            pass
+
+        vector_datatypes = {
+            k: self.vector_impl(dtype=v.dtype, shape=v.shape)
+            for k, v in schema.fields.items()
+            if isinstance(v, Vector)
+        }
+        if vector_datatypes:
+            for k in vector_datatypes:
+                if k in r:
+                    r[k] = vector_datatypes[k].encode_data(r[k], None)
+
+        if not self.json_native:
+            json_fields = [
+                k
+                for k in schema.fields
+                if getattr(schema.fields[k], 'dtype', None) == 'json'
+            ]
+            for k in json_fields:
+                if k in r:
+                    r[k] = json.dumps(r[k])
+
+        self.replace(table, condition=condition, r=r)
+        return
+
     @abstractmethod
     def replace(self, table: str, condition: t.Dict, r: t.Dict) -> t.List[str]:
         """Replace data.
@@ -169,7 +228,14 @@ def replace(self, table: str, condition: t.Dict, r: t.Dict) -> t.List[str]:
         :param r: The document to replace.
         """
 
-    def _update(self, table: str, condition: t.Dict, key: str, value: t.Any, datatype: BaseDataType | None = None):
+    def do_update(
+        self,
+        table: str,
+        condition: t.Dict,
+        key: str,
+        value: t.Any,
+        datatype: BaseDataType | None = None,
+    ):
         """Update data in the database.
 
         This method is a wrapper around the `update` method to ensure
@@ -179,11 +245,17 @@ def _update(self, table: str, condition: t.Dict, key: str, value: t.Any, datatyp
         :param condition: The condition to update.
         :param key: The key to update.
        :param value: The value to update.
+        :param datatype: The datatype to use for encoding the value.
         """
         if datatype is not None:
             value = datatype.encode_data(value, None)
             if datatype.dtype == 'json' and not self.json_native:
                 value = json.dumps(value)
+            elif isinstance(datatype, Vector):
+                vector_datatype = self.vector_impl(
+                    dtype=datatype.dtype, shape=datatype.shape
+                )
+                value = vector_datatype.encode_data(value, None)
         self.update(table, condition, key, value)
 
     @abstractmethod
@@ -268,12 +340,30 @@ def _wrap_results(self, query: Query, result, schema, raw: bool = False):
                 r['_source'] = str(r['_source'])
 
         if not self.json_native:
-            json_fields = [k for k in schema.fields if getattr(schema.fields[k], 'dtype', None) == 'json']
+            json_fields = [
+                k
+                for k in schema.fields
+                if getattr(schema.fields[k], 'dtype', None) == 'json'
+            ]
             for r in result:
                 for k in json_fields:
                     if k in r and isinstance(r[k], str):
                         r[k] = json.loads(r[k])
 
+        vector_datatypes = {
+            k: self.vector_impl(dtype=v.dtype, shape=v.shape)
+            for k, v in schema.fields.items()
+            if isinstance(v, Vector)
+        }
+
+        if vector_datatypes:
+            for r in result:
+                for k in vector_datatypes:
+                    if k in r and r[k] is not None:
+                        r[k] = vector_datatypes[k].decode_data(
+                            r[k], builds={}, db=self.db
+                        )
+
         if raw:
             return result
 
@@ -311,8 +401,13 @@ def get_schema(self, query) -> 'Schema':
 
         return base_schema
 
-    def _do_insert(self, table, documents, raw: bool = False):
+    def do_insert(self, table, documents, raw: bool = False):
+        """Insert data into the database.
 
+        :param table: The table to insert into.
+        :param documents: The documents to insert.
+        :param raw: If ``True``, insert raw documents.
+        """
         schema = self.get_schema(self.db[table])
 
         if not raw and not schema.trivial:
@@ -342,8 +437,23 @@ def _do_insert(self, table, documents, raw: bool = False):
                 pass
             documents[i] = r
 
+        vector_datatypes = {
+            k: self.vector_impl(dtype=v.dtype, shape=v.shape)
+            for k, v in schema.fields.items()
+            if isinstance(v, Vector)
+        }
+        if vector_datatypes:
+            for r in documents:
+                for k in vector_datatypes:
+                    if k in r:
+                        r[k] = vector_datatypes[k].encode_data(r[k], None)
+
         if not self.json_native:
-            json_fields = [k for k in schema.fields if getattr(schema.fields[k], 'dtype', None) == 'json']
+            json_fields = [
+                k
+                for k in schema.fields
+                if getattr(schema.fields[k], 'dtype', None) == 'json'
+            ]
             for r in documents:
                 for k in json_fields:
                     if k in r:
@@ -490,6 +600,8 @@ class KeyedDatabackend(BaseDataBackend):
     :param flavour: Flavour of the databackend.
     """
 
+    json_native: bool = True
+
     @abstractmethod
     def __getitem__(self, key: t.Tuple[str, str, str]) -> t.Dict:
         pass
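
Taken together, do_insert, do_replace, do_update and _wrap_results now push vector fields through the backend's vector_impl on write and decode them on read, and stringify JSON only when json_native is False. A self-contained sketch of that write path (FakeArray and encode_row are stand-ins for illustration, not superduper APIs):

    import json


    class FakeArray:
        """Hypothetical codec standing in for superduper's Array/NativeVector."""

        def __init__(self, dtype, shape):
            self.dtype, self.shape = dtype, shape

        def encode_data(self, item, context):
            return json.dumps(item)          # pretend storage encoding to a string column

        def decode_data(self, item, builds, db):
            return json.loads(item)


    def encode_row(row, vector_fields, json_fields, vector_impl, json_native):
        """Mirror the write path: vectors via vector_impl, JSON via json.dumps if needed."""
        out = dict(row)
        for key, (dt, shape) in vector_fields.items():
            if key in out:
                out[key] = vector_impl(dtype=dt, shape=shape).encode_data(out[key], None)
        if not json_native:
            for key in json_fields:
                if key in out:
                    out[key] = json.dumps(out[key])
        return out


    row = {'embedding': [0.1, 0.2, 0.3], 'meta': {'source': 'demo'}}
    encoded = encode_row(row, {'embedding': ('float64', 3)}, ['meta'], FakeArray, json_native=False)
    print(encoded)   # both fields are now plain strings, ready for a backend without JSON/vector types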

superduper/base/config.py

Lines changed: 0 additions & 15 deletions
@@ -118,18 +118,6 @@ class Downloads(BaseConfig):
     timeout: t.Optional[int] = None
 
 
-@dc.dataclass
-class DataTypePresets(BaseConfig):
-    """Paths of default types of data.
-
-    Overrides DataBackend.datatype_presets.
-
-    :param vector: BaseDataType to encode vectors.
-    """
-
-    vector: str | None = None
-
-
 @dc.dataclass
 class Config(BaseConfig):
     """The data class containing all configurable superduper values.
@@ -148,7 +136,6 @@ class Config(BaseConfig):
     :param logging_type: The type of logging to use
     :param log_hostname: Whether to include the hostname in the logs
     :param force_apply: Whether to force apply the configuration
-    :param datatype_presets: Presets to be applied for default types of data
     :param log_colorize: Whether to colorize the logs
     :param bytes_encoding: (Deprecated)
     :param output_prefix: The prefix for the output table and output field key
@@ -178,8 +165,6 @@ class Config(BaseConfig):
 
     force_apply: bool = False
 
-    datatype_presets: DataTypePresets = dc.field(default_factory=DataTypePresets)
-
     output_prefix: str = "_outputs__"
     vector_search_kwargs: t.Dict = dc.field(default_factory=dict)
 
superduper/base/datatype.py

Lines changed: 2 additions & 15 deletions
@@ -432,26 +432,13 @@ class Vector(BaseVector):
     :param shape: Shape of array.
     """
 
-    @cached_property
-    def datatype_impl(self):
-        type_ = CFG.datatype_presets.vector
-        if type_ is None:
-            return NativeVector(shape=self.shape, dtype=self.dtype)
-
-        module = '.'.join(type_.split('.')[:-1])
-        cls = type_.split('.')[-1]
-        datatype = getattr(import_module(module), cls)
-        if inspect.isclass(datatype):
-            datatype = datatype(dtype=self.dtype, shape=self.shape)
-        return datatype
-
     def encode_data(self, item, context):
         """Encode the given item into a bytes-like object or reference.
 
         :param item: The object/instance to encode.
         :param context: A context object containing caches.
         """
-        return self.datatype_impl.encode_data(item, context)
+        return item
 
     def decode_data(self, item, builds, db):
         """Decode the item from `bytes`.
@@ -460,7 +447,7 @@ def decode_data(self, item, builds, db):
         :param builds: The build cache.
         :param db: The Datalayer.
         """
-        return self.datatype_impl.decode_data(item, builds, db)
+        return item
 
 
 class JSON(BaseDataType):
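
Vector itself is now a pass-through: it no longer resolves a codec from CFG.datatype_presets, and the databackend's vector_impl performs the actual conversion (see the base data_backend changes above). A quick check of that behaviour, assuming superduper is importable and that Vector accepts dtype/shape keywords as its usage elsewhere in this commit suggests:

    from superduper.base.datatype import Vector

    v = Vector(dtype='float64', shape=3)   # constructor keywords assumed from v.dtype / v.shape usage above
    item = [0.1, 0.2, 0.3]

    # Both directions are identity operations after this commit; the backend
    # wraps them with NativeVector or Array instead.
    assert v.encode_data(item, None) is item
    assert v.decode_data(item, builds={}, db=None) is item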

superduper/base/event.py

Lines changed: 4 additions & 6 deletions
@@ -472,14 +472,12 @@ def execute(
                 self.component, uuid=self.data['uuid'], info=self.data
             )
         except Exception as e:
-            db.metadata.set_component_status(
+            db.metadata.set_component_failed(
                 component=self.component,
                 uuid=self.data['uuid'],
-                details_update={
-                    'phase': STATUS_FAILED,
-                    'reason': f'Failed to update: {str(e)}',
-                    'message': format_exc(),
-                },
+                reason=f'Failed to update: {str(e)}',
+                message=str(format_exc()),
+                context=self.context,
             )
             raise e
 
superduper/base/metadata.py

Lines changed: 1 addition & 1 deletion
@@ -459,7 +459,7 @@ def init(self):
         ).encode()
         r['version'] = 0
 
-        self.db.databackend.insert('Table', [r])
+        self.db.databackend.do_insert('Table', [r], raw=True)
 
         r = self.get_component('Table', 'Table')
 