Source code for xchrom.tr._utils

## train model functions
import numpy as np
import h5py
import tensorflow as tf
import scipy.sparse as sparse
import time
from tensorflow.keras.callbacks import Callback
from ._scbasset_utils import StochasticReverseComplement, StochasticShift, conv_block, conv_tower, dense_block, GELU
from ..tl._utils import calc_nsls_score

class Generator:
    """
    Generate input data for XChrom model training and create TensorFlow datasets.

    This class combines data generation and dataset creation functionality,
    compatible with tf.data.Dataset.from_generator.

    Parameters
    ----------
    seq_path : str
        The path to the sequence HDF5 file, generated by make_h5_sparse function
    adata : anndata.AnnData
        The anndata object containing:
        - adata.X: The raw count matrix of data
        - adata.obs['b_zscore']: The sequencing depth vector
        - adata.obsm[cell_embed]: The initial cell embedding matrix
    cell_input_key : str
        The key name of the initial cell embedding in adata.obsm
        (e.g. 'zscore32_perpc')
    peakid : array-like, optional
        The array of peak indices used to select rows of the 'X' dataset in
        the sequence file. If None, all peaks (0 .. adata.shape[1] - 1) are used
    m : scipy.sparse matrix or array-like, optional
        The peak-cell matrix (peaks in rows, cells in columns). Row ``i`` must
        describe the same peak as ``peakid[i]``, i.e. an explicitly supplied
        matrix has to be subset by the caller. Any format accepted by
        ``scipy.sparse.csr_matrix`` is allowed; it is converted to CSR.
        If None, the matrix is built from ``adata.X.T`` and restricted to the
        rows listed in ``peakid``
    batch_size : int, default 128
        The batch size for the dataset

    Attributes
    ----------
    n_cells : int
        The number of cells
    n_peaks : int
        The number of peaks

    Examples
    --------
    >>> gen1 = Generator(seq_path='sequence.h5', adata=ad_train, m=m_train,
    ...                  cell_input_key='zscore32_perpc', batch_size=256)
    >>> train_ds = gen1.create_dataset(shuffle=True)
    >>> gen2 = Generator(seq_path='sequence.h5', adata=ad_val, m=m_val,
    ...                  cell_input_key='zscore32_perpc', batch_size=256)
    >>> val_ds = gen2.create_dataset(shuffle=False)
    >>> model.fit(train_ds, validation_data=val_ds,
    ...           batch_size=128,epochs=1000, callbacks=callbacks_list)
    """

    def __init__(self, seq_path, adata, cell_input_key, peakid=None, m=None, batch_size=128):
        self.seq_path = seq_path
        self.batch_size = batch_size

        # set peak index to extract adata.X data
        if peakid is None:
            self.peakid = np.arange(adata.shape[1])
        else:
            self.peakid = np.array(peakid)

        # set peak-cell matrix; __call__ reads it through indptr/indices, so it
        # must be CSR and its rows must line up with self.peakid
        if m is None:
            # Message kept verbatim for callers that parse the output; the
            # matrix is in fact built as a sparse CSR matrix.
            print("Converting adata.X to dense array. For large datasets, consider pre-computing and saving as sparse matrix.")
            self.m = self._as_csr(adata.X.T, rows=None if peakid is None else self.peakid)
        else:
            self.m = self._as_csr(m)

        # extract cell embedding and other data
        try:
            self.cellembed = adata.obsm[cell_input_key]
        except KeyError as err:
            raise KeyError(f"Cell embedding '{cell_input_key}' not found in adata.obsm. "
                           f"Available keys: {list(adata.obsm.keys())}") from err

        try:
            self.b = np.array(adata.obs['b_zscore'].values)
        except KeyError as err:
            raise KeyError("'b_zscore' not found in adata.obs. "
                           "Please ensure the data is properly preprocessed.") from err

        self.n_cells = adata.shape[0]
        self.n_peaks = len(self.peakid)

        # open HDF5 file and get data
        try:
            self.file = h5py.File(seq_path, 'r')
            self.data = self.file['X'][self.peakid]
        except Exception as e:
            # do not leave a half-initialised object holding an open handle
            self.close()
            raise IOError(f"Error reading sequence file {seq_path}: {str(e)}") from e

    @staticmethod
    def _as_csr(matrix, rows=None):
        """Return ``matrix`` in CSR format, optionally restricted to ``rows``."""
        if sparse.issparse(matrix) and matrix.format == 'csr':
            csr = matrix
        else:
            csr = sparse.csr_matrix(matrix)
        if rows is not None:
            csr = csr[rows]
        return csr

    def __call__(self):
        """Generator function, used by tf.data.Dataset.from_generator

        Yields
        ------
        tuple
            ``((x_seq, cellembed, b), y_labels)`` for every loaded peak, where
            ``x_seq`` is the (seq_len, 4) int8 one-hot sequence and
            ``y_labels`` is the (n_cells,) int8 binary accessibility vector.
        """
        indptr = self.m.indptr
        indices = self.m.indices
        for i, x1 in enumerate(self.data):
            # get sequence data and convert to one-hot encoding
            seq_len = x1.shape[0]
            x_seq = sparse.coo_matrix(
                (np.ones(seq_len), (np.arange(seq_len), x1)),
                shape=(seq_len, 4),
                dtype='int8'
            ).toarray()

            # get labels: every stored entry of row i marks an accessible cell
            y_labels = np.zeros(self.n_cells, dtype='int8')
            y_labels[indices[indptr[i]:indptr[i + 1]]] = 1

            yield (x_seq, self.cellembed, self.b), y_labels

    def create_dataset(self, shuffle=False):
        """
        Create a TensorFlow dataset from the generator.

        Parameters
        ----------
        shuffle : bool, default False
            Whether to shuffle the dataset

        Returns
        -------
        tf.data.Dataset
            The configured TensorFlow dataset for training or prediction
        """
        seq_len = self.data.shape[1]
        cell_vec = self.cellembed.shape[1]

        dataset = tf.data.Dataset.from_generator(
            self,
            output_signature=(
                (tf.TensorSpec(shape=(seq_len, 4), dtype=tf.float32),
                 tf.TensorSpec(shape=(self.n_cells, cell_vec), dtype=tf.float32),
                 tf.TensorSpec(shape=(self.n_cells,), dtype=tf.float32)),
                tf.TensorSpec(shape=(self.n_cells,), dtype=tf.int8)
            )
        )

        if shuffle:
            dataset = dataset.shuffle(2000, reshuffle_each_iteration=True)

        return dataset.batch(self.batch_size).prefetch(tf.data.AUTOTUNE)

    def get_dataset_info(self):
        """
        Get information about the dataset.

        Returns
        -------
        dict
            Dictionary containing dataset information
        """
        return {
            'n_cells': self.n_cells,
            'n_peaks': self.n_peaks,
            'seq_length': self.data.shape[1],
            'cell_embedding_dim': self.cellembed.shape[1],
            'batch_size': self.batch_size,
        }

    def __enter__(self):
        """Context manager entry"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - close HDF5 file"""
        self.close()

    def close(self):
        """Explicitly close the HDF5 file"""
        # 'file' is missing when __init__ failed before the file was opened
        handle = getattr(self, 'file', None)
        if handle is not None:
            handle.close()

    def __del__(self):
        """Destructor to ensure file is closed"""
        self.close()
def add_noise(inputs, noise_factor=0.01):
    """
    Return ``inputs`` perturbed by scaled standard-normal noise.

    Parameters
    ----------
    inputs : tensor
        The tensor to perturb; the noise is sampled with the same shape
    noise_factor : float, default 0.01
        Multiplier applied to the sampled noise before it is added

    Returns
    -------
    tensor
        ``inputs + noise_factor * noise``
    """
    gaussian = tf.random.normal(shape=tf.shape(inputs))
    scaled = noise_factor * gaussian
    return inputs + scaled
def XChrom_model(
    n_cells:int,
    cell_vec:int = 32,
    seq_len:int = 1344,
    show_summary:bool = True,
    noise_factor:float = 0.01
    ):
    """
    XChrom model construction

    Parameters
    ----------
    n_cells: int
        number of cells
    cell_vec: int
        dimension of cell embeddings, default is 32 with rna.obsm['X_pca'],
        it is consistent with the final dimension of the peak embeddings.
        if you want to use other cell embeddings, you can set it to the
        dimension of the cell/peak embeddings.
    seq_len: int
        length of the sequence, default is 1344
    show_summary: bool
        whether to show the summary of the model, default is True
    noise_factor: float
        noise factor for the cell embeddings, default is 0.01

    Returns
    -------
    tf.keras.Model
        Model taking ``(sequence, cell_embed, sequencing_depth)`` and returning
        the sigmoid of the clipped logits, one value per cell.
    """
    sequence = tf.keras.Input(shape=(seq_len, 4), name="sequence")
    cell_input = tf.keras.Input(shape=(n_cells, cell_vec), name='cell_embed')
    # shape must be a tuple; a bare ``(n_cells)`` is just an int
    seq_depth = tf.keras.Input(shape=(n_cells,), name='sequencing_depth')

    # sequence part structure
    (x1, _reverse_bool,) = StochasticReverseComplement()(sequence)
    x1 = StochasticShift(3)(x1)
    x1 = conv_block(x1, filters=288, kernel_size=17, pool_size=3)
    x1 = conv_tower(x1, filters_init=288, filters_mult=1.122, repeat=6, kernel_size=5, pool_size=2)
    x1 = conv_block(x1, filters=256, kernel_size=1)
    x1 = dense_block(x1, flatten=True, units=cell_vec, dropout=0.2)
    x1 = GELU()(x1)
    x1 = tf.squeeze(x1, axis=1)
    x1 = tf.expand_dims(x1, axis=-1)  # (bsz, cell_vec, 1); 32 with the default cell_vec

    # Add noise to cell_embeddings (training phase only; inference sees the raw input)
    noisy_cell_input = tf.keras.layers.Lambda(lambda x: tf.keras.backend.in_train_phase(
        add_noise(x, noise_factor=noise_factor), x))(cell_input)

    x2 = tf.keras.layers.LayerNormalization()(noisy_cell_input)
    x2 = tf.keras.layers.Dense(64, activation='relu')(x2)
    # The layer name is looked up by Callback_TrackScore; do not rename it.
    x2 = tf.keras.layers.Dense(cell_vec, activation='linear', name='final_cellembed')(x2)

    # Depth bias branch: use Reshape to force a distinct tensor copy and avoid
    # GPU memory aliasing with intermediate tensors from the large cell branch
    b = tf.keras.layers.Reshape((n_cells, 1))(seq_depth)
    b = tf.keras.layers.Dense(1, activation='linear')(b)
    b = tf.keras.layers.Reshape((n_cells,))(b)

    multiplied = tf.matmul(x2, x1)
    multiplied = tf.squeeze(multiplied, axis=-1)  # (bsz, num_cells)

    logits = multiplied + b
    # Guard against rare GPU NaN (can occur when large intermediate tensors
    # share memory with b under TF2.6's BFC allocator)
    logits = tf.where(tf.math.is_nan(logits), tf.zeros_like(logits), logits)
    safe_logits = tf.clip_by_value(logits, -15.0, 15.0)
    output = tf.keras.activations.sigmoid(safe_logits)

    model = tf.keras.Model(inputs=(sequence, cell_input, seq_depth), outputs=output)

    if show_summary:
        model.summary()
    return model
class Callback_TrackScore(Callback):
    """
    Callback to track clustering metrics (neighbor score and label score) during XChrom model training.

    This callback computes cell embeddings from the trained model at each epoch and evaluates
    how well the learned embeddings preserve neighborhood structure and cell type labels
    compared to reference scRNA-seq data.

    Parameters
    ----------
    ad_rna : anndata.AnnData
        scRNA-seq data with raw cell representation in .obsm['X_pca'] (or custom key).
        Used to compute scRNA cell neighborhoods, and provide raw cell embedding for prediction.
    ad_atac : anndata.AnnData
        scATAC-seq raw or predicted data that has been assigned initial cell embeddings,
        by default from scRNA-seq in .obsm['zscore32_perpc']
        and must contain cell types, or clustering results from paired scRNA-seq data
    model : tf.keras.Model
        XChrom model being trained, has 'final_cellembed' layer by default.
    print_scores : bool, optional
        Whether to print the scores and the elapsed time after each epoch, default False.
    use_rep_rna : str, optional
        The key name of the scRNA cells dimension reduction, default is 'X_pca',
        to compute and generate scRNA neighbors list.
    use_rep_atac : str, optional
        The key name to store predicted cell embeddings in ad_atac.obsm, default 'x2pred'.
        Usually don't need to change this.
    label : str, optional
        The key name of the cell type labels in ad_atac.obs, default is 'pc32_leiden'.
        Pass e.g. 'celltype' or 'leiden' to use true cell types or another clustering
        from scRNA-seq data. Must be assigned to ad_atac.obs[label]
    cell_input_key : str, optional
        The key name of the cell input data in ad_atac.obsm, default 'zscore32_perpc'.
        Usually don't need to change this. If you want to use the raw cell embedding,
        you can change this to 'X_pca'.

    Notes
    -----
    Default values are optimized for standard XChrom workflow:
    - 'X_pca': Raw cell embedding from scRNA-seq data, saved in ad_rna.obsm['X_pca']
    - 'zscore32_perpc': Z-scored 32 PCs from scRNA-seq data, saved in ad_atac.obsm['zscore32_perpc']
    - 'x2pred': Standard key for predicted cell embeddings, saved in ad_atac.obsm['x2pred']
    - 'pc32_leiden': Cell labels (true cell types or leiden clustering based on scRNA-seq data),
      saved in ad_atac.obs['pc32_leiden']

    Custom variable names can be used by simply changing the parameter values.
    The callback will automatically use your custom keys to access the data.

    Examples
    --------
    Basic usage with explicit keys:

    >>> callback = Callback_TrackScore(
    ...     ad_rna=rna_data,
    ...     ad_atac=atac_data,
    ...     model=xchrom_model,
    ...     use_rep_rna='X_pca',  ## saved in ad_rna.obsm['X_pca'] from scRNA-seq data
    ...     use_rep_atac='x2pred',  ## saved in ad_atac.obsm['x2pred'] from model prediction per epoch
    ...     label='celltype',  ## saved in ad_atac.obs['celltype']
    ...     cell_input_key='zscore32_perpc'  ## saved in ad_atac.obsm['zscore32_perpc']
    ... )
    >>> model.fit(dataset, epochs=10, callbacks=[callback])

    Troubleshooting
    ---------------
    Common issues and solutions:

    1. KeyError about missing keys:
       - Check your data keys: ad_atac.obsm.keys() and ad_atac.obs.keys()
       - Use custom parameter names to match your data

    2. Shape mismatch errors:
       - Ensure cell_input_key data has shape (n_cells, cell_vec)
       - Verify ad_rna and ad_atac have the same number of cells

    3. Model layer not found:
       - Ensure XChrom model has 'final_cellembed' layer, which is the output layer of cell embedding
       - Check model architecture with model.summary()

    Data Requirements
    -----------------
    ad_rna.obsm must contain:
    - Raw cell embedding (default: 'X_pca') for generating initial cell embeddings
      and computing RNA neighbors from scRNA-seq data

    ad_atac.obsm must contain:
    - Initial cell embeddings (default: 'zscore32_perpc') for model prediction from scRNA-seq data

    ad_atac.obs must contain:
    - Cell type labels (default: 'pc32_leiden') for label consistency evaluation,
      which can be from true cell type labels or paired scRNA-seq clustering results

    Output
    ------
    During training, prints (when print_scores is True):
    - Neighbor Score: Fraction of shared neighbors between scRNA and predicted embeddings
    - Label Score: Fraction of neighbors with same cell type label

    Adds to the Keras ``logs`` dict:
    - 'neighbor_score' and 'label_score'

    Saves to ad_atac.obsm[use_rep_atac]:
    - Predicted cell embeddings from current epoch
    """

    def __init__(self, ad_rna, ad_atac, model,print_scores=False,use_rep_rna = 'X_pca',
                 use_rep_atac='x2pred', label='pc32_leiden', cell_input_key='zscore32_perpc'):
        super().__init__()
        self.ad_rna = ad_rna
        self.print_scores = print_scores
        self.use_rep_rna = use_rep_rna
        self.ad_atac = ad_atac
        self.use_rep_atac = use_rep_atac
        self.label = label
        self.cell_input_key = cell_input_key

        # Validate input data keys
        if self.cell_input_key not in ad_atac.obsm:
            raise KeyError(f"Cell input key '{self.cell_input_key}' not found in ad_atac.obsm. "
                           f"Available keys: {list(ad_atac.obsm.keys())}")
        if self.label not in ad_atac.obs:
            raise KeyError(f"Label key '{self.label}' not found in ad_atac.obs. "
                           f"Available keys: {list(ad_atac.obs.keys())}")

        # All cells are fed as a single sample: shape (1, n_cells, cell_vec)
        self.cell_input = tf.convert_to_tensor(
            np.expand_dims(ad_atac.obsm[self.cell_input_key], axis=0),
            dtype=tf.float32
        )
        self.model_ = model
        # Built lazily on first use, see _get_cellembedding_model
        self._cellembedding_model = None

    def _get_cellembedding_model(self):
        """Return the sub-model mapping the cell input to the 'final_cellembed' output."""
        # The sub-model reuses the layers of the model being trained, so it reads
        # the current weights on every call; building it once avoids creating a
        # new model and a new predict function at every epoch.
        if self._cellembedding_model is None:
            self._cellembedding_model = tf.keras.Model(
                inputs=self.model_.input[1],
                outputs=self.model_.get_layer('final_cellembed').output
            )
        return self._cellembedding_model

    def on_epoch_end(self, epoch, logs=None):
        """
        Compute the embeddings, store them in ad_atac.obsm and record both scores.

        Raises
        ------
        ValueError
            If the number of predicted embeddings differs from ad_atac.n_obs.
        """
        total_start_time = time.time()

        cellembedding_model = self._get_cellembedding_model()

        # extract cell embeddings; drop only the leading batch axis of size 1
        embeddings = cellembedding_model.predict(self.cell_input, batch_size=1, steps=1)
        embeddings = np.squeeze(np.asarray(embeddings), axis=0)

        # check the shape of embeddings, then save to ad_atac
        if embeddings.shape[0] != self.ad_atac.n_obs:
            raise ValueError(f"Predicted embeddings shape {embeddings.shape} does not match number of cells {self.ad_atac.n_obs}")
        self.ad_atac.obsm[self.use_rep_atac] = embeddings

        # calculate the cluster metrics(ns,ls score)
        neighbor_score, label_score = calc_nsls_score(
            self.ad_rna,
            self.ad_atac,
            use_rep_atac=self.use_rep_atac,
            use_rep_rna=self.use_rep_rna,
            label=self.label
        )

        # save the cluster metrics to logs
        if logs is not None:
            logs['neighbor_score'] = neighbor_score
            logs['label_score'] = label_score

        total_time = time.time() - total_start_time
        if self.print_scores:
            print(f"Neighbors Score: {neighbor_score:.4f}, Labels Score: {label_score:.4f},",
                  f"\tUsing time in epoch {epoch + 1}: {total_time:.4f}s")


class Callback_SaveModel(Callback):
    """
    Callback to save the model weights every ``save_freq`` epochs.

    Parameters
    ----------
    filepath: Union[str, Path]
        The path to save the model weights. May contain an ``{epoch}``
        placeholder, which is filled with the 1-based epoch number.
    save_freq: int
        The frequency (in epochs) to save the model weights. Must not be 0.

    Raises
    ------
    ValueError
        If ``save_freq`` is 0 or otherwise falsy.

    Examples
    --------
    >>> callback = Callback_SaveModel(filepath='model_weights_{epoch:02d}.h5', save_freq=10)
    >>> model.fit(dataset, epochs=10, callbacks=[callback])
    """

    def __init__(self, filepath, save_freq):
        super().__init__()
        # Fail at construction instead of with ZeroDivisionError after epoch 1
        if not save_freq:
            raise ValueError("save_freq must be a non-zero integer")
        self.filepath = filepath
        self.save_freq = save_freq

    def on_epoch_end(self, epoch, logs=None):
        """Save the weights when the 1-based epoch number is a multiple of save_freq."""
        if (epoch + 1) % self.save_freq == 0:
            # str() so that pathlib.Path objects, which have no .format, work too
            file_path = str(self.filepath).format(epoch=epoch + 1)
            self.model.save_weights(file_path)
            print(f'Saved model to {file_path} at epoch {epoch + 1}')