11import functools
2- import json
32import hashlib
3+ import json
44import typing as t
55import uuid
66from abc import ABC , abstractmethod
77
88from superduper import CFG , logging
99from superduper .base import exceptions
10- from superduper .base .constant import KEY_BLOBS , KEY_BUILDS , KEY_FILES
11- from superduper .base .datatype import JSON , BaseDataType
10+ from superduper .base .constant import KEY_BLOBS , KEY_BUILDS , KEY_FILES , KEY_PATH
11+ from superduper .base .datatype import JSON , BaseDataType , NativeVector , Vector
1212from superduper .base .document import Document
1313from superduper .base .query import Query
1414
@@ -31,6 +31,7 @@ class BaseDataBackend(ABC):
3131
3232 batched : bool = False
3333 id_field : str = 'id'
34+ vector_impl : t .Type = NativeVector
3435
3536 # TODO plugin not required
3637 # TODO flavour required?
@@ -160,6 +161,64 @@ def insert(self, table: str, documents: t.Sequence[t.Dict]) -> t.List[str]:
160161 :param documents: The documents to insert.
161162 """
162163
def do_replace(self, table: str, condition: t.Dict, r: t.Dict) -> None:
    """Prepare a replacement document and hand it to `replace`.

    This method is a wrapper around the `replace` method. Before
    delegating, it performs the following steps on ``r``:

    1. If ``r`` is a ``Document`` and the table schema is not trivial,
       ``r`` is encoded with that schema; when the encoded result holds
       blobs or files, they are saved to the artifact store.
    2. The ``KEY_BUILDS``, ``KEY_BLOBS``, ``KEY_FILES`` and ``KEY_PATH``
       entries are removed when present.
    3. Every field declared as ``Vector`` in the schema and present in
       ``r`` is encoded with this backend's ``vector_impl``.
    4. If the backend is not ``json_native``, every field whose datatype
       has ``dtype == 'json'`` and which is present in ``r`` is
       serialised with ``json.dumps``.

    Steps 2-4 operate on ``r`` itself, so a caller-supplied mapping may
    be modified in place.

    :param table: The table holding the data to replace.
    :param condition: The condition selecting the data to replace.
    :param r: The replacement document.
    """
    schema = self.get_schema(self.db[table])

    if isinstance(r, Document) and not schema.trivial:
        # NOTE(review): the ``Document(r)`` wrap is kept from the original;
        # presumably it shields the caller's document from ``encode`` --
        # confirm before simplifying to ``r.encode(...)``.
        r = Document(r).encode(schema=schema, db=self.db)
        if r.get(KEY_BLOBS) or r.get(KEY_FILES):
            self.db.artifact_store.save_artifact(r)

    # Drop the special keys when present; a missing key is not an error.
    for special_key in (KEY_BUILDS, KEY_BLOBS, KEY_FILES, KEY_PATH):
        try:
            r.pop(special_key)
        except KeyError:
            pass

    # Swap the schema-level ``Vector`` for the backend's own vector type.
    vector_datatypes = {
        k: self.vector_impl(dtype=v.dtype, shape=v.shape)
        for k, v in schema.fields.items()
        if isinstance(v, Vector)
    }
    for k, vector_datatype in vector_datatypes.items():
        if k in r:
            r[k] = vector_datatype.encode_data(r[k], None)

    if not self.json_native:
        # Backends without a native JSON column type store JSON as text.
        for k, datatype in schema.fields.items():
            if getattr(datatype, 'dtype', None) == 'json' and k in r:
                r[k] = json.dumps(r[k])

    self.replace(table, condition=condition, r=r)
221+
163222 @abstractmethod
164223 def replace (self , table : str , condition : t .Dict , r : t .Dict ) -> t .List [str ]:
165224 """Replace data.
@@ -169,7 +228,14 @@ def replace(self, table: str, condition: t.Dict, r: t.Dict) -> t.List[str]:
169228 :param r: The document to replace.
170229 """
171230
172- def _update (self , table : str , condition : t .Dict , key : str , value : t .Any , datatype : BaseDataType | None = None ):
def do_update(
    self,
    table: str,
    condition: t.Dict,
    key: str,
    value: t.Any,
    datatype: BaseDataType | None = None,
):
    """Encode a single value and write it through `update`.

    This method is a wrapper around the `update` method. When no
    ``datatype`` is given, ``value`` is passed on untouched. Otherwise
    ``value`` is first encoded with ``datatype`` and then:

    - serialised with ``json.dumps`` if ``datatype.dtype == 'json'`` and
      the backend is not ``json_native``; or, failing that,
    - encoded again with this backend's ``vector_impl`` if ``datatype``
      is a ``Vector``.

    :param table: The table to update.
    :param condition: The condition to update.
    :param key: The key to update.
    :param value: The value to update.
    :param datatype: The datatype to use for encoding the value.
    """
    if datatype is None:
        # Nothing to encode: forward the raw value.
        self.update(table, condition, key, value)
        return

    encoded = datatype.encode_data(value, None)
    needs_json_text = datatype.dtype == 'json' and not self.json_native
    if needs_json_text:
        encoded = json.dumps(encoded)
    elif isinstance(datatype, Vector):
        # Re-encode with the backend-specific vector type, built from the
        # declared dtype and shape.
        backend_vector = self.vector_impl(
            dtype=datatype.dtype, shape=datatype.shape
        )
        encoded = backend_vector.encode_data(encoded, None)
    self.update(table, condition, key, encoded)
188260
189261 @abstractmethod
@@ -268,12 +340,30 @@ def _wrap_results(self, query: Query, result, schema, raw: bool = False):
268340 r ['_source' ] = str (r ['_source' ])
269341
270342 if not self .json_native :
271- json_fields = [k for k in schema .fields if getattr (schema .fields [k ], 'dtype' , None ) == 'json' ]
343+ json_fields = [
344+ k
345+ for k in schema .fields
346+ if getattr (schema .fields [k ], 'dtype' , None ) == 'json'
347+ ]
272348 for r in result :
273349 for k in json_fields :
274350 if k in r and isinstance (r [k ], str ):
275351 r [k ] = json .loads (r [k ])
276352
353+ vector_datatypes = {
354+ k : self .vector_impl (dtype = v .dtype , shape = v .shape )
355+ for k , v in schema .fields .items ()
356+ if isinstance (v , Vector )
357+ }
358+
359+ if vector_datatypes :
360+ for r in result :
361+ for k in vector_datatypes :
362+ if k in r and r [k ] is not None :
363+ r [k ] = vector_datatypes [k ].decode_data (
364+ r [k ], builds = {}, db = self .db
365+ )
366+
277367 if raw :
278368 return result
279369
@@ -311,8 +401,13 @@ def get_schema(self, query) -> 'Schema':
311401
312402 return base_schema
313403
314- def _do_insert (self , table , documents , raw : bool = False ):
404+ def do_insert (self , table , documents , raw : bool = False ):
405+ """Insert data into the database.
315406
407+ :param table: The table to insert into.
408+ :param documents: The documents to insert.
409+ :param raw: If ``True``, insert raw documents.
410+ """
316411 schema = self .get_schema (self .db [table ])
317412
318413 if not raw and not schema .trivial :
@@ -342,8 +437,23 @@ def _do_insert(self, table, documents, raw: bool = False):
342437 pass
343438 documents [i ] = r
344439
440+ vector_datatypes = {
441+ k : self .vector_impl (dtype = v .dtype , shape = v .shape )
442+ for k , v in schema .fields .items ()
443+ if isinstance (v , Vector )
444+ }
445+ if vector_datatypes :
446+ for r in documents :
447+ for k in vector_datatypes :
448+ if k in r :
449+ r [k ] = vector_datatypes [k ].encode_data (r [k ], None )
450+
345451 if not self .json_native :
346- json_fields = [k for k in schema .fields if getattr (schema .fields [k ], 'dtype' , None ) == 'json' ]
452+ json_fields = [
453+ k
454+ for k in schema .fields
455+ if getattr (schema .fields [k ], 'dtype' , None ) == 'json'
456+ ]
347457 for r in documents :
348458 for k in json_fields :
349459 if k in r :
@@ -490,6 +600,8 @@ class KeyedDatabackend(BaseDataBackend):
490600 :param flavour: Flavour of the databackend.
491601 """
492602
603+ json_native : bool = True
604+
493605 @abstractmethod
494606 def __getitem__ (self , key : t .Tuple [str , str , str ]) -> t .Dict :
495607 pass
0 commit comments