from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import networkx as nx
import numpy as np
import pandas as pd
import matplotlib.gridspec as gridspec
from . import plot_helper, an
from .CCIData_class import CCIData
[docs]
def network_plot(
network,
p_vals=None,
diff_plot=False,
normalise=True,
remove_unconnected=True,
show_labels=False,
p_val_cutoff=0.05,
edge_weight=50,
text_size=15,
node_size=2500,
figsize=None,
arrowsize=20,
node_label_dist=1,
p_val_text_size=10,
node_colors=None,
node_palette="tab20",
outer_node_palette=None,
show=True,
show_legend=True,
legend_size=12,
title=None,
title_size=14
):
"""Plots a network with optional edge significance highlighting and node coloring based on in-degree and out-degree difference.
Args:
network (pandas.DataFrame or numpy.ndarray): The adjacency matrix representing the network.
p_vals (pandas.DataFrame or numpy.ndarray, optional): A matrix of p-values corresponding to the edges in `network`. If not provided, significance values will not be plotted. Defaults to None.
diff_plot (bool, optional): Whether you are plotting the network difference, to show up and down-regulated edges. Defaults to False.
normalise (bool, optional): Whether to normalize the network matrix before plotting. Defaults to True.
remove_unconnected (bool, optional): Whether to remove cell types that do not interact with any cell types. Defaults to True.
show_labels (bool, optional): Whether to show node labels. Defaults to True.
p_val_cutoff (float, optional): The p-value cutoff for determining significant edges. Defaults to 0.05.
edge_weight (float, optional): The base weight for edges. Defaults to 20.
text_size (int, optional): The font size for node labels. Defaults to 15.
node_size (int, optional): The size of the nodes. Defaults to 2500.
figsize (tuple, optional): The size of the figure. Defaults to None.
arrowsize (int, optional): The size of the arrow heads for edges. Defaults to 50.
node_label_dist (float, optional): A factor for adjusting the distance between nodes and labels. Defaults to 1.
p_val_text_size (int, optional): The font size for p-value labels. Defaults to 10.
node_colors (dict, optional): A dictionary of colors for each node. Overwrites node_palette. Defaults to None.
node_palette (str, optional): The name of the color palette to use for nodes. Defaults to "tab20".
outer_node_palette (str, optional): The name of the color palette to use for outer nodes to show sender/reciever nodes. Defaults to None.
show (bool, optional): Whether to show the plot or not. Defaults to True.
show_legend (bool, optional): Whether to show legend. Defaults to False.
legend_size (int, optional): Font size for legend. Defaults to 12.
title (str, optional): Title of the plot. Defaults to None.
title_size (int, optional): Font size for title. Defaults to 14.
Returns:
tuple: A tuple containing the figure and axis objects.
"""
if not isinstance(network, pd.DataFrame):
raise ValueError("Input should be a dataframe.")
if figsize is None:
if show_legend:
figsize = (10, 8)
else:
figsize = (8, 8)
# Adjust the figure layout to accommodate the legend on the right
fig = plt.figure(figsize=figsize)
gs = gridspec.GridSpec(1, 2, width_ratios=[5, 1], wspace=0.2)
# Main plot area
ax = fig.add_subplot(gs[0])
plt.sca(ax) # Set the current axis to the main plot area
if remove_unconnected:
cell_types = (network != 0).any(axis=0) + (network != 0).any(axis=1)
network = network.loc[cell_types, cell_types]
if normalise:
if network.min().min() >= 0:
network = network / network.sum().sum()
else:
network = network / network.abs().sum().sum()
network_abs = abs(network)
if normalise:
network_abs = network_abs / network_abs.sum().sum()
network_abs = network.astype(float)
G_network = nx.from_pandas_adjacency(network_abs, create_using=nx.DiGraph)
pos = nx.circular_layout(G_network)
weights = nx.get_edge_attributes(G_network, "weight")
# Calculate the in-degree and out-degree for each node
in_degree = dict(G_network.in_degree(weight="weight"))
out_degree = dict(G_network.out_degree(weight="weight"))
in_out_diff = {node: in_degree[node] - out_degree[node] for node in G_network.nodes}
max_diff = max(abs(value) for value in in_out_diff.values())
color_scale = np.linspace(-max_diff, max_diff, 256)
if outer_node_palette is None:
# Create a color scale based on the in-degree and out-degree difference
cmap_colors = [(1, 0, 0), (0.7, 0.7, 0.7), (0, 0, 1)] # Blue, Grey, Red
outer_node_cmap = LinearSegmentedColormap.from_list("custom_cmap", cmap_colors)
else:
outer_node_cmap = plt.get_cmap(outer_node_palette)
edge_colors = []
# Map node colors to the in-degree and out-degree difference
if sum(abs(value) for value in in_out_diff.values()) == 0:
edge_colors = ['grey' for node in G_network.nodes]
else:
edge_colors = [
outer_node_cmap(int(np.interp(in_out_diff[node], color_scale, range(256))))
for node in G_network.nodes
]
if node_colors is not None:
node_colors_list = [node_colors[node] for node in G_network.nodes]
else:
if node_palette is not None:
node_colors_list = \
list(plt.get_cmap(node_palette).colors)[:len(G_network.nodes)]
else:
node_colors_list = ["grey" for node in G_network.nodes]
if p_vals is None or diff_plot == False:
# Create a non-significant matrix
p_vals = network_abs.replace(network_abs.values, 1, inplace=False)
else:
# Prevent removal of pvals of 0
p_vals[p_vals == 0] = 1e-300
# Get edges that are significant
G_p_vals = nx.from_pandas_adjacency(p_vals, create_using=nx.DiGraph)
G_network_updown = nx.from_pandas_adjacency(network, create_using=nx.DiGraph)
non_sig_up = [
(u, v) for (u, v, d) in G_p_vals.edges(data=True)
if d["weight"] > p_val_cutoff
and u in G_network_updown
and v in G_network_updown[u]
and G_network_updown[u][v]["weight"] > 0]
non_sig_up = [edge for edge in non_sig_up if edge in weights.keys()]
non_sig_down = [
(u, v) for (u, v, d) in G_p_vals.edges(data=True)
if d["weight"] > p_val_cutoff
and u in G_network_updown
and v in G_network_updown[u]
and G_network_updown[u][v]["weight"] < 0]
non_sig_down = [edge for edge in non_sig_down if edge in weights.keys()]
sig_up = [
(u, v) for (u, v, d) in G_p_vals.edges(data=True)
if d["weight"] <= p_val_cutoff
and u in G_network_updown
and v in G_network_updown[u]
and G_network_updown[u][v]["weight"] > 0]
sig_up = [edge for edge in sig_up if edge in weights.keys()]
sig_down = [
(u, v) for (u, v, d) in G_p_vals.edges(data=True) if d["weight"] <= p_val_cutoff
and u in G_network_updown
and v in G_network_updown[u]
and G_network_updown[u][v]["weight"] < 0]
sig_down = [edge for edge in sig_down if edge in weights.keys()]
edge_thickness_non_sig_up = []
edge_thickness_non_sig_down = []
edge_thickness_sig_up = []
edge_thickness_sig_down = []
for edge in weights.keys():
if edge in non_sig_up:
edge_thickness_non_sig_up.append(weights[edge] * edge_weight)
else:
edge_thickness_non_sig_up.append(0)
if edge in non_sig_down:
edge_thickness_non_sig_down.append(weights[edge] * edge_weight)
else:
edge_thickness_non_sig_down.append(0)
if edge in sig_up:
edge_thickness_sig_up.append(weights[edge] * edge_weight)
else:
edge_thickness_sig_up.append(0)
if edge in sig_down:
edge_thickness_sig_down.append(weights[edge] * edge_weight)
else:
edge_thickness_sig_down.append(0)
# if node_colors is None:
nx.draw_networkx_nodes(
G_network,
pos,
node_size=node_size,
node_color=node_colors_list,
edgecolors=edge_colors,
linewidths=8.0,
)
if diff_plot:
# Draw non-self edges first
nx.draw_networkx_edges(
G_network,
pos,
node_size=node_size * 2,
connectionstyle="arc3,rad=0.08",
width=edge_thickness_non_sig_up,
arrows=True,
arrowstyle="->",
arrowsize=arrowsize,
edge_color="pink"
)
# Same pattern for non-sig down edges
nx.draw_networkx_edges(
G_network,
pos,
node_size=node_size * 2,
connectionstyle="arc3,rad=0.08",
width=edge_thickness_non_sig_down,
arrows=True,
arrowstyle="->",
arrowsize=arrowsize,
edge_color="lightgreen"
)
else:
# Non-self edges
nx.draw_networkx_edges(
G_network,
pos,
node_size=node_size * 2,
connectionstyle="arc3,rad=0.08",
width=edge_thickness_non_sig_up,
arrows=True,
arrowstyle="->",
arrowsize=arrowsize,
)
# Same for non-sig down edges
nx.draw_networkx_edges(
G_network,
pos,
node_size=node_size * 2,
connectionstyle="arc3,rad=0.08",
width=edge_thickness_non_sig_down,
arrows=True,
arrowstyle="->",
arrowsize=arrowsize,
)
# Significant up edges
nx.draw_networkx_edges(
G_network,
pos,
node_size=node_size * 2,
connectionstyle="arc3,rad=0.08",
width=edge_thickness_sig_up,
arrows=True,
arrowstyle="->",
arrowsize=arrowsize,
edge_color="purple",
)
# Significant down edges
nx.draw_networkx_edges(
G_network,
pos,
node_size=node_size * 2,
connectionstyle="arc3,rad=0.08",
width=edge_thickness_sig_down,
arrows=True,
arrowstyle="->",
arrowsize=arrowsize,
edge_color="green",
)
edge_labels = nx.get_edge_attributes(G_p_vals, "weight")
edge_labels = {
key: edge_labels[key] for key in G_network.edges().keys() if key in edge_labels
}
# Add edge labels for significant edges
for key, value in edge_labels.items():
if value > p_val_cutoff:
edge_labels[key] = ""
else:
edge_labels[key] = round(value, 3)
def offset(d, pos, dist=0.05, loop_shift=0.1):
for (u, v), obj in d.items():
if u != v:
par = dist * (pos[v] - pos[u])
dx, dy = par[1], -par[0]
x, y = obj.get_position()
obj.set_position((x + dx, y + dy))
else:
x, y = obj.get_position()
obj.set_position((x, y + loop_shift))
d = nx.draw_networkx_edge_labels(
G_network,
pos,
edge_labels,
font_size=p_val_text_size,
connectionstyle="arc3,rad=0.08"
)
offset(d, pos)
pos.update(
(x, [y[0] * 1.4 * node_label_dist, y[1] * (1.25 + 0.05) * node_label_dist])
for x, y in pos.items()
)
if show_labels:
nx.draw_networkx_labels(
G_network,
pos,
font_weight="bold",
font_color="black",
font_size=text_size,
clip_on=False,
horizontalalignment="center",
)
ax = plt.gca()
ax.margins(0.08)
plt.axis("off")
if show_legend:
# Create a color bar in the bottom-right corner
color_bar_ax = fig.add_axes([0.77, 0.2, 0.05, 0.2])
if sum(abs(value) for value in in_out_diff.values()) != 0:
sm = plt.cm.ScalarMappable(cmap=outer_node_cmap,
norm=plt.Normalize(vmin=-max_diff, vmax=max_diff))
sm.set_array([])
cbar = plt.colorbar(sm, cax=color_bar_ax)
cbar.set_ticks([]) # Remove the ticks
cbar.set_label('Net sender ← → Net receiver', fontsize=legend_size)
# Add the legend in the top-right corner
legend_elements = []
if diff_plot:
legend_elements.extend([
plt.Line2D([0], [0], color='pink', lw=2,
label='Non-significant positive'),
plt.Line2D([0], [0], color='lightgreen', lw=2,
label='Non-significant negative'),
plt.Line2D([0], [0], color='purple', lw=2,
label='Significant positive'),
plt.Line2D([0], [0], color='green', lw=2,
label='Significant negative')
])
if node_colors is not None:
for node, color in node_colors.items():
if node in network.index:
legend_elements.append(
plt.Line2D([0], [0], marker='o', color='w',
markerfacecolor=color, markersize=10, label=node)
)
else:
for i in range(len(G_network.nodes)):
legend_elements.append(
plt.Line2D([0], [0], marker='o', color='w',
markerfacecolor=node_colors_list[i], markersize=10,
label=list(G_network.nodes.keys())[i])
)
if legend_elements:
legend_ax = fig.add_axes([0.7, 0.55, 0.2, 0.3])
legend_ax.axis('off')
legend_ax.legend(handles=legend_elements, loc='center',
fontsize=legend_size)
if title is not None:
plt.suptitle(title, fontsize=title_size, y = 0.875, fontweight="bold")
plt.tight_layout()
if show:
plt.show()
else:
return fig, ax
[docs]
def chord_plot(
network,
min_int=0.001,
n_top_ccis=10,
colors=None,
show=True,
title=None,
title_size=14,
label_size=10,
figsize=None,
show_legend=False,
legend_size=12
):
"""Plots a chord plot of a network
Args:
network (pandas.DataFrame or numpy.ndarray): The adjacency matrix representing the network.
min_int (float): Minimum interactions to display cell type. Defaults to 0.01.
n_top_ccis (int): Number of top cell types to display. Defaults to 10.
colors (dict): Dict of colors for each cell type to use for the plot. Defaults to None.
show (bool): Whether to show plot or not. Defaults to True.
title (str): Title of the plot. Defaults to None.
title_size (int): Font size of the title. Defaults to 14.
label_size (int): Font size of the labels. Defaults to None.
figsize (tuple): Size of the figure. Defaults to None.
show_legend (bool): Whether to show legend. Defaults to False.
legend_size (int): Font size for legend. Defaults to 12.
Returns:
tuple: A tuple containing the figure and axis objects.
"""
if not isinstance(network, pd.DataFrame):
raise ValueError("Input should be a dataframe.")
network = network.transpose()
if figsize is None:
if show_legend:
figsize = (10, 8)
else:
figsize = (8, 8)
# Create figure with gridspec to accommodate legend
fig = plt.figure(figsize=figsize)
if show_legend:
gs = gridspec.GridSpec(1, 2, width_ratios=[4, 1], wspace=0)
ax = fig.add_subplot(gs[0])
else:
ax = plt.axes([0, 0, 1, 1])
flux = network.values
total_ints = flux.sum(axis=1) + flux.sum(axis=0) - flux.diagonal()
keep = total_ints > min_int
# Limit of 10 for good display #
if sum(keep) > n_top_ccis:
keep = np.argsort(-total_ints)[0:n_top_ccis]
flux = flux[:, keep]
flux = flux[keep, :].astype(float)
cell_names = network.index.values.astype(str)[keep]
nodes = cell_names
color_list = []
if colors is not None:
for cell in cell_names:
color_list.append(colors[cell])
else:
color_list = None
nodePos = plot_helper.chordDiagram(flux, ax, lim=1.25, colors=color_list)
ax.axis("off")
prop = dict(fontsize=label_size, ha="center", va="center")
for i in range(len(cell_names)):
x, y = nodePos[i][0:2]
if label_size != 0:
ax.text(x, y, nodes[i], rotation=nodePos[i][2], **prop)
if show_legend and colors is not None:
# Add the legend in the right subplot
legend_ax = fig.add_subplot(gs[1])
legend_ax.axis('off')
legend_elements = []
for node, color in colors.items():
if node in cell_names:
legend_elements.append(
plt.Line2D([0], [0], marker='o', color='w',
markerfacecolor=color, markersize=10, label=node)
)
if legend_elements:
legend_ax.legend(handles=legend_elements, loc='center',
fontsize=legend_size)
if title is not None:
fig.suptitle(title, fontsize=title_size, y=0.95, fontweight="bold")
plt.tight_layout()
if show:
plt.show()
else:
return fig, ax
[docs]
def dissim_hist(
dissimilarity_scores,
x_label_size=18,
y_label_size=24,
x_tick_size=14,
y_tick_size=12,
figsize=(6, 5),
show=True,
title=None,
title_size=14
):
"""Plots a histogram of dissimilarity scores.
Args:
dissimilarity_scores (dict): A dictionary of dissimilarity scores.
x_label_size (int): Font size for x-axis label. Defaults to 18.
y_label_size (int): Font size for y-axis label. Defaults to 24.
x_tick_size (int): Font size for ticks. Defaults to 14.
y_tick_size (int): Font size for ticks. Defaults to 12.
figsize (tuple): Size of the figure. Defaults to (10, 8).
show (bool): Whether to show the plot or not. Defaults to True.
title (str): Title of the plot. Defaults to None.
title_size (int): Font size of the title. Defaults to 14.
Returns:
matplotlib.figure.Figure: The figure
"""
fig = plt.figure(figsize=figsize)
plt.style.use('default')
plt.hist(list(dissimilarity_scores.values()))
plt.xlim(0, 1)
plt.xlabel("Dissimilarity Score", fontsize=x_label_size)
plt.ylabel("Count", fontsize=y_label_size)
plt.tick_params(axis='x', which='major', labelsize=x_tick_size)
plt.tick_params(axis='y', which='major', labelsize=y_tick_size)
if title is not None:
plt.title(title, fontsize=title_size, pad=20)
if show:
plt.show()
else:
plt.close(fig)
return fig
[docs]
def lr_top_dissimilarity(
dissimilarity_scores,
n=10,
top=True,
x_label_size=18,
y_label_size=24,
x_tick_size=14,
y_tick_size=12,
figsize=(6, 5),
show=True,
title=None,
title_size=14
):
"""Plots a bar plot of LR pairs with highest/lowest dissimilarity scores.
Args:
dissimilarity_scores (dict): A dictionary of dissimilarity scores.
n (int): Number of LR pairs to plot.
top (bool): If True, plot LR pairs with highest dissimilarity scores. If False, plot LR pairs with lowest dissimilarity scores.
x_label_size (int): Font size for x-axis label. Defaults to 18.
y_label_size (int): Font size for y-axis label. Defaults to 24.
x_tick_size (int): Font size for ticks. Defaults to 14.
y_tick_size (int): Font size for ticks. Defaults to 12.
figsize (tuple): Size of the figure. Defaults to (10, 8).
show (bool): Whether to show the plot or not. Defaults to True.
title (str): Title of the plot. Defaults to None.
title_size (int): Font size of the title. Defaults to 14.
Returns:
matplotlib.figure.Figure: The figure
"""
reverse = not top
sorted_items = sorted(
dissimilarity_scores.items(), key=lambda x: x[1], reverse=reverse
)
top_n_items = sorted_items[-n:]
keys, values = zip(*top_n_items)
fig = plt.figure(figsize=figsize)
plt.style.use('default')
plt.barh(keys, values)
plt.xlabel("Dissimilarity Score", fontsize=x_label_size)
plt.ylabel("LR Pair", fontsize=y_label_size)
plt.tick_params(axis='x', which='major', labelsize=x_tick_size)
plt.tick_params(axis='y', which='major', labelsize=y_tick_size)
if title is not None:
plt.title(title, fontsize=title_size, pad=20)
if show:
plt.show()
else:
plt.close(fig)
return fig
[docs]
def lr_barplot(
sample,
assay="raw",
n=15,
x_label_size=18,
y_label_size=24,
x_tick_size=14,
y_tick_size=12,
figsize=(6, 5),
show=True,
title=None,
title_size=14
):
"""Plots a bar plot of LR pairs and their proportions for a sample.
Args:
sample (CCIData): The CCIData object.
assay (str): The assay to use. Defaults to "raw".
n (int): Number of LR pairs to plot. If None, plot all LR pairs. Defaults to 15.
x_label_size (int): Font size for x-axis label. Defaults to 18.
y_label_size (int): Font size for y-axis label. Defaults to 24.
x_tick_size (int): Font size for tick labels. Defaults to 14.
y_tick_size (int): Font size for tick labels. Defaults to 12.
figsize (tuple): Size of the figure. Defaults to (10, 8).
title (str) (optional): Title for the plot. Defaults to None.
title_size (int): Font size of the title. Defaults to 14.
Returns:
matplotlib.figure.Figure: The figure
"""
if assay not in sample.assays:
raise ValueError("Assay not found in sample.")
interactions = [(lr, df.sum().sum()) for (lr, df) \
in sample.assays[assay]['cci_scores'].items()]
interactions.sort(key=lambda x: x[1])
interactions = interactions[-n:]
keys, values = zip(*interactions)
values = [value / sum(values) for value in values]
fig = plt.figure(figsize=figsize)
plt.style.use('default')
plt.barh(keys, values)
plt.xlabel("Relative Interaction Strength", fontsize=x_label_size)
plt.ylabel("LR Pair", fontsize=y_label_size)
plt.tick_params(axis='x', which='major', labelsize=x_tick_size)
plt.tick_params(axis='y', which='major', labelsize=y_tick_size)
if title:
plt.title(title, pad=20, fontsize=title_size)
plt.tight_layout()
if show:
plt.show()
else:
plt.close(fig)
return fig
[docs]
def lrs_per_celltype(
sample,
sender = None,
receiver = None,
assay="raw",
key="cci_scores",
p_vals=None,
n=15,
x_label_size=18,
y_label_size=24,
x_tick_size=14,
y_tick_size=12,
figsize=(6, 5),
show=True,
title=None,
title_size=14
):
"""Plots a bar plot of LR pairs and their proportions for a sender and receiver cell type pair along with p_values (optional).
Args:
sample (CCIData): The CCIData object.
sender (str): The sender cell type. Defaults to None.
receiver (str): The receiver cell type. Defaults to None.
assay (str): The assay to use. Defaults to "raw".
key (str): The key to use. Defaults to "cci_scores".
p_vals (dict): A dictionary of p-values. Defaults to None.
n (int): Number of LR pairs to plot. Defaults to 15.
x_label_size (int): Font size for x-axis label. Defaults to 18.
y_label_size (int): Font size for y-axis label. Defaults to 24.
x_tick_size (int): Font size for tick labels. Defaults to 14.
y_tick_size (int): Font size for tick labels. Defaults to 12.
figsize (tuple): Size of the figure. Defaults to (10, 8).
title (str) (optional): Title for the plot. Defaults to None.
title_size (int): Font size of the title. Defaults to 14.
Returns:
matplotlib.figure.Figure: The figure
"""
pairs = sample.get_lr_proportions(sender, receiver, assay, key)
keys = list(pairs.keys())[:n]
values = list(pairs.values())[:n]
keys.reverse()
values.reverse()
if p_vals is not None:
p_val_pairs = an.get_p_vals_per_celltype(p_vals, sender, receiver)
labels = [p_val_pairs[key] for key in keys]
# make labels readable (if less than 0.00001, show as <0.00001)
for i in range(len(labels)):
if labels[i] < 0.001:
labels[i] = "<0.001"
else:
labels[i] = f"{labels[i]:.3f}"
# Define colors based on p-values
colors = [
'#1f77b4' if val < 0.05 else 'grey' for val in [
p_val_pairs[key] for key in keys]]
# Create the figure and axis
fig, ax = plt.subplots(figsize=figsize)
plt.style.use('default')
if p_vals is None:
ax.barh(keys, values)
else:
bars = ax.barh(keys, values, color=colors)
ax.bar_label(bars, labels)
ax.set_xlabel("Proportion", fontsize=x_label_size)
ax.set_ylabel("LR Pair", fontsize=y_label_size)
ax.tick_params(axis='x', which='major', labelsize=x_tick_size)
ax.tick_params(axis='y', which='major', labelsize=y_tick_size)
if title:
plt.title(title, pad=20, fontsize=title_size)
plt.tight_layout()
if show:
plt.show()
else:
plt.close(fig)
return fig