import torch
import torch.nn as nn
import torch.autograd
from sklearn.metrics import pairwise_distances
from pygot.tools.traj.utils import find_neighbors
import numpy as np
from tqdm import tqdm
from scipy.optimize import minimize_scalar
from torch.distributions.normal import Normal
import torch.nn.functional as F
from sklearn.neighbors import NearestNeighbors
def dcor_test(adata, pseudotime_key, num_resamples=1):
    """Distance-covariance independence test of each gene against pseudotime.

    Writes 'pvalue' and 'statistic' columns into adata.var.
    """
    import dcor
    # adata.X may be sparse or dense
    X = adata.X.toarray() if hasattr(adata.X, 'toarray') else np.asarray(adata.X)
    y = adata.obs[pseudotime_key].to_numpy()
    res = []
    for i in tqdm(range(adata.shape[1])):
        res.append(list(dcor.independence.distance_covariance_test(
            X[:, i],
            y,
            num_resamples=num_resamples,
        )))
    adata.var[['pvalue', 'statistic']] = np.array(res)
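# A hedged usage sketch for dcor_test (not part of the public API): build a tiny
# synthetic AnnData and test every gene against a random pseudotime. It assumes
# `anndata` and `dcor` are installed and that dcor's HypothesisTest unpacks to
# (pvalue, statistic), as the loop above expects.
def _demo_dcor_test():
    import anndata as ad
    toy = ad.AnnData(np.random.rand(100, 5).astype(np.float32))
    toy.obs['pseudotime'] = np.random.rand(100)
    dcor_test(toy, pseudotime_key='pseudotime', num_resamples=10)
    return toy.var[['pvalue', 'statistic']]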
def strings_to_tensor(string_list):
    """Encode a list of strings as a LongTensor of integer codes.

    Sorting the unique strings makes the encoding deterministic across runs.
    """
    unique_strings = sorted(set(string_list))
    string_to_index = {s: idx for idx, s in enumerate(unique_strings)}
    integer_list = [string_to_index[s] for s in string_list]
    tensor = torch.tensor(integer_list, dtype=torch.long)
    return tensor, string_to_index
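# Tiny illustration of strings_to_tensor (labels are arbitrary):
def _demo_strings_to_tensor():
    tensor, mapping = strings_to_tensor(['B', 'A', 'B', 'C'])
    # With the sorted encoding above: mapping == {'A': 0, 'B': 1, 'C': 2}
    # and tensor == torch.tensor([1, 0, 1, 2])
    return tensor, mapping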
def std_bound(x):
    # Winsorize: clip values to within three standard deviations of the mean
    upper_bound = np.mean(x) + 3 * np.std(x)
    lower_bound = np.mean(x) - 3 * np.std(x)
    x[x > upper_bound] = upper_bound
    x[x < lower_bound] = lower_bound
    return x
def normal_sample(mu, logvar, epsilon=1e-6):
    # `logvar` is treated as a pre-softplus scale parameter, not a log-variance
    std = F.softplus(logvar) + epsilon
    dist = Normal(mu, std)
    # Reparameterized sample, so gradients flow through mu and logvar
    return dist.rsample()
def normal_log_likelihood(x, mu, logvar, epsilon=1e-6):
    # Same parameterization as normal_sample: softplus maps logvar to a positive std
    std = F.softplus(logvar) + epsilon
    normal_dist = Normal(mu, std)
    log_prob_z = normal_dist.log_prob(x)
    return log_prob_z
# Kumaraswamy distribution: pdf f(x; a, b) = a*b*x^(a-1)*(1 - x^a)^(b-1) on (0, 1)
def log_likelihood(x, a, b):
    if x <= 0 or x >= 1:
        return -np.inf
    return np.log(a) + np.log(b) + (a - 1) * np.log(x) + (b - 1) * np.log(1 - x**a)
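# Quick numerical sanity check for the Kumaraswamy log-likelihood above: the
# density exp(log_likelihood) should integrate to ~1 over (0, 1). Illustrative
# only; the (a, b) values are arbitrary.
def _check_kumaraswamy_normalization(a=2.0, b=3.0, n_grid=10000):
    xs = np.linspace(1e-6, 1 - 1e-6, n_grid)
    density = np.exp([log_likelihood(x, a, b) for x in xs])
    # Mean over a uniform grid on (0, 1) approximates the integral
    return np.mean(density)  # expected to be close to 1.0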
# RealNVP implemented by Jakub M. Tomczak
class RealNVP(nn.Module):
def __init__(self, nets, nett, num_flows, prior, D=2, dequantization=True):
super(RealNVP, self).__init__()
self.dequantization = dequantization
self.prior = prior
self.t = torch.nn.ModuleList([nett() for _ in range(num_flows)])
self.s = torch.nn.ModuleList([nets() for _ in range(num_flows)])
self.num_flows = num_flows
self.D = D
def set_prior(self, prior):
self.prior = prior
def pad_to_even(self, x):
if x.shape[1] % 2 != 0:
# Padding one dimension with 0 to make it even
padding = (0, 1) # Pad on the right side along dim=1
x = torch.nn.functional.pad(x, padding, mode='constant', value=0)
return x
    def coupling(self, x, index, forward=True):
        # x: input, either data (for the first transformation) or the output of the previous transformation
        # index: index of the coupling layer
        # forward: True maps x -> z (normalizing direction), False maps z -> x (generative direction)
        (xa, xb) = torch.chunk(x, 2, 1)
        s = self.s[index](xa)
        t = self.t[index](xa)
        if forward:
            # normalizing direction: yb = (xb - t) * exp(-s)
            yb = (xb - t) * torch.exp(-s)
        else:
            # generative direction: yb = exp(s) * xb + t
            yb = torch.exp(s) * xb + t
        return torch.cat((xa, yb), 1), s
def permute(self, x):
return x.flip(1)
def f(self, x):
x = self.pad_to_even(x)
log_det_J, z = x.new_zeros(x.shape[0]), x
for i in range(self.num_flows):
z, s = self.coupling(z, i, forward=True)
z = self.permute(z)
log_det_J = log_det_J - s.sum(dim=1)
return z, log_det_J
def log_prob(self, x):
z, log_det_J = self.f(x)
return self.prior.log_prob(z) + log_det_J
def f_inv(self, z):
x = z
for i in reversed(range(self.num_flows)):
x = self.permute(x)
x, _ = self.coupling(x, i, forward=False)
return x
def forward(self, x, reduction='avg'):
z, log_det_J = self.f(x)
if reduction == 'sum':
return -(self.prior.log_prob(z) + log_det_J).sum()
else:
return -(self.prior.log_prob(z) + log_det_J).mean()
    def sample(self, batchSize):
        # Draw from the base distribution and push through the inverse flow
        z = self.prior.sample((batchSize,))
        x = self.f_inv(z)
        return x.view(-1, self.D)
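# Minimal sketch of standalone RealNVP usage, assuming a 2-D standard normal
# prior and scale/translation factories analogous to those in DensityModel below.
def _demo_realnvp(D=2, num_flows=4, M=64):
    nets = lambda: nn.Sequential(nn.Linear(D // 2, M), nn.LeakyReLU(),
                                 nn.Linear(M, D // 2), nn.Tanh())
    nett = lambda: nn.Sequential(nn.Linear(D // 2, M), nn.LeakyReLU(),
                                 nn.Linear(M, D // 2))
    prior = torch.distributions.MultivariateNormal(torch.zeros(D), torch.eye(D))
    flow = RealNVP(nets, nett, num_flows, prior, D=D, dequantization=False)
    x = torch.randn(8, D)
    nll = flow(x)             # mean negative log-likelihood of the batch
    samples = flow.sample(8)  # 8 draws pushed through the inverse flow
    return nll, samples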
# Pearson correlation between x and y, where y is assumed to be standardized already
def torch_pearsonr_fix_y(x, y, dim=1):
    x = x - torch.mean(x, dim=dim)[:, None]
    x = x / (torch.std(x, dim=dim) + 1e-9)[:, None]
    return torch.mean(x * y, dim=dim)  # (D,)
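# Illustrative check: with y pre-standardized, torch_pearsonr_fix_y agrees with
# np.corrcoef up to the (n-1)/n factor introduced by torch.std's unbiased estimator.
def _check_pearsonr_fix_y(n=50):
    x = torch.randn(4, n)
    y = torch.randn(4, n)
    y = (y - y.mean(dim=1, keepdim=True)) / y.std(dim=1, keepdim=True)
    r = torch_pearsonr_fix_y(x, y)
    ref = np.array([np.corrcoef(x[i].numpy(), y[i].numpy())[0, 1] for i in range(4)])
    assert np.allclose(r.numpy(), ref * (n - 1) / n, atol=1e-4)
    return r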
# Neural Network for p(x,t)
class DensityModel(nn.Module):
    def __init__(self, dim, num_flows=8, M=256):
super(DensityModel, self).__init__()
block_dim = dim // 2
block_dim = block_dim + 1 if dim % 2 != 0 else block_dim
# scale (s) network
nets = lambda: nn.Sequential(nn.Linear(block_dim, M), nn.LeakyReLU(),
nn.Linear(M, M), nn.LeakyReLU(),
nn.Linear(M, block_dim), nn.Tanh())
# translation (t) network
nett = lambda: nn.Sequential(nn.Linear(block_dim, M), nn.LeakyReLU(),
nn.Linear(M, M), nn.LeakyReLU(),
nn.Linear(M, block_dim))
self.dim = dim
# Prior (a.k.a. the base distribution): Gaussian
prior = torch.distributions.MultivariateNormal(torch.zeros(dim), torch.eye(dim))
# Init RealNVP
self.px = RealNVP(nets, nett, num_flows, prior, D=dim, dequantization=False)
        # p(t|x) head: maps the flow latent z (padded to even size if dim is odd)
        # to the two parameters of a Normal distribution over t
        self.ptx = nn.Sequential(
            nn.Linear(dim + 1 if dim % 2 != 0 else dim, 64),
            nn.CELU(),
            nn.Linear(64, 64),
            nn.CELU(),
            nn.Linear(64, 2),
        )
    def to(self, device):
        # Rebuild the prior on the target device before moving the module parameters
        prior = torch.distributions.MultivariateNormal(torch.zeros(self.dim).to(device), torch.eye(self.dim).to(device))
        self.px.set_prior(prior)
        return super().to(device)
def pearson_loss(self, x_noise, x_neigh, y, reduction='avg', corr_cutoff=0.3):
n_neighbors = x_neigh.shape[1]
expectation_center = self.sample_t_given_x(x_noise).flatten()
expectation_nn = self.sample_t_given_x(x_neigh.reshape(-1, x_neigh.shape[-1]))
delta_t = expectation_nn.reshape(x_noise.shape[0], n_neighbors) - expectation_center[:,None]
corr = torch_pearsonr_fix_y(delta_t, y)
        # Cells whose correlation already exceeds the cutoff are zeroed out so
        # the loss concentrates on poorly aligned cells
        mask = corr > corr_cutoff
        corr[mask] = 0.
if sum(mask) == len(corr):
return torch.tensor([0.]), 0.
if reduction == 'avg':
return -corr.sum() / (len(corr) - sum(mask)), sum(mask) / len(corr)
else:
return -corr.sum(), sum(mask)
def sample_t_given_x(self, x):
z, _ = self.px.f(x)
pt_x = self.ptx(z)
pt_x_a, pt_x_b = pt_x[:,0][:,None], pt_x[:,1][:,None]
return normal_sample(pt_x_a, pt_x_b)
def log_prob_t_x(self, x, t, reduction='sum'):
z, _ = self.px.f(x)
pt_x = self.ptx(z)
pt_x_a, pt_x_b = pt_x[:,0][:,None], pt_x[:,1][:,None]
log_pt_x = normal_log_likelihood(t, pt_x_a, pt_x_b).flatten()
if reduction == 'avg':
return log_pt_x.mean()
elif reduction == 'sum':
return log_pt_x.sum()
else:
return log_pt_x
    def var_t_given_x(self, x):
        z, _ = self.px.f(x)
        pt_x = self.ptx(z)
        pt_x_a, pt_x_b = pt_x[:, 0][:, None], pt_x[:, 1][:, None]
        # softplus maps the raw head output to the positive scale of p(t|x)
        return F.softplus(pt_x_b)
    def estimate_t(self, x, mode='mle'):
        z, _ = self.px.f(x)
        pt_x = self.ptx(z)
        pt_x_a, pt_x_b = pt_x[:, 0][:, None], pt_x[:, 1][:, None]
        if mode == 'mle':
            # Treat the two head outputs as Kumaraswamy (a, b) parameters and
            # maximize the likelihood on (0, 1) for each cell separately
            pt_x_a = pt_x_a.detach().cpu().numpy().flatten()
            pt_x_b = pt_x_b.detach().cpu().numpy().flatten()
            return np.array([
                minimize_scalar(lambda t: -log_likelihood(t, pt_x_a[i], pt_x_b[i]),
                                bounds=(0, 1), method='bounded').x
                for i in tqdm(range(len(pt_x_a)))
            ])
        # mode='mean': return the Normal mean of p(t|x) as a 1-D tensor
        return pt_x_a.flatten()
def log_prob_x(self, x, reduction='sum'):
log_px = self.px.log_prob(x)
if reduction == 'avg':
return log_px.mean()
elif reduction == 'sum':
return log_px.sum()
else:
return log_px
def joint_log_prob_xt(self, x, t, reduction='sum'):
z, log_det_J = self.px.f(x)
log_px = self.px.prior.log_prob(z) + log_det_J
pt_x = self.ptx(z)
pt_x_a, pt_x_b = pt_x[:,0][:,None], pt_x[:,1][:,None]
log_pt_x = normal_log_likelihood(t, pt_x_a, pt_x_b).flatten()
if reduction == 'avg':
return (log_px + log_pt_x).mean()
elif reduction == 'sum':
return (log_px + log_pt_x).sum()
else:
return log_px + log_pt_x
def cosine(a, b):
return np.sum(a * b, axis=-1) / (np.linalg.norm(a, axis=-1)*np.linalg.norm(b, axis=-1))
def get_pair_wise_neighbors(X, n_neighbors=30):
    """Compute nearest neighbors, excluding each cell itself

    Parameters
    ----------
    X: all cell embeddings (n, m)
    n_neighbors: number of neighbors

    Returns
    -------
    nn_t_idx: neighbor indices (n, n_neighbors)
    """
    N_cell = X.shape[0]
    if N_cell < 3000:
        # Small datasets: brute-force distance matrix; column 0 is the cell itself
        ori_dist = pairwise_distances(X, X)
        nn_t_idx = np.argsort(ori_dist, axis=1)[:, 1:n_neighbors + 1]
    else:
        # Larger datasets: tree-based search; request one extra neighbor and drop self
        nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(X)
        _, nn_t_idx = nbrs.kneighbors(X)
        nn_t_idx = nn_t_idx[:, 1:]
    return nn_t_idx
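# Hedged example for get_pair_wise_neighbors: on a toy embedding, each row of
# the returned matrix indexes the nearest neighbors of the matching cell.
def _demo_neighbors():
    X = np.random.rand(200, 10)
    nn_t_idx = get_pair_wise_neighbors(X, n_neighbors=5)
    assert nn_t_idx.shape == (200, 5)
    return nn_t_idx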
class ProbabilityModel:
    """Probability model for pseudotime estimation

    Example:
    ----------
    ::

        # Assume the velocities have already been fitted in PCA space
        embedding_key = 'X_pca'
        velocity_key = 'velocity_pca'

        # Fit the probability model
        pm = pygot.tl.analysis.ProbabilityModel()
        pm.fit(adata, embedding_key=embedding_key, velocity_key=velocity_key)

        # Estimate pseudotime and per-cell time variance
        adata.obs['pseudotime'] = pm.estimate_pseudotime(adata)  # pseudotime
        adata.obs['var'] = pm.estimate_variance(adata)  # variance of time
    """
def __init__(self, device=None):
"""Init model
Arguments:
---------
device: :class:`~torch.device`
torch device
"""
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
print('Device:', device)
def to(self, device):
self.density_model.to(device)
self.device = device
def fit(
self,
adata,
embedding_key,
velocity_key,
n_neighbors=30,
n_iters=500,
mini_batch=True,
        batch_size=512,
):
"""fit model
Arguments:
---------
adata: :class:`~anndata.AnnData`
Annotated data matrix.
embedding_key: `str` (default: None)
Name of latent space, in adata.obsm.
velocity: `str` (default: None)
Name of latent velocity, in adata.obsm. Use to do variantional inference of conditonal time distribution if it offers.
time_key: `str` (default: None)
Name of time label, in adata.obs. Use as addition information for conditonal time distribution fitting if it offers.
n_neighbors: `int` (default: 30)
Number of neighbors of cell
n_iters: `float` (default: 500)
Number of training iterations
mini_batch: `bool` (default: True)
Use mini-batch training or not
batch_size: `int` (default: 512)
Number of batch size
"""
self.density_model = DensityModel(adata.obsm[embedding_key].shape[1]).to(self.device)
self.x = torch.tensor(adata.obsm[embedding_key], requires_grad=True).float().to(self.device)
if 'distances' in adata.obsp.keys():
self.nn_t_idx = find_neighbors(adata.obsp['distances'], directed=True)
else:
self.nn_t_idx = get_pair_wise_neighbors(adata.obsm[embedding_key], n_neighbors=n_neighbors)
        # Displacement vectors from each cell to its neighbors: (n, n_neighbors, dim)
        self.v_hat = adata.obsm[embedding_key][self.nn_t_idx.flatten()].reshape(self.nn_t_idx.shape[0], self.nn_t_idx.shape[1], -1) - adata.obsm[embedding_key][:, None, :]
self.velocity_key = velocity_key
self.embedding_key = embedding_key
        if len(adata) > 5000 and self.device == 'cpu' and not mini_batch:
            print('Large dataset on a CPU device; consider setting mini_batch=True')
        with torch.no_grad():
            # Project each neighbor displacement onto the velocity, normalized by the
            # squared speed: the first-order (Taylor) estimate of the time difference
            v = adata.obsm[self.velocity_key]
            norm_cos_sim = (v[:, None, :] * self.v_hat)
            norm_cos_sim = norm_cos_sim.sum(axis=-1) / (np.linalg.norm(v, axis=1) ** 2)[:, None]
            norm_cos_sim = torch.tensor(norm_cos_sim).to(self.device)
        optimizer = torch.optim.Adamax(self.density_model.parameters(), lr=1e-3)
        # Guard against batch sizes larger than the dataset
        batch_size = min(batch_size, len(adata))
        pbar = tqdm(range(n_iters))
for i in pbar:
if not mini_batch:
batch_idx = list(range(len(self.x)))
sample_x = self.x
else:
batch_idx = np.random.choice(range(len(adata)), size=batch_size, replace=False)
sample_x = self.x[batch_idx]
x_noise = sample_x + torch.randn_like(sample_x) * 0.05
            sub_nn_t_idx = self.nn_t_idx[batch_idx]
            # Map global cell indices to positions within the unique neighbor set
            sample_idx = np.unique(sub_nn_t_idx)
            mapper = (np.ones(len(self.x)) * -1).astype(int)
            mapper[sample_idx] = range(len(sample_idx))
expectation_center = self.density_model.sample_t_given_x(x_noise).flatten()
expectation_nn = self.density_model.sample_t_given_x(self.x[sample_idx]).flatten()
delta_t = expectation_nn[mapper[sub_nn_t_idx.flatten()]].reshape(sub_nn_t_idx.shape[0], sub_nn_t_idx.shape[1]) - expectation_center[:,None]
taylor_loss = torch.mean((delta_t - norm_cos_sim[batch_idx])**2)
loss = taylor_loss
pbar.set_description("Taylor Loss {:.4f}".format(loss.item()))
optimizer.zero_grad()
            if loss.grad_fn is not None:
loss.backward()
optimizer.step()
del self.x
del self.nn_t_idx
del self.v_hat
@torch.no_grad()
    def estimate_pseudotime(self, adata, mode='mean'):
        """Estimate the pseudotime

        Arguments:
        ---------
        adata: :class:`~anndata.AnnData`
            Annotated data matrix.
        mode: `str` (default: 'mean')
            'mean' returns the mean of p(t|x); 'mle' maximizes a per-cell
            Kumaraswamy likelihood on (0, 1).

        Returns
        -------
        :math:`t^*|x`: :class:`~np.ndarray`
            pseudotime of cells
        """
expectation = self.density_model.estimate_t(torch.tensor(adata.obsm[self.embedding_key].copy()).float().to(self.device), mode=mode)
if isinstance(expectation, torch.Tensor):
expectation = expectation.detach().cpu().numpy()
return expectation
@torch.no_grad()
    def estimate_variance(self, adata):
        """Estimate the variance of the pseudotime

        Arguments:
        ---------
        adata: :class:`~anndata.AnnData`
            Annotated data matrix.

        Returns
        -------
        var: :class:`~np.ndarray`
            variance of cell time
        """
        var = self.density_model.var_t_given_x(torch.tensor(adata.obsm[self.embedding_key].copy()).float().to(self.device)).detach().cpu().numpy()
        return var.flatten()
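# Minimal end-to-end sketch on synthetic data (hedged: real use follows the class
# docstring above, with a velocity field fitted beforehand, e.g. 'velocity_pca').
# The embedding and velocity here are random, so the learned pseudotime is not
# meaningful; this only exercises the API.
def _demo_probability_model():
    import anndata as ad
    rng = np.random.default_rng(0)
    emb = rng.normal(size=(300, 10)).astype(np.float32)
    vel = rng.normal(size=(300, 10)).astype(np.float32)
    adata = ad.AnnData(np.abs(emb), obsm={'X_pca': emb, 'velocity_pca': vel})
    pm = ProbabilityModel(device='cpu')
    pm.fit(adata, embedding_key='X_pca', velocity_key='velocity_pca',
           n_iters=10, batch_size=128)
    adata.obs['pseudotime'] = pm.estimate_pseudotime(adata, mode='mean')
    adata.obs['var'] = pm.estimate_variance(adata)
    return adata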