"""Data-point overlay rendering for geolatent.
:class:`DataOverlay` is responsible for all trace types that are layered *on
top* of decision surfaces: per-class scatter clouds, class-centroid markers,
and optional confidence-ellipsoid surfaces.
Design rationale: each rendering method returns a list of independent Plotly
traces rather than mutating an existing figure. This keeps the overlay layer
stateless and trivially composable with scene objects produced by other modules.
"""
from __future__ import annotations
from typing import Dict, List, Optional
import numpy as np
import plotly.graph_objects as go
from ..config.themes import VisualizationConfig
from ..core.geometry import GeometryUtils
# Overlay renderer
class DataOverlay:
    """Generates data-point and structural overlays for 3-D scenes.

    Every ``render_*`` method builds and returns new Plotly trace objects; no
    existing figure is mutated, so the results can be added to any scene.

    Parameters
    ----------
    config : VisualizationConfig
        Master configuration; colour palette, marker sizes, and opacity values
        are all sourced from here.

    Examples
    --------
    >>> overlay = DataOverlay(DARK_SCIENTIFIC)
    >>> traces = overlay.render_scatter(X_3d, y)
    >>> centroid = overlay.render_centroids(X_3d, y)
    """

    def __init__(self, config: VisualizationConfig) -> None:
        self.config = config
        # Geometry helper that the centroid / ellipsoid / hull renderers
        # delegate to; it is built from the same configuration.
        self._geometry = GeometryUtils(config)

    # Scatter cloud

    def render_scatter(
        self,
        X_proj: np.ndarray,
        y: np.ndarray,
        *,
        class_names: Optional[Dict] = None,
        point_size_override: Optional[int] = None,
        opacity_override: Optional[float] = None,
    ) -> List[go.Scatter3d]:
        """Render a per-class scatter trace for every unique class label.

        One ``go.Scatter3d`` trace is produced per class, allowing Plotly's
        interactive legend to toggle individual classes on and off.

        Parameters
        ----------
        X_proj : np.ndarray of shape (n_samples, 3)
            3-D projected coordinates. Columns 0, 1 and 2 are read as x, y, z.
        y : np.ndarray of shape (n_samples,)
            Class label vector.
        class_names : dict, optional
            Mapping ``{label: display_string}``. Labels missing from the
            mapping are displayed as ``"Class <label>"``.
        point_size_override : int, optional
            Override ``config.render.marker_size``. Only ``None`` selects the
            configured value; an explicit ``0`` is honoured.
        opacity_override : float, optional
            Override ``config.render.scatter_opacity``. Only ``None`` selects
            the configured value; an explicit ``0.0`` is honoured.

        Returns
        -------
        traces : list of go.Scatter3d
            One trace per unique class, in sorted label order.
        """
        # Coerce to arrays so that list inputs work: with a plain list,
        # ``y == cls`` would collapse to a single bool instead of a mask.
        X_proj = np.asarray(X_proj)
        y = np.asarray(y)

        colors = self.config.colors.class_colors
        # ``override or default`` would discard legitimate falsy overrides
        # (size 0, opacity 0.0), so the choice is made on ``None`` only.
        size = self._resolve_override(
            point_size_override, self.config.render.marker_size
        )
        opacity = self._resolve_override(
            opacity_override, self.config.render.scatter_opacity
        )

        traces: List[go.Scatter3d] = []
        # np.unique returns the labels sorted, which fixes the trace order.
        for idx, cls in enumerate(np.unique(y)):
            pts = X_proj[y == cls]
            # The palette is cycled when there are more classes than colours.
            color = colors[idx % len(colors)]
            label = self._class_label(cls, class_names)

            # Pre-rendered per-point hover text (label + coordinates).
            hover_text = [
                f"<b>{label}</b><br>"
                f"x: {pt[0]:.4f}<br>y: {pt[1]:.4f}<br>z: {pt[2]:.4f}"
                for pt in pts
            ]

            traces.append(
                go.Scatter3d(
                    x=pts[:, 0],
                    y=pts[:, 1],
                    z=pts[:, 2],
                    mode="markers",
                    marker=dict(
                        size=size,
                        color=color,
                        opacity=opacity,
                        line=dict(
                            color=self.config.colors.marker_line_color,
                            width=0.5,
                        ),
                        symbol="circle",
                    ),
                    name=label,
                    legendgroup=str(cls),
                    showlegend=True,
                    text=hover_text,
                    hovertemplate="%{text}<extra></extra>",
                )
            )
        return traces

    # Class centroids

    def render_centroids(
        self,
        X_proj: np.ndarray,
        y: np.ndarray,
        *,
        class_names: Optional[Dict] = None,
    ) -> go.Scatter3d:
        """Render diamond-shaped class-centroid markers.

        Delegates to
        :meth:`~geolatent.core.geometry.GeometryUtils.compute_class_centroids`
        and returns its result unchanged.

        Parameters
        ----------
        X_proj : np.ndarray of shape (n_samples, 3)
        y : np.ndarray of shape (n_samples,)
        class_names : dict, optional

        Returns
        -------
        trace : go.Scatter3d
        """
        return self._geometry.compute_class_centroids(
            X_proj, y, class_names=class_names
        )

    # Confidence ellipsoids

    def render_ellipsoids(
        self,
        X_proj: np.ndarray,
        y: np.ndarray,
        *,
        confidence: float = 0.90,
        class_names: Optional[Dict] = None,
    ) -> List[go.Surface]:
        """Render parametric confidence-ellipsoid surfaces for each class.

        Delegates to
        :meth:`~geolatent.core.geometry.GeometryUtils.compute_class_ellipsoids`
        and returns its result unchanged.

        Parameters
        ----------
        X_proj : np.ndarray of shape (n_samples, 3)
        y : np.ndarray of shape (n_samples,)
        confidence : float
            Confidence level for the Mahalanobis-distance ellipsoid
            (default 0.90).
        class_names : dict, optional

        Returns
        -------
        traces : list of go.Surface
        """
        return self._geometry.compute_class_ellipsoids(
            X_proj, y, confidence=confidence, class_names=class_names
        )

    # Convex hulls

    def render_convex_hulls(
        self,
        X_proj: np.ndarray,
        y: np.ndarray,
        *,
        class_names: Optional[Dict] = None,
    ) -> List[go.Mesh3d]:
        """Render transparent convex-hull surfaces around each class cluster.

        Delegates to
        :meth:`~geolatent.core.geometry.GeometryUtils.compute_convex_hull_traces`
        and returns its result unchanged.

        Parameters
        ----------
        X_proj : np.ndarray of shape (n_samples, 3)
        y : np.ndarray of shape (n_samples,)
        class_names : dict, optional

        Returns
        -------
        traces : list of go.Mesh3d
        """
        return self._geometry.compute_convex_hull_traces(
            X_proj, y, class_names=class_names
        )

    # Trajectory (optimisation path / attention walk)

    def render_trajectory(
        self,
        waypoints: np.ndarray,
        *,
        name: str = "Trajectory",
        color: Optional[str] = None,
        line_width: int = 4,
        show_waypoints: bool = True,
    ) -> List[go.BaseTraceType]:
        """Render an ordered sequence of waypoints as a 3-D polyline.

        Useful for visualising gradient-descent trajectories, attention paths,
        or other sequential processes in the projected embedding space.

        Parameters
        ----------
        waypoints : np.ndarray of shape (n_steps, 3)
            Ordered sequence of 3-D positions. Any array-like is accepted and
            converted to ``float64``.
        name : str
            Legend label for the trajectory.
        color : str, optional
            Line colour (hex). Defaults to the first class colour.
        line_width : int
            Width of the line in pixels.
        show_waypoints : bool
            Whether to overlay individual step markers.

        Returns
        -------
        traces : list of Plotly traces
            The line trace, followed by the step-marker trace when
            ``show_waypoints`` is true.

        Raises
        ------
        ValueError
            If ``waypoints`` is not a 2-D array with exactly three columns.
        """
        waypoints = np.asarray(waypoints, dtype=np.float64)
        if waypoints.ndim != 2 or waypoints.shape[1] != 3:
            raise ValueError(
                "waypoints must be a 2-D array of shape (n_steps, 3)."
            )

        if color is None:
            color = self.config.colors.class_colors[0]

        traces: List[go.BaseTraceType] = [
            go.Scatter3d(
                x=waypoints[:, 0],
                y=waypoints[:, 1],
                z=waypoints[:, 2],
                mode="lines",
                line=dict(color=color, width=line_width),
                name=name,
                showlegend=True,
                hoverinfo="skip",
            )
        ]

        if show_waypoints:
            # The step index drives both the marker colour (mapped through
            # the "Blues" colourscale) and the hover label.
            steps = list(range(len(waypoints)))
            traces.append(
                go.Scatter3d(
                    x=waypoints[:, 0],
                    y=waypoints[:, 1],
                    z=waypoints[:, 2],
                    mode="markers",
                    marker=dict(
                        size=4,
                        color=steps,
                        colorscale="Blues",
                        showscale=False,
                    ),
                    name=f"{name} steps",
                    showlegend=False,
                    hovertemplate=(
                        "Step %{customdata}<br>"
                        "x: %{x:.4f}<br>y: %{y:.4f}<br>z: %{z:.4f}<extra></extra>"
                    ),
                    customdata=steps,
                )
            )
        return traces

    # Private helpers

    @staticmethod
    def _resolve_override(override, default):
        """Return *override* unless it is ``None``, in which case *default*.

        Unlike ``override or default`` this keeps falsy-but-valid overrides
        such as ``0`` and ``0.0``.
        """
        return default if override is None else override

    @staticmethod
    def _class_label(cls: object, class_names: Optional[Dict]) -> str:
        """Return the display string for class label *cls*.

        Uses ``class_names[cls]`` when a non-empty mapping containing *cls* is
        supplied; otherwise falls back to ``"Class <cls>"``.
        """
        if class_names and cls in class_names:
            return str(class_names[cls])
        return f"Class {cls}"