import logging
import re
import sys
from collections.abc import Callable
from pathlib import Path
from typing import Any, ClassVar, Literal, Self, TypedDict
import attrs
import cv2
import numpy as np
import polars as pl
from neuralib.io import csv_header
from neuralib.typing import PathLike
from neuralib.util.verbose import fprint
from polars.testing import assert_frame_equal
from tifffile import tifffile
from tqdm import tqdm
# public API of this module
__all__ = ['SequenceLabeller']

# module-level logger (used for debug output of typed commands)
Logger = logging.getLogger(name=__name__)
# ================ #
# KeyBoard Mapping #
# ================ #
class KeyMapping(TypedDict, total=False):
    """
    Key name -> key code table; one table exists per operating system because the codes differ.

    All entries are optional. ``space`` is intentionally not part of the table, so a space
    keystroke is treated as a printable character and typed into the note buffer.
    """
    # control keys
    escape: int
    backspace: int
    enter: int
    # jump several frames at once
    left_square_bracket: int  # '['
    right_square_bracket: int  # ']'
    # arrow keys
    left: int
    right: int
    up: int
    down: int
    # symbol keys
    plus: int  # '+'
    minus: int  # '-'
    equal: int  # '='
# Key codes shared by every platform. 'space' (32) is deliberately omitted so that it is
# typed into the note buffer as a printable character.
COMMON_KEYMAPPING: KeyMapping = {
    'escape': 27,
    'enter': 13,
    'left_square_bracket': 91,
    'right_square_bracket': 93,
    'plus': ord('+'),
    'minus': ord('-'),
    'equal': ord('='),
}

# Platform specific codes (backspace and arrow keys) layered on top of the common table.
# Windows values are the extended codes returned by cv2.waitKeyEx.
WIN_KEYMAPPING: KeyMapping = COMMON_KEYMAPPING | {
    'backspace': 8,
    'left': 2424832,
    'right': 2555904,
    'up': 2490368,
    'down': 2621440,
}

MAC_KEYMAPPING: KeyMapping = COMMON_KEYMAPPING | {
    'backspace': 127,
    'left': 2,
    'right': 3,
    'up': 0,
    'down': 1,
}

LINUX_KEYMAPPING: KeyMapping = COMMON_KEYMAPPING | {
    'backspace': 8,
    'left': 81,
    'right': 83,
    'up': 82,
    'down': 84,
}
def get_keymapping() -> KeyMapping:
    """
    Select the key-code table for the running operating system.

    :return: the platform's :class:`KeyMapping`
    :raises RuntimeError: if ``sys.platform`` is not Linux, macOS or Windows
    """
    platform = sys.platform
    if platform == 'darwin':
        return MAC_KEYMAPPING
    if platform == 'win32':
        return WIN_KEYMAPPING
    if platform in ('linux', 'linux2'):
        return LINUX_KEYMAPPING
    raise RuntimeError(f'unsupported platform: {platform}')
@attrs.define
class FrameInfo:
    """One image of the sequence together with its display name and optional note."""

    # name of the image/frame (shown on screen and used as the key of the notes CSV)
    filename: str
    # `Array[uint, [H, W]|[H, W, 3]]`
    image: np.ndarray
    # free-text note for the image; None if there is none
    notes: str | None = attrs.field(default=None)

    @property
    def itype(self) -> Literal['gray', 'rgb']:
        """
        Image color type, derived from the number of array dimensions.

        :raises TypeError: if the image is neither 2D nor 3D
        """
        ndim = self.image.ndim
        if ndim == 2:
            return 'gray'
        if ndim == 3:
            return 'rgb'
        raise TypeError(f'unsupported image dimension: expected 2 (gray) or 3 (rgb), got {ndim}')

    @property
    def text_color(self) -> int | tuple[int, int, int]:
        """color used for drawing overlay text onto :attr:`image`"""
        if self.itype == 'gray':
            # maximum uint16 value; presumably clipped to the image dtype's maximum by OpenCV drawing
            return 2 ** 16 - 1
        # itype can only be 'rgb' here: it raises for any other dimensionality
        return 0, 0, 255

    @property
    def height(self) -> int:
        return self.image.shape[0]

    @property
    def width(self) -> int:
        return self.image.shape[1]
class CloseSaveInterrupt(KeyboardInterrupt):
    """
    Raised by a quit command to leave the GUI loop.

    ``:wq`` = write & quit, ``:q`` = quit, ``:q!`` = force quit. It derives from KeyboardInterrupt so
    that a command dispatcher that reports ordinary exceptions as messages lets it propagate.
    """

    def __init__(self, mode: Literal[':wq', ':q', ':q!']):
        super().__init__(mode)
        self.mode = mode
class SequenceLabeller:
    """
    OpenCV viewer for stepping through an image sequence and attaching a text note to each image.

    Keyboard: left/right arrows move one image; ``[`` / ``]`` move ten images backward/forward.
    Other printable keys are typed into an input buffer, which Enter executes as a command
    (``:h`` lists the commands) and Escape clears. Notes are loaded from and saved to the CSV
    file ``output``.
    """

    window_title: ClassVar[str] = 'SeqLabeller'

    def __init__(self, seqs_info: list[FrameInfo],
                 output: PathLike | None = None):
        """
        :param seqs_info: frames in display order
        :param output: CSV file used to load and save notes. ``None`` disables persistence.
        """
        self.seqs_info = seqs_info
        self.output: Path | None = Path(output) if output is not None else None  # for notes
        self.message_queue: list[str] = []  # text lines drawn on top of the current frame
        self.buffer = ''  # input buffer
        self._frame_index = 0

    def __len__(self) -> int:
        return len(self.seqs_info)

    @classmethod
    def load_sequences(cls, seqs: np.ndarray | list[np.ndarray],
                       filenames: list[str] | None = None,
                       output: PathLike | None = None) -> Self:
        """
        Build a labeller from in-memory images.

        :param seqs: image stack (split along the first axis) or a list of per-frame images
        :param filenames: display name per frame; defaults to the frame index as a string
        :param output: CSV file used to load/save notes
        :return: :class:`SequenceLabeller`
        :raises ValueError: if ``filenames`` does not contain exactly one name per frame
        """
        if isinstance(seqs, np.ndarray):
            seqs = list(seqs)

        n_frames = len(seqs)
        if filenames is None:
            filenames = [str(i) for i in range(n_frames)]
        elif len(filenames) != n_frames:
            raise ValueError(f'number of filenames ({len(filenames)}) does not match number of frames ({n_frames})')

        seqs_info = [FrameInfo(name, img, None) for name, img in zip(filenames, seqs)]
        return cls(seqs_info, output)

    @classmethod
    def load_from_dir(cls, directory: PathLike,
                      file_suffix: str = '.tif',
                      sort_func: Callable[[Path], Any] | None = None,
                      single_frame_per_file: bool = True,
                      output: PathLike | None = None) -> Self:
        """
        Build a labeller from the files of a directory (non-recursive).

        :param directory: directory containing the image sequence
        :param file_suffix: suffix of the files to load
        :param sort_func: sort key with signature `(filename: Path) -> Comparable`; default sorts by path
        :param single_frame_per_file: read every file with OpenCV as one RGB frame. If False, ``.tif``/``.tiff``
            files are read whole with tifffile (one entry per file). Ignored for ``.pdf``.
        :param output: CSV file used to load/save notes
        :return: :class:`SequenceLabeller`
        :raises NotADirectoryError: if ``directory`` is not a directory
        :raises FileNotFoundError: if no file matches ``file_suffix`` or an image cannot be decoded
        :raises NotImplementedError: if ``single_frame_per_file`` is False for a non-tif suffix
        """
        directory = Path(directory)
        if not directory.is_dir():
            raise NotADirectoryError(f'{directory}')

        files = sorted(directory.glob(f'*{file_suffix}'), key=sort_func)
        if not files:
            raise FileNotFoundError(f'no "*{file_suffix}" files found in {directory}')
        fprint(f'LOAD image sequence: {len(files)} files', vtype='io')

        seqs = [cls._read_image_file(f, file_suffix, single_frame_per_file)
                for f in tqdm(files, unit='file', ncols=80)]
        seqs_info = [FrameInfo(f.stem, img, None) for f, img in zip(files, seqs)]
        return cls(seqs_info, output)

    @staticmethod
    def _read_image_file(file: Path, file_suffix: str, single_frame_per_file: bool):
        """read one file according to the loading options of :meth:`load_from_dir`"""
        if file_suffix == '.pdf':
            from neuralib.imglib.io import read_pdf
            return read_pdf(file, dpi=200)  # NOTE(review): return type of read_pdf not visible here

        if single_frame_per_file:
            img = cv2.imread(str(file))
            if img is None:
                raise FileNotFoundError(f'Failed to read image: {file}')
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if file_suffix in ('.tif', '.tiff'):
            return tifffile.imread(str(file))

        raise NotImplementedError(f'loading whole files is not supported for "{file_suffix}"')

    @property
    def n_frames(self) -> int:
        """aka. number of images"""
        return len(self.seqs_info)

    @property
    def current_frame_index(self) -> int:
        return self._frame_index

    @current_frame_index.setter
    def current_frame_index(self, value: int):
        """move to frame ``value`` (wrapping around at both ends) and show its filename and note"""
        n = len(self.seqs_info)
        self._frame_index = value % n  # Python's % is non-negative for n > 0, so negative values wrap too
        self.message_queue = []

        info = self.seqs_info[self._frame_index]
        if info.filename is not None:
            self.enqueue_message(f'{info.filename}')
        if (note := self.read_note()) is not None:
            self.enqueue_message(note)

    @property
    def text_color(self) -> int | tuple[int, int, int]:
        return self.seqs_info[self.current_frame_index].text_color

    # ===== #
    # Notes #
    # ===== #

    # notes as last loaded from / saved to ``output`` (None until then); baseline of check_note_changes
    _prev_note: pl.DataFrame | None = None

    def save_note(self):
        """save image-related notes to file (one row per frame: filename, notes, datetime)"""
        from datetime import datetime

        if self.output is None:
            raise RuntimeError('specify output first for writing notes!')

        fields = ['filename', 'notes', 'datetime']
        t = datetime.now().replace(second=0, microsecond=0).strftime("%Y-%m-%d %H:%M")  # TODO to frame dependent
        with csv_header(self.output, fields, quotes_header='notes') as csv:
            for seq in self.seqs_info:
                csv(str(seq.filename), seq.notes, t)

        # record the saved state so that check_note_changes only reports edits made after this save
        self._prev_note = pl.DataFrame(
            {'filename': [str(seq.filename) for seq in self.seqs_info],
             'notes': [seq.notes for seq in self.seqs_info]},
            schema={'filename': pl.Utf8, 'notes': pl.Utf8},
        )

    def load_note(self):
        """read image-related notes from file and attach them to the frames with the same filename"""
        from neuralib.util.verbose import printdf

        if self.output is None:
            raise RuntimeError('specify output first for loading notes!')

        # force text dtypes: otherwise purely numeric notes would be inferred as numbers
        df = self._prev_note = pl.read_csv(self.output, schema_overrides={'filename': pl.Utf8, 'notes': pl.Utf8})
        printdf(df)

        notes = dict(zip(df['filename'].to_list(), df['notes'].to_list()))
        for info in self.seqs_info:
            # frames without a row in the file get no note instead of aborting the whole load
            info.notes = notes.get(str(info.filename))

    def write_note(self, note: str, *, append_mode: bool = False):
        """
        Set the note of the current frame and refresh the on-screen messages.

        :param note: note text
        :param append_mode: append to an existing note (separated by ``;``) instead of replacing it
        """
        info = self.seqs_info[self.current_frame_index]
        if append_mode and info.notes is not None:
            info.notes = info.notes + ';' + note
        else:
            info.notes = note

        self.current_frame_index = self.current_frame_index  # trigger enqueue_message
        if self.output is None:
            self.enqueue_message('specify output first for writing notes!')

    def read_note(self) -> str | None:
        """note of the current frame"""
        return self.seqs_info[self._frame_index].notes

    def clear_note(self):
        """remove the note of the current frame"""
        self.seqs_info[self.current_frame_index].notes = None

    def check_note_changes(self) -> bool:
        """True if any note differs from the last loaded/saved file (or, without one, if any note exists)"""
        if self._prev_note is None:
            return any(seq.notes is not None for seq in self.seqs_info)

        # compare per filename: immune to row order and to the dtype polars inferred for the columns
        prev = dict(zip(self._prev_note['filename'].cast(pl.Utf8).to_list(),
                        self._prev_note['notes'].cast(pl.Utf8).to_list()))
        return any(prev.get(str(seq.filename)) != seq.notes for seq in self.seqs_info)

    # ============= #
    # Key & Command #
    # ============= #

    def goto_begin(self):
        self.current_frame_index = 0

    def goto_end(self):
        self.current_frame_index = self.n_frames - 1

    def go_to(self, i: int):
        """go to the i-th (0-based) frame; an out-of-range index is reported on screen and ignored"""
        if not 0 <= i < len(self):
            self.enqueue_message(f'invalid sequence index: {i}')
            return
        self.current_frame_index = i

    def handle_keycode(self, k: int):
        """dispatch a raw key code: mapped keys act immediately, printable characters go to the buffer"""
        ret = self._handle_keymapping(get_keymapping(), k)
        # guard chr(): extended codes (e.g. function keys from waitKeyEx) can exceed the Unicode range
        if ret is not None and 0 <= ret <= sys.maxunicode and chr(ret).isprintable():
            self.buffer += chr(ret)

    def _handle_keymapping(self, mapping: KeyMapping, value: int) -> int | None:
        """
        Handling the keyboard mapping

        :param mapping: key name -> key code for the current platform
        :param value: raw key code
        :return: ``value`` if the key has no action here (i.e. it is text input), otherwise None
        """
        key = next((name for name, code in mapping.items() if code == value), None)

        # unmapped keys and symbol keys are text input (e.g. '+' starts the append-note command)
        if key is None or key in ('plus', 'minus', 'equal'):
            return value

        if key == 'left':
            self.current_frame_index -= 1
        elif key == 'right':
            self.current_frame_index += 1
        elif key == 'left_square_bracket':
            self.current_frame_index -= 10
        elif key == 'right_square_bracket':
            self.current_frame_index += 10
        elif key == 'backspace':
            self.buffer = self.buffer[:-1]
        elif key == 'enter':  # handle command in buffer
            command = self._proc_image_command = self.buffer
            self.buffer = ''
            try:
                self.handle_command(command)
            except KeyboardInterrupt:
                raise  # includes CloseSaveInterrupt, which main() handles
            except Exception as e:
                self.enqueue_message(f'command "{command}" {type(e).__name__}: {e}')
        elif key == 'escape':
            self.buffer = ''
        # 'up' / 'down' are mapped but currently have no action
        return None

    def handle_command(self, command: str):
        """
        Execute a command typed into the input buffer (``:h`` prints the list).

        :param command: buffer content when Enter was pressed
        :raises CloseSaveInterrupt: for ``:wq``, ``:q`` and ``:q!``
        :raises RuntimeError: for an unknown ``:`` command
        """
        Logger.debug('command: %s', command)
        if not command:
            return  # Enter on an empty buffer; writing '' would silently wipe the current note

        if command == ':h':
            self.enqueue_message(':h : print this document')
            self.enqueue_message(':q : quit (unable if not save the changes)')
            self.enqueue_message(':q! : quit (without save)')
            self.enqueue_message(':wq : save notes and quit')
            self.enqueue_message(':c : clear current note')
            self.enqueue_message(':i : print current file index')
            self.enqueue_message(':N : goto N-th image')
            self.enqueue_message('+message : append note')
            self.enqueue_message('message : (replace) note')
        elif command == ':c':
            self.enqueue_message(f'clear notes: {self.seqs_info[self.current_frame_index].notes}')
            self.clear_note()
        elif command == ':i':
            self.enqueue_message(f'current file index: {self.current_frame_index}')
        elif (match := re.fullmatch(r':(\d+)', command)) is not None:
            self.go_to(int(match.group(1)))
        elif command.startswith('+'):
            self.write_note(command[1:], append_mode=True)
        elif not command.startswith(':'):
            self.write_note(command)
        elif command in (':wq', ':q', ':q!'):
            raise CloseSaveInterrupt(command)
        else:
            raise RuntimeError(f'unknown command : "{command}"')

    # ============ #
    # Msg / Buffer #
    # ============ #

    def enqueue_message(self, text: str):
        """add a text line to be drawn on the current frame"""
        self.message_queue.append(text)

    def _show_queued_message(self, image: np.ndarray):
        """drawing enqueued message, one line every 30 px starting below the buffer line"""
        for i, message in enumerate(self.message_queue):
            cv2.putText(image, message, (10, 70 + 30 * i), cv2.FONT_HERSHEY_SIMPLEX, 1, self.text_color, 2,
                        cv2.LINE_AA)

    def _show_buffer(self, image: np.ndarray):
        """drawing input buffer content"""
        cv2.putText(image, self.buffer, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, self.text_color, 2, cv2.LINE_AA)

    # ========= #
    # Main Loop #
    # ========= #

    def main(self):
        """main loop for the GUI; returns after :wq, :q! or :q (the latter only without unsaved note changes)"""
        # load before creating the window, so a failing load does not leave an orphan window behind
        if self.output is not None and self.output.exists():
            self.load_note()
        if self.seqs_info:
            self.current_frame_index = self.current_frame_index  # show the first frame's filename/note right away

        cv2.namedWindow(self.window_title, cv2.WINDOW_GUI_NORMAL)
        try:
            while True:
                try:
                    while True:
                        self._loop()
                except CloseSaveInterrupt as e:
                    if e.mode == ':wq':
                        if self.output is not None:
                            self.save_note()
                            fprint(f'SAVE csv -> {str(self.output)}!', vtype='io')
                        else:
                            Logger.warning('notes were not saved: no output file specified')
                        break
                    elif e.mode == ':q!':
                        break
                    elif e.mode == ':q':
                        if not self.check_note_changes():
                            break
                        self.enqueue_message('please save the note using :wq, or force quit using :q!')
        finally:
            cv2.destroyWindow(self.window_title)

    def _loop(self):
        """draw the current frame with its overlays, then poll the keyboard once (1 ms) and dispatch the key"""
        try:
            info = self.seqs_info[self.current_frame_index]
        except IndexError:  # no frame to show
            pass
        else:
            image = info.image.copy()  # draw overlays on a copy, never on the stored frame
            if self.buffer:
                self._show_buffer(image)
            self._show_queued_message(image)
            # NOTE(review): load_from_dir converts frames to RGB while cv2.imshow assumes BGR,
            # so colors may appear channel-swapped; confirm the intended channel order.
            cv2.imshow(self.window_title, image)

        if sys.platform in ('darwin', 'linux', 'linux2'):
            k = cv2.waitKey(1)
        elif sys.platform == 'win32':
            k = cv2.waitKeyEx(1)  # extended codes are required for the arrow keys of WIN_KEYMAPPING
        else:
            raise RuntimeError(f'unsupported platform: {sys.platform}')

        if k >= 0:  # -1 means no key was pressed within the timeout
            self.handle_keycode(k)
# ============= #
# Main Argparse #
# ============= #
def main():
    """Command-line entry point: load an image sequence from a directory and open the labeller GUI."""
    import argparse

    ap = argparse.ArgumentParser(description='run the sequences labeller')
    ap.add_argument('-D', '--dir', type=Path, required=True, help='path with image sequences', dest='directory')
    # .png/.jpg/.jpeg are read through the same single-frame OpenCV path as .tif/.tiff
    ap.add_argument('--suffix', choices=('.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg'), default='.pdf',
                    help='image sequence suffix')
    ap.add_argument('-O', '--output', type=Path, default=None, help='csv output for note')
    opt = ap.parse_args()

    labeller = SequenceLabeller.load_from_dir(opt.directory, file_suffix=opt.suffix, output=opt.output)
    labeller.main()
# Script entry point: launch the GUI only when executed directly, not on import.
if __name__ == "__main__":
    main()