# load required packages
import os as os
import pathlib
import textwrap
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import skimage
import tensorly as tl
from scipy import stats
from statsmodels.stats.multicomp import pairwise_tukeyhsd
from tensorly.decomposition import non_negative_tucker
from ..helperfunctions._general import *
# Setup
sns.set_style("ticks")
# plotting functions
############################################################
"""
This is a function that creates a stacked bar plot of the percentage of observations in each category in a dataset. The input data can be normalized by a grouping variable or not, and the output can be saved as a file.
data: a Pandas DataFrame containing the data to be plotted.
per_cat: a string representing the column name containing the categories to be plotted.
grouping: a string representing the column name used to group the data.
cell_list: a list of strings representing the categories to be plotted.
output_dir: a string representing the output directory to save the plot.
norm: a boolean value indicating whether to normalize the data or not (default: True).
save_name: a string representing the filename to save the plot (default: None).
col_order: a list of strings representing the order of the columns in the plot (default: None).
sub_col: a string representing the column name used to subset the data (default: None).
name_cat: a string representing the name of the category column in the plot (default: 'Cell Type').
fig_sizing: a tuple representing the size of the plot (default: (8,4)).
plot_order: a list of strings representing the order of the categories in the plot (default: None).
color_dic: a dictionary containing color codes for the categories in the plot (default: None).
remove_leg: a boolean value indicating whether to remove the legend or not (default: False).
The function returns a Pandas DataFrame and a list of strings. The DataFrame contains the data used to create the plot, and the list of strings represents the order of the categories in the plot.
"""
def pl_stacked_bar_plot(
data,
per_cat,
grouping,
cell_list,
output_dir,
norm=True,
save_name=None,
col_order=None,
sub_col=None,
name_cat="Cell Type",
fig_sizing=(8, 4),
plot_order=None,
color_dic=None,
remove_leg=False,
):
"""
Plot a stacked bar plot based on the given data.
Parameters
----------
data : pandas.DataFrame
The input data containing the necessary information for plotting.
per_cat : str
The column name representing the categories.
grouping : str
The column name representing the grouping.
cell_list : list
The list of cell types to include in the plot.
output_dir : str
The output directory for saving the plot.
norm : bool, optional
Flag indicating whether to normalize the values. Defaults to True.
save_name : str, optional
The name to use when saving the plot. Defaults to None.
col_order : list, optional
The order of columns/categories for plotting. Defaults to None.
sub_col : str, optional
The column name representing sub-categories. Defaults to None.
name_cat : str, optional
The name for the category column in the plot. Defaults to 'Cell Type'.
fig_sizing : tuple, optional
The size of the figure (width, height) in inches. Defaults to (8, 4).
plot_order : list, optional
The order of categories for plotting. Defaults to None.
color_dic : dict, optional
A dictionary mapping categories to colors for custom colorization. Defaults to None.
remove_leg : bool, optional
Flag indicating whether to remove the legend. Defaults to False.
Returns
-------
pandas.DataFrame
The pivoted data used for plotting.
list
The order of categories used for plotting.
"""
# Find Percentage of cell type
if norm == True:
if sub_col is None:
test1 = data.loc[data[per_cat].isin(cell_list)]
sub_cell_list = list(test1[per_cat].unique())
else:
test1 = data.loc[data[sub_col].isin(cell_list)]
sub_cell_list = list(test1[per_cat].unique())
else:
if sub_col is None:
test1 = data.copy()
sub_cell_list = list(
data.loc[data[per_cat].isin(cell_list)][per_cat].unique()
)
else:
test1 = data.copy()
sub_cell_list = list(
data.loc[data[sub_col].isin(cell_list)][per_cat].unique()
)
test1[per_cat] = test1[per_cat].astype("category")
test_freq = test1.groupby(grouping).apply(
lambda x: x[per_cat].value_counts(normalize=True, sort=False) * 100
)
test_freq.columns = test_freq.columns.astype(str)
##### Can subset it here if I do not want normalized per the group
test_freq.reset_index(inplace=True)
sub_cell_list.append(grouping)
test_freq = test_freq[sub_cell_list]
melt_test = pd.melt(
test_freq, id_vars=[grouping]
) # , value_vars=test_freq.columns)
melt_test.rename(columns={per_cat: name_cat, "value": "percent"}, inplace=True)
if norm == True:
if col_order is None:
bb = melt_test.groupby([grouping, per_cat]).sum().reset_index()
col_order = (
bb.loc[bb[per_cat] == bb[per_cat][0]]
.sort_values(by="percent")[grouping]
.to_list()
)
else:
if col_order is None:
col_order = (
melt_test.groupby(grouping)
.sum()
.reset_index()
.sort_values(by="percent")[grouping]
.to_list()
)
if plot_order is None:
plot_order = list(melt_test[per_cat].unique())
# Set up for plotting
melt_test_piv = pd.pivot_table(
melt_test, columns=[name_cat], index=[grouping], values=["percent"]
)
melt_test_piv.columns = melt_test_piv.columns.droplevel(0)
melt_test_piv.reset_index(inplace=True)
melt_test_piv.set_index(grouping, inplace=True)
melt_test_piv = melt_test_piv.reindex(col_order)
melt_test_piv = melt_test_piv[plot_order]
# Get color dictionary
if color_dic is None:
# first subplot
ax1 = melt_test_piv.plot.bar(
alpha=0.8,
linewidth=1,
figsize=fig_sizing,
rot=90,
stacked=True,
edgecolor="black",
)
else:
# first subplot
ax1 = melt_test_piv.plot.bar(
alpha=0.8,
linewidth=1,
color=[color_dic.get(x) for x in melt_test_piv.columns],
figsize=fig_sizing,
rot=90,
stacked=True,
edgecolor="black",
)
for line in ax1.lines:
line.set_color("black")
ax1.spines["top"].set_visible(False)
ax1.spines["right"].set_visible(False)
if remove_leg == True:
ax1.set_ylabel("")
ax1.set_xlabel("")
else:
ax1.set_ylabel("percent")
# ax1.spines['left'].set_position(('data', 1.0))
# ax1.set_xticks(np.arange(1,melt_test.day.max()+1,1))
# ax1.set_ylim([0, int(ceil(max(max(melt_test_piv.sum(axis=1)), max(tm_piv.sum(axis=1)))))])
plt.xticks(
list(range(len(list(melt_test_piv.index)))),
list(melt_test_piv.index),
rotation=90,
)
lgd2 = ax1.legend(
loc="center left", bbox_to_anchor=(1.0, 0.5), ncol=1, frameon=False
)
if save_name:
plt.savefig(
output_dir + save_name + "_stacked_barplot.pdf",
dpi=300,
transparent=True,
bbox_inches="tight",
)
return melt_test_piv, plot_order
#############
"""
data: pandas DataFrame containing the data to be plotted
grouping: name of the column containing the grouping variable for the swarm boxplot
replicate: name of the column containing the replicate variable for the swarm boxplot
sub_col: name of the column containing the subsetting variable for the swarm boxplot
sub_list: list of values to subset the data by
per_cat: name of the column containing the categorical variable for the swarm boxplot
output_dir: directory where the output plot will be saved
norm: boolean (default True) to normalize data by subsetting variable before plotting
figure_sizing: tuple (default (10,5)) containing the size of the output plot
save_name: name of the file to save the output plot (if output_dir is provided)
plot_order: list of values to specify the order of the horizontal axis
col_in: list of values to subset the data by the per_cat column
color_dic: seaborn color palette for the boxplot and swarmplot
flip: boolean (default False) to flip the orientation of the plot
"""
# This function creates a box plot and swarm plot from the given data
# and returns a plot object.
def pl_swarm_box(
data,
grouping,
per_cat,
replicate,
sub_col,
sub_list,
output_dir,
norm=True,
figure_sizing=(10, 5),
save_name=None,
plot_order=None,
col_in=None,
color_dic=None,
flip=False,
):
# Find Percentage of cell type
test = data.copy()
sub_list1 = sub_list.copy()
if norm == True:
test1 = test.loc[test[sub_col].isin(sub_list1)]
immune_list = list(test1[per_cat].unique())
else:
test1 = test.copy()
immune_list = list(test.loc[test[sub_col].isin(sub_list1)][per_cat].unique())
test1[per_cat] = test1[per_cat].astype("category")
test_freq = test1.groupby([grouping, replicate]).apply(
lambda x: x[per_cat].value_counts(normalize=True, sort=False) * 100
)
test_freq.columns = test_freq.columns.astype(str)
test_freq.reset_index(inplace=True)
immune_list.extend([grouping, replicate])
test_freq1 = test_freq[immune_list]
melt_per_plot = pd.melt(
test_freq1,
id_vars=[
grouping,
replicate,
],
) # ,value_vars=immune_list)
melt_per_plot.rename(columns={"value": "percentage"}, inplace=True)
if col_in:
melt_per_plot = melt_per_plot.loc[melt_per_plot[per_cat].isin(col_in)]
else:
melt_per_plot = melt_per_plot
if plot_order is None:
plot_order = list(melt_per_plot[grouping].unique())
else:
# Order by average
plot_order = (
melt_per_plot.groupby(per_cat)
.mean()
.reset_index()
.sort_values(by="percentage")[per_cat]
.to_list()
)
# swarmplot to compare clustering
plt.figure(figsize=figure_sizing)
if flip == True:
plt.figure(figsize=figure_sizing)
if color_dic is None:
ax = sns.boxplot(
data=melt_per_plot,
x=grouping,
y="percentage",
dodge=True,
order=plot_order,
)
ax = sns.swarmplot(
data=melt_per_plot,
x=grouping,
y="percentage",
dodge=True,
order=plot_order,
edgecolor="black",
linewidth=1,
color="white",
palette=color_dic,
)
else:
ax = sns.boxplot(
data=melt_per_plot,
x=grouping,
y="percentage",
dodge=True,
order=plot_order,
palette=color_dic,
)
ax = sns.swarmplot(
data=melt_per_plot,
x=grouping,
y="percentage",
dodge=True,
order=plot_order,
edgecolor="black",
linewidth=1,
palette=color_dic,
)
for patch in ax.artists:
r, g, b, a = patch.get_facecolor()
patch.set_facecolor((r, g, b, 0.3))
plt.xticks(rotation=90)
plt.xlabel("")
plt.ylabel("")
plt.title(sub_list[0])
sns.despine()
else:
if color_dic is None:
ax = sns.boxplot(
data=melt_per_plot,
x=grouping,
y="percentage",
dodge=True,
order=plot_order,
)
ax = sns.swarmplot(
data=melt_per_plot,
x=grouping,
y="percentage",
dodge=True,
order=plot_order,
edgecolor="black",
linewidth=1,
color="white",
)
else:
ax = sns.boxplot(
data=melt_per_plot,
x=grouping,
y="percentage",
dodge=True,
order=plot_order,
palette=color_dic,
)
ax = sns.swarmplot(
data=melt_per_plot,
x=grouping,
y="percentage",
dodge=True,
order=plot_order,
edgecolor="black",
linewidth=1,
palette=color_dic,
)
for patch in ax.artists:
r, g, b, a = patch.get_facecolor()
patch.set_facecolor((r, g, b, 0.3))
# ax.set_yscale(\log\)
plt.xlabel("")
handles, labels = ax.get_legend_handles_labels()
plt.legend(
handles[: len(melt_per_plot[grouping].unique())],
labels[: len(melt_per_plot[grouping].unique())],
bbox_to_anchor=(1.05, 1),
loc=2,
borderaxespad=0.0,
frameon=False,
)
plt.xticks(rotation=90)
ax.set(ylim=(0, melt_per_plot["percentage"].max() + 1))
sns.despine()
if output_dir:
if save_name:
plt.savefig(
output_dir + save_name + "_swarm_boxplot.pdf",
dpi=300,
transparent=True,
bbox_inches="tight",
)
else:
print("define save_name")
else:
print("plot was not saved - to save the plot specify an output directory")
return melt_per_plot
#############
def pl_Shan_div(
tt,
test_results,
res,
grouping,
color_dic,
sub_list,
output_dir,
save=False,
plot_order=None,
fig_size=1.5,
):
"""
Plot Shannon Diversity using boxplot and swarmplot.
Parameters
----------
tt : unused
Not used in the function.
test_results : float
The p-value from the statistical test.
res : pandas.DataFrame
The input data containing the results and grouping information.
grouping : str
The column name representing the grouping.
color_dic : dict
A dictionary mapping groups to colors for custom colorization.
sub_list : list
The list of sub-groups.
output_dir : str
The output directory for saving the plots.
save : bool, optional
Flag indicating whether to save the plots. Defaults to False.
plot_order : list, optional
The order of groups for plotting. Defaults to None.
fig_size : float, optional
The size of the figure. Defaults to 1.5.
Returns
-------
pandas.DataFrame or bool
The Tukey's test results if the p-value is less than 0.05, otherwise False.
"""
# Order by average
if color_dic is None:
if plot_order is None:
plot_order = res[grouping].unique()
else:
plot_order = plot_order
# Plot the swarmplot of results
plt.figure(figsize=(fig_size, 3))
ax = sns.boxplot(
data=res, x=grouping, y="Shannon Diversity", dodge=True, order=plot_order
)
ax = sns.swarmplot(
data=res,
x=grouping,
y="Shannon Diversity",
dodge=True,
order=plot_order,
edgecolor="black",
linewidth=1,
color="white",
)
else:
if plot_order is None:
plot_order = res[grouping].unique()
else:
plot_order = plot_order
# Plot the swarmplot of results
plt.figure(figsize=(fig_size, 3))
ax = sns.boxplot(
data=res,
x=grouping,
y="Shannon Diversity",
dodge=True,
order=plot_order,
palette=color_dic,
)
ax = sns.swarmplot(
data=res,
x=grouping,
y="Shannon Diversity",
dodge=True,
order=plot_order,
edgecolor="black",
linewidth=1,
palette=color_dic,
)
for patch in ax.artists:
r, g, b, a = patch.get_facecolor()
patch.set_facecolor((r, g, b, 0.3))
plt.xticks(rotation=90)
plt.xlabel("")
plt.ylabel("Shannon Diversity")
plt.title("")
sns.despine()
if save == True:
plt.savefig(
output_dir + sub_list[0] + "_Shannon.pdf",
dpi=300,
transparent=True,
bbox_inches="tight",
)
plt.show()
if test_results < 0.05:
plt.figure(figsize=(fig_size, fig_size))
tukey = pairwise_tukeyhsd(
endog=res["Shannon Diversity"], groups=res[grouping], alpha=0.05
)
tukeydf = pd.DataFrame(
data=tukey._results_table.data[1:], columns=tukey._results_table.data[0]
)
tukedf_rev = tukeydf.copy()
tukedf_rev.rename(
columns={"group1": "groupa", "group2": "groupb"}, inplace=True
)
tukedf_rev.rename(
columns={"groupa": "group2", "groupb": "group1"}, inplace=True
)
tukedf_rev = tukedf_rev[tukeydf.columns]
tukey_all = pd.concat([tukedf_rev, tukeydf])
# Plot with tissue order preserved
table1 = pd.pivot_table(
tukey_all, values="p-adj", index=["group1"], columns=["group2"]
)
table1 = table1[plot_order]
table1 = table1.reindex(plot_order)
plt.figure(figsize=(5, 5))
ax = sns.heatmap(table1, cmap="coolwarm", center=0.05, vmax=0.05)
ax.set_title("Shannon Diversity")
ax.set_ylabel("")
ax.set_xlabel("")
if save == True:
plt.savefig(
output_dir + sub_list[0] + "_tukey.png",
format="png",
dpi=300,
transparent=True,
bbox_inches="tight",
)
plt.show()
else:
table1 = False
#############
def pl_cell_type_composition_vis(
data,
sample_column="sample",
cell_type_column="Cell Type",
figsize=(10, 10),
output_dir=None,
):
"""
Visualize cell type composition using stacked and unstacked bar plots.
Parameters
----------
data : pandas.DataFrame
The input data containing the sample and cell type information.
sample_column : str, optional
The column name representing the sample. Defaults to "sample".
cell_type_column : str, optional
The column name representing the cell type. Defaults to "Cell Type".
figsize : tuple, optional
The size of the figure (width, height) in inches. Defaults to (10, 10).
output_dir : str, optional
The output directory for saving the plots. Defaults to None.
Returns
-------
None
"""
if output_dir == None:
print("You have defined no output directory!")
# plotting option1
# pd.crosstab(df['sample'], df['final_cell_types']).plot(kind='barh', stacked=True,figsize = (10,12))
# plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
# plt.show()
# plotting option2
ax = pd.crosstab(data[sample_column], data[cell_type_column]).plot(
kind="barh", stacked=True, figsize=figsize
)
ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
fig = ax.get_figure()
ax.set(xlabel="count")
plt.savefig(output_dir + "/cell_types_composition_hstack.png", bbox_inches="tight")
# plotting option1
# pd.crosstab(df['sample'], df['final_cell_types']).plot(kind='barh', figsize = (10,10))
# plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
# plt.show()
# plotting option2
ax = pd.crosstab(data[sample_column], data[cell_type_column]).plot(
kind="barh", stacked=False, figsize=figsize
)
ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
fig = ax.get_figure()
ax.set(xlabel="count")
plt.savefig(
output_dir + "/cell_types_composition_hUNstack.png", bbox_inches="tight"
)
# Cell type percentage
st = pd.crosstab(data[sample_column], data[cell_type_column])
df_perc = (st / np.sum(st, axis=1)[:, None]) * 100
df_perc
# df_perc['sample'] = df_perc.index
# df_perc
tmp = st.T.apply(lambda x: 100 * x / x.sum())
ax = tmp.T.plot(kind="barh", stacked=True, figsize=figsize)
ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
fig = ax.get_figure()
ax.set(xlabel="percentage")
plt.savefig(
output_dir + "/cell_types_composition_perc_hstack.png", bbox_inches="tight"
)
##############
def pl_regions_per_sample(data, sample_col, region_col, bar_color="grey"):
# Group the dataframe by the specified sample column and count the unique regions
region_counts = data.groupby(sample_col)[region_col].nunique()
# Create a bar chart with the specified color
plt.bar(region_counts.index, region_counts.values, color=bar_color)
# Set chart title and axis labels
plt.title("Count of Unique Regions per Sample")
plt.xlabel("Samples")
plt.ylabel("Unique Regions")
# Rotate x-axis labels if needed
plt.xticks(rotation=45)
# Display the chart
plt.show()
##############
def pl_neighborhood_analysis_2(
data,
k_centroids,
values,
sum_cols,
X="x",
Y="y",
reg="unique_region",
output_dir=None,
k=35,
plot_specific_neighborhoods=None,
size=3,
axis="on",
ticks_fontsize=15,
show_spatial_plots=True,
palette="tab20",
):
"""
Perform neighborhood analysis and visualize results.
Parameters
----------
data : pandas.DataFrame
The input data containing the neighborhood information.
k_centroids : numpy.ndarray
The centroids of the neighborhoods.
values : numpy.ndarray
The values associated with each cell.
sum_cols : list
The column names to sum.
X : str, optional
The column name representing the x-coordinate. Defaults to 'x'.
Y : str, optional
The column name representing the y-coordinate. Defaults to 'y'.
reg : str, optional
The column name representing the region. Defaults to 'unique_region'.
output_dir : str, optional
The output directory for saving the plots. Defaults to None.
k : int, optional
The number of neighborhoods. Defaults to 35.
plot_specific_neighborhoods : bool or int, optional
Flag indicating whether to plot specific neighborhoods or not. If True, all neighborhoods will be plotted.
If an integer, only the specified neighborhood will be plotted. Defaults to None.
Returns
-------
None
"""
if show_spatial_plots == True:
# modify figure size aesthetics for each neighborhood
figs = pl_catplot(
data,
X=X,
Y=Y,
exp=reg,
hue="neighborhood" + str(k),
invert_y=True,
size=size,
axis=axis,
ticks_fontsize=ticks_fontsize,
palette=palette,
)
# Save Plots for Publication
for n, f in enumerate(figs):
f.savefig(output_dir + "neighborhood_" + str(k) + "_id{}.png".format(n))
# this plot shows the types of cells (ClusterIDs) in the different niches (0-9)
k_to_plot = k
niche_clusters = k_centroids[k_to_plot]
tissue_avgs = values.mean(axis=0)
fc = np.log2(
(
(niche_clusters + tissue_avgs)
/ (niche_clusters + tissue_avgs).sum(axis=1, keepdims=True)
)
/ tissue_avgs
)
fc = pd.DataFrame(fc, columns=sum_cols)
s = sns.clustermap(
fc, vmin=-3, vmax=3, cmap="bwr", row_colors=sns.color_palette(palette, len(fc))
)
s.savefig(output_dir + "celltypes_perniche_" + "_" + str(k) + ".png", dpi=600)
if plot_specific_neighborhoods is True:
# this plot shows the types of cells (ClusterIDs) in the different niches (0-9)
k_to_plot = k
niche_clusters = k_centroids[k_to_plot]
tissue_avgs = values.mean(axis=0)
fc = np.log2(
(
(niche_clusters + tissue_avgs)
/ (niche_clusters + tissue_avgs).sum(axis=1, keepdims=True)
)
/ tissue_avgs
)
fc = pd.DataFrame(fc, columns=sum_cols)
s = sns.clustermap(
fc.iloc[plot_specific_neighborhoods, :], vmin=-3, vmax=3, cmap="bwr"
)
s.savefig(output_dir + "celltypes_perniche_" + "_" + str(k) + ".png", dpi=600)
##############
def pl_highlighted_dot(
df,
x_col,
y_col,
group_col,
highlight_group,
highlight_color="red",
region_col="unique_region",
subset_col=None,
subset_list=None,
):
"""
Plots an XY dot plot colored by a grouping column for each unique region.
Parameters
----------
df : pandas.DataFrame
Input DataFrame.
x_col : str
Name of the column to be plotted on the x-axis.
y_col : str
Name of the column to be plotted on the y-axis.
group_col : str
Name of the column used for grouping and coloring the dots.
highlight_group : object
Value of the group to be highlighted.
highlight_color : str, optional
Color of the dots for the highlighted group (default: "red").
region_col : str, optional
Name of the column with information about the unique regions (default: "unique_region").
subset_col : str, optional
Name of the column to subset the data (default: None).
subset_list : list, optional
List of values to subset the data from the subset column (default: None).
Returns
-------
None
"""
# Create a colormap dictionary for coloring dots
colormap = {highlight_group: highlight_color, "default": "grey"}
# Subset the data based on the subset column and list
if subset_col and subset_list:
df = df[df[subset_col].isin(subset_list)]
unique_regions = df[region_col].unique()
# Determine the number of plots and the grid layout
num_plots = len(unique_regions)
if len(unique_regions) > 3:
num_cols = len(unique_regions) // 2 # Number of columns in the grid
else:
num_cols = 2
num_rows = (num_plots - 1) // num_cols + 1 # Number of rows in the grid
# Create the figure and subplots
fig, axs = plt.subplots(num_rows, num_cols, figsize=(15, 5 * num_rows))
# Iterate over unique regions and corresponding subplots
for region, ax in zip(unique_regions, axs.flatten()):
# Filter the dataframe for the current region
region_df = df[df[region_col] == region]
# Iterate over unique groups in the current region
for group in region_df[group_col].unique():
# Get x and y values for the group
x = region_df.loc[region_df[group_col] == group, x_col]
y = region_df.loc[region_df[group_col] == group, y_col]
# Get the color for the group
color = colormap.get(group, colormap["default"])
# Set the alpha value for red dots
alpha = 0.7 if group == highlight_group else 1.0
# Plot the dots
ax.scatter(x, y, color=color, label=group, alpha=alpha, zorder=2, s=1)
# Remove legend
ax.legend().remove()
# Set axis labels
ax.set_xlabel(x_col)
ax.set_ylabel(y_col)
# Set the title to the current region
ax.set_title(f"Region: {region}")
# Show the plot
plt.tight_layout()
plt.show()
##############
def pl_create_pie_charts(
data,
group_column,
count_column,
plot_order=None,
show_percentages=True,
color_dict=None,
):
"""
Create pie charts for each group based on a grouping column, showing the percentage of total rows based on a
count column.
Parameters:
data (pd.DataFrame): The input DataFrame.
group_column (str): The column name for grouping the data.
count_column (str): The column name used for counting occurrences.
plot_order (list, optional): The order of groups for plotting. Defaults to None.
show_percentages (bool, optional): Whether to show the percentage numbers on the pie charts. Defaults to True.
color_dict (dict, optional): A dictionary to manually set colors for neighborhoods. Defaults to None.
Returns:
None
"""
# Group the data by the grouping column
grouped_data = data.groupby(group_column)
# Sort the groups based on the plot_order if provided
if plot_order:
grouped_data = sorted(grouped_data, key=lambda x: plot_order.index(x[0]))
# Calculate the number of rows and columns for subplots
num_groups = len(grouped_data)
num_cols = 3 # Number of columns for subplots
num_rows = (num_groups - 1) // num_cols + 1
# Create subplots
fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 5 * num_rows))
axes = axes.flatten() # Flatten the subplots array
# Create a color dictionary if not provided
if color_dict is None:
color_dict = {}
# Generate a color dictionary for consistent colors if not provided
if not color_dict:
neighborhoods = data[count_column].unique()
color_cycle = plt.cm.tab20.colors
color_dict = {
neighborhood: color_cycle[i % 20]
for i, neighborhood in enumerate(neighborhoods)
}
# Iterate over each group and create a pie chart
for i, (group, group_df) in enumerate(grouped_data):
# Count the occurrences of each neighborhood within the group
neighborhood_counts = group_df[count_column].value_counts()
# Calculate the percentage of total rows for each neighborhood
percentages = neighborhood_counts / group_df.shape[0] * 100
# Create a color list for neighborhoods in the count column
colors = [
color_dict.get(neighborhood, "gray") for neighborhood in percentages.index
]
if show_percentages:
wedges, texts, autotexts = axes[i].pie(
percentages, labels=percentages.index, autopct="%1.1f%%", colors=colors
)
axes[i].set_title(f"Group: {group}")
else:
wedges, texts = axes[i].pie(
percentages, labels=percentages.index, colors=colors
)
axes[i].set_title(f"Group: {group}")
# Remove unused subplots
for j in range(num_groups, num_rows * num_cols):
fig.delaxes(axes[j])
# Adjust spacing between subplots
fig.tight_layout()
# Show the plot
plt.show()
##############
def pl_cell_types_de(data, pvals, neigh_num, output_dir, figsize=(20, 10)):
"""
Plot cell types differential expression as a heatmap.
Parameters
----------
data : pandas.DataFrame
The input data containing the differential expression values.
pvals : numpy.ndarray
The p-values associated with the differential expression values.
neigh_num : dict
A dictionary mapping neighborhood numbers to labels.
output_dir : str
The output directory for saving the plot.
figsize : tuple, optional
The size of the figure (width, height) in inches. Defaults to (20, 10).
Returns
-------
None
"""
# plot as heatmap
f, ax = plt.subplots(figsize=figsize)
g = sns.heatmap(data, cmap="bwr", vmin=-1, vmax=1, cbar=False, ax=ax)
for a, b in zip(*np.where(pvals < 0.05)):
plt.text(b + 0.5, a + 0.55, "*", fontsize=20, ha="center", va="center")
plt.tight_layout()
inv_map = {v: k for k, v in neigh_num.items()}
inv_map
# plot as heatmap
plt.style.use(["default"])
# GENERAL GRAPH SETTINGs
# font size of graph
SMALL_SIZE = 14
MEDIUM_SIZE = 16
BIGGER_SIZE = 18
# Settings for graph
plt.rc("font", size=SMALL_SIZE) # controls default text sizes
plt.rc("axes", titlesize=SMALL_SIZE) # fontsize of the axes title
plt.rc("axes", labelsize=MEDIUM_SIZE) # fontsize of the x and y labels
plt.rc("xtick", labelsize=SMALL_SIZE) # fontsize of the tick labels
plt.rc("ytick", labelsize=SMALL_SIZE) # fontsize of the tick labels
plt.rc("legend", fontsize=SMALL_SIZE) # legend fontsize
plt.rc("figure", titlesize=BIGGER_SIZE) # fontsize of the figure title
data_2 = data.rename(index=inv_map)
# Sort both axes
sort_sum = data_2.abs().sum(axis=1).to_frame()
sort_sum.columns = ["sum_col"]
xx = sort_sum.sort_values(by="sum_col")
sort_x = xx.index.values.tolist()
sort_sum_y = data_2.abs().sum(axis=0).to_frame()
sort_sum_y.columns = ["sum_col"]
yy = sort_sum_y.sort_values(by="sum_col")
sort_y = yy.index.values.tolist()
df_sort = data_2.reindex(index=sort_x, columns=sort_y)
f, ax = plt.subplots(figsize=figsize)
g = sns.heatmap(df_sort, cmap="bwr", vmin=-1, vmax=1, cbar=True, ax=ax)
for a, b in zip(*np.where(pvals < 0.05)):
plt.text(b + 0.5, a + 0.55, "*", fontsize=20, ha="center", va="center")
plt.tight_layout()
f.savefig(
output_dir + "tissue_neighborhood_coeff_pvalue_bar.png",
format="png",
dpi=300,
transparent=True,
bbox_inches="tight",
)
df_sort.abs().sum()
##############
def pl_community_analysis_2(
data,
values,
sum_cols,
output_dir,
# neighborhood_name,
k_centroids,
X="x",
Y="y",
reg="unique_region",
save_path=None,
k=100,
size=3,
axis="on",
ticks_fontsize=15,
plot_specific_community=None,
show_spatial_plots=True,
palette="tab20",
):
"""
Plot community analysis.
Parameters
----------
data : pandas.DataFrame
The input data containing the community information.
output_dir : str
The output directory for saving the plots.
neighborhood_name : str
The name of the neighborhood.
figsize : tuple, optional
The size of the figure. Defaults to (10, 10).
plot_specific_community : bool, optional
Whether to plot a specific community. Defaults to None.
Returns
-------
None
"""
output_dir2 = output_dir + "community_analysis/"
if not os.path.exists(output_dir2):
os.makedirs(output_dir2)
# cells = data.copy()
# #modify figure size aesthetics for each neighborhood
# plt.rcParams["legend.markerscale"] = 10
# figs = pl_catplot(cells,X = X,Y=Y,exp = reg,
# hue = neighborhood_name,invert_y=True,size = size,figsize=8, axis=axis, ticks_fontsize=ticks_fontsize)
# #Save Plots for Publication
# for n,f in enumerate(figs):
# f.savefig(output_dir2+neighborhood_name+'_id{}.png'.format(n))
# if plot_specific_community is True:
# #this plot shows the types of cells (ClusterIDs) in the different niches (0-9)
# k_to_plot = k
# niche_clusters = (k_centroids[k_to_plot])
# tissue_avgs = values.mean(axis = 0)
# fc = np.log2(((niche_clusters+tissue_avgs)/(niche_clusters+tissue_avgs).sum(axis = 1, keepdims = True))/tissue_avgs)
# fc = pd.DataFrame(fc,columns = sum_cols)
# s=sns.clustermap(fc.iloc[plot_specific_community,:], vmin =-3,vmax = 3,cmap = 'bwr',figsize=(10,5))
# s.savefig(output_dir2+"celltypes_perniche_"+"_"+str(k)+".png", dpi=600)
# #this plot shows the types of cells (ClusterIDs) in the different niches (0-9)
# k_to_plot = k
# niche_clusters = (k_centroids[k_to_plot])
# tissue_avgs = values.mean(axis = 0)
# fc = np.log2(((niche_clusters+tissue_avgs)/(niche_clusters+tissue_avgs).sum(axis = 1, keepdims = True))/tissue_avgs)
# fc = pd.DataFrame(fc,columns = sum_cols)
# s=sns.clustermap(fc, vmin =-3,vmax = 3,cmap = 'bwr', figsize=(10,10))
# s.savefig(output_dir2+"celltypes_perniche_"+"_"+str(k)+".png", dpi=600)
if show_spatial_plots == True:
# modify figure size aesthetics for each neighborhood
figs = pl_catplot(
data,
X=X,
Y=Y,
exp=reg,
hue="community" + str(k),
invert_y=True,
size=size,
axis=axis,
ticks_fontsize=ticks_fontsize,
palette=palette,
)
# Save Plots for Publication
for n, f in enumerate(figs):
f.savefig(output_dir + "community_" + str(k) + "_id{}.png".format(n))
# this plot shows the types of cells (ClusterIDs) in the different niches (0-9)
k_to_plot = k
niche_clusters = k_centroids[k_to_plot]
tissue_avgs = values.mean(axis=0)
fc = np.log2(
(
(niche_clusters + tissue_avgs)
/ (niche_clusters + tissue_avgs).sum(axis=1, keepdims=True)
)
/ tissue_avgs
)
fc = pd.DataFrame(fc, columns=sum_cols)
s = sns.clustermap(
fc, vmin=-3, vmax=3, cmap="bwr", row_colors=sns.color_palette(palette, len(fc))
)
s.savefig(output_dir + "celltypes_perniche_" + "_" + str(k) + ".png", dpi=600)
if plot_specific_community is True:
# this plot shows the types of cells (ClusterIDs) in the different niches (0-9)
k_to_plot = k
niche_clusters = k_centroids[k_to_plot]
tissue_avgs = values.mean(axis=0)
fc = np.log2(
(
(niche_clusters + tissue_avgs)
/ (niche_clusters + tissue_avgs).sum(axis=1, keepdims=True)
)
/ tissue_avgs
)
fc = pd.DataFrame(fc, columns=sum_cols)
s = sns.clustermap(
fc.iloc[plot_specific_neighborhoods, :], vmin=-3, vmax=3, cmap="bwr"
)
s.savefig(output_dir + "celltypes_perniche_" + "_" + str(k) + ".png", dpi=600)
###############
"""
This function visualizes the results of Canonical Correlation Analysis (CCA) using a graph.
The function takes in several parameters including the CCA results, the save path for the resulting plot, whether or not to save the plot, a p-value threshold, a name for the plot file, and a color palette to use for the nodes.
The function first creates an empty Petersen graph and then iterates over each pair of cell types in the CCA results.
For each pair, it calculates the observed correlation and the correlation for a set of permuted samples.
If the p-value for the observed correlation is less than the specified threshold, it adds an edge to the graph between the two cell types, weighted by the p-value.
The function then uses the graphviz_layout function to position the nodes in the graph and assigns a color to each node based on the specified color palette.
It then iterates over each edge in the graph and sets its alpha and linewidth based on the weight of the edge. Finally, it saves the resulting plot to the specified save path if save_fig is True.
Overall, this function provides a way to visually represent the relationships between cell types in the CCA results, allowing for a better understanding of the underlying patterns and correlations in the data.
"""
def pl_Visulize_CCA_results(
CCA_results,
output_dir,
save_fig=False,
p_thresh=0.1,
save_name="CCA_vis.png",
colors=None,
):
"""
Visualize the results of Canonical Correlation Analysis (CCA) using a graph.
Parameters
----------
CCA_results : dict
Dictionary containing the CCA results, where the keys are cell type pairs and
the values are tuples of (observed correlation, permuted correlations).
output_dir : str
The output directory for saving the plot.
save_fig : bool, optional
Whether to save the plot. Defaults to False.
p_thresh : float, optional
The p-value threshold for adding edges to the graph. Defaults to 0.1.
save_name : str, optional
The name of the plot file. Defaults to "CCA_vis.png".
colors : list or None, optional
The color palette for the nodes. If None, a default palette will be used.
Returns
-------
None
"""
# Visualization of CCA
g1 = nx.petersen_graph()
for cn_pair, cc in CCA_results.items():
s, t = cn_pair
obs, perms = cc
p = np.mean(obs > perms)
if p > p_thresh:
g1.add_edge(s, t, weight=p)
if colors != None:
pal = colors
else:
pal = sns.color_palette("bright", 50)
pos = nx.nx_agraph.graphviz_layout(g1, prog="neato")
for k, v in pos.items():
x, y = v
plt.scatter([x], [y], c=[pal[k]], s=300, zorder=3)
# plt.text(x,y, k, fontsize = 10, zorder = 10,ha = 'center', va = 'center')
plt.axis("off")
for e0, e1 in g1.edges():
p = g1.get_edge_data(e0, e1, default=0)
if len(p) == 0:
p = 0
else:
p = p["weight"]
print(p)
alpha = 3 * p**1
if alpha > 1:
alpha = 1
plt.plot(
[pos[e0][0], pos[e1][0]],
[pos[e0][1], pos[e1][1]],
c="black",
alpha=alpha,
linewidth=3 * p**3,
)
if save_fig == True:
plt.savefig(
output_dir + "/" + save_name,
format="png",
dpi=300,
transparent=True,
bbox_inches="tight",
)
#######
def pl_plot_modules_heatmap(
data, cns, cts, figsize=(20, 5), num_tissue_modules=2, num_cn_modules=5
):
"""
Plot the modules and their loadings using heatmaps.
Parameters
----------
data : array-like
The input data.
cns : list
The names of the copy number alterations (CNs).
cts : list
The names of the cell types.
figsize : tuple, optional
The figure size. Defaults to (20, 5).
num_tissue_modules : int, optional
The number of tissue modules. Defaults to 2.
num_cn_modules : int, optional
The number of CN modules. Defaults to 5.
Returns
-------
None
"""
figsize = figsize
core, factors = non_negative_tucker(
data, rank=[num_tissue_modules, num_cn_modules, num_cn_modules], random_state=32
)
plt.subplot(1, 2, 1)
sns.heatmap(pd.DataFrame(factors[1], index=cns))
plt.ylabel("CN")
plt.xlabel("CN module")
plt.title("Loadings onto CN modules")
plt.subplot(1, 2, 2)
sns.heatmap(pd.DataFrame(factors[2], index=cts))
plt.ylabel("CT")
plt.xlabel("CT module")
plt.title("Loadings onto CT modules")
plt.show()
figsize = (num_tissue_modules * 3, 3)
for p in range(num_tissue_modules):
plt.subplot(1, num_tissue_modules, p + 1)
sns.heatmap(pd.DataFrame(core[p]))
plt.title("tissue module {}, couplings".format(p))
plt.ylabel("CN module")
plt.ylabel("CT module")
plt.show()
#######
"""
This is a Python function that generates a graphical representation of modules discovered in a dataset using non-negative matrix factorization (NMF).
The function takes as input a dataset ('dat'), lists of tissue types ('cts') and copy number segments ('cns'), and parameters specifying the number of tissue and copy number modules to identify ('num_tissue_modules' and 'num_cn_modules', respectively).
The function then performs NMF on the input dataset and plots a separate graph for each tissue module.
In each graph, the function displays the copy number segments and tissue types as scatter points, with the color of each point representing the degree to which that segment or type belongs to the corresponding module.
The function also draws rectangles and lines to visually separate the different modules and indicate the strength of the relationships between copy number segments and tissue types within each module.
The resulting plots can be saved to a specified file path and name using the 'save_path' and 'save_name' arguments.
"""
def pl_plot_modules_graphical(
data,
cts,
cns,
num_tissue_modules=2,
num_cn_modules=4,
scale=0.4,
color_dic=None,
save_name=None,
save_path=None,
):
"""
Generate a graphical representation of modules discovered in a dataset using non-negative matrix factorization (NMF).
Parameters
----------
data : array-like
The input dataset.
cts : list
The list of tissue types.
cns : list
The list of copy number segments.
num_tissue_modules : int, optional
The number of tissue modules to identify. Defaults to 2.
num_cn_modules : int, optional
The number of copy number modules to identify. Defaults to 4.
scale : float, optional
The scaling factor for the plot. Defaults to 0.4.
color_dic : dict, optional
The color palette dictionary. Defaults to None.
save_name : str, optional
The name for saving the plots. Defaults to None.
save_path : str, optional
The file path for saving the plots. Defaults to None.
Returns
-------
None
"""
core, factors = non_negative_tucker(
data, rank=[num_tissue_modules, num_cn_modules, num_cn_modules], random_state=32
)
if color_dic is None:
color_dic = sns.color_palette("bright", 10)
palg = sns.color_palette("Greys", 10)
figsize = (3.67 * scale, 2.00 * scale)
cn_scatter_size = scale * scale * 45
cel_scatter_size = scale * scale * 15
for p in range(num_tissue_modules):
for idx in range(num_cn_modules):
an = (
float(np.max(core[p][idx, :]) > 0.1)
+ (np.max(core[p][idx, :]) <= 0.1) * 0.05
)
ac = (
float(np.max(core[p][:, idx]) > 0.1)
+ (np.max(core[p][:, idx]) <= 0.1) * 0.05
)
cn_fac = factors[1][:, idx]
cel_fac = factors[2][:, idx]
cols_alpha = [
(*color_dic[cn], an * np.minimum(cn_fac, 1.0)[i])
for i, cn in enumerate(cns)
]
cols = [
(*color_dic[cn], np.minimum(cn_fac, 1.0)[i]) for i, cn in enumerate(cns)
]
cell_cols_alpha = [
(0, 0, 0, an * np.minimum(cel_fac, 1.0)[i])
for i, _ in enumerate(cel_fac)
]
cell_cols = [
(0, 0, 0, np.minimum(cel_fac, 1.0)[i]) for i, _ in enumerate(cel_fac)
]
plt.scatter(
0.5 * np.arange(len(cn_fac)),
5 * idx + np.zeros(len(cn_fac)),
c=cols_alpha,
s=cn_scatter_size,
)
offset = 9
for i, k in enumerate(cns):
plt.text(
0.5 * i,
5 * idx,
k,
fontsize=scale * 2,
ha="center",
va="center",
alpha=an,
)
plt.scatter(
-4.2 + 0.25 * np.arange(len(cel_fac)) + offset,
5 * idx + np.zeros(len(cel_fac)),
c=cell_cols_alpha,
s=0.5 * cel_scatter_size,
) # ,vmax = 0.5,edgecolors=len(cell_cols_alpha)*[(0,0,0,min(1.0,max(0.1,2*an)))], linewidths= 0.05)
rect = plt.Rectangle(
(-0.5, 5 * idx - 2),
4.5,
4,
linewidth=scale * scale * 1,
edgecolor="black",
facecolor="none",
zorder=0,
alpha=an,
linestyle="--",
)
ax = plt.gca()
ax.add_artist(rect)
plt.scatter(
[offset - 5],
[5 * idx],
c="black",
marker="D",
s=scale * scale * 5,
zorder=5,
alpha=an,
)
plt.text(
offset - 5,
5 * idx,
idx,
color="white",
alpha=an,
ha="center",
va="center",
zorder=6,
fontsize=4.5,
)
plt.scatter(
[offset - 4.5],
[5 * idx],
c="black",
marker="D",
s=scale * scale * 5,
zorder=5,
alpha=ac,
)
plt.text(
offset - 4.5,
5 * idx,
idx,
color="white",
alpha=ac,
ha="center",
va="center",
zorder=6,
fontsize=4.5,
)
rect = plt.Rectangle(
(offset - 4.5, 5 * idx - 2),
4.5,
4,
linewidth=scale * 1,
edgecolor="black",
facecolor="none",
zorder=0,
alpha=ac,
linestyle="-.",
)
ax.add_artist(rect)
for i, ct in enumerate(cts):
plt.text(
-4.2 + offset + 0.25 * i,
27.5,
ct,
rotation=45,
color="black",
ha="left",
va="bottom",
fontsize=scale * 2,
alpha=1,
)
for cn_i in range(num_cn_modules):
for cel_i in range(num_cn_modules):
plt.plot(
[-3 + offset - 2, -4 + offset - 0.5],
[5 * cn_i, 5 * cel_i],
color="black",
linewidth=2
* scale
* scale
* 1
* min(1.0, max(0, -0.00 + core[p][cn_i, cel_i])),
alpha=min(1.0, max(0.000, -0.00 + 10 * core[p][cn_i, cel_i])),
) # max(an,ac))
plt.ylim(-5, 30)
plt.axis("off")
if save_name:
plt.savefig(
save_path + save_name + "_" + str(p) + "_tensor.png",
format="png",
dpi=300,
transparent=True,
bbox_inches="tight",
)
plt.show()
#########
def pl_evaluate_ranks(data, num_tissue_modules=2):
"""
Evaluate the reconstruction error of different ranks in non-negative matrix factorization (NMF).
Parameters
----------
data : array-like
The input dataset.
num_tissue_modules : int, optional
The number of tissue modules to evaluate. Defaults to 2.
Returns
-------
None
"""
num_tissue_modules = num_tissue_modules + 1
pal = sns.color_palette("bright", 10)
palg = sns.color_palette("Greys", 10)
mat1 = np.zeros((num_tissue_modules, 15))
for i in range(2, 15):
for j in range(1, num_tissue_modules):
# we use NNTD as described in the paper
facs_overall = non_negative_tucker(data, rank=[j, i, i], random_state=2336)
mat1[j, i] = np.mean(
(data - tl.tucker_to_tensor((facs_overall[0], facs_overall[1]))) ** 2
)
for j in range(1, num_tissue_modules):
plt.plot(2 + np.arange(13), mat1[j][2:], label="rank = ({},x,x)".format(j))
plt.xlabel("x")
plt.ylabel("reconstruction error")
plt.legend()
plt.show()
#########
"""
data: the input pandas data frame.
sub_list2: a list of subcategories to be considered.
per_categ: the categorical column in the data frame to be used.
group2: the grouping column in the data frame.
repl: the replicate column in the data frame.
sub_collumn: the subcategory column in the data frame.
cell: the cell type column in the data frame.
thres (optional): the threshold for the correlation, default is 0.9.
normed (optional): if the percentage should be normalized, default is True.
cell2 (optional): the second cell type column in the data frame.
"""
def pl_corr_cell(
data,
per_categ,
group2,
rep,
sub_column,
cell,
output_dir,
save_name,
thres=0.9,
normed=True,
cell2=None,
sub_list2=None,
):
"""
Perform correlation analysis on a pandas DataFrame.
Parameters
----------
data : pandas DataFrame
The input DataFrame.
per_categ : str
The categorical column in the DataFrame to be used.
group2 : str
The grouping column in the DataFrame.
rep : str
The replicate column in the DataFrame.
sub_column : str
The subcategory column in the DataFrame.
cell : str
The cell type column in the DataFrame.
output_dir : str
The directory to save the correlation plot.
save_name : str
The name of the saved correlation plot.
thres : float, optional
The threshold for correlation. Default is 0.9.
normed : bool, optional
If the percentage should be normalized. Default is True.
cell2 : str, optional
The second cell type column in the DataFrame. Default is None.
sub_list2 : list, optional
A list of subcategories to be considered. Default is None.
Returns
-------
all_pairs : list
List of all correlated pairs.
pair2 : list
List of correlated pairs above the threshold.
"""
if sub_list2 != None:
result = hf_per_only(
data=data,
per_cat=per_categ,
grouping=group2,
sub_list=sub_list2,
replicate=rep,
sub_col=sub_column,
norm=normed,
)
else:
sub_list2 = data[per_categ].unique()
result = hf_per_only(
data=data,
per_cat=per_categ,
grouping=group2,
sub_list=sub_list2,
replicate=rep,
sub_col=sub_column,
norm=normed,
)
# Format for correlation function
mp = pd.pivot_table(
result, columns=[per_categ], index=[group2, rep], values=["percentage"]
)
mp.columns = mp.columns.droplevel(0)
cc = mp.reset_index()
cmat = cc.corr()
# Plot
sl2, pair2, all_pairs = hf_cor_subset(cor_mat=cmat, threshold=thres, cell_type=cell)
if cell2:
sl3 = [cell2, cell]
pl_cor_subplot(
mp=cc, sub_list=sl3, output_dir=output_dir, save_name=cell + "_" + cell2
)
else:
pl_cor_subplot(mp=cc, sub_list=sl2, output_dir=output_dir, save_name=cell)
if save_name:
plt.savefig(
output_dir + save_name + ".png",
format="png",
dpi=300,
transparent=True,
bbox_inches="tight",
)
return all_pairs, pair2, cc
###########
"""
data: Pandas data frame which is used as input for plotting.
group1: Categorical column in data that will be used as the x-axis in the pairplot.
per_cat: Categorical column in data that will be used to calculate the correlation between categories in group1.
sub_col (optional): Categorical column in data that is used to subset the data.
sub_list (optional): List of values that is used to select a subset of data based on the sub_col.
norm (optional): Boolean that determines if the data should be normalized or not.
group2 (optional): Categorical column in data that is used to group the data.
count (optional): Boolean that determines if the count of each category in per_cat should be used instead of the percentage.
plot_scatter (optional): Boolean that determines if the scatterplot should be plotted or not.
cor_mat: Output data frame containing the correlation matrix.
mp: Output data frame containing the pivot table of the count or percentage of each category in per_cat based on group1.
Returns:
cor_mat (pandas dataframe): Correlation matrix.
mp (pandas dataframe): Data after pivoting and grouping.
"""
def pl_cor_plot(
data,
group1,
per_cat,
sub_col=None,
sub_list=None,
norm=False,
group2=None,
count=False,
plot_scatter=True,
):
"""
Create a correlation plot using a pandas DataFrame.
Parameters
----------
data : pandas DataFrame
The input DataFrame.
group1 : str
Categorical column in data that will be used as the x-axis in the pairplot.
per_cat : str
Categorical column in data that will be used to calculate the correlation between categories in group1.
sub_col : str, optional
Categorical column in data that is used to subset the data. Default is None.
sub_list : list, optional
List of values that is used to select a subset of data based on the sub_col. Default is None.
norm : bool, optional
Boolean that determines if the data should be normalized or not. Default is False.
group2 : str, optional
Categorical column in data that is used to group the data. Default is None.
count : bool, optional
Boolean that determines if the count of each category in per_cat should be used instead of the percentage. Default is False.
plot_scatter : bool, optional
Boolean that determines if the scatterplot should be plotted or not. Default is True.
Returns
-------
cor_mat : pandas DataFrame
Correlation matrix.
mp : pandas DataFrame
Data after pivoting and grouping.
"""
if group2:
plt.rcParams["legend.markerscale"] = 1
tf = (
data.groupby([group1, group2])
.apply(lambda x: x[per_cat].value_counts(normalize=True, sort=False) * 100)
.to_frame()
)
tf.columns = tf.columns.astype(str)
tf.reset_index(inplace=True)
mp = pd.pivot_table(
tf, columns=["level_2"], index=[group1, group2], values=[per_cat]
)
mp.columns = mp.columns.droplevel(0)
mp.reset_index(inplace=True)
mp2 = mp.fillna(0)
cor_mat = mp2.corr()
mask = np.triu(np.ones_like(cor_mat, dtype=bool))
plt.figure(figsize=(len(cor_mat.index), len(cor_mat.columns) * 0.8))
sns.heatmap(cor_mat, cmap="coolwarm", center=0, vmin=-1, vmax=1, mask=mask)
if plot_scatter:
sns.pairplot(
mp,
diag_kind="kde",
plot_kws={"alpha": 0.6, "s": 80, "edgecolor": "k"},
size=4,
hue=group2,
)
else:
if count:
tf = data.groupby([group1, per_cat]).count()["region"].to_frame()
tf.reset_index(inplace=True)
mp = pd.pivot_table(
tf, columns=[per_cat], index=[group1], values=["region"]
)
mp.columns = mp.columns.droplevel(0)
mp.reset_index(inplace=True)
mp2 = mp.fillna(0)
cor_mat = mp2.corr()
mask = np.triu(np.ones_like(cor_mat, dtype=bool))
plt.figure(figsize=(len(cor_mat.index), len(cor_mat.columns) * 0.8))
sns.heatmap(cor_mat, cmap="coolwarm", center=0, vmin=-1, vmax=1, mask=mask)
if plot_scatter:
sns.pairplot(
mp,
diag_kind="kde",
plot_kws={"scatter_kws": {"alpha": 0.6, "s": 80, "edgecolor": "k"}},
size=4,
kind="reg",
)
else:
# Find Percentage of cell type
test = data.copy()
if sub_list == None:
sub_list = data[per_cat].unique()
sub_list1 = sub_list.copy()
if norm == True:
test1 = test.loc[test[sub_col].isin(sub_list1)]
immune_list = list(test1[per_cat].unique())
else:
test1 = test.copy()
immune_list = list(
test.loc[test[sub_col].isin(sub_list1)][per_cat].unique()
)
test1[per_cat] = test1[per_cat].astype("category")
tf = test1.groupby([group1]).apply(
lambda x: x[per_cat].value_counts(normalize=True, sort=False) * 100
)
tf.columns = tf.columns.astype(str)
mp = tf[immune_list]
mp.reset_index(inplace=True)
cor_mat = mp.corr()
mask = np.triu(np.ones_like(cor_mat, dtype=bool))
plt.figure(figsize=(len(cor_mat.index), len(cor_mat.columns) * 0.8))
sns.heatmap(cor_mat, cmap="coolwarm", center=0, vmin=-1, vmax=1, mask=mask)
if plot_scatter:
sns.pairplot(
mp,
diag_kind="kde",
plot_kws={"scatter_kws": {"alpha": 0.6, "s": 80, "edgecolor": "k"}},
size=4,
kind="reg",
)
return cor_mat, mp
########
"""
mp: A pandas dataframe from which a subset of columns will be selected and plotted.
sub_list: A list of column names from the dataframe mp that will be selected and plotted.
save_name (optional): A string that specifies the file name for saving the plot.
If save_name is not provided, the plot will not be saved.
"""
def pl_cor_subplot(mp, sub_list, output_dir, save_name=None):
"""
Create a subplot of pairwise correlation plots using a subset of columns from a pandas DataFrame.
Parameters
----------
mp : pandas DataFrame
The input DataFrame from which a subset of columns will be selected and plotted.
sub_list : list
A list of column names from the dataframe `mp` that will be selected and plotted.
output_dir : str
The output directory where the plot will be saved.
save_name : str, optional
A string that specifies the file name for saving the plot. If `save_name` is not provided, the plot will not be saved.
"""
sub_cor = mp[sub_list]
sns.pairplot(
sub_cor,
diag_kind="kde",
plot_kws={"scatter_kws": {"alpha": 0.6, "s": 80, "edgecolor": "k"}},
size=4,
kind="reg",
corner=True,
)
if save_name:
plt.savefig(
output_dir + save_name + "_corrplot.png",
format="png",
dpi=300,
transparent=True,
bbox_inches="tight",
)
def annotate(data, names, **kws):
r, p = sp.stats.pearsonr(data[names[0]], data[names[1]])
ax = plt.gca()
ax.text(
0.5, 0.8, "r={:.2f}, p={:.2g}".format(r, p), transform=ax.transAxes, fontsize=14
)
def pl_cor_subplot_new(mp, sub_list, output_dir, save_name=None):
"""
Create a subplot of pairwise correlation plots using a subset of columns from a pandas DataFrame.
Parameters
----------
mp : pandas DataFrame
The input DataFrame from which a subset of columns will be selected and plotted.
sub_list : list
A list of column names from the dataframe `mp` that will be selected and plotted.
output_dir : str
The output directory where the plot will be saved.
save_name : str, optional
A string that specifies the file name for saving the plot. If `save_name` is not provided, the plot will not be saved.
"""
sub_cor = mp[sub_list]
names = sub_cor.columns.tolist()
g = sns.lmplot(x=names[0], y=names[1], data=df2, height=5, aspect=1)
g.map_dataframe(annotate(data=mp, names=names))
plt.xlabel(names[0], fontsize=14)
plt.ylabel(names[1], fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.show()
if save_name:
plt.savefig(
output_dir + save_name + "_corrplot.png",
format="png",
dpi=300,
transparent=True,
bbox_inches="tight",
)
########
def pl_Niche_heatmap(k_centroids, w, n_num, sum_cols):
"""
Create a heatmap to show the types of cells (ClusterIDs) in different niches.
Parameters
----------
k_centroids : numpy array
The centroid values for each niche.
w : pandas DataFrame
The input data containing cluster information.
n_num : int
The niche number to plot.
sum_cols : list
The list of columns to sum for the heatmap.
"""
# this plot shows the types of cells (ClusterIDs) in the different niches (0-9)
k_to_plot = n_num
niche_clusters = k_centroids[k_to_plot]
values = w[sum_cols].values
tissue_avgs = values.mean(axis=0)
fc = np.log2(
(
(niche_clusters + tissue_avgs)
/ (niche_clusters + tissue_avgs).sum(axis=1, keepdims=True)
)
/ tissue_avgs
)
fc = pd.DataFrame(fc, columns=sum_cols)
s = sns.clustermap(fc, cmap="bwr", vmax=-5)
def pl_Barycentric_coordinate_projection(
w,
plot_list,
threshold,
output_dir,
save_name,
col_dic,
l,
n_num,
cluster_col,
SMALL_SIZE=14,
MEDIUM_SIZE=16,
BIGGER_SIZE=18,
figsize=(14, 14),
):
"""
Create a barycentric coordinate projection plot.
Parameters
----------
w : pandas DataFrame
The input data containing coordinate information.
plot_list : list
The list of columns to plot.
threshold : int
The threshold value for data selection.
output_dir : str
The output directory where the plot will be saved.
save_name : str
The file name for saving the plot.
col_dic : dict
A dictionary mapping cluster IDs to colors.
l : list
A list of cluster IDs.
n_num : int
The niche number.
cluster_col : str
The column containing cluster information.
SMALL_SIZE : int, optional
The font size for small text. Default is 14.
MEDIUM_SIZE : int, optional
The font size for medium text. Default is 16.
BIGGER_SIZE : int, optional
The font size for large text. Default is 18.
figsize : tuple, optional
The figure size. Default is (14, 14).
"""
# Settings for graph
plt.rc("font", size=SMALL_SIZE) # controls default text sizes
plt.rc("axes", titlesize=SMALL_SIZE) # fontsize of the axes title
plt.rc("axes", labelsize=MEDIUM_SIZE) # fontsize of the x and y labels
plt.rc("xtick", labelsize=SMALL_SIZE) # fontsize of the tick labels
plt.rc("ytick", labelsize=SMALL_SIZE) # fontsize of the tick labels
plt.rc("legend", fontsize=SMALL_SIZE) # legend fontsize
plt.rc("figure", titlesize=BIGGER_SIZE) # fontsize of the figure title
lmap = {j: i for i, j in enumerate(l)}
palt = col_dic
wgc = w.loc[w.loc[:, plot_list].sum(axis=1) > threshold, :]
idx = wgc.index.values
xl = wgc.loc[:, plot_list]
proj = np.array([[0, 0], [np.cos(np.pi / 3), np.sin(np.pi / 3)], [1, 0]])
coords = np.dot(xl / n_num, proj) #####window size fraction
plt.figure(figsize=figsize)
jit = 0.002
cols = [palt[a] for a in wgc[cluster_col]]
plt.scatter(
coords[:, 0] + jit * np.random.randn(len(coords)),
coords[:, 1] + jit * np.random.randn(len(coords)),
s=15,
alpha=0.5,
c=cols,
)
plt.axis("off")
plt.show()
if save_name:
plt.savefig(
output_dir + save_name + ".png",
format="png",
dpi=300,
transparent=True,
bbox_inches="tight",
)
########
def pl_get_network(
ttl_per_thres,
comb_per_thres,
color_dic,
windows,
n_num,
l,
tissue_col=None,
tissue_subset_list=None,
sub_col="Tissue Unit",
neigh_sub=None,
save_name=None,
save_path=None,
figsize=(20, 10),
):
"""
Generate a network plot based on combination frequencies.
Parameters
----------
ttl_per_thres : float
The threshold for the total percentage of combinations.
comb_per_thres : float
The threshold for the combination frequency.
color_dic : dict
A dictionary mapping cluster IDs to colors.
windows : dict
A dictionary containing window data.
n_num : int
The window size.
l : list
A list of cluster IDs.
tissue_col : bool or None, optional
Whether to filter data based on tissue columns. Default is None.
tissue_subset_list : list or None, optional
A list of tissue subsets to consider. Default is None.
sub_col : str, optional
The name of the column for subsetting. Default is 'Tissue Unit'.
neigh_sub : None, optional
Subset neighborhoods based on specified values. Default is None.
save_name : str or None, optional
The name for saving the plot. Default is None.
save_path : str or None, optional
The path for saving the plot. Default is None.
figsize : tuple, optional
The figure size. Default is (20, 10).
"""
# Choose the windows size to continue with
w = windows[n_num]
if tissue_col == True:
w = w[w.tissue_col.isin(tissue_subset_list)]
if neigh_sub:
w = w[w[sub_col].isin(neigh_sub)]
xm = w.loc[:, l].values / n_num
# Get the neighborhood combinations based on the threshold
simps = hf_get_thresh_simps(xm, ttl_per_thres)
simp_freqs = simps.value_counts(normalize=True)
simp_sums = np.cumsum(simp_freqs)
g = nx.DiGraph()
thresh_cumulative = 0.95
thresh_freq = comb_per_thres
# selected_simps = simp_sums[simp_sums<=thresh_cumulative].index.values
selected_simps = simp_freqs[simp_freqs >= thresh_freq].index.values
# this builds the graph for the CN combination map
selected_simps
for e0 in selected_simps:
for e1 in selected_simps:
if (set(list(e0)) < set(list(e1))) and (len(e1) == len(e0) + 1):
g.add_edge(e0, e1)
# this plots the CN combination map
draw = g
pos = nx.drawing.nx_pydot.graphviz_layout(draw, prog="dot")
height = 8
plt.figure(figsize=figsize)
for n in draw.nodes():
col = "black"
if len(draw.in_edges(n)) < len(n):
col = "black"
plt.scatter(
pos[n][0],
pos[n][1] - 5,
s=simp_freqs[list(simp_freqs.index).index(n)] * 10000,
c=col,
zorder=-1,
)
# if n in tops:
# plt.text(pos[n][0],pos[n][1]-7, '*', fontsize = 25, color = 'white', ha = 'center', va = 'center',zorder = 20)
delta = 8
# plot_sim((pos[n][0]+delta, pos[n][1]+delta),n, scale = 20,s = 200,text = True,fontsize = 15)
plt.scatter(
[pos[n][0]] * len(n),
[pos[n][1] + delta * (i + 1) for i in range(len(n))],
c=[color_dic[l[i]] for i in n],
marker="^",
zorder=5,
s=400,
)
j = 0
for e0, e1 in draw.edges():
weight = 0.2
alpha = 0.3
if len(draw.in_edges(e1)) < len(e1):
color = "black"
lw = 1
weight = 0.4
color = "black"
plt.plot(
[pos[e0][0], pos[e1][0]],
[pos[e0][1], pos[e1][1]],
color=color,
linewidth=weight,
alpha=alpha,
zorder=-10,
)
plt.axis("off")
if save_name is not None:
plt.savefig(save_path + save_name + "_spatial_contexts.pdf") #'.png', dpi=300)
plt.show()
#########
def pl_spatial_context_stats_vis(
neigh_comb,
simp_df_tissue1,
simp_df_tissue2,
pal_tis={"Resection": "blue", "Biopsy": "orange"},
plot_order=["Resection", "Biopsy"],
figsize=(5, 5),
):
# Set Neigh and make comparison
neigh_comb = (9,)
df1 = simp_df_tissue1.loc[[neigh_comb]].T
df2 = simp_df_tissue2.loc[[neigh_comb]].T
print(stats.mannwhitneyu(df1[df1.columns[0]], df2[df2.columns[0]]))
df1.reset_index(inplace=True)
df1[["donor", "tissue"]] = df1["index"].str.split("_", expand=True)
df2.reset_index(inplace=True)
df2[["donor", "tissue"]] = df2["index"].str.split("_", expand=True)
df_m = pd.concat([df1, df2])
df_m["combo"] = str(neigh_comb)
# swarmplot to compare
plt.figure(figsize=figsize)
ax = sns.boxplot(
data=df_m,
x="combo",
y=neigh_comb,
hue="tissue",
dodge=True,
hue_order=plot_order,
palette=pal_tis,
)
ax = sns.swarmplot(
data=df_m,
x="combo",
y=neigh_comb,
hue="tissue",
dodge=True,
hue_order=plot_order,
edgecolor="black",
linewidth=1,
palette=pal_tis,
)
for patch in ax.artists:
r, g, b, a = patch.get_facecolor()
patch.set_facecolor((r, g, b, 0.3))
# ax.set_yscale(\log\)
plt.xlabel("")
handles, labels = ax.get_legend_handles_labels()
plt.legend(
handles[: len(df_m["tissue"].unique())],
labels[: len(df_m["tissue"].unique())],
bbox_to_anchor=(1.05, 1),
loc=2,
borderaxespad=0.0,
frameon=False,
)
plt.xticks(rotation=90)
sns.despine(trim=True)
# pt.savefig(save_path+save_name+'_swarm_boxplot.png', format='png', dpi=300, transparent=True, bbox_inches='tight')
##########
def pl_conplot(
df,
feature,
exp="Exp",
X="X",
Y="Y",
invert_y=False,
cmap="RdBu",
size=5,
alpha=1,
figsize=10,
exps=None,
fig=None,
**kwargs,
):
"""
Plot continuous variable with a colormap:
df: dataframe of cells with spatial location and feature to color. Must have columns ['X','Y','Exp',feature]
feature: feature in df to color points by
cmap: matplotlib colormap
size: point size
thresh_val: only include points below this value
"""
if invert_y:
y_orig = df[Y].values.copy()
df[Y] *= -1
if exps is None:
exps = list(df[exp].unique()) # display all experiments
elif type(exps) != list:
exps = [exps]
if fig is None:
f, ax = plt.subplots(len(exps), 1, figsize=(figsize, len(exps) * figsize))
if len(exps) == 1:
ax = [ax]
else:
f, ax = fig
for i, name in enumerate(exps):
data = df[df[exp] == name]
ax[i].scatter(
data[X], -data[Y], c=data[feature], cmap=cmap, s=size, alpha=alpha, **kwargs
)
ax[i].set_title(name + "_" + str(feature) + "_" + str(len(data)))
ax[i].axis("off")
if invert_y:
df[Y] = y_orig
return f, ax
##############
def pl_catplot(
df,
hue,
exp="Exp",
X="X",
Y="Y",
invert_y=False,
size=3,
legend=True,
palette="bright",
figsize=5,
style="white",
exps=None,
axis="on",
ticks_fontsize=15,
scatter_kws={},
**kwargs,
):
"""
Plots cells in tissue section color coded by either cell type or node allocation.
df: dataframe with cell information
size: size of point to plot for each cell.
hue: color by "Clusterid" or "Node" respectively.
legend: to include legend in plot.
"""
scatter_kws_ = {"s": size, "alpha": 1}
scatter_kws_.update(scatter_kws)
figures = []
df = df.rename(columns=lambda x: str(x))
df[hue] = df[hue].astype("category")
if invert_y:
y_orig = df[Y].values.copy()
df[Y] *= -1
style = {"axes.facecolor": style}
sns.set_style(style)
if exps == None:
exps = list(df[exp].unique()) # display all experiments
elif type(exps) != list:
exps = [exps]
for name in exps:
data = df[df[exp] == name]
data[X] = data[X] - data[X].min()
data[Y] = data[Y] - data[Y].min()
print(name)
xrange = data[X].max() - data[X].min()
yrange = data[Y].max() - data[Y].min()
# if 'aspect' not in kwargs:
# kwargs['aspect'] = xrange/yrange
f = sns.lmplot(
x=X,
y=Y,
data=data,
hue=hue,
legend=legend,
fit_reg=False,
markers=".",
height=yrange / 400,
palette=palette,
scatter=True,
scatter_kws=scatter_kws_,
aspect=xrange / yrange,
**kwargs,
)
if axis == "off":
sns.despine(top=True, right=True, left=True, bottom=True)
f = f.set(xticks=[], yticks=[]).set_xlabels("").set_ylabels("")
# plt.legend(frameon=True)
plt.title(name)
plt.xticks(fontsize=ticks_fontsize) # Increase x-axis label size
plt.yticks(fontsize=ticks_fontsize) # Increase y-axis label size
plt.show()
figures += [f]
if invert_y:
df[Y] = y_orig
return figures
##########
def pl_comb_num_freq(data_list, plot_order=None, pal_tis=None, figsize=(5, 5)):
df_new = []
for df in data_list:
df.reset_index(inplace=True)
df.rename(columns={"merge": "combination"}, inplace=True)
df["count"] = df["combination"].apply(len)
sum_df = df.groupby("count").sum()
tbt = sum_df.reset_index()
ttt = tbt.melt(id_vars=["count"])
ttt.rename(
columns={"variable": "unique_cond", "value": "fraction"}, inplace=True
)
df_new.append(ttt)
df_exp = pd.concat(df_new)
df_exp[["donor", "tissue"]] = df_exp["unique_cond"].str.split("_", expand=True)
# swarmplot to compare
plt.figure(figsize=figsize)
ax = sns.boxplot(
data=df_exp,
x="count",
y="fraction",
hue="tissue",
dodge=True,
hue_order=plot_order,
palette=pal_tis,
)
ax = sns.swarmplot(
data=df_exp,
x="count",
y="fraction",
hue="tissue",
dodge=True,
hue_order=plot_order,
edgecolor="black",
linewidth=1,
palette=pal_tis,
)
for patch in ax.artists:
r, g, b, a = patch.get_facecolor()
patch.set_facecolor((r, g, b, 0.3))
# ax.set_yscale(\log\)
plt.xlabel("")
handles, labels = ax.get_legend_handles_labels()
plt.legend(
handles[: len(df_exp["tissue"].unique())],
labels[: len(df_exp["tissue"].unique())],
bbox_to_anchor=(1.05, 1),
loc=2,
borderaxespad=0.0,
frameon=False,
)
plt.xticks(rotation=90)
sns.despine(trim=True)
return df_exp
##########
# this function helps to determine what threshold to use for remove noises
# default cut off is top 1%
[docs]
def zcount_thres(
dfz, col_num, cut_off=0.01, count_bin=50, zsum_bin=50, figsize=(10, 5)
):
"""
Determines the threshold to use for removing noises. The default cut off is the top 1%.
Parameters
----------
dfz : DataFrame
The input data from which the threshold is to be determined.
col_num : int
The column number up to which the operation is performed.
cut_off : float, optional
The cut off percentage for the threshold. By default, it is 0.01 (1%).
count_bin : int, optional
The number of bins for the count histogram. By default, it is 50.
zsum_bin : int, optional
The number of bins for the z-score sum histogram. By default, it is 50.
figsize : tuple, optional
The size of the figure to be plotted. By default, it is (10, 5).
Returns
-------
None
This function doesn't return anything. It plots two histograms for 'Count' and 'Zscore sum' with the cut off line.
"""
dfz_copy = dfz
dfz_copy["Count"] = dfz.iloc[:, : col_num + 1].ge(0).sum(axis=1)
dfz_copy["z_sum"] = dfz.iloc[:, : col_num + 1].sum(axis=1)
fig, axes = plt.subplots(1, 2, constrained_layout=True, figsize=figsize)
axes[0].hist(dfz_copy["Count"], bins=count_bin)
axes[0].set_title("Count")
axes[0].axvline(
dfz_copy["Count"].quantile(1 - cut_off),
color="k",
linestyle="dashed",
linewidth=1,
)
axes[0].text(
0.75,
0.75,
"Cut off: {:.2f}".format(dfz_copy["Count"].quantile(1 - cut_off)),
ha="right",
va="bottom",
transform=axes[0].transAxes,
)
axes[1].hist(dfz_copy["z_sum"], bins=zsum_bin)
axes[1].title.set_text("Zscore sum")
axes[1].axvline(
dfz_copy["z_sum"].quantile(1 - cut_off),
color="k",
linestyle="dashed",
linewidth=1,
)
axes[1].text(
0.75,
0.75,
"Cut off: {:.2f}".format(dfz_copy["z_sum"].quantile(1 - cut_off)),
ha="right",
va="bottom",
transform=axes[1].transAxes,
)
##########
def pl_mono_cluster_spatial(
df,
sample_col="Sample",
cluster_col="Cell Type",
x="x",
y="y",
color_dict=None,
s=3,
alpha=0.5,
figsize=(15, 12),
):
for i in df[sample_col].unique():
df_sub = df[df[sample_col] == i]
print(i)
celltype = list(df_sub[cluster_col].unique())
ncols = 4
nrows = len(celltype) // ncols + (len(celltype) % ncols > 0)
fig, axs = plt.subplots(nrows, ncols, figsize=figsize)
plt.subplots_adjust(hspace=0.5)
for ct, ax in zip(celltype, axs.ravel()):
df_tmp = df_sub[df_sub[cluster_col] == ct]
if color_dict == None:
sns.scatterplot(
x=x, y=y, data=df_tmp, hue=cluster_col, s=s, alpha=alpha, ax=ax
)
else:
sns.scatterplot(
x=x,
y=y,
data=df_tmp,
hue=cluster_col,
s=s,
alpha=alpha,
ax=ax,
palette=color_dict,
)
ax.set_title(ct.upper())
ax.invert_yaxis()
ax.get_legend().remove()
ax.set_xlabel("")
plt.show()
#########
def pl_visualize_2D_density_plot(
df,
region_column,
selected_region,
subsetting_column,
values_list,
x_column,
y_column,
):
# Subset the DataFrame based on region_column and selected_region
subset_df1 = df[df[region_column] == selected_region]
# Subset the DataFrame based on subsetting_column and values_list
subset_df2 = subset_df1[subset_df[subsetting_column].isin(values_list)]
# Create a 2D density plot
sns.kdeplot(data=subset_df2, x=x_column, y=y_column, fill=True)
# Overlay the individual data points as a scatter plot
sns.scatterplot(
data=subset_df1, x=x_column, y=y_column, color="lightgrey", alpha=0.5
)
# Add labels and title to the plot
plt.xlabel(x_column)
plt.ylabel(y_column)
plt.title("2D Density Plot with Overlay")
# Display the plot
plt.show()
#######
def pl_create_cluster_celltype_heatmap(dataframe, cluster_column, celltype_column):
# Create a frequency table using pandas crosstab
frequency_table = pd.crosstab(dataframe[cluster_column], dataframe[celltype_column])
# Create the heatmap using seaborn
plt.figure(figsize=(20, 6)) # Set the size of the heatmap (adjust as needed)
sns.heatmap(
frequency_table, cmap="YlGnBu", annot=True, fmt="d"
) # cmap sets the color palette
plt.title("Cluster-Cell Type Heatmap")
plt.xlabel("Cell Types")
plt.ylabel("Cluster IDs")
plt.show()
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
[docs]
def catplot(
adata,
color,
unique_region,
subset=None,
X="x",
Y="y",
invert_y=False,
size=6,
alpha=1,
palette=None, # default is None which means the color comes from the anndata object
savefig=False,
output_dir="./",
output_fname="",
figsize=5,
style="white",
axis="on",
scatter_kws={},
n_columns=4,
legend_padding=0.2,
rand_seed=1,
):
"""
Plots cells in tissue section color coded by either cell type or node allocation.
Parameters
----------
adata : AnnData
Annotated data matrix.
color : str
Color by "Clusterid" or "Node" respectively.
unique_region : str
Each region is one independent CODEX image.
subset : str, optional
Subset of data to plot. If None, all data is plotted.
X : str, optional
Column name for x-axis in the DataFrame.
Y : str, optional
Column name for y-axis in the DataFrame.
invert_y : bool, optional
If True, invert y-axis.
size : int, optional
Size of point to plot for each cell.
alpha : float, optional
Transparency of points.
palette : dict, optional
Colors to use for different levels of the `hue` variable. Should be something that can be interpreted by `color_palette()`, or a dictionary mapping hue levels to matplotlib colors.
savefig : bool, optional
If True, save figure.
output_dir : str, optional
Directory to save figure.
output_fname : str, optional
Filename to save figure.
figsize : int, optional
Size of the figure.
style : str, optional
Style of the plot.
axis : str, optional
If "off", axis is not displayed.
scatter_kws : dict, optional
Additional keyword arguments to pass to `scatterplot()`.
n_columns : int, optional
Number of columns in the figure.
legend_padding : float, optional
Padding around the legend.
rand_seed : int, optional
Seed for random number generator.
Returns
-------
None
This function doesn't return anything. It plots a scatterplot with the specified parameters.
"""
scatter_kws_ = {"s": size, "alpha": alpha}
scatter_kws_.update(scatter_kws)
df = pd.DataFrame(adata.obs[[X, Y, color, unique_region]])
df[color] = df[color].astype("category")
if invert_y:
y_orig = df[Y].values.copy()
df[Y] *= -1
if palette is None:
if color + "_colors" not in adata.uns.keys():
ct_colors = hf_generate_random_colors(
len(adata.obs[color].unique()), rand_seed=rand_seed
)
palette = dict(zip(np.sort(adata.obs[color].unique()), ct_colors))
adata.uns[color + "_colors"] = ct_colors
else:
palette = dict(
zip(np.sort(adata.obs[color].unique()), adata.uns[color + "_colors"])
)
style = {"axes.facecolor": style}
sns.set_style(style)
if subset is None:
region_list = list(
df[unique_region].unique().sort_values()
) # display all experiments
else:
if subset not in list(df[unique_region].unique().sort_values()):
print(subset + " is not in unique_region!")
return
else:
region_list = [subset]
n_rows = int(np.ceil(len(region_list) / n_columns))
fig, axes = plt.subplots(
n_rows,
n_columns,
figsize=(figsize * n_columns, figsize * n_rows),
squeeze=False,
gridspec_kw={"wspace": 0.5, "hspace": 0.4},
)
for i_ax, (name, ax) in enumerate(zip(region_list, axes.flatten())):
data = df[df[unique_region] == name]
# print(name)
sns.scatterplot(
x=X, y=Y, data=data, hue=color, palette=palette, ax=ax, s=size, alpha=alpha
)
ax.grid(False)
if axis == "off":
ax.axis("off")
ax.set_title(name)
ax.set_aspect("equal")
# Add padding to the legend
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
# frame = legend.get_frame()
# frame.set_facecolor('white') # Adjust the legend background color
for i in range(i_ax + 1, n_rows * n_columns):
axes.flatten()[i].axis("off")
# fig.tight_layout(pad = 0.5)
if savefig:
fig.savefig(
output_dir + output_fname + "_spatial_plot.pdf", bbox_inches="tight"
)
def pl_generate_CN_comb_map(
graph,
tops,
e0,
e1,
simp_freqs,
palette,
figsize=(40, 20),
savefig=False,
output_dir="./",
):
draw = graph
pos = nx.drawing.nx_pydot.graphviz_layout(draw, prog="dot")
height = 8
plt.figure(figsize=figsize)
for n in draw.nodes():
col = "black"
if len(draw.in_edges(n)) < len(n):
col = "black"
plt.scatter(
pos[n][0],
pos[n][1] - 5,
s=simp_freqs[list(simp_freqs.index).index(n)] * 10000,
c=col,
zorder=-1,
)
if n in tops:
plt.text(
pos[n][0],
pos[n][1] - 7,
"*",
fontsize=25,
color="white",
ha="center",
va="center",
zorder=20,
)
delta = 8
# l is just the color keys
l = list(palette.keys())
plt.scatter(
[pos[n][0]] * len(n),
[pos[n][1] + delta * (i + 1) for i in range(len(n))],
c=[palette[l[i]] for i in n],
marker="s",
zorder=5,
s=400,
)
j = 0
for e0, e1 in draw.edges():
weight = 0.2
alpha = 0.3
color = "black"
if len(draw.in_edges(e1)) < len(e1):
color = "black"
lw = 1
weight = 0.4
plt.plot(
[pos[e0][0], pos[e1][0]],
[pos[e0][1], pos[e1][1]],
color=color,
linewidth=weight,
alpha=alpha,
zorder=-10,
)
plt.axis("off")
if savefig:
plt.savefig(output_dir + "_CNMap.pdf", bbox_inches="tight")
else:
plt.show()
[docs]
def stacked_bar_plot(
adata,
color,
grouping,
cell_list,
output_dir,
norm=True,
savefig=False, # new
output_fname="", # new
col_order=None,
sub_col=None,
name_cat="celltype",
fig_sizing=(8, 4),
plot_order=None,
palette=None,
remove_leg=False,
rand_seed=1,
):
"""
Plot a stacked bar plot based on the given data.
Parameters
----------
data : pandas.DataFrame
The input data containing the necessary information for plotting.
color : str
The column name representing the categories.
grouping : str
The column name representing the grouping.
cell_list : list
The list of cell types to include in the plot.
output_dir : str
The output directory for saving the plot.
norm : bool, optional
Flag indicating whether to normalize the values. Defaults to True.
save_name : str, optional
The name to use when saving the plot. Defaults to None.
col_order : list, optional
The order of columns/categories for plotting. Defaults to None.
sub_col : str, optional
The column name representing sub-categories. Defaults to None.
name_cat : str, optional
The name for the category column in the plot. Defaults to 'celltype'.
fig_sizing : tuple, optional
The size of the figure (width, height) in inches. Defaults to (8, 4).
plot_order : list, optional
The order of categories for plotting. Defaults to None.
palette : dict, optional
A dictionary mapping categories to colors for custom colorization. Defaults to None.
remove_leg : bool, optional
Flag indicating whether to remove the legend. Defaults to False.
Returns
-------
pandas.DataFrame
The pivoted data used for plotting.
list
The order of categories used for plotting.
"""
data = adata.obs
# Find Percentage of cell type
if norm == True:
if sub_col is None:
test1 = data.loc[data[color].isin(cell_list)]
sub_cell_list = list(test1[color].unique())
else:
test1 = data.loc[data[sub_col].isin(cell_list)]
sub_cell_list = list(test1[color].unique())
else:
if sub_col is None:
test1 = data.copy()
sub_cell_list = list(data.loc[data[color].isin(cell_list)][color].unique())
else:
test1 = data.copy()
sub_cell_list = list(
data.loc[data[sub_col].isin(cell_list)][color].unique()
)
if palette is None:
if color + "_colors" not in adata.uns.keys():
ct_colors = hf_generate_random_colors(
len(adata.obs[color].unique()), rand_seed=rand_seed
)
palette = dict(zip(np.sort(adata.obs[color].unique()), ct_colors))
adata.uns[color + "_colors"] = ct_colors
else:
palette = dict(
zip(np.sort(adata.obs[color].unique()), adata.uns[color + "_colors"])
)
test1[color] = test1[color].astype("category")
test_freq = test1.groupby(grouping).apply(
lambda x: x[color].value_counts(normalize=True, sort=False) * 100
)
test_freq.columns = test_freq.columns.astype(str)
##### Can subset it here if I do not want normalized per the group
test_freq.reset_index(inplace=True)
sub_cell_list.append(grouping)
test_freq = test_freq[sub_cell_list]
melt_test = pd.melt(
test_freq, id_vars=[grouping]
) # , value_vars=test_freq.columns)
# melt_test.rename(columns={per_cat: name_cat, "value": "percent"}, inplace=True)
melt_test.rename(columns={"value": "percent"}, inplace=True)
if norm == True:
if col_order is None:
bb = melt_test.groupby([grouping, color]).sum().reset_index()
col_order = (
bb.loc[bb[color] == bb[color][0]]
.sort_values(by="percent")[grouping]
.to_list()
)
else:
if col_order is None:
col_order = (
melt_test.groupby(grouping)
.sum()
.reset_index()
.sort_values(by="percent")[grouping]
.to_list()
)
if plot_order is None:
plot_order = list(melt_test[color].unique())
# Set up for plotting
melt_test_piv = pd.pivot_table(
melt_test, columns=[color], index=[grouping], values=["percent"]
)
melt_test_piv.columns = melt_test_piv.columns.droplevel(0)
melt_test_piv.reset_index(inplace=True)
melt_test_piv.set_index(grouping, inplace=True)
melt_test_piv = melt_test_piv.reindex(col_order)
melt_test_piv = melt_test_piv[plot_order]
# Get color dictionary
ax1 = melt_test_piv.plot.bar(
alpha=0.8,
linewidth=1,
color=[palette.get(x) for x in melt_test_piv.columns],
figsize=fig_sizing,
rot=90,
stacked=True,
edgecolor="black",
)
for line in ax1.lines:
line.set_color("black")
ax1.spines["top"].set_visible(False)
ax1.spines["right"].set_visible(False)
if remove_leg == True:
ax1.set_ylabel("")
ax1.set_xlabel("")
else:
ax1.set_ylabel("percent")
# ax1.spines['left'].set_position(('data', 1.0))
# ax1.set_xticks(np.arange(1,melt_test.day.max()+1,1))
# ax1.set_ylim([0, int(ceil(max(max(melt_test_piv.sum(axis=1)), max(tm_piv.sum(axis=1)))))])
plt.xticks(
list(range(len(list(melt_test_piv.index)))),
list(melt_test_piv.index),
rotation=90,
)
lgd2 = ax1.legend(
loc="center left", bbox_to_anchor=(1.0, 0.5), ncol=1, frameon=False
)
if savefig:
plt.savefig(
output_dir + output_fname + ".pdf",
format="pdf",
dpi=300,
transparent=True,
bbox_inches="tight",
)
else:
plt.show()
return melt_test_piv, plot_order
def pl_swarm_box_ad(
adata,
grouping,
per_cat,
replicate,
sub_col,
sub_list,
output_dir,
norm=True,
figure_sizing=(10, 5),
save_name=None,
plot_order=None,
col_in=None,
color_dic=None,
flip=False,
):
# extract information form adata
data = adata.obs
# Find Percentage of cell type
test = data.copy()
sub_list1 = sub_list.copy()
if norm == True:
test1 = test.loc[test[sub_col].isin(sub_list1)]
immune_list = list(test1[per_cat].unique())
else:
test1 = test.copy()
immune_list = list(test.loc[test[sub_col].isin(sub_list1)][per_cat].unique())
test1[per_cat] = test1[per_cat].astype("category")
test_freq = test1.groupby([grouping, replicate]).apply(
lambda x: x[per_cat].value_counts(normalize=True, sort=False) * 100
)
test_freq.columns = test_freq.columns.astype(str)
test_freq.reset_index(inplace=True)
immune_list.extend([grouping, replicate])
test_freq1 = test_freq[immune_list]
melt_per_plot = pd.melt(
test_freq1,
id_vars=[
grouping,
replicate,
],
) # ,value_vars=immune_list)
melt_per_plot.rename(columns={"value": "percentage"}, inplace=True)
if col_in:
melt_per_plot = melt_per_plot.loc[melt_per_plot[per_cat].isin(col_in)]
else:
melt_per_plot = melt_per_plot
if plot_order is None:
plot_order = list(melt_per_plot[grouping].unique())
else:
# Order by average
plot_order = (
melt_per_plot.groupby(per_cat)
.mean()
.reset_index()
.sort_values(by="percentage")[per_cat]
.to_list()
)
# swarmplot to compare clustering
plt.figure(figsize=figure_sizing)
if flip == True:
plt.figure(figsize=figure_sizing)
if color_dic is None:
ax = sns.boxplot(
data=melt_per_plot,
x=grouping,
y="percentage",
dodge=True,
order=plot_order,
)
ax = sns.swarmplot(
data=melt_per_plot,
x=grouping,
y="percentage",
dodge=True,
order=plot_order,
edgecolor="black",
linewidth=1,
color="white",
palette=color_dic,
)
else:
ax = sns.boxplot(
data=melt_per_plot,
x=grouping,
y="percentage",
dodge=True,
order=plot_order,
palette=color_dic,
)
ax = sns.swarmplot(
data=melt_per_plot,
x=grouping,
y="percentage",
dodge=True,
order=plot_order,
edgecolor="black",
linewidth=1,
palette=color_dic,
)
for patch in ax.artists:
r, g, b, a = patch.get_facecolor()
patch.set_facecolor((r, g, b, 0.3))
plt.xticks(rotation=90)
plt.xlabel("")
plt.ylabel("")
plt.title(sub_list[0])
sns.despine()
else:
if color_dic is None:
ax = sns.boxplot(
data=melt_per_plot,
x=grouping,
y="percentage",
dodge=True,
order=plot_order,
)
ax = sns.swarmplot(
data=melt_per_plot,
x=grouping,
y="percentage",
dodge=True,
order=plot_order,
edgecolor="black",
linewidth=1,
color="white",
)
else:
ax = sns.boxplot(
data=melt_per_plot,
x=grouping,
y="percentage",
dodge=True,
order=plot_order,
palette=color_dic,
)
ax = sns.swarmplot(
data=melt_per_plot,
x=grouping,
y="percentage",
dodge=True,
order=plot_order,
edgecolor="black",
linewidth=1,
palette=color_dic,
)
for patch in ax.artists:
r, g, b, a = patch.get_facecolor()
patch.set_facecolor((r, g, b, 0.3))
# ax.set_yscale(\log\)
plt.xlabel("")
handles, labels = ax.get_legend_handles_labels()
plt.legend(
handles[: len(melt_per_plot[grouping].unique())],
labels[: len(melt_per_plot[grouping].unique())],
bbox_to_anchor=(1.05, 1),
loc=2,
borderaxespad=0.0,
frameon=False,
)
plt.xticks(rotation=90)
ax.set(ylim=(0, melt_per_plot["percentage"].max() + 1))
sns.despine()
if output_dir:
if save_name:
plt.savefig(
output_dir + save_name + "_swarm_boxplot.png",
format="png",
dpi=300,
transparent=True,
bbox_inches="tight",
)
else:
print("define save_name")
else:
print("plot was not saved - to save the plot specify an output directory")
return melt_per_plot
[docs]
def create_pie_charts(
adata,
color,
grouping,
plot_order=None,
show_percentages=True,
palette=None,
savefig=False,
output_fname="",
output_dir="./",
rand_seed=1,
):
"""
Create pie charts for each group based on a grouping column, showing the percentage of total rows based on a
count column.
Parameters
----------
adata : pd.DataFrame
The input DataFrame.
color : str
The column name used for counting occurrences.
grouping : str
The column name for grouping the data.
plot_order : list, optional
The order of groups for plotting. Defaults to None.
show_percentages : bool, optional
Whether to show the percentage numbers on the pie charts. Defaults to True.
palette : dict, optional
A dictionary to manually set colors for neighborhoods. Defaults to None.
savefig : bool, optional
Whether to save the figure or not. Defaults to False.
output_fname : str, optional
The output file name. Defaults to "".
output_dir : str, optional
The output directory. Defaults to "./".
rand_seed : int, optional
The random seed for color generation. Defaults to 1.
Returns
-------
None
"""
data = adata.obs
# Group the data by the grouping column
grouped_data = data.groupby(grouping)
# Sort the groups based on the plot_order if provided
if plot_order:
grouped_data = sorted(grouped_data, key=lambda x: plot_order.index(x[0]))
# Calculate the number of rows and columns for subplots
num_groups = len(grouped_data)
num_cols = 3 # Number of columns for subplots
num_rows = (num_groups - 1) // num_cols + 1
# Create subplots
fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 5 * num_rows))
axes = axes.flatten() # Flatten the subplots array
# Create a color dictionary if not provided
if palette is None:
if color + "_colors" not in adata.uns.keys():
ct_colors = hf_generate_random_colors(
len(adata.obs[color].unique()), rand_seed=rand_seed
)
palette = dict(zip(np.sort(adata.obs[color].unique()), ct_colors))
adata.uns[color + "_colors"] = ct_colors
else:
palette = dict(
zip(np.sort(adata.obs[color].unique()), adata.uns[color + "_colors"])
)
# Iterate over each group and create a pie chart
for i, (group, group_df) in enumerate(grouped_data):
# Count the occurrences of each neighborhood within the group
neighborhood_counts = group_df[color].value_counts()
# Calculate the percentage of total rows for each neighborhood
percentages = neighborhood_counts / group_df.shape[0] * 100
# Create a color list for neighborhoods in the count column
colors = [
palette.get(neighborhood, "gray") for neighborhood in percentages.index
]
if show_percentages:
wedges, texts, autotexts = axes[i].pie(
percentages, labels=percentages.index, autopct="%1.1f%%", colors=colors
)
axes[i].set_title(f"Group: {group}")
else:
wedges, texts = axes[i].pie(
percentages, labels=percentages.index, colors=colors
)
axes[i].set_title(f"Group: {group}")
# Remove unused subplots
for j in range(num_groups, num_rows * num_cols):
fig.delaxes(axes[j])
# Adjust spacing between subplots
fig.tight_layout()
if savefig:
plt.savefig(
output_dir + output_fname + "_piechart.pdf",
format="pdf",
dpi=300,
transparent=True,
bbox_inches="tight",
)
else:
# Show the plot
plt.show()
[docs]
def cn_exp_heatmap(
adata,
cluster_col,
cn_col,
palette=None,
savefig=False,
output_fname="",
output_dir="./",
row_clus=True,
col_clus=True,
rand_seed=1,
):
"""
Create a heatmap of expression data, clustered by rows and columns.
Parameters
----------
adata : AnnData
Annotated data matrix.
cluster_col : str
The column name for clustering the data.
cn_col : str
The column name for the color selection.
palette : dict, optional
A dictionary to manually set colors for neighborhoods. Defaults to None.
figsize : tuple, optional
The size of the figure. Defaults to (18, 12).
savefig : bool, optional
Whether to save the figure or not. Defaults to False.
output_fname : str, optional
The output file name. Defaults to "".
output_dir : str, optional
The output directory. Defaults to "./".
row_clus : bool, optional
Whether to cluster the rows or not. Defaults to True.
col_clus : bool, optional
Whether to cluster the columns or not. Defaults to True.
rand_seed : int, optional
The random seed for color generation. Defaults to 1.
Returns
-------
None
"""
data = adata.obs
output_dir = pathlib.Path(output_dir)
if palette is None:
if cn_col + "_colors" not in adata.uns.keys():
# Create a color dictionary if not provided
cn_colors = hf_generate_random_colors(
len(adata.obs[cn_col].unique()), rand_seed=rand_seed
)
palette = dict(zip(np.sort(adata.obs[cn_col].unique()), cn_colors))
adata.uns[cn_col + "_colors"] = cn_colors
else:
palette = dict(
zip(np.sort(adata.obs[cn_col].unique()), adata.uns[cn_col + "_colors"])
)
neigh_data = pd.DataFrame(
{cn_col: list(palette.keys()), "color": list(palette.values())}
)
neigh_data.set_index(keys=cn_col, inplace=True)
df3 = pd.concat([data, pd.get_dummies(data[cluster_col])], axis=1)
sum_cols2 = df3[cluster_col].unique()
values2 = df3[sum_cols2].values
cell_list = sum_cols2.copy()
cell_list = cell_list.tolist()
cell_list.append(cn_col)
subset = df3[cell_list]
niche_sub = subset.groupby(cn_col).sum()
niche_df = niche_sub.apply(lambda x: x / x.sum() * 10, axis=1)
neigh_clusters = niche_df.to_numpy()
tissue_avgs = values2.mean(axis=0)
fc_2 = np.log2(
(
(neigh_clusters + tissue_avgs)
/ (neigh_clusters + tissue_avgs).sum(axis=1, keepdims=True)
)
/ tissue_avgs
)
fc_2 = pd.DataFrame(fc_2, columns=sum_cols2)
fc_2.set_index(niche_df.index, inplace=True)
s = sns.clustermap(
fc_2,
vmin=-3,
vmax=3,
cmap="bwr",
figsize=(10, 5),
row_colors=[neigh_data.reindex(fc_2.index)["color"]],
cbar_pos=(0.03, 0.06, 0.03, 0.1),
)
s.ax_row_dendrogram.set_visible(row_clus)
s.ax_col_dendrogram.set_visible(col_clus)
s.ax_heatmap.set_ylabel("", labelpad=25)
s.ax_heatmap.tick_params(axis="y", pad=42)
s.ax_heatmap.yaxis.set_ticks_position("right")
# add text to bottom left corner of the plot
s.ax_heatmap.text(
-0.25,
-0.85,
"log2 fold change \nover tissue average",
transform=s.ax_heatmap.transAxes,
fontsize=12,
verticalalignment="bottom",
horizontalalignment="center",
)
if savefig:
s.figure.savefig(output_dir / (output_fname + ".pdf"), bbox_inches="tight")
def pl_area_nuc_cutoff(
df,
cutoff_area,
cutoff_nuc,
cellsize_column="area",
nuc_marker_column="Hoechst1",
color_by="unique_label",
palette="Paired",
alpha=0.8,
size=0.4,
log_scale=True,
):
# Custom the inside plot: options are: “scatter” | “reg” | “resid” | “kde” | “hex”
g = sns.jointplot(
x=df[nuc_marker_column],
y=df[cellsize_column],
hue=df[color_by],
palette=palette,
alpha=alpha,
)
if log_scale == True:
# log scale joint plot
g.ax_joint.set_xscale("log")
g.ax_joint.set_yscale("log")
# add horizontal and vertical lines
g.ax_joint.axhline(cutoff_area, color="k", linestyle="dashed", linewidth=1)
g.ax_joint.axvline(cutoff_nuc, color="k", linestyle="dashed", linewidth=1)
# place legend outside
g.ax_joint.legend(bbox_to_anchor=(1.2, 1), loc="upper left", borderaxespad=0)
# show plot
plt.show()
def pl_plot_scatter_correlation(data, x, y, xlabel=None, ylabel=None, save_path=None):
g = sns.lmplot(x=x, y=y, data=data, height=5, aspect=1)
g.map_dataframe(hf_annotate_cor_plot, x=x, y=y, data=data)
if xlabel:
plt.xlabel(xlabel, fontsize=14)
if ylabel:
plt.ylabel(ylabel, fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
if save_path:
plt.savefig(
save_path + "_corrplot.pdf", transparent=True, dpi=600, bbox_inches="tight"
)
plt.show()
def pl_plot_scatter_correlation_ad(
adata, x, y, xlabel=None, ylabel=None, save_path=None
):
data = adata.obs
g = sns.lmplot(x=x, y=y, data=data, height=5, aspect=1)
g.map_dataframe(hf_annotate_cor_plot, x=x, y=y, data=data)
if xlabel:
plt.xlabel(xlabel, fontsize=14)
if ylabel:
plt.ylabel(ylabel, fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
if save_path:
plt.savefig(
save_path + "_corrplot.pdf", transparent=True, dpi=600, bbox_inches="tight"
)
plt.show()
########
def pl_plot_correlation_matrix(cmat):
# plot correlation matrix as heatmap
# Create a mask to hide the upper triangle
mask = np.triu(np.ones_like(cmat, dtype=bool))
fig, ax = plt.subplots(figsize=(20, 20))
sns.heatmap(
cmat, annot=True, fmt=".2f", cmap="coolwarm", square=True, mask=mask, ax=ax
)
plt.show()
[docs]
def dumbbell(
data,
figsize=(10, 10),
colors=["#DB444B", "#006BA2"],
savefig=False,
output_fname="",
output_dir="./",
):
"""
Create a dumbbell plot.
Parameters
----------
data : pd.DataFrame
The input DataFrame. The DataFrame should have two columns representing two different conditions.
figsize : tuple, optional
The size of the figure. Defaults to (10, 10).
colors : list, optional
The colors to use for the two conditions. Defaults to ["#DB444B", "#006BA2"].
Returns
-------
None
"""
fig, ax = plt.subplots(figsize=figsize, facecolor="white")
# plot each country one at a time
# Create grid
# Zorder tells it which layer to put it on. We are setting this to 1 and our data to 2 so the grid is behind the data.
ax.grid(which="major", axis="both", color="#758D99", alpha=0.6, zorder=1)
# Remove splines. Can be done one at a time or can slice with a list.
ax.spines[["top", "right", "bottom"]].set_visible(False)
# Plot data
comp_cat = data.columns
# Plot horizontal lines first
ax.hlines(
y=data.index,
xmin=data[comp_cat[0]],
xmax=data[comp_cat[1]],
color="#758D99",
zorder=2,
linewidth=2,
label="_nolegend_",
alpha=0.8,
)
# Plot bubbles next
ax.scatter(
data[comp_cat[0]], data.index, label="2014", s=60, color=colors[0], zorder=3
)
ax.scatter(
data[comp_cat[1]], data.index, label="2018", s=60, color=colors[1], zorder=3
)
# Set xlim
# ax.set_xlim(-3, 3)
# Reformat x-axis tick labels
ax.xaxis.set_tick_params(
labeltop=True, # Put x-axis labels on top
labelbottom=False, # Set no x-axis labels on bottom
bottom=False, # Set no ticks on bottom
labelsize=11, # Set tick label size
pad=-1,
) # Lower tick labels a bit
ax.axvline(x=0, color="k", linestyle="--")
# Set Legend
ax.legend(
data.columns,
loc=(0, 1.076),
ncol=2,
frameon=False,
handletextpad=0.4,
handleheight=1,
)
if savefig == True:
plt.savefig(output_dir + output_fname + ".pdf", bbox_inches="tight")
[docs]
def plot_top_n_distances(
dist_table_filt,
dist_data_filt,
n=5,
colors=None,
dodge=False,
savefig=False,
output_fname="",
output_dir="./",
figsize=(5, 5),
unit="px",
errorbars=True,
):
# Calculate the aspect ratio based on the desired figure size
aspect = figsize[0] / figsize[1]
# calculate abs distance
dist_table_filt["abs_dist"] = abs(
dist_table_filt.iloc[:, 0] - dist_table_filt.iloc[:, 1]
)
# replace with the number of combinations you want
top_n_combinations = dist_table_filt.nlargest(n, "abs_dist")
print(top_n_combinations)
# Convert index of top_n_combinations to a set
top_pairs = set(top_n_combinations.index)
# Filter dist_data_filt
top_dist_data_filt = dist_data_filt[dist_data_filt["pairs"].isin(top_pairs)]
exploded_df = top_dist_data_filt.explode("observed")
# Draw a pointplot to show pulse as a function of three categorical factors
if errorbars == True:
g = sns.catplot(
data=exploded_df,
x="condition",
y="observed",
hue="interaction",
capsize=0.2,
palette=colors,
errorbar="se",
kind="point",
height=figsize[1],
aspect=aspect,
dodge=dodge,
)
else:
g = sns.catplot(
data=exploded_df,
x="condition",
y="observed",
hue="interaction",
capsize=0.2,
palette=colors,
kind="point",
height=figsize[1],
aspect=aspect,
dodge=dodge,
ci=None,
)
# change y axis name to "observed distance"
g.set_axis_labels("Condition", "Observed distance in " + unit)
g.despine(left=True)
if savefig == True:
plt.savefig(output_dir + output_fname + ".pdf", bbox_inches="tight")
[docs]
def cn_map(
adata,
cnmap_dict,
cn_col,
palette=None,
figsize=(40, 20),
savefig=False,
output_fname="",
output_dir="./",
rand_seed=1,
):
"""
Generates a CNMap plot using the provided data and parameters.
Parameters
----------
adata : anndata.AnnData
Annotated data matrix.
cnmap_dict : dict
Dictionary containing graph, tops, e0, e1, and simp_freqs.
cn_col : str
Column name in adata to be used for color coding.
palette : dict, optional
Color palette to use for the plot. If None, a random color palette is generated.
figsize : tuple, optional
Size of the figure. Defaults to (40, 20).
savefig : bool, optional
Whether to save the figure or not. Defaults to False.
output_fname : str, optional
The filename for the saved figure. Required if savefig is True. Defaults to "".
output_dir : str, optional
The directory where the figure will be saved. Defaults to "./".
rand_seed : int, optional
Seed for random number generator. Defaults to 1.
Returns
-------
None
"""
graph = cnmap_dict["g"]
tops = cnmap_dict["tops"]
e0 = cnmap_dict["e0"]
e1 = cnmap_dict["e1"]
simp_freqs = cnmap_dict["simp_freqs"]
draw = graph
pos = nx.drawing.nx_pydot.graphviz_layout(draw, prog="dot")
height = 8
# generate color
cn_colors = hf_generate_random_colors(
len(adata.obs[cn_col].unique()), rand_seed=rand_seed
)
if palette is None:
if cn_col + "_colors" not in adata.uns.keys():
palette = dict(zip(np.sort(adata.obs[cn_col].unique()), cn_colors))
adata.uns[cn_col + "_colors"] = cn_colors
else:
palette = dict(
zip(np.sort(adata.obs[cn_col].unique()), adata.uns[cn_col + "_colors"])
)
plt.figure(figsize=figsize)
# Sort the nodes
sorted_nodes = sorted(draw.nodes())
# Iterate over the sorted nodes
for n in sorted_nodes:
col = "black"
if len(draw.in_edges(n)) < len(n):
col = "black"
plt.scatter(
pos[n][0],
pos[n][1] - 5,
s=simp_freqs[list(simp_freqs.index).index(n)] * 10000,
c=col,
zorder=-1,
)
# Dummy scatter plots for legend
freqs = simp_freqs * 10000
max_size = max(freqs)
sizes = [
round(max_size) / 4,
round(max_size) / 2,
round(max_size),
] # Replace with the sizes you want in the legend
labels = [
str(round(max_size / 100) / 4) + "%",
str(round(max_size / 100) / 2) + "%",
str(round(max_size / 100)) + "%",
] # Replace with the labels you want in the legend
# Add legend
legend_elements = [
plt.Line2D(
[0],
[0],
marker="o",
color="w",
label=label,
markerfacecolor="black",
markersize=size**0.5,
)
for size, label in zip(sizes, labels)
]
# Add first legend
legend1 = plt.legend(
handles=legend_elements,
loc="lower right",
title="Total frequency",
title_fontsize=30,
fontsize=30,
handlelength=6,
handletextpad=1,
bbox_to_anchor=(0.0, -0.15, 1.0, 0.102),
)
if n in tops:
plt.text(
pos[n][0],
pos[n][1] - 7,
"*",
fontsize=25,
color="white",
ha="center",
va="center",
zorder=20,
)
delta = 8
# l is just the color keys
l = list(palette.keys())
plt.scatter(
[pos[n][0]] * len(n),
[pos[n][1] + delta * (i + 1) for i in range(len(n))],
c=[palette[l[i]] for i in n],
marker="s",
zorder=5,
s=1000,
)
j = 0
for e0, e1 in draw.edges():
weight = 2
alpha = 0.3
color = "black"
if len(draw.in_edges(e1)) < len(e1):
color = "black"
lw = 1
weight = 3
plt.plot(
[pos[e0][0], pos[e1][0]],
[pos[e0][1], pos[e1][1]],
color=color,
linewidth=weight,
alpha=alpha,
zorder=-10,
)
# Create a list of Patch objects, one for each unique color in your palette
legend_patches = [
mpatches.Patch(color=color, label=label) for label, color in palette.items()
]
# Add legend to bottom of plot
plt.gca().add_artist(legend1)
plt.legend(
handles=legend_patches,
bbox_to_anchor=(0.0, -0.15, 1.0, 0.102),
loc="lower left",
ncol=3,
borderaxespad=0.0,
fontsize=35,
)
plt.axis("off")
if savefig:
plt.savefig(output_dir + output_fname + "_CNMap.pdf", bbox_inches="tight")
else:
plt.show()
[docs]
def coordinates_on_image(
df,
overlay_data,
color=None,
x="x",
y="y",
fig_width=20,
fig_height=20,
dot_size=10,
convert_to_grey=True,
scale=False,
cmap="inferno",
savefig=False,
output_dir="./",
output_fname="",
):
"""
Plot coordinates on an image.
Parameters
----------
df : pd.DataFrame
The input DataFrame. The DataFrame should have columns 'x' and 'y' representing the coordinates.
overlay_data : ndarray
The image data to overlay the coordinates on.
color : str, optional
The column name in df for the color variable. Defaults to None.
x : str, optional
The column name in df for the x-coordinate. Defaults to "x".
y : str, optional
The column name in df for the y-coordinate. Defaults to "y".
fig_width : int, optional
The width of the figure. Defaults to 20.
fig_height : int, optional
The height of the figure. Defaults to 20.
dot_size : int, optional
The size of the dots representing the coordinates. Defaults to 10.
convert_to_grey : bool, optional
Whether to convert the image to grayscale. Defaults to True.
scale : bool, optional
Whether to scale the color variable. Defaults to False.
cmap : str, optional
The colormap to use. Defaults to "inferno".
savefig : bool, optional
Whether to save the figure or not. Defaults to False.
output_dir : str, optional
The output directory. Defaults to "./".
output_fname : str, optional
The output file name. Defaults to "".
Returns
-------
None
"""
# Create a new figure with increased size
plt.figure(figsize=(fig_width, fig_height))
if convert_to_grey:
# Convert the image to grayscale
overlay_data = skimage.color.rgb2gray(overlay_data)
plt.imshow(overlay_data, cmap="gray")
else:
plt.imshow(overlay_data)
image_width, image_height = overlay_data.shape[1], overlay_data.shape[0]
# Plot the coordinates on top of the image
# colorscale by area
if color != None:
if scale:
# minmax scale the variable
df[color] = (df[color] - df[color].min()) / (
df[color].max() - df[color].min()
)
# change dot size based on variable
plt.scatter(df[x], df[y], s=df[color] * 30, c=df[color], cmap=cmap)
else:
plt.scatter(df[x], df[y], c=df[color], s=dot_size, cmap=cmap)
else:
plt.scatter(df[x], df["y"], s=dot_size)
# add colorbar
plt.colorbar()
# set axis limits
plt.xlim(0, image_width)
plt.ylim(image_height, 0)
# Show the plot
if savefig:
plt.savefig(
output_dir + output_fname + "_seg_masks_overlay.pdf", bbox_inches="tight"
)
else:
plt.show()
[docs]
def count_patch_proximity_res(
adata,
x,
hue,
palette="Set3",
order=True,
key_name="ppa_result",
savefig=False,
output_dir="./",
output_fname="",
):
"""
Create a count plot for patch proximity results.
Parameters
----------
adata : AnnData
Annotated data matrix.
x : str
The column name in the DataFrame for the x-axis variable.
hue : str
The column name in the DataFrame for the hue variable.
palette : str, optional
The palette to use for the plot. Defaults to "Set3".
order : bool, optional
Whether to order the count plot. Defaults to True.
key_name : str, optional
The key name for the patch proximity results in adata.uns. Defaults to "ppa_result".
savefig : bool, optional
Whether to save the figure or not. Defaults to False.
output_dir : str, optional
The output directory. Defaults to "./".
output_fname : str, optional
The output file name. Defaults to "".
Returns
-------
None
"""
region_results = adata.uns[key_name]
ax = sns.countplot(
x=x,
hue=hue,
data=region_results,
palette=palette,
order=region_results[x].value_counts().index,
)
tick_positions = range(len(region_results[x].value_counts()))
tick_labels = region_results[x].value_counts().index
ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_labels, rotation=90)
if savefig:
plt.savefig(
output_dir + output_fname + "_count_ppa_result.pdf", bbox_inches="tight"
)
else:
plt.show()
[docs]
def BC_projection(
adata,
cnmap_dict,
cn_col,
plot_list, # list of 3 elements from cn_col
cn_col_annt=None,
palette=None,
figsize=(7, 7),
rand_seed=1,
SMALL_SIZE=14,
MEDIUM_SIZE=16,
BIGGER_SIZE=18,
n_num=None,
threshold=None,
savefig=False,
output_fname="",
output_dir="", # output directory
dpi=300,
):
"""
Plot barycentric projection.
Parameters
----------
adata : AnnData
Annotated data object.
cnmap_dict : dict
Dictionary containing keys 'w', 'l', 'k', and 'threshold'.
cn_col : str
Column name in adata.obs containing copy number information.
plot_list : list
List of 3 elements from cn_col.
cn_col_annt : str, optional
Annotated column name, by default None.
palette : dict, optional
Color palette, by default None.
figsize : tuple, optional
Figure size, by default (7, 7).
rand_seed : int, optional
Random seed, by default 1.
SMALL_SIZE : int, optional
Font size for small text, by default 14.
MEDIUM_SIZE : int, optional
Font size for medium text, by default 16.
BIGGER_SIZE : int, optional
Font size for large text, by default 18.
n_num : int, optional
Number, by default None.
threshold : float, optional
Threshold, by default None.
savefig : bool, optional
Whether to save the figure, by default False.
output_fname : str, optional
Output file name, by default "".
output_dir : str, optional
Output directory, by default "".
dpi : int, optional
Dots per inch, by default 300.
"""
if cn_col_annt is None:
plot_list_annot = plot_list
all_list = list(adata.obs[cn_col].unique())
if all(x in all_list for x in plot_list):
pass
else:
print(
"Please provide a valid list of plot_list from cn_col or specify cn_col_annt!"
)
return
else:
if cn_col == cn_col_annt:
plot_list_annot = plot_list
else:
annt_map = dict(
zip(adata.obs[cn_col].unique(), adata.obs[cn_col_annt].unique())
)
plot_list_annot = [annt_map[i] for i in plot_list]
# Settings for graph
plt.rc("font", size=SMALL_SIZE) # controls default text sizes
plt.rc("axes", titlesize=SMALL_SIZE) # fontsize of the axes title
plt.rc("axes", labelsize=MEDIUM_SIZE) # fontsize of the x and y labels
plt.rc("xtick", labelsize=SMALL_SIZE) # fontsize of the tick labels
plt.rc("ytick", labelsize=SMALL_SIZE) # fontsize of the tick labels
plt.rc("legend", fontsize=SMALL_SIZE) # legend fontsize
plt.rc("figure", titlesize=BIGGER_SIZE) # fontsize of the figure title
# generate color
if palette is None:
cn_colors = hf_generate_random_colors(
len(adata.obs[cn_col].unique()), rand_seed=rand_seed
)
if cn_col_annt + "_colors" not in adata.uns.keys():
palette = dict(zip(np.sort(adata.obs[cn_col].unique()), cn_colors))
adata.uns[cn_col + "_colors"] = cn_colors
else:
palette = dict(
zip(np.sort(adata.obs[cn_col].unique()), adata.uns[cn_col + "_colors"])
)
w = cnmap_dict["w"]
l = cnmap_dict["l"]
if n_num is None:
n_num = cnmap_dict["k"]
else:
n_num = n_num
if threshold is None:
threshold = cnmap_dict["threshold"]
threshold = threshold * 100
else:
threshold = threshold * 100
# lmap = {j: i for i, j in enumerate(l)}
wgc = w.loc[w.loc[:, plot_list].sum(axis=1) > threshold, :]
# idx = wgc.index.values
xl = wgc.loc[:, plot_list]
proj = np.array([[0, 0], [np.cos(np.pi / 3), np.sin(np.pi / 3)], [1, 0]])
coords = np.dot(xl / n_num, proj) #####window size fraction
plt.figure(figsize=figsize)
jit = 0.002
cols = [palette[a] for a in wgc[cn_col]]
plt.scatter(
coords[:, 0] + jit * np.random.randn(len(coords)),
coords[:, 1] + jit * np.random.randn(len(coords)),
s=15,
alpha=0.5,
c=cols,
)
# add label to corners of the triangle
plt.text(-0.15, -0.05, plot_list_annot[0], fontsize=10)
plt.text(0.65, -0.05, plot_list_annot[2], fontsize=10)
plt.text(
np.cos(np.pi / 3),
np.sin(np.pi / 3) + 0.035,
plot_list_annot[1],
fontsize=10,
horizontalalignment="center",
verticalalignment="center",
)
# Create a list of Patch objects, one for each unique color in your palette
legend_patches = [
mpatches.Patch(color=color, label=label) for label, color in palette.items()
]
# Add legend to bottom of plot
plt.legend(
handles=legend_patches,
bbox_to_anchor=(0.0, -0.15, 1.0, 0.102),
loc="lower center",
ncol=3,
borderaxespad=0.0,
fontsize=10,
)
plt.axis("off")
if savefig:
plt.savefig(
output_dir + output_fname + "_bc_proj.pdf",
dpi=dpi,
transparent=True,
bbox_inches="tight",
)
else:
plt.show()
[docs]
def ppa_res_donut(
adata,
cat_col,
palette=None,
key_names="ppa_result",
radii=[1, 2, 3, 4, 5],
unit="µm",
figsize=(10, 10),
add_guides=True,
text="example CN",
label_color="black",
rand_seed=1,
subset_column=None,
subset_condition=None,
title="Title",
savefig=False,
output_fname="",
output_dir="./",
):
# plotting
key_names = key_names[::-1]
plt.figure(figsize=figsize)
if add_guides == True:
# add grid lines
plt.plot([0, 0], [1.05, -1.05], color="black", alpha=0.3, zorder=-1)
plt.plot([1.05, -1.05], [0, 0], color="black", alpha=0.3, zorder=-1)
# add diagonal lines
for angle in [45, -45]:
x_new = 1.05 * np.cos(np.radians(angle))
y_new = 1.05 * np.sin(np.radians(angle))
plt.plot(
[-x_new, x_new], [-y_new, y_new], color="black", alpha=0.3, zorder=-1
)
# generate reproducable colors if no palette is provided
if palette is None:
if cat_col + "_colors" not in adata.uns.keys():
ct_colors = hf_generate_random_colors(
len(adata.obs[cat_col].unique()), rand_seed=rand_seed
)
palette = dict(zip(np.sort(adata.obs[cat_col].unique()), ct_colors))
adata.uns[cat_col + "_colors"] = ct_colors
else:
palette = dict(
zip(
np.sort(adata.obs[cat_col].unique()), adata.uns[cat_col + "_colors"]
)
)
for i, key_name in enumerate(key_names):
print(f"Key {i}: {key_name}")
# extract key from adata
region_results = adata.uns[key_name]
# test if region_results is empty
if region_results.shape[0] == 0:
print(f"Key {i} is empty.")
continue
else:
print(f"Key {i} has {region_results.shape[0]} rows.")
# subset by condition
if subset_column != None:
if subset_column not in region_results.columns:
raise ValueError(
f"Column '{subset_column}' does not exist in the DataFrame."
)
elif subset_condition not in region_results[subset_column].unique():
raise ValueError(
f"Value '{subset_condition}' does not exist in the column '{subset_column}'."
)
else:
region_results = region_results[
region_results[subset_column] == subset_condition
]
# calculate percentages of categories
percentage_list = region_results[cat_col].value_counts(normalize=True) * 100
# Check if all categories have a color in the palette
for category in percentage_list.index:
if category not in palette:
raise ValueError(
f"No color provided for category {category} in the palette"
)
# Get the colors for the current categories from the palette
category_colors = [palette[category] for category in percentage_list.index]
rsed_index = len(radii) - 1 - i
plt.pie(
percentage_list, radius=(0.6 + 0.1 * rsed_index), colors=category_colors
)
# add labels for each distance
for j, percentage_list in enumerate(radii):
rsed_index = len(radii) - 1 - j
plt.text(
0,
(0.53 + 0.1 * rsed_index),
str(radii[rsed_index]) + unit,
horizontalalignment="center",
fontweight="bold",
color=label_color,
)
# add a circle at the center to transform it in a donut chart
my_circle = plt.Circle((0, 0), 0.5, color="white")
p = plt.gcf()
p.gca().add_artist(my_circle)
# add a legend based on the colors and keys in palette
handles = [plt.Rectangle((0, 0), 1, 1, color=palette[key]) for key in palette]
plt.legend(
handles,
palette.keys(),
bbox_to_anchor=(0.94, 0.925),
loc="upper left",
prop={"size": 15},
)
# Define the maximum length of a line
max_line_length = 20
# Split the text into multiple lines if it's too long
wrapped_text = textwrap.fill(text, max_line_length)
# Add a title in the middle of the white circle
plt.text(
0,
0,
wrapped_text,
horizontalalignment="center",
verticalalignment="center",
fontsize=18,
)
plt.title(title, size=24, y=0.96)
if savefig:
plt.savefig(output_dir + output_fname + ".pdf", bbox_inches="tight")
else:
plt.show()
[docs]
def distance_graph(
dist_table,
distance_pvals,
palette=None,
condition_pair=None,
interaction_col="interaction",
condition_col="condition",
logfold_group_col="logfold_group",
celltype1_col="celltype1",
celltype2_col="celltype2",
pair_col="pairs",
with_labels=True,
node_size=910,
font_size=7,
multiplication_factor=10,
savefig=False,
output_fname="",
output_dir="", # output directory
dpi=300,
color_seed=0,
):
"""
Generates a distance graph from a dataframe.
Parameters:
df (DataFrame): The input dataframe.
palette (dict, optional): A dictionary mapping nodes to colors. If None, nodes are colored 'lightgrey'.
condition_pair (list, optional): A list of two conditions to compare. If None, uses unique conditions from the dataframe.
interaction_col (str, optional): The name of the interaction column in the dataframe. Defaults to 'interaction'.
condition_col (str, optional): The name of the condition column in the dataframe. Defaults to 'condition'.
logfold_group_col (str, optional): The name of the logfold group column in the dataframe. Defaults to 'logfold_group'.
celltype1_col (str, optional): The name of the first cell type column in the dataframe. Defaults to 'celltype1'.
celltype2_col (str, optional): The name of the second cell type column in the dataframe. Defaults to 'celltype2'.
with_labels (bool, optional): Whether to draw labels on the nodes. Defaults to True.
node_size (int, optional): The size of the nodes. Defaults to 910.
font_size (int, optional): The font size for the labels. Defaults to 7.
Returns:
None: The function generates a plot and does not return anything.
"""
# Generate df for graph
# this is to restrict to the pairs that were filtered
pairs = dist_table[pair_col].unique()
distance_pvals[pair_col] = (
distance_pvals[celltype1_col] + "_" + distance_pvals[celltype2_col]
)
distance_pvals = distance_pvals[distance_pvals[pair_col].isin(pairs)]
print(distance_pvals.shape)
pairs = distance_pvals[interaction_col].unique()
distance_pvals_grouped = distance_pvals.groupby([celltype1_col, celltype2_col])
result_list = {}
if condition_pair == None:
conditions = distance_pvals[condition_col].unique()
else:
conditions = condition_pair
for p in pairs:
distance_pvals_c1 = distance_pvals[
(distance_pvals[interaction_col] == p)
& (distance_pvals[condition_col] == conditions[0])
]
distance_pvals_c2 = distance_pvals[
(distance_pvals[interaction_col] == p)
& (distance_pvals[condition_col] == conditions[1])
]
difference = abs(
distance_pvals_c2[logfold_group_col].values[0]
- distance_pvals_c1[logfold_group_col].values[0]
)
if (
0
> (
distance_pvals_c2[logfold_group_col].values[0]
- distance_pvals_c1[logfold_group_col]
).values[0]
):
direction = "#3976AC"
else:
direction = "#C63D30"
df_res = pd.DataFrame(
{
celltype1_col: distance_pvals_c2[celltype1_col],
celltype2_col: distance_pvals_c2[celltype2_col],
"difference": difference,
"direction": direction,
}
)
result_list[p] = df_res
graph_df = pd.concat(result_list.values())
graph_df = graph_df.drop_duplicates(
subset=[celltype1_col, celltype2_col], keep="first"
)
## remove nan values
graph_df = graph_df[graph_df["difference"].notna()]
# Generate graph
G = nx.from_pandas_edgelist(
graph_df,
source=celltype1_col,
target=celltype2_col,
edge_attr=["difference", "direction"],
)
for u, v, d in G.edges(data=True):
d["weight"] = (
1e-100
if d["difference"] == 0
else abs(d["difference"] * multiplication_factor)
)
# Generate plot
weights = [G[u][v]["weight"] for u, v in G.edges()]
pos = nx.circular_layout(G)
edge_colors = [d["direction"] for u, v, d in G.edges(data=True)]
if palette is None:
combined_list = (
distance_pvals[celltype1_col].tolist()
+ distance_pvals[celltype2_col].tolist()
)
unique_values = set(combined_list)
ct_colors = hf_generate_random_colors(len(unique_values), rand_seed=color_seed)
# map colors to cell types
palette = {value: ct_colors[i] for i, value in enumerate(unique_values)}
node_colors = [palette[node] for node in G.nodes()]
else:
node_colors = [palette[node] for node in G.nodes()]
nx.draw(
G,
pos,
node_color=node_colors,
with_labels=False, # Don't draw labels here
edge_color=edge_colors,
node_size=node_size,
width=weights,
)
for node, (x, y) in pos.items():
wrapped_node = textwrap.fill(node, 10) # Split long names into multiple lines
plt.text(
x,
y,
wrapped_node,
fontsize=font_size,
color="white" if is_dark(palette.get(node, "lightgrey")) else "black",
ha="center",
va="center",
)
# add legend
plt.plot(
[],
[],
color="#3976AC",
label="More distant in " + conditions[0] + " than " + conditions[1],
)
plt.plot(
[],
[],
color="#C63D30",
label="Closer in " + conditions[0] + " than " + conditions[1],
)
# add width legend
plt.plot(
[],
[],
color="k",
label="edge width = absolute difference * " + str(multiplication_factor),
)
# position legend at bottom middle
plt.legend(bbox_to_anchor=(0.5, -0.2), loc="lower center", ncol=2)
if savefig:
plt.savefig(
output_dir + output_fname + "_dist_graph.pdf",
dpi=dpi,
bbox_inches="tight",
)
else:
plt.show()