Source code for spacec.tools._segmentation

import gc
import os
import pathlib
import shutil
import sys
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import skimage
import skimage.io
import tensorflow as tf
import tifffile
from cellpose import models  # Renamed import
from cellpose import io
from deepcell.applications import Mesmer
from deepcell.utils.plot_utils import create_rgb_image, make_outline_overlay
from IPython.display import clear_output
from scipy.ndimage import gaussian_filter, label
from skimage.measure import regionprops_table
from skimage.segmentation import relabel_sequential
from tensorflow.keras.models import load_model
from tqdm import tqdm
from tqdm.notebook import (
    tqdm as notebook_tqdm,  # Use notebook version for better display
)

from .._shared.segmentation import (
    combine_channels,
    create_multichannel_tiff,
    format_CODEX,
)


def load_image_dictionary(file_name, channel_file, input_format, nuclei_channel):
    """
    Loads images and channel names based on the specified format using tifffile.

    Memory Consideration: Loading large 'Multichannel' or 'CODEX' files directly
    with tifffile.imread can consume significant memory. For very large files,
    consider using memory mapping (e.g., tifffile.imread(..., aszarr=True)) if needed
    (requires downstream code modification to handle Zarr arrays). 'Channels' format reads
    individual files, which can also be memory-intensive if there are many large files.

    Parameters:
        file_name (str or Path): Path to image file (Multichannel/CODEX) or directory (Channels).
        channel_file (str or Path): Path to file containing channel names (one per line). Used for Multichannel/CODEX.
        input_format (str): 'Multichannel', 'Channels', or 'CODEX'.
        nuclei_channel (str): Name of the nuclei channel (must exist in the loaded channels).

    Returns:
        tuple: (img_ref, image_dict, channel_names_list)
               img_ref (ndarray): Reference image (typically nuclei channel or the raw loaded stack).
               image_dict (dict): Dictionary mapping channel names to 2D NumPy arrays.
               channel_names_list (list): List of loaded channel names.
               Returns (None, None, None) on error.
    """
    print(f"--- Loading Image Data (Format: {input_format}) ---")
    img_ref = None
    image_dict = None
    channel_names_list = None

    try:
        if input_format == "Channels":
            # file_name is the directory path
            # format_CODEX now uses tifffile internally for 'Channels'
            image_dict, channel_names_list = format_CODEX(
                image=file_name,
                input_format=input_format,
            )
            if image_dict is None:
                raise ValueError("Image formatting failed.")  # Check format_CODEX error

            if nuclei_channel not in image_dict:
                raise ValueError(
                    f"Specified nuclei_channel '{nuclei_channel}' not found in loaded channels: {list(image_dict.keys())}"
                )
            img_ref = image_dict[nuclei_channel]  # Use nuclei as reference

        elif input_format in ["Multichannel", "CODEX"]:
            # file_name is the image file path
            if not os.path.exists(file_name):
                raise FileNotFoundError(f"Image file not found: {file_name}")
            if not os.path.exists(channel_file):
                raise FileNotFoundError(f"Channel file not found: {channel_file}")

            print(f"Loading image: {file_name}")
            # Memory intensive step for large files - use tifffile
            # Consider tifffile.imread(file_name, aszarr=True) for memory mapping if needed later
            loaded_img = skimage.io.imread(file_name)
            img_ref = loaded_img  # Store raw loaded image as reference
            print(f"Loaded image shape: {loaded_img.shape}")

            with open(channel_file, "r") as f:
                channel_names_from_file = f.read().splitlines()
            print(
                f"Loaded {len(channel_names_from_file)} channel names from: {channel_file}"
            )

            # Determine CODEX parameters if needed (example, needs actual logic)
            number_cycles = None
            images_per_cycle = None
            if input_format == "CODEX":
                # Placeholder: Infer or require these parameters
                if loaded_img.ndim != 4:
                    raise ValueError(
                        f"CODEX input image must be 4D. Got shape: {loaded_img.shape}"
                    )
                # Assuming CODEX format is (cycles, H, W, images_per_cycle)
                # Adjust if your CODEX format is different (e.g., (cycles, images_per_cycle, H, W))
                number_cycles = loaded_img.shape[0]
                images_per_cycle = loaded_img.shape[3]
                print(
                    f"Inferred CODEX params: Cycles={number_cycles}, Channels/Cycle={images_per_cycle}"
                )

            image_dict, channel_names_list = format_CODEX(
                image=loaded_img,
                channel_names=channel_names_from_file,
                number_cycles=number_cycles,
                images_per_cycle=images_per_cycle,
                input_format=input_format,
            )
            if image_dict is None:
                raise ValueError("Image formatting failed.")  # Check format_CODEX error

            if nuclei_channel not in image_dict:
                raise ValueError(
                    f"Specified nuclei_channel '{nuclei_channel}' not found in loaded channels: {list(image_dict.keys())}"
                )

        else:
            raise ValueError(f"Invalid input_format: {input_format}")

        print(f"Image dictionary created with {len(image_dict)} channels.")
        return img_ref, image_dict, channel_names_list

    except (FileNotFoundError, ValueError, MemoryError) as e:
        print(f"Error loading image dictionary: {e}")
        return None, None, None
    except Exception as e:
        print(f"An unexpected error occurred during image loading: {e}")
        return None, None, None


# --------------------------------------------------------------------------
# Preprocessing Utilities
# --------------------------------------------------------------------------


def setup_gpu(use_gpu=True, set_memory_growth=True):
    """Configures TensorFlow GPU memory growth to avoid allocating all memory at once."""
    if use_gpu:
        gpus = tf.config.list_physical_devices("GPU")
        if gpus:
            try:
                for gpu in gpus:
                    if set_memory_growth == True:
                        tf.config.experimental.set_memory_growth(
                            gpu, True
                        )  # As your model requires more memory during execution, TensorFlow will gradually increase the allocation. Rather than just allocating all memory!
                        print(f"GPU(s) available: {len(gpus)}. Memory growth enabled.")
                    else:
                        print(f"GPU(s) available: {len(gpus)}. Memory growth not set.")
            except RuntimeError as e:
                print(
                    f"Warning: Could not set memory growth (may already be initialized): {e}"
                )
        else:
            print("No GPU detected by TensorFlow.")
            use_gpu = False  # Ensure CPU is used if no GPU found
    else:
        print("GPU usage explicitly disabled.")
        # Optionally force CPU only
        # tf.config.set_visible_devices([], 'GPU')
    return use_gpu  # Return status in case it changed


def prepare_segmentation_dict(image_dict, nuclei_channel, membrane_channel_list):
    """
    Prepares a dictionary containing only the channels needed for segmentation.
    Combines membrane channels if provided.

    Parameters:
        image_dict (dict): Full dictionary of images.
        nuclei_channel (str): Name of the nuclei channel.
        membrane_channel_list (list or None): List of membrane channel names to combine.

    Returns:
        tuple: (segmentation_dict, combined_membrane_channel_name)
               segmentation_dict (dict): Dictionary with 'nuclei_channel' and optionally 'segmentation_channel'.
               combined_membrane_channel_name (str or None): Name of the combined channel, or None.
    """
    if nuclei_channel not in image_dict:
        raise ValueError(
            f"Nuclei channel '{nuclei_channel}' not found in image dictionary."
        )

    segmentation_dict = {}
    # Make a copy to avoid modifying the original image_dict values if resizing/etc happens later
    segmentation_dict[nuclei_channel] = image_dict[nuclei_channel].copy()
    combined_membrane_channel_name = None

    if membrane_channel_list:
        combined_membrane_channel_name = "segmentation_channel"
        # Combine channels (creates a new array in the dict)
        # Pass a copy of image_dict to combine_channels to avoid modifying the original one
        segmentation_dict_temp = combine_channels(
            image_dict.copy(), membrane_channel_list, combined_membrane_channel_name
        )
        if combined_membrane_channel_name in segmentation_dict_temp:
            # Add the newly created combined channel to our segmentation_dict
            segmentation_dict[combined_membrane_channel_name] = segmentation_dict_temp[
                combined_membrane_channel_name
            ]
        else:
            print(
                f"Warning: Failed to create '{combined_membrane_channel_name}'. Proceeding without it."
            )
            combined_membrane_channel_name = None  # Reset if creation failed
        del segmentation_dict_temp  # Clean up temporary dict

    print(
        f"Segmentation dictionary prepared with channels: {list(segmentation_dict.keys())}"
    )
    return segmentation_dict, combined_membrane_channel_name


def resize_segmentation_images(seg_dict, resize_factor):
    """
    Resizes images within the segmentation dictionary using area interpolation.

    Memory Consideration: Creates new resized arrays, temporarily increasing memory usage.

    Parameters:
        seg_dict (dict): Dictionary containing images to resize.
        resize_factor (float): Factor by which to resize (e.g., 0.5 for half size).

    Returns:
        dict: Dictionary with resized images. Returns original dict if resize_factor is 1.
    """
    if resize_factor == 1:
        print("Resize factor is 1, skipping image resizing.")
        return seg_dict
    if resize_factor <= 0:
        raise ValueError("Resize factor must be positive.")

    print(f"Resizing segmentation images by factor: {resize_factor}")
    resized_seg_dict = {}
    for ch, im in seg_dict.items():
        if im is None:
            continue
        original_height, original_width = im.shape[:2]
        new_height = int(original_height * resize_factor)
        new_width = int(original_width * resize_factor)

        if new_height == 0 or new_width == 0:
            print(
                f"Warning: Resize factor {resize_factor} results in zero dimension for channel '{ch}'. Skipping resize."
            )
            resized_seg_dict[ch] = im.copy()  # Keep original
            continue

        # Use INTER_AREA for downsampling, INTER_LINEAR for upsampling (general purpose)
        interpolation = cv2.INTER_AREA if resize_factor < 1 else cv2.INTER_LINEAR
        resized_im = cv2.resize(
            im, (new_width, new_height), interpolation=interpolation
        )
        resized_seg_dict[ch] = resized_im
        print(f"  Resized channel '{ch}' from {im.shape} to {resized_im.shape}")
        # del im # Original image in seg_dict is not deleted here, caller should manage seg_dict
    gc.collect()
    return resized_seg_dict


def resize_mask(mask, target_shape_or_ref_img):
    """
    Resizes a segmentation mask to a target shape using nearest neighbor interpolation.

    Parameters:
        mask (ndarray): The segmentation mask to resize.
        target_shape_or_ref_img (tuple or ndarray): Target (height, width) or a reference image
                                                     from which to get the target shape.

    Returns:
        ndarray: The resized mask.
    """
    if mask is None:
        return None

    if isinstance(target_shape_or_ref_img, np.ndarray):
        target_height, target_width = target_shape_or_ref_img.shape[:2]
    elif (
        isinstance(target_shape_or_ref_img, tuple) and len(target_shape_or_ref_img) >= 2
    ):
        target_height, target_width = target_shape_or_ref_img[:2]
    else:
        raise ValueError(
            "target_shape_or_ref_img must be a NumPy array or a (height, width) tuple."
        )

    current_height, current_width = mask.shape[:2]

    if (current_height, current_width) == (target_height, target_width):
        return mask  # No resize needed

    # print(f"Resizing mask from {(current_height, current_width)} to {(target_height, target_width)}")
    # Use INTER_NEAREST for masks to preserve integer labels
    resized_mask = cv2.resize(
        mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST
    )
    return resized_mask


# --------------------------------------------------------------------------
# Tiling Utilities
# --------------------------------------------------------------------------


def generate_tiles(image_shape, tile_size, tile_overlap):
    """
    Generate tile coordinates (y_start, y_end, x_start, x_end) with overlap.
    Ensures full image coverage and that overlap does not exceed tile_overlap.
    Tiles at the edges will be made smaller if necessary to fit the image dimensions.

    Parameters:
        image_shape (tuple): (height, width) of the full image.
        tile_size (int): Desired size of the square tile.
        tile_overlap (int): Maximum desired overlap between adjacent tiles.

    Returns:
        list: List of tuples, each defining a tile's coordinates (y_start, y_end, x_start, x_end).
    """
    height, width = image_shape
    tiles = []

    if height <= 0 or width <= 0:
        return tiles  # No tiles for empty image

    # Adjust tile size if image is smaller
    actual_tile_h = min(tile_size, height)
    actual_tile_w = min(tile_size, width)

    # Adjust overlap: must be less than tile size and non-negative
    # Ensure overlap is not greater than the dimension itself minus 1
    actual_overlap_h = max(0, min(tile_overlap, actual_tile_h - 1))
    actual_overlap_w = max(0, min(tile_overlap, actual_tile_w - 1))

    # Calculate step size based on adjusted tile size and overlap
    step_y = actual_tile_h - actual_overlap_h
    step_x = actual_tile_w - actual_overlap_w

    # Ensure step is at least 1 to prevent infinite loops
    step_y = max(1, step_y)
    step_x = max(1, step_x)

    # --- Generate Start Positions ---
    y_starts = []
    y = 0
    while y < height:
        y_starts.append(y)
        # If a tile starting at y reaches or passes the image height, stop.
        if y + actual_tile_h >= height:
            break
        y += step_y
        # Safety break if step is somehow non-positive (shouldn't happen)
        if y <= y_starts[-1]:
            print(
                f"Warning: Step calculation resulted in non-positive step_y ({step_y}). Breaking y loop."
            )
            break

    # If the last start position doesn't allow coverage to the edge, add a final start position.
    # This final start position aligns the tile's end with the image's end.
    if y_starts and y_starts[-1] + actual_tile_h < height:
        final_y_start = height - actual_tile_h
        # Add only if it's different from the last start position
        if final_y_start > y_starts[-1]:
            y_starts.append(final_y_start)
    elif not y_starts and height > 0:  # Handle case where height <= tile_size
        y_starts = [0]

    # Similar logic for x start positions
    x_starts = []
    x = 0
    while x < width:
        x_starts.append(x)
        if x + actual_tile_w >= width:
            break
        x += step_x
        if x <= x_starts[-1]:
            print(
                f"Warning: Step calculation resulted in non-positive step_x ({step_x}). Breaking x loop."
            )
            break

    if x_starts and x_starts[-1] + actual_tile_w < width:
        final_x_start = width - actual_tile_w
        if final_x_start > x_starts[-1]:
            x_starts.append(final_x_start)
    elif not x_starts and width > 0:  # Handle case where width <= tile_size
        x_starts = [0]

    # Ensure start lists are unique (can happen if added start coincides with a step)
    # Sorting is not strictly necessary but keeps order predictable
    y_starts = sorted(list(set(y_starts)))
    x_starts = sorted(list(set(x_starts)))

    # --- Create Tile Coordinates ---
    for y_start in y_starts:
        for x_start in x_starts:
            # Calculate end coordinates, ensuring they don't exceed image boundaries
            y_end = min(y_start + actual_tile_h, height)
            x_end = min(x_start + actual_tile_w, width)

            # Ensure tile has valid positive dimensions before adding
            if y_end > y_start and x_end > x_start:
                tiles.append((y_start, y_end, x_start, x_end))

    # Final check for uniqueness, although the generation logic should minimize duplicates
    unique_tiles = list(dict.fromkeys(tiles))

    # print(f"Generated {len(unique_tiles)} unique tiles for shape {image_shape} (Tile: {tile_size}, Max Overlap: {tile_overlap})")
    return unique_tiles


def display_tile_progress(
    tiles_info, completed_tiles_indices, image_shape, current_tile_index=None
):
    """
    Display an ASCII grid showing tile processing progress in Jupyter/IPython.
    Updates in-place.

    Parameters:
        tiles_info (list): List of tile coordinate tuples.
        completed_tiles_indices (set): Set of indices of completed tiles.
        image_shape (tuple): (height, width) of the full image.
        current_tile_index (int, optional): Index of the tile currently being processed.
    """
    try:
        clear_output(wait=True)  # Clears the output cell in Jupyter/IPython

        height, width = image_shape
        if not tiles_info:
            print("\n--- Tile Processing Progress ---")
            print("No tiles to process.")
            return

        # Determine grid dimensions based on unique start coordinates
        y_starts = sorted(list(set([y for y, _, _, _ in tiles_info])))
        x_starts = sorted(list(set([x for _, _, x, _ in tiles_info])))
        grid_height = len(y_starts)
        grid_width = len(x_starts)

        if grid_height == 0 or grid_width == 0:
            print("\n--- Tile Processing Progress ---")
            print("Could not determine grid dimensions.")
            return

        # Create a mapping from (y_start, x_start) to tile index for grid population
        # This assumes generate_tiles produces tiles in a somewhat grid-like order
        tile_coord_to_index = {}
        temp_grid_map = {}
        y_start_map = {y: i for i, y in enumerate(y_starts)}
        x_start_map = {x: i for i, x in enumerate(x_starts)}

        for idx, (y, _, x, _) in enumerate(tiles_info):
            if (y, x) not in tile_coord_to_index:
                tile_coord_to_index[(y, x)] = idx
            # Map grid position to tile index
            y_grid_idx = y_start_map.get(y)
            x_grid_idx = x_start_map.get(x)
            if y_grid_idx is not None and x_grid_idx is not None:
                temp_grid_map[(y_grid_idx, x_grid_idx)] = idx

        grid = [["□"] * grid_width for _ in range(grid_height)]  # Initialize grid

        for r in range(grid_height):
            for c in range(grid_width):
                tile_idx = temp_grid_map.get((r, c))
                if tile_idx is not None:
                    if tile_idx == current_tile_index:
                        grid[r][c] = "P"  # Processing
                    elif tile_idx in completed_tiles_indices:
                        grid[r][c] = "✓"  # Completed
                    # else: remains '□' (Pending)

        total_tiles = len(tiles_info)
        completed_count = len(completed_tiles_indices)
        progress_percent = (
            (completed_count / total_tiles * 100) if total_tiles > 0 else 0
        )

        print("\n--- Tile Processing Progress ---")
        print(
            f"Image Size: {width}x{height} | Grid: {grid_width}x{grid_height} | Total Tiles: {total_tiles}"
        )
        print(f"Completed: {completed_count}/{total_tiles} ({progress_percent:.1f}%)")
        if current_tile_index is not None:
            print(f"Processing Tile: {current_tile_index + 1}")

        # Print the grid
        if grid_width > 0 and grid_height > 0:
            # Limit grid display size for very large grids
            max_display_width = 80
            max_display_height = 40
            display_grid_width = min(grid_width, max_display_width)
            display_grid_height = min(grid_height, max_display_height)

            print("┌" + "─" * (display_grid_width * 2 - 1) + "┐")
            for r in range(display_grid_height):
                row_str = " ".join(grid[r][:display_grid_width])
                if grid_width > max_display_width:
                    row_str += " ..."  # Indicate truncation
                print("│" + row_str + "│")
            if grid_height > max_display_height:
                print("." * (display_grid_width * 2 + 1))  # Indicate truncation
            print("└" + "─" * (display_grid_width * 2 - 1) + "┘")
            print("Legend: P = Processing, ✓ = Completed, □ = Pending\n")
        else:
            print("Grid display skipped (invalid dimensions).")

    except Exception as e:
        # Avoid crashing the main process if display fails
        print(f"Warning: Failed to display tile progress: {e}")


# --------------------------------------------------------------------------
# Segmentation Models
# --------------------------------------------------------------------------


def cellpose_segmentation(
    image_dict,
    output_dir,  # Note: Currently only used if save_mask_as_png=True
    membrane_channel_name=None,  # Name of membrane channel in image_dict (e.g., 'segmentation_channel')
    cytoplasm_channel_name=None,  # Name of cytoplasm channel in image_dict
    nucleus_channel_name=None,  # Name of nucleus channel in image_dict
    use_gpu=True,
    model="cyto3",  # Default Cellpose model
    custom_model=False,  # Set to True if 'model' is a path to a custom model file
    diameter=None,  # Cell diameter estimate (recommended)
    save_mask_as_png=False,  # Save Cellpose overlay PNG
):
    """
    Perform cell segmentation using Cellpose. Handles channel selection for Cellpose input.

    Parameters:
        image_dict (dict): Dict with images needed for segmentation (e.g., nuclei, membrane/cyto).
        output_dir (str or Path): Base output directory (used for saving PNG).
        membrane_channel_name (str, optional): Key in image_dict for membrane channel.
        cytoplasm_channel_name (str, optional): Key in image_dict for cytoplasm channel.
        nucleus_channel_name (str): Key in image_dict for nucleus channel.
        use_gpu (bool): Whether to use GPU.
        model (str): Cellpose model name (e.g., 'cyto3', 'nuclei') or path to custom model.
        custom_model (bool): True if 'model' is a path.
        diameter (float, optional): Estimated cell diameter.
        save_mask_as_png (bool): If True, saves Cellpose's diagnostic PNG.

    Returns:
        tuple: (masks, flows, styles) from cellpose.eval() or (None, None, None) on error.
               'masks' is the 2D labeled mask array.
    """
    if not nucleus_channel_name or nucleus_channel_name not in image_dict:
        print(
            f"Error: Nucleus channel '{nucleus_channel_name}' not provided or not found in image_dict."
        )
        return None, None, None

    # Determine Cellpose channels argument based on provided channel names
    # Cellpose channel mapping: 0=gray, 1=red(membrane), 2=green(cyto), 3=blue(nucleus)
    channels_arg = [0, 0]  # Default to grayscale (nucleus only)
    input_image = None

    # Prepare the input image (can be 2D grayscale or 3D RGB-like)
    nucleus_img = image_dict[nucleus_channel_name]
    membrane_img = image_dict.get(membrane_channel_name)
    cytoplasm_img = image_dict.get(cytoplasm_channel_name)

    if cytoplasm_img is not None:
        print("Using Cytoplasm (Green=2) and Nucleus (Blue=3) channels for Cellpose.")
        channels_arg = [2, 3]
        # Create a 3D array [H, W, C] -> R=0, G=cyto, B=nuc
        input_image = np.stack(
            [np.zeros_like(nucleus_img), cytoplasm_img, nucleus_img], axis=-1
        )
    elif membrane_img is not None:
        print("Using Membrane (Red=1) and Nucleus (Blue=3) channels for Cellpose.")
        channels_arg = [1, 3]
        # Create a 3D array [H, W, C] -> R=memb, G=0, B=nuc
        input_image = np.stack(
            [membrane_img, np.zeros_like(nucleus_img), nucleus_img], axis=-1
        )
    else:
        print("Using only Nucleus channel (Grayscale=0) for Cellpose.")
        channels_arg = [0, 0]  # Grayscale mode
        input_image = nucleus_img  # Use the 2D nucleus image directly

    if input_image is None:
        print("Error: Failed to prepare input image for Cellpose.")
        return None, None, None

    # Run CellPose core function
    try:
        masks, flows, styles = run_cellpose(
            image=input_image,
            output_dir=output_dir,
            use_gpu=use_gpu,
            model=model,
            custom_model=custom_model,
            diameter=diameter,
            channels=channels_arg,
            save_mask_as_png=save_mask_as_png,
        )
        return masks, flows, styles
    except Exception as e:
        print(f"Error during Cellpose segmentation run: {e}")
        return None, None, None


def run_cellpose(
    image,
    output_dir,
    use_gpu=True,
    model="cyto3",
    custom_model=False,
    diameter=None,
    channels=[0, 0],
    save_mask_as_png=False,
):
    """
    Internal helper to initialize and run the Cellpose model evaluation.

    Parameters: (See cellpose_segmentation docstring)

    Returns:
        tuple: (masks, flows, styles) from model.eval().
    """
    print(
        f"Running Cellpose: model='{model}', custom={custom_model}, diameter={diameter}, channels={channels}, gpu={use_gpu}"
    )

    # Initialize model
    model_obj = None
    try:
        if custom_model:
            if not os.path.exists(model):
                raise FileNotFoundError(f"Custom Cellpose model not found at: {model}")
            # Use CellposeModel for custom models
            model_obj = cellpose_models.CellposeModel(
                pretrained_model=model, gpu=use_gpu
            )
            print(f"Loaded custom Cellpose model from: {model}")
        else:
            # Use Cellpose or CellposeModel based on model type (nuclei often needs Cellpose)
            # Check if model is 'nuclei' or similar that might require the base class
            if model in ["nuclei"]:  # Add other models if needed
                model_obj = cellpose_models.Cellpose(model_type=model, gpu=use_gpu)
            else:  # Default to CellposeModel for 'cyto', 'cyto2', 'cyto3' etc.
                model_obj = cellpose_models.CellposeModel(model_type=model, gpu=use_gpu)
            print(f"Initialized Cellpose model: {model}")
    except Exception as e:
        print(f"Error initializing Cellpose model '{model}': {e}")
        raise  # Re-raise error

    # Run evaluation
    masks, flows, styles, diams = None, None, None, None  # Initialize
    try:
        # The eval signature might differ slightly based on Cellpose version and class used
        if isinstance(model_obj, cellpose_models.Cellpose):
            masks, flows, styles, diams = model_obj.eval(
                image, diameter=diameter, channels=channels, do_3D=False
            )
            print(
                f"Cellpose segmentation complete. Found {np.max(masks)} objects. Est. diameter: {diams:.2f}"
            )
        elif isinstance(model_obj, cellpose_models.CellposeModel):
            # CellposeModel.eval might not return diameter directly in all versions
            eval_output = model_obj.eval(
                image, diameter=diameter, channels=channels, do_3D=False
            )
            if len(eval_output) == 4:  # Older versions might return diams
                masks, flows, styles, diams = eval_output
                print(
                    f"Cellpose segmentation complete. Found {np.max(masks)} objects. Est. diameter: {diams:.2f}"
                )
            elif (
                len(eval_output) == 3
            ):  # Newer versions might return only masks, flows, styles
                masks, flows, styles = eval_output
                # Try to get diameter from the object if possible (might not always be set post-eval)
                diams = getattr(
                    model_obj, "diam_labels", diameter
                )  # Use provided diameter if not found
                print(f"Cellpose segmentation complete. Found {np.max(masks)} objects.")
            else:
                raise TypeError(
                    f"Unexpected output from CellposeModel.eval: {eval_output}"
                )

        else:
            raise TypeError("Unsupported Cellpose model object type.")

    except Exception as e:
        print(f"Error during Cellpose model.eval: {e}")
        raise  # Re-raise error

    # Save output PNG if requested
    if save_mask_as_png:
        try:
            from cellpose import io as cellpose_io  # Use cellpose io module

            output_dir_path = pathlib.Path(output_dir)
            output_dir_path.mkdir(parents=True, exist_ok=True)
            # Construct filename
            base_fname = (
                f"cellpose_seg_{pathlib.Path(model).stem if custom_model else model}"
            )
            filename = output_dir_path / f"{base_fname}_seg.png"

            # Use cellpose_io.save_masks or masks_flows_to_seg depending on needs
            # masks_flows_to_seg creates the standard diagnostic image
            # Need diameter value (estimated or provided)
            effective_diameter = diams if diams is not None else diameter
            if effective_diameter is None:
                print("Warning: Diameter not available for saving PNG. Using default.")
                effective_diameter = 30.0  # Default for saving function if unknown

            # Ensure image is suitable for saving (e.g., rescale if needed, handle multi-channel input)
            # masks_flows_to_seg expects an image that can be displayed (e.g., uint8 or float scaled 0-1)
            # It might handle the input 'image' directly, but let's prepare a displayable version
            from cellpose import utils

            img_display = image  # Use the input image directly first
            if image.ndim == 3 and image.shape[-1] == 3:  # RGB-like input
                # Cellpose plotting utils often work with channel-first or specific channel indices
                # Let's try to reconstruct a displayable image based on channels_arg
                if channels == [2, 3]:  # Cyto(G), Nuc(B)
                    img_display = image[
                        :, :, [2, 1, 0]
                    ]  # BGR for display? Or use utils.format_image
                elif channels == [1, 3]:  # Memb(R), Nuc(B)
                    img_display = image[:, :, [2, 0, 1]]  # BGR for display?
                # Or just use the first channel if grayscale
            elif image.ndim == 2:
                img_display = image

            # Normalize image for display if it's not uint8
            if img_display.dtype != np.uint8:
                img_display = utils.normalize99(img_display) * 255
                img_display = img_display.astype(np.uint8)

            cellpose_io.masks_flows_to_seg(
                img_display, masks, flows, effective_diameter, filename, channels
            )
            print(f"Saved Cellpose segmentation overlay to: {filename}")

        except ImportError:
            print(
                "Warning: Could not import cellpose.io or cellpose.utils. Skipping saving PNG."
            )
        except Exception as e:
            print(f"Warning: Error saving Cellpose PNG: {e}")

    return masks, flows, styles


def load_mesmer_model(model_dir):
    """
    Loads the Mesmer model from a specified directory. Downloads if not found.

    Parameters:
        model_dir (str or Path): Directory where 'Mesmer_model/MultiplexSegmentation'
                                 is located or will be downloaded to.

    Returns:
        tensorflow.keras.Model: The loaded Mesmer model, or None on error.
    """
    model_dir_path = pathlib.Path(model_dir)
    mesmer_subdir = "Mesmer_model"
    model_name = "MultiplexSegmentation"
    full_model_path = model_dir_path / mesmer_subdir / model_name

    if not full_model_path.exists():
        print(f"Mesmer model not found at {full_model_path}. Attempting download...")
        try:
            (model_dir_path / mesmer_subdir).mkdir(parents=True, exist_ok=True)

            # Download URL and target file path
            url = "https://deepcell-data.s3-us-west-1.amazonaws.com/saved-models/MultiplexSegmentation-9.tar.gz"
            tar_path = model_dir_path / mesmer_subdir / "MultiplexSegmentation.tar.gz"
            extract_target_dir = model_dir_path / mesmer_subdir

            # Download
            print(f"Downloading Mesmer model from {url}...")
            response = requests.get(url, stream=True)
            response.raise_for_status()  # Check for download errors
            with open(tar_path, "wb") as f:
                for chunk in tqdm(
                    response.iter_content(chunk_size=8192 * 16),
                    desc="Downloading Mesmer",
                ):  # Larger chunk size + tqdm
                    f.write(chunk)
            print(f"Downloaded model archive to {tar_path}")

            # Unpack
            print(f"Unpacking {tar_path}...")
            shutil.unpack_archive(tar_path, extract_target_dir)
            print(f"Unpacked model to {extract_target_dir}")

            # Check if the expected model directory exists after unpacking
            if not full_model_path.exists():
                raise FileNotFoundError(
                    f"Model directory '{model_name}' not found in {extract_target_dir} after unpacking."
                )

            # Clean up downloaded archive
            os.remove(tar_path)
            print(f"Removed downloaded archive {tar_path}")

        except requests.exceptions.RequestException as e:
            print(f"Error downloading Mesmer model: {e}")
            return None
        except (shutil.ReadError, FileNotFoundError, Exception) as e:
            print(f"Error setting up Mesmer model: {e}")
            if "tar_path" in locals() and os.path.exists(tar_path):
                os.remove(tar_path)  # Clean up failed download
            return None
    else:
        print(f"Found existing Mesmer model at: {full_model_path}")

    # Load the model
    print("Loading Mesmer model...")
    try:
        # Use tf.keras.models.load_model (already imported)
        mesmer_pretrained_model = load_model(str(full_model_path), compile=False)
        print("Mesmer model loaded successfully.")
        return mesmer_pretrained_model
    except Exception as e:
        print(f"Error loading Mesmer model from {full_model_path}: {e}")
        return None


def mesmer_segmentation(
    nuclei_image,
    membrane_image,  # Can be None for nuclear-only segmentation
    image_mpp=0.5,  # Microns per pixel - important for Mesmer performance
    plot_predictions=False,
    compartment="whole-cell",  # 'whole-cell' or 'nuclear'
    model_path="./models",  # Base directory for Mesmer model download/load
):
    """
    Perform segmentation using the DeepCell Mesmer model.

    Parameters:
        nuclei_image (ndarray): 2D NumPy array for nuclei.
        membrane_image (ndarray or None): 2D NumPy array for membrane/cytoplasm, or None for nuclear segmentation.
        image_mpp (float): Microns per pixel.
        plot_predictions (bool): Whether to plot segmentation overlay.
        compartment (str): 'whole-cell' or 'nuclear'.
        model_path (str or Path): Directory for Mesmer model.

    Returns:
        ndarray: 2D integer-labeled segmentation mask, or None on error.
    """
    print(f"Running Mesmer segmentation: compartment='{compartment}', mpp={image_mpp}")

    # Load Mesmer model
    mesmer_pretrained_model = load_mesmer_model(model_path)
    if mesmer_pretrained_model is None:
        return None  # Error loading model

    # Initialize Mesmer application
    try:
        app = Mesmer(model=mesmer_pretrained_model)
    except Exception as e:
        print(f"Error initializing Mesmer application: {e}")
        return None

    # Prepare input image stack for Mesmer: (batch, height, width, channels)
    # Channels: [Nuclear, Membrane/Cytoplasm]
    if nuclei_image.ndim != 2:
        print(f"Error: Nuclei image must be 2D, but got shape {nuclei_image.shape}")
        return None

    if compartment == "whole-cell":
        if membrane_image is None:
            print(
                "Warning: compartment is 'whole-cell' but membrane_image is None. Performing nuclear segmentation instead."
            )
            compartment = "nuclear"  # Switch to nuclear
            membrane_channel = np.zeros_like(nuclei_image)  # Dummy channel
        elif membrane_image.shape != nuclei_image.shape:
            print(
                f"Error: Nuclei ({nuclei_image.shape}) and membrane ({membrane_image.shape}) images must have the same shape."
            )
            return None
        else:
            membrane_channel = membrane_image
    elif compartment == "nuclear":
        membrane_channel = np.zeros_like(nuclei_image)  # Dummy channel if nuclear only
    else:
        print(
            f"Error: Invalid compartment: {compartment}. Choose 'whole-cell' or 'nuclear'."
        )
        return None

    # Stack channels and add batch dimension
    try:
        # Normalize images (Mesmer often expects float inputs, check docs if needed)
        # Example normalization (adjust based on expected input range):
        # nuclei_norm = (nuclei_image - np.min(nuclei_image)) / (np.max(nuclei_image) - np.min(nuclei_image) + 1e-6)
        # membrane_norm = (membrane_channel - np.min(membrane_channel)) / (np.max(membrane_channel) - np.min(membrane_channel) + 1e-6)
        # combined_image = np.stack([nuclei_norm, membrane_norm], axis=-1)

        # Using images as is for now, assuming Mesmer handles scaling or prefers raw values
        combined_image = np.stack([nuclei_image, membrane_channel], axis=-1)
        combined_image_batch = np.expand_dims(combined_image, axis=0).astype(
            np.float32
        )  # Mesmer expects float32
        # print(f\"Prepared Mesmer input batch with shape: {combined_image_batch.shape}\")
    except Exception as e:
        print(f"Error preparing Mesmer input stack: {e}")
        return None

    # Run the Mesmer model prediction
    print("Predicting with Mesmer...")
    try:
        segmented_batch = app.predict(
            combined_image_batch, image_mpp=image_mpp, compartment=compartment
        )
        # print(f\"Mesmer prediction output batch shape: {segmented_batch.shape}\")
    except Exception as e:
        print(f"Error during Mesmer prediction: {e}")
        return None

    # Extract the single mask from the batch: (1, H, W, 1) -> (H, W)
    if (
        segmented_batch is None
        or segmented_batch.shape[0] != 1
        or segmented_batch.shape[-1] != 1
    ):
        print(
            f"Warning: Unexpected Mesmer output shape {segmented_batch.shape if segmented_batch is not None else 'None'}. Cannot extract mask."
        )
        return None
    segmented_mask = np.squeeze(segmented_batch).astype(np.int32)
    print(
        f"Extracted Mesmer mask with shape: {segmented_mask.shape}, max label: {np.max(segmented_mask)}"
    )

    # Plotting (optional)
    if plot_predictions:
        try:
            from deepcell.utils.plot_utils import create_rgb_image, make_outline_overlay

            print("Plotting Mesmer predictions...")
            # Use the original (non-batch) combined image for plotting colors
            channel_colors = [
                "blue",
                "green",
            ]  # Nuc=Blue, Memb=Green (adjust as preferred)
            rgb_images = create_rgb_image(
                np.expand_dims(combined_image, axis=0), channel_colors=channel_colors
            )
            overlay_data = make_outline_overlay(
                rgb_data=rgb_images, predictions=segmented_batch
            )

            fig, ax = plt.subplots(1, 2, figsize=(15, 7))
            ax[0].imshow(rgb_images[0, ...])
            ax[0].set_title(
                f"Input (Nuc: {channel_colors[0]}, Memb: {channel_colors[1]})"
            )
            ax[0].axis("off")
            ax[1].imshow(overlay_data[0, ...])
            ax[1].set_title(f"Mesmer {compartment} Predictions")
            ax[1].axis("off")
            plt.tight_layout()
            plt.show()
        except ImportError:
            print(
                "Warning: deepcell.utils.plot_utils not found. Cannot plot Mesmer predictions."
            )
        except Exception as e:
            print(f"Warning: Error during Mesmer plotting: {e}")

    # Clean up
    del (
        combined_image,
        combined_image_batch,
        segmented_batch,
        app,
        mesmer_pretrained_model,
    )
    gc.collect()

    return segmented_mask


# --------------------------------------------------------------------------
# Stitching and Postprocessing
# --------------------------------------------------------------------------


def stitch_masks(tiles_info, tile_masks, full_shape, tile_overlap=32, sigma=128):
    """Stitch multiple segmentation masks from overlapping tiles with confidence-based blending.

    This function combines segmentation masks from multiple tiles into a single cohesive mask.
    It handles overlapping regions using a confidence-based approach and resolves conflicts
    between object labels across tile boundaries.

    Parameters
    ----------
    tiles_info : list of tuple
        List of tile coordinates, each tuple containing (y_start, y_end, x_start, x_end)
        defining the position of each tile in the full image.
    tile_masks : list of ndarray
        List of 2D integer-labeled segmentation masks, one for each tile.
        Must match the length of tiles_info. Each mask should have dimensions
        matching its corresponding tile coordinates.
    full_shape : tuple
        Shape (height, width) of the final stitched mask.
    tile_overlap : int, optional
        Overlap between adjacent tiles in pixels, by default 32.
        Used to create smooth transitions between tiles.
    sigma : float, optional
        Standard deviation for Gaussian smoothing of confidence maps,
        by default 128. Larger values create smoother transitions.

    Returns
    -------
    ndarray
        2D integer-labeled segmentation mask of shape full_shape.
        Background is 0, objects are labeled with consecutive positive integers.
        Returns zeros if no valid tiles are found.

    Notes
    -----
    Algorithm Steps:
    1. Validates input and initializes output arrays
    2. First pass: Prepares tiles and calculates confidence maps
        - Offsets labels to avoid conflicts
        - Creates confidence maps for overlap regions
    3. Second pass: Merges tiles with conflict resolution
        - Resolves overlapping objects using majority voting
        - Updates pixels based on confidence values
    4. Final relabeling to ensure consecutive labels

    Memory Optimization:
    - Uses caching for confidence maps of similar tile shapes
    - Processes tiles sequentially to limit memory usage
    - Cleans up intermediate arrays explicitly

    Performance Considerations:
    - Gaussian smoothing of confidence maps can be a bottleneck
    - Label conflict resolution scales with number of overlapping objects
    - Memory usage scales with full image size and tile overlap

    See Also
    --------
    scipy.ndimage.gaussian_filter : Used for confidence map smoothing
    skimage.morphology.label : Similar functionality for connected component labeling
    """

    print(
        f"Stitching {len(tiles_info)} masks with overlap={tile_overlap}, sigma={sigma}..."
    )
    start_time = time.time()  # Track stitching time

    # --- Input Validation ---
    if not tiles_info or not tile_masks or len(tiles_info) != len(tile_masks):
        print("Warning: Invalid input tiles_info or tile_masks. Returning empty mask.")
        return np.zeros(full_shape, dtype=np.int32)

    # Ensure tile_overlap is valid
    tile_overlap = max(0, int(tile_overlap))

    # --- Pre-allocation ---
    # Use float64 for accumulation to avoid potential overflow with large labels, then cast later if needed
    # Using int32 directly might be slightly faster if max_label doesn't exceed 2^31, but safer with int64 intermediate.
    # Let's stick to int32 as label counts rarely exceed this, and it matches skimage output.
    full_mask = np.zeros(full_shape, dtype=np.int32)
    # Confidence map stores the confidence of the label at each pixel
    confidence_map = np.zeros(full_shape, dtype=np.float32)

    # --- Confidence Map Cache ---
    confidence_cache = {}

    # --- First Pass: Prepare Tiles (Offset Labels and Get Confidence) ---
    print("  Preparing tiles (label offset and confidence)...")
    processed_tiles = []
    max_label = 0
    skipped_tiles = 0

    for i in range(len(tiles_info)):
        coords = tiles_info[i]
        tile_mask = tile_masks[i]

        # Skip if tile_mask is None or empty
        if tile_mask is None or tile_mask.size == 0 or not np.any(tile_mask):
            skipped_tiles += 1
            continue

        y_start, y_end, x_start, x_end = coords
        h, w = tile_mask.shape[:2]  # Use shape from actual mask

        # Ensure mask dimensions match coordinate dimensions
        if h != (y_end - y_start) or w != (x_end - x_start):
            print(
                f"Warning: Tile {i} mask shape {tile_mask.shape} mismatch with coords {coords}. Skipping."
            )
            skipped_tiles += 1
            continue

        # Copy mask to avoid modifying originals and offset labels
        tile_mask = tile_mask.copy()
        valid_mask_pixels = tile_mask > 0
        if max_label > 0 and np.any(valid_mask_pixels):
            tile_mask[valid_mask_pixels] += max_label

        current_max = tile_mask.max()
        max_label = max(max_label, current_max)

        # --- Calculate or Retrieve Confidence Map ---
        # Cache key based on actual tile shape
        cache_key = (
            h,
            w,
            tile_overlap,
            sigma,
            y_start > 0,
            y_end < full_shape[0],
            x_start > 0,
            x_end < full_shape[1],
        )
        if cache_key not in confidence_cache:
            # Create a confidence map: 1.0 in center, ramps down towards edges that overlap
            conf_local = np.ones((h, w), dtype=np.float32)
            overlap_y = min(
                tile_overlap, h // 2
            )  # Ensure overlap doesn't exceed half the tile dim
            overlap_x = min(tile_overlap, w // 2)

            # Create ramps (only if overlap > 0)
            if overlap_y > 0:
                ramp_y = np.linspace(0.0, 1.0, overlap_y, dtype=np.float32)
                if y_start > 0:  # Top edge needs ramp up
                    conf_local[:overlap_y, :] *= ramp_y[:, np.newaxis]
                if y_end < full_shape[0]:  # Bottom edge needs ramp down
                    conf_local[h - overlap_y :, :] *= ramp_y[::-1][:, np.newaxis]

            if overlap_x > 0:
                ramp_x = np.linspace(0.0, 1.0, overlap_x, dtype=np.float32)
                if x_start > 0:  # Left edge needs ramp up
                    conf_local[:, :overlap_x] *= ramp_x[np.newaxis, :]
                if x_end < full_shape[1]:  # Right edge needs ramp down
                    conf_local[:, w - overlap_x :] *= ramp_x[::-1][np.newaxis, :]

            # Apply Gaussian smoothing (potential bottleneck)
            # sigma/4 is used based on the original code's heuristic
            smooth_sigma = sigma / 4.0
            if smooth_sigma > 0:
                conf_local = gaussian_filter(
                    conf_local, sigma=smooth_sigma, mode="constant", cval=0.0
                )

            # Normalize confidence map (prevents issues if smoothing pushes max slightly > 1)
            max_conf = conf_local.max()
            if max_conf > 0:
                conf_local /= max_conf
            else:
                conf_local[:] = 0  # Avoid NaN if max is 0

            # Clip values to ensure they are within [0, 1] after filtering/normalization
            np.clip(conf_local, 0.0, 1.0, out=conf_local)

            confidence_cache[cache_key] = conf_local
        # --- End Confidence Map Calculation ---

        processed_tiles.append((coords, tile_mask, confidence_cache[cache_key]))

    if skipped_tiles > 0:
        print(f"  Skipped {skipped_tiles} empty or invalid tiles.")
    if not processed_tiles:
        print(
            "Warning: No valid tiles found to process after first pass. Returning empty mask."
        )
        return np.zeros(full_shape, dtype=np.int32)
    print(
        f"  Processed {len(processed_tiles)} tiles in first pass. Max label offset: {max_label}"
    )

    # --- Second Pass: Merge Tiles with Conflict Resolution ---
    print("  Merging tiles...")
    for coords, tile_mask, confidence in processed_tiles:
        y_start, y_end, x_start, x_end = coords

        # Get views into the full mask and confidence map
        region_mask = full_mask[y_start:y_end, x_start:x_end]
        region_conf = confidence_map[y_start:y_end, x_start:x_end]

        # --- Conflict Resolution ---
        # Identify pixels in the current tile that have labels (value > 0)
        tile_pixels_with_labels = tile_mask > 0

        # Find where these pixels overlap with existing labels in the full_mask region
        conflicting_pixels = tile_pixels_with_labels & (region_mask > 0)

        if np.any(conflicting_pixels):
            # Get labels from the tile and the existing region at conflicting pixels
            tile_labels_at_conflict = tile_mask[conflicting_pixels]
            region_labels_at_conflict = region_mask[conflicting_pixels]

            # Iterate through unique *tile* labels involved in conflicts
            unique_tile_labels_in_conflict = np.unique(tile_labels_at_conflict)

            for tile_L in unique_tile_labels_in_conflict:
                # Find where *this specific tile label* causes conflicts
                current_conflict_mask = conflicting_pixels & (tile_mask == tile_L)
                if not np.any(current_conflict_mask):
                    continue

                # Get the existing region labels that conflict with this tile_L
                overlapping_region_labels = region_mask[current_conflict_mask]

                # Find the most common *region* label overlapping with this *tile* label
                # Using bincount can be faster than unique if max label isn't excessively large
                try:
                    counts = np.bincount(overlapping_region_labels)
                    if counts.size > 0:
                        most_common_region_label = np.argmax(counts)
                        max_count = counts[most_common_region_label]

                        # If a region label is significantly present (>50% overlap),
                        # merge the current tile label (tile_L) into that region label.
                        # np.count_nonzero is faster than sum() for boolean arrays
                        if (
                            max_count > 0
                            and (max_count / np.count_nonzero(current_conflict_mask))
                            >= 0.5
                        ):
                            # Update the tile_mask *in place* for subsequent steps
                            tile_mask[tile_mask == tile_L] = most_common_region_label
                            # No need to break, continue checking other tile labels in conflict
                except (IndexError, ValueError) as e:
                    # Handle potential issues with bincount if labels are negative or too large
                    print(
                        f"Warning: Error during bincount conflict resolution for tile label {tile_L}: {e}. Falling back to slower unique."
                    )
                    unique_overlaps, counts = np.unique(
                        overlapping_region_labels, return_counts=True
                    )
                    if counts.size > 0:
                        max_idx = np.argmax(counts)
                        if (
                            counts[max_idx] / np.count_nonzero(current_conflict_mask)
                        ) >= 0.5:
                            tile_mask[tile_mask == tile_L] = unique_overlaps[max_idx]

        # --- Apply Update ---
        # Find pixels where the current tile has a label AND its confidence is higher
        # than the existing confidence in the region.
        update_mask = (tile_mask > 0) & (confidence > region_conf)

        # Apply the update using boolean indexing (generally efficient in NumPy)
        region_mask[update_mask] = tile_mask[update_mask]
        region_conf[update_mask] = confidence[update_mask]

    # Clean up intermediate list and cache
    del processed_tiles, confidence_cache
    gc.collect()

    # --- Final Relabeling ---
    print("  Relabeling final mask...")
    # Check if there are any labels before relabeling
    max_final_label = full_mask.max()
    if max_final_label > 0:
        # relabel_sequential ensures consecutive labels starting from 1
        full_mask, _, _ = relabel_sequential(full_mask)
        print(f"  Relabeling complete. Final max label: {full_mask.max()}")
    else:
        print("  Skipping relabeling as the final mask is empty.")

    end_time = time.time()
    print(f"Stitching finished in {end_time - start_time:.2f} seconds.")

    return full_mask


def remove_border_objects(mask):
    """
    Remove labeled objects in a mask that directly touch its borders.
    Assumes background is labeled as 0.
    """
    # Get unique labels along the four borders
    border_labels = set(np.unique(mask[0, :])).union(np.unique(mask[-1, :]))
    border_labels = border_labels.union(np.unique(mask[:, 0])).union(
        np.unique(mask[:, -1])
    )
    border_labels.discard(0)  # Keep background intact
    for lbl in border_labels:
        mask[mask == lbl] = 0
    return mask


# --------------------------------------------------------------------------
# Feature Extraction
# --------------------------------------------------------------------------


def extract_features(
    image_dict,
    segmentation_masks,
    channels_to_quantify,
    output_file,
    size_cutoff=0,
    # Tiling parameters for intensity calculation
    use_tiling_for_intensity=True,
    tile_size=2048,
    tile_overlap=128,
    memory_limit_gb=4,  # Approx. memory limit per channel before tiling intensity calc.
):
    """Extract morphological and intensity features from segmented images with memory optimization.

    This function performs feature extraction in multiple stages:
    1. Calculates morphological features using regionprops
    2. Filters objects based on size
    3. Creates a filtered mask
    4. Calculates mean intensities (with optional tiling)
    5. Combines and saves the features

    Parameters
    ----------
    image_dict : dict
        Dictionary mapping channel names to 2D image arrays. Images should be
        single-channel arrays with matching dimensions.
    segmentation_masks : ndarray
        2D integer-labeled segmentation mask. Background should be 0,
        objects labeled with consecutive positive integers.
    channels_to_quantify : list
        List of channel names in image_dict for intensity quantification.
        These names must exist as keys in image_dict.
    output_file : str or Path
        Path where the output CSV file will be saved.
    size_cutoff : int, optional
        Minimum object area in pixels to include in analysis, by default 0.
        Objects smaller than this are filtered out based on regionprops area.
    use_tiling_for_intensity : bool, optional
        Whether to enable tiled processing for intensity calculations,
        by default True. Recommended for large images.
    tile_size : int, optional
        Size of tiles in pixels for intensity calculation when tiling is used,
        by default 2048.
    tile_overlap : int, optional
        Overlap between adjacent tiles in pixels, by default 128.
        Prevents edge artifacts in tiled processing.
    memory_limit_gb : float, optional
        Memory threshold in GB per channel that triggers tiled processing,
        by default 4. Adjust based on available system memory.

    Returns
    -------
    pandas.DataFrame or None
        DataFrame containing extracted features if successful, with columns:
        - Morphological: 'label', 'y', 'x', 'area', 'eccentricity', 'perimeter',
          'convex_area', 'axis_major_length', 'axis_minor_length'
        - Intensity: mean intensity for each channel in channels_to_quantify
        Returns None if processing fails or no objects are found.

    Notes
    -----
    Memory Optimization:
    - Morphological features are calculated on the full mask at once
    - Intensity calculations can be tiled for large images
    - Intermediate results are explicitly cleaned up
    - Uses float64 for intensity calculations to maintain precision

    Performance Considerations:
    - Large masks may cause memory issues during morphological calculation
    - Tiling adds overhead but reduces peak memory usage
    - Consider reducing tile_size if memory errors occur
    - GPU memory is not used directly but temp arrays may impact GPU memory

    File Handling:
    - Creates output directory if it doesn't exist
    - Saves empty CSV if no objects are found
    - Attempts to save partial results (morphology only) on failure

    See Also
    --------
    skimage.measure.regionprops_table : Used for morphological feature extraction
    numpy.bincount : Used for efficient intensity calculation
    """
    print("--- Starting Feature Extraction ---")
    output_file = pathlib.Path(output_file)  # Ensure Path object

    if (
        segmentation_masks is None
        or np.prod(segmentation_masks.shape) == 0
        or np.max(segmentation_masks) == 0
    ):
        print(
            "Error: Segmentation mask is empty, None, or contains no labeled objects."
        )
        output_file.parent.mkdir(parents=True, exist_ok=True)
        pd.DataFrame().to_csv(output_file, index=False)
        print(f"Created empty features file: {output_file}")
        return None

    segmentation_masks = segmentation_masks.squeeze().astype(np.int32)
    img_h, img_w = segmentation_masks.shape

    # Ensure size_cutoff is non-negative
    size_cutoff = max(0, size_cutoff)

    # --- 1. Calculate Morphological Features (on full mask before size filtering) ---
    print("Calculating morphological features...")
    props_df = None
    try:
        # Run regionprops on the original mask
        if np.max(segmentation_masks) > 0:
            props = regionprops_table(
                segmentation_masks,  # Use the original mask
                properties=(
                    "label",
                    "centroid",
                    "area",
                    "eccentricity",
                    "perimeter",
                    "convex_area",
                    "axis_major_length",
                    "axis_minor_length",
                ),
            )
            props_df = pd.DataFrame(props)
            props_df.rename(
                columns={"centroid-0": "y", "centroid-1": "x"}, inplace=True
            )
            print(f"Calculated initial morphology for {len(props_df)} objects.")

            # --- 2. Filter Small Objects based on regionprops area ---
            print(f"Filtering objects with area < {size_cutoff} pixels...")
            props_df = props_df[
                props_df["area"] >= size_cutoff
            ].copy()  # Filter based on area
            props_df.set_index("label", inplace=True)  # Set index after filtering

            if props_df.empty:
                print("No objects remaining after size filtering.")
                # Create empty file
                output_file.parent.mkdir(parents=True, exist_ok=True)
                pd.DataFrame().to_csv(output_file, index=False)
                print(f"Created empty features file: {output_file}")
                return None
            print(f"Found {len(props_df)} objects after size filtering.")
        else:
            print("No objects found in the initial mask for morphology calculation.")
            # Create empty file
            output_file.parent.mkdir(parents=True, exist_ok=True)
            pd.DataFrame().to_csv(output_file, index=False)
            print(f"Created empty features file: {output_file}")
            return None

    except MemoryError as e:
        print(f"MemoryError calculating morphological features on full mask: {e}")
        print(
            "Consider using libraries designed for out-of-core morphology if this persists."
        )
        return None
    except Exception as e:
        print(f"Error calculating morphological features or filtering: {e}")
        return None

    # --- 3. Create Filtered Mask (filterimg) ---
    # This mask contains only the objects that passed the size filter.
    print("Creating filtered mask for intensity calculation...")
    filterimg = None  # Initialize
    max_label = 0  # Initialize
    final_nucleus_ids = props_df.index.to_numpy(
        dtype=np.int32
    )  # Get labels that passed filtering

    try:
        # Create a mapping array to zero out small labels based on props_df index
        original_max_label = np.max(segmentation_masks)
        filter_map = np.zeros(original_max_label + 1, dtype=segmentation_masks.dtype)
        # Only keep labels that are in the filtered props_df index
        valid_labels_in_mask = final_nucleus_ids[
            final_nucleus_ids <= original_max_label
        ]
        filter_map[valid_labels_in_mask] = valid_labels_in_mask  # Keep valid labels

        filterimg = filter_map[
            segmentation_masks
        ]  # Apply mapping to create filtered mask

        if len(final_nucleus_ids) > 0:
            # Use the max label from the *filtered* set for bincount minlength
            max_label = int(np.max(final_nucleus_ids))
        else:
            max_label = (
                0  # Should not happen if props_df wasn't empty, but safety check
            )

    except MemoryError:
        print(
            "MemoryError creating the filtered mask. Image/Mask might be too large for this step."
        )
        del props_df
        gc.collect()
        return None
    except Exception as e:
        print(f"Error creating filtered mask: {e}")
        del props_df
        gc.collect()
        return None

    # --- 4. Calculate Mean Intensities (Potentially Tiled, using filterimg) ---
    print("Calculating mean intensities...")
    mean_intensity_data = {}

    # Determine if tiling is needed for intensity calculation
    tiling_needed = False
    if channels_to_quantify:
        first_channel_name = channels_to_quantify[0]
        if first_channel_name in image_dict:
            dtype_size = image_dict[first_channel_name].dtype.itemsize
            estimated_gb_per_channel = (img_h * img_w * dtype_size) / (1024**3)
            tiling_needed = use_tiling_for_intensity and (
                estimated_gb_per_channel > memory_limit_gb
            )
        else:
            print(
                f"Warning: First channel '{first_channel_name}' for quantification not found in image_dict."
            )
            channels_to_quantify = []  # Avoid processing if first channel missing

    tiles_info = None
    if tiling_needed:
        print(
            f"Memory estimate ({estimated_gb_per_channel:.2f} GB) exceeds limit ({memory_limit_gb} GB). Using tiling for intensity."
        )
        tiles_info = generate_tiles(filterimg.shape, tile_size, tile_overlap)
        print(f"Generated {len(tiles_info)} tiles for intensity calculation.")
    elif channels_to_quantify:
        print("Processing intensities on full images.")
    else:
        print("Skipping intensity calculation (no valid channels specified).")

    # Process each channel using the 'filterimg'
    for chan in tqdm(channels_to_quantify, desc="Processing channels"):
        if chan not in image_dict:
            print(
                f"Warning: Channel '{chan}' not found in image_dict. Filling with NaN."
            )
            # Use props_df index length for consistency
            mean_intensity_data[chan] = np.full(len(props_df), np.nan)
            continue

        chan_data = None  # Ensure variable exists for finally block
        try:
            chan_data = image_dict[chan]
            if chan_data.shape != filterimg.shape:
                print(
                    f"Warning: Shape mismatch for channel '{chan}' ({chan_data.shape}) vs mask ({filterimg.shape}). Resizing channel."
                )
                chan_data = cv2.resize(
                    chan_data, (img_w, img_h), interpolation=cv2.INTER_LINEAR
                )
                if chan_data.shape != filterimg.shape:
                    raise ValueError(
                        f"Channel resize failed for '{chan}'. Expected {filterimg.shape}, got {chan_data.shape}."
                    )

            # Initialize sums and counts for this channel, ensure large enough for max_label from filtered set
            channel_sums = np.zeros(max_label + 1, dtype=np.float64)
            channel_counts = np.zeros(max_label + 1, dtype=np.int64)

            if tiling_needed and tiles_info:
                # --- Tiled Intensity Calculation ---
                for y_start, y_end, x_start, x_end in tiles_info:
                    try:
                        # Use filterimg for masking
                        tile_mask_view = filterimg[y_start:y_end, x_start:x_end]
                        tile_chan_view = chan_data[y_start:y_end, x_start:x_end]

                        if tile_mask_view.size == 0 or tile_chan_view.size == 0:
                            continue

                        # Use minlength=max_label + 1 based on filtered labels
                        tile_sums = np.bincount(
                            tile_mask_view.ravel(),
                            weights=tile_chan_view.ravel(),
                            minlength=max_label + 1,
                        )
                        tile_counts = np.bincount(
                            tile_mask_view.ravel(), minlength=max_label + 1
                        )

                        channel_sums += tile_sums
                        channel_counts += tile_counts
                        del tile_mask_view, tile_chan_view, tile_sums, tile_counts
                    except IndexError:
                        print(
                            f"Warning: Tile coordinates caused IndexError during intensity calculation for channel '{chan}'. Skipping tile."
                        )
                        continue
                    except Exception as tile_e:
                        print(
                            f"Warning: Error processing tile for channel '{chan}': {tile_e}. Skipping tile."
                        )
                        continue
            else:
                # --- Full Image Intensity Calculation ---
                if filterimg.size > 0 and chan_data.size > 0:
                    # Use filterimg here
                    channel_sums = np.bincount(
                        filterimg.ravel(),
                        weights=chan_data.ravel(),
                        minlength=max_label + 1,
                    )
                    channel_counts = np.bincount(
                        filterimg.ravel(), minlength=max_label + 1
                    )
                else:
                    print(
                        f"Warning: Empty filtered mask or channel data for '{chan}'. Skipping full image calculation."
                    )

            # Calculate mean intensity only for the labels present in props_df (final_nucleus_ids)
            # Ensure indices are valid before accessing sums/counts
            valid_indices_mask = (final_nucleus_ids >= 0) & (
                final_nucleus_ids <= max_label
            )
            valid_final_ids = final_nucleus_ids[valid_indices_mask]

            if len(valid_final_ids) == 0:
                print(
                    f"Warning: No valid labels found for channel '{chan}' after index check."
                )
                mean_intensity_data[chan] = np.full(
                    len(props_df), np.nan
                )  # Match props_df length
                continue

            # Get sums/counts only for the valid labels that passed size filtering
            counts_for_final_labels = channel_counts[valid_final_ids]
            sums_for_final_labels = channel_sums[valid_final_ids]

            # Initialize result array matching the length of props_df (filtered objects)
            mean_channel_intensity = np.full(len(props_df), np.nan, dtype=np.float64)

            # Calculate mean where counts > 0
            valid_counts_mask_local = counts_for_final_labels > 0
            # Ensure we only calculate for labels that actually had counts
            mean_values = np.full(
                np.sum(valid_counts_mask_local), np.nan
            )  # Initialize output for division
            np.divide(
                sums_for_final_labels[valid_counts_mask_local],
                counts_for_final_labels[valid_counts_mask_local],
                out=mean_values,
                where=valid_counts_mask_local,
            )  # Condition for division

            # Place calculated means into the correct positions in the result array
            # We need to map the results back based on the original index positions in final_nucleus_ids
            # Create a boolean mask aligned with final_nucleus_ids
            full_valid_mask = np.zeros(len(final_nucleus_ids), dtype=bool)
            full_valid_mask[valid_indices_mask] = valid_counts_mask_local

            mean_channel_intensity[full_valid_mask] = mean_values
            mean_intensity_data[chan] = mean_channel_intensity

        except MemoryError as e:
            print(f"\\nMemoryError processing channel '{chan}': {e}")
            mean_intensity_data[chan] = np.full(len(props_df), np.nan)
        except ValueError as e:
            print(f"\\nValueError processing channel '{chan}': {e}")
            mean_intensity_data[chan] = np.full(len(props_df), np.nan)
        except Exception as e:
            print(f"\\nError processing channel '{chan}': {e}")
            mean_intensity_data[chan] = np.full(len(props_df), np.nan)
        finally:
            if chan_data is not None:
                del chan_data
            gc.collect()

    # --- 5. Combine Features and Save ---
    print("Combining morphology and intensity features...")
    try:
        # filterimg no longer needed
        if filterimg is not None:
            del filterimg
        gc.collect()

        # Create DataFrame from intensity data using the filtered props_df index
        # Ensure the index matches props_df's index (which is 'label')
        mean_df = pd.DataFrame(mean_intensity_data, index=props_df.index)
        # mean_df.index.name = 'label' # Index already named 'label' from props_df

        # Join morphological features (already filtered) with mean intensities
        markers_df = props_df.join(
            mean_df, how="left"
        )  # props_df is already filtered by size

        if markers_df.isnull().values.any():
            nan_cols = markers_df.columns[markers_df.isnull().any()].tolist()
            print(
                f"Warning: Found NaN values in final features. Columns affected: {nan_cols}"
            )

        # Reorder columns (optional) - Use the globally defined morpho_cols
        present_morpho_cols = [col for col in morpho_cols if col in markers_df.columns]
        present_intensity_cols = [
            col for col in channels_to_quantify if col in markers_df.columns
        ]
        final_cols = present_morpho_cols + present_intensity_cols
        markers_df = markers_df.reindex(columns=final_cols)

        markers_df.reset_index(inplace=True)  # Make 'label' a column

        output_file.parent.mkdir(parents=True, exist_ok=True)
        markers_df.to_csv(output_file, index=False)
        print(
            f"Successfully saved features for {len(markers_df)} objects to {output_file}"
        )

    except Exception as e:
        print(f"Error combining features or saving CSV {output_file}: {e}")
        if props_df is not None:
            try:
                morpho_output_file = (
                    output_file.parent / f"{output_file.stem}_morphology_only.csv"
                )
                # Save the filtered props_df
                props_df.reset_index().to_csv(morpho_output_file, index=False)
                print(
                    f"Saved morphology-only features (after size filter) to {morpho_output_file}"
                )
            except Exception as save_e:
                print(f"Could not save morphology-only features: {save_e}")
        return None
    finally:
        if "props_df" in locals() and props_df is not None:
            del props_df
        if "mean_df" in locals() and mean_df is not None:
            del mean_df
        gc.collect()

    print("--- Feature Extraction Complete ---")
    return markers_df


# Make sure morpho_cols is defined (it was at the end of the first cell)
morpho_cols = [
    "y",
    "x",
    "area",
    "eccentricity",
    "perimeter",
    "convex_area",
    "axis_major_length",
    "axis_minor_length",
]


# --------------------------------------------------------------------------
# Main Segmentation Function
# --------------------------------------------------------------------------


def _perform_segmentation(
    seg_dict,  # Dict with channels needed for this tile/image
    seg_method,
    output_dir,  # For potential saving inside sub-functions
    nuclei_channel_name,
    membrane_channel_name,  # Name of combined membrane channel or None
    cytoplasm_channel_name,  # Name of cytoplasm channel or None (Cellpose only)
    compartment,  # 'whole-cell' or 'nuclear' (Mesmer)
    plot_predictions,  # For Mesmer plotting
    model_path,  # For Mesmer model download/load
    use_gpu,  # For Cellpose/TF
    model,  # Cellpose model name/path
    custom_model,  # Cellpose flag
    diameter,  # Cellpose diameter
    save_mask_as_png,  # Cellpose flag
    image_mpp=0.5,  # For Mesmer
):
    """Internal helper function to perform cell segmentation using either Mesmer or Cellpose.

    Parameters
    ----------
    seg_dict : dict
        Dictionary containing channel images needed for segmentation
    seg_method : str
        Segmentation method to use ('mesmer' or 'cellpose')
    output_dir : str or Path
        Directory for saving output files from segmentation
    nuclei_channel_name : str
        Name of the nuclei channel in seg_dict
    membrane_channel_name : str or None
        Name of the combined membrane channel in seg_dict, or None
    cytoplasm_channel_name : str or None
        Name of cytoplasm channel in seg_dict (Cellpose only), or None
    compartment : str
        Segmentation compartment for Mesmer ('whole-cell' or 'nuclear')
    plot_predictions : bool
        Whether to plot Mesmer predictions
    model_path : str or Path
        Path for Mesmer model download/loading
    use_gpu : bool
        Whether to use GPU acceleration
    model : str
        Cellpose model name or path to custom model
    custom_model : bool
        Whether model parameter is a path to custom Cellpose model
    diameter : float or None
        Expected cell diameter in pixels for Cellpose
    save_mask_as_png : bool
        Whether to save Cellpose overlay as PNG
    image_mpp : float, optional
        Microns per pixel for Mesmer, by default 0.5

    Returns
    -------
    ndarray or None
        2D integer-labeled segmentation mask if successful, None otherwise

    Notes
    -----
    The function handles both Mesmer and Cellpose segmentation methods:
    - For Mesmer: Requires nuclei image, optional membrane image
    - For Cellpose: Requires nuclei image, optional membrane/cytoplasm images

    The output mask is guaranteed to be 2D and of type int32 if successful.
    GPU memory is cleared after segmentation if GPU was used.
    """

    mask = None
    try:
        if seg_method == "mesmer":
            membrane_img = (
                seg_dict.get(membrane_channel_name) if membrane_channel_name else None
            )
            # Ensure nuclei image exists
            if nuclei_channel_name not in seg_dict:
                print(
                    f"Error: Nuclei channel '{nuclei_channel_name}' missing in seg_dict for Mesmer."
                )
                return None
            mask = mesmer_segmentation(
                nuclei_image=seg_dict[nuclei_channel_name],
                membrane_image=membrane_img,
                image_mpp=image_mpp,
                plot_predictions=plot_predictions,  # Plotting handled per-tile if called in loop
                compartment=compartment,
                model_path=model_path,
            )

        elif seg_method == "cellpose":
            # Note: cellpose_segmentation expects the *name* of the channels in seg_dict
            # It also needs the actual image data within seg_dict
            mask, _, _ = cellpose_segmentation(
                image_dict=seg_dict,  # Pass the dict containing tile images
                output_dir=output_dir,
                membrane_channel_name=membrane_channel_name,  # Pass name
                cytoplasm_channel_name=cytoplasm_channel_name,  # Pass name if available
                nucleus_channel_name=nuclei_channel_name,  # Pass name
                use_gpu=use_gpu,
                model=model,
                custom_model=custom_model,
                diameter=diameter,
                save_mask_as_png=save_mask_as_png,
            )
        else:
            print(f"Error: Unsupported segmentation method: {seg_method}")
            return None

        if mask is None:
            print(f"Warning: {seg_method} returned None.")
            return None

        # Ensure mask is 2D and integer type
        if mask.ndim > 2:
            mask = np.squeeze(mask)
        if mask.ndim != 2:
            print(
                f"Warning: Segmentation produced unexpected mask dimension {mask.ndim}. Expected 2D."
            )
            return None
        mask = mask.astype(np.int32)
        return mask

    except Exception as e:
        print(f"Error during _perform_segmentation ({seg_method}): {e}")
        # Print traceback for debugging
        import traceback

        traceback.print_exc()
        return None
    finally:
        # Clean up GPU memory if TF/Cellpose was used
        if use_gpu and (seg_method == "mesmer" or seg_method == "cellpose"):
            try:
                tf.keras.backend.clear_session()
                gc.collect()
            except Exception as clear_e:
                print(f"Warning: Error clearing TF session: {clear_e}")


[docs] def cell_segmentation( file_name, # Path to image file or directory channel_file, # Path to channel names file (if not input_format=="Channels") output_dir, # Base directory for outputs output_fname="", # Basename for output files seg_method="mesmer", # 'mesmer' or 'cellpose' nuclei_channel="DAPI", # Name of the nucleus channel input_format="Multichannel", # 'Multichannel', 'Channels', 'CODEX' membrane_channel_list=None, # List of channel names for membrane/whole-cell seg cytoplasm_channel_list=None, # List of channel names for cytoplasm (Cellpose only) size_cutoff=0, # Min object size for feature extraction compartment="whole-cell", # Mesmer: 'whole-cell' or 'nuclear'. Cellpose: Ignored. plot_predictions=False, # Plot Mesmer predictions model="cyto3", # Cellpose model name or path use_gpu=True, # Use GPU if available diameter=None, # Cellpose cell diameter estimate save_mask_as_png=False, # Save Cellpose overlay PNGs model_path="./models", # Path for Mesmer model download/load resize_factor=1, # Factor to resize images before segmentation custom_model=False, # True if 'model' is a path to a custom Cellpose model differentiate_nucleus_cytoplasm=False, # Perform separate Nuc/Whole seg tile_size=4096, # Tile size for segmentation tile_overlap=128, # Overlap between segmentation tiles tiling_threshold=5000, # Use tiling if H and W exceed this threshold image_mpp=0.5, # Microns per pixel (primarily for Mesmer) stitch_sigma=64, # Sigma for Gaussian blending during stitching remove_tile_border_objects=True, # Remove objects touching tile borders before stitching # Feature extraction parameters embedded feature_tile_size=4096, feature_tile_overlap=128, feature_memory_limit_gb=8, set_memory_growth=True, ): """Perform cell segmentation using Mesmer or Cellpose with optional tiling and feature extraction. This function implements a complete segmentation pipeline including image loading, preprocessing, segmentation, mask stitching, and feature extraction. It handles large images through tiling and provides memory-optimized processing. Parameters ---------- file_name : str or Path Path to input image file or directory (Multichannel = multichannel TIFF, Channels = single-channel TIFFs in a directory, CODEX = CODEX format with channels, cycles, y, x) channel_file : str or Path Path to channel names file (ignored if input_format=="Channels") output_dir : str or Path Base directory for output files output_fname : str, optional Basename for output files, by default auto-generated seg_method : {'mesmer', 'cellpose'}, optional Segmentation algorithm to use, by default 'mesmer' nuclei_channel : str, optional Name of the nuclei channel, by default 'DAPI' input_format : {'Multichannel', 'Channels', 'CODEX'}, optional Format of input data, by default 'Multichannel' membrane_channel_list : list of str, optional Channel names for membrane/whole-cell segmentation cytoplasm_channel_list : list of str, optional Channel names for cytoplasm (Cellpose only) size_cutoff : int, optional Minimum object size in pixels for feature extraction compartment : {'whole-cell', 'nuclear'}, optional Segmentation compartment for Mesmer (ignored by Cellpose) plot_predictions : bool, optional Whether to plot Mesmer predictions model : str, optional Model name or path for Cellpose use_gpu : bool, optional Whether to use GPU acceleration diameter : float, optional Expected cell diameter for Cellpose in pixels (setting a value is recommended to speed up segmentation significantly - if you are unsure you can measure the average cell diameter in ImageJ) save_mask_as_png : bool, optional Save Cellpose overlay as PNG model_path : str or Path, optional Path for Mesmer model download/load resize_factor : float, optional Factor to resize images before segmentation custom_model : bool, optional Whether 'model' is a path to custom Cellpose model differentiate_nucleus_cytoplasm : bool, optional Perform separate nuclear and whole-cell segmentation tile_size : int, optional Size of tiles for segmentation in pixels tile_overlap : int, optional Overlap between adjacent tiles in pixels tiling_threshold : int, optional Image size threshold to enable tiling image_mpp : float, optional Microns per pixel (for Mesmer) stitch_sigma : float, optional Sigma for Gaussian blending during stitching remove_tile_border_objects : bool, optional Remove objects touching tile borders feature_tile_size : int, optional Tile size for feature extraction feature_tile_overlap : int, optional Overlap for feature extraction tiles feature_memory_limit_gb : float, optional Memory limit per channel for feature extraction set_memory_growth : bool, optional Enable TensorFlow memory growth Returns ------- dict or None Dictionary containing: - 'img_ref': Reference image - 'image_dict': Channel images - 'masks': Primary segmentation mask - 'masks_nuclei': Nuclear mask (if differentiated) - 'masks_cytoplasm': Cytoplasm mask (if differentiated) - 'features': DataFrame of extracted features - 'features_nuclei/cytoplasm/whole_cell': Region-specific features - 'features_combined': Combined features from all regions Returns None on critical error Notes ----- Memory optimization strategies: - Tiling for large image segmentation - Memory-efficient feature extraction - Optional GPU memory growth - Cleanup of intermediate arrays The pipeline includes: 1. Image loading and preprocessing 2. Segmentation (tiled or full image) 3. Mask post-processing and stitching 4. Feature extraction and combination """ results = {} output_dir = pathlib.Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) # --- 1. Initial Setup --- print("--- Initializing Segmentation Pipeline ---") use_gpu = setup_gpu(use_gpu, set_memory_growth) # Setup GPU and update status if not output_fname: base = ( pathlib.Path(file_name).stem if input_format != "Channels" else pathlib.Path(file_name).name ) output_fname = f"{base}_{seg_method}" print(f"Output basename: {output_fname}") print(f"Segmentation method: {seg_method}") print(f"Differentiate Nucleus/Cytoplasm: {differentiate_nucleus_cytoplasm}") # --- 2. Load Images --- # Memory bottleneck potential here for very large files img_ref, image_dict, channel_names = load_image_dictionary( file_name, channel_file, input_format, nuclei_channel ) if image_dict is None: return None # Error during loading results["img_ref"] = img_ref # Store reference (could be large) results["image_dict"] = image_dict # Holds all channel arrays (potentially large) original_shape = image_dict[nuclei_channel].shape # Shape before any resizing # --- 3. Prepare Images for Segmentation --- print("\n--- Preparing Segmentation Inputs ---") # Creates a dict with only nuclei and optionally combined membrane channel segmentation_dict_prep, combined_membrane_channel_name = prepare_segmentation_dict( image_dict, nuclei_channel, membrane_channel_list ) # Check if cytoplasm channel exists for Cellpose (if needed later) # Assuming cytoplasm_channel_list contains only one name for simplicity here cytoplasm_channel_name = ( cytoplasm_channel_list[0] if cytoplasm_channel_list else None ) if cytoplasm_channel_name and cytoplasm_channel_name not in image_dict: print( f"Warning: Specified cytoplasm channel '{cytoplasm_channel_name}' not found." ) cytoplasm_channel_name = None # Disable if not found # Resize if necessary (creates new, potentially large, arrays) segmentation_dict_resized = resize_segmentation_images( segmentation_dict_prep, resize_factor ) del segmentation_dict_prep gc.collect() # Clean up intermediate dict full_shape_seg = segmentation_dict_resized[ nuclei_channel ].shape # Shape used for segmentation print(f"Image shape for segmentation: {full_shape_seg}") # --- 4. Determine Tiling for Segmentation --- use_tiling_seg = ( tile_size is not None and full_shape_seg[0] > tiling_threshold and full_shape_seg[1] > tiling_threshold ) if use_tiling_seg: print( f"\n--- Tiling Enabled for Segmentation (Threshold: {tiling_threshold}px, Tile: {tile_size}, Overlap: {tile_overlap}) ---" ) tiles_info = generate_tiles(full_shape_seg, tile_size, tile_overlap) else: print("\n--- Processing Full Image for Segmentation ---") tiles_info = None # --- 5. Perform Segmentation --- final_masks = {} # Store final, full-sized masks try: if differentiate_nucleus_cytoplasm: print("\n--- Segmentation Mode: Differentiate Nucleus/Cytoplasm ---") if not combined_membrane_channel_name: print("Error: Membrane channels must be provided for differentiation.") return None # Define tasks: Nuclear and Whole-Cell segmentation_tasks = { "nuclei": { "compartment": "nuclear", "membrane": None, "cytoplasm": None, }, "whole_cell": { "compartment": "whole-cell", "membrane": combined_membrane_channel_name, "cytoplasm": cytoplasm_channel_name, }, } task_results_resized = {} # Store masks at segmentation resolution for task_name, task_params in segmentation_tasks.items(): print(f"\n--- Running {task_name.upper()} Segmentation ---") task_tile_masks = [] completed_tiles_indices = set() if use_tiling_seg: print(f"Processing {len(tiles_info)} tiles for {task_name}...") for idx, (y_start, y_end, x_start, x_end) in enumerate(tiles_info): display_tile_progress( tiles_info, completed_tiles_indices, full_shape_seg, current_tile_index=idx, ) # Crop tile from the *resized* segmentation dict tile_seg_dict = { ch: im[y_start:y_end, x_start:x_end] for ch, im in segmentation_dict_resized.items() } tile_mask = _perform_segmentation( seg_dict=tile_seg_dict, seg_method=seg_method, output_dir=output_dir / f"tile_{idx}_{task_name}", nuclei_channel_name=nuclei_channel, membrane_channel_name=task_params["membrane"], cytoplasm_channel_name=task_params[ "cytoplasm" ], # Pass cyto name compartment=task_params["compartment"], plot_predictions=plot_predictions, model_path=model_path, use_gpu=use_gpu, model=model, custom_model=custom_model, diameter=diameter, save_mask_as_png=save_mask_as_png, image_mpp=image_mpp, ) if tile_mask is None: print( f"Warning: Tile {idx} segmentation failed for {task_name}. Skipping tile." ) # Append None or an empty mask? Append None for stitching robustness. task_tile_masks.append(None) else: # Optional: Resize mask back to tile's input size if needed (shouldn't be necessary if seg returns correct size) # tile_mask_resized = resize_mask(tile_mask, tile_seg_dict[nuclei_channel].shape) if remove_tile_border_objects: tile_mask = remove_border_objects(tile_mask) task_tile_masks.append(tile_mask) completed_tiles_indices.add(idx) del tile_seg_dict, tile_mask gc.collect() # Clean up tile data # time.sleep(0.01) # Small delay for display update display_tile_progress( tiles_info, completed_tiles_indices, full_shape_seg ) # Show final progress print(f"\nStitching {task_name} masks...") stitched_mask = stitch_masks( tiles_info, task_tile_masks, full_shape_seg, tile_overlap, sigma=stitch_sigma, ) task_results_resized[task_name] = stitched_mask del task_tile_masks, stitched_mask gc.collect() else: # No tiling for segmentation print(f"Processing full image for {task_name}...") full_mask = _perform_segmentation( seg_dict=segmentation_dict_resized, # Use the resized dict seg_method=seg_method, output_dir=output_dir / f"full_{task_name}", nuclei_channel_name=nuclei_channel, membrane_channel_name=task_params["membrane"], cytoplasm_channel_name=task_params["cytoplasm"], compartment=task_params["compartment"], plot_predictions=plot_predictions, model_path=model_path, use_gpu=use_gpu, model=model, custom_model=custom_model, diameter=diameter, save_mask_as_png=save_mask_as_png, image_mpp=image_mpp, ) if full_mask is None: raise RuntimeError( f"Full image segmentation failed for {task_name}." ) task_results_resized[task_name] = full_mask del full_mask gc.collect() # Resize final masks back to original image size print("Resizing final masks to original image shape...") final_masks["masks_nuclei"] = resize_mask( task_results_resized["nuclei"], original_shape ) final_masks["masks"] = resize_mask( task_results_resized["whole_cell"], original_shape ) # 'masks' holds whole-cell # Calculate Cytoplasm Mask print("Calculating cytoplasm masks...") if ( final_masks["masks_nuclei"] is not None and final_masks["masks"] is not None ): binary_masks_nuclei = final_masks["masks_nuclei"] > 0 binary_masks_whole_cell = final_masks["masks"] > 0 binary_masks_cytoplasm = binary_masks_whole_cell & ( ~binary_masks_nuclei ) # Relabel cytoplasm mask masks_cytoplasm_labeled, num_labels = label(binary_masks_cytoplasm) final_masks["masks_cytoplasm"] = masks_cytoplasm_labeled.astype( np.int32 ) print(f"Created cytoplasm mask with {num_labels} labeled objects.") del ( binary_masks_nuclei, binary_masks_whole_cell, binary_masks_cytoplasm, masks_cytoplasm_labeled, ) else: print( "Warning: Could not calculate cytoplasm mask due to missing nuclei or whole-cell mask." ) final_masks["masks_cytoplasm"] = None del task_results_resized gc.collect() else: # Standard (non-differentiated) segmentation print("\n--- Segmentation Mode: Standard ---") current_membrane_channel = combined_membrane_channel_name current_compartment = compartment if current_membrane_channel else "nuclear" if not current_membrane_channel: print( "Performing nuclear-only segmentation (no membrane channels provided)." ) task_tile_masks = [] completed_tiles_indices = set() primary_mask_resized = None # Mask at segmentation resolution if use_tiling_seg: print(f"Processing {len(tiles_info)} tiles...") for idx, (y_start, y_end, x_start, x_end) in enumerate(tiles_info): display_tile_progress( tiles_info, completed_tiles_indices, full_shape_seg, current_tile_index=idx, ) tile_seg_dict = { ch: im[y_start:y_end, x_start:x_end] for ch, im in segmentation_dict_resized.items() } tile_mask = _perform_segmentation( seg_dict=tile_seg_dict, seg_method=seg_method, output_dir=output_dir / f"tile_{idx}", nuclei_channel_name=nuclei_channel, membrane_channel_name=current_membrane_channel, cytoplasm_channel_name=cytoplasm_channel_name, compartment=current_compartment, plot_predictions=plot_predictions, model_path=model_path, use_gpu=use_gpu, model=model, custom_model=custom_model, diameter=diameter, save_mask_as_png=save_mask_as_png, image_mpp=image_mpp, ) if tile_mask is None: print( f"Warning: Tile {idx} segmentation failed. Skipping tile." ) task_tile_masks.append(None) else: if remove_tile_border_objects: tile_mask = remove_border_objects(tile_mask) task_tile_masks.append(tile_mask) completed_tiles_indices.add(idx) del tile_seg_dict, tile_mask gc.collect() # time.sleep(0.01) display_tile_progress( tiles_info, completed_tiles_indices, full_shape_seg ) print("\nStitching masks...") primary_mask_resized = stitch_masks( tiles_info, task_tile_masks, full_shape_seg, tile_overlap, sigma=stitch_sigma, ) del task_tile_masks gc.collect() else: # No tiling for segmentation print("Processing full image...") primary_mask_resized = _perform_segmentation( seg_dict=segmentation_dict_resized, seg_method=seg_method, output_dir=output_dir / "full", nuclei_channel_name=nuclei_channel, membrane_channel_name=current_membrane_channel, cytoplasm_channel_name=cytoplasm_channel_name, compartment=current_compartment, plot_predictions=plot_predictions, model_path=model_path, use_gpu=use_gpu, model=model, custom_model=custom_model, diameter=diameter, save_mask_as_png=save_mask_as_png, image_mpp=image_mpp, ) if primary_mask_resized is None: raise RuntimeError("Full image segmentation failed.") # Resize final mask back to original image size print("Resizing final mask to original image shape...") final_masks["masks"] = resize_mask(primary_mask_resized, original_shape) del primary_mask_resized gc.collect() except Exception as e: print(f"An error occurred during the segmentation stage: {e}") # Clean up potentially large intermediate data del segmentation_dict_resized if "task_results_resized" in locals(): del task_results_resized if "task_tile_masks" in locals(): del task_tile_masks gc.collect() return None # Critical error # Clean up resized segmentation dictionary del segmentation_dict_resized gc.collect() # --- 6. Feature Extraction --- print("\n--- Extracting Features ---") # Use the *original* image_dict (full resolution) and *final, original-size* masks try: if differentiate_nucleus_cytoplasm: if all( k in final_masks and final_masks[k] is not None for k in ["masks_nuclei", "masks_cytoplasm", "masks"] ): print("Quantifying features for Nuclei, Cytoplasm, and Whole Cell...") features = {} for region, mask_key in [ ("nuclei", "masks_nuclei"), ("cytoplasm", "masks_cytoplasm"), ("whole_cell", "masks"), ]: # 'masks' is whole_cell here print(f" Quantifying {region}...") output_file = output_dir / f"{output_fname}_{region}_features.csv" features[region] = extract_features( image_dict=image_dict, segmentation_masks=final_masks[mask_key], channels_to_quantify=channel_names, output_file=output_file, size_cutoff=size_cutoff, use_tiling_for_intensity=True, # Enable intensity tiling tile_size=feature_tile_size, tile_overlap=feature_tile_overlap, memory_limit_gb=feature_memory_limit_gb, ) if features[region] is not None: print(f" Saved {region} features to {output_file}") results[f"features_{region}"] = features[region] else: print(f" Feature extraction failed for {region}.") # Combine features if all parts were successful if all( f in features and features[f] is not None for f in ["nuclei", "cytoplasm", "whole_cell"] ): print("Combining features...") try: # Use whole_cell features as the base for metadata and labels base_features = features["whole_cell"].copy() base_features.set_index( "label", inplace=True ) # Ensure label is index # Prepare intensity columns with suffixes nuc_int = ( features["nuclei"] .set_index("label") .drop(columns=morpho_cols, errors="ignore") ) cyto_int = ( features["cytoplasm"] .set_index("label") .drop(columns=morpho_cols, errors="ignore") ) # whole_int = base_features.drop(columns=morpho_cols, errors='ignore') # Already have whole cell intensities nuc_int.columns = [f"{col}_nuc" for col in nuc_int.columns] cyto_int.columns = [f"{col}_cyto" for col in cyto_int.columns] # whole_int.columns = [f"{col}_whole" for col in whole_int.columns] # Rename base intensities # Join based on label index combined = base_features.join([nuc_int, cyto_int], how="left") combined.reset_index(inplace=True) # Make label a column again # Save combined features combined_output_file = ( output_dir / f"{output_fname}_combined_features.csv" ) combined.to_csv(combined_output_file, index=False) print(f"Saved combined features to {combined_output_file}") results["features_combined"] = combined except Exception as e: print(f"Warning: Failed to combine features: {e}") else: print("Skipping feature combination due to missing feature sets.") else: print( "Warning: Missing required masks for differentiated feature extraction." ) else: # Standard segmentation if "masks" in final_masks and final_masks["masks"] is not None: print("Quantifying features for segmented objects...") output_file = output_dir / f"{output_fname}_features.csv" features_df = extract_features( image_dict=image_dict, segmentation_masks=final_masks["masks"], channels_to_quantify=channel_names, output_file=output_file, size_cutoff=size_cutoff, use_tiling_for_intensity=True, # Enable intensity tiling tile_size=feature_tile_size, tile_overlap=feature_tile_overlap, memory_limit_gb=feature_memory_limit_gb, ) if features_df is not None: print(f"Saved features to {output_file}") results["features"] = features_df else: print("Feature extraction failed.") else: print("Warning: No final mask found for feature extraction.") except Exception as e: print(f"An error occurred during feature extraction: {e}") # Continue to return masks even if features fail # --- 7. Final Return --- print("\n--- Segmentation Pipeline Complete ---") results.update( final_masks ) # Add final masks ('masks', 'masks_nuclei', 'masks_cytoplasm') # Clean up large image_dict before returning? Optional. # del results['image_dict']; gc.collect() return results