Source code for hofmann.rendering.interactive

"""Interactive matplotlib viewer with mouse and keyboard controls."""

from __future__ import annotations

import time
from typing import Any

import matplotlib.pyplot as plt
import numpy as np

from hofmann.model import (
    CmapSpec,
    Colour,
    RenderStyle,
    StructureScene,
    ViewState,
    normalise_colour,
)
from hofmann.rendering.painter import _draw_scene
from hofmann.rendering.precompute import _precompute_scene
from hofmann.rendering.projection import _scene_extent
from hofmann.rendering.static import _resolve_style


# ---------------------------------------------------------------------------
# Rotation helpers
# ---------------------------------------------------------------------------

def _rotation_x(angle: float) -> np.ndarray:
    """Rotation matrix about the X axis by *angle* radians."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([
        [1.0, 0.0, 0.0],
        [0.0,   c,  -s],
        [0.0,   s,   c],
    ])


def _rotation_y(angle: float) -> np.ndarray:
    """Rotation matrix about the Y axis by *angle* radians."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([
        [ c,  0.0,  s],
        [0.0, 1.0, 0.0],
        [-s,  0.0,  c],
    ])


def _rotation_z(angle: float) -> np.ndarray:
    """Rotation matrix about the Z axis by *angle* radians."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([
        [  c,  -s, 0.0],
        [  s,   c, 0.0],
        [0.0, 0.0, 1.0],
    ])


# ---------------------------------------------------------------------------
# Interactive renderer — constants
# ---------------------------------------------------------------------------

_KEY_ROTATION_STEP = 0.05  # radians (~3 degrees) per key press
_KEY_ZOOM_FACTOR = 1.1  # multiplicative zoom per key press / scroll step
_KEY_PAN_FRACTION = 0.05  # fraction of scene extent per key press
_PERSPECTIVE_STEP = 0.1  # perspective increment per key press
_DISTANCE_FACTOR = 1.05  # viewing distance multiplier per key press

_HELP_TEXT = """\
Arrows     Rotate          Shift+Arrows  Pan
,  .       Roll            +  =  -       Zoom
p  P       Perspective     d  D          Distance
b          Bonds           o             Outlines
e          Polyhedra       u             Unit cell
a          Axes            r             Reset view
[  ]       Prev/next frame {  }          First/last
f          Frame indicator g             Go to frame
s          Set step size
h          Toggle help     Scroll        Zoom
Drag       Rotate"""


def _apply_key_action(
    key: str,
    view: "ViewState",
    style: "RenderStyle",
    state: dict,
    *,
    n_frames: int,
    base_extent: float,
    initial_view: dict,
    has_lattice: bool = False,
) -> str:
    """Apply a keyboard action, mutating *view*, *style*, and *state*.

    Returns a string indicating the required redraw kind:

    - ``"view"`` — view-only change (rotation, zoom, pan, etc.).
    - ``"full"`` — style or frame change needing recomputation.
    - ``"none"`` — unrecognised key, no redraw needed.
    """
    # -- Number input mode --
    input_mode = state.get("input_mode")
    if input_mode is not None:
        if key in "0123456789":
            state["input_buffer"] = state.get("input_buffer", "") + key
            return "view"
        elif key == "enter":
            buf = state.get("input_buffer", "")
            state["input_mode"] = None
            state["input_buffer"] = ""
            if input_mode == "goto":
                if buf:
                    target = int(buf)
                    state["frame_index"] = max(0, min(target, n_frames - 1))
                    return "full"
            elif input_mode == "step":
                if buf:
                    val = int(buf)
                    if val >= 1:
                        state["frame_step"] = min(val, n_frames - 1)
            return "view"
        elif key == "escape":
            state["input_mode"] = None
            state["input_buffer"] = ""
            return "view"
        else:
            # Swallow all other keys during input mode.
            return "view"

    # -- Rotation --
    if key == "left":
        view.rotation = _rotation_y(-_KEY_ROTATION_STEP) @ view.rotation
    elif key == "right":
        view.rotation = _rotation_y(_KEY_ROTATION_STEP) @ view.rotation
    elif key == "up":
        view.rotation = _rotation_x(-_KEY_ROTATION_STEP) @ view.rotation
    elif key == "down":
        view.rotation = _rotation_x(_KEY_ROTATION_STEP) @ view.rotation
    elif key == ",":
        view.rotation = _rotation_z(_KEY_ROTATION_STEP) @ view.rotation
    elif key == ".":
        view.rotation = _rotation_z(-_KEY_ROTATION_STEP) @ view.rotation

    # -- Zoom --
    elif key in ("+", "="):
        view.zoom = min(100.0, view.zoom * _KEY_ZOOM_FACTOR)
    elif key == "-":
        view.zoom = max(0.01, view.zoom / _KEY_ZOOM_FACTOR)

    # -- Pan (shift + arrows) --
    # Moving the centre in screen-right shifts the *camera* right,
    # so the scene appears to move left.  Negate so that the arrow
    # direction matches the apparent scene movement.
    elif key == "shift+left":
        step = _KEY_PAN_FRACTION * base_extent / view.zoom
        view.centre = view.centre + step * view.rotation[0]
    elif key == "shift+right":
        step = _KEY_PAN_FRACTION * base_extent / view.zoom
        view.centre = view.centre - step * view.rotation[0]
    elif key == "shift+down":
        step = _KEY_PAN_FRACTION * base_extent / view.zoom
        view.centre = view.centre + step * view.rotation[1]
    elif key == "shift+up":
        step = _KEY_PAN_FRACTION * base_extent / view.zoom
        view.centre = view.centre - step * view.rotation[1]

    # -- Perspective / distance --
    elif key == "p":
        view.perspective = min(1.0, view.perspective + _PERSPECTIVE_STEP)
    elif key == "P":
        view.perspective = max(0.0, view.perspective - _PERSPECTIVE_STEP)
    elif key == "d":
        view.view_distance *= _DISTANCE_FACTOR
    elif key == "D":
        view.view_distance = max(0.1, view.view_distance / _DISTANCE_FACTOR)

    # -- Style toggles (no recomputation needed) --
    elif key == "b":
        style.show_bonds = not style.show_bonds
    elif key == "o":
        style.show_outlines = not style.show_outlines
    elif key == "e":
        style.show_polyhedra = not style.show_polyhedra
    elif key == "u":
        # Resolve None (auto-detect) to the effective value before toggling.
        effective = style.show_cell if style.show_cell is not None else has_lattice
        style.show_cell = not effective
    elif key == "a":
        effective = style.show_axes if style.show_axes is not None else has_lattice
        style.show_axes = not effective

    # -- Frame navigation --
    elif key == "]" and n_frames > 1:
        step = state.get("frame_step", 1)
        state["frame_index"] = (state["frame_index"] + step) % n_frames
        return "full"
    elif key == "[" and n_frames > 1:
        step = state.get("frame_step", 1)
        state["frame_index"] = (state["frame_index"] - step) % n_frames
        return "full"
    elif key == "}" and n_frames > 1:
        state["frame_index"] = n_frames - 1
        return "full"
    elif key == "{" and n_frames > 1:
        state["frame_index"] = 0
        return "full"

    # -- Input mode entry --
    elif key == "g" and n_frames > 1:
        state["input_mode"] = "goto"
        state["input_buffer"] = ""
        return "view"
    elif key == "s" and n_frames > 1:
        state["input_mode"] = "step"
        state["input_buffer"] = ""
        return "view"
    elif key in ("g", "s") and n_frames <= 1:
        return "none"

    # -- Frame indicator --
    elif key == "f" and n_frames > 1:
        state["indicator_visible"] = not state.get("indicator_visible", False)
        return "view"
    elif key == "f" and n_frames <= 1:
        return "none"

    # -- Reset --
    elif key == "r":
        view.rotation = initial_view["rotation"].copy()
        view.zoom = initial_view["zoom"]
        view.centre = initial_view["centre"].copy()
        view.perspective = initial_view["perspective"]
        view.view_distance = initial_view["view_distance"]

    # -- Help overlay --
    elif key == "h":
        state["help_visible"] = not state["help_visible"]

    else:
        return "none"

    return "view"


[docs] def render_mpl_interactive( scene: StructureScene, *, style: RenderStyle | None = None, frame_index: int = 0, figsize: tuple[float, float] = (5.0, 5.0), dpi: int = 150, background: Colour = "white", colour_by: str | list[str] | None = None, cmap: CmapSpec | list[CmapSpec] = "viridis", colour_range: tuple[float, float] | None | list[tuple[float, float] | None] = None, **style_kwargs: object, ) -> tuple["ViewState", "RenderStyle"]: """Interactive matplotlib viewer with mouse and keyboard controls. Opens a matplotlib window where the user can manipulate the view with the mouse and keyboard: **Mouse:** - **Left-drag** to rotate the structure. - **Scroll** to zoom in/out. **Keyboard:** - **Arrow keys** rotate around the horizontal/vertical axes. - **,** / **.** roll in the screen plane. - **+** / **=** / **-** zoom in/out. - **Shift+Arrow** keys pan the view. - **p** / **P** increase/decrease perspective strength. - **d** / **D** increase/decrease viewing distance. - **b** toggle bonds, **o** toggle outlines, **e** toggle polyhedra, **u** toggle unit cell, **a** toggle axes widget. - **[** / **]** step to the previous/next frame; **{** / **}** jump to the first/last frame. - **f** toggle frame indicator, **g** go to a specific frame, **s** set frame step size. - **r** reset the view to its initial state. - **h** toggle a help overlay listing all keybindings. When the window is closed the updated :class:`ViewState` and :class:`RenderStyle` are returned, allowing the user to re-use both for static rendering:: view, style = scene.render_mpl_interactive() scene.view = view scene.render_mpl("output.svg", style=style) Args: scene: The StructureScene to render. style: A :class:`RenderStyle` controlling visual appearance. Any :class:`RenderStyle` field name may also be passed as a keyword argument to override individual fields. frame_index: Which frame to render initially. figsize: Figure size in inches ``(width, height)``. dpi: Resolution. background: Background colour. colour_by: Key (or list of keys) into ``scene.atom_data`` to colour atoms by. cmap: Matplotlib colourmap name, object, or callable. When *colour_by* is a list, may also be a list of the same length. colour_range: Explicit ``(vmin, vmax)`` for numerical data. When *colour_by* is a list, may also be a list of the same length. **style_kwargs: Any :class:`RenderStyle` field name as a keyword argument. Unknown names raise :class:`TypeError`. Returns: A ``(ViewState, RenderStyle)`` tuple reflecting any view and style changes applied during the interactive session. """ resolved = _resolve_style(style, **style_kwargs) n_frames = len(scene.frames) if not 0 <= frame_index < n_frames: raise ValueError( f"frame_index {frame_index} out of range for scene " f"with {n_frames} frame(s)" ) bg_rgb = normalise_colour(background) # Use lower-fidelity polygon counts for interactive responsiveness. # Save the static values so we can restore them before returning. static_circle_segments = resolved.circle_segments static_arc_segments = resolved.arc_segments resolved.circle_segments = resolved.interactive_circle_segments resolved.arc_segments = resolved.interactive_arc_segments # Work on a copy so we don't mutate the original scene's view. view = ViewState( rotation=scene.view.rotation.copy(), zoom=scene.view.zoom, centre=scene.view.centre.copy(), perspective=scene.view.perspective, view_distance=scene.view.view_distance, slab_origin=( scene.view.slab_origin.copy() if scene.view.slab_origin is not None else None ), slab_near=scene.view.slab_near, slab_far=scene.view.slab_far, ) # Fixed viewport extent — rotation-invariant so the scene doesn't # appear to shift or rescale while dragging. base_extent = _scene_extent(scene, view, frame_index, resolved.atom_scale) # Pre-compute bonds, colours, adjacency for the current frame — # these don't change during interactive rotation / zoom but are # recomputed on frame navigation. colour_kwargs: dict[str, Any] = dict( colour_by=colour_by, cmap=cmap, colour_range=colour_range, ) pre = _precompute_scene(scene, frame_index, resolved, **colour_kwargs) draw_kwargs: dict[str, Any] = dict( frame_index=frame_index, bg_rgb=bg_rgb, precomputed=pre, ) fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi) fig.set_facecolor(bg_rgb) fig.subplots_adjust(left=0, right=1, bottom=0, top=1) _draw_scene(ax, scene, view, resolved, viewport_extent=base_extent, **draw_kwargs) # ---- Interaction state ---- state: dict = { "drag_active": False, "drag_last_xy": None, "last_draw_t": 0.0, "frame_index": frame_index, "help_visible": False, "precomputed": pre, "base_extent": base_extent, "frame_step": 1, "input_mode": None, "input_buffer": "", "indicator_visible": False, } # Snapshot for the reset key. initial_view = { "rotation": view.rotation.copy(), "zoom": view.zoom, "centre": view.centre.copy(), "perspective": view.perspective, "view_distance": view.view_distance, } _DRAG_SENSITIVITY = 0.01 # radians per pixel _MIN_INTERVAL = 0.03 # seconds between redraws (~30 fps cap) # ---- Redraw helpers ---- def _redraw() -> None: """Repaint the scene using the fixed viewport extent.""" _draw_scene( ax, scene, view, resolved, viewport_extent=state["base_extent"], **draw_kwargs, ) if state["help_visible"]: _add_help_overlay() if state.get("indicator_visible") or state.get("input_mode"): _add_frame_indicator() fig.canvas.draw_idle() state["last_draw_t"] = time.monotonic() def _throttled_redraw() -> None: """Redraw only if enough time has elapsed since the last draw.""" if time.monotonic() - state["last_draw_t"] >= _MIN_INTERVAL: _redraw() def _full_redraw() -> None: """Recompute bonds/colours and repaint (for frame or style changes).""" state["precomputed"] = _precompute_scene( scene, state["frame_index"], resolved, **colour_kwargs, ) draw_kwargs["precomputed"] = state["precomputed"] draw_kwargs["frame_index"] = state["frame_index"] state["base_extent"] = _scene_extent( scene, view, state["frame_index"], resolved.atom_scale, ) _redraw() def _add_help_overlay() -> None: """Add the keybinding help text to the axes.""" ax.text( 0.02, 0.98, _HELP_TEXT, transform=ax.transAxes, fontsize=7, fontfamily="monospace", verticalalignment="top", bbox=dict( boxstyle="round,pad=0.5", facecolor="white", alpha=0.85, edgecolor="grey", ), zorder=1000, ) def _add_frame_indicator() -> None: """Add the frame indicator or input prompt at bottom-centre.""" idx = state["frame_index"] total = n_frames step = state.get("frame_step", 1) input_mode = state.get("input_mode") if input_mode == "goto": text = f"Go to: {state.get('input_buffer', '')}_" elif input_mode == "step": text = f"Step: {state.get('input_buffer', '')}_" elif step > 1: text = f"Frame {idx} / {total} (step {step})" else: text = f"Frame {idx} / {total}" ax.text( 0.5, 0.02, text, transform=ax.transAxes, fontsize=7, fontfamily="monospace", ha="center", va="bottom", color="0.4", zorder=1000, ) # ---- Mouse handlers ---- def on_press(event): if event.inaxes != ax or event.button != 1: return state["drag_active"] = True state["drag_last_xy"] = (event.x, event.y) def on_motion(event): if not state["drag_active"] or state["drag_last_xy"] is None: return x0, y0 = state["drag_last_xy"] dx = event.x - x0 dy = event.y - y0 state["drag_last_xy"] = (event.x, event.y) # Incremental rotation in screen space: horizontal drag rotates # around the screen Y axis, vertical drag around screen X. # Applying to the *current* rotation gives intuitive "grab and # drag the object" behaviour regardless of accumulated rotation. view.rotation = ( _rotation_y(dx * _DRAG_SENSITIVITY) @ _rotation_x(-dy * _DRAG_SENSITIVITY) @ view.rotation ) _throttled_redraw() def on_release(event): if state["drag_active"]: state["drag_active"] = False # Final redraw to ensure the last position is rendered. _redraw() def on_scroll(event): if event.inaxes != ax or event.step is None: return factor = _KEY_ZOOM_FACTOR ** event.step view.zoom = max(0.01, min(100.0, view.zoom * factor)) _redraw() # ---- Keyboard handler ---- def on_key_press(event): if event.key is None: return kind = _apply_key_action( event.key, view, resolved, state, n_frames=len(scene.frames), base_extent=state["base_extent"], initial_view=initial_view, has_lattice=scene.lattice is not None, ) if kind == "full": _full_redraw() elif kind == "view": _throttled_redraw() # ---- Connect events ---- fig.canvas.mpl_connect("button_press_event", on_press) fig.canvas.mpl_connect("motion_notify_event", on_motion) fig.canvas.mpl_connect("button_release_event", on_release) fig.canvas.mpl_connect("scroll_event", on_scroll) fig.canvas.mpl_connect("key_press_event", on_key_press) # Disconnect matplotlib's default key handler to avoid conflicts # (e.g. 'p' for pan tool, 'o' for zoom-to-rect). manager = fig.canvas.manager if manager is not None: handler_id = getattr(manager, "key_press_handler_id", None) if handler_id is not None: fig.canvas.mpl_disconnect(handler_id) try: plt.show() finally: # Restore static-quality segment counts so the returned style # is ready for publication rendering. resolved.circle_segments = static_circle_segments resolved.arc_segments = static_arc_segments plt.close(fig) return view, resolved