diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index c99b47d2..c8d68089 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -2,8 +2,9 @@ import contextlib import sys +import warnings from collections import OrderedDict -from collections.abc import Callable +from collections.abc import Callable, Sequence from copy import deepcopy from pathlib import Path from typing import Any, Literal, cast @@ -938,6 +939,15 @@ def show( show, ) + if fig is not None and not isinstance(ax, Sequence): + warnings.warn( + "`fig` is being deprecated as an argument to `PlotAccessor.show` in spatialdata-plot. " + "To use a custom figure, create axes from it and pass them via `ax` instead: " + "`ax = fig.add_subplot(111)`.", + DeprecationWarning, + stacklevel=2, + ) + sdata = self._copy() # Evaluate execution tree for plotting @@ -1242,6 +1252,8 @@ def _draw_colorbar( raise IndexError("The number of titles must match the number of coordinate systems.") from e ax.set_title(t) ax.set_aspect("equal") + if fig_params.frameon is False: + ax.axis("off") extent = get_extent( sdata, diff --git a/tests/_images/Show_frameon_false_multi_panel.png b/tests/_images/Show_frameon_false_multi_panel.png new file mode 100644 index 00000000..6e05a648 Binary files /dev/null and b/tests/_images/Show_frameon_false_multi_panel.png differ diff --git a/tests/_images/Show_frameon_false_single_panel.png b/tests/_images/Show_frameon_false_single_panel.png new file mode 100644 index 00000000..746daab2 Binary files /dev/null and b/tests/_images/Show_frameon_false_single_panel.png differ diff --git a/tests/_images/Show_no_decorations.png b/tests/_images/Show_no_decorations.png new file mode 100644 index 00000000..826e8ce1 Binary files /dev/null and b/tests/_images/Show_no_decorations.png differ diff --git a/tests/pl/test_show.py b/tests/pl/test_show.py index a4033d2b..fec5b49a 100644 --- a/tests/pl/test_show.py +++ b/tests/pl/test_show.py @@ -1,9 +1,13 @@ +import warnings from unittest.mock import patch import matplotlib import matplotlib.pyplot as plt +import pytest import scanpy as sc +from matplotlib.figure import Figure from spatialdata import SpatialData +from spatialdata.transformations import Identity, set_transformation import spatialdata_plot # noqa: F401 from tests.conftest import DPI, PlotTester, PlotTesterMeta @@ -25,6 +29,19 @@ class TestShow(PlotTester, metaclass=PlotTesterMeta): def test_plot_pad_extent_adds_padding(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_images(element="blobs_image").pl.show(pad_extent=100) + def test_plot_frameon_false_single_panel(self, sdata_blobs: SpatialData): + """Visual test: frameon=False hides axes decorations on a single panel (regression for #204).""" + sdata_blobs.pl.render_images(element="blobs_image").pl.show(frameon=False) + + def test_plot_frameon_false_multi_panel(self, sdata_blobs: SpatialData): + """Visual test: frameon=False hides axes decorations on all panels (regression for #204).""" + set_transformation(sdata_blobs["blobs_image"], Identity(), "second_cs") + sdata_blobs.pl.render_images(element="blobs_image").pl.show(frameon=False, title="") + + def test_plot_no_decorations(self, sdata_blobs: SpatialData): + """Visual test: frameon=False + title='' produces just the plot content (regression for #204).""" + sdata_blobs.pl.render_images(element="blobs_image").pl.show(frameon=False, title="", colorbar=False) + def test_no_plt_show_when_ax_provided(self, sdata_blobs: SpatialData): """plt.show() must not be called when the user supplies ax= (regression for #362).""" _, ax = plt.subplots() @@ -40,3 +57,56 @@ def test_plt_show_when_ax_provided_and_show_true(self, sdata_blobs: SpatialData) sdata_blobs.pl.render_images(element="blobs_image").pl.show(ax=ax, show=True) mock_show.assert_called_once() plt.close("all") + + def test_frameon_false_hides_axes_decorations(self, sdata_blobs: SpatialData): + """frameon=False should turn off axes decorations (regression for #204).""" + ax = sdata_blobs.pl.render_images(element="blobs_image").pl.show(frameon=False, return_ax=True, show=False) + assert not ax.axison + plt.close("all") + + def test_frameon_none_keeps_axes_decorations(self, sdata_blobs: SpatialData): + """Default frameon=None should keep axes decorations visible.""" + ax = sdata_blobs.pl.render_images(element="blobs_image").pl.show(frameon=None, return_ax=True, show=False) + assert ax.axison + plt.close("all") + + def test_title_empty_string_suppresses_title(self, sdata_blobs: SpatialData): + """title='' should suppress the default coordinate system title (regression for #204).""" + ax = sdata_blobs.pl.render_images(element="blobs_image").pl.show(title="", return_ax=True, show=False) + assert ax.get_title() == "" + plt.close("all") + + +def test_fig_parameter_emits_deprecation_warning(sdata_blobs: SpatialData): + """Passing fig= should emit a DeprecationWarning (regression for #204).""" + fig = Figure() + with pytest.warns(DeprecationWarning, match="`fig` is being deprecated"): + sdata_blobs.pl.render_images(element="blobs_image").pl.show(fig=fig, show=False) + plt.close("all") + + +def test_fig_parameter_default_no_warning(sdata_blobs: SpatialData): + """Not passing fig= should not emit a deprecation warning.""" + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + sdata_blobs.pl.render_images(element="blobs_image").pl.show(show=False) + plt.close("all") + + +def test_fig_parameter_no_warning_with_ax_list(sdata_blobs: SpatialData): + """Passing fig= with a list of axes should not warn (fig is still required there).""" + set_transformation(sdata_blobs["blobs_image"], Identity(), "second_cs") + fig, axs = plt.subplots(1, 2) + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + sdata_blobs.pl.render_images(element="blobs_image").pl.show(fig=fig, ax=list(axs), show=False) + plt.close("all") + + +def test_frameon_false_multi_panel(sdata_blobs: SpatialData): + """frameon=False should apply to all panels in a multi-panel plot (regression for #204).""" + set_transformation(sdata_blobs["blobs_image"], Identity(), "second_cs") + axs = sdata_blobs.pl.render_images(element="blobs_image").pl.show(frameon=False, return_ax=True, show=False) + for ax in axs: + assert not ax.axison + plt.close("all")