Source code for neuralib.stardist.run_2d

from pathlib import Path
from typing import TYPE_CHECKING, Any, cast

import numpy as np
from argclz import argument, as_argument
from neuralib.util.verbose import fprint

from .base import AbstractSegmentationOptions
from .core import STARDIST_MODEL, StarDistSegmentation, read_stardist

if TYPE_CHECKING:
    from stardist.models import StarDist2D  # pyright: ignore[reportMissingImports]

__all__ = ['StarDist2DOptions']


class StarDist2DOptions(AbstractSegmentationOptions):
    """Command-line options that run a pretrained StarDist 2D model on images.

    :meth:`run` either evaluates the model and writes the segmentation next to
    the input, or (when ``napari_view`` is set) shows the result in napari.
    """

    DESCRIPTION = 'Run the Stardist model for segmentation'

    # Re-declare the inherited ``model`` argument with a 2D-specific default and help text.
    model: STARDIST_MODEL = cast(Any, cast(Any, as_argument(AbstractSegmentationOptions.model)).with_options(  # pyright: ignore[reportIncompatibleVariableOverride]
        default='2D_versatile_fluo',
        help='stardist pretrained model'
    ))

    # Optional probability threshold; ``None`` is handed to ``predict_instances`` unchanged.
    prob_thresh: float | None = argument(
        '--prob',
        default=None,
        help='Consider only object candidates from pixels with predicted object probability above this threshold. '
             'Seealso: stardist.models.base._predict_instances_generator: prob_thresh'
    )

    def run(self) -> None:
        """Open the napari viewer if ``napari_view`` is set, otherwise run the evaluation."""
        if self.napari_view:
            self.launch_napari()
        else:
            self.eval()

    def seg_output(self, filepath: Path) -> Path:
        """
        Return the path of the ``.npz`` segmentation result for an input file.

        The result sits next to the input and is named ``<stem>_seg.npz``.

        :param filepath: Path of the input image
        :return: Path of the segmentation output file
        """
        # Build the final name in a single step. Appending '_seg' first and then
        # calling ``with_suffix`` would treat everything after the last dot of
        # the stem as an extension, so ``img.v1.tif`` would collapse to
        # ``img.npz`` and collide with the output of ``img.v2.tif``.
        return filepath.with_name(f'{filepath.stem}_seg.npz')

    def eval(self, **kwargs) -> None:
        """
        Load the pretrained model and segment the configured input(s).

        :param kwargs: Extra keyword arguments forwarded to ``StarDist2D.predict_instances``
        :raises RuntimeError: If neither ``file_mode`` nor ``batch_mode`` is set
        """
        # Imported here rather than at module level; the top-level import of
        # this name is guarded by ``TYPE_CHECKING``.
        from stardist.models import StarDist2D  # pyright: ignore[reportMissingImports]

        # Load the model once and reuse it for every image.
        model = StarDist2D.from_pretrained(self.model)

        if self.file_mode:
            self._eval(self.file, self.process_image(), model, **kwargs)
        elif self.batch_mode:
            for file, image in self.foreach_process_image():
                self._eval(file, image, model, **kwargs)
        else:
            raise RuntimeError('run fail')

    def _eval(self, filepath: Path, image: np.ndarray, model: 'StarDist2D', **kwargs) -> None:
        """
        Segment one image and save the result, unless a cached result can be reused.

        :param filepath: Path of the input image; determines the output locations
        :param image: Image array passed to the model
        :param model: Loaded StarDist 2D model
        :param kwargs: Extra keyword arguments passed to ``model.predict_instances``
        """
        out_seg = self.seg_output(filepath)

        # Reuse an existing result unless the user asked to invalidate it.
        if out_seg.exists() and not self.invalid_existed_result:
            fprint(f'cached {filepath} because {out_seg} exists, use --invalid to invalid cache', vtype='IO')
            return

        labels, detail = model.predict_instances(image, prob_thresh=self.prob_thresh, **kwargs)

        # NaN cannot be stored in an integer array, so convert to float before
        # replacing the zero labels with NaN.
        labels = labels.astype(np.float64)
        labels[labels == 0] = np.nan

        res = StarDistSegmentation(labels, detail['coord'], detail['prob'], str(filepath), self.model)

        # mask probability
        if self.prob_thresh is not None:
            res.mask_probability(self.prob_thresh)

        # save output
        res.to_npz(out_seg)
        if self.save_ij_roi:
            res.to_roi(self.ij_roi_output(filepath))

    # noinspection PyTypeChecker
    def launch_napari(self, with_widget: bool = False, **kwargs) -> None:
        """
        Launch napari viewer for stardist results

        The segmentation is computed first if its output file does not exist
        or if ``invalid_existed_result`` is set.

        :param with_widget: If True, launch also with the starDist widget (required package ``stardist-napari``)
        :param kwargs: Accepted but currently unused
        """
        import napari

        file = self.seg_output(self.file)
        if not file.exists() or self.invalid_existed_result:
            self.eval()

        res = read_stardist(file)

        viewer = napari.Viewer()
        viewer.add_image(self.process_image(), name='image')
        # Labels are drawn half-transparent over the image layer.
        viewer.add_image(res.labels, name='labels', colormap='twilight_shifted', opacity=0.5)
        viewer.add_points(res.points, face_color='red')

        if with_widget:
            viewer.window.add_plugin_dock_widget("stardist-napari", "StarDist")

        napari.run()
# Script entry point: only when executed directly, hand control to the
# options class' ``main``.
if __name__ == '__main__':
    StarDist2DOptions().main()