# Source code for pygot.plotting.plot_traj

import matplotlib.pyplot as plt
import matplotlib
import scanpy as sc
import numpy as np
import seaborn as sns
import pandas as pd
import numpy as np
from scipy.cluster.hierarchy import fcluster
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
from sklearn.preprocessing import minmax_scale
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import minmax_scale
from tqdm import tqdm

def plot_cell_fate_embedding(adata, color,
                             obsm_key='descendant', basis='umap', **kwargs):
    """Plot one column of an ``adata.obsm`` table on an embedding.

    The column ``adata.obsm[obsm_key][color]`` is copied into
    ``adata.obs[color]`` (the copy is left in place afterwards) and then drawn
    with ``scanpy.pl.embedding``.

    Parameters
    ----------
    adata : AnnData
        Annotated data object; modified in place (``adata.obs[color]``).
    color : str
        Column of ``adata.obsm[obsm_key]`` to visualise.
    obsm_key : str
        Key of the ``obsm`` table holding the values (default ``'descendant'``).
    basis : str
        Embedding basis passed to ``scanpy.pl.embedding``.
    **kwargs
        Forwarded unchanged to ``scanpy.pl.embedding``.
    """
    fate_values = adata.obsm[obsm_key][color]
    adata.obs[color] = fate_values
    sc.pl.embedding(adata, basis=basis, color=color, **kwargs)

def cluster_series(data,  num_clusters = 6, method='average', metric='euclidean', ):
    
    
    cluster_grid = sns.clustermap(data, 
               col_cluster=False, standard_scale=0, cmap='Spectral_r', xticklabels = False, method=method, metric=metric)
    plt.close() 
    row_linkage = cluster_grid.dendrogram_row.linkage

    
    row_clusters = fcluster(row_linkage, num_clusters, criterion='maxclust')
    return row_clusters

def _to_dense(matrix):
    """Return *matrix* as a dense NumPy array (sparse inputs are densified)."""
    # AnnData stores X / layers either as scipy.sparse matrices or as dense
    # arrays; only the sparse kind exposes ``toarray``.
    return matrix.toarray() if hasattr(matrix, 'toarray') else np.asarray(matrix)


def _order_genes_by_peak_time(data_scaled, row_clusters):
    """Relabel gene clusters by median peak time and sort genes inside each cluster.

    Parameters
    ----------
    data_scaled : numpy.ndarray
        Genes x cells matrix whose columns are already sorted by pseudotime.
    row_clusters : array-like
        One cluster label per gene (row of ``data_scaled``).

    Returns
    -------
    new_clusters : numpy.ndarray
        Labels remapped to ``0 .. k-1`` so that cluster 0 has the earliest
        median peak time.
    gene_idx : numpy.ndarray
        Row positions grouped by new label (0 first) and, within a cluster,
        sorted by increasing peak time.
    """
    unique_clusters, inverse = np.unique(np.asarray(row_clusters), return_inverse=True)
    inverse = inverse.ravel()
    n_groups = len(unique_clusters)

    # "Peak time" = expression-weighted mean column position along the
    # pseudotime-ordered cells. Genes whose scaled expression is all zero give
    # 0/0 = NaN; the divide warnings for those are silenced.
    with np.errstate(divide='ignore', invalid='ignore'):
        weights = data_scaled / data_scaled.sum(axis=1)[:, None]
    gene_peak_time = np.sum(weights * np.arange(data_scaled.shape[1]), axis=1)

    # nanmedian: a single all-zero gene must not turn its cluster's median
    # into NaN, which would push the whole cluster to the end of the order.
    cluster_medians = np.array(
        [np.nanmedian(gene_peak_time[inverse == k]) for k in range(n_groups)]
    )
    new_label = np.empty(n_groups, dtype=int)
    new_label[np.argsort(cluster_medians)] = np.arange(n_groups)
    new_clusters = new_label[inverse]

    gene_idx = []
    for label in range(n_groups):
        members = np.flatnonzero(new_clusters == label)
        # Keep the sorted result (it was previously computed and discarded).
        gene_idx.append(members[np.argsort(gene_peak_time[members], kind='stable')])
    gene_idx = np.concatenate(gene_idx) if gene_idx else np.array([], dtype=int)
    return new_clusters, gene_idx


def _annotate_highlight_genes(ax, highlight_genes, n_cols, gene_font_size):
    """Blank all y tick labels, then label only *highlight_genes* to the right of the heatmap."""
    yticks = ax.get_yticks()
    yticklabels = [tick.get_text() for tick in ax.get_yticklabels()]

    ax.set_yticklabels([''] * len(yticklabels), fontsize=10)
    ax.tick_params(axis='y', left=False)

    # First occurrence wins for duplicated labels (same as list.index).
    tick_position = {}
    for label, y_pos in zip(yticklabels, yticks):
        tick_position.setdefault(label, y_pos)

    # Assumes the heatmap columns occupy x in [0, n_cols] (seaborn heatmap
    # convention); leader lines start just right of the last column.
    right_bound = n_cols + 0.5
    for gene in highlight_genes:
        if gene in tick_position:
            y_pos = tick_position[gene]
            ax.plot([right_bound, right_bound + 10], [y_pos, y_pos],
                    color='black', lw=1.5, clip_on=False)
            ax.text(right_bound + 10.5, y_pos, gene,
                    ha='left', va='center', fontsize=gene_font_size)

    ax.tick_params(axis='both', which='both', length=0)


def plot_dynamical_genes_clusetermap(
        adata,
        layer=None,
        pseudotime_key='pseudotime',
        row_clusters=None,
        n_clusters=4,
        method='weighted',
        metric='correlation',
        gene_font_size=8,
        color=None,
        gene_color=None,
        cmap='Spectral_r',
        highlight_genes=None,
        show_gene=True,
        **kwargs
    ):
    """Draw a genes x cells heatmap with cells ordered by pseudotime and genes grouped by cluster.

    Expression is min-max scaled per gene. Genes are clustered, either by
    ``row_clusters`` or by KMeans on a PCA of the scaled profiles. Clusters
    are renumbered by median peak time, and genes are sorted by peak time
    within each cluster.

    Parameters
    ----------
    adata : AnnData
        Input data; not modified.
    layer : str, optional
        Layer to plot; ``adata.X`` when None. Sparse or dense matrices are accepted.
    pseudotime_key : str
        Column of ``adata.obs`` used to order cells.
    row_clusters : array-like, optional
        Precomputed cluster label per gene (length ``adata.n_vars``). When
        None, KMeans with ``n_clusters`` clusters is run on up to 20 PCA
        components of the scaled expression.
    n_clusters : int
        Number of KMeans clusters (only used when ``row_clusters`` is None).
    method, metric : str
        Forwarded to ``seaborn.clustermap``. Row and column clustering are
        both disabled here, so these do not affect the displayed order.
    gene_font_size : int
        Font size of gene labels.
    color : str, optional
        Categorical ``adata.obs`` column drawn as a second column-colour bar,
        using ``adata.uns['<color>_colors']``.
    gene_color : str, optional
        Numeric ``adata.var`` column drawn as an extra row-colour bar.
    cmap : str
        Heatmap colormap.
    highlight_genes : list of str, optional
        When given (and ``show_gene`` is True), only these genes are labelled,
        with leader lines to the right of the heatmap.
    show_gene : bool
        Hide the y axis entirely when False.
    **kwargs
        Forwarded to ``seaborn.clustermap``.

    Returns
    -------
    cluster_grid : seaborn.matrix.ClusterGrid
    pandas.DataFrame
        Columns ``'cluster'`` and ``'gene'`` in plotted row order.

    Raises
    ------
    ValueError
        If ``row_clusters`` does not have one label per gene.
    """
    # One pseudotime ordering shared by the heatmap columns and the cell
    # annotation bar, so the two always line up.
    sorted_idx = adata.obs[pseudotime_key].sort_values().index
    sorted_adata = adata[sorted_idx, :]

    matrix = sorted_adata.X if layer is None else sorted_adata.layers[layer]
    # minmax_scale works per column (gene) of the cells x genes matrix; .T
    # then yields genes x cells.
    data_scaled = minmax_scale(_to_dense(matrix)).T
    data = pd.DataFrame(data_scaled, index=adata.var.index)

    if row_clusters is None:
        # PCA requires n_components <= min(n_samples, n_features).
        n_components = min(20, *data_scaled.shape)
        data_pca = PCA(n_components=n_components).fit_transform(data_scaled)
        row_clusters = KMeans(n_clusters=n_clusters, random_state=42).fit_predict(data_pca)
    elif len(row_clusters) != data_scaled.shape[0]:
        raise ValueError(
            'row_clusters must have one label per gene '
            '({} given, {} genes)'.format(len(row_clusters), data_scaled.shape[0])
        )

    row_clusters, gene_idx = _order_genes_by_peak_time(data_scaled, row_clusters)
    n_groups = len(np.unique(row_clusters))

    # Column colours: pseudotime rank (viridis), plus optional cell annotation.
    cmap_t = sns.color_palette("viridis", as_cmap=True)
    pseudotime_colors = [cmap_t(t) for t in np.linspace(0, 1, data.shape[1])]

    if color is not None:
        # np.asarray: the palette may be stored as a list, which cannot be
        # indexed by an integer array.
        palette = np.asarray(adata.uns['{}_colors'.format(color)])
        codes = sorted_adata.obs[color].cat.codes.to_numpy()
        cell_type_colors = [to_rgba(c) for c in palette[codes]]
        col_colors = pd.DataFrame([pseudotime_colors, cell_type_colors], index=['', '']).T
    else:
        col_colors = pseudotime_colors

    # Row colours: gene cluster, plus optional per-gene value in greys.
    cluster_cmap = ListedColormap(sns.color_palette("RdBu_r", n_groups))
    cluster_colors = np.array([cluster_cmap(cluster) for cluster in row_clusters])

    if gene_color is not None:
        grey_cmap = ListedColormap(sns.color_palette("Greys", n_groups))
        values = minmax_scale(adata.var[gene_color].to_numpy())
        g_colors = np.array([grey_cmap(v) for v in values])

        row_colors = pd.DataFrame(
            [list(g_colors[gene_idx]), list(cluster_colors[gene_idx])],
            index=[gene_color, 'cluster'],
        ).T
        row_colors['idx'] = adata.var.index[gene_idx]
        row_colors = row_colors.set_index('idx')
    else:
        row_colors = cluster_colors[gene_idx]

    # iloc (positional) so duplicated gene names do not duplicate rows.
    cluster_grid = sns.clustermap(
        data.iloc[gene_idx],
        col_cluster=False,
        row_cluster=False,
        cmap=cmap,
        xticklabels=False,
        yticklabels=True,
        method=method,
        metric=metric,
        cbar_pos=None,     # remove the colorbar
        row_colors=row_colors,
        col_colors=col_colors,
        **kwargs,
    )

    if not show_gene:
        cluster_grid.ax_heatmap.yaxis.set_visible(False)
    elif highlight_genes is not None:
        _annotate_highlight_genes(cluster_grid.ax_heatmap, highlight_genes,
                                  data.shape[1], gene_font_size)
    else:
        cluster_grid.ax_heatmap.set_yticklabels(
            cluster_grid.ax_heatmap.get_yticklabels(),
            fontsize=gene_font_size,
            rotation=0,
        )

    gene_table = pd.DataFrame(
        [row_clusters[gene_idx], adata.var.index[gene_idx]], index=['cluster', 'gene']
    ).T
    return cluster_grid, gene_table



def plot_trajectory(adata, traj, basis='pumap', title='', ax=None, embedding_kw=None, **kwargs):
    """Overlay trajectories on an embedding scatter plot.

    Draws ``adata`` with ``scanpy.pl.embedding``. Each trajectory is then
    drawn as a faint black line, with its first point as a blue dot and its
    last point as a hollow red circle.

    Parameters
    ----------
    adata : AnnData
        Passed to ``scanpy.pl.embedding`` to draw the background.
    traj : array-like
        Indexed as ``traj[time, trajectory, coord]``; only coords 0 and 1 are
        drawn. Presumably these are positions in the ``basis`` embedding;
        confirm with callers.
    basis : str
        Embedding basis for ``scanpy.pl.embedding``.
    title : str
        Axes title passed to ``scanpy.pl.embedding``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; a new 1x1 figure is created when omitted.
    embedding_kw : dict, optional
        Extra keyword arguments for ``scanpy.pl.embedding`` (copied, not mutated).
    **kwargs
        Forwarded to ``plt.subplots``; only used when ``ax`` is None.
    """
    marker = matplotlib.markers.MarkerStyle('o', fillstyle='none')
    embedding_kw = {} if embedding_kw is None else embedding_kw.copy()
    if ax is None:
        _, ax = plt.subplots(1, 1, **kwargs)
    ax.axis('off')
    sc.pl.embedding(adata, basis=basis, ax=ax, show=False, colorbar_loc=None,
                    title=title, **embedding_kw)
    # With 2-D x/y, Axes.plot draws one line per column (one per trajectory),
    # replacing the per-trajectory Python loop with a single call.
    ax.plot(traj[:, :, 0], traj[:, :, 1], color='black', alpha=.3, linewidth=1)
    ax.scatter(traj[0, :, 0], traj[0, :, 1], color='blue', s=20, label='start')
    ax.scatter(traj[-1, :, 0], traj[-1, :, 1], color='red', s=20, label='end', marker=marker)