# Source code for docarray.array.storage.milvus.getsetdel

from typing import Iterable, Dict, TYPE_CHECKING

import numpy as np

from docarray import DocumentArray
from docarray.array.storage.base.getsetdel import BaseGetSetDelMixin
from docarray.array.storage.base.helper import Offset2ID
from docarray.array.storage.milvus.backend import (
    _always_true_expr,
    _ids_to_milvus_expr,
    _batch_list,
)

if TYPE_CHECKING:
    from docarray import Document, DocumentArray


class GetSetDelMixin(BaseGetSetDelMixin):
    """Milvus-backed implementation of document get/set/delete, plus
    persistence of the offset-to-id mapping (``Offset2ID``) that backs
    list-like indexing.

    Relies on attributes and helpers provided by the Milvus backend mixins
    (``self._collection``, ``self._offset2id_collection``,
    ``self.loaded_collection``, ``self._update_kwargs_from_config``,
    ``self._docs_from_query_response``, ``self._docs_to_milvus_payload``,
    ``self._create_or_reuse_collection`` and friends) — see the backend
    module for their contracts.
    """

    def _get_doc_by_id(self, _id: str) -> 'Document':
        # Single-doc lookup delegates to the bulk path and unwraps the result.
        return self._get_docs_by_ids([_id])[0]

    def _del_doc_by_id(self, _id: str):
        # Single-doc delete delegates to the bulk path.
        self._del_docs_by_ids([_id])

    def _set_doc_by_id(self, _id: str, value: 'Document', **kwargs):
        # Single-doc set delegates to the bulk path; ``mismatch_ids`` is
        # irrelevant for a single document, hence None.
        self._set_docs_by_ids([_id], [value], None, **kwargs)

    def _load_offset2ids(self):
        """Rebuild ``self._offset2ids`` from the offset2id collection.

        Only queried when the array is list-like; otherwise an empty
        mapping is installed.
        """
        if self._list_like:
            collection = self._offset2id_collection
            kwargs = self._update_kwargs_from_config('consistency_level', **dict())
            with self.loaded_collection(collection):
                res = collection.query(
                    expr=_always_true_expr('document_id'),
                    output_fields=['offset', 'document_id'],
                    **kwargs,
                )
            # Milvus returns rows unordered; restore insertion order using
            # the numeric 'offset' field (stored as a string).
            sorted_res = sorted(res, key=lambda k: int(k['offset']))
            self._offset2ids = Offset2ID([r['document_id'] for r in sorted_res])
        else:
            self._offset2ids = Offset2ID([], list_like=self._list_like)

    def _save_offset2ids(self):
        """Persist ``self._offset2ids`` to the offset2id collection.

        The stored mapping is dropped and fully rewritten. No-op when the
        array is not list-like or the mapping is empty.
        """
        if self._list_like:
            # delete old entries
            self._clear_offset2ids_milvus()
            # insert current entries
            ids = self._offset2ids.ids
            if not ids:
                return
            offsets = [str(i) for i in range(len(ids))]
            # The collection schema requires a vector field; a length-1 zero
            # vector serves as filler — it is never searched.
            dummy_vectors = [np.zeros(1) for _ in range(len(ids))]
            collection = self._offset2id_collection
            collection.insert([offsets, ids, dummy_vectors])

    def _get_docs_by_ids(self, ids: 'Iterable[str]', **kwargs) -> 'DocumentArray':
        """Fetch documents by id, batched, preserving the order of ``ids``.

        :raises KeyError: if a query batch returns no rows.

        NOTE(review): the full ``kwargs`` (including ``batch_size``) is
        forwarded to ``collection.query`` — presumably pymilvus tolerates
        the extra key; confirm against the backend config handling.
        """
        if not ids:
            return DocumentArray()
        ids = list(ids)
        kwargs = self._update_kwargs_from_config('consistency_level', **kwargs)
        kwargs = self._update_kwargs_from_config('batch_size', **kwargs)
        with self.loaded_collection():
            docs = DocumentArray()
            for id_batch in _batch_list(ids, kwargs['batch_size']):
                res = self._collection.query(
                    expr=f'document_id in {_ids_to_milvus_expr(id_batch)}',
                    output_fields=['serialized'],
                    **kwargs,
                )
                if not res:
                    raise KeyError(f'No documents found for ids {ids}')
                docs.extend(self._docs_from_query_response(res))
        # sort output docs according to input id sorting
        return DocumentArray([docs[d] for d in ids])

    def _del_docs_by_ids(self, ids: 'Iterable[str]', **kwargs) -> 'DocumentArray':
        """Delete the documents with the given ids, in batches."""
        kwargs = self._update_kwargs_from_config('consistency_level', **kwargs)
        kwargs = self._update_kwargs_from_config('batch_size', **kwargs)
        for id_batch in _batch_list(list(ids), kwargs['batch_size']):
            self._collection.delete(
                expr=f'document_id in {_ids_to_milvus_expr(id_batch)}', **kwargs
            )

    def _set_docs_by_ids(
        self, ids, docs: 'Iterable[Document]', mismatch_ids: 'Dict', **kwargs
    ):
        """Upsert ``docs`` under ``ids`` as delete-then-insert, in batches."""
        kwargs = self._update_kwargs_from_config('consistency_level', **kwargs)
        kwargs = self._update_kwargs_from_config('batch_size', **kwargs)
        # delete old entries
        for id_batch in _batch_list(list(ids), kwargs['batch_size']):
            self._collection.delete(
                expr=f'document_id in {_ids_to_milvus_expr(id_batch)}',
                **kwargs,
            )
        for docs_batch in _batch_list(list(docs), kwargs['batch_size']):
            # insert new entries
            payload = self._docs_to_milvus_payload(docs_batch)
            self._collection.insert(payload, **kwargs)

    def _clear_storage(self):
        # Drop and recreate the main collection, then reset offset2id storage.
        self._collection.drop()
        self._create_or_reuse_collection()
        self._clear_offset2ids_milvus()

    def _clear_offset2ids_milvus(self):
        # Drop and recreate the offset2id collection.
        self._offset2id_collection.drop()
        self._create_or_reuse_offset2id_collection()