# Source code for xchrom.tr.train

#!/usr/bin/env python
"""
XChrom model training script

Run as a standalone script:
CUDA_VISIBLE_DEVICES="0" python -m xchrom.tr.train \
    --input_folder  ./data/1_within_sample/train_data/  \
    --cell_embedding_ad  ./data/1_within_sample/m_brain_paired_rna.h5ad \
    --cellembed_raw  X_pca  \
    --bottleneck  32 \
    --out_path  ./data/1_within_sample/train_out/ \
    --epochs  1000\
    --save_freq  1000 \
    --verbose  0 # silent mode, no progress bar

Import as a module:
    >>> import xchrom as xc
    >>> result = xc.tr.train_XChrom(
    ...     input_folder='./data/1_within_sample/train_data/',
    ...     cell_embedding_ad='./data/1_within_sample/m_brain_paired_rna.h5ad',
    ...     cellembed_raw='X_pca',
    ...     out_path='./data/1_within_sample/train_out/',
    ...     trackscore=True,
    ...     celltype='pc32_leiden',
    ...     epochs=1000,
    ...     save_freq=1000,
    ...     verbose=0,  # silent mode, no progress bar
    ... )
    >>> history = result['history']
"""

import anndata
import h5py
import tensorflow as tf
import numpy as np
import scipy.sparse as sparse
import pickle
import configargparse
import random
from pathlib import Path
from typing import Union, Dict, Any, Literal
from scipy import stats
import os
try:
    from .._utils import setup_seed
    from ._utils import Generator, XChrom_model, Callback_TrackScore, Callback_SaveModel
except ImportError:
    import sys
    sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
    from xchrom._utils import setup_seed
    from xchrom.tr._utils import Generator, XChrom_model, Callback_TrackScore, Callback_SaveModel

# os.environ["CUDA_VISIBLE_DEVICES"]="3,2,1,0"  

def train_XChrom(
    input_folder: Union[str, Path],
    cell_embedding_ad: Union[str, Path],
    out_path: Union[str, Path] = './train_out',
    bottleneck: int = 32,
    batch_size: int = 128,
    lr: float = 0.01,
    epochs: int = 1000,
    save_freq: int = 1000,
    trackscore: bool = False,
    celltype: str = 'celltype',
    seed: int = 20,
    train_split: float = 0.9,
    cellembed_raw: str = 'X_pca',
    verbose: Literal[0, 1, 2] = 1,
    print_scores: bool = False,
    **kwargs
) -> Dict[str, Any]:
    """
    Train XChrom model.

    Parameters
    ----------
    input_folder: Union[str, Path]
        Preprocessed data folder, should contain:
        trainval_seqs.h5, splits.h5, ad_trainval.h5ad, m_trainval.npz
    cell_embedding_ad: Union[str, Path]
        scRNA-seq .h5ad file containing the raw cell embedding in
        ``.obsm[cellembed_raw]``. NOTE: this file is rewritten in place with
        an additional ``.obsm['zscore32_perpc']`` entry holding the
        column-wise z-scored embedding.
    out_path: Union[str, Path], default './train_out'
        Output directory; an ``epoch_model/`` sub-directory is created in it.
    bottleneck: int, default 32
        Bottleneck layer size, should be the same as the dimension of the raw
        cell embedding. Passed to ``XChrom_model`` as ``cell_vec``.
    batch_size: int, default 128
        Batch size used by the training and validation generators.
    lr: float, default 0.01
        Initial learning rate of an exponential-decay schedule
        (decay_rate=0.9 per 10000 optimizer steps).
    epochs: int, default 1000
        Maximum number of training epochs (early stopping may end sooner).
    save_freq: int, default 1000
        Epoch interval handed to ``Callback_SaveModel`` for snapshots saved
        under ``out_path/epoch_model/``.
    trackscore: bool, default False
        Whether to compute score metrics every epoch.
    celltype: str, default 'celltype'
        Cell type label column name (used when trackscore=True). Looked up in
        the scATAC ``.obs`` first and copied from the scRNA ``.obs`` if absent.
    seed: int, default 20
        Random seed passed to ``setup_seed``.
    train_split: float, default 0.9
        Fraction of peaks used for training; the remainder is used for
        validation. Must be strictly between 0 and 1.
    cellembed_raw: str, default 'X_pca'
        Raw cell embedding key in ``.obsm`` of the cell embedding adata.
    verbose: int, default 1
        Training verbosity mode.
        0=silent, 1=progress bar, 2=one line per epoch
    print_scores: bool, default False
        Whether to print ns,ls scores every epoch when trackscore=True.
    **kwargs: dict
        Additional keyword arguments forwarded to ``XChrom_model``.

    Returns
    -------
    Dict[str, Any]
        ``history`` (the Keras ``History.history`` dict), ``model`` (the
        trained model as left by ``model.fit``; the best weights are not
        restored into it), ``best_model_path`` (absolute path of the
        best-checkpoint weights file), ``train_cells_number``,
        ``train_peaks_number`` and ``val_peaks_number``.

    Raises
    ------
    ValueError
        If ``train_split`` is not in (0, 1), either peak split is empty,
        ``cellembed_raw`` is missing from the RNA data, the scATAC and scRNA
        cell indices disagree, or ``celltype`` is missing while
        ``trackscore=True``.
    FileNotFoundError
        If a required input file does not exist.

    Examples
    --------
    >>> import xchrom as xc
    >>> result = xc.tr.train_XChrom(
    ...     input_folder='./data/1_within_sample/train_data/',
    ...     cell_embedding_ad='./data/1_within_sample/m_brain_paired_rna.h5ad',
    ...     cellembed_raw='X_pca',
    ...     out_path='./data/1_within_sample/train_out/',
    ...     trackscore=True,
    ...     celltype='pc32_leiden',
    ...     epochs=1000,
    ...     save_freq=1000,
    ...     verbose=0,  # silent mode, no progress bar
    ...     print_scores=False,  # print ns,ls scores every epoch when trackscore=True
    ... )
    """
    # Reject an unusable split ratio up front: otherwise one of the peak
    # subsets is empty and the failure surfaces deep inside the generators.
    if not 0.0 < train_split < 1.0:
        raise ValueError(
            f"train_split must be strictly between 0 and 1, got {train_split}"
        )

    # Convert paths to Path objects
    input_folder = Path(input_folder)
    cell_embedding_ad = Path(cell_embedding_ad)
    out_path = Path(out_path)

    # Verify input files
    required_files = [
        'trainval_seqs.h5',
        'splits.h5',
        'ad_trainval.h5ad',
        'm_trainval.npz',
    ]
    for file in required_files:
        if not (input_folder / file).exists():
            raise FileNotFoundError(f"Required file not found: {input_folder / file}")
    if not cell_embedding_ad.exists():
        raise FileNotFoundError(f"Cell embedding file not found: {cell_embedding_ad}")

    # Seed every RNG before any randomness (peak shuffling, model init) is used.
    setup_seed(seed)

    # Create output directory (makedirs also creates out_path itself).
    os.makedirs(out_path / 'epoch_model', exist_ok=True)

    print("=== Start training XChrom model ===")
    print(f"Input folder: {os.path.abspath(input_folder)}")
    print(f"Cell embedding file: {os.path.abspath(cell_embedding_ad)}")
    print(f"Raw cell embedding key: {cellembed_raw}")
    print(f"Output path: {os.path.abspath(out_path)}")
    print(f"Model parameters: bottleneck={bottleneck}, batch_size={batch_size}, lr={lr}")

    # 1. Load raw cell embedding and make z-score normalization
    print("1. Load raw cell embedding and make z-score normalization...")
    rna_ad = anndata.read_h5ad(cell_embedding_ad)
    if cellembed_raw not in rna_ad.obsm:
        raise ValueError(f"Embedding key '{cellembed_raw}' not found in RNA data")
    print(f"Raw cell embedding loaded from: {os.path.abspath(cell_embedding_ad)}.obsm['{cellembed_raw}']")
    # Standardise each embedding column independently across cells.
    zscore32_perpc = stats.zscore(np.array(rna_ad.obsm[cellembed_raw]), axis=0)
    rna_ad.obsm['zscore32_perpc'] = zscore32_perpc
    # NOTE: overwrites the user's input file in place so the normalised
    # embedding is persisted next to the raw one.
    rna_ad.write_h5ad(cell_embedding_ad)
    print(f"Initial cell embedding saved to: {os.path.abspath(cell_embedding_ad)}.obsm['zscore32_perpc']")
    print(f"Initial cell embedding shape: {rna_ad.obsm['zscore32_perpc'].shape}")

    # 2. Load training data
    print("2. Load training data...")
    with h5py.File(input_folder / 'splits.h5', 'r') as hf:
        trainval_cellid = hf['trainval_cell'][:]
    rna_trainval = rna_ad[trainval_cellid, :]
    trainval_seq = str(input_folder / 'trainval_seqs.h5')
    ad_trainval = anndata.read_h5ad(input_folder / 'ad_trainval.h5ad')
    m_trainval = sparse.load_npz(input_folder / 'm_trainval.npz').tocsr()

    # Verify data consistency
    if not ad_trainval.obs.index.equals(rna_trainval.obs.index):
        raise ValueError("scATAC and scRNA data cell indices do not match")
    ad_trainval.obsm['zscore32_perpc'] = rna_trainval.obsm['zscore32_perpc']

    # 3. Prepare training/validation split (peaks are rows of m_trainval and
    # columns of ad_trainval, per the indexing below).
    print("3. Prepare train/val data split...")
    peak_ids = list(range(m_trainval.shape[0]))
    random.shuffle(peak_ids)
    train_size = int(train_split * len(peak_ids))
    train_id = sorted(peak_ids[:train_size])
    val_id = sorted(peak_ids[train_size:])
    if not train_id or not val_id:
        raise ValueError(
            f"train_split={train_split} leaves an empty train or validation set "
            f"for {len(peak_ids)} peaks"
        )
    print(f"Training peak number: {len(train_id)}, Validation peak number: {len(val_id)}")

    # Prepare data subset
    ad_train = ad_trainval[:, train_id]
    ad_val = ad_trainval[:, val_id]
    m_train = m_trainval[train_id, :]
    m_val = m_trainval[val_id, :]
    train_cell = ad_train.shape[0]

    # 4. Create TensorFlow dataset
    print("4. Create TensorFlow dataset...")
    gen1 = Generator(
        adata=ad_train,
        seq_path=trainval_seq,
        cell_input_key='zscore32_perpc',
        peakid=train_id,
        m=m_train,
        batch_size=batch_size,
    )
    train_ds = gen1.create_dataset(shuffle=True)
    gen2 = Generator(
        adata=ad_val,
        seq_path=trainval_seq,
        cell_input_key='zscore32_perpc',
        peakid=val_id,
        m=m_val,
        batch_size=batch_size,
    )
    val_ds = gen2.create_dataset(shuffle=False)

    # 5. Build and compile model
    print("5. Build and compile model...")
    model = XChrom_model(n_cells=train_cell, cell_vec=bottleneck, **kwargs)
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=lr,
        decay_steps=10000,
        decay_rate=0.9,
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(
            learning_rate=lr_schedule,
            beta_1=0.95,
            beta_2=0.9995,
            clipnorm=1.0,
        ),
        loss='binary_crossentropy',
        metrics=[
            'binary_accuracy',
            tf.keras.metrics.AUC(name='auc', curve='ROC', multi_label=False),
            tf.keras.metrics.AUC(name='pr', curve='PR', multi_label=False),
        ],
    )

    # 6. Set training callbacks
    print("6. Set training callbacks...")
    best_model_path = out_path / 'E1000best_model.h5'
    save_freq = int(save_freq)
    # Both callbacks monitor the *training* ROC-AUC ('auc'), not 'val_auc'.
    # ModelCheckpoint has no `restore_best_weights` option (tf.keras 2.x
    # silently ignored it here, Keras 3 rejects it), so it is not passed.
    callbacks_list = [
        tf.keras.callbacks.ModelCheckpoint(
            str(best_model_path),
            save_best_only=True,
            save_weights_only=True,
            monitor='auc',
            mode='max',
        ),
        tf.keras.callbacks.EarlyStopping(
            monitor='auc',
            min_delta=1e-6,
            mode='max',
            patience=50,
            verbose=1,
        ),
        Callback_SaveModel(
            str(out_path / 'epoch_model' / 'epoch{epoch:04d}_model.h5'),
            save_freq,
        ),
    ]
    if trackscore:
        if celltype not in ad_trainval.obs.columns:
            if celltype not in rna_trainval.obs.columns:
                raise ValueError(f"Cell type column '{celltype}' not found in either RNA or AD data, which is required when trackscore=True")
            # Cell indices were verified equal above, so this aligns row-for-row.
            ad_trainval.obs[celltype] = rna_trainval.obs[celltype]
        callbacks_list.append(
            Callback_TrackScore(
                rna_trainval,
                ad_trainval,
                model,
                print_scores=print_scores,
                use_rep_rna=cellembed_raw,
                label=celltype,
                cell_input_key='zscore32_perpc',
            )
        )

    # 7. Start training (batching is done by the datasets, so fit() gets no
    # batch_size, as Keras documents for tf.data inputs).
    print("7. Start training...")
    print(f"Model will be saved to: {best_model_path}")
    history = model.fit(
        train_ds,
        epochs=epochs,
        validation_data=val_ds,
        callbacks=callbacks_list,
        verbose=verbose,
    )

    # 8. Save results
    print("8. Save training results...")
    history_path = out_path / 'history.pickle'
    with open(history_path, 'wb') as f:
        pickle.dump(history.history, f)

    print("=== Training completed! ===")
    print(f"Best model: {os.path.abspath(best_model_path)}")
    print(f"Training history: {os.path.abspath(history_path)}")

    return {
        'history': history.history,
        'model': model,
        'best_model_path': os.path.abspath(best_model_path),
        'train_cells_number': train_cell,
        'train_peaks_number': len(train_id),
        'val_peaks_number': len(val_id),
    }
def make_parser():
    """Create the command-line argument parser for XChrom training.

    Every option maps one-to-one onto a keyword argument of ``train_XChrom``.

    Returns
    -------
    configargparse.ArgParser
        The configured parser.
    """
    parser = configargparse.ArgParser(
        description="Train XChrom model - can be run as a standalone script or imported as a module"
    )
    parser.add_argument('--input_folder', type=str, required=True,
                        help='Preprocessed data folder, should contain: trainval_seqs.h5, splits.h5, ad_trainval.h5ad, m_trainval.npz')
    parser.add_argument('--cell_embedding_ad', type=str, required=True,
                        help='scRNA-seq data file path containing raw cell embedding')
    parser.add_argument('--out_path', type=str, default='train_out',
                        help='Output path, default to ./train_out/')
    parser.add_argument('--bottleneck', type=int, default=32,
                        help='Bottleneck layer size, default to 32')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='Batch size for training, default to 128')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='Learning rate, default to 0.01')
    parser.add_argument('--epochs', type=int, default=1000,
                        help='Number of training epochs, default to 1000')
    parser.add_argument('--save_freq', type=int, default=1000,
                        help='Model saving frequency, default to 1000 epochs, just save the best model')
    parser.add_argument('--trackscore', action='store_true',
                        help='Whether to compute ns,ls score metrics every epoch, default to False')
    # NOTE: the CLI default ('cell_type') differs from train_XChrom's own
    # default ('celltype'); kept as-is for backward compatibility.
    parser.add_argument('--celltype', type=str, default='cell_type',
                        help='Cell type label column name, required when trackscore=True')
    parser.add_argument('--seed', type=int, default=20,
                        help='Random seed, default to 20')
    parser.add_argument('--train_split', type=float, default=0.9,
                        help='Training set/validation set ratio, default to 0.9')
    parser.add_argument('--cellembed_raw', type=str, default='X_pca',
                        help='Raw cell embedding key in RNA data, default to X_pca')
    parser.add_argument('--verbose', type=int, default=1, choices=[0, 1, 2],
                        help='Training verbosity mode. 0=silent, 1=progress bar, 2=one line per epoch, default to 1')
    # Exposes train_XChrom's print_scores option, previously unreachable from the CLI.
    parser.add_argument('--print_scores', action='store_true',
                        help='Whether to print ns,ls scores every epoch when --trackscore is set, default to False')
    return parser


def main():
    """Command-line entry point: parse arguments, train, and print summary statistics.

    Any exception raised by training is reported on stdout and then re-raised.
    """
    parser = make_parser()
    args = parser.parse_args()

    # call training function
    try:
        result = train_XChrom(
            input_folder=args.input_folder,
            cell_embedding_ad=args.cell_embedding_ad,
            out_path=args.out_path,
            bottleneck=args.bottleneck,
            batch_size=args.batch_size,
            lr=args.lr,
            epochs=args.epochs,
            save_freq=args.save_freq,
            trackscore=args.trackscore,
            celltype=args.celltype,
            seed=args.seed,
            train_split=args.train_split,
            cellembed_raw=args.cellembed_raw,
            verbose=args.verbose,
            print_scores=args.print_scores,
        )
    except Exception as e:
        print(f"Error during training: {e}")
        raise

    print("\n=== Training statistics ===")
    print(f"Training cell number: {result['train_cells_number']}")
    print(f"Training peak number: {result['train_peaks_number']}")
    print(f"Validation peak number: {result['val_peaks_number']}")


if __name__ == "__main__":
    main()