Source code for pygot.tools.analysis.cell_fate

from pygot.tools.traj import velocity_graph, diffusion_graph
from pygot.preprocessing import mutual_nearest_neighbors
import scanpy as sc
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix
import statsmodels.distributions.empirical_distribution as edf
from scipy.interpolate import interp1d
from cellrank._utils._linear_solver import _solve_lin_system
from datetime import datetime




def current():
    """Return the current time as a 'YYYY-MM-DD HH:MM:SS' string (for progress logging)."""
    now = datetime.now()
    return now.strftime("%Y-%m-%d %H:%M:%S")


class CellFate:
    """Cell fate prediction based on the Markov chain.

    .. math::

        r = \\frac{1}{|C_k|} \\sum_{i \\in C_k} R_{\\cdot, i} \\\\
        P = (I - Q)^{-1} r

    where :math:`R` is the transition matrix from non-target cells to target
    cells and :math:`C_k` is the set of cells of target cell type :math:`k`.
    :math:`Q` is the transition matrix from non-target cells to non-target
    cells, and :math:`I` is the identity matrix. The solution :math:`P` gives
    the absorbing probabilities for each target cell type.

    Example:
    ----------
    ::

        cf = pygot.tl.analysis.CellFate()
        cf.fit(adata, embedding_key='X_pca', velocity_key='velocity_pca',
               cell_type_key='clusters',
               target_cell_types=['Beta', 'Alpha', 'Delta', 'Epsilon'])
        adata.obs[adata.obsm['descendant'].columns] = adata.obsm['descendant']
        sc.pl.umap(adata, color=adata.obsm['descendant'].columns, ncols=2)

    """
    def __init__(self):
        pass
    def fit(
        self,
        adata,
        embedding_key,
        velocity_key,
        cell_type_key,
        target_cell_types=None,
        target_cell_idx=None,
        n_neighbors=30,
        mutual=True,
        sde=True,
        D=1.,
    ):
        """Fit the cell fate prediction model and export the result into
        adata.obsm['descendant'] and adata.obsm['ancestor'].

        Arguments:
        ----------
        adata: :class:`anndata.AnnData`
            AnnData object
        embedding_key: `str`
            The key of the embedding in adata.obsm
        velocity_key: `str`
            The key of the velocity in adata.obsm
        cell_type_key: `str`
            The key of the cell type in adata.obs
        target_cell_types: `list` (default: None)
            The list of target cell types
        target_cell_idx: `list` (default: None)
            The list of target cell indices
        n_neighbors: `int` (default: 30)
            The number of neighbors for the nearest-neighbors graph
        mutual: `bool` (default: True)
            Whether to use the mutual nearest neighbors graph. Might isolate some cells if set to True
        sde: `bool` (default: True)
            If True, use the diffusion (inner-product) kernel; if False, use the cosine kernel
        D: `float` (default: 1.)
            The diffusion factor. Larger D means larger diffusion.

        """
        assert (target_cell_types is not None) or (target_cell_idx is not None), \
            'Must offer target_cell_types or target_cell_idx'
        adata.obs['transition'] = 0
        if target_cell_types is not None:
            adata.obs.loc[adata.obs[cell_type_key].isin(target_cell_types), 'transition'] = 1
        else:
            # Mark target cells by positional index when only indices are given
            adata.obs.iloc[target_cell_idx, adata.obs.columns.get_loc('transition')] = 1

        model = TimeSeriesRoadmap(adata, embedding_key=embedding_key, velocity_key=velocity_key,
                                  time_key='transition', sde=sde, D=D)
        model.compute_state_coupling(cell_type_key=cell_type_key, n_neighbors=n_neighbors, mutual=mutual)
        model.export_result()
        self.model = model
    def get_cluster_transition_map(self, pvalue=1e-3, max_cutoff=0.45):
        """Get the cluster transition map from the fitted cell fate model.

        Arguments:
        ----------
        pvalue: `float` (default: 1e-3)
            The p-value cutoff for the cluster transition map
        max_cutoff: `float` (default: 0.45)
            The maximum cutoff for the cluster transition map

        """
        transition_list = self.model.filter_state_coupling(pvalue=pvalue, max_cutoff=max_cutoff)
        return transition_list[0]
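
# A minimal illustrative sketch (hypothetical helper, not part of the pygot
# API) of the absorbing-probability computation from the CellFate docstring.
# It uses a dense toy chain and numpy's linear solver in place of cellrank's
# _solve_lin_system; the matrices below are made up for the example.
def _demo_absorbing_probabilities():
    # Q: transitions among 3 non-target cells; R: transitions from the 3
    # non-target cells into 2 absorbing target cells. Rows of [Q | R] sum to 1.
    Q = np.array([[0.1, 0.2, 0.1],
                  [0.0, 0.3, 0.2],
                  [0.2, 0.1, 0.0]])
    R = np.array([[0.4, 0.2],
                  [0.3, 0.2],
                  [0.5, 0.2]])
    # r averages the columns of R within each target cell type; each target
    # type here is a single cell, so r == R.
    r = R
    # P = (I - Q)^{-1} r: absorbing probability of each target type per cell.
    P = np.linalg.solve(np.eye(Q.shape[0]) - Q, r)
    assert np.allclose(P.sum(axis=1), 1.0)  # fate probabilities sum to 1
    return P
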
class TimeSeriesRoadmap:
    """Developmental tree inference based on the velocity graph.

    Example:
    ----------
    ::

        embedding_key = 'X_pca'
        velocity_key = 'velocity_pca'
        time_key = 'stage_numeric'
        cell_type_key = 'clusters'

        roadmap = pygot.tl.analysis.TimeSeriesRoadmap(adata, embedding_key, velocity_key, time_key)
        roadmap.compute_state_coupling(cell_type_key='clusters', n_neighbors=30)
        # Permutation test to filter the cell type coupling
        filtered_state_coupling_list = roadmap.filter_state_coupling(pvalue=0.001)

    """
    def __init__(self, adata, embedding_key, velocity_key, time_key, sde=False, D=1.):
        self.adata = adata
        self.embedding_key = embedding_key
        self.velocity_key = velocity_key
        self.time_key = time_key
        self.ts = np.sort(np.unique(adata.obs[time_key]))
        self.state_map = {t: {} for t in self.ts[:-1]}
        self.sde = sde
        self.D = D
    def compute_state_coupling(
        self,
        cell_type_key='cell_type',
        n_neighbors=None,
        permutation_iter_n=100,
        mutual=True,
    ):
        ad = self.adata
        self.cell_type_key = cell_type_key
        print(current(), '\t Compute transition roadmap among', self.ts)
        ad.obs['idx'] = range(len(ad))
        for i in range(len(self.ts) - 1):
            start = self.ts[i]
            end = self.ts[i + 1]
            print(current(), '\t Compute transition between {} and {}'.format(start, end))
            x0_obs = ad.obs.loc[ad.obs[self.time_key] == start]
            x1_obs = ad.obs.loc[ad.obs[self.time_key] == end]
            # Subset the AnnData to the two adjacent stages
            x0x1_ad = ad[np.concatenate([x0_obs.index, x1_obs.index])].copy()
            fwd, bwd, fbwd, null, descendant, ancestor = time_series_transition_map(
                x0x1_ad, self.embedding_key, self.velocity_key, self.time_key,
                start, end, norm=0, n_neighbors=n_neighbors, cell_type_key=cell_type_key,
                permutation_iter_n=permutation_iter_n, mutual=mutual, sde=self.sde, D=self.D,
            )
            self.state_map[start]['fwd'] = fwd
            self.state_map[start]['bwd'] = bwd
            self.state_map[start]['fbwd'] = fbwd
            self.state_map[start]['null'] = null
            self.state_map[start]['null_iedf'] = fit_null_distribution(null)
            self.state_map[start]['descendant'] = descendant
            self.state_map[start]['ancestor'] = ancestor
    def filter_state_coupling(self, pvalue=0.001, max_cutoff=0.45):
        filtered_fbwd_list = []
        for key in self.state_map.keys():
            if len(self.state_map[key]['fbwd']) > 1:
                cutoff = min(max_cutoff, self.state_map[key]['null_iedf'](1 - pvalue))
            else:
                cutoff = max_cutoff
            self.state_map[key]['cutoff'] = cutoff
            filtered_fbwd_list.append((self.state_map[key]['fbwd'] > cutoff) * self.state_map[key]['fbwd'])
            self.state_map[key]['filtered_fbwd'] = filtered_fbwd_list[-1]
        return filtered_fbwd_list
    def export_result(self):
        print("Export result into adata.obsm['descendant'] and adata.obsm['ancestor']")
        descendant_col = np.unique(np.concatenate([self.state_map[key]['descendant'].columns for key in self.state_map.keys()]))
        ancestor_col = np.unique(np.concatenate([self.state_map[key]['ancestor'].columns for key in self.state_map.keys()]))

        self.adata.obsm['descendant'] = pd.DataFrame(np.zeros(shape=(len(self.adata.obs), len(descendant_col))),
                                                     index=self.adata.obs.index, columns=descendant_col)
        self.adata.obsm['ancestor'] = pd.DataFrame(np.zeros(shape=(len(self.adata.obs), len(ancestor_col))),
                                                   index=self.adata.obs.index, columns=ancestor_col)

        for key in self.state_map.keys():
            m = self.state_map[key]['descendant']
            self.adata.obsm['descendant'].loc[m.index, m.columns] = m
            m = self.state_map[key]['ancestor']
            self.adata.obsm['ancestor'].loc[m.index, m.columns] = m

        # Terminal-stage cells are their own descendants and first-stage cells
        # their own ancestors: one-hot encode them by observed cell type.
        start_cells = self.adata.obs.loc[self.adata.obs[self.time_key] == self.ts[0]].index
        end_cells = self.adata.obs.loc[self.adata.obs[self.time_key] == self.ts[-1]].index
        for cell_type in descendant_col:
            self.adata.obsm['descendant'].loc[end_cells, cell_type] = np.array(
                self.adata.obs.loc[end_cells][self.cell_type_key] == cell_type).astype(float)
        for cell_type in ancestor_col:
            self.adata.obsm['ancestor'].loc[start_cells, cell_type] = np.array(
                self.adata.obs.loc[start_cells][self.cell_type_key] == cell_type).astype(float)
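
# Hedged sketch of the thresholding done by TimeSeriesRoadmap.filter_state_coupling:
# the inverted ECDF of the permutation null yields a (1 - pvalue) quantile, the
# cutoff is capped at max_cutoff, and sub-cutoff couplings are zeroed out.
# `_demo_filter_cutoff` and its coupling table are made up for illustration.
def _demo_filter_cutoff(pvalue=0.001, max_cutoff=0.45):
    fbwd = pd.DataFrame([[0.50, 0.02],
                         [0.10, 0.38]],
                        index=['Ngn3 high', 'Fev+'],  # toy source types
                        columns=['Alpha', 'Beta'])    # toy target types
    null = np.random.uniform(0., 0.2, size=1000)      # stand-in permutation null
    null_iedf = fit_null_distribution(null)
    cutoff = min(max_cutoff, null_iedf(1 - pvalue))
    return (fbwd > cutoff) * fbwd  # keep only significant couplings
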
def fit_null_distribution(sample):
    sample = sample[~np.isnan(sample)]
    sample_edf = edf.ECDF(sample)
    slope_changes = sorted(set(sample))
    sample_edf_values_at_slope_changes = [sample_edf(item) for item in slope_changes]
    inverted_edf = interp1d(sample_edf_values_at_slope_changes, slope_changes)
    return inverted_edf


def split_list(lst, sizes):
    result = []
    start = 0
    for size in sizes:
        result.append(lst[start:start + size])
        start += size
    return result


def time_series_transition_map(
    x0x1_ad,
    embedding_key,
    velocity_key,
    time_key,
    current_stage,
    next_stage,
    cell_type_key='cell_type',
    n_neighbors=None,
    norm=0,
    permutation_iter_n=100,
    mutual=True,
    sde=False,
    D=1.,
):
    print(current(), '\t Compute velocity graph')
    # Add an index column
    x0x1_ad.obs['idx'] = range(len(x0x1_ad))
    if n_neighbors is None:
        n_neighbors = min(50, max(15, int(len(x0x1_ad) * 0.0025)))
        print('{} to {} | Number of neighbors: {}'.format(current_stage, next_stage, n_neighbors))

    # Compute neighbors based on the embedding
    if len(x0x1_ad) < 8192:  # scanpy exact nn cutoff
        # symmetric graph
        # sc.pp.neighbors(x0x1_ad, n_neighbors=n_neighbors, use_rep=embedding_key)
        # graph = x0x1_ad.obsp['connectivities']
        graph = mutual_nearest_neighbors(x0x1_ad, n_neighbors=n_neighbors, use_rep=embedding_key, mutual=False, sym=True)
    else:
        # mutual nearest neighbors graph
        if mutual:
            graph = mutual_nearest_neighbors(x0x1_ad, n_neighbors=n_neighbors, use_rep=embedding_key, mutual=True, sym=False)
        else:
            graph = mutual_nearest_neighbors(x0x1_ad, n_neighbors=n_neighbors, use_rep=embedding_key, mutual=False, sym=True)

    # Get indices for the current and next stage cells
    x0_idx = x0x1_ad.obs.loc[x0x1_ad.obs[time_key] == current_stage, 'idx'].to_numpy()
    x1_idx = x0x1_ad.obs.loc[x0x1_ad.obs[time_key] == next_stage, 'idx'].to_numpy()

    if not sde:
        # Compute the velocity graph
        velocity_graph(x0x1_ad, embedding_key, velocity_key, graph=graph, split_negative=True)
        P_fwd = x0x1_ad.uns['velocity_graph']
        P_bwd = -x0x1_ad.uns['velocity_graph_neg']
    else:
        P_fwd = diffusion_graph(X=x0x1_ad.obsm[embedding_key], V=x0x1_ad.obsm[velocity_key], graph=graph, D=D)
        P_bwd = diffusion_graph(X=x0x1_ad.obsm[embedding_key], V=-x0x1_ad.obsm[velocity_key], graph=graph, D=D)

    print(current(), '\t Convert into Markov chain')
    P_fwd /= P_fwd.sum(axis=1)
    P_fwd = csr_matrix(P_fwd)
    P_bwd /= P_bwd.sum(axis=1)
    P_bwd = csr_matrix(P_bwd)

    x0x1_markov = P_fwd[x0_idx][:, x1_idx]
    x1x0_markov = P_bwd[x1_idx][:, x0_idx]
    Q_bwd = P_bwd[x1_idx][:, x1_idx]
    Q_fwd = P_fwd[x0_idx][:, x0_idx]
    fixed_fwd = np.array(Q_fwd.sum(axis=1) == 0).flatten()
    fixed_bwd = np.array(Q_bwd.sum(axis=1) == 0).flatten()

    x0_obs = x0x1_ad.obs.iloc[x0_idx]
    x1_obs = x0x1_ad.obs.iloc[x1_idx]

    x0_cell_list = x0_obs[cell_type_key].unique()
    x0_cell_idx_list = [np.where(x0_obs[cell_type_key] == c)[0] for c in x0_cell_list]
    x0_cell_num_list = np.array([len(n) for n in x0_cell_idx_list])

    x1_cell_list = x1_obs[cell_type_key].unique()
    x1_cell_idx_list = [np.where(x1_obs[cell_type_key] == c)[0] for c in x1_cell_list]
    x1_cell_num_list = np.array([len(n) for n in x1_cell_idx_list])

    s_fwd = np.array([x0x1_markov[:, f].sum(axis=1) for f in x1_cell_idx_list])[:, :, 0].T
    s_bwd = np.array([x1x0_markov[:, f].sum(axis=1) for f in x0_cell_idx_list])[:, :, 0].T
    descendant = s_fwd / x1_cell_num_list
    ancestor = s_bwd / x0_cell_num_list
    descendant *= (np.array(1 - Q_fwd.sum(axis=1)).flatten() / (descendant.sum(axis=1) + 1e-6))[:, None]
    ancestor *= (np.array(1 - Q_bwd.sum(axis=1)).flatten() / (ancestor.sum(axis=1) + 1e-6))[:, None]

    print(current(), '\t Solve absorbing probabilities')
    IQR_fwd = _solve_lin_system(Q_fwd[~fixed_fwd, :][:, ~fixed_fwd], csr_matrix(descendant[~fixed_fwd, :]),
                                use_eye=True, show_progress_bar=False)
    IQR_bwd = _solve_lin_system(Q_bwd[~fixed_bwd, :][:, ~fixed_bwd], csr_matrix(ancestor[~fixed_bwd, :]),
                                use_eye=True, show_progress_bar=False)
    descendant[~fixed_fwd, :] = IQR_fwd
    ancestor[~fixed_bwd, :] = IQR_bwd
    ancestor[np.isnan(ancestor)] = 0.
    descendant[np.isnan(descendant)] = 0.

    def compute_state_coupling(norm=0):
        # Compute the forward state coupling matrix
        fwd = np.zeros((len(x0_cell_list), len(x1_cell_list)))
        for i in range(len(x0_cell_list)):
            fwd[i, :] = descendant[x0_cell_idx_list[i], :].sum(axis=0)
        fwd = fwd / fwd.sum(axis=1)[:, None]  # Normalize by rows
        fwd[np.isnan(fwd)] = 0.
        fwd += 1e-3

        # Compute the backward state coupling matrix
        bwd = np.zeros((len(x0_cell_list), len(x1_cell_list)))
        for j in range(len(x1_cell_list)):
            bwd[:, j] = ancestor[x1_cell_idx_list[j], :].sum(axis=0)
        bwd = bwd / bwd.sum(axis=0)  # Normalize by columns
        bwd[np.isnan(bwd)] = 0.
        bwd += 1e-3

        # Combine the forward and backward matrices
        state_coupling = fwd * bwd
        state_coupling = state_coupling / (state_coupling.sum(axis=0) if norm == 0 else state_coupling.sum(axis=1)[:, None])
        return fwd, bwd, state_coupling

    # Make sure cell type is treated as a string
    x0x1_ad.obs[cell_type_key] = x0x1_ad.obs[cell_type_key].astype(str)

    # Calculate the forward and backward state couplings
    state_coupling_fwd, state_coupling_bwd, state_coupling = compute_state_coupling(norm)
    state_coupling = pd.DataFrame(state_coupling, index=x0_cell_list, columns=x1_cell_list)
    state_coupling_fwd = pd.DataFrame(state_coupling_fwd, index=x0_cell_list, columns=x1_cell_list)
    state_coupling_bwd = pd.DataFrame(state_coupling_bwd, index=x0_cell_list, columns=x1_cell_list)

    bwd_cc = ancestor / ancestor.sum(axis=1)[:, None]
    fwd_cc = descendant / descendant.sum(axis=1)[:, None]
    bwd_cc = pd.DataFrame(bwd_cc, columns=state_coupling.index, index=x1_obs.index)
    fwd_cc = pd.DataFrame(fwd_cc, columns=state_coupling.columns, index=x0_obs.index)

    # Perform permutation testing
    permutation_list = []
    N, M = len(x0_obs), len(x1_obs)
    print(current(), '\t Generate NULL distribution')
    for _ in range(permutation_iter_n):
        x0_cell_idx_list = split_list(np.random.permutation(N), x0_cell_num_list)
        x1_cell_idx_list = split_list(np.random.permutation(M), x1_cell_num_list)
        _, _, permu_state_coupling = compute_state_coupling(norm)
        permutation_list.append(permu_state_coupling.flatten())

    return state_coupling_fwd, state_coupling_bwd, state_coupling, np.concatenate(permutation_list, axis=0), \
        fwd_cc, bwd_cc
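
# Sketch of how the permutation null in time_series_transition_map is built:
# cell indices are shuffled, re-split into groups matching the original
# cell-type sizes via split_list, and the coupling is recomputed under the
# random labels. `_demo_permutation_null` and its toy inputs are hypothetical.
def _demo_permutation_null(permutation_iter_n=100):
    descendant = np.random.dirichlet(np.ones(2), size=30)  # toy fate probabilities
    cell_num_list = np.array([10, 20])                     # toy cell-type sizes
    null = []
    for _ in range(permutation_iter_n):
        idx_list = split_list(np.random.permutation(30), cell_num_list)
        coupling = np.stack([descendant[idx].sum(axis=0) for idx in idx_list])
        coupling /= coupling.sum(axis=1)[:, None]          # row-normalize, as in fwd
        null.append(coupling.flatten())
    return np.concatenate(null)  # samples for fit_null_distribution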