Source code for spacec.tools._general

# load required packages
from __future__ import annotations

# Filter out specific FutureWarnings from anndata and tm
import warnings

warnings.filterwarnings("ignore", category=FutureWarning, module="anndata.utils")
import logging

logging.disable(
    logging.CRITICAL
)  # disables all logging messages at and below CRITICAL level

import os
import platform
import subprocess
import sys
import tempfile
import zipfile

import requests

# TODO: Remove this!
# Windows only: make sure the libvips binaries needed by pyvips are present.
# If vips.exe is missing from the hard-coded location, the official libvips
# Windows build is downloaded and extracted, and pyvips is pip-installed.
if platform.system() == "Windows":
    vipsbin = r"c:\vips-dev-8.15\bin\vips-dev-8.15\bin"
    vips_file_path = os.path.join(vipsbin, "vips.exe")

    # Check if VIPS is installed
    if not os.path.exists(vips_file_path):
        # VIPS is not installed, download and extract it
        url = "https://github.com/libvips/build-win64-mxe/releases/download/v8.15.2/vips-dev-w64-all-8.15.2.zip"
        # Download into a temporary directory so the archive is not left
        # behind in the current working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            zip_file_path = os.path.join(tmp_dir, "vips-dev-w64-all-8.15.2.zip")
            # Stream in chunks: iter_content applies content decoding (unlike
            # response.raw) and never holds the whole archive in memory.
            with requests.get(url, stream=True, timeout=60) as response:
                downloaded = response.status_code == 200
                if downloaded:
                    with open(zip_file_path, "wb") as f:
                        for chunk in response.iter_content(chunk_size=1 << 20):
                            f.write(chunk)

            if downloaded:
                # Extract the zip file
                # NOTE(review): the archive is extracted into vipsbin itself;
                # confirm its internal layout really places vips.exe at
                # vips_file_path, otherwise this download repeats on every
                # import.
                with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
                    zip_ref.extractall(vipsbin)
            else:
                print("Error downloading the file.")

        # Install pyvips
        subprocess.check_call([sys.executable, "-m", "pip", "install", "pyvips"])

    # Python >= 3.8 on Windows no longer resolves extension-module DLL
    # dependencies through PATH, so register the directory explicitly when
    # os.add_dll_directory exists (it fails for a missing directory, hence
    # the isdir guard). PATH is still extended for older interpreters and
    # for child processes.
    add_dll_dir = getattr(os, "add_dll_directory", None)
    if callable(add_dll_dir) and os.path.isdir(vipsbin):
        add_dll_dir(vipsbin)
    os.environ["PATH"] = os.pathsep.join((vipsbin, os.environ.get("PATH", "")))


import argparse
import json
import pathlib
import pickle
import re
import time
from builtins import range
from multiprocessing import Pool
from typing import TYPE_CHECKING

import geopandas as gpd
import matplotlib.patheffects as PathEffects
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import panel as pn
import scipy.stats as st
import skimage.filters.rank
import tissuumaps.jupyter as tj
import torch
from concave_hull import concave_hull_indexes
from joblib import Parallel, delayed
from matplotlib.patches import Patch
from pyFlowSOM import map_data_to_nodes, som
from scipy import stats
from scipy.spatial import Delaunay
from scipy.spatial.distance import cdist
from shapely.geometry import MultiPolygon
from skimage.io import imsave
from skimage.segmentation import find_boundaries
from sklearn.calibration import CalibratedClassifierCV
from sklearn.cluster import HDBSCAN, MiniBatchKMeans
from sklearn.cross_decomposition import CCA
from sklearn.metrics import classification_report, f1_score, pairwise_distances
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors
from sklearn.svm import SVC, LinearSVC
from tqdm import tqdm
from yellowbrick.cluster import KElbowVisualizer

if TYPE_CHECKING:
    from anndata import AnnData

from ..helperfunctions._general import *
from ..plotting._general import catplot

try:
    import cupy as cp
    import rapids_singlecell as rsc
    from cupyx.scipy.sparse import csc_matrix as csc_matrix_gpu
    from cupyx.scipy.sparse import csr_matrix as csr_matrix_gpu
    from cupyx.scipy.sparse import isspmatrix_csc as isspmatrix_csc_gpu
    from cupyx.scipy.sparse import isspmatrix_csr as isspmatrix_csr_gpu
    from scanpy.get import _get_obs_rep, _set_obs_rep
    from scipy.sparse import isspmatrix_csc as isspmatrix_csc_cpu
    from scipy.sparse import isspmatrix_csr as isspmatrix_csr_cpu
except ImportError:
    pass

# Tools
############################################################


def tl_calculate_neigh_combs(w, l, n_num, threshold=0.85, per_keep_thres=0.85):
    """
    Calculate neighborhood combinations based on a threshold.

    The columns ``l`` of ``w`` are divided by ``n_num`` and passed to
    ``hf_get_thresh_simps`` together with ``threshold`` to obtain one
    combination per row. Combination frequencies and their cumulative sums
    are computed and two line plots of the cumulative sums are shown.

    Parameters
    ----------
    w : DataFrame
        DataFrame containing the data. Modified in place: the columns
        ``"combination"`` (tuples of names taken from ``l``) and
        ``"combination_num"`` (tuples of positions into ``l``) are added.
    l : list
        List of column names to be used.
    n_num : int
        Number of neighborhoods or k chosen for the neighborhoods; used to
        normalize the values in ``w.loc[:, l]``.
    threshold : float, optional
        Threshold for neighborhood combinations, by default 0.85.
    per_keep_thres : float, optional
        Cumulative-frequency threshold: the combinations whose cumulative
        share stays below this value are reported and plotted separately,
        by default 0.85.

    Returns
    -------
    tuple
        A tuple containing:
        - simps: Series of neighborhood combinations.
        - simp_freqs: Series of normalized frequency counts of the
          combinations, most frequent first.
        - simp_sums: Series of cumulative sums of the frequency counts.
    """
    # need to normalize by number of neighborhoods or k chosen for the neighborhoods
    xm = w.loc[:, l].values / n_num

    # Get the neighborhood combinations based on the threshold
    simps = hf_get_thresh_simps(xm, threshold)
    simp_freqs = simps.value_counts(normalize=True)
    simp_sums = simp_freqs.cumsum()

    # simp_sums is non-decreasing, so this is the leading run of combinations
    # whose cumulative share stays below per_keep_thres.
    test_sums_thres = simp_sums[simp_sums < per_keep_thres]
    test_len = len(test_sums_thres)
    # Share of the first combination that crosses the threshold. Positional
    # access (.iloc) is required because the index holds combinations, not
    # positions; the guard avoids an IndexError when nothing crosses it.
    per_values_above = (
        simp_freqs.iloc[test_len] if test_len < len(simp_freqs) else 0.0
    )
    print(test_len, per_values_above)

    w["combination"] = [tuple(l[a] for a in s) for s in simps]
    w["combination_num"] = [tuple(s) for s in simps]

    # this shows what proportion (y) of the total cells are assigned to the top x combinations
    plt.figure(figsize=(7, 3))
    plt.plot(simp_sums.values)
    plt.title(
        "proportion (y) of the total cells are assigned to the top x combinations"
    )
    plt.show()

    # same curve, restricted to the combinations below per_keep_thres
    plt.figure(figsize=(7, 3))
    plt.plot(test_sums_thres.values)
    plt.title(
        "proportion (y) of the total cells are assigned to the top x combinations - thresholded"
    )
    plt.show()

    return (simps, simp_freqs, simp_sums)


def tl_build_graph_CN_comb_map(simp_freqs, thresh_freq=0.001):
    """
    Build a directed graph for the CN combination map.

    Only combinations with a frequency of at least ``thresh_freq`` are used.
    An edge ``e0 -> e1`` is added whenever ``e1`` contains every element of
    ``e0`` plus exactly one more.

    Parameters
    ----------
    simp_freqs : pandas.Series
        A series containing the frequencies of simplices, indexed by the
        simplices (iterables of CN identifiers).
    thresh_freq : float, optional
        The threshold frequency to filter simplices, by default 0.001.

    Returns
    -------
    tuple
        A tuple containing:
        - g : networkx.DiGraph
            The directed graph with edges representing the CN combination map.
        - tops : list
            A list of the top 20 simplices sorted by frequency.
        - e0 : tuple or None
            The last selected simplex (kept for backward compatibility), or
            None when no simplex passes ``thresh_freq``.
        - e1 : tuple or None
            Same value as ``e0``.
    """
    g = nx.DiGraph()

    frequent = simp_freqs[simp_freqs >= thresh_freq]
    selected_simps = frequent.index.values

    # Bucket combinations by size so each combination is only compared with
    # combinations exactly one element larger. Buckets keep the original
    # order, so edges are added in the same sequence as an all-pairs scan.
    by_size = {}
    for simp in selected_simps:
        by_size.setdefault(len(simp), []).append(simp)

    # this builds the graph for the CN combination map
    for e0 in selected_simps:
        parent = set(e0)
        for e1 in by_size.get(len(e0) + 1, ()):
            if parent < set(e1):
                g.add_edge(e0, e1)

    tops = frequent.sort_values(ascending=False).index.values.tolist()[:20]

    # The historical return values were the nested-loop variables, which
    # both ended on the last selected simplex; None when nothing was selected.
    last = selected_simps[-1] if len(selected_simps) else None
    return (g, tops, last, last)


def _flowsom_labels(adata, xdim, ydim, rlen, seed):
    """Train a FlowSOM map on ``adata.X`` and return each observation's node as a string ``pd.Categorical``."""
    matrix = adata.X
    # FlowSOM works on a dense array; densify sparse input instead of failing.
    som_input_arr = np.asarray(matrix.toarray() if hasattr(matrix, "toarray") else matrix)
    # train the SOM
    node_output = som(som_input_arr, xdim=xdim, ydim=ydim, rlen=rlen, seed=seed)
    # use trained SOM to assign clusters to each observation in your data
    clusters, _dists = map_data_to_nodes(node_output, som_input_arr)
    # string labels, consistent with leiden/louvain output
    return pd.Categorical(np.asarray(clusters).astype(str))


def clustering(
    adata,
    clustering="leiden",
    marker_list=None,
    resolution=1,
    n_neighbors=10,
    reclustering=False,
    key_added=None,
    key_filter=None,
    subset_cluster=None,
    seed=42,
    fs_xdim=10,
    fs_ydim=10,
    fs_rlen=10,  # FlowSOM parameters
    **cluster_kwargs,
):
    """
    Perform clustering on the given annotated data matrix.

    Parameters
    ----------
    adata : AnnData
        The annotated data matrix of shape n_obs x n_vars. Rows correspond
        to cells and columns to stained markers.
    clustering : str, optional
        The clustering algorithm to use. Options are "leiden", "louvain",
        "leiden_gpu" or "flowSOM". Defaults to "leiden". "leiden_gpu" falls
        back to "leiden" when rapids_singlecell cannot be imported.
    marker_list : list, optional
        A list of markers for clustering. Markers missing from
        ``adata.var_names`` are dropped (order preserved). Defaults to None.
    resolution : float, optional
        The resolution for the clustering algorithm. Defaults to 1.
    n_neighbors : int, optional
        The number of neighbors to use for the neighbors graph. Defaults to 10.
    reclustering : bool, optional
        If set to True, the function will skip the calculation of neighbors
        and UMAP. This can be used to speed up the process when just
        reclustering or running flowSOM.
    key_added : str, optional
        The key name to add to the adata object. Defaults to
        ``clustering + "_" + str(resolution)``.
    key_filter : str, optional
        The ``obs`` column used to select the cells to (sub)cluster.
        Defaults to None.
    subset_cluster : list or str, optional
        The values of ``key_filter`` to subset to; required when
        ``key_filter`` is given. A single string is treated as a
        one-element list. Defaults to None.
    seed : int, optional
        Seed for random state. Default is 42.
    fs_xdim : int, optional
        X dimension for FlowSOM. Default is 10.
    fs_ydim : int, optional
        Y dimension for FlowSOM. Default is 10.
    fs_rlen : int, optional
        Rlen for FlowSOM. Default is 10.
    **cluster_kwargs : dict
        Additional keyword arguments for the leiden/louvain function.

    Returns
    -------
    AnnData
        The annotated data matrix with the clustering results added. When
        ``key_filter`` is given, the full input object is returned: cells in
        the subset carry their new label in ``obs[key_added]`` and all other
        cells carry their ``key_filter`` label (both as strings).

    Raises
    ------
    ValueError
        If ``clustering`` is not a supported option, or ``key_filter`` is
        given without ``subset_cluster``.
    """
    if clustering not in ["leiden", "louvain", "leiden_gpu", "flowSOM"]:
        raise ValueError(
            "Invalid clustering options. Please select from leiden, louvain, "
            "leiden_gpu or flowSOM! "
            "For GPU accelerated leiden clustering, please use leiden_gpu"
        )

    # test if rapids_singlecell is available
    if clustering == "leiden_gpu":
        try:
            import cudf  # noqa: F401  (availability check only)
            import cuml  # noqa: F401
            import cupy  # noqa: F401
            import rapids_singlecell as rsc
        except ImportError:
            print("Please install rapids_singlecell to use leiden_gpu!")
            print("install_gpu_leiden(CUDA = your cuda version as string)")
            print("For example: sp.tl.install_gpu_leiden(CUDA = '12')")
            print("THIS FUNCTION DOES NOT WORK ON MacOS OR WINDOWS")
            print("using leiden instead of leiden_gpu")
            clustering = "leiden"

    if key_added is None:
        key_added = clustering + "_" + str(resolution)

    # Keep a handle on the full object; `adata` may be narrowed below.
    adata_tmp = adata
    subset_mask = None
    if key_filter is not None:
        if subset_cluster is None:
            raise ValueError("Please provide subset_cluster!")
        if isinstance(subset_cluster, str):
            subset_cluster = [subset_cluster]
        subset_mask = adata.obs[key_filter].isin(subset_cluster).to_numpy()
        adata = adata[subset_mask]

    # input a list of markers for clustering
    if marker_list is not None:
        var_names = set(adata.var_names)
        if any(m not in var_names for m in marker_list):
            print("Marker list not all in adata var_names! Using intersection instead!")
            # keep the caller's order (a set intersection is nondeterministic)
            marker_list = [m for m in dict.fromkeys(marker_list) if m in var_names]
            print("New marker_list: " + " ".join(marker_list))
        adata = adata[:, marker_list]

    # Materialize slices once instead of letting each tool implicitly
    # convert the view (with warnings).
    if adata is not adata_tmp:
        adata = adata.copy()

    if clustering == "leiden_gpu":
        anndata_to_GPU(adata)  # moves `.X` to the GPU
        if not reclustering:
            print("Computing neighbors and UMAP on GPU")
            rsc.pp.neighbors(adata, n_neighbors=n_neighbors)
            # UMAP computation
            rsc.tl.umap(adata)
        print("Clustering on GPU")
        rsc.tl.leiden(
            adata,
            resolution=resolution,
            key_added=key_added,
            random_state=seed,
            **cluster_kwargs,
        )
        anndata_to_CPU(adata)  # moves `.X` to the CPU
    else:
        if not reclustering:
            # Compute the neighborhood relations of single cells the range 2 to 100 and usually 10
            print("Computing neighbors and UMAP")
            print("- neighbors")
            sc.pp.neighbors(adata, n_neighbors=n_neighbors)
            # UMAP computation
            print("- UMAP")
            sc.tl.umap(adata)
        print("Clustering")
        if clustering == "leiden":
            print("Leiden clustering")
            sc.tl.leiden(
                adata,
                resolution=resolution,
                key_added=key_added,
                random_state=seed,
                **cluster_kwargs,
            )
        elif clustering == "louvain":
            print("Louvain clustering")
            sc.tl.louvain(
                adata,
                resolution=resolution,
                key_added=key_added,
                random_state=seed,
                **cluster_kwargs,
            )
        else:
            print("FlowSOM clustering")
            adata.obs[key_added] = _flowsom_labels(
                adata, fs_xdim, fs_ydim, fs_rlen, seed
            )

    if key_filter is None:
        if marker_list is None:
            return adata
        # Only markers were subset: copy the results back onto the full object.
        adata_tmp.obs[key_added] = adata.obs[key_added].values
        # append other data
        adata_tmp.obsm = adata.obsm
        adata_tmp.obsp = adata.obsp
        adata_tmp.uns = adata.uns
        return adata_tmp

    # Sub-clustering: the subset rows are exactly `subset_mask` (in order),
    # so new labels are written back positionally. Cells outside the subset
    # keep their key_filter label. Both columns become string categoricals;
    # "nan" strings (from missing values) are restored to NaN.
    obs = adata_tmp.obs.copy()
    filter_labels = obs[key_filter].astype(str).replace("nan", np.nan)
    combined = filter_labels.to_numpy(dtype=object, copy=True)
    combined[subset_mask] = adata.obs[key_added].astype(str).to_numpy(dtype=object)
    obs[key_filter] = filter_labels.astype("category")
    obs[key_added] = pd.Categorical(combined)
    obs.index = obs.index.astype(str)
    # assign df as obs for adata_tmp
    adata_tmp.obs = obs
    return adata_tmp
def neighborhood_analysis(
    adata,
    unique_region,
    cluster_col,
    X="x",
    Y="y",
    k=35,
    n_neighborhoods=30,
    elbow=False,
    metric="distortion",
):
    """
    Compute cellular neighborhoods (CNs).

    For every cell, the ``k`` neighbors returned by ``hf_get_windows``
    (computed per region) are looked up. The number of neighbors of each
    ``cluster_col`` category is counted, and these count vectors are
    clustered with MiniBatchKMeans into ``n_neighborhoods`` CNs.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix.
    unique_region : str
        Each region is one independent CODEX image.
    cluster_col : str
        Columns to compute CNs on, typically 'celltype'.
    X : str, optional
        X coordinate column name, by default "x".
    Y : str, optional
        Y coordinate column name, by default "y".
    k : int, optional
        Number of neighbors to compute, by default 35.
    n_neighborhoods : int, optional
        Number of neighborhoods one ends up with, by default 30.
    elbow : bool, optional
        If True, additionally fit a yellowbrick ``KElbowVisualizer`` over
        ``k=n_neighborhoods`` and show the elbow plot to help choose the
        number of clusters. By default False.
    metric : str, optional
        Scoring metric for the elbow plot, by default "distortion". Other
        options include "silhouette" and "calinski_harabasz".

    Returns
    -------
    AnnData
        The input object, modified in place and returned:
        ``adata.obs["CN_k{k}_n{n_neighborhoods}"]`` holds the CN label of
        each cell and ``adata.uns["Centroid_k{k}_n{n_neighborhoods}"]`` holds
        ``{str(k): cluster_centers}``.
    """
    df = pd.DataFrame(adata.obs[[X, Y, cluster_col, unique_region]])
    cells = pd.concat([df, pd.get_dummies(df[cluster_col])], axis=1)
    sum_cols = cells[cluster_col].unique()
    # one-hot matrix in adata.obs row order
    values = cells[sum_cols].values

    neighborhood_name = "CN" + "_k" + str(k) + "_n" + str(n_neighborhoods)
    centroids_name = "Centroid" + "_k" + str(k) + "_n" + str(n_neighborhoods)

    cells[unique_region] = cells[unique_region].astype("str")
    # Positional index so per-region windows can be realigned to row order.
    cells = cells.reset_index(drop=True)

    # Get each region
    tissue_group = cells[[X, Y, unique_region]].groupby(unique_region)
    exps = list(cells[unique_region].unique())
    exp_pos = {exp: pos for pos, exp in enumerate(exps)}
    tissue_chunks = [
        (time.time(), exp_pos[t], t, a)
        for t, indices in tissue_group.groups.items()
        for a in np.array_split(indices, 1)
    ]
    tissues = [
        hf_get_windows(job, k, exps=exps, tissue_group=tissue_group, X=X, Y=Y)
        for job in tissue_chunks
    ]

    # Sum the one-hot rows of the first k neighbors of every cell.
    frames = []
    for neighbors, job in zip(tissues, tissue_chunks):
        counts = values[neighbors[:, :k]].sum(axis=1)
        frames.append(
            pd.DataFrame(
                counts.astype(np.float16),
                index=job[3].astype(int),
                columns=sum_cols,
            )
        )
    window = pd.concat(frames, axis=0).loc[cells.index.values]
    features = window[sum_cols].values

    # The stored labels always use n_neighborhoods clusters, matching the
    # name under which they are stored.
    km = MiniBatchKMeans(n_clusters=n_neighborhoods, random_state=0)
    labels = km.fit_predict(features)
    adata.obs[neighborhood_name] = labels
    adata.uns[centroids_name] = {str(k): km.cluster_centers_}

    if elbow:
        visualizer = KElbowVisualizer(
            MiniBatchKMeans(random_state=0),
            k=n_neighborhoods,
            timings=False,
            metric=metric,
        )
        visualizer.fit(features)  # Fit the data to the visualizer
        visualizer.show()  # Finalize and render the figure

    return adata
def build_cn_map( adata, cn_col, unique_region, palette=None, k=75, X="x", Y="y", threshold=0.85, per_keep_thres=0.85, sub_list=None, sub_col=None, rand_seed=1, ): """ Generate a cellular neighborhood (CN) map. Parameters ---------- adata : AnnData Annotated data matrix. cn_col : str Column name for cellular neighborhood. unique_region : str Unique region identifier. palette : dict, optional Color palette for the CN map, by default None. k : int, optional Number of neighbors to compute, by default 75. X : str, optional X coordinate column name, by default "x". Y : str, optional Y coordinate column name, by default "y". threshold : float, optional Threshold for neighborhood computation, by default 0.85. per_keep_thres : float, optional Threshold for keeping percentage, by default 0.85. sub_list : list, optional List of sub regions, by default None. sub_col : str, optional Column name for sub regions, by default None. rand_seed : int, optional Random seed for color generation, by default 1. Returns ------- dict Dictionary containing the graph, top nodes, edges and simplicial frequencies. 
""" ks = [k] cells_df = pd.DataFrame(adata.obs) cells_df = cells_df[[X, Y, unique_region, cn_col]] cells_df.reset_index(inplace=True) sum_cols = cells_df[cn_col].unique() keep_cols = cells_df.columns cn_colors = hf_generate_random_colors( len(adata.obs[cn_col].unique()), rand_seed=rand_seed ) if palette is None: if cn_col + "_colors" not in adata.uns.keys(): palette = dict(zip(np.sort(adata.obs[cn_col].unique()), cn_colors)) adata.uns[cn_col + "_colors"] = cn_colors else: palette = dict( zip(np.sort(adata.obs[cn_col].unique()), adata.uns[cn_col + "_colors"]) ) Neigh = Neighborhoods( cells_df, ks, cn_col, sum_cols, keep_cols, X, Y, reg=unique_region, add_dummies=True, ) windows = Neigh.k_windows() w = windows[k] if sub_list: # convert sub_list to list if only str is provided if isinstance(sub_list, str): sub_list = [sub_list] w = w[w[sub_col].isin(sub_list)] l = list(palette.keys()) simps, simp_freqs, simp_sums = tl_calculate_neigh_combs( w, l, k, threshold=threshold, per_keep_thres=per_keep_thres # color palette ) g, tops, e0, e1 = tl_build_graph_CN_comb_map(simp_freqs) return { "g": g, "tops": tops, "e0": e0, "e1": e1, "simp_freqs": simp_freqs, "w": w, "l": l, "k": k, "threshold": threshold, } def tl_format_for_squidpy(adata, x_col, y_col): """ Format an AnnData object for use with Squidpy. Parameters ---------- adata : AnnData Annotated data matrix. x_col : str Column name for x spatial coordinates. y_col : str Column name for y spatial coordinates. Returns ------- AnnData Annotated data matrix formatted for Squidpy, with spatial data in the 'obsm' attribute. 
""" # Validate input types if not isinstance(adata, ad.AnnData): raise TypeError("adata must be an AnnData object") if not isinstance(x_col, str) or not isinstance(y_col, str): raise TypeError("x_col and y_col must be strings") # Check if the columns exist in the 'obs' metadata if x_col not in adata.obs.columns or y_col not in adata.obs.columns: raise ValueError(f"Columns {x_col} and/or {y_col} not found in adata.obs") # Extract the count data from your original AnnData object counts = adata.X # Extract the spatial coordinates from the 'obs' metadata spatial_coordinates = adata.obs[[x_col, y_col]].values # Ensure spatial coordinates are numeric if not np.issubdtype(spatial_coordinates.dtype, np.number): raise ValueError("Spatial coordinates must be numeric") # Create a new AnnData object with the expected format new_adata = ad.AnnData(counts, obsm={"spatial": spatial_coordinates}) return new_adata def compute_triangulation_edges(df_input, x_pos, y_pos): """ Compute unique Delaunay triangulation edges from input coordinates. This function computes the Delaunay triangulation for the set of points defined by the x and y positions contained in a DataFrame. It then extracts all unique edges from the triangulation, calculates their Euclidean distances, and returns these as a new DataFrame. Parameters ---------- df_input : pandas.DataFrame DataFrame containing the coordinate data. x_pos : str The column name in df_input for the x-coordinate. y_pos : str The column name in df_input for the y-coordinate. Returns ------- pandas.DataFrame A DataFrame with columns: - ind1: Index of the first point in each edge. - ind2: Index of the second point in each edge. - x1: x-coordinate of the first point. - y1: y-coordinate of the first point. - x2: x-coordinate of the second point. - y2: y-coordinate of the second point. - distance: Euclidean distance between the two points. 
""" points = df_input[[x_pos, y_pos]].values tri = Delaunay(points) # Generate edges from triangles and remove duplicates edges = np.vstack( [tri.simplices[:, [0, 1]], tri.simplices[:, [1, 2]], tri.simplices[:, [2, 0]]] ) # Sort each edge so that [i, j] and [j, i] are considered the same edges = np.sort(edges, axis=1) # Remove duplicate edges edges = np.unique(edges, axis=0) # Vectorized distance computation x_coords = points[:, 0] y_coords = points[:, 1] ind1, ind2 = edges[:, 0], edges[:, 1] x1_arr, y1_arr = x_coords[ind1], y_coords[ind1] x2_arr, y2_arr = x_coords[ind2], y_coords[ind2] dist_arr = np.sqrt((x2_arr - x1_arr) ** 2 + (y2_arr - y1_arr) ** 2) edges_df = pd.DataFrame( { "ind1": ind1, "ind2": ind2, "x1": x1_arr, "y1": y1_arr, "x2": x2_arr, "y2": y2_arr, "distance": dist_arr, } ) return edges_df def annotate_triangulation_vectorized( edges_df, df_input, id_col, x_pos, y_pos, cell_type_col, region ): """ Annotate edges with cell metadata in a vectorized manner. This function takes the computed edges from the triangulation and annotates them with additional information retrieved from the input DataFrame. It creates both the forward and reverse (symmetrical) edges with cell identifiers, cell types, positions, and region info. Parameters ---------- edges_df : pandas.DataFrame DataFrame containing the triangulation edges and their distances. df_input : pandas.DataFrame DataFrame containing cell metadata. id_col : str The column name in df_input that serves as the cell identifier. x_pos : str The column name in df_input for the x-coordinate. y_pos : str The column name in df_input for the y-coordinate. cell_type_col : str The column name in df_input for cell type annotation. region : str The column name in df_input for region information. Returns ------- pandas.DataFrame A DataFrame containing annotated edges with the following columns: - region: The region identifier. - celltype1_index, celltype1, celltype1_X, celltype1_Y: Information for the first cell. 
- celltype2_index, celltype2, celltype2_X, celltype2_Y: Information for the second cell. - distance: The Euclidean distance between the two cells. """ if len(df_input[region].unique()) == 1: region_val = df_input[region].iloc[0] else: # In case of multiple regions, use the first region as annotation. region_val = df_input[region].iloc[0] # Convert needed columns to arrays for fast indexing id_array = df_input[id_col].values ct_array = df_input[cell_type_col].values x_array = df_input[x_pos].values y_array = df_input[y_pos].values # Build references from edges DataFrame ind1 = edges_df["ind1"].values ind2 = edges_df["ind2"].values x1_arr = edges_df["x1"].values y1_arr = edges_df["y1"].values x2_arr = edges_df["x2"].values y2_arr = edges_df["y2"].values dist_arr = edges_df["distance"].values # Create direct "forward" annotated DataFrame data_forward = pd.DataFrame( { region: [region_val] * len(ind1), "celltype1_index": id_array[ind1], "celltype1": ct_array[ind1], "celltype1_X": x1_arr, "celltype1_Y": y1_arr, "celltype2_index": id_array[ind2], "celltype2": ct_array[ind2], "celltype2_X": x2_arr, "celltype2_Y": y2_arr, "distance": dist_arr, } ) # Create symmetrical (reverse) annotated DataFrame data_reverse = pd.DataFrame( { region: [region_val] * len(ind1), "celltype1_index": id_array[ind2], "celltype1": ct_array[ind2], "celltype1_X": x2_arr, "celltype1_Y": y2_arr, "celltype2_index": id_array[ind1], "celltype2": ct_array[ind1], "celltype2_X": x1_arr, "celltype2_Y": y1_arr, "distance": dist_arr, } ) # Concatenate forward and reverse dataframes annotated_result = pd.concat([data_forward, data_reverse], ignore_index=True) annotated_result = annotated_result[ [ region, "celltype1_index", "celltype1", "celltype1_X", "celltype1_Y", "celltype2_index", "celltype2", "celltype2_X", "celltype2_Y", "distance", ] ] return annotated_result def calculate_triangulation_distances(df_input, id, x_pos, y_pos, cell_type, region): """ Calculate and annotate triangulation distances for 
cells. This function computes the triangulation edges for input cell data and then annotates them with cell metadata. It serves as a wrapper combining both steps into one process. Parameters ---------- df_input : pandas.DataFrame DataFrame containing the cell data. id : str Column name for cell identifiers. x_pos : str Column name for the x-coordinate. y_pos : str Column name for the y-coordinate. cell_type : str Column name for cell type information. region : str Column name for region information. Returns ------- pandas.DataFrame Annotated DataFrame with triangulation edges and metadata. """ edges_df = compute_triangulation_edges(df_input, x_pos, y_pos) annotated_result = annotate_triangulation_vectorized( edges_df, df_input, id, x_pos, y_pos, cell_type, region ) return annotated_result def process_region(df, unique_region, id, x_pos, y_pos, cell_type, region): """ Process triangulation distances for a specific region. This function subsets the dataframe to one specific region, adds unique identifier columns, and calculates the triangulation distances for that region. Parameters ---------- df : pandas.DataFrame The full dataset containing cell information. unique_region : str The specific region to process. id : str Column name for cell identifiers. x_pos : str Column name for x-coordinate. y_pos : str Column name for y-coordinate. cell_type : str Column name for cell type information. region : str Column name for region information. Returns ------- pandas.DataFrame Annotated DataFrame with triangulation distances for the specified region. 
""" subset = df[df[region] == unique_region].copy() subset["uniqueID"] = ( subset[id].astype(str) + "-" + subset[x_pos].astype(str) + "-" + subset[y_pos].astype(str) ) subset["XYcellID"] = subset[x_pos].astype(str) + "_" + subset[y_pos].astype(str) result = calculate_triangulation_distances( df_input=subset, id=id, x_pos=x_pos, y_pos=y_pos, cell_type=cell_type, region=region, ) return result def get_triangulation_distances( df_input, id, x_pos, y_pos, cell_type, region, num_cores=None, correct_dtype=True ): """ Compute triangulation distances for each unique region with parallel processing. This function processes the input DataFrame by first ensuring datatype consistency (optionally converting coordinate values to integers), and then computes triangulation distances per region in parallel using half of the available CPU cores (by default). Parameters ---------- df_input : pandas.DataFrame DataFrame containing cell data including coordinates, cell types, and region info. id : str Column name for cell identifiers. x_pos : str Column name for the x-coordinate. y_pos : str Column name for the y-coordinate. cell_type : str Column name for cell type information. region : str Column name for region information. num_cores : int, optional Number of CPU cores to use for parallel processing. If None, defaults to half of os.cpu_count(). correct_dtype : bool, optional Flag to convert columns to proper data types. Defaults to True. Returns ------- pandas.DataFrame A concatenated DataFrame with triangulation distances computed for all regions. """ if correct_dtype: df_input[cell_type] = df_input[cell_type].astype(str) df_input[region] = df_input[region].astype(str) if not issubclass(df_input[x_pos].dtype.type, np.integer): print("This function expects integer values for xy coordinates.") print( x_pos + " and " + y_pos + " will be changed to integer. Please check the generated output!" 
) df_input[x_pos] = df_input[x_pos].astype(int).values df_input[y_pos] = df_input[y_pos].astype(int).values unique_regions = df_input[region].unique() df_input = df_input.loc[:, [id, x_pos, y_pos, cell_type, region]] if num_cores is None: num_cores = os.cpu_count() // 2 # Parallelize region processing results = Parallel(n_jobs=num_cores)( delayed(process_region)(df_input, reg, id, x_pos, y_pos, cell_type, region) for reg in unique_regions ) triangulation_distances = pd.concat(results) return triangulation_distances def shuffle_annotations(df_input, cell_type, region, permutation): """ Shuffle cell type annotations within each region. This function randomizes the cell type annotations of the input DataFrame on a per-region basis using a pseudo-random permutation seed. Parameters ---------- df_input : pandas.DataFrame DataFrame containing cell data. cell_type : str Column name for cell type information. region : str Column name for region information. permutation : int An integer used to seed the random number generator for reproducible shuffling. Returns ------- pandas.DataFrame A copy of df_input with an added column "random_annotations" representing the shuffled cell types. """ np.random.seed(permutation + 1234) df_shuffled = df_input.copy() for region_name in df_shuffled[region].unique(): region_mask = df_shuffled[region] == region_name shuffled_values = df_shuffled.loc[region_mask, cell_type].sample(frac=1).values df_shuffled.loc[region_mask, "random_annotations"] = shuffled_values return df_shuffled def _process_region_iterations( subset, edges_df, id_col, x_col, y_col, cell_type_col, region_col, region_val, num_iterations, ): """ Process multiple iterations of permutation for a given region. This helper function takes a subset of the data and precomputed triangulation edges and performs a series of iterations where cell type annotations are shuffled and the mean distances are computed for each permutation. 
Parameters ---------- subset : pandas.DataFrame DataFrame containing a subset of cell data for a single region. edges_df : pandas.DataFrame Precomputed triangulation edges for the subset. id_col : str Column name for cell identifiers. x_col : str Column name for the x-coordinate. y_col : str Column name for the y-coordinate. cell_type_col : str Column name for cell type or annotation to be shuffled. region_col : str Column name for region information. region_val : str The specific region value being processed. num_iterations : int Number of permutation iterations to perform. Returns ------- pandas.DataFrame A DataFrame concatenating the mean distance summaries for each iteration. """ results_list = [] for iteration in range(1, num_iterations + 1): shuffled = shuffle_annotations(subset, cell_type_col, region_col, iteration) annotated_df = annotate_triangulation_vectorized( edges_df, shuffled, id_col, x_col, y_col, "random_annotations", region_col ) per_cell_summary = ( annotated_df.groupby(["celltype1_index", "celltype1", "celltype2"]) .distance.mean() .reset_index(name="per_cell_mean_dist") ) per_celltype_summary = ( per_cell_summary.groupby(["celltype1", "celltype2"]) .per_cell_mean_dist.mean() .reset_index(name="mean_dist") ) per_celltype_summary[region_col] = region_val per_celltype_summary["iteration"] = iteration results_list.append(per_celltype_summary) return pd.concat(results_list, ignore_index=True) def tl_iterate_tri_distances( df_input, id, x_pos, y_pos, cell_type, region, num_cores=None, num_iterations=1000 ): """ Perform iterative permutation analysis for triangulation distances. This function iterates over each unique region to calculate permutation-based triangulation distance summaries using precomputed edges. It applies parallel processing to perform multiple iterations efficiently. Parameters ---------- df_input : pandas.DataFrame DataFrame containing the cell information. id : str Column name for cell identifiers. 
x_pos : str Column name for the x-coordinate. y_pos : str Column name for the y-coordinate. cell_type : str Column name for cell type information. region : str Column name for region information. num_cores : int, optional Number of CPU cores to use for parallelization. Defaults to half of os.cpu_count() if None. num_iterations : int, optional Number of permutation iterations to perform. Defaults to 1000. Returns ------- pandas.DataFrame A concatenated DataFrame with permutation-based mean distances for each region. """ unique_regions = df_input[region].unique() df_input = df_input[[id, x_pos, y_pos, cell_type, region]] # Precompute triangulation edges for each region region2df = {} region2edges_df = {} for reg_name in unique_regions: subset = df_input[df_input[region] == reg_name].copy() subset["uniqueID"] = ( subset[id].astype(str) + "-" + subset[x_pos].astype(str) + "-" + subset[y_pos].astype(str) ) subset["XYcellID"] = subset[x_pos].astype(str) + "_" + subset[y_pos].astype(str) edges_df = compute_triangulation_edges(subset, x_pos, y_pos) region2df[reg_name] = subset region2edges_df[reg_name] = edges_df def process_one_region(r): subset = region2df[r] edges_df = region2edges_df[r] return _process_region_iterations( subset, edges_df, id, x_pos, y_pos, cell_type, region, r, num_iterations ) results_per_region = Parallel( n_jobs=num_cores if num_cores is not None else os.cpu_count() // 2 )(delayed(process_one_region)(r) for r in unique_regions) iterative_triangulation_distances = pd.concat(results_per_region, ignore_index=True) return iterative_triangulation_distances def add_missing_columns( triangulation_distances, metadata, shared_column="unique_region" ): """ Add missing metadata columns to the triangulation distances DataFrame. This function compares the metadata DataFrame with the triangulation distances DataFrame and adds any columns from the metadata that are not present. It uses the shared_column to map values and fills any missing values with "Unknown". 
Parameters ---------- triangulation_distances : pandas.DataFrame DataFrame containing triangulation distances and possibly missing metadata columns. metadata : pandas.DataFrame DataFrame containing additional metadata including the shared column. shared_column : str, optional Column name that is common to both DataFrames, by default "unique_region". Returns ------- pandas.DataFrame The updated triangulation distances DataFrame with added metadata columns. """ missing_columns = set(metadata.columns) - set(triangulation_distances.columns) for column in missing_columns: triangulation_distances[column] = pd.NA region_to_tissue = pd.Series( metadata[column].values, index=metadata["unique_region"] ).to_dict() triangulation_distances[column] = triangulation_distances["unique_region"].map( region_to_tissue ) triangulation_distances[column].fillna("Unknown", inplace=True) return triangulation_distances def calculate_pvalue(row): """ Calculate the p-value using the Mann-Whitney U test. For a given row containing expected and observed lists of distances, this function computes the p-value from the Mann-Whitney U test comparing the two distributions. If the test fails, a NaN is returned. Parameters ---------- row : pandas.Series A row containing "expected" and "observed" distance lists. Returns ------- float The p-value computed from the Mann-Whitney U test, or NaN if computation fails. """ try: return st.mannwhitneyu( row["expected"], row["observed"], alternative="two-sided" ).pvalue except ValueError: return np.nan
def identify_interactions(
    adata,
    cellid,
    x_pos,
    y_pos,
    cell_type,
    region,
    comparison,
    min_observed=10,
    distance_threshold=128,
    num_cores=None,
    num_iterations=1000,
    key_name=None,
    correct_dtype=False,
    aggregate_per_cell=True,
):
    """
    Identify significant cell-cell interactions based on spatial distances.

    Observed triangulation distances are computed per region, and expected
    distances are obtained by permuting cell type labels. For each pair of cell
    types (and each value of ``comparison``), observed and expected distance
    distributions are compared with a two-sided Mann-Whitney U test. A log2
    fold change of observed over expected mean distance is also reported.

    Parameters
    ----------
    adata : AnnData
        Annotated data object that holds cell observation data (adata.obs).
    cellid : str
        Column name to be used as the unique cell identifier. If absent from
        adata.obs, the obs index is used instead.
    x_pos : str
        Column name for the x-coordinate.
    y_pos : str
        Column name for the y-coordinate.
    cell_type : str
        Column name for cell type information.
    region : str
        Column name for region information.
    comparison : str
        Column name used to compare different conditions.
    min_observed : int, optional
        Groups need more than this many observed/expected values to be kept
        (default: 10).
    distance_threshold : int, optional
        Distances (or permuted mean distances) above this value are ignored
        (default: 128).
    num_cores : int, optional
        Number of CPU cores to use for parallel processing. Defaults to half of
        available cores if None.
    num_iterations : int, optional
        The number of permutation iterations for generating expected distances
        (default: 1000).
    key_name : str, optional
        Key under which the observed triangulation distances are stored in
        adata.uns. If None, defaults to "triDist".
    correct_dtype : bool, optional
        Forwarded to ``get_triangulation_distances``: cast cell type/region to
        str and coordinates to int (default: False).
    aggregate_per_cell : bool, optional
        If True, distances are first averaged per source cell
        (``celltype1_index``). Otherwise they are averaged per source/target
        cell pair (default: True).

    Returns
    -------
    tuple
        - distance_pvals (pandas.DataFrame): p-values and log-fold changes for
          each pair of cell types.
        - triangulation_distances_dict (dict): keys "distance_pvals",
          "triangulation_distances_observed" (observed distances with metadata)
          and "triangulation_distances_iterated" (permutation results with
          metadata).
    """
    df_input = pd.DataFrame(adata.obs)
    if cellid in df_input.columns:
        df_input.index = df_input[cellid]
    else:
        print(cellid + " is not in the adata.obs, use index as cellid instead!")
        df_input[cellid] = df_input.index

    df_input[cell_type] = df_input[cell_type].astype(str)
    df_input[region] = df_input[region].astype(str)

    print("Computing for observed distances between cell types!")
    triangulation_distances = get_triangulation_distances(
        df_input=df_input,
        id=cellid,
        x_pos=x_pos,
        y_pos=y_pos,
        cell_type=cell_type,
        region=region,
        num_cores=num_cores,
        correct_dtype=correct_dtype,
    )

    triDist_keyname = "triDist" if key_name is None else key_name
    adata.uns[triDist_keyname] = triangulation_distances
    print("Save triangulation distances output to anndata.uns " + triDist_keyname)

    print("Permuting data labels to obtain the randomly distributed distances!")
    print("this step can take awhile")
    iterative_triangulation_distances = tl_iterate_tri_distances(
        df_input=df_input,
        id=cellid,
        x_pos=x_pos,
        y_pos=y_pos,
        cell_type=cell_type,
        region=region,
        num_cores=num_cores,
        num_iterations=num_iterations,
    )

    # Metadata keyed by the region column, used to attach the comparison column
    # to both result tables (dict.fromkeys avoids duplicate columns).
    metadata = df_input.loc[:, list(dict.fromkeys([region, comparison]))].copy()

    # Reformat observed dataset
    triangulation_distances_long = add_missing_columns(
        triangulation_distances, metadata, shared_column=region
    )
    if aggregate_per_cell:
        per_cell_keys = ["celltype1_index", "celltype1", "celltype2", comparison, region]
    else:
        per_cell_keys = [
            "celltype1_index",
            "celltype2_index",
            "celltype1",
            "celltype2",
            comparison,
            region,
        ]
    observed_distances = (
        triangulation_distances_long.query("distance <= @distance_threshold")
        .groupby(per_cell_keys)
        .agg(mean_per_cell=("distance", "mean"))
        .reset_index()
        .groupby(["celltype1", "celltype2", comparison])
        .agg(
            observed=("mean_per_cell", list),
            observed_mean=("mean_per_cell", "mean"),
        )
        .reset_index()
    )

    # Reformat expected dataset
    iterated_triangulation_distances_long = add_missing_columns(
        iterative_triangulation_distances, metadata, shared_column=region
    )
    expected_distances = (
        iterated_triangulation_distances_long.query("mean_dist <= @distance_threshold")
        .groupby(["celltype1", "celltype2", comparison])
        .agg(expected=("mean_dist", list), expected_mean=("mean_dist", "mean"))
        .reset_index()
    )

    # Drop groups with too few values for a meaningful test.
    observed_distances["keep"] = observed_distances["observed"].apply(
        lambda x: len(x) > min_observed
    )
    observed_distances = observed_distances[observed_distances["keep"]]
    expected_distances["keep"] = expected_distances["expected"].apply(
        lambda x: len(x) > min_observed
    )
    expected_distances = expected_distances[expected_distances["keep"]]

    distance_pvals = expected_distances.merge(
        observed_distances, on=["celltype1", "celltype2", comparison], how="left"
    )
    distance_pvals["pvalue"] = distance_pvals.apply(calculate_pvalue, axis=1)
    distance_pvals["logfold_group"] = np.log2(
        distance_pvals["observed_mean"] / distance_pvals["expected_mean"]
    )
    distance_pvals["interaction"] = (
        distance_pvals["celltype1"] + " --> " + distance_pvals["celltype2"]
    )

    # Collect final results (observed and iterated tables under matching keys).
    triangulation_distances_dict = {
        "distance_pvals": distance_pvals,
        "triangulation_distances_observed": triangulation_distances_long,
        "triangulation_distances_iterated": iterated_triangulation_distances_long,
    }

    return distance_pvals, triangulation_distances_dict
def adata_cell_percentages(adata, column_percentage="cell_type"):
    """
    Compute the share of cells in each category of an ``adata.obs`` column.

    Parameters
    ----------
    adata : AnnData
        Object whose ``obs`` table holds the annotation. ``len(adata)`` is used
        as the total cell count.
    column_percentage : str, optional
        Name of the ``adata.obs`` column to tabulate. Defaults to "cell_type".

    Returns
    -------
    pandas.DataFrame
        Two columns: ``column_percentage`` (each category, in the order returned
        by ``value_counts``, i.e. by descending count) and "percentage" (its
        share of all cells, in percent).
    """
    n_cells = len(adata)
    counts = adata.obs[column_percentage].value_counts()
    shares = (counts / n_cells) * 100
    return pd.DataFrame({column_percentage: counts.index, "percentage": shares.values})
def filter_interactions(
    distance_pvals, pvalue=0.05, logfold_group_abs=0.1, comparison="condition"
):
    """
    Select significant heterotypic interactions and tabulate their log fold changes.

    A cell-type pair is considered significant if, in at least one row, it has a
    p-value below ``pvalue``, distinct cell types, a non-missing observed mean
    and an absolute log fold change above ``logfold_group_abs``. All rows of
    such pairs (across every ``comparison`` value) are returned.

    Parameters
    ----------
    distance_pvals : pandas.DataFrame
        Output of ``identify_interactions``. Its "logfold_group_abs" and "pairs"
        columns are added/overwritten in place.
    pvalue : float, optional
        Maximum p-value for significance. Defaults to 0.05.
    logfold_group_abs : float, optional
        Minimum absolute log fold change. Defaults to 0.1.
    comparison : str, optional
        Column whose values become the columns of ``dist_table``. Defaults to
        "condition".

    Returns
    -------
    dist_table : pandas.DataFrame
        Log fold changes indexed by pair with one column per comparison value;
        pairs with a missing value in any column are dropped.
    distance_pvals_sig_sub : pandas.DataFrame
        Rows of ``distance_pvals`` (with a non-null "interaction") that belong
        to a significant pair.
    """
    distance_pvals["logfold_group_abs"] = distance_pvals["logfold_group"].abs()
    distance_pvals["pairs"] = (
        distance_pvals["celltype1"] + "_" + distance_pvals["celltype2"]
    )

    is_significant = distance_pvals["pvalue"] < pvalue
    is_heterotypic = distance_pvals["celltype1"] != distance_pvals["celltype2"]
    has_observed = distance_pvals["observed_mean"].notna()
    is_large_effect = distance_pvals["logfold_group_abs"] > logfold_group_abs
    significant_pairs = distance_pvals.loc[
        is_significant & is_heterotypic & has_observed & is_large_effect, "pairs"
    ].unique()

    annotated = distance_pvals[distance_pvals["interaction"].notna()]
    distance_pvals_sig_sub = annotated[annotated["pairs"].isin(significant_pairs)]

    # One row per pair, one column per comparison value.
    dist_table = (
        distance_pvals_sig_sub[[comparison, "logfold_group", "pairs"]]
        .set_index("pairs")
        .pivot(columns=comparison, values="logfold_group")
        .dropna()
    )
    return dist_table, distance_pvals_sig_sub
def remove_rare_cell_types(
    adata, distance_pvals, cell_type_column="cell_type", min_cell_type_percentage=1
):
    """
    Drop interactions that involve rare cell types.

    A cell type is rare when it makes up less than ``min_cell_type_percentage``
    percent of all cells in ``adata``. The rare types are printed, and every row
    of ``distance_pvals`` whose "celltype1" or "celltype2" is rare is removed.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix.
    distance_pvals : DataFrame
        DataFrame containing distance p-values with columns 'celltype1' and 'celltype2'.
    cell_type_column : str, optional
        Column name in adata.obs containing cell type information, by default "cell_type".
    min_cell_type_percentage : float, optional
        Minimum percentage threshold for cell types to be retained, by default 1.

    Returns
    -------
    DataFrame
        Filtered distance_pvals DataFrame with rare cell types removed.
    """
    percentages = adata_cell_percentages(adata, column_percentage=cell_type_column)
    below_threshold = percentages["percentage"] < min_cell_type_percentage
    rare_cell_types = percentages.loc[below_threshold, cell_type_column].values

    print(
        f"Cell types that belong to less than {min_cell_type_percentage}% of total cells:"
    )
    print(rare_cell_types)

    involves_rare = distance_pvals["celltype1"].isin(rare_cell_types) | distance_pvals[
        "celltype2"
    ].isin(rare_cell_types)
    return distance_pvals[~involves_rare]
def stellar_get_edge_index(
    pos,
    distance_thres,
    max_memory_usage=1.6e10,
    chunk_size=5000,
):
    """
    Build the edge list of cells closer than ``distance_thres`` to each other.

    Parameters
    ----------
    pos : array-like of shape (n_samples, n_features)
        Cell positions.
    distance_thres : float
        Pairs whose distance (``pairwise_distances``) is strictly below this
        value become edges; self-pairs are excluded.
    max_memory_usage : float, optional
        If a dense float64 distance matrix (n_samples**2 * 8 bytes) would exceed
        this many bytes, distances are computed in row chunks instead.
    chunk_size : int, optional
        Number of rows per chunk in chunked mode.

    Returns
    -------
    list of list of int
        ``[i, j]`` pairs in row-major order. Because distances are symmetric,
        both ``[i, j]`` and ``[j, i]`` appear.
    """
    n_cells = pos.shape[0]

    if n_cells * n_cells * 8 <= max_memory_usage:
        adjacency = pairwise_distances(pos) < distance_thres
        np.fill_diagonal(adjacency, 0)
        return np.transpose(np.nonzero(adjacency)).tolist()

    print("Processing will be done in chunks to save memory.")
    edge_list = []
    for start in tqdm(range(0, n_cells, chunk_size), desc="Processing chunks"):
        block = pairwise_distances(pos[start : start + chunk_size], pos) < distance_thres
        # The square sub-block covering this chunk's own columns holds self-distances.
        np.fill_diagonal(block[:, start : start + chunk_size], 0)
        local_edges = np.transpose(np.nonzero(block)).tolist()
        edge_list.extend([start + row, col] for row, col in local_edges)
    return edge_list
def _stellar_edge_array(edges, n_nodes):
    """Return ``edges`` as an int array of shape (k, 2) without rows referencing an index >= n_nodes."""
    # reshape(-1, 2) also turns an empty list into a (0, 2) array, which can be
    # column-indexed safely (np.array([]) alone is 1-D).
    edge_array = np.asarray(edges, dtype=int).reshape(-1, 2)
    return edge_array[(edge_array < n_nodes).all(axis=1)]


def adata_stellar(
    adata_train,
    adata_unannotated,
    celltype_col="cell_type",
    region_column=None,
    x_col="x",
    y_col="y",
    sample_rate=0.5,
    distance_thres=50,
    epochs=50,
    num_seed_class=3,
    key_added="stellar_pred",
    STELLAR_path="",
    max_memory_usage=1.6e10,
    chunk_size=5000,
    wd=5e-2,
    lr=1e-3,
    seed=1,
    batch_size=1,
):
    """
    Apply the STELLAR algorithm to annotated and unannotated spatial single-cell data.

    Parameters
    ----------
    adata_train : AnnData
        The annotated single-cell data used for training.
    adata_unannotated : AnnData
        The unannotated single-cell data for which predictions are desired.
    celltype_col : str, optional
        Column in adata_train.obs with the cell type labels, by default "cell_type".
    region_column : str or None, optional
        If not None, edges are computed independently per region (cells in
        different regions are never connected), by default None.
    x_col, y_col : str, optional
        Columns in ``.obs`` holding the spatial coordinates, by default "x"/"y".
    sample_rate : float, optional
        Forwarded to STELLAR as ``args.sample_rate``. The training data is not
        subsampled by this function. By default 0.5.
    distance_thres : int, optional
        Cells closer than this distance are connected, by default 50.
    epochs : int, optional
        Number of training epochs, by default 50.
    num_seed_class : int, optional
        Extra classes added to the number of known cell types to obtain
        ``args.num_heads``, by default 3.
    key_added : str, optional
        Column of adata_unannotated.obs receiving the predictions, by default
        "stellar_pred".
    STELLAR_path : str, optional
        Path to the STELLAR repository; appended to sys.path if not already there.
    max_memory_usage : float, optional
        Byte limit above which edge computation is chunked, by default 1.6e10.
    chunk_size : int, optional
        Chunk size for chunked edge computation, by default 5000.
    wd, lr : float, optional
        Weight decay and learning rate, by default 5e-2 and 1e-3.
    seed : int, optional
        Seed passed to STELLAR, by default 1.
    batch_size : int, optional
        Batch size passed to STELLAR, by default 1.

    Returns
    -------
    AnnData
        adata_unannotated with ``obs[key_added]`` holding string labels. Classes
        not seen in training are named ``"New_Class_<id>"``.

    Notes
    -----
    If no edges are found for either dataset, a single placeholder edge is
    inserted so that STELLAR can still be constructed.
    """
    print(
        "Please consider to cite the following paper when using STELLAR: "
        "Brbić, M., Cao, K., Hickey, J.W. et al. Annotation of spatially resolved single-cell data with STELLAR. "
        "Nat Methods 19, 1411–1418 (2022). https://doi.org/10.1038/s41592-022-01651-8"
    )

    stellar_path = str(STELLAR_path)
    if stellar_path not in sys.path:
        sys.path.append(stellar_path)
    from datasets import GraphDataset
    from STELLAR import STELLAR
    from utils import prepare_save_dir  # noqa: F401

    parser = argparse.ArgumentParser(description="STELLAR")
    parser.add_argument("--seed", type=int, default=seed, metavar="S")
    parser.add_argument("--name", type=str, default="STELLAR")
    parser.add_argument("--epochs", type=int, default=epochs)
    parser.add_argument("--lr", type=float, default=lr)
    parser.add_argument("--wd", type=float, default=wd)
    # input_dim is set below from the training matrix.
    parser.add_argument("--num-heads", type=int, default=13)
    parser.add_argument("--num-seed-class", type=int, default=num_seed_class)
    parser.add_argument("--sample-rate", type=float, default=sample_rate)
    parser.add_argument("-b", "--batch-size", default=batch_size, type=int, metavar="N")
    parser.add_argument("--distance_thres", default=distance_thres, type=int)
    parser.add_argument("--savedir", type=str, default="./")
    args = parser.parse_args(args=[])
    args.cuda = torch.cuda.is_available()
    args.device = torch.device("cuda" if args.cuda else "cpu")
    # Number of heads = known cell types + num_seed_class (IMPORTANT).
    args.num_heads = len(adata_train.obs[celltype_col].unique()) + num_seed_class

    print("Preparing input data")
    train_df = adata_train.to_df()
    positions_celltype = adata_train.obs[[x_col, y_col, celltype_col]]
    train_df = pd.concat([train_df, positions_celltype], axis=1)
    # No subsampling happens here, so training rows map 1:1 to adata_train rows.
    original_to_sampled = {idx: idx for idx in range(len(train_df))}

    # The last three columns are x, y and the cell type label.
    train_X = train_df.iloc[:, :-3].values
    test_X = adata_unannotated.to_df().values
    # Keep the original (case-sensitive) cell type labels.
    train_y = train_df[celltype_col].values
    labeled_pos = train_df.iloc[:, -3:-1].values
    unlabeled_pos = adata_unannotated.obs[[x_col, y_col]].values

    print(f"Number of unique cell types: {len(np.unique(train_y))}")
    print(f"Cell types: {np.unique(train_y)[:10]}")

    cell_types = np.sort(list(set(train_y))).tolist()
    cell_type_dict = {t: i for i, t in enumerate(cell_types)}
    inverse_dict = dict(enumerate(cell_types))
    print(f"Cell type mapping size: {len(cell_type_dict)}")
    print(f"First 3 mappings: {list(cell_type_dict.items())[:3]}")
    train_y = np.array([cell_type_dict[x] for x in train_y])

    edge_kwargs = dict(
        distance_thres=distance_thres,
        max_memory_usage=max_memory_usage,
        chunk_size=chunk_size,
    )

    if region_column is not None:
        labeled_edges_list = []
        unlabeled_edges_list = []
        sampled_cells = set(train_df.index)

        print("Processing labeled data")
        for region in adata_train.obs[region_column].unique():
            print(f"Processing region: {region}")
            region_mask_full = adata_train.obs[region_column] == region
            region_cells = adata_train[region_mask_full].obs_names
            region_sampled_cells = [
                cell for cell in region_cells if cell in sampled_cells
            ]
            if len(region_sampled_cells) < 2:
                print(
                    f" Skipping region {region}: too few sampled cells ({len(region_sampled_cells)})"
                )
                continue

            region_pos = train_df.loc[region_sampled_cells, [x_col, y_col]].values

            # Map local (per-region) row numbers to rows of train_X.
            local_to_global = {}
            for local_idx, cell in enumerate(region_sampled_cells):
                global_idx = original_to_sampled.get(
                    adata_train.obs_names.get_loc(cell), None
                )
                if global_idx is not None:
                    local_to_global[local_idx] = global_idx
            if len(local_to_global) < 2:
                print(
                    f" Skipping region {region}: too few cells with valid mapping ({len(local_to_global)})"
                )
                continue

            region_edges = stellar_get_edge_index(region_pos, **edge_kwargs)
            labeled_edges_list.extend(
                [local_to_global[src], local_to_global[dst]]
                for src, dst in region_edges
                if src in local_to_global and dst in local_to_global
            )

        print("Processing unlabeled data")
        for region in adata_unannotated.obs[region_column].unique():
            print(f"Processing region: {region}")
            region_mask = adata_unannotated.obs[region_column] == region
            region_indices = np.where(region_mask)[0]
            if len(region_indices) < 2:
                print(
                    f" Skipping region {region}: too few cells ({len(region_indices)})"
                )
                continue

            region_pos = adata_unannotated.obs.loc[region_mask, [x_col, y_col]].values
            region_edges = stellar_get_edge_index(region_pos, **edge_kwargs)
            # Translate local row numbers into rows of test_X.
            local_edges = _stellar_edge_array(region_edges, len(region_indices))
            unlabeled_edges_list.extend(region_indices[local_edges].tolist())

        labeled_edges = _stellar_edge_array(labeled_edges_list, len(train_X))
        unlabeled_edges = _stellar_edge_array(unlabeled_edges_list, len(test_X))
        print(f"Final labeled edges: {labeled_edges.shape[0]}")
        print(f"Final unlabeled edges: {unlabeled_edges.shape[0]}")
    else:
        # Global edges over all cells; normalising also handles an empty result.
        labeled_edges = _stellar_edge_array(
            stellar_get_edge_index(labeled_pos, **edge_kwargs), len(train_X)
        )
        unlabeled_edges = _stellar_edge_array(
            stellar_get_edge_index(unlabeled_pos, **edge_kwargs), len(test_X)
        )

    print("Building dataset")
    print(f"Training data shape: {train_X.shape}, max label index: {len(train_X)-1}")
    print(f"Testing data shape: {test_X.shape}, max label index: {len(test_X)-1}")
    if len(labeled_edges) > 0:
        print(f"Max labeled edge index: {np.max(labeled_edges)}")
    if len(unlabeled_edges) > 0:
        print(f"Max unlabeled edge index: {np.max(unlabeled_edges)}")

    # STELLAR needs at least one edge per dataset; fall back to a placeholder edge.
    if len(labeled_edges) == 0 or len(unlabeled_edges) == 0:
        print(
            "WARNING: No edges found for labeled or unlabeled data - STELLAR requires edges for both datasets"
        )
    if len(labeled_edges) == 0:
        print("Adding token edge to labeled data")
        labeled_edges = np.array([[0, min(1, len(train_X) - 1)]])
    if len(unlabeled_edges) == 0:
        print("Adding token edge to unlabeled data")
        unlabeled_edges = np.array([[0, min(1, len(test_X) - 1)]])

    dataset = GraphDataset(train_X, train_y, test_X, labeled_edges, unlabeled_edges)

    # IMPORTANT: input dimension comes from the training matrix.
    args.input_dim = train_X.shape[1]
    print(f"Setting input dimension to: {args.input_dim}")

    print("Running STELLAR")
    stellar = STELLAR(args, dataset)
    model_params = sum(p.numel() for p in stellar.model.parameters())
    print(f"Model has {model_params} parameters")

    stellar.train()
    uncertainty, results = stellar.pred()
    print(f"Mean prediction uncertainty: {uncertainty:.4f}")

    unique_preds, counts = np.unique(results, return_counts=True)
    print("Prediction distribution:")
    for pred, count in zip(unique_preds, counts):
        print(f"Class {pred}: {count} cells ({count/len(results)*100:.2f}%)")

    # Map numeric predictions back to labels; unknown ids become "New_Class_<id>".
    labels = np.array(
        [inverse_dict.get(r, f"New_Class_{r}") for r in results], dtype=object
    )
    adata_unannotated.obs[key_added] = pd.Categorical(labels)
    adata_unannotated.obs[key_added] = adata_unannotated.obs[key_added].astype(str)
    return adata_unannotated
def ml_train(
    adata_train,
    label,
    test_size=0.2,
    random_state=0,
    nan_policy_y="omit",
    mode="accurate_SVC",  # 'accurate_SVC' or 'fast_SVC' or 'knn'
    showfig=True,
    figsize=(8, 6),
    n_neighbors=5,  # Number of neighbors for KNN
):
    """
    Train a classifier (SVC, LinearSVC, or KNN) on the data.

    Parameters
    ----------
    adata_train : anndata.AnnData
        The AnnData object containing the training data. The input features are
        expected in adata_train.X (dense or sparse), and the target labels in
        adata_train.obs[label].
    label : str
        The column name in adata_train.obs to use as the target variable.
    test_size : float, optional
        The proportion of the dataset to include in the test split. Default is 0.2.
    random_state : int, optional
        Seed for the train/test split. Default is 0.
    nan_policy_y : {'omit', 'raise'}, optional
        Policy for handling NaNs in the target variable. 'omit' drops cells with
        a missing label, 'raise' raises a ValueError if any label is missing.
        Default is 'omit'.
    mode : {'accurate_SVC', 'fast_SVC', 'knn'}, optional
        The type of classifier to use.
        - 'accurate_SVC': SVC(kernel="linear", probability=True); slower.
        - 'fast_SVC': LinearSVC wrapped in CalibratedClassifierCV; faster.
        - 'knn': KNeighborsClassifier with the given n_neighbors.
        Default is 'accurate_SVC'.
    showfig : bool, optional
        Whether to display a heatmap of the classification report. Default is True.
    figsize : tuple, optional
        Size of the figure if showfig is True. Default is (8, 6).
    n_neighbors : int, optional
        Number of neighbors for the KNN classifier. Default is 5.

    Returns
    -------
    svc : sklearn.base.BaseEstimator
        The trained classifier. All three modes expose both `predict` and
        `predict_proba`, so the model can be passed to `ml_predict`.

    Raises
    ------
    ValueError
        If `mode` or `nan_policy_y` is not one of the allowed options, or if
        `nan_policy_y='raise'` and some labels are NaN.

    Notes
    -----
    - Both options are validated before any data is touched.
    - In 'accurate_SVC' mode the evaluation label of a test cell is the class
      with the highest `predict_proba`; the other modes use `predict`.
    - The classification report is computed on the held-out split and shown
      as a heatmap if `showfig` is True.
    - The function prints progress messages (e.g., "Preparing training data!",
      "Training now!", etc.).
    """
    # Validate options up front so a typo fails before any expensive work.
    if mode not in ("accurate_SVC", "fast_SVC", "knn"):
        raise ValueError("mode must be 'accurate_SVC', 'fast_SVC', or 'knn'")
    if nan_policy_y not in ("omit", "raise"):
        raise ValueError("nan_policy_y must be either 'omit' or 'raise'")

    print("Preparing training data!")
    X = adata_train.X
    if not isinstance(X, np.ndarray):
        X = X.toarray()  # sparse -> dense

    y = adata_train.obs[label].values
    y_missing = pd.isna(y)  # works for object, numeric and categorical labels
    if y_missing.any():
        if nan_policy_y == "raise":
            raise ValueError(
                f"adata_train.obs['{label}'] contains {int(y_missing.sum())} NaN "
                "label(s); drop them or use nan_policy_y='omit'"
            )
        X = X[~y_missing]
        y = y[~y_missing]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    print("Unique labels:", np.unique(y))

    print("Training now!")
    if mode == "accurate_SVC":
        svc = SVC(kernel="linear", probability=True)
    elif mode == "knn":
        print("Using KNN classifier with n_neighbors = ", n_neighbors)
        svc = KNeighborsClassifier(n_neighbors=n_neighbors)
    else:  # "fast_SVC"; calibration adds predict_proba on top of LinearSVC
        svc = CalibratedClassifierCV(LinearSVC())
    svc.fit(X_train, y_train)

    print("Evaluating now!")
    if mode == "accurate_SVC":
        y_prob = pd.DataFrame(svc.predict_proba(X_test), columns=svc.classes_)
        svm_label = y_prob.idxmax(axis=1)
    else:
        svm_label = svc.predict(X_test)

    # Let sklearn derive the label set from y_test and the predictions.
    # Passing target_names=svc.classes_ raised a ValueError when the test split
    # lacked a training class, and mislabeled rows when the sets differed but
    # happened to have the same size.
    svm_eval = classification_report(
        y_true=y_test, y_pred=svm_label, output_dict=True
    )

    if showfig:
        plt.figure(figsize=figsize)
        # Drop the trailing "support" row and put classes on the y-axis.
        sns.heatmap(pd.DataFrame(svm_eval).iloc[:-1, :].T, annot=True)
        plt.title(f"Classification Report ({mode})")
        plt.show()
    return svc
def ml_predict(adata_val, svc, save_name="svm_pred", return_prob_mat=False):
    """
    Predict labels for a dataset using a trained classifier.

    The predicted label of each cell is the class with the highest probability
    returned by ``svc.predict_proba``.

    Parameters
    ----------
    adata_val : AnnData
        The data to classify. Features are read from ``adata_val.X`` (dense or
        sparse) and must match the features the model was trained on.
    svc : sklearn estimator
        A fitted classifier exposing ``predict_proba`` and ``classes_`` (for
        example the model returned by ``ml_train``).
    save_name : str, optional
        Column of ``adata_val.obs`` that receives the predicted labels, by
        default "svm_pred".
    return_prob_mat : bool, optional
        Whether to return the probability matrix, by default False.

    Returns
    -------
    DataFrame or None
        If `return_prob_mat` is True, a DataFrame of class probabilities with
        one row per cell (indexed like ``adata_val.obs``) and one column per
        class in ``svc.classes_``. Otherwise None.
    """
    print("Classifying!")
    X_val = adata_val.X
    # ml_train fits on a dense copy of X; densify sparse input the same way.
    X_val = X_val.toarray() if hasattr(X_val, "toarray") else np.asarray(X_val)

    y_prob_val = pd.DataFrame(
        svc.predict_proba(X_val),
        columns=svc.classes_,
        index=adata_val.obs.index,
    )
    svm_label_val = y_prob_val.idxmax(axis=1, skipna=True)

    print("Saving cell type labels to adata!")
    adata_val.obs[save_name] = svm_label_val.values

    if return_prob_mat:
        print("Returning probability matrix!")
        return y_prob_val
    return None
def masks_to_outlines_scikit_image(masks):
    """Get outlines of masks as a boolean array.

    Parameters
    ----------------
    masks: int, 2D or 3D array
        size [Ly x Lx] or [Lz x Ly x Lx], 0=NO masks; 1,2,...=mask labels

    Returns
    ----------------
    outlines: 2D or 3D array
        Same shape as ``masks``; True pixels are outlines (the inner boundary
        of each labelled region, ``find_boundaries(mode="inner")``). A 3D
        input is processed plane by plane along its first axis.

    Raises
    ------
    ValueError
        If ``masks`` is not 2D or 3D.
    """
    if masks.ndim not in (2, 3):
        raise ValueError(
            "masks_to_outlines takes 2D or 3D array, not %dD array" % masks.ndim
        )
    if masks.ndim == 2:
        return find_boundaries(masks, mode="inner")
    outlines = np.zeros(masks.shape, dtype=bool)
    for z, plane in enumerate(masks):
        outlines[z] = find_boundaries(plane, mode="inner")
    return outlines


def download_file_tm(url, save_path, timeout=60):
    """
    Download a file from a given URL and save it to a specified path.

    Parameters
    ----------
    url : str
        The URL of the file to download.
    save_path : str
        The local path where the downloaded file will be saved.
    timeout : float or tuple, optional
        Seconds to wait for the server (passed to ``requests.get``). Without a
        timeout the request can hang indefinitely. Default is 60.

    Raises
    ------
    requests.exceptions.HTTPError
        If the HTTP request returned an unsuccessful status code.
    requests.exceptions.Timeout
        If the server does not respond within `timeout`.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # The body is fully received before the file is opened, so a failed
    # request never leaves a partial plugin file behind.
    with open(save_path, "wb") as file:
        file.write(response.content)


def check_download_tm_plugins():
    """
    Check and download the TissUUmaps plugins if they are not already present.

    The plugin directory of the ``tissuumaps`` package inside the active Conda
    environment (``$CONDA_PREFIX``) is located, and every plugin file that is
    missing there is downloaded.

    Raises
    ------
    EnvironmentError
        If the Conda environment is not activated.
    FileNotFoundError
        If no TissUUmaps plugin directory exists in the Conda environment.
    """
    urls = [
        "https://tissuumaps.github.io/TissUUmaps/plugins/latest/ClassQC.js",
        "https://tissuumaps.github.io/TissUUmaps/plugins/latest/Plot_Histogram.js",
        "https://tissuumaps.github.io/TissUUmaps/plugins/latest/Points2Regions.js",
        "https://tissuumaps.github.io/TissUUmaps/plugins/latest/Spot_Inspector.js",
        "https://tissuumaps.github.io/TissUUmaps/plugins/latest/Feature_Space.js",
    ]
    conda_env_path = os.getenv("CONDA_PREFIX")
    if not conda_env_path:
        raise EnvironmentError("Conda environment is not activated.")

    python_version = f"python{sys.version_info.major}.{sys.version_info.minor}"
    candidates = [
        # <env>/lib/pythonX.Y/site-packages layout
        os.path.join(
            conda_env_path, "lib", python_version, "site-packages", "tissuumaps", "plugins"
        ),
        # <env>/lib/site-packages layout without a pythonX.Y level
        # (e.g. conda on Windows, where paths are case-insensitive)
        os.path.join(conda_env_path, "lib", "site-packages", "tissuumaps", "plugins"),
    ]
    save_directory = next((d for d in candidates if os.path.isdir(d)), None)
    if save_directory is None:
        raise FileNotFoundError(
            "TissUUmaps plugin directory not found; looked in: " + ", ".join(candidates)
        )

    for url in urls:
        save_path = os.path.join(save_directory, os.path.basename(url))
        if not os.path.exists(save_path):
            download_file_tm(url, save_path)
            print(f"Plug-in downloaded and saved to {save_path}")
def tm_viewer(
    adata,
    images_pickle_path,
    directory=None,
    region_column="unique_region",
    region="",
    xSelector="x",
    ySelector="y",
    color_by="cell_type",
    keep_list=None,
    include_masks=True,
    open_viewer=True,
    add_UMAP=True,
    use_jpg_compression=False,
):
    """
    Prepare and visualize spatial transcriptomics data using TissUUmaps.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix.
    images_pickle_path : str
        Path to the pickle file containing images and masks (a dict with the
        keys "image_dict" and "masks"). Only load pickles you trust.
    directory : str, optional
        Directory to save the output files. If None, a temporary directory will be created.
    region_column : str, optional
        Column name in `adata.obs` that specifies the region, by default "unique_region".
    region : str, optional
        Specific region to process, by default "".
    xSelector : str, optional
        Column name for x coordinates, by default "x".
    ySelector : str, optional
        Column name for y coordinates, by default "y".
    color_by : str, optional
        Column name for coloring the points, by default "cell_type".
    keep_list : list, optional
        List of columns to keep from `adata.obs`, by default None
        (meaning [region_column, xSelector, ySelector, color_by]).
    include_masks : bool, optional
        Whether to add a mask-outline image ("masks.jpg") to the output, by default True.
    open_viewer : bool, optional
        Whether to open the TissUUmaps viewer, by default True.
    add_UMAP : bool, optional
        Whether to add UMAP coordinates (from `adata.obsm["X_umap"]`) to the output, by default True.
    use_jpg_compression : bool, optional
        Whether to save images as JPEG instead of TIFF, by default False.

    Returns
    -------
    list
        List of paths to the saved image files.
    list
        List of paths to the saved CSV files.
    """
    print(
        "Please consider to cite the following paper when using TissUUmaps: TissUUmaps 3: Improvements in interactive visualization, exploration, and quality assessment of large-scale spatial omics data - Pielawski, Nicolas et al. 2023 - Heliyon, Volume 9, Issue 5, e15306"
    )

    check_download_tm_plugins()

    # SECURITY: pickle.load can execute arbitrary code; the file must come
    # from a trusted source (e.g. this package's own segmentation output).
    with open(images_pickle_path, "rb") as f:
        seg_output = pickle.load(f)
    image_dict = seg_output["image_dict"]
    masks = seg_output["masks"]

    if keep_list is None:
        keep_list = [region_column, xSelector, ySelector, color_by]

    print("Preparing TissUUmaps input...")
    if directory is None:
        directory = tempfile.mkdtemp()
    # str()/f-strings below so non-string region labels (e.g. ints) work too.
    cache_dir = pathlib.Path(directory) / str(region)
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Explicit copy: adding the UMAP columns to a column subset of adata.obs
    # would otherwise trigger pandas' SettingWithCopyWarning.
    segmented_matrix = adata.obs[keep_list].copy()
    if add_UMAP:
        umap = adata.obsm["X_umap"]
        segmented_matrix["UMAP_1"] = umap[:, 0]
        segmented_matrix["UMAP_2"] = umap[:, 1]

    # Only the requested region is exported.
    region_matrix = segmented_matrix.loc[segmented_matrix[region_column] == region]
    csv_path = cache_dir / f"{region}.csv"
    region_matrix.to_csv(csv_path)
    csv_paths = [csv_path]

    image_dir = cache_dir / "images"
    image_dir.mkdir(parents=True, exist_ok=True)

    # Save every image of image_dict, named by its key.
    image_list = []
    if use_jpg_compression:
        print("Using jpg compression")
    for key, image in image_dict.items():
        if use_jpg_compression:
            file_path = os.path.join(image_dir, f"{key}.jpg")
            imsave(file_path, image, quality=100)
        else:
            file_path = os.path.join(image_dir, f"{key}.tif")
            imsave(file_path, image, check_contrast=False)
        image_list.append(file_path)

    if include_masks:
        # Black canvas with the shape/dtype of the first image; a 2D canvas is
        # turned into 3 identical channels so it can hold RGB outlines.
        first_image = next(iter(image_dict.values()))
        reference_image = np.zeros_like(first_image)
        if reference_image.ndim == 2:
            reference_image = np.repeat(reference_image[..., np.newaxis], 3, axis=-1)

        # Drop singleton dimensions from the masks before outlining.
        masks_3d = np.squeeze(masks)
        outlines = masks_to_outlines_scikit_image(masks_3d)
        reference_image[outlines] = [255, 0, 0]  # red outlines on black

        # JPEG has no alpha channel: the background is saved as plain black.
        file_path = os.path.join(image_dir, "masks.jpg")
        imsave(file_path, reference_image.astype(np.uint8))
        image_list.append(file_path)

    if open_viewer:
        print("Opening TissUUmaps viewer...")
        tj.loaddata(
            images=image_list,
            csvFiles=[str(p) for p in csv_paths],
            xSelector=xSelector,
            ySelector=ySelector,
            keySelector=color_by,
            nameSelector=color_by,
            colorSelector=color_by,
            piechartSelector=None,
            shapeSelector=None,
            scaleSelector=None,
            fixedShape=None,
            scaleFactor=1,
            colormap=None,
            compositeMode="source-over",
            boundingBox=None,
            port=5100,
            host="localhost",
            height=900,
            tmapFilename=f"{region}_project",
            plugins=[
                "Plot_Histogram",
                "Points2Regions",
                "Spot_Inspector",
                "Feature_Space",
                "ClassQC",
            ],
        )

    return image_list, csv_paths
def tm_viewer_catplot(
    adata,
    directory=None,
    region_column="unique_region",
    x="x",
    y="y",
    color_by="cell_type",
    open_viewer=True,
    add_UMAP=False,
    keep_list=None,
):
    """
    Generate and visualize categorical plots using TissUUmaps.

    One CSV file per region in ``adata.obs[region_column]`` is written to
    `directory` (cells whose region is missing are skipped). Optionally, all
    CSVs are then opened together in the TissUUmaps viewer.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix.
    directory : str, optional
        Directory to save the output CSV files. If None, a temporary directory is created.
    region_column : str, optional
        Column name in `adata.obs` that contains region information. Default is "unique_region".
    x : str, optional
        Column name in `adata.obs` to be used for x-axis. Default is "x".
    y : str, optional
        Column name in `adata.obs` to be used for y-axis. Default is "y".
    color_by : str, optional
        Column name in `adata.obs` to be used for coloring the points. Default is "cell_type".
    open_viewer : bool, optional
        Whether to open the TissUUmaps viewer after generating the CSV files. Default is True.
    add_UMAP : bool, optional
        Whether to add UMAP coordinates (from `adata.obsm["X_umap"]`) to the output data. Default is False.
    keep_list : list of str, optional
        List of columns to keep from `adata.obs`. If None, defaults to [region_column, x, y, color_by].

    Returns
    -------
    list of str
        List of paths to the generated CSV files.
    """
    check_download_tm_plugins()

    if keep_list is None:
        keep_list = [region_column, x, y, color_by]

    print("Preparing TissUUmaps input...")
    if directory is None:
        print(
            "Creating temporary directory... If you want to save the files, please specify a directory."
        )
        directory = tempfile.mkdtemp()
    os.makedirs(directory, exist_ok=True)

    # Explicit copy: adding the UMAP columns to a column subset of adata.obs
    # would otherwise trigger pandas' SettingWithCopyWarning.
    segmented_matrix = adata.obs[keep_list].copy()
    if add_UMAP:
        umap = adata.obsm["X_umap"]
        segmented_matrix["UMAP_1"] = umap[:, 0]
        segmented_matrix["UMAP_2"] = umap[:, 1]

    csv_paths = []
    # NaN regions are skipped: `== NaN` never matches any row, so they could
    # only produce an empty "nan.csv".
    for region in segmented_matrix[region_column].dropna().unique():
        region_matrix = segmented_matrix.loc[segmented_matrix[region_column] == region]
        # f-string so non-string region labels (e.g. ints) also work.
        region_csv_path = os.path.join(directory, f"{region}.csv")
        region_matrix.to_csv(region_csv_path)
        csv_paths.append(region_csv_path)

    if open_viewer:
        print("Opening TissUUmaps viewer...")
        tj.loaddata(
            images=[],
            csvFiles=[str(p) for p in csv_paths],
            xSelector=x,
            ySelector=y,
            keySelector=color_by,
            nameSelector=color_by,
            colorSelector=color_by,
            piechartSelector=None,
            shapeSelector=None,
            scaleSelector=None,
            fixedShape=None,
            scaleFactor=1,
            colormap=None,
            compositeMode="source-over",
            boundingBox=None,
            port=5100,
            host="localhost",
            height=900,
            tmapFilename="project",
            plugins=[
                "Plot_Histogram",
                "Points2Regions",
                "Spot_Inspector",
                "Feature_Space",
                "ClassQC",
            ],
        )

    return csv_paths
def install_gpu_leiden(CUDA="12"):
    """
    Install the necessary packages for GPU-accelerated Leiden clustering.

    Parameters
    ----------
    CUDA : str or int, optional
        The CUDA major version to install packages for. Options are '11' and
        '12' (integers and values such as '11.8' are accepted and reduced to
        their major version). Any other value falls back to the CUDA 12
        packages. Default is '12'.

    Returns
    -------
    None

    Notes
    -----
    Runs a series of ``pip install`` commands through the interpreter that
    is running this code (``sys.executable -m pip``). This ensures the
    packages land in the active environment, and no shell is involved, so
    version specifiers such as ``==24.2.*`` are never glob-expanded. The
    output and errors of each command are printed. A failing command does
    not stop the remaining ones.
    """
    if platform.system() != "Linux":
        print("This feature is currently only supported on Linux.")
        return

    print("installing rapids_singlecell")

    cuda_major = str(CUDA).split(".")[0].strip()
    cuda_tag = "cu11" if cuda_major == "11" else "cu12"
    rapids_packages = [
        f"{name}-{cuda_tag}==24.2.*"
        for name in (
            "cudf",
            "dask-cudf",
            "cuml",
            "cugraph",
            "cuspatial",
            "cuproj",
            "cuxfilter",
            "cucim",
            "pylibraft",
            "raft-dask",
        )
    ]

    pip_install = [sys.executable, "-m", "pip", "install"]
    commands = [
        pip_install + ["rapids-singlecell==0.9.5"],
        pip_install + ["--extra-index-url=https://pypi.nvidia.com", *rapids_packages],
        pip_install + ["protobuf==3.20"],
    ]

    for command in commands:
        completed = subprocess.run(
            command, capture_output=True, encoding="utf-8", errors="replace"
        )
        # Print the output and error, if any
        if completed.stdout:
            print(f"Output:\n{completed.stdout}")
        if completed.stderr:
            print(f"Error:\n{completed.stderr}")


def anndata_to_GPU(
    adata: AnnData,
    layer: str | None = None,
    convert_all: bool = False,
) -> AnnData:
    """
    Transfers matrices and arrays to the GPU.

    Parameters
    ----------
    adata
        AnnData object. It is not modified; a copy is returned.
    layer
        Layer to use as input instead of `X`. If `None`, `X` is used.
        Ignored when ``convert_all`` is True.
    convert_all
        If True, move `X` and every layer to the GPU.

    Returns
    -------
    Returns an updated copy with data on GPU.

    Warns
    -----
    Warning
        For every matrix whose type is not a CPU CSR/CSC matrix or a NumPy
        array; such matrices are left unchanged.
    """
    adata_gpu = adata.copy()

    def _move_to_gpu(target_layer):
        # Convert one matrix of adata_gpu in place (None means `X`).
        X = _get_obs_rep(adata_gpu, layer=target_layer)
        if isspmatrix_csr_cpu(X):
            X = csr_matrix_gpu(X)
        elif isspmatrix_csc_cpu(X):
            X = csc_matrix_gpu(X)
        elif isinstance(X, np.ndarray):
            X = cp.asarray(X)
        else:
            error = target_layer if target_layer else "X"
            warnings.warn(f"{error} not supported for GPU conversion", Warning)
        _set_obs_rep(adata_gpu, X, layer=target_layer)

    if convert_all:
        # Convert on this single copy; recursive calls would each return a
        # fresh copy, and their results were previously discarded.
        _move_to_gpu(None)
        for key in list(adata_gpu.layers.keys()):
            _move_to_gpu(key)
    else:
        _move_to_gpu(layer)

    return adata_gpu


def anndata_to_CPU(
    adata: AnnData,
    layer: str | None = None,
    convert_all: bool = False,
    copy: bool = False,
) -> AnnData | None:
    """
    Transfers matrices and arrays from the GPU.

    Parameters
    ----------
    adata
        AnnData object
    layer
        Layer to use as input instead of `X`. If `None`, `X` is used.
        Ignored when ``convert_all`` is True.
    convert_all
        If True, move `X` and every layer to host memory.
    copy
        Whether to return a copy or update `adata`.

    Returns
    -------
    Updates `adata` inplace (returns None) or returns an updated copy.
    Matrices that are not GPU CSR/CSC matrices or CuPy arrays are left as they are.
    """
    if copy:
        adata = adata.copy()

    def _move_to_host(target_layer):
        # Convert one matrix of `adata` in place (None means `X`).
        X = _get_obs_rep(adata, layer=target_layer)
        if isspmatrix_csr_gpu(X) or isspmatrix_csc_gpu(X) or isinstance(X, cp.ndarray):
            X = X.get()
        _set_obs_rep(adata, X, layer=target_layer)

    if convert_all:
        _move_to_host(None)
        for key in list(adata.layers.keys()):
            _move_to_host(key)
    else:
        _move_to_host(layer)

    if copy:
        return adata


def install_stellar(CUDA=12):
    """
    Install PyTorch, PyTorch Geometric and its compiled extensions for STELLAR.

    Parameters
    ----------
    CUDA : int, float or str, optional
        CUDA version to install wheels for: 12 (PyG wheels for
        ``torch-2.3.0+cu121``) or 11.8 (torch from the cu118 index and PyG
        wheels for ``torch-2.3.0+cu118``). String forms such as "12" or
        "11.8" are accepted. Default is 12.

    Returns
    -------
    None

    Raises
    ------
    subprocess.CalledProcessError
        If any pip command fails.

    Notes
    -----
    Installs through ``sys.executable -m pip`` so the packages land in the
    environment of the running interpreter. For any other CUDA value,
    installation instructions are printed and nothing is installed.
    """
    try:
        cuda_version = float(CUDA)
    except (TypeError, ValueError):
        cuda_version = None

    if cuda_version == 12:
        torch_args = ["torch"]
        wheel_tag = "cu121"
    elif cuda_version == 11.8:
        torch_args = ["torch", "--index-url", "https://download.pytorch.org/whl/cu118"]
        wheel_tag = "cu118"
    else:
        print("Please choose between CUDA 12 or 11.8")
        print(
            "If neither is working for you check the installation guide at: https://pytorch.org/get-started/locally/ and https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html"
        )
        return

    pip_install = [sys.executable, "-m", "pip", "install"]
    subprocess.run(pip_install + torch_args, check=True)
    subprocess.run(pip_install + ["torch_geometric"], check=True)
    subprocess.run(
        pip_install
        + [
            "pyg_lib",
            "torch_scatter",
            "torch_sparse",
            "torch_cluster",
            "torch_spline_conv",
            "-f",
            f"https://data.pyg.org/whl/torch-2.3.0+{wheel_tag}.html",
        ],
        check=True,
    )
def launch_interactive_clustering(adata=None, output_dir=None):
    """
    Launch an interactive clustering application for single-cell data analysis.

    Builds a Panel app with tabs for clustering, subclustering, cluster
    annotation, spatial visualization and saving. Result plots are appended
    to a tabbed visualization area shown next to the controls.

    Note that this calls ``warnings.filterwarnings("ignore")``, which
    silences warnings for the whole Python process, not just this app.

    Parameters
    ----------
    adata : AnnData, optional
        An AnnData object containing single-cell data. If provided, the data
        will be loaded automatically.
    output_dir : str, optional
        The directory where the annotated AnnData object will be saved.
        Required if `adata` is provided. Created if it does not exist.

    Returns
    -------
    main_layout : panel.layout.Row
        The main layout of the interactive clustering application.

    Raises
    ------
    ValueError
        If `adata` is provided but `output_dir` is not specified, or if
        `output_dir` is not a string.
    """
    warnings.filterwarnings("ignore")
    pn.extension("deckgl", design="bootstrap", theme="default", template="bootstrap")
    pn.state.template.config.raw_css.append(
        """
    #main {
        padding: 0;
    }"""
    )

    # An output directory is mandatory when adata is passed in directly.
    if adata is not None and not output_dir:
        raise ValueError(
            "Please provide an output directory to save the annotated AnnData object."
        )
    if output_dir and not isinstance(output_dir, str):
        raise ValueError("output_dir must be a string.")
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    def create_clustering_app():
        # Holds the current AnnData object ("adata") and the obs column whose
        # clusters are listed in the annotation table ("annotation_key").
        adata_container = {}

        # ---- small helpers shared by the callbacks -------------------------

        def default_cluster_key():
            # Key written by run_clustering: the user-entered key, or
            # "<method>_<resolution>" when none is given.
            if key_added_input.value:
                return key_added_input.value
            return clustering_method.value + "_" + str(resolution.value)

        def umap_figure(adata_obj, key):
            # Render a UMAP coloured by `key` and detach it from pyplot.
            sc.pl.umap(adata_obj, color=[key], show=False)
            fig = plt.gcf()
            plt.close()
            return fig

        def dotplot_figure(adata_obj, marker_list, groupby):
            # Render a marker dotplot grouped by `groupby` and detach it from pyplot.
            sc.pl.dotplot(
                adata_obj,
                marker_list,
                groupby=groupby,
                dendrogram=True,
                show=False,
            )
            fig = plt.gcf()
            plt.close()
            return fig

        def cluster_count_figure(adata_obj, key, xlabel, title):
            # Bar chart of cells per cluster, sorted by cluster label.
            counts = adata_obj.obs[key].value_counts().sort_index()
            fig, ax = plt.subplots()
            counts.plot(kind="bar", ax=ax)
            ax.set_xlabel(xlabel)
            ax.set_ylabel("Number of Cells")
            ax.set_title(title)
            plt.close(fig)
            return fig

        def reset_annotation_table(adata_obj, key):
            # Fill the annotation table with the clusters of `key` and
            # remember `key` so save_annotations maps the matching column.
            clusters = adata_obj.obs[key].unique().astype(str)
            cluster_annotation.value = pd.DataFrame(
                {"Cluster": clusters, "Annotation": [""] * len(clusters)}
            )
            adata_container["annotation_key"] = key

        # ---- callbacks -----------------------------------------------------

        def load_data(event=None):
            if adata is not None:
                adata_container["adata"] = adata
                marker_list_input.options = list(adata.var_names)
                output_area.object = "**AnnData object loaded successfully.**"
                return

            if not input_path.value or not os.path.isfile(input_path.value):
                output_area.object = "**Please enter a valid AnnData file path.**"
                return

            loaded_adata = sc.read_h5ad(input_path.value)
            adata_container["adata"] = loaded_adata
            marker_list_input.options = list(loaded_adata.var_names)
            output_area.object = "**AnnData file loaded successfully.**"

        def run_clustering(event):
            adata_obj = adata_container.get("adata", None)
            if adata_obj is None:
                output_area.object = "**Please load an AnnData file first.**"
                return

            marker_list = (
                list(marker_list_input.value) if marker_list_input.value else None
            )
            key_added = default_cluster_key()

            loading_indicator.active = True
            output_area.object = "**Clustering in progress...**"

            try:
                adata_obj = clustering(
                    adata_obj,
                    clustering=clustering_method.value,
                    marker_list=marker_list,
                    resolution=resolution.value,
                    n_neighbors=n_neighbors.value,
                    reclustering=reclustering.value,
                    key_added=key_added,
                    key_filter=None,
                    subset_cluster=None,
                    # Honour the "Random Seed" widget instead of a fixed seed.
                    seed=seed.value,
                    fs_xdim=fs_xdim.value,
                    fs_ydim=fs_ydim.value,
                    fs_rlen=fs_rlen.value,
                )
                adata_container["adata"] = adata_obj
                output_area.object = "**Clustering completed.**"

                tabs = [
                    (
                        "UMAP",
                        pn.pane.Matplotlib(umap_figure(adata_obj, key_added), dpi=100),
                    )
                ]
                if marker_list:
                    tabs.append(
                        (
                            "Dotplot",
                            pn.pane.Matplotlib(
                                dotplot_figure(adata_obj, marker_list, key_added),
                                dpi=100,
                            ),
                        )
                    )
                tabs.append(
                    (
                        "Histogram",
                        pn.pane.Matplotlib(
                            cluster_count_figure(
                                adata_obj,
                                key_added,
                                "Cluster",
                                f"Cluster Counts for {key_added}",
                            ),
                            dpi=100,
                        ),
                    )
                )
                for name, pane in tabs:
                    visualization_area.append((name, pane))

                reset_annotation_table(adata_obj, key_added)
            except Exception as e:
                output_area.object = f"**Error during clustering: {e}**"
            finally:
                loading_indicator.active = False

        def run_subclustering(event):
            adata_obj = adata_container.get("adata", None)
            if adata_obj is None:
                output_area.object = "**Please run clustering first.**"
                return
            if not subcluster_key.value or not subcluster_values.value:
                output_area.object = "**Please provide subcluster key and values.**"
                return

            clusters = [c.strip() for c in subcluster_values.value.split(",")]
            key_added = subcluster_key.value + "_subcluster"

            loading_indicator_subcluster.active = True
            output_area.object = "**Subclustering in progress...**"

            try:
                sc.tl.leiden(
                    adata_obj,
                    seed=seed.value,
                    restrict_to=(subcluster_key.value, clusters),
                    resolution=subcluster_resolution.value,
                    key_added=key_added,
                )
                adata_container["adata"] = adata_obj
                output_area.object = "**Subclustering completed.**"

                tabs = [
                    (
                        "UMAP_Sub",
                        pn.pane.Matplotlib(umap_figure(adata_obj, key_added), dpi=100),
                    )
                ]
                marker_list = (
                    list(marker_list_input.value) if marker_list_input.value else None
                )
                if marker_list:
                    tabs.append(
                        (
                            "Dotplot_Sub",
                            pn.pane.Matplotlib(
                                dotplot_figure(adata_obj, marker_list, key_added),
                                dpi=100,
                            ),
                        )
                    )
                tabs.append(
                    (
                        "Histogram_Sub",
                        pn.pane.Matplotlib(
                            cluster_count_figure(
                                adata_obj,
                                key_added,
                                "Subcluster",
                                f"Subcluster Counts for {key_added}",
                            ),
                            dpi=100,
                        ),
                    )
                )
                for name, pane in tabs:
                    visualization_area.append((name, pane))

                reset_annotation_table(adata_obj, key_added)
            except Exception as e:
                output_area.object = f"**Error during subclustering: {e}**"
            finally:
                loading_indicator_subcluster.active = False

        def save_annotations(event):
            adata_obj = adata_container.get("adata", None)
            if adata_obj is None:
                output_area.object = "**No AnnData object to annotate.**"
                return

            annotation_dict = dict(
                zip(
                    cluster_annotation.value["Cluster"],
                    cluster_annotation.value["Annotation"],
                )
            )
            # Map with the column whose clusters are listed in the table (set
            # by the last clustering or subclustering run), so subcluster
            # labels are not looked up in the parent clustering column.
            key_to_annotate = (
                adata_container.get("annotation_key") or default_cluster_key()
            )
            if key_to_annotate not in adata_obj.obs:
                output_area.object = (
                    f"**Column '{key_to_annotate}' not found in AnnData.obs. "
                    "Please run clustering first.**"
                )
                return

            adata_obj.obs["cell_type"] = (
                adata_obj.obs[key_to_annotate]
                .astype(str)
                .map(annotation_dict)
                .astype("category")
            )
            output_area.object = "**Annotations saved to AnnData object.**"

        def save_adata(event):
            adata_obj = adata_container.get("adata", None)
            if adata_obj is None:
                output_area.object = "**No AnnData object to save.**"
                return
            if not output_dir_widget.value:
                output_area.object = "**Please specify an output directory.**"
                return
            os.makedirs(output_dir_widget.value, exist_ok=True)
            output_filepath = os.path.join(
                output_dir_widget.value, "adata_annotated.h5ad"
            )
            adata_obj.write(output_filepath)
            output_area.object = f"**AnnData saved to {output_filepath}.**"

        def run_spatial_visualization(event):
            adata_obj = adata_container.get("adata", None)
            if adata_obj is None:
                output_area.object = "**Please load an AnnData file first.**"
                return
            try:
                catplot(
                    adata_obj,
                    color=spatial_color.value,
                    unique_region=spatial_unique_region.value,
                    X=spatial_x.value,
                    Y=spatial_y.value,
                    n_columns=spatial_n_columns.value,
                    palette=spatial_palette.value,
                    savefig=spatial_savefig.value,
                    output_fname=spatial_output_fname.value,
                    output_dir=output_dir_widget.value,
                    figsize=spatial_figsize.value,
                    size=spatial_size.value,
                )
                spatial_fig = plt.gcf()
                plt.close()
                visualization_area.append(
                    ("Spatial Visualization", pn.pane.Matplotlib(spatial_fig, dpi=100))
                )
                output_area.object = "**Spatial visualization completed.**"
            except Exception as e:
                output_area.object = f"**Error during spatial visualization: {e}**"

        # ---- widgets -------------------------------------------------------

        # File paths
        input_path = pn.widgets.TextInput(
            name="AnnData File Path", placeholder="Enter path to .h5ad file"
        )
        output_dir_widget = pn.widgets.TextInput(
            name="Output Directory",
            placeholder="Enter output directory path",
            value=output_dir if output_dir else "",
        )
        load_data_button = pn.widgets.Button(name="Load Data", button_type="primary")

        # Clustering parameters
        clustering_method = pn.widgets.Select(
            name="Clustering Method",
            options=["leiden", "louvain", "flowSOM", "leiden_gpu"],
        )
        resolution = pn.widgets.FloatInput(name="Resolution", value=1.0)
        n_neighbors = pn.widgets.IntInput(name="Number of Neighbors", value=10)
        reclustering = pn.widgets.Checkbox(name="Reclustering", value=False)
        seed = pn.widgets.IntInput(name="Random Seed", value=42)
        key_added_input = pn.widgets.TextInput(
            name="Key Added", placeholder="Enter key to add to AnnData.obs", value=""
        )
        marker_list_input = pn.widgets.MultiChoice(
            name="Marker List", options=[], width=950
        )

        # Subclustering parameters
        subcluster_key = pn.widgets.TextInput(
            name="Subcluster Key",
            placeholder='Enter key to filter on (e.g., "leiden_1")',
        )
        subcluster_values = pn.widgets.TextInput(
            name="Subcluster Values",
            placeholder="Enter clusters to subset (comma-separated)",
        )
        subcluster_resolution = pn.widgets.FloatInput(
            name="Subcluster Resolution", value=0.3
        )
        subcluster_button = pn.widgets.Button(
            name="Run Subclustering", button_type="primary"
        )

        # Cluster annotation
        cluster_annotation = pn.widgets.DataFrame(
            pd.DataFrame(columns=["Cluster", "Annotation"]),
            name="Cluster Annotations",
            autosize_mode="fit_columns",
        )
        save_annotations_button = pn.widgets.Button(
            name="Save Annotations", button_type="success"
        )

        fs_xdim = pn.widgets.IntInput(name="FlowSOM xdim", value=10)
        fs_ydim = pn.widgets.IntInput(name="FlowSOM ydim", value=10)
        fs_rlen = pn.widgets.IntInput(name="FlowSOM rlen", value=10)

        # Buttons
        run_clustering_button = pn.widgets.Button(
            name="Run Clustering", button_type="primary"
        )
        save_adata_button = pn.widgets.Button(
            name="Save AnnData", button_type="success"
        )

        # Loading indicators
        loading_indicator = pn.widgets.Progress(
            name="Clustering Progress", active=False, bar_color="primary"
        )
        loading_indicator_subcluster = pn.widgets.Progress(
            name="Subclustering Progress", active=False, bar_color="primary"
        )

        # Output areas; the visualization area is a Tabs widget holding every plot.
        output_area = pn.pane.Markdown()
        visualization_area = pn.Tabs()

        # Spatial visualization parameters
        spatial_color = pn.widgets.TextInput(
            name="Color By Column",
            placeholder="Enter group column name (e.g., cell_type_coarse)",
        )
        spatial_unique_region = pn.widgets.TextInput(
            name="Unique Region Column", value="unique_region"
        )
        spatial_x = pn.widgets.TextInput(name="X Coordinate Column", value="x")
        spatial_y = pn.widgets.TextInput(name="Y Coordinate Column", value="y")
        spatial_n_columns = pn.widgets.IntInput(name="Number of Columns", value=2)
        spatial_palette = pn.widgets.TextInput(name="Color Palette", value="tab20")
        spatial_figsize = pn.widgets.FloatInput(name="Figure Size", value=17)
        spatial_size = pn.widgets.FloatInput(name="Point Size", value=20)
        spatial_savefig = pn.widgets.Checkbox(name="Save Figure", value=False)
        spatial_output_fname = pn.widgets.TextInput(
            name="Output Filename", placeholder="Enter output filename"
        )
        run_spatial_visualization_button = pn.widgets.Button(
            name="Run Spatial Visualization", button_type="primary"
        )

        # Link callbacks
        load_data_button.on_click(load_data)
        run_clustering_button.on_click(run_clustering)
        subcluster_button.on_click(run_subclustering)
        save_annotations_button.on_click(save_annotations)
        save_adata_button.on_click(save_adata)
        run_spatial_visualization_button.on_click(run_spatial_visualization)

        # ---- layout --------------------------------------------------------

        clustering_tab = pn.Column(
            pn.pane.Markdown("### Load Data"),
            (
                pn.Row(input_path, output_dir_widget, load_data_button)
                if adata is None
                else pn.pane.Markdown("AnnData object loaded.")
            ),
            pn.layout.Divider(),
            pn.pane.Markdown("### Clustering Parameters"),
            pn.Row(clustering_method, resolution, n_neighbors),
            pn.Row(seed, reclustering),
            pn.Row(fs_xdim, fs_ydim, fs_rlen),
            key_added_input,
            marker_list_input,
            pn.layout.Divider(),
            pn.Row(run_clustering_button, loading_indicator),
            output_area,
        )

        subclustering_tab = pn.Column(
            pn.pane.Markdown("### Subclustering Parameters"),
            pn.Row(subcluster_key, subcluster_values, subcluster_resolution),
            pn.layout.Divider(),
            pn.Row(subcluster_button, loading_indicator_subcluster),
            output_area,
        )

        annotation_tab = pn.Column(
            pn.pane.Markdown("### Cluster Annotation"),
            cluster_annotation,
            pn.layout.Divider(),
            save_annotations_button,
            output_area,
        )

        save_tab = pn.Column(
            pn.pane.Markdown("### Save Data"), save_adata_button, output_area
        )

        spatial_visualization_tab = pn.Column(
            pn.pane.Markdown("### Spatial Visualization Parameters"),
            pn.Row(spatial_color, spatial_palette),
            pn.Row(spatial_unique_region, spatial_n_columns),
            pn.Row(spatial_x, spatial_y),
            pn.Row(spatial_figsize, spatial_size),
            pn.layout.Divider(),
            pn.Row(spatial_savefig, spatial_output_fname),
            pn.layout.Divider(),
            pn.Row(run_spatial_visualization_button),
            output_area,
        )

        tabs = pn.Tabs(
            ("Clustering", clustering_tab),
            ("Subclustering", subclustering_tab),
            ("Annotation", annotation_tab),
            ("Spatial Visualization", spatial_visualization_tab),
            ("Save", save_tab),
        )

        main_layout = pn.Row(tabs, visualization_area, sizing_mode="stretch_both")

        # Automatically load data if adata is provided
        if adata is not None:
            load_data()

        return main_layout

    main_layout = create_clustering_app()
    main_layout.servable(title="SPACEc Clustering App")
    return main_layout
# Functions for PPA ## Adjust clustering parameter to get the desired number of clusters def apply_dbscan_clustering( df, min_samples=10, x_col="x", y_col="y", allow_single_cluster=True ): """ Apply DBSCAN clustering to a dataframe and update the cluster labels in the original dataframe. Parameters ---------- df : pandas.DataFrame The dataframe to be clustered. min_cluster_size : int, optional The number of samples in a neighborhood for a point to be considered as a core point, by default 10 Returns ------- None """ # Initialize a new column for cluster labels df["cluster"] = -1 # Apply DBSCAN clustering hdbscan = HDBSCAN( min_samples=min_samples, min_cluster_size=5, cluster_selection_epsilon=0.0, metric="euclidean", cluster_selection_method="eom", allow_single_cluster=allow_single_cluster, ) coords = df[[x_col, y_col]].values labels = hdbscan.fit_predict(coords) # Number of clusters in labels, ignoring noise if present. n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0) n_noise_ = list(labels).count(-1) print("Estimated number of clusters: %d" % n_clusters_) print("Estimated number of noise points: %d" % n_noise_) # Update the cluster labels in the original dataframe df.loc[df.index, "cluster"] = labels def identify_points_in_proximity( df, full_df, identification_column, cluster_column="cluster", x_column="x", y_column="y", radius=200, edge_neighbours=3, plot=True, concave_hull_length_threshold=50, concavity=2, ): """ Identify points in proximity within clusters and generate result and outline DataFrames. Parameters ---------- df : pandas.DataFrame DataFrame containing the points to be processed. full_df : pandas.DataFrame Full DataFrame containing all points. identification_column : str Column name used for identification. cluster_column : str, optional Column name for cluster labels, by default "cluster". x_column : str, optional Column name for x-coordinates, by default "x". y_column : str, optional Column name for y-coordinates, by default "y". 
radius : int, optional Radius for proximity search, by default 200. edge_neighbours : int, optional Number of edge neighbours, by default 3. plot : bool, optional Whether to plot the results, by default True. concave_hull_length_threshold : int, optional Threshold for concave hull length, by default 50. Returns ------- result : pandas.DataFrame DataFrame containing the result points. outlines : pandas.DataFrame DataFrame containing the outline points. """ nbrs, unique_clusters = precompute( df, x_column, y_column, full_df, identification_column, edge_neighbours ) num_processes = max( 1, os.cpu_count() - 1 ) # Use all available CPUs minus 2, but at least 1 with Pool(processes=num_processes) as pool: results = pool.starmap( process_cluster, [ ( ( df, cluster, cluster_column, x_column, y_column, concave_hull_length_threshold, edge_neighbours, full_df, radius, plot, identification_column, concavity, ), nbrs, unique_clusters, ) for cluster in set(df[cluster_column]) - {-1} ], ) # Unpack the results result_list, outline_list = zip(*results) # Concatenate the list of DataFrames into a single result DataFrame if len(result_list) > 0: result = pd.concat(result_list) else: result = pd.DataFrame( columns=[x_column, y_column, "patch_id", identification_column] ) if len(outline_list) > 0: outlines = pd.concat(outline_list) else: outlines = pd.DataFrame( columns=[x_column, y_column, "patch_id", identification_column] ) return result, outlines # Precompute nearest neighbors model and unique clusters def precompute(df, x_column, y_column, full_df, identification_column, edge_neighbours): """ Precompute nearest neighbors and unique clusters. Parameters ---------- df : pandas.DataFrame DataFrame containing the points to be processed. x_column : str Column name for x-coordinates. y_column : str Column name for y-coordinates. full_df : pandas.DataFrame Full DataFrame containing all points. identification_column : str Column name used for identification. 
edge_neighbours : int Number of edge neighbours. Returns ------- nbrs : sklearn.neighbors.NearestNeighbors Fitted NearestNeighbors model. unique_clusters : numpy.ndarray Array of unique cluster identifiers. """ nbrs = NearestNeighbors(n_neighbors=edge_neighbours).fit(df[[x_column, y_column]]) unique_clusters = full_df[identification_column].unique() return nbrs, unique_clusters def process_cluster(args, nbrs, unique_clusters): ( df, cluster, cluster_column, x_column, y_column, concave_hull_length_threshold, edge_neighbours, full_df, radius, plot, identification_column, concavity, ) = args """ Process a single cluster to identify points in proximity and generate hull points. Parameters ---------- args : tuple Tuple containing the following elements: - df : pandas.DataFrame DataFrame containing the points to be processed. - cluster : int Cluster identifier. - cluster_column : str Column name for cluster labels. - x_column : str Column name for x-coordinates. - y_column : str Column name for y-coordinates. - concave_hull_length_threshold : int Threshold for concave hull length. - edge_neighbours : int Number of edge neighbours. - full_df : pandas.DataFrame Full DataFrame containing all points. - radius : int Radius for proximity search. - plot : bool Whether to plot the results. - identification_column : str Column name used for identification. - concavity : int Concavity parameter for hull generation. nbrs : sklearn.neighbors.NearestNeighbors Fitted NearestNeighbors model. unique_clusters : numpy.ndarray Array of unique cluster identifiers. Returns ------- prox_points : pandas.DataFrame DataFrame containing points within the proximity of the cluster. hull_nearest_neighbors : pandas.DataFrame DataFrame containing the nearest neighbors of the hull points. 
""" # Filter DataFrame for the current cluster subset = df.loc[df[cluster_column] == cluster] points = subset[[x_column, y_column]].values # Compute concave hull indexes idxes = concave_hull_indexes( points[:, :2], length_threshold=concave_hull_length_threshold, concavity=concavity, ) # Get hull points from the DataFrame hull_points = pd.DataFrame(points[idxes], columns=[x_column, y_column]) # Find nearest neighbors of hull points in the original DataFrame distances, indices = nbrs.kneighbors(hull_points[[x_column, y_column]]) hull_nearest_neighbors = df.iloc[indices.flatten()] # Convert radius to a list if it's a single value if not isinstance(radius, (list, tuple, np.ndarray)): radius_list = [radius] else: radius_list = radius # Extract hull points coordinates hull_coords = hull_nearest_neighbors[[x_column, y_column]].values # Calculate distances from all points in full_df to all hull points distances = cdist(full_df[[x_column, y_column]].values, hull_coords) # Process each radius all_prox_points = [] for r in radius_list: # Identify points within the circle for each hull point in_circle = distances <= r # Identify points from a different cluster for each hull point diff_cluster = ( full_df[identification_column].values[:, np.newaxis] != hull_nearest_neighbors[identification_column].values ) # Combine the conditions in_circle_diff_cluster = in_circle & diff_cluster # Collect all points within the circle but from a different cluster r_in_circle_diff_cluster = full_df[np.any(in_circle_diff_cluster, axis=1)] # Remove duplicates r_prox_points = r_in_circle_diff_cluster.drop_duplicates() # Add patch_id and distance_from_patch columns r_prox_points["patch_id"] = cluster r_prox_points["distance_from_patch"] = r all_prox_points.append(r_prox_points) # Combine results from all radii if all_prox_points: prox_points = pd.concat(all_prox_points, ignore_index=True) # If multiple radii were used, keep only the smallest distance for each point if len(radius_list) > 1: 
prox_points = prox_points.sort_values( "distance_from_patch" ).drop_duplicates( subset=[ col for col in prox_points.columns if col != "distance_from_patch" ] ) else: # Create empty DataFrame with appropriate columns prox_points = pd.DataFrame( columns=full_df.columns.tolist() + ["patch_id", "distance_from_patch"] ) return prox_points, hull_nearest_neighbors def identify_hull_points( df, cluster_column="cluster", x_col="x", y_col="y", concave_hull_length_threshold=50, concavity=2, ): """ Identify hull points with improved performance. Parameters ---------- df : pandas.DataFrame DataFrame containing spatial points and cluster labels. cluster_column : str, optional Column name for clusters, by default "cluster". x_col : str, optional Column name for the x-coordinate, by default "x". y_col : str, optional Column name for the y-coordinate, by default "y". concave_hull_length_threshold : int, optional Threshold for concave hull length, by default 50. concavity : int, optional Concavity parameter, by default 2. Returns ------- pandas.DataFrame DataFrame of hull points sorted by patch_id and order. 
""" clusters = sorted(set(df[cluster_column].unique()) - {-1}) if not clusters: return pd.DataFrame(columns=[x_col, y_col, "patch_id"]) hullpoints_list = [] for cluster in clusters: mask = df[cluster_column] == cluster subset = df[mask] points = subset[[x_col, y_col]].values if len(points) < 3: continue idxes = concave_hull_indexes( points, concavity=concavity, length_threshold=concave_hull_length_threshold ) hull_points = subset.iloc[idxes].reset_index(drop=True) hull_points["order"] = range(len(hull_points)) hull_points["patch_id"] = cluster hullpoints_list.append(hull_points) if not hullpoints_list: return pd.DataFrame(columns=[x_col, y_col, "patch_id"]) hull = pd.concat(hullpoints_list, ignore_index=True, sort=False) return hull.sort_values(by=["patch_id", "order"]).drop(columns="order") def convert_dataframe_to_geojson( df, output_dir, region_name=None, x="x", y="y", sample_col=None, region_col="unique_region", patch_col="patch_id", geojson_prefix="hull_coordinates", save_geojson=True, ): """ Convert a DataFrame into GeoJSON format with optional saving to file. Parameters ---------- df : pandas.DataFrame Input DataFrame with spatial coordinates. output_dir : str Directory in which GeoJSON files will be saved. region_name : str, optional Optional region name to create a subfolder, by default None. x : str, optional Column name for the x-coordinate, by default "x". y : str, optional Column name for the y-coordinate, by default "y". sample_col : str, optional Column name to separate by samples, by default None. region_col : str, optional Column name for the region, by default "unique_region". patch_col : str, optional Column name for the patch, by default "patch_id". geojson_prefix : str, optional Prefix for the GeoJSON filename, by default "hull_coordinates". save_geojson : bool, optional Whether to save the GeoJSON to disk, by default True. Returns ------- list of dict Each dictionary contains a filename and the GeoJSON feature collection. 
""" required_columns = [region_col, patch_col, x, y] if sample_col is not None: required_columns.append(sample_col) missing_cols = [col for col in required_columns if col not in df.columns] if missing_cols: raise KeyError(f"Missing required columns in dataframe: {missing_cols}") if save_geojson: if region_name is not None: region_dir = os.path.join(output_dir, f"region_{region_name}") os.makedirs(region_dir, exist_ok=True) else: region_dir = output_dir os.makedirs(region_dir, exist_ok=True) geojson_results = [] sample_values = df[sample_col].unique() if sample_col is not None else [None] for sample in sample_values: if sample is None: sample_df = df sample_label = "all" else: sample_df = df[df[sample_col] == sample] sample_search = re.search(r"\d+", str(sample)) sample_label = sample_search.group() if sample_search else str(sample) for region in sample_df[region_col].unique(): region_df = sample_df[sample_df[region_col] == region] features, skipped, region_label = process_geojson_region( region_df, region, region_col, patch_col, x, y, sample_label ) all_features = features if sample is not None: filename = f"{geojson_prefix}_sample-{sample_label}_region-{region_label}_separate_coordinates.geojson" else: filename = f"{geojson_prefix}_region-{region_label}_separate_coordinates.geojson" geojson_dict = { "type": "FeatureCollection", "features": all_features, "name": filename, } if save_geojson: geojson_path = os.path.join(region_dir, filename) with open(geojson_path, "w") as f: json.dump(geojson_dict, f) geojson_results.append({"filename": filename, "geojson": geojson_dict}) if skipped: print(f"Skipped {len(skipped)} clusters with insufficient points") return geojson_results def process_geojson_region( region_df, region, region_col, patch_col, x, y, sample_label="all" ): """ Process a single region to generate GeoJSON features. Parameters ---------- region_df : pandas.DataFrame Subset DataFrame for the region. region : any Region identifier used to extract a label. 
region_col : str Column name indicating region information. patch_col : str Column name for patch identifiers. x : str Column name for x-coordinate. y : str Column name for y-coordinate. sample_label : str, optional Label for sample grouping, by default "all". Returns ------- tuple A tuple containing: - features (list): List of GeoJSON feature dictionaries. - skipped_clusters (list): List of clusters skipped due to insufficient points. - region_label (str): Extracted region label. """ region_search = re.search(r"\d+", str(region)) region_label = region_search.group() if region_search else str(region) features = [] skipped_clusters = [] for cluster, cluster_df in region_df.groupby(patch_col): coordinates = cluster_df[[x, y]].values.tolist() if len(coordinates) >= 3: geom_type = "Polygon" coords_format = [coordinates] elif len(coordinates) == 2: geom_type = "LineString" coords_format = coordinates else: skipped_clusters.append((sample_label, region_label, cluster)) continue feature = { "type": "Feature", "properties": { "sample": ( int(sample_label) if str(sample_label).isdigit() else sample_label ), "region": ( int(region_label) if str(region_label).isdigit() else region_label ), "cluster": int(cluster) if str(cluster).isdigit() else cluster, }, "geometry": { "type": geom_type, "coordinates": coords_format, }, } features.append(feature) return features, skipped_clusters, region_label def extract_region_number(unique_region_value): """ Extract the numeric part of a region identifier. Parameters ---------- unique_region_value : int, float, or str The region value from which to extract digits. Returns ------- str The numeric region value as a string. 
""" try: if isinstance(unique_region_value, (int, float)): return str(int(unique_region_value)) digits = "".join(filter(str.isdigit, str(unique_region_value))) return str(int(digits)) if digits else str(unique_region_value) except ValueError: return str(unique_region_value) def analyze_peripheral_cells( patches_gdf, codex_gdf, buffer_distances, original_unit_scale, tolerance_distance ): """ Analyze peripheral cells with parallel processing. Parameters ---------- patches_gdf : geopandas.GeoDataFrame GeoDataFrame with patch geometries. codex_gdf : geopandas.GeoDataFrame GeoDataFrame with codex point geometries. buffer_distances : list of int List of distances to buffer. original_unit_scale : float Scale factor for the units. tolerance_distance : float Tolerance for determining peripheral regions. Returns ------- tuple A tuple containing: - results (dict): Dictionary with keys as distances and values as DataFrames of peripheral cells. - buffer_geometries (dict): Dictionary with buffer geometries for visualization. 
""" region_tasks = [] for region_name, region_patches in patches_gdf.groupby("region_numeric"): region_codex_cells = codex_gdf[ codex_gdf["unique_region_numeric"] == region_name ] if len(region_codex_cells) > 0: region_tasks.append( ( region_name, region_patches, region_codex_cells, buffer_distances, original_unit_scale, tolerance_distance, ) ) results = {dist: [] for dist in buffer_distances} buffer_geometries = {dist: [] for dist in buffer_distances} max_workers = min(os.cpu_count(), len(region_tasks)) if max_workers > 1: with concurrent.futures.ProcessPoolExecutor( max_workers=max_workers ) as executor: for region_name, region_results, region_buffers in executor.map( process_region_peripheral_cells, region_tasks ): for dist, df in region_results.items(): if not df.empty: results[dist].append(df) for dist, buffers in region_buffers.items(): buffer_geometries[dist].extend(buffers) else: for task in region_tasks: region_name, region_results, region_buffers = ( process_region_peripheral_cells(task) ) for dist, df in region_results.items(): if not df.empty: results[dist].append(df) for dist, buffers in region_buffers.items(): buffer_geometries[dist].extend(buffers) for dist in buffer_distances: if results[dist]: results[dist] = pd.concat(results[dist], ignore_index=True) else: results[dist] = pd.DataFrame() return results, buffer_geometries def save_peripheral_cells(results, unit_name, region_name, output_dir, save_csv=True): """ Save peripheral cells for each buffer distance to CSV files. Parameters ---------- results : dict Dictionary with keys as distances and values as DataFrames of peripheral cells. unit_name : str Name of the unit to include in filenames. region_name : str Region identifier used in filenames. output_dir : str Directory to save CSV files. save_csv : bool, optional Whether to save CSV files, by default True. Returns ------- pandas.DataFrame Combined DataFrame of peripheral cells from all distances. 
""" all_frames = [] if save_csv: region_dir = os.path.join(output_dir, f"region_{region_name}") os.makedirs(region_dir, exist_ok=True) for dist, data in results.items(): if data.empty: continue data["dist"] = dist all_frames.append(data) if save_csv: peripheral_path = os.path.join( region_dir, f"{unit_name}_region_{region_name}_peripheral_cells_{dist}um.csv", ) data.to_csv(peripheral_path, index=False) combined_df = ( pd.concat(all_frames, ignore_index=True) if all_frames else pd.DataFrame() ) if save_csv and not combined_df.empty: combined_path = os.path.join( region_dir, f"{unit_name}_region_{region_name}_peripheral_cells_combined.csv", ) combined_df.to_csv(combined_path, index=False) return combined_df def process_region_peripheral_cells(args): """ Process peripheral cells for a given region (for parallel processing). Parameters ---------- args : tuple A tuple containing: - region_name : any Region identifier. - region_patches : geopandas.GeoDataFrame GeoDataFrame of patches in the region. - region_codex_cells : geopandas.GeoDataFrame GeoDataFrame of codex cells in the region. - buffer_distances : list of int List of distances to buffer. - original_unit_scale : float Original unit scale for distance conversion. - tolerance_distance : float Tolerance for buffering. Returns ------- tuple A tuple containing: - region_name : any - results : dict Dictionary with peripheral cell DataFrames for each distance. - buffer_geometries : dict Dictionary with buffer geometry information. 
""" ( region_name, region_patches, region_codex_cells, buffer_distances, original_unit_scale, tolerance_distance, ) = args region_codex_cells_sindex = region_codex_cells.sindex results = {dist: [] for dist in buffer_distances} buffer_geometries = {dist: [] for dist in buffer_distances} for _, patch in region_patches.iterrows(): patch_polygon = patch["geometry"] cluster_label = patch["cluster"] patch_id = patch.name for dist in buffer_distances: scaled_dist = dist / original_unit_scale expanded_patch = patch_polygon.buffer(scaled_dist) peripheral_region = expanded_patch.difference(patch_polygon).buffer( tolerance_distance ) buffer_geometries[dist].append( { "patch_id": patch_id, "cluster": cluster_label, "original": patch_polygon, "expanded": expanded_patch, "peripheral": peripheral_region, } ) possible_matches_idx = list( region_codex_cells_sindex.intersection(peripheral_region.bounds) ) possible_matches = region_codex_cells.iloc[possible_matches_idx] mask = possible_matches.geometry.within(peripheral_region) peripheral_cells = possible_matches[mask].copy() peripheral_cells["cluster"] = cluster_label peripheral_cells["buffer_distance"] = dist peripheral_cells["patch_id"] = patch_id results[dist].append(peripheral_cells) for dist in buffer_distances: if results[dist]: results[dist] = pd.concat(results[dist], ignore_index=True) else: results[dist] = pd.DataFrame() return region_name, results, buffer_geometries def extract_unit_name(geojson): """ Extract a unit name from a GeoJSON object. Parameters ---------- geojson : dict or object with attribute 'name' The GeoJSON object or an object having a 'name' property. Returns ------- str The extracted unit name. Raises ------ ValueError If the GeoJSON does not have a 'name' property. 
""" if hasattr(geojson, "name"): file_name = geojson.name elif isinstance(geojson, dict) and "name" in geojson: file_name = geojson["name"] else: raise ValueError("GeoJSON object does not have a 'name' property") file_name_no_ext = os.path.splitext(file_name)[0] parts = file_name_no_ext.split("_") if "ppa" in parts: return "_".join(parts[: parts.index("ppa")]) return file_name_no_ext
def patch_proximity_analysis(
    adata,
    region_column,
    patch_column,
    group,
    min_cluster_size=80,
    x_column="x",
    y_column="y",
    radius=128,
    edge_neighbours=1,
    plot=True,
    savefig=False,
    output_dir="./",
    output_fname="",
    save_geojson=True,
    allow_single_cluster=True,
    method="border_cell_radius",
    concave_hull_length_threshold=50,
    concavity=2,
    original_unit_scale=1,
    tolerance_distance=0.001,
    key_name=None,
):
    """
    Performs a proximity analysis on patches of a given group within each region of a dataset.

    This function takes the cell observations of an AnnData object and runs a
    proximity analysis on one cell group (e.g. a cell type or neighborhood)
    in each region. The analysis applies DBSCAN clustering and identifies
    concave hull boundaries. With "border_cell_radius" it then finds nearby
    cells within a fixed search radius. With "hull_expansion" it instead uses
    a peripheral buffering approach. Optionally, the function can plot the
    analysis and save outputs (figures and GeoJSON).

    Parameters
    ----------
    adata : AnnData
        The annotated data matrix of shape (n_obs x n_vars). Only
        ``adata.obs`` is read; results are written to ``adata.uns``.
    region_column : str
        The name of the column in adata.obs that contains region information.
    patch_column : str
        The name of the column in adata.obs that contains patch (or group) information.
    group : str
        The specific group (e.g. cell type or patch identifier) on which the
        proximity analysis is to be performed.
    min_cluster_size : int, optional
        The minimum number of ``group`` cells required in a region to perform
        the analysis; also passed to DBSCAN as ``min_samples``. Regions with
        fewer cells are skipped. Default is 80.
    x_column : str, optional
        The column name in adata.obs corresponding to the x-coordinate. Default is "x".
    y_column : str, optional
        The column name in adata.obs corresponding to the y-coordinate. Default is "y".
    radius : int, float, or list/tuple of these, optional
        The distance(s) (in spatial units) within which points are considered
        to be in proximity. Multiplied by original_unit_scale. Default is 128.
    edge_neighbours : int, optional
        The number of neighbouring edge points to consider when identifying
        proximity relationships. Default is 1.
    plot : bool, optional
        Whether to generate and display visualizations. Default is True.
    savefig : bool, optional
        Whether to save the generated figure (PDF) instead of showing it.
        Default is False.
    output_dir : str, optional
        The directory in which to save output files. Default is "./".
    output_fname : str, optional
        The filename prefix to use when saving figures. Default is "".
    save_geojson : bool, optional
        Whether to write the patch outlines as GeoJSON files. Default is True.
    allow_single_cluster : bool, optional
        If True, allows DBSCAN to assign all cells to a single cluster.
        Default is True.
    method : str, optional
        "border_cell_radius" (default) or "hull_expansion".
    concave_hull_length_threshold : int, optional
        Threshold value used for generating the concave hull boundary. Default is 50.
    concavity : int, optional
        Degree of concavity when calculating the hull boundary. Default is 2.
    original_unit_scale : int or float, optional
        Scaling factor applied to ``radius``. Default is 1.
    tolerance_distance : float, optional
        Tolerance value for buffering in the peripheral analysis (used when
        method is "hull_expansion"). Default is 0.001.
    key_name : str, optional
        Key in adata.uns under which results are stored; results are appended
        when the key already exists. Defaults to "ppa_result".

    Returns
    -------
    final_results : pandas.DataFrame
        Combined proximity analysis results from all processed regions,
        including a "unique_patch_ID" column that concatenates the region,
        group, and patch identifier.
    outlines_results : pandas.DataFrame
        The outline (or hull) points corresponding to the patches.

    Raises
    ------
    ValueError
        If ``method`` is unknown, a required column is missing, ``group`` does
        not occur in ``patch_column``, or no region yielded any patches.

    Notes
    -----
    ``adata.obs`` itself is not modified: categorical columns are converted
    to strings on a copy.
    """
    # Validate the method before any (expensive) clustering work is done.
    if method not in ("border_cell_radius", "hull_expansion"):
        raise ValueError(
            f"Unknown method: {method}. Please choose either 'border_cell_radius' or 'hull_expansion'."
        )

    # multiply radius by original_unit_scale
    if isinstance(radius, (list, tuple)):
        radius = [r * original_unit_scale for r in radius]
    else:
        radius = radius * original_unit_scale

    os.makedirs(output_dir, exist_ok=True)

    # hull_expansion iterates over buffer distances, so always pass a list.
    # Any scalar (int, float, NumPy scalar) is wrapped, not only Python ints.
    # NOTE(review): process_region_peripheral_cells divides each distance by
    # original_unit_scale again, so the scaling cancels for hull_expansion.
    distance_from_patch = [radius] if np.ndim(radius) == 0 else list(radius)

    # Copy so the categorical->str conversion below does not alter adata.obs.
    df = adata.obs.copy()

    if region_column not in df.columns:
        raise ValueError(f"Column '{region_column}' not found in adata.obs")
    if patch_column not in df.columns:
        raise ValueError(f"Column '{patch_column}' not found in adata.obs")
    if group not in df[patch_column].unique():
        raise ValueError(f"Group '{group}' not found in column '{patch_column}'")

    for col in df.select_dtypes(["category"]).columns:
        df[col] = df[col].astype(str)

    region_results = []
    outlines = []

    for region in df[region_column].unique():
        df_region = df[df[region_column] == region].copy()
        df_community = df_region[df_region[patch_column] == group].copy()

        if df_community.shape[0] < min_cluster_size:
            print(f"No {group} in {region}")
            continue
        print(f"Processing {region}_{group}")

        region_dir = os.path.join(output_dir, f"region_{region}")
        if save_geojson or (plot and savefig):
            os.makedirs(region_dir, exist_ok=True)

        apply_dbscan_clustering(
            df_community,
            min_samples=min_cluster_size,
            x_col=x_column,
            y_col=y_column,
            allow_single_cluster=allow_single_cluster,
        )

        hull = identify_hull_points(
            df_community,
            cluster_column="cluster",
            x_col=x_column,
            y_col=y_column,
            concave_hull_length_threshold=concave_hull_length_threshold,
            concavity=concavity,
        )

        if hull.empty:
            print(f"No clusters found for region {region}.")
            continue

        # NOTE(review): the GeoJSON conversions below assume the hull carries
        # "unique_region" and "patch_id" columns (from identify_hull_points,
        # not visible here) - verify when region_column != "unique_region".
        if method == "border_cell_radius":
            results, hull_nearest_neighbors = identify_points_in_proximity(
                df=df_community,
                full_df=df_region,
                cluster_column="cluster",
                identification_column=patch_column,
                x_column=x_column,
                y_column=y_column,
                radius=radius,
                edge_neighbours=edge_neighbours,
                plot=plot,
                concave_hull_length_threshold=concave_hull_length_threshold,
                concavity=concavity,
            )
            outlines.append(hull_nearest_neighbors)

            if save_geojson:
                convert_dataframe_to_geojson(
                    df=hull,
                    output_dir=output_dir,
                    region_name=region,
                    x=x_column,
                    y=y_column,
                    region_col="unique_region",
                    patch_col="patch_id",
                    save_geojson=save_geojson,
                )

            if plot:
                # The lambda is invoked immediately by the helper, so it
                # always sees this iteration's variables.
                _save_or_show_ppa_figure(
                    lambda: create_visualization_border_cell_radius(
                        region_name=region,
                        group_name=group,
                        df_community=df_community,
                        df_full=df_region,
                        cluster_column="cluster",
                        identification_column=patch_column,
                        x_column=x_column,
                        y_column=y_column,
                        radius=radius,
                        hull_points=hull,
                        proximity_results=results,
                        hull_neighbors=hull_nearest_neighbors,
                    ),
                    region=region,
                    savefig=savefig,
                    plot_dir=region_dir,
                    output_fname=output_fname,
                )

            print(f"Finished {region}_{group}")
            region_results.append(results)

        else:  # method == "hull_expansion" (validated above)
            geojson_results = convert_dataframe_to_geojson(
                df=hull,
                output_dir=output_dir,
                region_name=region,
                x=x_column,
                y=y_column,
                region_col="unique_region",
                patch_col="patch_id",
                save_geojson=save_geojson,
            )

            patches_gdf = gpd.GeoDataFrame.from_features(
                geojson_results[0]["geojson"]["features"]
            )
            codex_points = gpd.points_from_xy(
                df_region[x_column], df_region[y_column]
            )
            codex_gdf = gpd.GeoDataFrame(df_region, geometry=codex_points)
            # from_features() is called without a CRS; only propagate a real one.
            if patches_gdf.crs is not None:
                codex_gdf.set_crs(patches_gdf.crs, inplace=True)

            patches_gdf["region_numeric"] = patches_gdf["region"].apply(
                extract_region_number
            )
            codex_gdf["unique_region_numeric"] = df_region[region_column].apply(
                extract_region_number
            )

            buffer_results, buffer_geometries = analyze_peripheral_cells(
                patches_gdf=patches_gdf,
                codex_gdf=codex_gdf,
                buffer_distances=distance_from_patch,
                original_unit_scale=original_unit_scale,
                tolerance_distance=tolerance_distance,
            )

            unit_name = extract_unit_name(geojson_results[0]["geojson"])
            combined_df = save_peripheral_cells(
                results=buffer_results,
                unit_name=unit_name,
                region_name=region,
                output_dir=output_dir,
                save_csv=False,
            )

            combined_df = combined_df.rename(columns={"dist": "distance_from_patch"})
            combined_df = combined_df.drop(
                columns=["unique_region_numeric", "buffer_distance"],
                errors="ignore",
            )
            # With no peripheral cells the frame has no columns at all, and
            # drop_duplicates(subset=...) would raise KeyError.
            if not combined_df.empty:
                combined_df = combined_df.drop_duplicates(
                    subset=[x_column, y_column, "patch_id", "distance_from_patch"]
                )

            region_results.append(combined_df)
            outlines.append(hull)

            if plot:
                _save_or_show_ppa_figure(
                    lambda: create_visualization_hull_expansion(
                        region=region,
                        group=group,
                        df_community=df_community,
                        hull=hull,
                        patches_gdf=patches_gdf,
                        df_region=df_region,
                        buffer_geometries=buffer_geometries,
                        peripheral_results=buffer_results,
                        x_column=x_column,
                        y_column=y_column,
                        buffer_distances=distance_from_patch,
                        original_unit_scale=original_unit_scale,
                    ),
                    region=region,
                    savefig=savefig,
                    plot_dir=region_dir,
                    output_fname=output_fname,
                )

    if not region_results:
        raise ValueError(
            f"No patches of group '{group}' found: every region had fewer than "
            f"{min_cluster_size} such cells or no detectable clusters."
        )

    final_results = pd.concat(region_results)
    outlines_results = pd.concat(outlines)

    # Build the ID before storing in adata.uns so stored rows also carry it.
    if final_results.empty:
        final_results["unique_patch_ID"] = pd.Series(dtype=str)
    else:
        final_results["unique_patch_ID"] = (
            final_results[region_column].astype(str)
            + "_"
            + final_results[patch_column].astype(str)
            + "_"
            + "patch_no_"
            + final_results["patch_id"].astype(str)
        )

    if key_name is None:
        key_name = "ppa_result"
    if key_name in adata.uns:
        adata.uns[key_name] = pd.concat([adata.uns[key_name], final_results])
    else:
        adata.uns[key_name] = final_results

    return final_results, outlines_results


def _save_or_show_ppa_figure(make_figure, region, savefig, plot_dir, output_fname):
    """
    Build a patch-proximity figure and either save it as a PDF or show it.

    Plotting is best-effort: any exception while building, saving or showing
    the figure is printed as a warning and does not abort the analysis. A
    figure that was created is always closed afterwards to free memory.

    Parameters
    ----------
    make_figure : callable
        Zero-argument callable returning a matplotlib Figure.
    region : any
        Region identifier, used in the filename and warning message.
    savefig : bool
        Save to ``plot_dir`` when True, otherwise call ``plt.show()``.
    plot_dir : str
        Directory for the saved PDF.
    output_fname : str
        Filename prefix for the saved PDF.
    """
    fig = None
    try:
        fig = make_figure()
        if savefig:
            plot_path = os.path.join(
                plot_dir,
                f"{output_fname}_patch_proximity_analysis_region_{region}.pdf",
            )
            fig.savefig(plot_path, bbox_inches="tight", dpi=300, format="pdf")
            print(f"Saved visualization to: {plot_path}")
        else:
            plt.show()
    except Exception as e:  # deliberate: a plotting failure must not stop the run
        print(
            f"Warning: Failed to create visualization for region {region}: {str(e)}"
        )
    finally:
        if fig is not None:
            plt.close(fig)
def _ppa_padded_bounds(df, x_column, y_column, padding=0.05):
    """Return ``(x_bounds, y_bounds)`` enclosing every point of *df* with a margin.

    Each bound is a ``[low, high]`` list widened on both sides by ``padding``
    times the data span, so all panels of a figure can share one view. When a
    coordinate has zero span (all points aligned) a margin of one coordinate
    unit is used instead, so Matplotlib never receives identical limits.

    Parameters
    ----------
    df : pandas.DataFrame
        Points whose extent defines the bounds.
    x_column, y_column : str
        Coordinate column names in *df*.
    padding : float, optional
        Fraction of the span added on each side. Default is 0.05 (5%).

    Returns
    -------
    tuple of (list, list)
        ``[x_low, x_high]`` and ``[y_low, y_high]``.
    """

    def _bounds(values):
        low, high = values.min(), values.max()
        span = high - low
        margin = padding * span if span > 0 else 1.0
        return [low - margin, high + margin]

    return _bounds(df[x_column]), _bounds(df[y_column])


def _ppa_exterior_xy(geometry):
    """Yield the ``(x, y)`` exterior coordinate arrays of every polygon in *geometry*.

    ``Polygon`` yields its own exterior; ``MultiPolygon`` and
    ``GeometryCollection`` are recursed into part by part. Anything else
    (points, lines, ``None``) has no exterior and yields nothing.
    """
    geom_type = getattr(geometry, "geom_type", None)
    if geom_type == "Polygon":
        yield geometry.exterior.xy
    elif geom_type in ("MultiPolygon", "GeometryCollection"):
        for part in geometry.geoms:
            yield from _ppa_exterior_xy(part)


def _ppa_circle_alpha(num_circles, max_alpha=0.6, step=0.005, min_alpha=0.2):
    """Return an outline alpha that fades linearly with the number of circles.

    The value starts at ``max_alpha``, drops by ``step`` per circle, and is
    clamped at ``min_alpha`` so dense plots stay readable. It never increases
    as ``num_circles`` grows.
    """
    return max(min_alpha, max_alpha - num_circles * step)


def _ppa_cluster_colormap(n_clusters):
    """Return a ListedColormap with one tab20-derived colour per cluster."""
    colors = plt.cm.tab20(np.linspace(0, 1, max(20, n_clusters)))
    # Keep at least one colour so an empty clustering still yields a valid colormap.
    return mcolors.ListedColormap(colors[: max(n_clusters, 1)])


def _ppa_format_axis(ax, title, x_column, y_column, x_bounds, y_bounds):
    """Apply the shared title, labels, grid, aspect ratio and limits to one panel."""
    ax.set_title(title, fontsize=14)
    ax.set_xlabel(x_column, fontsize=12)
    ax.set_ylabel(y_column, fontsize=12)
    ax.grid(alpha=0.3)
    ax.set_aspect("equal")
    ax.set_xlim(x_bounds)
    ax.set_ylim(y_bounds)


def _ppa_add_cluster_key(fig, ax, scatter, cluster_cmap, unique_clusters):
    """Add a per-cluster legend (10 or fewer clusters) or a colorbar (more) to *ax*."""
    if len(unique_clusters) <= 10:
        # The legend stays attached as ``ax.legend_``; also passing it to
        # ``add_artist`` would make the axes draw it a second time.
        ax.legend(
            handles=[
                Patch(color=cluster_cmap(i), label=f"Cluster {cluster}")
                for i, cluster in enumerate(unique_clusters)
            ],
            title="Clusters",
            loc="best",
            frameon=True,
            bbox_to_anchor=(1.02, 1),
            fontsize=10,
        )
    else:
        cbar = fig.colorbar(scatter, ax=ax, pad=0.01, shrink=0.8)
        cbar.set_label("Cluster ID")


def _ppa_legend_if_labeled(ax, **legend_kwargs):
    """Draw a legend on *ax* only if some artist carries a label (avoids a warning)."""
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(**legend_kwargs)


def _ppa_finalize_figure(fig, title, explanation_text):
    """Add the figure title and footer note, then lay the panels out around them."""
    fig.suptitle(title, fontsize=18, y=0.98)
    fig.text(
        0.5,
        0.01,
        explanation_text,
        ha="center",
        va="bottom",
        fontsize=12,
        wrap=True,
        bbox=dict(boxstyle="round,pad=0.5", fc="lightyellow", ec="orange", alpha=0.8),
    )
    # Reserve the bottom 3% for the footer and the top 5% for the suptitle.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])


def create_visualization_hull_expansion(
    region,
    group,
    df_community,
    hull,
    patches_gdf,
    df_region,
    buffer_geometries,
    peripheral_results,
    x_column,
    y_column,
    buffer_distances,
    original_unit_scale,
    figsize=(20, 16),
):
    """
    Create comprehensive visualization of the patch proximity analysis.

    This function creates a four-panel figure that visualizes the workflow of
    the hull-expansion analysis: (1) the original clustering of cells,
    (2) hull points and the resulting patch polygons, (3) the expanded buffer
    zones around each polygon, and (4) the peripheral cells detected in those
    buffer zones.

    Parameters
    ----------
    region : str
        Name of the region to be analyzed.
    group : str
        Name of the cell group (or category) under investigation.
    df_community : pandas.DataFrame
        DataFrame containing the subset of cells (community) used for
        clustering. Must contain a ``"cluster"`` column; rows labelled ``-1``
        are treated as noise and not drawn.
    hull : pandas.DataFrame
        DataFrame with hull points that form the concave boundaries of clusters.
    patches_gdf : geopandas.GeoDataFrame
        GeoDataFrame containing the patch geometries (with a ``"cluster"``
        column) for visualization. Polygon and MultiPolygon geometries are
        supported.
    df_region : pandas.DataFrame
        DataFrame containing all cells in the region for contextual plotting;
        its extent defines the shared view of all panels.
    buffer_geometries : dict
        Dictionary mapping each buffer distance to a list of dicts with keys
        ``"patch_id"`` (matched against the index of ``patches_gdf``) and
        ``"expanded"`` (the expanded geometry).
    peripheral_results : dict
        Dictionary mapping each buffer distance to a DataFrame of peripheral
        cells detected within the buffer zones.
    x_column : str
        Column name for the x-coordinate in the DataFrames.
    y_column : str
        Column name for the y-coordinate in the DataFrames.
    buffer_distances : int or list of int
        Buffer distance(s) (in spatial units) used for creating buffer zones.
        A single number is treated as a one-element list.
    original_unit_scale : float
        Scale factor representing the unit conversion (e.g., 1 unit = N µm).
        Not used by the plotting itself; accepted for interface compatibility
        with the analysis call.
    figsize : tuple, optional
        Size of the generated figure (width, height in inches). Default is (20, 16).

    Returns
    -------
    matplotlib.figure.Figure
        A Matplotlib Figure object containing the multi-panel visualization.

    Notes
    -----
    - Clusters use Matplotlib's tab20 palette; buffer distances use plasma.
    - Buffer zones are drawn as dashed outlines of the expanded polygons.
    - With more than 10 clusters, panel 1 shows a colorbar instead of a legend.
    """
    if not isinstance(buffer_distances, (list, tuple, np.ndarray)):
        buffer_distances = [buffer_distances]
    buffer_distances = list(buffer_distances)

    # Only clustered cells are coloured; -1 marks noise.
    df_filtered = df_community[df_community["cluster"] != -1]
    unique_clusters = sorted(df_filtered["cluster"].unique())
    n_clusters = len(unique_clusters)
    cluster_cmap = _ppa_cluster_colormap(n_clusters)
    cluster_position = {cluster: i for i, cluster in enumerate(unique_clusters)}

    def patch_color(cluster):
        # Patches whose cluster is not among the clustered cells use the first colour.
        return cluster_cmap(cluster_position.get(cluster, 0))

    buffer_colors = {
        dist: plt.cm.plasma(i / len(buffer_distances))
        for i, dist in enumerate(buffer_distances)
    }
    markers = ["o", "s", "^", "d", "*"]  # circle, square, triangle, diamond, star

    x_bounds, y_bounds = _ppa_padded_bounds(df_region, x_column, y_column)

    # Use a single layout engine (tight, applied at the end). Creating the figure
    # with constrained_layout=True and then calling tight_layout() warns, and
    # raises once a colorbar has been added (matplotlib >= 3.6).
    fig, axes = plt.subplots(2, 2, figsize=figsize)
    ax1, ax2, ax3, ax4 = axes.flatten()

    # PANEL 1: clustering, drawn on top of all region cells for context.
    ax1.scatter(
        df_region[x_column],
        df_region[y_column],
        color="lightgray",
        alpha=0.3,
        s=10,
        label="All cells",
    )
    scatter1 = ax1.scatter(
        df_filtered[x_column],
        df_filtered[y_column],
        c=df_filtered["cluster"],
        cmap=cluster_cmap,
        alpha=0.7,
        s=30,
        edgecolor="none",
    )
    _ppa_format_axis(
        ax1,
        f"Clustering of {group} cells in Region {region}",
        x_column,
        y_column,
        x_bounds,
        y_bounds,
    )
    _ppa_add_cluster_key(fig, ax1, scatter1, cluster_cmap, unique_clusters)
    ax1.annotate(
        f"Number of clusters: {n_clusters}",
        xy=(0.02, 0.98),
        xycoords="axes fraction",
        fontsize=11,
        ha="left",
        va="top",
        bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8),
    )

    # PANEL 2: hull points and the patch polygons built from them.
    ax2.scatter(
        df_filtered[x_column],
        df_filtered[y_column],
        c=df_filtered["cluster"],
        cmap=cluster_cmap,
        alpha=0.3,
        s=20,
    )
    if not hull.empty:
        ax2.scatter(
            hull[x_column],
            hull[y_column],
            color="red",
            s=50,
            label="Hull Points",
            edgecolor="black",
            linewidth=1,
            alpha=0.8,
        )
    for _, patch in patches_gdf.iterrows():
        geometry = patch.geometry
        if geometry is None or geometry.is_empty:
            continue
        color = patch_color(patch["cluster"])
        for x, y in _ppa_exterior_xy(geometry):
            ax2.plot(x, y, color=color, linewidth=2, alpha=0.9)
        # Cluster label at the centroid; the white stroke keeps it readable.
        centroid = geometry.centroid
        txt = ax2.text(
            centroid.x,
            centroid.y,
            f"C{patch['cluster']}",
            fontsize=10,
            ha="center",
            va="center",
            fontweight="bold",
            color="black",
        )
        txt.set_path_effects([PathEffects.withStroke(linewidth=3, foreground="white")])
    _ppa_format_axis(
        ax2,
        "Hull Points and Resulting Polygons",
        x_column,
        y_column,
        x_bounds,
        y_bounds,
    )
    _ppa_legend_if_labeled(ax2, loc="best", bbox_to_anchor=(1.02, 1))

    # PANEL 3: original polygons plus the dashed outline of each buffer zone.
    for idx, patch in patches_gdf.iterrows():
        color = patch_color(patch["cluster"])
        for x, y in _ppa_exterior_xy(patch.geometry):
            ax3.plot(x, y, color=color, linewidth=2, alpha=0.7)
        for dist in buffer_distances:
            for buffer_geom in buffer_geometries.get(dist, []):
                # Buffers are matched to patches via the GeoDataFrame index.
                if buffer_geom["patch_id"] != idx:
                    continue
                for x, y in _ppa_exterior_xy(buffer_geom["expanded"]):
                    ax3.plot(
                        x,
                        y,
                        color=buffer_colors[dist],
                        linewidth=1.5,
                        linestyle="--",
                        alpha=0.7,
                    )
    _ppa_format_axis(
        ax3,
        "Buffer Zones Around Polygons",
        x_column,
        y_column,
        x_bounds,
        y_bounds,
    )
    ax3.legend(
        handles=[
            Patch(color=buffer_colors[dist], alpha=0.7, label=f"{dist} unit buffer")
            for dist in buffer_distances
        ],
        loc="best",
        bbox_to_anchor=(1.02, 1),
    )

    # PANEL 4: peripheral cells found in each buffer zone.
    for _, patch in patches_gdf.iterrows():
        color = patch_color(patch["cluster"])
        for x, y in _ppa_exterior_xy(patch.geometry):
            ax4.plot(x, y, color=color, linewidth=2, alpha=0.7)
    for dist_idx, dist in enumerate(buffer_distances):
        peripheral_cells = peripheral_results.get(dist)
        if peripheral_cells is None or peripheral_cells.empty:
            continue
        ax4.scatter(
            peripheral_cells[x_column],
            peripheral_cells[y_column],
            color=buffer_colors[dist],
            marker=markers[dist_idx % len(markers)],
            s=40,
            alpha=0.7,
            edgecolor="black",
            linewidth=0.5,
            label=f"Peripheral cells ({dist} unit buffer)",
        )
    _ppa_format_axis(
        ax4,
        "Detected Peripheral Cells by Buffer Distance",
        x_column,
        y_column,
        x_bounds,
        y_bounds,
    )
    _ppa_legend_if_labeled(ax4, loc="best", bbox_to_anchor=(1.02, 1))

    explanation_text = (
        f"This analysis identifies clusters of {group} cells, creates boundary polygons, and detects "
        f"nearby cells within {', '.join(map(str, buffer_distances))} unit buffer zones."
    )
    _ppa_finalize_figure(
        fig,
        f"Patch Proximity Analysis for {group} in Region {region}",
        explanation_text,
    )
    return fig


def create_visualization_border_cell_radius(
    region_name,
    group_name,
    df_community,
    df_full,
    cluster_column="cluster",
    identification_column=None,
    x_column="x",
    y_column="y",
    radius=200,
    hull_points=None,
    proximity_results=None,
    hull_neighbors=None,
    figsize=(20, 16),
):
    """
    Create a multi-panel visualization for the border cell radius proximity analysis method.

    The figure has four panels: (1) the original clustering with noise points
    marked, (2) the concave hull boundary of each cluster, (3) the search
    circles drawn around every hull point for each radius, and (4) the cells
    from other categories found near the cluster boundaries. All panels share
    the same view, derived from the extent of ``df_full``.

    Parameters
    ----------
    region_name : str
        Name of the region being analyzed.
    group_name : str
        Name of the cell group (or category) under investigation.
    df_community : pandas.DataFrame
        DataFrame containing the subset of cells (community) used for
        clustering; rows whose cluster label is ``-1`` are shown as noise.
    df_full : pandas.DataFrame
        DataFrame containing all the cells in the region for context in plots.
    cluster_column : str, optional
        Column name for cluster labels in df_community. Default is "cluster".
    identification_column : str, optional
        Column name used to colour proximity results by cell category.
        Default is None.
    x_column : str, optional
        Column name for the x-coordinate in the DataFrames. Default is "x".
    y_column : str, optional
        Column name for the y-coordinate in the DataFrames. Default is "y".
    radius : int or list of int, optional
        The radius or list of radii (in the same units as the coordinates)
        used for the proximity search. Default is 200.
    hull_points : pandas.DataFrame, optional
        Points that form the concave hull boundaries. If a ``"patch_id"``
        column is present it is matched against the cluster labels; if an
        ``"order"`` column is present it defines the boundary order.
        Default is None.
    proximity_results : pandas.DataFrame, optional
        Cells found near the hull points. If it has a ``"distance_from_patch"``
        column and several radii are given, each radius is drawn with its own
        marker. Default is None.
    hull_neighbors : pandas.DataFrame, optional
        Hull points around which the search circles are drawn. Default is None.
    figsize : tuple, optional
        Size of the generated figure in inches (width, height). Default is (20, 16).

    Returns
    -------
    matplotlib.figure.Figure
        A Matplotlib Figure object that contains the generated multi-panel visualization.

    Notes
    -----
    - Search radii are drawn as dashed ``Circle`` patches; their alpha fades as
      the number of circles grows, down to a floor of 0.2.
    - Clusters use Matplotlib's tab20 palette; each proximity category keeps
      one Set2 colour across all radii.
    - A single radius is internally converted to a one-element list.
    """
    from matplotlib.lines import Line2D
    from matplotlib.patches import Circle

    df_filtered = df_community[df_community[cluster_column] != -1]
    noise_points = df_community[df_community[cluster_column] == -1]
    unique_clusters = sorted(df_filtered[cluster_column].unique())
    n_clusters = len(unique_clusters)
    cluster_cmap = _ppa_cluster_colormap(n_clusters)

    if isinstance(radius, (list, tuple, np.ndarray)):
        radius_list = list(radius)
    else:
        radius_list = [radius]
    radius_colors = plt.cm.plasma(np.linspace(0, 0.8, len(radius_list)))
    radius_color_map = {r: radius_colors[i] for i, r in enumerate(radius_list)}
    radii_text = ", ".join(map(str, radius_list))
    markers = ["o", "s", "^", "d", "*"]  # circle, square, triangle, diamond, star

    x_bounds, y_bounds = _ppa_padded_bounds(df_full, x_column, y_column)

    # Single layout engine (tight, applied at the end); see the hull-expansion
    # plot for why constrained_layout is not combined with tight_layout().
    fig, axes = plt.subplots(2, 2, figsize=figsize)
    ax1, ax2, ax3, ax4 = axes.flatten()

    # PANEL 1: clustering with noise points marked by 'x'.
    ax1.scatter(
        df_full[x_column],
        df_full[y_column],
        color="lightgray",
        alpha=0.3,
        s=10,
        label="All cells",
    )
    scatter1 = ax1.scatter(
        df_filtered[x_column],
        df_filtered[y_column],
        c=df_filtered[cluster_column],
        cmap=cluster_cmap,
        alpha=0.7,
        s=30,
        edgecolor="none",
    )
    if len(noise_points) > 0:
        ax1.scatter(
            noise_points[x_column],
            noise_points[y_column],
            color="gray",
            marker="x",
            s=20,
            alpha=0.5,
            label="Noise points",
        )
    _ppa_format_axis(
        ax1,
        f"HDBSCAN Clustering of {group_name} in Region {region_name}",
        x_column,
        y_column,
        x_bounds,
        y_bounds,
    )
    _ppa_add_cluster_key(fig, ax1, scatter1, cluster_cmap, unique_clusters)
    ax1.annotate(
        f"Number of clusters: {n_clusters}\n" f"Noise points: {len(noise_points)}",
        xy=(0.02, 0.98),
        xycoords="axes fraction",
        fontsize=11,
        ha="left",
        va="top",
        bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8),
    )

    # PANEL 2: concave hull points and the closed boundary through them.
    ax2.scatter(
        df_filtered[x_column],
        df_filtered[y_column],
        c=df_filtered[cluster_column],
        cmap=cluster_cmap,
        alpha=0.3,
        s=20,
    )
    if hull_points is not None and len(hull_points) > 0:
        has_patch_ids = "patch_id" in hull_points.columns
        # Without patch ids the hull cannot be split per cluster: draw it once.
        clusters_to_draw = unique_clusters if has_patch_ids else unique_clusters[:1]
        points_labelled = False
        boundary_labelled = False
        for cluster_idx, cluster in enumerate(clusters_to_draw):
            if has_patch_ids:
                cluster_hull = hull_points[hull_points["patch_id"] == cluster]
            else:
                cluster_hull = hull_points
            if len(cluster_hull) == 0:
                continue
            ax2.scatter(
                cluster_hull[x_column],
                cluster_hull[y_column],
                color="red",
                s=50,
                edgecolor="black",
                linewidth=1,
                alpha=0.8,
                label="" if points_labelled else "Hull Points",
            )
            points_labelled = True
            if len(cluster_hull) > 2:
                ring = (
                    cluster_hull.sort_values("order")
                    if "order" in cluster_hull.columns
                    else cluster_hull
                )
                ring_x = ring[x_column].to_numpy()
                ring_y = ring[y_column].to_numpy()
                # Repeat the first vertex so the boundary is drawn closed.
                ax2.plot(
                    np.append(ring_x, ring_x[0]),
                    np.append(ring_y, ring_y[0]),
                    color=cluster_cmap(cluster_idx),
                    linestyle="-",
                    linewidth=2,
                    alpha=0.7,
                    label="" if boundary_labelled else f"Hull Boundary (C{cluster})",
                )
                boundary_labelled = True
    _ppa_format_axis(
        ax2,
        "Concave Hull Boundary Detection",
        x_column,
        y_column,
        x_bounds,
        y_bounds,
    )
    ax2.annotate(
        "Concave hull forms the\nouter boundary of each cluster",
        xy=(0.02, 0.02),
        xycoords="axes fraction",
        fontsize=11,
        ha="left",
        va="bottom",
        bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8),
    )
    _ppa_legend_if_labeled(ax2, loc="best", bbox_to_anchor=(1.02, 1))

    # PANEL 3: search circle of every radius around every hull point.
    ax3.scatter(
        df_full[x_column], df_full[y_column], color="lightgray", alpha=0.1, s=10
    )
    ax3.scatter(
        df_filtered[x_column],
        df_filtered[y_column],
        c=df_filtered[cluster_column],
        cmap=cluster_cmap,
        alpha=0.3,
        s=20,
    )
    has_neighbors = hull_neighbors is not None and len(hull_neighbors) > 0
    if has_neighbors:
        circle_alpha = _ppa_circle_alpha(len(hull_neighbors) * len(radius_list))
        for cx, cy in zip(hull_neighbors[x_column], hull_neighbors[y_column]):
            for r in radius_list:
                ax3.add_patch(
                    Circle(
                        (cx, cy),
                        r,
                        color=radius_color_map[r],
                        fill=False,
                        linestyle="--",
                        linewidth=1.0,
                        alpha=circle_alpha,
                    )
                )
        # Hull points go on top of the circles.
        ax3.scatter(
            hull_neighbors[x_column],
            hull_neighbors[y_column],
            color="red",
            s=40,
            edgecolor="black",
            linewidth=0.8,
            alpha=0.8,
            label="Hull Points",
        )
    _ppa_format_axis(
        ax3,
        f"Search Radii ({radii_text} units) from Hull Points",
        x_column,
        y_column,
        x_bounds,
        y_bounds,
    )
    if hull_neighbors is not None:
        ax3.annotate(
            f"All {len(hull_neighbors)} hull points search for\n"
            f"neighboring cells within {radii_text} units",
            xy=(0.02, 0.02),
            xycoords="axes fraction",
            fontsize=11,
            ha="left",
            va="bottom",
            bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8),
        )
    if has_neighbors and radius_list:
        # Unattached proxy lines: hidden patches would copy their invisibility
        # into the legend and show no marker.
        ax3.legend(
            handles=[
                Line2D(
                    [],
                    [],
                    color=radius_color_map[r],
                    linestyle="--",
                    linewidth=1.0,
                    alpha=0.8,
                    label=f"Search radius ({r} units)",
                )
                for r in radius_list
            ],
            loc="upper right",
        )

    # PANEL 4: cells of other categories found near the cluster boundaries.
    ax4.scatter(
        df_full[x_column],
        df_full[y_column],
        color="lightgray",
        alpha=0.15,
        s=10,
        label="All cells",
    )
    ax4.scatter(
        df_filtered[x_column],
        df_filtered[y_column],
        c=df_filtered[cluster_column],
        cmap=cluster_cmap,
        alpha=0.4,
        s=20,
    )
    if hull_points is not None and len(hull_points) > 0:
        ax4.scatter(
            hull_points[x_column],
            hull_points[y_column],
            color="red",
            s=30,
            alpha=0.7,
            edgecolor="black",
            linewidth=0.5,
            label="Hull Points",
        )

    def scatter_categories(points, marker, label_suffix):
        # One Set2 colour per category, shared across radii for consistency.
        for category, color in category_palette:
            subset = points[points[identification_column] == category]
            if len(subset) == 0:
                continue
            ax4.scatter(
                subset[x_column],
                subset[y_column],
                color=color,
                marker=marker,
                s=80,
                alpha=0.8,
                edgecolor="black",
                linewidth=0.5,
                label=f"{category}{label_suffix}",
            )

    if proximity_results is not None and len(proximity_results) > 0:
        category_palette = None
        if (
            identification_column is not None
            and identification_column in proximity_results.columns
        ):
            categories = proximity_results[identification_column].unique()
            category_palette = list(
                zip(categories, plt.cm.Set2(np.linspace(0, 1, len(categories))))
            )

        if "distance_from_patch" in proximity_results.columns and len(radius_list) > 1:
            for r_idx, r in enumerate(radius_list):
                r_points = proximity_results[proximity_results["distance_from_patch"] == r]
                if len(r_points) == 0:
                    continue
                marker = markers[r_idx % len(markers)]
                if category_palette is not None:
                    scatter_categories(r_points, marker, f" ({r} units)")
                else:
                    ax4.scatter(
                        r_points[x_column],
                        r_points[y_column],
                        color=radius_color_map[r],
                        marker=marker,
                        s=80,
                        alpha=0.8,
                        edgecolor="black",
                        linewidth=0.5,
                        label=f"Proximity Points ({r} units)",
                    )
        elif category_palette is not None:
            scatter_categories(proximity_results, "o", "")
        else:
            ax4.scatter(
                proximity_results[x_column],
                proximity_results[y_column],
                color="gold",
                marker="o",
                s=80,
                alpha=0.8,
                edgecolor="black",
                linewidth=0.5,
                label="Proximity Points",
            )
    _ppa_format_axis(
        ax4,
        "Detected Points in Proximity to Clusters",
        x_column,
        y_column,
        x_bounds,
        y_bounds,
    )
    prox_count = len(proximity_results) if proximity_results is not None else 0
    ax4.annotate(
        f"Found {prox_count} cells from different\ncategories near cluster boundaries",
        xy=(0.02, 0.02),
        xycoords="axes fraction",
        fontsize=11,
        ha="left",
        va="bottom",
        bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8),
    )
    _ppa_legend_if_labeled(ax4, loc="best", bbox_to_anchor=(1.02, 1))

    explanation_text = (
        "This analysis: (1) Clusters cells using HDBSCAN algorithm, (2) Identifies the concave hull boundary of each cluster, "
        f"(3) For each hull point, searches for cells within {radii_text} units, and "
        "(4) Identifies cells from different categories that are in proximity to cluster boundaries."
    )
    _ppa_finalize_figure(
        fig,
        f"Border Cell Radius Proximity Analysis for {group_name} in Region {region_name}",
        explanation_text,
    )
    return fig