Source code for docarray.document.mixins.plot

import copy
from typing import Optional

import numpy as np

from docarray.document.mixins.mesh import MeshEnum, PointCloudEnum


class PlotMixin:
    """Provide helper functions for :class:`Document` to plot and visualize itself."""

    def _ipython_display_(self):
        """Displays the object in IPython as a side effect"""
        self.summary()

    def __rich_console__(self, console, options):
        """Yield the rich renderables describing this Document.

        A header line with the Document id is yielded first, followed by a
        two-column table of the non-empty fields (only if the table has rows).

        :param console: the rich console (unused here).
        :param options: the rich console options (unused here).
        """
        yield f":page_facing_up: [b]Document[/b]: [cyan]{self.id}[/cyan]"

        from rich import box
        from rich.table import Table
        from rich.text import Text

        my_table = Table(
            'Attribute', 'Value', width=80, box=box.ROUNDED, highlight=True
        )

        for f in self.non_empty_fields:
            if f.startswith('_'):
                # private fields are never shown
                continue
            elif f in ('text', 'blob', 'uri') and len(getattr(self, f)) > 100:
                # long content is truncated to its first 100 characters
                v = getattr(self, f)
                my_table.add_row(f, Text(str(v)[:100] + f'... [dim](length: {len(v)})'))
            elif f in ('embedding', 'tensor'):
                from docarray.math.ndarray import to_numpy_array

                v = to_numpy_array(getattr(self, f))
                # Apply the size limit to the squeezed vector: for an array of
                # shape (1, N) ``len(v)`` is 1 and would never trip the limit.
                flat = v.squeeze()
                if flat.ndim == 1 and len(flat) < 1000:
                    from docarray.document.mixins.rich_embedding import (
                        ColorBoxEmbedding,
                    )

                    v = ColorBoxEmbedding(flat)
                else:
                    v = f'{type(getattr(self, f))} in shape {v.shape}, dtype: {v.dtype}'
                my_table.add_row(f, v)
            elif f not in ('id', 'chunks', 'matches'):
                # id is already in the header; chunks/matches are rendered by
                # `_plot_recursion` as tree branches
                my_table.add_row(f, Text(str(getattr(self, f))))

        if my_table.rows:
            yield my_table

    def summary(self) -> None:
        """Print non-empty fields and nested structure of this Document object."""
        from rich import print as rich_print

        rich_print(self._plot_recursion())

    def _plot_recursion(self, tree=None):
        """Build the rich tree of this Document and its nested matches/chunks.

        :param tree: the parent tree node to attach to; when ``None`` a new
            root :class:`rich.tree.Tree` is created for this Document.
        :return: the tree node that represents this Document.
        """
        if tree is None:
            from rich.tree import Tree

            tree = Tree(self)
        else:
            tree = tree.add(self)

        for a in ('matches', 'chunks'):
            children = getattr(self, a)
            if children:
                if a == 'chunks':
                    _icon = ':diamond_with_a_dot:'
                else:
                    _icon = ':large_orange_diamond:'
                _match_tree = tree.add(f'{_icon} [b]{a.capitalize()}[/b]')
                for d in children:
                    d._plot_recursion(_match_tree)
        return tree

    def display(self, from_: Optional[str] = None):
        """
        Plot image data from :attr:`.uri` or from :attr:`.tensor` if :attr:`.uri` is empty .

        3D content (point cloud, RGB-D image, vertices and faces) is detected
        first and dispatched to the dedicated display method. If nothing can
        be displayed, the summary is printed once.

        :param from_: an optional string to decide if a document should display using either the uri or the tensor field.
        """
        if self._is_3d_point_cloud():
            self.display_point_cloud_tensor()
        elif self._is_3d_rgbd():
            self.display_rgbd_tensor()
        elif self._is_3d_vertices_and_faces():
            self.display_vertices_and_faces()
        else:
            if not from_:
                if self.uri:
                    from_ = 'uri'
                elif self.tensor is not None:
                    from_ = 'tensor'
                # no fallback here: the final `else` below prints the summary,
                # so printing it at this point as well would show it twice

            if from_ == 'uri':
                self.display_uri()
            elif from_ == 'tensor':
                self.display_tensor()
            else:
                self.summary()

    def _is_3d_point_cloud(self):
        """
        Tells if Document stores a 3D object saved as point cloud tensor.

        True when :attr:`.tensor` is set, is 2-dimensional and its last axis has size 3.

        :return: bool.
        """
        tensor = self.tensor
        return tensor is not None and tensor.ndim == 2 and tensor.shape[-1] == 3

    def _is_3d_rgbd(self):
        """
        Tells if Document stores a 3D object saved as RGB-D image tensor.

        True when :attr:`.tensor` is set, is 3-dimensional and its last axis has size 4.

        :return: bool.
        """
        tensor = self.tensor
        return tensor is not None and tensor.ndim == 3 and tensor.shape[-1] == 4

    def _is_3d_vertices_and_faces(self):
        """
        Tells if Document stores a 3D object saved as vertices and faces.

        True when the chunks contain both a chunk tagged with the vertices
        name and a chunk tagged with the faces name.

        :return: bool.
        """
        if not self.chunks:
            return False
        # `.get` so that chunks without a 'name' tag are ignored instead of
        # raising KeyError for Documents whose chunks are not mesh parts
        name_tags = {c.tags.get('name') for c in self.chunks}
        return (
            MeshEnum.VERTICES.value in name_tags
            and MeshEnum.FACES.value in name_tags
        )

    def display_tensor(self) -> None:
        """Plot image data from :attr:`.tensor`

        The tensor is shown as a PIL image when it converts to an ``RGB``
        image; in every other case (PIL missing, conversion failing, other
        image mode) it is drawn with ``matplotlib.pyplot.matshow``.

        :raises ValueError: if :attr:`.tensor` is None.
        """
        if self.tensor is None:
            raise ValueError(
                'Impossible to display with tensor when the tensor is None'
            )

        from IPython.display import display

        try:
            import PIL.Image

            image = PIL.Image.fromarray(self.tensor)
        except Exception:
            # PIL is unavailable or cannot interpret the tensor
            image = None

        if image is not None and image.mode == 'RGB':
            display(image)
        else:
            import matplotlib.pyplot as plt

            plt.matshow(self.tensor)

    def display_vertices_and_faces(self):
        """Plot mesh consisting of vertices and faces.

        The mesh is loaded from :attr:`.uri` when it is set; otherwise it is
        built from the last chunk tagged as vertices and the last chunk tagged
        as faces.
        """
        from IPython.display import display

        if self.uri:
            # mesh from uri
            mesh = self._load_mesh()
            display(mesh.show())
        else:
            # mesh from chunks
            import trimesh

            vertices = [
                c.tensor
                for c in self.chunks
                if c.tags.get('name') == MeshEnum.VERTICES.value
            ][-1]
            faces = [
                c.tensor
                for c in self.chunks
                if c.tags.get('name') == MeshEnum.FACES.value
            ][-1]
            mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
            display(mesh.show())

    def display_point_cloud_tensor(self) -> None:
        """Plot interactive point cloud from :attr:`.tensor`

        Points are black unless a chunk tagged with the point cloud colors
        name provides a tensor whose last axis has size 3 or 4; the last such
        chunk wins.
        """
        import trimesh
        from IPython.display import display
        from hubble.utils.notebook import is_notebook

        # default: one black colour triple per point
        colors = np.tile(np.array([0, 0, 0]), (len(self.tensor), 1))
        for chunk in self.chunks:
            if (
                chunk.tags.get('name') == PointCloudEnum.COLORS.value
                and chunk.tensor.shape[-1] in [3, 4]
            ):
                colors = chunk.tensor

        pc = trimesh.points.PointCloud(
            vertices=self.tensor,
            colors=colors,
        )

        if is_notebook():
            s = trimesh.Scene(geometry=pc)
            display(s.show())
        else:
            display(pc.show())

    def display_rgbd_tensor(self) -> None:
        """Plot an RGB-D image and a corresponding depth image from :attr:`.tensor`"""
        import matplotlib.pyplot as plt
        from mpl_toolkits.axes_grid1 import make_axes_locatable

        rgb_img = self.tensor[:, :, :3]

        # scale the depth channel to 0..255; the epsilon avoids dividing by zero
        depth_img = self.tensor[:, :, -1]
        depth_img = depth_img / (np.max(depth_img) + 1e-08) * 255
        depth_img = depth_img.astype(np.uint8)

        f, ax = plt.subplots(1, 2, figsize=(16, 6))

        ax[0].imshow(rgb_img, interpolation='None')
        ax[0].set_title('RGB image\n', fontsize=16)

        # the raw depth values drive the colorbar ...
        im2 = ax[1].imshow(self.tensor[:, :, -1], cmap='gray')
        cax = make_axes_locatable(ax[1]).append_axes('right', size='5%', pad=0.05)
        f.colorbar(im2, cax=cax, orientation='vertical', label='Depth')
        # ... and the rescaled uint8 image is then drawn on the same axes
        ax[1].imshow(depth_img, cmap='gray')
        ax[1].set_title('Depth image\n', fontsize=16)

        plt.show()

    def display_uri(self):
        """Plot image data from :attr:`.uri`

        Audio and video (by :attr:`.mime_type` or data-URI prefix) are shown
        with an HTML5 player, image data URIs with an HTML image tag, and
        anything else with :class:`IPython.display.Image`.

        :raises ValueError: if :attr:`.uri` is empty.
        """
        if not self.uri:
            raise ValueError('Impossible to display with uri when the uri is None')

        from IPython.display import Image, display

        if self.mime_type.startswith('audio') or self.uri.startswith('data:audio/'):
            uri = _convert_display_uri(self.uri, self.mime_type)
            _html5_audio_player(uri)
        elif self.mime_type.startswith('video') or self.uri.startswith('data:video/'):
            uri = _convert_display_uri(self.uri, self.mime_type)
            _html5_video_player(uri)
        elif self.uri.startswith('data:image/'):
            _html5_image(self.uri)
        else:
            display(Image(self.uri))

    def plot_matches_sprites(
        self,
        top_k: int = 10,
        channel_axis: int = -1,
        inv_normalize: bool = False,
        skip_empty: bool = False,
        canvas_size: int = 1920,
        min_size: int = 100,
        output: Optional[str] = None,
    ):
        """Generate a sprite image for the query and its matching images in this Document object.

        An image sprite is a collection of images put into a single image. Query image is on the left
        followed by matching images. The Document object should contain matches.

        :param top_k: the number of top matching documents to show in the sprite.
        :param channel_axis: the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
        :param inv_normalize: If set to True, inverse the normalization of a float32 image :attr:`.tensor` into a uint8
                image :attr:`.tensor` inplace.
        :param skip_empty: skip matches which has no .uri or .tensor.
        :param canvas_size: the width of the canvas
        :param min_size: the minimum size of the image
        :param output: Optional path to store the visualization. If not given, show in UI
        :raises ValueError: if the Document is empty or has no matches, has
            neither `uri` nor `tensor`, if ``top_k`` is not positive, or if
            building the sprite fails.
        """
        if not self or not self.matches:
            raise ValueError(f'{self!r} is empty or has no matches')

        if not self.uri and self.tensor is None:
            raise ValueError(
                'Document has neither `uri` nor `tensor`, cannot be plotted'
            )

        if top_k <= 0:
            raise ValueError(f'`top_k` must be larger than 0, receiving `{top_k}`')

        import matplotlib.pyplot as plt

        # two extra slots: the query image and an empty gap before the matches
        img_per_row = min(top_k, len(self.matches)) + 2
        img_size = int((canvas_size - 50) / img_per_row)
        if img_size < min_size:
            # image is too small, recompute the image size and canvas size
            img_size = min_size
            canvas_size = img_per_row * img_size + 50

        try:
            # work on copies so that neither the query nor the matches are modified
            _d = copy.deepcopy(self)
            if _d.content_type != 'tensor':
                _d.load_uri_to_image_tensor()

            # the channel axis is -1
            if inv_normalize:
                # inverse normalise to uint8 and set the channel axis to -1
                _d.set_image_tensor_inv_normalization(channel_axis)
            _d.set_image_tensor_channel_axis(channel_axis, -1)

            # Maintain the aspect ratio keeping the width fixed
            h, w, _ = _d.tensor.shape
            img_h, img_w = int(h * (img_size / float(w))), img_size

            sprite_img = np.ones([img_h + 20, canvas_size, 3], dtype='uint8')
            _d.set_image_tensor_shape(shape=(img_h, img_w))
            sprite_img[10 : img_h + 10, 10 : 10 + img_w] = _d.tensor

            pos = canvas_size // img_per_row
            # matches start at slot 2; slot 0 holds the query, slot 1 stays empty
            for col_id, d in enumerate(self.matches, start=2):
                if not d.uri and d.tensor is None:
                    if skip_empty:
                        continue
                    else:
                        raise ValueError(
                            'Document match has neither `uri` nor `tensor`, cannot be plotted'
                        )

                _d = copy.deepcopy(d)
                if _d.content_type != 'tensor':
                    _d.load_uri_to_image_tensor()

                if inv_normalize:
                    _d.set_image_tensor_inv_normalization(channel_axis=channel_axis)

                _d.set_image_tensor_channel_axis(
                    channel_axis, -1
                ).set_image_tensor_shape(shape=(img_h, img_w))

                # paste it on the main canvas
                sprite_img[
                    10 : img_h + 10,
                    (col_id * pos) : ((col_id * pos) + img_w),
                ] = _d.tensor

                # stop once the next slot would fall outside the row
                col_id += 1
                if col_id >= img_per_row:
                    break
        except Exception as ex:
            raise ValueError('Bad image tensor. Try different `channel_axis`') from ex

        from PIL import Image

        im = Image.fromarray(sprite_img)

        if output:
            with open(output, 'wb') as fp:
                im.save(fp)
        else:
            plt.figure(figsize=(img_per_row, 2))
            plt.gca().set_axis_off()
            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            plt.margins(0, 0)
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.imshow(im, interpolation="none")
            plt.show()
def _convert_display_uri(uri, mime_type): import urllib from docarray.document.mixins.helper import _to_datauri, _uri_to_blob scheme = urllib.parse.urlparse(uri).scheme if scheme not in ['data', 'http', 'https']: blob = _uri_to_blob(uri) return _to_datauri(mime_type, blob) return uri def _html5_image(uri): from IPython.display import display from IPython.core.display import HTML # noqa src = f''' <body> <image src="{uri}" height="200px"> </body> ''' display(HTML(src)) # noqa def _html5_video_player(uri): from IPython.display import display from IPython.core.display import HTML # noqa src = f''' <body> <video width="320" height="240" autoplay muted controls> <source src="{uri}"> Your browser does not support the video tag. </video> </body> ''' display(HTML(src)) # noqa def _html5_audio_player(uri): from IPython.display import display from IPython.core.display import HTML # noqa src = f''' <body> <audio controls="controls" style="width:320px" > <source src="{uri}"/> Your browser does not support the audio element. </audio> </body> ''' display(HTML(src)) # noqa