import ot as pot
import torch
from sklearn.neighbors import KNeighborsRegressor
from cellrank.kernels import CytoTRACEKernel
import scanpy as sc
import pandas as pd
import numpy as np
from tqdm import tqdm
import warnings
import networkx as nx
from sklearn.neighbors import NearestNeighbors
from scipy.sparse.csgraph import dijkstra
from sklearn.metrics import pairwise_distances
from sklearn.preprocessing import minmax_scale
from copy import deepcopy
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import pygot.external.palantir as palantir
def scale(x):
return (x - np.nanmin(x)) / (np.nanmax(x) - np.nanmin(x))
def highlight_extrema(adata, basis='umap', **kwargs):
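    """Plot cells on an embedding and overlay the extrema stored in ``adata.uns['extrema']``."""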
extrema_names = adata.uns['extrema']
fig,ax = plt.subplots(1,1)
sc.pl.embedding(adata, show=False, ax=ax, basis=basis, **kwargs)
ax.scatter(adata[extrema_names].obsm['X_'+basis][:,0], adata[extrema_names].obsm['X_'+basis][:,1], label='extrema')
plt.legend()
def calcu_w(D, sdv=None):
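    """Convert a distance matrix into Gaussian kernel weights.

    If ``sdv`` is None, the bandwidth follows a Silverman-style rule of thumb
    on the flattened distances. Returns the weights and the bandwidth used.
    """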
if sdv is None:
sdv = np.std(np.ravel(D)) * 1.06 * len(np.ravel(D)) ** (-1 / 5)
W = np.exp(-0.5 * np.power((D / sdv), 2))
return W, sdv
def _max_min_sampling(data, num_waypoints, seed=None):
"""Function for max min sampling of waypoints
:param data: Data matrix along which to sample the waypoints,
usually diffusion components
:param num_waypoints: Number of waypoints to sample
:param seed: Random number generator seed to find initial guess.
    :return: Array of unique indices of the sampled waypoints
"""
waypoint_set = list()
no_iterations = int((num_waypoints) / data.shape[1])
if seed is not None:
np.random.seed(seed)
# Sample along each component
N = data.shape[0]
for ind in range(data.shape[1]):
# Data vector
vec = np.ravel(data[:,ind])
        # Random initialization
iter_set = [
np.random.randint(N),
]
# Distances along the component
dists = np.zeros([N, no_iterations])
dists[:, 0] = abs(vec - data[iter_set, ind])
for k in range(1, no_iterations):
# Minimum distances across the current set
min_dists = dists[:, 0:k].min(axis=1)
# Point with the maximum of the minimum distances is the new waypoint
new_wp = np.where(min_dists == min_dists.max())[0][0]
iter_set.append(new_wp)
# Update distances
dists[:, k] = abs(vec - data[new_wp, ind])
# Update global set
waypoint_set = waypoint_set + iter_set
# Unique waypoints
waypoints = np.unique(waypoint_set)
return waypoints
def fast_palantir(dist_matrix, start_cell, waypoints, waypoints_D, waypoints_W, sdv, max_iterations = 25):
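    """Compute a Palantir-style pseudotime from precomputed waypoint distances.

    Starting from the shortest-path distances of ``start_cell``, waypoint
    perspectives are iteratively re-aligned and averaged with weights ``W``
    until successive pseudotimes correlate above 0.9999 (or ``max_iterations``
    is reached). The result is rescaled to [0, 1].
    """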
waypoints_idx = waypoints != start_cell
start_D = dist_matrix[[start_cell]]
D = np.concatenate([start_D, waypoints_D[waypoints_idx]])
start_W, _ = calcu_w(start_D, sdv)
W = np.concatenate([start_W, waypoints_W[waypoints_idx]])
norm_c = W.sum(axis=0)
zero_idx = norm_c == 0
norm_c[zero_idx] = 1.
W = W / norm_c
W[:, zero_idx] = 1. / W.shape[0]
    # Initialize pseudotime to start cell distances
pseudotime = D[0]
converged = False
# Iteratively update perspective and determine pseudotime
iteration = 1
while not converged and iteration < max_iterations:
P = deepcopy(D)
        # Perspective matrix by aligning to the start-cell distances
for i,wp in enumerate(waypoints[waypoints_idx]):
# Position of waypoints relative to start
idx_val = pseudotime[wp]
# Convert all cells before starting point to the negative
before_indices = np.where(pseudotime < idx_val)[0]
P[i+1, before_indices] = -D[i+1, before_indices]
# Align to start
P[i+1, :] = P[i+1, :] + idx_val
# Weighted pseudotime
new_traj = (P*W).sum(axis=0)
# Check for convergence
corr = pearsonr(pseudotime, new_traj)[0]
if corr > 0.9999:
converged = True
        # Adopt the new trajectory (kept as the final value once converged)
pseudotime = new_traj
iteration += 1
pseudotime -= np.min(pseudotime)
pseudotime /= np.max(pseudotime)
return pseudotime
def diffmap_extrema(adata, diffmap_key='X_diffmap', ):
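    """Store the cells at the minimum and maximum of each diffusion component
    in ``adata.uns['extrema']``."""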
extrema = []
eigenvectors = adata.obsm[diffmap_key]
for dcomp in range(eigenvectors.shape[1]):
ec = eigenvectors[:, dcomp].argmax()
extrema.append(ec)
ec = eigenvectors[:, dcomp].argmin()
extrema.append(ec)
extrema_names = adata.obs.index[extrema]
adata.uns['extrema'] = extrema_names
def init_candidiates(adata, diffmap_key, cell_type_key=None):
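    """Initialize anchor candidates from diffusion-map extrema and, if
    ``cell_type_key`` is given, one randomly sampled cell per cell type."""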
diffmap_extrema(adata, diffmap_key=diffmap_key)
if cell_type_key is None:
return
candidates = []
for cell_type in pd.unique(adata.obs[cell_type_key]):
candidates.append(adata.obs.loc[adata.obs[cell_type_key] == cell_type].sample(n=1).index)
adata.uns['extrema'] = np.concatenate([adata.uns['extrema'], np.concatenate(candidates)])
adata.uns['extrema'] = np.unique(adata.uns['extrema'])
def _connect_graph(adj, data, start_cell):
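    """Add edges to ``adj`` until every cell in ``data`` is reachable from
    ``start_cell``, linking the farthest reachable cell to its nearest
    unreachable one at each step."""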
# Create graph and compute distances
graph = nx.Graph(adj)
dists = pd.Series(nx.single_source_dijkstra_path_length(graph, start_cell))
dists = pd.Series(dists.values, index=data.index[dists.index])
    # Identify unreachable nodes
unreachable_nodes = data.index.difference(dists.index)
if len(unreachable_nodes) > 0:
        warnings.warn(
            "Some of the cells were unreachable. Consider increasing the k for "
            "nearest neighbor graph construction."
        )
# Connect unreachable nodes
while len(unreachable_nodes) > 0:
farthest_reachable = np.where(data.index == dists.idxmax())[0][0]
# Compute distances to unreachable nodes
unreachable_dists = pairwise_distances(
data.iloc[farthest_reachable, :].values.reshape(1, -1),
data.loc[unreachable_nodes, :],
)
unreachable_dists = pd.Series(
np.ravel(unreachable_dists), index=unreachable_nodes
)
        # Add edge between the farthest reachable cell and its nearest unreachable one
add_edge = np.where(data.index == unreachable_dists.idxmin())[0][0]
adj[farthest_reachable, add_edge] = unreachable_dists.min()
# Recompute distances to early cell
graph = nx.Graph(adj)
dists = pd.Series(nx.single_source_dijkstra_path_length(graph, start_cell))
dists = pd.Series(dists.values, index=data.index[dists.index])
        # Identify unreachable nodes
unreachable_nodes = data.index.difference(dists.index)
return adj
def compute_spdist(adata, embedding_key='X_pca', n_neighbors=20, start_cells=None, scale=True):
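    """Build a kNN graph on the embedding and return it together with the
    all-pairs shortest-path distance matrix. If ``start_cells`` is given,
    the graph is first made connected from each of those cells."""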
X = pd.DataFrame(adata.obsm[embedding_key], index=adata.obs.index)
if scale:
        X = pd.DataFrame(minmax_scale(X), index=adata.obs.index)
neighbors = NearestNeighbors(n_neighbors=n_neighbors, metric="euclidean").fit(X)
knn_graph = neighbors.kneighbors_graph(X, mode="distance")
if start_cells is not None:
intersection_index = pd.Index(start_cells).intersection(adata.obs.index)
        print('Converting into a connected graph')
for start_cell in tqdm(intersection_index):
knn_graph = _connect_graph(knn_graph, X, np.where(adata.obs.index == start_cell)[0][0])
    dist_matrix = dijkstra(csgraph=knn_graph, directed=False)
return knn_graph, dist_matrix
def calcu_ot_loss(adata, embedding_key, pseudo_group_key='pseudobin', p=2):
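    """Sum the exact optimal transport costs between consecutive pseudotime
    bins, using the pairwise Lp distance on the embedding as ground cost."""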
def _ot_loss(M, a=None, b=None):
if a is None:
a = torch.ones(M.shape[0]) / M.shape[0]
if b is None:
b = torch.ones(M.shape[1]) / M.shape[1]
pi = pot.emd(a, b, M)
#pi = pot.sinkhorn_unbalanced(a, b, M, 0., [1., 10.])
return torch.sum(pi * M).item()
loss = 0.
for i in range(int(np.max(adata.obs[pseudo_group_key]))):
M = torch.cdist(torch.tensor(adata[adata.obs.loc[adata.obs[pseudo_group_key] == i].index].obsm[embedding_key]),
torch.tensor(adata[adata.obs.loc[adata.obs[pseudo_group_key] == i+1].index].obsm[embedding_key]), p=p)
loss += _ot_loss(M,)
return loss
def calcu_got_loss(adata, pseudo_group_key='pseudobin', dist_matrix = None):
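    """Sum the exact optimal transport costs between consecutive pseudotime
    bins, using a precomputed (shortest-path) distance matrix as ground cost."""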
adata.obs['idx'] = range(len(adata))
def _ot_loss(M, a=None, b=None):
if a is None:
a = torch.ones(M.shape[0]) / M.shape[0]
if b is None:
b = torch.ones(M.shape[1]) / M.shape[1]
pi = pot.emd(a, b, M)
#pi = pot.sinkhorn_unbalanced(a, b, M, 0., [1., 10.])
return torch.sum(pi * M).item()
loss = 0.
for i in range(int(np.max(adata.obs[pseudo_group_key]))):
idx1 = adata.obs.loc[adata.obs[pseudo_group_key] == i].idx.tolist()
idx2 = adata.obs.loc[adata.obs[pseudo_group_key] == i+1].idx.tolist()
M = torch.tensor(dist_matrix[idx1,:])
M = M[:, idx2]
loss += _ot_loss(M,)
return loss
def greedy_search_best_source(adata, embedding_key, kernel='dpt', split_k=None, graph_dist=False, n_neighbors=20, connect_anchor=None, n_waypoints=1200):
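    """Score every cell as a candidate root: compute its pseudotime, bin the
    cells along it, and sum the OT transport costs between consecutive bins.
    Returns a DataFrame of per-cell losses sorted in ascending order."""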
    assert kernel == 'dpt' or graph_dist, "kernels other than 'dpt' require graph_dist=True"
time_key = kernel + '_pseudotime'
if split_k is None:
split_k = int(len(adata) / 100)
res = {}
if graph_dist:
knn_graph, dist_matrix = compute_spdist(adata, embedding_key, n_neighbors=n_neighbors, start_cells=connect_anchor)
if kernel == 'palantir':
waypoints = _max_min_sampling(adata.obsm[embedding_key], num_waypoints=n_waypoints, seed=20)
waypoints_D = dist_matrix[waypoints]
waypoints_W, sdv = calcu_w(waypoints_D)
adata.obs['root_loss'] = np.nan
for i in tqdm(range(len(adata))):
filtered_idx = range(len(adata))
if kernel == 'dpt':
adata.uns['iroot'] = i
sc.tl.dpt(adata)
elif kernel == 'sp':
adata.obs[time_key] = dist_matrix[i,:]
elif kernel == 'palantir':
adata.obs[time_key] = fast_palantir(dist_matrix, i, waypoints, waypoints_D, waypoints_W, sdv)
elif kernel == 'euclidean':
adata.obs[time_key] = np.mean((adata.obsm[embedding_key] - adata.obsm[embedding_key][i])**2, axis=1)
if np.sum(np.isinf(adata.obs[time_key])) > 0:
filtered_idx = np.where(~np.isinf(adata.obs[time_key]))[0].tolist()
data = adata[~np.isinf(adata.obs[time_key])].copy()
else:
data = adata
data.obs[time_key] = scale(data.obs[time_key])
if len(data) < 100:
res[adata.obs.index[i]] = np.nan
continue
generate_time_points(data, k=split_k, pseudotime_key=time_key, sigma=0.)
with torch.no_grad():
            if not graph_dist:
loss = calcu_ot_loss(data, embedding_key, pseudo_group_key='pseudobin',)
else:
cost_matrix = dist_matrix[filtered_idx,:]
cost_matrix = cost_matrix[:, filtered_idx]
loss = calcu_got_loss(data, pseudo_group_key='pseudobin', dist_matrix=cost_matrix )
res[adata.obs.index[i]] = loss
res = pd.DataFrame([list(res.keys()), list(res.values())], index=['cell', 'raw_root_loss']).T
res['idx'] = range(len(res))
res = res.sort_values('raw_root_loss')
return res
def generate_time_points(adata, k=4, pseudotime_key = 'dpt_pseudotime', time_key='pseudobin', sigma=.0, ):
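    """Split cells into ``k`` equally sized bins along (optionally jittered)
    pseudotime and store the bin index in ``adata.obs[time_key]``."""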
adata.obs[time_key] = -1
adata.obs[time_key+'_noise'] = adata.obs[pseudotime_key] + np.random.rand(len(adata)) * sigma
sorted_idx = adata.obs.sort_values(time_key+'_noise').index
bin_idxs = np.array_split(sorted_idx, k)
for i in range(k):
adata.obs.loc[bin_idxs[i], time_key] = i
adata.obs[time_key] = adata.obs[time_key].astype(float)
def smoothe_score(X, y, n_neighbors=5):
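    """Smooth a per-cell score by k-nearest-neighbor regression on ``X``,
    fitting only on non-NaN entries. Returns the smoothed scores and the
    fitted regressor."""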
idx = ~np.isnan(y)
knn = KNeighborsRegressor(n_neighbors=n_neighbors)
knn.fit(X[idx], y[idx])
y_smoothed = knn.predict(X)
return y_smoothed, knn
def determine_source_state(adata, embedding_key, graph_dist=True, n_neighbors=30,
split_m=30, kernel='dpt', n_comps=15, down_sampling=True, n_obs=3000,
cytotrace=True, alpha = 0.1, smooth_k=5,
                           connect_anchor=False):
"""Determine souce cell for snapshot data
In most developing biological scenario, source cells will develop into multiple different cells.
By setting cell :math:`r` as start cell, the pseudotime :math:`\hat{t}(x_i)` can be computed, and
the empirical distribution can be divided into :math:`m` portions that :math:`X_1, X_2, ..., X_m`, according to time :math:`\hat{t}(x_i)`.
The transport cost of this time-vary distribution :math:`p_t(x|r)` can be quantified by optimal transport with graphical metrics.
.. math::
W_2^2(r)=\sum_{i=1}^{m-1}\inf_{\pi}\sum_{x \in X_i}\sum_{y \in X_{i+1}}c(x,y | G)\pi(x,y)
where :math:`c(x,y|G)` is the shorest path distance between two cells :math:`x,y` in graph :math:`G`.
According to the energy-saving hypothesis, the defined transport cost of real source cell will be smallest, that
.. math::
{r}^* = arg \min_{r} W_2^2(r)
.. note::
This assumption may *fails* in the case of *linear progression* that souce cell only developing in one direction.
In that case, the transport cost of real source cell and terminate cell will be very close. So this function
will detect linear progression and compute cytotrace score with very low weight (default 0.1) to choose the optimal source cell.
\\
To accelerate the computation, we suggest to down sample the dataset to 3000 cells (default) and use the down sampled data to compute the transport cost.
    Arguments
    ---------
    adata: :class:`~anndata.AnnData`
        Annotated data matrix.
    embedding_key: `str`
        Name of the latent space in ``adata.obsm``.
    graph_dist: `bool` (default: `True`)
        Whether to use shortest-path distance instead of Euclidean distance.
    n_neighbors: `int` (default: 30)
        Number of kNN neighbors used to compute the shortest-path distance.
    split_m: `int` (default: 30)
        Number of pseudotime bins. This number should NOT be too small.
    kernel: 'dpt' or 'palantir' or 'sp' or 'euclidean' (default: 'dpt')
        Pseudotime method; 'dpt' is recommended.
    n_comps: `int` (default: 15)
        Number of diffusion map components used for DPT computation.
    down_sampling: `bool` (default: `True`)
        Whether to downsample the dataset to accelerate computation.
    n_obs: `int` (default: 3000)
        Downsampled dataset size.
    cytotrace: `bool` (default: `True`)
        Use CytoTRACE to help. Note CytoTRACE is implemented by CellRank 2.
    alpha: `float` (default: 0.1)
        Weight of the CytoTRACE score. We do NOT suggest increasing it.
    smooth_k: `int` (default: 5)
        Number of neighbors used to smooth the final score.
    connect_anchor: `bool` (default: `False`)
        Whether to connect diffusion-map extrema so that every anchor is
        reachable in the kNN graph before computing shortest-path distances.

    Returns
    -------
    ot_root (.uns): `int`
        Best source cell index using the transport cost only.
    ot_ct_root (.uns): `int`
        Best source cell index using both the transport cost and CytoTRACE.
    root_score (.obs): `np.ndarray`
        Source cell score (a higher score means a higher probability of being the source).
    ct_root_score (.obs): `np.ndarray`
        Source cell score combined with the CytoTRACE score,
        ``(root_score + alpha * (1 - ct_pseudotime)) / (1 + alpha)``
        (a higher score means a higher probability of being the source).
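
    Example
    -------
    A minimal usage sketch (the preprocessing calls are assumptions and may need
    adapting to your data)::

        import scanpy as sc

        sc.pp.pca(adata)
        sc.pp.neighbors(adata)
        res = determine_source_state(adata, embedding_key='X_pca')
        root_idx = adata.uns['ot_ct_root']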
"""
    if kernel == 'palantir':
        if 'DM_EigenVectors' not in adata.obsm:
            palantir.run_diffusion_maps(adata, n_components=n_comps)
    else:
        if 'X_diffmap' not in adata.obsm:
            sc.tl.diffmap(adata, n_comps=n_comps)
if down_sampling and len(adata) > n_obs:
print('Down sampling')
sub_adata = adata.copy()
sc.pp.subsample(sub_adata, n_obs=n_obs)
sc.pp.neighbors(sub_adata, use_rep=embedding_key)
else:
sub_adata = adata
if connect_anchor:
if kernel == 'palantir':
init_candidiates(sub_adata, diffmap_key='DM_EigenVectors')
else:
init_candidiates(sub_adata, diffmap_key='X_diffmap')
connect_anchor = sub_adata.uns['extrema']
else:
connect_anchor = None
res = greedy_search_best_source(sub_adata, embedding_key, split_k=split_m, kernel=kernel,
graph_dist=graph_dist, n_neighbors=n_neighbors, connect_anchor=connect_anchor,
)
res['root_loss'], knn = smoothe_score(sub_adata[res['cell']].obsm[embedding_key],
res['raw_root_loss'].to_numpy().astype(float),
smooth_k)
res['root_score'] = 1 - scale(res['root_loss'])
res = res.sort_values('root_score', ascending=False)
print("optimal transport root cell write in adata.uns['ot_root']")
adata.uns['ot_root'] = np.where(adata.obs.index == res['cell'].tolist()[0])[0][0]
adata.obs['root_score'] = 1 - scale(knn.predict(adata.obsm[embedding_key]))
    # optional: refine the score with CytoTRACE
if cytotrace:
        if ('spliced' not in adata.layers) or ('unspliced' not in adata.layers):
            # CytoTRACE expects the 'Ms'/'Mu' moment layers; fall back to adata.X
            adata.layers['Ms'] = adata.X
            adata.layers['Mu'] = adata.X
CytoTRACEKernel(adata).compute_cytotrace()
adata.obs['ct_root_score'] = (adata.obs['root_score'].to_numpy() + alpha * (1 - adata.obs['ct_pseudotime'].to_numpy())) / (1+alpha)
res['ct_root_score'] = adata[res['cell']].obs['ct_root_score'].tolist()
res = res.sort_values('ct_root_score', ascending=False)
print("optimal transport + cytotrace root cell write in adata.uns['ot_ct_root']")
adata.uns['ot_ct_root'] = np.where(adata.obs.index == res['cell'].tolist()[0])[0][0]
adata.uns['cytotrace_alpha'] = alpha
return res