from collections.abc import Iterable
from typing import Optional, Union
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rcParams
from matplotlib.axes import Axes
from matplotlib.colorbar import Colorbar
from matplotlib.colors import LogNorm, SymLogNorm
from matplotlib.figure import Figure
from mpl_toolkits.axes_grid1 import make_axes_locatable
from simpple.model import Model
[docs]
def plot_mosaic(
    cube: np.ndarray | list[np.ndarray],
    nrows: int,
    ncols: int,
    titles: list[str] | None = None,
    ytitles: list[str] | None = None,
    xtitles: list[str] | None = None,
    colorbar: bool = True,
    imshow_kwargs: dict | None = None,
) -> tuple[Figure, Axes]:
    """
    Plot an image cube in a mosaic display.

    Images fill the grid row by row, one panel per image. Panels left over
    when the grid holds more panels than there are images are removed from
    the figure.

    :param cube: Image cube with shape ``(Nimg, Ny, Nx)``
    :param nrows: Number of rows in the mosaic
    :param ncols: Number of columns in the mosaic
    :param titles: List of titles to use, indexed like the images in ``cube``.
    :param xtitles: List of titles for each column. They are set as the title
        of the first-row panels, replacing any entry of ``titles`` there.
    :param ytitles: List of titles for each row. They are set as the y-label
        of the first-column panels.
    :param colorbar: Show a colorbar for each panel if ``True``.
    :param imshow_kwargs: Keyword arguments passed to imshow for all panels.
        They take precedence over the defaults (``norm="symlog"`` and
        ``origin="lower"``).
    :return: Figure and Axes used to create the plot. The axes are always a
        2-D array with shape ``(nrows, ncols)``.
    :raises AssertionError: If the grid has fewer panels than ``cube`` has
        images, or if ``ytitles``/``xtitles`` do not have exactly ``nrows``/
        ``ncols`` entries.
    """
    # Tall mosaics get a narrower figure: the width is divided by this factor.
    aspect = nrows / ncols
    if aspect > 4.0:
        height_factor = 1.5
    elif aspect >= 1.5:
        height_factor = aspect
    else:
        height_factor = 1.0

    # Validate before creating the figure so that a failed check does not
    # leave a half-built figure registered with pyplot.
    n_images = len(cube)
    assert nrows * ncols >= n_images, "There are less axes in grid than datasets in cube"
    if ytitles is not None:
        assert len(ytitles) == nrows
    if xtitles is not None:
        assert len(xtitles) == ncols

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(4 * ncols / height_factor, 4 * nrows),
        dpi=100,
        squeeze=False,
        sharex=True,
        sharey=True,
    )

    # Symlog avoids having blank pixels when slightly < 0.
    # ``origin`` lives in the defaults (rather than being passed explicitly to
    # imshow) so that callers can override it without a duplicate-keyword error.
    panel_kwargs = {"norm": "symlog", "origin": "lower"}
    if imshow_kwargs is not None:
        panel_kwargs = panel_kwargs | imshow_kwargs

    for i, img in enumerate(cube):
        ax = axes.flat[i]
        im = ax.imshow(img, **panel_kwargs)
        # Ticks clutter the figure and we don't really need them here
        ax.set_xticks([])
        ax.set_yticks([])
        if titles is not None:
            ax.set_title(titles[i])
        if colorbar:
            # Carve the colorbar out of the panel itself so every panel keeps
            # its own colorbar next to it.
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="5%", pad=0.1)
            cb = fig.colorbar(im, cax=cax)
            cb.formatter = lambda x, _: f"{x:.2f}"

    if ytitles is not None:
        for ax, label in zip(axes[:, 0], ytitles):
            ax.set_ylabel(label, size=rcParams["axes.titlesize"])
    if xtitles is not None:
        for ax, label in zip(axes[0], xtitles):
            ax.set_title(label)

    # Drop the panels that did not receive an image.
    for ax in axes.flat[n_images:]:
        fig.delaxes(ax)

    return fig, axes
[docs]
def _symlog_linthresh(img, fraction: float = 1e-3) -> float:
    """Return a strictly positive linear threshold for a symlog norm of ``img``."""
    linthresh = np.nanmax(img) * fraction
    if np.isfinite(linthresh) and linthresh > 0:
        return linthresh
    # The maximum is zero, negative or not finite (e.g. a residual image that
    # is entirely <= 0). Fall back to the largest finite absolute value so the
    # threshold stays positive, and to ``fraction`` itself for all-zero data.
    arr = np.asarray(img)
    finite_abs = np.abs(arr[np.isfinite(arr)])
    peak = finite_abs.max() if finite_abs.size else 0.0
    return peak * fraction if peak > 0 else fraction


def plot_image(
    img: np.ndarray,
    scale: str = "log",
    ax: Optional[Axes] = None,
    fig: Optional[Figure] = None,
    color_label: Optional[str] = None,
    return_colorbar: bool = False,
    **kwargs,
) -> tuple[Figure, Axes] | tuple[Figure, Axes, Colorbar]:
    """Plot an image with a given scaling and axes

    Simple wrapper to handle proper scaling, colorbar and axis names

    :param img: Image array
    :param scale: Scale to apply: ``"log"``, ``"symlog"``, or anything else
        for no explicit norm. If a "norm" kwarg is present, it will override
        this setting. Symlog uses a threshold of 1e-3 of the max (of the
        largest absolute value when the max is not positive).
    :param ax: Axis. If None, the current axis of ``fig`` is used when ``fig``
        is given, otherwise the current axis (`plt.gca()`).
    :param fig: Figure. If None, the figure that ``ax`` belongs to is used.
    :param color_label: Label of the colorbar.
    :param return_colorbar: Also return the colorbar if ``True``.
    :param kwargs: Extra keyword arguments passed to ``imshow``. For the
        ``"log"`` and ``"symlog"`` scales, ``vmin`` and ``vmax`` are taken out
        of them and given to the norm instead.
    :return: Figure and axis, followed by the colorbar if ``return_colorbar``
        is ``True``.
    """
    # Resolve the figure/axis pair from whichever one was given, so the image
    # and its colorbar never end up on an unrelated "current" figure.
    if ax is None:
        ax = plt.gca() if fig is None else fig.gca()
    if fig is None:
        fig = ax.get_figure()

    if "norm" in kwargs:
        norm = kwargs.pop("norm")
    elif scale == "log":
        norm = LogNorm(vmin=kwargs.pop("vmin", None), vmax=kwargs.pop("vmax", None))
    elif scale == "symlog":
        norm = SymLogNorm(
            linthresh=_symlog_linthresh(img),
            vmin=kwargs.pop("vmin", None),
            vmax=kwargs.pop("vmax", None),
        )
    else:
        norm = None

    im = ax.imshow(img, norm=norm, origin="lower", **kwargs)
    ax.set_xlabel("X [pix]")
    ax.set_ylabel("Y [pix]")
    cb = fig.colorbar(im, ax=ax, label=color_label)

    if return_colorbar:
        return fig, ax, cb
    return fig, ax
[docs]
def plot_with_diff(
    img1: np.ndarray,
    img2: np.ndarray,
    scale: Union[str, list] = "log",
    fig: Optional[Figure] = None,
    axes: Optional[Axes] = None,
) -> tuple[Figure, Axes]:
    """Simple 3-panel plot with two images and their difference

    The panels show ``img1``, ``img2`` and ``img1 - img2``, in that order.

    If ``fig`` is passed without ``axes``, three side-by-side subplots are
    added to it. If only ``axes`` is passed, the figure is taken from
    ``axes[0]``. If neither is passed, a new 3-panel figure is created.

    :param img1: First image
    :param img2: Second image
    :param scale: Scale to use. Log will automatically use symlog for the diff.
        Pass a list (or any other non-string iterable) with one entry per
        panel to override per-panel.
    :param fig: Figure to use. 3-panel figure created if None.
    :param axes: Axes to use. Indices 0-2 will be used. Make sure they exist.
    :return: Tuple with figure and axes
    """
    img_res = img1 - img2

    # Check for a string first: strings are iterable too, and comparing an
    # array-like ``scale`` against "log" would be ambiguous.
    if isinstance(scale, str):
        scales = ["log", "log", "symlog"] if scale == "log" else [scale] * 3
    elif isinstance(scale, Iterable):
        # Materialise so that generators and other non-indexable iterables
        # can be indexed per panel below.
        scales = list(scale)
    else:
        scales = [scale] * 3

    if axes is None:
        if fig is None:
            fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(16, 4))
        else:
            axes = [fig.add_subplot(1, 3, i + 1) for i in range(3)]
    elif fig is None:
        fig = axes[0].get_figure()

    for i, image in enumerate((img1, img2, img_res)):
        plot_image(image, ax=axes[i], scale=scales[i], fig=fig)

    return fig, axes
def plot_flat_samples(
    img: np.ndarray,
    err: np.ndarray,
    samples: np.ndarray,
) -> tuple[Figure, Axes]:
    """Plot flattened image data against samples, with a residual panel

    The top panel shows the flattened image with its error bars and every
    sample drawn over it. The bottom panel shows ``samples - img`` together
    with the ``+/- err`` envelope and a zero line.

    :param img: Image data. It is flattened and plotted against pixel index.
    :param err: Uncertainty on ``img``, used for the error bars and envelope.
    :param samples: Samples to compare with the data. Iterated along the first
        axis, each entry being flattened like ``img`` (presumably an array of
        shape ``(Nsamples, *img.shape)`` -- confirm against callers).
    :return: Figure and the two axes (data panel, residual panel)
    """
    fig, axs = plt.subplots(2, 1, gridspec_kw={"height_ratios": (3, 1)})
    data_ax, resid_ax = axs

    pixel_index = np.arange(img.size)
    flat_img = img.ravel()
    flat_err = err.ravel()
    residuals = samples - img

    # Every sample is drawn with the same faint marker in both panels.
    sample_style = {"color": "C0", "alpha": 0.1}

    # Data panel: measurements first, then the samples.
    data_ax.errorbar(
        pixel_index,
        flat_img,
        yerr=flat_err,
        fmt="k.",
        capsize=2,
        label="Image data",
        mfc="w",
    )
    for idx, sample in enumerate(samples):
        # Only the first sample carries a label, to get a single legend entry.
        legend_label = None
        if idx == 0:
            legend_label = "Posterior samples"
        data_ax.plot(pixel_index, sample.ravel(), ".", label=legend_label, **sample_style)
    data_ax.set_ylabel("Count rate [DN/s]")

    # Residual panel: uncertainty envelope, zero line, then the residuals.
    resid_ax.fill_between(
        pixel_index,
        -flat_err,
        flat_err,
        alpha=0.5,
        color="k",
        zorder=10000,
        label="Uncertainty envelope",
    )
    resid_ax.axhline(0.0, linestyle="--", color="k")
    resid_ax.set_ylabel("Samples - Data [DN/s]")
    for residual in residuals:
        resid_ax.plot(pixel_index, residual.ravel(), ".", **sample_style)
    resid_ax.set_xlabel("Pixel")

    fig.legend()
    return fig, axs