# pyright: reportAttributeAccessIssue=false, reportCallIssue=false, reportArgumentType=false, reportIncompatibleMethodOverride=false, reportOptionalMemberAccess=false, reportOptionalSubscript=false, reportOptionalOperand=false, reportPossiblyUnboundVariable=false, reportAssignmentType=false
import colorsys
import traceback
from datetime import datetime
from typing import Literal
import cv2
import matplotlib.pyplot as plt
import numpy as np
from argclz import AbstractParser
from PyQt6.QtCore import QPointF, Qt, QTimer
from PyQt6.QtGui import QDragEnterEvent, QDropEvent, QFont, QImage, QMouseEvent, QPainter, QPen, QPixmap, QWheelEvent
from PyQt6.QtWidgets import (
QApplication,
QCheckBox,
QComboBox,
QFileDialog,
QGraphicsPixmapItem,
QGraphicsScene,
QGraphicsView,
QGridLayout,
QGroupBox,
QHBoxLayout,
QLabel,
QListWidget,
QMainWindow,
QPushButton,
QSizePolicy,
QSlider,
QSpinBox,
QSplitter,
QTextEdit,
QVBoxLayout,
QWidget,
)
from skimage.io import imread
from .ccf import DorsalCCF
__all__ = ['RegistrationApp', 'RegistrationOptions']
# TODO: consider loading a reference from the retinotopic map and aligning it with the current widefield image.
# TODO: apply translation/rotation to the current widefield image together with the 2D projective transform, and save the combined result as 2D.
def _normalize_to_uint8(arr: np.ndarray) -> np.ndarray:
if arr.dtype == np.uint8:
return np.ascontiguousarray(arr)
arr_min = arr.min()
arr_range = np.ptp(arr)
if arr_range == 0:
return np.zeros(arr.shape, dtype=np.uint8)
arr8 = 255 * (arr - arr_min) / arr_range
return np.ascontiguousarray(arr8.astype(np.uint8))
def _label_image_to_rgb(arr: np.ndarray, label_colors: dict[int, tuple[int, int, int]] | None = None) -> np.ndarray:
rgb_img = np.zeros((*arr.shape, 3), dtype=np.uint8)
for label in np.unique(arr):
if label == 0:
continue
label_int = int(label)
if label_colors is not None and label_int in label_colors:
color = label_colors[label_int]
else:
color = tuple(int(c * 255) for c in plt.cm.tab20(label_int % 20)[:3])
rgb_img[arr == label] = color
return np.ascontiguousarray(rgb_img)
def _distinct_label_colors(labels: list[int]) -> dict[int, tuple[int, int, int]]:
colors = {}
for index, label in enumerate(sorted(labels)):
hue = (index * 0.618033988749895) % 1.0
red, green, blue = colorsys.hsv_to_rgb(hue, 0.70, 0.95)
colors[label] = (int(red * 255), int(green * 255), int(blue * 255))
return colors
def np_to_qpixmap(
arr: np.ndarray,
colorize_labels: bool = False,
label_colors: dict[int, tuple[int, int, int]] | None = None
) -> QPixmap:
try:
match arr.ndim:
case 2:
if colorize_labels:
rgb_img = _label_image_to_rgb(arr, label_colors)
else:
# Grayscale
arr8 = _normalize_to_uint8(arr)
h, w = arr8.shape
image = QImage(arr8.data, w, h, w, QImage.Format.Format_Grayscale8)
return QPixmap.fromImage(image)
case 3:
if arr.shape[2] == 3: # RGB image
rgb_img = _normalize_to_uint8(arr)
else:
# video stack
frame = arr[0] # Take first frame
arr8 = _normalize_to_uint8(frame)
h, w = arr8.shape
image = QImage(arr8.data, w, h, w, QImage.Format.Format_Grayscale8)
return QPixmap.fromImage(image)
case _:
raise ValueError(f"Unsupported array shape: {arr.shape}")
h, w = rgb_img.shape[:2]
image = QImage(rgb_img.data, w, h, 3 * w, QImage.Format.Format_RGB888)
return QPixmap.fromImage(image)
except Exception as e:
print("Error in np_to_qpixmap:", e)
return QPixmap()
class ZoomPanGraphicsView(QGraphicsView):
    """Graphics view with wheel zoom, hand-drag panning, file drops and a pixel-axis overlay.

    Parameters
    ----------
    scene : QGraphicsScene
        The scene to display.
    drop_callback : callable(str) or None
        Called with the local path of every dropped file whose extension is in
        ``DROP_EXTENSIONS``. Drops are ignored when this is None.
    """

    ZOOM_STEP = 1.25  # scale factor applied per wheel notch
    TICK_SPACING = 50  # distance between axis ticks, in scene (image pixel) units
    DROP_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp', '.npy', '.mp4', '.avi', '.mov')

    def __init__(self, scene, drop_callback=None):
        super().__init__(scene)
        self.setRenderHint(QPainter.RenderHint.Antialiasing)
        self.setDragMode(QGraphicsView.DragMode.ScrollHandDrag)
        self.setTransformationAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
        self.setResizeAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
        self.setMinimumSize(300, 300)
        self.setAcceptDrops(True)
        self.drop_callback = drop_callback

    def wheelEvent(self, event: QWheelEvent):
        """Zoom in on scroll-up and out on scroll-down, anchored under the mouse."""
        delta = event.angleDelta().y()
        if delta == 0:
            # Purely horizontal scroll (e.g. a sideways trackpad swipe) is not a zoom gesture;
            # previously it always zoomed out. Let the base class scroll instead.
            super().wheelEvent(event)
            return
        factor = self.ZOOM_STEP if delta > 0 else 1 / self.ZOOM_STEP
        self.scale(factor, factor)

    def dragEnterEvent(self, event: QDragEnterEvent):
        """Accept drags that carry URLs (files)."""
        if event.mimeData().hasUrls():
            event.acceptProposedAction()

    def dragMoveEvent(self, event):
        """Keep accepting URL drags while they move over the view.

        QGraphicsView's default implementation forwards the move to the scene and ignores it
        unless an item accepts drops; without this override ``dropEvent`` is never delivered.
        """
        if event.mimeData().hasUrls():
            event.acceptProposedAction()
        else:
            super().dragMoveEvent(event)

    def dropEvent(self, event: QDropEvent):
        """Pass every dropped file with a supported extension to ``drop_callback``."""
        if not event.mimeData().hasUrls():
            super().dropEvent(event)
            return
        for url in event.mimeData().urls():
            file_path = url.toLocalFile()
            if file_path.lower().endswith(self.DROP_EXTENSIONS) and self.drop_callback:
                self.drop_callback(file_path)
        event.acceptProposedAction()

    def drawForeground(self, painter, rect):
        """Draw X/Y axes with a labelled tick every ``TICK_SPACING`` scene units along the viewport edges."""
        super().drawForeground(painter, rect)
        # get the scene coordinates of the viewport corners
        scene_rect = self.mapToScene(self.viewport().rect()).boundingRect()
        left, top, right, bottom = scene_rect.left(), scene_rect.top(), scene_rect.right(), scene_rect.bottom()
        pen = QPen(Qt.GlobalColor.white)
        pen.setWidth(1)
        # The painter works in scene coordinates; a cosmetic pen stays 1 device pixel wide
        # at every zoom level instead of thickening as the view zooms in.
        pen.setCosmetic(True)
        painter.setPen(pen)
        painter.setFont(QFont("Sans", 8))
        # X axis along the bottom edge, Y axis along the left edge
        painter.drawLine(QPointF(left, bottom), QPointF(right, bottom))
        painter.drawLine(QPointF(left, top), QPointF(left, bottom))
        step = self.TICK_SPACING
        # X ticks + labels
        for x in np.arange(np.floor(left / step) * step, right, step):
            tick_pos = QPointF(x, bottom)
            painter.drawLine(tick_pos, tick_pos + QPointF(0, -5))
            painter.drawText(tick_pos + QPointF(2, -7), f"{int(x)}")
        # Y ticks + labels
        for y in np.arange(np.floor(top / step) * step, bottom, step):
            tick_pos = QPointF(left, y)
            painter.drawLine(tick_pos, tick_pos + QPointF(5, 0))
            painter.drawText(tick_pos + QPointF(7, 3), f"{int(y)}")
class ImageScene(QGraphicsScene):
    """Scene that shows one image and records left-clicks on it as annotation points.

    Parameters
    ----------
    name : str
        Identifier passed back to ``click_callback`` (e.g. ``"Wfield"`` or ``"Dorsal"``).
    click_callback : callable(name, [x, y])
        Called with the scene-coordinate position of every accepted click.
    parent : QObject, optional
        Qt parent of the scene.
    colorize_labels : bool
        Render 2-D arrays as label images instead of grayscale (see ``np_to_qpixmap``).
    label_colors : dict[int, tuple[int, int, int]] or None
        Optional fixed colour per label, used when ``colorize_labels`` is True.
    show_boundaries : bool
        With ``colorize_labels``, overlay label boundaries in white.

    Clicks are only recorded while ``enabled`` is True and they land on the image.
    """

    MARKER_SIZE = 5  # diameter of the red point markers, in scene units

    def __init__(
        self,
        name,
        click_callback,
        parent=None,
        colorize_labels=False,
        label_colors=None,
        show_boundaries=False
    ):
        super().__init__(parent)
        self.name = name
        self.pix_item = None
        self.click_callback = click_callback
        self.colorize_labels = colorize_labels
        self.label_colors = label_colors
        self.show_boundaries = show_boundaries
        self.enabled = False
        self.points = []  # [x, y] scene coordinates of every drawn marker

    def set_image(self, array):
        """Replace the displayed image with ``array`` and redraw the stored point markers.

        Errors are printed rather than raised.
        """
        try:
            display_array = self._display_array(array)
            self.clear()
            pixmap = np_to_qpixmap(
                display_array,
                colorize_labels=self.colorize_labels,
                label_colors=self.label_colors
            )
            self.pix_item = QGraphicsPixmapItem(pixmap)
            self.addItem(self.pix_item)
            for pt in self.points:
                self._add_marker(pt)
        except Exception as e:
            print(f"Error setting image in {self.name}:", e)

    def _display_array(self, array):
        """Return the array to render: label RGB with white boundaries when enabled, else ``array``."""
        if self.show_boundaries and self.colorize_labels:
            from skimage.segmentation import find_boundaries
            rgb_img = _label_image_to_rgb(array, self.label_colors)
            boundary = find_boundaries(array, mode='outer')
            rgb_img[boundary] = (255, 255, 255)
            return rgb_img
        return array

    def _add_marker(self, point):
        """Draw one red marker centred on ``point``.

        Shared by ``set_image`` and ``draw_point``; they previously used different sizes
        (4 vs 5), and the 5-wide marker was off-centre.
        """
        size = self.MARKER_SIZE
        half = size / 2
        self.addEllipse(point[0] - half, point[1] - half, size, size, brush=Qt.GlobalColor.red)

    def clear_scene(self):
        """Remove all items, the image reference and the stored points."""
        self.clear()
        self.pix_item = None
        self.points.clear()

    def draw_point(self, point):
        """Store ``point`` and draw its marker."""
        self.points.append(point)
        self._add_marker(point)

    def mousePressEvent(self, event):
        """Record a left-click on the image as a point and report it to ``click_callback``.

        ``event`` is the scene's mouse event (it provides ``scenePos()``).
        """
        if not self.enabled or event.button() != Qt.MouseButton.LeftButton:
            return
        pos = event.scenePos()
        # Ignore clicks that miss the image (empty canvas, or no image loaded yet); they have
        # no meaningful correspondence in the other image.
        if self.pix_item is None or not self.pix_item.sceneBoundingRect().contains(pos):
            return
        point = [pos.x(), pos.y()]
        self.draw_point(point)
        # Pass a separate list so the callback's copy is independent of self.points.
        self.click_callback(self.name, list(point))
class RegistrationApp(QMainWindow):
    """Main window for registering a widefield image or movie onto a dorsal cortex label map.

    The user loads a widefield image, multi-frame TIFF or movie, plus a dorsal map (from a
    ``.npy`` file or built from the CCF atlas via region/hemisphere selection). Corresponding
    points are then clicked alternately on the widefield view and the dorsal view, and collected
    in ``point_pairs`` as ``{'wfield': [x, y], 'dorsal': [x, y] | None}`` dicts.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Dorsal Map Registration Tool")

        # --- State ---
        self.point_pairs = []
        self.expecting = 'Wfield'  # name of the scene whose click is accepted next

        # --- Scenes & Views ---
        self.wf_scene = ImageScene("Wfield", self.on_point_clicked)
        self.dorsal_scene = ImageScene(
            "Dorsal",
            self.on_point_clicked,
            colorize_labels=True,
            show_boundaries=True
        )
        self.wf_view = ZoomPanGraphicsView(self.wf_scene)
        self.dorsal_view = ZoomPanGraphicsView(self.dorsal_scene)

        # --- Status Labels ---
        self.hist_shape_label = QLabel("WF size: N/A")
        self.dorsal_shape_label = QLabel("Dorsal size: N/A")

        # --- Controls ---
        # File I/O
        self.load_hist_btn = QPushButton("Load Widefield Image (Image/Movie)")
        self.load_dorsal_btn = QPushButton("Load Dorsal Map (.npy)")
        # Video controls
        self.play_btn = QPushButton("▶ Play")
        self.pause_btn = QPushButton("⏸ Pause")
        # Transform controls
        self.transform_btn = QPushButton("Apply Transform")
        self.save_tf_btn = QPushButton("Save Transform")
        self.save_img_btn = QPushButton("Save Image")
        self.transform_type_box = QComboBox()
        self.transform_type_box.addItems(["projective", "similarity"])
        self.ransac_threshold_slider = QSlider(Qt.Orientation.Horizontal)
        self.ransac_threshold_slider.setRange(10, 100)
        self.ransac_threshold_slider.setValue(50)
        self.ransac_threshold_label = QLabel("RANSAC Threshold: 50 px")
        self.ransac_threshold_label.hide()
        self.ransac_threshold_slider.hide()
        # Annotation controls
        self.clear_btn = QPushButton("Clear All Points")
        self.undo_btn = QPushButton("Undo Last Pair")
        # Region selection
        self.update_dorsal_btn = QPushButton("Update Dorsal View")
        self.region_list = QListWidget()
        self.region_list.setSelectionMode(QListWidget.SelectionMode.MultiSelection)
        self.region_list.setMinimumHeight(260)
        self.region_list.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Expanding)
        # Hemisphere selection
        self.left_hemisphere_checkbox = QCheckBox("Left Hemisphere")
        self.right_hemisphere_checkbox = QCheckBox("Right Hemisphere")
        # Reshape and View
        self.resize_btn = QPushButton("Resize Both Images")
        self.resize_width = QSpinBox()
        self.resize_width.setRange(64, 2048)
        self.resize_width.setValue(300)
        self.resize_height = QSpinBox()
        self.resize_height.setRange(64, 2048)
        self.resize_height.setValue(300)
        self.zoom_to_fit_btn = QPushButton("Zoom to Fit Views")
        # Log console
        self.log_box = QTextEdit()
        self.log_box.setReadOnly(True)
        self.log_box.setStyleSheet("background-color: black; color: lime; font-family: monospace;")

        # === Group Boxes ===
        # File I/O Group
        io_group = QGroupBox("File I/O")
        io_layout = QVBoxLayout(io_group)
        io_layout.addWidget(self.load_hist_btn)
        io_layout.addWidget(self.load_dorsal_btn)
        # Video Controls Group
        video_group = QGroupBox("Video Controls")
        video_layout = QHBoxLayout(video_group)
        video_layout.addWidget(self.play_btn)
        video_layout.addWidget(self.pause_btn)
        # Transform Controls Group
        transform_group = QGroupBox("Transform Controls")
        transform_layout = QVBoxLayout(transform_group)
        transform_layout.addWidget(QLabel("Transform type:"))
        transform_layout.addWidget(self.transform_type_box)
        transform_layout.addWidget(self.transform_btn)
        transform_layout.addWidget(self.save_tf_btn)
        transform_layout.addWidget(self.save_img_btn)
        transform_layout.addWidget(self.ransac_threshold_label)
        transform_layout.addWidget(self.ransac_threshold_slider)
        # Annotation Group
        points_group = QGroupBox("Point Annotation")
        points_layout = QHBoxLayout(points_group)
        points_layout.addWidget(self.clear_btn)
        points_layout.addWidget(self.undo_btn)
        # Region Selection Group
        region_group = QGroupBox("Region Selection")
        region_layout = QVBoxLayout(region_group)
        region_layout.addWidget(QLabel("Select region(s):"))
        region_layout.addWidget(self.region_list)
        region_layout.addWidget(self.update_dorsal_btn)
        # Hemisphere selection sub-layout
        hemisphere_layout = QHBoxLayout()
        hemisphere_layout.addWidget(self.left_hemisphere_checkbox)
        hemisphere_layout.addWidget(self.right_hemisphere_checkbox)
        region_layout.addLayout(hemisphere_layout)
        region_group.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Expanding)
        # Resize group
        resize_group = QGroupBox("Resize Widefield and Map")
        resize_layout = QHBoxLayout(resize_group)
        resize_layout.addWidget(QLabel("Width:"))
        resize_layout.addWidget(self.resize_width)
        resize_layout.addWidget(QLabel("Height:"))
        resize_layout.addWidget(self.resize_height)
        resize_layout.addWidget(self.resize_btn)
        view_group = QGroupBox("View Tools")
        view_layout = QVBoxLayout(view_group)
        view_layout.addWidget(self.zoom_to_fit_btn)
        view_layout.addStretch(1)

        # --- Assemble Top Controls ---
        controls_layout = QGridLayout()
        controls_layout.setContentsMargins(0, 0, 0, 0)
        controls_layout.setHorizontalSpacing(12)
        controls_layout.setVerticalSpacing(10)
        controls_layout.addWidget(io_group, 0, 0)
        controls_layout.addWidget(video_group, 0, 1)
        controls_layout.addWidget(region_group, 0, 2, 4, 1)
        controls_layout.addWidget(transform_group, 1, 0, 1, 2)
        controls_layout.addWidget(points_group, 2, 0)
        controls_layout.addWidget(view_group, 2, 1)
        controls_layout.addWidget(resize_group, 3, 0, 1, 2)
        controls_layout.setColumnStretch(0, 1)
        controls_layout.setColumnStretch(1, 1)
        controls_layout.setColumnStretch(2, 2)
        controls_layout.setRowStretch(3, 1)
        controls_widget = QWidget()
        controls_widget.setLayout(controls_layout)
        controls_widget.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred)

        # --- Main area: status bar, views, log ---
        # Status bar
        status_layout = QHBoxLayout()
        status_layout.addWidget(self.hist_shape_label)
        status_layout.addStretch(1)
        status_layout.addWidget(self.dorsal_shape_label)
        status_widget = QWidget()
        status_widget.setLayout(status_layout)
        # Views splitter
        view_splitter = QSplitter(Qt.Orientation.Horizontal)
        view_splitter.addWidget(self.wf_view)
        view_splitter.addWidget(self.dorsal_view)
        # Vertical splitter stacking status, views, and log
        main_splitter = QSplitter(Qt.Orientation.Vertical)
        main_splitter.addWidget(status_widget)
        main_splitter.addWidget(view_splitter)
        main_splitter.addWidget(self.log_box)
        main_splitter.setStretchFactor(0, 0)
        main_splitter.setStretchFactor(1, 5)
        main_splitter.setStretchFactor(2, 0)
        # Final layout: controls on top, content below
        main_layout = QVBoxLayout()
        main_layout.setContentsMargins(8, 8, 8, 8)
        main_layout.setSpacing(8)
        main_layout.addWidget(controls_widget)
        main_layout.addWidget(main_splitter, 1)
        main_layout.setStretch(0, 0)
        main_layout.setStretch(1, 1)
        container = QWidget()
        container.setLayout(main_layout)
        self.setCentralWidget(container)

        # --- Signal-Slot Connections ---
        # NOTE(review): apply_transform and save_transform are not defined in this view of the
        # class; this list raises AttributeError unless they are provided elsewhere — confirm.
        for btn, fn in [
            (self.load_hist_btn, self.load_widefield),
            (self.load_dorsal_btn, self.load_dorsal_map),
            (self.play_btn, self.play_video),
            (self.pause_btn, self.pause_video),
            (self.transform_btn, self.apply_transform),
            (self.save_tf_btn, self.save_transform),
            (self.save_img_btn, self.save_dorsal_image),
            (self.clear_btn, self.clear_all_points),
            (self.undo_btn, self.undo_last_pair),
            (self.zoom_to_fit_btn, self.zoom_to_fit_views),
            (self.update_dorsal_btn, self.update_dorsal_from_region),
            (self.resize_btn, self.resize_both_images)
        ]:
            # safe_call's wrapper takes no arguments, so e.g. update_dorsal_from_region keeps crop=False.
            btn.clicked.connect(self.safe_call(fn))
        # Checkbox connections for hemisphere selection
        self.left_hemisphere_checkbox.stateChanged.connect(self.safe_call(self.on_hemisphere_changed))
        self.right_hemisphere_checkbox.stateChanged.connect(self.safe_call(self.on_hemisphere_changed))
        self.ransac_threshold_slider.valueChanged.connect(self.update_ransac_threshold)

        # --- Finish initialization ---
        self.video_path = None
        self.wf_img = None
        self.dorsal_map = None
        self.current_transform = None
        # video: video_cap is None, a cv2.VideoCapture, or the string "tiff" for TIFF stacks
        self.video_cap = None
        self.current_frame_idx = None  # index of the NEXT TIFF frame to show
        self.tiff_frames = None
        self.video_timer = QTimer()
        self.video_timer.timeout.connect(self.safe_call(self.next_frame))
        # Live transform overlay attributes
        self.transform_stats = None
        self.live_fig = None
        self.live_ax = None
        self.live_overlay = None
        self.live_frame_text = None
        self.original_next_tiff_frame = None
        self.original_next_frame = None
        # Live-overlay mode may shadow the class methods _next_tiff_frame / _next_frame with
        # instance attributes; next_frame() looks them up on self. No placeholder assignment is
        # needed: until shadowed, lookup finds the class methods (storing bound methods on self
        # only created a self-referencing cycle).

        self.ccf = DorsalCCF.from_json()
        self.dorsal_scene.label_colors = _distinct_label_colors([
            label.label
            for label in self.ccf.region_labels
        ])
        self.region_list.addItem("[All Regions]")
        for region in self.ccf.region_list:
            self.region_list.addItem(region)

    def log(self, message, level: Literal['info', 'warn', 'error', 'success', 'debug'] = "info"):
        """Append a timestamped, colour-coded line to the log console."""
        import html
        timestamp = datetime.now().strftime("%H:%M:%S")
        colors = {
            "info": "lime",
            "warn": "yellow",
            "error": "red",
            "success": "cyan",
            "debug": "gray"
        }
        color = colors.get(level, "lime")
        level_tag = level.upper()
        # The console renders rich text: escape the message so text like "<module>" in a
        # traceback is shown instead of parsed as a tag, and keep its line breaks.
        body = html.escape(str(message)).replace("\n", "<br>")
        formatted_msg = f'<span style="color: {color};">[{timestamp}] [{level_tag}] {body}</span>'
        self.log_box.append(formatted_msg)

    def safe_call(self, func):
        """Wrap ``func`` (called with no arguments) so exceptions are logged instead of propagating."""
        def wrapper():
            try:
                func()
            except Exception as e:
                self.log(f"Error: {e}\n{traceback.format_exc()}", "error")
        return wrapper

    def zoom_to_fit_views(self):
        """Fit each view to its image, keeping the aspect ratio."""
        if self.wf_scene.pix_item is not None:
            self.wf_view.fitInView(self.wf_scene.pix_item, Qt.AspectRatioMode.KeepAspectRatio)
        if self.dorsal_scene.pix_item is not None:
            self.dorsal_view.fitInView(self.dorsal_scene.pix_item, Qt.AspectRatioMode.KeepAspectRatio)
        self.log("Zoomed to fit both views.")

    # ============== #
    # Video Features #
    # ============== #

    def next_frame(self):
        """Advance playback by one frame (timer slot).

        Dispatches through ``self._next_tiff_frame`` / ``self._next_frame`` so live-overlay
        mode can substitute the ``*_with_overlay`` variants; falls back to the class
        implementations if those attributes were replaced with something non-callable.
        """
        if self.video_cap is None:
            return
        if self.video_cap == "tiff":
            if callable(self._next_tiff_frame):
                self._next_tiff_frame()
            else:
                RegistrationApp._next_tiff_frame(self)
        else:
            if callable(self._next_frame):
                self._next_frame()
            else:
                RegistrationApp._next_frame(self)

    def _match_current_size(self, frame):
        """Resize ``frame`` to the currently displayed widefield size, if one is set.

        This keeps a user's "Resize Both Images" choice in effect across video frames.
        """
        if self.wf_img is not None:
            target_h, target_w = self.wf_img.shape[:2]
            if frame.shape[:2] != (target_h, target_w):
                frame = cv2.resize(frame, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
        return frame

    def _next_tiff_frame(self):
        """Show the next frame of the loaded TIFF stack, wrapping around at the end."""
        # The attributes always exist (set in __init__), so test their values, not hasattr().
        if self.tiff_frames is None or self.current_frame_idx is None:
            self.video_timer.stop()
            self.log("TIFF video data not available.", "error")
            return
        n_frames = self.tiff_frames.shape[0]
        shown_idx = self.current_frame_idx
        frame = self.tiff_frames[shown_idx]
        self.current_frame_idx = (shown_idx + 1) % n_frames
        frame = self._match_current_size(frame)
        self.wf_img = frame
        self.wf_scene.set_image(self.wf_img)
        # Report the frame actually on screen (1-based). The old label used the already
        # advanced index, so it was off by one and showed 0/N for the last frame.
        self.hist_shape_label.setText(
            f"TIFF video frame {shown_idx + 1}/{n_frames}: {frame.shape[1]} x {frame.shape[0]}"
        )

    def _next_frame(self):
        """Read and show the next movie frame, rewinding to the start at end of stream."""
        ret, frame = self.video_cap.read()
        if not ret:
            self.video_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            ret, frame = self.video_cap.read()
            if not ret:
                self.video_timer.stop()
                self.log("Video playback error.", "error")
                return
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = self._match_current_size(frame)
        self.wf_img = frame
        self.wf_scene.set_image(self.wf_img)
        self.hist_shape_label.setText(f"Histology video frame size: {frame.shape[1]} x {frame.shape[0]}")

    def play_video(self):
        """Start playback at one frame every 30 ms, if a movie or TIFF stack is loaded."""
        if self.video_cap is not None:
            self.video_timer.start(30)
            self.log("Playing video.")
        else:
            self.log("No video loaded.", "warn")

    def pause_video(self):
        """Stop playback."""
        self.video_timer.stop()
        self.log("Paused video.")

    def _release_video(self):
        """Stop playback and forget any loaded movie capture or TIFF stack."""
        self.video_timer.stop()
        if self.video_cap is not None and not isinstance(self.video_cap, str):
            # Release the cv2.VideoCapture explicitly instead of relying on garbage collection.
            self.video_cap.release()
        self.video_cap = None
        self.tiff_frames = None
        self.current_frame_idx = None

    @staticmethod
    def _is_frame_stack(data) -> bool:
        """Decide whether an ``imread`` result is a multi-frame stack rather than one image.

        4-D arrays are stacks. A 3-D array is a stack when its first axis has more than 10
        entries and its last axis is not a colour axis (3 or 4 channels). Previously every RGB
        image taller than 10 pixels was treated as a stack.
        """
        if data.ndim == 4:
            return True
        is_color_image = data.ndim == 3 and data.shape[-1] in (3, 4)
        return data.ndim == 3 and data.shape[0] > 10 and not is_color_image

    # ============================== #
    # Load image/video OR dorsal map #
    # ============================== #

    def load_widefield(self):
        """Prompt for a widefield image, multi-frame TIFF or movie, show it and reset the points.

        Movies (.avi/.mp4/.mov/.mkv) are opened with OpenCV and their first frame is shown.
        Other files are read with ``skimage.io.imread``; see ``_is_frame_stack`` for how
        frame stacks are recognised.
        """
        path, _ = QFileDialog.getOpenFileName(self, "Select Widefield Image or Movie")
        if not path:
            return

        if path.lower().endswith(('.avi', '.mp4', '.mov', '.mkv')):
            cap = cv2.VideoCapture(path)
            if not cap.isOpened():
                cap.release()
                self.log(f"Could not open movie: {path}", "error")
                return
            self._release_video()
            self.video_path = path
            self.video_cap = cap
            # Forget the previous image so the first frame keeps its native size rather than
            # being resized to whatever was shown before (see _match_current_size).
            self.wf_img = None
            self.next_frame()
            if self.wf_img is None:
                self.wf_scene.clear_scene()
                self.log(f"Could not read a frame from movie: {path}", "error")
                return
            self.log(f"Loaded movie: {path}", "success")
        else:
            loaded_data = imread(path)
            self._release_video()
            if self._is_frame_stack(loaded_data):
                self.video_path = path
                self.tiff_frames = loaded_data
                self.current_frame_idx = 0
                self.video_cap = "tiff"  # flag for TIFF video
                self.wf_img = loaded_data[0]  # show first frame
                h, w = self.wf_img.shape[:2]
                self.hist_shape_label.setText(f"TIFF video size: {w} x {h} ({loaded_data.shape[0]} frames)")
                self.log(f"Loaded TIFF video: {path} ({loaded_data.shape[0]} frames)", "success")
            else:
                self.wf_img = loaded_data
                h, w = self.wf_img.shape[:2]
                self.hist_shape_label.setText(f"widefield image size: {w} x {h}")
                self.log(f"Loaded image: {path}", "success")

        # Limit the resize spin boxes to the widefield size. This is computed for every
        # branch; the movie branch used to reach it with h/w unbound.
        h, w = self.wf_img.shape[:2]
        self.resize_width.setMaximum(w)
        self.resize_height.setMaximum(h)
        # Existing pairs refer to the previous image: drop them and their markers, render
        # the new image and re-arm the widefield scene.
        self._reset_points()

    def load_dorsal_map(self):
        """Prompt for a dorsal label map (.npy), show it and reset the point pairs.

        If a widefield image is loaded, the map is resized (nearest-neighbour) to its size.
        """
        path, _ = QFileDialog.getOpenFileName(self, "Select Dorsal Map (.npy)")
        if not path:
            return
        arr = np.load(path)
        # if the widefield image is already loaded, resize dorsal to match its pixel shape
        if self.wf_img is not None:
            h, w = self.wf_img.shape[:2]
            arr = self._resize_label_map(arr, w, h)
            self.log(f'resize dorsal map to {w, h}', "debug")
        self.dorsal_map = arr
        self._reset_points()
        h, w = self.dorsal_map.shape[:2]
        self.dorsal_shape_label.setText(f"Dorsal map size: {w} x {h}")
        self.log(f"Loaded dorsal map: {path}", "success")

    @staticmethod
    def _resize_label_map(arr: np.ndarray, width: int, height: int) -> np.ndarray:
        """Nearest-neighbour resize of a label map to ``(height, width)``, for any dtype.

        Uses the same source-index rule as ``cv2.INTER_NEAREST`` (``floor(dst * src / dst_size)``,
        clamped to the last index), but with numpy indexing. OpenCV cannot wrap 64-bit integer
        arrays, which is NumPy's default integer dtype on most 64-bit platforms.
        Trailing axes beyond the first two are kept.
        """
        src_h, src_w = arr.shape[:2]
        rows = np.minimum((np.arange(height) * (src_h / height)).astype(np.intp), src_h - 1)
        cols = np.minimum((np.arange(width) * (src_w / width)).astype(np.intp), src_w - 1)
        return np.ascontiguousarray(arr[rows[:, None], cols[None, :]])

    def update_dorsal_from_region(self, crop: bool = False):
        """Rebuild the dorsal map from the selected CCF regions and hemisphere checkboxes.

        Parameters
        ----------
        crop : bool
            Crop the map to the bounding box of its non-zero labels before resizing.

        The map is resized to the widefield image size. Requires a loaded widefield image.
        """
        if self.wf_img is None:
            self.log("Load the wfield image first!", level='warn')
            return
        # Start with region selection
        selected = [it.text() for it in self.region_list.selectedItems()
                    if it.text() != "[All Regions]"]
        if selected:
            ccf = self.ccf.select_region(selected)
            region_label = ", ".join(selected)
        else:
            ccf = self.ccf
            region_label = "[All Regions]"
        # Apply hemisphere selection
        ccf = self._apply_hemisphere_selection(ccf)
        # Build the display label
        left_checked = self.left_hemisphere_checkbox.isChecked()
        right_checked = self.right_hemisphere_checkbox.isChecked()
        if left_checked and not right_checked:
            hemisphere_label = " (Left Hemisphere)"
        elif right_checked and not left_checked:
            hemisphere_label = " (Right Hemisphere)"
        else:
            hemisphere_label = ""
        full_label = region_label + hemisphere_label
        # get the array and crop to nonzero if needed
        arr = ccf.to_numpy()
        if crop:
            ys, xs = np.nonzero(arr > 0)
            if ys.size:
                arr = arr[ys.min():ys.max() + 1, xs.min():xs.max() + 1]
        # resize to match the widefield image (guaranteed loaded by the guard above)
        h, w = self.wf_img.shape[:2]
        arr = self._resize_label_map(arr, w, h)
        self.log(f"Resized dorsal map to histology dims: {w}×{h}", "debug")
        # update the view and label
        self.dorsal_map = arr
        self.dorsal_scene.set_image(arr)
        self.dorsal_shape_label.setText(f"Dorsal size: {w} × {h}")
        self.log(f"Dorsal view updated for: {full_label}", "success")

    def resize_both_images(self):
        """Resize the widefield image and the dorsal map to the spin-box size.

        The widefield image uses bilinear interpolation. The dorsal label map uses
        nearest-neighbour so label values are preserved. Stored point pairs are scaled with
        their image so they stay on the same image content; previously they were left at
        their old coordinates.
        """
        target_w = self.resize_width.value()
        target_h = self.resize_height.value()
        size = (target_w, target_h)
        self.log(f"Resizing both images to {size}")
        # Resize widefield
        if self.wf_img is not None:
            old_h, old_w = self.wf_img.shape[:2]
            hist = cv2.resize(self.wf_img, size, interpolation=cv2.INTER_LINEAR)
            self.wf_img = hist
            self._scale_points('wfield', target_w / old_w, target_h / old_h)
            self.hist_shape_label.setText(f"Wfield size: {hist.shape[1]} × {hist.shape[0]}")
        else:
            self.log("Wfield image not loaded.", "warn")
        # Resize dorsal
        if self.dorsal_map is not None:
            old_h, old_w = self.dorsal_map.shape[:2]
            dorsal = self._resize_label_map(self.dorsal_map, target_w, target_h)
            self.dorsal_map = dorsal
            self._scale_points('dorsal', target_w / old_w, target_h / old_h)
            self.dorsal_shape_label.setText(f"Dorsal size: {dorsal.shape[1]} × {dorsal.shape[0]}")
        else:
            self.log("Dorsal map not loaded.", "warn")
        # Re-render both scenes with the rescaled markers.
        self._refresh_scenes()
        self.log("Resize complete.", "success")

    # ==================== #
    # Hemisphere Selection #
    # ==================== #

    def on_hemisphere_changed(self):
        """Handle hemisphere checkbox changes and update the dorsal view."""
        self.update_dorsal_from_region()

    def _apply_hemisphere_selection(self, ccf):
        """Restrict ``ccf`` to the checked hemisphere.

        Exactly one box checked selects that hemisphere. Both or neither checked returns
        ``ccf`` unchanged.
        """
        left_checked = self.left_hemisphere_checkbox.isChecked()
        right_checked = self.right_hemisphere_checkbox.isChecked()
        if left_checked and right_checked:
            return ccf
        elif left_checked:
            return ccf.select_hemisphere('left')
        elif right_checked:
            return ccf.select_hemisphere('right')
        else:
            return ccf

    def convert_to_boundary(self):
        """Replace the dorsal map with its label boundaries (skimage ``find_boundaries``, mode 'outer').

        Boundary pixels keep their label value; all other pixels become 0.
        """
        if self.dorsal_map is None:
            self.log("No dorsal map loaded. Please load or generate a dorsal map first.", "warn")
            return
        try:
            from skimage.segmentation import find_boundaries
            boundary = find_boundaries(self.dorsal_map, mode='outer')
            boundary_arr = np.where(boundary, self.dorsal_map, 0).astype(self.dorsal_map.dtype)
            self.dorsal_map = boundary_arr
            self.dorsal_scene.set_image(self.dorsal_map)
            h, w = self.dorsal_map.shape[:2]
            self.dorsal_shape_label.setText(f"Boundary map size: {w} × {h}")
            self.log("Converted dorsal map to boundary", "success")
        except Exception as e:
            self.log(f"Error converting to boundary: {e}", "error")

    # ================== #
    # Point / Annotation #
    # ================== #

    def on_point_clicked(self, image_name, point):
        """Record a click from ``image_name``; clicks must alternate Wfield → Dorsal.

        A widefield click opens a new pair. The following dorsal click completes it. Clicks
        from the scene that is not currently expected are ignored.
        """
        if image_name != self.expecting:
            return
        if self.expecting == 'Wfield':
            self.point_pairs.append({'wfield': point, 'dorsal': None})
            self.expecting = 'Dorsal'
            self.wf_scene.enabled = False
            self.dorsal_scene.enabled = True
            self.log(f"WField point: {point}", "debug")
        else:
            self.point_pairs[-1]['dorsal'] = point
            self.expecting = 'Wfield'
            self.wf_scene.enabled = True
            self.dorsal_scene.enabled = False
            self.log(f"Dorsal point: {point}", "debug")

    def update_ransac_threshold(self):
        """Mirror the RANSAC slider value in its label."""
        thr = self.ransac_threshold_slider.value()
        self.ransac_threshold_label.setText(f"RANSAC Threshold: {thr} px")

    def _scale_points(self, key, scale_x, scale_y):
        """Multiply the ``key`` ('wfield' or 'dorsal') coordinates of every stored pair."""
        for pair in self.point_pairs:
            point = pair[key]
            if point is not None:
                pair[key] = [point[0] * scale_x, point[1] * scale_y]

    def _refresh_scenes(self):
        """Re-render both images and redraw the markers of every stored point pair."""
        self.wf_scene.clear_scene()
        self.dorsal_scene.clear_scene()
        if self.wf_img is not None:
            self.wf_scene.set_image(self.wf_img)
        if self.dorsal_map is not None:
            self.dorsal_scene.set_image(self.dorsal_map)
        for pair in self.point_pairs:
            if pair['wfield'] is not None:
                self.wf_scene.draw_point(pair['wfield'])
            if pair['dorsal'] is not None:
                self.dorsal_scene.draw_point(pair['dorsal'])

    def _reset_points(self):
        """Drop all point pairs and markers, and re-arm the widefield scene for the next click."""
        self.point_pairs.clear()
        self._refresh_scenes()
        self.expecting = 'Wfield'
        self.wf_scene.enabled = True
        self.dorsal_scene.enabled = False

    def clear_all_points(self):
        """Remove every point pair and marker."""
        self._reset_points()
        self.log("Cleared all points and scenes.", "success")

    def undo_last_pair(self):
        """Remove the most recent pair (complete or half-finished) and redraw the markers."""
        if not self.point_pairs:
            return
        # Whether the last pair was complete or only had its widefield half, removing it
        # means the next click must again be a widefield point.
        self.point_pairs.pop()
        self.expecting = 'Wfield'
        self._refresh_scenes()
        self.wf_scene.enabled = True
        self.dorsal_scene.enabled = False
        self.log("Undid last point pair.", "success")

    # ================== #
    # Transform / Result #
    # ================== #

    def save_dorsal_image(self):
        """Save the current dorsal map with ``np.save`` to a user-chosen path."""
        if self.dorsal_map is None:
            self.log("No dorsal map to save.", "warn")
            return
        out_path, _ = QFileDialog.getSaveFileName(self, "Save Dorsal Map", filter="NumPy array (*.npy)")
        if out_path:
            np.save(out_path, self.dorsal_map)
            self.log(f"Saved dorsal image: {out_path}", "success")

    # ====================== #
    # Live Transform Overlay #
    # ====================== #

    def _next_tiff_frame_with_overlay(self):
        """TIFF frame advance followed by a live-overlay refresh."""
        if self.original_next_tiff_frame is not None:
            self.original_next_tiff_frame()
        else:
            RegistrationApp._next_tiff_frame(self)
        self.update_live_overlay()

    def _next_frame_with_overlay(self):
        """Movie frame advance followed by a live-overlay refresh."""
        if self.original_next_frame is not None:
            self.original_next_frame()
        else:
            RegistrationApp._next_frame(self)
        self.update_live_overlay()

    @staticmethod
    def _overlay_to_uint8(frame):
        """Convert a widefield frame to uint8 for the live overlay.

        uint16 frames with values above 255 are scaled by 255/65535; uint16 frames within 0-255
        are cast directly. Floats in [0, 1] are multiplied by 255. Other integer and float data
        are min/max stretched. A constant frame becomes mid-grey (128).
        NOTE: this differs from the main view's min/max normalisation.
        """
        if frame.dtype == np.uint8:
            return frame
        frame_min, frame_max = frame.min(), frame.max()
        if frame.dtype == np.uint16:
            if frame_max > 255:
                return (frame.astype(np.float32) / 65535.0 * 255).astype(np.uint8)
            return frame.astype(np.uint8)
        is_integer = np.issubdtype(frame.dtype, np.integer)
        if not is_integer and frame_max <= 1.0 and frame_min >= 0.0:
            return (frame * 255).astype(np.uint8)
        if frame_max > frame_min:
            values = frame.astype(np.float32) if is_integer else frame
            return (255 * (values - frame_min) / (frame_max - frame_min)).astype(np.uint8)
        return np.full_like(frame, 128, dtype=np.uint8)

    def update_live_overlay(self):
        """Warp the current widefield frame with ``current_transform`` and redraw the live plot.

        Does nothing unless a live figure exists. If its window was closed,
        ``stop_live_transform_plot`` is called. Otherwise it also needs an overlay artist, a
        widefield frame, a transform and a dorsal map (the warp target size). Errors are logged,
        not raised.
        """
        try:
            # The attributes always exist (set in __init__): test their values, not hasattr().
            if self.live_fig is None:
                return
            # check if the plot window is still open
            if not plt.fignum_exists(self.live_fig.number):
                # restore original methods
                # NOTE(review): stop_live_transform_plot is not defined in this view of the class — confirm.
                self.stop_live_transform_plot()
                return
            if (self.live_overlay is None or self.wf_img is None
                    or self.current_transform is None or self.dorsal_map is None):
                return
            # Use the exact same frame that's displayed in the main GUI
            img_uint8 = self._overlay_to_uint8(self.wf_img.copy())
            h, w = self.dorsal_map.shape[:2]
            warped = cv2.warpPerspective(
                img_uint8, self.current_transform, (w, h),
                flags=cv2.INTER_LINEAR,
                borderMode=cv2.BORDER_CONSTANT,
                borderValue=0
            )
            # Fixed 0-255 intensity range for a consistent display across frames
            self.live_overlay.set_data(warped.astype(np.float32))
            self.live_overlay.set_clim(vmin=0, vmax=255)
            # Update frame counter (TIFF stacks only; movies do not track a frame index)
            if (self.live_frame_text is not None and self.video_cap == "tiff"
                    and self.tiff_frames is not None and self.current_frame_idx is not None):
                n_frames = self.tiff_frames.shape[0]
                # current_frame_idx already points at the next frame; report the one on screen (1-based).
                shown = (self.current_frame_idx - 1) % n_frames + 1
                self.live_frame_text.set_text(f"Frame: {shown}/{n_frames}")
            # Refresh the plot
            self.live_fig.canvas.draw_idle()
            self.live_fig.canvas.flush_events()
        except Exception as e:
            self.log(f"Error updating live overlay: {e}", "error")
            print(f"Live overlay error details: {e}")  # Additional debug info
class RegistrationOptions(AbstractParser):
    """Command-line entry point that opens the :class:`RegistrationApp` window."""

    DESCRIPTION = 'Register widefield images to a dorsal cortex map'

    def run(self):
        """Show the registration window, starting a Qt event loop if none exists yet.

        When a QApplication already exists (e.g. an interactive session with a running Qt
        loop), the window is shown and this method returns without calling ``exec()``.
        """
        app = QApplication.instance()
        owns_app = app is None
        if owns_app:
            app = QApplication([])
        window = RegistrationApp()
        window.show()
        # Keep a Python reference to the parentless window. When an existing event loop is
        # reused, run() returns immediately, and PyQt would otherwise delete the window once
        # the local ``window`` goes out of scope.
        self._window = window
        if owns_app:
            app.exec()