Source code for pygot.tools.analysis.grn_inference

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.sparse import issparse
from .utils import TF_human, TF_mm
from pygot.evalute import *

class GeneDegradation(nn.Module):
    """Per-gene degradation rates, optionally learnable and clamped from below."""

    def __init__(self, output_size, init_beta=1.0, min_beta=0.0, beta_grad=True):
        super(GeneDegradation, self).__init__()
        if beta_grad:
            self.beta = nn.Parameter(init_beta * torch.ones(output_size))
            self.beta.register_hook(self.hinge_hook)
            self.min_beta = torch.tensor(min_beta)
        else:
            # Fixed degradation rates; registered as a buffer so .to(device) moves them
            self.register_buffer('beta', init_beta * torch.ones(output_size))
        self.relu = nn.ReLU()

    def hinge_hook(self, grad):
        # Clamp beta to its lower bound whenever a gradient flows through it
        with torch.no_grad():
            self.beta.data = torch.clamp(self.beta, min=self.min_beta)
        return grad

    def forward(self, x):
        # Degradation term: relu(beta) * x, elementwise per gene
        return self.relu(self.beta) * x
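
# Minimal usage sketch (illustrative, not part of the original module): GeneDegradation
# scales an expression batch by a per-gene degradation rate that the hinge hook keeps
# above `min_beta`. The shapes below are assumptions chosen for demonstration.
def _demo_gene_degradation():
    beta_hat = GeneDegradation(output_size=5, init_beta=1.0, min_beta=0.1)
    x = torch.rand(8, 5)                 # (cells, genes) expression batch
    degradation = beta_hat(x)            # relu(beta) * x, same shape as x
    assert degradation.shape == x.shape
    return degradation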

class GeneRegulatroyModel(nn.Module):
    """Linear regulatory model: a learnable (gene x TF) strength matrix G."""

    def __init__(self, tf_num, gene_num, tf_idx, init_jacobian=None, non_negative=True):
        super(GeneRegulatroyModel, self).__init__()
        # G[i,j] ~ regulatory strength of gene j (a TF) on gene i
        if init_jacobian is None:
            init_jacobian = torch.rand(gene_num, tf_num)
        self.linear = nn.Parameter(init_jacobian)

        # Remove self-regulating edges
        if tf_num == gene_num:
            self.linear.register_hook(self.remove_diagonal_hook)
        else:
            self.indices_to_remove = tf_idx
            self.linear.register_hook(self.custom_remove_hook)

        self.non_negative = non_negative

    def forward(self, x):
        # x: (batch, tf_num) -> (batch, gene_num) via a batched matrix-vector product
        return (self.linear @ x[:, :, None]).squeeze(-1)

    def custom_remove_hook(self, grad):
        # Zero the entry G[tf_idx[j], j] so each TF cannot regulate itself
        with torch.no_grad():
            self.linear[self.indices_to_remove, range(self.linear.shape[1])] = 0
        return grad

    def remove_diagonal_hook(self, grad):
        # Zero the diagonal so no gene regulates itself
        with torch.no_grad():
            self.linear -= torch.diag(torch.diag(self.linear))
        return grad

    def apply_non_negative(self):
        # Project the regulatory matrix onto the non-negative orthant
        with torch.no_grad():
            self.linear.data = torch.clamp(self.linear.data, min=0)
class GRNData:
    """Gene regulatory network data structure

    This class stores the variables of the inferred GRN.

    Variables:
    ----------
    self.G: :class:`np.ndarray`, (n_gene, n_tf)
        Regulatory strength; self.G[i,j] is the regulatory strength of gene j on gene i
    self.beta: :class:`np.ndarray`, (n_gene,)
        Degradation rate of each gene
    self.ranked_edges: :class:`pd.DataFrame`
        Ranked regulatory relationships
    self.tf_names: `list`
        TF names
    self.gene_names: `list`
        Gene names
    self.models: `dict`
        self.models['G'] is the torch model of G; self.models['beta'] is the torch model of beta
    """

    def __init__(self, G_hat: GeneRegulatroyModel, beta_hat: GeneDegradation, tf_names, gene_names):
        """Initialization

        Arguments:
        ----------
        G_hat: :class:`GeneRegulatroyModel`
            torch model of G
        beta_hat: :class:`GeneDegradation`
            torch model of beta
        tf_names: `list`
            TF names
        gene_names: `list`
            Gene names
        """
        self.G = G_hat.linear.detach().cpu().numpy()
        self.beta = beta_hat.beta.data.detach().cpu().numpy()
        self.ranked_edges = get_ranked_edges(self.G, tf_names=tf_names, gene_names=gene_names)
        self.tf_names = tf_names
        self.gene_names = gene_names
        self.models = {'G': G_hat, 'beta': beta_hat}

    def export_grn_into_celloracle(self, oracle):
        """CellOracle interface

        Export the fitted GRN into CellOracle for further analysis, such as perturbation.

        Arguments:
        ----------
        oracle: :class:`celloracle.Oracle`
            celloracle object
        """
        network = pd.DataFrame(self.G.T, index=self.tf_names, columns=self.gene_names)
        coef_matrix = pd.DataFrame(
            np.zeros(shape=(len(self.gene_names), len(self.gene_names))),
            index=self.gene_names.tolist(), columns=self.gene_names.tolist()
        )
        coef_matrix.loc[network.index] = network
        oracle.coef_matrix = coef_matrix
        oracle.active_regulatory_genes = self.tf_names.tolist()
        oracle.all_regulatory_genes_in_TFdict = self.tf_names.tolist()
        print('Finish!')
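
# Minimal usage sketch (illustrative, not part of the original module): wrapping fitted
# torch models in a GRNData and inspecting the top-ranked edges. All names and sizes
# here are toy assumptions.
def _demo_grn_data():
    tf_names = pd.Index(['TF_a', 'TF_b'])
    gene_names = pd.Index(['TF_a', 'TF_b', 'g1', 'g2'])
    tf_idx = np.array([0, 1])
    G_hat = GeneRegulatroyModel(tf_num=2, gene_num=4, tf_idx=tf_idx)
    beta_hat = GeneDegradation(output_size=4)
    grn = GRNData(G_hat, beta_hat, tf_names, gene_names)
    print(grn.ranked_edges.head())       # Gene1 (regulator) -> Gene2 (target), by |weight|
    return grn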

class GRN:
    """Gene regulatory network inferred by velocity linear regression

    Example:
    ----------
    ::

        grn = GRN()
        grn_data = grn.fit(adata, species='human')
        print(grn_data.ranked_edges.head())  # print the top regulatory relationships
    """

    def __init__(self):
        pass

    def fit(
        self,
        adata,
        TF_constrain=True,
        TF_names=None,
        species='human',
        non_negative=True,
        layer_key=None,
        n_epoch=10000,
        lr=0.01,
        l1_penalty=0.005,
        init_beta=1.0,
        min_beta=1.0,
        init_jacobian=None,
        early_stopping=True,
        batch_size=2048,
        val_split=0.2,
        device=None,
        lineage_key=None,
    ):
        """Fit the gene regulatory network

        Arguments:
        ----------
        adata: :class:`~anndata.AnnData`
            Annotated data matrix; gene velocity should be stored in adata.layers['velocity']
        TF_constrain: `bool` (default: True)
            Only fit transcription factors (TFs) as regulators
        TF_names: `list` (default: None)
            Names of TFs; if None, use the default TF list
        species: 'human' or 'mm' (default: 'human')
            Species of the default TF list
        non_negative: `bool` (default: True)
            Only fit positive regulatory relationships, which may reduce overfitting
        layer_key: `str` (default: None)
            Layer to use as x; if None, use adata.X, otherwise adata.layers[layer_key]
        n_epoch: `int` (default: 10000)
            Number of training epochs
        lr: `float` (default: 0.01)
            Learning rate
        l1_penalty: `float` (default: 0.005)
            L1 weight, controlling the sparsity of the GRN
        init_beta: `float` or :class:`GeneDegradation` (default: 1.0)
            Initial gene degradation rate
        min_beta: `float` (default: 1.0)
            Lower bound of the degradation rate
        init_jacobian: `np.ndarray` (default: None)
            Initial GRN
        early_stopping: `bool` (default: True)
            Stop training early when the validation loss stops improving
        batch_size: `int` (default: 2048)
            Batch size for mini-batch training
        val_split: `float` (default: 0.2)
            Fraction of cells held out for validation
        device: :class:`torch.device` (default: None)
            torch device
        lineage_key: deprecated
            This parameter is deprecated

        Returns
        -------
        grn_data: :class:`GRNData`
            Gene regulatory network
        """
        return infer_GRN(
            adata, TF_constrain, TF_names, species, lineage_key, layer_key,
            n_epoch, lr, l1_penalty, init_beta, min_beta, init_jacobian,
            device, early_stopping, batch_size, val_split, non_negative
        )
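
# Minimal usage sketch (illustrative, not part of the original module): fitting a GRN
# on a small synthetic AnnData with an explicit TF list. The data, gene names, and
# epoch count are toy assumptions; real usage follows the class docstring above.
def _demo_grn_fit():
    import anndata
    X = np.abs(np.random.randn(100, 4)).astype(np.float32)
    adata = anndata.AnnData(X)
    adata.var_names = ['TF_a', 'TF_b', 'g1', 'g2']
    adata.layers['velocity'] = np.random.randn(100, 4).astype(np.float32)
    grn_data = GRN().fit(adata, TF_names=['TF_a', 'TF_b'], n_epoch=100,
                         early_stopping=False, device=torch.device('cpu'))
    print(grn_data.ranked_edges.head())
    return grn_data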

def preprocess_dataset(adata, TF_names, batch_size, layer_key, early_stopping, val_split, device):
    y = torch.Tensor(adata.layers['scaled_velocity'])
    adata.var['idx'] = range(len(adata.var))
    if layer_key is None:
        X = torch.Tensor(adata.X)
    else:
        X = torch.Tensor(adata.layers[layer_key])
    tf_idx = adata.var.loc[TF_names]['idx'].to_numpy()

    # Split into training and validation sets if early stopping is enabled
    if early_stopping:
        dataset_size = X.shape[0]
        indices = list(range(dataset_size))
        split = int(np.floor(val_split * dataset_size))
        np.random.shuffle(indices)
        train_indices, val_indices = indices[split:], indices[:split]
        X_train, y_train = X[train_indices].to(device), y[train_indices].to(device)
        X_val, y_val = X[val_indices].to(device), y[val_indices].to(device)
    else:
        X_train, y_train = X.to(device), y.to(device)
        X_val, y_val = None, None
    batch_size = min(batch_size, len(X_train))
    return X, y, X_train, y_train, X_val, y_val, tf_idx


def optimize_global_GRN(adata, TF_names, layer_key=None, beta_grad=True,
                        num_epochs=10000, lr=0.01, l1_penalty=0.005,
                        init_beta=1.0, min_beta=1.0, init_jacobian=None,
                        device=torch.device('cpu'), early_stopping=False, min_epochs=500,
                        batch_size=32, val_split=0.2, non_negative=True):
    print('l1_penalty:', l1_penalty, 'min_beta:', min_beta)
    X, y, X_train, y_train, X_val, y_val, tf_idx = preprocess_dataset(
        adata, TF_names, batch_size, layer_key, early_stopping, val_split, device)
    batch_size = min(batch_size, len(X_train))
    gene_num = y.shape[1]
    tf_num = tf_idx.shape[0]

    G_hat = GeneRegulatroyModel(tf_num, gene_num, tf_idx, init_jacobian, non_negative=non_negative).to(device)
    if isinstance(init_beta, float):
        beta_hat = GeneDegradation(gene_num, init_beta, min_beta, beta_grad).to(device)
    elif isinstance(init_beta, GeneDegradation):
        beta_hat = init_beta
        beta_hat.min_beta = beta_hat.min_beta.to(device)
        beta_hat.beta = beta_hat.beta.to(device)

    optimizer_G = optim.SGD(G_hat.parameters(), lr=lr)
    if beta_grad:
        optimizer_beta = optim.SGD(beta_hat.parameters(), lr=lr)

    loss_list = []
    best_val_loss = float('inf')
    patience = 10
    patience_counter = 0
    pbar = tqdm(range(num_epochs))
    for epoch in pbar:
        G_hat.train()
        beta_hat.train()
        train_loss = 0
        permutation = torch.randperm(X_train.size()[0])
        for i in range(0, X_train.size()[0], batch_size):
            indices = permutation[i:i + batch_size]
            batch_x, batch_y = X_train[indices], y_train[indices]

            optimizer_G.zero_grad()
            if beta_grad:
                optimizer_beta.zero_grad()
            # Predicted velocity: regulatory input from TFs minus degradation
            outputs = G_hat(batch_x[:, tf_idx]) - beta_hat(batch_x)
            mse_loss = torch.mean((outputs - batch_y) ** 2)
            # L1 penalty encourages a sparse regulatory matrix
            #l1_loss = l1_penalty * (torch.norm(G_hat.linear, p=1, dim=0).sum() + torch.norm(G_hat.linear, p=1, dim=1).sum())
            l1_loss = l1_penalty * torch.norm(G_hat.linear, p=1)
            loss = mse_loss + l1_loss
            loss.backward()
            optimizer_G.step()
            if beta_grad:
                optimizer_beta.step()
            if G_hat.non_negative:
                G_hat.apply_non_negative()
            train_loss += loss.item()
        train_loss /= (X_train.size()[0] // batch_size)

        if early_stopping:
            G_hat.eval()
            beta_hat.eval()
            val_loss = 0
            with torch.no_grad():
                for i in range(0, X_val.size()[0], batch_size):
                    batch_x, batch_y = X_val[i:i + batch_size], y_val[i:i + batch_size]
                    outputs = G_hat(batch_x[:, tf_idx]) - beta_hat(batch_x)
                    mse_loss = torch.mean((outputs - batch_y) ** 2)
                    loss = mse_loss + l1_loss
                    val_loss += loss.item()
            val_loss /= (X_val.size()[0] // batch_size) if batch_size < X_val.size()[0] else 1

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
            if patience_counter >= patience and epoch > min_epochs:
                print(f'Early stopping at epoch {epoch+1}. Best validation loss: {best_val_loss:.5f}')
                break
            pbar.set_description(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
        else:
            pbar.set_description(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}')
        loss_list.append(train_loss)

    # Per-cell goodness of fit: mean squared residual of the fitted velocity model
    # (the residual uses the same G(x_tf) - beta(x) form as the training objective)
    fit_godness = torch.mean((G_hat(X[:, tf_idx].to(device)) - beta_hat(X.to(device)) - y.to(device)) ** 2, dim=-1).detach().cpu().numpy()
    return G_hat, beta_hat, fit_godness


def get_ranked_edges(jacobian, tf_names, gene_names, cutoff=1e-5):
    df = pd.DataFrame(jacobian, index=gene_names, columns=tf_names).T
    stacked = df.stack()
    values = stacked.to_numpy().flatten()
    idx = np.argsort(abs(values))[::-1]
    # Keep only edges whose absolute weight exceeds the cutoff
    num_top = np.sum(abs(jacobian) > cutoff)
    top_idx = idx[:num_top]
    gene1 = tf_names[top_idx // len(gene_names)]
    gene2 = gene_names[top_idx % len(gene_names)]
    result = pd.DataFrame([gene1, gene2, values[top_idx]], index=['Gene1', 'Gene2', 'EdgeWeight']).T
    result['absEdgeWeight'] = abs(result.EdgeWeight)
    result = result.sort_values('absEdgeWeight', ascending=False)
    return result


def infer_GRN(
    adata,
    TF_constrain=True,
    TF_names=None,
    species='human',
    lineage_key=None,
    layer_key=None,
    n_epoch=10000,
    lr=0.01,
    l1_penalty=0.005,
    init_beta=1.0,
    min_beta=1.0,
    init_jacobian=None,
    device=None,
    early_stopping=True,
    batch_size=2048,
    val_split=0.2,
    non_negative=True,
):
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    if 'velocity' not in adata.layers.keys():
        raise KeyError("Please compute velocity first and store it in adata.layers['velocity']")
    if 'gene_name' not in adata.uns.keys():
        adata.uns['gene_name'] = adata.var.index

    if TF_constrain:
        if TF_names is None:
            if species == 'human':
                TF_names = TF_human
            elif species == 'mm':
                TF_names = TF_mm
            else:
                raise NotImplementedError('The default database does not contain a TF list for species {}. Please specify the `TF_names` parameter'.format(species))
    else:
        TF_names = adata.uns['gene_name']

    TF_names = pd.Index(TF_names).intersection(pd.Index(adata.uns['gene_name']))
    adata.uns['tf_name'] = TF_names
    print("TF number: {}, {}".format(len(TF_names), TF_names))

    # Rescale velocity so that its magnitude matches the expression data
    if layer_key is None:
        if issparse(adata.X):
            adata.X = adata.X.toarray()
        scale = np.mean(adata.X[adata.X > 0]) / np.mean(abs(adata.layers['velocity']))
    else:
        if issparse(adata.layers[layer_key]):
            adata.layers[layer_key] = adata.layers[layer_key].toarray()
        scale = np.mean(adata.layers[layer_key][adata.layers[layer_key] > 0]) / np.mean(abs(adata.layers['velocity']))
    print('scale velocity with factor : {}'.format(scale))
    adata.layers['scaled_velocity'] = scale * adata.layers['velocity']

    if lineage_key is not None:
        lineages = np.unique(adata.obs[lineage_key])
        lineages = lineages[lineages != 'uncertain']
        adatas = [adata[adata.obs.loc[adata.obs[lineage_key] == lineages[i]].index] for i in range(len(lineages))]
        if not isinstance(init_beta, GeneDegradation):
            print("Using the whole dataset to estimate degradation..")
            _, beta_hat, _ = optimize_global_GRN(
                adata, TF_names=TF_names, layer_key=layer_key, beta_grad=True,
                num_epochs=n_epoch, lr=lr, l1_penalty=l1_penalty,
                init_beta=init_beta, min_beta=min_beta, init_jacobian=init_jacobian,
                device=device, early_stopping=early_stopping,
                batch_size=batch_size, val_split=val_split, non_negative=non_negative
            )
        else:
            beta_hat = init_beta
        grns = {}
        adata.obs['global_grn_fit_godness'] = np.nan
        for i in range(len(lineages)):
            print(f"Training GRN for lineage: {lineages[i]}")
            G_hat, beta_hat, fit_godness = optimize_global_GRN(
                adatas[i], TF_names=TF_names, layer_key=layer_key,
                beta_grad=False,  # do not update beta for each individual GRN
                num_epochs=n_epoch, lr=lr, l1_penalty=l1_penalty,
                init_beta=beta_hat, min_beta=min_beta, init_jacobian=init_jacobian,
                device=device, early_stopping=early_stopping,
                batch_size=batch_size, val_split=val_split, non_negative=non_negative
            )
            grn = GRNData(G_hat, beta_hat, adata.uns['tf_name'], adata.uns['gene_name'])
            grns[lineages[i]] = grn
            adata.obs.loc[adatas[i].obs.index, 'global_grn_fit_godness'] = fit_godness
        return grns
    else:
        G_hat, beta_hat, fit_godness = optimize_global_GRN(
            adata, TF_names=TF_names, layer_key=layer_key, beta_grad=True,
            num_epochs=n_epoch, lr=lr, l1_penalty=l1_penalty,
            init_beta=init_beta, min_beta=min_beta, init_jacobian=init_jacobian,
            device=device, early_stopping=early_stopping,
            batch_size=batch_size, val_split=val_split, non_negative=non_negative
        )
        adata.obs['global_grn_fit_godness'] = fit_godness
        grn = GRNData(G_hat, beta_hat, adata.uns['tf_name'], adata.uns['gene_name'])
        adata.uns['gene_name'] = np.array(adata.uns['gene_name'])
        adata.uns['tf_name'] = np.array(adata.uns['tf_name'])
        return grn
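
# Minimal sketch (illustrative, not part of the original module): the fitted model
# approximates the scaled velocity as  dx/dt ~ G @ x_TF - beta * x,  which is exactly
# the residual optimize_global_GRN minimizes (plus an L1 penalty on G). The toy arrays
# below are assumptions for demonstration.
def _demo_velocity_reconstruction():
    rng = np.random.default_rng(0)
    G = rng.random((4, 2))               # (genes, TFs) regulatory strengths
    beta = np.ones(4)                    # per-gene degradation rates
    x = rng.random(4)                    # expression of one cell
    tf_idx = np.array([0, 1])            # TF positions among the genes
    velocity_hat = G @ x[tf_idx] - beta * x
    return velocity_hat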