diff --git a/doc/python/imshow.md b/doc/python/imshow.md index bec3b83de9..18e14ba95c 100644 --- a/doc/python/imshow.md +++ b/doc/python/imshow.md @@ -74,6 +74,22 @@ fig = px.imshow(img, binary_format="jpeg", binary_compression_level=0) fig.show() ``` +Image data is encoded as a lossless base64-encoded WebP string by default. +The example below uses a lossy WebP instead by passing a setting to the underlying image library backend. + +```python +import plotly.express as px +from skimage import data +img = data.astronaut() +fig = px.imshow( + img, + # Pillow backend parameters are documented here: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp + # Available parameters depend on the `binary_format` + binary_backend_kwargs={"lossless": False} +) +fig.show() +``` + ### Display single-channel 2D data as a heatmap For a 2D image, `px.imshow` uses a colorscale to map scalar data to colors. The default colorscale is the one of the active template (see [the tutorial on templates](/python/templates/)). diff --git a/packages/python/plotly/_plotly_utils/data_utils.py b/packages/python/plotly/_plotly_utils/data_utils.py index 5fb05b0311..f24ace5aae 100644 --- a/packages/python/plotly/_plotly_utils/data_utils.py +++ b/packages/python/plotly/_plotly_utils/data_utils.py @@ -10,7 +10,9 @@ pil_imported = False -def image_array_to_data_uri(img, backend="pil", compression=4, ext="png"): +def image_array_to_data_uri( + img, backend="pil", compression=4, ext="webp", backend_kwargs=None +): """Converts a numpy array of uint8 into a base64 png or jpg string. Parameters @@ -22,8 +24,10 @@ def image_array_to_data_uri(img, backend="pil", compression=4, ext="png"): otherwise pypng. compression: int, between 0 and 9 compression level to be passed to the backend - ext: str, 'png' or 'jpg' + ext: str, 'webp', 'png', or 'jpg' compression format used to generate b64 string + backend_kwargs : dict or None + keyword arguments to be passed to the backend """ # PIL and pypng error messages are quite obscure so we catch invalid compression values if compression < 0 or compression > 9: @@ -41,15 +45,26 @@ def image_array_to_data_uri(img, backend="pil", compression=4, ext="png"): if backend == "auto": backend = "pil" if pil_imported else "pypng" if ext != "png" and backend != "pil": - raise ValueError("jpg binary strings are only available with PIL backend") + raise ValueError( + "webp and jpg binary strings are only available with PIL backend" + ) + + if backend_kwargs is None: + backend_kwargs = {} if backend == "pypng": + backend_kwargs.setdefault("compression", compression) + ndim = img.ndim sh = img.shape if ndim == 3: img = img.reshape((sh[0], sh[1] * sh[2])) w = Writer( - sh[1], sh[0], greyscale=(ndim == 2), alpha=alpha, compression=compression + sh[1], + sh[0], + greyscale=(ndim == 2), + alpha=alpha, + **backend_kwargs, ) img_png = from_array(img, mode=mode) prefix = "data:image/png;base64," @@ -57,19 +72,32 @@ def image_array_to_data_uri(img, backend="pil", compression=4, ext="png"): w.write(stream, img_png.rows) base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8") else: # pil + if ext == "png": + backend_kwargs.setdefault("compress_level", compression) + + if ext == "webp": + backend_kwargs.setdefault("lossless", True) + if not pil_imported: raise ImportError( "pillow needs to be installed to use `backend='pil'. Please" "install pillow or use `backend='pypng'." ) pil_img = Image.fromarray(img) - if ext == "jpg" or ext == "jpeg": + if ext == "webp": + prefix = "data:image/webp;base64," + ext = "webp" + elif ext == "jpg" or ext == "jpeg": prefix = "data:image/jpeg;base64," ext = "jpeg" else: prefix = "data:image/png;base64," ext = "png" with BytesIO() as stream: - pil_img.save(stream, format=ext, compress_level=compression) + pil_img.save( + stream, + format=ext, + **backend_kwargs, + ) base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8") return base64_string diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index de0e22284b..00d8ee1e29 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -78,7 +78,8 @@ def imshow( binary_string=None, binary_backend="auto", binary_compression_level=4, - binary_format="png", + binary_format="webp", + binary_backend_kwargs=None, text_auto=False, ) -> go.Figure: """ @@ -204,10 +205,15 @@ def imshow( test `len(fig.data[0].source)` and to time the execution of `imshow` to tune the level of compression. 0 means no compression (not recommended). - binary_format: str, 'png' (default) or 'jpg' - compression format used to generate b64 string. 'png' is recommended - since it uses lossless compression, but 'jpg' (lossy) compression can - result if smaller binary strings for natural images. + binary_format: str, 'webp' (default), 'png', or 'jpg' + compression format used to generate b64 string. 'webp' is recommended + since it supports both lossless and lossy compression with better quality + then 'png' or 'jpg' of similar sizes, but 'jpg' or 'png' can be used for + environments that do not support 'webp'. + + binary_backend_kwargs : dict or None + keyword arguments for the image backend. For Pillow, these are passed to `Image.save`. + For 'pypng', these are passed to `Writer.__init__` text_auto: bool or str (default `False`) If `True` or a string, single-channel `img` values will be displayed as text. @@ -502,6 +508,7 @@ def imshow( backend=binary_backend, compression=binary_compression_level, ext=binary_format, + backend_kwargs=binary_backend_kwargs, ) for index_tup in itertools.product(*iterables) ] diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_imshow.py index c2e863c846..952f112878 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_imshow.py @@ -16,7 +16,9 @@ def decode_image_string(image_string): """ Converts image string to numpy array. """ - if "png" in image_string[:22]: + if "webp" in image_string[:23]: + return np.asarray(Image.open(BytesIO(base64.b64decode(image_string[23:])))) + elif "png" in image_string[:22]: return np.asarray(Image.open(BytesIO(base64.b64decode(image_string[22:])))) elif "jpeg" in image_string[:23]: return np.asarray(Image.open(BytesIO(base64.b64decode(image_string[23:])))) @@ -62,7 +64,7 @@ def test_automatic_zmax_from_dtype(): @pytest.mark.parametrize("binary_string", [False, True]) -@pytest.mark.parametrize("binary_format", ["png", "jpg"]) +@pytest.mark.parametrize("binary_format", ["webp", "png", "jpg"]) def test_origin(binary_string, binary_format): for i, img in enumerate([img_rgb, img_gray]): fig = px.imshow( @@ -76,7 +78,9 @@ def test_origin(binary_string, binary_format): # The equality below does not hold for jpeg compression since it's lossy assert np.all(img[::-1] == decode_image_string(fig.data[0].source)) if binary_string: - if binary_format == "jpg": + if binary_format == "webp": + assert fig.data[0].source[:15] == "data:image/webp" + elif binary_format == "jpg": assert fig.data[0].source[:15] == "data:image/jpeg" else: assert fig.data[0].source[:14] == "data:image/png" @@ -324,36 +328,41 @@ def test_imshow_dataframe(): def test_imshow_source_dtype_zmax(dtype, contrast_rescaling): img = np.arange(100, dtype=dtype).reshape((10, 10)) fig = px.imshow(img, binary_string=True, contrast_rescaling=contrast_rescaling) + + decoded = decode_image_string(fig.data[0].source)[:, :, 0] if contrast_rescaling == "minmax": assert ( np.max( np.abs( rescale_intensity(img, in_range="image", out_range=np.uint8) - - decode_image_string(fig.data[0].source) + - decoded ) ) < 1 ) else: if dtype in [np.uint8, np.float32, np.float64]: - assert np.all(img == decode_image_string(fig.data[0].source)) + assert np.all(img == decoded) else: - assert ( - np.abs( - np.max(decode_image_string(fig.data[0].source)) - - 255 * img.max() / np.iinfo(dtype).max - ) - < 1 - ) + assert np.abs(np.max(decoded) - 255 * img.max() / np.iinfo(dtype).max) < 1 @pytest.mark.parametrize("backend", ["auto", "pypng", "pil"]) def test_imshow_backend(backend): - fig = px.imshow(img_rgb, binary_backend=backend) + fig = px.imshow(img_rgb, binary_backend=backend, binary_format="png") decoded_img = decode_image_string(fig.data[0].source) assert np.all(decoded_img == img_rgb) +@pytest.mark.parametrize("lossless", [True, False]) +def test_imshow_backend_kwargs(lossless): + fig = px.imshow(img_rgb, binary_backend_kwargs={"lossless": lossless}) + decoded_img = decode_image_string(fig.data[0].source) + + if lossless: + assert np.all(decoded_img == img_rgb) + + @pytest.mark.parametrize("level", [0, 3, 6, 9]) def test_imshow_compression(level): _, grid_img = np.mgrid[0:10, 0:100] @@ -361,6 +370,7 @@ def test_imshow_compression(level): fig = px.imshow( grid_img, binary_string=True, + binary_format="png", binary_compression_level=level, contrast_rescaling="infer", )