## train model functions
import numpy as np
import h5py
import tensorflow as tf
import scipy.sparse as sparse
import time
from tensorflow.keras.callbacks import Callback
from ._scbasset_utils import StochasticReverseComplement, StochasticShift, conv_block, conv_tower, dense_block, GELU
from ..tl._utils import calc_nsls_score
class Generator:
    """
    Generate input data for XChrom model training and create TensorFlow datasets.

    This class combines data generation and dataset creation functionality,
    compatible with tf.data.Dataset.from_generator. Each generated element is
    ``((one_hot_sequence, cell_embedding, depth), labels)`` for one peak; the
    cell embedding and depth arrays are the same objects for every peak.

    Parameters
    ----------
    seq_path : str
        The path to the sequence HDF5 file, generated by make_h5_sparse function.
        Its dataset ``'X'`` is read with ``peakid`` as the row index.
    adata : anndata.AnnData
        The anndata object containing:
        - adata.X: The raw count matrix of data
        - adata.obs['b_zscore']: The sequencing depth vector
        - adata.obsm[cell_embed]: The initial cell embedding matrix
    cell_input_key : str
        The key name of the initial cell embedding in adata.obsm
        (e.g. 'zscore32_perpc').
    peakid : array-like, optional
        The array of peak indices used to select rows of the sequence file.
        If None, uses all indices ``0 .. adata.shape[1] - 1``.
    m : scipy.sparse matrix or array-like, optional
        The peak-cell matrix (rows are peaks, columns are cells). Can be loaded
        with sparse.load_npz('m.npz'). Any sparse format or a dense array is
        accepted and is converted to CSR. If None, uses adata.X transposed.
        Row ``i`` of ``m`` is paired with the ``i``-th selected sequence.
    batch_size : int, default 128
        The batch size for the dataset

    Attributes
    ----------
    n_cells : int
        The number of cells
    n_peaks : int
        The number of peaks

    Examples
    --------
    >>> gen1 = Generator(seq_path='sequence.h5', adata=ad_train, m=m_train,
    ...                  cell_input_key='zscore32_perpc', batch_size=256)
    >>> train_ds = gen1.create_dataset(shuffle=True)
    >>> gen2 = Generator(seq_path='sequence.h5', adata=ad_val, m=m_val,
    ...                  cell_input_key='zscore32_perpc', batch_size=256)
    >>> val_ds = gen2.create_dataset(shuffle=False)
    >>> # The datasets are already batched, so do not pass batch_size to fit().
    >>> model.fit(train_ds, validation_data=val_ds,
    ...           epochs=1000, callbacks=callbacks_list)
    """

    def __init__(self, seq_path, adata, cell_input_key,
                 peakid=None, m=None, batch_size=128):
        """
        Load sequences, labels and per-cell inputs needed by the generator.

        Raises
        ------
        KeyError
            If ``cell_input_key`` is missing from ``adata.obsm`` or
            ``'b_zscore'`` is missing from ``adata.obs``.
        IOError
            If the sequence file cannot be opened or read.
        """
        self.seq_path = seq_path
        self.batch_size = batch_size

        # set peak index to extract adata.X data
        if peakid is None:
            self.peakid = list(range(adata.shape[1]))
        else:
            self.peakid = np.array(peakid)

        # set peak-cell matrix; __call__ slices indices/indptr row-wise, which
        # is only meaningful for CSR, so every input is normalised to CSR.
        if m is None:
            print("Converting adata.X to a sparse CSR peak-by-cell matrix. For large datasets, consider pre-computing and saving as sparse matrix.")
            m = adata.X.T
        self.m = self._to_csr(m)

        # extract cell embedding and other data
        try:
            self.cellembed = adata.obsm[cell_input_key]
        except KeyError as err:
            raise KeyError(f"Cell embedding '{cell_input_key}' not found in adata.obsm. "
                           f"Available keys: {list(adata.obsm.keys())}") from err
        try:
            self.b = np.array(adata.obs['b_zscore'].values)
        except KeyError as err:
            raise KeyError("'b_zscore' not found in adata.obs. "
                           "Please ensure the data is properly preprocessed.") from err

        self.n_cells = adata.shape[0]
        self.n_peaks = len(self.peakid)

        # open HDF5 file and get data
        try:
            self.file = h5py.File(seq_path, 'r')
            self.data = self.file['X'][self.peakid]
        except Exception as e:
            # Do not leave a half-initialised object holding an open handle.
            self.close()
            raise IOError(f"Error reading sequence file {seq_path}: {str(e)}") from e

    @staticmethod
    def _to_csr(m):
        """Return ``m`` as a CSR matrix (no copy when it already is CSR)."""
        if sparse.issparse(m):
            return m.tocsr()
        return sparse.csr_matrix(m)

    def __call__(self):
        """Generator function, used by tf.data.Dataset.from_generator"""
        n_rows, seq_len = self.data.shape[0], self.data.shape[1]
        # Loop-invariant pieces of the one-hot construction.
        ones = np.ones(seq_len, dtype='int8')
        rows = np.arange(seq_len)
        indptr, indices = self.m.indptr, self.m.indices

        for i in range(n_rows):
            # get sequence data and convert to one-hot encoding:
            # position r gets a 1 in column data[i][r]
            x_seq = sparse.coo_matrix(
                (ones, (rows, self.data[i])),
                shape=(seq_len, 4),
                dtype='int8'
            ).toarray()

            # get labels: every cell with a stored entry in row i is marked 1
            y_labels = np.zeros(self.n_cells, dtype='int8')
            y_labels[indices[indptr[i]:indptr[i + 1]]] = 1

            yield (x_seq, self.cellembed, self.b), y_labels

    def create_dataset(self, shuffle=False):
        """
        Create a TensorFlow dataset from the generator.

        Parameters
        ----------
        shuffle : bool, default False
            Whether to shuffle the dataset (buffer of 2000 elements,
            reshuffled each iteration)

        Returns
        -------
        tf.data.Dataset
            The configured TensorFlow dataset for training or prediction,
            batched by ``batch_size`` and prefetched.
        """
        seq_len = self.data.shape[1]
        cell_vec = self.cellembed.shape[1]
        dataset = tf.data.Dataset.from_generator(
            self,
            output_signature=(
                (tf.TensorSpec(shape=(seq_len, 4), dtype=tf.float32),
                 tf.TensorSpec(shape=(self.n_cells, cell_vec), dtype=tf.float32),
                 tf.TensorSpec(shape=(self.n_cells,), dtype=tf.float32)),
                tf.TensorSpec(shape=(self.n_cells,), dtype=tf.int8)
            )
        )
        if shuffle:
            dataset = dataset.shuffle(2000, reshuffle_each_iteration=True)
        return dataset.batch(self.batch_size).prefetch(tf.data.AUTOTUNE)

    def get_dataset_info(self):
        """
        Get information about the dataset.

        Returns
        -------
        dict
            Dictionary containing dataset information
        """
        return {
            'n_cells': self.n_cells,
            'n_peaks': self.n_peaks,
            'seq_length': self.data.shape[1],
            'cell_embedding_dim': self.cellembed.shape[1],
            'batch_size': self.batch_size,
        }

    def __enter__(self):
        """Context manager entry"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - close HDF5 file"""
        self.close()

    def close(self):
        """Explicitly close the HDF5 file (no-op if it was never opened)"""
        handle = getattr(self, 'file', None)
        if handle is not None:
            handle.close()

    def __del__(self):
        """Destructor to ensure file is closed"""
        self.close()
def add_noise(inputs, noise_factor=0.01):
    """Return ``inputs`` plus ``noise_factor`` times a random-normal tensor of the same shape."""
    gaussian = tf.random.normal(shape=tf.shape(inputs))
    perturbation = noise_factor * gaussian
    return inputs + perturbation
def XChrom_model(
    n_cells: int,
    cell_vec: int = 32,
    seq_len: int = 1344,
    show_summary: bool = True,
    noise_factor: float = 0.01,
):
    """
    XChrom model construction

    The model takes three inputs (one-hot sequence, cell embeddings,
    sequencing depth) and outputs one sigmoid probability per cell.

    Parameters
    ----------
    n_cells: int
        number of cells
    cell_vec: int
        dimension of cell embeddings, default is 32 with rna.obsm['X_pca'], it is consistent with the final dimension of the peak embeddings.
        if you want to use other cell embeddings, you can set it to the dimension of the cell/peak embeddings.
    seq_len: int
        length of the sequence, default is 1344
    show_summary: bool
        whether to show the summary of the model, default is True
    noise_factor: float
        noise factor for the cell embeddings, default is 0.01

    Returns
    -------
    tf.keras.Model
        Model with inputs ``(sequence, cell_embed, sequencing_depth)`` and an
        output of shape ``(batch, n_cells)``.
    """
    sequence = tf.keras.Input(shape=(seq_len, 4), name="sequence")
    cell_input = tf.keras.Input(shape=(n_cells, cell_vec), name='cell_embed')
    # shape must be a tuple; "(n_cells)" without the comma is a bare int.
    seq_depth = tf.keras.Input(shape=(n_cells,), name='sequencing_depth')

    # sequence part structure
    x1, _ = StochasticReverseComplement()(sequence)
    x1 = StochasticShift(3)(x1)
    x1 = conv_block(x1, filters=288, kernel_size=17, pool_size=3)
    x1 = conv_tower(x1, filters_init=288, filters_mult=1.122, repeat=6, kernel_size=5, pool_size=2)
    x1 = conv_block(x1, filters=256, kernel_size=1)
    x1 = dense_block(x1, flatten=True, units=cell_vec, dropout=0.2)
    x1 = GELU()(x1)
    x1 = tf.squeeze(x1, axis=1)
    x1 = tf.expand_dims(x1, axis=-1)  # (bsz, cell_vec, 1)

    # Add noise to cell_embeddings (training phase only)
    noisy_cell_input = tf.keras.layers.Lambda(lambda x: tf.keras.backend.in_train_phase(
        add_noise(x, noise_factor=noise_factor), x))(cell_input)
    x2 = tf.keras.layers.LayerNormalization()(noisy_cell_input)
    x2 = tf.keras.layers.Dense(64, activation='relu')(x2)
    x2 = tf.keras.layers.Dense(cell_vec, activation='linear', name='final_cellembed')(x2)

    # Depth bias branch: use Reshape to force a distinct tensor copy and avoid
    # GPU memory aliasing with intermediate tensors from the large cell branch
    b = tf.keras.layers.Reshape((n_cells, 1))(seq_depth)
    b = tf.keras.layers.Dense(1, activation='linear')(b)
    b = tf.keras.layers.Reshape((n_cells,))(b)

    multiplied = tf.matmul(x2, x1)
    multiplied = tf.squeeze(multiplied, axis=-1)  # (bsz, num_cells)
    logits = multiplied + b

    # Guard against rare GPU NaN (can occur when large intermediate tensors
    # share memory with b under TF2.6's BFC allocator)
    logits = tf.where(tf.math.is_nan(logits), tf.zeros_like(logits), logits)
    safe_logits = tf.clip_by_value(logits, -15.0, 15.0)
    output = tf.keras.activations.sigmoid(safe_logits)

    model = tf.keras.Model(inputs=(sequence, cell_input, seq_depth), outputs=output)
    if show_summary:
        model.summary()
    return model
class Callback_TrackScore(Callback):
    """
    Callback to track clustering metrics (neighbor score and label score) during XChrom model training.

    This callback computes cell embeddings from the trained model at each epoch and evaluates
    how well the learned embeddings preserve neighborhood structure and cell type labels
    compared to reference scRNA-seq data.

    Parameters
    ----------
    ad_rna : anndata.AnnData
        scRNA-seq data with raw cell representation in .obsm['X_pca'] (or custom key).
        Used to compute scRNA cell neighborhoods, and provide raw cell embedding for prediction.
    ad_atac : anndata.AnnData
        scATAC-seq raw or predicted data that has been assigned initial cell embeddings, by default from scRNA-seq in .obsm['zscore32_perpc']
        and must contain cell types, or clustering results from paired scRNA-seq data
    model : tf.keras.Model
        XChrom model being trained, has 'final_cellembed' layer by default.
    print_scores : bool, optional
        Whether to print the scores and the time spent at the end of each epoch, default is False.
    use_rep_rna : str, optional
        The key name of the scRNA cells dimension reduction, default is 'X_pca', to compute and generate scRNA neighbors list.
    use_rep_atac : str, optional
        The key name to store predicted cell embeddings in ad_atac.obsm, default 'x2pred'.
        Usually don't need to change this.
    label : str, optional
        The key name of the cell type labels in ad_atac.obs, default is 'pc32_leiden'.
        Pass e.g. 'celltype' to use true cell type labels, or 'leiden' from scRNA-seq data.
        Should be assigned to ad_atac.obs[label]
    cell_input_key : str, optional
        The key name of the cell input data in ad_atac.obsm, default 'zscore32_perpc'.
        Usually don't need to change this. If you want to use the raw cell embedding, you can change this to 'X_pca'.

    Raises
    ------
    KeyError
        If ``cell_input_key`` is not in ``ad_atac.obsm`` or ``label`` is not in ``ad_atac.obs``.

    Notes
    -----
    Keys used in the standard XChrom workflow:
    - 'X_pca': Raw cell embedding from scRNA-seq data, saved in ad_rna.obsm['X_pca']
    - 'zscore32_perpc': Z-scored 32 PCs from scRNA-seq data, saved in ad_atac.obsm['zscore32_perpc']
    - 'x2pred': Standard key for predicted cell embeddings, saved in ad_atac.obsm['x2pred']
    - 'celltype': True cell types, or leiden clustering based on scRNA-seq data, saved in ad_atac.obs['celltype']
      (must be passed explicitly as ``label='celltype'``)
    Custom variable names can be used by simply changing the parameter values.
    The callback will automatically use your custom keys to access the data.

    Examples
    --------
    Basic usage:
    >>> callback = Callback_TrackScore(
    ...     ad_rna=rna_data,
    ...     ad_atac=atac_data,
    ...     model=xchrom_model,
    ...     use_rep_rna='X_pca',  ## saved in ad_rna.obsm['X_pca'] from scRNA-seq data
    ...     use_rep_atac='x2pred',  ## saved in ad_atac.obsm['x2pred'] from model prediction per epoch
    ...     label='celltype',  ## saved in ad_atac.obs['celltype']
    ...     cell_input_key='zscore32_perpc'  ## saved in ad_atac.obsm['zscore32_perpc']
    ... )
    >>> model.fit(dataset, epochs=10, callbacks=[callback])

    Troubleshooting
    ---------------
    Common issues and solutions:
    1. KeyError about missing keys:
    - Check your data keys: ad_atac.obsm.keys() and ad_atac.obs.keys()
    - Use custom parameter names to match your data
    2. Shape mismatch errors:
    - Ensure cell_input_key data has shape (n_cells, cell_vec)
    - Verify ad_rna and ad_atac have the same number of cells
    3. Model layer not found:
    - Ensure XChrom model has 'final_cellembed' layer, which is the output layer of cell embedding
    - Check model architecture with model.summary()

    Data Requirements
    -----------------
    ad_rna.obsm must contain:
    - Raw cell embedding (default: 'X_pca') for generating initial cell embeddings and computing RNA neighbors from scRNA-seq data
    ad_atac.obsm must contain:
    - Initial cell embeddings (default: 'zscore32_perpc') for model prediction from scRNA-seq data
    ad_atac.obs must contain:
    - Cell type labels (default: 'pc32_leiden') for label consistency evaluation, which can be from true cell type labels or paired scRNA-seq clustering results

    Output
    ------
    During training, prints (when print_scores is True):
    - Neighbor Score: Fraction of shared neighbors between scRNA and predicted embeddings
    - Label Score: Fraction of neighbors with same cell type label
    Saves to ad_atac.obsm[use_rep_atac]:
    - Predicted cell embeddings from current epoch
    Adds to the epoch ``logs`` dict:
    - 'neighbor_score' and 'label_score'
    """

    def __init__(self, ad_rna, ad_atac, model, print_scores=False, use_rep_rna='X_pca',
                 use_rep_atac='x2pred', label='pc32_leiden', cell_input_key='zscore32_perpc'):
        super().__init__()
        self.ad_rna = ad_rna
        self.print_scores = print_scores
        self.use_rep_rna = use_rep_rna
        self.ad_atac = ad_atac
        self.use_rep_atac = use_rep_atac
        self.label = label
        self.cell_input_key = cell_input_key

        # Validate input data keys
        if self.cell_input_key not in ad_atac.obsm:
            raise KeyError(f"Cell input key '{self.cell_input_key}' not found in ad_atac.obsm. "
                           f"Available keys: {list(ad_atac.obsm.keys())}")
        if self.label not in ad_atac.obs:
            raise KeyError(f"Label key '{self.label}' not found in ad_atac.obs. "
                           f"Available keys: {list(ad_atac.obs.keys())}")

        # A single "batch" holding the full (n_cells, cell_vec) embedding matrix.
        self.cell_input = tf.convert_to_tensor(
            np.expand_dims(ad_atac.obsm[self.cell_input_key], axis=0),
            dtype=tf.float32
        )
        self.model_ = model
        # Built on first use; see on_epoch_end.
        self._cellembedding_model = None

    def on_epoch_end(self, epoch, logs=None):
        """Predict cell embeddings, store them in ad_atac.obsm and record both scores."""
        total_start_time = time.time()

        # The sub-model reuses the trained model's layer objects, so it always
        # reads the current weights; building it once avoids constructing a
        # new Keras model and predict function on every epoch.
        if self._cellembedding_model is None:
            self._cellembedding_model = tf.keras.Model(
                inputs=self.model_.input[1],
                outputs=self.model_.get_layer('final_cellembed').output
            )

        # extract cell embeddings; drop only the leading batch axis so that
        # size-1 cell or embedding dimensions are preserved
        embeddings = self._cellembedding_model.predict(self.cell_input, batch_size=1, steps=1)
        embeddings = np.asarray(embeddings)[0]

        # check the shape of embeddings, then save to ad_atac
        if embeddings.shape[0] != self.ad_atac.n_obs:
            raise ValueError(f"Predicted embeddings shape {embeddings.shape} does not match number of cells {self.ad_atac.n_obs}")
        self.ad_atac.obsm[self.use_rep_atac] = embeddings

        # calculate the cluster metrics(ns,ls score)
        neighbor_score, label_score = calc_nsls_score(
            self.ad_rna, self.ad_atac,
            use_rep_atac=self.use_rep_atac,
            use_rep_rna=self.use_rep_rna,
            label=self.label,
        )

        # save the cluster metrics to logs
        if logs is not None:
            logs['neighbor_score'] = neighbor_score
            logs['label_score'] = label_score

        total_time = time.time() - total_start_time
        if self.print_scores:
            print(f"Neighbors Score: {neighbor_score:.4f}, Labels Score: {label_score:.4f},",
                  f"\tUsing time in epoch {epoch + 1}: {total_time:.4f}s")
class Callback_SaveModel(Callback):
    """
    Callback to save the model weights every ``save_freq`` epochs.

    Parameters
    ----------
    filepath: Union[str, Path]
        The path to save the model weights. It is used as a ``str.format``
        template and may contain an ``{epoch}`` field, which is filled with
        the 1-based epoch number.
    save_freq: int
        The frequency (in epochs) to save the model weights.

    Examples
    --------
    >>> callback = Callback_SaveModel(filepath='model_weights_{epoch:02d}.h5', save_freq=10)
    >>> model.fit(dataset, epochs=10, callbacks=[callback])
    """

    def __init__(self, filepath, save_freq):
        super().__init__()
        self.filepath = filepath
        self.save_freq = save_freq

    def on_epoch_end(self, epoch, logs=None):
        """Save weights when the 1-based epoch number is a multiple of save_freq."""
        if (epoch + 1) % self.save_freq != 0:
            return
        # str() so that pathlib.Path templates work; Path has no .format().
        file_path = str(self.filepath).format(epoch=epoch + 1)
        self.model.save_weights(file_path)
        print(f'Saved model to {file_path} at epoch {epoch + 1}')