import json
import os
import pathlib
import random
import sys
import uuid
import warnings
from os.path import expanduser
from typing import Any, Dict, Optional, Sequence, Tuple, Union, TYPE_CHECKING
from collections import Counter
if TYPE_CHECKING: # pragma: no cover
from docarray import Document, DocumentArray
# Directory of static resource files bundled with the package
# (e.g. ``ci-vendors.json`` used by :func:`get_ci_vendor`).  Resolved against
# the already-imported ``docarray`` module when available, otherwise against
# this file's own location.
__resources_path__ = os.path.join(
os.path.dirname(
sys.modules.get('docarray').__file__ if 'docarray' in sys.modules else __file__
),
'resources',
)
# Per-user cache directory for this package, e.g. ``~/.cache/docarray``.
__cache_path__ = f'{expanduser("~")}/.cache/{__package__}'
def typename(obj):
    """Return the fully qualified type name of ``obj``.

    :param obj: Target object (an instance or a type).
    :return: ``module.name`` of the object's type, or ``str`` of the type
        when it lacks those attributes.
    """
    cls = obj if isinstance(obj, type) else obj.__class__
    try:
        return f'{cls.__module__}.{cls.__name__}'
    except AttributeError:
        return str(cls)
def deprecate_by(new_fn, removed_at: str):
    """Create a deprecated alias for ``new_fn``.

    Usage: ``old_fn_name = deprecate_by(new_fn, '0.30')``

    :param new_fn: the replacement function.
    :param removed_at: the version at which the old name will be removed.
    :return: a wrapper that emits a :class:`FutureWarning` and then delegates
        to ``new_fn``.
    """

    def _f(*args, **kwargs):
        import inspect

        # Best-effort recovery of the name the caller used.  ``code_context``
        # is ``None`` when source is unavailable (REPL, ``exec``), in which
        # case the original ``stack()[1][4][0]`` crashed with TypeError —
        # fall back to the new function's name instead.
        old_fn_name = new_fn.__name__
        caller = inspect.stack()[1]
        if caller.code_context:
            old_fn_name = caller.code_context[0].strip().split("=")[0].strip()
        warnings.warn(
            f'`{old_fn_name}` is renamed to `.{new_fn.__name__}()` with the same usage, please use the latter instead. The old function will be removed in {removed_at}.',
            FutureWarning,
        )
        return new_fn(*args, **kwargs)

    return _f
def dunder_get(_dict: Any, key: str) -> Any:
    """Return the value addressed by a "dunderkey".

    A dunderkey is a field name that may contain double underscores to
    reference nested keys, e.g. with ``data = {'a': {'b': 1}}`` the key
    ``'a__b'`` resolves to ``1``.

    :param _dict: (dict, list, struct or object) which we want to index into.
    :param key: (str) first-level or nested (dunder-separated) key.
    :return: (mixed) value corresponding to the key; ``None`` for a falsy
        ``_dict`` or a missing dict key.
    """
    if not _dict:
        return None

    head, _, tail = key.partition('__')

    # A numeric head selects by position / integer key.
    try:
        head = int(head)
    except ValueError:
        pass

    if isinstance(head, int):
        result = _dict[head]
    elif isinstance(_dict, dict):
        result = _dict[head] if head in _dict else None
    elif isinstance(_dict, Sequence):
        result = _dict[head]
    else:
        result = getattr(_dict, head)

    # Recurse into the remainder of the dunderkey, if any.
    return dunder_get(result, tail) if tail else result
def random_identity(use_uuid1: bool = False) -> str:
    """Generate a random UUID as a hex string.

    .. note::
        A MAC address / time-based UUID1 sorts closer together than a fully
        random UUID4, which can improve database performance, and can help
        debugging when origin data is lost or not explicitly stored.

    :param use_uuid1: use UUID1 instead of UUID4. This is the default
        Document ID generator.
    :return: hex string of a random UUID.
    """
    return random_uuid(use_uuid1=use_uuid1).hex
def random_uuid(use_uuid1: bool = False) -> uuid.UUID:
    """Return a freshly generated UUID.

    :param use_uuid1: if True use :func:`uuid.uuid1`, otherwise
        :func:`uuid.uuid4`.
    :return: a random UUID object.
    """
    if use_uuid1:
        return uuid.uuid1()
    return uuid.uuid4()
def download_mermaid_url(mermaid_url, output) -> None:
    """Download the rendered image behind ``mermaid_url`` into ``output``.

    :param mermaid_url: the URL of the image.
    :param output: destination file name; its svg/jpg suffix determines the
        file type of the output image.
    :raises RuntimeError: if the URL cannot be fetched or the file written.
    """
    from urllib.request import Request, urlopen

    try:
        req = Request(mermaid_url, headers={'User-Agent': 'Mozilla/5.0'})
        # Close the HTTP response deterministically instead of leaking it.
        with urlopen(req) as resp, open(output, 'wb') as fp:
            fp.write(resp.read())
    except Exception as ex:
        # Narrowed from a bare ``except`` so KeyboardInterrupt/SystemExit
        # still propagate; chain the cause for easier debugging.
        raise RuntimeError('Invalid or too-complicated graph') from ex
def get_full_version() -> Dict:
    """Collect versions of libraries used in Jina plus environment details.

    :return: version information and environment variables.
    """
    import platform
    from uuid import getnode

    import google.protobuf
    from docarray import __version__
    from google.protobuf.internal import api_implementation

    info = {
        'docarray': __version__,
        'protobuf': google.protobuf.__version__,
        'proto-backend': api_implementation.Type(),
        'python': platform.python_version(),
        'platform': platform.system(),
        'platform-release': platform.release(),
        'platform-version': platform.version(),
        'architecture': platform.machine(),
        'processor': platform.processor(),
        'uid': getnode(),
        'session-id': str(random_uuid(use_uuid1=True)),
        'ci-vendor': get_ci_vendor(),
    }
    return info
def get_ci_vendor() -> str:
    """Detect which CI system the current process runs on.

    Matches the environment against the vendor descriptors shipped in
    ``resources/ci-vendors.json``.  Each descriptor's ``env`` entry may be a
    single variable name, a mapping of exact name/value pairs, or a list of
    variable names.

    :return: the vendor constant of the first match, or ``'unset'``.
    """
    with open(os.path.join(__resources_path__, 'ci-vendors.json')) as fp:
        vendors = json.load(fp)

    env = os.environ
    for vendor in vendors:
        spec = vendor['env']
        matched = False
        if isinstance(spec, str):
            matched = spec in env
        elif isinstance(spec, dict):
            matched = any(env.get(name, None) == val for name, val in spec.items())
        elif isinstance(spec, list):
            matched = any(name in env for name in spec)
        if matched:
            return vendor['constant']
    return 'unset'
# Module-level bookkeeping shared by :func:`reset_ports` and
# :func:`random_port`.
assigned_ports = set()  # ports handed out so far
unassigned_ports = []  # shuffled pool of candidate ports
# Default scan range for random ports (within the dynamic/private range).
DEFAULT_MIN_PORT = 49153
MAX_PORT = 65535
def reset_ports():
    """Rebuild and reshuffle the pool of candidate ports.

    Clears all bookkeeping and refills ``unassigned_ports`` with every port
    in the configured range (overridable via ``JINA_RANDOM_PORT_MIN`` /
    ``JINA_RANDOM_PORT_MAX``), in random order.
    """

    def _candidate_ports():
        # When the default range is nearly exhausted, fall back to a lower
        # minimum so more ports become eligible.
        running_low = MAX_PORT - DEFAULT_MIN_PORT - len(assigned_ports) < 100
        default_min = '16384' if running_low else str(DEFAULT_MIN_PORT)
        lo = int(os.environ.get('JINA_RANDOM_PORT_MIN', default_min))
        hi = int(os.environ.get('JINA_RANDOM_PORT_MAX', str(MAX_PORT)))
        return set(range(lo, hi + 1)) - set(assigned_ports)

    unassigned_ports.clear()
    assigned_ports.clear()
    unassigned_ports.extend(_candidate_ports())
    random.shuffle(unassigned_ports)
def random_port() -> Optional[int]:
    """
    Get a random available port number.

    Draws from the shuffled module-level pool (``unassigned_ports``), probing
    each candidate with a bind test until one is free.  On exhaustion the
    pool is reset once and the scan retried.

    :return: a random available port.
    :raises OSError: if no free port can be found even after resetting.
    """

    def _random_port():
        import socket

        def _check_bind(port):
            with socket.socket() as s:
                try:
                    # Set SO_REUSEADDR *before* binding so ports lingering in
                    # TIME_WAIT are considered available; the original set it
                    # after bind, where it had no effect on that bind.
                    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                    s.bind(('', port))
                    return port
                except OSError:
                    return None

        _port = None
        if len(unassigned_ports) == 0:
            reset_ports()
        for idx, _port in enumerate(unassigned_ports):
            if _check_bind(_port) is not None:
                break
        else:
            # Loop exhausted without a successful bind.
            raise OSError(
                f'can not find an available port in {len(unassigned_ports)} unassigned ports, assigned already {len(assigned_ports)} ports'
            )
        int_port = int(_port)
        unassigned_ports.pop(idx)
        assigned_ports.add(int_port)
        return int_port

    try:
        return _random_port()
    except OSError:
        # One retry with a fresh pool before giving up.
        assigned_ports.clear()
        unassigned_ports.clear()
        return _random_port()
class cached_property:
    """Descriptor that caches a property's value on the instance.

    The computed value is stored in the instance ``__dict__`` under
    ``CACHED_<name>``; subsequent accesses return it without re-running the
    wrapped function.  Note a cached value of ``None`` counts as "not
    cached" and is recomputed on each access.
    """

    def __init__(self, func):
        """
        Create the :class:`cached_property`.

        :param func: function whose result should be cached.
        """
        self.func = func

    def __get__(self, obj, cls):
        # Class-level access (obj is None) returns the descriptor itself,
        # mirroring stdlib ``property``/``functools.cached_property``; the
        # original raised AttributeError on ``None.__dict__`` here.
        if obj is None:
            return self
        cached_value = obj.__dict__.get(f'CACHED_{self.func.__name__}', None)
        if cached_value is not None:
            return cached_value
        value = obj.__dict__[f'CACHED_{self.func.__name__}'] = self.func(obj)
        return value

    def __delete__(self, obj):
        cached_value = obj.__dict__.get(f'CACHED_{self.func.__name__}', None)
        if cached_value is not None:
            # Release resources held by the cached object before dropping it.
            if hasattr(cached_value, 'close'):
                cached_value.close()
            del obj.__dict__[f'CACHED_{self.func.__name__}']
def compress_bytes(data: bytes, algorithm: Optional[str] = None) -> bytes:
    """Compress ``data`` with the named algorithm.

    :param data: raw bytes to compress.
    :param algorithm: one of ``lz4``, ``bz2``, ``lzma``, ``zlib``, ``gzip``;
        any other value (including ``None``) returns ``data`` unchanged.
    :return: the (possibly) compressed bytes.
    """
    if algorithm == 'lz4':
        import lz4.frame

        return lz4.frame.compress(data)
    if algorithm == 'bz2':
        import bz2

        return bz2.compress(data)
    if algorithm == 'lzma':
        import lzma

        return lzma.compress(data)
    if algorithm == 'zlib':
        import zlib

        return zlib.compress(data)
    if algorithm == 'gzip':
        import gzip

        return gzip.compress(data)
    return data
def decompress_bytes(data: bytes, algorithm: Optional[str] = None) -> bytes:
    """Decompress ``data`` produced by the named algorithm.

    :param data: compressed bytes.
    :param algorithm: one of ``lz4``, ``bz2``, ``lzma``, ``zlib``, ``gzip``;
        any other value (including ``None``) returns ``data`` unchanged.
    :return: the (possibly) decompressed bytes.
    """
    if algorithm == 'lz4':
        import lz4.frame

        return lz4.frame.decompress(data)
    if algorithm == 'bz2':
        import bz2

        return bz2.decompress(data)
    if algorithm == 'lzma':
        import lzma

        return lzma.decompress(data)
    if algorithm == 'zlib':
        import zlib

        return zlib.decompress(data)
    if algorithm == 'gzip':
        import gzip

        return gzip.decompress(data)
    return data
def get_compress_ctx(algorithm: Optional[str] = None, mode: str = 'wb'):
    """Return a factory that wraps a file object in a compressing file.

    :param algorithm: one of ``lz4``, ``gzip``, ``bz2``, ``lzma``; anything
        else yields ``None``.
    :param mode: mode the compressed file object is opened with.
    :return: a one-argument callable producing the compressed-file wrapper,
        or ``None`` when the algorithm is unknown.
    """
    if algorithm == 'lz4':
        import lz4.frame

        return lambda x: lz4.frame.LZ4FrameFile(x, mode)
    if algorithm == 'gzip':
        import gzip

        return lambda x: gzip.GzipFile(fileobj=x, mode=mode)
    if algorithm == 'bz2':
        import bz2

        return lambda x: bz2.BZ2File(x, mode)
    if algorithm == 'lzma':
        import lzma

        return lambda x: lzma.LZMAFile(x, mode)
    return None
def dataclass_from_dict(klass, dikt):
    """Recursively build an instance of ``klass`` from a plain dict.

    Annotated fields are converted recursively; when ``klass`` has no
    ``__annotations__``, lists/tuples are converted element-wise using the
    first type argument of ``klass`` and everything else is returned as-is.

    :param klass: target (dataclass-like) type.
    :param dikt: dict (or list/tuple, or plain value) to convert.
    :return: the constructed instance, converted list, or ``dikt`` unchanged.
    """
    try:
        field_types = klass.__annotations__
        return klass(
            **{f: dataclass_from_dict(field_types[f], dikt[f]) for f in dikt}
        )
    except AttributeError:
        if isinstance(dikt, (tuple, list)):
            return [dataclass_from_dict(klass.__args__[0], item) for item in dikt]
        return dikt
def protocol_and_compress_from_file_path(
    file_path: str,
    default_protocol: Optional[str] = None,
    default_compress: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
    """Extract protocol and compression algorithm from a path's suffixes.

    :param file_path: path of a file.
    :param default_protocol: serialization protocol used when none is found.
    :param default_compress: compression method used when none is found.

    Examples:
    >>> protocol_and_compress_from_file_path('./docarray_fashion_mnist.protobuf.gzip')
    ('protobuf', 'gzip')
    >>> protocol_and_compress_from_file_path('/Documents/docarray_fashion_mnist.protobuf')
    ('protobuf', None)
    >>> protocol_and_compress_from_file_path('/Documents/docarray_fashion_mnist.gzip')
    (None, 'gzip')
    """
    known_protocols = {'pickle', 'protobuf', 'protobuf-array', 'pickle-array'}
    known_compressions = {'lz4', 'bz2', 'lzma', 'zlib', 'gzip'}

    protocol, compress = default_protocol, default_compress
    for suffix in pathlib.Path(file_path).suffixes:
        ext = suffix.replace('.', '')
        if ext in known_protocols:
            protocol = ext
        elif ext in known_compressions:
            compress = ext
    return protocol, compress
def add_protocol_and_compress_to_file_path(
    file_path: str, protocol: Optional[str] = None, compress: Optional[str] = None
) -> str:
    """Append the protocol and compression method as path extensions.

    :param file_path: path of a file.
    :param protocol: chosen protocol.
    :param compress: compression algorithm.

    Examples:
    >>> add_protocol_and_compress_to_file_path('docarray_fashion_mnist.bin')
    'docarray_fashion_mnist.bin'
    >>> add_protocol_and_compress_to_file_path('docarray_fashion_mnist', 'protobuf', 'gzip')
    'docarray_fashion_mnist.protobuf.gzip'
    """
    parts = [file_path]
    if protocol:
        parts.append(protocol)
    if compress:
        parts.append(compress)
    return '.'.join(parts)
def filter_dict(d: Dict) -> Dict:
    """Return a copy of ``d`` without the entries whose value is ``None``.

    :param d: input dict.
    :return: filtered dict.
    """
    return {k: v for k, v in d.items() if v is not None}
def _safe_cast_int(value: Union[str, int, float]) -> int:
    """Cast a string/float/int to an integer without silent truncation.

    :param value: value to be cast.
    :return: the cast integer.
    :raises ValueError: if ``value`` is a float with a fractional part, which
        plain ``int()`` would silently round toward zero.
    """
    lossy = isinstance(value, float) and not value.is_integer()
    if lossy:
        raise ValueError(f"Can't safely cast {value} to an int")
    return int(value)
def _get_array_info(da: 'DocumentArray'):
    """Summarize which (public) attributes the Documents in ``da`` carry.

    :param da: the array to inspect.
    :return: ``(is_homo, _nested_in, _nested_items, attr_counter,
        all_attrs_names)`` where ``is_homo`` says whether all Documents share
        the same attribute set, ``_nested_in`` lists the nesting fields
        present ('chunks'/'matches'), and ``_nested_items`` holds display
        rows describing each distinct attribute set when not homogeneous.
    """
    raw_attrs = da._get_attributes('non_empty_fields')
    # Drop private (underscore-prefixed) attributes from each Document's set.
    public_attrs = [
        tuple(a for a in attrs if not a.startswith('_')) for attrs in raw_attrs
    ]
    attr_counter = Counter(public_attrs)
    all_attrs_names = {name for attrs in public_attrs for name in attrs}

    _nested_in = [field for field in ('chunks', 'matches') if field in all_attrs_names]

    is_homo = len(attr_counter) == 1
    _nested_items = []
    if not is_homo:
        for attr_tuple, n_docs in attr_counter.most_common():
            subject = (
                f'{n_docs} Document has' if n_docs == 1 else f'{n_docs} Documents have'
            )
            if len(attr_tuple) == 1:
                text = f'{subject} one attribute'
            elif len(attr_tuple) == 0:
                text = f'{subject} no attribute'
            else:
                text = f'{subject} attributes'
            _nested_items.append(
                dict(name=text, value=str(attr_tuple), description='')
            )
    return is_homo, _nested_in, _nested_items, attr_counter, all_attrs_names
def check_root_id(da: 'DocumentArray', value: Union['Document', Sequence['Document']]):
    """Warn when Documents added to ``da`` lack root/parent identifiers.

    For :class:`DocumentArrayInMemory` every Document should carry a truthy
    ``parent_id``; for stores configured with ``root_id`` every Document
    should carry the ``_root_id_`` tag.  Values that are not a Document (or
    a non-empty sequence of Documents) are silently ignored.

    :param da: the destination array.
    :param value: a Document or sequence of Documents about to be inserted.
    """
    from docarray import Document
    from docarray.array.memory import DocumentArrayInMemory

    if isinstance(value, Document):
        docs = [value]
    # Guard against an empty sequence: the original indexed ``value[0]``
    # unconditionally and raised IndexError for ``[]``.
    elif isinstance(value, Sequence) and value and isinstance(value[0], Document):
        docs = value
    else:
        return

    if isinstance(da, DocumentArrayInMemory):
        if not all(getattr(doc, 'parent_id', None) for doc in docs):
            warnings.warn(
                "Not all documents have parent_id set. This may cause unexpected behavior.",
                UserWarning,
            )
    elif da._config.root_id and not all(
        doc.tags.get('_root_id_', None) for doc in docs
    ):
        warnings.warn(
            "root_id is enabled but not all documents have _root_id_ set. This may cause unexpected behavior.",
            UserWarning,
        )
def login(interactive: Optional[bool] = None, force: bool = False, **kwargs):
    """Log in to the Jina AI Cloud account.

    :param interactive: if True, support notebook environments; otherwise
        the environment is inferred.
    :param force: if True, overwrite the stored token and re-login.
    """
    # Imported lazily so the hubble dependency is only needed when used.
    import hubble

    hubble.login(interactive=interactive, force=force)
def logout():
    """Log out of the Hubble account."""
    # Imported lazily so the hubble dependency is only needed when used.
    import hubble

    hubble.logout()