Source code for pygot.tools.traj.model_training

import torch
from functools import partial
from .v_centric_training import v_centric_training, GraphicalOTVelocitySampler
from .x_centric_training import x_centric_training


class ODEwrapper(torch.nn.Module):
    def __init__(self, func):
        super(ODEwrapper, self).__init__()
        self.func = func

    def forward(self, t, x):
        # NOTE: the forward pass used with torchdiffeq must be forward(self, t, x)
        if self.func.time_varying:
            if len(t.size()) == 0:
                # scalar t: broadcast it to a (batch, 1) time column
                time = t.repeat(x.size()[0], 1)
            else:
                time = t
            dxdt = self.func(torch.concat([x, time], dim=-1).float())
        else:
            dxdt = self.func(x)
        return dxdt


class ODEwrapperNoTime(torch.nn.Module):
    def __init__(self, func):
        super(ODEwrapperNoTime, self).__init__()
        self.func = func

    def forward(self, t, x):
        # NOTE: the forward pass used with torchdiffeq must be forward(self, t, x)
        dxdt = self.func(x)
        return dxdt

class MLP(torch.nn.Module):
    def __init__(self, dim, time_varying=False):
        super().__init__()
        self.time_varying = time_varying
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim + (1 if time_varying else 0), 16, bias=False),
            torch.nn.CELU(),
            torch.nn.Linear(16, 32, bias=False),
            torch.nn.CELU(),
            torch.nn.Linear(32, 16, bias=False),
            torch.nn.CELU(),
            torch.nn.Linear(16, dim, bias=False),
        )

    def forward(self, x):
        return self.net(x) 
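
# Illustrative sketch (not part of the library): the wrappers above expose the
# forward(t, x) signature that torchdiffeq expects, so an MLP velocity field can
# be integrated directly with odeint. The tensor shapes and the dim=10 setting
# below are placeholder assumptions.
#
#     from torchdiffeq import odeint
#
#     velocity_fn = ODEwrapper(MLP(dim=10, time_varying=True))
#     x0 = torch.randn(32, 10)                      # batch of initial states
#     t_span = torch.linspace(0.0, 1.0, steps=20)   # integration time points
#     trajectory = odeint(velocity_fn, x0, t_span)  # shape: (20, 32, 10)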
    



def fit_velocity_model(
    adata,
    time_key,
    embedding_key,
    landmarks=False,
    device=None,
    graph_key=None,
    n_neighbors=50,
    v_centric_iter_n=1000,
    v_centric_batch_size=256,
    add_noise=True,
    sigma=0.1,
    lr=5e-3,
    path='',
    linear=False,
    distance_metrics='SP',
    pretrained_model=None,
    x_centric=True,
    x_centric_iter_n=2000,
    x_centric_batch_size=256,
    reverse_schema=True,
    time_varying=True,
    filtered=True,
    **kwargs
):
    """Estimates velocities and fits trajectories in latent space.

    Arguments
    ---------
    adata: :class:`~anndata.AnnData`
        Annotated data matrix.
    time_key: `str`
        Name of the time label in adata.obs.
    embedding_key: `str`
        Name of the latent space to fit, in adata.obsm.
    landmarks: `bool`
        Use landmarks to approximate graphical paths.
    graph_key: `str` (default: None)
        Name of the graph to fit, in adata.obsm.
    device: :class:`~torch.device`
        Torch device.
    n_neighbors: `int` (default: 50)
        Number of neighbors in the kNN graph.
    v_centric_iter_n: `int` (default: 1000)
        Number of iterations of v-centric training.
    v_centric_batch_size: `int` (default: 256)
        Batch size of v-centric training. Note: increasing it dramatically
        increases training time due to the complexity of optimal transport.
    add_noise: `bool` (default: True)
        Assume a Gaussian distribution of velocity.
    sigma: `float` (default: 0.1)
        Variance of the Gaussian distribution of velocity.
    distance_metrics: 'SP' or 'L2' (default: 'SP')
        Distance metric used in optimal transport (shortest-path or Euclidean distance).
    pretrained_model: :class:`~torch.nn.Module` (default: None)
        Continue training from a pretrained model.
    path: `str` (default: '')
        Directory path used to store the shortest-path file.
    linear: `bool` (default: False)
        Accept linear paths or not.
    filtered: `bool` (default: True)
        Use filtered velocity or not.
    x_centric: `bool` (default: True)
        Perform x-centric training.
    x_centric_iter_n: `int` (default: 2000)
        Number of iterations of x-centric training.
    x_centric_batch_size: `int` (default: 256)
        Batch size of x-centric training.
    reverse_schema: `bool` (default: True)
        Also simulate trajectories from end to start (e.g. day7 -> day0) in x-centric training.
    time_varying: `bool` (default: True)
        Whether the neural network takes the time label as input.
    lr: `float` (default: 5e-3)
        Learning rate.

    Returns
    -------
    model: :class:`~ODEwrapper`
        Velocity model.
    history: `list`
        Training history.
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    dim = adata.obsm[embedding_key].shape[1]
    adata.obsm[embedding_key] = adata.obsm[embedding_key][:, :dim]

    # Sampler that draws velocity targets from graphical optimal-transport paths
    sp_sampler = GraphicalOTVelocitySampler(
        adata, time_key,
        graph_key=embedding_key if graph_key is None else graph_key,
        embedding_key=embedding_key,
        landmarks=landmarks,
        device=device,
        path=path,
        linear=linear,
        n_neighbors=n_neighbors,
        **kwargs
    )

    if pretrained_model is None:
        model = ODEwrapper(MLP(dim=dim, time_varying=time_varying)).to(device)
    else:
        model = pretrained_model

    optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=0.001)

    if filtered:
        sample_fn_path = partial(
            sp_sampler.filtered_sample_batch_path,
            sigma=sigma,
            batch_size=v_centric_batch_size,
            distance_metrics=distance_metrics,
            add_noise=add_noise,
        )
    else:
        sample_fn_path = partial(
            sp_sampler.sample_batch_path,
            sigma=sigma,
            batch_size=v_centric_batch_size,
            distance_metrics=distance_metrics,
            add_noise=add_noise,
        )

    # v-centric training: match model velocities to sampled path velocities
    model, history = v_centric_training(
        model, optimizer, sample_fn_path,
        iter_n=v_centric_iter_n, device=device,
    )

    if x_centric:
        # x-centric training: refine the model by simulating trajectories in state space
        model, history2 = x_centric_training(
            adata, time_key, embedding_key, model,
            reverse_schema=reverse_schema,
            batch_size=x_centric_batch_size,
            iter_n=x_centric_iter_n,
            sp_sampler=sp_sampler,
            distance_metrics=distance_metrics,
            device=device,
        )
        return model, (history, history2)

    return model, history
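
# Illustrative sketch (not part of the library): a typical call, assuming
# adata.obs carries an experimental-time column and adata.obsm holds a
# low-dimensional embedding. The key names 'time' and 'X_pca' below are
# placeholder assumptions, not values required by pygot.
#
#     model, history = fit_velocity_model(
#         adata,
#         time_key='time',
#         embedding_key='X_pca',
#     )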