"""
Module: valid
-------------
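JAX PyTree factory functions for configuration management
in the cryoblob preprocessing pipeline. The configuration
factories validate their inputs functionally with
`jax.lax.cond`: rather than raising, an invalid value is
replaced by a safe default. Covered configurations are
preprocessing, file processing (batching and memory limits),
blob detection, ridge detection, watershed, Hessian blob
detection and adaptive filtering.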
Functions
---------
- `make_mrc_image`:
Factory function to create an MRC_Image instance.
- `make_preprocessing_config`:
Factory function for preprocessing configuration PyTree
- `make_blob_detection_config`:
Factory function for blob detection configuration PyTree
- `make_file_processing_config`:
Factory function for file processing configuration PyTree
- `make_mrc_metadata`:
Factory function for MRC metadata PyTree
- `make_ridge_detection_config`:
Factory function for ridge detection configuration PyTree
- `make_watershed_config`:
Factory function for watershed configuration PyTree
- `make_enhanced_blob_detection_config`:
Factory function for enhanced blob detection configuration PyTree
- `make_hessian_blob_config`:
Factory function for Hessian blob detection configuration PyTree
- `make_adaptive_filter_config`:
Factory function for adaptive filter configuration PyTree
"""
from beartype.typing import Optional, Union
from beartype import beartype
from jaxtyping import Array, Float, Num, jaxtyped
import jax.numpy as jnp
from jax import lax
from cryoblob.types import (
scalar_float,
scalar_int,
PreprocessingConfig,
BlobDetectionConfig,
FileProcessingConfig,
MRCMetadata,
MRC_Image,
AdaptiveFilterConfig,
RidgeDetectionConfig,
WatershedConfig,
HessianBlobConfig,
EnhancedBlobDetectionConfig,
)
@jaxtyped(typechecker=beartype)
def make_mrc_image(
    image_data: Union[Num[Array, "H W"], Num[Array, "D H W"]],
    voxel_size: Float[Array, "3"],
    origin: Float[Array, "3"],
    data_min: scalar_float,
    data_max: scalar_float,
    data_mean: scalar_float,
    mode: scalar_int,
) -> MRC_Image:
    """
    Description
    -----------
    Assemble an ``MRC_Image`` PyTree from already-loaded MRC contents.

    Apart from the runtime annotation checks performed by the
    ``jaxtyped``/``beartype`` decorator, nothing is converted or
    validated here: every argument is stored on the result as given.

    Parameters
    ----------
    - `image_data` (Num[Array, "H W"] | Num[Array, "D H W"]):
        Image data read from the MRC file, either 2D (H, W) or
        3D (D, H, W).
    - `voxel_size` (Float[Array, "3"]):
        Voxel size, ordered (Z, Y, X).
    - `origin` (Float[Array, "3"]):
        Origin coordinates from the MRC header, ordered (Z, Y, X).
    - `data_min` (scalar_float):
        Minimum image value as recorded in the header.
    - `data_max` (scalar_float):
        Maximum image value as recorded in the header.
    - `data_mean` (scalar_float):
        Mean image value as recorded in the header.
    - `mode` (scalar_int):
        MRC header data-type mode (e.g. 0: int8, 2: float32).

    Returns
    -------
    - `MRC_Image`:
        The assembled MRC_Image PyTree.
    """
    # Everything except the pixel data comes from the file header.
    header_fields = {
        "voxel_size": voxel_size,
        "origin": origin,
        "data_min": data_min,
        "data_max": data_max,
        "data_mean": data_mean,
        "mode": mode,
    }
    return MRC_Image(image_data=image_data, **header_fields)
def make_preprocessing_config(
    exponential: bool = True,
    logarizer: bool = False,
    gblur: int = 2,
    background: int = 0,
    apply_filter: int = 0,
) -> PreprocessingConfig:
    """
    Factory function to create a PreprocessingConfig PyTree with validation.

    Validation never raises. Each integer parameter outside its valid
    range is replaced by its default, and ``logarizer`` is switched off
    when it is requested together with ``exponential``.

    Parameters
    ----------
    exponential : bool
        Apply exponential function to enhance contrast (default: True)
    logarizer : bool
        Apply logarithmic transformation (default: False). Forced to
        False when ``exponential`` is also True.
    gblur : int
        Gaussian blur sigma, 0 means no blur (default: 2). Valid range
        0-50; values outside it fall back to 2.
    background : int
        Background subtraction sigma, 0 means no subtraction (default: 0).
        Valid range 0-100; values outside it fall back to 0.
    apply_filter : int
        Wiener filter kernel size, 0 means no filter (default: 0). Valid
        range 0-20; values outside it fall back to 0.

    Returns
    -------
    PreprocessingConfig
        Validated preprocessing configuration PyTree whose fields are
        scalar JAX arrays (bool_ for the flags, int32 for the sizes).
    """
    exponential_arr = jnp.asarray(exponential, dtype=jnp.bool_)
    logarizer_arr = jnp.asarray(logarizer, dtype=jnp.bool_)
    gblur_arr = jnp.asarray(gblur, dtype=jnp.int32)
    background_arr = jnp.asarray(background, dtype=jnp.int32)
    apply_filter_arr = jnp.asarray(apply_filter, dtype=jnp.int32)

    # Out-of-range values fall back to their defaults instead of raising.
    gblur_validated = lax.cond(
        (gblur_arr < 0) | (gblur_arr > 50),
        lambda x: jnp.asarray(2, dtype=jnp.int32),
        lambda x: x,
        gblur_arr,
    )
    background_validated = lax.cond(
        (background_arr < 0) | (background_arr > 100),
        lambda x: jnp.asarray(0, dtype=jnp.int32),
        lambda x: x,
        background_arr,
    )
    apply_filter_validated = lax.cond(
        (apply_filter_arr < 0) | (apply_filter_arr > 20),
        lambda x: jnp.asarray(0, dtype=jnp.int32),
        lambda x: x,
        apply_filter_arr,
    )

    # Exponential and logarithmic contrast transforms are treated as
    # conflicting; the exponential one wins when both are requested.
    logarizer_validated = lax.cond(
        exponential_arr & logarizer_arr,
        lambda x: jnp.asarray(False, dtype=jnp.bool_),
        lambda x: x,
        logarizer_arr,
    )

    return PreprocessingConfig(
        exponential=exponential_arr,
        logarizer=logarizer_validated,
        gblur=gblur_validated,
        background=background_validated,
        apply_filter=apply_filter_validated,
    )
def make_blob_detection_config(
    min_sigma: float = 1.0,
    max_sigma: float = 50.0,
    num_sigma: int = 10,
    threshold: float = 0.01,
    exclude_border: int = 0,
) -> BlobDetectionConfig:
    """
    Factory function to create a BlobDetectionConfig PyTree with validation.

    Validation never raises:

    * a non-positive or NaN ``min_sigma`` / ``max_sigma`` falls back to
      its default (1.0 / 50.0);
    * ``max_sigma`` is raised to ``min_sigma`` if it would be smaller;
    * a non-positive ``num_sigma`` falls back to 10;
    * a negative ``exclude_border`` becomes 0;
    * ``threshold`` is passed through unchanged.

    Parameters
    ----------
    min_sigma : float
        Minimum sigma for Laplacian of Gaussian (default: 1.0)
    max_sigma : float
        Maximum sigma for Laplacian of Gaussian (default: 50.0)
    num_sigma : int
        Number of sigma values to test (default: 10)
    threshold : float
        Detection threshold (default: 0.01)
    exclude_border : int
        Pixels to exclude from border (default: 0)

    Returns
    -------
    BlobDetectionConfig
        Validated blob detection configuration PyTree
    """
    min_sigma_arr = jnp.asarray(min_sigma, dtype=jnp.float32)
    max_sigma_arr = jnp.asarray(max_sigma, dtype=jnp.float32)
    num_sigma_arr = jnp.asarray(num_sigma, dtype=jnp.int32)
    threshold_arr = jnp.asarray(threshold, dtype=jnp.float32)
    exclude_border_arr = jnp.asarray(exclude_border, dtype=jnp.int32)

    # Float checks use ``~(x > 0)`` instead of ``x <= 0`` so that NaN,
    # for which every comparison is False, also falls back to the default.
    min_sigma_validated = lax.cond(
        ~(min_sigma_arr > 0),
        lambda x: jnp.asarray(1.0, dtype=jnp.float32),
        lambda x: x,
        min_sigma_arr,
    )
    max_sigma_validated = lax.cond(
        ~(max_sigma_arr > 0),
        lambda x: jnp.asarray(50.0, dtype=jnp.float32),
        lambda x: x,
        max_sigma_arr,
    )
    # Keep the sigma range non-empty.
    max_sigma_validated = lax.cond(
        max_sigma_validated < min_sigma_validated,
        lambda x: min_sigma_validated,
        lambda x: x,
        max_sigma_validated,
    )
    num_sigma_validated = lax.cond(
        num_sigma_arr <= 0,
        lambda x: jnp.asarray(10, dtype=jnp.int32),
        lambda x: x,
        num_sigma_arr,
    )
    exclude_border_validated = lax.cond(
        exclude_border_arr < 0,
        lambda x: jnp.asarray(0, dtype=jnp.int32),
        lambda x: x,
        exclude_border_arr,
    )

    return BlobDetectionConfig(
        min_sigma=min_sigma_validated,
        max_sigma=max_sigma_validated,
        num_sigma=num_sigma_validated,
        threshold=threshold_arr,
        exclude_border=exclude_border_validated,
    )
def make_file_processing_config(
    batch_size: int = 4,
    memory_limit_gb: float = 8.0,
) -> FileProcessingConfig:
    """
    Factory function to create a FileProcessingConfig PyTree with validation.

    Validation never raises:

    * a non-positive ``batch_size`` falls back to 4;
    * a non-positive or NaN ``memory_limit_gb`` falls back to 8.0.

    Parameters
    ----------
    batch_size : int
        Number of files to process in parallel (default: 4)
    memory_limit_gb : float
        Memory limit in GB (default: 8.0)

    Returns
    -------
    FileProcessingConfig
        Validated file processing configuration PyTree whose fields are
        scalar JAX arrays (int32 ``batch_size``, float32 ``memory_limit_gb``).
    """
    batch_size_arr = jnp.asarray(batch_size, dtype=jnp.int32)
    memory_limit_gb_arr = jnp.asarray(memory_limit_gb, dtype=jnp.float32)

    batch_size_validated = lax.cond(
        batch_size_arr <= 0,
        lambda x: jnp.asarray(4, dtype=jnp.int32),
        lambda x: x,
        batch_size_arr,
    )
    # ``~(x > 0)`` rather than ``x <= 0`` so NaN also falls back.
    memory_limit_gb_validated = lax.cond(
        ~(memory_limit_gb_arr > 0),
        lambda x: jnp.asarray(8.0, dtype=jnp.float32),
        lambda x: x,
        memory_limit_gb_arr,
    )

    return FileProcessingConfig(
        batch_size=batch_size_validated,
        memory_limit_gb=memory_limit_gb_validated,
    )
def make_adaptive_filter_config(
    kernel_size: int = 5,
    noise_estimate: float = 0.01,
    iterations: int = 10,
) -> AdaptiveFilterConfig:
    """
    Factory function to create an AdaptiveFilterConfig PyTree with validation.

    Validation never raises:

    * a non-positive ``kernel_size`` falls back to 5;
    * an even ``kernel_size`` is bumped to the next odd value;
    * a non-positive or NaN ``noise_estimate`` falls back to 0.01;
    * a non-positive ``iterations`` falls back to 10.

    Parameters
    ----------
    kernel_size : int
        Size of the filter kernel (default: 5). Always odd after validation.
    noise_estimate : float
        Initial noise estimate (default: 0.01)
    iterations : int
        Number of adaptation iterations (default: 10)

    Returns
    -------
    AdaptiveFilterConfig
        Validated adaptive filter configuration PyTree
    """
    kernel_size_arr = jnp.asarray(kernel_size, dtype=jnp.int32)
    noise_estimate_arr = jnp.asarray(noise_estimate, dtype=jnp.float32)
    iterations_arr = jnp.asarray(iterations, dtype=jnp.int32)

    # First make the kernel size positive, then force it to be odd so the
    # kernel has a well-defined centre pixel.
    kernel_size_validated = lax.cond(
        kernel_size_arr <= 0,
        lambda x: jnp.asarray(5, dtype=jnp.int32),
        lambda x: x,
        kernel_size_arr,
    )
    kernel_size_validated = lax.cond(
        kernel_size_validated % 2 == 0,
        lambda x: x + 1,
        lambda x: x,
        kernel_size_validated,
    )
    # ``~(x > 0)`` rather than ``x <= 0`` so NaN also falls back.
    noise_estimate_validated = lax.cond(
        ~(noise_estimate_arr > 0),
        lambda x: jnp.asarray(0.01, dtype=jnp.float32),
        lambda x: x,
        noise_estimate_arr,
    )
    iterations_validated = lax.cond(
        iterations_arr <= 0,
        lambda x: jnp.asarray(10, dtype=jnp.int32),
        lambda x: x,
        iterations_arr,
    )

    return AdaptiveFilterConfig(
        kernel_size=kernel_size_validated,
        noise_estimate=noise_estimate_validated,
        iterations=iterations_validated,
    )
def make_ridge_detection_config(
    min_scale: float = 1.0,
    max_scale: float = 10.0,
    scale_step: float = 0.5,
    threshold: float = 0.1,
) -> RidgeDetectionConfig:
    """
    Factory function to create a RidgeDetectionConfig PyTree with validation.

    Validation never raises:

    * a non-positive or NaN ``min_scale`` / ``max_scale`` / ``scale_step``
      falls back to its default (1.0 / 10.0 / 0.5);
    * ``max_scale`` is raised to ``min_scale`` if it would be smaller;
    * ``threshold`` is passed through unchanged.

    Parameters
    ----------
    min_scale : float
        Minimum scale for ridge detection (default: 1.0)
    max_scale : float
        Maximum scale for ridge detection (default: 10.0)
    scale_step : float
        Step size for scale space (default: 0.5)
    threshold : float
        Detection threshold (default: 0.1)

    Returns
    -------
    RidgeDetectionConfig
        Validated ridge detection configuration PyTree
    """
    min_scale_arr = jnp.asarray(min_scale, dtype=jnp.float32)
    max_scale_arr = jnp.asarray(max_scale, dtype=jnp.float32)
    scale_step_arr = jnp.asarray(scale_step, dtype=jnp.float32)
    threshold_arr = jnp.asarray(threshold, dtype=jnp.float32)

    # Float checks use ``~(x > 0)`` instead of ``x <= 0`` so that NaN,
    # for which every comparison is False, also falls back to the default.
    min_scale_validated = lax.cond(
        ~(min_scale_arr > 0),
        lambda x: jnp.asarray(1.0, dtype=jnp.float32),
        lambda x: x,
        min_scale_arr,
    )
    max_scale_validated = lax.cond(
        ~(max_scale_arr > 0),
        lambda x: jnp.asarray(10.0, dtype=jnp.float32),
        lambda x: x,
        max_scale_arr,
    )
    # Keep the scale range non-empty.
    max_scale_validated = lax.cond(
        max_scale_validated < min_scale_validated,
        lambda x: min_scale_validated,
        lambda x: x,
        max_scale_validated,
    )
    scale_step_validated = lax.cond(
        ~(scale_step_arr > 0),
        lambda x: jnp.asarray(0.5, dtype=jnp.float32),
        lambda x: x,
        scale_step_arr,
    )

    return RidgeDetectionConfig(
        min_scale=min_scale_validated,
        max_scale=max_scale_validated,
        scale_step=scale_step_validated,
        threshold=threshold_arr,
    )
def make_watershed_config(
    min_distance: int = 10,
    threshold_abs: Optional[float] = None,
    compactness: float = 0.0,
) -> WatershedConfig:
    """
    Factory function to create a WatershedConfig PyTree with validation.

    Validation never raises:

    * a non-positive ``min_distance`` falls back to 10;
    * a negative or NaN ``compactness`` falls back to 0.0;
    * ``threshold_abs=None`` is stored as the sentinel value -1.0, and any
      other value is passed through unchanged.

    Parameters
    ----------
    min_distance : int
        Minimum distance between markers (default: 10)
    threshold_abs : Optional[float]
        Absolute threshold for markers (default: None, stored as -1.0)
    compactness : float
        Compactness parameter for watershed (default: 0.0)

    Returns
    -------
    WatershedConfig
        Validated watershed configuration PyTree
    """
    min_distance_arr = jnp.asarray(min_distance, dtype=jnp.int32)
    # ``None`` is encoded as the sentinel -1.0 so the field is always a
    # float32 array. NOTE(review): an explicit negative threshold_abs is
    # indistinguishable from None; presumably downstream code treats any
    # negative value as "no absolute threshold" -- confirm.
    threshold_abs_val = jnp.asarray(
        -1.0 if threshold_abs is None else threshold_abs,
        dtype=jnp.float32,
    )
    compactness_arr = jnp.asarray(compactness, dtype=jnp.float32)

    min_distance_validated = lax.cond(
        min_distance_arr <= 0,
        lambda x: jnp.asarray(10, dtype=jnp.int32),
        lambda x: x,
        min_distance_arr,
    )
    # ``~(x >= 0)`` rather than ``x < 0`` so NaN also falls back.
    compactness_validated = lax.cond(
        ~(compactness_arr >= 0),
        lambda x: jnp.asarray(0.0, dtype=jnp.float32),
        lambda x: x,
        compactness_arr,
    )

    return WatershedConfig(
        min_distance=min_distance_validated,
        threshold_abs=threshold_abs_val,
        compactness=compactness_validated,
    )
def make_hessian_blob_config(
    min_sigma: float = 1.0,
    max_sigma: float = 30.0,
    num_sigma: int = 10,
    threshold: float = 0.01,
) -> HessianBlobConfig:
    """
    Factory function to create a HessianBlobConfig PyTree with validation.

    Validation never raises:

    * a non-positive or NaN ``min_sigma`` / ``max_sigma`` falls back to
      its default (1.0 / 30.0);
    * ``max_sigma`` is raised to ``min_sigma`` if it would be smaller;
    * a non-positive ``num_sigma`` falls back to 10;
    * ``threshold`` is passed through unchanged.

    Parameters
    ----------
    min_sigma : float
        Minimum sigma for scale space (default: 1.0)
    max_sigma : float
        Maximum sigma for scale space (default: 30.0)
    num_sigma : int
        Number of scales to test (default: 10)
    threshold : float
        Detection threshold (default: 0.01)

    Returns
    -------
    HessianBlobConfig
        Validated Hessian blob detection configuration PyTree
    """
    min_sigma_arr = jnp.asarray(min_sigma, dtype=jnp.float32)
    max_sigma_arr = jnp.asarray(max_sigma, dtype=jnp.float32)
    num_sigma_arr = jnp.asarray(num_sigma, dtype=jnp.int32)
    threshold_arr = jnp.asarray(threshold, dtype=jnp.float32)

    # Float checks use ``~(x > 0)`` instead of ``x <= 0`` so that NaN,
    # for which every comparison is False, also falls back to the default.
    min_sigma_validated = lax.cond(
        ~(min_sigma_arr > 0),
        lambda x: jnp.asarray(1.0, dtype=jnp.float32),
        lambda x: x,
        min_sigma_arr,
    )
    max_sigma_validated = lax.cond(
        ~(max_sigma_arr > 0),
        lambda x: jnp.asarray(30.0, dtype=jnp.float32),
        lambda x: x,
        max_sigma_arr,
    )
    # Keep the sigma range non-empty.
    max_sigma_validated = lax.cond(
        max_sigma_validated < min_sigma_validated,
        lambda x: min_sigma_validated,
        lambda x: x,
        max_sigma_validated,
    )
    num_sigma_validated = lax.cond(
        num_sigma_arr <= 0,
        lambda x: jnp.asarray(10, dtype=jnp.int32),
        lambda x: x,
        num_sigma_arr,
    )

    return HessianBlobConfig(
        min_sigma=min_sigma_validated,
        max_sigma=max_sigma_validated,
        num_sigma=num_sigma_validated,
        threshold=threshold_arr,
    )
def make_enhanced_blob_detection_config(
    min_blob_size: float = 5.0,
    max_blob_size: float = 50.0,
    detection_threshold: float = 0.05,
    use_ridge_detection: bool = True,
    use_watershed: bool = True,
) -> EnhancedBlobDetectionConfig:
    """
    Factory function to create an EnhancedBlobDetectionConfig PyTree with validation.

    Validation never raises:

    * a non-positive or NaN ``min_blob_size`` / ``max_blob_size`` /
      ``detection_threshold`` falls back to its default
      (5.0 / 50.0 / 0.05);
    * ``max_blob_size`` is raised to ``min_blob_size`` if it would be
      smaller;
    * the two boolean switches are passed through unchanged.

    Parameters
    ----------
    min_blob_size : float
        Minimum expected blob size (default: 5.0)
    max_blob_size : float
        Maximum expected blob size (default: 50.0)
    detection_threshold : float
        Overall detection threshold (default: 0.05)
    use_ridge_detection : bool
        Enable ridge detection for elongated objects (default: True)
    use_watershed : bool
        Enable watershed for overlapping blobs (default: True)

    Returns
    -------
    EnhancedBlobDetectionConfig
        Validated enhanced blob detection configuration PyTree
    """
    min_blob_size_arr = jnp.asarray(min_blob_size, dtype=jnp.float32)
    max_blob_size_arr = jnp.asarray(max_blob_size, dtype=jnp.float32)
    detection_threshold_arr = jnp.asarray(detection_threshold, dtype=jnp.float32)
    use_ridge_detection_arr = jnp.asarray(use_ridge_detection, dtype=jnp.bool_)
    use_watershed_arr = jnp.asarray(use_watershed, dtype=jnp.bool_)

    # Float checks use ``~(x > 0)`` instead of ``x <= 0`` so that NaN,
    # for which every comparison is False, also falls back to the default.
    min_blob_size_validated = lax.cond(
        ~(min_blob_size_arr > 0),
        lambda x: jnp.asarray(5.0, dtype=jnp.float32),
        lambda x: x,
        min_blob_size_arr,
    )
    max_blob_size_validated = lax.cond(
        ~(max_blob_size_arr > 0),
        lambda x: jnp.asarray(50.0, dtype=jnp.float32),
        lambda x: x,
        max_blob_size_arr,
    )
    # Keep the size range non-empty.
    max_blob_size_validated = lax.cond(
        max_blob_size_validated < min_blob_size_validated,
        lambda x: min_blob_size_validated,
        lambda x: x,
        max_blob_size_validated,
    )
    detection_threshold_validated = lax.cond(
        ~(detection_threshold_arr > 0),
        lambda x: jnp.asarray(0.05, dtype=jnp.float32),
        lambda x: x,
        detection_threshold_arr,
    )

    return EnhancedBlobDetectionConfig(
        min_blob_size=min_blob_size_validated,
        max_blob_size=max_blob_size_validated,
        detection_threshold=detection_threshold_validated,
        use_ridge_detection=use_ridge_detection_arr,
        use_watershed=use_watershed_arr,
    )