# Source code for docarray.array.storage.weaviate.backend

import copy
import uuid
from dataclasses import dataclass, field, asdict
from typing import (
    Iterable,
    Dict,
    Optional,
    TYPE_CHECKING,
    Union,
    Tuple,
    List,
)

import numpy as np
import weaviate

from docarray import Document
from docarray.helper import dataclass_from_dict, filter_dict, _safe_cast_int
from docarray.array.storage.base.backend import BaseBackendMixin, TypeMap
from docarray.array.storage.registry import _REGISTRY

if TYPE_CHECKING:  # pragma: no cover
    from docarray.typing import ArrayType, DocumentArraySourceType


@dataclass
class WeaviateConfig:
    """Settings used to connect to a Weaviate server and describe the class schema.

    Fields cover the server endpoint, how documents are stored, the HNSW
    ``vectorIndexConfig`` options and the weaviate python client options.
    """

    # -- server endpoint --
    host: Optional[str] = 'localhost'
    port: Optional[int] = 8080
    protocol: Optional[str] = 'http'

    # -- storage --
    name: Optional[str] = None
    list_like: bool = True
    serialize_config: Dict = field(default_factory=dict)
    # deprecated, not used anymore since weaviate 1.10
    n_dim: Optional[int] = None

    # -- vectorIndexConfig parameters --
    ef: Optional[int] = None
    ef_construction: Optional[int] = None
    # client timeout pair; normalized to a tuple in __post_init__
    timeout_config: Optional[Tuple[int, int]] = (10, 60)
    max_connections: Optional[int] = None
    dynamic_ef_min: Optional[int] = None
    dynamic_ef_max: Optional[int] = None
    dynamic_ef_factor: Optional[int] = None
    vector_cache_max_objects: Optional[int] = None
    flat_search_cutoff: Optional[int] = None
    cleanup_interval_seconds: Optional[int] = None
    skip: Optional[bool] = None
    columns: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] = None
    distance: Optional[str] = None

    # -- weaviate python client parameters --
    batch_size: Optional[int] = 50
    dynamic_batching: Optional[bool] = False
    root_id: bool = True

    def __post_init__(self):
        # A timeout supplied as a list (e.g. from a plain dict/deserialized
        # config) is converted to a tuple; any other value is left untouched.
        timeout = self.timeout_config
        if isinstance(timeout, list):
            self.timeout_config = tuple(timeout)
_banned_classname_chars = [ '[', ' ', '"', '*', '\\', '<', '|', ',', '>', '/', '?', ']', '@', '.', ] def _sanitize_class_name(name): new_name = name for char in _banned_classname_chars: new_name = new_name.replace(char, '') return new_name
class BackendMixin(BaseBackendMixin):
    """Provide necessary functions to enable this storage backend."""

    TYPE_MAP = {
        'str': TypeMap(type='string', converter=str),
        'float': TypeMap(type='number', converter=float),
        'int': TypeMap(type='int', converter=_safe_cast_int),
    }

    def _init_storage(
        self,
        _docs: Optional['DocumentArraySourceType'] = None,
        config: Optional[Union[WeaviateConfig, Dict]] = None,
        **kwargs,
    ):
        """Initialize weaviate storage.

        :param _docs: the documents to initialize with; when given, the weaviate
            class is cleared before they are stored
        :param config: the config object (or a dict with the same fields) used to
            initialize the connection to the weaviate server
        :param kwargs: extra keyword arguments forwarded to the parent ``_init_storage``
        :raises ValueError: if ``config.name`` is set and is not equal to
            ``config.name.capitalize()``
        """
        # work on a copy: the name and columns are filled in / normalized below
        config = copy.deepcopy(config)
        if not config:
            config = WeaviateConfig()
        elif isinstance(config, dict):
            config = dataclass_from_dict(WeaviateConfig, config)
        self._serialize_config = config.serialize_config

        # NOTE: str.capitalize() also lower-cases every character after the
        # first one, so a name such as 'MyClass' is rejected by this check.
        if config.name and config.name != config.name.capitalize():
            raise ValueError(
                'Weaviate class name has to be capitalized. '
                'Please capitalize when declaring the name field in config.'
            )

        self._client = weaviate.Client(
            f'{config.protocol}://{config.host}:{config.port}',
            timeout_config=config.timeout_config,
        )
        self._config = config
        self._config.columns = self._normalize_columns(self._config.columns)

        self._schemas = self._load_or_create_weaviate_schema()
        self._list_like = config.list_like

        # register this instance under its python class and weaviate class name
        _REGISTRY[self.__class__.__name__][self._class_name].append(self)

        super()._init_storage(_docs, **kwargs)

        # To align with Sqlite behavior; if `docs` is not `None` and table name
        # is provided, :class:`DocumentArraySqlite` will clear the existing
        # table and load the given `docs`
        if _docs is None:
            return
        elif isinstance(_docs, Iterable):
            self.clear()
            self.extend(_docs)
        else:
            self.clear()
            if isinstance(_docs, Document):
                self.append(_docs)

    def _ensure_unique_config(
        self,
        config_root: dict,
        config_subindex: dict,
        config_joined: dict,
        subindex_name: str,
    ) -> dict:
        """Give a subindex config a class name derived from the root name.

        Only applies when ``config_subindex`` does not set its own ``name``; the
        new name is ``<name> + 'subindex' + <subindex_name>`` with banned
        characters stripped by ``_sanitize_class_name``.

        :param config_root: the root config (unused here)
        :param config_subindex: the user-provided subindex config
        :param config_joined: the merged config, updated in place
        :param subindex_name: the subindex selector
        :return: ``config_joined``
        """
        if 'name' not in config_subindex:
            unique_name = _sanitize_class_name(
                config_joined['name'] + 'subindex' + subindex_name
            )
            config_joined['name'] = unique_name
        return config_joined

    def _get_weaviate_class_name(self) -> str:
        """Generate a random class/schema name with the ``uuid4`` module, formatted
        to fit the weaviate class name convention

        :return: string of the form ``'Class' + uuid4().hex``
        """
        return f'Class{uuid.uuid4().hex}'

    def _get_schema_by_name(self, cls_name: str) -> Dict:
        """Return the schema dictionary object with the class name

        The returned dictionary holds two classes: ``cls_name`` (documents, with a
        ``_serialized`` blob property plus one property per configured column) and
        ``cls_name + 'Meta'`` (holding the ``_offset2ids`` list). For a given config
        only the class names differ between calls.

        :param cls_name: the name of the schema/class in weaviate
        :return: the schema dictionary
        """
        # TODO: ideally we should only use one schema. this will allow us to deal with
        # consistency better
        hnsw_config = {
            'ef': self._config.ef,
            'efConstruction': self._config.ef_construction,
            'maxConnections': self._config.max_connections,
            'dynamicEfMin': self._config.dynamic_ef_min,
            'dynamicEfMax': self._config.dynamic_ef_max,
            'dynamicEfFactor': self._config.dynamic_ef_factor,
            'vectorCacheMaxObjects': self._config.vector_cache_max_objects,
            'flatSearchCutoff': self._config.flat_search_cutoff,
            'cleanupIntervalSeconds': self._config.cleanup_interval_seconds,
            'skip': self._config.skip,
            'distance': self._config.distance,
        }

        base_classes = {
            'classes': [
                {
                    'class': cls_name,
                    "vectorizer": "none",
                    # a user-provided `skip` in hnsw_config overrides this default
                    'vectorIndexConfig': {'skip': False, **filter_dict(hnsw_config)},
                    'properties': [
                        {
                            'dataType': ['blob'],
                            'name': '_serialized',
                            'indexInverted': False,
                        },
                    ],
                },
                {
                    'class': cls_name + 'Meta',
                    "vectorizer": "none",
                    'vectorIndexConfig': {'skip': True},
                    'properties': [
                        {
                            'dataType': ['string[]'],
                            'name': '_offset2ids',
                            'indexInverted': False,
                        },
                    ],
                },
            ]
        }

        for col, coltype in self._config.columns.items():
            new_property = {
                'dataType': [self._map_type(coltype)],
                'name': col,
                'indexInverted': True,
            }
            base_classes['classes'][0]['properties'].append(new_property)

        return base_classes

    def _load_or_create_weaviate_schema(self):
        """Create a new weaviate schema for this :class:`DocumentArrayWeaviate` object
        if not present in weaviate or if ``self._config.name`` is None. If ``self._config.name``
        is provided and not None and schema with the specified name exists in weaviate,
        then load the object with the given ``self._config.name``

        :return: the schemas of this :class`DocumentArrayWeaviate` object and its meta
        """
        if not self._config.name:
            # draw random names until one is not already in weaviate
            name_candidate = self._get_weaviate_class_name()
            doc_schemas = self._get_schema_by_name(name_candidate)
            while self._client.schema.contains(doc_schemas):
                name_candidate = self._get_weaviate_class_name()
                doc_schemas = self._get_schema_by_name(name_candidate)
            self._client.schema.create(doc_schemas)
            self._config.name = name_candidate
            return doc_schemas

        doc_schemas = self._get_schema_by_name(self._config.name)
        if self._client.schema.contains(doc_schemas):
            return doc_schemas

        self._client.schema.create(doc_schemas)
        return doc_schemas

    def _update_offset2ids_meta(self):
        """Update the offset2ids in weaviate to the current local version

        Updates the existing meta object when ``self._offset2ids_wid`` refers to
        one; otherwise creates a new meta object under a fresh ``uuid1`` id.
        """
        if self._offset2ids_wid is not None and self._client.data_object.exists(
            self._offset2ids_wid
        ):
            self._client.data_object.update(
                data_object={'_offset2ids': self._offset2ids.ids},
                class_name=self._meta_name,
                uuid=self._offset2ids_wid,
            )
        else:
            self._offset2ids_wid = str(uuid.uuid1())
            self._client.data_object.create(
                data_object={'_offset2ids': self._offset2ids.ids},
                class_name=self._meta_name,
                uuid=self._offset2ids_wid,
            )

    def _get_offset2ids_meta(self) -> Tuple[List, Optional[str]]:
        """Return the offset2ids stored in weaviate along with the weaviate id of
        the meta object that stores it

        :return: a tuple with first element as a list of offset2ids and second
            element being the weaviate id of the meta object, or ``([], None)``
            when no meta object exists yet
        :raises ValueError: if the meta class name is not defined, or if more than
            one meta object is found
        """
        if not self._meta_name:
            raise ValueError('meta object is not defined')

        resp = (
            self._client.query.get(self._meta_name, ['_offset2ids', '_additional {id}'])
            .do()
            .get('data', {})
            .get('Get', {})
            .get(self._meta_name, [])
        )

        if not resp:
            return [], None
        elif len(resp) == 1:
            return resp[0]['_offset2ids'], resp[0]['_additional']['id']
        else:
            raise ValueError('received multiple meta copies which is invalid')

    @property
    def name(self):
        """An alias to _class_name that returns the id/name of the class
        in the weaviate of this :class:`DocumentArrayWeaviate`

        :return: name of weaviate class/schema of this :class:`DocumentArrayWeaviate`
        """
        return self._class_name

    @property
    def _class_name(self):
        """Return the name of the class in weaviate of this :class:`DocumentArrayWeaviate`

        :return: name of weaviate class/schema of this :class:`DocumentArrayWeaviate`
        """
        if not self._schemas:
            return None
        return self._schemas['classes'][0]['class']

    @property
    def _meta_name(self):
        """Return the name of the class in weaviate that stores the meta information of
        this :class:`DocumentArrayWeaviate`

        :return: name of weaviate class/schema of class that stores the meta information
        """
        # TODO: remove this after we combine the meta info to the DocumentArray class
        if not self._schemas:
            return None
        return self._schemas['classes'][1]['class']

    @property
    def _class_schema(self) -> Optional[Dict]:
        """Return the schema dictionary of this :class:`DocumentArrayWeaviate`'s weaviate schema

        :return: the dictionary representing this weaviate schema
        """
        if not self._schemas:
            return None
        return self._schemas['classes'][0]

    @property
    def _meta_schema(self) -> Optional[Dict]:
        """Return the schema dictionary of this weaviate schema that stores this object's meta

        :return: the dictionary representing a meta object's weaviate schema, or
            ``None`` if no schemas are loaded or the meta class is missing
        """
        # `or` (not `and`): with `and`, a None/empty `_schemas` fell through to
        # `len(None)`. Count the classes list, not the outer dict's keys.
        if not self._schemas or len(self._schemas.get('classes', [])) < 2:
            return None
        return self._schemas['classes'][1]

    def _doc2weaviate_create_payload(self, value: 'Document'):
        """Return the payload to store :class:`Document` into weaviate

        Each configured column is filled from ``value.tags`` (via ``_map_column``),
        the document itself is serialized with ``to_base64``.

        :param value: document to create a payload for
        :return: the payload dictionary
        """
        extra_columns = {
            col: self._map_column(value.tags.get(col), col_type)
            for col, col_type in self._config.columns.items()
        }

        return dict(
            data_object={
                '_serialized': value.to_base64(**self._serialize_config),
                **extra_columns,
            },
            class_name=self._class_name,
            uuid=self._map_id(value.id),
            vector=self._map_embedding(value.embedding),
        )

    @staticmethod
    def _map_id(doc_id: str):
        """Map a Document id to a UUID string accepted by weaviate.

        :param doc_id: the Document id
        :return: the id itself as a UUID string if it is a valid hex UUID,
            otherwise a deterministic ``uuid5`` derived from it
        """
        # if doc_id is a random ID in hex format, just translate back to UUID str
        # otherwise, create UUID5 from doc_id
        try:
            return str(uuid.UUID(hex=doc_id))
        except ValueError:
            return str(uuid.uuid5(uuid.NAMESPACE_URL, doc_id))

    def _map_embedding(self, embedding: 'ArrayType'):
        """Convert an embedding to the vector form passed to weaviate.

        :param embedding: the embedding, or ``None``
        :return: ``None`` for no embedding; a one-element python list for a
            single-value embedding; otherwise the (squeezed) numpy array
        """
        if embedding is not None:
            from docarray.math.ndarray import to_numpy_array

            embedding = to_numpy_array(embedding)

            if embedding.ndim != 1:
                # drop singleton axes, but keep at least one dimension: squeezing
                # e.g. a (1, 1) array (or a 0-d input) would otherwise yield a 0-d
                # array, on which len() below raises TypeError
                embedding = np.atleast_1d(np.asarray(embedding).squeeze())

            # Weaviate expects vector to have dim 2 at least
            # or get weaviate.exceptions.UnexpectedStatusCodeException:  models.C11yVector
            # hence we cast it to list of a single element
            if len(embedding) == 1:
                embedding = [embedding[0]]
        else:
            embedding = None
        return embedding

    def __getstate__(self):
        # the weaviate client is not pickled; __setstate__ rebuilds it
        d = dict(self.__dict__)
        del d['_client']
        return d

    def __setstate__(self, state):
        self.__dict__ = state
        config = state['_config']
        # recreate the client exactly as _init_storage does, including the
        # configured timeouts (previously dropped on unpickling)
        self._client = weaviate.Client(
            f'{config.protocol}://{config.host}:{config.port}',
            timeout_config=config.timeout_config,
        )