diff --git a/Singlecell_multi_omics/.keep b/Singlecell_multi_omics/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Singlecell_multi_omics/src/.keep b/Singlecell_multi_omics/src/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Singlecell_multi_omics/src/core/.keep b/Singlecell_multi_omics/src/core/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Singlecell_multi_omics/src/core/__init__.py b/Singlecell_multi_omics/src/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87d16c6e6035f9afaa63f1390701b1acb7489544 --- /dev/null +++ b/Singlecell_multi_omics/src/core/__init__.py @@ -0,0 +1,17 @@ +from .posterior import Posterior +from .trainer import Trainer +from .inference import UnsupervisedTrainer, AdapterTrainer +from .annotation import ( + ClassifierTrainer, +) +from .multi_inference import MultiPosterior, MultiTrainer + +__all__ = [ + "Trainer", + "Posterior", + "UnsupervisedTrainer", + "AdapterTrainer", + "ClassifierTrainer", + "MultiPosterior", + "MultiTrainer" +] diff --git a/Singlecell_multi_omics/src/core/annotation.py b/Singlecell_multi_omics/src/core/annotation.py new file mode 100644 index 0000000000000000000000000000000000000000..dddb7fc9d3f09556efc230ac65e0757501ff3e54 --- /dev/null +++ b/Singlecell_multi_omics/src/core/annotation.py @@ -0,0 +1,317 @@ +from collections import namedtuple + +import numpy as np +import logging + +from sklearn import neighbors +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import GridSearchCV +from sklearn.neighbors import KNeighborsClassifier +from sklearn.svm import SVC + +import torch +from torch.nn import functional as F + +from scMVP.inference import Posterior +from scMVP.inference import Trainer +from scMVP.inference.inference import UnsupervisedTrainer +from scMVP.inference.posterior import unsupervised_clustering_accuracy + +logger = logging.getLogger(__name__) + + +class AnnotationPosterior(Posterior): + def __init__(self, *args, model_zl=False, **kwargs): + super().__init__(*args, **kwargs) + self.model_zl = model_zl + + def accuracy(self): + model, cls = ( + (self.sampling_model, self.model) + if hasattr(self, "sampling_model") + else (self.model, None) + ) + acc = compute_accuracy(model, self, classifier=cls, model_zl=self.model_zl) + logger.debug("Acc: %.4f" % (acc)) + return acc + + accuracy.mode = "max" + + @torch.no_grad() + def hierarchical_accuracy(self): + all_y, all_y_pred = self.compute_predictions() + acc = np.mean(all_y == all_y_pred) + + all_y_groups = np.array([self.model.labels_groups[y] for y in all_y]) + all_y_pred_groups = np.array([self.model.labels_groups[y] for y in all_y_pred]) + h_acc = np.mean(all_y_groups == all_y_pred_groups) + + logger.debug("Hierarchical Acc : %.4f\n" % h_acc) + return acc + + accuracy.mode = "max" + + @torch.no_grad() + def compute_predictions(self, soft=False): + """ + :return: the true labels and the predicted labels + :rtype: 2-tuple of :py:class:`numpy.int32` + """ + model, cls = ( + (self.sampling_model, self.model) + if hasattr(self, "sampling_model") + else (self.model, None) + ) + return compute_predictions( + model, self, classifier=cls, soft=soft, model_zl=self.model_zl + ) + + @torch.no_grad() + def unsupervised_classification_accuracy(self): + all_y, all_y_pred = 
self.compute_predictions() + uca = unsupervised_clustering_accuracy(all_y, all_y_pred)[0] + logger.debug("UCA : %.4f" % (uca)) + return uca + + unsupervised_classification_accuracy.mode = "max" + + @torch.no_grad() + def nn_latentspace(self, posterior): + data_train, _, labels_train = self.get_latent() + data_test, _, labels_test = posterior.get_latent() + nn = KNeighborsClassifier() + nn.fit(data_train, labels_train) + score = nn.score(data_test, labels_test) + return score + + +class ClassifierTrainer(Trainer): + r"""The ClassifierTrainer class for training a classifier either on the raw data or on top of the latent + space of another model (VAE, VAEC, SCANVI). + + Args: + :model: A model instance from class ``VAE``, ``VAEC``, ``SCANVI`` + :gene_dataset: A gene_dataset instance like ``CortexDataset()`` + :train_size: The train size, either a float between 0 and 1 or an integer for the number of training samples + to use. Default: ``0.8``. + :test_size: The test size, either a float between 0 and 1 or an integer for the number of test samples + to use. Default: ``None``. + :sampling_model: Model with z_encoder with which to first transform data. + :sampling_zl: Transform data with sampling_model z_encoder and l_encoder and concat. + :\**kwargs: Other keyword arguments from the general Trainer class. + + + Examples: + >>> gene_dataset = CortexDataset() + >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False, + ... n_labels=gene_dataset.n_labels) + + >>> classifier = Classifier(vae.n_latent, n_labels=gene_dataset.n_labels) + >>> trainer = ClassifierTrainer(classifier, gene_dataset, sampling_model=vae, train_size=0.5) + >>> trainer.train(n_epochs=20, lr=1e-3) + >>> trainer.test_set.accuracy() + """ + + def __init__( + self, + *args, + train_size=0.8, + test_size=None, + sampling_model=None, + sampling_zl=False, + use_cuda=True, + **kwargs + ): + self.sampling_model = sampling_model + self.sampling_zl = sampling_zl + super().__init__(*args, use_cuda=use_cuda, **kwargs) + self.train_set, self.test_set, self.validation_set = self.train_test_validation( + self.model, + self.gene_dataset, + train_size=train_size, + test_size=test_size, + type_class=AnnotationPosterior, + ) + self.train_set.to_monitor = ["accuracy"] + self.test_set.to_monitor = ["accuracy"] + self.validation_set.to_monitor = ["accuracy"] + self.train_set.model_zl = sampling_zl + self.test_set.model_zl = sampling_zl + self.validation_set.model_zl = sampling_zl + + @property + def posteriors_loop(self): + return ["train_set"] + + def __setattr__(self, key, value): + if key in ["train_set", "test_set"]: + value.sampling_model = self.sampling_model + super().__setattr__(key, value) + + def loss(self, tensors_labelled): + x, _, _, _, labels_train = tensors_labelled + if self.sampling_model: + if hasattr(self.sampling_model, "classify"): + return F.cross_entropy( + self.sampling_model.classify(x), labels_train.view(-1) + ) + else: + if self.sampling_model.log_variational: + x = torch.log(1 + x) + if self.sampling_zl: + x_z = self.sampling_model.z_encoder(x)[0] + x_l = self.sampling_model.l_encoder(x)[0] + x = torch.cat((x_z, x_l), dim=-1) + else: + x = self.sampling_model.z_encoder(x)[0] + return F.cross_entropy(self.model(x), labels_train.view(-1)) + + @torch.no_grad() + def compute_predictions(self, soft=False): + """ + :return: the true labels and the predicted labels + :rtype: 2-tuple of :py:class:`numpy.int32` + """ + model, cls = ( + (self.sampling_model, self.model) + if hasattr(self, "sampling_model")
+ else (self.model, None) + ) + full_set = self.create_posterior(type_class=AnnotationPosterior) + return compute_predictions( + model, full_set, classifier=cls, soft=soft, model_zl=self.sampling_zl + ) + + +@torch.no_grad() +def compute_predictions( + model, data_loader, classifier=None, soft=False, model_zl=False +): + all_y_pred = [] + all_y = [] + + for i_batch, tensors in enumerate(data_loader): + sample_batch, _, _, _, labels = tensors + all_y += [labels.view(-1).cpu()] + + if hasattr(model, "classify"): + y_pred = model.classify(sample_batch) + elif classifier is not None: + # Then we use the specified classifier + if model is not None: + if model.log_variational: + sample_batch = torch.log(1 + sample_batch) + if model_zl: + sample_z = model.z_encoder(sample_batch)[0] + sample_l = model.l_encoder(sample_batch)[0] + sample_batch = torch.cat((sample_z, sample_l), dim=-1) + else: + sample_batch, _, _ = model.z_encoder(sample_batch) + y_pred = classifier(sample_batch) + else: # The model is the raw classifier + y_pred = model(sample_batch) + + if not soft: + y_pred = y_pred.argmax(dim=-1) + + all_y_pred += [y_pred.cpu()] + + all_y_pred = np.array(torch.cat(all_y_pred)) + all_y = np.array(torch.cat(all_y)) + + return all_y, all_y_pred + + +@torch.no_grad() +def compute_accuracy(vae, data_loader, classifier=None, model_zl=False): + all_y, all_y_pred = compute_predictions( + vae, data_loader, classifier=classifier, model_zl=model_zl + ) + return np.mean(all_y == all_y_pred) + + +Accuracy = namedtuple( + "Accuracy", ["unweighted", "weighted", "worst", "accuracy_classes"] +) + + +@torch.no_grad() +def compute_accuracy_tuple(y, y_pred): + y = y.ravel() + n_labels = len(np.unique(y)) + classes_probabilities = [] + accuracy_classes = [] + for cl in range(n_labels): + idx = y == cl + classes_probabilities += [np.mean(idx)] + accuracy_classes += [ + np.mean((y[idx] == y_pred[idx])) if classes_probabilities[-1] else 0 + ] + # This is also referred to as the "recall": p = n_true_positive / (n_false_negative + n_true_positive) + # ( We could also compute the "precision": p = n_true_positive / (n_false_positive + n_true_positive) ) + accuracy_named_tuple = Accuracy( + unweighted=np.dot(accuracy_classes, classes_probabilities), + weighted=np.mean(accuracy_classes), + worst=np.min(accuracy_classes), + accuracy_classes=accuracy_classes, + ) + return accuracy_named_tuple + + +@torch.no_grad() +def compute_accuracy_nn(data_train, labels_train, data_test, labels_test, k=5): + clf = neighbors.KNeighborsClassifier(k, weights="distance") + return compute_accuracy_classifier( + clf, data_train, labels_train, data_test, labels_test + ) + + +@torch.no_grad() +def compute_accuracy_classifier(clf, data_train, labels_train, data_test, labels_test): + clf.fit(data_train, labels_train) + # Predicting the labels + y_pred_test = clf.predict(data_test) + y_pred_train = clf.predict(data_train) + + return ( + ( + compute_accuracy_tuple(labels_train, y_pred_train), + compute_accuracy_tuple(labels_test, y_pred_test), + ), + y_pred_test, + ) + + +@torch.no_grad() +def compute_accuracy_svc( + data_train, + labels_train, + data_test, + labels_test, + param_grid=None, + verbose=0, + max_iter=-1, +): + if param_grid is None: + param_grid = [ + {"C": [1, 10, 100, 1000], "kernel": ["linear"]}, + {"C": [1, 10, 100, 1000], "gamma": [0.001, 0.0001], "kernel": ["rbf"]}, + ] + svc = SVC(max_iter=max_iter) + clf = GridSearchCV(svc, param_grid, verbose=verbose) + return compute_accuracy_classifier( + clf, data_train, labels_train, 
data_test, labels_test + ) + + +@torch.no_grad() +def compute_accuracy_rf( + data_train, labels_train, data_test, labels_test, param_grid=None, verbose=0 +): + if param_grid is None: + param_grid = {"max_depth": np.arange(3, 10), "n_estimators": [10, 50, 100, 200]} + rf = RandomForestClassifier(max_depth=2, random_state=0) + clf = GridSearchCV(rf, param_grid, verbose=verbose) + return compute_accuracy_classifier( + clf, data_train, labels_train, data_test, labels_test + ) diff --git a/Singlecell_multi_omics/src/core/autotune.py b/Singlecell_multi_omics/src/core/autotune.py new file mode 100644 index 0000000000000000000000000000000000000000..33c8922a27cf126f91209657c75ae74ec934ff2a --- /dev/null +++ b/Singlecell_multi_omics/src/core/autotune.py @@ -0,0 +1,1301 @@ +import datetime +import logging +import multiprocessing +import os +import pickle +import threading +import time +from collections import defaultdict +from functools import partial, wraps +from logging.handlers import QueueListener, QueueHandler +from queue import Empty +from subprocess import Popen +from typing import Any, Callable, Dict, List, TextIO, Type, Union + +import numpy as np +import pymongo +import torch +import tqdm +from hyperopt import fmin, tpe, Trials, hp, STATUS_OK, STATUS_FAIL +from hyperopt.mongoexp import ( + as_mongo_str, + MongoJobs, + MongoTrials, + MongoWorker, + ReserveTimeout, +) + +from scMVP._settings import autotune_formatter +from scMVP.dataset import DownloadableDataset, GeneExpressionDataset +from scMVP.models import VAE +from . import Trainer, UnsupervisedTrainer + +# spawning is required for processes relying on cuda, and for windows +multiprocessing.set_start_method("spawn", force=True) + +# instantiate logger, handler and formatter +# logger_all is used to send *all* autotune logs to a logfile +logger_all = logging.getLogger(__name__ + ".all") +logger_all.setLevel(logging.DEBUG) +logger = logging.getLogger(__name__) +# instantiate hyperopt and autotune file handlers as global variables for clean up +fh_autotune = None +fh_hyperopt = None + + +class FminTimeoutError(Exception): + """Thrown if fmin process hasn't finished in the allotted + time after all workers have died. + """ + + +class DispatchHandler(logging.Handler): + """A simple dispatcher for logging events. + + It dispatches events to loggers based on the name in the received record, + which then get dispatched, by the logging system, to the handlers, configured for those loggers. + """ + + def emit(self, record: logging.LogRecord): + record_logger = logging.getLogger(record.name) + if record.levelno >= record_logger.level: + record_logger.handle(record) + + +class StoppableThread(threading.Thread): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.stop_event = threading.Event() + + def stop(self): + self.stop_event.set() + + +# register running process and open files to terminate/close at exit +started_processes: List[Union[multiprocessing.Process, Popen, QueueListener]] = [] +started_threads: List[StoppableThread] = [] +started_queues: List[multiprocessing.Queue] = [] +open_files: List[TextIO] = [] + + +# cleanup helpers +def _cleanup_processes_files(): + """Cleanup function, starts with latest processes/files. + + Terminates processes, sets stop events to stop threads, closes open files. 
+ """ + logger_all.info("Cleaning up") + logger_all.debug("Cleaning up: closing files.") + for f in open_files[::-1]: + if not f.closed: + f.close() + logger_all.debug("Cleaning up: closing queues.") + for q in started_queues: + q.close() + logger_all.debug("Cleaning up: setting cleanup_event and joining threads.") + for t in started_threads[::-1]: + if t.is_alive(): + logger_all.debug("Closing Thread {}.".format(t.name)) + t.stop_event.set() + t.join() + else: + logger_all.debug("Thread {} already done.".format(t.name)) + logger_all.debug("Cleaning up: terminating processes.") + for p in started_processes[::-1]: + if isinstance(p, Popen): + if p.poll() is not None: + logger_all.debug("Terminating mongod process.") + p.terminate() + p.wait() + else: + logger_all.debug("mongodd process already done.") + if isinstance(p, multiprocessing.Process): + if p.is_alive(): + logger_all.debug("Terminating Process {}.".format(p.name)) + p.terminate() + else: + logger_all.debug("Process {} already done.".format(p.name)) + if isinstance(p, QueueListener): + if p._thread is not None and not p.queue._closed: + p.stop() + + +def _cleanup_logger(): + """Removes added handlers.""" + logger_all.debug("Cleaning up: removing added logging handler.") + hp_logger = logging.getLogger("hyperopt") + for handler in hp_logger.handlers: + if handler == fh_hyperopt: + logger_all.debug("Cleaning up: removing hyperopt FileHandler.") + hp_logger.removeHandler(fh_hyperopt) + break + for handler in logger_all.handlers: + if handler == fh_autotune: + logger_all.debug("Cleaning up: removing autotune FileHandler.") + logger_all.removeHandler(fh_autotune) + + +def _cleanup_decorator(func: Callable): + """Decorates top-level calls in order to launch cleanup when an Exception is caught.""" + + @wraps(func) + def decorated(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + logger_all.exception( + "Caught {exception} in {func}, starting cleanup.".format( + exception=e.args, func=func.__name__ + ), + exc_info=True, + ) + _cleanup_processes_files() + _cleanup_logger() + raise + + return decorated + + +def _error_logger_decorator(func: Callable): + """Decorates top-level calls in order to launch cleanup when an Exception is caught.""" + + @wraps(func) + def decorated(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + logger_all.exception( + "Caught {exception} in {func}, starting cleanup.".format( + exception=e.args, func=func.__name__ + ), + exc_info=True, + ) + raise + + return decorated + + +def configure_asynchronous_logging(logging_queue: multiprocessing.Queue): + """Helper for asynchronous logging - Writes all logs to a queue.""" + root_logger = logging.getLogger() + root_logger.setLevel(logging.DEBUG) + queue_handler = QueueHandler(logging_queue) + queue_handler.setLevel(logging.DEBUG) + root_logger.addHandler(queue_handler) + logger_all.debug("Asynchronous logging has been set.") + + +def _asynchronous_logging_method_decorator(func: Callable): + """Decorates top-level calls in order to launch cleanup when an Exception is caught.""" + + @wraps(func) + def decorated(self, *args, **kwargs): + configure_asynchronous_logging(self.logging_queue) + return func(self, *args, **kwargs) + + return decorated + + +@_cleanup_decorator +def auto_tune_scvi_model( + exp_key: str, + gene_dataset: GeneExpressionDataset = None, + delayed_populating: bool = False, + custom_objective_hyperopt: Callable = None, + objective_kwargs: Dict[str, Any] = None, + model_class: VAE = VAE, + 
trainer_class: Trainer = UnsupervisedTrainer, + metric_name: str = None, + metric_kwargs: Dict[str, Any] = None, + posterior_name: str = "test_set", + model_specific_kwargs: dict = None, + trainer_specific_kwargs: dict = None, + train_func_specific_kwargs: dict = None, + space: dict = None, + max_evals: int = 100, + train_best: bool = True, + pickle_result: bool = True, + save_path: str = ".", + use_batches: bool = False, + parallel: bool = True, + n_cpu_workers: int = None, + gpu_ids: List[int] = None, + n_workers_per_gpu: int = 1, + reserve_timeout: float = 180.0, + fmin_timeout: float = 300.0, + fmin_timer: float = None, + mongo_port: str = "1234", + mongo_host: str = "localhost", + db_name: str = "scMVP_db", + multiple_hosts: bool = False, +) -> (Type[Trainer], Trials): + """Perform automatic hyperparameter optimization of an scMVP model + and return best model and hyperopt Trials object. + + ``Trials`` object contains hyperparameter space and loss history for each trial. + We provide a default hyperparameter search space (see source code), + but we recommend the user to build a custom one for each application. + Convention: fixed parameters (no default) have precedence over tunable parameters (default). + Note that the verbosity of this function has to be set using the logging module. + In particular, for the parallel case, only a progress bar is shown if the + logging level is equal or higher to ``logging.WARNING``. + + :param exp_key: Name of the experiment in MongoDb. + If already exists in db, ``hyperopt`` will run a number of trainings equal to + the difference between current and previous ``max_evals``. + :param gene_dataset: scVI gene expression dataset. + :param delayed_populating: Switch for the delayed populating mechanism + of scvi.dataset.dataset.DownloadableDataset. Useful for large datasets + which have to be instantiated inside the workers. + :param custom_objective_hyperopt: A custom objective function respecting the ``hyperopt`` format. + Roughly, it needs to return the quantity to optimize for, either directly + or in a ``dict`` under the "loss" key. + See https://github.com/hyperopt/hyperopt/wiki for a more detailed explanation. + By default, we provide an objective function which can be parametrized + through the various arguments of this function (``gene_dataset``, ``model_class``, etc.) + :param objective_kwargs: Dictionary containing the fixed keyword arguments ` + to the custom `objective_hyperopt. + :param model_class: scVI model class (e.g ``VAE``, ``VAEC``, ``SCANVI``) + :param trainer_class: ``Trainer`` sub-class (e.g ``UnsupervisedTrainer``) + :param metric_name: Name of the metric to optimize for. If `None` defaults to ``marginal_ll`` + :param metric_kwargs: keyword arguments for the metric method. + If `metric_name` is None, defaults to {"n_mc_samples": 100}. + :param posterior_name: Name of the posterior distribution to compute the metric with. + :param model_specific_kwargs: ``dict`` of fixed parameters which will be passed to the model. + :param trainer_specific_kwargs: ``dict`` of fixed parameters which will be passed to the trainer. + :param train_func_specific_kwargs: dict of fixed parameters which will be passed to the train method. + :param space: dict containing up to three sub-dicts with keys "model_tunable_kwargs", + "trainer_tunable_kwargs" or "train_func_tunable_kwargs". + Each of those dict contains ``hyperopt`` defined parameter spaces (e.g. 
``hp.choice(..)``) + which will be passed to the corresponding object : model, trainer or train method + when performing hyper-optimization. Default: mutable, see source code. + :param max_evals: Maximum number of evaluations of the objective. + :param train_best: If ``True``, train best model and return it. + :param pickle_result: If ``True``, pickle ``Trials`` and ``Trainer`` objects using ``save_path``. + :param save_path: Path where to save best model, trainer, trials and mongo files. + :param use_batches: If ``False``, pass ``n_batch=0`` to model else pass ``gene_dataset.n_batches``. + :param parallel: If ``True``, use ``MongoTrials`` object to run trainings in parallel. + :param n_cpu_workers: Number of cpu workers to launch. If None, and no GPUs are found, + defaults to ``os.cpucount() - 1``. Else, defaults to 0. + :param gpu_ids: Ids of the GPUs to use. If None defaults to all GPUs found by ``torch``. + Note that considered gpu ids are int from 0 to ``torch.cuda.device_count()``. + :param n_workers_per_gpu: Number of workers to launch per gpu found by ``torch``. + :param reserve_timeout: Amount of time, in seconds, a worker tries to reserve a job for + before throwing a ``ReserveTimeout`` Exception. + :param fmin_timeout: Amount of time, in seconds, fmin_process has to terminate + after all workers have died - before throwing a ``FminTimeoutError``. + If ``multiple_hosts`` is set to ``True``, this is set to ``None`` to prevent timing out. + :param fmin_timer: Global amount of time allowed for fmin_process. + If not None, the minimization procedure will be stopped after ``fmin_timer`` seconds. + Used only if ``parallel`` is set to ``True``. + :param mongo_port: Port to the Mongo db. + :param mongo_host: Hostname used with ``mongo_port`` to indicate the prefix of the mongodb address. + The prefix of the address passed onto the workers and ``MongoTrials`` object + is ``'{mongo_host}:{mongo_port}'``. + :param db_name: Name to use when creating the Mongo database. Suffix of the Mongo address. + :param multiple_hosts: If ``True``, user is considered to have workers launched on several machines. + Therefore, setting this to ``True`` disables the ``fmin_timeout`` behaviour. + :return: ``Trainer`` object for the best model and ``(Mongo)Trials`` object containing logs for the different runs. + + Examples: + >>> from scMVP.dataset import CortexDataset + >>> gene_dataset = CortexDataset() + >>> best_trainer, trials = auto_tune_scvi_model("cortex", gene_dataset) + """ + global fh_autotune + + # add file handler + fh_autotune = logging.handlers.RotatingFileHandler( + os.path.join(save_path, "scvi_autotune_logfile.txt") + ) + fh_autotune.setFormatter(autotune_formatter) + fh_autotune.setLevel(logging.DEBUG) + logger_all.addHandler(fh_autotune) + + if delayed_populating and not isinstance(gene_dataset, DownloadableDataset): + raise ValueError( + "The delayed_population mechanism requires an " + "instance of scvi.dataset.dataset.DownloadableDataset." + ) + + if fmin_timer and train_best: + logger_all.warning( + "fmin_timer and train_best are both set to True. " + "This means that runtime will exceed fmin_timer " + "by at least the time it takes to complete a full training." 
+ ) + + logger_all.info("Starting experiment: {exp_key}".format(exp_key=exp_key)) + + # default search space + if space is None: + logger_all.debug("Using default parameter search space.") + space = { + "model_tunable_kwargs": { + "n_latent": 5 + hp.randint("n_latent", 11), # [5, 15] + "n_hidden": hp.choice("n_hidden", [64, 128, 256]), + "n_layers": 1 + hp.randint("n_layers", 5), + "dropout_rate": hp.choice("dropout_rate", [0.1, 0.3, 0.5, 0.7]), + "reconstruction_loss": hp.choice("reconstruction_loss", ["zinb", "nb"]), + }, + "train_func_tunable_kwargs": { + "lr": hp.choice("lr", [0.01, 0.005, 0.001, 0.0005, 0.0001]) + }, + } + + # default metric + if metric_name is None: + metric_name = "marginal_ll" + metric_kwargs = {"n_mc_samples": 100} + + # build a partial objective function restricted to the search space + if custom_objective_hyperopt is None: + # default specific kwargs + model_specific_kwargs = model_specific_kwargs if model_specific_kwargs else {} + trainer_specific_kwargs = ( + trainer_specific_kwargs if trainer_specific_kwargs else {} + ) + train_func_specific_kwargs = ( + train_func_specific_kwargs if train_func_specific_kwargs else {} + ) + + # default early stopping + if "early_stopping_kwargs" not in trainer_specific_kwargs: + logger_all.debug("Adding default early stopping behaviour.") + early_stopping_kwargs = { + "early_stopping_metric": "elbo", + "save_best_state_metric": "elbo", + "patience": 50, + "threshold": 0, + "reduce_lr_on_plateau": True, + "lr_patience": 25, + "lr_factor": 0.2, + } + trainer_specific_kwargs["early_stopping_kwargs"] = early_stopping_kwargs + # add elbo to metrics to monitor + metrics_to_monitor = trainer_specific_kwargs.get("metrics_to_monitor", []) + metrics_to_monitor.append("elbo") + trainer_specific_kwargs["metrics_to_monitor"] = metrics_to_monitor + + logger_all.info( + "Fixed parameters: \n" + "model: \n" + + str(model_specific_kwargs) + + "\n" + + "trainer: \n" + + str(trainer_specific_kwargs) + + "\n" + + "train method: \n" + + str(train_func_specific_kwargs) + ) + objective_hyperopt = partial( + _objective_function, + **{ + "gene_dataset": gene_dataset, + "delayed_populating": delayed_populating, + "model_class": model_class, + "trainer_class": trainer_class, + "metric_name": metric_name, + "metric_kwargs": metric_kwargs, + "posterior_name": posterior_name, + "model_specific_kwargs": model_specific_kwargs, + "trainer_specific_kwargs": trainer_specific_kwargs, + "train_func_specific_kwargs": train_func_specific_kwargs, + "use_batches": use_batches, + }, + ) + else: + logger_all.info("Using custom objective function.") + objective_hyperopt = partial(custom_objective_hyperopt, **objective_kwargs) + + if parallel: + logger_all.info("Starting parallel hyperoptimization") + trials = _auto_tune_parallel( + objective_hyperopt=objective_hyperopt, + exp_key=exp_key, + space=space, + max_evals=max_evals, + save_path=save_path, + n_cpu_workers=n_cpu_workers, + gpu_ids=gpu_ids, + n_workers_per_gpu=n_workers_per_gpu, + reserve_timeout=reserve_timeout, + fmin_timeout=fmin_timeout, + fmin_timer=fmin_timer, + mongo_port=mongo_port, + mongo_host=mongo_host, + db_name=db_name, + multiple_hosts=multiple_hosts, + ) + + else: + logger_all.info("Starting sequential hyperoptimization") + trials = Trials() + + # run hyperoptimization + _ = fmin( + fn=objective_hyperopt, + space=space, + algo=tpe.suggest, + max_evals=max_evals, + trials=trials, + ) + + # return best model, trained + if train_best: + logger_all.debug("Training best model with full training set") 
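+ # trials.best_trial stores the "space" returned by the objective; it already
+ # holds the merged fixed/tunable kwargs and, when early stopping ran, the
+ # number of epochs actually reached, so calling the objective again with
+ # is_best_training=True retrains that configuration on the full training set.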
+ best_space = trials.best_trial["result"]["space"] + best_trainer = objective_hyperopt(best_space, is_best_training=True) + + if pickle_result: + if train_best: + logger_all.debug("Pickling best model and trainer") + # pickle trainer and save model (overkill?) + with open( + os.path.join(save_path, "best_trainer_{key}".format(key=exp_key)), "wb" + ) as f: + pickle.dump(best_trainer, f) + torch.save( + best_trainer.model.state_dict(), + os.path.join(save_path, "best_model_{key}".format(key=exp_key)), + ) + # remove object containing thread.lock (otherwise pickle.dump throws) + logger_all.debug("Pickling Trials object") + if hasattr(trials, "handle"): + del trials.handle + with open( + os.path.join(save_path, "trials_{key}".format(key=exp_key)), "wb" + ) as f: + pickle.dump(trials, f) + + # remove added logging handlers + _cleanup_logger() + + if train_best: + return best_trainer, trials + else: + return trials + + +def _auto_tune_parallel( + objective_hyperopt: Callable, + exp_key: str, + space: dict = None, + max_evals: int = 100, + save_path: str = ".", + n_cpu_workers: int = None, + gpu_ids: List[int] = None, + n_workers_per_gpu: int = 1, + reserve_timeout: float = 180.0, + fmin_timeout: float = 300.0, + fmin_timer: float = None, + mongo_port: str = "1234", + mongo_host: str = "localhost", + db_name: str = "scvi_db", + multiple_hosts: bool = False, +) -> MongoTrials: + """Parallel version of the hyperoptimization procedure. + Called by ``auto_tune_scvi_model`` when ``parallel=True``. + Specifically, first the MongoDb service is launched in its own forked process. + Then, the call to the minimization process is made in its own forked process. + Then, the call ``worker_launcher`` is made in its own Thread. + After that, the program waits for either the minimization + process to finish or for the workers to all timeout. + When one of these conditions is verified the program kills the waiter for the other + and tries to dequeue the results from the minimization process. + At that point, if ``multiple_hosts`` is set to True, the program waits indefinitely + for the minimization process to put the results in the queue. + If not, the minimisation process has ``fmin_timeout`` seconds to finish. + This mechanism ensures that the program does not hang if, for any reason, + the workers die before completing all the jobs. + Note that logs to the ``hyperopt`` package are automatically stored in ``./hyperopt_logfile.txt``. + Note that the progress bar is automatically disabled if the logging level + for ``scvi.inference.autotune`` is lower than logging.WARNING. + + :param objective_hyperopt: Callable, the objective function to minimize. + :param exp_key: Name of the experiment in MongoDb. + :param space: ``dict`` containing up to three sub-dicts with keys "model_tunable_kwargs", + "trainer_tunable_kwargs" or "train_func_tunable_kwargs". + Each of those dict contains ``hyperopt`` defined parameter spaces (e.g. ``hp.choice(..)``) + which will be passed to the corresponding object : model, trainer or train method + when performing hyperoptimization. Default: mutable, see source code. + :param max_evals: Maximum number of evaluations of the objective. + :param save_path: Path where to save best model, trainer, trials and mongo files. + :param n_cpu_workers: Number of cpu workers to launch. If None, and no GPUs are found, + defaults to ``os.cpucount() - 1``. Else, defaults to 0. + :param gpu_ids: Ids of the GPUs to use. If None defaults to all GPUs found by ``torch``. 
+ Note that considered gpu ids are int from ``0`` to ``torch.cuda.device_count()``. + :param n_workers_per_gpu: Number of workers ton launch per gpu found by ``torch``. + :param reserve_timeout: Amount of time, in seconds, a worker tries to reserve a job for + before throwing a ``ReserveTimeout`` Exception. + :param fmin_timeout: Amount of time, in seconds, ``fmin_process`` has to terminate + after all workers have died - before throwing a ``FminTimeoutError``. + If ``multiple_hosts`` is set to ``True``, this is set to None to disable the timeout behaviour. + :param fmin_timer: Global amount of time allowed for fmin_process. + If not None, the minimization procedure will be stopped after ``fmin_timer`` seconds. + Used only if ``parallel`` is set to ``True``. + :param mongo_port: Port to the mongo db. + :param mongo_host: Hostname used with mongo_port to indicate the prefix of the mongodb address. + The prefix of the address passed onto the workers and MongoTrials object is ``'{mongo_host}:{mongo_port}'``. + :param db_name: Name to use when creating the Mongo database. Suffix of the mongo address. + :param multiple_hosts: If ``True``, user is considered to have workers launched on several machines. + Therefore, setting this to ``True`` disables the ``fmin_timeout`` behaviour. + :return: ``MongoTrials`` object containing the results of the program. + """ + global started_processes + global started_threads + global started_queues + global fh_hyperopt + + # prepare parallel logging + logging_queue = multiprocessing.Queue() + started_queues.append(logging_queue) + listener = QueueListener(logging_queue, DispatchHandler()) + listener.start() + started_processes.append(listener) + + # run mongod bash script + mongo_path = os.path.join(save_path, "mongo") + if not os.path.exists(mongo_path): + os.makedirs(mongo_path) + mongo_logfile = open(os.path.join(mongo_path, "mongo_logfile.txt"), "w") + open_files.append(mongo_logfile) + logger_all.debug( + "Starting MongoDb process, logs redirected to " + "{name}.".format(name=mongo_logfile.name) + ) + mongod_process = Popen( + [ + "mongod", + "--quiet", + "--dbpath={path}".format(path=mongo_path), + "--port={port}".format(port=mongo_port), + ], + stdout=mongo_logfile, + ) + # let mongo server start and check it did + time.sleep(5) + client = pymongo.MongoClient( + mongo_host + ":" + mongo_port, serverSelectionTimeoutMS=100 + ) + try: + client.server_info() + client.close() + except pymongo.mongo_client.ServerSelectionTimeoutError: + logger_all.error("Failed to connect to mongo agent.") + mongo_logfile.close() + mongo_logfile = open(os.path.join(mongo_path, "mongo_logfile.txt"), "r") + logger_all.error( + "Logs for the mongod subprocess: \n" + "".join(mongo_logfile.readlines()) + ) + raise + + mongo_url = os.path.join(mongo_host + ":" + mongo_port, db_name) + started_processes.append(mongod_process) + + # log hyperopt only to file + hp_logger = logging.getLogger("hyperopt") + hp_logger.propagate = False + fh_hyperopt = logging.handlers.RotatingFileHandler( + os.path.join(save_path, "hyperopt_logfile.txt") + ) + fh_hyperopt.setFormatter(autotune_formatter) + hp_logger.addHandler(fh_hyperopt) + + # start fmin launcher thread + logger_all.debug("Starting minimization procedure") + queue = multiprocessing.Queue() + started_queues.append(queue) + fmin_launcher_thread = FminLauncherThread( + logging_queue=logging_queue, + queue=queue, + objective_hyperopt=objective_hyperopt, + exp_key=exp_key, + space=space, + algo=tpe.suggest, + max_evals=max_evals, + 
fmin_timer=fmin_timer, + mongo_url=mongo_url, + ) + fmin_launcher_thread.start() + started_threads.append(fmin_launcher_thread) + + # start worker launcher + logger_all.debug("Starting worker launcher") + worker_launcher_thread = WorkerLauncherThread( + logging_queue=logging_queue, + exp_key=exp_key, + n_cpu_workers=n_cpu_workers, + gpu_ids=gpu_ids, + n_workers_per_gpu=n_workers_per_gpu, + reserve_timeout=reserve_timeout, + workdir=mongo_path, + mongo_url=mongo_url, + multiple_hosts=multiple_hosts, + max_evals=max_evals, + ) + worker_launcher_thread.start() + started_threads.append(worker_launcher_thread) + + # wait for one to finish + while worker_launcher_thread.is_alive() and fmin_launcher_thread.is_alive(): + time.sleep(5) + + if not fmin_launcher_thread.is_alive(): + logger_all.debug("Setting worker launcher stop event.") + worker_launcher_thread.stop_event.set() + try: + if multiple_hosts: + # if using multiple_hosts, there could still be workers -> disable fmin timeout + fmin_timeout = None + logger_all.debug( + "multiple_hosts set to True, fmin will block until all trials have been completed." + ) + else: + logger_all.debug( + "multiple_hosts set to false, Fmin has {time} seconds to finish".format( + time=fmin_timeout + ) + ) + trials = queue.get(timeout=fmin_timeout) + queue.close() + except Empty: + logger_all.error( + "Queue still empty {fmin_timeout} seconds after all workers have died." + "\n".format(fmin_timeout=fmin_timeout) + "Terminating minimization process." + ) + raise FminTimeoutError( + "Queue still empty {fmin_timeout} seconds after all workers " + "have died. Check that you have used a new exp_key or allowed " + "a higher max_evals".format(fmin_timeout=fmin_timeout) + ) + + # sanity: wait for fmin, terminate workers and wait for launcher + fmin_launcher_thread.join() + worker_launcher_thread.join() + logger_all.info( + "Finished minimization procedure for experiment {exp_key}.".format( + exp_key=exp_key + ) + ) + logger_all.debug("Terminating mongod process.") + mongod_process.terminate() + # wait for process to actually terminate, avoid issues with unreleased mongod.lock + mongod_process.wait() + mongo_logfile.close() + mongo_logfile = open(os.path.join(mongo_path, "mongo_logfile.txt"), "r") + logger_all.info( + "Logs for the mongod subprocess: \n" + "".join(mongo_logfile.readlines()) + ) + logger_all.debug("Stopping asynchronous logging listener.") + listener.stop() + logging_queue.close() + + # cleanup queues, processes, threads, files and logger + _cleanup_processes_files() + _cleanup_logger() + + return trials + + +class FminLauncherThread(StoppableThread): + """Starts the process which ultimately call the minimzation procedure. + + Is encapsulated in a ``threading.Thread`` to allow for the ``fmin_timer`` mechanism. + + :param logging_queue: Queue to send logs to main process using a ``QueueHandler``. + Here to be passed on to `FminProcess`. + :param queue: Queue to put trials in. Here to be passed on to `FminProcess`. + :param objective_hyperopt: Callable, the objective function to minimize + :param exp_key: Name of the experiment in MongoDb. + :param space: ``dict`` containing up to three sub-dicts with keys "model_tunable_kwargs", + "trainer_tunable_kwargs" or "train_func_tunable_kwargs". + Each of those dict contains ``hyperopt`` defined parameter spaces (e.g. ``hp.choice(..)``) + which will be passed to the corresponding object : model, trainer or train method + when performing hyperoptimization. Default: mutable, see source code. 
+ :param algo: Bayesian optimization algorithm from ``hyperopt`` to use. + :param max_evals: Maximum number of evaluations of the objective. + :param fmin_timer: Global amount of time allowed for fmin_process. + If not None, the minimization procedure will be stopped after ``fmin_timer`` seconds. + Used only if ``parallel`` is set to ``True``. + :param mongo_url: String of the form mongo_host:mongo_port/db_name. + """ + + def __init__( + self, + logging_queue: multiprocessing.Queue, + queue: multiprocessing.Queue, + objective_hyperopt: Callable, + exp_key: str, + space: dict, + algo: Callable = tpe.suggest, + max_evals: int = 100, + fmin_timer: float = None, + mongo_url: str = "localhost:1234/scvi_db", + ): + super().__init__(name="Fmin Launcher") + self.logging_queue = logging_queue + self.queue = queue + self.objective_hyperopt = objective_hyperopt + self.exp_key = exp_key + self.space = space + self.algo = algo + self.max_evals = max_evals + self.fmin_timer = fmin_timer + self.mongo_url = mongo_url + + @_error_logger_decorator + def run(self): + """Launches a ``hyperopt`` minimization procedure.""" + # call fmin in a process to enable termination + fmin_process = FminProcess( + logging_queue=self.logging_queue, + queue=self.queue, + objective_hyperopt=self.objective_hyperopt, + space=self.space, + mongo_url=self.mongo_url, + exp_key=self.exp_key, + algo=self.algo, + max_evals=self.max_evals, + ) + logger_all.debug("Starting FminProcess.") + fmin_process.start() + started_processes.append(fmin_process) + if self.fmin_timer is not None: + logger_all.info( + "Timer set, fmin will run for at most {timer}.".format( + timer=self.fmin_timer + ) + ) + start_time = time.monotonic() + run_time = 0 + while ( + run_time < self.fmin_timer + and fmin_process.is_alive() + and not self.stop_event.is_set() + ): + time.sleep(10) + run_time = time.monotonic() - start_time + if self.stop_event.is_set(): + logger_all.debug("Stop event set.") + elif run_time > self.fmin_timer and fmin_process.is_alive(): + logger_all.debug( + "Timer ran out. Terminating FminProcess and putting current Trials in queue." + ) + fmin_process.terminate() + # queue.put uses pickle so remove attribute containing thread.lock + trials = MongoTrials( + as_mongo_str(os.path.join(self.mongo_url, "jobs")), + exp_key=self.exp_key, + ) + if hasattr(trials, "handle"): + logger_all.debug("Deleting Trial handle for pickling.") + del trials.handle + logger_all.debug("Putting Trials in Queue.") + self.queue.put(trials) + else: + logger_all.debug("fmin finished.") + else: + logger_all.debug("No timer, waiting for fmin...") + while fmin_process.is_alive() and not self.stop_event.is_set(): + time.sleep(10) + logger_all.debug("fmin finished.") + + +class FminProcess(multiprocessing.Process): + """Call ``hyperopt``'s fmin. + + Is encapsulated in a ``multiprocessing.Process`` in order to + allow for termination in case cleanup is required. + + :param logging_queue: Queue to send logs to main process using a ``QueueHandler``. + :param queue: Queue to put trials in. + :param objective_hyperopt: Callable, the objective function to minimize + :param space: ``dict`` containing up to three sub-dicts with keys "model_tunable_kwargs", + "trainer_tunable_kwargs" or "train_func_tunable_kwargs". + Each of those dict contains ``hyperopt`` defined parameter spaces (e.g. ``hp.choice(..)``) + which will be passed to the corresponding object : model, trainer or train method + when performing hyperoptimization. Default: mutable, see source code. 
+ :param exp_key: Name of the experiment in MongoDb. + :param mongo_url: String of the form mongo_host:mongo_port/db_name + :param algo: Bayesian optimization algorithm from ``hyperopt`` to use. + :param max_evals: Maximum number of evaluations of the objective. + :param show_progressbar: Whether or not to show the ``hyperopt`` progress bar. + """ + + def __init__( + self, + logging_queue: multiprocessing.Queue, + queue: multiprocessing.Queue, + objective_hyperopt: Callable, + space: dict, + exp_key: str, + mongo_url: str = "localhost:1234/scvi_db", + algo: Callable = tpe.suggest, + max_evals: int = 100, + show_progressbar: bool = False, + ): + super().__init__(name="Fmin") + self.logging_queue = logging_queue + self.queue = queue + self.objective_hyperopt = objective_hyperopt + self.space = space + self.mongo_url = mongo_url + self.exp_key = exp_key + self.algo = algo + self.max_evals = max_evals + self.show_progressbar = show_progressbar + + @_asynchronous_logging_method_decorator + @_error_logger_decorator + def run(self): + logger_all.debug("Instantiating MongoTrials object.") + trials = MongoTrials( + as_mongo_str(os.path.join(self.mongo_url, "jobs")), exp_key=self.exp_key + ) + logger_all.debug("Calling fmin.") + fmin( + fn=self.objective_hyperopt, + space=self.space, + algo=self.algo, + max_evals=self.max_evals, + trials=trials, + show_progressbar=self.show_progressbar, + ) + # queue.put uses pickle so remove attribute containing thread.lock + if hasattr(trials, "handle"): + logger_all.debug("fmin returned. Deleting Trial handle for pickling.") + del trials.handle + logger_all.debug("Putting Trials in Queue.") + self.queue.put(trials) + + +class WorkerLauncherThread(StoppableThread): + """Launches the local workers which are going to run the jobs required by the minimization process. + Terminates when the worker_watchdog call finishes. + Specifically, first ``n_gpu_workers`` are launched per GPU in ``gpu_ids`` in their own spawned process. + Then, ``n_cpu_workers`` CPU workers are launched, also in their own spawned process. + The use of spawned processes (each have their own python interpreter) is mandatory for compatiblity with CUDA. + See https://pytorch.org/docs/stable/notes/multiprocessing.html for more information. + + :param logging_queue: Queue to send logs to main process using a ``QueueHandler``. + Here to be passed on to the `HyperoptWorker` processes. + :param exp_key: This key is used by hyperopt as a suffix to the part of the MongoDb + which corresponds to the current experiment. In particular, it has to be passed to ``MongoWorker``. + :param n_cpu_workers: Number of cpu workers to launch. If None, and no GPUs are found, + defaults to ``os.cpu_count() - 1``. Else, defaults to 0. + :param gpu_ids: Ids of the GPUs to use. If None defaults to all GPUs found by ``torch``. + Note that considered gpu ids are int from ``0`` to ``torch.cuda.device_count()``. + :param n_workers_per_gpu: Number of workers ton launch per gpu found by ``torch``. + :param reserve_timeout: Amount of time, in seconds, a worker tries to reserve a job for + before throwing a ``ReserveTimeout`` Exception. + :param workdir: Directory where the workers + :param mongo_url: Address to the running MongoDb service. + :param multiple_hosts: ``True`` if launching workers form multiple hosts. + :param max_evals: Maximum number of evaluations of the objective. + Useful for instantiating a progress bar. 
+ """ + + def __init__( + self, + logging_queue: multiprocessing.Queue, + exp_key: str, + n_cpu_workers: int = None, + gpu_ids: List[int] = None, + n_workers_per_gpu: int = 1, + reserve_timeout: float = 30.0, + workdir: str = ".", + mongo_url: str = "localhost:1234/scvi_db", + multiple_hosts: bool = False, + max_evals: int = 100, + ): + super().__init__(name="Worker Launcher") + self.logging_queue = logging_queue + self.exp_key = exp_key + self.n_cpu_workers = n_cpu_workers + self.gpu_ids = gpu_ids + self.n_workers_per_gpu = n_workers_per_gpu + self.reserve_timeout = reserve_timeout + self.workdir = workdir + self.mongo_url = mongo_url + self.multiple_hosts = multiple_hosts + self.max_evals = max_evals + + @_error_logger_decorator + def run(self): + global started_processes + + if self.gpu_ids is None: + n_gpus = torch.cuda.device_count() + logger_all.debug( + "gpu_ids is None, defaulting to all {n_gpus} GPUs found by torch.".format( + n_gpus=n_gpus + ) + ) + self.gpu_ids = list(range(n_gpus)) + if n_gpus and self.n_cpu_workers is None: + self.n_cpu_workers = 0 + logger_all.debug( + "Some GPU.s found and n_cpu_wokers is None, defaulting to n_cpu_workers = 0" + ) + if not n_gpus and self.n_cpu_workers is None: + self.n_cpu_workers = os.cpu_count() - 1 + logger_all.debug( + "No GPUs found and n_cpu_wokers is None, defaulting to n_cpu_workers = " + "{n_cpu_workers} (os.cpu_count() - 1)".format( + n_cpu_workers=self.n_cpu_workers + ) + ) + if ( + self.gpu_ids is None + and (self.n_cpu_workers == 0 or self.n_cpu_workers is None) + and not self.multiple_hosts + ): + raise ValueError("No hardware (cpu/gpu) selected/found.") + + # log progress with queue and progress_listener + pbar = None + if not self.multiple_hosts and logger.level >= logging.WARNING: + pbar = tqdm.tqdm(total=self.max_evals) + progress_queue = multiprocessing.Queue() + started_queues.append(progress_queue) + prog_listener = ProgressListener(progress_queue=progress_queue, pbar=pbar) + prog_listener.start() + started_threads.append(prog_listener) + + running_workers = [] + # launch gpu workers + logger_all.info( + "Starting {n_workers_per_gpu} worker.s for each of the {n_gpus} gpu.s set for use/" + "found.".format( + n_workers_per_gpu=self.n_workers_per_gpu, n_gpus=len(self.gpu_ids) + ) + ) + for gpu_id in self.gpu_ids: + for sub_id in range(self.n_workers_per_gpu): + worker = HyperoptWorker( + progress_queue=progress_queue, + logging_queue=self.logging_queue, + exp_key=self.exp_key, + workdir=self.workdir, + gpu=True, + hw_id=str(gpu_id), + reserve_timeout=self.reserve_timeout, + mongo_url=self.mongo_url, + name="Worker GPU " + str(gpu_id) + ":" + str(sub_id), + ) + worker.start() + running_workers.append(worker) + + # launch cpu workers + logger_all.info( + "Starting {n_cpu_workers} cpu worker.s".format( + n_cpu_workers=self.n_cpu_workers + ) + ) + for cpu_id in range(self.n_cpu_workers): + worker = HyperoptWorker( + progress_queue=progress_queue, + logging_queue=self.logging_queue, + exp_key=self.exp_key, + workdir=self.workdir, + gpu=False, + hw_id=str(cpu_id), + reserve_timeout=self.reserve_timeout, + mongo_url=self.mongo_url, + name="Worker CPU " + str(cpu_id), + ) + worker.start() + running_workers.append(worker) + started_processes.extend(running_workers) + + # wait or return if all workers have died + while not self.stop_event.is_set(): + n_alive = 0 + for worker in running_workers: + n_alive += 1 if worker.is_alive() else n_alive + if n_alive == 0: + logger_all.debug( + "All workers have died, check stdout/stderr for 
error tracebacks." + ) + break + logger_all.debug( + "Worker watchdog finished, terminating workers and stopping listener." + ) + for worker in running_workers: + if worker.is_alive(): + worker.terminate() + prog_listener.stop_event.set() + prog_listener.join() + + +class ProgressListener(StoppableThread): + """Listens to workers when they finish a job and logs progress. + + Workers put in the progress_queue when they finish a job + and when they do this function sends a log to the progress logger. + """ + + def __init__(self, progress_queue: multiprocessing.Queue, pbar: tqdm.tqdm = None): + super().__init__(name="Progress Listener") + self.progress_queue = progress_queue + self.pbar = pbar + + @_error_logger_decorator + def run(self): + logger_all.debug("Listener listening...") + + i = 0 + while not self.stop_event.is_set(): + # get job done signal + try: + self.progress_queue.get(block=False) + i += 1 + logger_all.info("{i} job.s done".format(i=i)) + # update progress bar through ProgressHandler + if self.pbar is not None: + self.pbar.update() + except Empty: + pass + time.sleep(5) + if self.pbar is not None: + self.pbar.close() + self.progress_queue.close() + + +class HyperoptWorker(multiprocessing.Process): + """Launches a ``hyperopt`` ``MongoWorker`` which runs jobs until ``ReserveTimeout`` is raised. + + :param progress_queue: Queue in which to put None when a job is done. + :param logging_queue: Queue to send logs to main process using a ``QueueHandler``. + :param exp_key: This key is used by hyperopt as a suffix to the part of the MongoDb + which corresponds to the current experiment. In particular, it has to be passed to ``MongoWorker``. + :param workdir: + :param gpu: If ``True`` means a GPU is to be used. + :param hw_id: Id of the GPU to use. set via env variable ``CUDA_VISIBLE_DEVICES``. + :param poll_interval: Time to wait between attempts to reserve a job. + :param reserve_timeout: Amount of time, in seconds, a worker tries to reserve a job for + before throwing a ``ReserveTimeout`` Exception. + :param mongo_url: Address to the running MongoDb service. + """ + + def __init__( + self, + name: str, + progress_queue: multiprocessing.Queue, + logging_queue: multiprocessing.Queue, + exp_key: str, + workdir: str = ".", + gpu: bool = True, + hw_id: str = None, + poll_interval: float = 1.0, + reserve_timeout: float = 30.0, + mongo_url: str = "localhost:1234/scvi_db", + ): + super().__init__(name=name) + self.progress_queue = progress_queue + self.logging_queue = logging_queue + self.exp_key = exp_key + self.workdir = workdir + self.gpu = gpu + self.hw_id = hw_id + self.poll_interval = poll_interval + self.reserve_timeout = reserve_timeout + self.mongo_url = mongo_url + + @_asynchronous_logging_method_decorator + @_error_logger_decorator + def run(self): + logger_all.debug("Worker working...") + + os.environ["CUDA_VISIBLE_DEVICES"] = self.hw_id if self.gpu else str() + + mjobs = MongoJobs.new_from_connection_str( + os.path.join(as_mongo_str(self.mongo_url), "jobs") + ) + mworker = MongoWorker( + mjobs, float(self.poll_interval), workdir=self.workdir, exp_key=self.exp_key + ) + + while True: + try: + mworker.run_one(reserve_timeout=float(self.reserve_timeout)) + self.progress_queue.put(None) + except ReserveTimeout: + logger_all.debug( + "Caught ReserveTimeout. 
" + "Exiting after failing to reserve job for {time} seconds.".format( + time=self.reserve_timeout + ) + ) + break + + +@_error_logger_decorator +def _objective_function( + space: dict, + gene_dataset: GeneExpressionDataset, + delayed_populating: bool = False, + model_class: Type[VAE] = VAE, + trainer_class: Type[Trainer] = UnsupervisedTrainer, + metric_name: str = None, + metric_kwargs: Dict[str, Any] = None, + posterior_name: str = "test_set", + model_specific_kwargs: dict = None, + trainer_specific_kwargs: dict = None, + train_func_specific_kwargs: dict = None, + use_batches: bool = False, + is_best_training: bool = False, +) -> Union[Dict[str, Any], Trainer]: + """Objective function for automatic hyperparameter optimization. + Train a scVI model and return the best value of the early-stopping metric (e.g, log-likelihood). + Convention: fixed parameters (no default) have precedence over tunable parameters (default). + + :param space: dict containing up to three sub-dicts with keys "model_tunable_kwargs", + "trainer_tunable_kwargs" or "train_func_tunable_kwargs". + Each of those dict contains hyperopt defined parameter spaces (e.g. ``hp.choice(..)``) + which will be passed to the corresponding object : model, trainer or train method + when performing hyperoptimization. + :param gene_dataset: scVI gene dataset + :param model_class: scVI model class (e.g ``VAE``, ``VAEC``, ``SCANVI``) + :param trainer_class: Trainer class (e.g ``UnsupervisedTrainer``) + :param metric_name: Name of the metric to optimize for. If `None` defaults to "marginal_ll" + :param metric_kwargs: keyword arguments for the metric method. + If `metric_name` is None, defaults to {"n_mc_samples": 100}. + :param posterior_name: Name of the posterior distribution to compute the metric with. + :param model_specific_kwargs: ``dict`` of fixed parameters which will be passed to the model. + :param model_specific_kwargs: dict of fixed parameters which will be passed to the model. + :param trainer_specific_kwargs: dict of fixed parameters which will be passed to the trainer. + :param train_func_specific_kwargs: dict of fixed parameters which will be passed to the train method. 
+ :param use_batches: If False, pass n_batch=0 to model else pass gene_dataset.n_batches + :param is_best_training: True if training the model with the best hyperparameters + :return: best value of the early stopping metric, and best model if is_best_training + """ + # handle mutable defaults + metric_kwargs = metric_kwargs if metric_kwargs is not None else {} + + if delayed_populating and isinstance(gene_dataset, DownloadableDataset): + gene_dataset.populate() + + start_time = time.monotonic() + # hyperopt params + space = defaultdict(dict, space) + model_tunable_kwargs = space["model_tunable_kwargs"] + trainer_tunable_kwargs = space["trainer_tunable_kwargs"] + train_func_tunable_kwargs = space["train_func_tunable_kwargs"] + + # use_cuda default + if "use_cuda" not in trainer_specific_kwargs: + trainer_specific_kwargs["use_cuda"] = bool(torch.cuda.device_count()) + if "n_epochs" not in {**train_func_specific_kwargs, **train_func_tunable_kwargs}: + train_func_specific_kwargs["n_epochs"] = 1000 + + # add hardcoded parameters + # disable scVI progbar + trainer_specific_kwargs["show_progbar"] = False + if is_best_training: + trainer_specific_kwargs["train_size"] = 1.0 + # no monitoring, will crash otherwise + trainer_specific_kwargs["frequency"] = None + trainer_specific_kwargs["early_stopping_kwargs"] = {} + else: + # evaluate at each epoch + trainer_specific_kwargs["frequency"] = 1 + + # merge params with fixed param precedence + model_tunable_kwargs.update(model_specific_kwargs) + trainer_tunable_kwargs.update(trainer_specific_kwargs) + train_func_tunable_kwargs.update(train_func_specific_kwargs) + + if not is_best_training: + logger_all.info( + "Parameters being tested: \n" + "model: \n" + + str(model_tunable_kwargs) + + "\n" + + "trainer: \n" + + str(trainer_tunable_kwargs) + + "\n" + + "train method: \n" + + str(train_func_tunable_kwargs) + ) + + # define model + logger_all.debug("Instantiating model") + model = model_class( + n_input=gene_dataset.nb_genes, + n_batch=gene_dataset.n_batches * use_batches, + **model_tunable_kwargs, + ) + + # define trainer + logger_all.debug("Instantiating trainer") + trainer = trainer_class(model, gene_dataset, **trainer_tunable_kwargs) + + # train model + logger_all.debug("Starting training") + trainer.train(**train_func_tunable_kwargs) + logger_all.debug("Finished training") + elapsed_time = time.monotonic() - start_time + # if training the best model, return model else return criterion + if is_best_training: + return trainer + else: + # select metric from early stopping kwargs if possible + metric = None + save_best_state_metric = None + early_stopping_kwargs = trainer_specific_kwargs.get( + "early_stopping_kwargs", None + ) + if early_stopping_kwargs is not None: + metric = early_stopping_kwargs.get("early_stopping_metric", None) + save_best_state_metric = early_stopping_kwargs.get( + "save_best_state_metric", None + ) + + # store run results + if metric is not None: + early_stopping_loss_is_best = True + best_epoch = trainer.best_epoch + # add actual number of epochs to be used when training best model + space["train_func_tunable_kwargs"]["n_epochs"] = best_epoch + early_stopping_loss = trainer.early_stopping.best_performance + metric += "_" + trainer.early_stopping.on + # default to elbo + else: + early_stopping_loss_is_best = False + metric = "elbo_test_set" + early_stopping_loss = trainer.history[metric][-1] + best_epoch = len(trainer.history[metric]) + + # load best state + if save_best_state_metric is not None: + 
model.load_state_dict(trainer.best_state_dict) + + # compute optimized metric + loss = getattr(getattr(trainer, posterior_name), metric_name)(**metric_kwargs) + + logger_all.debug( + "Training of {n_epochs} epochs finished in {time} with loss = {loss}".format( + n_epochs=len(trainer.history[metric]), + time=str(datetime.timedelta(seconds=elapsed_time)), + loss=loss, + ) + ) + + # check status + status = STATUS_OK + if np.isnan(loss): + status = STATUS_FAIL + + return { + "loss": loss, + "early_stopping_loss": early_stopping_loss, + "early_stopping_loss_is_best": early_stopping_loss_is_best, + "best_epoch": best_epoch, + "elapsed_time": elapsed_time, + "status": status, + "history": trainer.history, + "space": space, + "worker_name": multiprocessing.current_process().name, + } diff --git a/Singlecell_multi_omics/src/core/data/README.md b/Singlecell_multi_omics/src/core/data/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ad30b7fe12d3cca62e0bf5f9c44ad22f998a5344 --- /dev/null +++ b/Singlecell_multi_omics/src/core/data/README.md @@ -0,0 +1,54 @@ +# Available datasets + +The full descriptions of the datasets and the studies of origin can be found in the manuscript. Here we provide the links to access the processed datasets. + +## Pretraining data + +To download data from CellXGene and build for pretraining, go to the folder [cellxgene](cellxgene) and follow the instructions. + +## Datasets for cell type annotation + +- Multiple Sclerosis (M.S.) dataset: [link](https://drive.google.com/drive/folders/1Qd42YNabzyr2pWt9xoY4cVMTAxsNBt4v?usp=sharing) + +- Myeloid (Mye.) dataset: [link](https://drive.google.com/drive/folders/1VbpApQufZq8efFGakW3y8QDDpY9MBoDS?usp=drive_link) + +- hPancreas dataset: [link](https://drive.google.com/drive/folders/1s9XjcSiPC-FYV3VeHrEa7SeZetrthQVV?usp=drive_link) + +## Datasets for multi-batch integration + +- PBMC 10K: [link](https://docs.scvi-tools.org/en/stable/api/reference/scvi.data.pbmc_dataset.html) + +- Perirhinal Cortex dataset: [link](https://drive.google.com/file/d/1rDAxDtvWx1GpJaNhlKBi71f8-psUNppE/view?usp=drive_link) + +- COVID-19 dataset: [link](https://drive.google.com/file/d/1eD9LbxNJ35YUde3VtdVcjkwm-f4iyJ6x/view?usp=drive_link) + +## Datasets for multi-omics integration + +- BMMC dataset: [link](https://drive.google.com/drive/folders/1VRsVugg6vgCq8GG0gGajYsyXfrEtP0jK?usp=sharing) + +- 10x Multiome PBMC dataset: [link](https://drive.google.com/drive/folders/163J4Qi7R-awuLiHnWCh-eJD7RPMnb_yK?usp=sharing) + +## Datasets for perturbation prediction + +- Adamson dataset: [link](https://dataverse.harvard.edu/api/access/datafile/6154417) + +- Norman dataset: [link](https://dataverse.harvard.edu/api/access/datafile/6154020) + +## Datasets for the GRN analysis + +- Immune Human dataset [link](https://figshare.com/ndownloader/files/25717328) + +## Datasets for zero-shot integration + +- Lung-Kim dataset: [link](https://drive.google.com/file/d/1z_0vWYMhRuRiD1EyhuFtY9ReIR0msWaL/view?usp=sharing) + +- COVID-19 dataset: [link](https://drive.google.com/file/d/1eD9LbxNJ35YUde3VtdVcjkwm-f4iyJ6x/view?usp=drive_link) + +- Multiple Sclerosis (M.S.) 
dataset: [link](https://drive.google.com/drive/folders/1Qd42YNabzyr2pWt9xoY4cVMTAxsNBt4v?usp=sharing) + + +## Datasets for zero-shot integration + +- COVID-19 dataset(splitted) : [link](https://drive.google.com/drive/folders/1jSPoPunGQOmd71vDsK0FS7UvmDhGdhQS?usp=sharing) + +- Lung-Kim dataset(splitted): [link](https://drive.google.com/drive/folders/1gbfO7VqxCOkfzgHAih6hO88zFv6pd8wO?usp=sharing) diff --git a/Singlecell_multi_omics/src/core/data/__init__.py b/Singlecell_multi_omics/src/core/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Singlecell_multi_omics/src/core/data/cellxgene/README.md b/Singlecell_multi_omics/src/core/data/cellxgene/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1c9ab6e48b3ce32abb3c868d20af5b9d38f0e60d --- /dev/null +++ b/Singlecell_multi_omics/src/core/data/cellxgene/README.md @@ -0,0 +1,58 @@ +# Build Training Cell Corpus from Cellxgene Census + +- This documentation describes the procedure for building the pre-training cell corpus from the cellxgene census. +- Please note that this script is designed to run on a cluster with the SLURM workload manager for parallelization. +- You may need to modify the scripts to run on your own system. +- Internet access is required for querying the cellxgene census dataset. +- The scripts referred to in this document are located in the `/data/cellxgene` directory. + +## General Workflow for Cell Corpus Construction +- The general workflow is: + + 0. (Optional) Configure the query list and query conditions. + 1. Build the cell index files based on query + 2. Download the dataset in `h5ad` chunks + 3. Transform the `h5ad` into `scb` (single-cell bank for high-performance IO) + +## (Optional) Configure the Query List and Query Conditions +- If you wish to customize your pre-training dataset, you may modify the `data_config.py` file and `query_list.txt` file. +- In the `data_config.py` file, + - `MAJOR_TISSUE_LIST` refers to the general organ system defined in the cellxgene census; it defines the resolution we used to store the cells. + - `VERSION` refers to the version of the cellxgene census; we used the version `2023-05-08` for our experiments. You may change it to the latest/LTS version. Check out the [cellxgene census release plan](https://chanzuckerberg.github.io/cellxgene-census/cellxgene_census_docsite_data_release_info.html) for more information. + - As we only use normal cells for pre-training, we filter the dataset by the `DISEASE` column in the cellxgene census. + - For the `pan-cancer` model, we filter the dataset by the `DISEASE` column in the cellxgene census. The filtered cancer list is defined in the `cancer_list.txt` file. You may modify it according to your own needs. + +## Build the Cell Index Files Based on Query + +- We first query cells from the cellxgene census and filter the cells according to our needs. + - `INDEX_PATH` is the path to the cell index file (to be generated), cell index is the SOMA id (unique index in cellxgene census) for each cell in the cellxgene census. + - `QUERY_PATH` is the path to the query file; each line in the query file is a general organ system defined in the cellxgene census. 
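+ - `QUERY_PATH` should point to a plain-text query list such as the provided `query_list.txt`, which holds one query per line: `heart`, `blood`, `brain`, `lung`, `kidney`, `intestine`, `pancreas`, `others` and `pan-cancer`.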
+- Replace them in the following command and run it to generate the cell index file: + +```{bash} +INDEX_PATH="path/to/index" +QUERY_PATH="path/to/query" + +./build_soma_idx.sh $INDEX_PATH $QUERY_PATH +``` + +## Download the Dataset in Chunks +- We download the dataset in chunks; each chunk contains a maximum of 200000 cells, and the chunk size can be modified by changing the `MAX_PARTITION_SIZE` in the `download_partition.sh` file. +- Before running the script, you need to modify the `DATA_PATH`, `QUERY_PATH` and `INDEX_PATH` in the `array_download_partition.sh` file. + - Keep the `INDEX_PATH` and `QUERY_PATH` consistent with the previous step. + - `DATA_PATH` is the path to the directory to store the downloaded dataset. The resulting dataset will be stored in the `h5ad` format. +- Submit it to download the dataset (each compute node will need internet access): +```{bash} +sbatch array_download_partition.sh +``` + +## Build the `scb` Files +- We preprocess the dataset and then transform the `h5ad` into `scb` (single-cell bank for high-performance I/O). +- Before running the script, you need to modify the `DATA_PATH`, `OUTPUT_PATH`, `QUERY_PATH`, and `VOCAB_PATH` in the `array_build_scb.sh` file. + - Keep the `DATA_PATH` and `QUERY_PATH` consistent with the previous step. + - `OUTPUT_PATH` is the path to store the `scb` files. + - `VOCAB_PATH` is the path to the vocabulary file, which is used to map the gene id to token id. +- Then simply submit the job to the cluster by: +```{bash} +sbatch array_build_scb.sh +``` \ No newline at end of file diff --git a/Singlecell_multi_omics/src/core/data/cellxgene/array_build_scb.sh b/Singlecell_multi_omics/src/core/data/cellxgene/array_build_scb.sh new file mode 100644 index 0000000000000000000000000000000000000000..55c722b780faabd7d9afcd89f132ae372b543454 --- /dev/null +++ b/Singlecell_multi_omics/src/core/data/cellxgene/array_build_scb.sh @@ -0,0 +1,31 @@ +#!/bin/bash +#SBATCH --time=8:00:00 +#SBATCH --cpus-per-task=8 +#SBATCH --array=1-9 +#SBATCH --mem=96G +#SBATCH --qos=nopreemption +#SBATCH -p cpu + + +QUERY_PATH="path/to/query.txt" + + +query_name=$(sed -n "${SLURM_ARRAY_TASK_ID}p" $QUERY_PATH) + +DATA_PATH="path/to/data/${query_name}" +OUTPUT_PATH="path/to/output/${query_name}" +VOCAB_PATH="path/to/vocab" + +echo "processing ${query_name}" +N=200000 + + +mkdir -p $OUTPUT_PATH + +echo "downloading to ${OUTPUT_PATH}" + +python build_large_scale_data.py \ + --input-dir ${DATA_PATH} \ + --output-dir ${OUTPUT_PATH} \ + --vocab-file ${VOCAB_PATH} \ + --N ${N} diff --git a/Singlecell_multi_omics/src/core/data/cellxgene/array_download_partition.sh b/Singlecell_multi_omics/src/core/data/cellxgene/array_download_partition.sh new file mode 100644 index 0000000000000000000000000000000000000000..09c40f187730e2c67fd6304a890011fcff7825f7 --- /dev/null +++ b/Singlecell_multi_omics/src/core/data/cellxgene/array_download_partition.sh @@ -0,0 +1,21 @@ +#!/bin/bash +#SBATCH --time=8:00:00 +#SBATCH --cpus-per-task=4 +#SBATCH --array=1-9 +#SBATCH --mem=48G +#SBATCH --qos=nopreemption +#SBATCH -p cpu + + + +INDEX_PATH="path/to/index" +QUERY_PATH="path/to/query" +DATA_PATH="path/to/data" + +cd $DATA_PATH + +query_name=$(sed -n "${SLURM_ARRAY_TASK_ID}p" $QUERY_PATH) + +echo "downloading ${query_name}" + +./download_partition.sh ${query_name} ${INDEX_PATH} ${DATA_PATH} \ No newline at end of file diff --git a/Singlecell_multi_omics/src/core/data/cellxgene/array_process_allcounts.sh b/Singlecell_multi_omics/src/core/data/cellxgene/array_process_allcounts.sh new file mode 
100644 index 0000000000000000000000000000000000000000..508a4992c950a8d19d68916ff40222f215a620b8 --- /dev/null +++ b/Singlecell_multi_omics/src/core/data/cellxgene/array_process_allcounts.sh @@ -0,0 +1,53 @@ +#!/bin/sh +#SBATCH --time=8:00:00 +#SBATCH --cpus-per-task=8 +#SBATCH --array=1,2,4,5,7,8 +#SBATCH --mem=48G +#SBATCH --qos=nopreemption +#SBATCH -p cpu + +QUERY_PATH="query_list.txt" + +query_name=$(sed -n "${SLURM_ARRAY_TASK_ID}p" $QUERY_PATH) + +echo "processing ${query_name}" + +DATASET="/scratch/ssd004/datasets/cellxgene/scb_strict/${query_name}/all_counts" +VOCAB_PATH="/scratch/ssd004/datasets/cellxgene/scFormer/scformer/tokenizer/default_census_vocab.json" +bash ~/.bashrc + +NPROC=$SLURM_GPUS_ON_NODE +JOB_NAME="cellxgene_census_${QUERY_NAME}" +LOG_INTERVAL=2000 +VALID_SIZE_OR_RATIO=0.03 +MAX_LENGTH=1200 +per_proc_batch_size=32 +LAYERS=12 +MODEL_SCALE=8 +SAVE_DIR="/scratch/ssd004/datasets/cellxgene/profile_tmp" +# others, pancreas, lung, kidney, heart, blood +alias python_=~/.cache/pypoetry/virtualenvs/scformer-9yG_XnDJ-py3.9/bin/python +python_ -c "import torch; print(torch.version.cuda)" +python_ process_allcounts.py \ + --data-source $DATASET \ + --save-dir ${SAVE_DIR}/${JOB_NAME}-$(date +%b%d-%H-%M-%Y) \ + --vocab-path ${VOCAB_PATH} \ + --valid-size-or-ratio $VALID_SIZE_OR_RATIO \ + --max-seq-len $MAX_LENGTH \ + --batch-size $per_proc_batch_size \ + --eval-batch-size $(($per_proc_batch_size * 2)) \ + --nlayers $LAYERS \ + --nheads 8 \ + --embsize $((MODEL_SCALE * 64)) \ + --d-hid $((MODEL_SCALE * 64)) \ + --grad-accu-steps $((128 / $per_proc_batch_size)) \ + --epochs 2 \ + --lr 0.0001 \ + --warmup-ratio-or-step 10000 \ + --log-interval $LOG_INTERVAL \ + --save-interval $(($LOG_INTERVAL * 3)) \ + --trunc-by-sample \ + --no-cls \ + --no-cce \ + --fp16 | + awk '{ print strftime("[%Y-%m-%d %H:%M:%S]"), $0; fflush(); }' diff --git a/Singlecell_multi_omics/src/core/data/cellxgene/build_large_scale_data.py b/Singlecell_multi_omics/src/core/data/cellxgene/build_large_scale_data.py new file mode 100644 index 0000000000000000000000000000000000000000..78f219d8842f72ffbaa0c583797fb63a97c3b381 --- /dev/null +++ b/Singlecell_multi_omics/src/core/data/cellxgene/build_large_scale_data.py @@ -0,0 +1,224 @@ +# build large-scale data in scBank format from a group of AnnData objects +# %% +import gc +import json +from pathlib import Path +import argparse +import shutil +import traceback +from typing import Dict, List, Optional +import warnings +import numpy as np +import os + +import scanpy as sc + +import sys + +sys.path.insert(0, "../../") +import scgpt as scg +from scgpt import scbank + +# %% +parser = argparse.ArgumentParser( + description="Build large-scale data in scBank format from a group of AnnData objects" +) +parser.add_argument( + "--input-dir", + type=str, + required=True, + help="Directory containing AnnData objects", +) +parser.add_argument( + "--output-dir", + type=str, + default="./data.scb", + help="Directory to save scBank data, by default will make a directory named " + "data.scb in the current directory", +) +parser.add_argument( + "--include-files", + type=str, + nargs="*", + help="Space separated file names to include, default to all files in input_dir", +) +parser.add_argument( + "--metainfo", + type=str, + default=None, + help="Json file containing meta information for each dataset, default to None.", +) + +# vocabulary +parser.add_argument( + "--vocab-file", + type=str, + default=None, + help="File containing the gene vocabulary, default to None. 
If None, will " + "use the default gene vocabulary from scGPT, which use HGNC gene symbols.", +) + +parser.add_argument( + "--N", + type=int, + default=10000, + help="Hyperparam for filtering genes, default to 10000.", +) + + +# if scg.utils.isnotebook(): +# args = parser.parse_args( +# [ +# "--input-dir", +# "./datasets/", +# "--output-dir", +# "./databanks/", +# "--include-files", +# "f72958f5-7f42-4ebb-98da-445b0c6de516.h5ad", +# "--metainfo", +# "./metainfo.json", +# "--vocab-file", +# "../../scgpt/tokenizer/default_cellxgene_vocab.json", +# ] +# ) +# else: +args = parser.parse_args() + +"""command line example +python build_large_scale_data.py \ + --input-dir ./datasets/ \ + --output-dir ./databanks/ \ + --metainfo ./metainfo.json \ + --vocab-file ../../scgpt/tokenizer/default_cellxgene_vocab.json +""" + +# %% +print(args) + +input_dir = Path(args.input_dir) +output_dir = Path(args.output_dir) +files = [f for f in input_dir.glob("*.h5ad")] +print(f"Found {len(files)} files in {input_dir}") +if args.include_files is not None: + files = [f for f in files if f.name in args.include_files] +if args.metainfo is not None: + metainfo = json.load(open(args.metainfo)) + files = [f for f in files if f.stem in metainfo] + include_obs = { + f.stem: {"disease": metainfo[f.stem]["include_disease"]} + for f in files + if "include_disease" in metainfo[f.stem] + } + +if args.vocab_file is None: + vocab = scg.tokenizer.get_default_gene_vocab() +else: + vocab = scg.tokenizer.GeneVocab.from_file(args.vocab_file) + +# %% [markdown] +# # preprocessing data + + +def preprocess( + adata: sc.AnnData, + main_table_key: str = "counts", + include_obs: Optional[Dict[str, List[str]]] = None, + N=10000, +) -> sc.AnnData: + """ + Preprocess the data for scBank. This function will modify the AnnData object in place. + + Args: + adata: AnnData object to preprocess + main_table_key: key in adata.layers to store the main table + include_obs: dict of column names and values to include in the main table + + Returns: + The preprocessed AnnData object + """ + if include_obs is not None: + # include only cells that have the specified values in the specified columns + for col, values in include_obs.items(): + adata = adata[adata.obs[col].isin(values)] + + # filter genes + sc.pp.filter_genes(adata, min_counts=(3 / 10000) * N) + + # TODO: add binning in sparse matrix and save in separate datatable + # preprocessor = Preprocessor( + # use_key="X", # the key in adata.layers to use as raw data + # filter_gene_by_counts=False, # step 1 + # filter_cell_by_counts=False, # step 2 + # normalize_total=False, # 3. whether to normalize the raw data and to what sum + # log1p=False, # 4. whether to log1p the normalized data + # binning=51, # 6. 
whether to bin the raw data and to what number of bins + # result_binned_key="X_binned", # the key in adata.layers to store the binned data + # ) + # preprocessor(adata) + + adata.layers[main_table_key] = adata.X.copy() # preserve counts + # sc.pp.normalize_total(adata, target_sum=1e4) + # sc.pp.log1p(adata) + # adata.raw = adata # freeze the state in `.raw` + + # apply a hard clip to the data for now + print( + f"original mean and max of counts: {adata.layers[main_table_key].mean():.2f}, " + f"{adata.layers[main_table_key].max():.2f}" + ) + # if isinstance(adata.layers[main_table_key], np.ndarray): + # adata.layers[main_table_key] = adata.layers[main_table_key].clip(0, 30) + # else: # assume it is a sparse matrix + # adata.layers[main_table_key].data = adata.layers[main_table_key].data.clip(0, 30) + + return adata + + +# %% +main_table_key = "counts" +token_col = "feature_name" +for f in files: + try: + adata = sc.read(f, cache=True) + adata = preprocess(adata, main_table_key, N=args.N) + print(f"read {adata.shape} valid data from {f.name}") + + # TODO: CHECK AND EXPAND VOCABULARY IF NEEDED + # NOTE: do not simply expand, need to check whether to use the same style of gene names + + # BUILD SCBANK DATA + db = scbank.DataBank.from_anndata( + adata, + vocab=vocab, + to=output_dir / f"{f.stem}.scb", + main_table_key=main_table_key, + token_col=token_col, + immediate_save=False, + ) + db.meta_info.on_disk_format = "parquet" + # sync all to disk + db.sync() + + # clean up + del adata + del db + gc.collect() + except Exception as e: + traceback.print_exc() + warnings.warn(f"failed to process {f.name}: {e}") + shutil.rmtree(output_dir / f"{f.stem}.scb", ignore_errors=True) + +# or run scbank.DataBank.batch_from_anndata(files, to=args.output_dir) +# %% +# test loading from disk +# db = scbank.DataBank.from_path(args.output_dir) + +# %% run this to copy all parquet datatables to a single directory +target_dir = output_dir / f"all_{main_table_key}" +target_dir.mkdir(exist_ok=True) +for f in files: + output_parquet_dt = ( + output_dir / f"{f.stem}.scb" / f"{main_table_key}.datatable.parquet" + ) + if output_parquet_dt.exists(): + os.symlink(output_parquet_dt, target_dir / f"{f.stem}.datatable.parquet") diff --git a/Singlecell_multi_omics/src/core/data/cellxgene/build_soma_idx.py b/Singlecell_multi_omics/src/core/data/cellxgene/build_soma_idx.py new file mode 100644 index 0000000000000000000000000000000000000000..e2add15bbab6abc4214c0302207a4c1abaa9b88a --- /dev/null +++ b/Singlecell_multi_omics/src/core/data/cellxgene/build_soma_idx.py @@ -0,0 +1,71 @@ +### This script is used to retrieve cell soma ids from cellxgene census + +import cellxgene_census +from data_config import VALUE_FILTER, VERSION +from typing import List +import os +import argparse + +parser = argparse.ArgumentParser( + description='Build soma index list based on query') + + +parser.add_argument("--query-name", + type=str, + required=True, + help="query name to build the index", +) + +parser.add_argument("--output-dir", + type=str, + required=True, + help="Directory to store the output idx file", +) + +args = parser.parse_args() +# print(args) + + +def retrieve_soma_idx(query_name) -> List[str]: + """ + This function is used to retrieve cell soma ids from cellxgene census based on the query name + """ + + with cellxgene_census.open_soma(census_version=VERSION) as census: + cell_metadata = census["census_data"]["homo_sapiens"].obs.read( + value_filter = VALUE_FILTER[query_name], + column_names = ["soma_joinid"] + ) + cell_metadata = 
cell_metadata.concat() + cell_metadata = cell_metadata.to_pandas() + return cell_metadata["soma_joinid"].to_list() + +def convert2file(idx_list: List[str], query_name: str, output_dir: str) -> None: + """ + This function is used to convert the retrieved idx_list to file by query_name + """ + + # set up the dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + file_path = os.path.join(output_dir, f"{query_name}.idx") + + # write to the file + with open(file_path, 'w') as fp: + for item in idx_list: + fp.write("%s\n" % item) + +def build_soma_idx(query_name, output_dir) -> None: + """ + This function is used to build the soma idx for cells under query_name + """ + idx_list = retrieve_soma_idx(query_name) + convert2file(idx_list, query_name, output_dir) + + +# if __name__ == "__main__": +# build_soma_idx("heart") + +build_soma_idx(args.query_name, args.output_dir) + + diff --git a/Singlecell_multi_omics/src/core/data/cellxgene/build_soma_idx.sh b/Singlecell_multi_omics/src/core/data/cellxgene/build_soma_idx.sh new file mode 100644 index 0000000000000000000000000000000000000000..c6486f94236c7388fa38163f329df778fdff3cad --- /dev/null +++ b/Singlecell_multi_omics/src/core/data/cellxgene/build_soma_idx.sh @@ -0,0 +1,10 @@ +#!/bin/sh +# output directory for the index +OUTPUT_DIR=$1 +QUERY_LIST=$2 + +while read QUERY; do + echo "building index for ${QUERY}" + python3 ./build_soma_idx.py --query-name ${QUERY} --output-dir ${OUTPUT_DIR} +done < ${QUERY_LIST} + diff --git a/Singlecell_multi_omics/src/core/data/cellxgene/cancer_list.txt b/Singlecell_multi_omics/src/core/data/cellxgene/cancer_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..727adc1ca308f6938b95b8c14e0644979d673bf3 --- /dev/null +++ b/Singlecell_multi_omics/src/core/data/cellxgene/cancer_list.txt @@ -0,0 +1,21 @@ +malignant ovarian serous tumor +glioblastoma +lung adenocarcinoma +squamous cell lung carcinoma +small cell lung carcinoma +non-small cell lung carcinoma +B-cell non-Hodgkin lymphoma +follicular lymphoma +gastric cancer +blastoma +pilocytic astrocytoma +acute myeloid leukemia +tubular adenoma +clear cell renal carcinoma +adenocarcinoma +tubulovillous adenoma +colorectal cancer +Wilms tumor +acute promyelocytic leukemia +neuroendocrine carcinoma +chromophobe renal cell carcinoma diff --git a/Singlecell_multi_omics/src/core/data/cellxgene/data_config.py b/Singlecell_multi_omics/src/core/data/cellxgene/data_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ae46283995b2d3d9168192c7b84c7aaa6842e6de --- /dev/null +++ b/Singlecell_multi_omics/src/core/data/cellxgene/data_config.py @@ -0,0 +1,33 @@ + + + +MAJOR_TISSUE_LIST = ["heart", "blood", "brain", "lung", "kidney", "intestine", "pancreas"] +VERSION = "2023-05-08" + +CANCER_LIST_PATH = "./cancer_list.txt" +with open(CANCER_LIST_PATH) as f: + CANCER_LIST = [line.rstrip('\n') for line in f] + +# build the value filter dict for each tissue +VALUE_FILTER = { + tissue : f"suspension_type != 'na' and disease == 'normal' and tissue_general == '{tissue}'" for tissue in MAJOR_TISSUE_LIST +} +# build the value filter dict for cells related with other tissues +# since tileDB does not support `not in ` operator, we will just use `!=` to filter out the other tissues +VALUE_FILTER["others"] = f"suspension_type != 'na' and disease == 'normal'" +for tissue in MAJOR_TISSUE_LIST: + VALUE_FILTER["others"] = f"{VALUE_FILTER['others']} and (tissue_general != '{tissue}')" + +VALUE_FILTER['pan-cancer'] = f"suspension_type != 'na'" 
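+# NOTE: each per-tissue filter above expands to a string such as
+# "suspension_type != 'na' and disease == 'normal' and tissue_general == 'heart'".
+# The pan-cancer filter assigned above is only a placeholder: after the loop below it is
+# rebuilt as an OR over every disease listed in cancer_list.txt, e.g.
+# "(suspension_type != 'na') and ((disease == 'glioblastoma') or (disease == 'lung adenocarcinoma') or ...)".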
+cancer_condition = "" +for disease in CANCER_LIST: + if cancer_condition == "": + cancer_condition = f"(disease == '{disease}')" + else: + cancer_condition = f"{cancer_condition} or (disease == '{disease}')" +VALUE_FILTER['pan-cancer'] = f"(suspension_type != 'na') and ({cancer_condition})" + +if __name__ == "__main__": + # print(VALUE_FILTER["others"]) + # print(MAJOR_TISSUE_LIST) + print(VALUE_FILTER['pan-cancer']) diff --git a/Singlecell_multi_omics/src/core/data/cellxgene/download_partition.py b/Singlecell_multi_omics/src/core/data/cellxgene/download_partition.py new file mode 100644 index 0000000000000000000000000000000000000000..32df29b636c5ae6f5088b75297ce162835e48c23 --- /dev/null +++ b/Singlecell_multi_omics/src/core/data/cellxgene/download_partition.py @@ -0,0 +1,103 @@ +import cellxgene_census +import pandas as pd +import numpy as np +from data_config import VERSION +from typing import List +import os +import argparse + + +parser = argparse.ArgumentParser( + description='Download a given partition cell of the query in h5ad') + +parser.add_argument("--query-name", + type=str, + required=True, + help="query name to build the index", +) + +parser.add_argument("--partition-idx", + type=int, + required=True, + help="partition index to download", +) +parser.add_argument("--output-dir", + type=str, + required=True, + help="Directory to store the output h4ad file", +) + +parser.add_argument("--index-dir", + type=str, + required=True, + help="Directory to find the index file", +) + +parser.add_argument("--max-partition-size", + type=int, + required=True, + help="The max partition size for each partition(chunk)", +) + + +args = parser.parse_args() + +# print(args) + + + + +def define_partition(partition_idx, id_list, partition_size) -> List[str]: + """ + This function is used to define the partition for each job + + partition_idx is the partition index, which is an integer, and 0 <= partition_idx <= len(id_list) // MAX_PARTITION_SIZE + """ + i = partition_idx * partition_size + return id_list[i:i + partition_size] + + +def load2list(query_name, soma_id_dir) -> List[int]: + """ + This function is used to load the idx list from file + """ + file_path = os.path.join(soma_id_dir, f"{query_name}.idx") + with open(file_path, 'r') as fp: + idx_list = fp.readlines() + idx_list = [int(x.strip()) for x in idx_list] + return idx_list + +def download_partition(partition_idx, query_name, output_dir, index_dir, partition_size): + """ + This function is used to download the partition_idx partition of the query_name + """ + # define id partition + id_list = load2list(query_name, index_dir) + id_partition = define_partition(partition_idx, id_list, partition_size) + with cellxgene_census.open_soma(census_version=VERSION) as census: + adata = cellxgene_census.get_anndata(census, + organism="Homo sapiens", + obs_coords=id_partition, + ) + # prepare the query dir if not exist + query_dir = os.path.join(output_dir, query_name) + if not os.path.exists(query_dir): + os.makedirs(query_dir) + query_adata_path = os.path.join(query_dir, f"partition_{partition_idx}.h5ad") + adata.write_h5ad(query_adata_path) + return query_adata_path + +def del_partition(partition_idx, query_name, output_dir, index_dir, partition_size): + query_dir = os.path.join(output_dir, query_name) + query_adata_path = os.path.join(query_dir, f"partition_{partition_idx}.h5ad") + os.remove(query_adata_path) + + +if __name__ == "__main__": + + download_partition(partition_idx=args.partition_idx, + query_name=args.query_name, + 
output_dir=args.output_dir, + index_dir=args.index_dir, + partition_size=args.max_partition_size + ) \ No newline at end of file diff --git a/Singlecell_multi_omics/src/core/data/cellxgene/download_partition.sh b/Singlecell_multi_omics/src/core/data/cellxgene/download_partition.sh new file mode 100644 index 0000000000000000000000000000000000000000..344da6f9502dcee2fd99c4b4ef721e7a5e3457e8 --- /dev/null +++ b/Singlecell_multi_omics/src/core/data/cellxgene/download_partition.sh @@ -0,0 +1,22 @@ +#!/bin/sh +QUERY=$1 +INDEX_DIR=$2 +OUTPUT_DIR=$3 + +MAX_PARTITION_SIZE=200000 + +total_num=`wc -l ${INDEX_DIR}/${QUERY}.idx | awk '{ print $1 }'` +total_partition=$(($total_num / $MAX_PARTITION_SIZE)) +# echo $total_num +# echo $total_partition" + +for i in $(seq 0 $total_partition) +do + echo "downloading partition ${i}/${total_partition} for ${QUERY}" + python3 ./download_partition.py \ + --query-name ${QUERY} \ + --index-dir ${INDEX_DIR} \ + --output-dir ${OUTPUT_DIR} \ + --partition-idx ${i} \ + --max-partition-size ${MAX_PARTITION_SIZE} +done \ No newline at end of file diff --git a/Singlecell_multi_omics/src/core/data/cellxgene/expand_gene_list.py b/Singlecell_multi_omics/src/core/data/cellxgene/expand_gene_list.py new file mode 100644 index 0000000000000000000000000000000000000000..86b0c7b68c78f59b9ff6870b3f5c9a26aa34f630 --- /dev/null +++ b/Singlecell_multi_omics/src/core/data/cellxgene/expand_gene_list.py @@ -0,0 +1,41 @@ +import cellxgene_census +import json + +VERSION = "2023-05-08" + +with cellxgene_census.open_soma(census_version=VERSION) as census: + meta_data = ( + census["census_data"]["homo_sapiens"] + .ms["RNA"] + .var.read( + column_names=[ + "feature_name", + ], + ) + ) + new_gene_list = meta_data.concat().to_pandas()["feature_name"].to_list() + # print(gene_name) + +with open("../../scgpt/tokenizer/default_cellxgene_vocab.json", "r") as f: + old_gene_dict = json.load(f) + +print("old gene list length:", len(old_gene_dict)) + +expanded_dict = old_gene_dict.copy() + +# count the genes in old but not in new: +# for gene, num in old_gene_dict.items(): +# if gene not in new_gene_list: +# print(f"diff at {gene}") + +starting_num = max(old_gene_dict.values()) + 1 +for new_gene in new_gene_list: + if new_gene not in old_gene_dict.keys(): + expanded_dict[new_gene] = starting_num + starting_num += 1 +print("new gene dict length:", len(expanded_dict)) + +dump_path = "../../scgpt/tokenizer/default_census_vocab.json" + +with open(dump_path, "w") as f: + json.dump(expanded_dict, f, indent=2) diff --git a/Singlecell_multi_omics/src/core/data/cellxgene/metainfo.json b/Singlecell_multi_omics/src/core/data/cellxgene/metainfo.json new file mode 100644 index 0000000000000000000000000000000000000000..186fa7f3409a1a426874fe7cef3a15c909bfd1b1 --- /dev/null +++ b/Singlecell_multi_omics/src/core/data/cellxgene/metainfo.json @@ -0,0 +1,244 @@ +{ + "f72958f5-7f42-4ebb-98da-445b0c6de516": { + "name": "Azimuth meta-analysis of 10 datasets of healthy and diseased human lung", + "url": "https://cellxgene.cziscience.com/e/f72958f5-7f42-4ebb-98da-445b0c6de516.cxg/", + "include_disease": ["normal", "COVID-19"] + }, + "d0c12af4-c0e4-4c7b-873a-70752b449689": { + "name": "Stromal cells (all non-immune cells)", + "url": "https://cellxgene.cziscience.com/e/d0c12af4-c0e4-4c7b-873a-70752b449689.cxg/" + }, + "804d9a85-665b-45c6-b204-13457fdcc7ac": { + "name": "Megakaryocyte/erythroid cells", + "url": "https://cellxgene.cziscience.com/e/804d9a85-665b-45c6-b204-13457fdcc7ac.cxg/" + }, + 
"6a30bf44-c490-41ac-965b-0bb58432b10a": { + "name": "HSC/progenitor cells", + "url": "https://cellxgene.cziscience.com/e/6a30bf44-c490-41ac-965b-0bb58432b10a.cxg/" + }, + "2aa1c93c-4ef3-4e9a-98e7-0bd37933953c": { + "name": "Myeloid cells", + "url": "https://cellxgene.cziscience.com/e/2aa1c93c-4ef3-4e9a-98e7-0bd37933953c.cxg/" + }, + "3affa268-8a74-460a-ac9b-a984c0832469": { + "name": "Lymphoid cells", + "url": "https://cellxgene.cziscience.com/e/3affa268-8a74-460a-ac9b-a984c0832469.cxg/" + }, + "fd072bc3-2dfb-46f8-b4e3-467cb3223182": { + "name": "Full dataset of single-cell RNA-seq profiles from 9 developmental tissues across gestation (4-17 pcw)", + "url": "https://cellxgene.cziscience.com/e/fd072bc3-2dfb-46f8-b4e3-467cb3223182.cxg/" + }, + "48101fa2-1a63-4514-b892-53ea1d3a8657": { + "name": "HSC/immune cells (all hematopoietic-derived cells)", + "url": "https://cellxgene.cziscience.com/e/48101fa2-1a63-4514-b892-53ea1d3a8657.cxg/" + }, + "aa633105-e8e5-4dcc-b72d-6da6c191b3e9": { + "name": "NK/T cells", + "url": "https://cellxgene.cziscience.com/e/aa633105-e8e5-4dcc-b72d-6da6c191b3e9.cxg/" + }, + "1b9d8702-5af8-4142-85ed-020eb06ec4f6": { + "name": "Global", + "url": "https://cellxgene.cziscience.com/e/1b9d8702-5af8-4142-85ed-020eb06ec4f6.cxg/" + }, + "71be997d-ff75-41b9-8a9f-1288c865f921": { + "name": "B cell compartment", + "url": "https://cellxgene.cziscience.com/e/71be997d-ff75-41b9-8a9f-1288c865f921.cxg/" + }, + "fe52003e-1460-4a65-a213-2bb1a508332f": { + "name": "Myeloid compartment", + "url": "https://cellxgene.cziscience.com/e/fe52003e-1460-4a65-a213-2bb1a508332f.cxg/" + }, + "53d208b0-2cfd-4366-9866-c3c6114081bc": { + "name": "Tabula Sapiens - All Cells", + "url": "https://cellxgene.cziscience.com/e/53d208b0-2cfd-4366-9866-c3c6114081bc.cxg/" + }, + "c5d88abe-f23a-45fa-a534-788985e93dad": { + "name": "Tabula Sapiens - Immune", + "url": "https://cellxgene.cziscience.com/e/c5d88abe-f23a-45fa-a534-788985e93dad.cxg/" + }, + "ae29ebd0-1973-40a4-a6af-d15a5f77a80f": { + "name": "T & innate lymphoid cells", + "url": "https://cellxgene.cziscience.com/e/ae29ebd0-1973-40a4-a6af-d15a5f77a80f.cxg/" + }, + "218acb0f-9f2f-4f76-b90b-15a4b7c7f629": { + "name": "multiplexed scRNA-seq of 1.2 million PBMCs from adult lupus samples", + "url": "https://cellxgene.cziscience.com/e/218acb0f-9f2f-4f76-b90b-15a4b7c7f629.cxg/", + "include_disease": ["normal"] + }, + "3faad104-2ab8-4434-816d-474d8d2641db": { + "name": "Single-cell eQTL mapping identifies cell type specific genetic control of autoimmune disease", + "url": "https://cellxgene.cziscience.com/e/3faad104-2ab8-4434-816d-474d8d2641db.cxg/" + }, + "ebc2e1ff-c8f9-466a-acf4-9d291afaf8b3": { + "name": "COMBAT project: single cell gene expression data from COVID-19, sepsis and flu patient PBMCs", + "url": "https://cellxgene.cziscience.com/e/ebc2e1ff-c8f9-466a-acf4-9d291afaf8b3.cxg/" + }, + "2a498ace-872a-4935-984b-1afa70fd9886": { + "name": "PBMC", + "url": "https://cellxgene.cziscience.com/e/2a498ace-872a-4935-984b-1afa70fd9886.cxg/" + }, + "c05fb583-eb2f-4e3a-8e74-f9bd6414e418": { + "name": "healthy young bone marrow donor", + "url": "https://cellxgene.cziscience.com/e/c05fb583-eb2f-4e3a-8e74-f9bd6414e418.cxg/" + }, + "cd4c96bb-ad66-4e83-ba9e-a7df8790eb12": { + "name": "3 healthy young and 3 healthy old bone marrow donors (Reference sample)", + "url": "https://cellxgene.cziscience.com/e/cd4c96bb-ad66-4e83-ba9e-a7df8790eb12.cxg/" + }, + "d3566d6a-a455-4a15-980f-45eb29114cab": { + "name": "blood and bone marrow from a healthy young donor", + "url": 
"https://cellxgene.cziscience.com/e/d3566d6a-a455-4a15-980f-45eb29114cab.cxg/" + }, + "1a2e3350-28a8-4f49-b33c-5b67ceb001f6": { + "name": "Fetal Bone Marrow (10x)", + "url": "https://cellxgene.cziscience.com/e/1a2e3350-28a8-4f49-b33c-5b67ceb001f6.cxg/" + }, + "343ff97c-85df-494b-8400-beb937618611": { + "name": "Human Fetal Bone Marrow (CITE-seq)", + "url": "https://cellxgene.cziscience.com/e/343ff97c-85df-494b-8400-beb937618611.cxg/" + }, + "471647b3-04fe-4c76-8372-3264feb950e8": { + "name": "CD34+ Fetal Bone Marrow, Fetal Liver, Cord Blood (CITE-seq)", + "url": "https://cellxgene.cziscience.com/e/471647b3-04fe-4c76-8372-3264feb950e8.cxg/" + }, + "4c4cd77c-8fee-4836-9145-16562a8782fe": { + "name": "Individual Single-Cell RNA-seq PBMC Data from Lee et al.", + "url": "https://cellxgene.cziscience.com/e/4c4cd77c-8fee-4836-9145-16562a8782fe.cxg/" + }, + "db0752b9-f20e-40b8-8997-992f3ae0bb2e": { + "name": "Classical Monocyte sub_clusters of COVID-19 Immune Altas: Integration of 5 public COVID-19 PBMC single-cell datasets", + "url": "https://cellxgene.cziscience.com/e/db0752b9-f20e-40b8-8997-992f3ae0bb2e.cxg/" + }, + "e763ed0d-0e5a-4b8e-9514-6da3d9e47956": { + "name": "Platelet sub_clusters of COVID-19 Immune Altas: Integration of 5 public COVID-19 PBMC single-cell datasets", + "url": "https://cellxgene.cziscience.com/e/e763ed0d-0e5a-4b8e-9514-6da3d9e47956.cxg/" + }, + "59b69042-47c2-47fd-ad03-d21beb99818f": { + "name": "Individual Single-Cell RNA-seq PBMC Data from Arunachalam et al.", + "url": "https://cellxgene.cziscience.com/e/59b69042-47c2-47fd-ad03-d21beb99818f.cxg/" + }, + "d9b4bc69-ed90-4f5f-99b2-61b0681ba436": { + "name": "B Cell/Plasmablast Sub_clusters of COVID-19 Immune Altas: Integration of 5 public COVID-19 PBMC single-cell datasets", + "url": "https://cellxgene.cziscience.com/e/d9b4bc69-ed90-4f5f-99b2-61b0681ba436.cxg/" + }, + "96a3f64b-0ee9-40d8-91e9-813ce38261c9": { + "name": "COVID-19 Immune Altas: Integration of 5 public COVID-19 PBMC single-cell datasets", + "url": "https://cellxgene.cziscience.com/e/96a3f64b-0ee9-40d8-91e9-813ce38261c9.cxg/" + }, + "bc2a7b3d-f04e-477e-96c9-9d5367d5425c": { + "name": "T Cell and NK Cell Subtypes of COVID-19 Immune Altas: Integration of 5 public COVID-19 PBMC single-cell datasets", + "url": "https://cellxgene.cziscience.com/e/bc2a7b3d-f04e-477e-96c9-9d5367d5425c.cxg/" + }, + "055ca631-6ffb-40de-815e-b931e10718c0": { + "name": "Individual Single-Cell RNA-seq PBMC Data from Wilk et al.", + "url": "https://cellxgene.cziscience.com/e/055ca631-6ffb-40de-815e-b931e10718c0.cxg/" + }, + "ae5341b8-60fb-4fac-86db-86e49ee66287": { + "name": "Individual Single-Cell RNA-seq PBMC Data from Guo et al.", + "url": "https://cellxgene.cziscience.com/e/ae5341b8-60fb-4fac-86db-86e49ee66287.cxg/" + }, + "5e717147-0f75-4de1-8bd2-6fda01b8d75f": { + "name": "Individual Single-Cell RNA-seq PBMC Data from Schulte-Schrepping et al.", + "url": "https://cellxgene.cziscience.com/e/5e717147-0f75-4de1-8bd2-6fda01b8d75f.cxg/" + }, + "01ad3cd7-3929-4654-84c0-6db05bd5fd59": { + "name": "Type I interferon autoantibodies are associated with systemic immune alterations in patients with COVID-19", + "url": "https://cellxgene.cziscience.com/e/01ad3cd7-3929-4654-84c0-6db05bd5fd59.cxg/" + }, + "ed5d841d-6346-47d4-ab2f-7119ad7e3a35": { + "name": "nygc multimodal pbmc", + "url": "https://cellxgene.cziscience.com/e/ed5d841d-6346-47d4-ab2f-7119ad7e3a35.cxg/" + }, + "c7775e88-49bf-4ba2-a03b-93f00447c958": { + "name": "Single-cell multi-omics analysis of the immune response in COVID-19", + 
"url": "https://cellxgene.cziscience.com/e/c7775e88-49bf-4ba2-a03b-93f00447c958.cxg/" + }, + "30cd5311-6c09-46c9-94f1-71fe4b91813c": { + "name": "Time-resolved Systems Immunology Reveals a Late Juncture Linked to Fatal COVID-19: Innate Cells", + "url": "https://cellxgene.cziscience.com/e/30cd5311-6c09-46c9-94f1-71fe4b91813c.cxg/" + }, + "c874f155-9bf9-4928-b821-f52c876b3e48": { + "name": "49 years old male - Fresh PBMCs (1 day post-intubation)", + "url": "https://cellxgene.cziscience.com/e/c874f155-9bf9-4928-b821-f52c876b3e48.cxg/" + }, + "8a554710-08bc-4005-87cd-da9675bdc2e7": { + "name": "82 years old female - Fresh PBMCs (1 day post-intubation)", + "url": "https://cellxgene.cziscience.com/e/8a554710-08bc-4005-87cd-da9675bdc2e7.cxg/" + }, + "881fe679-c6e0-45a3-9427-c4e81be6921f": { + "name": "66 years old female - Fresh PBMCs (2 days post-intubation)", + "url": "https://cellxgene.cziscience.com/e/881fe679-c6e0-45a3-9427-c4e81be6921f.cxg/" + }, + "eeacb0c1-2217-4cf6-b8ce-1f0fedf1b569": { + "name": "49 years old male - Fresh PBMCs (3 days post-intubation)", + "url": "https://cellxgene.cziscience.com/e/eeacb0c1-2217-4cf6-b8ce-1f0fedf1b569.cxg/" + }, + "ed9e9f96-4f08-49d2-bef5-b2c29adf3edc": { + "name": "66 years old female - Fresh PBMCs (4 days post-intubation)", + "url": "https://cellxgene.cziscience.com/e/ed9e9f96-4f08-49d2-bef5-b2c29adf3edc.cxg/" + }, + "01c93cf6-b695-4e30-a26e-121ae8b16a9e": { + "name": "66 years old female - Fresh PBMCs (7 days post-intubation)", + "url": "https://cellxgene.cziscience.com/e/01c93cf6-b695-4e30-a26e-121ae8b16a9e.cxg/" + }, + "db59611b-42de-4035-93aa-1ed39f38b467": { + "name": "49 years old male - Fresh PBMCs (2 days post-intubation)", + "url": "https://cellxgene.cziscience.com/e/db59611b-42de-4035-93aa-1ed39f38b467.cxg/" + }, + "ea786a06-5855-48b7-80d7-0313a21a2044": { + "name": "66 years old female - Fresh PBMCs (3 days post-intubation)", + "url": "https://cellxgene.cziscience.com/e/ea786a06-5855-48b7-80d7-0313a21a2044.cxg/" + }, + "84230ea4-998d-4aa8-8456-81dd54ce23af": { + "name": "74 years old female - Fresh PBMCs (3 days post-intubation)", + "url": "https://cellxgene.cziscience.com/e/84230ea4-998d-4aa8-8456-81dd54ce23af.cxg/" + }, + "50eb1e23-b8d4-4f76-a184-44e5541fa05a": { + "name": "74 years old female - Fresh PBMCs (8 days post-intubation)", + "url": "https://cellxgene.cziscience.com/e/50eb1e23-b8d4-4f76-a184-44e5541fa05a.cxg/" + }, + "79ef1959-a6b4-4cac-82ca-30feaec48df1": { + "name": "74 years old female - Fresh PBMCs (7 days post-intubation)", + "url": "https://cellxgene.cziscience.com/e/79ef1959-a6b4-4cac-82ca-30feaec48df1.cxg/" + }, + "9dbab10c-118d-496b-966a-67f1763a6b7d": { + "name": "Large-scale single-cell analysis reveals critical immune characteristics of COVID-19 patients", + "url": "https://cellxgene.cziscience.com/e/9dbab10c-118d-496b-966a-67f1763a6b7d.cxg/" + }, + "krasnow_lab_human_lung_cell_atlas_10x-1-remixed": { + "name": "Krasnow Lab Human Lung Cell Atlas, 10X", + "url": "https://cellxgene.cziscience.com/e/krasnow_lab_human_lung_cell_atlas_10x-1-remixed.cxg/" + }, + "krasnow_lab_human_lung_cell_atlas_smartseq2-2-remixed": { + "name": "Krasnow Lab Human Lung Cell Atlas, Smart-seq2", + "url": "https://cellxgene.cziscience.com/e/krasnow_lab_human_lung_cell_atlas_smartseq2-2-remixed.cxg/" + }, + "c2a461b1-0c15-4047-9fcb-1f966fe55100": { + "name": "Autoimmunity PBMCs", + "url": "https://cellxgene.cziscience.com/e/c2a461b1-0c15-4047-9fcb-1f966fe55100.cxg/" + }, + "fa8605cf-f27e-44af-ac2a-476bee4410d3": { + "name": "PBMCs", + "url": 
"https://cellxgene.cziscience.com/e/fa8605cf-f27e-44af-ac2a-476bee4410d3.cxg/" + }, + "Single_cell_atlas_of_peripheral_immune_response_to_SARS_CoV_2_infection": { + "name": "Single-cell atlas of peripheral immune response to SARS-CoV-2 infection", + "url": "https://cellxgene.cziscience.com/e/Single_cell_atlas_of_peripheral_immune_response_to_SARS_CoV_2_infection.cxg/" + }, + "human_cell_landscape": { + "name": "Construction of a human cell landscape at single-cell level", + "url": "https://cellxgene.cziscience.com/e/human_cell_landscape.cxg/" + }, + "01209dce-3575-4bed-b1df-129f57fbc031": { + "name": "Single-cell transcriptomics of human T cells reveals tissue and activation signatures in health and disease", + "url": "https://cellxgene.cziscience.com/e/01209dce-3575-4bed-b1df-129f57fbc031.cxg/" + }, + "5bc42b88-bb76-4954-927b-8bb7369adc64": { + "name": "Pregnant Uterus (All)", + "url": "https://cellxgene.cziscience.com/e/5bc42b88-bb76-4954-927b-8bb7369adc64.cxg/" + }, + "de2c780c-1747-40bd-9ccf-9588ec186cee": { + "name": "Immunophenotyping of COVID-19 and influenza highlights the role of type I interferons in development of severe COVID-19", + "url": "https://cellxgene.cziscience.com/e/de2c780c-1747-40bd-9ccf-9588ec186cee.cxg/" + } +} diff --git a/Singlecell_multi_omics/src/core/data/cellxgene/process_allcounts.py b/Singlecell_multi_omics/src/core/data/cellxgene/process_allcounts.py new file mode 100644 index 0000000000000000000000000000000000000000..09943395d5cfd7c20a892c0b00677a3dee239f56 --- /dev/null +++ b/Singlecell_multi_omics/src/core/data/cellxgene/process_allcounts.py @@ -0,0 +1,315 @@ +import argparse +from pathlib import Path +from scgpt.tokenizer import GeneVocab, random_mask_value +import sys +from datasets import Dataset, load_dataset +import os + +sys.path.insert(0, "../") + +parser = argparse.ArgumentParser() +parser.add_argument( + "-d", + "--data-source", + type=str, + required=True, + help='The name of the data source (currently support "scvi" datasets), or the ' + "path to the data file.", +) +parser.add_argument( + "-s", + "--save-dir", + type=str, + required=True, + help="The directory to save the trained model and the results.", +) +parser.add_argument( + "--load-model", + type=str, + default=None, + help="The directory containing the model and configs to load and continue training.", +) + +# settings for data +parser.add_argument( + "--n-hvg", + type=int, + default=None, + help="The number of highly variable genes. If set to 0, will use all genes. " + "Default is None, which will determine the n_hvg automatically.", +) +parser.add_argument( + "--valid-size-or-ratio", + type=float, + default=0.1, + help="The ratio or size of the validation set size if split the dataset. " + "If value is between 0 and 1, will be parsed as the ratio. If value is " + "greater than 1 and be an integer, will be parsed as the size. If value " + "is 0, will not split the dataset.", +) + +parser.add_argument( + "--grad-accu-steps", + type=int, + default=1, + help="The number of gradient accumulation steps. Default is 1.", +) + +# settings for tokenizer +parser.add_argument( + "--pad-token", + type=str, + default="", + help="The token to use for padding. Default is .", +) +parser.add_argument( + "--input-style", + type=str, + choices=["normed_raw", "log1p", "binned"], + default="binned", + help="The style of the input data. 
Default is binned.", +) +parser.add_argument( + "--input-emb-style", + type=str, + choices=["category", "continuous", "scaling"], + default="continuous", + help="The style of the input embedding. Default is continuous.", +) +parser.add_argument( + "--n-bins", + type=int, + default=51, + help="The number of bins to use for the binned input style. Default is 51.", +) +parser.add_argument( + "--max-seq-len", + type=int, + default=1536, + help="The maximum length of the sequence. Default is 1000. The actual used " + "max length would be the minimum of this value and the length of the longest " + "sequence in the data.", +) +# omit the args for MLM and MVC, will always use them by default +parser.add_argument( + "--training-tasks", # choices of "mlm", "gen", "both" + type=str, + default="both", + choices=["pcpt", "gen", "both"], + help="The tasks to use for training. pcpt: perception training with maked token " + "learning. gen: generation. Default is both.", +) +parser.add_argument( + "--mask-ratio", + type=float, + default=0.40, + help="The ratio of masked values in the training data. Default is 0.40. This" + "value will be ignored if --training-tasks is set to gen or both.", +) +parser.add_argument( + "--trunc-by-sample", + action="store_true", + help="Whether to truncate the input by sampling rather than cutting off if " + "sequence length > max_seq_length. Default is False.", +) +parser.add_argument( + "--vocab-path", + type=str, + help="Path to the vocabulary file.", +) +# settings for training +parser.add_argument( + "--local-rank", + type=int, + default=-1, + help="The local rank of the process for using the torch.distributed.launch " + "utility. Will be -1 if not running in distributed model.", +) +parser.add_argument( + "--batch-size", + type=int, + default=32, + help="The batch size for training. Default is 32.", +) +parser.add_argument( + "--eval-batch-size", + type=int, + default=32, + help="The batch size for evaluation. Default is 32.", +) +parser.add_argument( + "--epochs", + type=int, + default=10, + help="The number of epochs for training.", +) +parser.add_argument( + "--lr", + type=float, + default=1e-3, + help="The learning rate for training. Default is 1e-3.", +) +parser.add_argument( + "--scheduler-interval", + type=int, + default=100, + help="The interval iterations for updating the learning rate. Default is 100. " + "This will only be used when warmup-ratio is 0.", +) +parser.add_argument( + "--scheduler-factor", + type=float, + default=0.99, + help="The factor for updating the learning rate. Default is 0.99. " + "This will only be used when warmup-ratio is 0.", +) +parser.add_argument( + "--warmup-ratio-or-step", + type=float, + default=0.1, + help="The ratio of warmup steps out of the total training steps. Default is 0.1. " + "If warmup-ratio is above 0, will use a cosine scheduler with warmup. If " + "the value is above 1, will use it as the number of warmup steps.", +) +parser.add_argument( + "--no-cls", + action="store_true", + help="Whether to deactivate the classification loss. Default is False.", +) +parser.add_argument( + "--no-cce", + action="store_true", + help="Whether to deactivate the contrastive cell embedding objective. " + "Default is False.", +) +parser.add_argument( + "--fp16", + action="store_true", + help="Whether to train in automatic mixed precision. Default is False.", +) +parser.add_argument( + "--fast-transformer", + type=bool, + default=True, + help="Whether to use the fast transformer. 
Default is True.", +) + +# settings for model +parser.add_argument( + "--nlayers", + type=int, + default=4, + help="The number of layers for the transformer. Default is 4.", +) +parser.add_argument( + "--nheads", + type=int, + default=4, + help="The number of heads for the transformer. Default is 4.", +) +parser.add_argument( + "--embsize", + type=int, + default=64, + help="The embedding size for the transformer. Default is 64.", +) +parser.add_argument( + "--d-hid", + type=int, + default=64, + help="dimension of the feedforward network model in the transformer. " + "Default is 64.", +) +parser.add_argument( + "--dropout", + type=float, + default=0.2, + help="The dropout rate. Default is 0.2.", +) +parser.add_argument( + "--n-layers-cls", + type=int, + default=3, + help="The number of layers for the classification network, including the " + "output layer. Default is 3.", +) + +# settings for logging +parser.add_argument( + "--log-interval", + type=int, + default=100, + help="The interval for logging. Default is 100.", +) +parser.add_argument( + "--save-interval", + type=int, + default=1000, + help="The interval for saving the model. Default is 1000.", +) + +args = parser.parse_args() +# args.pad_value = -2 + +if args.input_style == "binned": + if args.input_emb_style == "scaling": + raise ValueError("input_emb_style `scaling` is not supported for binned input.") +elif args.input_style == "log1p" or args.input_style == "normed_raw": + if args.input_emb_style == "category": + raise ValueError( + "input_emb_style `category` is not supported for log1p or normed_raw input." + ) + +if args.input_emb_style == "category": + args.mask_value = args.n_bins + 1 + args.pad_value = args.n_bins # for padding gene expr values + n_input_bins = args.n_bins + 2 +else: + args.mask_value = -1 + args.pad_value = -2 + n_input_bins = args.n_bins + + +def _map_append_cls(dataset: Dataset) -> Dataset: + dataset = dataset.map( + lambda example: { + "genes": [vocab[""]] + example["genes"], + "expressions": [args.pad_value] + example["expressions"], + }, + # batched=True, # not using since then the map func needs to loop + num_proc=len(os.sched_getaffinity(0)), + ) + + return dataset + + +special_tokens = [args.pad_token, "", ""] + +parquet_files = [str(f) for f in Path(args.data_source).glob("*.parquet")] +cache_dir = Path(args.data_source).parent / "cache" +vocab = GeneVocab.from_file(Path(args.vocab_path)) +for s in special_tokens: + if s not in vocab: + vocab.append_token(s) + + +# load or make the dataset w/ appended at the beginning +cls_prefix_datatable = Path(args.data_source) / "cls_prefix_data.parquet" +if not cls_prefix_datatable.exists(): + print("preparing cls prefix dataset") + raw_dataset = load_dataset( + "parquet", + data_files=parquet_files, + split="train", + cache_dir=str(cache_dir), + ) + raw_dataset = _map_append_cls(raw_dataset) + raw_dataset.to_parquet(str(cls_prefix_datatable)) +raw_dataset = load_dataset( + "parquet", + data_files=str(cls_prefix_datatable), + split="train", + cache_dir=str(cache_dir), +) + +# others, pancreas, lung, kidney, heart, blood diff --git a/Singlecell_multi_omics/src/core/data/cellxgene/query_list.txt b/Singlecell_multi_omics/src/core/data/cellxgene/query_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..19232a3ef7c8fb8458c16698b0bf2da052a7ed4d --- /dev/null +++ b/Singlecell_multi_omics/src/core/data/cellxgene/query_list.txt @@ -0,0 +1,9 @@ +heart +blood +brain +lung +kidney +intestine +pancreas +others +pan-cancer diff --git 
a/Singlecell_multi_omics/src/core/inference.py b/Singlecell_multi_omics/src/core/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..48ef587c0e423f914a4e2025a030bff6c4448085 --- /dev/null +++ b/Singlecell_multi_omics/src/core/inference.py @@ -0,0 +1,131 @@ +import copy + +import matplotlib.pyplot as plt +import torch + +from scMVP.inference import Trainer + +plt.switch_backend("agg") + + +class UnsupervisedTrainer(Trainer): + r"""The VariationalInference class for the unsupervised training of an autoencoder. + + Args: + :model: A model instance from class ``VAE``, ``VAEC``, + :gene_dataset: A gene_dataset instance like ``snareDataset()`` + :train_size: The train size, either a float between 0 and 1 or an integer for the number of training samples + to use Default: ``0.8``. + :test_size: The test size, either a float between 0 and 1 or an integer for the number of training samples + to use Default: ``None``, which is equivalent to data not in the train set. If ``train_size`` and ``test_size`` + do not add to 1 or the length of the dataset then the remaining samples are added to a ``validation_set``. + :n_epochs_kl_warmup: Number of epochs for linear warmup of KL(q(z|x)||p(z)) term. After `n_epochs_kl_warmup`, + the training objective is the ELBO. This might be used to prevent inactivity of latent units, and/or to + improve clustering of latent space, as a long warmup turns the model into something more of an autoencoder. + :normalize_loss: A boolean determining whether the loss is divided by the total number of samples used for + training. In particular, when the global KL divergence is equal to 0 and the division is performed, the loss + for a minibatchis is equal to the average of reconstruction losses and KL divergences on the minibatch. + Default: ``None``, which is equivalent to setting False when the model is an instance from class + ``AutoZIVAE`` and True otherwise. + :\*\*kwargs: Other keywords arguments from the general Trainer class. + + Examples: + >>> gene_dataset = snareDataset() + >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False, + ... n_labels=gene_dataset.n_labels) + + >>> infer = VariationalInference(gene_dataset, vae, train_size=0.5) + >>> infer.train(n_epochs=20, lr=1e-3) + """ + default_metrics_to_monitor = ["elbo"] + + def __init__( + self, + model, + gene_dataset, + train_size=0.8, + test_size=None, + n_epochs_kl_warmup=400, + normalize_loss=None, + **kwargs + ): + super().__init__(model, gene_dataset, **kwargs) + self.n_epochs_kl_warmup = n_epochs_kl_warmup + + self.normalize_loss = ( + not ( + hasattr(self.model, "reconstruction_loss") + and self.model.reconstruction_loss == "autozinb" + ) + if normalize_loss is None + else normalize_loss + ) + + # Total size of the dataset used for training + # (e.g. training set in this class but testing set in AdapterTrainer). + # It used to rescale minibatch losses (cf. eq. 
(8) in Kingma et al., Auto-Encoding Variational Bayes, iCLR 2013) + self.n_samples = 1.0 + + if type(self) is UnsupervisedTrainer: + ( + self.train_set, + self.test_set, + self.validation_set, + ) = self.train_test_validation(model, gene_dataset, train_size, test_size) + self.train_set.to_monitor = ["elbo"] + self.test_set.to_monitor = ["elbo"] + self.validation_set.to_monitor = ["elbo"] + self.n_samples = len(self.train_set.indices) + + @property + def posteriors_loop(self): + return ["train_set"] + + def loss(self, tensors): + sample_batch, local_l_mean, local_l_var, batch_index, y = tensors + #reconst_loss, kl_divergence_local, kl_divergence_global = self.model( + # sample_batch, local_l_mean, local_l_var, batch_index, y + #) + reconst_loss, kl_divergence_local, kl_divergence_global = self.model( + sample_batch, local_l_mean, local_l_var, batch_index, batch_index + ) + loss = ( + self.n_samples + * torch.mean(reconst_loss + self.kl_weight * kl_divergence_local) + + kl_divergence_global + ) + if self.normalize_loss: + loss = loss / self.n_samples + return loss + + def on_epoch_begin(self): + if self.n_epochs_kl_warmup is not None: + self.kl_weight = min(1, self.epoch / self.n_epochs_kl_warmup) + else: + self.kl_weight = 1.0 + + +class AdapterTrainer(UnsupervisedTrainer): + def __init__(self, model, gene_dataset, posterior_test, frequency=5): + super().__init__(model, gene_dataset, frequency=frequency) + self.test_set = posterior_test + self.test_set.to_monitor = ["elbo"] + self.params = list(self.model.z_encoder.parameters()) + list( + self.model.l_encoder.parameters() + ) + self.z_encoder_state = copy.deepcopy(model.z_encoder.state_dict()) + self.l_encoder_state = copy.deepcopy(model.l_encoder.state_dict()) + self.n_scale = len(self.test_set.indices) + + @property + def posteriors_loop(self): + return ["test_set"] + + def train(self, n_path=10, n_epochs=50, **kwargs): + for i in range(n_path): + # Re-initialize to create new path + self.model.z_encoder.load_state_dict(self.z_encoder_state) + self.model.l_encoder.load_state_dict(self.l_encoder_state) + super().train(n_epochs, params=self.params, **kwargs) + + return min(self.history["elbo_test_set"]) diff --git a/Singlecell_multi_omics/src/core/multi_inference.py b/Singlecell_multi_omics/src/core/multi_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..0b341b5609fd4cda8ba2a83a97862526ecb2ca1e --- /dev/null +++ b/Singlecell_multi_omics/src/core/multi_inference.py @@ -0,0 +1,538 @@ +from typing import Optional +import logging +import torch +from torch.distributions import Poisson, Gamma, Bernoulli, Normal +from torch.utils.data import DataLoader +import torch.nn.functional as F +from torch import logsumexp +import torch.distributions as distributions +import numpy as np + +from scMVP.inference import Posterior +from . import UnsupervisedTrainer + +from scMVP.dataset import GeneExpressionDataset +from scMVP.models import multi_vae_attention +from sklearn.utils.linear_assignment_ import linear_assignment + +logger = logging.getLogger(__name__) + + + + + +class MultiPosterior(Posterior): + r"""The functional data unit for Multivae. A `MultiPosterior` instance is instantiated with a model and + a gene_dataset, and as well as additional arguments that for Pytorch's `DataLoader`. A subset of indices + can be specified, for purposes such as splitting the data into train/test/validation. Each trainer instance of the `Trainer` class can therefore have multiple + `MultiPosterior` instances to train a model. 
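+    Compared with a plain `Posterior`, the data loader built here additionally yields the
+    `atac_expression` tensor (loaded as ``np.float32`` through a custom ``collate_fn``), so each
+    minibatch carries both RNA and ATAC counts.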
A `MultiPosterior` instance also comes with many methods or + utilities for its corresponding data. + + + :param model: A model instance from class ``Multivae`` + :param gene_dataset: A gene_dataset instance like ``ATACDataset()`` with attribute ``ATAC_expression`` + :param shuffle: Specifies if a `RandomSampler` or a `SequentialSampler` should be used + :param indices: Specifies how the data should be split with regards to train/test or labelled/unlabelled + :param use_cuda: Default: ``True`` + :param data_loader_kwarg: Keyword arguments to passed into the `DataLoader` + + Examples: + + Let us instantiate a `trainer`, with a gene_dataset and a model + + >>> gene_dataset = CbmcDataset() + >>> totalvi = TOTALVI(gene_dataset.nb_genes, len(gene_dataset.protein_names), + ... n_batch=gene_dataset.n_batches * False, n_labels=gene_dataset.n_labels, use_cuda=True) + >>> trainer = TotalTrainer(vae, gene_dataset) + >>> trainer.train(n_epochs=400) + """ + + def __init__( + self, + model: multi_vae_attention, + gene_dataset: GeneExpressionDataset, + shuffle: bool = False, + indices: Optional[np.ndarray] = None, + use_cuda: bool = True, + data_loader_kwargs=dict(), + ): + + super().__init__( + model, + gene_dataset, + shuffle=shuffle, + indices=indices, + use_cuda=use_cuda, + data_loader_kwargs=data_loader_kwargs, + ) + # Add atac tensor as another tensor to be loaded + self.data_loader_kwargs.update( + { + "collate_fn": gene_dataset.collate_fn_builder( + {"atac_expression": np.float32}# debug cell index + ) + } + ) + + self.data_loader = DataLoader(gene_dataset, **self.data_loader_kwargs) + + def corrupted(self): + return self.update( + { + "collate_fn": self.gene_dataset.collate_fn_builder( + {"atac_expression": np.float32}, corrupted=True + ) + } + ) + + def uncorrupted(self): + return self.update( + { + "collate_fn": self.gene_dataset.collate_fn_builder( + {"atac_expression": np.float32} + ) + } + ) + + @torch.no_grad() + def elbo(self): + elbo = self.compute_elbo(self.model) + logger.debug("ELBO : %.4f" % elbo) + return elbo + elbo.mode = "min" + + @torch.no_grad() + def reconstruction_error(self): + reconstruction_error = self.compute_reconstruction_error(self.model, self) + logger.debug("Reconstruction Error : %.4f" % reconstruction_error) + return reconstruction_error + + reconstruction_error.mode = "min" + + @torch.no_grad() + def marginal_ll(self, n_mc_samples=1000): + + ll = self.compute_marginal_log_likelihood(self.model, self, n_mc_samples) + logger.debug("True LL : %.4f" % ll) + return ll + + def compute_elbo(self, vae:multi_vae_attention, **kwargs): + """ Computes the ELBO. + + The ELBO is the reconstruction error + the KL divergences + between the variational distributions and the priors. + It differs from the marginal log likelihood. + Specifically, it is a lower bound on the marginal log likelihood + plus a term that is constant with respect to the variational distribution. + It still gives good insights on the modeling of the data, and is fast to compute. 
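+
+        A minimal usage sketch (hypothetical name: ``trainer`` is an already fitted
+        ``MultiTrainer`` built on this codebase):
+
+        >>> posterior = trainer.test_set
+        >>> posterior.elbo()  # thin wrapper that calls compute_elbo(self.model)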
+ """ + # Iterate once over the posterior and compute the elbo + elbo = 0 + for i_batch, tensors in enumerate(self): + ( + sample_batch_X, + local_l_mean, + local_l_var, + batch_index, + label, + sample_batch_Y, + ) = tensors + + reconst_loss, kl_divergence_local, kl_divergence_global = vae( + sample_batch_X, sample_batch_Y, local_l_mean, local_l_var, batch_index, label + ) + elbo += torch.sum(reconst_loss + kl_divergence_local).item() + n_samples = len(self.indices) + elbo += kl_divergence_global + return elbo / n_samples + + def compute_reconstruction_error(self, vae:multi_vae_attention, **kwargs): + r""" Computes log p(x/z), which is the reconstruction error . + Differs from the marginal log likelihood, but still gives good + insights on the modeling of the data, and is fast to compute + + This is really a helper function to self.ll, self.ll_protein, etc. + """ + # Iterate once over the posterior and computes the total log_likelihood + log_lkl = 0 + for i_batch, tensors in enumerate(self): + sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors[ + :5 + ] # general fish case + + # Distribution parameters + outputs = vae.inference(sample_batch, batch_index, labels, **kwargs) + p_rna_r = outputs["p_rna_r"] + p_rna_rate = outputs["p_rna_rate"] + p_rna_dropout = outputs["p_rna_dropout"] + p_atac_mean = outputs["p_atac_mean"] + p_atac_r = outputs["p_atac_r"] + p_atac_dropout = outputs["p_atac_dropout"] + + # Reconstruction loss + reconst_rna_loss = vae.get_reconstruction_loss( + sample_batch, + p_rna_rate, + p_rna_r, + p_rna_dropout, +# bernoulli_params=bernoulli_params, + **kwargs + ) + reconst_atac_loss = vae.get_reconstruction_atac_loss( + sample_batch, + p_atac_mean, + p_atac_r, + p_atac_dropout, + **kwargs + ) + + log_lkl += torch.sum(reconst_rna_loss).item() + log_lkl += torch.sum(reconst_atac_loss).item() + n_samples = len(self.indices) + return log_lkl / n_samples + + def compute_marginal_log_likelihood(self, vae:multi_vae_attention , n_mc_samples): + """ Computes a biased estimator for log p(x), which is the marginal log likelihood. 
+ + Despite its bias, the estimator still converges to the real value + of log p(x) when n_samples_mc (for Monte Carlo) goes to infinity + (a fairly high value like 100 should be enough) + Due to the Monte Carlo sampling, this method is not as computationally efficient + as computing only the reconstruction loss + """ + # Uses MC sampling to compute a tighter lower bound on log p(x) + + log_lkl = 0 + for i_batch, tensors in enumerate(self): + sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors + to_sum = torch.zeros(sample_batch.size()[0], n_mc_samples) + + for i in range(n_mc_samples): + # Distribution parameters and sampled variables + outputs = vae.inference(sample_batch, batch_index, labels) + p_rna_r = outputs["p_rna_r"] + p_rna_rate = outputs["p_rna_rate"] + p_rna_dropout = outputs["p_rna_dropout"] + qz_m = outputs["qz_m"] + qz_v = outputs["qz_v"] + z = outputs["z"] + p_atac_mean = outputs["p_atac_mean"] + p_atac_r = outputs["p_atac_r"] + p_atac_dropout = outputs["p_atac_dropout"] + mu_c = outputs["mu_c"] + var_c = outputs["var_c"] + gamma = outputs["gamma"] + mu_c_max = outputs["mu_c_max"], + var_c_max = outputs["var_c_max"], + z_c_max = outputs["z_c_max"], + + # Reconstruction Loss + reconst_rna_loss = vae.get_reconstruction_loss( + sample_batch, + p_rna_r, + p_rna_rate, + p_rna_dropout, + ) + reconst_atac_loss = vae.get_reconstruction_atac_loss( + sample_batch, + p_atac_r, + p_atac_mean, + p_atac_dropout, + ) + + # Log-probabilities + #p_l = Normal(local_l_mean, local_l_var.sqrt()).log_prob(library).sum(dim=-1) + p_z = 0.0 + for prob, mu, var in mu_c, var_c, gamma: + p_z += prob*Normal(mu, var.sqrt()).log_prob(z).sum(dim=-1) + + p_x_zl = -reconst_rna_loss - reconst_atac_loss + q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1) + #q_z_max = Normal(mu_c_max, var_c_max.sqrt()).log_prob(z_c_max).sum(dim=-1) + + to_sum[:, i] = p_z + p_x_zl - q_z_x #- q_z_max + + batch_log_lkl = logsumexp(to_sum, dim=-1) - np.log(n_mc_samples) + log_lkl += torch.sum(batch_log_lkl).item() + + n_samples = len(self.indices) + # The minus sign is there because we actually look at the negative log likelihood + return -log_lkl / n_samples + + @torch.no_grad() + def get_latent(self, sample=False): + """ + Output posterior z mean or sample, batch index, and label + :param sample: z mean or z sample + :return: three np.ndarrays, latent, batch_indices, labels + """ + latent = [] + latent_rna = []; + latent_atac = []; + batch_indices = [] + labels = [] + cluster_gamma = [] + cluster_index = [] + for tensors in self: + sample_batch_rna, local_l_mean, local_l_var, batch_index, label, sample_batch_atac = tensors + give_mean = not sample + latent_temp = self.model.sample_from_posterior_z( + [sample_batch_rna, sample_batch_atac], y=label, give_mean=give_mean + ) + latent += [ + latent_temp[0][0].cpu() + ] + latent_rna += [ + latent_temp[1][0].cpu() + ] + latent_atac += [ + latent_temp[2].cpu() + ] + gamma, mu_c, var_c, pi = self.model.get_gamma(latent_temp[0][0]) + cluster_gamma += [gamma.cpu()] + cluster_index += [torch.argmax(gamma.cpu(),dim=1)] + batch_indices += [batch_index.cpu()] + labels += [label.cpu()] + return ( + np.array(torch.cat(latent)), + np.array(torch.cat(latent_rna)), + np.array(torch.cat(latent_atac)), + np.array(torch.cat(cluster_gamma)), + np.array(torch.cat(cluster_index)), + np.array(torch.cat(batch_indices)), + np.array(torch.cat(labels)).ravel(), + ) + + @torch.no_grad() + def generate( + self, + n_samples: int = 100, + genes: Optional[np.ndarray] = None, + 
batch_size: int = 256, + #batch_size: int = 128, + ) : + """ + Create observation samples from the Posterior Predictive distribution + + :param n_samples: Number of required samples for each cell + :param genes: Indices of genes of interest + :param batch_size: Desired Batch size to generate data + + :return: Tuple (x_new, x_old) + Where x_old has shape (n_cells, n_genes) + Where x_new has shape (n_cells, n_genes, n_samples) + """ + assert self.model.reconstruction_loss in ["zinb", "zip"] + zero_inflated = "zinb" + + rna_old = [] + rna_new = [] + atac_old = [] + atac_new = [] + for tensors in self.update({"batch_size": batch_size}): + sample_batch, _, _, batch_index, labels = tensors + outputs = self.model.inference( + sample_batch, batch_index=batch_index, y=labels, n_samples=n_samples + ) + p_rna_r = outputs["p_rna_r"] + p_rna_rate = outputs["p_rna_rate"] + p_rna_dropout = outputs["p_rna_dropout"] + p_atac_mean = outputs["p_atac_mean"] + p_atac_dropout = outputs["p_atac_dropout"] + + # Generating rna-seq data + p = p_rna_rate / (p_rna_rate + p_rna_r) + r = p_rna_r + # Important remark: Gamma is parametrized by the rate = 1/scale! + l_train_rna = distributions.Gamma(concentration=r, rate=(1 - p) / p).sample() + # Clamping as distributions objects can have buggy behaviors when + # their parameters are too high + l_train_rna = torch.clamp(l_train_rna, max=1e8) + gene_expressions = distributions.Poisson( + l_train_rna + ).sample() # Shape : (n_samples, n_cells_batch, n_genes) + + #Generating atac-seq data + l_train_atac = torch.clamp(p_atac_mean, max=1e2) + atac_expressions = distributions.Poisson( + l_train_atac + ).sample() + + # zero-inflate + if zero_inflated: + p_zero_rna = (1.0 + torch.exp(-p_rna_dropout)).pow(-1) + random_prob_rna = torch.rand_like(p_zero_rna) + gene_expressions[random_prob_rna <= p_zero_rna] = 0 + + p_zero_atac = (1.0 + torch.exp(-p_atac_dropout)).pow(-1) + random_prob_atac = torch.rand_like(p_zero_atac) + atac_expressions[random_prob_atac <= p_zero_atac] = 0 + + gene_expressions = gene_expressions.permute( + [1, 2, 0] + ) # Shape : (n_cells_batch, n_genes, n_samples) + atac_expressions = atac_expressions.permute( + [1, 2, 0] + ) + + rna_old.append(sample_batch[0].cpu()) + rna_new.append(gene_expressions.cpu()) + atac_old.append(sample_batch[1].cpu()) + atac_new.append(atac_expressions.cpu()) + + rna_old = torch.cat(rna_old) # Shape (n_cells, n_genes) + rna_new = torch.cat(rna_new) # Shape (n_cells, n_genes, n_samples) + if genes is not None: + gene_ids = self.gene_dataset.genes_to_index(genes) + rna_new = rna_new[:, gene_ids, :] + rna_old = rna_old[:, gene_ids] + return rna_new.numpy(), rna_old.numpy(), atac_new.numpy(), rna_old.numpy() + + @torch.no_grad() + def imputation(self, n_samples: int = 1): + """ Gene imputation + """ + imputed_rna_list = [] + imputed_atac_list = [] + label_list = [] # for the annotated data + atac_list = [] + for tensors in self: + x_rna, local_l_mean, local_l_var, batch_index, label, x_atac = tensors + p_rna_rate, p_atac_rate = self.model.get_sample_rate( + x=[x_rna,x_atac], batch_index=batch_index, y=label, n_samples=n_samples, local_l_mean = local_l_mean, local_l_var = local_l_var + ) + imputed_rna_list += [np.array(p_rna_rate.cpu())] + imputed_atac_list += [np.array(p_atac_rate.cpu())] + label_list += [np.array(label.cpu())] # only for annotated data + atac_list += [np.array(x_atac.cpu())] # for the bins without call peak + imputed_rna_list = np.concatenate(imputed_rna_list) + imputed_atac_list = np.concatenate(imputed_atac_list) + 
label_list = np.concatenate(label_list)  # only for the annotated data
+        atac_list = np.concatenate(atac_list)  # for the bins without called peaks
+        return imputed_rna_list.squeeze(), imputed_atac_list.squeeze(), label_list.squeeze(), atac_list
+
+    @torch.no_grad()
+    def get_sample_scale(self):
+        p_rna_scales = []
+        p_atac_scales = []
+        for tensors in self:
+            x_rna, _, _, batch_index, labels, x_atac = tensors
+            p_rna_scales += [
+                np.array(
+                    self.model.get_sample_scale(
+                        x=[x_rna, x_atac], batch_index=batch_index, y=labels, n_samples=1
+                    )[0]
+                )
+            ]
+            p_atac_scales += [
+                np.array(
+                    self.model.get_sample_scale(
+                        x=[x_rna, x_atac], batch_index=batch_index, y=labels, n_samples=1
+                    )[1]
+                )
+            ]
+        return np.concatenate(p_rna_scales), np.concatenate(p_atac_scales)
+
+    @staticmethod
+    def cluster_acc(Y_pred, Y):
+        # Hungarian matching between predicted clusters and reference labels.
+        assert Y_pred.size == Y.size
+        D = max(Y_pred.max(), Y.max()) + 1
+        w = np.zeros((D, D), dtype=np.int64)
+        for i in range(Y_pred.size):
+            w[Y_pred[i], Y[i]] += 1
+        ind = linear_assignment(w.max() - w)
+        return sum([w[i, j] for i, j in ind]) * 1.0 / Y_pred.size, ind
+
+    def get_clustering(self):
+        (
+            latent,
+            latent_rna,
+            latent_atac,
+            cluster_gamma,
+            cluster_index,
+            batch_indices,
+            labels,
+        ) = self.get_latent()
+        cluster_accuracy, ind = self.cluster_acc(np.argmax(cluster_gamma, axis=1), labels)
+        print('cell dataset multi-vae - clustering accuracy: %.2f%%' % (cluster_accuracy * 100))
+        return cluster_accuracy, ind
+
+
+class MultiTrainer(UnsupervisedTrainer):
+    r"""Trainer for the joint unsupervised training of the RNA+ATAC multi-view autoencoder.
+
+    Args:
+        :model: A model instance from class ``Multivae``
+        :gene_dataset: A gene_dataset instance like ``ATACDataset()`` with attribute ``atac_expression``
+        :train_size: The train size, either a float between 0 and 1 or an integer for the number of
+            training samples to use. Default: ``0.90``.
+        :test_size: The test size, either a float between 0 and 1 or an integer for the number of test
+            samples to use. Default: ``0.05``. Note that if train and test do not add to 1 the remainder
+            is placed in a validation set.
+        :\*\*kwargs: Other keywords arguments from the general Trainer class.
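+
+    Examples:
+        A sketch with illustrative names (assume ``multi_vae`` is a ``multi_vae_attention``
+        model and ``joint_dataset`` a dataset carrying both RNA counts and ``atac_expression``):
+
+        >>> trainer = MultiTrainer(multi_vae, joint_dataset, train_size=0.90, test_size=0.05)
+        >>> trainer.train(n_epochs=100, lr=1e-3)
+        >>> trainer.test_set.elbo()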
+ """ + default_metrics_to_monitor = ["elbo"] + + def __init__( + self, + model, + dataset, + train_size=0.90, + test_size=0.05, + pro_recons_weight=1.0, + n_epochs_back_kl_warmup=50, #200, init + n_epochs_kl_warmup=200, + **kwargs + ): + self.n_genes = dataset.nb_genes + self.n_proteins = model.n_input_atac + + self.pro_recons_weight = pro_recons_weight + self.n_epochs_back_kl_warmup = n_epochs_back_kl_warmup + super().__init__( + model, dataset, n_epochs_kl_warmup=n_epochs_kl_warmup, **kwargs + ) + if type(self) is MultiTrainer: + ( + self.train_set, + self.test_set, + self.validation_set, + ) = self.train_test_validation( + model, dataset, train_size, test_size, type_class=MultiPosterior + ) + self.train_set.to_monitor = [] + self.test_set.to_monitor = ["elbo"] + self.validation_set.to_monitor = ["elbo"] + + def loss(self, tensors): + ( + sample_batch_X, + local_l_mean, + local_l_var, + batch_index, + label, + sample_batch_Y, + ) = tensors + + #reconst_loss, kl_divergence_local, kl_divergence_global = self.model( + # sample_batch_X, sample_batch_Y, local_l_mean, local_l_var, batch_index, label + #) + reconst_loss, kl_divergence_local, kl_divergence_global = self.model( + sample_batch_X, sample_batch_Y, local_l_mean, local_l_var, batch_index, batch_index + ) + loss = ( + self.n_samples + * torch.mean(reconst_loss + self.back_warmup_weight * kl_divergence_local) + + kl_divergence_global + ) + print( + "reconst_loss = %f,kl_divergence_local = %f,kl_weight = %f,loss = %f" % + (torch.mean(reconst_loss), torch.mean(kl_divergence_local), self.back_warmup_weight, loss) + ) + # self.KL_divergence = kl_divergence_global + if self.normalize_loss: + loss = loss / self.n_samples + return loss + + + def on_epoch_begin(self): + super().on_epoch_begin() + if self.n_epochs_back_kl_warmup is not None: + #self.back_warmup_weight = min(1, self.epoch + self.n_epochs_back_kl_warmup / self.n_epochs_back_kl_warmup) + self.back_warmup_weight = min(1, self.epoch + self.n_epochs_back_kl_warmup / self.n_epochs_back_kl_warmup) + else: + self.back_warmup_weight = 1.0 + diff --git a/Singlecell_multi_omics/src/core/posterior.py b/Singlecell_multi_omics/src/core/posterior.py new file mode 100644 index 0000000000000000000000000000000000000000..8514d7ae33779641b8c0617ad43788ac137539ae --- /dev/null +++ b/Singlecell_multi_omics/src/core/posterior.py @@ -0,0 +1,1223 @@ +import copy +import os +import logging + +from typing import List, Optional, Union, Tuple + +import numpy as np +import pandas as pd +import scipy +import torch +import torch.distributions as distributions + +from matplotlib import pyplot as plt +from scipy.stats import kde, entropy +from sklearn.cluster import KMeans +from sklearn.manifold import TSNE +from sklearn.metrics import adjusted_rand_score as ARI +from sklearn.metrics import normalized_mutual_info_score as NMI +from sklearn.metrics import silhouette_score +from sklearn.mixture import GaussianMixture as GMM +from sklearn.neighbors import NearestNeighbors, KNeighborsRegressor +from sklearn.utils.linear_assignment_ import linear_assignment +from torch.utils.data import DataLoader +from torch.utils.data.sampler import ( + SequentialSampler, + SubsetRandomSampler, + RandomSampler, +) + +from scMVP.dataset import GeneExpressionDataset +from scMVP.models.log_likelihood import ( + compute_elbo, + compute_reconstruction_error, + compute_marginal_log_likelihood_scvi, + compute_marginal_log_likelihood_autozi, +) + +logger = logging.getLogger(__name__) + + +class SequentialSubsetSampler(SubsetRandomSampler): 
+ def __iter__(self): + return iter(self.indices) + + +class Posterior: + r"""The functional data unit. A `Posterior` instance is instantiated with a model and a gene_dataset, and + as well as additional arguments that for Pytorch's `DataLoader`. A subset of indices can be specified, for + purposes such as splitting the data into train/test or labelled/unlabelled (for semi-supervised learning). + Each trainer instance of the `Trainer` class can therefore have multiple `Posterior` instances to train a model. + A `Posterior` instance also comes with many methods or utilities for its corresponding data. + + + :param model: A model instance from class ``VAE``, ``VAEC``, ``SCANVI`` + :param gene_dataset: A gene_dataset instance like ``CortexDataset()`` + :param shuffle: Specifies if a `RandomSampler` or a `SequentialSampler` should be used + :param indices: Specifies how the data should be split with regards to train/test or labelled/unlabelled + :param use_cuda: Default: ``True`` + :param data_loader_kwarg: Keyword arguments to passed into the `DataLoader` + + Examples: + + Let us instantiate a `trainer`, with a gene_dataset and a model + + >>> gene_dataset = CortexDataset() + >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False, + ... n_labels=gene_dataset.n_labels, use_cuda=True) + >>> trainer = UnsupervisedTrainer(vae, gene_dataset) + >>> trainer.train(n_epochs=50) + + A `UnsupervisedTrainer` instance has two `Posterior` attributes: `train_set` and `test_set` + For this subset of the original gene_dataset instance, we can examine the differential expression, + log_likelihood, entropy batch mixing, ... or display the TSNE of the data in the latent space through the + scVI model + + >>> trainer.train_set.differential_expression_stats() + >>> trainer.train_set.reconstruction_error() + >>> trainer.train_set.entropy_batch_mixing() + >>> trainer.train_set.show_t_sne(n_samples=1000, color_by="labels") + + """ + + def __init__( + self, + model, + gene_dataset: GeneExpressionDataset, + shuffle=False, + indices=None, + use_cuda=True, + data_loader_kwargs=dict(), + ): + """ + + When added to annotation, has a private name attribute + """ + self.model = model + self.gene_dataset = gene_dataset + self.to_monitor = [] + self.use_cuda = use_cuda + + if indices is not None and shuffle: + raise ValueError("indices is mutually exclusive with shuffle") + if indices is None: + if shuffle: + sampler = RandomSampler(gene_dataset) + else: + sampler = SequentialSampler(gene_dataset) + else: + if hasattr(indices, "dtype") and indices.dtype is np.dtype("bool"): + indices = np.where(indices)[0].ravel() + sampler = SubsetRandomSampler(indices) + self.data_loader_kwargs = copy.copy(data_loader_kwargs) + self.data_loader_kwargs.update( + {"collate_fn": gene_dataset.collate_fn_builder(), "sampler": sampler} + ) + self.data_loader = DataLoader(gene_dataset, **self.data_loader_kwargs) + + def accuracy(self): + pass + + accuracy.mode = "max" + + @property + def indices(self): + if hasattr(self.data_loader.sampler, "indices"): + return self.data_loader.sampler.indices + else: + return np.arange(len(self.gene_dataset)) + + @property + def nb_cells(self): + if hasattr(self.data_loader.sampler, "indices"): + return len(self.data_loader.sampler.indices) + else: + return self.gene_dataset.nb_cells + + def __iter__(self): + return map(self.to_cuda, iter(self.data_loader)) + + def to_cuda(self, tensors): + return [t.cuda() if self.use_cuda else t for t in tensors] + + def update(self, data_loader_kwargs): + 
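+        """Return a copy of this posterior with updated ``DataLoader`` keyword arguments.
+
+        The model and gene_dataset are shared with the original object; only the data
+        loader (sampler, batch size, collate function, ...) is rebuilt.
+        """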
posterior = copy.copy(self) + posterior.data_loader_kwargs = copy.copy(self.data_loader_kwargs) + posterior.data_loader_kwargs.update(data_loader_kwargs) + posterior.data_loader = DataLoader( + self.gene_dataset, **posterior.data_loader_kwargs + ) + return posterior + + #def sequential(self, batch_size=128): + def sequential(self, batch_size=64): + return self.update( + { + "batch_size": batch_size, + "sampler": SequentialSubsetSampler(indices=self.indices), + } + ) + + def corrupted(self): + return self.update( + {"collate_fn": self.gene_dataset.collate_fn_builder(corrupted=True)} + ) + + def uncorrupted(self): + return self.update({"collate_fn": self.gene_dataset.collate_fn_builder()}) + + @torch.no_grad() + def elbo(self): + elbo = compute_elbo(self.model, self) + logger.debug("ELBO : %.4f" % elbo) + return elbo + + elbo.mode = "min" + + @torch.no_grad() + def reconstruction_error(self): + reconstruction_error = compute_reconstruction_error(self.model, self) + logger.debug("Reconstruction Error : %.4f" % reconstruction_error) + return reconstruction_error + + reconstruction_error.mode = "min" + + @torch.no_grad() + def marginal_ll(self, n_mc_samples=1000): + if ( + hasattr(self.model, "reconstruction_loss") + and self.model.reconstruction_loss == "autozinb" + ): + ll = compute_marginal_log_likelihood_autozi(self.model, self, n_mc_samples) + else: + ll = compute_marginal_log_likelihood_scvi(self.model, self, n_mc_samples) + logger.debug("True LL : %.4f" % ll) + return ll + + @torch.no_grad() + def get_latent(self, sample=False): + """ + Output posterior z mean or sample, batch index, and label + :param sample: z mean or z sample + :return: three np.ndarrays, latent, batch_indices, labels + """ + latent = [] + batch_indices = [] + labels = [] + for tensors in self: + sample_batch, local_l_mean, local_l_var, batch_index, label = tensors + give_mean = not sample + latent += [ + self.model.sample_from_posterior_z( + sample_batch, give_mean=give_mean + ).cpu() + ] + batch_indices += [batch_index.cpu()] + labels += [label.cpu()] + return ( + np.array(torch.cat(latent)), + np.array(torch.cat(batch_indices)), + np.array(torch.cat(labels)).ravel(), + ) + + @torch.no_grad() + def entropy_batch_mixing(self, **kwargs): + if self.gene_dataset.n_batches == 2: + latent, batch_indices, labels = self.get_latent() + be_score = entropy_batch_mixing(latent, batch_indices, **kwargs) + logger.debug("Entropy batch mixing : {}".format(be_score)) + return be_score + + entropy_batch_mixing.mode = "max" + + @torch.no_grad() + def differential_expression_stats(self, M_sampling=100): + """ + Output average over statistics in a symmetric way (a against b), forget the sets if permutation is True + + :param M_sampling: number of samples + :return: Tuple px_scales, all_labels where (i) px_scales: scales of shape (M_sampling, n_genes) + (ii) all_labels: labels of shape (M_sampling, ) + """ + px_scales = [] + all_labels = [] + batch_size = max( + self.data_loader_kwargs["batch_size"] // M_sampling, 2 + ) # Reduce batch_size on GPU + if len(self.gene_dataset) % batch_size == 1: + batch_size += 1 + for tensors in self.update({"batch_size": batch_size}): + sample_batch, _, _, batch_index, labels = tensors + px_scales += [ + np.array( + ( + self.model.get_sample_scale( + sample_batch, + batch_index=batch_index, + y=labels, + n_samples=M_sampling, + ) + ).cpu() + ) + ] + + # Align the sampling + if M_sampling > 1: + px_scales[-1] = (px_scales[-1].transpose((1, 0, 2))).reshape( + -1, px_scales[-1].shape[-1] + ) + all_labels += 
[np.array((labels.repeat(1, M_sampling).view(-1, 1)).cpu())] + + px_scales = np.concatenate(px_scales) + all_labels = np.concatenate(all_labels).ravel() # this will be used as boolean + + return px_scales, all_labels + + @torch.no_grad() + def sample_scale_from_batch(self, n_samples, batchid=None, selection=None): + px_scales = [] + if selection is None: + raise ValueError("selections should be a list of cell subsets indices") + else: + if selection.dtype is np.dtype("bool"): + selection = np.asarray(np.where(selection)[0].ravel()) + old_loader = self.data_loader + for i in batchid: + idx = np.random.choice( + np.arange(len(self.gene_dataset))[selection], n_samples + ) + sampler = SubsetRandomSampler(idx) + self.data_loader_kwargs.update({"sampler": sampler}) + self.data_loader = DataLoader(self.gene_dataset, **self.data_loader_kwargs) + px_scales.append(self.get_harmonized_scale(i)) + self.data_loader = old_loader + px_scales = np.concatenate(px_scales) + return px_scales + + @torch.no_grad() + def differential_expression_score( + self, + idx1: Union[List[bool], np.ndarray], + idx2: Union[List[bool], np.ndarray], + batchid1: Optional[Union[List[int], np.ndarray]] = None, + batchid2: Optional[Union[List[int], np.ndarray]] = None, + genes: Optional[Union[List[str], np.ndarray]] = None, + n_samples: int = None, + sample_pairs: bool = True, + M_permutation: int = None, + all_stats: bool = True, + ): + """ + Computes gene specific Bayes factors using masks idx1 and idx2 + + To that purpose we sample the Posterior in the following way: + 1. The posterior is sampled n_samples times for each subpopulation + 2. For computation efficiency (posterior sampling is quite expensive), instead of + comparing element-wise the obtained samples, we can permute posterior samples. + Remember that computing the Bayes Factor requires sampling + q(z_A | x_A) and q(z_B | x_B) + + :param idx1: bool array masking subpopulation cells 1. Should be True where cell is + from associated population + :param idx2: bool array masking subpopulation cells 2. Should be True where cell is + from associated population + :param batchid1: List of batch ids for which you want to perform DE Analysis for + subpopulation 1. By default, all ids are taken into account + :param batchid2: List of batch ids for which you want to perform DE Analysis for + subpopulation 2. By default, all ids are taken into account + :param genes: list Names of genes for which Bayes factors will be computed + :param n_samples: Number of times the posterior will be sampled for each pop + :param sample_pairs: Activates step 2 described above. + Simply formulated, pairs obtained from posterior sampling (when calling + `sample_scale_from_batch`) will be randomly permuted so that the number of + pairs used to compute Bayes Factors becomes M_permutation. + :param M_permutation: Number of times we will "mix" posterior samples in step 2. + Only makes sense when sample_pairs=True + :param all_stats: If False returns Bayes factors alone + else, returns not only Bayes Factor of population 1 vs population 2 but other metrics as + well, mostly used for sanity checks, such as (i) Bayes Factors of 2 vs 1 and (ii) + Bayes factors obtained when shuffled indices (iii) Gene expression statistics (mean, scale ...) 
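+
+        Example (a sketch; ``post`` stands for any fitted ``Posterior`` and the two
+        boolean masks are illustrative, here derived from the stored labels):
+
+        >>> labels = post.gene_dataset.labels.ravel()
+        >>> de_table = post.differential_expression_score(
+        ...     idx1=(labels == 0), idx2=(labels == 1), n_samples=500
+        ... )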
+ :return: + """ + + n_samples = 5000 if n_samples is None else n_samples + M_permutation = 10000 if M_permutation is None else M_permutation + if batchid1 is None: + batchid1 = np.arange(self.gene_dataset.n_batches) + if batchid2 is None: + batchid2 = np.arange(self.gene_dataset.n_batches) + px_scale1 = self.sample_scale_from_batch( + selection=idx1, batchid=batchid1, n_samples=n_samples + ) + px_scale2 = self.sample_scale_from_batch( + selection=idx2, batchid=batchid2, n_samples=n_samples + ) + px_scale_mean1 = px_scale1.mean(axis=0) + px_scale_mean2 = px_scale2.mean(axis=0) + px_scale = np.concatenate((px_scale1, px_scale2), axis=0) + all_labels = np.concatenate( + (np.repeat(0, len(px_scale1)), np.repeat(1, len(px_scale2))), axis=0 + ) + if genes is not None: + px_scale = px_scale[:, self.gene_dataset.genes_to_index(genes)] + bayes1 = get_bayes_factors( + px_scale, + all_labels, + cell_idx=0, + M_permutation=M_permutation, + permutation=False, + sample_pairs=sample_pairs, + ) + if all_stats is True: + bayes1_permuted = get_bayes_factors( + px_scale, + all_labels, + cell_idx=0, + M_permutation=M_permutation, + permutation=True, + sample_pairs=sample_pairs, + ) + bayes2 = get_bayes_factors( + px_scale, + all_labels, + cell_idx=1, + M_permutation=M_permutation, + permutation=False, + sample_pairs=sample_pairs, + ) + bayes2_permuted = get_bayes_factors( + px_scale, + all_labels, + cell_idx=1, + M_permutation=M_permutation, + permutation=True, + sample_pairs=sample_pairs, + ) + ( + mean1, + mean2, + nonz1, + nonz2, + norm_mean1, + norm_mean2, + ) = self.gene_dataset.raw_counts_properties(idx1, idx2) + res = pd.DataFrame( + [ + bayes1, + bayes1_permuted, + bayes2, + bayes2_permuted, + mean1, + mean2, + nonz1, + nonz2, + norm_mean1, + norm_mean2, + px_scale_mean1, + px_scale_mean2, + ], + index=[ + "bayes1", + "bayes1_permuted", + "bayes2", + "bayes2_permuted", + "mean1", + "mean2", + "nonz1", + "nonz2", + "norm_mean1", + "norm_mean2", + "scale1", + "scale2", + ], + columns=self.gene_dataset.gene_names, + ).T + res = res.sort_values(by=["bayes1"], ascending=False) + return res + else: + return bayes1 + + @torch.no_grad() + def one_vs_all_degenes( + self, + subset: Optional[Union[List[bool], np.ndarray]] = None, + cell_labels: Optional[Union[List, np.ndarray]] = None, + min_cells: int = 10, + n_samples: int = None, + sample_pairs: bool = False, + M_permutation: int = None, + output_file: bool = False, + save_dir: str = "./", + filename="one2all", + ): + """ + Performs one population vs all others Differential Expression Analysis + given labels or using cell types, for each type of population + + :param subset: None Or bool array masking subset of cells you are interested in + (True when you want to select cell). In that case, it should have same length than `gene_dataset` + :param cell_labels: optional: Labels of cells + :param min_cells: Ceil number of cells used to compute Bayes Factors + :param n_samples: Number of times the posterior will be sampled for each pop + :param sample_pairs: Activates pair random permutations. + Simply formulated, pairs obtained from posterior sampling (when calling + `sample_scale_from_batch`) will be randomly permuted so that the number of + pairs used to compute Bayes Factors becomes M_permutation. + :param M_permutation: Number of times we will "mix" posterior samples in step 2. + Only makes sense when sample_pairs=True + :param output_file: Bool: save file? 
+ :param save_dir: + :param filename: + :return: Tuple (de_res, de_cluster) (i) de_res is a list of length nb_clusters + (based on provided labels or on hardcoded cell types) (ii) de_res[i] contains Bayes Factors + for population number i vs all the rest (iii) de_cluster returns the associated names of clusters. + Are contained in this results only clusters for which we have at least `min_cells` + elements to compute predicted Bayes Factors + """ + if cell_labels is not None: + if len(cell_labels) != len(self.gene_dataset): + raise ValueError( + " the length of cell_labels have to be " + "the same as the number of cells" + ) + if (cell_labels is None) and not hasattr(self.gene_dataset, "cell_types"): + raise ValueError( + "If gene_dataset is not annotated with labels and cell types," + " then must provide cell_labels" + ) + # Input cell_labels take precedence over cell type label annotation in dataset + elif cell_labels is not None: + cluster_id = np.unique(cell_labels[cell_labels >= 0]) + # Can make cell_labels < 0 to filter out cells when computing DE + else: + cluster_id = self.gene_dataset.cell_types + cell_labels = self.gene_dataset.labels.ravel() + de_res = [] + de_cluster = [] + for i, x in enumerate(cluster_id): + if subset is None: + idx1 = cell_labels == i + idx2 = cell_labels != i + else: + idx1 = (cell_labels == i) * subset + idx2 = (cell_labels != i) * subset + if np.sum(idx1) > min_cells and np.sum(idx2) > min_cells: + de_cluster.append(x) + res = self.differential_expression_score( + idx1=idx1, + idx2=idx2, + M_permutation=M_permutation, + n_samples=n_samples, + sample_pairs=sample_pairs, + ) + res["clusters"] = np.repeat(x, len(res.index)) + de_res.append(res) + if output_file: # store as an excel spreadsheet + writer = pd.ExcelWriter( + save_dir + "differential_expression.%s.xlsx" % filename, + engine="xlsxwriter", + ) + for i, x in enumerate(de_cluster): + de_res[i].to_excel(writer, sheet_name=str(x)) + writer.close() + return de_res, de_cluster + + def within_cluster_degenes( + self, + cell_labels: Optional[Union[List, np.ndarray]] = None, + min_cells: int = 10, + states: Union[List[bool], np.ndarray] = [], + batch1: Optional[Union[List[int], np.ndarray]] = None, + batch2: Optional[Union[List[int], np.ndarray]] = None, + subset: Optional[Union[List[bool], np.ndarray]] = None, + n_samples: int = None, + sample_pairs: bool = False, + M_permutation: int = None, + output_file: bool = False, + save_dir: str = "./", + filename: str = "within_cluster", + ): + """ + Performs Differential Expression within clusters for different cell states + + :param cell_labels: optional: Labels of cells + :param min_cells: Ceil number of cells used to compute Bayes Factors + :param states: States of the cells. + :param batch1: List of batch ids for which you want to perform DE Analysis for + subpopulation 1. By default, all ids are taken into account + :param batch2: List of batch ids for which you want to perform DE Analysis for + subpopulation 2. By default, all ids are taken into account + :param subset: MASK: Subset of cells you are interested in. + :param n_samples: Number of times the posterior will be sampled for each pop + :param sample_pairs: Activates pair random permutations. + Simply formulated, pairs obtained from posterior sampling (when calling + `sample_scale_from_batch`) will be randomly permuted so that the number of + pairs used to compute Bayes Factors becomes M_permutation. + :param M_permutation: Number of times we will "mix" posterior samples in step 2. 
+ Only makes sense when sample_pairs=True + :param output_file: Bool: save file? + :param save_dir: + :param filename: + :return: Tuple (de_res, de_cluster) (i) de_res is a list of length nb_clusters + (based on provided labels or on hardcoded cell types) (ii) de_res[i] contains Bayes Factors + for population number i vs all the rest (iii) de_cluster returns the associated names of clusters. + Are contained in this results only clusters for which we have at least `min_cells` + elements to compute predicted Bayes Factors + """ + if len(self.gene_dataset) != len(states): + raise ValueError( + " the length of states have to be the same as the number of cells" + ) + if cell_labels is not None: + if len(cell_labels) != len(self.gene_dataset): + raise ValueError( + " the length of cell_labels have to be " + "the same as the number of cells" + ) + if (cell_labels is None) and not hasattr(self.gene_dataset, "cell_types"): + raise ValueError( + "If gene_dataset is not annotated with labels and cell types," + " then must provide cell_labels" + ) + # Input cell_labels take precedence over cell type label annotation in dataset + elif cell_labels is not None: + cluster_id = np.unique(cell_labels[cell_labels >= 0]) + # Can make cell_labels < 0 to filter out cells when computing DE + else: + cluster_id = self.gene_dataset.cell_types + cell_labels = self.gene_dataset.labels.ravel() + de_res = [] + de_cluster = [] + states = np.asarray([1 if x else 0 for x in states]) + nstates = np.asarray([0 if x else 1 for x in states]) + for i, x in enumerate(cluster_id): + if subset is None: + idx1 = (cell_labels == i) * states + idx2 = (cell_labels == i) * nstates + else: + idx1 = (cell_labels == i) * subset * states + idx2 = (cell_labels == i) * subset * nstates + if np.sum(idx1) > min_cells and np.sum(idx2) > min_cells: + de_cluster.append(x) + res = self.differential_expression_score( + idx1=idx1, + idx2=idx2, + batchid1=batch1, + batchid2=batch2, + M_permutation=M_permutation, + n_samples=n_samples, + sample_pairs=sample_pairs, + ) + res["clusters"] = np.repeat(x, len(res.index)) + de_res.append(res) + if output_file: # store as an excel spreadsheet + writer = pd.ExcelWriter( + save_dir + "differential_expression.%s.xlsx" % filename, + engine="xlsxwriter", + ) + for i, x in enumerate(de_cluster): + de_res[i].to_excel(writer, sheet_name=str(x)) + writer.close() + return de_res, de_cluster + + @torch.no_grad() + def imputation(self, n_samples=1): + imputed_list = [] + for tensors in self: + sample_batch, _, _, batch_index, labels = tensors + px_rate = self.model.get_sample_rate( + sample_batch, batch_index=batch_index, y=labels, n_samples=n_samples + ) + imputed_list += [np.array(px_rate.cpu())] + imputed_list = np.concatenate(imputed_list) + return imputed_list.squeeze() + + @torch.no_grad() + def generate( + self, + n_samples: int = 100, + genes: Union[list, np.ndarray] = None, + batch_size: int = 64, + #batch_size: int = 128, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Create observation samples from the Posterior Predictive distribution + + :param n_samples: Number of required samples for each cell + :param genes: Indices of genes of interest + :param batch_size: Desired Batch size to generate data + + :return: Tuple (x_new, x_old) + Where x_old has shape (n_cells, n_genes) + Where x_new has shape (n_cells, n_genes, n_samples) + """ + assert self.model.reconstruction_loss in ["zinb", "nb"] + zero_inflated = self.model.reconstruction_loss == "zinb" + x_old = [] + x_new = [] + for tensors in 
self.update({"batch_size": batch_size}): + sample_batch, _, _, batch_index, labels = tensors + outputs = self.model.inference( + sample_batch, batch_index=batch_index, y=labels, n_samples=n_samples + ) + px_r = outputs["px_r"] + px_rate = outputs["px_rate"] + px_dropout = outputs["px_dropout"] + + p = px_rate / (px_rate + px_r) + r = px_r + # Important remark: Gamma is parametrized by the rate = 1/scale! + l_train = distributions.Gamma(concentration=r, rate=(1 - p) / p).sample() + # Clamping as distributions objects can have buggy behaviors when + # their parameters are too high + l_train = torch.clamp(l_train, max=1e8) + gene_expressions = distributions.Poisson( + l_train + ).sample() # Shape : (n_samples, n_cells_batch, n_genes) + if zero_inflated: + p_zero = (1.0 + torch.exp(-px_dropout)).pow(-1) + random_prob = torch.rand_like(p_zero) + gene_expressions[random_prob <= p_zero] = 0 + + gene_expressions = gene_expressions.permute( + [1, 2, 0] + ) # Shape : (n_cells_batch, n_genes, n_samples) + + x_old.append(sample_batch.cpu()) + x_new.append(gene_expressions.cpu()) + + x_old = torch.cat(x_old) # Shape (n_cells, n_genes) + x_new = torch.cat(x_new) # Shape (n_cells, n_genes, n_samples) + if genes is not None: + gene_ids = self.gene_dataset.genes_to_index(genes) + x_new = x_new[:, gene_ids, :] + x_old = x_old[:, gene_ids] + return x_new.numpy(), x_old.numpy() + + @torch.no_grad() + def generate_parameters(self): + dropout_list = [] + mean_list = [] + dispersion_list = [] + for tensors in self.sequential(1000): + sample_batch, _, _, batch_index, labels = tensors + + outputs = self.model.inference( + sample_batch, batch_index=batch_index, y=labels, n_samples=1 + ) + px_r = outputs["px_r"] + px_rate = outputs["px_rate"] + px_dropout = outputs["px_dropout"] + + dispersion_list += [ + np.repeat(np.array(px_r.cpu())[np.newaxis, :], px_rate.size(0), axis=0) + ] + mean_list += [np.array(px_rate.cpu())] + dropout_list += [np.array(px_dropout.cpu())] + + return ( + np.concatenate(dropout_list), + np.concatenate(mean_list), + np.concatenate(dispersion_list), + ) + + @torch.no_grad() + def get_stats(self): + libraries = [] + #for tensors in self.sequential(batch_size=128): + for tensors in self.sequential(batch_size=64): + x, local_l_mean, local_l_var, batch_index, y = tensors + library = self.model.inference(x, batch_index, y)["library"] + libraries += [np.array(library.cpu())] + libraries = np.concatenate(libraries) + return libraries.ravel() + + @torch.no_grad() + def get_harmonized_scale(self, fixed_batch): + px_scales = [] + fixed_batch = float(fixed_batch) + for tensors in self: + sample_batch, local_l_mean, local_l_var, batch_index, label = tensors + px_scales += [self.model.scale_from_z(sample_batch, fixed_batch).cpu()] + return np.concatenate(px_scales) + + @torch.no_grad() + def get_sample_scale(self): + px_scales = [] + for tensors in self: + sample_batch, _, _, batch_index, labels = tensors + px_scales += [ + np.array( + ( + self.model.get_sample_scale( + sample_batch, batch_index=batch_index, y=labels, n_samples=1 + ) + ).cpu() + ) + ] + return np.concatenate(px_scales) + + @torch.no_grad() + def imputation_list(self, n_samples=1): + original_list = [] + imputed_list = [] + batch_size = 10000 # self.data_loader_kwargs["batch_size"] // n_samples + for tensors, corrupted_tensors in zip( + self.uncorrupted().sequential(batch_size=batch_size), + self.corrupted().sequential(batch_size=batch_size), + ): + batch = tensors[0] + actual_batch_size = batch.size(0) + dropout_batch, _, _, batch_index, 
labels = corrupted_tensors + px_rate = self.model.get_sample_rate( + dropout_batch, batch_index=batch_index, y=labels, n_samples=n_samples + ) + + indices_dropout = torch.nonzero(batch - dropout_batch) + if indices_dropout.size() != torch.Size([0]): + i = indices_dropout[:, 0] + j = indices_dropout[:, 1] + + batch = batch.unsqueeze(0).expand( + (n_samples, batch.size(0), batch.size(1)) + ) + original = np.array(batch[:, i, j].view(-1).cpu()) + imputed = np.array(px_rate[..., i, j].view(-1).cpu()) + + cells_index = np.tile(np.array(i.cpu()), n_samples) + + original_list += [ + original[cells_index == i] for i in range(actual_batch_size) + ] + imputed_list += [ + imputed[cells_index == i] for i in range(actual_batch_size) + ] + else: + original_list = np.array([]) + imputed_list = np.array([]) + return original_list, imputed_list + + @torch.no_grad() + def imputation_score(self, original_list=None, imputed_list=None, n_samples=1): + if original_list is None or imputed_list is None: + original_list, imputed_list = self.imputation_list(n_samples=n_samples) + + original_list_concat = np.concatenate(original_list) + imputed_list_concat = np.concatenate(imputed_list) + are_lists_empty = (len(original_list_concat) == 0) and ( + len(imputed_list_concat) == 0 + ) + if are_lists_empty: + logger.info( + "No difference between corrupted dataset and uncorrupted dataset" + ) + return 0.0 + else: + return np.median(np.abs(original_list_concat - imputed_list_concat)) + + @torch.no_grad() + def imputation_benchmark( + self, n_samples=8, show_plot=True, title_plot="imputation", save_path="" + ): + original_list, imputed_list = self.imputation_list(n_samples=n_samples) + # Median of medians for all distances + median_score = self.imputation_score( + original_list=original_list, imputed_list=imputed_list + ) + + # Mean of medians for each cell + imputation_cells = [] + for original, imputed in zip(original_list, imputed_list): + has_imputation = len(original) and len(imputed) + imputation_cells += [ + np.median(np.abs(original - imputed)) if has_imputation else 0 + ] + mean_score = np.mean(imputation_cells) + + logger.debug( + "\nMedian of Median: %.4f\nMean of Median for each cell: %.4f" + % (median_score, mean_score) + ) + + plot_imputation( + np.concatenate(original_list), + np.concatenate(imputed_list), + show_plot=show_plot, + title=os.path.join(save_path, title_plot), + ) + return original_list, imputed_list + + @torch.no_grad() + def knn_purity(self): + latent, _, labels = self.get_latent() + score = knn_purity(latent, labels) + logger.debug("KNN purity score : {}".format(score)) + return score + + knn_purity.mode = "max" + + @torch.no_grad() + def clustering_scores(self, prediction_algorithm="knn"): + if self.gene_dataset.n_labels > 1: + latent, _, labels = self.get_latent() + if prediction_algorithm == "knn": + labels_pred = KMeans( + self.gene_dataset.n_labels, n_init=200 + ).fit_predict( + latent + ) # n_jobs>1 ? 
+ elif prediction_algorithm == "gmm": + gmm = GMM(self.gene_dataset.n_labels) + gmm.fit(latent) + labels_pred = gmm.predict(latent) + + asw_score = silhouette_score(latent, labels) + nmi_score = NMI(labels, labels_pred) + ari_score = ARI(labels, labels_pred) + uca_score = unsupervised_clustering_accuracy(labels, labels_pred)[0] + logger.debug( + "Clustering Scores:\nSilhouette: %.4f\nNMI: %.4f\nARI: %.4f\nUCA: %.4f" + % (asw_score, nmi_score, ari_score, uca_score) + ) + return asw_score, nmi_score, ari_score, uca_score + + @torch.no_grad() + def nn_overlap_score(self, **kwargs): + """ + Quantify how much the similarity between cells in the mRNA latent space resembles their similarity at the + protein level. Compute the overlap fold enrichment between the protein and mRNA-based cell 100-nearest neighbor + graph and the Spearman correlation of the adjacency matrices. + """ + if hasattr(self.gene_dataset, "protein_expression_clr"): + latent, _, _ = self.sequential().get_latent() + protein_data = self.gene_dataset.protein_expression_clr[self.indices] + spearman_correlation, fold_enrichment = nn_overlap( + latent, protein_data, **kwargs + ) + logger.debug( + "Overlap Scores:\nSpearman Correlation: %.4f\nFold Enrichment: %.4f" + % (spearman_correlation, fold_enrichment) + ) + return spearman_correlation, fold_enrichment + + @torch.no_grad() + def show_t_sne( + self, + n_samples=1000, + color_by="", + save_name="", + latent=None, + batch_indices=None, + labels=None, + n_batch=None, + ): + # If no latent representation is given + if latent is None: + latent, batch_indices, labels = self.get_latent(sample=True) + latent, idx_t_sne = self.apply_t_sne(latent, n_samples) + batch_indices = batch_indices[idx_t_sne].ravel() + labels = labels[idx_t_sne].ravel() + if not color_by: + plt.figure(figsize=(10, 10)) + plt.scatter(latent[:, 0], latent[:, 1]) + if color_by == "scalar": + plt.figure(figsize=(10, 10)) + plt.scatter(latent[:, 0], latent[:, 1], c=labels.ravel()) + else: + if n_batch is None: + n_batch = self.gene_dataset.n_batches + if color_by == "batches" or color_by == "labels": + indices = ( + batch_indices.ravel() if color_by == "batches" else labels.ravel() + ) + n = n_batch if color_by == "batches" else self.gene_dataset.n_labels + if self.gene_dataset.cell_types is not None and color_by == "labels": + plt_labels = self.gene_dataset.cell_types + else: + plt_labels = [str(i) for i in range(len(np.unique(indices)))] + plt.figure(figsize=(10, 10)) + for i, label in zip(range(n), plt_labels): + plt.scatter( + latent[indices == i, 0], latent[indices == i, 1], label=label + ) + plt.legend() + elif color_by == "batches and labels": + fig, axes = plt.subplots(1, 2, figsize=(14, 7)) + batch_indices = batch_indices.ravel() + for i in range(n_batch): + axes[0].scatter( + latent[batch_indices == i, 0], + latent[batch_indices == i, 1], + label=str(i), + ) + axes[0].set_title("batch coloring") + axes[0].axis("off") + axes[0].legend() + + indices = labels.ravel() + if hasattr(self.gene_dataset, "cell_types"): + plt_labels = self.gene_dataset.cell_types + else: + plt_labels = [str(i) for i in range(len(np.unique(indices)))] + for i, cell_type in zip(range(self.gene_dataset.n_labels), plt_labels): + axes[1].scatter( + latent[indices == i, 0], + latent[indices == i, 1], + label=cell_type, + ) + axes[1].set_title("label coloring") + axes[1].axis("off") + axes[1].legend() + plt.axis("off") + plt.tight_layout() + if save_name: + plt.savefig(save_name) + + @staticmethod + def apply_t_sne(latent, n_samples=1000): + 
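+        """Project the latent space to 2D for plotting.
+
+        Draws up to ``n_samples`` random cells and, unless the latent space is already
+        two-dimensional, embeds the drawn subset with t-SNE. Returns the resulting
+        coordinates and the sampled indices.
+        """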
idx_t_sne = ( + np.random.permutation(len(latent))[:n_samples] + if n_samples + else np.arange(len(latent)) + ) + if latent.shape[1] != 2: + latent = TSNE().fit_transform(latent[idx_t_sne]) + return latent, idx_t_sne + + def raw_data(self): + """ + Returns raw data for classification + """ + return ( + self.gene_dataset.X[self.indices], + self.gene_dataset.labels[self.indices].ravel(), + ) + + +def entropy_from_indices(indices): + return entropy(np.array(np.unique(indices, return_counts=True)[1].astype(np.int32))) + + +def entropy_batch_mixing( + latent_space, batches, n_neighbors=50, n_pools=50, n_samples_per_pool=100 +): + def entropy(hist_data): + n_batches = len(np.unique(hist_data)) + if n_batches > 2: + raise ValueError("Should be only two clusters for this metric") + frequency = np.mean(hist_data == 1) + if frequency == 0 or frequency == 1: + return 0 + return -frequency * np.log(frequency) - (1 - frequency) * np.log(1 - frequency) + + n_neighbors = min(n_neighbors, len(latent_space) - 1) + nne = NearestNeighbors(n_neighbors=1 + n_neighbors, n_jobs=8) + nne.fit(latent_space) + kmatrix = nne.kneighbors_graph(latent_space) - scipy.sparse.identity( + latent_space.shape[0] + ) + + score = 0 + for t in range(n_pools): + indices = np.random.choice( + np.arange(latent_space.shape[0]), size=n_samples_per_pool + ) + score += np.mean( + [ + entropy( + batches[ + kmatrix[indices].nonzero()[1][ + kmatrix[indices].nonzero()[0] == i + ] + ] + ) + for i in range(n_samples_per_pool) + ] + ) + return score / float(n_pools) + + +def get_bayes_factors( + px_scale: Union[List[float], np.ndarray], + all_labels: Union[List, np.ndarray], + cell_idx: Union[int, str], + other_cell_idx: Optional[Union[int, str]] = None, + genes_idx: Union[List[int], np.ndarray] = None, + M_permutation: int = 10000, + permutation: bool = False, + sample_pairs: bool = True, +): + """ + Returns an array of bayes factor for all genes + + :param px_scale: The gene frequency array for all cells (might contain multiple samples per cells) + :param all_labels: The labels array for the corresponding cell types + :param cell_idx: The first cell type population to consider. Either a string or an idx + :param other_cell_idx: (optional) The second cell type population to consider. Either a string or an idx + :param genes_idx: Indices of genes for which DE Analysis applies + :param sample_pairs: Activates subsampling. + Simply formulated, pairs obtained from posterior sampling (when calling + `sample_scale_from_batch`) will be randomly permuted so that the number of + pairs used to compute Bayes Factors becomes M_permutation. + :param M_permutation: Number of times we will "mix" posterior samples in step 2. + Only makes sense when sample_pairs=True + :param permutation: Whether or not to permute. Normal behavior is False. + Setting permutation=True basically shuffles cell_idx and other_cell_idx so that we + estimate Bayes Factors of random populations of the union of cell_idx and other_cell_idx. 
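+
+    The value returned for each gene is the (clipped) log-odds that population 1 has the
+    larger scale: ``log(p + 1e-8) - log(1 - p + 1e-8)`` with
+    ``p = mean(first_set >= second_set)`` taken over the sampled pairs.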
+ :return: + """ + idx = all_labels == cell_idx + idx_other = ( + (all_labels == other_cell_idx) + if other_cell_idx is not None + else (all_labels != cell_idx) + ) + if genes_idx is not None: + px_scale = px_scale[:, genes_idx] + sample_rate_a = px_scale[idx].reshape(-1, px_scale.shape[1]) + sample_rate_b = px_scale[idx_other].reshape(-1, px_scale.shape[1]) + + # agregate dataset + samples = np.vstack((sample_rate_a, sample_rate_b)) + + if sample_pairs is True: + # prepare the pairs for sampling + list_1 = list(np.arange(sample_rate_a.shape[0])) + list_2 = list(sample_rate_a.shape[0] + np.arange(sample_rate_b.shape[0])) + if not permutation: + # case1: no permutation, sample from A and then from B + u, v = ( + np.random.choice(list_1, size=M_permutation), + np.random.choice(list_2, size=M_permutation), + ) + else: + # case2: permutation, sample from A+B twice + u, v = ( + np.random.choice(list_1 + list_2, size=M_permutation), + np.random.choice(list_1 + list_2, size=M_permutation), + ) + + # then constitutes the pairs + first_set = samples[u] + second_set = samples[v] + else: + first_set = sample_rate_a + second_set = sample_rate_b + res = np.mean(first_set >= second_set, 0) + res = np.log(res + 1e-8) - np.log(1 - res + 1e-8) + return res + + +def plot_imputation(original, imputed, show_plot=True, title="Imputation"): + y = imputed + x = original + + ymax = 10 + mask = x < ymax + x = x[mask] + y = y[mask] + + mask = y < ymax + x = x[mask] + y = y[mask] + + l_minimum = np.minimum(x.shape[0], y.shape[0]) + + x = x[:l_minimum] + y = y[:l_minimum] + + data = np.vstack([x, y]) + + plt.figure(figsize=(5, 5)) + + axes = plt.gca() + axes.set_xlim([0, ymax]) + axes.set_ylim([0, ymax]) + + nbins = 50 + + # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents + k = kde.gaussian_kde(data) + xi, yi = np.mgrid[0 : ymax : nbins * 1j, 0 : ymax : nbins * 1j] + zi = k(np.vstack([xi.flatten(), yi.flatten()])) + + plt.title(title, fontsize=12) + plt.ylabel("Imputed counts") + plt.xlabel("Original counts") + + plt.pcolormesh(yi, xi, zi.reshape(xi.shape), cmap="Reds") + + a, _, _, _ = np.linalg.lstsq(y[:, np.newaxis], x, rcond=-1) + linspace = np.linspace(0, ymax) + plt.plot(linspace, a * linspace, color="black") + + plt.plot(linspace, linspace, color="black", linestyle=":") + if show_plot: + plt.show() + plt.savefig(title + ".png") + + +def nn_overlap(X1, X2, k=100): + """ + Compute the overlap between the k-nearest neighbor graph of X1 and X2 using Spearman correlation of the + adjacency matrices. 
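+
+    Also returns a fold-enrichment score for the shared edges of the two graphs,
+    ``len(E1 & E2) * n_samples ** 2 / (len(E1) * len(E2))``, where E1 and E2 are the
+    edge sets of the X1 and X2 neighbour graphs (diagonal removed).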
+ """ + assert len(X1) == len(X2) + n_samples = len(X1) + k = min(k, n_samples - 1) + nne = NearestNeighbors(n_neighbors=k + 1) # "n_jobs=8 + nne.fit(X1) + kmatrix_1 = nne.kneighbors_graph(X1) - scipy.sparse.identity(n_samples) + nne.fit(X2) + kmatrix_2 = nne.kneighbors_graph(X2) - scipy.sparse.identity(n_samples) + + # 1 - spearman correlation from knn graphs + spearman_correlation = scipy.stats.spearmanr( + kmatrix_1.A.flatten(), kmatrix_2.A.flatten() + )[0] + # 2 - fold enrichment + set_1 = set(np.where(kmatrix_1.A.flatten() == 1)[0]) + set_2 = set(np.where(kmatrix_2.A.flatten() == 1)[0]) + fold_enrichment = ( + len(set_1.intersection(set_2)) + * n_samples ** 2 + / (float(len(set_1)) * len(set_2)) + ) + return spearman_correlation, fold_enrichment + + +def unsupervised_clustering_accuracy(y, y_pred): + """ + Unsupervised Clustering Accuracy + """ + assert len(y_pred) == len(y) + u = np.unique(np.concatenate((y, y_pred))) + n_clusters = len(u) + mapping = dict(zip(u, range(n_clusters))) + reward_matrix = np.zeros((n_clusters, n_clusters), dtype=np.int64) + for y_pred_, y_ in zip(y_pred, y): + if y_ in mapping: + reward_matrix[mapping[y_pred_], mapping[y_]] += 1 + cost_matrix = reward_matrix.max() - reward_matrix + ind = linear_assignment(cost_matrix) + return sum([reward_matrix[i, j] for i, j in ind]) * 1.0 / y_pred.size, ind + + +def knn_purity(latent, label, n_neighbors=30): + nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(latent) + indices = nbrs.kneighbors(latent, return_distance=False)[:, 1:] + neighbors_labels = np.vectorize(lambda i: label[i])(indices) + + # pre cell purity scores + scores = ((neighbors_labels - label.reshape(-1, 1)) == 0).mean(axis=1) + res = [ + np.mean(scores[label == i]) for i in np.unique(label) + ] # per cell-type purity + + return np.mean(res) + + +def proximity_imputation(real_latent1, normed_gene_exp_1, real_latent2, k=4): + knn = KNeighborsRegressor(k, weights="distance") + y = knn.fit(real_latent1, normed_gene_exp_1).predict(real_latent2) + return y diff --git a/Singlecell_multi_omics/src/core/trainer.py b/Singlecell_multi_omics/src/core/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5860a1a8cf1513c8eb1b7e1fa1df7bdddf6ae21f --- /dev/null +++ b/Singlecell_multi_omics/src/core/trainer.py @@ -0,0 +1,457 @@ +import logging +import sys +import time + +from abc import abstractmethod +from collections import defaultdict, OrderedDict +from itertools import cycle + +import numpy as np +import torch + +from sklearn.model_selection._split import _validate_shuffle_split +from torch.utils.data.sampler import SubsetRandomSampler +from tqdm import trange + +from scMVP.inference.posterior import Posterior + +logger = logging.getLogger(__name__) + + +class Trainer: + r"""The abstract Trainer class for training a PyTorch model and monitoring its statistics. It should be + inherited at least with a .loss() function to be optimized in the training loop. + + Args: + :model: A model instance from class ``VAE``, ``VAEC``, ``SCANVI`` + :gene_dataset: A gene_dataset instance like ``CortexDataset()`` + :use_cuda: Default: ``True``. + :metrics_to_monitor: A list of the metrics to monitor. If not specified, will use the + ``default_metrics_to_monitor`` as specified in each . Default: ``None``. + :benchmark: if True, prevents statistics computation in the training. Default: ``False``. + :frequency: The frequency at which to keep track of statistics. Default: ``None``. 
+ :early_stopping_metric: The statistics on which to perform early stopping. Default: ``None``. + :save_best_state_metric: The statistics on which we keep the network weights achieving the best store, and + restore them at the end of training. Default: ``None``. + :on: The data_loader name reference for the ``early_stopping_metric`` and ``save_best_state_metric``, that + should be specified if any of them is. Default: ``None``. + :show_progbar: If False, disables progress bar. + :seed: Random seed for train/test/validate split + """ + default_metrics_to_monitor = [] + + def __init__( + self, + model, + gene_dataset, + use_cuda=True, + metrics_to_monitor=None, + benchmark=False, + frequency=None, + weight_decay=1e-6, + early_stopping_kwargs=None, + data_loader_kwargs=None, + show_progbar=True, + seed=0, + ): + # handle mutable defaults + early_stopping_kwargs = ( + early_stopping_kwargs if early_stopping_kwargs else dict() + ) + data_loader_kwargs = data_loader_kwargs if data_loader_kwargs else dict() + + self.model = model + self.gene_dataset = gene_dataset + self._posteriors = OrderedDict() + self.seed = seed + + #self.data_loader_kwargs = {"batch_size": 256, "pin_memory": use_cuda} # 128 for batchsize in init + self.data_loader_kwargs = {"batch_size": 64, "pin_memory": use_cuda} # 128 for batchsize in init + self.data_loader_kwargs.update(data_loader_kwargs) + + self.weight_decay = weight_decay + self.benchmark = benchmark + self.epoch = -1 # epoch = self.epoch + 1 in compute metrics + self.training_time = 0 + # self.KL_divergence = -1 + # self.KL_divergence_max = 10000 + + if metrics_to_monitor is not None: + self.metrics_to_monitor = set(metrics_to_monitor) + else: + self.metrics_to_monitor = set(self.default_metrics_to_monitor) + + self.early_stopping = EarlyStopping(**early_stopping_kwargs) + + if self.early_stopping.early_stopping_metric: + self.metrics_to_monitor.add(self.early_stopping.early_stopping_metric) + + self.use_cuda = use_cuda and torch.cuda.is_available() + if self.use_cuda: + self.model.cuda() + + self.frequency = frequency if not benchmark else None + + self.history = defaultdict(list) + + self.best_state_dict = self.model.state_dict() + self.best_epoch = self.epoch + + self.show_progbar = show_progbar + + @torch.no_grad() + def compute_metrics(self): + begin = time.time() + epoch = self.epoch + 1 + if self.frequency and ( + epoch == 0 or epoch == self.n_epochs or (epoch % self.frequency == 0) + ): + with torch.set_grad_enabled(False): + self.model.eval() + logger.debug("\nEPOCH [%d/%d]: " % (epoch, self.n_epochs)) + + for name, posterior in self._posteriors.items(): + message = " ".join([s.capitalize() for s in name.split("_")[-2:]]) + if posterior.nb_cells < 5: + logging.debug( + message + " is too small to track metrics (<5 samples)" + ) + continue + if hasattr(posterior, "to_monitor"): + for metric in posterior.to_monitor: + if metric not in self.metrics_to_monitor: + logger.debug(message) + result = getattr(posterior, metric)() + self.history[metric + "_" + name] += [result] + for metric in self.metrics_to_monitor: + result = getattr(posterior, metric)() + self.history[metric + "_" + name] += [result] + self.model.train() + self.compute_metrics_time += time.time() - begin + + def train(self, n_epochs=20, lr=1e-3, eps=0.01, params=None): + begin = time.time() + self.model.train() + + if params is None: + params = filter(lambda p: p.requires_grad, self.model.parameters()) + + optimizer = self.optimizer = torch.optim.Adam( + params, lr=lr, eps=eps, 
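# Hypothetical minimal subclass showing the contract Trainer expects from children:
# `posteriors_loop` names the registered posteriors to iterate each epoch, and
# `loss` consumes one batch per posterior. The names and the placeholder loss are
# illustrative only; the real trainers in this package are more involved.
from scMVP.inference import Trainer  # import path used elsewhere in this codebase

class ToyTrainer(Trainer):
    default_metrics_to_monitor = []

    def __init__(self, model, gene_dataset, **kwargs):
        super().__init__(model, gene_dataset, **kwargs)
        # registered via __setattr__, so it becomes available to data_loaders_loop()
        self.train_set = self.create_posterior(shuffle=True)

    @property
    def posteriors_loop(self):
        return ["train_set"]

    def loss(self, tensors):
        x = tensors[0]                  # expression matrix of the minibatch
        return self.model(x).mean()     # placeholder scalar loss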
weight_decay=self.weight_decay + ) + aa = self.model.parameters() + + self.compute_metrics_time = 0 + self.n_epochs = n_epochs + flag = True + + with trange( + n_epochs, desc="training", file=sys.stdout, disable=not self.show_progbar + ) as pbar: + # We have to use tqdm this way so it works in Jupyter notebook. + # See https://stackoverflow.com/questions/42212810/tqdm-in-jupyter-notebook + for self.epoch in pbar: + self.on_epoch_begin() + pbar.update(1) + for tensors_list in self.data_loaders_loop(): + if tensors_list[0][0].shape[0] < 3: + continue + loss = self.loss(*tensors_list) + print(loss) + # if self.KL_divergence > self.KL_divergence_max: + # break + #if self.epoch == 15 and flag: + # flag = False + # optimizer.add_param_group({'params': self.model.get_params()}) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if not self.on_epoch_end(): + break + + if self.early_stopping.save_best_state_metric is not None: + self.model.load_state_dict(self.best_state_dict) + self.compute_metrics() + + self.model.eval() + self.training_time += (time.time() - begin) - self.compute_metrics_time + if self.frequency: + logger.debug( + "\nTraining time: %i s. / %i epochs" + % (int(self.training_time), self.n_epochs) + ) + self.compute_metrics() + + + def on_epoch_begin(self): + pass + + def on_epoch_end(self): + self.compute_metrics() + on = self.early_stopping.on + early_stopping_metric = self.early_stopping.early_stopping_metric + save_best_state_metric = self.early_stopping.save_best_state_metric + if save_best_state_metric is not None and on is not None: + if self.early_stopping.update_state( + self.history[save_best_state_metric + "_" + on][-1] + ): + self.best_state_dict = self.model.state_dict() + self.best_epoch = self.epoch + + continue_training = True + if early_stopping_metric is not None and on is not None: + continue_training, reduce_lr = self.early_stopping.update( + self.history[early_stopping_metric + "_" + on][-1] + ) + if reduce_lr: + logger.info("Reducing LR.") + for param_group in self.optimizer.param_groups: + param_group["lr"] *= self.early_stopping.lr_factor + + # if self.KL_divergence > self.KL_divergence_max: + # continue_training = False + return continue_training + + @property + @abstractmethod + def posteriors_loop(self): + pass + + def data_loaders_loop(self): + """returns an zipped iterable corresponding to loss signature""" + data_loaders_loop = [self._posteriors[name] for name in self.posteriors_loop] + return zip( + data_loaders_loop[0], + *[cycle(data_loader) for data_loader in data_loaders_loop[1:]] + ) + + def register_posterior(self, name, value): + name = name.strip("_") + self._posteriors[name] = value + + def corrupt_posteriors( + self, rate=0.1, corruption="uniform", update_corruption=True + ): + if not hasattr(self.gene_dataset, "corrupted") and update_corruption: + self.gene_dataset.corrupt(rate=rate, corruption=corruption) + for name, posterior in self._posteriors.items(): + self.register_posterior(name, posterior.corrupted()) + + def uncorrupt_posteriors(self): + for name_, posterior in self._posteriors.items(): + self.register_posterior(name_, posterior.uncorrupted()) + + def __getattr__(self, name): + if "_posteriors" in self.__dict__: + _posteriors = self.__dict__["_posteriors"] + if name.strip("_") in _posteriors: + return _posteriors[name.strip("_")] + return object.__getattribute__(self, name) + + def __delattr__(self, name): + if name.strip("_") in self._posteriors: + del self._posteriors[name.strip("_")] + else: + 
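# Standalone sketch of the zip/cycle pattern used by data_loaders_loop: the first
# loader sets the epoch length, while any additional loaders are recycled so that
# shorter auxiliary sets never exhaust the loop. Toy lists stand in for posteriors.
from itertools import cycle

primary = [1, 2, 3, 4, 5]       # drives the epoch: 5 iterations
auxiliary = ["a", "b"]          # recycled as needed

for batch_p, batch_a in zip(primary, *[cycle(dl) for dl in [auxiliary]]):
    print(batch_p, batch_a)     # (1, a) (2, b) (3, a) (4, b) (5, a)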
object.__delattr__(self, name) + + def __setattr__(self, name, value): + if isinstance(value, Posterior): + name = name.strip("_") + self.register_posterior(name, value) + else: + object.__setattr__(self, name, value) + + def train_test_validation( + self, + model=None, + gene_dataset=None, + train_size=0.1, + test_size=None, + type_class=Posterior, + ): + """Creates posteriors ``train_set``, ``test_set``, ``validation_set``. + If ``train_size + test_size < 1`` then ``validation_set`` is non-empty. + + :param train_size: float, int, or None (default is 0.1) + :param test_size: float, int, or None (default is None) + """ + model = self.model if model is None and hasattr(self, "model") else model + gene_dataset = ( + self.gene_dataset + if gene_dataset is None and hasattr(self, "model") + else gene_dataset + ) + n = len(gene_dataset) + try: + n_train, n_test = _validate_shuffle_split(n, test_size, train_size) + except ValueError: + if train_size != 1.0: + raise ValueError( + "Choice of train_size={} and test_size={} not understood".format( + train_size, test_size + ) + ) + n_train, n_test = n, 0 + random_state = np.random.RandomState(seed=self.seed) + permutation = random_state.permutation(n) + indices_test = permutation[:n_test] + indices_train = permutation[n_test : (n_test + n_train)] + indices_validation = permutation[(n_test + n_train) :] + + return ( + self.create_posterior( + model, gene_dataset, indices=indices_train, type_class=type_class + ), + self.create_posterior( + model, gene_dataset, indices=indices_test, type_class=type_class + ), + self.create_posterior( + model, gene_dataset, indices=indices_validation, type_class=type_class + ), + ) + + def create_posterior( + self, + model=None, + gene_dataset=None, + shuffle=False, + indices=None, + type_class=Posterior, + ): + model = self.model if model is None and hasattr(self, "model") else model + gene_dataset = ( + self.gene_dataset + if gene_dataset is None and hasattr(self, "model") + else gene_dataset + ) + return type_class( + model, + gene_dataset, + shuffle=shuffle, + indices=indices, + use_cuda=self.use_cuda, + data_loader_kwargs=self.data_loader_kwargs, + ) + + +class SequentialSubsetSampler(SubsetRandomSampler): + def __init__(self, indices): + self.indices = np.sort(indices) + + def __iter__(self): + return iter(self.indices) + + +class EarlyStopping: + def __init__( + self, + early_stopping_metric: str = None, + save_best_state_metric: str = None, + on: str = "test_set", + patience: int = 15, + threshold: int = 3, + benchmark: bool = False, + reduce_lr_on_plateau: bool = False, + lr_patience: int = 10, + lr_factor: float = 0.5, + posterior_class=Posterior, + ): + self.benchmark = benchmark + self.patience = patience + self.threshold = threshold + self.epoch = 0 + self.wait = 0 + self.wait_lr = 0 + self.mode = ( + getattr(posterior_class, early_stopping_metric).mode + if early_stopping_metric is not None + else None + ) + # We set the best to + inf because we're dealing with a loss we want to minimize + self.current_performance = np.inf + self.best_performance = np.inf + self.best_performance_state = np.inf + # If we want to maximize, we start at - inf + if self.mode == "max": + self.best_performance *= -1 + self.current_performance *= -1 + self.mode_save_state = ( + getattr(Posterior, save_best_state_metric).mode + if save_best_state_metric is not None + else None + ) + if self.mode_save_state == "max": + self.best_performance_state *= -1 + + self.early_stopping_metric = early_stopping_metric + 
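# Minimal sketch of the index split performed by train_test_validation: a seeded
# permutation is sliced into test / train / validation, and whatever
# train_size + test_size leaves over becomes the validation set. Uses the same
# private sklearn helper as the method above.
import numpy as np
from sklearn.model_selection._split import _validate_shuffle_split

n = 100
n_train, n_test = _validate_shuffle_split(n, test_size=0.2, train_size=0.7)
perm = np.random.RandomState(seed=0).permutation(n)
indices_test = perm[:n_test]
indices_train = perm[n_test:n_test + n_train]
indices_validation = perm[n_test + n_train:]
print(len(indices_train), len(indices_test), len(indices_validation))  # 70 20 10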
self.save_best_state_metric = save_best_state_metric + self.on = on + self.reduce_lr_on_plateau = reduce_lr_on_plateau + self.lr_patience = lr_patience + self.lr_factor = lr_factor + + def update(self, scalar): + self.epoch += 1 + if self.benchmark: + continue_training = True + reduce_lr = False + elif self.wait >= self.patience: + continue_training = False + reduce_lr = False + else: + # Check if we should reduce the learning rate + if not self.reduce_lr_on_plateau: + reduce_lr = False + elif self.wait_lr >= self.lr_patience: + reduce_lr = True + self.wait_lr = 0 + else: + reduce_lr = False + # Shift + self.current_performance = scalar + + # Compute improvement + if self.mode == "max": + improvement = self.current_performance - self.best_performance + elif self.mode == "min": + improvement = self.best_performance - self.current_performance + else: + raise NotImplementedError("Unknown optimization mode") + + # updating best performance + if improvement > 0: + self.best_performance = self.current_performance + + if improvement < self.threshold: + self.wait += 1 + self.wait_lr += 1 + else: + self.wait = 0 + self.wait_lr = 0 + + continue_training = True + if not continue_training: + # FIXME: log total number of epochs run + logger.info( + "\nStopping early: no improvement of more than " + + str(self.threshold) + + " nats in " + + str(self.patience) + + " epochs" + ) + logger.info( + "If the early stopping criterion is too strong, " + "please instantiate it with different parameters in the train method." + ) + return continue_training, reduce_lr + + def update_state(self, scalar): + improved = ( + self.mode_save_state == "max" and scalar - self.best_performance_state > 0 + ) or ( + self.mode_save_state == "min" and self.best_performance_state - scalar > 0 + ) + if improved: + self.best_performance_state = scalar + return improved diff --git a/Singlecell_multi_omics/src/scGPT/.keep b/Singlecell_multi_omics/src/scGPT/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Singlecell_multi_omics/src/scGPT/model/.keep b/Singlecell_multi_omics/src/scGPT/model/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Singlecell_multi_omics/src/scGPT/model/__init__.py b/Singlecell_multi_omics/src/scGPT/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4411db4709c1c7fc43b8a484603d3a3c68894b27 --- /dev/null +++ b/Singlecell_multi_omics/src/scGPT/model/__init__.py @@ -0,0 +1,11 @@ +from .model import ( + TransformerModel, + FlashTransformerEncoderLayer, + GeneEncoder, + AdversarialDiscriminator, + MVCDecoder, +) +from .generation_model import * +from .multiomic_model import MultiOmicTransformerModel +from .dsbn import * +from .grad_reverse import * diff --git a/Singlecell_multi_omics/src/scGPT/model/dsbn.py b/Singlecell_multi_omics/src/scGPT/model/dsbn.py new file mode 100644 index 0000000000000000000000000000000000000000..324b128e6c0dc23cf4d6bebc3819109df9901c5f --- /dev/null +++ b/Singlecell_multi_omics/src/scGPT/model/dsbn.py @@ -0,0 +1,82 @@ +from typing import Optional, Tuple + +import torch +from torch import nn + +# The code is modified from https://github.com/wgchang/DSBN/blob/master/model/dsbn.py +class _DomainSpecificBatchNorm(nn.Module): + _version = 2 + + def __init__( + self, + num_features: int, + num_domains: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = 
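# Usage sketch for the EarlyStopping helper above, outside of a Trainer. Its metric
# lookup only needs a class whose attribute exposes a `.mode` flag, so a tiny
# stand-in replaces a real Posterior here; the EarlyStopping import path is assumed.
class _FakeMetrics:
    def elbo(self):
        pass
    elbo.mode = "min"   # treat the monitored scalar as a loss to minimize

stopper = EarlyStopping(
    early_stopping_metric="elbo", patience=2, threshold=1, posterior_class=_FakeMetrics
)
for epoch, loss in enumerate([100.0, 99.5, 99.4, 99.3, 99.2]):
    keep_going, reduce_lr = stopper.update(loss)
    if not keep_going:
        print("early stop at epoch", epoch)  # fires once `wait` reaches `patience`
        break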
True, + ): + super(_DomainSpecificBatchNorm, self).__init__() + self._cur_domain = None + self.num_domains = num_domains + self.bns = nn.ModuleList( + [ + self.bn_handle(num_features, eps, momentum, affine, track_running_stats) + for _ in range(num_domains) + ] + ) + + @property + def bn_handle(self) -> nn.Module: + raise NotImplementedError + + @property + def cur_domain(self) -> Optional[int]: + return self._cur_domain + + @cur_domain.setter + def cur_domain(self, domain_label: int): + self._cur_domain = domain_label + + def reset_running_stats(self): + for bn in self.bns: + bn.reset_running_stats() + + def reset_parameters(self): + for bn in self.bns: + bn.reset_parameters() + + def _check_input_dim(self, input: torch.Tensor): + raise NotImplementedError + + def forward(self, x: torch.Tensor, domain_label: int) -> torch.Tensor: + self._check_input_dim(x) + if domain_label >= self.num_domains: + raise ValueError( + f"Domain label {domain_label} exceeds the number of domains {self.num_domains}" + ) + bn = self.bns[domain_label] + self.cur_domain = domain_label + return bn(x) + + +class DomainSpecificBatchNorm1d(_DomainSpecificBatchNorm): + @property + def bn_handle(self) -> nn.Module: + return nn.BatchNorm1d + + def _check_input_dim(self, input: torch.Tensor): + if input.dim() > 3: + raise ValueError( + "expected at most 3D input (got {}D input)".format(input.dim()) + ) + + +class DomainSpecificBatchNorm2d(_DomainSpecificBatchNorm): + @property + def bn_handle(self) -> nn.Module: + return nn.BatchNorm2d + + def _check_input_dim(self, input: torch.Tensor): + if input.dim() != 4: + raise ValueError("expected 4D input (got {}D input)".format(input.dim())) diff --git a/Singlecell_multi_omics/src/scGPT/model/generation_model.py b/Singlecell_multi_omics/src/scGPT/model/generation_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7c65d9aebf34ea42bde10f6aeffc53305b9dfb4d --- /dev/null +++ b/Singlecell_multi_omics/src/scGPT/model/generation_model.py @@ -0,0 +1,549 @@ +import os +import math +from typing import Mapping, Optional, Tuple, Any, Union + +import torch +from torch import nn, Tensor +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn import TransformerEncoder, TransformerEncoderLayer +from torch.distributions import Bernoulli +from torch.utils.data import dataset +from tqdm import trange + +from .model import ( + ExprDecoder, + MVCDecoder, + ContinuousValueEncoder, + FastTransformerEncoderWrapper, + FlashTransformerEncoderLayer, +) +from ..utils import map_raw_id_to_vocab_id +from .. 
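# Usage sketch for DomainSpecificBatchNorm1d: one BatchNorm1d per domain, selected by
# an integer domain label at call time. BatchNorm1d normalizes over dim 1, so
# (batch, seq_len, embsize) activations are permuted before and after the call, as the
# transformer models later in this diff do.
import torch
# from scGPT.model.dsbn import DomainSpecificBatchNorm1d  # hypothetical import path; class defined above

dsbn = DomainSpecificBatchNorm1d(num_features=16, num_domains=3)
x = torch.randn(8, 32, 16)                       # (batch, seq_len, embsize)
out = dsbn(x.permute(0, 2, 1), domain_label=1)   # use domain 1's running statistics
out = out.permute(0, 2, 1)                       # back to (batch, seq_len, embsize)
print(out.shape, dsbn.cur_domain)                # torch.Size([8, 32, 16]) 1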
import logger + + +class TransformerGenerator(nn.Module): + def __init__( + self, + ntoken: int, + d_model: int, + nhead: int, + d_hid: int, + nlayers: int, + nlayers_cls: int, + n_cls: int, + vocab: Any, + dropout: float = 0.5, + pad_token: str = "", + pad_value: int = 0, + pert_pad_id: int = 2, + do_mvc: bool = False, + domain_spec_batchnorm: Union[bool, str] = False, + n_input_bins: Optional[int] = 0, + cell_emb_style: str = "cls", + mvc_decoder_style: str = "inner product", + decoder_activation: Optional[str] = None, + decoder_adaptive_bias: bool = False, + ecs_threshold: float = 0.3, + explicit_zero_prob: bool = False, + use_fast_transformer: bool = False, + fast_transformer_backend: str = "flash", + pre_norm: bool = False, + ): + super().__init__() + self.model_type = "Transformer" + self.d_model = d_model + self.pad_token_id = vocab[pad_token] + self.pad_value = pad_value + self.pert_pad_id = pert_pad_id + self.ecs_threshold = ecs_threshold + self.domain_spec_batchnorm = domain_spec_batchnorm + self.n_input_bins = n_input_bins + self.cell_emb_style = cell_emb_style + self.explicit_zero_prob = explicit_zero_prob + self.norm_scheme = "pre" if pre_norm else "post" + if cell_emb_style not in ["cls", "avg-pool", "w-pool"]: + raise ValueError(f"Unknown cell_emb_style: {cell_emb_style}") + if use_fast_transformer: + try: + from flash_attn.flash_attention import FlashMHA + except ImportError: + import warnings + + warnings.warn( + "flash-attn is not installed, using pytorch transformer instead. " + "Set use_fast_transformer=False to avoid this warning. " + "Installing flash-attn is highly recommended." + ) + use_fast_transformer = False + self.use_fast_transformer = use_fast_transformer + + self.encoder = GeneEncoder(ntoken, d_model, padding_idx=vocab[pad_token]) + self.value_encoder = ContinuousValueEncoder(d_model, dropout) + self.pert_encoder = nn.Embedding(3, d_model, padding_idx=pert_pad_id) + + # print("Using simple batchnorm instead of domain specific batchnorm") + # self.bn = nn.BatchNorm1d(d_model, eps=6.1e-5) + + if use_fast_transformer: + if fast_transformer_backend == "linear": + self.transformer_encoder = FastTransformerEncoderWrapper( + d_model, nhead, d_hid, nlayers, dropout + ) + elif fast_transformer_backend == "flash": + encoder_layers = FlashTransformerEncoderLayer( + d_model, + nhead, + d_hid, + dropout, + batch_first=True, + norm_scheme=self.norm_scheme, + ) + self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) + else: + encoder_layers = TransformerEncoderLayer( + d_model, nhead, d_hid, dropout, batch_first=True + ) + self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) + + # self.decoder = nn.Linear(d_model, 1) + self.decoder = AffineExprDecoder( + d_model, + explicit_zero_prob=explicit_zero_prob, + activation=decoder_activation, + adaptive_bias=decoder_adaptive_bias, + ) + self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers=nlayers_cls) + if do_mvc: + self.mvc_decoder = MVCDecoder( + d_model, + arch_style=mvc_decoder_style, + explicit_zero_prob=explicit_zero_prob, + ) + + self.init_weights() + + def init_weights(self) -> None: + initrange = 0.1 + self.encoder.embedding.weight.data.uniform_(-initrange, initrange) + + def _encode( + self, + src: Tensor, + values: Tensor, + input_pert_flags, + src_key_padding_mask: Tensor, + ) -> Tensor: + src = self.encoder(src) # (batch, seq_len, embsize) + self.cur_gene_token_embs = src + values = self.value_encoder(values) # (batch, seq_len, embsize) + perts = self.pert_encoder(input_pert_flags) 
# (batch, seq_len, embsize) + total_embs = src + values + perts + + # total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1) + output = self.transformer_encoder( + total_embs, src_key_padding_mask=src_key_padding_mask + ) + return output # (batch, seq_len, embsize) + + def _get_cell_emb_from_layer( + self, layer_output: Tensor, weights: Tensor = None + ) -> Tensor: + """ + Args: + layer_output(:obj:`Tensor`): shape (batch, seq_len, embsize) + weights(:obj:`Tensor`): shape (batch, seq_len), optional and only used + when :attr:`self.cell_emb_style` is "w-pool". + + Returns: + :obj:`Tensor`: shape (batch, embsize) + """ + if self.cell_emb_style == "cls": + cell_emb = layer_output[:, 0, :] # (batch, embsize) + elif self.cell_emb_style == "avg-pool": + cell_emb = torch.mean(layer_output, dim=1) + elif self.cell_emb_style == "w-pool": + if weights is None: + raise ValueError("weights is required when cell_emb_style is w-pool") + if weights.dim() != 2: + raise ValueError("weights should be 2D") + cell_emb = torch.sum(layer_output * weights.unsqueeze(2), dim=1) + cell_emb = F.normalize(cell_emb, p=2, dim=1) # (batch, embsize) + + return cell_emb + + def forward( + self, + src: Tensor, + values: Tensor, + input_pert_flags: Tensor, + src_key_padding_mask: Tensor, + CLS: bool = False, + CCE: bool = False, + MVC: bool = False, + ECS: bool = False, + do_sample: bool = False, + ) -> Mapping[str, Tensor]: + """ + Args: + src (:obj:`Tensor`): token ids, shape [batch_size, seq_len] + values (:obj:`Tensor`): token values, shape [batch_size, seq_len] + src_key_padding_mask (:obj:`Tensor`): mask for src, shape [batch_size, + seq_len] + CLS (:obj:`bool`): if True, return the celltype classification objective + (CLS) output + CCE (:obj:`bool`): if True, return the contrastive cell embedding objective + (CCE) output + MVC (:obj:`bool`): if True, return the masked value prediction for cell + embedding MVC output + ECS (:obj:`bool`): if True, return the elastic cell similarity objective + (ECS) output. + + Returns: + dict of output Tensors. 
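# Standalone illustration of the three cell_emb_style options handled by
# _get_cell_emb_from_layer: "cls" takes the first token, "avg-pool" averages over the
# sequence, and "w-pool" weights positions (e.g. by expression values) before
# L2-normalizing.
import torch
import torch.nn.functional as F

layer_output = torch.randn(4, 10, 32)   # (batch, seq_len, embsize)
weights = torch.rand(4, 10)             # (batch, seq_len)

cls_emb = layer_output[:, 0, :]         # (batch, embsize)
avg_emb = layer_output.mean(dim=1)      # (batch, embsize)
w_emb = F.normalize((layer_output * weights.unsqueeze(2)).sum(dim=1), p=2, dim=1)
print(cls_emb.shape, avg_emb.shape, w_emb.shape)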
+ """ + if self.explicit_zero_prob and not do_sample and not self.training: + do_sample = True + logger.warning("Auto set do_sample to True when model is in eval mode.") + + # binning input gene values + if self.n_input_bins > 0: + from ..preprocess import binning + + processed_values = torch.stack( + [binning(row, n_bins=self.n_input_bins) for row in values], dim=0 + ).to(values.device) + else: + processed_values = values + + transformer_output = self._encode( + src, processed_values, input_pert_flags, src_key_padding_mask + ) + output = {} + mlm_output = self.decoder(transformer_output, values) + if self.explicit_zero_prob and do_sample: + bernoulli = Bernoulli(probs=mlm_output["zero_probs"]) + output["mlm_output"] = bernoulli.sample() * mlm_output["pred"] + else: + output["mlm_output"] = mlm_output["pred"] # (batch, seq_len) + if self.explicit_zero_prob: + output["mlm_zero_probs"] = mlm_output["zero_probs"] + + cell_emb = self._get_cell_emb_from_layer(transformer_output, values) + if CLS: + output["cls_output"] = self.cls_decoder(cell_emb) # (batch, n_cls) + if MVC: + mvc_output = self.mvc_decoder( + cell_emb, + self.cur_gene_token_embs, + ) # (batch, seq_len) + if self.explicit_zero_prob and do_sample: + bernoulli = Bernoulli(probs=mvc_output["zero_probs"]) + output["mvc_output"] = bernoulli.sample() * mvc_output["pred"] + else: + output["mvc_output"] = mvc_output["pred"] # (batch, seq_len) + if self.explicit_zero_prob: + output["mvc_zero_probs"] = mvc_output["zero_probs"] + if ECS: + # Here using customized cosine similarity instead of F.cosine_similarity + # to avoid the pytorch issue of similarity larger than 1.0, pytorch # 78064 + # normalize the embedding + cell_emb_normed = F.normalize(cell_emb, p=2, dim=1) + cos_sim = torch.mm(cell_emb_normed, cell_emb_normed.t()) # (batch, batch) + + # mask out diagnal elements + mask = torch.eye(cos_sim.size(0)).bool().to(cos_sim.device) + cos_sim = cos_sim.masked_fill(mask, 0.0) + # only optimize positive similarities + cos_sim = F.relu(cos_sim) + + output["loss_ecs"] = torch.mean(1 - (cos_sim - self.ecs_threshold) ** 2) + + return output + + def encode_batch( + self, + src: Tensor, + values: Tensor, + src_key_padding_mask: Tensor, + batch_size: int, + output_to_cpu: bool = True, + ) -> Tensor: + """ + Args: + src: Tensor, shape [N, seq_len] + values: Tensor, shape [N, seq_len] + src_key_padding_mask: Tensor, shape [N, seq_len] + + Returns: + output Tensor of shape [N, seq_len, embsize] + """ + outputs = [] + N = src.size(0) + device = next(self.parameters()).device + for i in trange(0, N, batch_size): + output = self._encode( + src[i : i + batch_size].to(device), + values[i : i + batch_size].to(device), + src_key_padding_mask[i : i + batch_size].to(device), + ) + if output_to_cpu: + output = output.cpu() + outputs.append(output) + return torch.cat(outputs, dim=0) + + def pred_perturb( + self, + batch_data, + include_zero_gene="batch-wise", + gene_ids=None, + amp=True, + ) -> Tensor: + """ + Args: + batch_data: a dictionary of input data with keys. 
+ + Returns: + output Tensor of shape [N, seq_len] + """ + self.eval() + device = next(self.parameters()).device + batch_data.to(device) + batch_size = len(batch_data.pert) + x: torch.Tensor = batch_data.x + ori_gene_values = x[:, 0].view(batch_size, -1) # (batch_size, n_genes) + pert_flags = x[:, 1].long().view(batch_size, -1) + + if include_zero_gene in ["all", "batch-wise"]: + assert gene_ids is not None + if include_zero_gene == "all": + input_gene_ids = torch.arange(ori_gene_values.size(1), device=device) + else: # batch-wise + input_gene_ids = ( + ori_gene_values.nonzero()[:, 1].flatten().unique().sort()[0] + ) + input_values = ori_gene_values[:, input_gene_ids] + input_pert_flags = pert_flags[:, input_gene_ids] + + mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, gene_ids) + mapped_input_gene_ids = mapped_input_gene_ids.repeat(batch_size, 1) + + src_key_padding_mask = torch.zeros_like( + input_values, dtype=torch.bool, device=device + ) + with torch.cuda.amp.autocast(enabled=amp): + output_dict = self( + mapped_input_gene_ids, + input_values, + input_pert_flags, + src_key_padding_mask=src_key_padding_mask, + CLS=False, + CCE=False, + MVC=False, + ECS=False, + do_sample=True, + ) + output_values = output_dict["mlm_output"].float() + pred_gene_values = torch.zeros_like(ori_gene_values) + pred_gene_values[:, input_gene_ids] = output_values + return pred_gene_values + + +def generate_square_subsequent_mask(sz: int) -> Tensor: + """Generates an upper-triangular matrix of -inf, with zeros on diag.""" + return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1) + + +class GeneEncoder(nn.Module): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + ): + super().__init__() + self.embedding = nn.Embedding( + num_embeddings, embedding_dim, padding_idx=padding_idx + ) + self.enc_norm = nn.LayerNorm(embedding_dim) + + def forward(self, x: Tensor) -> Tensor: + x = self.embedding(x) # (batch, seq_len, embsize) + x = self.enc_norm(x) + return x + + +class AffineExprDecoder(nn.Module): + def __init__( + self, + d_model: int, + explicit_zero_prob: bool = False, + activation: Optional[str] = None, + tanh_coeff: bool = False, + adaptive_bias: bool = False, + ): + """ + Predict the expression value of each gene in an affine like form of Ax + b. + This decoder takes two ExprDecoder intrinsically to genrate the coefficient A and bias b. + + Args: + d_model: The embedding dimension. + explicit_zero_prob: If True, predict the probability of each gene being + zero. + activation: The activation function for the coefficient A and bias b. + tanh_coeff: If True, use tanh activation for the coefficient A. + adaptive_bias: If True, use a learnable bias for the bias b. 
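# Toy version of the "batch-wise" include_zero_gene branch of pred_perturb: keep only
# genes that are nonzero in at least one cell of the batch, then slice values and
# perturbation flags down to that shared gene set.
import torch

ori_gene_values = torch.tensor([[0.0, 2.0, 0.0, 1.0],
                                [0.0, 0.0, 3.0, 0.0]])
input_gene_ids = ori_gene_values.nonzero()[:, 1].flatten().unique().sort()[0]
print(input_gene_ids)                               # tensor([1, 2, 3]); gene 0 is dropped
input_values = ori_gene_values[:, input_gene_ids]   # (batch, n_selected_genes)
print(input_values)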
+ """ + super().__init__() + self.explicit_zero_prob = explicit_zero_prob + self.tanh_coeff = tanh_coeff + self.adaptive_bias = adaptive_bias + self.coeff_decoder = ExprDecoder(d_model, explicit_zero_prob=explicit_zero_prob) + self.bias_decoder = ExprDecoder(d_model, explicit_zero_prob=explicit_zero_prob) + + self.activation = activation + if activation is not None: + assert hasattr(nn, activation), f"Unknown activation: {activation}" + self.activation = getattr(nn, activation)() + + def forward(self, x: Tensor, values: Tensor) -> Tensor: + """ + Args: + x: Tensor, shape [batch_size, seq_len, embsize] + values: Tensor, shape [batch_size, seq_len] + + Returns: + output Tensor of shape [batch_size, seq_len] + """ + coeff = self.coeff_decoder(x) + bias = self.bias_decoder(x) + + if self.activation is not None: + coeff["pred"] = self.activation(coeff["pred"]) + bias["pred"] = self.activation(bias["pred"]) + + # if self.tanh_coeff: + # coeff["pred"] = 1 + torch.tanh(coeff["pred"]) + + if self.adaptive_bias: + # bias["pred"] = bias["pred"] * values.mean(dim=1, keepdim=True) + non_zero_value_mean = values.sum(dim=1, keepdim=True) / (values != 0).sum( + dim=1, keepdim=True + ) + bias["pred"] = bias["pred"] * non_zero_value_mean + + if self.explicit_zero_prob: + return { + "pred": coeff["pred"] * values + bias["pred"], + "zero_probs": coeff["zero_probs"], + } + + return dict(pred=coeff["pred"] * values + bias["pred"]) + + +class TokenEmbedding(nn.Module): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + zero_out_idx: Optional[int] = None, + ): + """ + Generic token embedding module. + + Args: + num_embeddings: The number of tokens. + embedding_dim: The embedding dimension. + padding_idx: The index of the padding token. + zero_out_idx: Indicate if any idx embedding should be zero vector. + """ + super().__init__() + self.embedding = nn.Embedding( + num_embeddings, embedding_dim, padding_idx=padding_idx + ) + self.enc_norm = nn.LayerNorm(embedding_dim) + + self.zero_out_idx = zero_out_idx + if zero_out_idx is not None: + self._fill_idx_with_zero(zero_out_idx) + zero_vector = self(zero_out_idx) + assert torch.all(zero_vector == 0.0) + assert not zero_vector.requires_grad + + def _fill_idx_with_zero(self, idx) -> None: + with torch.no_grad(): + self.embedding.weight[idx].fill_(0) + + def forward(self, x: Tensor) -> Tensor: + x = self.embedding(x) # (batch, seq_len, embsize) + x = self.enc_norm(x) + return x + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) + ) + pe = torch.zeros(max_len, 1, d_model) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: Tensor, shape [seq_len, batch_size, embedding_dim] + """ + x = x + self.pe[: x.size(0)] + return self.dropout(x) + + +class Similarity(nn.Module): + """ + Dot product or cosine similarity + """ + + def __init__(self, temp): + super().__init__() + self.temp = temp + self.cos = nn.CosineSimilarity(dim=-1) + + def forward(self, x, y): + return self.cos(x, y) / self.temp + + +class ClsDecoder(nn.Module): + """ + Decoder for classification task. 
+ """ + + def __init__( + self, + d_model: int, + n_cls: int, + nlayers: int = 3, + activation: callable = nn.ReLU, + ): + super().__init__() + # module list + self._decoder = nn.ModuleList() + for i in range(nlayers - 1): + self._decoder.append(nn.Linear(d_model, d_model)) + self._decoder.append(activation()) + self._decoder.append(nn.LayerNorm(d_model)) + self.out_layer = nn.Linear(d_model, n_cls) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: Tensor, shape [batch_size, embsize] + """ + for layer in self._decoder: + x = layer(x) + return self.out_layer(x) diff --git a/Singlecell_multi_omics/src/scGPT/model/grad_reverse.py b/Singlecell_multi_omics/src/scGPT/model/grad_reverse.py new file mode 100644 index 0000000000000000000000000000000000000000..7dc7e6127717d1682253fb6d278d2f389ba974c0 --- /dev/null +++ b/Singlecell_multi_omics/src/scGPT/model/grad_reverse.py @@ -0,0 +1,17 @@ +import torch +from torch.autograd import Function + + +class GradReverse(Function): + @staticmethod + def forward(ctx, x: torch.Tensor, lambd: float) -> torch.Tensor: + ctx.lambd = lambd + return x.view_as(x) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + return grad_output.neg() * ctx.lambd, None + + +def grad_reverse(x: torch.Tensor, lambd: float = 1.0) -> torch.Tensor: + return GradReverse.apply(x, lambd) diff --git a/Singlecell_multi_omics/src/scGPT/model/model.py b/Singlecell_multi_omics/src/scGPT/model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..ad0639bee3ca200966291dfed65b557c648c86ee --- /dev/null +++ b/Singlecell_multi_omics/src/scGPT/model/model.py @@ -0,0 +1,1045 @@ +import gc +import math +from typing import Dict, Mapping, Optional, Tuple, Any, Union + +import torch +import numpy as np +from torch import nn, Tensor +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn import TransformerEncoder, TransformerEncoderLayer +from torch.distributions import Bernoulli +from tqdm import trange + +try: + from flash_attn.flash_attention import FlashMHA + + flash_attn_available = True +except ImportError: + import warnings + + warnings.warn("flash_attn is not installed") + flash_attn_available = False + +from .dsbn import DomainSpecificBatchNorm1d +from .grad_reverse import grad_reverse + + +class TransformerModel(nn.Module): + def __init__( + self, + ntoken: int, + d_model: int, + nhead: int, + d_hid: int, + nlayers: int, + nlayers_cls: int = 3, + n_cls: int = 1, + vocab: Any = None, + dropout: float = 0.5, + pad_token: str = "", + pad_value: int = 0, + do_mvc: bool = False, + do_dab: bool = False, + use_batch_labels: bool = False, + num_batch_labels: Optional[int] = None, + domain_spec_batchnorm: Union[bool, str] = False, + input_emb_style: str = "continuous", + n_input_bins: Optional[int] = None, + cell_emb_style: str = "cls", + mvc_decoder_style: str = "inner product", + ecs_threshold: float = 0.3, + explicit_zero_prob: bool = False, + use_fast_transformer: bool = False, + fast_transformer_backend: str = "flash", + pre_norm: bool = False, + ): + super().__init__() + self.model_type = "Transformer" + self.d_model = d_model + self.do_dab = do_dab + self.ecs_threshold = ecs_threshold + self.use_batch_labels = use_batch_labels + self.domain_spec_batchnorm = domain_spec_batchnorm + self.input_emb_style = input_emb_style + self.cell_emb_style = cell_emb_style + self.explicit_zero_prob = explicit_zero_prob + self.norm_scheme = "pre" if pre_norm else "post" + if self.input_emb_style not in ["category", 
"continuous", "scaling"]: + raise ValueError( + f"input_emb_style should be one of category, continuous, scaling, " + f"got {input_emb_style}" + ) + if cell_emb_style not in ["cls", "avg-pool", "w-pool"]: + raise ValueError(f"Unknown cell_emb_style: {cell_emb_style}") + if use_fast_transformer: + if not flash_attn_available: + warnings.warn( + "flash-attn is not installed, using pytorch transformer instead. " + "Set use_fast_transformer=False to avoid this warning. " + "Installing flash-attn is highly recommended." + ) + use_fast_transformer = False + self.use_fast_transformer = use_fast_transformer + + # TODO: add dropout in the GeneEncoder + self.encoder = GeneEncoder(ntoken, d_model, padding_idx=vocab[pad_token]) + + # Value Encoder, NOTE: the scaling style is also handled in _encode method + if input_emb_style == "continuous": + self.value_encoder = ContinuousValueEncoder(d_model, dropout) + elif input_emb_style == "category": + assert n_input_bins > 0 + self.value_encoder = CategoryValueEncoder( + n_input_bins, d_model, padding_idx=pad_value + ) + else: + self.value_encoder = nn.Identity() # nn.Softmax(dim=1) + # TODO: consider row-wise normalization or softmax + # TODO: Correct handle the mask_value when using scaling + + # Batch Encoder + if use_batch_labels: + self.batch_encoder = BatchLabelEncoder(num_batch_labels, d_model) + + if domain_spec_batchnorm is True or domain_spec_batchnorm == "dsbn": + use_affine = True if domain_spec_batchnorm == "do_affine" else False + print(f"Use domain specific batchnorm with affine={use_affine}") + self.dsbn = DomainSpecificBatchNorm1d( + d_model, num_batch_labels, eps=6.1e-5, affine=use_affine + ) + elif domain_spec_batchnorm == "batchnorm": + print("Using simple batchnorm instead of domain specific batchnorm") + self.bn = nn.BatchNorm1d(d_model, eps=6.1e-5) + + if use_fast_transformer: + if fast_transformer_backend == "linear": + self.transformer_encoder = FastTransformerEncoderWrapper( + d_model, nhead, d_hid, nlayers, dropout + ) + elif fast_transformer_backend == "flash": + encoder_layers = FlashTransformerEncoderLayer( + d_model, + nhead, + d_hid, + dropout, + batch_first=True, + norm_scheme=self.norm_scheme, + ) + self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) + else: + encoder_layers = TransformerEncoderLayer( + d_model, nhead, d_hid, dropout, batch_first=True + ) + self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) + + self.decoder = ExprDecoder( + d_model, + explicit_zero_prob=explicit_zero_prob, + use_batch_labels=use_batch_labels, + ) + self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers=nlayers_cls) + if do_mvc: + self.mvc_decoder = MVCDecoder( + d_model, + arch_style=mvc_decoder_style, + explicit_zero_prob=explicit_zero_prob, + use_batch_labels=use_batch_labels, + ) + + if do_dab: + self.grad_reverse_discriminator = AdversarialDiscriminator( + d_model, + n_cls=num_batch_labels, + reverse_grad=True, + ) + + self.sim = Similarity(temp=0.5) # TODO: auto set temp + self.creterion_cce = nn.CrossEntropyLoss() + + self.init_weights() + + def init_weights(self) -> None: + initrange = 0.1 + # TODO: check if this initialization is helpful and shall we apply to all? 
+ self.encoder.embedding.weight.data.uniform_(-initrange, initrange) + + def _encode( + self, + src: Tensor, + values: Tensor, + src_key_padding_mask: Tensor, + batch_labels: Optional[Tensor] = None, # (batch,) + ) -> Tensor: + self._check_batch_labels(batch_labels) + + src = self.encoder(src) # (batch, seq_len, embsize) + self.cur_gene_token_embs = src + + values = self.value_encoder(values) # (batch, seq_len, embsize) + if self.input_emb_style == "scaling": + values = values.unsqueeze(2) + total_embs = src * values + else: + total_embs = src + values + + if getattr(self, "dsbn", None) is not None: + batch_label = int(batch_labels[0].item()) + total_embs = self.dsbn(total_embs.permute(0, 2, 1), batch_label).permute( + 0, 2, 1 + ) # the batch norm always works on dim 1 + elif getattr(self, "bn", None) is not None: + total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1) + + output = self.transformer_encoder( + total_embs, src_key_padding_mask=src_key_padding_mask + ) + return output # (batch, seq_len, embsize) + + def _get_cell_emb_from_layer( + self, layer_output: Tensor, weights: Tensor = None + ) -> Tensor: + """ + Args: + layer_output(:obj:`Tensor`): shape (batch, seq_len, embsize) + weights(:obj:`Tensor`): shape (batch, seq_len), optional and only used + when :attr:`self.cell_emb_style` is "w-pool". + + Returns: + :obj:`Tensor`: shape (batch, embsize) + """ + if self.cell_emb_style == "cls": + cell_emb = layer_output[:, 0, :] # (batch, embsize) + elif self.cell_emb_style == "avg-pool": + cell_emb = torch.mean(layer_output, dim=1) + elif self.cell_emb_style == "w-pool": + if weights is None: + raise ValueError("weights is required when cell_emb_style is w-pool") + if weights.dim() != 2: + raise ValueError("weights should be 2D") + cell_emb = torch.sum(layer_output * weights.unsqueeze(2), dim=1) + cell_emb = F.normalize(cell_emb, p=2, dim=1) # (batch, embsize) + + return cell_emb + + def _check_batch_labels(self, batch_labels: Tensor) -> None: + if self.use_batch_labels or self.domain_spec_batchnorm: + assert batch_labels is not None + elif batch_labels is not None: + raise ValueError( + "batch_labels should only be provided when `self.use_batch_labels`" + " or `self.domain_spec_batchnorm` is True" + ) + + def generate( + self, + cell_emb: Tensor, + src: Tensor, + values: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + gen_iters: int = 1, + batch_labels: Optional[Tensor] = None, # (batch,) + ) -> Tensor: + """ + Args: + cell_emb(:obj:`Tensor`): shape (batch, embsize) + src(:obj:`Tensor`): shape (batch, seq_len) + values(:obj:`Tensor`): shape (batch, seq_len), optional + src_key_padding_mask(:obj:`Tensor`): shape (batch, seq_len), optional + gen_iters(:obj:`int`): number of generation iterations + batch_labels(:obj:`Tensor`): shape (batch,), optional + """ + # TODO: should have a tag indicate the generation mode + # TODO: if gen_iters > 1, should have a tag indicate the current iteration + try: + self._check_batch_labels(batch_labels) + except: + import warnings + + warnings.warn( + "batch_labels is required but not provided, using zeros instead" + ) + batch_labels = torch.zeros( + cell_emb.shape[0], dtype=torch.long, device=cell_emb.device + ) + + src = self.encoder(src) # (batch, seq_len, embsize) + + if values is not None: + values = self.value_encoder(values) # (batch, seq_len, embsize) + if self.input_emb_style == "scaling": + values = values.unsqueeze(2) + total_embs = src * values + else: + total_embs = src + values + else: + total_embs = 
src + + if getattr(self, "dsbn", None) is not None: + batch_label = int(batch_labels[0].item()) + total_embs = self.dsbn(total_embs.permute(0, 2, 1), batch_label).permute( + 0, 2, 1 + ) # the batch norm always works on dim 1 + elif getattr(self, "bn", None) is not None: + total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1) + + total_embs[:, 0, :] = cell_emb + + if src_key_padding_mask is None: + src_key_padding_mask = torch.zeros( + total_embs.shape[:2], dtype=torch.bool, device=total_embs.device + ) + transformer_output = self.transformer_encoder( + total_embs, src_key_padding_mask=src_key_padding_mask + ) + + if self.use_batch_labels: + batch_emb = self.batch_encoder(batch_labels) # (batch, embsize) + mlm_output = self.decoder( + transformer_output + if not self.use_batch_labels + else torch.cat( + [ + transformer_output, + batch_emb.unsqueeze(1).repeat(1, transformer_output.shape[1], 1), + ], + dim=2, + ), + # else transformer_output + batch_emb.unsqueeze(1), + ) + output = mlm_output["pred"] # (batch, seq_len) + + return output # (batch, seq_len) + + def forward( + self, + src: Tensor, + values: Tensor, + src_key_padding_mask: Tensor, + batch_labels: Optional[Tensor] = None, + CLS: bool = False, + CCE: bool = False, + MVC: bool = False, + ECS: bool = False, + do_sample: bool = False, + ) -> Mapping[str, Tensor]: + """ + Args: + src (:obj:`Tensor`): token ids, shape [batch_size, seq_len] + values (:obj:`Tensor`): token values, shape [batch_size, seq_len] + src_key_padding_mask (:obj:`Tensor`): mask for src, shape [batch_size, + seq_len] + batch_labels (:obj:`Tensor`): batch labels, shape [batch_size] + CLS (:obj:`bool`): if True, return the celltype classification objective + (CLS) output + CCE (:obj:`bool`): if True, return the contrastive cell embedding objective + (CCE) output + MVC (:obj:`bool`): if True, return the masked value prediction for cell + embedding MVC output + ECS (:obj:`bool`): if True, return the elastic cell similarity objective + (ECS) output. + + Returns: + dict of output Tensors. 
+ """ + transformer_output = self._encode( + src, values, src_key_padding_mask, batch_labels + ) + if self.use_batch_labels: + batch_emb = self.batch_encoder(batch_labels) # (batch, embsize) + + output = {} + mlm_output = self.decoder( + transformer_output + if not self.use_batch_labels + else torch.cat( + [ + transformer_output, + batch_emb.unsqueeze(1).repeat(1, transformer_output.shape[1], 1), + ], + dim=2, + ), + # else transformer_output + batch_emb.unsqueeze(1), + ) + if self.explicit_zero_prob and do_sample: + bernoulli = Bernoulli(probs=mlm_output["zero_probs"]) + output["mlm_output"] = bernoulli.sample() * mlm_output["pred"] + else: + output["mlm_output"] = mlm_output["pred"] # (batch, seq_len) + if self.explicit_zero_prob: + output["mlm_zero_probs"] = mlm_output["zero_probs"] + + cell_emb = self._get_cell_emb_from_layer(transformer_output, values) + output["cell_emb"] = cell_emb + + if CLS: + output["cls_output"] = self.cls_decoder(cell_emb) # (batch, n_cls) + if CCE: + cell1 = cell_emb + transformer_output2 = self._encode( + src, values, src_key_padding_mask, batch_labels + ) + cell2 = self._get_cell_emb_from_layer(transformer_output2) + + # Gather embeddings from all devices if distributed training + if dist.is_initialized() and self.training: + cls1_list = [ + torch.zeros_like(cell1) for _ in range(dist.get_world_size()) + ] + cls2_list = [ + torch.zeros_like(cell2) for _ in range(dist.get_world_size()) + ] + dist.all_gather(tensor_list=cls1_list, tensor=cell1.contiguous()) + dist.all_gather(tensor_list=cls2_list, tensor=cell2.contiguous()) + + # NOTE: all_gather results have no gradients, so replace the item + # of the current rank with the original tensor to keep gradients. + # See https://github.com/princeton-nlp/SimCSE/blob/main/simcse/models.py#L186 + cls1_list[dist.get_rank()] = cell1 + cls2_list[dist.get_rank()] = cell2 + + cell1 = torch.cat(cls1_list, dim=0) + cell2 = torch.cat(cls2_list, dim=0) + # TODO: should detach the second run cls2? 
Can have a try + cos_sim = self.sim(cell1.unsqueeze(1), cell2.unsqueeze(0)) # (batch, batch) + labels = torch.arange(cos_sim.size(0)).long().to(cell1.device) + output["loss_cce"] = self.creterion_cce(cos_sim, labels) + if MVC: + mvc_output = self.mvc_decoder( + cell_emb + if not self.use_batch_labels + else torch.cat([cell_emb, batch_emb], dim=1), + # else cell_emb + batch_emb, + self.cur_gene_token_embs, + ) + if self.explicit_zero_prob and do_sample: + bernoulli = Bernoulli(probs=mvc_output["zero_probs"]) + output["mvc_output"] = bernoulli.sample() * mvc_output["pred"] + else: + output["mvc_output"] = mvc_output["pred"] # (batch, seq_len) + if self.explicit_zero_prob: + output["mvc_zero_probs"] = mvc_output["zero_probs"] + if ECS: + # Here using customized cosine similarity instead of F.cosine_similarity + # to avoid the pytorch issue of similarity larger than 1.0, pytorch # 78064 + # normalize the embedding + cell_emb_normed = F.normalize(cell_emb, p=2, dim=1) + cos_sim = torch.mm(cell_emb_normed, cell_emb_normed.t()) # (batch, batch) + + # mask out diagnal elements + mask = torch.eye(cos_sim.size(0)).bool().to(cos_sim.device) + cos_sim = cos_sim.masked_fill(mask, 0.0) + # only optimize positive similarities + cos_sim = F.relu(cos_sim) + + output["loss_ecs"] = torch.mean(1 - (cos_sim - self.ecs_threshold) ** 2) + + if self.do_dab: + output["dab_output"] = self.grad_reverse_discriminator(cell_emb) + + return output + + def encode_batch( + self, + src: Tensor, + values: Tensor, + src_key_padding_mask: Tensor, + batch_size: int, + batch_labels: Optional[Tensor] = None, + output_to_cpu: bool = True, + time_step: Optional[int] = None, + return_np: bool = False, + ) -> Tensor: + """ + Args: + src (Tensor): shape [N, seq_len] + values (Tensor): shape [N, seq_len] + src_key_padding_mask (Tensor): shape [N, seq_len] + batch_size (int): batch size for encoding + batch_labels (Tensor): shape [N, n_batch_labels] + output_to_cpu (bool): whether to move the output to cpu + time_step (int): the time step index in the transformer output to return. + The time step is along the second dimenstion. If None, return all. 
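# Standalone form of the elastic cell similarity (ECS) term computed above: pairwise
# cosine similarities with the diagonal masked out, clipped at zero, and pulled
# toward ecs_threshold.
import torch
import torch.nn.functional as F

def ecs_loss(cell_emb: torch.Tensor, ecs_threshold: float = 0.3) -> torch.Tensor:
    emb = F.normalize(cell_emb, p=2, dim=1)
    cos_sim = emb @ emb.t()                                    # (batch, batch)
    mask = torch.eye(cos_sim.size(0), dtype=torch.bool, device=cos_sim.device)
    cos_sim = F.relu(cos_sim.masked_fill(mask, 0.0))           # drop diagonal, keep positives
    return torch.mean(1 - (cos_sim - ecs_threshold) ** 2)

print(ecs_loss(torch.randn(8, 16)))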
+ return_np (bool): whether to return numpy array + + Returns: + output Tensor of shape [N, seq_len, embsize] + """ + N = src.size(0) + device = next(self.parameters()).device + + # initialize the output tensor + array_func = np.zeros if return_np else torch.zeros + float32_ = np.float32 if return_np else torch.float32 + shape = ( + (N, self.d_model) + if time_step is not None + else (N, src.size(1), self.d_model) + ) + outputs = array_func(shape, dtype=float32_) + + for i in trange(0, N, batch_size): + raw_output = self._encode( + src[i : i + batch_size].to(device), + values[i : i + batch_size].to(device), + src_key_padding_mask[i : i + batch_size].to(device), + batch_labels[i : i + batch_size].to(device) + if batch_labels is not None + else None, + ) + output = raw_output.detach() + if output_to_cpu: + output = output.cpu() + if return_np: + output = output.numpy() + if time_step is not None: + output = output[:, time_step, :] + outputs[i : i + batch_size] = output + + return outputs + + +def generate_square_subsequent_mask(sz: int) -> Tensor: + """Generates an upper-triangular matrix of -inf, with zeros on diag.""" + return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1) + + +class FastTransformerEncoderWrapper(nn.Module): + def __init__( + self, + d_model: int, + nhead: int, + d_hid: int, + nlayers: int, + dropout: float = 0.5, + ): + super().__init__() + self.fast_transformer_encoder = self.build_fast_transformer_encoder( + d_model, nhead, d_hid, nlayers, dropout + ) + + @staticmethod + def build_fast_transformer_encoder( + d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float + ) -> nn.Module: + from fast_transformers.builders import TransformerEncoderBuilder + + if d_model % nhead != 0: + raise ValueError( + f"d_model must be divisible by nhead, " + f"got d_model={d_model} and nhead={nhead}" + ) + builder = TransformerEncoderBuilder.from_kwargs( + n_layers=nlayers, + n_heads=nhead, + query_dimensions=d_model // nhead, + value_dimensions=d_model // nhead, + feed_forward_dimensions=d_hid, + attention_type="linear", + attention_dropout=dropout, + dropout=dropout, + activation="gelu", + ) + assert builder.attention_type == "linear" + return builder.get() + + @staticmethod + def build_length_mask( + src: Tensor, + src_key_padding_mask: torch.BoolTensor, + ) -> "LengthMask": + from fast_transformers.masking import LengthMask + + seq_len = src.shape[1] + num_paddings = src_key_padding_mask.sum(dim=1) + actual_seq_len = seq_len - num_paddings # (N,) + length_mask = LengthMask(actual_seq_len, max_len=seq_len, device=src.device) + + if src_key_padding_mask[length_mask.bool_matrix].sum() != 0: + raise ValueError( + "Found padding tokens in the middle of the sequence. " + "src_key_padding_mask and length_mask are not compatible." 
+ ) + return length_mask + + def forward( + self, + src: Tensor, + src_key_padding_mask: torch.BoolTensor, + ) -> Tensor: + """ + Args: + src: Tensor, shape [N, seq_len, embsize] + src_key_padding_mask: Tensor, shape [N, seq_len] + + Returns: + output Tensor of shape [N, seq_len, embsize] + """ + if src_key_padding_mask.shape != src.shape[:2]: + raise ValueError( + f"src_key_padding_mask shape {src_key_padding_mask.shape} " + f"does not match first two dims of src shape {src.shape[:2]}" + ) + + if src_key_padding_mask.dtype != torch.bool: + raise ValueError( + f"src_key_padding_mask needs to be of type torch.bool, " + f"got {src_key_padding_mask.dtype}" + ) + + length_mask = self.build_length_mask(src, src_key_padding_mask) + output = self.fast_transformer_encoder(src, length_mask=length_mask) + return output + + +class FlashTransformerEncoderLayer(nn.Module): + r"""TransformerEncoderLayer is made up of self-attn and feedforward network. + The class is modified from torch.nn.TransformerEncoderLayer to support the + FlashAttention. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of intermediate layer, relu or gelu (default=relu). + layer_norm_eps: the eps value in layer normalization components (default=1e-5). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False``. + + Examples:: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + + Alternatively, when ``batch_first`` is ``True``: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True) + >>> src = torch.rand(32, 10, 512) + >>> out = encoder_layer(src) + """ + __constants__ = ["batch_first"] + + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + layer_norm_eps=1e-5, + batch_first=True, + device=None, + dtype=None, + norm_scheme="post", # "pre" or "post" + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.self_attn = FlashMHA( + embed_dim=d_model, + num_heads=nhead, + batch_first=batch_first, + attention_dropout=dropout, + **factory_kwargs, + ) + # Version compatibility workaround + if not hasattr(self.self_attn, "batch_first"): + self.self_attn.batch_first = batch_first + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) + + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = self._get_activation_fn(activation) + self.norm_scheme = norm_scheme + if self.norm_scheme not in ["pre", "post"]: + raise ValueError(f"norm_scheme should be pre or post, not {norm_scheme}") + + @staticmethod + def _get_activation_fn(activation): + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + + def __setstate__(self, state): + if "activation" not in 
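# The length-mask bookkeeping of build_length_mask without the fast_transformers
# dependency: the per-sequence valid length is simply seq_len minus the number of
# padded positions (padding marked True).
import torch

src_key_padding_mask = torch.tensor([[False, False, True],
                                     [False, True,  True]])
seq_len = src_key_padding_mask.shape[1]
actual_seq_len = seq_len - src_key_padding_mask.sum(dim=1)
print(actual_seq_len)   # tensor([2, 1])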
state: + state["activation"] = F.relu + super().__setstate__(state) + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + if src_mask is not None: + raise ValueError("FlashTransformerEncoderLayer does not support src_mask") + + if not src_key_padding_mask.any().item(): + # no padding tokens in src + src_key_padding_mask_ = None + else: + if src_key_padding_mask.dtype != torch.bool: + src_key_padding_mask = src_key_padding_mask.bool() + # NOTE: the FlashMHA uses mask 0 for padding tokens, which is the opposite + src_key_padding_mask_ = ~src_key_padding_mask + + if self.norm_scheme == "pre": + src = self.norm1(src) + src2 = self.self_attn(src, key_padding_mask=src_key_padding_mask_)[0] + src = src + self.dropout1(src2) + src = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + else: + src2 = self.self_attn(src, key_padding_mask=src_key_padding_mask_)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + + return src + + +class GeneEncoder(nn.Module): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + ): + super().__init__() + self.embedding = nn.Embedding( + num_embeddings, embedding_dim, padding_idx=padding_idx + ) + self.enc_norm = nn.LayerNorm(embedding_dim) + + def forward(self, x: Tensor) -> Tensor: + x = self.embedding(x) # (batch, seq_len, embsize) + x = self.enc_norm(x) + return x + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) + ) + pe = torch.zeros(max_len, 1, d_model) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: Tensor, shape [seq_len, batch_size, embedding_dim] + """ + x = x + self.pe[: x.size(0)] + return self.dropout(x) + + +class ContinuousValueEncoder(nn.Module): + """ + Encode real number values to a vector using neural nets projection. 
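# The mask flip in FlashTransformerEncoderLayer.forward in one line: torch.nn marks
# padding positions with True, while FlashMHA (per the NOTE above) expects True for
# positions that should be attended, so the boolean mask is simply inverted.
import torch

torch_style = torch.tensor([[False, False, True],   # True = padding (torch.nn convention)
                            [False, True,  True]])
flash_style = ~torch_style                          # True = keep / attend (FlashMHA convention)
print(flash_style)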
+ """ + + def __init__(self, d_model: int, dropout: float = 0.1, max_value: int = 512): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + self.linear1 = nn.Linear(1, d_model) + self.activation = nn.ReLU() + self.linear2 = nn.Linear(d_model, d_model) + self.norm = nn.LayerNorm(d_model) + self.max_value = max_value + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: Tensor, shape [batch_size, seq_len] + """ + # TODO: test using actual embedding layer if input is categorical + # expand last dimension + x = x.unsqueeze(-1) + # clip x to [-inf, max_value] + x = torch.clamp(x, max=self.max_value) + x = self.activation(self.linear1(x)) + x = self.linear2(x) + x = self.norm(x) + return self.dropout(x) + + +class CategoryValueEncoder(nn.Module): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + ): + super().__init__() + self.embedding = nn.Embedding( + num_embeddings, embedding_dim, padding_idx=padding_idx + ) + self.enc_norm = nn.LayerNorm(embedding_dim) + + def forward(self, x: Tensor) -> Tensor: + x = x.long() + x = self.embedding(x) # (batch, seq_len, embsize) + x = self.enc_norm(x) + return x + + +class BatchLabelEncoder(nn.Module): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + ): + super().__init__() + self.embedding = nn.Embedding( + num_embeddings, embedding_dim, padding_idx=padding_idx + ) + self.enc_norm = nn.LayerNorm(embedding_dim) + + def forward(self, x: Tensor) -> Tensor: + x = self.embedding(x) # (batch, embsize) + x = self.enc_norm(x) + return x + + +class Similarity(nn.Module): + """ + Dot product or cosine similarity + """ + + def __init__(self, temp): + super().__init__() + self.temp = temp + self.cos = nn.CosineSimilarity(dim=-1) + + def forward(self, x, y): + return self.cos(x, y) / self.temp + + +class ExprDecoder(nn.Module): + def __init__( + self, + d_model: int, + explicit_zero_prob: bool = False, + use_batch_labels: bool = False, + ): + super().__init__() + d_in = d_model * 2 if use_batch_labels else d_model + self.fc = nn.Sequential( + nn.Linear(d_in, d_model), + nn.LeakyReLU(), + nn.Linear(d_model, d_model), + nn.LeakyReLU(), + nn.Linear(d_model, 1), + ) + self.explicit_zero_prob = explicit_zero_prob + if explicit_zero_prob: + self.zero_logit = nn.Sequential( + nn.Linear(d_in, d_model), + nn.LeakyReLU(), + nn.Linear(d_model, d_model), + nn.LeakyReLU(), + nn.Linear(d_model, 1), + ) + + def forward(self, x: Tensor) -> Dict[str, Tensor]: + """x is the output of the transformer, (batch, seq_len, d_model)""" + pred_value = self.fc(x).squeeze(-1) # (batch, seq_len) + + if not self.explicit_zero_prob: + return dict(pred=pred_value) + zero_logits = self.zero_logit(x).squeeze(-1) # (batch, seq_len) + zero_probs = torch.sigmoid(zero_logits) + return dict(pred=pred_value, zero_probs=zero_probs) + # TODO: note that the return currently is only for training. Since decoder + # is not used in the test setting for the integration task, the eval/inference + # logic is not implemented yet. However, remember to implement it when + # the decoder is used in any test setting. The inference logic will need + # to sample from the bernoulli distribution with the zero_probs. + + +class ClsDecoder(nn.Module): + """ + Decoder for classification task. 
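+
+    A stack of ``nlayers - 1`` Linear -> activation -> LayerNorm blocks followed
+    by a final Linear layer that maps the cell embedding to ``n_cls`` logits.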
+ """ + + def __init__( + self, + d_model: int, + n_cls: int, + nlayers: int = 3, + activation: callable = nn.ReLU, + ): + super().__init__() + # module list + self._decoder = nn.ModuleList() + for i in range(nlayers - 1): + self._decoder.append(nn.Linear(d_model, d_model)) + self._decoder.append(activation()) + self._decoder.append(nn.LayerNorm(d_model)) + self.out_layer = nn.Linear(d_model, n_cls) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: Tensor, shape [batch_size, embsize] + """ + for layer in self._decoder: + x = layer(x) + return self.out_layer(x) + + +class MVCDecoder(nn.Module): + """ + Decoder for the masked value prediction for cell embeddings. + """ + + def __init__( + self, + d_model: int, + arch_style: str = "inner product", + query_activation: nn.Module = nn.Sigmoid, + hidden_activation: nn.Module = nn.PReLU, + explicit_zero_prob: bool = False, + use_batch_labels: bool = False, + ) -> None: + """ + Args: + d_model (:obj:`int`): dimension of the gene embedding. + arch_style (:obj:`str`): architecture style of the decoder, choice from + 1. "inner product" or 2. "concat query" or 3. "sum query". + query_activation (:obj:`nn.Module`): activation function for the query + vectors. + hidden_activation (:obj:`nn.Module`): activation function for the hidden + layers. + """ + super().__init__() + d_in = d_model * 2 if use_batch_labels else d_model + if arch_style in ["inner product", "inner product, detach"]: + self.gene2query = nn.Linear(d_model, d_model) + self.query_activation = query_activation() + self.W = nn.Linear(d_model, d_in, bias=False) + if explicit_zero_prob: # by default, gene-wise prob rate + self.W_zero_logit = nn.Linear(d_model, d_in) + elif arch_style == "concat query": + self.gene2query = nn.Linear(d_model, 64) + self.query_activation = query_activation() + self.fc1 = nn.Linear(d_model + 64, 64) + self.hidden_activation = hidden_activation() + self.fc2 = nn.Linear(64, 1) + elif arch_style == "sum query": + self.gene2query = nn.Linear(d_model, d_model) + self.query_activation = query_activation() + self.fc1 = nn.Linear(d_model, 64) + self.hidden_activation = hidden_activation() + self.fc2 = nn.Linear(64, 1) + else: + raise ValueError(f"Unknown arch_style: {arch_style}") + + self.arch_style = arch_style + self.do_detach = arch_style.endswith("detach") + self.explicit_zero_prob = explicit_zero_prob + + def forward( + self, cell_emb: Tensor, gene_embs: Tensor + ) -> Union[Tensor, Dict[str, Tensor]]: + """ + Args: + cell_emb: Tensor, shape (batch, embsize=d_model) + gene_embs: Tensor, shape (batch, seq_len, embsize=d_model) + """ + gene_embs = gene_embs.detach() if self.do_detach else gene_embs + if self.arch_style in ["inner product", "inner product, detach"]: + query_vecs = self.query_activation(self.gene2query(gene_embs)) + cell_emb = cell_emb.unsqueeze(2) # (batch, embsize, 1) + # the pred gene expr values, # (batch, seq_len) + pred_value = torch.bmm(self.W(query_vecs), cell_emb).squeeze(2) + if not self.explicit_zero_prob: + return dict(pred=pred_value) + # zero logits need to based on the cell_emb, because of input exprs + zero_logits = torch.bmm(self.W_zero_logit(query_vecs), cell_emb).squeeze(2) + zero_probs = torch.sigmoid(zero_logits) + return dict(pred=pred_value, zero_probs=zero_probs) + elif self.arch_style == "concat query": + query_vecs = self.query_activation(self.gene2query(gene_embs)) + # expand cell_emb to (batch, seq_len, embsize) + cell_emb = cell_emb.unsqueeze(1).expand(-1, gene_embs.shape[1], -1) + + h = self.hidden_activation( + 
self.fc1(torch.cat([cell_emb, query_vecs], dim=2)) + ) + if self.explicit_zero_prob: + raise NotImplementedError + return self.fc2(h).squeeze(2) # (batch, seq_len) + elif self.arch_style == "sum query": + query_vecs = self.query_activation(self.gene2query(gene_embs)) + cell_emb = cell_emb.unsqueeze(1) + + h = self.hidden_activation(self.fc1(cell_emb + query_vecs)) + if self.explicit_zero_prob: + raise NotImplementedError + return self.fc2(h).squeeze(2) # (batch, seq_len) + + +class AdversarialDiscriminator(nn.Module): + """ + Discriminator for the adversarial training for batch correction. + """ + + def __init__( + self, + d_model: int, + n_cls: int, + nlayers: int = 3, + activation: callable = nn.LeakyReLU, + reverse_grad: bool = False, + ): + super().__init__() + # module list + self._decoder = nn.ModuleList() + for i in range(nlayers - 1): + self._decoder.append(nn.Linear(d_model, d_model)) + self._decoder.append(activation()) + self._decoder.append(nn.LayerNorm(d_model)) + self.out_layer = nn.Linear(d_model, n_cls) + self.reverse_grad = reverse_grad + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: Tensor, shape [batch_size, embsize] + """ + if self.reverse_grad: + x = grad_reverse(x, lambd=1.0) + for layer in self._decoder: + x = layer(x) + return self.out_layer(x) diff --git a/Singlecell_multi_omics/src/scGPT/model/multiomic_model.py b/Singlecell_multi_omics/src/scGPT/model/multiomic_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5128fac74e1604df8ae75251aab4ff86a2bb93ed --- /dev/null +++ b/Singlecell_multi_omics/src/scGPT/model/multiomic_model.py @@ -0,0 +1,1079 @@ +import gc +import math +from typing import Dict, Mapping, Optional, Tuple, Any, Union + +import torch +import numpy as np +from torch import nn, Tensor +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn import TransformerEncoder, TransformerEncoderLayer +from torch.distributions import Bernoulli +from tqdm import trange + +try: + from flash_attn.flash_attention import FlashMHA +except ImportError: + import warnings + + warnings.warn("flash_attn is not installed") + +from .dsbn import DomainSpecificBatchNorm1d +from .grad_reverse import grad_reverse + + +class MultiOmicTransformerModel(nn.Module): + def __init__( + self, + ntoken: int, + d_model: int, + nhead: int, + d_hid: int, + nlayers: int, + nlayers_cls: int = 3, + n_cls: int = 1, + vocab: Any = None, + dropout: float = 0.5, + pad_token: str = "", + pad_value: int = 0, + do_mvc: bool = False, + do_dab: bool = False, + use_batch_labels: bool = False, + num_batch_labels: Optional[int] = None, + domain_spec_batchnorm: Union[bool, str] = False, + input_emb_style: str = "continuous", + n_input_bins: Optional[int] = None, + cell_emb_style: str = "cls", + mvc_decoder_style: str = "inner product", + ecs_threshold: float = 0.3, + explicit_zero_prob: bool = False, + use_fast_transformer: bool = False, + fast_transformer_backend: str = "flash", + pre_norm: bool = False, + use_mod: bool = False, + ntokens_mod: Optional[int] = None, + vocab_mod: Optional[Any] = None, + ): + super().__init__() + self.model_type = "Transformer" + self.d_model = d_model + self.do_dab = do_dab + self.ecs_threshold = ecs_threshold + self.use_batch_labels = use_batch_labels + self.domain_spec_batchnorm = domain_spec_batchnorm + self.input_emb_style = input_emb_style + self.cell_emb_style = cell_emb_style + self.explicit_zero_prob = explicit_zero_prob + self.norm_scheme = "pre" if pre_norm else "post" + self.use_mod = use_mod + + 
if self.input_emb_style not in ["category", "continuous", "scaling"]: + raise ValueError( + f"input_emb_style should be one of category, continuous, scaling, " + f"got {input_emb_style}" + ) + if cell_emb_style not in ["cls", "avg-pool", "w-pool"]: + raise ValueError(f"Unknown cell_emb_style: {cell_emb_style}") + + # TODO: add dropout in the GeneEncoder + self.encoder = GeneEncoder(ntoken, d_model, padding_idx=vocab[pad_token]) + + # Value Encoder, NOTE: the scaling style is also handled in _encode method + if input_emb_style == "continuous": + self.value_encoder = ContinuousValueEncoder(d_model, dropout) + elif input_emb_style == "category": + assert n_input_bins > 0 + self.value_encoder = CategoryValueEncoder( + n_input_bins, d_model, padding_idx=pad_value + ) + else: + self.value_encoder = nn.Identity() # nn.Softmax(dim=1) + # TODO: consider row-wise normalization or softmax + # TODO: Correct handle the mask_value when using scaling + + # Batch Encoder + if use_batch_labels: + self.batch_encoder = BatchLabelEncoder(num_batch_labels, d_model) + + if use_mod: + self.mod_encoder = BatchLabelEncoder( + ntokens_mod, d_model, padding_idx=vocab_mod[pad_token] + ) + + if domain_spec_batchnorm is True or domain_spec_batchnorm == "dsbn": + use_affine = True if domain_spec_batchnorm == "do_affine" else False + print(f"Use domain specific batchnorm with affine={use_affine}") + self.dsbn = DomainSpecificBatchNorm1d( + d_model, num_batch_labels, eps=6.1e-5, affine=use_affine + ) + elif domain_spec_batchnorm == "batchnorm": + print("Using simple batchnorm instead of domain specific batchnorm") + self.bn = nn.BatchNorm1d(d_model, eps=6.1e-5) + + if use_fast_transformer: + if fast_transformer_backend == "linear": + self.transformer_encoder = FastTransformerEncoderWrapper( + d_model, nhead, d_hid, nlayers, dropout + ) + elif fast_transformer_backend == "flash": + encoder_layers = FlashTransformerEncoderLayer( + d_model, + nhead, + d_hid, + dropout, + batch_first=True, + norm_scheme=self.norm_scheme, + ) + self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) + else: + encoder_layers = TransformerEncoderLayer( + d_model, nhead, d_hid, dropout, batch_first=True + ) + self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) + + self.decoder = ExprDecoder( + d_model, + explicit_zero_prob=explicit_zero_prob, + use_batch_labels=use_batch_labels, + use_mod=use_mod, + ) + self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers=nlayers_cls) + if do_mvc: + self.mvc_decoder = MVCDecoder( + d_model, + arch_style=mvc_decoder_style, + explicit_zero_prob=explicit_zero_prob, + use_batch_labels=use_batch_labels, + use_mod=use_mod, + ) + + if do_dab: + self.grad_reverse_discriminator = AdversarialDiscriminator( + d_model, + n_cls=num_batch_labels, + reverse_grad=True, + ) + + self.sim = Similarity(temp=0.5) # TODO: auto set temp + self.creterion_cce = nn.CrossEntropyLoss() + + self.init_weights() + + def init_weights(self) -> None: + initrange = 0.1 + # TODO: check if this initialization is helpful and shall we apply to all? 
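+        # Currently only the gene-token embedding is re-initialized; the value/batch
+        # encoders and the transformer layers keep their default PyTorch initialization.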
+ self.encoder.embedding.weight.data.uniform_(-initrange, initrange) + + def _encode( + self, + src: Tensor, + values: Tensor, + src_key_padding_mask: Tensor, + batch_labels: Optional[Tensor] = None, # (batch,) + ) -> Tensor: + self._check_batch_labels(batch_labels) + + src = self.encoder(src) # (batch, seq_len, embsize) + self.cur_gene_token_embs = src + + values = self.value_encoder(values) # (batch, seq_len, embsize) + if self.input_emb_style == "scaling": + values = values.unsqueeze(2) + total_embs = src * values + else: + total_embs = src + values + + if getattr(self, "dsbn", None) is not None: + batch_label = int(batch_labels[0].item()) + total_embs = self.dsbn(total_embs.permute(0, 2, 1), batch_label).permute( + 0, 2, 1 + ) # the batch norm always works on dim 1 + elif getattr(self, "bn", None) is not None: + total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1) + + output = self.transformer_encoder( + total_embs, src_key_padding_mask=src_key_padding_mask + ) + return output # (batch, seq_len, embsize) + + def _get_cell_emb_from_layer( + self, layer_output: Tensor, weights: Tensor = None + ) -> Tensor: + """ + Args: + layer_output(:obj:`Tensor`): shape (batch, seq_len, embsize) + weights(:obj:`Tensor`): shape (batch, seq_len), optional and only used + when :attr:`self.cell_emb_style` is "w-pool". + + Returns: + :obj:`Tensor`: shape (batch, embsize) + """ + if self.cell_emb_style == "cls": + cell_emb = layer_output[:, 0, :] # (batch, embsize) + elif self.cell_emb_style == "avg-pool": + cell_emb = torch.mean(layer_output, dim=1) + elif self.cell_emb_style == "w-pool": + if weights is None: + raise ValueError("weights is required when cell_emb_style is w-pool") + if weights.dim() != 2: + raise ValueError("weights should be 2D") + cell_emb = torch.sum(layer_output * weights.unsqueeze(2), dim=1) + cell_emb = F.normalize(cell_emb, p=2, dim=1) # (batch, embsize) + + return cell_emb + + def _check_batch_labels(self, batch_labels: Tensor) -> None: + if self.use_batch_labels or self.domain_spec_batchnorm: + assert batch_labels is not None + elif batch_labels is not None: + raise ValueError( + "batch_labels should only be provided when `self.use_batch_labels`" + " or `self.domain_spec_batchnorm` is True" + ) + + def generate( + self, + cell_emb: Tensor, + src: Tensor, + values: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + gen_iters: int = 1, + batch_labels: Optional[Tensor] = None, # (batch,) + ) -> Tensor: + """ + Args: + cell_emb(:obj:`Tensor`): shape (batch, embsize) + src(:obj:`Tensor`): shape (batch, seq_len) + values(:obj:`Tensor`): shape (batch, seq_len), optional + src_key_padding_mask(:obj:`Tensor`): shape (batch, seq_len), optional + gen_iters(:obj:`int`): number of generation iterations + batch_labels(:obj:`Tensor`): shape (batch,), optional + """ + # TODO: should have a tag indicate the generation mode + # TODO: if gen_iters > 1, should have a tag indicate the current iteration + try: + self._check_batch_labels(batch_labels) + except: + import warnings + + warnings.warn( + "batch_labels is required but not provided, using zeros instead" + ) + batch_labels = torch.zeros( + cell_emb.shape[0], dtype=torch.long, device=cell_emb.device + ) + + src = self.encoder(src) # (batch, seq_len, embsize) + + if values is not None: + values = self.value_encoder(values) # (batch, seq_len, embsize) + if self.input_emb_style == "scaling": + values = values.unsqueeze(2) + total_embs = src * values + else: + total_embs = src + values + else: + total_embs = 
src + + if getattr(self, "dsbn", None) is not None: + batch_label = int(batch_labels[0].item()) + total_embs = self.dsbn(total_embs.permute(0, 2, 1), batch_label).permute( + 0, 2, 1 + ) # the batch norm always works on dim 1 + elif getattr(self, "bn", None) is not None: + total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1) + + total_embs[:, 0, :] = cell_emb + + if src_key_padding_mask is None: + src_key_padding_mask = torch.zeros( + total_embs.shape[:2], dtype=torch.bool, device=total_embs.device + ) + transformer_output = self.transformer_encoder( + total_embs, src_key_padding_mask=src_key_padding_mask + ) + + if self.use_batch_labels: + batch_emb = self.batch_encoder(batch_labels) # (batch, embsize) + mlm_output = self.decoder( + transformer_output + if not self.use_batch_labels + else torch.cat( + [ + transformer_output, + batch_emb.unsqueeze(1).repeat(1, transformer_output.shape[1], 1), + ], + dim=2, + ), + # else transformer_output + batch_emb.unsqueeze(1), + ) + output = mlm_output["pred"] # (batch, seq_len) + + return output # (batch, seq_len) + + def forward( + self, + src: Tensor, + values: Tensor, + src_key_padding_mask: Tensor, + batch_labels: Optional[Tensor] = None, + CLS: bool = False, + CCE: bool = False, + MVC: bool = False, + ECS: bool = False, + do_sample: bool = False, + mod_types: Optional[Tensor] = None, + ) -> Mapping[str, Tensor]: + """ + Args: + src (:obj:`Tensor`): token ids, shape [batch_size, seq_len] + values (:obj:`Tensor`): token values, shape [batch_size, seq_len] + src_key_padding_mask (:obj:`Tensor`): mask for src, shape [batch_size, + seq_len] + batch_labels (:obj:`Tensor`): batch labels, shape [batch_size] + CLS (:obj:`bool`): if True, return the celltype classification objective + (CLS) output + CCE (:obj:`bool`): if True, return the contrastive cell embedding objective + (CCE) output + MVC (:obj:`bool`): if True, return the masked value prediction for cell + embedding MVC output + ECS (:obj:`bool`): if True, return the elastic cell similarity objective + (ECS) output. + do_sample (:obj:`bool`): if True, sample from the output distribution + and apply to the output. + mod_types (:obj:`Tensor`): shape [batch_size, seq_len], optional, only + used when `self.use_mod` is True. The token types for the tokens. + + Returns: + dict of output Tensors. 
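+            Always contains "mlm_output" and "cell_emb". "cls_output", "loss_cce",
+            "mvc_output", "loss_ecs" and "dab_output" are added when CLS, CCE, MVC,
+            ECS or the DAB discriminator are enabled, and "mlm_zero_probs" /
+            "mvc_zero_probs" are added when `explicit_zero_prob` is True.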
+ """ + transformer_output = self._encode( + src, values, src_key_padding_mask, batch_labels + ) + if self.use_batch_labels: + batch_emb = self.batch_encoder(batch_labels) # (batch, embsize) + + if self.use_mod: + mod_emb = self.mod_encoder(mod_types) + + output = {} + + if self.use_batch_labels and self.use_mod: + cat_0 = ( + batch_emb.unsqueeze(1).repeat(1, transformer_output.shape[1], 1) + + mod_emb + ) + elif self.use_batch_labels and not self.use_mod: + cat_0 = batch_emb.unsqueeze(1).repeat(1, transformer_output.shape[1], 1) + elif self.use_mod and not self.use_batch_labels: + cat_0 = mod_emb + else: + cat_0 = None + + mlm_output = self.decoder( + transformer_output + if cat_0 is None + else torch.cat( + [transformer_output, cat_0], + dim=2, + ), + ) + if self.explicit_zero_prob and do_sample: + bernoulli = Bernoulli(probs=mlm_output["zero_probs"]) + output["mlm_output"] = bernoulli.sample() * mlm_output["pred"] + else: + output["mlm_output"] = mlm_output["pred"] # (batch, seq_len) + if self.explicit_zero_prob: + output["mlm_zero_probs"] = mlm_output["zero_probs"] + + cell_emb = self._get_cell_emb_from_layer(transformer_output, values) + output["cell_emb"] = cell_emb + + if CLS: + output["cls_output"] = self.cls_decoder(cell_emb) # (batch, n_cls) + if CCE: + cell1 = cell_emb + transformer_output2 = self._encode( + src, values, src_key_padding_mask, batch_labels + ) + cell2 = self._get_cell_emb_from_layer(transformer_output2) + + # Gather embeddings from all devices if distributed training + if dist.is_initialized() and self.training: + cls1_list = [ + torch.zeros_like(cell1) for _ in range(dist.get_world_size()) + ] + cls2_list = [ + torch.zeros_like(cell2) for _ in range(dist.get_world_size()) + ] + dist.all_gather(tensor_list=cls1_list, tensor=cell1.contiguous()) + dist.all_gather(tensor_list=cls2_list, tensor=cell2.contiguous()) + + # NOTE: all_gather results have no gradients, so replace the item + # of the current rank with the original tensor to keep gradients. + # See https://github.com/princeton-nlp/SimCSE/blob/main/simcse/models.py#L186 + cls1_list[dist.get_rank()] = cell1 + cls2_list[dist.get_rank()] = cell2 + + cell1 = torch.cat(cls1_list, dim=0) + cell2 = torch.cat(cls2_list, dim=0) + # TODO: should detach the second run cls2? 
Could be worth trying.
+            cos_sim = self.sim(cell1.unsqueeze(1), cell2.unsqueeze(0))  # (batch, batch)
+            labels = torch.arange(cos_sim.size(0)).long().to(cell1.device)
+            output["loss_cce"] = self.creterion_cce(cos_sim, labels)
+
+        if MVC:
+            if self.use_batch_labels and self.use_mod:
+                cat_1 = batch_emb + self._get_cell_emb_from_layer(mod_emb)
+                cat_2 = (
+                    batch_emb.unsqueeze(1).repeat(1, transformer_output.shape[1], 1)
+                    + mod_emb
+                )
+            elif self.use_batch_labels and not self.use_mod:
+                cat_1 = batch_emb
+                cat_2 = batch_emb.unsqueeze(1).repeat(1, transformer_output.shape[1], 1)
+            elif self.use_mod and not self.use_batch_labels:
+                cat_1 = self._get_cell_emb_from_layer(mod_emb)
+                cat_2 = mod_emb
+            else:
+                cat_1 = None
+                cat_2 = None
+
+            mvc_output = self.mvc_decoder(
+                cell_emb if cat_1 is None else torch.cat([cell_emb, cat_1], dim=1),
+                self.cur_gene_token_embs
+                if cat_2 is None
+                else torch.cat([self.cur_gene_token_embs, cat_2], dim=2),
+            )
+
+            if self.explicit_zero_prob and do_sample:
+                bernoulli = Bernoulli(probs=mvc_output["zero_probs"])
+                output["mvc_output"] = bernoulli.sample() * mvc_output["pred"]
+            else:
+                output["mvc_output"] = mvc_output["pred"]  # (batch, seq_len)
+            if self.explicit_zero_prob:
+                output["mvc_zero_probs"] = mvc_output["zero_probs"]
+        if ECS:
+            # Use a customized cosine similarity instead of F.cosine_similarity
+            # to avoid the PyTorch issue of similarities larger than 1.0 (pytorch#78064).
+            # Normalize the embeddings first.
+            cell_emb_normed = F.normalize(cell_emb, p=2, dim=1)
+            cos_sim = torch.mm(cell_emb_normed, cell_emb_normed.t())  # (batch, batch)
+
+            # mask out diagonal elements
+            mask = torch.eye(cos_sim.size(0)).bool().to(cos_sim.device)
+            cos_sim = cos_sim.masked_fill(mask, 0.0)
+            # only optimize positive similarities
+            cos_sim = F.relu(cos_sim)
+
+            output["loss_ecs"] = torch.mean(1 - (cos_sim - self.ecs_threshold) ** 2)
+
+        if self.do_dab:
+            output["dab_output"] = self.grad_reverse_discriminator(cell_emb)
+
+        return output
+
+    def encode_batch(
+        self,
+        src: Tensor,
+        values: Tensor,
+        src_key_padding_mask: Tensor,
+        batch_size: int,
+        batch_labels: Optional[Tensor] = None,
+        output_to_cpu: bool = True,
+        time_step: Optional[int] = None,
+        return_np: bool = False,
+    ) -> Tensor:
+        """
+        Args:
+            src (Tensor): shape [N, seq_len]
+            values (Tensor): shape [N, seq_len]
+            src_key_padding_mask (Tensor): shape [N, seq_len]
+            batch_size (int): batch size for encoding
+            batch_labels (Tensor): shape [N, n_batch_labels]
+            output_to_cpu (bool): whether to move the output to cpu
+            time_step (int): the time step index in the transformer output to return.
+                The time step is along the second dimension. If None, return all.
+ return_np (bool): whether to return numpy array + + Returns: + output Tensor of shape [N, seq_len, embsize] + """ + N = src.size(0) + device = next(self.parameters()).device + + # initialize the output tensor + array_func = np.zeros if return_np else torch.zeros + float32_ = np.float32 if return_np else torch.float32 + shape = ( + (N, self.d_model) + if time_step is not None + else (N, src.size(1), self.d_model) + ) + outputs = array_func(shape, dtype=float32_) + + for i in trange(0, N, batch_size): + raw_output = self._encode( + src[i : i + batch_size].to(device), + values[i : i + batch_size].to(device), + src_key_padding_mask[i : i + batch_size].to(device), + batch_labels[i : i + batch_size].to(device) + if batch_labels is not None + else None, + ) + output = raw_output.detach() + if output_to_cpu: + output = output.cpu() + if return_np: + output = output.numpy() + if time_step is not None: + output = output[:, time_step, :] + outputs[i : i + batch_size] = output + + return outputs + + +def generate_square_subsequent_mask(sz: int) -> Tensor: + """Generates an upper-triangular matrix of -inf, with zeros on diag.""" + return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1) + + +class FastTransformerEncoderWrapper(nn.Module): + def __init__( + self, + d_model: int, + nhead: int, + d_hid: int, + nlayers: int, + dropout: float = 0.5, + ): + super().__init__() + self.fast_transformer_encoder = self.build_fast_transformer_encoder( + d_model, nhead, d_hid, nlayers, dropout + ) + + @staticmethod + def build_fast_transformer_encoder( + d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float + ) -> nn.Module: + from fast_transformers.builders import TransformerEncoderBuilder + + if d_model % nhead != 0: + raise ValueError( + f"d_model must be divisible by nhead, " + f"got d_model={d_model} and nhead={nhead}" + ) + builder = TransformerEncoderBuilder.from_kwargs( + n_layers=nlayers, + n_heads=nhead, + query_dimensions=d_model // nhead, + value_dimensions=d_model // nhead, + feed_forward_dimensions=d_hid, + attention_type="linear", + attention_dropout=dropout, + dropout=dropout, + activation="gelu", + ) + assert builder.attention_type == "linear" + return builder.get() + + @staticmethod + def build_length_mask( + src: Tensor, + src_key_padding_mask: torch.BoolTensor, + ) -> "LengthMask": + from fast_transformers.masking import LengthMask + + seq_len = src.shape[1] + num_paddings = src_key_padding_mask.sum(dim=1) + actual_seq_len = seq_len - num_paddings # (N,) + length_mask = LengthMask(actual_seq_len, max_len=seq_len, device=src.device) + + if src_key_padding_mask[length_mask.bool_matrix].sum() != 0: + raise ValueError( + "Found padding tokens in the middle of the sequence. " + "src_key_padding_mask and length_mask are not compatible." 
+ ) + return length_mask + + def forward( + self, + src: Tensor, + src_key_padding_mask: torch.BoolTensor, + ) -> Tensor: + """ + Args: + src: Tensor, shape [N, seq_len, embsize] + src_key_padding_mask: Tensor, shape [N, seq_len] + + Returns: + output Tensor of shape [N, seq_len, embsize] + """ + if src_key_padding_mask.shape != src.shape[:2]: + raise ValueError( + f"src_key_padding_mask shape {src_key_padding_mask.shape} " + f"does not match first two dims of src shape {src.shape[:2]}" + ) + + if src_key_padding_mask.dtype != torch.bool: + raise ValueError( + f"src_key_padding_mask needs to be of type torch.bool, " + f"got {src_key_padding_mask.dtype}" + ) + + length_mask = self.build_length_mask(src, src_key_padding_mask) + output = self.fast_transformer_encoder(src, length_mask=length_mask) + return output + + +class FlashTransformerEncoderLayer(nn.Module): + r"""TransformerEncoderLayer is made up of self-attn and feedforward network. + The class is modified from torch.nn.TransformerEncoderLayer to support the + FlashAttention. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of intermediate layer, relu or gelu (default=relu). + layer_norm_eps: the eps value in layer normalization components (default=1e-5). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False``. + + Examples:: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + + Alternatively, when ``batch_first`` is ``True``: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True) + >>> src = torch.rand(32, 10, 512) + >>> out = encoder_layer(src) + """ + __constants__ = ["batch_first"] + + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + layer_norm_eps=1e-5, + batch_first=True, + device=None, + dtype=None, + norm_scheme="post", # "pre" or "post" + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.self_attn = FlashMHA( + embed_dim=d_model, + num_heads=nhead, + batch_first=batch_first, + attention_dropout=dropout, + **factory_kwargs, + ) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) + + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = self._get_activation_fn(activation) + self.norm_scheme = norm_scheme + if self.norm_scheme not in ["pre", "post"]: + raise ValueError(f"norm_scheme should be pre or post, not {norm_scheme}") + + @staticmethod + def _get_activation_fn(activation): + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = F.relu + super().__setstate__(state) + + def forward( + self, + src: Tensor, + src_mask: 
Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + if src_mask is not None: + raise ValueError("FlashTransformerEncoderLayer does not support src_mask") + + if not src_key_padding_mask.any().item(): + # no padding tokens in src + src_key_padding_mask_ = None + else: + if src_key_padding_mask.dtype != torch.bool: + src_key_padding_mask = src_key_padding_mask.bool() + # NOTE: the FlashMHA uses mask 0 for padding tokens, which is the opposite + src_key_padding_mask_ = ~src_key_padding_mask + + if self.norm_scheme == "pre": + src = self.norm1(src) + src2 = self.self_attn(src, key_padding_mask=src_key_padding_mask_)[0] + src = src + self.dropout1(src2) + src = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + else: + src2 = self.self_attn(src, key_padding_mask=src_key_padding_mask_)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + + return src + + +class GeneEncoder(nn.Module): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + ): + super().__init__() + self.embedding = nn.Embedding( + num_embeddings, embedding_dim, padding_idx=padding_idx + ) + self.enc_norm = nn.LayerNorm(embedding_dim) + + def forward(self, x: Tensor) -> Tensor: + x = self.embedding(x) # (batch, seq_len, embsize) + x = self.enc_norm(x) + return x + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) + ) + pe = torch.zeros(max_len, 1, d_model) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: Tensor, shape [seq_len, batch_size, embedding_dim] + """ + x = x + self.pe[: x.size(0)] + return self.dropout(x) + + +class ContinuousValueEncoder(nn.Module): + """ + Encode real number values to a vector using neural nets projection. 
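+
+    The input is expected to be a float tensor of shape (batch, seq_len); values
+    above ``max_value`` are clipped before the two-layer projection and LayerNorm.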
+ """ + + def __init__(self, d_model: int, dropout: float = 0.1, max_value: int = 512): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + self.linear1 = nn.Linear(1, d_model) + self.activation = nn.ReLU() + self.linear2 = nn.Linear(d_model, d_model) + self.norm = nn.LayerNorm(d_model) + self.max_value = max_value + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: Tensor, shape [batch_size, seq_len] + """ + # TODO: test using actual embedding layer if input is categorical + # expand last dimension + x = x.unsqueeze(-1) + # clip x to [-inf, max_value] + x = torch.clamp(x, max=self.max_value) + x = self.activation(self.linear1(x)) + x = self.linear2(x) + x = self.norm(x) + return self.dropout(x) + + +class CategoryValueEncoder(nn.Module): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + ): + super().__init__() + self.embedding = nn.Embedding( + num_embeddings, embedding_dim, padding_idx=padding_idx + ) + self.enc_norm = nn.LayerNorm(embedding_dim) + + def forward(self, x: Tensor) -> Tensor: + x = x.long() + x = self.embedding(x) # (batch, seq_len, embsize) + x = self.enc_norm(x) + return x + + +class BatchLabelEncoder(nn.Module): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + ): + super().__init__() + self.embedding = nn.Embedding( + num_embeddings, embedding_dim, padding_idx=padding_idx + ) + self.enc_norm = nn.LayerNorm(embedding_dim) + + def forward(self, x: Tensor) -> Tensor: + x = self.embedding(x) # (batch, embsize) + x = self.enc_norm(x) + return x + + +class Similarity(nn.Module): + """ + Dot product or cosine similarity + """ + + def __init__(self, temp): + super().__init__() + self.temp = temp + self.cos = nn.CosineSimilarity(dim=-1) + + def forward(self, x, y): + return self.cos(x, y) / self.temp + + +class ExprDecoder(nn.Module): + def __init__( + self, + d_model: int, + explicit_zero_prob: bool = False, + use_batch_labels: bool = False, + use_mod: bool = False, + ): + super().__init__() + d_in = d_model * 2 if use_batch_labels or use_mod else d_model + self.fc = nn.Sequential( + nn.Linear(d_in, d_model), + nn.LeakyReLU(), + nn.Linear(d_model, d_model), + nn.LeakyReLU(), + nn.Linear(d_model, 1), + ) + self.explicit_zero_prob = explicit_zero_prob + if explicit_zero_prob: + self.zero_logit = nn.Sequential( + nn.Linear(d_in, d_model), + nn.LeakyReLU(), + nn.Linear(d_model, d_model), + nn.LeakyReLU(), + nn.Linear(d_model, 1), + ) + + def forward(self, x: Tensor) -> Dict[str, Tensor]: + """x is the output of the transformer, (batch, seq_len, d_model)""" + pred_value = self.fc(x).squeeze(-1) # (batch, seq_len) + + if not self.explicit_zero_prob: + return dict(pred=pred_value) + zero_logits = self.zero_logit(x).squeeze(-1) # (batch, seq_len) + zero_probs = torch.sigmoid(zero_logits) + return dict(pred=pred_value, zero_probs=zero_probs) + # TODO: note that the return currently is only for training. Since decoder + # is not used in the test setting for the integration task, the eval/inference + # logic is not implemented yet. However, remember to implement it when + # the decoder is used in any test setting. The inference logic will need + # to sample from the bernoulli distribution with the zero_probs. + + +class ClsDecoder(nn.Module): + """ + Decoder for classification task. 
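+
+    Applied to the pooled cell embedding to produce the ``cls_output`` logits of
+    ``MultiOmicTransformerModel.forward``.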
+ """ + + def __init__( + self, + d_model: int, + n_cls: int, + nlayers: int = 3, + activation: callable = nn.ReLU, + ): + super().__init__() + # module list + self._decoder = nn.ModuleList() + for i in range(nlayers - 1): + self._decoder.append(nn.Linear(d_model, d_model)) + self._decoder.append(activation()) + self._decoder.append(nn.LayerNorm(d_model)) + self.out_layer = nn.Linear(d_model, n_cls) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: Tensor, shape [batch_size, embsize] + """ + for layer in self._decoder: + x = layer(x) + return self.out_layer(x) + + +class MVCDecoder(nn.Module): + """ + Decoder for the masked value prediction for cell embeddings. + """ + + def __init__( + self, + d_model: int, + arch_style: str = "inner product", + query_activation: nn.Module = nn.Sigmoid, + hidden_activation: nn.Module = nn.PReLU, + explicit_zero_prob: bool = False, + use_batch_labels: bool = False, + use_mod: bool = False, + ) -> None: + """ + Args: + d_model (:obj:`int`): dimension of the gene embedding. + arch_style (:obj:`str`): architecture style of the decoder, choice from + 1. "inner product" or 2. "concat query" or 3. "sum query". + query_activation (:obj:`nn.Module`): activation function for the query + vectors. + hidden_activation (:obj:`nn.Module`): activation function for the hidden + layers. + """ + super().__init__() + d_in = d_model * 2 if use_batch_labels or use_mod else d_model + d_model = d_model * 2 if use_batch_labels or use_mod else d_model + if arch_style in ["inner product", "inner product, detach"]: + self.gene2query = nn.Linear(d_model, d_model) + self.query_activation = query_activation() + self.W = nn.Linear(d_model, d_in, bias=False) + if explicit_zero_prob: # by default, gene-wise prob rate + self.W_zero_logit = nn.Linear(d_model, d_in) + elif arch_style == "concat query": + self.gene2query = nn.Linear(d_model, 64) + self.query_activation = query_activation() + self.fc1 = nn.Linear(d_model + 64, 64) + self.hidden_activation = hidden_activation() + self.fc2 = nn.Linear(64, 1) + elif arch_style == "sum query": + self.gene2query = nn.Linear(d_model, d_model) + self.query_activation = query_activation() + self.fc1 = nn.Linear(d_model, 64) + self.hidden_activation = hidden_activation() + self.fc2 = nn.Linear(64, 1) + else: + raise ValueError(f"Unknown arch_style: {arch_style}") + + self.arch_style = arch_style + self.do_detach = arch_style.endswith("detach") + self.explicit_zero_prob = explicit_zero_prob + + def forward( + self, cell_emb: Tensor, gene_embs: Tensor + ) -> Union[Tensor, Dict[str, Tensor]]: + """ + Args: + cell_emb: Tensor, shape (batch, embsize=d_model) + gene_embs: Tensor, shape (batch, seq_len, embsize=d_model) + """ + gene_embs = gene_embs.detach() if self.do_detach else gene_embs + if self.arch_style in ["inner product", "inner product, detach"]: + query_vecs = self.query_activation(self.gene2query(gene_embs)) + cell_emb = cell_emb.unsqueeze(2) # (batch, embsize, 1) + # the pred gene expr values, # (batch, seq_len) + pred_value = torch.bmm(self.W(query_vecs), cell_emb).squeeze(2) + if not self.explicit_zero_prob: + return dict(pred=pred_value) + # zero logits need to based on the cell_emb, because of input exprs + zero_logits = torch.bmm(self.W_zero_logit(query_vecs), cell_emb).squeeze(2) + zero_probs = torch.sigmoid(zero_logits) + return dict(pred=pred_value, zero_probs=zero_probs) + elif self.arch_style == "concat query": + query_vecs = self.query_activation(self.gene2query(gene_embs)) + # expand cell_emb to (batch, seq_len, embsize) 
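+            # so that it can be concatenated with the per-gene query vectors along dim=2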
+ cell_emb = cell_emb.unsqueeze(1).expand(-1, gene_embs.shape[1], -1) + + h = self.hidden_activation( + self.fc1(torch.cat([cell_emb, query_vecs], dim=2)) + ) + if self.explicit_zero_prob: + raise NotImplementedError + return self.fc2(h).squeeze(2) # (batch, seq_len) + elif self.arch_style == "sum query": + query_vecs = self.query_activation(self.gene2query(gene_embs)) + cell_emb = cell_emb.unsqueeze(1) + + h = self.hidden_activation(self.fc1(cell_emb + query_vecs)) + if self.explicit_zero_prob: + raise NotImplementedError + return self.fc2(h).squeeze(2) # (batch, seq_len) + + +class AdversarialDiscriminator(nn.Module): + """ + Discriminator for the adversarial training for batch correction. + """ + + def __init__( + self, + d_model: int, + n_cls: int, + nlayers: int = 3, + activation: callable = nn.LeakyReLU, + reverse_grad: bool = False, + ): + super().__init__() + # module list + self._decoder = nn.ModuleList() + for i in range(nlayers - 1): + self._decoder.append(nn.Linear(d_model, d_model)) + self._decoder.append(activation()) + self._decoder.append(nn.LayerNorm(d_model)) + self.out_layer = nn.Linear(d_model, n_cls) + self.reverse_grad = reverse_grad + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: Tensor, shape [batch_size, embsize] + """ + if self.reverse_grad: + x = grad_reverse(x, lambd=1.0) + for layer in self._decoder: + x = layer(x) + return self.out_layer(x)
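
For orientation, the following is a minimal usage sketch of the MultiOmicTransformerModel defined above, run with the vanilla PyTorch backend (use_fast_transformer=False) so that neither flash_attn nor fast_transformers is required. The import path, the "<pad>"/"<cls>" token names, the toy vocabulary and the hyperparameters are illustrative assumptions rather than part of the diff; with cell_emb_style="cls" the model simply treats position 0 of each sequence as the cell token.

import torch

# Import path is an assumption; adjust to how the package is installed or laid out.
from scGPT.model.multiomic_model import MultiOmicTransformerModel

# Toy gene vocabulary: any mapping that supports vocab[pad_token] works here.
vocab = {"<pad>": 0, "<cls>": 1, "GENE_A": 2, "GENE_B": 3, "GENE_C": 4}

model = MultiOmicTransformerModel(
    ntoken=len(vocab),
    d_model=64,
    nhead=4,
    d_hid=128,
    nlayers=2,
    n_cls=5,                       # number of cell-type classes for the CLS head
    vocab=vocab,
    dropout=0.1,
    pad_token="<pad>",             # assumed pad-token name
    pad_value=0,
    input_emb_style="continuous",  # expression values go through ContinuousValueEncoder
    cell_emb_style="cls",          # cell embedding = hidden state at position 0
    use_fast_transformer=False,    # fall back to torch.nn.TransformerEncoder
)
model.eval()

batch_size, seq_len = 2, 6
src = torch.randint(1, len(vocab), (batch_size, seq_len))         # gene token ids
src[:, 0] = vocab["<cls>"]                                         # position 0 acts as the cell token
values = torch.rand(batch_size, seq_len)                           # preprocessed expression values
padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)  # True marks padding positions

with torch.no_grad():
    out = model(src, values, padding_mask, CLS=True)

print(out["mlm_output"].shape)  # torch.Size([2, 6])  per-gene expression predictions
print(out["cell_emb"].shape)    # torch.Size([2, 64]) pooled cell embedding
print(out["cls_output"].shape)  # torch.Size([2, 5])  cell-type logits

The same tensors can be passed to model.encode_batch(src, values, padding_mask, batch_size=...) to obtain the (N, seq_len, d_model) transformer embeddings in mini-batches.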