From 9b849e4a4e8d40059744ccc353b70bc310405759 Mon Sep 17 00:00:00 2001 From: marklescroart Date: Tue, 28 Jan 2025 14:56:01 -0800 Subject: [PATCH] Updates to make_subplot_image_animation to add possible overlays, background color, and to enforce tight layout --- plot_utils/utils.py | 58 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/plot_utils/utils.py b/plot_utils/utils.py index f895f5f..f20da83 100644 --- a/plot_utils/utils.py +++ b/plot_utils/utils.py @@ -1254,14 +1254,20 @@ def animate(i): frames=images.shape[0], interval=1/fps * 1000, blit=True) return anim -def make_subplot_image_animation(image_stacks, data=None, n_rows=None, n_cols=None, figsize=(5,5), fps=30, extent=None, cmap=None, - yticks=None, xticks=None, ylabel=None, xlabel=None, **kwargs): +def make_subplot_image_animation(image_stacks, overlay=None, overlay_kwargs=None, n_rows=None, n_cols=None, figsize=(5,5), fps=30, extent=None, cmap=None, + yticks=None, xticks=None, ylabel=None, xlabel=None, + background=None, + **kwargs): """interval appears to be in ms Parameters ---------- - images : array - array of (time, vdim, hdim, color) + image_stacks : array-like + list or array of arrays, one for each subplot. Each array in the list is + (time, vdim, hdim, color) + overlay : array-like + list or array of arrays, for overlay data (e.g. image masks), same length + and array sizes (except color dimension) as image_stacks figsize : tuple, optional size of figure. Determines aspect ratio of movie. fps : int, optional @@ -1278,7 +1284,9 @@ def make_subplot_image_animation(image_stacks, data=None, n_rows=None, n_cols=No Y axis label xlabel : None, optional X axis label - + background : None, optional + If set, color of axis and figure background (both) + Returns ------- TYPE @@ -1286,8 +1294,9 @@ def make_subplot_image_animation(image_stacks, data=None, n_rows=None, n_cols=No """ if n_rows is None: n_rows, n_cols = find_squarish_dimensions(len(image_stacks)) - if data is None: - data = [None] * len(image_stacks) + # if data is None: + # data = [None] * len(image_stacks) + n_frames = np.max([len(x) for x in image_stacks]) # First set up the figure, the axis, and the plot element we want to animate fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize) @@ -1295,11 +1304,21 @@ def make_subplot_image_animation(image_stacks, data=None, n_rows=None, n_cols=No axs = np.array([axs]) # Show image & prettify im_h = [] + if overlay is not None: + im_o_h = [] + if overlay_kwargs is None: + overlay_kwargs = {} + for j, (ims, ax) in enumerate(zip(image_stacks, axs.flatten())): tmp = ims[0] - h = ax.imshow(tmp, extent=extent, cmap=cmap, **kwargs) + h = ax.imshow(tmp, extent=extent, cmap=cmap, **kwargs) im_h.append(h) imsz = tmp.shape + if overlay is not None: + otmp = overlays[0] + im_o = ax.imshow(otmp, extent=extent, **overlay_kwargs) + im_o_h.append(im_o) + osz = otmp.shape if yticks is not None: ax.set_yticks(yticks) if xticks is not None: @@ -1308,23 +1327,42 @@ def make_subplot_image_animation(image_stacks, data=None, n_rows=None, n_cols=No ax.set_ylabel(ylabel) if xlabel is not None: ax.set_xlabel(xlabel) + if background is not None: + ax.patch.set_color(background) if (xticks is None) and (yticks is None) and (xlabel is None) and (ylabel is None): # ax.set_position([0, 0, 1, 1]) ax.axis('off') + # Color + if background is not None: + fig.patch.set_color(background) + # Tight layout + plt.tight_layout() # Hide figure, we don't care plt.close(fig.number) # initialization function: plot the background of each frame def init(): for h in im_h: h.set_array(np.zeros(imsz)) - return im_h + if overlay is not None: + for oh in im_o_h: + oh.set_array(np.zeros(osz)) + return (im_h, im_o_h) + else: + return im_h # animation function. This is called sequentially def animate(i): for j, h in enumerate(im_h): if i >= len(image_stacks[j]): continue h.set_array(image_stacks[j][i]) - return im_h + if overlay is not None: + for j, ho in enumerate(im_o_h): + if i >= len(overlay[j]): + continue + ho.set_array(overlay[j][i]) + return (im_h, im_o_h) + else: + return im_h # call the animator. blit=True means only re-draw the parts that have changed. anim = animation.FuncAnimation(fig, animate, init_func=init, frames=n_frames, interval=1/fps * 1000, blit=True)