Source code for docarray.array.storage.qdrant.getsetdel

from abc import abstractmethod
from typing import Iterable, Iterator

from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models.models import (
    PointIdsList,
    PointStruct,
    VectorParams,
)

from docarray import Document
from docarray.array.storage.base.getsetdel import BaseGetSetDelMixin
from docarray.array.storage.base.helper import Offset2ID


[docs]class GetSetDelMixin(BaseGetSetDelMixin): @property @abstractmethod def client(self) -> QdrantClient: raise NotImplementedError() @property @abstractmethod def serialization_config(self) -> dict: raise NotImplementedError() @property @abstractmethod def n_dim(self) -> int: raise NotImplementedError() @property @abstractmethod def collection_name(self) -> str: raise NotImplementedError() @property @abstractmethod def scroll_batch_size(self) -> int: raise NotImplementedError() def _upload_batch(self, docs: Iterable['Document']): batch = [] for doc in docs: batch.append(self._document_to_qdrant(doc)) if len(batch) > self.scroll_batch_size: self.client.upsert( collection_name=self.collection_name, points=batch, wait=True, ) batch = [] if len(batch) > 0: self.client.upsert( collection_name=self.collection_name, wait=True, points=batch, ) def _qdrant_to_document(self, qdrant_record: dict) -> 'Document': return Document.from_base64( qdrant_record['_serialized'], **self.serialization_config ) def _document_to_qdrant(self, doc: 'Document') -> 'PointStruct': extra_columns = { col: doc.tags.get(col) for col, _ in self._config.columns.items() } return PointStruct( id=self._map_id(doc.id), payload=dict( _serialized=doc.to_base64(**self.serialization_config), **extra_columns ), vector=self._map_embedding(doc.embedding), ) def _get_doc_by_id(self, _id: str) -> 'Document': try: resp = self.client.retrieve( collection_name=self.collection_name, ids=[self._map_id(_id)] ) if len(resp) == 0: raise KeyError(_id) return self._qdrant_to_document(resp[0].payload) except UnexpectedResponse as response_error: if response_error.status_code in [404, 400]: raise KeyError(_id) def _del_doc_by_id(self, _id: str): self.client.delete( collection_name=self.collection_name, points_selector=PointIdsList(points=[self._map_id(_id)]), wait=True, ) def _set_doc_by_id(self, _id: str, value: 'Document'): if _id != value.id: self._del_doc_by_id(_id) self.client.upsert( collection_name=self.collection_name, wait=True, points=[self._document_to_qdrant(value)], )
[docs] def scan(self) -> Iterator['Document']: offset = None while True: response, next_page = self.client.scroll( collection_name=self.collection_name, offset=offset, limit=self.scroll_batch_size, with_payload=['_serialized'], with_vectors=False, ) for point in response: yield self._qdrant_to_document(point.payload) if next_page: offset = next_page else: break
def _load_offset2ids(self): if self._list_like: ids = self._get_offset2ids_meta() self._offset2ids = Offset2ID(ids, list_like=self._list_like) else: self._offset2ids = Offset2ID([], list_like=self._list_like) def _save_offset2ids(self): if self._list_like: self._update_offset2ids_meta() def _clear_storage(self): self.client.recreate_collection( self.collection_name, vectors_config=VectorParams( size=self.n_dim, distance=self.distance, ), )