diff --git a/dynamo/estimation/csc/velocity.py b/dynamo/estimation/csc/velocity.py index 8588836fa..d1bfc8c8c 100755 --- a/dynamo/estimation/csc/velocity.py +++ b/dynamo/estimation/csc/velocity.py @@ -2,6 +2,7 @@ from multiprocessing.dummy import Pool as ThreadPool from warnings import warn +import numpy as np from scipy.sparse import csr_matrix from tqdm import tqdm @@ -17,6 +18,7 @@ ) from .utils_velocity import * + # from sklearn.cluster import KMeans # from sklearn.neighbors import NearestNeighbors @@ -141,6 +143,8 @@ def vel_u(self, U, repeat=None, update_alpha=True): if self.parameters["beta"].ndim == 1: beta = np.repeat(self.parameters["beta"].reshape((-1, 1)), U.shape[1], axis=1) + elif self.parameters["beta"].shape[1] == U.shape[1]: # to support cell-wise beta + beta = self.parameters["beta"] elif self.parameters["beta"].shape[1] == len(t_uniq) and len(t_uniq) > 1: beta = np.zeros_like(U.shape) for i in range(len(t_uniq)): @@ -404,6 +408,9 @@ def __init__( P=None, US=None, S2=None, + NewCounts=None, + TotalCounts=None, + NewSmoothCSP=None, conn=None, t=None, ind_for_proteins=None, @@ -426,6 +433,9 @@ def __init__( "p": P, "us": US, "s2": S2, + "new_counts": NewCounts, + "total_counts": TotalCounts, + "new_smooth_csp": NewSmoothCSP, } # U: (unlabeled) unspliced; S: (unlabeled) spliced; U / Ul: old and labeled; U, Ul, S, Sl: uu/ul/su/sl if concat_data: self.concatenate_data() @@ -1333,94 +1343,135 @@ def fit( bf, ) elif np.all(self._exist_data("uu", "ul")): - k, k_intercept, k_r2, k_logLL, bs, bf = ( - np.zeros(n_genes), - np.zeros(n_genes), - np.zeros(n_genes), - np.zeros(n_genes), - np.zeros(n_genes), - np.zeros(n_genes), - ) - U = self.data["ul"] - S = self.data["ul"] + self.data["uu"] - US = ( - self.data["us"] - if self.data["us"] is not None - else calc_2nd_moment(U.T, S.T, self.conn, mX=U.T, mY=S.T).T - ) - S2 = ( - self.data["s2"] - if self.data["s2"] is not None - else calc_2nd_moment(S.T, S.T, self.conn, mX=S.T, mY=S.T).T - ) - if cores == 1: - for 
i in tqdm(range(n_genes), desc="estimating gamma"): + if one_shot_method == "storm-csp": + gamma, gamma_r2, k = ( + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + ) + new_counts = self.data["new_counts"] + total_counts = self.data["total_counts"] + new_smooth_csp = self.data["new_smooth_csp"] + new_smooth = self.data['ul'] + total_smooth = self.data["ul"] + self.data["uu"] + for i in tqdm(range(n_genes), desc="estimating gamma via storm's csp model"): ( + gamma[i], + gamma_r2[i], k[i], - k_intercept[i], - _, - k_r2[i], - _, - k_logLL[i], - bs[i], - bf[i], - ) = self.fit_gamma_stochastic( - self.est_method, - U[i], - S[i], - US[i], - S2[i], + ) = self.fit_gamma_storm_csp( + new_counts[i], + total_counts[i], + new_smooth[i], + total_smooth[i], + t_uniq=t_uniq, perc_left=perc_left, perc_right=perc_right, normalize=True, ) + _, alpha = one_shot_gamma_alpha_matrix(k, t_uniq, new_smooth_csp) + ( + self.parameters["alpha"], + self.parameters["gamma"], + self.aux_param["gamma_k"], + self.aux_param["gamma_intercept"], + self.aux_param["gamma_r2"], + ) = ( + alpha, + gamma, + k, + np.zeros(n_genes), + gamma_r2, + ) else: - pool = ThreadPool(cores) - res = pool.starmap( - self.fit_gamma_stochastic, - zip( - itertools.repeat(self.est_method), - U, - S, - US, - S2, - itertools.repeat(perc_left), - itertools.repeat(perc_right), - itertools.repeat(True), - ), + k, k_intercept, k_r2, k_logLL, bs, bf = ( + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), ) - pool.close() - pool.join() - (k, k_intercept, _, k_r2, _, k_logLL, bs, bf) = zip(*res) - (k, k_intercept, k_r2, k_logLL, bs, bf) = ( - np.array(k), - np.array(k_intercept), - np.array(k_r2), - np.array(k_logLL), - np.array(bs), - np.array(bf), + U = self.data["ul"] + S = self.data["ul"] + self.data["uu"] + US = ( + self.data["us"] + if self.data["us"] is not None + else calc_2nd_moment(U.T, S.T, self.conn, mX=U.T, mY=S.T).T + ) + S2 = ( + 
self.data["s2"] + if self.data["s2"] is not None + else calc_2nd_moment(S.T, S.T, self.conn, mX=S.T, mY=S.T).T ) + if cores == 1: + for i in tqdm(range(n_genes), desc="estimating gamma"): + ( + k[i], + k_intercept[i], + _, + k_r2[i], + _, + k_logLL[i], + bs[i], + bf[i], + ) = self.fit_gamma_stochastic( + self.est_method, + U[i], + S[i], + US[i], + S2[i], + perc_left=perc_left, + perc_right=perc_right, + normalize=True, + ) + else: + pool = ThreadPool(cores) + res = pool.starmap( + self.fit_gamma_stochastic, + zip( + itertools.repeat(self.est_method), + U, + S, + US, + S2, + itertools.repeat(perc_left), + itertools.repeat(perc_right), + itertools.repeat(True), + ), + ) + pool.close() + pool.join() + (k, k_intercept, _, k_r2, _, k_logLL, bs, bf) = zip(*res) + (k, k_intercept, k_r2, k_logLL, bs, bf) = ( + np.array(k), + np.array(k_intercept), + np.array(k_r2), + np.array(k_logLL), + np.array(bs), + np.array(bf), + ) - gamma, alpha = one_shot_gamma_alpha_matrix(k, t_uniq, U) - ( - self.parameters["alpha"], - self.parameters["gamma"], - self.aux_param["gamma_k"], - self.aux_param["gamma_intercept"], - self.aux_param["gamma_r2"], - self.aux_param["gamma_logLL"], - self.aux_param["bs"], - self.aux_param["bf"], - ) = ( - alpha, - gamma, - k, - k_intercept, - k_r2, - k_logLL, - bs, - bf, - ) + gamma, alpha = one_shot_gamma_alpha_matrix(k, t_uniq, U) + ( + self.parameters["alpha"], + self.parameters["gamma"], + self.aux_param["gamma_k"], + self.aux_param["gamma_intercept"], + self.aux_param["gamma_r2"], + self.aux_param["gamma_logLL"], + self.aux_param["bs"], + self.aux_param["bf"], + ) = ( + alpha, + gamma, + k, + k_intercept, + k_r2, + k_logLL, + bs, + bf, + ) elif self.extyp.lower() == "mix_std_stm": t_min, t_max = np.min(self.t), np.max(self.t) if np.all(self._exist_data("ul", "uu", "su")): @@ -1468,7 +1519,7 @@ def fit( n_genes = self.data["uu"].shape[0] # self.get_n_genes(data=U) gamma, U = np.zeros(n_genes), np.zeros(n_genes) for i in tqdm( - range(n_genes), 
desc="solving gamma, alpha" + range(n_genes), desc="solving gamma, alpha" ): # apply sci-fate like approach (can also use one-single time point to estimate gamma) # tmp = self.data['uu'][i, self.t == 0] + self.data['ul'][i, self.t == 0] tmp_ = self.data["uu"][i, self.t == t_max] + self.data["ul"][i, self.t == t_max] @@ -1552,65 +1603,1152 @@ def fit( _, # self.aux_param["delta_logLL"], ) = (delta, delta_intercept, delta_r2, delta_logLL) - def fit_gamma_steady_state(self, u, s, intercept=True, perc_left=None, perc_right=5, normalize=True): - """Estimate gamma using linear regression based on the steady state assumption. - - Arguments - --------- - u: :class:`~numpy.ndarray` or sparse `csr_matrix` - A matrix of unspliced mRNA counts. Dimension: genes x cells. - s: :class:`~numpy.ndarray` or sparse `csr_matrix` - A matrix of spliced mRNA counts. Dimension: genes x cells. - intercept: bool - If using steady state assumption for fitting, then: - True -- the linear regression is performed with an unfixed intercept; - False -- the linear regresssion is performed with a fixed zero intercept. - perc_left: float - The percentage of samples included in the linear regression in the left tail. If set to None, then all the - left samples are excluded. - perc_right: float - The percentage of samples included in the linear regression in the right tail. If set to None, then all the - samples are included. - normalize: bool - Whether to first normalize the data. - - Returns - ------- - k: float - The slope of the linear regression model, which is gamma under the steady state assumption. - b: float - The intercept of the linear regression model. - r2: float - Coefficient of determination or r square for the extreme data points. - r2: float - Coefficient of determination or r square for the extreme data points. - all_r2: float - Coefficient of determination or r square for all data points. 
- """ - if intercept and perc_left is None: - perc_left = perc_right - u = u.A.flatten() if issparse(u) else u.flatten() - s = s.A.flatten() if issparse(s) else s.flatten() - - mask = find_extreme( - s, - u, - normalize=normalize, - perc_left=perc_left, - perc_right=perc_right, - ) + def fit_protein(self, intercept, perc_left, perc_right, cores): + """Fit the input data to estimate parameters for protein.""" + if np.all(self._exist_data("p", "su")): + ind_for_proteins = self.ind_for_proteins + n_genes = len(ind_for_proteins) if ind_for_proteins is not None else 0 - if self.est_method.lower() == "ols": - k, b, r2, all_r2 = fit_linreg(s, u, mask, intercept) - else: - k, b, r2, all_r2 = fit_linreg_robust(s, u, mask, intercept, self.est_method) + if self.asspt_prot.lower() == "ss" and n_genes > 0: + self.parameters["eta"] = np.ones(n_genes) + (delta, delta_intercept, delta_r2, delta_logLL,) = ( + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + ) - logLL, all_logLL = ( - calc_norm_loglikelihood(s[mask], u[mask], k), - calc_norm_loglikelihood(s, u, k), - ) + s = ( + self.data["su"][ind_for_proteins] + self.data["sl"][ind_for_proteins] + if self._exist_data("sl") + else self.data["su"][ind_for_proteins] + ) + if cores == 1: + for i in tqdm(range(n_genes), desc="estimating delta"): + ( + delta[i], + delta_intercept[i], + _, + delta_r2[i], + _, + delta_logLL[i], + ) = self.fit_gamma_steady_state( + s[i], + self.data["p"][i], + intercept, + perc_left, + perc_right, + ) + else: + pool = ThreadPool(cores) + res = pool.starmap( + self.fit_gamma_steady_state, + zip( + s, + self.data["p"], + itertools.repeat(intercept), + itertools.repeat(perc_left), + itertools.repeat(perc_right), + ), + ) + pool.close() + pool.join() + (delta, delta_intercept, _, delta_r2, _, delta_logLL) = zip(*res) + (delta, delta_intercept, delta_r2, delta_logLL) = ( + np.array(delta), + np.array(delta_intercept), + np.array(delta_r2), + np.array(delta_logLL), + ) + ( + 
self.parameters["delta"], + self.aux_param["delta_intercept"], + self.aux_param["delta_r2"], + _, # self.aux_param["delta_logLL"], + ) = (delta, delta_intercept, delta_r2, delta_logLL) - return k, b, r2, all_r2, logLL, all_logLL + def fit_conventional_deterministic( + self, + intercept=False, + perc_left=None, + perc_right=5, + ): + """Fit the input data to estimate parameters for conventional experiment type and steady-state kinetics + experiment type with deterministic model.""" + n_genes = self.get_n_genes() + cores = max(1, int(self.cores)) + if np.all(self._exist_data("uu", "su")): + self.parameters["beta"] = np.ones(n_genes) + gamma, gamma_intercept, gamma_r2, gamma_logLL = ( + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + ) + U = self.data["uu"] if self.data["ul"] is None else self.data["uu"] + self.data["ul"] + S = self.data["su"] if self.data["sl"] is None else self.data["su"] + self.data["sl"] + if cores == 1: + for i in tqdm(range(n_genes), desc="estimating gamma"): + ( + gamma[i], + gamma_intercept[i], + _, + gamma_r2[i], + _, + gamma_logLL[i], + ) = self.fit_gamma_steady_state(U[i], S[i], intercept, perc_left, perc_right) + else: + pool = ThreadPool(cores) + res = pool.starmap( + self.fit_gamma_steady_state, + zip( + U, + S, + itertools.repeat(intercept), + itertools.repeat(perc_left), + itertools.repeat(perc_right), + ), + ) + pool.close() + pool.join() + ( + gamma, + gamma_intercept, + _, + gamma_r2, + _, + gamma_logLL, + ) = zip(*res) + (gamma, gamma_intercept, gamma_r2, gamma_logLL) = ( + np.array(gamma), + np.array(gamma_intercept), + np.array(gamma_r2), + np.array(gamma_logLL), + ) + ( + self.parameters["gamma"], + self.aux_param["gamma_intercept"], + self.aux_param["gamma_r2"], + self.aux_param["gamma_logLL"], + ) = (gamma, gamma_intercept, gamma_r2, gamma_logLL) + elif np.all(self._exist_data("uu", "ul")): + self.parameters["beta"] = np.ones(n_genes) + gamma, gamma_intercept, gamma_r2, gamma_logLL = ( + 
np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + ) + U = self.data["ul"] + S = self.data["uu"] + self.data["ul"] + if cores == 1: + for i in tqdm(range(n_genes), desc="estimating gamma"): + ( + gamma[i], + gamma_intercept[i], + _, + gamma_r2[i], + _, + gamma_logLL[i], + ) = self.fit_gamma_steady_state(U[i], S[i], intercept, perc_left, perc_right) + else: + pool = ThreadPool(cores) + res = pool.starmap( + self.fit_gamma_steady_state, + zip( + U, + S, + itertools.repeat(intercept), + itertools.repeat(perc_left), + itertools.repeat(perc_right), + ), + ) + pool.close() + pool.join() + ( + gamma, + gamma_intercept, + _, + gamma_r2, + _, + gamma_logLL, + ) = zip(*res) + (gamma, gamma_intercept, gamma_r2, gamma_logLL) = ( + np.array(gamma), + np.array(gamma_intercept), + np.array(gamma_r2), + np.array(gamma_logLL), + ) + ( + self.parameters["gamma"], + self.aux_param["gamma_intercept"], + self.aux_param["gamma_r2"], + self.aux_param["gamma_logLL"], + ) = (gamma, gamma_intercept, gamma_r2, gamma_logLL) + self.fit_protein(intercept=intercept, perc_left=perc_left, perc_right=perc_right, cores=cores) + + def fit_conventional_stochastic( + self, + intercept=False, + perc_left=None, + perc_right=5, + ): + """Fit the input data to estimate parameters for conventional experiment type and steady-state kinetics + experiment type with stochastic model.""" + n_genes = self.get_n_genes() + cores = max(1, int(self.cores)) + if np.all(self._exist_data("uu", "su")): + self.parameters["beta"] = np.ones(n_genes) + gamma, gamma_intercept, gamma_r2, gamma_logLL, bs, bf = ( + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + ) + U = self.data["uu"] if self.data["ul"] is None else self.data["uu"] + self.data["ul"] + S = self.data["su"] if self.data["sl"] is None else self.data["su"] + self.data["sl"] + US = ( + self.data["us"] + if self.data["us"] is not None + else 
calc_2nd_moment(U.T, S.T, self.conn, mX=U.T, mY=S.T).T + ) + S2 = ( + self.data["s2"] + if self.data["s2"] is not None + else calc_2nd_moment(S.T, S.T, self.conn, mX=S.T, mY=S.T).T + ) + if cores == 1: + for i in tqdm(range(n_genes), desc="estimating gamma"): + ( + gamma[i], + gamma_intercept[i], + _, + gamma_r2[i], + _, + gamma_logLL[i], + bs[i], + bf[i], + ) = self.fit_gamma_stochastic( + self.est_method, + U[i], + S[i], + US[i], + S2[i], + perc_left=perc_left, + perc_right=perc_right, + normalize=True, + ) + else: + pool = ThreadPool(cores) + res = pool.starmap( + self.fit_gamma_stochastic, + zip( + itertools.repeat(self.est_method), + U, + S, + US, + S2, + itertools.repeat(perc_left), + itertools.repeat(perc_right), + itertools.repeat(True), + ), + ) + pool.close() + pool.join() + ( + gamma, + gamma_intercept, + _, + gamma_r2, + _, + gamma_logLL, + bs, + bf, + ) = zip(*res) + (gamma, gamma_intercept, gamma_r2, gamma_logLL, bs, bf,) = ( + np.array(gamma), + np.array(gamma_intercept), + np.array(gamma_r2), + np.array(gamma_logLL), + np.array(bs), + np.array(bf), + ) + ( + self.parameters["gamma"], + self.aux_param["gamma_intercept"], + self.aux_param["gamma_r2"], + self.aux_param["gamma_logLL"], + self.aux_param["bs"], + self.aux_param["bf"], + ) = (gamma, gamma_intercept, gamma_r2, gamma_logLL, bs, bf) + elif np.all(self._exist_data("uu", "ul")): + self.parameters["beta"] = np.ones(n_genes) + gamma, gamma_intercept, gamma_r2, gamma_logLL, bs, bf = ( + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + ) + U = self.data["ul"] + S = self.data["uu"] + self.data["ul"] + US = ( + self.data["us"] + if self.data["us"] is not None + else calc_2nd_moment(U.T, S.T, self.conn, mX=U.T, mY=S.T).T + ) + S2 = ( + self.data["s2"] + if self.data["s2"] is not None + else calc_2nd_moment(S.T, S.T, self.conn, mX=S.T, mY=S.T).T + ) + if cores == 1: + for i in tqdm(range(n_genes), desc="estimating gamma"): + ( 
+ gamma[i], + gamma_intercept[i], + _, + gamma_r2[i], + _, + gamma_logLL[i], + bs[i], + bf[i], + ) = self.fit_gamma_stochastic( + self.est_method, + U[i], + S[i], + US[i], + S2[i], + perc_left=perc_left, + perc_right=perc_right, + normalize=True, + ) + else: + pool = ThreadPool(cores) + res = pool.starmap( + self.fit_gamma_stochastic, + zip( + itertools.repeat(self.est_method), + U, + S, + US, + S2, + itertools.repeat(perc_left), + itertools.repeat(perc_right), + itertools.repeat(True), + ), + ) + pool.close() + pool.join() + ( + gamma, + gamma_intercept, + _, + gamma_r2, + _, + gamma_logLL, + bs, + bf, + ) = zip(*res) + (gamma, gamma_intercept, gamma_r2, gamma_logLL, bs, bf,) = ( + np.array(gamma), + np.array(gamma_intercept), + np.array(gamma_r2), + np.array(gamma_logLL), + np.array(bs), + np.array(bf), + ) + + ( + self.parameters["gamma"], + self.aux_param["gamma_intercept"], + self.aux_param["gamma_r2"], + self.aux_param["gamma_logLL"], + self.aux_param["bs"], + self.aux_param["bf"], + ) = (gamma, gamma_intercept, gamma_r2, gamma_logLL, bs, bf) + + self.fit_protein(intercept=intercept, perc_left=perc_left, perc_right=perc_right, cores=cores) + + def fit_oneshot( + self, + intercept=False, + perc_left=None, + perc_right=5, + clusters=None, + one_shot_method="combined", + ): + """Fit the input data to estimate parameters for one-shot experiment type.""" + n_genes = self.get_n_genes() + cores = max(1, int(self.cores)) + if len(np.unique(self.t)) > 1: + if np.all(self._exist_data("ul", "uu", "su")): + if not self._exist_parameter("beta"): + warn("beta & gamma estimation: only works when there're at least 2 time points.") + uu_m, uu_v, t_uniq = calc_12_mom_labeling(self.data["uu"], self.t) + su_m, su_v, _ = calc_12_mom_labeling(self.data["su"], self.t) + + ( + self.parameters["beta"], + self.parameters["gamma"], + self.aux_param["uu0"], + self.aux_param["su0"], + ) = self.fit_beta_gamma_lsq(t_uniq, uu_m, su_m) + # alpha estimation + ul_m, ul_v, t_uniq = 
calc_12_mom_labeling(self.data["ul"], self.t) + alpha = np.zeros(n_genes) + # let us only assume one alpha for each gene in all cells + if cores == 1: + for i in tqdm(range(n_genes), desc="estimating alpha"): + # for j in range(len(self.data['ul'][i])): + alpha[i] = fit_alpha_synthesis(t_uniq, ul_m[i], self.parameters["beta"][i]) + else: + pool = ThreadPool(cores) + alpha = pool.starmap( + fit_alpha_synthesis, + zip( + itertools.repeat(t_uniq), + ul_m, + self.parameters["beta"], + ), + ) + pool.close() + pool.join() + alpha = np.array(alpha) + self.parameters["alpha"] = alpha + elif np.all(self._exist_data("ul", "uu")): + n_genes = self.data["uu"].shape[0] # self.get_n_genes(data=U) + u0, gamma = np.zeros(n_genes), np.zeros(n_genes) + uu_m, uu_v, t_uniq = calc_12_mom_labeling(self.data["uu"], self.t) + for i in tqdm(range(n_genes), desc="estimating gamma"): + try: + gamma[i], u0[i] = fit_first_order_deg_lsq(t_uniq, uu_m[i]) + except: + gamma[i], u0[i] = 0, 0 + self.parameters["gamma"], self.aux_param["uu0"] = gamma, u0 + alpha = np.zeros(n_genes) + # let us only assume one alpha for each gene in all cells + ul_m, ul_v, _ = calc_12_mom_labeling(self.data["ul"], self.t) + if cores == 1: + for i in tqdm(range(n_genes), desc="estimating gamma"): + # for j in range(len(self.data['ul'][i])): + alpha[i] = fit_alpha_synthesis(t_uniq, ul_m[i], self.parameters["gamma"][i]) + else: + pool = ThreadPool(cores) + alpha = pool.starmap( + fit_alpha_synthesis, + zip( + itertools.repeat(t_uniq), + ul_m, + self.parameters["gamma"], + ), + ) + pool.close() + pool.join() + alpha = np.array(alpha) + self.parameters["alpha"] = alpha + # alpha: one-shot + # 'one_shot' + else: + t_uniq = np.unique(self.t) + if len(t_uniq) > 1: + raise Exception( + "By definition, one-shot experiment should involve only one time point measurement!" 
+ ) + # calculate when having splicing or no splicing + if self.model.lower() == "deterministic": + if np.all(self._exist_data("ul", "uu", "su")): + if self._exist_parameter("beta", "gamma").all(): + self.parameters["alpha"] = self.fit_alpha_oneshot( + self.t, + self.data["ul"], + self.parameters["beta"], + clusters, + ) + else: + beta, gamma, U0, S0 = ( + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + ) + for i in range( + n_genes + ): # can also use the two extreme time points and apply sci-fate like approach. + S, U = ( + self.data["su"][i] + self.data["sl"][i], + self.data["uu"][i] + self.data["ul"][i], + ) + + S0[i], gamma[i] = ( + np.mean(S), + solve_gamma(np.max(self.t), self.data["su"][i], S), + ) + U0[i], beta[i] = ( + np.mean(U), + solve_gamma(np.max(self.t), self.data["uu"][i], U), + ) + ( + self.aux_param["U0"], + self.aux_param["S0"], + self.parameters["beta"], + self.parameters["gamma"], + ) = (U0, S0, beta, gamma) + + ul_m, ul_v, t_uniq = calc_12_mom_labeling(self.data["ul"], self.t) + alpha = np.zeros(n_genes) + # let us only assume one alpha for each gene in all cells + if cores == 1: + for i in tqdm(range(n_genes), desc="estimating alpha"): + # for j in range(len(self.data['ul'][i])): + alpha[i] = fit_alpha_synthesis( + t_uniq, + ul_m[i], + self.parameters["beta"][i], + ) + else: + pool = ThreadPool(cores) + alpha = pool.starmap( + fit_alpha_synthesis, + zip( + itertools.repeat(t_uniq), + ul_m, + self.parameters["beta"], + ), + ) + pool.close() + pool.join() + alpha = np.array(alpha) + self.parameters["alpha"] = alpha + # self.parameters['alpha'] = self.fit_alpha_oneshot(self.t, self.data['ul'], self.parameters['beta'], clusters) + else: + if self._exist_data("ul") and self._exist_parameter("gamma"): + self.parameters["alpha"] = self.fit_alpha_oneshot( + self.t, + self.data["ul"], + self.parameters["gamma"], + clusters, + ) + elif self._exist_data("ul") and self._exist_data("uu"): + if one_shot_method in 
["sci-fate", "sci_fate"]: + gamma, total0 = np.zeros(n_genes), np.zeros(n_genes) + for i in tqdm(range(n_genes), desc="estimating gamma"): + total = self.data["uu"][i] + self.data["ul"][i] + total0[i], gamma[i] = ( + np.mean(total), + solve_gamma( + np.max(self.t), + self.data["uu"][i], + total, + ), + ) + (self.aux_param["total0"], self.parameters["gamma"],) = ( + total0, + gamma, + ) + + ul_m, ul_v, t_uniq = calc_12_mom_labeling(self.data["ul"], self.t) + # let us only assume one alpha for each gene in all cells + alpha = np.zeros(n_genes) + if cores == 1: + for i in tqdm(range(n_genes), desc="estimating alpha"): + # for j in range(len(self.data['ul'][i])): + alpha[i] = fit_alpha_synthesis( + t_uniq, + ul_m[i], + self.parameters["gamma"][i], + ) # ul_m[i] / t_uniq + else: + pool = ThreadPool(cores) + alpha = pool.starmap( + fit_alpha_synthesis, + zip( + itertools.repeat(t_uniq), + ul_m, + self.parameters["gamma"], + ), + ) + pool.close() + pool.join() + alpha = np.array(alpha) + self.parameters["alpha"] = alpha + # self.parameters['alpha'] = self.fit_alpha_oneshot(self.t, self.data['ul'], self.parameters['gamma'], clusters) + elif one_shot_method == "combined": + self.parameters["alpha"] = ( + csr_matrix(self.data["ul"].shape) + if issparse(self.data["ul"]) + else np.zeros_like(self.data["ul"].shape) + ) + (t_uniq, gamma, gamma_k, gamma_intercept, gamma_r2, gamma_logLL,) = ( + np.unique(self.t), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + ) + U, S = ( + self.data["ul"], + self.data["uu"] + self.data["ul"], + ) + + if cores == 1: + for i in tqdm(range(n_genes), desc="estimating gamma"): + ( + gamma_k[i], + gamma_intercept[i], + _, + gamma_r2[i], + _, + gamma_logLL[i], + ) = self.fit_gamma_steady_state(U[i], S[i], False, None, perc_right) + ( + gamma[i], + self.parameters["alpha"][i], + ) = one_shot_gamma_alpha(gamma_k[i], t_uniq, U[i]) + else: + pool = ThreadPool(cores) + res1 = pool.starmap( + 
self.fit_gamma_steady_state, + zip( + U, + S, + itertools.repeat(False), + itertools.repeat(None), + itertools.repeat(perc_right), + ), + ) + + ( + gamma_k, + gamma_intercept, + _, + gamma_r2, + _, + gamma_logLL, + ) = zip(*res1) + (gamma_k, gamma_intercept, gamma_r2, gamma_logLL,) = ( + np.array(gamma_k), + np.array(gamma_intercept), + np.array(gamma_r2), + np.array(gamma_logLL), + ) + + res2 = pool.starmap( + one_shot_gamma_alpha, + zip(gamma_k, itertools.repeat(t_uniq), U), + ) + + (gamma, alpha) = zip(*res2) + (gamma, self.parameters["alpha"]) = ( + np.array(gamma), + np.array(alpha), + ) + + pool.close() + pool.join() + ( + self.parameters["gamma"], + self.aux_param["gamma_k"], + self.aux_param["gamma_intercept"], + self.aux_param["gamma_r2"], + self.aux_param["gamma_logLL"], + self.aux_param["alpha_r2"], + ) = ( + gamma, + gamma_k, + gamma_intercept, + gamma_r2, + gamma_logLL, + gamma_r2, + ) + elif self.model.lower() == "stochastic": + if np.all(self._exist_data("uu", "ul", "su", "sl")): + self.parameters["beta"] = np.ones(n_genes) + k, k_intercept, k_r2, k_logLL, bs, bf = ( + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + ) + U = self.data["uu"] + S = self.data["uu"] + self.data["ul"] + US = ( + self.data["us"] + if self.data["us"] is not None + else calc_2nd_moment(U.T, S.T, self.conn, mX=U.T, mY=S.T).T + ) + S2 = ( + self.data["s2"] + if self.data["s2"] is not None + else calc_2nd_moment(S.T, S.T, self.conn, mX=S.T, mY=S.T).T + ) + if cores == 1: + for i in tqdm( + range(n_genes), + desc="estimating beta and alpha for one-shot experiment", + ): + ( + k[i], + k_intercept[i], + _, + k_r2[i], + _, + k_logLL[i], + bs[i], + bf[i], + ) = self.fit_gamma_stochastic( + self.est_method, + U[i], + S[i], + US[i], + S2[i], + perc_left=perc_left, + perc_right=perc_right, + normalize=True, + ) + else: + pool = ThreadPool(cores) + res = pool.starmap( + self.fit_gamma_stochastic, + zip( + 
itertools.repeat(self.est_method), + U, + S, + US, + S2, + itertools.repeat(perc_left), + itertools.repeat(perc_right), + itertools.repeat(True), + ), + ) + pool.close() + pool.join() + ( + k, + k_intercept, + _, + k_r2, + _, + k_logLL, + bs, + bf, + ) = zip(*res) + (k, k_intercept, k_r2, k_logLL, bs, bf) = ( + np.array(k), + np.array(k_intercept), + np.array(k_r2), + np.array(k_logLL), + np.array(bs), + np.array(bf), + ) + beta, alpha0 = one_shot_gamma_alpha_matrix(k, t_uniq, U) + + self.parameters["beta"], self.aux_param["beta_k"] = ( + beta, + k, + ) + + U = self.data["uu"] + self.data["ul"] + S = U + self.data["su"] + self.data["sl"] + US = ( + self.data["us"] + if self.data["us"] is not None + else calc_2nd_moment(U.T, S.T, self.conn, mX=U.T, mY=S.T).T + ) + S2 = ( + self.data["s2"] + if self.data["s2"] is not None + else calc_2nd_moment(S.T, S.T, self.conn, mX=S.T, mY=S.T).T + ) + if cores == 1: + for i in tqdm( + range(n_genes), + desc="estimating gamma and alpha for one-shot experiment", + ): + ( + k[i], + k_intercept[i], + _, + k_r2[i], + _, + k_logLL[i], + bs[i], + bf[i], + ) = self.fit_gamma_stochastic( + self.est_method, + U[i], + S[i], + US[i], + S2[i], + perc_left=perc_left, + perc_right=perc_right, + normalize=True, + ) + else: + pool = ThreadPool(cores) + res = pool.starmap( + self.fit_gamma_stochastic, + zip( + itertools.repeat(self.est_method), + U, + S, + US, + S2, + itertools.repeat(perc_left), + itertools.repeat(perc_right), + itertools.repeat(True), + ), + ) + pool.close() + pool.join() + (k, k_intercept, _, k_r2, _, k_logLL, bs, bf) = zip(*res) + (k, k_intercept, k_r2, k_logLL, bs, bf) = ( + np.array(k), + np.array(k_intercept), + np.array(k_r2), + np.array(k_logLL), + np.array(bs), + np.array(bf), + ) + + gamma, alpha = one_shot_gamma_alpha_matrix(k, t_uniq, U) + ( + self.parameters["alpha"], + self.parameters["gamma"], + self.aux_param["gamma_k"], + self.aux_param["gamma_intercept"], + self.aux_param["gamma_r2"], + 
self.aux_param["gamma_logLL"], + self.aux_param["bs"], + self.aux_param["bf"], + ) = ( + (alpha + alpha0) / 2, + gamma, + k, + k_intercept, + k_r2, + k_logLL, + bs, + bf, + ) + elif np.all(self._exist_data("uu", "ul")): + if one_shot_method == "storm-csp": + gamma, gamma_r2, k = ( + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + ) + new_counts = self.data["new_counts"] + total_counts = self.data["total_counts"] + new_smooth_csp = self.data["new_smooth_csp"] + new_smooth = self.data['ul'] + total_smooth = self.data["ul"] + self.data["uu"] + for i in tqdm(range(n_genes), desc="estimating gamma via storm's csp model"): + ( + gamma[i], + gamma_r2[i], + k[i], + ) = self.fit_gamma_storm_csp( + new_counts[i], + total_counts[i], + new_smooth[i], + total_smooth[i], + t_uniq=t_uniq, + perc_left=perc_left, + perc_right=perc_right, + normalize=True, + ) + _, alpha = one_shot_gamma_alpha_matrix(k, t_uniq, new_smooth_csp) + ( + self.parameters["alpha"], + self.parameters["gamma"], + self.aux_param["gamma_k"], + self.aux_param["gamma_intercept"], + self.aux_param["gamma_r2"], + ) = ( + alpha, + gamma, + k, + np.zeros(n_genes), + gamma_r2, + ) + else: + k, k_intercept, k_r2, k_logLL, bs, bf = ( + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + np.zeros(n_genes), + ) + U = self.data["ul"] + S = self.data["ul"] + self.data["uu"] + US = ( + self.data["us"] + if self.data["us"] is not None + else calc_2nd_moment(U.T, S.T, self.conn, mX=U.T, mY=S.T).T + ) + S2 = ( + self.data["s2"] + if self.data["s2"] is not None + else calc_2nd_moment(S.T, S.T, self.conn, mX=S.T, mY=S.T).T + ) + if cores == 1: + for i in tqdm(range(n_genes), desc="estimating gamma"): + ( + k[i], + k_intercept[i], + _, + k_r2[i], + _, + k_logLL[i], + bs[i], + bf[i], + ) = self.fit_gamma_stochastic( + self.est_method, + U[i], + S[i], + US[i], + S2[i], + perc_left=perc_left, + perc_right=perc_right, + normalize=True, + ) + else: + pool = 
def fit_mix_std_stm(
    self,
    intercept=False,
    perc_left=None,
    perc_right=5,
):
    """Fit the input data to estimate parameters for mix_std_stm experiment type.

    Only the cells whose time equals the largest value in ``self.t`` are used to solve
    ``gamma`` (and ``beta`` when spliced layers are available). ``alpha`` is then solved
    from the labeled unspliced layer via ``self.solve_alpha_mix_std_stm``. The estimates
    are stored in ``self.parameters`` / ``self.aux_param`` and ``self.fit_protein`` is
    always called at the end.

    Arguments
    ---------
        intercept: bool
            Forwarded unchanged to ``self.fit_protein``.
        perc_left: float or None
            Forwarded unchanged to ``self.fit_protein``.
        perc_right: float
            Forwarded unchanged to ``self.fit_protein``.
    """
    n_genes = self.get_n_genes()
    cores = max(1, int(self.cores))
    t_max = np.max(self.t)
    # The mask is identical for every gene, so build it once instead of per access.
    at_t_max = self.t == t_max

    # "sl" is read below, so it must be part of the availability check; otherwise a
    # missing "sl" layer fails while indexing instead of falling back to the
    # labeling-only branch.
    if np.all(self._exist_data("ul", "uu", "su", "sl")):
        gamma, beta, total, U = (np.zeros(n_genes) for _ in range(4))
        # NOTE: one could also use the two extreme time points and apply a
        # sci-fate like approach.
        for i in tqdm(range(n_genes), desc="solving gamma/beta"):
            uu_i = self.data["uu"][i, at_t_max]
            ul_i = self.data["ul"][i, at_t_max]
            su_i = self.data["su"][i, at_t_max]
            sl_i = self.data["sl"][i, at_t_max]

            # gamma: unlabeled (uu + su) against all four species.
            all_rna = uu_i + ul_i + su_i + sl_i
            total[i] = np.mean(all_rna)
            gamma[i] = solve_gamma(t_max, uu_i + su_i, all_rna)

            # beta: same solver applied to the unspliced species only.
            unspliced = uu_i + ul_i
            U[i] = np.mean(unspliced)
            beta[i] = solve_gamma(t_max, uu_i, unspliced)

        self.parameters["beta"] = beta
        self.parameters["gamma"] = gamma
        self.aux_param["total0"] = total
        self.aux_param["U0"] = U
        # alpha estimation
        self.parameters["alpha"] = self.solve_alpha_mix_std_stm(
            self.t, self.data["ul"], self.parameters["beta"]
        )
    elif np.all(self._exist_data("ul", "uu")):
        n_genes = self.data["uu"].shape[0]
        gamma, U = np.zeros(n_genes), np.zeros(n_genes)
        # Stimulation-based estimate from the last time point (a steady-state or
        # sci-fate like estimate would be possible alternatives).
        for i in tqdm(range(n_genes), desc="solving gamma, alpha"):
            uu_i = self.data["uu"][i, at_t_max]
            total_i = uu_i + self.data["ul"][i, at_t_max]

            U[i] = np.mean(total_i)
            gamma[i] = solve_gamma(t_max, uu_i, total_i)

        self.parameters["gamma"] = gamma
        self.aux_param["U0"] = U
        # No splicing information in this branch: beta is set to ones.
        self.parameters["beta"] = np.ones(gamma.shape)
        # alpha estimation (gamma takes the role of the rate passed to the solver)
        self.parameters["alpha"] = self.solve_alpha_mix_std_stm(
            self.t, self.data["ul"], self.parameters["gamma"]
        )

    self.fit_protein(intercept=intercept, perc_left=perc_left, perc_right=perc_right, cores=cores)
def fit_gamma_storm_csp(
    self,
    new_counts,
    total_counts,
    new_smooth,
    total_smooth,
    t_uniq,
    perc_left=None,
    perc_right=50,
    normalize=True,
):
    """Estimate gamma using Storm's CSP model based on the steady state assumption.

    The cells entering the estimate are chosen by ``find_extreme`` on the smoothed data;
    gamma is then computed in closed form from the ratio of mean new to mean total raw
    counts over those cells.

    Arguments
    ---------
        new_counts: :class:`~numpy.ndarray` or sparse `csr_matrix`
            New mRNA raw counts of one gene across cells (flattened internally).
        total_counts: :class:`~numpy.ndarray` or sparse `csr_matrix`
            Total mRNA raw counts of one gene across cells (flattened internally).
        new_smooth: :class:`~numpy.ndarray` or sparse `csr_matrix`
            New mRNA smoothed data, used only to select cells.
        total_smooth: :class:`~numpy.ndarray` or sparse `csr_matrix`
            Total mRNA smoothed data, used only to select cells.
        t_uniq: float
            The labeling duration of one-shot experiment.
        perc_left: float
            Forwarded to ``find_extreme``. The percentage of samples included in the left tail. If set to None,
            then all the left samples are excluded.
        perc_right: float
            Forwarded to ``find_extreme``. The percentage of samples included in the right tail. If set to None,
            then all the samples are included.
        normalize: bool
            Forwarded to ``find_extreme``. Whether to first normalize the data.

    Returns
    -------
        gamma: float
            ``-log(1 - mean(new) / mean(total)) / t_uniq`` over the selected cells.
        gamma_r2: float
            Always 1.0; no goodness of fit is computed here.
        k: float
            ``1 - exp(-gamma * t_uniq)``.
    """

    def _dense_1d(x):
        # ``.toarray()`` rather than ``.A``: the ``.A`` shortcut was removed from
        # scipy sparse matrices in SciPy 1.14.
        return x.toarray().flatten() if issparse(x) else x.flatten()

    new_counts = _dense_1d(new_counts)
    total_counts = _dense_1d(total_counts)
    new_smooth = _dense_1d(new_smooth)
    total_smooth = _dense_1d(total_smooth)

    mask = find_extreme(new_smooth, total_smooth, perc_left=perc_left, perc_right=perc_right, normalize=normalize)
    new_fraction = np.mean(new_counts[mask]) / np.mean(total_counts[mask])
    gamma = -np.log(1 - new_fraction) / t_uniq
    gamma_r2 = 1.0
    k = 1 - np.exp(-gamma * t_uniq)
    return gamma, gamma_r2, k
+ gamma_init: The initial value of gamma. shape: (n_var,). + cell_total: The total counts of reads for each cell. shape: (n_obs,). + + Returns: + gamma: The estimated total mRNA degradation rate gamma. shape: (n_var,). + gamma_r2: The R2 of gamma. shape: (n_var,). + gamma_r2_raw: The R2 of gamma without correction. shape: (n_var,). + alpha: The estimated gene specific transcription rate alpha. shape: (n_var,). + + """ + n_var = N.shape[0] + n_obs = N.shape[1] + cell_capture_rate = cell_total / np.median(cell_total) + + # When there is only one labeling duration we can obtain the analytical solution directly but cannot define the + # goodness-of-fit. + if len(np.unique(time)) == 1: + gamma = np.zeros(n_var) + gamma_r2 = np.ones(n_var) # As goodness of fit could not be defined, all were set to 1. + gamma_r2_raw = np.ones(n_var) + alpha = np.zeros(n_var) + for i, r, n, r_smooth, n_smooth in tqdm( + zip(np.arange(n_var), R, N, Total_smoothed, New_smoothed), + "Infer parameters via maximum likelihood estimation based on the CSP model under the steady-state assumption" + ): + n = n.A.flatten() if issparse(n) else n.flatten() + r = r.A.flatten() if issparse(r) else r.flatten() + n_smooth = n_smooth.A.flatten() if issparse(n_smooth) else n_smooth.flatten() + r_smooth = r_smooth.A.flatten() if issparse(r_smooth) else r_smooth.flatten() + t_unique = np.unique(time) + mask = find_extreme(n_smooth, r_smooth, perc_left=None, perc_right=50) + gamma[i] = - np.log(1 - np.mean(n[mask]) / np.mean(r[mask])) / t_unique + alpha[i] = gamma[i]*np.mean(r[mask])/np.mean(cell_capture_rate[mask]) + else: + gamma = np.zeros(n_var) + gamma_r2 = np.zeros(n_var) + gamma_r2_raw = np.zeros(n_var) + alphadivgamma = np.zeros(n_var) + for i, r, n in tqdm( + zip(np.arange(n_var), R, N), + "Infer parameters via maximum likelihood estimation based on the CSP model under the steady-state assumption" + ): + n = n.A.flatten() if issparse(n) else n.flatten() + r = r.A.flatten() if issparse(r) else 
r.flatten() + + def loss_func_ss(parameters): + # Loss function of cell specific Poisson model under the steady-state assumption + parameter_alpha_div_gamma, parameter_gamma = parameters + mu_new = parameter_alpha_div_gamma * (1 - np.exp(-parameter_gamma * time)) * cell_capture_rate + loss_new = -np.sum(n * np.log(mu_new) - mu_new) + mu_total = parameter_alpha_div_gamma * cell_capture_rate + loss_total = -np.sum(r * np.log(mu_total) - mu_total) + loss = loss_new + loss_total + return loss + + # Initialize and add boundary conditions + alpha_div_gamma_init = np.mean(n) / np.mean(cell_capture_rate * (1 - np.exp(-gamma_init[i] * time))) + b1 = (0, 10 * alpha_div_gamma_init) + b2 = (0, 10 * gamma_init[i]) + bnds = (b1, b2) + parameters_init = np.array([alpha_div_gamma_init, gamma_init[i]]) + + # Solve + res = minimize(loss_func_ss, parameters_init, method='SLSQP', bounds=bnds, tol=1e-2, options={'maxiter': 1000}) + # res = minimize(loss_func_ss, parameters_init, method='Nelder-Mead', tol=1e-2, options={'maxiter': 1000}) + # res = minimize(loss_func_ss, parameters_init, method='COBYLA', bounds=bnds, tol=1e-2, options={'maxiter': 1000}) + parameters = res.x + loss = res.fun + success = res.success + alphadivgamma[i], gamma[i] = parameters + + # Calculate deviance R2 as goodness of fit + + def null_loss_func_ss(parameters_null): + # Loss function of null model under the steady-state assumption + parameters_a0_new, parameters_a0_total = parameters_null + mu_new = parameters_a0_new * cell_capture_rate + loss0_new = -np.sum(n * np.log(mu_new) - mu_new) + mu_total = parameters_a0_total * cell_capture_rate + loss0_total = -np.sum(r * np.log(mu_total) - mu_total) + loss0 = loss0_new + loss0_total + return loss0 + + def saturated_loss_func_ss(): + # Loss function of saturated model under the steady-state assumption + loss_saturated_new = -np.sum(n[n > 0] * np.log(n[n > 0]) - n[n > 0]) + loss_saturated_total = -np.sum(r[r > 0] * np.log(r[r > 0]) - r[r > 0]) + loss_saturated = 
def mle_cell_specific_poisson(
    N: Union[np.ndarray, csr_matrix],
    time: np.ndarray,
    gamma_init: np.ndarray,
    cell_total: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Infer parameters based on cell specific Poisson distributions using maximum likelihood estimation

    For every gene, ``alpha / gamma`` and ``gamma`` are fitted with SLSQP by minimizing the
    negative Poisson log-likelihood (up to a constant) of the new counts, with mean
    ``alpha/gamma * (1 - exp(-gamma * time)) * cell_capture_rate``. Goodness of fit is an
    adjusted deviance R2 against a null model with mean proportional to the capture rate.

    Args:
        N: The number of new mRNA counts for each gene in each cell. shape: (n_var, n_obs).
        time: The time point of each cell. shape: (n_obs,).
        gamma_init: The initial value of gamma. shape: (n_var,).
        cell_total: The total counts of reads for each cell. shape: (n_obs,).

    Returns:
        gamma: The estimated total mRNA degradation rate gamma. shape: (n_var,).
        gamma_r2: Binary selection flag derived from the R2: 1 for the top 40% of genes ranked by R2 (genes with
            ``gamma < 0.01`` are ranked as R2 = 0), 0 for all others. shape: (n_var,).
        gamma_r2_raw: The R2 of gamma without correction. shape: (n_var,).
        alpha: The estimated gene specific transcription rate alpha. shape: (n_var,).
    """
    n_var, n_obs = N.shape[0], N.shape[1]
    gamma = np.zeros(n_var)
    gamma_r2_raw = np.zeros(n_var)
    alphadivgamma = np.zeros(n_var)

    # Gene-independent, so computed once rather than once per gene.
    cell_capture_rate = cell_total / np.median(cell_total)
    mean_capture_rate = np.mean(cell_capture_rate)

    for i, n in tqdm(
        zip(np.arange(n_var), N),
        "Infer parameters via maximum likelihood estimation based on the CSP model"
    ):
        # ``.toarray()`` instead of ``.A`` (removed from sparse matrices in SciPy 1.14).
        n = n.toarray().flatten() if issparse(n) else n.flatten()

        def loss_func(parameters):
            # Loss function of cell specific Poisson model
            parameter_alpha_div_gamma, parameter_gamma = parameters
            mu = parameter_alpha_div_gamma * (1 - np.exp(-parameter_gamma * time)) * cell_capture_rate
            return -np.sum(n * np.log(mu) - mu)

        # Initialize and add boundary conditions
        alpha_div_gamma_init = np.mean(n) / np.mean(cell_capture_rate * (1 - np.exp(-gamma_init[i] * time)))
        bnds = ((0, 10 * alpha_div_gamma_init), (0, 10 * gamma_init[i]))
        parameters_init = np.array([alpha_div_gamma_init, gamma_init[i]])

        # Solve
        res = minimize(loss_func, parameters_init, method='SLSQP', bounds=bnds, tol=1e-2, options={'maxiter': 1000})
        alphadivgamma[i], gamma[i] = res.x
        loss = res.fun

        # Calculate deviance R2 as goodness of fit.
        # Null model: mean proportional to the capture rate only (no time dependence).
        mu_null = (np.mean(n) / mean_capture_rate) * cell_capture_rate
        loss0 = -np.sum(n * np.log(mu_null) - mu_null)

        # Saturated model: mean equal to each observation; zero counts contribute 0.
        positive = n[n > 0]
        loss_saturated = -np.sum(positive * np.log(positive) - positive)

        null_deviance = 2 * (loss0 - loss_saturated)
        deviance = 2 * (loss - loss_saturated)
        gamma_r2_raw[i] = 1 - (deviance / (n_obs - 2)) / (null_deviance / (n_obs - 1))

    # Top 40% genes were selected by goodness of fit
    gamma_r2 = gamma_r2_raw.copy()
    number_selected_genes = int(n_var * 0.4)
    gamma_r2[gamma < 0.01] = 0
    sort_index = np.argsort(-gamma_r2)
    gamma_r2[sort_index[:number_selected_genes]] = 1
    # Slice starts at ``number_selected_genes`` (not ``+ 1``) so that the first
    # non-selected gene is also reset to 0 instead of keeping its raw R2.
    gamma_r2[sort_index[number_selected_genes:]] = 0

    return gamma, gamma_r2, gamma_r2_raw, alphadivgamma * gamma
+ """ + n_var = N.shape[0] + n_obs = N.shape[1] + gamma = np.zeros(n_var) + gamma_r2 = np.zeros(n_var) + gamma_r2_raw = np.zeros(n_var) + prob_off = np.zeros(n_var) + alphadivgamma = np.zeros(n_var) + + for i, n in tqdm( + zip(np.arange(n_var), N), + "Infer parameters via maximum likelihood estimation based on the CSZIP model" + ): + n = n.A.flatten() if issparse(n) else n.flatten() + cell_capture_rate = cell_total / np.median(cell_total) + + def loss_func(parameters): + # Loss function of cell specific zero-inflated Poisson model + parameter_alpha_div_gamma, parameter_gamma, parameter_prob_off = parameters + mu = parameter_alpha_div_gamma * (1 - np.exp(-parameter_gamma * time)) * cell_capture_rate + n_eq_0_index = n < 0.001 + n_over_0_index = n > 0.001 + loss_eq0 = -np.sum(np.log(parameter_prob_off + (1 - parameter_prob_off) * np.exp(-mu[n_eq_0_index]))) + loss_over0 = -np.sum(np.log(1 - parameter_prob_off) + (-mu[n_over_0_index]) + n[n_over_0_index] * np.log( + mu[n_over_0_index])) + loss = loss_eq0 + loss_over0 + return loss + + # Initialize and add boundary conditions + mean_n = np.mean(n) + s2_n = np.mean(np.power(n, 2)) + temp = np.mean(cell_capture_rate * (1 - np.exp(-gamma_init[i] * time))) + prob_off_init = 1 - mean_n * mean_n * np.mean( + np.power(cell_capture_rate * (1 - np.exp(-gamma_init[i] * time)), 2)) / ( + temp * temp * (s2_n - mean_n)) # Use moment estimation as the initial value of prob_off + alphadivgamma_init = mean_n / ((1 - prob_off_init) * temp) + b1 = (0, 10 * alphadivgamma_init) + b2 = (0, 10 * gamma_init[i]) + b3 = (0, (np.sum(n < 0.001) / np.sum(n > -1))) + bnds = (b1, b2, b3) + parameters_init = np.array([alphadivgamma_init, gamma_init[i], prob_off_init]) + + # Slove + res = minimize(loss_func, parameters_init, method='SLSQP', bounds=bnds, tol=1e-2, options={'maxiter': 1000}) + # res = minimize(loss_func, parameters_init, method='Nelder-Mead', tol=1e-2, options={'maxiter': 1000}) + # res = minimize(loss_func, parameters_init, 
method='COBYLA', bounds=bnds, tol=1e-2, options={'maxiter': 1000}) + parameters = res.x + alphadivgamma[i], gamma[i], prob_off[i] = parameters + loss = res.fun + success = res.success + + # Calculate deviance R2 as goodness of fit + + def null_Loss_func(parameters_null): + # Loss function of null model + parameters_null_lambda, parameters_null_prob_off = parameters_null + mu = parameters_null_lambda * cell_capture_rate + n_eq_0_index = n < 0.0001 + n_over_0_index = n > 0.0001 + null_loss_eq0 = -np.sum( + np.log(parameters_null_prob_off + (1 - parameters_null_prob_off) * np.exp(-mu[n_eq_0_index]))) + null_loss_over0 = -np.sum( + np.log(1 - parameters_null_prob_off) + (-mu[n_over_0_index]) + n[n_over_0_index] * np.log( + mu[n_over_0_index])) + null_loss = null_loss_eq0 + null_loss_over0 + return null_loss + + mean_cell_capture_rate = np.mean(cell_capture_rate) + prob_off_init_null = 1 - mean_n * mean_n * np.mean(np.power(cell_capture_rate, 2)) / ( + mean_cell_capture_rate * mean_cell_capture_rate * (s2_n - mean_n)) + lambda_init_null = mean_n / ((1 - prob_off_init_null) * mean_cell_capture_rate) + b1_null = (0, 10 * lambda_init_null) + b2_null = (0, (np.sum(n < 0.001) / np.sum(n > -1))) + bnds_null = (b1_null, b2_null) + parameters_init_null = np.array([lambda_init_null, prob_off_init_null]) + res_null = minimize(null_Loss_func, parameters_init_null, method='SLSQP', bounds=bnds_null, tol=1e-2, + options={'maxiter': 1000}) + loss0 = res_null.fun + + def saturated_loss_func(): + loss_saturated = -np.sum(n[n > 0] * np.log(n[n > 0]) - n[n > 0]) + return loss_saturated + + loss_saturated = saturated_loss_func() + null_devanice = 2 * (loss0 - loss_saturated) + devanice = 2 * (loss - loss_saturated) + + gamma_r2_raw[i] = 1 - (devanice / (n_obs - 2)) / (null_devanice / (n_obs - 1)) + + # Top 40% genes were selected by goodness of fit + gamma_r2 = gamma_r2_raw.copy() + number_selected_genes = int(n_var * 0.4) + gamma_r2[gamma < 0.01] = 0 + sort_index = np.argsort(-gamma_r2) + 
def mle_independent_cell_specific_poisson(
    UL: Union[np.ndarray, csr_matrix],
    SL: Union[np.ndarray, csr_matrix],
    time: np.ndarray,
    gamma_init: np.ndarray,
    beta_init: np.ndarray,
    cell_total: np.ndarray,
    Total_smoothed: Union[np.ndarray, csr_matrix],
    S_smoothed: Union[np.ndarray, csr_matrix],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Infer parameters based on independent cell specific Poisson distributions using maximum likelihood estimation

    For every gene, ``alpha``, ``beta`` and ``gamma_s`` are fitted with SLSQP by minimizing the summed
    negative Poisson log-likelihoods (up to constants) of the unspliced-labeled and spliced-labeled counts.
    The total degradation rate ``gamma_t`` is derived from ``gamma_s`` by the factor
    ``sum(s**2) / sum(r * s)`` computed from the smoothed spliced (``s``) and total (``r``) data.

    Args:
        UL: The number of unspliced labeled mRNA counts for each gene in each cell. shape: (n_var, n_obs).
        SL: The number of spliced labeled mRNA counts for each gene in each cell. shape: (n_var, n_obs).
        time: The time point of each cell. shape: (n_obs,).
        gamma_init: The initial value of gamma. shape: (n_var,).
        beta_init: The initial value of beta. shape: (n_var,).
        cell_total: The total counts of reads for each cell. shape: (n_obs,).
        Total_smoothed: The number of total mRNA expression after normalization and smoothing for each gene in each
            cell. shape: (n_var, n_obs).
        S_smoothed: The number of spliced mRNA expression after normalization and smoothing for each gene in each
            cell. shape: (n_var, n_obs).

    Returns:
        gamma_s: The estimated spliced mRNA degradation rate gamma_s. shape: (n_var,).
        gamma_r2: Binary selection flag derived from the R2: 1 for the top 40% of genes ranked by R2 (genes with
            ``gamma_s < 0.01`` are ranked as R2 = 0), 0 for all others. shape: (n_var,).
        beta: The estimated gene specific splicing rate beta. shape: (n_var,).
        gamma_t: The estimated total mRNA degradation rate gamma_t. shape: (n_var,).
        gamma_r2_raw: The R2 of gamma without correction. shape: (n_var,).
        alpha: The estimated gene specific transcription rate alpha. shape: (n_var,).
    """
    n_var, n_obs = UL.shape[0], UL.shape[1]
    gamma_s = np.zeros(n_var)
    gamma_r2_raw = np.zeros(n_var)
    beta = np.zeros(n_var)
    alpha = np.zeros(n_var)
    gamma_t = np.zeros(n_var)

    # Gene-independent, so computed once rather than once per gene.
    cell_capture_rate = cell_total / np.median(cell_total)
    mean_capture_rate = np.mean(cell_capture_rate)

    def _dense_1d(x):
        # ``.toarray()`` instead of ``.A`` (removed from sparse matrices in SciPy 1.14).
        return x.toarray().flatten() if issparse(x) else x.flatten()

    for i, ul, sl, r, s in tqdm(
        zip(np.arange(n_var), UL, SL, Total_smoothed, S_smoothed),
        "Estimate gamma via maximum likelihood estimation based on the ICSP model "
    ):
        sl = _dense_1d(sl)
        ul = _dense_1d(ul)
        r = _dense_1d(r)
        s = _dense_1d(s)

        def loss_func(parameters):
            # Loss function of independent cell specific Poisson model
            parameter_alpha, parameter_beta, parameter_gamma_s = parameters
            mu_u = parameter_alpha / parameter_beta * (1 - np.exp(-parameter_beta * time)) * cell_capture_rate
            mu_s = (parameter_alpha / parameter_gamma_s * (1 - np.exp(-parameter_gamma_s * time)) + parameter_alpha /
                    (parameter_gamma_s - parameter_beta) * (np.exp(-parameter_gamma_s * time) - np.exp(
                        -parameter_beta * time))) * cell_capture_rate
            loss_u = -np.sum(ul * np.log(mu_u) - mu_u)
            loss_s = -np.sum(sl * np.log(mu_s) - mu_s)
            return loss_u + loss_s

        # Conversion terms between gamma_t and gamma_s, from the smoothed data.
        sum_rs = np.sum(r * s)
        sum_ss = np.sum(np.power(s, 2))

        # The initial values of gamma_s, beta and alpha are obtained from the initial values of gamma_t.
        gamma_s_init = gamma_init[i] * sum_rs / sum_ss
        beta_init_new = beta_init[i] * gamma_s_init / gamma_init[i]
        alpha_init = np.mean(ul + sl) / np.mean(cell_capture_rate * (
            (1 - np.exp(-beta_init_new * time)) / beta_init_new + (1 - np.exp(-gamma_s_init * time)) / gamma_s_init
            + (np.exp(-gamma_s_init * time) - np.exp(-beta_init_new * time)) / (gamma_s_init - beta_init_new)))

        # Initialize and add boundary conditions
        bnds = ((0, 10 * alpha_init), (0, 10 * beta_init_new), (0, 10 * gamma_s_init))
        parameters_init = np.array([alpha_init, beta_init_new, gamma_s_init])

        # Solve
        res = minimize(loss_func, parameters_init, method='SLSQP', bounds=bnds, tol=1e-2, options={'maxiter': 1000})
        alpha[i], beta[i], gamma_s[i] = res.x
        loss = res.fun

        # Calculate deviance R2 as goodness of fit.
        # Null model: each species has a mean proportional to the capture rate only.
        mu_u_null = (np.mean(ul) / mean_capture_rate) * cell_capture_rate
        mu_s_null = (np.mean(sl) / mean_capture_rate) * cell_capture_rate
        loss0 = -np.sum(ul * np.log(mu_u_null) - mu_u_null) - np.sum(sl * np.log(mu_s_null) - mu_s_null)

        # Saturated model: mean equal to each observation; zero counts contribute 0.
        ul_pos, sl_pos = ul[ul > 0], sl[sl > 0]
        loss_saturated = -np.sum(ul_pos * np.log(ul_pos) - ul_pos) - np.sum(sl_pos * np.log(sl_pos) - sl_pos)

        null_deviance = 2 * (loss0 - loss_saturated)
        deviance = 2 * (loss - loss_saturated)
        gamma_r2_raw[i] = 1 - (deviance / (2 * n_obs - 3)) / (null_deviance / (2 * n_obs - 2))

        gamma_t[i] = gamma_s[i] * sum_ss / sum_rs

    # Top 40% genes were selected by goodness of fit
    gamma_r2 = gamma_r2_raw.copy()
    number_selected_genes = int(n_var * 0.4)
    gamma_r2[gamma_s < 0.01] = 0
    sort_index = np.argsort(-gamma_r2)
    gamma_r2[sort_index[:number_selected_genes]] = 1
    # Slice starts at ``number_selected_genes`` (not ``+ 1``) so that the first
    # non-selected gene is also reset to 0 instead of keeping its raw R2.
    gamma_r2[sort_index[number_selected_genes:]] = 0

    return gamma_s, gamma_r2, beta, gamma_t, gamma_r2_raw, alpha
def visualize_CSP_loss_landscape(
    adata: AnnData,
    gene_name_list: list,
    figsize: tuple = (3, 3),
    dpi: int = 75,
    save_name: Optional[str] = None,
):
    """Draw the landscape of CSP model-based loss function for the given genes.

    Reads ``adata.obs['initial_cell_size']``, ``adata.obs['time']``, ``adata.layers['new']`` and
    ``adata.var['gamma']`` for the requested genes and shows one 3D figure per gene.

    Args:
        adata: class:`~anndata.AnnData`
            an Annodata object
        gene_name_list: A list of gene names that are going to be visualized.
        figsize: The width and height of each panel in the figure.
        dpi: The dot per inch of the figure.
        save_name: The save path for visualization results. save_name = None means that only show but not save the
            results. The same path is used for every gene, so with several genes each figure overwrites the
            previous file.

    Returns:
    -------
        A matplotlib plot that shows the landscape of CSP model-based loss function for the given genes.
    """

    def _traverse_CSP(n, time, gamma_init, cell_total):
        """Traverse the CSP loss function to draw the landscape"""
        # ``.toarray()`` instead of ``.A`` (removed from sparse matrices in SciPy 1.14).
        n = n.toarray().flatten() if issparse(n) else n.flatten()
        cell_capture_rate = cell_total / np.median(cell_total)

        def loss_func(parameters):
            # Loss function of cell specific Poisson model
            parameter_alpha_div_gamma, parameter_gamma = parameters
            mu = parameter_alpha_div_gamma * (1 - np.exp(-parameter_gamma * time)) * cell_capture_rate
            return -np.sum(n * np.log(mu) - mu - gammaln(n + 1))

        def dldalpha_eq0(gamma):
            # Analytic solution to the equation that the derivative of the loss with respect to alpha is equal to 0
            return np.mean(n) / np.mean(cell_capture_rate * (1 - np.exp(-gamma * time)))

        def alpha_constant(gamma):
            # When gamma is sufficiently small, alpha is approximated as a constant.
            return np.mean(n) / np.mean(cell_capture_rate * (gamma * time))

        # Determine the scope of the traversal
        alpha_div_gamma_init = np.mean(n / (1 - np.exp(-gamma_init * time)))
        gamma_range = gamma_init * np.logspace(-2, 1, base=5, num=200)
        alpha_div_gamma_range = alpha_div_gamma_init * np.logspace(-2, 1, base=5, num=200)

        # Iterate over the value of the loss function in the given range
        loss_all = np.zeros((len(gamma_range), len(alpha_div_gamma_range)))
        for row, gamma_temp in enumerate(gamma_range):
            for col, alpha_div_gamma_temp in enumerate(alpha_div_gamma_range):
                loss_all[row, col] = loss_func((alpha_div_gamma_temp, gamma_temp))

        # Create grid data for drawing
        X, Y = np.meshgrid(gamma_range, alpha_div_gamma_range)
        Z = np.transpose(loss_all)

        # Calculate the loss value where dl/dalpha is equal to 0 and alpha is equal to a constant
        alpha_div_gamma_dldalpha_eq0_range = np.zeros_like(gamma_range)
        alpha_div_gamma_constant_range = np.zeros_like(gamma_range)
        loss_dldalpha_eq0_range = np.zeros_like(gamma_range)
        loss_constant_range = np.zeros_like(gamma_range)
        for row, gamma_temp in enumerate(gamma_range):
            alpha_div_gamma_dldalpha_eq0_range[row] = dldalpha_eq0(gamma_temp)
            alpha_div_gamma_constant_range[row] = alpha_constant(gamma_temp)
            loss_dldalpha_eq0_range[row] = loss_func((alpha_div_gamma_dldalpha_eq0_range[row], gamma_temp))
            loss_constant_range[row] = loss_func((alpha_div_gamma_constant_range[row], gamma_temp))

        return X, Y, Z, gamma_range, alpha_div_gamma_dldalpha_eq0_range, \
            alpha_div_gamma_constant_range, loss_dldalpha_eq0_range, loss_constant_range

    def _plot_landscape(X, Y, Z, gamma, alpha_div_gamma_dldalpha_eq0, alpha_div_gamma_constant,
                        loss_dldalpha_eq0, loss_constant, figsize, dpi, gene_name, save_name):
        """Draw the landscape with the dl/d(alpha) = 0 curve and the constant-alpha curve."""

        # Adjust the range of the parameter to make the results clearer
        index1 = np.where(np.logical_and(gamma > np.min(X), gamma < np.max(X)))
        index2_dldgeq0 = np.where(
            np.logical_and(alpha_div_gamma_dldalpha_eq0 > np.min(Y), alpha_div_gamma_dldalpha_eq0 < np.max(Y)))
        index_dldgeq0 = np.intersect1d(index1, index2_dldgeq0)
        index2_constant = np.where(
            np.logical_and(alpha_div_gamma_constant > np.min(Y), alpha_div_gamma_constant < np.max(Y)))
        index_constant = np.intersect1d(index1, index2_constant)

        # Create figure
        fig = plt.figure(figsize=figsize, dpi=dpi)
        ax = fig.add_subplot(111, projection='3d')
        plt.tick_params(pad=-2)

        # Create plot. Labels are raw strings: the runtime text is unchanged, but the
        # LaTeX backslashes no longer rely on invalid escape sequences.
        surf = ax.plot_surface(X, Y, Z, cmap='rainbow', rstride=1, cstride=1, alpha=0.75)
        ax.plot(gamma[index_dldgeq0], alpha_div_gamma_dldalpha_eq0[index_dldgeq0], loss_dldalpha_eq0[index_dldgeq0],
                color='black',
                linewidth=1, label=r'$\frac{\partial \ell}{\partial \alpha}(\alpha, \gamma_{t})=0$')
        ax.plot(gamma[index_constant], alpha_div_gamma_constant[index_constant], loss_constant[index_constant],
                color='red',
                linewidth=1, label=r'$\alpha=\alpha_{cons}$')
        plt.legend()

        cax = fig.add_axes([0.005, 0.15, 0.025, 0.75])  # left, bottom, width, height
        fig.colorbar(surf, ax=ax, shrink=0.5, aspect=5, cax=cax)

        # Add labels
        ax.set_xlabel(r'$\gamma_{t}$', labelpad=-7)
        ax.set_ylabel(r'$\alpha/\gamma_{t}$', labelpad=-7)
        ax.set_zlabel(r'$-\ell(\alpha,\gamma_{t})$', labelpad=-7)
        ax.set_zlim(np.min(Z), np.max(Z))
        ax.set_title(f'Loss function landscape of for gene {gene_name}')
        ax.zaxis.get_major_formatter().set_powerlimits((0, 1))

        fig.tight_layout()
        plt.grid(False)
        if save_name:
            plt.savefig(save_name)
        plt.show()

    sub_adata = adata[:, gene_name_list]
    cell_total = sub_adata.obs['initial_cell_size'].astype("float").values
    # Plain float array (like ``cell_total``) so the loss arithmetic stays in numpy
    # instead of going through pandas Series operations.
    time = sub_adata.obs['time'].astype("float").values
    N = sub_adata.layers['new'].T
    gamma_init = sub_adata.var['gamma']
    for n, gene, gamma_init_i in tqdm(
        zip(N, gene_name_list, gamma_init),
        'Visualize the landscape of the CSP model loss function'
    ):
        X, Y, Z, gamma, alpha_div_gamma_dldalpha_eq0, alpha_div_gamma_constant, loss_dldalpha_eq0, loss_constant = \
            _traverse_CSP(n, time, gamma_init_i, cell_total)
        _plot_landscape(X, Y, Z, gamma, alpha_div_gamma_dldalpha_eq0, alpha_div_gamma_constant, loss_dldalpha_eq0,
                        loss_constant, figsize, dpi, gene, save_name)
def calculate_robustness_measure_CSP(
        N: Union[np.ndarray, csr_matrix],
        time: np.ndarray,
        cell_total: np.ndarray,
) -> np.ndarray:
    """Calculate the robustness measure based on CSP model inference.

    For every gene the negative log-likelihood of the cell specific Poisson (CSP) model is evaluated on a
    fixed grid of gamma values (0.01 to 1.50 in steps of 0.01). At each grid point alpha/gamma is set to
    ``mean(n) / mean(cell_capture_rate * (1 - exp(-gamma * time)))``. The measure is the sum of the absolute
    differences between the losses at neighbouring grid points.

    Args:
        N: The number of new mRNA counts for each gene in each cell. shape: (n_var, n_obs).
        time: The time point of each cell. shape: (n_obs,).
        cell_total: The total counts of reads for each cell. shape: (n_obs,).

    Returns:
        robustness_measure: The robustness measure based on CSP model inference for each gene. shape: (n_var,).
    """
    # Work on plain float arrays so array-likes such as a pandas Series behave like ndarrays.
    time = np.asarray(time, dtype=float)
    cell_total = np.asarray(cell_total, dtype=float)

    # These quantities do not depend on the gene, so they are computed once.
    cell_capture_rate = cell_total / np.median(cell_total)
    gamma_range = np.arange(0.01, 1.51, 0.01)

    n_var = N.shape[0]
    robustness_measure = np.zeros(n_var)
    for i, n in tqdm(enumerate(N), desc="Calculate the robustness measure", total=n_var):
        n = n.toarray().flatten() if issparse(n) else np.asarray(n).flatten()
        mean_n = np.mean(n)
        log_factorial = gammaln(n + 1)

        loss = np.zeros_like(gamma_range)
        for s, gamma_temp in enumerate(gamma_range):
            labeled_fraction = 1 - np.exp(-gamma_temp * time)
            alpha_div_gamma_temp = mean_n / np.mean(cell_capture_rate * labeled_fraction)
            # Poisson mean of the CSP model: alpha/gamma scales the labeled fraction, gamma sets its
            # time dependence (the two must not be interchanged).
            mu = alpha_div_gamma_temp * labeled_fraction * cell_capture_rate
            loss[s] = -np.sum(n * np.log(mu) - mu - log_factorial)

        robustness_measure[i] = np.sum(np.abs(loss[1:] - loss[0:-1]))

    return robustness_measure
import tqdm @@ -29,6 +30,7 @@ from ..estimation.tsc.estimation_kinetic import * from ..estimation.tsc.twostep import fit_slope_stochastic, lin_reg_gamma_synthesis from ..estimation.tsc.utils_kinetic import * +from ..estimation.tsc import storm from .moments import ( moments, prepare_data_deterministic, @@ -38,6 +40,7 @@ prepare_data_no_splicing, ) from .utils import ( + get_auto_assump_mRNA, get_data_for_kin_params_estimation, get_U_S_for_velocity_estimation, get_valid_bools, @@ -53,30 +56,9 @@ warnings.simplefilter("ignore", SparseEfficiencyWarning) -# incorporate the model selection code soon -def dynamics( - adata: AnnData, - filter_gene_mode: Literal["final", "basic", "no"] = "final", - use_smoothed: bool = True, - assumption_mRNA: Literal["ss", "kinetic", "auto"] = "auto", - assumption_protein: Literal["ss"] = "ss", - model: Literal["auto", "deterministic", "stochastic"] = "auto", - est_method: Literal["ols", "rlm", "ransac", "gmm", "negbin", "auto", "twostep", "direct"] = "auto", - NTR_vel: bool = False, - group: Optional[str] = None, - protein_names: Optional[List[str]] = None, - concat_data: bool = False, - log_unnormalized: bool = True, - one_shot_method: Literal["combined", "sci-fate", "sci_fate"] = "combined", - fraction_for_deg: bool = False, - re_smooth: bool = False, - sanity_check: bool = False, - del_2nd_moments: Optional[bool] = None, - cores: int = 1, - tkey: str = None, - **est_kwargs, -) -> AnnData: - """Inclusive model of expression dynamics considers splicing, metabolic labeling and protein translation. +class BaseDynamics: + """The base class for the inclusive model of expression dynamics considers splicing, metabolic labeling and protein + translation. The function supports learning high-dimensional velocity vector samples for droplet based (10x, inDrop, drop-seq, etc), scSLAM-seq, NASC-seq sci-fate, scNT-seq, scEU-seq, cite-seq or REAP-seq datasets. 
@@ -285,202 +267,1653 @@ def dynamics( use_smoothed: Whether to use smoothed data (or first moment, done via local average of neighbor cells) NTR_vel: Whether to estimate NTR velocity log_unnormalized: Whether to log transform unnormalized data. - """ + """ + + def __init__(self, dynamics_kwargs: Dict): + self.adata = dynamics_kwargs["adata"] + self.filter_gene_mode = dynamics_kwargs["filter_gene_mode"] + self.use_smoothed = dynamics_kwargs["use_smoothed"] + self.assumption_mRNA = dynamics_kwargs["assumption_mRNA"] + self.assumption_protein = dynamics_kwargs["assumption_protein"] + self.model = dynamics_kwargs["model"] + self.model_was_auto = dynamics_kwargs["model_was_auto"] + self.experiment_type = dynamics_kwargs["experiment_type"] + self.has_splicing = dynamics_kwargs["has_splicing"] + self.has_labeling = dynamics_kwargs["has_labeling"] + self.splicing_labeling = dynamics_kwargs["splicing_labeling"] + self.has_protein = dynamics_kwargs["has_protein"] + self.est_method = dynamics_kwargs["est_method"] + self.NTR_vel = dynamics_kwargs["NTR_vel"] + self.group = dynamics_kwargs["group"] + self.protein_names = dynamics_kwargs["protein_names"] + self.concat_data = dynamics_kwargs["concat_data"] + self.log_unnormalized = dynamics_kwargs["log_unnormalized"] + self.one_shot_method = dynamics_kwargs["one_shot_method"] + self.fraction_for_deg = dynamics_kwargs["fraction_for_deg"] + self.re_smooth = dynamics_kwargs["re_smooth"] + self.sanity_check = dynamics_kwargs["sanity_check"] + self.del_2nd_moments = DynamoAdataConfig.use_default_var_if_none( + dynamics_kwargs["del_2nd_moments"], DynamoAdataConfig.DYNAMICS_DEL_2ND_MOMENTS_KEY + ) + self.cores = dynamics_kwargs["cores"] + if dynamics_kwargs["tkey"] is not None: + if dynamics_kwargs["adata"].obs[dynamics_kwargs["tkey"]].max() > 60: + main_warning( + "Looks like you are using minutes as the time unit. For the purpose of numeric stability, " + "we recommend using hour as the time unit." 
def estimate_params_utils(self, fit_kwargs=None, **kwargs):
    """Default method to estimate the velocity parameters.

    Builds an `ss_estimation` object from `kwargs`, stores it on `self.est` and runs the conventional fit
    that matches `self.model`.

    Args:
        fit_kwargs: keyword arguments forwarded to the fit method. `None` means no extra arguments.
        **kwargs: keyword arguments forwarded to `ss_estimation`.

    Raises:
        NotImplementedError: if `self.model` is neither "deterministic" nor "stochastic" (case-insensitive).
    """
    model = self.model.lower()
    if model not in ("deterministic", "stochastic"):
        # Reject an unsupported model before building an estimator that could never be fitted.
        raise NotImplementedError("Method not implemented.")
    # `**None` raises TypeError, so the default has to be normalised to an empty mapping.
    fit_kwargs = {} if fit_kwargs is None else fit_kwargs

    self.est = ss_estimation(**kwargs)
    if model == "deterministic":
        self.est.fit_conventional_deterministic(**fit_kwargs)
    else:
        self.est.fit_conventional_stochastic(**fit_kwargs)
"gamma" in subset_adata.var.keys() else None + ss_estimation_kwargs = {"beta": self.beta, "gamma": self.gamma} + else: + ss_estimation_kwargs = {} - filter_gene_mode = filter_gene_mode_list[which_filter] + if self.one_shot_method == "storm-csp": + _, valid_bools, _ = self._filter() + self.NewCounts = self.adata[:, valid_bools].layers['new'].T + self.TotalCounts = self.adata[:, valid_bools].layers['total'].T + self.NewSmoothCSP = self.adata[:, valid_bools].layers['M_CSP_n'].T + else: + self.NewCounts = None + self.TotalCounts = None + self.NewSmoothCSP = None + + self.estimate_params_utils( + fit_kwargs=self.est_kwargs, + U=self.U.copy() if self.U is not None else None, + Ul=self.Ul.copy() if self.Ul is not None else None, + S=self.S.copy() if self.S is not None else None, + Sl=self.Sl.copy() if self.Sl is not None else None, + P=self.P.copy() if self.P is not None else None, + US=self.US.copy() if self.US is not None else None, + S2=self.S2.copy() if self.S2 is not None else None, + NewCounts=self.NewCounts.copy() if self.NewCounts is not None else None, + TotalCounts=self.TotalCounts.copy() if self.TotalCounts is not None else None, + NewSmoothCSP=self.NewSmoothCSP.copy() if self.NewSmoothCSP is not None else None, + conn=subset_adata.obsp["moments_con"], + t=self.t, + ind_for_proteins=self.ind_for_proteins, + model=self.model, + est_method=self.est_method, + experiment_type=self.experiment_type, + assumption_mRNA=self.assumption_mRNA, + assumption_protein=self.assumption_protein, + concat_data=self.concat_data, + cores=self.cores, + **ss_estimation_kwargs, + ) - valid_bools = get_valid_bools(adata, filter_gene_mode) - gene_num = sum(valid_bools) - if gene_num == 0: - raise Exception(f"no genes pass filter. 
def estimate_params_kin(self, cur_grp_i: int, cur_grp: str, subset_adata: AnnData, **est_params_args):
    """Estimate velocity parameters with kinetic mRNA assumption. Will be overridden in the subclass.

    Calls `self.estimate_params_utils` and unpacks its result into seven values, so the implementation in
    use must return such a tuple. The results are stored on the instance: `alpha`, `beta`, `a`, `b`,
    `alpha_a`, `alpha_i`, `gamma`, `cost`, `logLL`, `kin_extra_params` and the `X_data` / `X_fit_data` pair.

    Args:
        cur_grp_i: position of the current group; indexes `X_data` / `X_fit_data` when there are several groups.
        cur_grp: the current group; compared with `self._group[0]` to detect the first group.
        subset_adata: forwarded to `estimate_params_utils`.
        **est_params_args: accepted but not used in this method.
    """
    return_ntr = True if self.fraction_for_deg and self.experiment_type.lower() == "deg" else False

    # Resolve the "auto" choices before fitting.
    if self.model_was_auto and self.experiment_type.lower() == "kin":
        self.model = "mixture"
    if self.est_method == "auto":
        self.est_method = "direct"
    data_type = "smoothed" if self.use_smoothed else "sfs"

    # NOTE: `self.est_kwargs` is passed both as `fit_kwargs` and expanded as keyword arguments.
    (params, half_life, self.cost, self.logLL, param_ranges, cur_X_data, cur_X_fit_data,) = self.estimate_params_utils(
        fit_kwargs=self.est_kwargs,
        subset_adata=subset_adata,
        tkey=self.tkey,
        model=self.model,
        est_method=self.est_method,
        experiment_type=self.experiment_type,
        has_splicing=self.has_splicing,
        splicing_labeling=self.splicing_labeling,
        has_switch=True,
        param_rngs={},
        data_type=data_type,
        return_ntr=return_ntr,
        **self.est_kwargs,
    )

    # `params` is handled either as a dict (alpha/beta are popped, the rest becomes a DataFrame) or as
    # an object with `.columns` / `.loc` (presumably a DataFrame already -- confirm against the fitter).
    if type(params) == dict:
        self.alpha = params.pop("alpha")
        self.beta = params.pop("beta") if "beta" in params else None
        params = pd.DataFrame(params)
    else:
        self.alpha = params.loc[:, "alpha"].values if "alpha" in params.columns else None
        self.beta = params.loc[:, "beta"].values if "beta" in params.columns else None

    len_t, len_g = len(np.unique(self.t)), len(self._group)
    # With several groups, allocate one slot per group when the first group is processed.
    if cur_grp == self._group[0]:
        if len_g != 1:
            self.X_data, self.X_fit_data = [None] * len_g, [None] * len_g

    if len(self._group) == 1:
        self.X_data, self.X_fit_data = cur_X_data, cur_X_fit_data
    else:
        self.X_data[cur_grp_i], self.X_fit_data[cur_grp_i] = (
            cur_X_data,
            cur_X_fit_data,
        )

    # Each parameter is None when the fit did not produce a column with that name.
    self.a, self.b, self.alpha_a, self.alpha_i, self.gamma = (
        params.loc[:, "a"].values if "a" in params.columns else None,
        params.loc[:, "b"].values if "b" in params.columns else None,
        params.loc[:, "alpha_a"].values if "alpha_a" in params.columns else None,
        params.loc[:, "alpha_i"].values if "alpha_i" in params.columns else None,
        params.loc[:, "gamma"].values if "gamma" in params.columns else None,
    )
    # No alpha from the fit: derive it with `fbar`, using 0 in place of a missing `alpha_i`.
    if self.alpha is None:
        self.alpha = fbar(self.a, self.b, self.alpha_a, 0) if self.alpha_i is None else fbar(self.a, self.b,
                                                                                             self.alpha_a,
                                                                                             self.alpha_i)
    all_kinetic_params = [
        "a",
        "b",
        "alpha_a",
        "alpha_i",
        "alpha",
        "beta",
        "gamma",
    ]

    # Every remaining column is kept as an extra kinetic parameter.
    self.kin_extra_params = params.loc[:, params.columns.difference(all_kinetic_params)]
def set_velocity(
    self,
    vel_U: Union[ndarray, csr_matrix],
    vel_S: Union[ndarray, csr_matrix],
    vel_N: Union[ndarray, csr_matrix],
    vel_T: Union[ndarray, csr_matrix],
    vel_P: Union[ndarray, csr_matrix],
    cur_grp: int,
    cur_cells_bools: ndarray,
    valid_bools_: ndarray,
    kin_param_pre: str,
    **set_velo_args,
):
    """Save the calculated parameters and velocity to anndata. Override this in the subclass if the class has a
    different assumption.

    Args:
        vel_U: velocity for the unspliced layer.
        vel_S: velocity for the spliced layer.
        vel_N: velocity for the new layer.
        vel_T: velocity for the total layer.
        vel_P: protein velocity.
        cur_grp: the current group, forwarded to the setters together with `self._group`.
            NOTE(review): annotated as int, but `estimate` passes the group label itself -- confirm.
        cur_cells_bools: forwarded to the setters (presumably the cell mask of the current group).
        valid_bools_: forwarded to the setters (presumably the gene mask).
        kin_param_pre: forwarded to the parameter setter (presumably a prefix for the stored keys).
        **set_velo_args: accepted but not used in this method.
    """
    if self.assumption_mRNA.lower() == "ss" or (self.experiment_type.lower() in ["one-shot", "mix_std_stm"]):
        # The bare name `set_velocity` resolves to the module-level function, not to this method:
        # method names are not part of a function's scope.
        self.adata = set_velocity(
            self.adata,
            vel_U,
            vel_S,
            vel_N,
            vel_T,
            vel_P,
            self._group,
            cur_grp,
            cur_cells_bools,
            valid_bools_,
            self.ind_for_proteins,
        )

        self.adata = set_param_ss(
            self.adata,
            self.est,
            self.alpha,
            self.beta,
            self.gamma,
            self.eta,
            self.delta,
            self.experiment_type,
            self._group,
            cur_grp,
            kin_param_pre,
            valid_bools_,
            self.ind_for_proteins,
        )
    elif self.assumption_mRNA.lower() == "kinetic":
        self.adata = set_velocity(
            self.adata,
            vel_U,
            vel_S,
            vel_N,
            vel_T,
            vel_P,
            self._group,
            cur_grp,
            cur_cells_bools,
            valid_bools_,
            self.ind_for_proteins,
        )

        self.adata = set_param_kinetic(
            self.adata,
            self.alpha,
            self.a,
            self.b,
            self.alpha_a,
            self.alpha_i,
            self.beta,
            self.gamma,
            self.cost,
            self.logLL,
            kin_param_pre,
            self.kin_extra_params,
            self._group,
            cur_grp,
            cur_cells_bools,
            valid_bools_,
        )
    else:
        # Nothing is written for any other assumption; only a warning is emitted.
        main_warning("Not implemented yet.")
def calculate_velocity(self, subset_adata: AnnData) -> Tuple:
    """Read the U, S, N, T matrix, create the Velocity class and call the velocity calculation function.

    Args:
        subset_adata: the data of the current group, passed to `get_U_S_for_velocity_estimation`.

    Returns:
        A tuple `(vel_U, vel_S, vel_N, vel_T, vel_P)`: the four values returned by `self.calculate_vels`
        followed by the value returned by `self.calculate_vel_P`.

    Raises:
        NotImplementedError: if the assumption / experiment type combination matches neither the steady
            state nor the kinetic branch.
    """
    # Build the Velocity object first, so an unsupported configuration fails with a clear error
    # before any matrix is loaded.
    if self.assumption_mRNA.lower() == "ss" or (self.experiment_type.lower() in ["one-shot", "mix_std_stm"]):
        vel = Velocity(estimation=self.est)
    elif self.assumption_mRNA.lower() == "kinetic":
        params = {"alpha": self.alpha, "beta": self.beta, "gamma": self.gamma, "t": self.t}
        vel = Velocity(**params)
    else:
        raise NotImplementedError(
            f"Velocity calculation is not implemented for assumption_mRNA={self.assumption_mRNA!r} "
            f"with experiment_type={self.experiment_type!r}."
        )

    # The same helper is called twice and differs only in the last flag: False yields (U, S),
    # True yields (N, T).
    U, S = get_U_S_for_velocity_estimation(
        subset_adata,
        self.use_smoothed,
        self.has_splicing,
        self.has_labeling,
        self.log_unnormalized,
        False,
    )
    N, T = get_U_S_for_velocity_estimation(
        subset_adata,
        self.use_smoothed,
        self.has_splicing,
        self.has_labeling,
        self.log_unnormalized,
        True,
    )

    vel_U, vel_S, vel_N, vel_T = self.calculate_vels(vel=vel, U=U, S=S, N=N, T=T)
    vel_P = self.calculate_vel_P(vel=vel, U=U, S=S, N=N, T=T)

    return vel_U, vel_S, vel_N, vel_T, vel_P
def _smooth(self, valid_bools: ndarray):
    """Smooth the data by moments when necessary.

    Existing `M_` layers are deleted and recomputed with `moments` when fewer than two of them exist or
    when `self.re_smooth` is set. Otherwise the layers are kept and, if `self.tkey` is not None, a
    warning reminds the user to check the tkey / group used for the earlier smoothing.

    Args:
        valid_bools: passed to `moments` as its `genes` argument.
    """
    M_layers = [i for i in self.adata.layers.keys() if i.startswith("M_")]

    if len(M_layers) < 2 or self.re_smooth:
        main_info("removing existing M layers:%s..." % (str(list(M_layers))), indent_level=2)
        for i in M_layers:
            del self.adata.layers[i]
        main_info("making adata smooth...", indent_level=2)

        # Group by `self.group` when it names an obs column, otherwise by the time key.
        if self.group is not None and self.group in self.adata.obs.columns:
            moments(self.adata, genes=valid_bools, group=self.group)
        else:
            moments(self.adata, genes=valid_bools, group=self.tkey)
    elif self.tkey is not None:
        main_warning(
            f"You used tkey {self.tkey} (or group {self.group}), but you have calculated local smoothing (1st moment) "
            f"for your data before. Please ensure you used the desired tkey or group when the smoothing was "
            f"performed. Try setting re_smooth = True if not sure."
        )
Cell number for each group are {adata.obs[group].value_counts()}" - ) - - else: - _group = ["_all_cells"] - - for cur_grp_i, cur_grp in enumerate(_group): - if cur_grp == "_all_cells": - kin_param_pre = "" - cur_cells_bools = np.ones(valid_adata.shape[0], dtype=bool) - subset_adata = valid_adata[cur_cells_bools] - else: - kin_param_pre = str(group) + "_" + str(cur_grp) + "_" - cur_cells_bools = (valid_adata.obs[group] == cur_grp).values - subset_adata = valid_adata[cur_cells_bools] - - if model.lower() == "stochastic" or use_smoothed: - moments(subset_adata) - ( - U, - Ul, - S, - Sl, - P, - US, - U2, - S2, - t, - normalized, - ind_for_proteins, - assump_mRNA, - ) = get_data_for_kin_params_estimation( - subset_adata, - has_splicing, - has_labeling, - model, - use_smoothed, - tkey, - protein_names, - log_unnormalized, - NTR_vel, + def _sanity_check( + self, + valid_bools: ndarray, + valid_bools_: ndarray, + gene_num: int, + subset_adata: AnnData, + kin_param_pre: str, + ) -> Tuple: + """Perform sanity check by checking the slope for kinetic or degradation metabolic labeling experiments.""" + indices_valid_bools = np.where(valid_bools)[0] + self.t, L = ( + self.t.flatten(), + (0 if self.Ul is None else self.Ul) + (0 if self.Sl is None else self.Sl), ) + t_uniq = np.unique(self.t) - valid_bools_ = valid_bools.copy() - if sanity_check and experiment_type.lower() in ["kin", "deg"]: - indices_valid_bools = np.where(valid_bools)[0] - t, L = ( - t.flatten(), - (0 if Ul is None else Ul) + (0 if Sl is None else Sl), + valid_gene_checker = np.zeros(gene_num, dtype=bool) + for L_iter, cur_L in tqdm( + enumerate(L), + desc=f"sanity check of {self.experiment_type} experiment data:", + ): + cur_L = cur_L.A.flatten() if issparse(cur_L) else cur_L.flatten() + y = strat_mom(cur_L, self.t, np.nanmean) + slope, _ = fit_linreg(t_uniq, y, intercept=True, r2=False) + valid_gene_checker[L_iter] = ( + True + if (slope > 0 and self.experiment_type == "kin") or (slope < 0 and 
def estimate(self):
    """Main function to estimate the RNA dynamics.

    The function initially conducts filtering, smoothing, and sanity checks to ensure data quality. Subsequently, it
    calls the corresponding functions to estimate parameters and compute velocity. Lastly, it updates the AnnData
    object and saves all results.

    Returns:
        `self.adata`, with a `use_for_dynamics` column added to `.var` and a summary dictionary stored in `.uns`.
    """
    self.X_data, self.X_fit_data = None, None
    filter_gene_mode, valid_bools, gene_num = self._filter()

    if self.model.lower() == "stochastic" or self.use_smoothed or self.re_smooth:
        self._smooth(valid_bools=valid_bools)

    valid_adata = self.adata[:, valid_bools].copy()
    if self.group is not None and self.group in self.adata.obs.columns:
        self._group = self.adata.obs[self.group].unique()
        group_sizes = self.adata.obs[self.group].value_counts()
        if (group_sizes < 50).any():
            main_warning(
                f"Note that some groups have less than 50 cells, this may lead to the velocities for some "
                f"cells are all NaN values and cause issues for all downstream analysis. Please try to "
                f"coarse-grain cell groupings. Cell number for each group are {group_sizes}"
            )
    else:
        self._group = ["_all_cells"]

    # `self.sanity_check` is the boolean switch; the check itself is implemented by `self._sanity_check`.
    # The experiment type is lower-cased once so the loop and the final bookkeeping below agree.
    run_sanity_check = self.sanity_check and self.experiment_type.lower() in ["kin", "deg"]

    for cur_grp_i, cur_grp in enumerate(self._group):
        if cur_grp == "_all_cells":
            kin_param_pre = ""
            cur_cells_bools = np.ones(valid_adata.shape[0], dtype=bool)
            subset_adata = valid_adata[cur_cells_bools]
        else:
            kin_param_pre = str(self.group) + "_" + str(cur_grp) + "_"
            cur_cells_bools = (valid_adata.obs[self.group] == cur_grp).values
            subset_adata = valid_adata[cur_cells_bools]

            # Moments are recomputed on the cells of the current group only.
            if self.model.lower() == "stochastic" or self.use_smoothed:
                moments(subset_adata)
        (
            self.U,
            self.Ul,
            self.S,
            self.Sl,
            self.P,
            self.US,
            self.U2,
            self.S2,
            self.t,
            self.normalized,
            self.ind_for_proteins,
            assump_mRNA,
        ) = get_data_for_kin_params_estimation(
            subset_adata,
            self.has_splicing,
            self.has_labeling,
            self.model,
            self.use_smoothed,
            self.tkey,
            self.protein_names,
            self.log_unnormalized,
            self.NTR_vel,
        )

        valid_bools_ = valid_bools.copy()
        if run_sanity_check:
            subset_adata, valid_bools_ = self._sanity_check(
                valid_bools, valid_bools_, gene_num, subset_adata, kin_param_pre
            )

        self.estimate_parameters(cur_grp_i=cur_grp_i, cur_grp=cur_grp, subset_adata=subset_adata)
        vel_U, vel_S, vel_N, vel_T, vel_P = self.calculate_velocity(subset_adata=subset_adata)
        self.set_velocity(vel_U, vel_S, vel_N, vel_T, vel_P, cur_grp, cur_cells_bools, valid_bools_, kin_param_pre)

    # NOTE(review): `x in Series` tests the Series index, not the column names, so the last condition is
    # normally False and the results land under "dynamics". That outcome is kept because other code may
    # read `adata.uns["dynamics"]`. The column check only prevents a KeyError when `group` is not an obs column.
    if (
        self.group is not None
        and self.group in self.adata.obs.columns
        and self.group in self.adata.obs[self.group]
    ):
        uns_key = self.group + "_dynamics"
    else:
        uns_key = "dynamics"

    if run_sanity_check:
        # A gene is used when it passed the sanity check in at least one group.
        sanity_check_cols = self.adata.var.columns.str.endswith("sanity_check")
        self.adata.var["use_for_dynamics"] = self.adata.var.loc[:, sanity_check_cols].sum(1).astype(bool)
    else:
        self.adata.var["use_for_dynamics"] = False
        self.adata.var.loc[valid_bools, "use_for_dynamics"] = True

    self.adata.uns[uns_key] = {
        "filter_gene_mode": filter_gene_mode,
        "t": self.t,
        "group": self.group,
        "X_data": self.X_data,
        "X_fit_data": self.X_fit_data,
        "asspt_mRNA": self.assumption_mRNA,
        "experiment_type": self.experiment_type,
        "normalized": self.normalized,
        "model": self.model,
        "est_method": self.est_method,
        "has_splicing": self.has_splicing,
        "has_labeling": self.has_labeling,
        "splicing_labeling": self.splicing_labeling,
        "has_protein": self.has_protein,
        "use_smoothed": self.use_smoothed,
        "NTR_vel": self.NTR_vel,
        "log_unnormalized": self.log_unnormalized,
        "fraction_for_deg": self.fraction_for_deg,
    }

    if self.del_2nd_moments:
        remove_2nd_moments(self.adata)

    return self.adata
All subclass should implement this method.""" + raise NotImplementedError("This method has not been implemented.") + + def calculate_vel_S( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + """Calculate spliced velocity. All subclass should implement this method.""" + raise NotImplementedError("This method has not been implemented.") + + def calculate_vel_N( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + """Calculate new velocity. All subclass should implement this method.""" + raise NotImplementedError("This method has not been implemented.") + + def calculate_vel_T( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + """Calculate total velocity. All subclass should implement this method.""" + raise NotImplementedError("This method has not been implemented.") + + def calculate_vels( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Tuple: + """Implement the velocity calculation function for metabolic labeling data. 
Unsplcied and spliced velocity will + be nan for data without splicing information.""" + if self.has_splicing: + vel_U = self.calculate_vel_U(vel=vel, U=U, S=S, N=N, T=T) + vel_S = self.calculate_vel_S(vel=vel, U=U, S=S, N=N, T=T) + else: + vel_U, vel_S = np.nan, np.nan + vel_N = self.calculate_vel_N(vel=vel, U=U, S=S, N=N, T=T) + vel_T = self.calculate_vel_T(vel=vel, U=U, S=S, N=N, T=T) + return vel_U, vel_S, vel_N, vel_T + + +class OneShotDynamics(LabeledDynamics): + """Dynamics model for the one shot experiment, where there is only one labeling time point.""" + def estimate_params_utils(self, fit_kwargs=None, **kwargs): + self.est = ss_estimation(**kwargs) + if self.experiment_type.lower() in ["one-shot", "one_shot"]: + if self.one_shot_method == "storm-csp": + self.est.fit_oneshot(one_shot_method=self.one_shot_method, perc_right=50, **fit_kwargs) + else: + self.est.fit_oneshot(one_shot_method=self.one_shot_method, **fit_kwargs) + + def calculate_vel_U( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return vel.vel_u(U) + + def calculate_vel_S( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return vel.vel_s(U, S) + + def calculate_vel_N( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return vel.vel_u(N) + + def calculate_vel_T( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return vel.vel_s(N, T - N) if self.has_splicing else vel.vel_u(T) + + +class SSKineticsDynamics(LabeledDynamics): 
+ """Two-step dynamics model for the Kinetic experiment with steady state assumption, which relies on two consecutive + linear regressions to estimate the degradation rate.""" + def calculate_vel_U( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return N.multiply(csr_matrix(self.gamma_ / self.Kc)) - csr_matrix(self.beta).multiply(U) + + def calculate_vel_S( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return vel.vel_s(U, S) + + def calculate_vel_N( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return (N - csr_matrix(self.Kc).multiply(N)).multiply(csr_matrix(self.gamma_ / self.Kc)) + + def calculate_vel_T( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return (N - csr_matrix(self.Kc).multiply(T)).multiply(csr_matrix(self.gamma_ / self.Kc)) + + def calculate_vels( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Tuple: + """Override the velocity calculation function to calculate extra parameters slope and actual gamma.""" + self.Kc = np.clip(self.gamma[:, None], 0, 1 - 1e-3) # S - U slope + self.gamma_ = -(np.log(1 - self.Kc) / self.t[None, :]) # actual gamma + if self.has_splicing: + vel_U = self.calculate_vel_U(vel=vel, U=U, S=S, N=N, T=T) + vel_S = self.calculate_vel_S(vel=vel, U=U, S=S, N=N, T=T) + else: + vel_U, vel_S = np.nan, np.nan + vel_N = 
self.calculate_vel_N(vel=vel, U=U, S=S, N=N, T=T) + vel_T = self.calculate_vel_T(vel=vel, U=U, S=S, N=N, T=T) + return vel_U, vel_S, vel_N, vel_T + + +class KineticsDynamics(LabeledDynamics): + """Dynamic models for the kinetic experiment with kinetic assumption. This includes a kinetic two-step method and + the direct method.""" + def calculate_vel_U( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return vel.vel_u(U) + + def calculate_vel_S( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return vel.vel_s(U, S) + + def calculate_vel_N( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return vel.vel_u(N) + + def calculate_vel_T( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return vel.vel_u(T) + + def calculate_vels( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Tuple: + """Override the velocity calculation function to reset beta or alpha.""" + if self.has_splicing: + vel_U = self.calculate_vel_U(vel=vel, U=U, S=S, N=N, T=T) + vel_S = self.calculate_vel_S(vel=vel, U=U, S=S, N=N, T=T) + vel.parameters["beta"] = self.gamma + else: + vel_U, vel_S = np.nan, np.nan + alpha_ = one_shot_alpha_matrix(N, self.gamma, self.t) + vel.parameters["alpha"] = alpha_ + vel_N = self.calculate_vel_N(vel=vel, U=U, S=S, N=N, T=T) + vel_T = self.calculate_vel_T(vel=vel, U=U, S=S, N=N, T=T) + 
return vel_U, vel_S, vel_N, vel_T + + +class TwoStepKineticsDynamics(KineticsDynamics): + """Dynamic models for the kinetic experiment with two-step method.""" + def estimate_params_utils(self, fit_kwargs=None, **kwargs): + kin_estimation = KineticEstimation(**kwargs) + return kin_estimation.fit_twostep_kinetics(**fit_kwargs) + + +class KineticsStormDynamics(LabeledDynamics): + """Stochastic transient dynamics for the kinetic experiment with kinetic assumption. This includes three stochastic + models. In Model 1, only transcription and mRNA degradation were considered. In Model 2, we considered + transcription, splicing, and spliced mRNA degradation. And in Model 3, we considered the switching of gene + expression states, transcription in the active state, and mRNA degradation.""" + def estimate_params_utils(self, fit_kwargs=None, **kwargs): + kin_estimation = KineticEstimation(**kwargs) + return kin_estimation.fit_storm(**fit_kwargs) + + def calculate_vel_U( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return vel.vel_u(U) + + def calculate_vel_S( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return vel.vel_s(U, S) + + def calculate_vel_N( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + if self.est_method == 'storm-icsp': + return vel.vel_u(self.Sl) + else: + return vel.vel_u(N) + + def calculate_vel_T( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + if self.est_method == 'storm-icsp': + 
return vel.vel_u(S) + else: + return vel.vel_u(T) + + def calculate_vels( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Tuple: + """Override the velocity calculation function to reset beta or alpha.""" + if self.has_splicing: + vel_U = self.calculate_vel_U(vel=vel, U=U, S=S, N=N, T=T) + vel_S = self.calculate_vel_S(vel=vel, U=U, S=S, N=N, T=T) + vel.parameters["beta"] = self.gamma + else: + vel_U, vel_S = np.nan, np.nan + vel_N = self.calculate_vel_N(vel=vel, U=U, S=S, N=N, T=T) + vel_T = self.calculate_vel_T(vel=vel, U=U, S=S, N=N, T=T) + return vel_U, vel_S, vel_N, vel_T + + +class DirectKineticsDynamics(KineticsDynamics): + """Dynamic models for the kinetic experiment with direct method.""" + def estimate_params_utils(self, fit_kwargs=None, **kwargs): + kin_estimation = KineticEstimation(**kwargs) + return kin_estimation.fit_direct_kinetics(**fit_kwargs) + + +class DegradationDynamics(LabeledDynamics): + """Dynamics model for the degradation experiment. 
In degradation experiment, samples are chased after an extended + 4sU (or other nucleotide analog) labeling period and the wash-out to observe the decay of the abundance of the + (labeled) unspliced and spliced RNA decay over time.""" + def estimate_params_utils(self, fit_kwargs=None, **kwargs): + kin_estimation = KineticEstimation(**kwargs) + return kin_estimation.fit_degradation(**fit_kwargs) + + def calculate_vel_U( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return np.nan + + def calculate_vel_S( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return vel.vel_s(U, S) + + def calculate_vel_N( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return np.nan + + def calculate_vel_T( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return np.nan + + +class MixStdStmDynamics(LabeledDynamics): + """Dynamics model for the mixed steady state and stimulation labeling (mix_std_stm) experiment.""" + def estimate_params_utils(self, fit_kwargs=None, **kwargs): + self.est = ss_estimation(**kwargs) + self.est.fit_mix_std_stm(**fit_kwargs) + + def calculate_vel_U( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return self.alpha1 - csr_matrix(self.beta[:, None]).multiply(U) + + def calculate_vel_S( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: 
Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return vel.vel_s(U, S) + + def calculate_vel_N( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return self.alpha1 - csr_matrix(self.gamma[:, None]).multiply(self.u_new) + + def calculate_vel_T( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return self.alpha1 - csr_matrix(self.gamma[:, None]).multiply(T) + + def calculate_vels( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Tuple: + """Override the velocity calculation function to calculate extra parameters u_new and alpha1.""" + if self.has_splicing: + u0, self.u_new, self.alpha1 = solve_alpha_2p_mat( + t0=np.max(self.t) - self.t, + t1=self.t, + alpha0=self.alpha[0], + beta=self.beta, + u1=N, + ) + vel_U = self.calculate_vel_U(vel=vel, U=U, S=S, N=N, T=T) + vel_S = self.calculate_vel_S(vel=vel, U=U, S=S, N=N, T=T) + else: + u0, self.u_new, self.alpha1 = solve_alpha_2p_mat( + t0=np.max(self.t) - self.t, + t1=self.t, + alpha0=self.alpha[0], + beta=self.gamma, + u1=N, + ) + vel_U, vel_S = np.nan, np.nan + vel_N = self.calculate_vel_N(vel=vel, U=U, S=S, N=N, T=T) + vel_T = self.calculate_vel_T(vel=vel, U=U, S=S, N=N, T=T) + return vel_U, vel_S, vel_N, vel_T + + +class MixKineticsDynamics(LabeledDynamics): + """Dynamics model for two mix experiment type: mix_kin_deg and mix_pulse_chase.""" + def estimate_params_utils(self, fit_kwargs=None, **kwargs): + kin_estimation = KineticEstimation(**kwargs) + return kin_estimation.fit_mix_kinetics(**fit_kwargs) + + def calculate_vel_U( + self, + 
vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return vel.vel_u(U, repeat=True) + + def calculate_vel_S( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return vel.vel_s(U, S) + + def calculate_vel_N( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return vel.vel_u(N, repeat=True) + + def calculate_vel_T( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Union[ndarray, csr_matrix]: + return vel.vel_u(T) if not self.has_splicing and self.NTR_vel else vel.vel_u(T, repeat=True) + + def calculate_vels( + self, + vel: Velocity, + U: Union[ndarray, csr_matrix], + S: Union[ndarray, csr_matrix], + N: Union[ndarray, csr_matrix], + T: Union[ndarray, csr_matrix], + ) -> Tuple: + """Override the velocity calculation function to reset beta when the data contains splicing information.""" + if self.has_splicing: + vel_U = self.calculate_vel_U(vel=vel, U=U, S=S, N=N, T=T) + vel_S = self.calculate_vel_S(vel=vel, U=U, S=S, N=N, T=T) + vel.parameters["beta"] = self.gamma + else: + vel_U, vel_S = np.nan, np.nan + vel_N = self.calculate_vel_N(vel=vel, U=U, S=S, N=N, T=T) + vel_T = self.calculate_vel_T(vel=vel, U=U, S=S, N=N, T=T) + return vel_U, vel_S, vel_N, vel_T + + +# TODO: rename this later +def dynamics_wrapper( + adata: AnnData, + filter_gene_mode: Literal["final", "basic", "no"] = "final", + use_smoothed: bool = True, + assumption_mRNA: Literal["ss", "kinetic", "auto"] = "auto", + assumption_protein: Literal["ss"] = "ss", + model: 
Literal["auto", "deterministic", "stochastic"] = "auto", + est_method: Literal["ols", "rlm", "ransac", "gmm", "negbin", "auto", "twostep", "direct"] = "auto", + NTR_vel: bool = False, + group: Optional[str] = None, + protein_names: Optional[List[str]] = None, + concat_data: bool = False, + log_unnormalized: bool = True, + one_shot_method: Literal["combined", "sci-fate", "sci_fate"] = "combined", + fraction_for_deg: bool = False, + re_smooth: bool = False, + sanity_check: bool = False, + del_2nd_moments: Optional[bool] = None, + cores: int = 1, + tkey: str = None, + **est_kwargs, +) -> AnnData: + """Predict the model and assumption if they are set as auto. Run corresponding Dynamics methods according to the + experiment type. More information can be found in the class BaseDynamics.""" + if "pp" not in adata.uns_keys(): + raise ValueError(f"\nPlease run `dyn.pp.receipe_monocle(adata)` before running this function!") + if model.lower() == "auto": + model = "stochastic" + model_was_auto = True + else: + model = model + model_was_auto = False + + (experiment_type, has_splicing, has_labeling, splicing_labeling, has_protein,) = ( + adata.uns["pp"]["experiment_type"], + adata.uns["pp"]["has_splicing"], + adata.uns["pp"]["has_labeling"], + adata.uns["pp"]["splicing_labeling"], + adata.uns["pp"]["has_protein"], + ) + + (NTR_vel, assump_mRNA) = get_auto_assump_mRNA( + subset_adata=adata, + has_splicing=has_splicing, + has_labeling=has_labeling, + use_moments=use_smoothed, + tkey=tkey, + NTR_vel=NTR_vel, + ) + if assumption_mRNA.lower() == "auto": + assumption_mRNA = assump_mRNA + if experiment_type.lower() == "conventional": + assumption_mRNA = "ss" + elif experiment_type.lower() in ["mix_pulse_chase", "deg", "kin"]: + assumption_mRNA = "kinetic" + + if model.lower() == "stochastic" and experiment_type.lower() not in [ + "conventional", + "kinetics", + "degradation", + "kin", + "deg", + "one-shot", + ]: + """ + # temporially convert to deterministic model as moment model for 
mix_std_stm + and other types of labeling experiment is ongoing.""" + + model = "deterministic" + + if model_was_auto and experiment_type.lower() in [ + "kinetic", + "kin", + "degradation", + "deg", + ]: + model = "deterministic" + + dynamics_kwargs = { + "adata": adata, + "filter_gene_mode": filter_gene_mode, + "use_smoothed": use_smoothed, + "assumption_mRNA": assumption_mRNA, + "assumption_protein": assumption_protein, + "model": model, + "model_was_auto": model_was_auto, + "experiment_type": experiment_type, + "has_splicing": has_splicing, + "has_labeling": has_labeling, + "splicing_labeling": splicing_labeling, + "has_protein": has_protein, + "est_method": est_method, + "NTR_vel": NTR_vel, + "group": group, + "protein_names": protein_names, + "concat_data": concat_data, + "log_unnormalized": log_unnormalized, + "one_shot_method": one_shot_method, + "fraction_for_deg": fraction_for_deg, + "re_smooth": re_smooth, + "sanity_check": sanity_check, + "del_2nd_moments": del_2nd_moments, + "cores": cores, + "tkey": tkey, + "est_kwargs": est_kwargs, + } + + if experiment_type == "conventional": + estimator = SplicedDynamics(dynamics_kwargs) + elif experiment_type in ["one-shot", "one_shot"]: + estimator = OneShotDynamics(dynamics_kwargs) + elif experiment_type == "kin": + if assumption_mRNA == "ss": + estimator = SSKineticsDynamics(dynamics_kwargs) + elif assumption_mRNA == "kinetic": + if est_method == 'twostep': + estimator = TwoStepKineticsDynamics(dynamics_kwargs) + elif est_method == "direct": + estimator = DirectKineticsDynamics(dynamics_kwargs) + elif "storm" in est_method: + estimator = KineticsStormDynamics(dynamics_kwargs) + else: + raise NotImplementedError("This method has not been implemented.") + elif experiment_type == "deg": + estimator = DegradationDynamics(dynamics_kwargs) + elif experiment_type == "mix_std_stm": + estimator = MixStdStmDynamics(dynamics_kwargs) + elif experiment_type in ["mix_kin_deg", "mix_pulse_chase"]: + estimator = 
MixKineticsDynamics(dynamics_kwargs) + else: + raise NotImplementedError("This method has not been implemented.") + adata = estimator.estimate() + return adata + + +# incorporate the model selection code soon +def dynamics( + adata: AnnData, + filter_gene_mode: Literal["final", "basic", "no"] = "final", + use_smoothed: bool = True, + assumption_mRNA: Literal["ss", "kinetic", "auto"] = "auto", + assumption_protein: Literal["ss"] = "ss", + model: Literal["auto", "deterministic", "stochastic"] = "auto", + est_method: Literal["ols", "rlm", "ransac", "gmm", "negbin", "auto", "twostep", "direct"] = "auto", + NTR_vel: bool = False, + group: Optional[str] = None, + protein_names: Optional[List[str]] = None, + concat_data: bool = False, + log_unnormalized: bool = True, + one_shot_method: Literal["combined", "sci-fate", "sci_fate"] = "combined", + fraction_for_deg: bool = False, + re_smooth: bool = False, + sanity_check: bool = False, + del_2nd_moments: Optional[bool] = None, + cores: int = 1, + tkey: str = None, + **est_kwargs, +) -> AnnData: + """Inclusive model of expression dynamics considers splicing, metabolic labeling and protein translation. + + The function supports learning high-dimensional velocity vector samples for droplet based (10x, inDrop, drop-seq, + etc), scSLAM-seq, NASC-seq sci-fate, scNT-seq, scEU-seq, cite-seq or REAP-seq datasets. + + Args: + adata: an AnnData object. + filter_gene_mode: The string for indicating which mode of gene filter will be used. Defaults to "final". + use_smoothed: whether to use the smoothed data when estimating kinetic parameters and calculating velocity for + each gene. When you have time-series data (`tkey` is not None), we recommend to smooth data among cells from + each time point. Defaults to True. + assumption_mRNA: Parameter estimation assumption for mRNA. Available options are: + (1) 'ss': pseudo steady state; + (2) 'kinetic' or None: degradation and kinetic data without steady state assumption. 
+ (3) 'auto': dynamo will choose a reasonable assumption of the system under study automatically. + If no labelling data exists, assumption_mRNA will automatically be set to 'ss'. For one-shot experiments, + assumption_mRNA is set to be None. However, we will use the steady state assumption to estimate parameters alpha + and gamma either by a deterministic linear regression or the first order decay approach in line with the + sci-fate paper; + Defaults to "auto". + assumption_protein: Parameter estimation assumption for protein. Available options are: + (1) 'ss': pseudo steady state; + Defaults to "ss". + model: String indicating which estimation model will be used. + Available options are: + (1) 'deterministic': The method based on `deterministic` ordinary differential equations; + (2) 'stochastic' or `moment`: The new method from us that is based on `stochastic` master equations; + Note that the `kinetic` model does not need to assume that the `experiment_type` is not `conventional`. As with other + labeling experiments, if you specify the `tkey`, dynamo can also apply the `kinetic` model on `conventional` + scRNA-seq datasets. A "model_selection" model will be supported soon in which alpha, beta and gamma will be + modeled as a function of time. + Defaults to "auto". + est_method: This parameter should be used in conjunction with the `model` parameter. + Available options when the `assumption_mRNA` is 'ss' include: + (1) 'ols': The canonical method or Ordinary Least Squares regression from the seminal RNA velocity paper + based on deterministic ordinary differential equations; + (2) 'rlm': The robust linear models from statsmodels. Robust Regression provides an alternative to OLS + regression by lowering the restrictions on assumptions and dampening the effect of outliers in order + to fit the majority of the data. + (3) 'ransac': RANSAC (RANdom SAmple Consensus) algorithm for robust linear regression. 
RANSAC is an + iterative algorithm for the robust estimation of parameters from a subset of inliers from the + complete dataset. The RANSAC implementation is based on the RANSACRegressor function from the sklearn package. + Note that if `rlm` or `ransac` fails, it will fall back to the `ols` method. In addition, `ols`, + `rlm` and `ransac` can only be used in conjunction with the `deterministic` model. + (4) 'gmm': The new generalized method of moments from us that is based on master equations, similar to + the "moment" model in the excellent scVelo package; + (5) 'negbin': The new method from us that models steady state RNA expression as a negative binomial + distribution, also built upon master equations. + (6) 'auto': dynamo will choose the suitable estimation method based on the `assumption_mRNA`, + `experiment_type` and `model` parameters. + Note that all those methods require using extreme data points (except negbin, which uses all data points) for + estimation. Extreme data points are defined as the data from cells whose expression of unspliced / spliced + or new / total RNA, etc. are in the top or bottom 5%, for example. `linear_regression` only considers the + mean of RNA species (based on the `deterministic` ordinary differential equations) while moment based methods + (`gmm`, `negbin`) consider both first moment (mean) and second moment (uncentered variance) of RNA species + (based on the `stochastic` master equations). + The above methods are all (generalized) linear regression based methods. In addition to the estimated + parameters (including RNA half-life), they also return R-squared (either just for extreme data points + or all data points) as well as the log-likelihood of the fitting, which will be used for transition matrix + and velocity embedding. 
+ Available options when the `assumption_mRNA` is 'kinetic' include: + (1) 'auto': dynamo will choose the suitable estimation method based on the `assumption_mRNA`, + `experiment_type` and `model` parameters. + (2) `twostep`: first for each time point, estimate K (1-e^{-rt}) using the total and new RNA data. Then + use regression via t-np.log(1-K) to get degradation rate gamma. When splicing and labeling data both + exist, replacing new/total with ul/u can be used to estimate beta. Suitable for velocity estimation. + (3) `direct` (default): method that directly uses the kinetic model to estimate rate parameters, + generally not good for velocity estimation. + Under the `kinetic` model, the choice of estimation method is `experiment_type` dependent. For `kinetics` experiments, + dynamo supports methods with or without RNA bursting. Dynamo also adaptively estimates + parameters, based on whether or not the data has splicing information. + Under the `kinetic` assumption, the above methods use non-linear least squares fitting. In addition to the + estimated parameters (including RNA half-life), they also return the log-likelihood of the + fitting, which will be used for transition matrix and velocity embedding. + All `est_method` options use least squares to estimate optimal parameters with a Latin hypercube sampler for initial + sampling. Defaults to "auto". + NTR_vel: whether to use NTR (new/total ratio) velocity for labeling datasets. Defaults to False. + group: the column key/name that identifies the grouping information (for example, clusters that correspond to + different cell types) of cells. This will be used to calculate 1st/2nd moments and covariance for each cell + in each group. It will also enable estimating group-specific (i.e. cell-type specific) kinetic parameters. + Defaults to None. + protein_names: a list of gene names corresponding to the rows of the measured proteins in the `X_protein` of the + `obsm` attribute. 
The names have to be included in the adata.var.index. Defaults to None. + concat_data: whether to concatenate data before estimation. If your data is a list of matrices for each time + point, this needs to be set to True. Defaults to False. + log_unnormalized: whether to log transform the unnormalized data. Defaults to True. + one_shot_method: The method that will be used for estimating kinetic parameters for one-shot experiment data. + (1) the "sci-fate" method directly solves gamma with the first-order decay model; + (2) the "combined" model uses the linear regression under steady state to estimate relative gamma, and then + calculates absolute gamma (degradation rate), beta (splicing rate) and cell-wise alpha (transcription + rate). Defaults to "combined". + fraction_for_deg: whether to use the fraction of labeled RNA instead of the raw labeled RNA to estimate the + degradation parameter. Defaults to False. + re_smooth: whether to re-smooth the adata and also recalculate 1st/2nd moments or covariance. Defaults to False. + sanity_check: whether to perform sanity-check before estimating kinetic parameters and velocity vectors, + currently only applicable to kinetic or degradation metabolic labeling based scRNA-seq data. The basic idea + is that for kinetic (degradation) experiments, the total labelled RNA for each gene should increase + (decrease) over time. If they don't satisfy this criterion, those genes will be ignored during the + estimation. Defaults to False. + del_2nd_moments: whether to remove second moments or covariances. Keeping them (`False`) avoids + recalculating 2nd moments or covariance but it may take a lot of memory when your dataset is big. Set this to + `True` when your data is huge (e.g. > 25,000 cells or so) to reduce the memory footprint. Defaults to + None, in which case the default configured in `DynamoAdataConfig` is used. + cores: number of cores to run the estimation. If cores is set to be > 1, multiprocessing will be used to + parallelize the parameter estimation. 
Currently only applicable to cases where assumption_mRNA is `ss` or cases + where experiment_type is either "one-shot" or "mix_std_stm". Defaults to 1. + tkey: the column key for the labeling time of cells in .obs. Used for labeling based scRNA-seq data. If `tkey` + is None, then `adata.uns["pp"]["tkey"]` will be checked and used if it exists. Defaults to None. + **est_kwargs: Other arguments passed to the fit method (steady state models) or estimation methods (kinetic + models). + + Raises: + ValueError: preprocessing not performed. + Exception: No genes pass the filter. + Exception: Too few valid genes. + + Returns: + An updated AnnData object with estimated kinetic parameters, inferred velocity and estimation related + information included. The estimated kinetic parameters are currently appended to .obs (should move to .obsm with + the key `dynamics` later). Depending on the estimation method, experiment type and whether you applied estimation + for each group via `group`, the number of returned parameters can vary. 
For conventional scRNA-seq + (including cite-seq or other types of protein/RNA coassays) and sometimes metabolic labeling data, the + parameters will at most include: + alpha: Transcription rate + beta: Splicing rate + gamma: Spliced RNA degradation rate + eta: Translation rate (only applicable to RNA/protein coassay) + delta: Protein degradation rate (only applicable to RNA/protein coassay) + alpha_b: intercept of alpha fit + beta_b: intercept of beta fit + gamma_b: intercept of gamma fit + eta_b: intercept of eta fit (only applicable to RNA/protein coassay) + delta_b: intercept of delta fit (only applicable to RNA/protein coassay) + alpha_r2: r-squared for goodness of fit of alpha estimation + beta_r2: r-squared for goodness of fit of beta estimation + gamma_r2: r-squared for goodness of fit of gamma estimation + eta_r2: r-squared for goodness of fit of eta estimation (only applicable to RNA/protein coassay) + delta_r2: r-squared for goodness of fit of delta estimation (only applicable to RNA/protein coassay) + alpha_logLL: loglikelihood of alpha estimation (only applicable to stochastic model) + beta_logLL: loglikelihood of beta estimation (only applicable to stochastic model) + gamma_logLL: loglikelihood of gamma estimation (only applicable to stochastic model) + eta_logLL: loglikelihood of eta estimation (only applicable to stochastic model and RNA/protein coassay) + delta_logLL: loglikelihood of delta estimation (only applicable to stochastic model and RNA/protein + coassay) + uu0: estimated amount of unspliced unlabeled RNA at time 0 (only applicable to data with both splicing + and labeling) + ul0: estimated amount of unspliced labeled RNA at time 0 (only applicable to data with both splicing + and labeling) + su0: estimated amount of spliced unlabeled RNA at time 0 (only applicable to data with both splicing + and labeling) + sl0: estimated amount of spliced labeled RNA at time 0 (only applicable to data with both splicing and + labeling) + U0: 
estimated amount of unspliced RNA (uu + ul) at time 0 + S0: estimated amount of spliced (su + sl) RNA at time 0 + total0: estimated amount of total (U + S) RNA at time 0 + half_life: Spliced mRNA's half-life (log(2) / gamma) + + Note that all data points are used when estimating r2 although only extreme data points are used for + estimating the parameters. This is applicable to all estimation methods, either `linear_regression`, `gmm` or `negbin`. + By default we set the intercept to be 0. + + For metabolic labeling data, the kinetic parameters will at most include: + alpha: Transcription rate (effective - when RNA promoter switching is considered) + beta: Splicing rate + gamma: Spliced RNA degradation rate + a: Switching rate from active promoter state to inactive promoter state + b: Switching rate from inactive promoter state to active promoter state + alpha_a: Transcription rate for active promoter + alpha_i: Transcription rate for inactive promoter + cost: cost of the kinetic parameters estimation + logLL: loglikelihood of kinetic parameters estimation + alpha_r2: r-squared for goodness of fit of alpha estimation + beta_r2: r-squared for goodness of fit of beta estimation + gamma_r2: r-squared for goodness of fit of gamma estimation + uu0: estimated amount of unspliced unlabeled RNA at time 0 (only applicable to data with both splicing + and labeling) + ul0: estimated amount of unspliced labeled RNA at time 0 (only applicable to data with both splicing + and labeling) + su0: estimated amount of spliced unlabeled RNA at time 0 (only applicable to data with both splicing + and labeling) + sl0: estimated amount of spliced labeled RNA at time 0 (only applicable to data with both splicing and + labeling) + u0: estimated amount of unspliced RNA (including uu, ul) at time 0 + s0: estimated amount of spliced (including su, sl) RNA at time 0 + total0: estimated amount of total (including U, S) RNA at time 0 + p_half_life: half-life for unspliced mRNA + half_life: half-life for 
spliced mRNA + + If sanity_check has been performed, a column with key `sanity_check` will also be included, which indicates which + genes pass the filter (`filter_gene_mode`) and sanity check. This is only applicable to kinetic and degradation + metabolic labeling experiments. + + In addition, the `dynamics` key of the .uns attribute corresponds to a dictionary that includes the + following keys: + t: An array like object that indicates the time point of each cell used during parameters estimation + (applicable only to kinetic models) + group: The group that you used to estimate parameters group-wise + X_data: The input that was used for estimating parameters (applicable only to kinetic models) + X_fit_data: The data that was fitted during parameters estimation (applicable only to kinetic models) + asspt_mRNA: Assumption of mRNA dynamics (steady state or kinetic) + experiment_type: Experiment type (either conventional or metabolic labeling based) + normalized: Whether to normalize data + model: Model used for the parameter estimation (either auto, deterministic or stochastic) + has_splicing: Does the adata have splicing? detected automatically + has_labeling: Does the adata have labelling? detected automatically + has_protein: Does the adata have protein information? detected automatically + use_smoothed: Whether to use smoothed data (or first moment, done via local average of neighbor cells) + NTR_vel: Whether to estimate NTR velocity + log_unnormalized: Whether to log transform unnormalized data. 
+ """ + + del_2nd_moments = DynamoAdataConfig.use_default_var_if_none( + del_2nd_moments, DynamoAdataConfig.DYNAMICS_DEL_2ND_MOMENTS_KEY + ) + if "pp" not in adata.uns_keys(): + raise ValueError(f"\nPlease run `dyn.pp.receipe_monocle(adata)` before running this function!") + if tkey is None: + tkey = adata.uns["pp"]["tkey"] + (experiment_type, has_splicing, has_labeling, splicing_labeling, has_protein,) = ( + adata.uns["pp"]["experiment_type"], + adata.uns["pp"]["has_splicing"], + adata.uns["pp"]["has_labeling"], + adata.uns["pp"]["splicing_labeling"], + adata.uns["pp"]["has_protein"], + ) + + X_data, X_fit_data = None, None + filter_list, filter_gene_mode_list = ( + [ + "use_for_pca", + "pass_basic_filter", + "no", + ], + ["final", "basic", "no"], + ) + filter_checker = [i in adata.var.columns for i in filter_list[:2]] + filter_checker.append(True) + filter_id = filter_gene_mode_list.index(filter_gene_mode) + which_filter = np.where(filter_checker[filter_id:])[0][0] + filter_id + + filter_gene_mode = filter_gene_mode_list[which_filter] + + valid_bools = get_valid_bools(adata, filter_gene_mode) + gene_num = sum(valid_bools) + if gene_num == 0: + raise Exception(f"no genes pass filter. Try resetting `filter_gene_mode = 'no'` to use all genes.") + + if model.lower() == "auto": + model = "stochastic" + model_was_auto = True + else: + model_was_auto = False + + if tkey is not None: + if adata.obs[tkey].max() > 60: + main_warning( + "Looks like you are using minutes as the time unit. For the purpose of numeric stability, " + "we recommend using hour as the time unit." + ) + + if model.lower() == "stochastic" or use_smoothed or re_smooth: + M_layers = [i for i in adata.layers.keys() if i.startswith("M_")] + + if len(M_layers) < 2 or re_smooth: + main_info("removing existing M layers:%s..." 
% (str(list(M_layers))), indent_level=2) + for i in M_layers: + del adata.layers[i] + main_info("making adata smooth...", indent_level=2) + + if group is not None and group in adata.obs.columns: + moments(adata, genes=valid_bools, group=group) + else: + moments(adata, genes=valid_bools, group=tkey) + elif tkey is not None: + main_warning( + f"You used tkey {tkey} (or group {group}), but you have calculated local smoothing (1st moment) " + f"for your data before. Please ensure you used the desired tkey or group when the smoothing was " + f"performed. Try setting re_smooth = True if not sure." + ) + + valid_adata = adata[:, valid_bools].copy() + if group is not None and group in adata.obs.columns: + _group = adata.obs[group].unique() + if any(adata.obs[group].value_counts() < 50): + main_warning( + f"Note that some groups have less than 50 cells, this may lead to the velocities for some " + f"cells are all NaN values and cause issues for all downstream analysis. Please try to " + f"coarse-grain cell groupings. 
Cell number for each group are {adata.obs[group].value_counts()}" + ) + + else: + _group = ["_all_cells"] + + for cur_grp_i, cur_grp in enumerate(_group): + if cur_grp == "_all_cells": + kin_param_pre = "" + cur_cells_bools = np.ones(valid_adata.shape[0], dtype=bool) + subset_adata = valid_adata[cur_cells_bools] + else: + kin_param_pre = str(group) + "_" + str(cur_grp) + "_" + cur_cells_bools = (valid_adata.obs[group] == cur_grp).values + subset_adata = valid_adata[cur_cells_bools] + + if model.lower() == "stochastic" or use_smoothed: + moments(subset_adata) + ( + U, + Ul, + S, + Sl, + P, + US, + U2, + S2, + t, + normalized, + ind_for_proteins, + assump_mRNA, + ) = get_data_for_kin_params_estimation( + subset_adata, + has_splicing, + has_labeling, + model, + use_smoothed, + tkey, + protein_names, + log_unnormalized, + NTR_vel, + ) + + valid_bools_ = valid_bools.copy() + if sanity_check and experiment_type.lower() in ["kin", "deg"]: + indices_valid_bools = np.where(valid_bools)[0] + t, L = ( + t.flatten(), + (0 if Ul is None else Ul) + (0 if Sl is None else Sl), + ) + t_uniq = np.unique(t) + + valid_gene_checker = np.zeros(gene_num, dtype=bool) + for L_iter, cur_L in tqdm( + enumerate(L), + desc=f"sanity check of {experiment_type} experiment data:", + ): + cur_L = cur_L.A.flatten() if issparse(cur_L) else cur_L.flatten() + y = strat_mom(cur_L, t, np.nanmean) + slope, _ = fit_linreg(t_uniq, y, intercept=True, r2=False) + valid_gene_checker[L_iter] = ( + True + if (slope > 0 and experiment_type == "kin") or (slope < 0 and experiment_type == "deg") + else False + ) + valid_bools_[indices_valid_bools[~valid_gene_checker]] = False + main_warning(f"filtering {gene_num - valid_gene_checker.sum()} genes after sanity check.") + + if len(valid_bools_) < 5: + raise Exception( + f"After sanity check, you have less than 5 valid genes. Something is wrong about your " + f"metabolic labeling experiment!" 
+ ) + + U, Ul, S, Sl = ( + (None if U is None else U[valid_gene_checker, :]), + (None if Ul is None else Ul[valid_gene_checker, :]), + (None if S is None else S[valid_gene_checker, :]), + (None if Sl is None else Sl[valid_gene_checker, :]), + ) + subset_adata = subset_adata[:, valid_gene_checker] + adata.var[kin_param_pre + "sanity_check"] = valid_bools_ + + if assumption_mRNA.lower() == "auto": + assumption_mRNA = assump_mRNA + if experiment_type.lower() == "conventional": + assumption_mRNA = "ss" + elif experiment_type.lower() in ["mix_pulse_chase", "deg", "kin"]: + assumption_mRNA = "kinetic" + + if model.lower() == "stochastic" and experiment_type.lower() not in [ + "conventional", + "kinetics", + "degradation", + "kin", + "deg", + "one-shot", + ]: + """ + # temporially convert to deterministic model as moment model for mix_std_stm + and other types of labeling experiment is ongoing.""" + + model = "deterministic" + + if model_was_auto and experiment_type.lower() in [ + "kinetic", + "kin", + "degradation", + "deg", + ]: + model = "deterministic" + + if assumption_mRNA.lower() == "ss" or (experiment_type.lower() in ["one-shot", "mix_std_stm"]): + if est_method.lower() == "auto": + est_method = "gmm" if model.lower() == "stochastic" else "ols" + + if experiment_type.lower() == "one-shot": + try: vel_params_df = get_vel_params(subset_adata) beta = vel_params_df.beta if "beta" in vel_params_df.columns else None gamma = vel_params_df.gamma if "gamma" in vel_params_df.columns else None @@ -488,502 +1921,1545 @@ def dynamics( beta, gamma = None, None ss_estimation_kwargs = {"beta": beta, "gamma": gamma} else: - ss_estimation_kwargs = {} + ss_estimation_kwargs = {} + + est = ss_estimation( + U=U.copy() if U is not None else None, + Ul=Ul.copy() if Ul is not None else None, + S=S.copy() if S is not None else None, + Sl=Sl.copy() if Sl is not None else None, + P=P.copy() if P is not None else None, + US=US.copy() if US is not None else None, + S2=S2.copy() if S2 is not 
None else None, + conn=subset_adata.obsp["moments_con"], + t=t, + ind_for_proteins=ind_for_proteins, + model=model, + est_method=est_method, + experiment_type=experiment_type, + assumption_mRNA=assumption_mRNA, + assumption_protein=assumption_protein, + concat_data=concat_data, + cores=cores, + **ss_estimation_kwargs, + ) # U: (unlabeled) unspliced; S: (unlabeled) spliced; U / Ul: old and labeled; U, Ul, S, Sl: uu/ul/su/sl + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + if experiment_type.lower() in ["one-shot", "one_shot"]: + est.fit(one_shot_method=one_shot_method, **est_kwargs) + else: + # experiment_type can be `kin` also and by default use + # conventional method to estimate k but correct for time + est.fit(**est_kwargs) + + alpha, beta, gamma, eta, delta = est.parameters.values() + + U, S = get_U_S_for_velocity_estimation( + subset_adata, + use_smoothed, + has_splicing, + has_labeling, + log_unnormalized, + NTR_vel, + ) + vel = Velocity(estimation=est) + + if experiment_type.lower() in [ + "one_shot", + "one-shot", + "kin", + "mix_std_stm", + ]: + U_, S_ = get_U_S_for_velocity_estimation( + subset_adata, + use_smoothed, + has_splicing, + has_labeling, + log_unnormalized, + not NTR_vel, + ) + + # also get vel_N and vel_T + if NTR_vel: + if has_splicing: + if experiment_type == "kin": + Kc = np.clip(gamma[:, None], 0, 1 - 1e-3) # S - U slope + gamma_ = -(np.log(1 - Kc) / t[None, :]) # actual gamma + + vel_U = U.multiply(csr_matrix(gamma_ / Kc)) - csr_matrix(beta).multiply(U_) # vel.vel_s(U_) + vel_S = vel.vel_s(U_, S_) + + vel_N = (U - csr_matrix(Kc).multiply(U)).multiply(csr_matrix(gamma_ / Kc)) # vel.vel_u(U) + # scale back to true velocity via multiplying "gamma_ / Kc". 
+ vel_T = (U - csr_matrix(Kc).multiply(S)).multiply(csr_matrix(gamma_ / Kc)) + elif experiment_type == "mix_std_stm": + # steady state RNA: u0, stimulation RNA: u_new; + # cell-wise transcription rate under simulation: alpha1 + u0, u_new, alpha1 = solve_alpha_2p_mat( + t0=np.max(t) - t, + t1=t, + alpha0=alpha[0], + beta=beta, + u1=U, + ) + vel_U = alpha1 - csr_matrix(beta[:, None]).multiply(U_) + vel_S = vel.vel_s(U_, S_) + + vel_N = alpha1 - csr_matrix(gamma[:, None]).multiply(u_new) + vel_T = alpha1 - csr_matrix(beta[:, None]).multiply(S) + else: + vel_U = vel.vel_u(U_) + vel_S = vel.vel_s(U_, S_) + vel_N = vel.vel_u(U) + vel_T = vel.vel_s(U, S - U) # need to consider splicing + else: + if experiment_type == "kin": + vel_U = np.nan + vel_S = np.nan + + Kc = np.clip(gamma[:, None], 0, 1 - 1e-3) # S - U slope + gamma_ = -(np.log(1 - Kc) / t[None, :]) # actual gamma + vel_N = (U - csr_matrix(Kc).multiply(U)).multiply(csr_matrix(gamma_ / Kc)) # vel.vel_u(U) + # scale back to true velocity via multiplying "gamma_ / Kc". 
+ vel_T = (U - csr_matrix(Kc).multiply(S)).multiply(csr_matrix(gamma_ / Kc)) + elif experiment_type == "mix_std_stm": + vel_U = np.nan + vel_S = np.nan + + # steady state RNA: u0, stimulation RNA: u_new; + # cell-wise transcription rate under simulation: alpha1 + u0, u_new, alpha1 = solve_alpha_2p_mat( + t0=np.max(t) - t, + t1=t, + alpha0=alpha[0], + beta=gamma, + u1=U, + ) + + vel_N = alpha1 - csr_matrix(gamma[:, None]).multiply(u_new) + vel_T = alpha1 - csr_matrix(gamma[:, None]).multiply(S) + else: + vel_U = np.nan + vel_S = np.nan + vel_N = vel.vel_u(U) + vel_T = vel.vel_u(S) # don't consider splicing + else: + if has_splicing: + if experiment_type == "kin": + Kc = np.clip(gamma[:, None], 0, 1 - 1e-3) # S - U slope + gamma_ = -(np.log(1 - Kc) / t[None, :]) # actual gamma + + vel_U = U_.multiply(csr_matrix(gamma_ / Kc) - csr_matrix(beta).multiply(U)) # vel.vel_u(U) + vel_S = vel.vel_s(U, S) + + vel_N = (U_ - csr_matrix(Kc).multiply(U_)).multiply( + csr_matrix(gamma_ / Kc) + ) # vel.vel_u(U_) + # scale back to true velocity via multiplying "gamma_ / Kc". 
+ vel_T = (U_ - csr_matrix(Kc).multiply(S_)).multiply(csr_matrix(gamma_ / Kc)) + elif experiment_type == "mix_std_stm": + # steady state RNA: u0, stimulation RNA: u_new; + # cell-wise transcription rate under simulation: alpha1 + u0, u_new, alpha1 = solve_alpha_2p_mat( + t0=np.max(t) - t, + t1=t, + alpha0=alpha[0], + beta=beta, + u1=U_, + ) + + vel_U = alpha1 - csr_matrix(beta[:, None]).multiply(U) + vel_S = vel.vel_s(U, S) + + vel_N = alpha1 - csr_matrix(gamma[:, None]).multiply(u_new) + vel_T = alpha1 - csr_matrix(beta[:, None]).multiply(S_) + + else: + vel_U = vel.vel_u(U) + vel_S = vel.vel_s(U, S) + vel_N = vel.vel_u(U_) + vel_T = vel.vel_s(U_, S_ - U_) # need to consider splicing + else: + if experiment_type == "kin": + vel_U = np.nan + vel_S = np.nan + + Kc = np.clip(gamma[:, None], 0, 1 - 1e-3) # S - U slope + gamma_ = -(np.log(1 - Kc) / t[None, :]) # actual gamma + vel_N = (U_ - csr_matrix(Kc).multiply(U_)).multiply( + csr_matrix(gamma_ / Kc) + ) # vel.vel_u(U_) + # scale back to true velocity via multiplying "gamma_ / Kc". 
+ vel_T = (U_ - csr_matrix(Kc).multiply(S_)).multiply(csr_matrix(gamma_ / Kc)) + elif experiment_type == "mix_std_stm": + vel_U = np.nan + vel_S = np.nan + + # steady state RNA: u0, stimulation RNA: u_new; + # cell-wise transcription rate under simulation: alpha1 + u0, u_new, alpha1 = solve_alpha_2p_mat( + t0=np.max(t) - t, + t1=t, + alpha0=alpha[0], + beta=gamma, + u1=U_, + ) + + vel_N = alpha1 - csr_matrix(gamma[:, None]).multiply(u_new) + vel_T = alpha1 - csr_matrix(gamma[:, None]).multiply(S_) + else: + vel_U = np.nan + vel_S = np.nan + vel_N = vel.vel_u(U_) + vel_T = vel.vel_u(S_) # don't consider splicing + else: + vel_U = vel.vel_u(U) + vel_S = vel.vel_s(U, S) + vel_N, vel_T = np.nan, np.nan + + vel_P = vel.vel_p(S, P) + + adata = set_velocity( + adata, + vel_U, + vel_S, + vel_N, + vel_T, + vel_P, + _group, + cur_grp, + cur_cells_bools, + valid_bools_, + ind_for_proteins, + ) + + adata = set_param_ss( + adata, + est, + alpha, + beta, + gamma, + eta, + delta, + experiment_type, + _group, + cur_grp, + kin_param_pre, + valid_bools_, + ind_for_proteins, + ) + + elif assumption_mRNA.lower() == "kinetic": + return_ntr = True if fraction_for_deg and experiment_type.lower() == "deg" else False + + if model_was_auto and experiment_type.lower() == "kin": + model = "mixture" + if est_method == "auto": + est_method = "direct" + data_type = "smoothed" if use_smoothed else "sfs" + + (params, half_life, cost, logLL, param_ranges, cur_X_data, cur_X_fit_data,) = kinetic_model( + subset_adata, + tkey, + model, + est_method, + experiment_type, + has_splicing, + splicing_labeling, + has_switch=True, + param_rngs={}, + data_type=data_type, + return_ntr=return_ntr, + **est_kwargs, + ) + + if type(params) == dict: + alpha = params.pop("alpha") + params = pd.DataFrame(params) + else: + alpha = params.loc[:, "alpha"].values if "alpha" in params.columns else None + + len_t, len_g = len(np.unique(t)), len(_group) + if cur_grp == _group[0]: + if len_g != 1: + # X_data, X_fit_data = 
np.zeros((len_g, adata.n_vars, len_t)), np.zeros((len_g, adata.n_vars,len_t)) + X_data, X_fit_data = [None] * len_g, [None] * len_g + + if len(_group) == 1: + X_data, X_fit_data = cur_X_data, cur_X_fit_data + else: + # X_data[cur_grp_i, :, :], X_fit_data[cur_grp_i, :, :] = cur_X_data, cur_X_fit_data + X_data[cur_grp_i], X_fit_data[cur_grp_i] = ( + cur_X_data, + cur_X_fit_data, + ) + + a, b, alpha_a, alpha_i, beta, gamma = ( + params.loc[:, "a"].values if "a" in params.columns else None, + params.loc[:, "b"].values if "b" in params.columns else None, + params.loc[:, "alpha_a"].values if "alpha_a" in params.columns else None, + params.loc[:, "alpha_i"].values if "alpha_i" in params.columns else None, + params.loc[:, "beta"].values if "beta" in params.columns else None, + params.loc[:, "gamma"].values if "gamma" in params.columns else None, + ) + if alpha is None: + alpha = fbar(a, b, alpha_a, 0) if alpha_i is None else fbar(a, b, alpha_a, alpha_i) + all_kinetic_params = [ + "a", + "b", + "alpha_a", + "alpha_i", + "alpha", + "beta", + "gamma", + ] - est = ss_estimation( - U=U.copy() if U is not None else None, - Ul=Ul.copy() if Ul is not None else None, - S=S.copy() if S is not None else None, - Sl=Sl.copy() if Sl is not None else None, - P=P.copy() if P is not None else None, - US=US.copy() if US is not None else None, - S2=S2.copy() if S2 is not None else None, - conn=subset_adata.obsp["moments_con"], - t=t, - ind_for_proteins=ind_for_proteins, - model=model, - est_method=est_method, - experiment_type=experiment_type, - assumption_mRNA=assumption_mRNA, - assumption_protein=assumption_protein, - concat_data=concat_data, - cores=cores, - **ss_estimation_kwargs, - ) # U: (unlabeled) unspliced; S: (unlabeled) spliced; U / Ul: old and labeled; U, Ul, S, Sl: uu/ul/su/sl + extra_params = params.loc[:, params.columns.difference(all_kinetic_params)] + # if alpha = None, set alpha to be U; N - gamma R + params = {"alpha": alpha, "beta": beta, "gamma": gamma, "t": t} + vel = 
Velocity(**params) + # Fix below: + U, S = get_U_S_for_velocity_estimation( + subset_adata, + use_smoothed, + has_splicing, + has_labeling, + log_unnormalized, + NTR_vel, + ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") + U_, S_ = get_U_S_for_velocity_estimation( + subset_adata, + use_smoothed, + has_splicing, + has_labeling, + log_unnormalized, + not NTR_vel, + ) - if experiment_type.lower() in ["one-shot", "one_shot"]: - est.fit(one_shot_method=one_shot_method, **est_kwargs) + # also get vel_N and vel_T + if NTR_vel: + if has_splicing: + if experiment_type == "kin": + vel_U = vel.vel_u(U_) + vel_S = vel.vel_s(U_, S_) + vel.parameters["beta"] = gamma + vel_N = vel.vel_u(U) + vel_T = vel.vel_u(S) # no need to consider splicing + elif experiment_type == "deg": + if splicing_labeling: + vel_U = np.nan + vel_S = vel.vel_s(U_, S_) + vel_N = np.nan + vel_T = np.nan + else: + vel_U = np.nan + vel_S = vel.vel_s(U_, S_) + vel_N = np.nan + vel_T = np.nan + elif experiment_type in ["mix_kin_deg", "mix_pulse_chase"]: + vel_U = vel.vel_u(U_, repeat=True) + vel_S = vel.vel_s(U_, S_) + vel.parameters["beta"] = gamma + vel_N = vel.vel_u(U, repeat=True) + vel_T = vel.vel_u(S, repeat=True) # no need to consider splicing else: - # experiment_type can be `kin` also and by default use - # conventional method to estimate k but correct for time - est.fit(**est_kwargs) + if experiment_type == "kin": + vel_U = np.nan + vel_S = np.nan + + # calculate cell-wise alpha, if est_method is twostep, this can be skipped + alpha_ = one_shot_alpha_matrix(U, gamma, t) + + vel.parameters["alpha"] = alpha_ + + vel_N = vel.vel_u(U) + vel_T = vel.vel_u(S) # don't consider splicing + elif experiment_type == "deg": + vel_U = np.nan + vel_S = np.nan + vel_N = np.nan + vel_T = np.nan + elif experiment_type in ["mix_kin_deg", "mix_pulse_chase"]: + vel_U = np.nan + vel_S = np.nan + vel_N = vel.vel_u(U, repeat=True) + vel_T = vel.vel_u(S) # don't consider splicing + else: + if 
has_splicing: + if experiment_type == "kin": + vel_U = vel.vel_u(U) + vel_S = vel.vel_s(U, S) + vel.parameters["beta"] = gamma + vel_N = vel.vel_u(U_) + vel_T = vel.vel_u(S_) # no need to consider splicing + elif experiment_type == "deg": + if splicing_labeling: + vel_U = np.nan + vel_S = vel.vel_s(U, S) + vel_N = np.nan + vel_T = np.nan + else: + vel_U = np.nan + vel_S = vel.vel_s(U, S) + vel_N = np.nan + vel_T = np.nan + elif experiment_type in ["mix_kin_deg", "mix_pulse_chase"]: + vel_U = vel.vel_u(U, repeat=True) + vel_S = vel.vel_s(U, S) + vel.parameters["beta"] = gamma + vel_N = vel.vel_u(U_, repeat=True) + vel_T = vel.vel_u(S_, repeat=True) # no need to consider splicing + else: + if experiment_type == "kin": + vel_U = np.nan + vel_S = np.nan + + # calculate cell-wise alpha, if est_method is twostep, this can be skipped + alpha_ = one_shot_alpha_matrix(U_, gamma, t) + + vel.parameters["alpha"] = alpha_ + + vel_N = vel.vel_u(U_) + vel_T = vel.vel_u(S_) # need to consider splicing + elif experiment_type == "deg": + vel_U = np.nan + vel_S = np.nan + vel_N = np.nan + vel_T = np.nan + elif experiment_type in ["mix_kin_deg", "mix_pulse_chase"]: + vel_U = np.nan + vel_S = np.nan + vel_N = vel.vel_u(U_, repeat=True) + vel_T = vel.vel_u(S_, repeat=True) # don't consider splicing + + vel_P = vel.vel_p(S, P) + + adata = set_velocity( + adata, + vel_U, + vel_S, + vel_N, + vel_T, + vel_P, + _group, + cur_grp, + cur_cells_bools, + valid_bools_, + ind_for_proteins, + ) + + adata = set_param_kinetic( + adata, + alpha, + a, + b, + alpha_a, + alpha_i, + beta, + gamma, + cost, + logLL, + kin_param_pre, + extra_params, + _group, + cur_grp, + cur_cells_bools, + valid_bools_, + ) + # add protein related parameters in the moment model below: + elif model.lower() == "model_selection": + main_warning("Not implemented yet.") + + if group is not None and group in adata.obs[group]: + uns_key = group + "_dynamics" + else: + uns_key = "dynamics" + + if sanity_check and experiment_type in 
["kin", "deg"]: + sanity_check_cols = adata.var.columns.str.endswith("sanity_check") + adata.var["use_for_dynamics"] = adata.var.loc[:, sanity_check_cols].sum(1).astype(bool) + else: + adata.var["use_for_dynamics"] = False + adata.var.loc[valid_bools, "use_for_dynamics"] = True + + adata.uns[uns_key] = { + "filter_gene_mode": filter_gene_mode, + "t": t, + "group": group, + "X_data": X_data, + "X_fit_data": X_fit_data, + "asspt_mRNA": assumption_mRNA, + "experiment_type": experiment_type, + "normalized": normalized, + "model": model, + "est_method": est_method, + "has_splicing": has_splicing, + "has_labeling": has_labeling, + "splicing_labeling": splicing_labeling, + "has_protein": has_protein, + "use_smoothed": use_smoothed, + "NTR_vel": NTR_vel, + "log_unnormalized": log_unnormalized, + "fraction_for_deg": fraction_for_deg, + } + + if del_2nd_moments: + remove_2nd_moments(adata) + + return adata + + +class KineticEstimation: + """The clss to estimate the parameters required for velocity estimation when the mRNA assumption is 'kinetic'.""" + def __init__( + self, + subset_adata: AnnData, + tkey: str, + model: Literal["auto", "deterministic", "stochastic"], + est_method: Literal["twostep", "direct", "storm-csp", "storm-cszip", "storm-icsp"], + experiment_type: str, + has_splicing: bool, + splicing_labeling: bool, + has_switch: bool, + param_rngs: Dict[str, List[int]], + data_type: Literal["smoothed", "sfs"] = "sfs", + return_ntr: bool = False, + **est_kwargs, + ): + """Constructor. + + Args: + subset_adata: an AnnData object with invalid genes trimmed. + tkey: the column key for the labeling time of cells in .obs. Used for labeling based scRNA-seq data. If `tkey` + is None, then `adata.uns["pp"]["tkey"]` will be checked and used if exists. + model: String indicates which estimation model will be used. 
+ Available options are: + (1) 'deterministic': The method based on `deterministic` ordinary differential equations; + (2) 'stochastic' or `moment`: The new method from us that is based on `stochastic` master equations; + Note that `kinetic` model doesn't need to assume the `experiment_type` is not `conventional`. As other + labeling experiments, if you specify the `tkey`, dynamo can also apply `kinetic` model on `conventional` + scRNA-seq datasets. A "model_selection" model will be supported soon in which alpha, beta and gamma will be + modeled as a function of time. + est_method: Available options when the `assumption_mRNA` is 'kinetic' include: + (1) 'auto': dynamo will choose the suitable estimation method based on the `assumption_mRNA`, + `experiment_type` and `model` parameter. + (2) `twostep`: first for each time point, estimate K (1-e^{-rt}) using the total and new RNA data. Then + use regression via t-np.log(1-K) to get degradation rate gamma. When splicing and labeling data both + exist, replacing new/total with ul/u can be used to estimate beta. Suitable for velocity estimation. + (3) `direct` (default): method that directly uses the kinetic model to estimate rate parameters, + generally not good for velocity estimation. + Under `kinetic` model, choosing estimation is `experiment_type` dependent. For `kinetics` experiments, + dynamo supposes methods including RNA bursting or without RNA bursting. Dynamo also adaptively estimates + parameters, based on whether the data has splicing or without splicing. + Under `kinetic` assumption, the above method uses non-linear least square fitting. In order to return + estimated parameters (including RNA half-life), it additionally returns the log-likelihood of the + fitting, which will be used for transition matrix and velocity embedding. + All `est_method` uses least square to estimate optimal parameters with latin cubic sampler for initial + sampling. + experiment_type: the experiment type of the data. 
+ has_splicing: whether the object contains unspliced and spliced data + splicing_labeling: whether the object contains both splicing and labelling data + has_switch: whether there should be switch for stochastic model. + param_rngs: the range set for each parameter. + data_type: the data type, could be "smoothed" or "sfs". Defaults to "sfs". + return_ntr: whether to deal with new/total ratio. Defaults to False. + est_kwargs: additional keyword arguments of fitting function. + + Returns: + A tuple (Estm_df, half_life, cost, logLL, _param_ranges, X_data, X_fit_data), where Estm_df contains the + parameters required for mRNA velocity calculation, half_life is for half-life of spliced mRNA, cost is for the + cost of kinetic parameters estimation, logLL is for loglikelihood of kinetic parameters estimation, + _param_ranges is for the intended range of parameter estimation, X_data is for the data used for parameter + estimation, and X_fit_data is for the data that get fitted during parameter estimation. 
+ """ + self.subset_adata = subset_adata + self.tkey = tkey + self.model = model + self.est_method = est_method + self.experiment_type = experiment_type + self.has_splicing = has_splicing + self.splicing_labeling = splicing_labeling + self.has_switch = has_switch + self.param_rngs = param_rngs + self.data_type = data_type + self.return_ntr = return_ntr + self.est_kwargs = est_kwargs + self.time = subset_adata.obs[tkey].astype("float").values + + def fit_twostep_kinetics(self): + """Fit the input data to estimate parameters for kinetics experiment type with two-step method.""" + if self.has_splicing: + layers = ( + ["M_u", "M_s", "M_t", "M_n"] + if ("M_u" in self.subset_adata.layers.keys() and self.data_type == "smoothed") + else ["X_u", "X_s", "X_t", "X_n"] + ) + U, S, Total, New = ( + self.subset_adata.layers[layers[0]].T, + self.subset_adata.layers[layers[1]].T, + self.subset_adata.layers[layers[2]].T, + self.subset_adata.layers[layers[3]].T, + ) + US, S2 = ( + self.subset_adata.layers["M_us"].T, + self.subset_adata.layers["M_ss"].T, + ) + # gamma, gamma_r2 = lin_reg_gamma_synthesis(U, Ul, time, perc_right=100) + ( + gamma_k, + gamma_b, + gamma_all_r2, + gamma_all_logLL, + ) = fit_slope_stochastic(S, U, US, S2, perc_left=None, perc_right=100) + ( + gamma, + gamma_r2, + X_data, + mean_R2, + K_fit, + ) = lin_reg_gamma_synthesis(Total, New, self.time, perc_right=100) + + k = 1 - np.exp(-gamma[:, None] * self.time[None, :]) + beta = gamma / gamma_k # gamma_k = gamma / beta + + Estm_df = { + "alpha": csr_matrix(gamma[:, None]).multiply(New).multiply(1 / k), + "beta": beta, + "gamma_k": gamma_k, + "gamma_b": gamma_b, + "gamma_k_r2": gamma_all_r2, + "gamma_logLL": gamma_all_logLL, + "gamma": gamma, + "gamma_r2": gamma_r2, + "mean_R2": mean_R2, + } + half_life = np.log(2) / gamma + cost, logLL, _param_ranges, X_data, X_fit_data = ( + None, + None, + None, + X_data, + K_fit, + ) + + return ( + Estm_df, + half_life, + cost, + logLL, + _param_ranges, + X_data, + X_fit_data, 
+ ) + else: + layers = ( + ["M_t", "M_n"] + if ("M_t" in self.subset_adata.layers.keys() and self.data_type == "smoothed") + else ["X_t", "X_n"] + ) + Total, New = ( + self.subset_adata.layers[layers[0]].T, + self.subset_adata.layers[layers[1]].T, + ) + ( + gamma, + gamma_r2, + X_data, + mean_R2, + K_fit, + ) = lin_reg_gamma_synthesis(Total, New, self.time, perc_right=100) + + k = 1 - np.exp(-gamma[:, None] * self.time[None, :]) + Estm_df = { + "alpha": csr_matrix(gamma[:, None]).multiply(New).multiply(1 / k), + "gamma": gamma, + "gamma_k": gamma, # required for phase_potrait + "gamma_r2": gamma_r2, + "mean_R2": mean_R2, + } + half_life = np.log(2) / gamma + cost, logLL, _param_ranges, X_data, X_fit_data = ( + None, + None, + None, + X_data, + K_fit, + ) + + return ( + Estm_df, + half_life, + cost, + logLL, + _param_ranges, + X_data, + X_fit_data, + ) + def fit_storm(self): + """Fit the input data to estimate parameters for kinetics experiment type with storm method.""" + if self.has_splicing: + # Initialization based on the steady-state assumption + layers_smoothed = ["M_u", "M_s", "M_t", "M_n"] + U_smoothed, S_smoothed, Total_smoothed, New_smoothed = ( + self.subset_adata.layers[layers_smoothed[0]].T, + self.subset_adata.layers[layers_smoothed[1]].T, + self.subset_adata.layers[layers_smoothed[2]].T, + self.subset_adata.layers[layers_smoothed[3]].T, + ) + + US_smoothed, S2_smoothed = ( + self.subset_adata.layers["M_us"].T, + self.subset_adata.layers["M_ss"].T, + ) + (gamma_k, _, _, _,) = fit_slope_stochastic(S_smoothed, U_smoothed, US_smoothed, S2_smoothed, + perc_left=None, perc_right=5) + (gamma_init, _, _, _, _) = lin_reg_gamma_synthesis(Total_smoothed, New_smoothed, self.time, perc_right=5) + beta_init = gamma_init / gamma_k # gamma_k = gamma / beta + + # Read raw counts + layers_raw = ["ul", "sl"] + UL_raw, SL_raw = ( + self.subset_adata.layers[layers_raw[0]].T, + self.subset_adata.layers[layers_raw[1]].T, + ) - alpha, beta, gamma, eta, delta = 
est.parameters.values() + # Read smoothed values based CSP type distribution for cell-specific parameter inference + UL_smoothed_CSP, SL_smoothed_CSP = ( + self.subset_adata.layers['M_CSP_ul'].T, + self.subset_adata.layers['M_CSP_sl'].T, + ) - U, S = get_U_S_for_velocity_estimation( - subset_adata, - use_smoothed, - has_splicing, - has_labeling, - log_unnormalized, - NTR_vel, + # Parameters inference based on maximum likelihood estimation + cell_total = self.subset_adata.obs['initial_cell_size'].astype("float").values + # Independent cell-specific Poisson + (gamma_s, gamma_r2, beta, gamma_t, gamma_r2_raw, alpha) = storm.mle_independent_cell_specific_poisson \ + (UL_raw, SL_raw, self.time, gamma_init, beta_init, cell_total, Total_smoothed, S_smoothed) + gamma_k = gamma_s / beta + gamma_b = np.zeros_like(gamma_k) + + # Cell specific parameters (fixed gamma_s) + alpha, beta = storm.cell_specific_alpha_beta(UL_smoothed_CSP, SL_smoothed_CSP, self.time, gamma_s, beta) + + # # Cell specific parameters(fixed gamma_t) + # k = 1 - np.exp(-gamma_t[:, None] * time[None, :]) + # alpha = csr_matrix(gamma_t[:, None]).multiply(UL_smoothed_CSP+SL_smoothed_CSP).multiply(1 / k) + + Estm_df = { + "alpha": alpha, + "beta": beta, + "gamma_k": gamma_k, + "gamma_b": gamma_b, + # "gamma_k_r2": gamma_all_r2, + # "gamma_logLL": gamma_all_logLL, + "gamma": gamma_s, + "gamma_r2": gamma_r2, + # "mean_R2": mean_R2, + "gamma_t": gamma_t, + "gamma_r2_raw": gamma_r2_raw, + } + half_life = np.log(2) / gamma_s + cost, logLL, _param_ranges, X_data, X_fit_data = ( + None, + None, + None, + None, + None, ) - vel = Velocity(estimation=est) - if experiment_type.lower() in [ - "one_shot", - "one-shot", - "kin", - "mix_std_stm", - ]: - U_, S_ = get_U_S_for_velocity_estimation( - subset_adata, - use_smoothed, - has_splicing, - has_labeling, - log_unnormalized, - not NTR_vel, - ) + return ( + Estm_df, + half_life, + cost, + logLL, + _param_ranges, + X_data, + X_fit_data, + ) + else: + # Initialization based 
on the steady-state assumption + layers_smoothed = ["M_t", "M_n"] + Total_smoothed, New_smoothed = ( + self.subset_adata.layers[layers_smoothed[0]].T, + self.subset_adata.layers[layers_smoothed[1]].T, + ) + (gamma_init, _, _, _, _,) = lin_reg_gamma_synthesis(Total_smoothed, New_smoothed, self.time, + perc_right=5) + + # Read raw counts + layers_raw = ["total", "new"] + Total_raw, New_raw = ( + self.subset_adata.layers[layers_raw[0]].T, + self.subset_adata.layers[layers_raw[1]].T, + ) - # also get vel_N and vel_T - if NTR_vel: - if has_splicing: - if experiment_type == "kin": - Kc = np.clip(gamma[:, None], 0, 1 - 1e-3) # S - U slope - gamma_ = -(np.log(1 - Kc) / t[None, :]) # actual gamma + # Read smoothed values based CSP type distribution for cell-specific parameter inference + layers_smoothed_CSP = ["M_CSP_t", "M_CSP_n"] + Total_smoothed_CSP, New_smoothed_CSP = ( + self.subset_adata.layers[layers_smoothed_CSP[0]].T, + self.subset_adata.layers[layers_smoothed_CSP[1]].T, + ) - vel_U = U.multiply(csr_matrix(gamma_ / Kc)) - csr_matrix(beta).multiply(U_) # vel.vel_s(U_) - vel_S = vel.vel_s(U_, S_) + # Parameters inference based on maximum likelihood estimation + cell_total = self.subset_adata.obs['initial_cell_size'].astype("float").values - vel_N = (U - csr_matrix(Kc).multiply(U)).multiply(csr_matrix(gamma_ / Kc)) # vel.vel_u(U) - # scale back to true velocity via multiplying "gamma_ / Kc". 
- vel_T = (U - csr_matrix(Kc).multiply(S)).multiply(csr_matrix(gamma_ / Kc)) - elif experiment_type == "mix_std_stm": - # steady state RNA: u0, stimulation RNA: u_new; - # cell-wise transcription rate under simulation: alpha1 - u0, u_new, alpha1 = solve_alpha_2p_mat( - t0=np.max(t) - t, - t1=t, - alpha0=alpha[0], - beta=beta, - u1=U, - ) - vel_U = alpha1 - csr_matrix(beta[:, None]).multiply(U_) - vel_S = vel.vel_s(U_, S_) + if "storm-csp" == self.est_method: + gamma, gamma_r2, gamma_r2_raw, alpha = storm.mle_cell_specific_poisson(New_raw, self.time, + gamma_init, cell_total) + elif "storm-cszip" == self.est_method: + gamma, prob_off, gamma_r2, gamma_r2_raw, alpha = storm.mle_cell_specific_zero_inflated_poisson( + New_raw, self.time, gamma_init, cell_total) + alpha = alpha * (1 - prob_off) # gene-wise alpha + else: + raise NotImplementedError("This method has not been implemented.") + + k = 1 - np.exp(-gamma[:, None] * self.time[None, :]) + alpha = csr_matrix(gamma[:, None]).multiply(New_smoothed_CSP).multiply(1 / k) # gene-cell-wise alpha + + Estm_df = { + "alpha": alpha, + "gamma": gamma, + "gamma_k": gamma, # required for phase_potrait + "gamma_r2": gamma_r2, + "gamma_r2_raw": gamma_r2_raw, + # "mean_R2": mean_R2, + "prob_off": prob_off if "cszip" in self.est_method else None + } + half_life = np.log(2) / gamma + cost, logLL, _param_ranges, X_data, X_fit_data = ( + None, + None, + None, + None, # X_data, + None, # K_fit, + ) - vel_N = alpha1 - csr_matrix(gamma[:, None]).multiply(u_new) - vel_T = alpha1 - csr_matrix(beta[:, None]).multiply(S) - else: - vel_U = vel.vel_u(U_) - vel_S = vel.vel_s(U_, S_) - vel_N = vel.vel_u(U) - vel_T = vel.vel_s(U, S - U) # need to consider splicing - else: - if experiment_type == "kin": - vel_U = np.nan - vel_S = np.nan + return ( + Estm_df, + half_life, + cost, + logLL, + _param_ranges, + X_data, + X_fit_data, + ) - Kc = np.clip(gamma[:, None], 0, 1 - 1e-3) # S - U slope - gamma_ = -(np.log(1 - Kc) / t[None, :]) # actual gamma - 
vel_N = (U - csr_matrix(Kc).multiply(U)).multiply(csr_matrix(gamma_ / Kc)) # vel.vel_u(U) - # scale back to true velocity via multiplying "gamma_ / Kc". - vel_T = (U - csr_matrix(Kc).multiply(S)).multiply(csr_matrix(gamma_ / Kc)) - elif experiment_type == "mix_std_stm": - vel_U = np.nan - vel_S = np.nan + def fit_direct_kinetics(self): + """Fit the input data to estimate parameters for kinetics experiment type with direct method.""" + if self.has_splicing and self.splicing_labeling: + layers = ( + ["M_ul", "M_sl", "M_uu", "M_su"] + if ("M_ul" in self.subset_adata.layers.keys() and self.data_type == "smoothed") + else ["X_ul", "X_sl", "X_uu", "X_su"] + ) - # steady state RNA: u0, stimulation RNA: u_new; - # cell-wise transcription rate under simulation: alpha1 - u0, u_new, alpha1 = solve_alpha_2p_mat( - t0=np.max(t) - t, - t1=t, - alpha0=alpha[0], - beta=gamma, - u1=U, - ) + if self.model.lower() in ["deterministic", "stochastic"]: + layer_u = "M_ul" if ("M_ul" in self.subset_adata.layers.keys() and self.data_type == "smoothed") else "X_ul" + layer_s = "M_sl" if ("M_ul" in self.subset_adata.layers.keys() and self.data_type == "smoothed") else "X_sl" - vel_N = alpha1 - csr_matrix(gamma[:, None]).multiply(u_new) - vel_T = alpha1 - csr_matrix(gamma[:, None]).multiply(S) - else: - vel_U = np.nan - vel_S = np.nan - vel_N = vel.vel_u(U) - vel_T = vel.vel_u(S) # don't consider splicing - else: - if has_splicing: - if experiment_type == "kin": - Kc = np.clip(gamma[:, None], 0, 1 - 1e-3) # S - U slope - gamma_ = -(np.log(1 - Kc) / t[None, :]) # actual gamma + X, X_raw = prepare_data_has_splicing( + self.subset_adata, + self.subset_adata.var.index, + self.time, + layer_u=layer_u, + layer_s=layer_s, + total_layers=layers, + ) + elif self.model.startswith("mixture"): + X, _, X_raw = prepare_data_deterministic( + self.subset_adata, + self.subset_adata.var.index, + self.time, + layers=layers, + total_layers=layers, + ) - vel_U = U_.multiply(csr_matrix(gamma_ / Kc) - 
csr_matrix(beta).multiply(U)) # vel.vel_u(U) - vel_S = vel.vel_s(U, S) + if self.model.lower() == "deterministic": + X = [X[i][[0, 1], :] for i in range(len(X))] + _param_ranges = { + "alpha": [0, 1000], + "beta": [0, 1000], + "gamma": [0, 1000], + } + x0 = {"u0": [0, 1000], "s0": [0, 1000]} + Est, _ = Estimation_DeterministicKin, Deterministic + elif self.model.lower() == "stochastic": + x0 = { + "u0": [0, 1000], + "s0": [0, 1000], + "uu0": [0, 1000], + "ss0": [0, 1000], + "us0": [0, 1000], + } - vel_N = (U_ - csr_matrix(Kc).multiply(U_)).multiply( - csr_matrix(gamma_ / Kc) - ) # vel.vel_u(U_) - # scale back to true velocity via multiplying "gamma_ / Kc". - vel_T = (U_ - csr_matrix(Kc).multiply(S_)).multiply(csr_matrix(gamma_ / Kc)) - elif experiment_type == "mix_std_stm": - # steady state RNA: u0, stimulation RNA: u_new; - # cell-wise transcription rate under simulation: alpha1 - u0, u_new, alpha1 = solve_alpha_2p_mat( - t0=np.max(t) - t, - t1=t, - alpha0=alpha[0], - beta=beta, - u1=U_, - ) + if self.has_switch: + _param_ranges = { + "a": [0, 1000], + "b": [0, 1000], + "alpha_a": [0, 1000], + "alpha_i": 0, + "beta": [0, 1000], + "gamma": [0, 1000], + } + Est, _ = Estimation_MomentKin, Moments + else: + _param_ranges = { + "alpha": [0, 1000], + "beta": [0, 1000], + "gamma": [0, 1000], + } - vel_U = alpha1 - csr_matrix(beta[:, None]).multiply(U) - vel_S = vel.vel_s(U, S) + Est, _ = ( + Estimation_MomentKinNoSwitch, + Moments_NoSwitching, + ) + elif self.model.lower() == "mixture": + _param_ranges = { + "alpha": [0, 1000], + "alpha_2": [0, 0], + "beta": [0, 1000], + "gamma": [0, 1000], + } + x0 = { + "ul0": [0, 0], + "sl0": [0, 0], + "uu0": [0, 1000], + "su0": [0, 1000], + } - vel_N = alpha1 - csr_matrix(gamma[:, None]).multiply(u_new) - vel_T = alpha1 - csr_matrix(beta[:, None]).multiply(S_) + Est = Mixture_KinDeg_NoSwitching(Deterministic(), Deterministic()) + elif self.model.lower() == "mixture_deterministic_stochastic": + X, X_raw = 
prepare_data_mix_has_splicing( + self.subset_adata, + self.subset_adata.var.index, + self.time, + layer_u=layers[2], + layer_s=layers[3], + layer_ul=layers[0], + layer_sl=layers[1], + total_layers=layers, + mix_model_indices=[0, 1, 5, 6, 7, 8, 9], + ) - else: - vel_U = vel.vel_u(U) - vel_S = vel.vel_s(U, S) - vel_N = vel.vel_u(U_) - vel_T = vel.vel_s(U_, S_ - U_) # need to consider splicing - else: - if experiment_type == "kin": - vel_U = np.nan - vel_S = np.nan + _param_ranges = { + "alpha": [0, 1000], + "alpha_2": [0, 0], + "beta": [0, 1000], + "gamma": [0, 1000], + } + x0 = { + "ul0": [0, 0], + "sl0": [0, 0], + "u0": [0, 1000], + "s0": [0, 1000], + "uu0": [0, 1000], + "ss0": [0, 1000], + "us0": [0, 1000], + } + Est = Mixture_KinDeg_NoSwitching(Deterministic(), Moments_NoSwitching()) + elif self.model.lower() == "mixture_stochastic_stochastic": + _param_ranges = { + "alpha": [0, 1000], + "alpha_2": [0, 0], + "beta": [0, 1000], + "gamma": [0, 1000], + } + X, X_raw = prepare_data_mix_has_splicing( + self.subset_adata, + self.subset_adata.var.index, + self.time, + layer_u=layers[2], + layer_s=layers[3], + layer_ul=layers[0], + layer_sl=layers[1], + total_layers=layers, + mix_model_indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + ) + x0 = { + "ul0": [0, 1000], + "sl0": [0, 1000], + "ul_ul0": [0, 1000], + "sl_sl0": [0, 1000], + "ul_sl0": [0, 1000], + "u0": [0, 1000], + "s0": [0, 1000], + "uu0": [0, 1000], + "ss0": [0, 1000], + "us0": [0, 1000], + } + Est = Mixture_KinDeg_NoSwitching(Moments_NoSwitching(), Moments_NoSwitching()) + else: + raise NotImplementedError( + f"model {self.model} with kinetic assumption is not implemented. 
" + f"current supported models for kinetics experiments include: stochastic, deterministic, mixture," + f"mixture_deterministic_stochastic or mixture_stochastic_stochastic" + ) + else: + total_layer = "M_t" if ("M_t" in self.subset_adata.layers.keys() and self.data_type == "smoothed") else "X_total" + + if self.model.lower() in ["deterministic", "stochastic"]: + layer = "M_n" if ("M_n" in self.subset_adata.layers.keys() and self.data_type == "smoothed") else "X_new" + X, X_raw = prepare_data_no_splicing( + self.subset_adata, + self.subset_adata.var.index, + self.time, + layer=layer, + total_layer=total_layer, + ) + elif self.model.lower().startswith("mixture"): + layers = ( + ["M_n", "M_t"] + if ("M_n" in self.subset_adata.layers.keys() and self.data_type == "smoothed") + else ["X_new", "X_total"] + ) - Kc = np.clip(gamma[:, None], 0, 1 - 1e-3) # S - U slope - gamma_ = -(np.log(1 - Kc) / t[None, :]) # actual gamma - vel_N = (U_ - csr_matrix(Kc).multiply(U_)).multiply( - csr_matrix(gamma_ / Kc) - ) # vel.vel_u(U_) - # scale back to true velocity via multiplying "gamma_ / Kc". 
- vel_T = (U_ - csr_matrix(Kc).multiply(S_)).multiply(csr_matrix(gamma_ / Kc)) - elif experiment_type == "mix_std_stm": - vel_U = np.nan - vel_S = np.nan + X, _, X_raw = prepare_data_deterministic( + self.subset_adata, + self.subset_adata.var.index, + self.time, + layers=layers, + total_layers=total_layer, + ) - # steady state RNA: u0, stimulation RNA: u_new; - # cell-wise transcription rate under simulation: alpha1 - u0, u_new, alpha1 = solve_alpha_2p_mat( - t0=np.max(t) - t, - t1=t, - alpha0=alpha[0], - beta=gamma, - u1=U_, - ) + if self.model.lower() == "deterministic": + X = [X[i][0, :] for i in range(len(X))] + _param_ranges = { + "alpha": [0, 1000], + "gamma": [0, 1000], + } + x0 = {"u0": [0, 1000]} + Est, _ = ( + Estimation_DeterministicKinNosp, + Deterministic_NoSplicing, + ) + elif self.model.lower() == "stochastic": + x0 = { + "u0": [0, 1000], + "uu0": [0, 1000], + } + if self.has_switch: + _param_ranges = { + "a": [0, 1000], + "b": [0, 1000], + "alpha_a": [0, 1000], + "alpha_i": 0, + "gamma": [0, 1000], + } + Est, _ = Estimation_MomentKinNosp, Moments_Nosplicing + else: + _param_ranges = { + "alpha": [0, 1000], + "gamma": [0, 1000], + } + Est, _ = ( + Estimation_MomentKinNoSwitchNoSplicing, + Moments_NoSwitchingNoSplicing, + ) + elif self.model.lower() == "mixture": + _param_ranges = { + "alpha": [0, 1000], + "alpha_2": [0, 0], + "gamma": [0, 1000], + } + x0 = {"u0": [0, 0], "o0": [0, 1000]} + Est = Mixture_KinDeg_NoSwitching(Deterministic_NoSplicing(), Deterministic_NoSplicing()) + elif self.model.lower() == "mixture_deterministic_stochastic": + X, X_raw = prepare_data_mix_no_splicing( + self.subset_adata, + self.subset_adata.var.index, + self.time, + layer_n=layers[0], + layer_t=layers[1], + total_layer=total_layer, + mix_model_indices=[0, 2, 3], + ) - vel_N = alpha1 - csr_matrix(gamma[:, None]).multiply(u_new) - vel_T = alpha1 - csr_matrix(gamma[:, None]).multiply(S_) - else: - vel_U = np.nan - vel_S = np.nan - vel_N = vel.vel_u(U_) - vel_T = 
vel.vel_u(S_) # don't consider splicing + _param_ranges = { + "alpha": [0, 1000], + "alpha_2": [0, 0], + "gamma": [0, 1000], + } + x0 = {"u0": [0, 1000], "o0": [0, 1000], "oo0": [0, 1000]} + Est = Mixture_KinDeg_NoSwitching( + Deterministic_NoSplicing(), + Moments_NoSwitchingNoSplicing(), + ) + elif self.model.lower() == "mixture_stochastic_stochastic": + X, X_raw = prepare_data_mix_no_splicing( + self.subset_adata, + self.subset_adata.var.index, + self.time, + layer_n=layers[0], + layer_t=layers[1], + total_layer=total_layer, + mix_model_indices=[0, 1, 2, 3], + ) + + _param_ranges = { + "alpha": [0, 1000], + "alpha_2": [0, 0], + "gamma": [0, 1000], + } + x0 = { + "u0": [0, 1000], + "uu0": [0, 1000], + "o0": [0, 1000], + "oo0": [0, 1000], + } + Est = Mixture_KinDeg_NoSwitching( + Moments_NoSwitchingNoSplicing(), + Moments_NoSwitchingNoSplicing(), + ) else: - vel_U = vel.vel_u(U) - vel_S = vel.vel_s(U, S) - vel_N, vel_T = np.nan, np.nan + raise NotImplementedError( + f"model {self.model} with kinetic assumption is not implemented. 
" + f"current supported models for kinetics experiments include: stochastic, deterministic, " + f"mixture, mixture_deterministic_stochastic or mixture_stochastic_stochastic" + ) + _param_ranges = update_dict(_param_ranges, self.param_rngs) + x0_ = np.vstack([ran for ran in x0.values()]).T + + n_genes = self.subset_adata.n_vars + cost, logLL = np.zeros(n_genes), np.zeros(n_genes) + all_keys = list(_param_ranges.keys()) + list(x0.keys()) + all_keys = [cur_key for cur_key in all_keys if cur_key != "alpha_i"] + half_life, Estm = np.zeros(n_genes), [None] * n_genes + X_data, X_fit_data = [None] * n_genes, [None] * n_genes + if self.experiment_type: + popt = [None] * n_genes + + main_debug("model: %s, experiment_type: %s" % (self.model, self.experiment_type)) + for i_gene in tqdm(range(n_genes), desc="estimating kinetic-parameters using kinetic model"): + if self.model.lower().startswith("mixture"): + estm = Est + if self.model.lower() == "mixture": + cur_X_data = np.vstack([X[i_layer][i_gene] for i_layer in range(len(X))]) + if issparse(X_raw[0]): + cur_X_raw = np.hstack([X_raw[i_layer][:, i_gene].A for i_layer in range(len(X))]) + else: + cur_X_raw = np.hstack([X_raw[i_layer][:, i_gene] for i_layer in range(len(X))]) + else: + cur_X_data = X[i_gene] + cur_X_raw = X_raw[i_gene] - vel_P = vel.vel_p(S, P) + if issparse(cur_X_raw[0, 0]): + cur_X_raw = np.hstack((cur_X_raw[0, 0].A, cur_X_raw[1, 0].A)) - adata = set_velocity( - adata, - vel_U, - vel_S, - vel_N, - vel_T, - vel_P, - _group, - cur_grp, - cur_cells_bools, - valid_bools_, - ind_for_proteins, - ) + _, cost[i_gene] = estm.auto_fit(np.unique(self.time), cur_X_data) + ( + model_1, + model_2, + kinetic_parameters, + mix_x0, + ) = estm.export_dictionary().values() + tmp = list(kinetic_parameters.values()) + tmp.extend(mix_x0) + Estm[i_gene] = tmp + else: + cur_X_data, cur_X_raw = X[i_gene], X_raw[i_gene] - adata = set_param_ss( - adata, - est, - alpha, - beta, - gamma, - eta, - delta, - experiment_type, - _group, - 
cur_grp, - kin_param_pre, - valid_bools_, - ind_for_proteins, - ) + if self.has_splicing: + alpha0 = guestimate_alpha(np.sum(cur_X_data, 0), np.unique(self.time)) + else: + alpha0 = ( + guestimate_alpha(cur_X_data, np.unique(self.time)) + if cur_X_data.ndim == 1 + else guestimate_alpha(cur_X_data[0], np.unique(self.time)) + ) - elif assumption_mRNA.lower() == "kinetic": - return_ntr = True if fraction_for_deg and experiment_type.lower() == "deg" else False + if self.model.lower() == "stochastic": + _param_ranges.update({"alpha_a": [0, alpha0 * 10]}) + elif self.model.lower() == "deterministic": + _param_ranges.update({"alpha": [0, alpha0 * 10]}) + param_ranges = [ran for ran in _param_ranges.values()] - if model_was_auto and experiment_type.lower() == "kin": - model = "mixture" - if est_method == "auto": - est_method = "direct" - data_type = "smoothed" if use_smoothed else "sfs" + estm = Est(*param_ranges, x0=x0_) if "x0" in inspect.getfullargspec(Est) else Est(*param_ranges) + _, cost[i_gene] = estm.fit_lsq(np.unique(self.time), cur_X_data, **self.est_kwargs) + if self.model.lower() == "deterministic": + Estm[i_gene] = estm.export_parameters() + else: + tmp = np.ma.array(estm.export_parameters(), mask=False) + tmp.mask[3] = True + Estm[i_gene] = tmp.compressed() - (params, half_life, cost, logLL, param_ranges, cur_X_data, cur_X_fit_data,) = kinetic_model( - subset_adata, - tkey, - model, - est_method, - experiment_type, - has_splicing, - splicing_labeling, - has_switch=True, - param_rngs={}, - data_type=data_type, - return_ntr=return_ntr, - **est_kwargs, - ) + if issparse(cur_X_raw[0, 0]): + cur_X_raw = np.hstack((cur_X_raw[0, 0].A, cur_X_raw[1, 0].A)) - if type(params) == dict: - alpha = params.pop("alpha") - params = pd.DataFrame(params) + X_data[i_gene] = cur_X_data + if self.model.lower().startswith("mixture"): + X_fit_data[i_gene] = estm.simulator.x.T + X_fit_data[i_gene][estm.model1.n_species:] *= estm.scale else: - alpha = params.loc[:, "alpha"].values if 
"alpha" in params.columns else None + if hasattr(estm, "extract_data_from_simulator"): + X_fit_data[i_gene] = estm.extract_data_from_simulator() + else: + X_fit_data[i_gene] = estm.simulator.x.T - len_t, len_g = len(np.unique(t)), len(_group) - if cur_grp == _group[0]: - if len_g != 1: - # X_data, X_fit_data = np.zeros((len_g, adata.n_vars, len_t)), np.zeros((len_g, adata.n_vars,len_t)) - X_data, X_fit_data = [None] * len_g, [None] * len_g + half_life[i_gene] = np.log(2) / Estm[i_gene][-1] - if len(_group) == 1: - X_data, X_fit_data = cur_X_data, cur_X_fit_data + if self.model.lower().startswith("mixture"): + species = [0, 1, 2, 3] if self.has_splicing else [0, 1] + gof = GoodnessOfFit(estm.export_model(), params=estm.export_parameters()) + gof.prepare_data(self.time, cur_X_raw.T, species=species, normalize=True) else: - # X_data[cur_grp_i, :, :], X_fit_data[cur_grp_i, :, :] = cur_X_data, cur_X_fit_data - X_data[cur_grp_i], X_fit_data[cur_grp_i] = ( - cur_X_data, - cur_X_fit_data, + gof = GoodnessOfFit( + estm.export_model(), + params=estm.export_parameters(), + x0=estm.simulator.x0, ) + gof.prepare_data(self.time, cur_X_raw.T, normalize=True) - a, b, alpha_a, alpha_i, beta, gamma = ( - params.loc[:, "a"].values if "a" in params.columns else None, - params.loc[:, "b"].values if "b" in params.columns else None, - params.loc[:, "alpha_a"].values if "alpha_a" in params.columns else None, - params.loc[:, "alpha_i"].values if "alpha_i" in params.columns else None, - params.loc[:, "beta"].values if "beta" in params.columns else None, - params.loc[:, "gamma"].values if "gamma" in params.columns else None, - ) - if alpha is None: - alpha = fbar(a, b, alpha_a, 0) if alpha_i is None else fbar(a, b, alpha_a, alpha_i) - all_kinetic_params = [ - "a", - "b", - "alpha_a", - "alpha_i", - "alpha", - "beta", - "gamma", - ] + logLL[i_gene] = gof.calc_mean_squared_deviation() # .calc_gaussian_loglikelihood() - extra_params = params.loc[:, params.columns.difference(all_kinetic_params)] 
- # if alpha = None, set alpha to be U; N - gamma R - params = {"alpha": alpha, "beta": beta, "gamma": gamma, "t": t} - vel = Velocity(**params) - # Fix below: - U, S = get_U_S_for_velocity_estimation( - subset_adata, - use_smoothed, - has_splicing, - has_labeling, - log_unnormalized, - NTR_vel, + Estm_df = pd.DataFrame(np.vstack(Estm), columns=[*all_keys[: len(Estm[0])]]) + + return Estm_df, half_life, cost, logLL, _param_ranges, X_data, X_fit_data + + def fit_degradation(self): + """Fit the input data to estimate parameters for degradation experiment type.""" + if self.has_splicing and self.splicing_labeling: + layers = ( + ["M_ul", "M_sl", "M_uu", "M_su"] + if ("M_ul" in self.subset_adata.layers.keys() and self.data_type == "smoothed") + else ["X_ul", "X_sl", "X_uu", "X_su"] ) - U_, S_ = get_U_S_for_velocity_estimation( - subset_adata, - use_smoothed, - has_splicing, - has_labeling, - log_unnormalized, - not NTR_vel, + if self.model.lower() in ["deterministic", "stochastic"]: + layer_u = "M_ul" if ("M_ul" in self.subset_adata.layers.keys() and self.data_type == "smoothed") else "X_ul" + layer_s = "M_sl" if ("M_sl" in self.subset_adata.layers.keys() and self.data_type == "smoothed") else "X_sl" + + X, X_raw = prepare_data_has_splicing( + self.subset_adata, + self.subset_adata.var.index, + self.time, + layer_u=layer_u, + layer_s=layer_s, + total_layers=layers, + return_ntr=self.return_ntr, + ) + elif self.model.lower().startswith("mixture"): + X, _, X_raw = prepare_data_deterministic( + self.subset_adata, + self.subset_adata.var.index, + self.time, + layers=layers, + total_layers=layers, + return_ntr=self.return_ntr, + ) + + if self.model.lower() == "deterministic": + X = [X[i][[0, 1], :] for i in range(len(X))] + _param_ranges = { + "beta": [0, 1000], + "gamma": [0, 1000], + } + x0 = { + "u0": [0, 1000], + "s0": [0, 1000], + } + Est, _ = Estimation_DeterministicDeg, Deterministic + elif self.model.lower() == "stochastic": + _param_ranges = { + "beta": [0, 1000], 
+ "gamma": [0, 1000], + } + x0 = { + "u0": [0, 1000], + "s0": [0, 1000], + "uu0": [0, 1000], + "ss0": [0, 1000], + "us0": [0, 1000], + } + Est, _ = Estimation_MomentDeg, Moments_NoSwitching + else: + raise NotImplementedError( + f"model {self.model} with kinetic assumption is not implemented. " + f"current supported models for degradation experiment include: " + f"stochastic, deterministic." + ) + else: + total_layer = "M_t" if ("M_t" in self.subset_adata.layers.keys() and self.data_type == "smoothed") else "X_total" + + layer = "M_n" if ("M_n" in self.subset_adata.layers.keys() and self.data_type == "smoothed") else "X_new" + X, X_raw = prepare_data_no_splicing( + self.subset_adata, + self.subset_adata.var.index, + self.time, + layer=layer, + total_layer=total_layer, + return_ntr=self.return_ntr, ) - # also get vel_N and vel_T - if NTR_vel: - if has_splicing: - if experiment_type == "kin": - vel_U = vel.vel_u(U_) - vel_S = vel.vel_s(U_, S_) - vel.parameters["beta"] = gamma - vel_N = vel.vel_u(U) - vel_T = vel.vel_u(S) # no need to consider splicing - elif experiment_type == "deg": - if splicing_labeling: - vel_U = np.nan - vel_S = vel.vel_s(U_, S_) - vel_N = np.nan - vel_T = np.nan - else: - vel_U = np.nan - vel_S = vel.vel_s(U_, S_) - vel_N = np.nan - vel_T = np.nan - elif experiment_type in ["mix_kin_deg", "mix_pulse_chase"]: - vel_U = vel.vel_u(U_, repeat=True) - vel_S = vel.vel_s(U_, S_) - vel.parameters["beta"] = gamma - vel_N = vel.vel_u(U, repeat=True) - vel_T = vel.vel_u(S, repeat=True) # no need to consider splicing + if self.model.lower() == "deterministic": + X = [X[i][0, :] for i in range(len(X))] + _param_ranges = { + "gamma": [0, 10], + } + x0 = {"u0": [0, 1000]} + Est, _ = ( + Estimation_DeterministicDegNosp, + Deterministic_NoSplicing, + ) + elif self.model.lower() == "stochastic": + _param_ranges = { + "gamma": [0, 10], + } + x0 = {"u0": [0, 1000], "uu0": [0, 1000]} + Est, _ = Estimation_MomentDegNosp, Moments_NoSwitchingNoSplicing + else: + raise 
NotImplementedError( + f"model {self.model} with kinetic assumption is not implemented. " + f"current supported models for degradation experiment include: " + f"stochastic, deterministic.") + _param_ranges = update_dict(_param_ranges, self.param_rngs) + x0_ = np.vstack([ran for ran in x0.values()]).T + + n_genes = self.subset_adata.n_vars + cost, logLL = np.zeros(n_genes), np.zeros(n_genes) + all_keys = list(_param_ranges.keys()) + list(x0.keys()) + all_keys = [cur_key for cur_key in all_keys if cur_key != "alpha_i"] + half_life, Estm = np.zeros(n_genes), [None] * n_genes + X_data, X_fit_data = [None] * n_genes, [None] * n_genes + if self.experiment_type: + popt = [None] * n_genes + + main_debug("model: %s, experiment_type: %s" % (self.model, self.experiment_type)) + for i_gene in tqdm(range(n_genes), desc="estimating kinetic-parameters using kinetic model"): + if self.model.lower().startswith("mixture"): + estm = Est + if self.model.lower() == "mixture": + cur_X_data = np.vstack([X[i_layer][i_gene] for i_layer in range(len(X))]) + if issparse(X_raw[0]): + cur_X_raw = np.hstack([X_raw[i_layer][:, i_gene].A for i_layer in range(len(X))]) + else: + cur_X_raw = np.hstack([X_raw[i_layer][:, i_gene] for i_layer in range(len(X))]) else: - if experiment_type == "kin": - vel_U = np.nan - vel_S = np.nan + cur_X_data = X[i_gene] + cur_X_raw = X_raw[i_gene] - # calculate cell-wise alpha, if est_method is twostep, this can be skipped - alpha_ = one_shot_alpha_matrix(U, gamma, t) + if issparse(cur_X_raw[0, 0]): + cur_X_raw = np.hstack((cur_X_raw[0, 0].A, cur_X_raw[1, 0].A)) + + _, cost[i_gene] = estm.auto_fit(np.unique(self.time), cur_X_data) + ( + model_1, + model_2, + kinetic_parameters, + mix_x0, + ) = estm.export_dictionary().values() + tmp = list(kinetic_parameters.values()) + tmp.extend(mix_x0) + Estm[i_gene] = tmp + else: + estm = Est() + cur_X_data, cur_X_raw = X[i_gene], X_raw[i_gene] - vel.parameters["alpha"] = alpha_ + _, cost[i_gene] = 
estm.auto_fit(np.unique(self.time), cur_X_data) + Estm[i_gene] = estm.export_parameters()[1:] - vel_N = vel.vel_u(U) - vel_T = vel.vel_u(S) # don't consider splicing - elif experiment_type == "deg": - vel_U = np.nan - vel_S = np.nan - vel_N = np.nan - vel_T = np.nan - elif experiment_type in ["mix_kin_deg", "mix_pulse_chase"]: - vel_U = np.nan - vel_S = np.nan - vel_N = vel.vel_u(U, repeat=True) - vel_T = vel.vel_u(S) # don't consider splicing + if issparse(cur_X_raw[0, 0]): + cur_X_raw = np.hstack((cur_X_raw[0, 0].A, cur_X_raw[1, 0].A)) + # model_1, kinetic_parameters, mix_x0 = estm.export_dictionary().values() + # tmp = list(kinetic_parameters.values()) + # tmp.extend(mix_x0) + # Estm[i_gene] = tmp + + X_data[i_gene] = cur_X_data + if self.model.lower().startswith("mixture"): + X_fit_data[i_gene] = estm.simulator.x.T + X_fit_data[i_gene][estm.model1.n_species:] *= estm.scale else: - if has_splicing: - if experiment_type == "kin": - vel_U = vel.vel_u(U) - vel_S = vel.vel_s(U, S) - vel.parameters["beta"] = gamma - vel_N = vel.vel_u(U_) - vel_T = vel.vel_u(S_) # no need to consider splicing - elif experiment_type == "deg": - if splicing_labeling: - vel_U = np.nan - vel_S = vel.vel_s(U, S) - vel_N = np.nan - vel_T = np.nan - else: - vel_U = np.nan - vel_S = vel.vel_s(U, S) - vel_N = np.nan - vel_T = np.nan - elif experiment_type in ["mix_kin_deg", "mix_pulse_chase"]: - vel_U = vel.vel_u(U, repeat=True) - vel_S = vel.vel_s(U, S) - vel.parameters["beta"] = gamma - vel_N = vel.vel_u(U_, repeat=True) - vel_T = vel.vel_u(S_, repeat=True) # no need to consider splicing + if hasattr(estm, "extract_data_from_simulator"): + X_fit_data[i_gene] = estm.extract_data_from_simulator() else: - if experiment_type == "kin": - vel_U = np.nan - vel_S = np.nan - - # calculate cell-wise alpha, if est_method is twostep, this can be skipped - alpha_ = one_shot_alpha_matrix(U_, gamma, t) + X_fit_data[i_gene] = estm.simulator.x.T - vel.parameters["alpha"] = alpha_ + half_life[i_gene] = 
estm.calc_half_life("gamma") - vel_N = vel.vel_u(U_) - vel_T = vel.vel_u(S_) # need to consider splicing - elif experiment_type == "deg": - vel_U = np.nan - vel_S = np.nan - vel_N = np.nan - vel_T = np.nan - elif experiment_type in ["mix_kin_deg", "mix_pulse_chase"]: - vel_U = np.nan - vel_S = np.nan - vel_N = vel.vel_u(U_, repeat=True) - vel_T = vel.vel_u(S_, repeat=True) # don't consider splicing + if self.model.lower().startswith("mixture"): + species = [0, 1, 2, 3] if self.has_splicing else [0, 1] + gof = GoodnessOfFit(estm.export_model(), params=estm.export_parameters()) + gof.prepare_data(self.time, cur_X_raw.T, species=species, normalize=True) + else: + gof = GoodnessOfFit( + estm.export_model(), + params=estm.export_parameters(), + x0=estm.simulator.x0, + ) + gof.prepare_data(self.time, cur_X_raw.T, normalize=True) - vel_P = vel.vel_p(S, P) + logLL[i_gene] = gof.calc_mean_squared_deviation() # .calc_gaussian_loglikelihood() - adata = set_velocity( - adata, - vel_U, - vel_S, - vel_N, - vel_T, - vel_P, - _group, - cur_grp, - cur_cells_bools, - valid_bools_, - ind_for_proteins, + if self.est_method == "twostep" and self.has_splicing: + layers = ["M_u", "M_s"] if ("M_u" in self.subset_adata.layers.keys() and self.data_type == "smoothed") else ["X_u", + "X_s"] + U, S = ( + self.subset_adata.layers[layers[0]].T, + self.subset_adata.layers[layers[1]].T, + ) + US, S2 = self.subset_adata.layers["M_us"].T, self.subset_adata.layers["M_ss"].T + # beta, beta_r2 = lin_reg_gamma_synthesis(U, Ul, time, perc_right=100) + gamma_k, gamma_b, gamma_all_r2, gamma_all_logLL = fit_slope_stochastic( + S, U, US, S2, perc_left=None, perc_right=5 ) - adata = set_param_kinetic( - adata, - alpha, - a, - b, - alpha_a, - alpha_i, - beta, - gamma, - cost, - logLL, - kin_param_pre, - extra_params, - _group, - cur_grp, - cur_cells_bools, - valid_bools_, + Estm_df = pd.DataFrame(np.vstack(Estm), columns=[*all_keys[: len(Estm[0])]]) + Estm_df["gamma_k"] = gamma_k # gamma_k = gamma / beta + 
Estm_df["beta"] = Estm_df["gamma"] / gamma_k # gamma_k = gamma / beta + Estm_df["gamma_r2"] = gamma_all_r2 + else: + Estm_df = pd.DataFrame(np.vstack(Estm), columns=[*all_keys[: len(Estm[0])]]) + + return Estm_df, half_life, cost, logLL, _param_ranges, X_data, X_fit_data + + def fit_mix_kinetics(self): + """Fit the input data to estimate parameters for mix_kinetics_degradation experiment type.""" + total_layer = "M_t" if ("M_t" in self.subset_adata.layers.keys() and self.data_type == "smoothed") else "X_total" + + if self.model.lower() in ["deterministic"]: + layer = "M_n" if ("M_n" in self.subset_adata.layers.keys() and self.data_type == "smoothed") else "X_new" + X, X_raw = prepare_data_no_splicing( + self.subset_adata, + self.subset_adata.var.index, + self.time, + layer=layer, + total_layer=total_layer, ) - # add protein related parameters in the moment model below: - elif model.lower() == "model_selection": - main_warning("Not implemented yet.") + if self.model.lower() == "deterministic": + X = [X[i][0, :] for i in range(len(X))] + _param_ranges = { + "alpha": [0, 1000], + "gamma": [0, 1000], + } + x0 = {"u0": [0, 1000]} + Est = Estimation_KineticChase + else: + raise NotImplementedError( + f"only `deterministic` model implemented for mix_pulse_chase/mix_kin_deg experiment!" 
+ ) + _param_ranges = update_dict(_param_ranges, self.param_rngs) + x0_ = np.vstack([ran for ran in x0.values()]).T + + n_genes = self.subset_adata.n_vars + cost, logLL = np.zeros(n_genes), np.zeros(n_genes) + all_keys = list(_param_ranges.keys()) + list(x0.keys()) + all_keys = [cur_key for cur_key in all_keys if cur_key != "alpha_i"] + half_life, Estm = np.zeros(n_genes), [None] * n_genes + X_data, X_fit_data = [None] * n_genes, [None] * n_genes + if self.experiment_type: + popt = [None] * n_genes + + main_debug("model: %s, experiment_type: %s" % (self.model, self.experiment_type)) + for i_gene in tqdm(range(n_genes), desc="estimating kinetic-parameters using kinetic model"): + if self.model.lower().startswith("mixture"): + estm = Est + if self.model.lower() == "mixture": + cur_X_data = np.vstack([X[i_layer][i_gene] for i_layer in range(len(X))]) + if issparse(X_raw[0]): + cur_X_raw = np.hstack([X_raw[i_layer][:, i_gene].A for i_layer in range(len(X))]) + else: + cur_X_raw = np.hstack([X_raw[i_layer][:, i_gene] for i_layer in range(len(X))]) + else: + cur_X_data = X[i_gene] + cur_X_raw = X_raw[i_gene] - if group is not None and group in adata.obs[group]: - uns_key = group + "_dynamics" - else: - uns_key = "dynamics" + if issparse(cur_X_raw[0, 0]): + cur_X_raw = np.hstack((cur_X_raw[0, 0].A, cur_X_raw[1, 0].A)) - if sanity_check and experiment_type in ["kin", "deg"]: - sanity_check_cols = adata.var.columns.str.endswith("sanity_check") - adata.var["use_for_dynamics"] = adata.var.loc[:, sanity_check_cols].sum(1).astype(bool) - else: - adata.var["use_for_dynamics"] = False - adata.var.loc[valid_bools, "use_for_dynamics"] = True + _, cost[i_gene] = estm.auto_fit(np.unique(self.time), cur_X_data) + ( + model_1, + model_2, + kinetic_parameters, + mix_x0, + ) = estm.export_dictionary().values() + tmp = list(kinetic_parameters.values()) + tmp.extend(mix_x0) + Estm[i_gene] = tmp + else: + estm = Est() + cur_X_data, cur_X_raw = X[i_gene], X_raw[i_gene] - adata.uns[uns_key] = 
{ - "filter_gene_mode": filter_gene_mode, - "t": t, - "group": group, - "X_data": X_data, - "X_fit_data": X_fit_data, - "asspt_mRNA": assumption_mRNA, - "experiment_type": experiment_type, - "normalized": normalized, - "model": model, - "est_method": est_method, - "has_splicing": has_splicing, - "has_labeling": has_labeling, - "splicing_labeling": splicing_labeling, - "has_protein": has_protein, - "use_smoothed": use_smoothed, - "NTR_vel": NTR_vel, - "log_unnormalized": log_unnormalized, - "fraction_for_deg": fraction_for_deg, - } + popt[i_gene], cost[i_gene] = estm.auto_fit(np.unique(self.time), cur_X_data) + Estm[i_gene] = estm.export_parameters() - if del_2nd_moments: - remove_2nd_moments(adata) + if issparse(cur_X_raw[0, 0]): + cur_X_raw = np.hstack((cur_X_raw[0, 0].A, cur_X_raw[1, 0].A)) + # model_1, kinetic_parameters, mix_x0 = estm.export_dictionary().values() + # tmp = list(kinetic_parameters.values()) + # tmp.extend(mix_x0) + # Estm[i_gene] = tmp - return adata + X_data[i_gene] = cur_X_data + if self.model.lower().startswith("mixture"): + X_fit_data[i_gene] = estm.simulator.x.T + X_fit_data[i_gene][estm.model1.n_species:] *= estm.scale + else: + # kinetic chase simulation + kinetic_chase = estm.simulator.x.T + # hidden x + tt, h = estm.simulator.calc_init_conc() + + X_fit_data[i_gene] = [kinetic_chase, [tt, h]] + + half_life[i_gene] = estm.calc_half_life("gamma") + + if self.model.lower().startswith("mixture"): + species = [0, 1, 2, 3] if self.has_splicing else [0, 1] + gof = GoodnessOfFit(estm.export_model(), params=estm.export_parameters()) + gof.prepare_data(self.time, cur_X_raw.T, species=species, normalize=True) + else: + gof = GoodnessOfFit( + estm.export_model(), + params=estm.export_parameters(), + x0=estm.simulator.x0, + ) + gof.prepare_data(self.time, cur_X_raw.T, normalize=True) + + logLL[i_gene] = gof.calc_mean_squared_deviation() # .calc_gaussian_loglikelihood() + + if self.est_method == "twostep": + if self.has_splicing: + layers = ( + 
["M_u", "M_s"] if ("M_u" in self.subset_adata.layers.keys() and self.data_type == "smoothed") else ["X_u", + "X_s"] + ) + U, S = ( + self.subset_adata.layers[layers[0]].T, + self.subset_adata.layers[layers[1]].T, + ) + US, S2 = ( + self.subset_adata.layers["M_us"].T, + self.subset_adata.layers["M_ss"].T, + ) + # beta, beta_r2 = lin_reg_gamma_synthesis(U, Ul, time, perc_right=100) + ( + gamma_k, + gamma_b, + gamma_all_r2, + gamma_all_logLL, + ) = fit_slope_stochastic(S, U, US, S2, perc_left=None, perc_right=5) + + Estm_df = pd.DataFrame(np.vstack(Estm), columns=[*all_keys[: len(Estm[0])]]) + Estm_df["gamma_k"] = gamma_k # gamma_k = gamma / beta + Estm_df["beta"] = Estm_df["gamma"] / gamma_k # gamma_k = gamma / beta + Estm_df["gamma_r2"] = gamma_all_r2 + else: + Estm_df = pd.DataFrame(np.vstack(Estm), columns=[*all_keys[: len(Estm[0])]]) + Estm_df["gamma_k"] = Estm_df["gamma"] # fix a bug in pl.dynamics + else: + Estm_df = pd.DataFrame(np.vstack(Estm), columns=[*all_keys[: len(Estm[0])]]) + + return Estm_df, half_life, cost, logLL, _param_ranges, X_data, X_fit_data def kinetic_model( subset_adata: AnnData, tkey: str, model: Literal["auto", "deterministic", "stochastic"], - est_method: Literal["twostep", "direct"], + est_method: Literal["twostep", "direct", "storm-csp", "storm-cszip", "storm-icsp"], experiment_type: str, has_splicing: bool, splicing_labeling: bool, @@ -1162,6 +3638,152 @@ def kinetic_model( K_fit, ) + return ( + Estm_df, + half_life, + cost, + logLL, + _param_ranges, + X_data, + X_fit_data, + ) + elif "storm" in est_method: + if has_splicing: + # Initialization based on the steady-state assumption + layers_smoothed = ["M_u", "M_s", "M_t", "M_n"] + U_smoothed, S_smoothed, Total_smoothed, New_smoothed = ( + subset_adata.layers[layers_smoothed[0]].T, + subset_adata.layers[layers_smoothed[1]].T, + subset_adata.layers[layers_smoothed[2]].T, + subset_adata.layers[layers_smoothed[3]].T, + ) + + US_smoothed, S2_smoothed = ( + subset_adata.layers["M_us"].T, + 
subset_adata.layers["M_ss"].T, + ) + (gamma_k, _, _, _,) = fit_slope_stochastic(S_smoothed, U_smoothed, US_smoothed, S2_smoothed, + perc_left=None, perc_right=5) + (gamma_init, _, _, _, _) = lin_reg_gamma_synthesis(Total_smoothed, New_smoothed, time, perc_right=5) + beta_init = gamma_init / gamma_k # gamma_k = gamma / beta + + # Read raw counts + layers_raw = ["ul", "sl"] + UL_raw, SL_raw = ( + subset_adata.layers[layers_raw[0]].T, + subset_adata.layers[layers_raw[1]].T, + ) + + # Read smoothed values based CSP type distribution for cell-specific parameter inference + UL_smoothed_CSP, SL_smoothed_CSP = ( + subset_adata.layers['M_CSP_ul'].T, + subset_adata.layers['M_CSP_sl'].T, + ) + + # Parameters inference based on maximum likelihood estimation + cell_total = subset_adata.obs['initial_cell_size'].astype("float").values + # Independent cell-specific Poisson + (gamma_s, gamma_r2, beta, gamma_t, gamma_r2_raw, alpha) = storm.mle_independent_cell_specific_poisson \ + (UL_raw, SL_raw, time, gamma_init, beta_init, cell_total, Total_smoothed, S_smoothed) + gamma_k = gamma_s / beta + gamma_b = np.zeros_like(gamma_k) + + # Cell specific parameters (fixed gamma_s) + alpha, beta = storm.cell_specific_alpha_beta(UL_smoothed_CSP, SL_smoothed_CSP, time, gamma_s, beta) + + # # Cell specific parameters(fixed gamma_t) + # k = 1 - np.exp(-gamma_t[:, None] * time[None, :]) + # alpha = csr_matrix(gamma_t[:, None]).multiply(UL_smoothed_CSP+SL_smoothed_CSP).multiply(1 / k) + + Estm_df = { + "alpha": alpha, + "beta": beta, + "gamma_k": gamma_k, + "gamma_b": gamma_b, + # "gamma_k_r2": gamma_all_r2, + # "gamma_logLL": gamma_all_logLL, + "gamma": gamma_s, + "gamma_r2": gamma_r2, + # "mean_R2": mean_R2, + "gamma_t": gamma_t, + "gamma_r2_raw": gamma_r2_raw, + } + half_life = np.log(2) / gamma_s + cost, logLL, _param_ranges, X_data, X_fit_data = ( + None, + None, + None, + None, + None, + ) + + return ( + Estm_df, + half_life, + cost, + logLL, + _param_ranges, + X_data, + X_fit_data, + ) + 
else: + # Initialization based on the steady-state assumption + layers_smoothed = ["M_t", "M_n"] + Total_smoothed, New_smoothed = ( + subset_adata.layers[layers_smoothed[0]].T, + subset_adata.layers[layers_smoothed[1]].T, + ) + (gamma_init, _, _, _, _,) = lin_reg_gamma_synthesis(Total_smoothed, New_smoothed, time, + perc_right=5) + + # Read raw counts + layers_raw = ["total", "new"] + Total_raw, New_raw = ( + subset_adata.layers[layers_raw[0]].T, + subset_adata.layers[layers_raw[1]].T, + ) + + # Read smoothed values based CSP type distribution for cell-specific parameter inference + layers_smoothed_CSP = ["M_CSP_t", "M_CSP_n"] + Total_smoothed_CSP, New_smoothed_CSP = ( + subset_adata.layers[layers_smoothed_CSP[0]].T, + subset_adata.layers[layers_smoothed_CSP[1]].T, + ) + + # Parameters inference based on maximum likelihood estimation + cell_total = subset_adata.obs['initial_cell_size'].astype("float").values + + if "storm-csp" == est_method: + gamma, gamma_r2, gamma_r2_raw, alpha = storm.mle_cell_specific_poisson(New_raw, time, + gamma_init, cell_total) + elif "storm-cszip" == est_method: + gamma, prob_off, gamma_r2, gamma_r2_raw, alpha = storm.mle_cell_specific_zero_inflated_poisson( + New_raw, time, gamma_init, cell_total) + alpha = alpha * (1 - prob_off) # gene-wise alpha + else: + raise NotImplementedError("This method has not been implemented.") + + k = 1 - np.exp(-gamma[:, None] * time[None, :]) + alpha = csr_matrix(gamma[:, None]).multiply(New_smoothed_CSP).multiply(1 / k) # gene-cell-wise alpha + + Estm_df = { + "alpha": alpha, + "gamma": gamma, + "gamma_k": gamma, # required for phase_potrait + "gamma_r2": gamma_r2, + "gamma_r2_raw": gamma_r2_raw, + # "mean_R2": mean_R2, + "prob_off": prob_off if "cszip" in est_method else None + } + half_life = np.log(2) / gamma + cost, logLL, _param_ranges, X_data, X_fit_data = ( + None, + None, + None, + None, # X_data, + None, # K_fit, + ) + return ( Estm_df, half_life, @@ -1663,7 +4285,7 @@ def kinetic_model( 
X_data[i_gene] = cur_X_data if model.lower().startswith("mixture"): X_fit_data[i_gene] = estm.simulator.x.T - X_fit_data[i_gene][estm.model1.n_species :] *= estm.scale + X_fit_data[i_gene][estm.model1.n_species:] *= estm.scale elif experiment_type in ["mix_kin_deg", "mix_pulse_chase"]: # kinetic chase simulation kinetic_chase = estm.simulator.x.T diff --git a/dynamo/tools/moments.py b/dynamo/tools/moments.py index 1cfb2ec54..d7d77cfad 100755 --- a/dynamo/tools/moments.py +++ b/dynamo/tools/moments.py @@ -173,6 +173,16 @@ def moments( ) layers = DynamoAdataKeyManager.get_available_layer_keys(adata, layers, False, False) + + # for CSP-type method + layers_raw = [ + layer + for layer in layers + if not (layer.startswith("X")) and not (layer.startswith("M")) and ( + not layer.endswith("matrix") and not layer.endswith("ambiguous")) + ] + layers_raw.sort(reverse=True) # ensure we get M_CSP_us, M_CSP_tn, etc (instead of M_CSP_su or M_CSP_nt). + layers = [ layer for layer in layers @@ -224,6 +234,30 @@ def moments( layer_x, layer_y, conn, normalize_W=normalize, mX=None, mY=None ) + # for CSP-type method + size_factor = adata.obs['Size_Factor'].astype("float").values + mapper_CSP = { + "new": "M_CSP_n", + "old": "M_CSP_o", + "total": "M_CSP_t", + "uu": "M_CSP_uu", + "ul": "M_CSP_ul", + "su": "M_CSP_su", + "sl": "M_CSP_sl", + "unspliced": "M_CSP_u", + "spliced": "M_CSP_s", + } + + # for CSP-type method + for i, layer in enumerate(layers_raw): + layer_x = adata.layers[layer].copy() + layer_x = inverse_norm(adata, layer_x) + + if mapper_CSP[layer] not in adata.layers.keys(): + local_size_factor = conn.dot(size_factor) + local_raw_counts = conn.dot(layer_x) + adata.layers[mapper_CSP[layer]] = csr_matrix(local_raw_counts/local_size_factor.reshape(-1,1)) + if "X_protein" in adata.obsm.keys(): # may need to update with mnn or just use knn from protein layer itself. 
def get_auto_assump_mRNA(
    subset_adata,
    has_splicing: bool,
    has_labeling: bool,
    use_moments: bool,
    tkey: str,
    NTR_vel: bool,
):
    """Choose the mRNA assumption ("ss" or "kinetic") from the layers present in `subset_adata`.

    Args:
        subset_adata: object exposing `.layers` (queried by layer name through `.keys()`) and
            `.obs.columns`. Presumably an AnnData subset -- TODO confirm against the caller.
        has_splicing: whether the data is treated as having splicing information.
        has_labeling: whether the data is treated as having labeling information.
        use_moments: if true, the layer name that `get_mapper()` associates with `X_new` /
            `X_unspliced` is checked; otherwise the `X_*` name itself is checked.
        tkey: column name searched for in `subset_adata.obs.columns`. For data without labeling
            the result is "kinetic" when the column exists and "ss" otherwise.
        NTR_vel: requested NTR-velocity flag. It is forced to `True` (with a warning) when the data
            has labeling but no splicing.

    Returns:
        A tuple `(NTR_vel, assumption_mRNA)`: the possibly updated flag, and "ss", "kinetic", or
        `None` when no rule below matched.
    """
    # Labeling-only data is analysed with NTR_vel switched on, so override a False setting.
    if not NTR_vel and has_labeling and not has_splicing:
        main_warning(
            "Your adata only has labeling data, but `NTR_vel` is set to be "
            "`False`. Dynamo will reset it to `True` to enable this analysis."
        )
        NTR_vel = True

    mapper = get_mapper()
    layer_names = subset_adata.layers.keys()

    def _all_present(names):
        # True when every name in `names` is an available layer.
        return all(name in layer_names for name in names)

    # Every labeling-driven rule yields the same answer, fixed by the (possibly updated) NTR_vel.
    labeling_assumption = "ss" if NTR_vel else "kinetic"
    assumption_mRNA = None

    # labeling plus splicing: X_* layers, their mapped counterparts, or the raw uu/ul/sl/su layers
    normalized_keys = ["X_ul", "X_sl", "X_su"]
    if (
        _all_present(normalized_keys)
        or _all_present([mapper[key] for key in normalized_keys])
        or _all_present(["uu", "ul", "sl", "su"])
    ):
        assumption_mRNA = labeling_assumption

    # labeling without splicing: run new / total ratio (NTR)
    if not has_splicing and (
        ("X_new" in layer_names and not use_moments)
        or (mapper["X_new"] in layer_names and use_moments)
        or "new" in layer_names
    ):
        assumption_mRNA = labeling_assumption

    # splicing data without labeling; checked after the rules above, so it overrides them
    if not has_labeling and (
        ("X_unspliced" in layer_names and not use_moments)
        or (mapper["X_unspliced"] in layer_names and use_moments)
        or "unspliced" in layer_names
    ):
        assumption_mRNA = "kinetic" if tkey in subset_adata.obs.columns else "ss"

    # Fallback: labeling data that matched none of the layer-based rules.
    if has_labeling and assumption_mRNA is None:
        assumption_mRNA = labeling_assumption

    return NTR_vel, assumption_mRNA