from typing import Optional, Union, Tuple
from anndata import AnnData
import numpy as np
import pandas as pd
from matplotlib.axes import Axes
from matplotlib import colormaps
from matplotlib.collections import LineCollection, PolyCollection
from matplotlib.colors import Normalize
from matplotlib.patches import Patch
from matplotlib.transforms import offset_copy
from scipy.sparse import csr_matrix, isspmatrix
from ..tools._cell_shape_modeling import SEM
from .._utils import get_axes, get_cid_list, add_colorbar,set_axes, get_arr, get_cat_arr_color
def plot_cell_shape(sem: SEM,
                    vis_key: Optional[str] = None,
                    arr: Optional[Union[np.ndarray, pd.Series]] = None,
                    summary: str = 'sender',
                    compute_alphashape: bool = False,
                    cid_list: Optional[np.ndarray] = None,
                    cmap_name: str = 'Reds',
                    vmax: Optional[float] = None,
                    vmin: Optional[float] = None,
                    boundary_width: float = 1,
                    boundary_color: Optional[Union[str, Tuple]] = None,
                    boundary_alpha: float = 1,
                    face_alpha: float = 1,
                    show_axis: bool = False,
                    enable_annotation: bool = False,
                    enable_legend: bool = True,
                    enable_colorbar: bool = True,
                    ax: Optional[Axes] = None,
                    save_name: Optional[str] = None,
                    **kwargs) -> Axes:
    """
    Plot cell shapes using alpha shapes, coloring each cell by data.

    Every selected cell is drawn as one polygon built from its alpha-shape
    boundary. Categorical data (or `sem.ctype` when no data is given) is
    shown with a legend; continuous data is mapped through a colormap and
    shown with a colorbar.

    Parameters
    ----------
    sem : SEM
        Subcellular element method object
    vis_key : str, optional
        Key to retrieve visualization data from `sem.adata`
    arr : np.ndarray or pd.Series, optional
        Data for visualization. shape = (nc,) or (len(cid_list),).
        Ignored if `vis_key` is provided.
        `sem.ctype` will be visualized if `arr` and `vis_key` are both not provided.
    summary : str, default='sender'
        Data source used to look up `vis_key` (forwarded to `get_arr`):
        'sender' retrieves sender signal from adata.obsm['sender_signal'][vis_key],
        'receiver' retrieves receiver signal from adata.obsm['receiver_signal'][vis_key],
        'gene' retrieves gene expression data from adata.
    compute_alphashape : bool, default=False
        Recompute alpha shapes if True (they are also computed when missing
        or when `kwargs` are given).
    cid_list : ndarray, optional
        Array of index for cells to be visualized. Default: all cells
    cmap_name : str, default='Reds'
        Valid matplotlib colormap name to visualize data
    vmax : float, optional
        Colormap upper bound. Default: data max
    vmin : float, optional
        Colormap lower bound. Default: data min
    boundary_width : float, default=1
        Cell boundary line width
    boundary_color : str or tuple, optional
        Cell boundary line color. Default: same color as the face, with
        `boundary_alpha` opacity. When given, it is used as-is
        (`boundary_alpha` is not applied to it).
    boundary_alpha : float, default=1
        Cell boundary line opacity, 0 (fully transparent), 1 (fully opaque)
    face_alpha : float, default=1
        Cell shape face opacity, 0 (fully transparent), 1 (fully opaque)
    show_axis : bool, default=False
        Show axis
    enable_annotation : bool, default=False
        Annotate cells with their index at their `sem.xc` position
    enable_legend : bool, default=True
        Show categorical legend (only for category data)
    enable_colorbar : bool, default=True
        Show colorbar (only for continuous data)
    ax : Axes, optional
        Target matplotlib axes object. Creates new figure if None.
    save_name : str, optional
        Output path for figure saving (e.g., 'figure.pdf')
    **kwargs
        keyword arguments passed to `sem.compute_alphashape()`

    Returns
    -------
    ax : Axes

    Raises
    ------
    KeyError
        If `vis_key` is given but no data could be retrieved for it.
    ValueError
        If continuous data has neither `sem.nc` nor `len(cid_list)` entries.

    Examples
    --------
    >>> cr.pl.plot_cell_shape(sem)
    """
    # Any alpha-shape option in kwargs forces a recompute so it takes effect.
    if compute_alphashape or not sem.alphashape_info['computed'] or kwargs:
        sem.compute_alphashape(**kwargs)
    fig, ax = get_axes(ax)
    cid_list, _ = get_cid_list(sem, cid_list)
    arr = get_arr(sem, vis_key, arr, summary)

    if arr is None:
        if vis_key is not None:
            raise KeyError(
                f"vis_key '{vis_key}' not found in genes or adata.obs")
        # No data requested: fall back to the cell types stored on `sem`.
        cat_code = sem.ctype[cid_list]
        cat_list = sem.ctype_list
        color_list = sem.color_list
        is_categorical = True
    elif arr.dtype.name == 'category':
        # Obtain category codes and their colors from arr.
        cat_code, cat_list, color_list = get_cat_arr_color(
            sem, arr, cid_list, vis_key, cmap_name)
        is_categorical = True
    else:
        is_categorical = False

    if is_categorical:
        colors = color_list[cat_code]
        facecolors = np.insert(colors, 3, face_alpha, axis=1)
        edgecolors = np.insert(colors, 3, boundary_alpha, axis=1)
        enable_colorbar = False
    else:
        # Index positionally (by cell id); a pd.Series would otherwise be
        # indexed by its labels.
        arr = np.asarray(arr)
        if len(arr) == sem.nc:
            arr = arr[cid_list]
        elif len(arr) != len(cid_list):
            raise ValueError('len(arr)!=len(cid_list)')
        cmap = colormaps[cmap_name]
        vmax = arr.max() if vmax is None else vmax
        vmin = arr.min() if vmin is None else vmin
        norm = Normalize(vmin=vmin, vmax=vmax, clip=False)
        facecolors = cmap(norm(arr))
        edgecolors = facecolors.copy()
        facecolors[:, 3] = face_alpha
        edgecolors[:, 3] = boundary_alpha
        enable_legend = False

    # Draw one polygon per selected cell.
    n_cells = len(cid_list)
    all_boundaries = [sem.alphashape[cid].get_boundary() for cid in cid_list]
    # Boundaries use the data colors carrying boundary_alpha unless an
    # explicit boundary_color overrides them.
    bc = edgecolors[:n_cells] if boundary_color is None else boundary_color
    polyc = PolyCollection(all_boundaries,
                           facecolors=facecolors[:n_cells],
                           edgecolors=bc,
                           linewidths=boundary_width)
    ax.add_collection(polyc)
    set_axes(ax, show_axis)

    if enable_colorbar:
        add_colorbar(fig, ax, cmap, norm)
    elif enable_legend:
        legend_patches = [Patch(color=color_list[i], label=cat_list[i])
                          for i in np.unique(cat_code)]
        # Place the legend just outside the right edge of the axes.
        transform = offset_copy(ax.transAxes, x=5, y=0,
                                units='points', fig=fig)
        ax.legend(handles=legend_patches,
                  loc='center left',
                  bbox_to_anchor=(1, 0.5),
                  bbox_transform=transform,
                  frameon=False)

    if enable_annotation:
        # Map sem.xc to plot coordinates (xc * scale + deltax).
        spatial_coor = sem.xc * sem.scale + sem.deltax
        for i in cid_list:
            ax.annotate(f'{i}', spatial_coor[i], ha='center', va='center',
                        fontweight='bold')

    if save_name is not None:
        fig.savefig(save_name, dpi=500, bbox_inches='tight', transparent=True)
    return ax
def element_plot(sem: SEM,
                 vis_key: Optional[str] = None,
                 arr: Optional[Union[np.ndarray, pd.Series]] = None,
                 summary: str = 'sender',
                 cid_list: Optional[np.ndarray] = None,
                 cmap_name: str = 'Reds',
                 spot_size: float = 1,
                 scaling: bool = True,
                 show_axis: bool = True,
                 enable_colorbar: bool = True,
                 enable_legend: bool = True,
                 ax: Optional[Axes] = None,
                 save_name: Optional[str] = None,) -> Axes:
    """
    Plot cell elements (alias of `plot_cell_element`).

    This function had an implementation identical to `plot_cell_element`.
    It now forwards every argument to it so that the two cannot drift apart.

    Parameters
    ----------
    sem : SEM
        Subcellular element method object
    vis_key : str, optional
        Key to retrieve visualization data from `sem.adata`.
    arr : np.ndarray or pd.Series, optional
        Data for visualization. Accepts both cell-level (nc,) and element-level (ne,)
    summary : str, default='sender'
        'sender' represents sender signal, retrieves data from adata.obsm['sender_signal'][vis_key]
        'receiver' retrieves receiver signal data from adata.obsm['receiver_signal'][vis_key]
        'gene' retrieves gene expression data from adata
    cid_list : ndarray, optional
        Array of index for cells to be visualized. Default: all cells
    cmap_name : str, default='Reds'
        Valid matplotlib colormap name to visualize data
    spot_size : float, default=1
        Markersize for `matplotlib.pyplot.scatter`
    scaling : bool, default=True
        Scale coordinates back to original data(`xc`) if True, otherwise visualize directly.
    show_axis : bool, default=True
        Show axis.
    enable_colorbar : bool, default=True
        Show colorbar (only for continuous data).
    enable_legend : bool, default=True
        Show categorical legend (only for category data).
    ax : Axes, optional
        Target matplotlib axes object. Creates new figure if None
    save_name : str, optional
        Output path for figure saving (e.g., 'figure.pdf')

    Returns
    -------
    ax : Axes
    """
    return plot_cell_element(sem,
                             vis_key=vis_key,
                             arr=arr,
                             summary=summary,
                             cid_list=cid_list,
                             cmap_name=cmap_name,
                             spot_size=spot_size,
                             scaling=scaling,
                             show_axis=show_axis,
                             enable_colorbar=enable_colorbar,
                             enable_legend=enable_legend,
                             ax=ax,
                             save_name=save_name)
def plot_cell_element(
    sem: SEM,
    vis_key: Optional[str] = None,
    arr: Optional[Union[np.ndarray, pd.Series]] = None,
    summary: str = 'sender',
    cid_list: Optional[np.ndarray] = None,
    cmap_name: str = 'Reds',
    spot_size: float = 1,
    scaling: bool = True,
    show_axis: bool = True,
    enable_colorbar: bool = True,
    enable_legend: bool = True,
    ax: Optional[Axes] = None,
    save_name: Optional[str] = None
) -> Axes:
    """
    Plot cell elements as a scatter plot, colored by category or by value.

    Categorical data (or `sem.ctype` when no data is given) is drawn with one
    scatter per category and an optional legend. Continuous data is mapped
    through a colormap and drawn with an optional colorbar. For non-negative
    data the normalization spans [min, 95th percentile]; otherwise it spans
    symmetrically [-a, a], where a is the 95th percentile of |arr|.

    Parameters
    ----------
    sem : SEM
        Subcellular element method object
    vis_key : str, optional
        Key to retrieve visualization data from `sem.adata`.
    arr : np.ndarray or pd.Series, optional
        Data for visualization. Accepts both cell-level (nc,) and element-level (ne,)
    summary : str, default='sender'
        'sender' represents sender signal, retrieves data from adata.obsm['sender_signal'][vis_key]
        'receiver' retrieves receiver signal data from adata.obsm['receiver_signal'][vis_key]
        'gene' retrieves gene expression data from adata
    cid_list : ndarray, optional
        Array of index for cells to be visualized. Default: all cells
    cmap_name : str, default='Reds'
        Valid matplotlib colormap name to visualize data
    spot_size : float, default=1
        Markersize for `matplotlib.pyplot.scatter`
    scaling : bool, default=True
        Scale coordinates back to original data(`xc`) if True, otherwise visualize directly.
    show_axis : bool, default=True
        Show axis.
    enable_colorbar : bool, default=True
        Show colorbar (only for continuous data).
    enable_legend : bool, default=True
        Show categorical legend (only for category data).
    ax : Axes, optional
        Target matplotlib axes object. Creates new figure if None
    save_name : str, optional
        Output path for figure saving (e.g., 'figure.pdf')

    Returns
    -------
    ax : Axes

    Raises
    ------
    KeyError
        If `vis_key` is given but no data could be retrieved for it.
    ValueError
        If continuous `arr` has neither `sem.nc` nor `sem.ne` entries.
    """
    fig, ax = get_axes(ax)
    cid_list, xe = get_cid_list(sem, cid_list, scaling)
    arr = get_arr(sem, vis_key, arr, summary)

    # dtype=int keeps an empty selection usable as an index array.
    cids = np.asarray(cid_list, dtype=int)
    ceidn = np.asarray(sem.ceidn)
    # Cell c owns element ids ceidn[c]:ceidn[c+1]; count them per selected cell.
    counts = ceidn[cids + 1] - ceidn[cids]

    if arr is not None and arr.dtype.name != 'category':
        # ---- continuous data ----
        arr = np.asarray(arr)
        cmap = colormaps[cmap_name]
        if arr.min() >= 0:
            norm = Normalize(vmin=arr.min(), vmax=np.percentile(arr, 95),
                             clip=False)
        else:
            a = np.percentile(np.abs(arr), 95)
            norm = Normalize(vmin=-a, vmax=a, clip=False)

        # Global element ids of the selected cells, in cid_list order.
        if cids.size:
            eid = np.concatenate(
                [np.arange(ceidn[c], ceidn[c + 1]) for c in cids])
        else:
            eid = np.empty(0, dtype=int)

        if arr.shape[0] == sem.nc:
            # Cell-level values: each element inherits its cell's color.
            ec = cmap(norm(arr))[np.repeat(cids, counts)]
        elif arr.shape[0] == sem.ne:
            ec = cmap(norm(arr))[eid]
        else:
            raise ValueError(
                f'arr has {arr.shape[0]} entries; expected nc={sem.nc} '
                f'(cell-level) or ne={sem.ne} (element-level)')

        # One scatter for all elements: same points, colors and draw order
        # as drawing cell by cell, but a single artist.
        # NOTE(review): xe is indexed here by global element id, whereas the
        # categorical branch below indexes it positionally in cid_list order.
        # The two agree when cid_list is all cells in order; confirm which
        # layout get_cid_list returns for a subset of cells.
        if eid.size:
            ax.scatter(xe[eid, 0], xe[eid, 1], c=ec, s=spot_size)
        if enable_colorbar:
            add_colorbar(fig, ax, cmap, norm)
    else:
        # ---- categorical data ----
        if arr is None:
            if vis_key is not None:
                raise KeyError(
                    f"vis_key '{vis_key}' not found in genes or adata.obs")
            # No data requested: fall back to the cell types stored on `sem`.
            cat_code = sem.ctype[cid_list]
            cat_list = sem.ctype_list
            color_list = sem.color_list
        else:
            cat_code, cat_list, color_list = get_cat_arr_color(
                sem, arr, cid_list, vis_key, cmap_name)

        # Position (within cid_list) of the cell owning each plotted element.
        owner_pos = np.repeat(np.arange(cids.size), counts)
        element_cat = np.asarray(cat_code)[owner_pos]
        for i in np.unique(cat_code):
            vis = element_cat == i
            ax.scatter(xe[vis, 0], xe[vis, 1],
                       c=color_list[i][np.newaxis],
                       label=cat_list[i],
                       s=spot_size)
        if enable_legend:
            # Legend just outside the right edge; enlarge its markers so tiny
            # spots stay visible.
            transform = offset_copy(ax.transAxes, x=5, y=0, units='points',
                                    fig=fig)
            ax.legend(loc='center left',
                      bbox_to_anchor=(1, 0.5),
                      bbox_transform=transform,
                      frameon=False,
                      markerscale=5 / spot_size)

    set_axes(ax, show_axis)
    if save_name is not None:
        fig.savefig(save_name, dpi=500, bbox_inches='tight', transparent=True)
    return ax