# Source code for hofmann.model.colour

from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING

import numpy as np

from hofmann.model.composition import Composition

if TYPE_CHECKING:
    from hofmann.model.atom_data import AtomData
    from hofmann.model.atom_style import AtomStyle

#: A colour specification accepted throughout hofmann.
#:
#: Can be any of:
#:
#: - A CSS colour name or hex string (e.g. ``"red"``, ``"#ff0000"``).
#: - A single number for grey (``0.0`` = black, ``1.0`` = white).  An
#:   ``int`` is accepted as well as a ``float``; ``bool`` is rejected.
#: - An RGB tuple or list of exactly three values in ``[0, 1]``
#:   (e.g. ``(1.0, 0.0, 0.0)``).
#:
#: Grey values or RGB components outside ``[0, 1]`` are rejected with
#: :class:`ValueError` by :func:`normalise_colour`.
#:
#: See :func:`normalise_colour` for conversion to a normalised RGB tuple.
Colour = str | float | tuple[float, float, float] | list[float]

#: A colourmap specification for atom-data colouring.
#:
#: Can be any of:
#:
#: - A matplotlib colourmap name (e.g. ``"viridis"``), looked up in
#:   ``matplotlib.colormaps`` when the colourmap is resolved.
#: - A callable mapping a float in ``[0, 1]`` to an RGB or RGBA sequence.
#: - A matplotlib :class:`~matplotlib.colors.Colormap` object (which is
#:   callable and returns RGBA).
#:
#: Callables returning RGBA are automatically truncated to RGB by
#: :func:`_resolve_cmap`, which keeps only the first three components.
#: Anything that is neither a string nor callable is rejected there with
#: :class:`TypeError`.
CmapSpec = str | Callable[[float], Sequence[float]]


def normalise_colour(colour: Colour) -> tuple[float, float, float]:
    """Convert a colour specification to a normalised (r, g, b) tuple.

    Accepts CSS colour names (e.g. ``"red"``), hex strings
    (e.g. ``"#FF0000"``), grey floats (e.g. ``0.7``), or RGB tuples
    (e.g. ``(1.0, 0.3, 0.3)``).

    Args:
        colour: The colour to normalise.  Numbers (``int`` or ``float``,
            but not ``bool``) are treated as a grey level; tuples and
            lists must hold exactly three numeric components; strings
            are passed to :func:`matplotlib.colors.to_rgb`.

    Returns:
        A tuple of three floats in [0, 1].

    Raises:
        ValueError: If the colour cannot be interpreted: a grey value or
            RGB component outside ``[0, 1]`` (NaN included), a sequence
            whose length is not 3, a component that cannot be converted
            to ``float``, an unrecognised colour string, or an
            unsupported type.
    """
    # ``bool`` is a subclass of ``int``; exclude it so that ``True`` is
    # not silently read as white.
    if isinstance(colour, (int, float)) and not isinstance(colour, bool):
        f = float(colour)
        # Written as ``not (lo <= f <= hi)`` so that NaN is rejected too.
        if not 0.0 <= f <= 1.0:
            raise ValueError(f"Grey value must be in [0, 1], got {f}")
        return (f, f, f)

    if isinstance(colour, (tuple, list)):
        if len(colour) != 3:
            raise ValueError(
                f"RGB sequence must have 3 elements, got {len(colour)}"
            )
        try:
            r, g, b = (float(c) for c in colour)
        except (TypeError, ValueError) as exc:
            # ``float(None)`` and friends raise TypeError; surface every
            # unconvertible component as the documented ValueError.
            raise ValueError(f"Cannot interpret colour: {colour!r}") from exc
        for name, val in [("r", r), ("g", g), ("b", b)]:
            if not 0.0 <= val <= 1.0:
                raise ValueError(
                    f"RGB component {name} must be in [0, 1], got {val}"
                )
        return (r, g, b)

    if isinstance(colour, str):
        # Imported lazily so numeric colours do not require matplotlib.
        from matplotlib.colors import to_rgb

        try:
            return to_rgb(colour)
        except ValueError as exc:
            raise ValueError(f"Unrecognised colour name: {colour!r}") from exc

    raise ValueError(f"Cannot interpret colour: {colour!r}")
def _species_colours(
    species: tuple[str | Composition, ...],
    atom_styles: dict[str, AtomStyle],
) -> list[tuple[float, float, float]]:
    """Return per-atom colours from species styles (the default path).

    For pure-string sites, returns the species' :class:`AtomStyle.colour`.
    For :class:`Composition` sites, returns the colour of the site's
    ``dominant_species``.  Sites with no matching style fall back to grey
    ``(0.5, 0.5, 0.5)``.

    Colours are normalised once per resolved label and cached, so the
    cost is proportional to the number of distinct labels rather than
    the number of atoms.

    Args:
        species: Per-atom species labels or compositions.
        atom_styles: Mapping from species label to its style.

    Returns:
        One ``(r, g, b)`` tuple per entry of *species*.
    """
    cache: dict[str, tuple[float, float, float]] = {}
    colours: list[tuple[float, float, float]] = []
    for site in species:
        label = site.dominant_species if isinstance(site, Composition) else site
        if label not in cache:
            style = atom_styles.get(label)
            if style is not None:
                cache[label] = normalise_colour(style.colour)
            else:
                cache[label] = (0.5, 0.5, 0.5)
        colours.append(cache[label])
    return colours


def _resolve_cmap(
    cmap: CmapSpec,
) -> Callable[[float], tuple[float, float, float]]:
    """Turn a colourmap specification into a callable float -> RGB.

    Accepts a colourmap name (string), a callable mapping a float in
    ``[0, 1]`` to a colour tuple, or a matplotlib ``Colormap`` object.
    The returned wrapper always produces a 3-tuple ``(r, g, b)`` even if
    the underlying callable returns RGBA.

    Args:
        cmap: Colourmap name or callable.

    Returns:
        A function mapping a float to the first three components of the
        underlying colourmap's result.

    Raises:
        TypeError: If *cmap* is not a string and not callable.
    """
    if isinstance(cmap, str):
        # Imported lazily so callable colourmaps do not need matplotlib.
        import matplotlib

        fn: Callable[..., Sequence[float]] = matplotlib.colormaps[cmap]
    elif callable(cmap):
        fn = cmap
    else:
        raise TypeError(f"Unsupported cmap type: {type(cmap)}")

    def _wrap(val: float) -> tuple[float, float, float]:
        # Keep only r, g, b; any alpha component is discarded.
        result = fn(val)
        return (result[0], result[1], result[2])

    return _wrap


def _resolve_single_layer(
    atom_data: dict[str, np.ndarray],
    key: str,
    fallback: list[tuple[float, float, float]],
    cmap: CmapSpec,
    colour_range: tuple[float, float] | None,
    scene_atom_data: AtomData | None = None,
) -> tuple[list[tuple[float, float, float]], np.ndarray]:
    """Resolve colours for a single colour_by key.

    Arrays whose dtype kind is ``"U"`` or ``"O"`` are treated as
    categorical; every other array is handed to the numerical path.

    Args:
        atom_data: Per-atom metadata arrays, indexed by key.
        key: The entry of *atom_data* to colour by.
        fallback: Per-atom colours used for atoms with missing data.
        cmap: Colourmap specification for this layer.
        colour_range: Explicit ``(vmin, vmax)`` for numerical data, or
            ``None``.  Not used for categorical data.
        scene_atom_data: The scene's :class:`AtomData` container.  When
            provided, derived global metadata is used for 2-D data so
            that colouring is consistent across animation frames:
            :attr:`AtomData.ranges` for numeric data (only when
            *colour_range* is ``None``) and :attr:`AtomData.labels` for
            categorical data.

    Returns:
        A tuple of ``(colours, missing_mask)`` where *colours* is a
        per-atom list of ``(r, g, b)`` tuples and *missing_mask* is a
        boolean array that is ``True`` for atoms with missing data
        (which received their species fallback colour).

    Raises:
        KeyError: If *key* is not found in *atom_data*.
    """
    values = atom_data[key]
    cmap_fn = _resolve_cmap(cmap)

    if values.dtype.kind in ("U", "O"):
        labels = None
        if scene_atom_data is not None:
            labels = scene_atom_data.labels[key]
        return _resolve_categorical(values, fallback, cmap_fn, labels)

    # An explicit range always wins over the scene-wide derived range.
    if colour_range is None and scene_atom_data is not None:
        colour_range = scene_atom_data.ranges[key]
    return _resolve_numerical(values, fallback, cmap_fn, colour_range)


def _resolve_atom_colours(
    species: tuple[str | Composition, ...],
    atom_styles: dict[str, AtomStyle],
    atom_data: dict[str, np.ndarray],
    colour_by: str | list[str] | None = None,
    cmap: CmapSpec | list[CmapSpec] = "viridis",
    colour_range: tuple[float, float] | None | list[tuple[float, float] | None] = None,
    scene_atom_data: AtomData | None = None,
) -> list[tuple[float, float, float]]:
    """Resolve per-atom RGB colours, optionally using a colourmap.

    When *colour_by* is ``None`` (the default) the usual species-based
    colours from *atom_styles* are returned.  When it is a single
    string, the named array from *atom_data* is mapped through *cmap*.

    When *colour_by* is a **list** of keys, each layer is tried in
    order and the first non-missing value (non-NaN for numerical,
    non-empty for categorical) determines the atom's colour.  This
    allows different colouring rules for different atom subsets::

        scene.set_atom_data("metal_type", by_index={0: "Fe", 2: "Co"})
        scene.set_atom_data("o_coord", by_index={1: 4, 3: 6})
        scene.render_mpl(
            colour_by=["metal_type", "o_coord"],
            cmap=["Set1", "Blues"],
        )

    Args:
        species: Per-atom species labels.
        atom_styles: Species-to-style mapping.
        atom_data: Per-atom metadata arrays from the scene.
        colour_by: Key (or list of keys) into *atom_data* to colour by,
            or ``None`` for species-based colouring.  When a list,
            layers are tried in priority order.
        cmap: A matplotlib colourmap name (e.g. ``"viridis"``), a
            matplotlib ``Colormap`` object, or a callable mapping a
            float in ``[0, 1]`` to an ``(r, g, b)`` tuple.  When
            *colour_by* is a list, *cmap* may also be a list of the
            same length (one per layer).  A single value is broadcast
            to all layers.
        colour_range: Explicit ``(vmin, vmax)`` for normalising
            numerical data.  ``None`` auto-ranges from the data.
            Ignored for categorical data.  When *colour_by* is a list,
            may also be a list of the same length.
        scene_atom_data: The scene's :class:`AtomData` container.  When
            provided and a key's *colour_range* is ``None``, the derived
            global range from :attr:`AtomData.ranges` is used for 2-D
            numeric data so that colourmap scaling is consistent across
            frames.

    Returns:
        List of ``(r, g, b)`` tuples, one per atom.

    Raises:
        KeyError: If *colour_by* (or any key in the list) is not found
            in *atom_data*.
        ValueError: If *colour_by* is a list and *cmap* or
            *colour_range* is also a list of a different length, or if
            *colour_by* is a single string and *cmap* or *colour_range*
            is a list.
    """
    # Species colours are both the default result and the per-atom
    # fallback for missing data, so compute them exactly once.
    fallback = _species_colours(species, atom_styles)
    if colour_by is None:
        return fallback

    # --- Single key (common case) ---
    if isinstance(colour_by, str):
        if isinstance(cmap, list):
            raise ValueError(
                "cmap must not be a list when colour_by is a single string"
            )
        if isinstance(colour_range, list):
            raise ValueError(
                "colour_range must not be a list when colour_by is a "
                "single string"
            )
        colours, _mask = _resolve_single_layer(
            atom_data,
            colour_by,
            fallback,
            cmap,
            colour_range,
            scene_atom_data=scene_atom_data,
        )
        return colours

    # --- List of keys (priority merge) ---
    n_layers = len(colour_by)

    # Broadcast cmap / colour_range to lists.
    if not isinstance(cmap, list):
        cmaps = [cmap] * n_layers
    else:
        cmaps = cmap
    if not isinstance(colour_range, list):
        ranges: list[tuple[float, float] | None] = [colour_range] * n_layers
    else:
        ranges = colour_range

    if len(cmaps) != n_layers:
        raise ValueError(
            f"colour_by has {n_layers} keys but cmap has "
            f"{len(cmaps)} entries"
        )
    if len(ranges) != n_layers:
        raise ValueError(
            f"colour_by has {n_layers} keys but colour_range has "
            f"{len(ranges)} entries"
        )

    # Resolve each layer independently.
    layers = [
        _resolve_single_layer(
            atom_data,
            key,
            fallback,
            cm,
            cr,
            scene_atom_data=scene_atom_data,
        )
        for key, cm, cr in zip(colour_by, cmaps, ranges)
    ]

    # Merge: first layer with non-missing data wins; atoms missing from
    # every layer keep their species fallback colour.
    n_atoms = len(species)
    result: list[tuple[float, float, float]] = list(fallback)
    for i in range(n_atoms):
        for layer_colours, layer_mask in layers:
            if not layer_mask[i]:
                result[i] = layer_colours[i]
                break
    return result


def _resolve_numerical(
    values: np.ndarray,
    fallback: list[tuple[float, float, float]],
    cmap_fn: Callable[[float], tuple[float, float, float]],
    colour_range: tuple[float, float] | None,
) -> tuple[list[tuple[float, float, float]], np.ndarray]:
    """Map numerical values through a colourmap.

    Integer arrays are automatically coerced to float so that NaN
    sentinels (used for missing data) are representable.

    When *colour_range* is ``None`` the range is taken from the finite,
    non-NaN values only, so a stray ``inf`` cannot stretch the range;
    infinite values are then clipped to the nearest end of the
    colourmap.  If the range is degenerate (``vmin == vmax``, or no
    finite values exist) every non-missing atom maps to ``0.5``.

    Args:
        values: Per-atom numeric array (bool, integer, or float dtype).
        fallback: Per-atom colours used where *values* is NaN.
        cmap_fn: Function mapping a float in ``[0, 1]`` to ``(r, g, b)``.
        colour_range: Explicit ``(vmin, vmax)``, or ``None`` to
            auto-range from the data.

    Returns:
        A tuple of ``(colours, missing_mask)`` where *missing_mask* is a
        boolean array that is ``True`` for atoms whose values are NaN.

    Raises:
        ValueError: If *values* does not have a numeric dtype.
    """
    if values.dtype.kind not in ("b", "i", "u", "f"):
        raise ValueError(
            f"_resolve_numerical requires a numeric dtype "
            f"(bool, integer, or float), got {values.dtype}"
        )
    values = values.astype(float, copy=False)
    mask = np.isnan(values)

    if colour_range is not None:
        vmin, vmax = colour_range
    else:
        valid = values[~mask]
        if len(valid) == 0:
            # Nothing to map: every atom keeps its fallback colour.
            return list(fallback), mask
        # Auto-range over finite values only; including +/-inf would
        # make the range infinite and turn the normalised values into
        # NaN or collapse them onto one end of the colourmap.
        finite = valid[np.isfinite(valid)]
        if finite.size == 0:
            # Only infinities present: treat as a degenerate range.
            vmin = vmax = 0.0
        else:
            vmin, vmax = float(np.min(finite)), float(np.max(finite))

    if vmin == vmax:
        # Degenerate range: place every non-missing atom mid-colourmap.
        normalised = np.where(mask, np.nan, 0.5)
    else:
        normalised = (values - vmin) / (vmax - vmin)
        normalised = np.clip(normalised, 0.0, 1.0)

    colours: list[tuple[float, float, float]] = [
        fallback[i] if mask[i] else cmap_fn(float(val))
        for i, val in enumerate(normalised)
    ]
    return colours, mask


def _is_categorical_missing(v: object) -> bool:
    """Return True if *v* should be treated as a missing categorical value.

    Missing values are ``None``, empty strings, and float ``NaN``
    (including numpy floating scalars such as ``np.float64('nan')``).
    """
    if v is None:
        return True
    if isinstance(v, str) and v == "":
        return True
    if isinstance(v, (float, np.floating)) and np.isnan(v):
        return True
    return False


def _resolve_categorical(
    values: np.ndarray,
    fallback: list[tuple[float, float, float]],
    cmap_fn: Callable[[float], tuple[float, float, float]],
    global_labels: tuple[str, ...] | None = None,
) -> tuple[list[tuple[float, float, float]], np.ndarray]:
    """Map categorical labels through a colourmap.

    Values of ``None``, ``NaN``, and empty strings are treated as
    missing and receive their species fallback colour.  Labels are
    compared by their ``str()`` form and spaced evenly across
    ``[0, 1]``; a single label is placed at ``0.5``.

    Args:
        values: Per-atom array of labels.
        fallback: Per-atom colours used for missing values.
        cmap_fn: Function mapping a float in ``[0, 1]`` to ``(r, g, b)``.
        global_labels: When provided, these labels define the colourmap
            positions (consistent across animation frames).  Typically
            looked up via :attr:`AtomData.labels`, which holds unique
            labels across all frames.  When ``None``, labels are
            discovered from *values* in order of first appearance.

    Returns:
        A tuple of ``(colours, missing_mask)`` where *missing_mask* is a
        boolean array that is ``True`` for atoms whose values are
        missing (``None``, ``NaN``, or empty string).

    Raises:
        KeyError: If *global_labels* is given and a non-missing value's
            string form is not one of those labels.
    """
    # Build missing mask.
    missing = np.array(
        [_is_categorical_missing(v) for v in values], dtype=bool
    )

    # Use global labels when available, otherwise discover from slice.
    seen: dict[str, int]
    if global_labels is not None:
        seen = {label: idx for idx, label in enumerate(global_labels)}
    else:
        seen = {}
        for i, v in enumerate(values):
            if not missing[i]:
                s = str(v)
                if s not in seen:
                    seen[s] = len(seen)

    n_labels = len(seen)
    if n_labels == 0:
        return list(fallback), missing

    # Space labels evenly across [0, 1].
    if n_labels == 1:
        positions = {label: 0.5 for label in seen}
    else:
        positions = {
            label: idx / (n_labels - 1) for label, idx in seen.items()
        }

    colours: list[tuple[float, float, float]] = [
        fallback[i] if missing[i] else cmap_fn(positions[str(v)])
        for i, v in enumerate(values)
    ]
    return colours, missing