Source code for xchrom.tl._tf_activity

import os
from pathlib import Path
from typing import Union
import anndata
import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm import tqdm
from scipy import stats
from Bio import SeqIO
from .._utils import setup_seed
from ..tr._utils import XChrom_model, Generator
from ._utils import fasta_to_h5, generate_h5_files, generate_tf_activity_data
import scipy.sparse

[docs]def calc_tf_activity( motif_dir: Union[str, Path], background_fasta: Union[str, Path], model_path: Union[str, Path], ad_rna: anndata.AnnData, output_file: Union[str, Path] = Path('/tf_activity.h5ad'), cell_embed_raw: str = 'X_pca_harmony', regenerate_motif_h5: bool = False, regenerate_bg_h5: bool = False, seq_len: int = 1344, seed: int = 20, **model_kwargs ) -> anndata.AnnData: """ Calculate motif insertion scores. Parameters ---------- motif_dir: Union[str, Path] The path to the directory containing motif insertion fasta files background_fasta: Union[str, Path] The path to the background sequence fasta file model_path: Union[str, Path] The path to the trained XChrom model weights file ad_rna: anndata.AnnData scRNA-seq data, must contain raw cell embedding in ad_rna.obsm output_file: Union[str, Path], default='./tf_activity.h5ad' Output file path for TF activity results cell_embed_raw: str, default='X_pca_harmony' The key name of raw cell embedding from ad_rna.obsm regenerate_motif_h5: bool, default False Whether to regenerate motif insertion sequence h5 files, if False, will use existing h5 files regenerate_bg_h5: bool, default False Whether to regenerate background sequence h5 files, if False, will use existing h5 file seq_len: int, default 1344 Sequence length for background sequence and motif insertion sequence seed: int, default 20 Random seed **model_kwargs: dict Additional parameters passed to XChrom_model Returns ------- anndata.AnnData TF activity results, X is the activity matrix of cells × motifs save to './tf_activity.h5ad' by default Examples -------- >>> import xchrom as xc >>> tf_act = xc.tl.calc_tf_activity( ... motif_dir='./motif_fasta/', ... background_fasta='./shuffled_peaks.fasta', ... model_path='./best_model.h5', ... ad_rna=m1d1_rna, ... cell_embed_raw='X_pca_harmony', ... regenerate_bg_h5=True, ... regenerate_motif_h5=True, ... seed=20 ... ) """ setup_seed(seed=seed) print("=== Start calculating TF Activity ===") # 1. Prepare raw cell embedding print("1. Prepare raw cell embedding...") if cell_embed_raw not in ad_rna.obsm: raise ValueError(f"The raw cell embedding '{cell_embed_raw}' is not found in ad_rna.obsm") cell_embed = stats.zscore(np.array(ad_rna.obsm[cell_embed_raw]), axis=0) # 2. Generate motif insertion sequence h5 files print("2. Generate motif insertion sequence h5 files...") if regenerate_motif_h5: h5_files, motif_names = generate_h5_files(motif_dir) else: h5_files = [] motif_names = [] files = sorted([f for f in os.listdir(motif_dir) if f.endswith('.h5')]) if not files: print("No h5 files found, generating...") h5_files, motif_names = generate_h5_files(motif_dir) else: for h5 in files: h5_path = os.path.join(motif_dir, h5) h5_files.append(h5_path) motif_names.append(os.path.splitext(h5)[0]) print(f"Found {len(h5_files)} motif files") # 3. Generate background sequence h5 file print("3. Generate background sequence h5 file...") bg_h5_path = os.path.splitext(background_fasta)[0] + ".h5" if regenerate_bg_h5 or not os.path.exists(bg_h5_path): print("Generating background sequence h5 file...") fasta_to_h5(background_fasta, seq_len=seq_len) num_seqs = len(list(SeqIO.parse(background_fasta, "fasta"))) # 4. Create virtual data for dataset creation print("4. Prepare data for XChrom model input...") obs_info = ad_rna.obs # Create virtual data for dataset generation zero_matrix = np.zeros((len(obs_info), num_seqs)) data_virtual = anndata.AnnData(X=zero_matrix, obs=obs_info) data_virtual.obsm['zscore32_perpc'] = cell_embed data_virtual.obs['b_zscore'] = np.full(data_virtual.shape[0], 1) # Pre-compute sparse matrix to avoid repeated warnings print("Converting adata.X to dense array. For large datasets, consider pre-computing and saving as sparse matrix.") m_sparse = scipy.sparse.csr_matrix(data_virtual.X.T) # 5. Load trained model print("5. Load trained model...") n_cells = data_virtual.shape[0] model = XChrom_model(n_cells=n_cells, show_summary=False, **model_kwargs) model_path = Path(model_path) if not model_path.exists(): raise FileNotFoundError(f"Trained model file not found: {model_path}") model.load_weights(model_path) # Create prediction model (remove the last layer) prediction_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-2].output) # 6. Calculate background sequence prediction print("6. Calculate background sequence prediction...") gen = Generator( seq_path=bg_h5_path, adata=data_virtual, cell_input_key='zscore32_perpc', peakid=None, m=m_sparse, batch_size=128 ) bg_ds = gen.create_dataset() bg_pred = prediction_model.predict(bg_ds) bg_pred_mean = bg_pred.mean(axis=0) # 7. Calculate motif insertion scores print("7. Calculate motif insertion scores...") h5_files, motif_names = zip(*sorted(zip(h5_files, motif_names))) n_motifs = len(h5_files) tf_scores_raw = np.zeros((n_cells, n_motifs)) print(f"Processing {n_motifs} motif files...") for i, motif_path in enumerate(tqdm(h5_files)): try: # Create tf.data.Dataset for each motif file gen = Generator( seq_path=motif_path, adata=data_virtual, cell_input_key='zscore32_perpc', peakid=None, m=m_sparse, batch_size=128) motif_ds = gen.create_dataset() # Predict motif insertion effect motif_pred = prediction_model.predict(motif_ds) motif_mean = motif_pred.mean(axis=0) # Calculate the difference between motif insertion and background delta = motif_mean - bg_pred_mean tf_scores_raw[:, i] = delta except Exception as e: print(f"Error processing {motif_path}: {str(e)}") tf_scores_raw[:, i] = 0 # 8. Create result AnnData object print("8. Save results...") tf_act = anndata.AnnData( X=tf_scores_raw, # cells × motifs obs=ad_rna.obs.copy(), var=pd.DataFrame(index=list(motif_names)) ) tf_act.var["motif_name"] = tf_act.var.index # Save results output_file = Path(output_file) output_file.parent.mkdir(parents=True, exist_ok=True) tf_act.write_h5ad(output_file) print(f"=== Done! Results saved to: {output_file} ===") print(f"TF activity matrix shape: {tf_act.shape} (cells × motifs)") return tf_act