Skip to content

Image Processing Module

Image processing functions for qi2lab 3D MERFISH.

This module includes various utilities for image processing, such as downsampling, padding, and chunked GPU-based deconvolution.

History:
  • 2024/12: Refactored repo structure.
  • 2024/07: Added numba-accelerated downsampling, padding helper functions, and chunked GPU deconvolution.

Functions:

Name Description
downsample_axis

Numba accelerated downsampling for 3D images along a specified axis.

downsample_image_anisotropic

Numba accelerated anisotropic downsampling

estimate_shading

Estimate shading using stack of images and BaSiCPy.

no_op

Function to monkey patch print to suppress output.

replace_hot_pixels

Replace hot pixels with median values surrounding them.

downsample_axis(image, level=2, axis=0)

Numba accelerated downsampling for 3D images along a specified axis.

Parameters:

Name Type Description Default
image ArrayLike

3D image to be downsampled.

required
level int

Amount of downsampling.

2
axis int

Axis along which to downsample (0, 1, or 2).

0

Returns:

Name Type Description
downsampled_image ArrayLike

3D downsampled image.

Source code in src/merfish3danalysis/utils/imageprocessing.py
@njit(parallel=True)
def downsample_axis(
    image: ArrayLike, 
    level: int = 2, 
    axis: int = 0
) -> ArrayLike:
    """Numba accelerated downsampling for 3D images along a specified axis.

    Consecutive, non-overlapping groups of ``level`` samples along ``axis``
    are averaged. When the axis length is not a multiple of ``level`` the
    last output sample averages only the remaining samples, so the output
    length along ``axis`` is ``ceil(image.shape[axis] / level)``.

    Parameters
    ----------
    image: ArrayLike
        3D image to be downsampled.
    level: int
        Amount of downsampling. Must be >= 1.
    axis: int
        Axis along which to downsample (0, 1, or 2).

    Returns
    -------
    downsampled_image: ArrayLike
        3D downsampled image with the same dtype as ``image``. The block mean
        is computed in floating point and cast back to ``image.dtype`` (for
        integer dtypes this cast truncates rather than rounds).

    Raises
    ------
    ValueError
        If ``axis`` is not 0, 1, or 2, or if ``level`` is less than 1.
    """
    # Validate up front: an unsupported axis used to fall through every
    # branch and leave ``downsampled_image`` undefined at the return.
    if axis < 0 or axis > 2:
        raise ValueError("axis must be 0, 1, or 2")
    if level < 1:
        raise ValueError("level must be >= 1")

    n_z, n_y, n_x = image.shape

    if axis == 0:
        # ceil(n_z / level): a trailing partial block still yields one sample.
        new_length = (n_z + level - 1) // level
        downsampled_image = np.zeros((new_length, n_y, n_x), dtype=image.dtype)

        for y in prange(n_y):
            for x in range(n_x):
                for z in range(new_length):
                    start = z * level
                    stop = min(start + level, n_z)
                    sum_value = 0.0
                    for k in range(start, stop):
                        sum_value += image[k, y, x]
                    downsampled_image[z, y, x] = sum_value / (stop - start)

    elif axis == 1:
        new_length = (n_y + level - 1) // level
        downsampled_image = np.zeros((n_z, new_length, n_x), dtype=image.dtype)

        for z in prange(n_z):
            for x in range(n_x):
                for y in range(new_length):
                    start = y * level
                    stop = min(start + level, n_y)
                    sum_value = 0.0
                    for k in range(start, stop):
                        sum_value += image[z, k, x]
                    downsampled_image[z, y, x] = sum_value / (stop - start)

    else:
        # axis == 2 (guaranteed by the validation above). Using ``else``
        # makes ``downsampled_image`` defined on every path for Numba.
        new_length = (n_x + level - 1) // level
        downsampled_image = np.zeros((n_z, n_y, new_length), dtype=image.dtype)

        for z in prange(n_z):
            for y in range(n_y):
                for x in range(new_length):
                    start = x * level
                    stop = min(start + level, n_x)
                    sum_value = 0.0
                    for k in range(start, stop):
                        sum_value += image[z, y, k]
                    downsampled_image[z, y, x] = sum_value / (stop - start)

    return downsampled_image

downsample_image_anisotropic(image, level=(2, 6, 6))

Numba accelerated anisotropic downsampling

Parameters:

Name Type Description Default
image ArrayLike

3D image to be downsampled

required
level tuple[int, int, int]

Per-axis downsampling factors, applied to axes 0, 1, and 2 respectively.

(2, 6, 6)

Returns:

Name Type Description
downsampled_image ArrayLike

downsampled 3D image

Source code in src/merfish3danalysis/utils/imageprocessing.py
def downsample_image_anisotropic(image: ArrayLike, level: tuple[int,int,int] = (2,6,6)) -> ArrayLike:
    """Downsample a 3D image by an independent factor along each axis.

    The Numba-accelerated ``downsample_axis`` is applied successively along
    axes 0, 1 and 2, using ``level[0]``, ``level[1]`` and ``level[2]`` as the
    respective downsampling factors.

    Parameters
    ----------
    image: ArrayLike
        3D image to be downsampled
    level: tuple[int,int,int], default=(2,6,6)
        per-axis downsampling factors for axes 0, 1 and 2

    Returns
    -------
    downsampled_image: ArrayLike
        downsampled 3D image
    """
    result = image
    # Reduce one axis at a time; each pass feeds the next.
    for axis_index in (0, 1, 2):
        result = downsample_axis(result, level[axis_index], axis_index)
    return result

estimate_shading(images)

Estimate shading using stack of images and BaSiCPy.

Parameters:

Name Type Description Default
images list[ArrayLike]

Per-position image stacks [p][z,y,x]; each element's `.result()` is max-projected along z before fitting.

required

Returns:

Name Type Description
shading_image ArrayLike

estimated shading image

Source code in src/merfish3danalysis/utils/imageprocessing.py
def estimate_shading(
    images: list[ArrayLike]
) -> ArrayLike:
    """Estimate shading using stack of images and BaSiCPy.

    Each element of ``images`` is resolved with ``.result()`` and
    max-projected along its first axis. The projections are stacked, cast to
    uint16 and passed to BaSiCPy (``get_darkfield=False``) for ``autotune``
    and ``fit``. The fitted flatfield is normalized by its maximum.

    Parameters
    ----------
    images: list[ArrayLike]
        Per-position image stacks [p][z,y,x]. Each element must expose a
        ``.result()`` method returning the stack (presumably a future or
        delayed object -- TODO confirm against callers).

    Returns
    -------
    shading_image: ArrayLike
        Estimated shading image as a float32 array, normalized by its maximum.

    Raises
    ------
    ValueError
        If ``images`` is empty.
    """

    # GPU
    import cupy as cp  # type: ignore
    from basicpy import BaSiC  # type: ignore

    # Max-project each stack along z. cp.asnumpy accepts CuPy and NumPy
    # arrays alike, so converting each projection individually avoids
    # relying on NumPy implicitly converting a list of CuPy arrays (which
    # CuPy forbids).
    maxz_list = [
        cp.asnumpy(cp.squeeze(cp.max(image.result(), axis=0)))
        for image in images
    ]
    if not maxz_list:
        raise ValueError("images must contain at least one image stack")

    maxz_images = np.stack(maxz_list).astype(np.uint16)
    del maxz_list
    gc.collect()
    cp.cuda.Stream.null.synchronize()
    cp.get_default_memory_pool().free_all_blocks()
    cp.get_default_pinned_memory_pool().free_all_blocks()

    # BaSiCPy is chatty; silence print while it runs, but always restore
    # the real print, even if autotune/fit raise. Otherwise print would stay
    # suppressed for the whole process.
    original_print = builtins.print
    builtins.print = no_op
    try:
        basic = BaSiC(get_darkfield=False)
        basic.autotune(maxz_images)
        basic.fit(maxz_images)
    finally:
        builtins.print = original_print

    flatfield = basic.flatfield.astype(np.float32)
    shading_correction = flatfield / np.max(flatfield, axis=(0, 1))

    del basic, flatfield
    gc.collect()

    cp.cuda.Stream.null.synchronize()
    cp.get_default_memory_pool().free_all_blocks()
    cp.get_default_pinned_memory_pool().free_all_blocks()

    return shading_correction

no_op(*args, **kwargs)

No-op function used to monkey-patch `print` and suppress output; it accepts and ignores any arguments.

Parameters:

Name Type Description Default
args

positional arguments

()
kwargs

keyword arguments

{}
Source code in src/merfish3danalysis/utils/imageprocessing.py
def no_op(*args, **kwargs):
    """Accept any arguments and do nothing.

    Intended as a drop-in replacement for ``print`` when its output should
    be silenced.

    Parameters
    ----------
    args: Any
        positional arguments (ignored)
    kwargs: Any
        keyword arguments (ignored)
    """
    return None

replace_hot_pixels(noise_map, data, threshold=375.0)

Replace hot pixels with median values surrounding them.

Parameters:

Name Type Description Default
noise_map ArrayLike

darkfield image collected at long exposure time to get hot pixels

required
data ArrayLike

ND data [broadcast_dim,z,y,x]

required
threshold float

noise_map values strictly above this threshold are treated as hot pixels.

375.0

Returns:

Name Type Description
data ArrayLike

Hot-pixel-corrected data (uint16).

Source code in src/merfish3danalysis/utils/imageprocessing.py
def replace_hot_pixels(
    noise_map: ArrayLike, 
    data: ArrayLike, 
    threshold: float = 375.0
) -> ArrayLike:
    """Replace hot pixels with median values surrounding them.

    Pixels whose ``noise_map`` value is strictly greater than ``threshold``
    are flagged as hot. For each index along the first axis of ``data``, hot
    pixels are replaced by the output of a size-3 median filter applied to
    that sub-array; all other pixels keep their value. Neither input is
    modified.

    Parameters
    ----------
    noise_map: ArrayLike
        darkfield image collected at long exposure time to get hot pixels.
        After squeezing it must broadcast against ``data[i, :, :]``.
    data: ArrayLike
        ND data [broadcast_dim,z,y,x]. It is processed slice by slice along
        the first axis; for 3D [z,y,x] input each plane is filtered
        independently.
    threshold: float, default=375.0
        noise_map values strictly above this are treated as hot pixels.

    Returns
    -------
    data: ArrayLike
        Hot-pixel-corrected data as a uint16 NumPy array, with values
        clipped to the uint16 range before casting.
    """

    # GPU
    import cupy as cp  # type: ignore
    from cupyx.scipy import ndimage  # type: ignore

    # cp.array copies by default. cp.asarray would alias a caller's float32
    # CuPy array, and the in-place writes below would then corrupt it.
    corrected = cp.array(data, dtype=cp.float32)

    # Build the binary mask with a single comparison. Thresholding in place
    # used to mutate (a view of) noise_map, mislabelled pixels when
    # threshold < 0, and let NaNs leak into the output.
    hot_pixels = (
        cp.squeeze(cp.asarray(noise_map, dtype=cp.float32)) > threshold
    ).astype(cp.float32)
    cold_pixels = 1.0 - hot_pixels

    for idx in range(corrected.shape[0]):
        median = ndimage.median_filter(corrected[idx, :, :], size=3)
        # Keep cold pixels as-is; swap hot pixels for the local median.
        corrected[idx, :] = cold_pixels * corrected[idx, :] + hot_pixels * median
        del median
    del hot_pixels, cold_pixels

    # Saturate rather than wrap when casting to uint16.
    corrected = cp.clip(corrected, 0, np.iinfo(np.uint16).max)

    result = cp.asnumpy(corrected).astype(np.uint16)
    # Drop the GPU reference so free_all_blocks can actually release it.
    del corrected
    gc.collect()
    cp.cuda.Stream.null.synchronize()
    cp.get_default_memory_pool().free_all_blocks()
    cp.get_default_pinned_memory_pool().free_all_blocks()

    return result