import copy
import uuid
import warnings
from dataclasses import dataclass, field
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Tuple,
    Union,
)
import numpy as np
from elasticsearch import Elasticsearch
from elasticsearch.helpers import parallel_bulk
from docarray.array.storage.base.backend import BaseBackendMixin, TypeMap
from docarray import Document
from docarray.helper import dataclass_from_dict, _safe_cast_int
if TYPE_CHECKING:  # pragma: no cover
    from docarray.typing import DocumentArraySourceType, ArrayType
@dataclass
class ElasticConfig:
    n_dim: int  # number of dimensions of the dense_vector field in Elastic
    distance: str = 'cosine'  # vector similarity metric in Elastic, e.g. 'cosine', 'l2_norm' or 'dot_product'
hosts: Union[
str, List[Union[str, Mapping[str, Union[str, int]]]], None
] = 'http://localhost:9200'
index_name: Optional[str] = None
list_like: bool = True
es_config: Dict[str, Any] = field(default_factory=dict)
index_text: bool = False
tag_indices: List[str] = field(default_factory=list)
batch_size: int = 64
ef_construction: Optional[int] = None
m: Optional[int] = None
columns: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] = None
root_id: bool = True
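
# Hedged usage sketch (not part of the original module): ElasticConfig is normally
# supplied through the `config` argument of DocumentArray; the host, dimensionality
# and index name below are example assumptions.
def _example_elastic_config_usage():
    from docarray import DocumentArray

    da = DocumentArray(
        storage='elasticsearch',
        config={'n_dim': 128, 'distance': 'cosine', 'hosts': 'http://localhost:9200'},
    )
    return da
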
_banned_indexname_chars = [
    '[', ' ', '"', '*', '\\', '<', '|', ',', '>', '/', '?', ']',
]
def _sanitize_index_name(name: str) -> str:
new_name = name
for char in _banned_indexname_chars:
new_name = new_name.replace(char, '')
return new_name
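
# Hedged example (illustration only): characters Elastic forbids in index names are
# simply stripped, e.g. a space, '*' and '?' all disappear.
def _example_sanitize_index_name():
    assert _sanitize_index_name('my index*?') == 'myindex'
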
class BackendMixin(BaseBackendMixin):
"""Provide necessary functions to enable this storage backend."""
TYPE_MAP = {
'str': TypeMap(type='text', converter=str),
'float': TypeMap(type='float', converter=float),
'int': TypeMap(type='integer', converter=_safe_cast_int),
'double': TypeMap(type='double', converter=float),
'long': TypeMap(type='long', converter=_safe_cast_int),
'bool': TypeMap(type='boolean', converter=bool),
}
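
    # Hedged sketch (illustration only): a user column declared as 'float' resolves
    # through TYPE_MAP to the Elastic field type 'float', and its converter coerces
    # raw tag values before indexing; 'price' is an assumed example column name.
    def _example_column_type_mapping(self):
        type_map = self.TYPE_MAP['float']
        return {'price': {'type': type_map.type, 'example_value': type_map.converter('9.99')}}
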
def _init_storage(
self,
_docs: Optional['DocumentArraySourceType'] = None,
config: Optional[Union[ElasticConfig, Dict]] = None,
**kwargs,
):
config = copy.deepcopy(config)
if not config:
raise ValueError('Empty config is not allowed for Elastic storage')
elif isinstance(config, dict):
config = dataclass_from_dict(ElasticConfig, config)
        if config.index_name is None:
            random_id = uuid.uuid4().hex
            config.index_name = 'index_name__' + random_id
self._index_name_offset2id = 'offset2id__' + config.index_name
self._config = config
self._config.columns = self._normalize_columns(self._config.columns)
self.n_dim = self._config.n_dim
self._list_like = self._config.list_like
self._client = self._build_client()
self._build_index()
self._build_offset2id_index()
# Note super()._init_storage() calls _load_offset2ids which calls _get_offset2ids_meta
super()._init_storage(**kwargs)
        if _docs is None:
            return
        elif isinstance(_docs, Iterable):
            self.extend(_docs)
        elif isinstance(_docs, Document):
            self.append(_docs)
def _ensure_unique_config(
self,
config_root: dict,
config_subindex: dict,
config_joined: dict,
subindex_name: str,
) -> dict:
if 'index_name' not in config_subindex:
unique_index_name = _sanitize_index_name(
config_joined['index_name'] + '_subindex_' + subindex_name
)
config_joined['index_name'] = unique_index_name
return config_joined
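
    # Hedged example (illustration only): for a root index named 'root_idx' and a
    # subindex on '@c', the joined config gets the sanitized name
    # 'root_idx_subindex_@c'; any banned characters would be stripped.
    def _example_subindex_name(self):
        return _sanitize_index_name('root_idx' + '_subindex_' + '@c')
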
def _build_offset2id_index(self):
if self._list_like and not self._client.indices.exists(
index=self._index_name_offset2id
):
self._client.indices.create(index=self._index_name_offset2id, ignore=[404])
def _build_schema_from_elastic_config(self, elastic_config):
da_schema = {
'mappings': {
'dynamic': 'true',
'_source': {'enabled': 'true'},
'properties': {
'embedding': {
'type': 'dense_vector',
'dims': elastic_config.n_dim,
'index': 'true',
'similarity': elastic_config.distance,
},
'text': {'type': 'text', 'index': elastic_config.index_text},
},
}
}
if elastic_config.tag_indices:
for index in elastic_config.tag_indices:
da_schema['mappings']['properties'][index] = {
'type': 'text',
'index': True,
}
for col, coltype in self._config.columns.items():
da_schema['mappings']['properties'][col] = {
'type': self._map_type(coltype),
'index': True,
}
if self._config.m or self._config.ef_construction:
index_options = {
'type': 'hnsw',
'm': self._config.m or 16,
'ef_construction': self._config.ef_construction or 100,
}
da_schema['mappings']['properties']['embedding'][
'index_options'
] = index_options
return da_schema
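
    # Hedged sketch (illustration only): what the schema built above looks like for a
    # config with n_dim=3, index_text=False, no extra columns and no HNSW overrides.
    def _example_schema(self):
        return {
            'mappings': {
                'dynamic': 'true',
                '_source': {'enabled': 'true'},
                'properties': {
                    'embedding': {
                        'type': 'dense_vector',
                        'dims': 3,
                        'index': 'true',
                        'similarity': 'cosine',
                    },
                    'text': {'type': 'text', 'index': False},
                },
            }
        }
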
def _build_client(self):
client = Elasticsearch(
hosts=self._config.hosts,
**self._config.es_config,
)
return client
def _build_index(self):
schema = self._build_schema_from_elastic_config(self._config)
if not self._client.indices.exists(index=self._config.index_name):
self._client.indices.create(
index=self._config.index_name, mappings=schema['mappings']
)
self._client.indices.refresh(index=self._config.index_name)
    def _send_requests(self, requests, **kwargs) -> List[Dict]:
        """Send bulk requests to Elastic and collect the info of the operations that succeeded"""
        # default to the configured batch size unless the caller overrides it
        # (kept as a kwarg for backward compatibility)
        if 'chunk_size' not in kwargs:
            kwargs['chunk_size'] = self._config.batch_size
        accumulated_info = []
        for success, info in parallel_bulk(
            self._client,
            requests,
            raise_on_error=False,
            raise_on_exception=False,
            **kwargs,
        ):
if not success:
warnings.warn(str(info))
else:
accumulated_info.append(info)
return accumulated_info
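
    # Hedged sketch (illustration only): each action passed to _send_requests is a
    # plain bulk-action dict; the doc id below is an assumed example value.
    def _example_bulk_request(self):
        return [
            {
                '_op_type': 'index',
                '_id': 'doc-1',
                '_index': self._config.index_name,
                'embedding': [0.1] * self._config.n_dim,
            }
        ]
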
def _refresh(self, index_name):
self._client.indices.refresh(index=index_name)
def _doc_id_exists(self, doc_id):
return self._client.exists(index=self._config.index_name, id=doc_id)
def _update_offset2ids_meta(self):
"""Update the offset2ids in elastic"""
if self._client.indices.exists(index=self._index_name_offset2id):
            requests = [
                {
                    '_op_type': 'index',
                    # the offset is used as the document _id so that ids can be
                    # looked up by offset directly
                    '_id': offset_,
                    '_index': self._index_name_offset2id,
                    'blob': f'{id_}',  # the Document id is stored as the payload
                }
                for offset_, id_ in enumerate(self._offset2ids.ids)
            ]
self._send_requests(requests)
self._client.indices.refresh(index=self._index_name_offset2id)
# Clean trailing unused offsets
offset_count = self._client.count(index=self._index_name_offset2id)
unused_offsets = range(len(self._offset2ids.ids), offset_count['count'])
if len(unused_offsets) > 0:
            requests = [
                {
                    '_op_type': 'delete',
                    '_id': offset_,  # offsets beyond the current length are stale
                    '_index': self._index_name_offset2id,
                }
                for offset_ in unused_offsets
            ]
self._send_requests(requests)
self._client.indices.refresh(index=self._index_name_offset2id)
    def _get_offset2ids_meta(self) -> List:
        """Return the offset2ids stored in Elastic

        :return: a list containing ids
        :raises ValueError: raised if the Elastic client has not been initialized
        """
if not self._client:
raise ValueError('Elastic client does not exist')
n_docs = self._client.count(index=self._index_name_offset2id)["count"]
if n_docs != 0:
            offsets = list(range(n_docs))
resp = self._client.mget(index=self._index_name_offset2id, ids=offsets)
ids = [x['_source']['blob'] for x in resp['docs']]
return ids
else:
return []
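
    # Hedged round-trip note (illustration only): _update_offset2ids_meta writes one
    # tiny doc per offset (offset as _id, Document id in 'blob'), and this method
    # reads them back with mget, so the recovered list preserves insertion order.
    def _example_offset2ids_roundtrip(self):
        self._update_offset2ids_meta()
        return self._get_offset2ids_meta()
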
    def _map_embedding(self, embedding: 'ArrayType') -> 'np.ndarray':
        from docarray.math.helper import EPSILON

        if embedding is None:
            embedding = np.zeros(self.n_dim) + EPSILON
        else:
            from docarray.math.ndarray import to_numpy_array

            embedding = to_numpy_array(embedding)
            if embedding.ndim > 1:
                embedding = np.asarray(embedding).squeeze()
        # Elastic rejects all-zero vectors under cosine similarity, so nudge them
        if np.all(embedding == 0):
            embedding = embedding + EPSILON
        return embedding
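
    # Hedged example (illustration only): a missing embedding becomes an
    # EPSILON-filled vector of length n_dim rather than an all-zero vector.
    def _example_map_missing_embedding(self):
        return self._map_embedding(None)  # shape (n_dim,), every entry == EPSILON
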
    def __getstate__(self):
        # the Elasticsearch client holds live network connections and cannot be
        # pickled, so it is dropped here and rebuilt in __setstate__
        d = dict(self.__dict__)
        del d['_client']
        return d
def __setstate__(self, state):
self.__dict__ = state
self._client = self._build_client()
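
# Hedged pickling sketch (illustration only): __getstate__/__setstate__ above let an
# object using this mixin survive pickling by rebuilding the client on load.
def _example_pickle_roundtrip(da):
    import pickle

    restored = pickle.loads(pickle.dumps(da))  # _client is rebuilt in __setstate__
    return restored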