Source code for cryoblob.files

"""
Module: files
---------------------------

Contains the codes for interfacing with data files.
One goal here is to separate the Python code from
the JAX code. Thus most of the necessary outward
facing code, which is necessarily in Python, is here.

Functions
---------
- `file_params`:
    Get the parameters for the file organization.
- `load_mrc`:
    Reads an MRC-format cryo-EM file, extracting image data and metadata.
- `process_single_file`:
    Process a single file for blob detection with memory optimization.
- `process_batch_of_files`:
    Process a batch of files in parallel with memory optimization.
- `folder_blobs`:
    Process a folder of images for blob detection with memory optimization.
- `estimate_batch_size`:
    Estimate optimal batch size for processing MRC files based on available memory.
- `estimate_memory_usage`:
    Estimate memory usage in GB for processing a single MRC file.
- `get_optimal_batch_size`:
    Get optimal batch size by sampling multiple files from the list.
"""

import glob
import json
import os
from importlib.resources import files

import jax
import jax.numpy as jnp
import mrcfile
import numpy as np
import pandas as pd
from beartype import beartype
from beartype.typing import Dict, List, Literal, Optional, Tuple, Union
from jax import device_get, device_put, vmap
from jaxtyping import Array, Float, jaxtyped
from tqdm.auto import tqdm

from cryoblob.blobs import blob_list_log, preprocessing
from cryoblob.types import MRC_Image, scalar_float, scalar_int, PreprocessingConfig
from cryoblob.valid import make_preprocessing_config, make_mrc_image

jax.config.update("jax_enable_x64", True)


@jaxtyped(typechecker=beartype)
def file_params() -> Tuple[str, dict]:
    """
    Description
    -----------
    Run this at the beginning to generate the dict
    This gives both the absolute and relative paths
    on how the files are organized.

    Returns
    -------
    - `main_directory` (str):
        the main directory where the package is located.
    - `folder_structure` (dict):
        where the files and data are stored, as read
        from the organization.json file.
    """
    pkg_directory: str = os.path.dirname(__file__)
    listring: List = pkg_directory.split("/")[1:-2]
    listring.append("")
    listring.insert(0, "")
    main_directory: str = "/".join(listring)
    folder_structure: dict = json.load(
        open(files("cryoblob.params").joinpath("organization.json"))
    )
    return (main_directory, folder_structure)


[docs] def load_mrc(filepath: str) -> MRC_Image: """ Description ----------- Reads an MRC-format cryo-EM file from the specified path, extracting image data and relevant metadata. All numeric data are converted into JAX arrays and wrapped into a structured `MRC_Image` PyTree, compatible with JAX's functional programming paradigm. Parameters ---------- - `filepath` (str): Path to the MRC file to be loaded. Returns ------- `MRC_Image` (A PyTree containing): - `image_data`: Image array (2D or 3D). - `voxel_size`: Array containing voxel dimensions in Å (Z, Y, X). - `origin`: Array indicating the origin coordinates from the header (Z, Y, X). - `data_min`: Minimum pixel value. - `data_max`: Maximum pixel value. - `data_mean`: Mean pixel value. - `mode`: Integer code representing data type (e.g., 0=int8, 1=int16, 2=float32). Examples -------- >>> mrc_image = load_mrc("example.mrc") >>> print(mrc_image.voxel_size) Array([1.2, 1.2, 1.2], dtype=float32) Notes ----- - This function uses the `mrcfile` library for parsing MRC files. - The resulting PyTree structure (`MRC_Image`) is explicitly designed for use in JAX-based image processing pipelines. """ with mrcfile.open(filepath, permissive=True) as mrc: data = jnp.array(mrc.data) voxel_size = jnp.array( [ float(mrc.voxel_size.z), float(mrc.voxel_size.y), float(mrc.voxel_size.x), ] ) origin = jnp.array( [ float(mrc.header.origin.z), float(mrc.header.origin.y), float(mrc.header.origin.x), ] ) data_min = jnp.array(mrc.header.dmin) data_max = jnp.array(mrc.header.dmax) data_mean = jnp.array(mrc.header.dmean) mode = jnp.array(mrc.header.mode) MRC_data: MRC_Image = make_mrc_image( image_data=data, voxel_size=voxel_size, origin=origin, data_min=data_min, data_max=data_max, data_mean=data_mean, mode=mode, ) return MRC_data
@jaxtyped(typechecker=beartype) def process_single_file( file_path: str, preprocessing_config: PreprocessingConfig, blob_downscale: scalar_float, stream_mode: Optional[bool] = True, ) -> Tuple[Float[Array, "n 3"], str]: """ Description ----------- Process a single MRC file for blob detection with memory optimization and validated preprocessing configuration. Parameters ---------- - `file_path` (str): Path to the MRC image file to process - `preprocessing_config` (PreprocessingConfig): Validated preprocessing configuration containing all processing parameters - `blob_downscale` (scalar_float): Downscaling factor applied during blob detection to reduce computational load - `stream_mode` (bool, optional): Whether to use memory-mapped file access for large files to reduce memory usage. Default is True Returns ------- - `scaled_blobs` (Float[Array, "n 3"]): Array of detected blob coordinates and sizes where each row contains [Y_position_nm, X_position_nm, Size_nm] - `file_path` (str): Original file path for tracking processed files Raises ------ Exception: Returns empty array and original file path if processing fails, with error message printed to console Notes ----- The function uses streaming mode for large files to reduce memory usage and immediately releases file handles after reading. All intermediate arrays are explicitly deleted to manage GPU memory efficiently. """ try: if stream_mode: with mrcfile.mmap(file_path, mode="r") as data_mrc: x_calib: scalar_float = jnp.asarray(data_mrc.voxel_size.x) y_calib: scalar_float = jnp.asarray(data_mrc.voxel_size.y) im_data: Float[Array, "h w"] = jnp.asarray( data_mrc.data, dtype=jnp.float64 ) else: with mrcfile.open(file_path) as data_mrc: x_calib: scalar_float = jnp.asarray(data_mrc.voxel_size.x) y_calib: scalar_float = jnp.asarray(data_mrc.voxel_size.y) im_data: Float[Array, "h w"] = jnp.asarray( data_mrc.data, dtype=jnp.float64 ) im_data = device_put(im_data) mrc_image: MRC_Image = make_mrc_image( image_data=im_data, voxel_size=jnp.array([1.0, y_calib, x_calib]), origin=jnp.zeros(3), data_min=jnp.min(im_data), data_max=jnp.max(im_data), data_mean=jnp.mean(im_data), mode=2, ) preprocessed_imdata: Float[Array, "h w"] = preprocessing( image_orig=mrc_image.image_data, return_params=False, exponential=preprocessing_config.exponential, logarizer=preprocessing_config.logarizer, gblur=preprocessing_config.gblur, background=preprocessing_config.background, apply_filter=preprocessing_config.apply_filter, ) preprocessed_mrc: MRC_Image = make_mrc_image( image_data=preprocessed_imdata, voxel_size=mrc_image.voxel_size, origin=mrc_image.origin, data_min=jnp.min(preprocessed_imdata), data_max=jnp.max(preprocessed_imdata), data_mean=jnp.mean(preprocessed_imdata), mode=mrc_image.mode, ) del im_data del mrc_image blob_list: Float[Array, "n 3"] = blob_list_log( preprocessed_mrc, downscale=blob_downscale ) del preprocessed_imdata del preprocessed_mrc scaled_blobs: Float[Array, "n 3"] = jnp.concatenate( [ (blob_list[:, 0] * y_calib)[:, None], (blob_list[:, 1] * x_calib)[:, None], (blob_list[:, 2] * jnp.sqrt(y_calib**2 + x_calib**2))[:, None], ], axis=1, ) return scaled_blobs, file_path except Exception as e: print(f"Error processing {file_path}: {str(e)}") empty_array: Float[Array, "0 3"] = jnp.array([]).reshape(0, 3) return empty_array, file_path @jaxtyped(typechecker=beartype) def process_batch_of_files( file_batch: List[str], preprocessing_config: PreprocessingConfig, blob_downscale: scalar_float, ) -> List[Tuple[Float[Array, "n 3"], str]]: """ Process a batch of files in parallel with memory optimization. Parameters ---------- - `file_batch` (List[str]): List of file paths to process - `preprocessing_kwargs` (Dict): Preprocessing parameters - `blob_downscale` (float): Downscaling factor Returns ------- - `results` (List[Tuple[Array, str]]): List of (blobs, file_path) tuples """ batch_process_fn = vmap( lambda x: process_single_file(x, preprocessing_config, blob_downscale) ) return batch_process_fn(jnp.array(file_batch)) @jaxtyped(typechecker=beartype) def folder_blobs( folder_location: str, file_type: Optional[Literal[" mrc"]] = "mrc", blob_downscale: Optional[scalar_float] = 7.0, target_memory_gb: Optional[scalar_float] = 4.0, stream_large_files: Optional[bool] = True, **kwargs, ) -> pd.DataFrame: """ Description ----------- Process a folder of MRC images for blob detection with memory optimization and validated preprocessing configuration. Automatically manages batch processing and memory usage to prevent GPU memory overflow. Parameters ---------- - `folder_location` (str): Path to folder containing MRC images to process - `file_type` (Literal["mrc"], optional): File extension to search for in the folder. Default is "mrc" - `blob_downscale` (scalar_float, optional): Downscaling factor applied during blob detection. Default is 7.0 - `target_memory_gb` (scalar_float, optional): Target GPU memory usage in GB for batch size optimization. Default is 4.0 - `stream_large_files` (bool, optional): Whether to use memory-mapped file access for large files. Default is True - `**kwargs`: Additional preprocessing parameters passed to PreprocessingConfig. Valid options: exponential, logarizer, gblur, background, apply_filter Returns ------- - `blob_dataframe` (pd.DataFrame): DataFrame containing detected blob information with columns: ['File Location', 'Center Y (nm)', 'Center X (nm)', 'Size (nm)'] Raises ------ ValueError: If preprocessing parameters are invalid according to PreprocessingConfig validation Notes ----- Memory Management: - Uses batch processing to control memory usage - Automatically adjusts batch size based on available memory - Clears device memory between batches - Streams large files if needed - Efficiently handles intermediate results The function processes files in batches to prevent memory overflow and provides a progress bar to track processing status. Empty folders return an empty DataFrame with the expected column structure. """ default_kwargs: Dict[str, Union[bool, int]] = { "exponential": False, "logarizer": False, "gblur": 0, "background": 0, "apply_filter": 0, } combined_kwargs: Dict[str, Union[bool, int]] = {**default_kwargs, **kwargs} preprocessing_config: PreprocessingConfig = make_preprocessing_config( **combined_kwargs ) file_pattern: str = folder_location + "*." + file_type file_list: List[str] = glob.glob(file_pattern) if not file_list: empty_columns: List[str] = [ "File Location", "Center Y (nm)", "Center X (nm)", "Size (nm)", ] return pd.DataFrame(columns=empty_columns) batch_size: scalar_int = estimate_batch_size(file_list[0], target_memory_gb) all_blobs: List[np.ndarray] = [] all_files: List[str] = [] with tqdm(total=len(file_list), desc="Processing files") as pbar: for i in range(0, len(file_list), batch_size): batch_files: List[str] = file_list[i : i + batch_size] batch_results: List[Tuple[Float[Array, "n 3"], str]] = ( process_batch_of_files( batch_files, preprocessing_config, blob_downscale ) ) for blobs, file_path in batch_results: if len(blobs) > 0: cpu_blobs: np.ndarray = device_get(blobs) all_blobs.append(cpu_blobs) file_count: int = len(cpu_blobs) all_files.extend([file_path] * file_count) pbar.update(len(batch_files)) if all_blobs: combined_blobs: np.ndarray = np.concatenate(all_blobs, axis=0) blob_data: np.ndarray = np.column_stack((all_files, combined_blobs)) column_names: List[str] = [ "File Location", "Center Y (nm)", "Center X (nm)", "Size (nm)", ] blob_dataframe: pd.DataFrame = pd.DataFrame( data=blob_data, columns=column_names ) else: column_names: List[str] = [ "File Location", "Center Y (nm)", "Center X (nm)", "Size (nm)", ] blob_dataframe: pd.DataFrame = pd.DataFrame(columns=column_names) return blob_dataframe @jaxtyped(typechecker=beartype) def estimate_batch_size( sample_file_path: str, target_memory_gb: Optional[scalar_float] = 4.0, safety_factor: Optional[scalar_float] = 0.7, processing_overhead: Optional[scalar_float] = 3.0, ) -> scalar_int: """ Description ----------- Estimate optimal batch size for processing MRC files based on available memory and file characteristics. This function analyzes a sample file to estimate memory requirements and calculates the maximum number of files that can be processed simultaneously without exceeding memory limits. Parameters ---------- - `sample_file_path` (str): Path to a representative MRC file for size estimation - `target_memory_gb` (scalar_float, optional): Target GPU memory usage in GB. Default is 4.0 - `safety_factor` (scalar_float, optional): Safety factor to prevent memory overflow (0.0-1.0). Default is 0.7 (use 70% of target memory) - `processing_overhead` (scalar_float, optional): Memory overhead multiplier for processing operations. Default is 3.0 (processing uses 3x the raw data size) Returns ------- - `batch_size` (scalar_int): Recommended batch size for processing Notes ----- The estimation considers: - Raw file size in memory (dtype conversion) - Preprocessing operations (filtering, transformations) - Blob detection memory requirements - JAX compilation overhead - Intermediate array storage Memory estimation formula: ``` per_file_memory = file_size * processing_overhead available_memory = target_memory_gb * safety_factor * 1e9 batch_size = max(1, available_memory // per_file_memory) ``` Examples -------- >>> batch_size = estimate_batch_size("sample.mrc", target_memory_gb=8.0) >>> print(f"Recommended batch size: {batch_size}") """ try: with mrcfile.open(sample_file_path, mode="r", permissive=True) as mrc: data_shape: tuple = mrc.data.shape array_elements: scalar_int = int(jnp.prod(jnp.array(data_shape))) jax_memory_bytes: scalar_float = float(array_elements * 8) per_file_memory: scalar_float = jax_memory_bytes * processing_overhead target_memory_bytes: scalar_float = target_memory_gb * 1e9 available_memory: scalar_float = target_memory_bytes * safety_factor estimated_batch_size: scalar_float = available_memory / per_file_memory batch_size: scalar_int = max(1, int(jnp.floor(estimated_batch_size))) min_batch_size: scalar_int = 1 max_batch_size: scalar_int = 50 final_batch_size: scalar_int = max( min_batch_size, min(batch_size, max_batch_size) ) return final_batch_size except Exception as e: print(f"Warning: Could not estimate batch size for {sample_file_path}: {e}") print("Using conservative batch size of 2") return 2 @jaxtyped(typechecker=beartype) def estimate_memory_usage( file_path: str, include_preprocessing: Optional[bool] = True, include_blob_detection: Optional[bool] = True, ) -> scalar_float: """ Description ----------- Estimate memory usage in GB for processing a single MRC file. Parameters ---------- - `file_path` (str): Path to MRC file - `include_preprocessing` (bool, optional): Include memory for preprocessing operations. Default is True - `include_blob_detection` (bool, optional): Include memory for blob detection. Default is True Returns ------- - `memory_gb` (scalar_float): Estimated memory usage in GB """ try: with mrcfile.open(file_path, mode="r", permissive=True) as mrc: data_shape: tuple = mrc.data.shape array_elements: scalar_int = int(jnp.prod(jnp.array(data_shape))) base_memory: scalar_float = float(array_elements * 8) total_memory: scalar_float = base_memory if include_preprocessing: preprocessing_memory: scalar_float = base_memory * 2.0 total_memory += preprocessing_memory if include_blob_detection: typical_scales: scalar_int = 10 downscale_factor: scalar_float = 4.0 downscaled_elements: scalar_float = array_elements / ( downscale_factor**2 ) scale_space_memory: scalar_float = ( downscaled_elements * typical_scales * 8 ) total_memory += scale_space_memory memory_gb: scalar_float = total_memory / 1e9 return memory_gb except Exception as e: print(f"Warning: Could not estimate memory for {file_path}: {e}") return 1.0 @jaxtyped(typechecker=beartype) def get_optimal_batch_size( file_list: list[str], target_memory_gb: Optional[scalar_float] = 4.0, sample_fraction: Optional[scalar_float] = 0.1, ) -> scalar_int: """ Description ----------- Get optimal batch size by sampling multiple files from the list. Parameters ---------- - `file_list` (list[str]): List of file paths to process - `target_memory_gb` (scalar_float, optional): Target memory usage in GB. Default is 4.0 - `sample_fraction` (scalar_float, optional): Fraction of files to sample for estimation. Default is 0.1 Returns ------- - `batch_size` (scalar_int): Optimal batch size """ if not file_list: return 1 num_samples: scalar_int = max(1, int(len(file_list) * sample_fraction)) sample_indices: list = jnp.linspace( 0, len(file_list) - 1, num_samples, dtype=int ).tolist() batch_sizes: list = [] for idx in sample_indices: try: batch_size: scalar_int = estimate_batch_size( file_list[idx], target_memory_gb=target_memory_gb ) batch_sizes.append(batch_size) except Exception: continue if not batch_sizes: return 2 optimal_batch_size: scalar_int = int(min(batch_sizes)) return max(1, optimal_batch_size)