Source code for spacec.plotting._segmentation

import random

import matplotlib.pyplot as plt
import numpy as np
import skimage.io
from deepcell.utils.plot_utils import create_rgb_image, make_outline_overlay
from skimage.measure import regionprops_table

from .._shared.segmentation import combine_channels, format_CODEX


# plot membrane channel selected for segmentation
[docs] def segmentation_ch( file_name, # image for segmentation channel_file, # all channels used for staining output_dir, savefig=False, # new output_fname="", # new extra_seg_ch_list=None, # channels used for membrane segmentation nuclei_channel="DAPI", input_format="Multichannel", # CODEX or Phenocycler --> This depends on the machine you are using and the resulting file format (see documentation above) ): """ Plot the channel selected for segmentation. Parameters ---------- file_name : str The path to the image file for segmentation. channel_file : str The path to the file containing all channels used for staining. output_dir : str The directory to save the output in. extra_seg_ch_list : list, optional The channels used for membrane segmentation, by default None. nuclei_channel : str, optional The channel used for nuclei, by default "DAPI". input_format : str, optional The input_format used (either "CODEX", "Multichannel" or channels), by default "Multichannel". Returns ------- None """ if input_format != "Channels": # Load the image img = skimage.io.imread(file_name) # Read channels and store as list with open(channel_file, "r") as f: channel_names = f.read().splitlines() # Function reads channels and stores them as a dictionary (storing as a dictionary allows to select specific channels by name) image_dict = format_CODEX( image=img, channel_names=channel_names, # file with list of channel names (see channelnames.txt) input_format=input_format, ) else: image_dict = format_CODEX( image=file_name, channel_names=None, # file with list of channel names (see channelnames.txt) input_format=input_format, ) image_dict = combine_channels( image_dict, extra_seg_ch_list, new_channel_name="segmentation_channel" ) fig, ax = plt.subplots(1, 2, figsize=(15, 15)) ax[0].imshow(image_dict[nuclei_channel]) ax[1].imshow(image_dict["segmentation_channel"]) ax[0].set_title("nuclei") ax[1].set_title("membrane") # save or plot figure if savefig: plt.savefig( output_dir + output_fname + ".pdf", format="pdf", dpi=300, transparent=True, bbox_inches="tight", ) else: plt.show()
[docs] def show_masks( seg_output, nucleus_channel, additional_channels=None, show_subsample=True, n=2, # need to be at least 2 tilesize=100, idx=0, rand_seed=1, ): """ Visualize the segmentation results of an image. Parameters ---------- seg_output : dict The output from the segmentation process. It should contain 'image_dict' and 'masks'. nucleus_channel : str The name of the nucleus channel in the image_dict. additional_channels : list of str, optional The names of additional channels to be combined with the nucleus channel for visualization. show_subsample : bool, optional Whether to show a subsample of the image. Default is True. n : int, optional The number of subsamples to show. Default is 2. tilesize : int, optional The size of the tiles for subsampling. Default is 100. idx : int, optional The index for displaying. Default is 0. rand_seed : int, optional The seed for the random number generator. Default is 1. Returns ------- overlay_data : ndarray The overlay of the segmentation results on the RGB images. rgb_images : ndarray The RGB images. Raises ------ ValueError If the image size is smaller than the tile size or if there are not enough tiles to display. """ image_dict = seg_output["image_dict"] masks = seg_output["masks"] # Create a combined image stack # Assumes nuclei_image and membrane_image are numpy arrays of the same shape if len(masks.shape) == 2: masks = np.expand_dims(masks, axis=0) masks = np.expand_dims(masks, axis=0) masks = np.moveaxis(masks, 0, -1) if additional_channels != None: image_dict = combine_channels( image_dict, additional_channels, new_channel_name="segmentation_channel" ) nuclei_image = image_dict[nucleus_channel] add_chan_image = image_dict["segmentation_channel"] combined_image = np.stack([nuclei_image, add_chan_image], axis=-1) # Add an extra dimension to make it compatible with Mesmer's input requirements # Changes shape from (height, width, channels) to (1, height, width, channels) combined_image = np.expand_dims(combined_image, axis=0) # create rgb overlay of image data for visualization rgb_images = create_rgb_image(combined_image, channel_colors=["green", "blue"]) else: nuclei_image = image_dict[nucleus_channel] combined_image = np.stack([nuclei_image], axis=-1) # Add an extra dimension to make it compatible with Mesmer's input requirements # Changes shape from (height, width, channels) to (1, height, width, channels) combined_image = np.expand_dims(combined_image, axis=0) # create rgb overlay of image data for visualization rgb_images = create_rgb_image(combined_image, channel_colors=["blue"]) # create overlay of segmentation results overlay_data = make_outline_overlay(rgb_data=rgb_images, predictions=masks) # select index for displaying # plot the data fig, ax = plt.subplots(1, 2, figsize=(15, 15)) ax[0].imshow(rgb_images[idx, ...]) ax[1].imshow(overlay_data[idx, ...]) ax[0].set_title("Raw data") ax[1].set_title("Predictions") plt.show() random.seed(rand_seed) if show_subsample: overlay_data = np.squeeze(overlay_data, axis=0) rgb_images = np.squeeze(rgb_images, axis=0) # Ensure the sizes are compatible for tile calculation if overlay_data.shape[0] < tilesize or overlay_data.shape[1] < tilesize: print("Image size is smaller than the tile size. Cannot display tiles.") else: # Calculate the number of tiles in x and y directions y_tiles, x_tiles = ( overlay_data.shape[0] // tilesize, overlay_data.shape[1] // tilesize, ) # Check if either x_tiles or y_tiles is zero if x_tiles == 0 or y_tiles == 0: print("Not enough tiles to display.") else: # Split images into tiles overlay_tiles = [] grayscale_tiles = [] for i in range(x_tiles): for j in range(y_tiles): x_start, y_start = i * tilesize, j * tilesize overlay_tile = overlay_data[ y_start : y_start + tilesize, x_start : x_start + tilesize ] image_tile = rgb_images[ y_start : y_start + tilesize, x_start : x_start + tilesize ] overlay_tiles.append(overlay_tile) grayscale_tiles.append(image_tile) # Randomly select n tiles random_indices = random.sample(range(len(overlay_tiles)), n) # Plot the tiles fig, axs = plt.subplots(n, 2, figsize=(10, 5 * n)) for i, idx in enumerate(random_indices): axs[i, 0].imshow(grayscale_tiles[idx]) axs[i, 0].axis("off") axs[i, 1].imshow(overlay_tiles[idx]) axs[i, 1].axis("off") axs[i, 0].set_title("Raw data") axs[i, 1].set_title("Predictions") plt.tight_layout() plt.show() return overlay_data, rgb_images