Skip to content

Commit bbd669f

Browse files
committed
Move datatype_presets to databackend logic
1 parent 8c32e3b commit bbd669f

File tree

3 files changed

+28
-31
lines changed

3 files changed

+28
-31
lines changed

superduper/backends/base/data_backend.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from superduper import CFG, logging
99
from superduper.base import exceptions
1010
from superduper.base.constant import KEY_BLOBS, KEY_BUILDS, KEY_FILES
11-
from superduper.base.datatype import JSON, BaseDataType
11+
from superduper.base.datatype import JSON, BaseDataType, NativeVector, Vector
1212
from superduper.base.document import Document
1313
from superduper.base.query import Query
1414

@@ -31,6 +31,7 @@ class BaseDataBackend(ABC):
3131

3232
batched: bool = False
3333
id_field: str = 'id'
34+
vector_impl: t.Type = NativeVector
3435

3536
# TODO plugin not required
3637
# TODO flavour required?
@@ -184,6 +185,9 @@ def _update(self, table: str, condition: t.Dict, key: str, value: t.Any, datatyp
184185
value = datatype.encode_data(value, None)
185186
if datatype.dtype == 'json' and not self.json_native:
186187
value = json.dumps(value)
188+
elif isinstance(datatype, Vector):
189+
vector_datatype = self.vector_impl(dtype=datatype.dtype, shape=datatype.shape)
190+
value = vector_datatype.encode_data(value, None)
187191
self.update(table, condition, key, value)
188192

189193
@abstractmethod
@@ -274,6 +278,17 @@ def _wrap_results(self, query: Query, result, schema, raw: bool = False):
274278
if k in r and isinstance(r[k], str):
275279
r[k] = json.loads(r[k])
276280

281+
vector_datatypes = {
282+
k: self.vector_impl(dtype=v.dtype, shape=v.shape)
283+
for k, v in schema.fields.items() if isinstance(v, Vector)
284+
}
285+
286+
if vector_datatypes:
287+
for r in result:
288+
for k in vector_datatypes:
289+
if k in r and not r[k] is None:
290+
r[k] = vector_datatypes[k].decode_data(r[k], builds={}, db=self.db)
291+
277292
if raw:
278293
return result
279294

@@ -342,6 +357,16 @@ def _do_insert(self, table, documents, raw: bool = False):
342357
pass
343358
documents[i] = r
344359

360+
vector_datatypes = {
361+
k: self.vector_impl(dtype=v.dtype, shape=v.shape)
362+
for k, v in schema.fields.items() if isinstance(v, Vector)
363+
}
364+
if vector_datatypes:
365+
for r in documents:
366+
for k in vector_datatypes:
367+
if k in r:
368+
r[k] = vector_datatypes[k].encode_data(r[k], None)
369+
345370
if not self.json_native:
346371
json_fields = [k for k in schema.fields if getattr(schema.fields[k], 'dtype', None) == 'json']
347372
for r in documents:

superduper/base/config.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -118,18 +118,6 @@ class Downloads(BaseConfig):
118118
timeout: t.Optional[int] = None
119119

120120

121-
@dc.dataclass
122-
class DataTypePresets(BaseConfig):
123-
"""Paths of default types of data.
124-
125-
Overrides DataBackend.datatype_presets.
126-
127-
:param vector: BaseDataType to encode vectors.
128-
"""
129-
130-
vector: str | None = None
131-
132-
133121
@dc.dataclass
134122
class Config(BaseConfig):
135123
"""The data class containing all configurable superduper values.
@@ -148,7 +136,6 @@ class Config(BaseConfig):
148136
:param logging_type: The type of logging to use
149137
:param log_hostname: Whether to include the hostname in the logs
150138
:param force_apply: Whether to force apply the configuration
151-
:param datatype_presets: Presets to be applied for default types of data
152139
:param log_colorize: Whether to colorize the logs
153140
:param bytes_encoding: (Deprecated)
154141
:param output_prefix: The prefix for the output table and output field key
@@ -178,8 +165,6 @@ class Config(BaseConfig):
178165

179166
force_apply: bool = False
180167

181-
datatype_presets: DataTypePresets = dc.field(default_factory=DataTypePresets)
182-
183168
output_prefix: str = "_outputs__"
184169
vector_search_kwargs: t.Dict = dc.field(default_factory=dict)
185170

superduper/base/datatype.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -432,26 +432,13 @@ class Vector(BaseVector):
432432
:param shape: Shape of array.
433433
"""
434434

435-
@cached_property
436-
def datatype_impl(self):
437-
type_ = CFG.datatype_presets.vector
438-
if type_ is None:
439-
return NativeVector(shape=self.shape, dtype=self.dtype)
440-
441-
module = '.'.join(type_.split('.')[:-1])
442-
cls = type_.split('.')[-1]
443-
datatype = getattr(import_module(module), cls)
444-
if inspect.isclass(datatype):
445-
datatype = datatype(dtype=self.dtype, shape=self.shape)
446-
return datatype
447-
448435
def encode_data(self, item, context):
449436
"""Encode the given item into a bytes-like object or reference.
450437
451438
:param item: The object/instance to encode.
452439
:param context: A context object containing caches.
453440
"""
454-
return self.datatype_impl.encode_data(item, context)
441+
return item
455442

456443
def decode_data(self, item, builds, db):
457444
"""Decode the item from `bytes`.
@@ -460,7 +447,7 @@ def decode_data(self, item, builds, db):
460447
:param builds: The build cache.
461448
:param db: The Datalayer.
462449
"""
463-
return self.datatype_impl.decode_data(item, builds, db)
450+
return item
464451

465452

466453
class JSON(BaseDataType):

0 commit comments

Comments
 (0)