from typing import Optional, Union, List, Dict
from math import pi, sqrt
from scipy.spatial import Delaunay, distance
from scipy.sparse import lil_matrix, coo_matrix, csr_matrix
from matplotlib import colormaps
from matplotlib.colors import to_rgb
import numpy as np
from numba import cuda
import pickle
import os
from sklearn.neighbors import NearestNeighbors
from anndata import AnnData
import warnings
from tqdm import tqdm
from .._utils import AlphaShape
class CellBase:
    """Base class for cell morphology representations.

    Stores per-cell coordinates, type codes and colors, plus a flat element
    array (``xe``/``ecid``/``ceidn``) that subclasses populate. Provides
    cell-cell contact, distance and alpha-shape computations on top of it.
    """
    adata: AnnData
    """Linked annotation data"""
    nc: int
    """cell number"""
    dim: int
    """spatial dimension"""
    xc: np.ndarray
    """cell coordinates. shape=(nc,dim)"""
    ctype: np.ndarray
    """cell type code. shape=(nc)"""
    ctype_list: np.ndarray
    """cell type category. shape=(nct)"""
    color_list: np.ndarray
    """colors for cell type category. shape=(nct,3)"""
    ne: int
    """element number"""
    xe: np.ndarray
    """element coordinates. shape=(ne,dim)"""
    ecid: np.ndarray
    """id of cell to which each element belongs. shape: (ne)"""
    ceidn: np.ndarray
    """first elements id of each cell. shape=(nc+1)
    Elements of cell i can be retrieve by ceidn[i]:ceidn[i+1].
    Number of elements of cell i is ceidn[i+1]-ceidn[i]"""
    scale: float
    """normalization scaling factor"""
    deltax: np.ndarray
    """normalization scaling offset. shape=(1,dim)"""
    contact_matrix: csr_matrix
    """cell-cell contact matrix. shape=(nc,nc)"""
    spatial_distances: Dict[str, csr_matrix]
    """dictionary to store cell-cell distance matrix"""
    alphashape: List[AlphaShape]
    """list of alpha-shape object for each cell. len=nc"""
    alphashape_info: Dict[str, Union[bool, float, int, None]]
    """information of alpha-shape"""
    alpha_radius : np.ndarray
    """alpha radius of each cell"""
    ns_default: int
    """number of augment points for alpha-shape"""

    def __init__(self,
                 adata: Optional[AnnData] = None,
                 xc: Optional[np.ndarray] = None,
                 ctype: Optional[np.ndarray] = None,
                 cluster_key: str = 'leiden',
                 spatial_key: str = 'spatial',
                 color_list: Optional[np.ndarray] = None):
        """Initialize attributes shared by all cell representations.

        Parameters
        ----------
        adata : Optional[AnnData]
            If given, coordinates, cell types and colors are read from it and
            ``xc``, ``ctype`` and ``color_list`` are ignored.
        xc : Optional[np.ndarray]
            Cell coordinates, shape (nc, dim). Required when ``adata`` is None.
        ctype : Optional[np.ndarray]
            Cell type codes, shape (nc,). Defaults to all zeros.
        cluster_key : str, default 'leiden'
            Column of ``adata.obs`` holding cell types.
        spatial_key : str, default 'spatial'
            Key of ``adata.obsm`` holding cell coordinates.
        color_list : Optional[np.ndarray]
            One color per cell type (RGB rows or matplotlib color strings).
            Generated with :meth:`set_color` when not available.

        Raises
        ------
        KeyError
            If ``adata`` lacks ``spatial_key`` in ``.obsm`` or ``cluster_key`` in ``.obs``.
        ValueError
            If neither ``adata`` nor ``xc`` is provided.
        """
        self.adata = adata
        self.contact_matrix = None
        self.spatial_distances = dict()
        self.alphashape_info = {'computed': False, 'alpha': None, 'ns': None, 'r': None}
        self.ns_default = 0
        self.scale = 1.
        self.deltax = np.zeros(2)
        if self.adata is not None:
            self._init_from_adata(cluster_key, spatial_key)
        else:
            self._init_from_direct_inputs(xc, ctype, color_list)
        # Fall back to a matplotlib colormap when no colors were supplied
        if self.color_list is None:
            self.set_color()
        # Downstream code assumes 2D (triangles in compute_distance), so drop extra dims
        if self.dim > 2:
            warnings.warn('xc is 3d, use first two dim')
            self.dim = 2
            self.xc = self.xc[:, [0, 1]]
        self.alpha_radius = np.zeros(self.nc)

    def _init_from_adata(self, cluster_key, spatial_key):
        """Read coordinates, cell types and colors from ``self.adata``.

        Missing ``{cluster_key}_colors`` in ``.uns`` is not an error: a warning
        is emitted and ``color_list`` is left as None so that ``__init__``
        assigns default colors.

        Raises
        ------
        KeyError
            If ``spatial_key`` is not in ``.obsm`` or ``cluster_key`` is not in ``.obs``.
        """
        if spatial_key not in self.adata.obsm:
            raise KeyError(f"Spatial key '{spatial_key}' not found in adata.obsm")
        # np.asarray also accepts DataFrame-valued .obsm entries
        self.xc = np.asarray(self.adata.obsm[spatial_key]).astype(np.float32)
        self.nc, self.dim = self.xc.shape
        print(f"gathering cell positions from adata.obsm['{spatial_key}']")

        if cluster_key not in self.adata.obs:
            raise KeyError(f"'{cluster_key}' not found in adata.obs")
        labels = self.adata.obs[cluster_key]
        if not hasattr(labels, 'cat'):
            # the .cat accessor only exists on categorical columns
            labels = labels.astype('category')
        self.ctype = labels.cat.codes.to_numpy()
        self.ctype_list = labels.cat.categories.to_numpy()
        print(f"gathering cell types from adata.obs['{cluster_key}']")

        color_key = f'{cluster_key}_colors'
        if color_key in self.adata.uns:
            color_list = self.adata.uns[color_key]
            if isinstance(color_list[0], str):
                color_list = np.array([to_rgb(x) for x in color_list])
            self.color_list = color_list
            print(f"gathering cell type colors from adata.uns['{color_key}']")
        else:
            warnings.warn(f"'{color_key}' not found in adata.uns, using default colors")
            self.color_list = None

    def _init_from_direct_inputs(self, xc, ctype, color_list):
        """Initialize coordinates, cell types and colors from explicit arrays.

        Raises
        ------
        ValueError
            If ``xc`` is None.
        """
        if xc is None:
            raise ValueError("Either adata or xc must be provided")
        self.xc = np.asarray(xc).astype(np.float32)
        self.nc, self.dim = self.xc.shape
        if ctype is None:
            warnings.warn('ctype are not provided, set cell type to 0')
            self.ctype = np.zeros(self.nc, dtype=int)
        else:
            self.ctype = np.asarray(ctype)
        self.ctype_list = np.unique(self.ctype)
        self.color_list = color_list
        if self.color_list is not None and isinstance(self.color_list[0], str):
            self.color_list = np.array([to_rgb(x) for x in self.color_list])

    def set_color(self, cmap_name: str = 'Set1') -> None:
        """
        Set a color map for `sem.ctype` (cell types)

        Parameters
        -------
        cmap_name: str
            Name of a matplotlib colormap (`matplotlib.colormaps`).
            Unknown names fall back to 'Set1'.
        """
        if cmap_name not in colormaps:
            cmap_name = 'Set1'
        cmap = colormaps[cmap_name]
        # one evenly spaced RGB (alpha dropped) per cell type
        self.color_list = cmap(np.linspace(0, 1, len(self.ctype_list)))[:, :3]

    def compute_contact(self,
                        k: int = 8,
                        d_th: Optional[float] = None,
                        add_key: str = 'contacts') -> None:
        """
        Compute cell-cell contacts

        Sets ``.contact_matrix`` (csr_matrix). Before symmetrization, entry
        (i, j) counts the distinct elements of cell j that are among the k-1
        nearest neighbours (within ``d_th``) of some element of cell i. The
        stored matrix is the average of that matrix and its transpose. When
        linked to ``adata``, it is also stored in ``adata.obsp[add_key]``, and
        ``{'k', 'd_th'}`` are stored in ``adata.uns[add_key]``.

        Parameters
        ------
        k : int, default=8
            Number of neighbors (including the element itself)
        d_th : Optional[float], default=None
            Distance threshold. Defaults to ``2 * self._get_e_radius()``.
        add_key : str, default='contacts'
            Key for storing cell-cell contacts matrix to .obsp
        """
        if d_th is None:
            d_th = 2 * self._get_e_radius()
        nbrs = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(self.xe)
        distances, indices = nbrs.kneighbors(self.xe)
        row_indices = []
        col_indices = []
        data = []
        for ci in range(self.nc):
            lo, hi = self.ceidn[ci], self.ceidn[ci + 1]
            # column 0 is the query element itself
            neighbor = indices[lo:hi, 1:].flatten()
            d_contact = distances[lo:hi, 1:].flatten()
            contact_eid = np.unique(neighbor[d_contact <= d_th])
            contact_cid = self.ecid[contact_eid]
            for cj in contact_cid[contact_cid != ci]:
                row_indices.append(ci)
                col_indices.append(cj)
                data.append(1)
        # duplicates are summed by coo_matrix; the sum with the transpose yields a csr_matrix
        contact_matrix = coo_matrix((data, (row_indices, col_indices)), shape=(self.nc, self.nc))
        self.contact_matrix = (contact_matrix + contact_matrix.T) / 2
        if self.adata is not None:
            print(f"add .obsp['{add_key}'], .uns['{add_key}']")
            self.adata.obsp[add_key] = self.contact_matrix
            self.adata.uns[add_key] = {'k': k, 'd_th': d_th}

    def _get_e_radius(self) -> float:  # replaced in sub-class
        """get default d_th for computing contact and alphashape"""
        return 1.0

    def compute_distance(self,
                         method: str,
                         k: int = 3,
                         return_distances: bool = False) -> Union[csr_matrix, None]:
        """
        Compute symmetric cell-cell Euclidean distances on a neighbour graph

        Adds ``.spatial_distances[method]`` (csr_matrix) unless
        ``return_distances`` is True.

        Parameters
        ------
        method : str
            valid methods: 'knn', 'delaunay', 'contact'
        k : int, default: 3
            k for knn (including the cell itself)
        return_distances : bool, default: False
            Whether to return the distances matrix instead of storing it

        Return
        ------
        distance_matrix : csr_matrix
            Cell-cell distances if return_distances is True, None otherwise

        Raises
        ------
        ValueError
            If ``method`` is unknown, or 'contact' is requested before
            :meth:`compute_contact`.
        """
        valid_methods = ('knn', 'delaunay', 'contact')
        if method not in valid_methods:
            raise ValueError(f"method must be one of {valid_methods}, got {method}")
        xc = self.xc
        distance_matrix = lil_matrix((self.nc, self.nc))
        if method == 'knn':
            nbrs = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(xc)
            distances, indices = nbrs.kneighbors(xc)
            for i in range(self.nc):
                # skip column 0, the query cell itself
                for nj, j in enumerate(indices[i, 1:]):
                    distance_matrix[i, j] = distances[i, nj + 1]
                    distance_matrix[j, i] = distances[i, nj + 1]
        elif method == 'delaunay':
            tri = Delaunay(xc)
            # 2D triangulation: every simplex is a triangle with 3 vertices
            for simplex in tri.simplices:
                for i in range(3):
                    for j in range(i + 1, 3):
                        d = distance.euclidean(xc[simplex[i]], xc[simplex[j]])
                        distance_matrix[simplex[i], simplex[j]] = d
                        distance_matrix[simplex[j], simplex[i]] = d
        else:  # 'contact'
            if self.contact_matrix is None:
                raise ValueError('contact is not computed')
            indices = self.contact_matrix.indices
            indptr = self.contact_matrix.indptr
            for i in range(self.nc):
                for j in indices[indptr[i]:indptr[i + 1]]:
                    d = distance.euclidean(xc[i], xc[j])
                    distance_matrix[i, j] = d
                    distance_matrix[j, i] = d
        if return_distances:
            return distance_matrix.tocsr()
        self.spatial_distances[method] = distance_matrix.tocsr()

    @staticmethod
    def _same_alpha(a, b) -> bool:
        """Compare two alpha settings (None, scalar or per-cell array) without
        triggering numpy's ambiguous-truth-value error."""
        if a is None or b is None:
            return a is None and b is None
        return np.array_equal(np.asarray(a), np.asarray(b))

    def compute_alphashape(self,
                           alpha: Optional[Union[np.ndarray, float]] = None,
                           ns: Optional[int] = None,
                           r: Optional[float] = None) -> None:
        """
        Compute alpha-shape for each cell

        Sets ``.alphashape`` (List[AlphaShape]) and ``.alpha_radius``. The
        shapes are rebuilt when they were never computed or ``ns``/``r``
        changed. If only ``alpha`` changed, the existing shapes are updated.

        Parameters
        ------
        alpha : scalar, array of length nc, or None
            Alpha radius for every cell (scalar) or per cell (array). With
            None, each shape uses twice its ``alpha_best``.
        ns : Optional[int]
            Number of augment points. Defaults to ``self.ns_default``.
        r : Optional[float]
            Augment radius in normalized units. Defaults to ``self._get_e_radius()/2``.

        Raises
        ------
        ValueError
            If an array ``alpha`` does not have one entry per cell.
        """
        if r is None:
            r = self._get_e_radius() / 2
        if ns is None:
            ns = self.ns_default
        if alpha is None:
            alpha_array = None
        elif np.ndim(alpha) == 0:
            # any scalar (float, int, numpy scalar) is broadcast to every cell
            alpha_array = np.full(self.nc, alpha, dtype=float)
        else:
            alpha_array = np.asarray(alpha)
            if len(alpha_array) != self.nc:
                raise ValueError(f"alpha must have {self.nc} entries, got {len(alpha_array)}")
        info = self.alphashape_info
        needs_rebuild = not info['computed'] or info['ns'] != ns or info['r'] != r
        if needs_rebuild:
            print(f"Computing alpha-shape with parameters: alpha={alpha}, ns={ns}, r={r}")
            # shapes are built in un-normalized coordinates
            xe = self.xe * self.scale + self.deltax
            self.alphashape = []
            for cid in tqdm(range(self.nc), 'Processing Cell Shapes'):
                alpha_i = alpha_array[cid] if alpha_array is not None else None
                shp = AlphaShape(xe[self.ceidn[cid]:self.ceidn[cid + 1]], alpha=alpha_i, ns=ns, r=r * self.scale)
                if alpha_i is None:
                    shp.update(2 * shp.alpha_best)
                self.alpha_radius[cid] = shp.alpha
                self.alphashape.append(shp)
        elif not self._same_alpha(info['alpha'], alpha):
            print(f"Updating alpha to {alpha}")
            for cid, shp in enumerate(self.alphashape):
                alpha_i = alpha_array[cid] if alpha_array is not None else None
                if alpha_i is None:
                    if shp.alpha_best is None:
                        shp.optimize_alpha()
                    alpha_i = 2 * shp.alpha_best
                shp.update(alpha_i)
                self.alpha_radius[cid] = shp.alpha
        # else: same parameters, nothing to do
        info['computed'] = True
        # store a copy so later mutation of the caller's array cannot alias the cache key
        stored_alpha = alpha if np.ndim(alpha) == 0 else np.array(alpha)
        info.update({'alpha': stored_alpha, 'ns': ns, 'r': r})

    def get_alpha(self) -> Union[np.ndarray, None]:
        """
        Get alpha radius of each cell

        Return
        ------
        alpha : Union[np.ndarray, None]
            Alpha radius of each cell if alphashapes have been computed, otherwise None
        """
        if not self.alphashape_info['computed']:
            warnings.warn('alphashape is not computed')
            return None
        return self.alpha_radius

    def get_area(self) -> Union[np.ndarray, None]:
        """
        Get cell areas

        Return
        ------
        area : Union[np.ndarray, None]
            Cell areas if alphashapes have been computed, otherwise None
        """
        if not self.alphashape_info['computed']:
            warnings.warn('alphashape is not computed')
            return None
        return np.array([shp.get_area() for shp in self.alphashape[:self.nc]], dtype=float)

    def get_elements(self, i: int) -> np.ndarray:
        '''
        Get elements of cell i

        Parameters
        ------
        i : int
            Cell index

        Return
        ------
        xe : np.ndarray
            Element coordinates of cell i, mapped back with ``scale`` and ``deltax``
        '''
        return self.xe[self.ceidn[i]:self.ceidn[i + 1]] * self.scale + self.deltax

    def update_xc(self) -> None:
        """
        Update .xc (cell coordinates) as the mean of each cell's elements
        (in the same normalized frame as ``.xe``)
        """
        self.xc = np.array([np.mean(self.xe[self.ceidn[i]:self.ceidn[i + 1]], axis=0) for i in range(self.nc)])
class SEM(CellBase):
    """Subcellular Element Method"""
    sim_name: str
    """simulation name"""
    t: int
    """time step"""
    param: dict
    """simulation parameters"""

    def __init__(self,
                 ne_per_cell: int,
                 re: float,
                 rd_ratio: float = 2.5,
                 adata: Optional[AnnData] = None,
                 cluster_key: Optional[str] = None,
                 spatial_key: str = 'spatial',
                 embedding_key: str = 'X_pca',
                 xc: Optional[np.ndarray] = None,
                 ctype: Optional[np.ndarray] = None,
                 sim_name: str = 'untitled',
                 seed: int = 1):
        """
        Create a SEM object

        Parameters
        -------
        ne_per_cell : int
            Number of elements per cell
        re : float
            Element radius
        rd_ratio : float, default: 2.5
            Cell radius-distance ratio (cell radius = typical cell distance / rd_ratio)
            rd_ratio>2: cell radius < cell distance/2, tissue with gaps
            rd_ratio=2: cell radius = cell distance/2, no gaps (confluent tissue)
            rd_ratio<2: cell radius > cell distance/2, overcrowded
        adata : Anndata
            Anndata with .obsm[spatial_key] for cell coordinates, .obs[cluster_key] for cell types, .obsm[embedding_key] for low-dim embedding
            If not provided, xc (and optionally ctype) are required
        cluster_key : str, default: 'leiden'
            Key for cell type in .obs (None means 'leiden')
        spatial_key : str, default: 'spatial'
            Key for spatial coordinates in .obsm
        embedding_key : str, default: 'X_pca'
            Key for low-dim embedding in .obsm, used for computing gene similarity
        xc : Optional[np.ndarray]
            cell coordinates. Ignored, if adata is provided
        ctype : Optional[np.ndarray]
            cell types. Ignored, if adata is provided
        sim_name : str, default: 'untitled'
            Simulation name
        seed : int, default: 1
            Seed of the random number generator used for element placement and noise

        Raises
        ------
        KeyError
            If adata lacks the spatial, cluster or embedding key.
        """
        if cluster_key is None:
            cluster_key = 'leiden'  # documented default
        super().__init__(adata, xc, ctype, cluster_key, spatial_key, None)
        # element bookkeeping: every cell owns ne_per_cell consecutive elements
        self.ne = ne_per_cell * self.nc
        self.ecid = np.repeat(np.arange(self.nc, dtype=np.int32), ne_per_cell)
        self.ceidn = np.insert(np.cumsum([ne_per_cell] * self.nc), 0, 0)
        self.xe = np.zeros((self.ne, self.dim), dtype=np.float32)

        # cell radius in the original coordinate units
        rc = self._estimate_cell_radius(rd_ratio)
        self.rc = rc
        # radius of a disc packing ne_per_cell elements of radius re/2 (normalized units)
        rc_n = np.sqrt(ne_per_cell * (re / 2) ** 2)
        self.scale = rc / rc_n
        self.deltax = np.mean(self.xc, axis=0)
        self.xc = (self.xc - self.deltax) / self.scale  # xc_n/rc_n = xc/rc

        self.rng = np.random.default_rng(seed)
        self.rng_seed = seed
        # place each cell's elements uniformly in a disc of radius rc_n around its center
        for cid in range(self.nc):
            ne_i = self.ceidn[cid + 1] - self.ceidn[cid]
            r = rc_n * np.sqrt(self.rng.uniform(0, 1, size=(ne_i, 1)))  # sqrt -> uniform area density
            phi = self.rng.uniform(-pi, pi, size=(ne_i, 1))
            xe = np.concatenate((r * np.cos(phi), r * np.sin(phi)), axis=1)
            xe = xe - np.mean(xe, axis=0) + self.xc[cid]  # center elements exactly on the cell
            self.xe[self.ceidn[cid]:self.ceidn[cid + 1]] = xe.astype(np.float32)

        self.sim_name = sim_name
        self.t = 0
        self.param = dict()

        # adhesion strength is scaled by expression similarity between cells
        if self.adata is None:
            self.corr_matrix = np.ones((self.nc, self.nc), dtype=np.float32)
        else:
            if embedding_key not in self.adata.obsm:
                raise KeyError(f"Embedding key '{embedding_key}' not found in adata.obsm")
            X_em = np.asarray(self.adata.obsm[embedding_key])
            # atleast_2d: corrcoef of a single row returns a 0-d value
            corr_matrix = np.atleast_2d(np.corrcoef(X_em))
            c_min = 0.05
            # constant embedding rows give NaN correlations; treat them as minimal similarity
            corr_matrix = np.nan_to_num(corr_matrix, nan=c_min)
            corr_matrix[corr_matrix < c_min] = c_min
            self.corr_matrix = corr_matrix.astype(np.float32)
        self.ns_default = 10  # alphashape default param

    def _estimate_cell_radius(self, rd_ratio: float) -> float:
        """Estimate the cell radius from spacing of the (un-normalized) cell coordinates.

        With more than two cells, uses the median over cells of the mean
        Delaunay-neighbour distance; with two cells, their distance; each
        divided by ``rd_ratio``. A single cell (or fully overlapping cells)
        gets radius 1.
        """
        if self.nc > 2:
            dmat = self.compute_distance('delaunay', return_distances=True)
            dc = np.full(self.nc, np.nan)
            for cid in range(self.nc):
                row = dmat.data[dmat.indptr[cid]:dmat.indptr[cid + 1]]
                row = row[row != 0]  # overlapping points have zero distance
                if row.size > 0:
                    dc[cid] = row.mean()
            dc = dc[~np.isnan(dc)]
            if dc.size == 0:
                warnings.warn('all cell positions overlap, using cell radius 1')
                return 1
            return np.median(dc) / rd_ratio
        if self.nc == 2:
            return distance.euclidean(self.xc[0], self.xc[1]) / rd_ratio
        return 1

    def __repr__(self):
        return f'Simulation Name: {self.sim_name}\nt: {self.t}\nCell Number: {self.nc}\nElement Number: {self.xe.shape[0]}\nDim: {self.dim}\nParameters: {self.param}\nContact Matrix: {self.contact_matrix.__repr__()}'

    def _get_e_radius(self) -> float:
        """get SEM default d_th (``param['rm_inter']``) for computing contact and alphashape"""
        if 'rm_inter' in self.param:
            return self.param['rm_inter']
        warnings.warn('rm_inter is not provided, using d_th = 1')
        return 1.0

    def sim_gpu(self, param: dict, T: int) -> None:
        """
        Implement SEM simulation on the GPU

        Parameters
        ------
        param : dict
            Parameters with keys 'rm_intra', 'rm_inter', 'dt', 'sigma',
            'gamma' and 'alpha' (a (alpha_max, alpha_min) pair)
        T : int
            Time steps
        """
        self.param = param
        rm_intra = param["rm_intra"]
        rm_inter = param["rm_inter"]
        dt = param["dt"]
        sigma = param["sigma"]
        gamma = param["gamma"]
        alpha_max, alpha_min = param["alpha"]
        cmax = self.corr_matrix.max()
        cmin = self.corr_matrix.min()
        if cmax == cmin:
            # constant similarity: uniform adhesion
            alpha = alpha_max * np.ones_like(self.corr_matrix)
        else:
            # affine map of corr_matrix from [cmin, cmax] onto [alpha_min, alpha_max]
            alpha = (alpha_max - alpha_min) / (cmax - cmin) * self.corr_matrix + (alpha_min * cmax - alpha_max * cmin) / (cmax - cmin)
        sigmadt = sqrt(dt) * sigma
        # double buffer: kernel reads d_xe (step t-1) and accumulates into d_xe_F (step t)
        d_xe = cuda.to_device(self.xe)
        d_xe_F = cuda.to_device(self.xe)
        d_ecid = cuda.to_device(self.ecid)
        d_alpha = cuda.to_device(alpha)
        tpb = 128
        bpg = 128
        cuda.synchronize()
        for t in tqdm(range(T), 'Simulation'):
            # noise amplitude decays as sqrt((T-t)/T) over the run
            x_randt = cuda.to_device((sigmadt * np.sqrt((T - t) / T) * self.rng.normal(0, 1, size=self.xe.shape)).astype(np.float32))
            dynamics2d_gpu[bpg, tpb](d_xe, d_xe_F, d_ecid, d_alpha, gamma, x_randt, rm_intra, rm_inter, dt)
            cuda.synchronize()
            d_xe[:, :] = d_xe_F  # advance the read buffer to step t
            cuda.synchronize()
            self.t += 1
        cuda.synchronize()
        self.xe = d_xe.copy_to_host()
        self.update_xc()
        self.alphashape_info['computed'] = False  # alpha shapes are now stale

    def save_sim(self) -> None:
        """
        Save simulation to ``{sim_name}_{t}.pkl`` in the working directory
        (``_temp`` is appended to the name if that file already exists)
        """
        filename = f'{self.sim_name}_{self.t}'
        if os.path.exists(filename + '.pkl'):
            warnings.warn(f"File '{filename}' already exists.")
            filename = filename + '_temp'
        filename = filename + '.pkl'
        with open(filename, 'wb') as f:
            data = {
                'xe': self.xe,
                'ceidn': self.ceidn,
                'ecid': self.ecid,
                'param': self.param,
                'scale': self.scale,
                'deltax': self.deltax
            }
            if self.alphashape_info['computed']:
                data['alpha_radius'] = self.alpha_radius
            pickle.dump(data, f)
        print(f"saved as {filename}")

    def load_sim(self, sim_name: str, t: float, path: str = '.', rename: bool = True) -> None:
        '''
        Restore a simulation from `{path}/{sim_name}_{t}.pkl`

        Warning: the file is unpickled, which can execute arbitrary code;
        only load files you trust.

        Parameters
        ------
        sim_name : str
            Name of simulation
        t : float
            Time point (integer-valued floats such as 2000.0 are accepted)
        path : str, default: '.'
            Path to simulation data
        rename : bool, default: True
            If True, rename the `sem` to `sim_name`
        '''
        if isinstance(t, float) and t.is_integer():
            t = int(t)  # save_sim writes integer step counts into the file name
        filename = f'{path}/{sim_name}_{t}.pkl'
        with open(filename, 'rb') as f:
            print(f'load sim data from {filename}')
            data = pickle.load(f)
        self.xe = data['xe']
        self.ceidn = data['ceidn']
        self.ecid = data['ecid']
        self.ne = self.xe.shape[0]
        self.scale = data['scale']
        self.deltax = data['deltax']
        # element positions were replaced: any existing alpha shapes are stale
        self.alphashape_info['computed'] = False
        if 'param' in data:
            print('.param loaded')
            self.param = data['param']
        if 'alpha_radius' in data:
            print('.alpha_radius loaded')
            self.compute_alphashape(alpha=data['alpha_radius'])
        self.t = t
        self.update_xc()
        if rename:
            print(f'Simulation renamed as {sim_name}')
            self.sim_name = sim_name
@cuda.jit
def dynamics2d_gpu(xe, xe_F, ecid, alpha, gamma, x_randt, rm_intra, rm_inter, dt):
    """
    Simulation function: one time step of the 2D element dynamics.

    Each thread processes elements start, start+stride, ... where stride is
    the total number of launched threads. For element i, pair terms from
    every other element j whose offset lies inside a 30x30 box are added to
    ``xe_F[i]``, followed by the noise ``x_randt[i]``. ``xe`` is only read and
    ``xe_F`` only written, so the caller must initialise ``xe_F`` with the
    current positions and copy it back into ``xe`` after the launch.

    Parameters
    ----------
    xe, xe_F : device arrays, shape (ne, >=2)
        Positions at the previous step (read) and the next step (accumulated).
    ecid : device array, shape (ne,)
        Cell id of each element.
    alpha : device array indexed [cell_i, cell_j]
        Inter-cell adhesion strength.
    gamma : float
        Coefficient of the linear intra-cell term ``gamma * r``.
    x_randt : device array, shape (ne, >=2)
        Noise displacement added to each element.
    rm_intra, rm_inter : float
        Potential minimum distance for same-cell and different-cell pairs.
    dt : float
        Time step.
    """
    start = cuda.grid(1)
    stride = cuda.gridsize(1)
    ne = xe.shape[0]
    for i in range(start, ne, stride):
        cid = ecid[i]
        xi = xe[i, 0]
        yi = xe[i, 1]
        for j in range(ne):
            if j == i:
                continue
            deltax = xi - xe[j, 0]
            deltay = yi - xe[j, 1]
            # cheap box cutoff before computing the distance
            if abs(deltax) < 30 and abs(deltay) < 30:
                r = sqrt(deltax**2 + deltay**2)
                if r == 0.0:
                    # coincident elements: no defined direction and rm/r would divide by zero
                    continue
                if ecid[j] == cid:
                    dV = max(2*d_potential_LJ_gpu(r, rm_intra, 1.5) + gamma*r, -10.0)
                else:
                    dV = max(alpha[cid, ecid[j]]*d_potential_LJ_gpu(r, rm_inter, 1.5), -10.0)
                xe_F[i, 0] += -dt * dV * deltax
                xe_F[i, 1] += -dt * dV * deltay
        xe_F[i, 0] += x_randt[i, 0]
        xe_F[i, 1] += x_randt[i, 1]
@cuda.jit(device=True)
def d_potential_LJ_gpu(r, rm, epsilon):
    """Radial pair factor of the Lennard-Jones-type potential (CUDA device function).

    Returns ``epsilon * r**-2 * ((rm/r)**6 - (rm/r)**12)``: zero at r == rm,
    negative for r < rm and positive for r > rm. Callers multiply it by the
    displacement vector to obtain the pair contribution.
    """
    ratio6 = (rm / r) ** 6
    return epsilon * r ** -2 * (ratio6 - ratio6 * ratio6)
class cellshape_GT(CellBase):
    """Cell shape representation for experimental data visualization"""

    def __init__(self,
                 xe: np.ndarray,
                 ecid: np.ndarray,
                 ceidn: np.ndarray,
                 xc: Optional[np.ndarray] = None,
                 ctype: Optional[np.ndarray] = None,
                 color_list: Optional[np.ndarray] = None,
                 adata: Optional[AnnData] = None,
                 spatial_key: str = 'spatial',
                 cluster_key: Optional[str] = None):
        """Wrap precomputed element coordinates as cells.

        Parameters
        ----------
        xe : np.ndarray
            Element coordinates, shape (ne, dim).
        ecid : np.ndarray
            Cell id of each element, shape (ne,).
        ceidn : np.ndarray
            Element offsets per cell, shape (nc+1,); cell i owns
            ``xe[ceidn[i]:ceidn[i+1]]``.
        xc : Optional[np.ndarray]
            Cell coordinates. When both ``xc`` and ``adata`` are None they are
            computed as the mean of each cell's elements.
        ctype, color_list : Optional[np.ndarray]
            Cell type codes and colors (see CellBase).
        adata : Optional[AnnData]
            Source of coordinates/types/colors instead of explicit arrays.
        spatial_key : str, default 'spatial'
            Key of ``adata.obsm`` holding coordinates.
        cluster_key : Optional[str], default None
            Column of ``adata.obs`` with cell types (None means 'leiden').
        """
        self.nc = ceidn.shape[0] - 1
        if cluster_key is None:
            cluster_key = 'leiden'
        derive_xc = xc is None and adata is None
        if derive_xc:
            # CellBase requires xc when no adata is given, so seed it with element centroids
            print('compute xc from xe')
            xc = np.array([np.mean(xe[ceidn[i]:ceidn[i + 1]], axis=0) for i in range(self.nc)])
        super().__init__(adata, xc, ctype, cluster_key, spatial_key, color_list)
        # Visualization-specific properties
        self.xe = xe
        self.ecid = ecid
        self.ceidn = ceidn
        self.ne = xe.shape[0]
        self.dim = xe.shape[1]
        self.ne_per_cell: np.ndarray = ceidn[1:] - ceidn[:-1]
        if derive_xc:
            # recompute in the elements' own dtype (CellBase cast xc to float32)
            self.update_xc()

    def __repr__(self):
        return f'Cell Number: {self.nc}\nElement Number: {self.xe.shape[0]}\nDim: {self.dim}'
def cell_shape_modeling(adata: AnnData,
                        cluster_key: str,
                        ne: int = 20,
                        rd_ratio: float = 2.5,
                        spatial_key: str = 'spatial',
                        pca_key: str = 'X_pca',
                        seed: int = 1
                        ) -> SEM:
    """
    Perform cell shape modeling based on subcellular element method

    Parameters
    ----------
    adata : Anndata
        AnnData object
    cluster_key : str
        Key in `adata.obs` that contains cell type annotations
    ne : int, default 20
        number of elements per cells
    rd_ratio : float, default 2.5
        Cell radius-distance ratio
        - rd_ratio>2: cell radius < cell distance/2, tissue with gaps
        - rd_ratio=2: cell radius = cell distance/2, no gaps (confluent tissue)
        - rd_ratio<2: cell radius > cell distance/2, overcrowded
    spatial_key : str, default 'spatial'
        Key in `adata.obsm` that contains spatial coordinates
    pca_key : str, default 'X_pca'
        Key in `adata.obsm` that contains PCA embeddings (e.g. from
        `scanpy.pp.pca(adata)`). It must already exist; a KeyError is raised otherwise.
    seed : int, default 1
        random seed

    Returns
    -------
    SEM
        SEM object containing cell shapes and cell-cell contacts information

    Set the field in adata
    `.obsp['contacts']` (csr_matrix) for cell-cell contacts
    `.uns['contacts']` (dict) with the contact parameters `k` and `d_th`

    Examples
    --------
    >>> sem = cr.tl.cell_shape_modeling(adata,cluster_key = 'cell_type')
    """
    rm = 2  # element radius, passed to SEM as `re`
    # 'alpha' is the (max, min) range onto which expression similarity is mapped
    param = {"rm_intra": rm, "rm_inter": rm * 1.2, "dt": 0.04, 'sigma': 1, 'alpha': (8, 0.5), 'gamma': 0.001}
    sem = SEM(
        ne, rm, rd_ratio, adata=adata,
        cluster_key=cluster_key,
        spatial_key=spatial_key,
        embedding_key=pca_key,
        seed=seed
    )
    sem.sim_gpu(param, T=2000)
    sem.compute_contact()
    sem.compute_alphashape()
    return sem