# Source code for plotpy.items.quiver

# -*- coding: utf-8 -*-
#
# Licensed under the terms of the BSD 3-Clause
# (see plotpy/LICENSE for details)

"""
Quiver plot item
================

This module provides the :py:class:`QuiverItem` class for displaying 2D vector
fields (quiver plots), similar to Matplotlib's :py:func:`matplotlib.pyplot.quiver`.

.. autoclass:: QuiverItem
   :members:
"""

from __future__ import annotations

import math
import sys
from typing import TYPE_CHECKING

import numpy as np
from guidata.utils.misc import assert_interfaces_valid
from qtpy import QtCore as QC
from qtpy import QtGui as QG
from qwt import QwtPlotItem

from plotpy.interfaces import IBasePlotItem, IDecoratorItemType

if TYPE_CHECKING:
    import qwt.scale_map
    from qtpy.QtCore import QPointF, QRectF
    from qtpy.QtGui import QPainter

    from plotpy.interfaces import IItemType
    from plotpy.styles.base import ItemParameters


class QuiverItem(QwtPlotItem):
    """Quiver (2D vector field) plot item

    Displays arrows at grid positions (X, Y) with direction and magnitude
    defined by (U, V) components, similar to Matplotlib's ``quiver``.

    Arrow lengths are proportional to the vector magnitudes: the longest
    vector is drawn ``arrow_scale`` pixels long and the others are scaled
    linearly. Vectors with a null or non-finite magnitude, and vectors located
    at non-finite positions, are not drawn.

    Args:
        x: 1D or 2D array of arrow X positions
        y: 1D or 2D array of arrow Y positions
        u: 1D or 2D array of arrow X components
        v: 1D or 2D array of arrow Y components
        color: Arrow color (default: black)
        arrow_scale: Length in pixels of the longest arrow (default: 30.0).
         Larger values produce longer arrows.
        arrow_head_size: Length in pixels of the arrow head sides (default: 6.0)
        headwidth: Arrow head half-opening angle, as a multiple of 30 degrees
         (default: 0.7, i.e. 21 degrees)
    """

    __implements__ = (IBasePlotItem,)

    _readonly = True
    _private = False
    _can_select = True
    _can_resize = False
    _can_rotate = False
    _can_move = False
    _icon_name = "quiver.png"

    # Vectors whose magnitude is below this threshold are not drawn
    _MIN_MAGNITUDE = 1e-12
    # Arrows shorter than this (in pixels) are drawn without a head
    _MIN_HEAD_LENGTH = 2.0
    # Base half-opening angle of the arrow head (30 degrees), scaled by headwidth
    _HEAD_BASE_ANGLE = math.pi / 6

    def __init__(
        self,
        x: np.ndarray,
        y: np.ndarray,
        u: np.ndarray,
        v: np.ndarray,
        color: str | QG.QColor = "black",
        arrow_scale: float = 30.0,
        arrow_head_size: float = 6.0,
        headwidth: float = 0.7,
    ) -> None:
        super().__init__()
        self.setItemAttribute(QwtPlotItem.AutoScale, True)
        self.selected = False
        self.immutable = True
        self._set_data(x, y, u, v)
        self._color = QG.QColor(color)
        self._arrow_scale = arrow_scale
        self._arrow_head_size = arrow_head_size
        self._headwidth = headwidth

    # ---- Data handling ----------------------------------------------------------

    def _set_data(
        self,
        x: np.ndarray,
        y: np.ndarray,
        u: np.ndarray,
        v: np.ndarray,
    ) -> None:
        """Set and validate vector field data.

        When ``x`` and ``y`` are 1D and ``u`` is 2D, ``x`` and ``y`` are expanded
        to a grid with :func:`numpy.meshgrid`, so ``u`` and ``v`` are expected to
        have shape ``(len(y), len(x))`` (Matplotlib convention).

        Args:
            x: Arrow X positions (1D or 2D)
            y: Arrow Y positions (1D or 2D)
            u: Arrow X direction components (1D or 2D)
            v: Arrow Y direction components (1D or 2D)

        Raises:
            ValueError: if the arrays do not describe the same number of
             arrows, or if a 2D ``u``/``v`` does not match the grid shape.
        """
        x = np.asarray(x, dtype=np.float64)
        y = np.asarray(y, dtype=np.float64)
        u = np.asarray(u, dtype=np.float64)
        v = np.asarray(v, dtype=np.float64)

        # If X, Y are 1D and U, V are 2D, expand via meshgrid
        if x.ndim == 1 and y.ndim == 1 and u.ndim == 2:
            x, y = np.meshgrid(x, y)
            # Equal element counts are not enough here: a transposed U/V would
            # pass the size check below but pair vectors with the wrong points.
            if u.shape != x.shape or (v.ndim == 2 and v.shape != x.shape):
                raise ValueError(
                    f"u and v must have shape {x.shape} (len(y), len(x)) when "
                    f"x and y are 1D, got {u.shape} and {v.shape}"
                )

        # Flatten all arrays for uniform processing
        self._x = x.ravel()
        self._y = y.ravel()
        self._u = u.ravel()
        self._v = v.ravel()

        if not (self._x.size == self._y.size == self._u.size == self._v.size):
            raise ValueError("x, y, u, v arrays must have the same number of elements")

    def set_data(
        self,
        x: np.ndarray,
        y: np.ndarray,
        u: np.ndarray,
        v: np.ndarray,
    ) -> None:
        """Set vector field data and trigger replot.

        Args:
            x: Arrow X positions (1D or 2D)
            y: Arrow Y positions (1D or 2D)
            u: Arrow X direction components (1D or 2D)
            v: Arrow Y direction components (1D or 2D)
        """
        self._set_data(x, y, u, v)
        plot = self.plot()
        if plot is not None:
            plot.replot()

    def get_data(self) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Return the vector field data.

        Returns:
            Tuple of flattened (x, y, u, v) arrays
        """
        return self._x, self._y, self._u, self._v

    # ---- Style accessors -------------------------------------------------------

    def set_color(self, color: str | QG.QColor) -> None:
        """Set arrow color.

        Args:
            color: Color name string or QColor
        """
        self._color = QG.QColor(color)

    def get_color(self) -> QG.QColor:
        """Return arrow color.

        Returns:
            Arrow color
        """
        return self._color

    # ---- QwtPlotItem interface -------------------------------------------------

    def boundingRect(self) -> QC.QRectF:
        """Return the bounding rectangle of the data.

        The bounding rectangle is expanded by a margin to ensure arrows at the
        edges of the field remain visible when the plot auto-scales. Base points
        with non-finite coordinates are ignored.

        Returns:
            Bounding rectangle in data coordinates (empty if there is no finite
            base point)
        """
        finite = np.isfinite(self._x) & np.isfinite(self._y)
        if not finite.any():
            return QC.QRectF()
        xs, ys = self._x[finite], self._y[finite]
        xmin, xmax = float(xs.min()), float(xs.max())
        ymin, ymax = float(ys.min()), float(ys.max())
        # Add margin to account for arrow length extending beyond base points.
        # Use 10% of the data range, or fall back to a fraction of the min
        # absolute coordinate value for degenerate cases (e.g., single column).
        dx = xmax - xmin
        dy = ymax - ymin
        margin_x = dx * 0.1 if dx > 0 else abs(xmin) * 0.1 + 1.0
        margin_y = dy * 0.1 if dy > 0 else abs(ymin) * 0.1 + 1.0
        return QC.QRectF(
            xmin - margin_x,
            ymin - margin_y,
            dx + 2 * margin_x,
            dy + 2 * margin_y,
        )

    def is_empty(self) -> bool:
        """Return True if the item has no data.

        Returns:
            True if the item is empty, False otherwise
        """
        return self._x.size == 0

    @staticmethod
    def _pixel_direction(scale_map: qwt.scale_map.QwtScaleMap) -> float:
        """Return +1.0 if pixel coordinates grow with data values along the axis
        described by *scale_map*, -1.0 otherwise (e.g. a standard Y axis, whose
        screen coordinates grow downwards)."""
        p_span = scale_map.p2() - scale_map.p1()
        s_span = scale_map.s2() - scale_map.s1()
        return 1.0 if p_span * s_span >= 0 else -1.0

    def draw(
        self,
        painter: QPainter,
        xMap: qwt.scale_map.QwtScaleMap,
        yMap: qwt.scale_map.QwtScaleMap,
        canvasRect: QRectF,
    ) -> None:
        """Draw the vector field.

        Args:
            painter: QPainter instance
            xMap: X axis scale map (data -> pixel)
            yMap: Y axis scale map (data -> pixel)
            canvasRect: Canvas rectangle in pixel coordinates
        """
        if self._x.size == 0:
            return

        # Magnitudes, normalized by the largest finite one
        mag = np.hypot(self._u, self._v)
        finite_mag = mag[np.isfinite(mag)]
        max_mag = float(finite_mag.max()) if finite_mag.size else 0.0
        if max_mag == 0:
            return
        # Pixel length of each arrow: linear in magnitude, the longest arrow
        # being exactly ``arrow_scale`` pixels long
        lengths = mag / max_mag * self._arrow_scale

        # On-screen orientation of each axis (handles reversed axes; a standard
        # Y axis yields -1 because screen Y points down)
        xsign = self._pixel_direction(xMap)
        ysign = self._pixel_direction(yMap)

        selected = self.selected
        color = QG.QColor("red") if selected else self._color
        pen = QG.QPen(color, 2.5 if selected else 1.5)
        pen.setCapStyle(QC.Qt.RoundCap)
        pen.setJoinStyle(QC.Qt.RoundJoin)

        head_size = self._arrow_head_size
        half_angle = self._HEAD_BASE_ANGLE * self._headwidth

        painter.save()
        try:
            painter.setRenderHint(QG.QPainter.Antialiasing)
            painter.setPen(pen)
            painter.setBrush(QG.QBrush(color))

            for i in range(self._x.size):
                m = mag[i]
                # Skip null and non-finite vectors
                if not math.isfinite(m) or m < self._MIN_MAGNITUDE:
                    continue
                x, y = self._x[i], self._y[i]
                # Skip base points that cannot be mapped to the canvas
                if not (math.isfinite(x) and math.isfinite(y)):
                    continue

                # Transform base point to canvas coordinates
                px = xMap.transform(x)
                py = yMap.transform(y)

                # Unit direction vector in pixel space
                ux = xsign * self._u[i] / m
                uy = ysign * self._v[i] / m
                length = lengths[i]
                end_x = px + ux * length
                end_y = py + uy * length

                # Draw shaft
                painter.drawLine(QC.QPointF(px, py), QC.QPointF(end_x, end_y))

                # Draw arrow head
                if length < self._MIN_HEAD_LENGTH:
                    continue  # Too small for a head
                angle = math.atan2(uy, ux)
                h1x = end_x - head_size * math.cos(angle - half_angle)
                h1y = end_y - head_size * math.sin(angle - half_angle)
                h2x = end_x - head_size * math.cos(angle + half_angle)
                h2y = end_y - head_size * math.sin(angle + half_angle)
                head = QG.QPolygonF(
                    [
                        QC.QPointF(end_x, end_y),
                        QC.QPointF(h1x, h1y),
                        QC.QPointF(h2x, h2y),
                    ]
                )
                painter.drawPolygon(head)
        finally:
            # Always restore the painter state, even if drawing fails midway
            painter.restore()

    # ---- IBasePlotItem interface -----------------------------------------------

    def types(self) -> tuple[type[IItemType], ...]:
        """Returns a group or category for this item.

        Returns:
            Tuple of class objects inheriting from IItemType
        """
        return (IDecoratorItemType,)

    def set_readonly(self, state: bool) -> None:
        """Set object readonly state.

        Args:
            state: True if object is readonly, False otherwise
        """
        self._readonly = state

    def is_readonly(self) -> bool:
        """Return object readonly state.

        Returns:
            True if object is readonly, False otherwise
        """
        return self._readonly

    def set_private(self, state: bool) -> None:
        """Set object as private.

        Args:
            state: True if object is private, False otherwise
        """
        self._private = state

    def is_private(self) -> bool:
        """Return True if object is private.

        Returns:
            True if object is private, False otherwise
        """
        return self._private

    def get_icon_name(self) -> str:
        """Return the icon name.

        Returns:
            Icon name
        """
        return self._icon_name

    def set_icon_name(self, icon_name: str) -> None:
        """Set the icon name.

        Args:
            icon_name: Icon name
        """
        self._icon_name = icon_name

    def set_selectable(self, state: bool) -> None:
        """Set item selectable state.

        Args:
            state: True if item is selectable, False otherwise
        """
        self._can_select = state

    def set_resizable(self, state: bool) -> None:
        """Set item resizable state.

        Args:
            state: True if item is resizable, False otherwise
        """
        self._can_resize = state

    def set_movable(self, state: bool) -> None:
        """Set item movable state.

        Args:
            state: True if item is movable, False otherwise
        """
        self._can_move = state

    def set_rotatable(self, state: bool) -> None:
        """Set item rotatable state.

        Args:
            state: True if item is rotatable, False otherwise
        """
        self._can_rotate = state

    def can_select(self) -> bool:
        """Returns True if this item can be selected.

        Returns:
            True if item can be selected, False otherwise
        """
        return self._can_select

    def can_resize(self) -> bool:
        """Returns True if this item can be resized.

        Returns:
            True if item can be resized, False otherwise
        """
        return self._can_resize

    def can_rotate(self) -> bool:
        """Returns True if this item can be rotated.

        Returns:
            True if item can be rotated, False otherwise
        """
        return self._can_rotate

    def can_move(self) -> bool:
        """Returns True if this item can be moved.

        Returns:
            True if item can be moved, False otherwise
        """
        return self._can_move

    def select(self) -> None:
        """Select the object and highlight it."""
        self.selected = True
        plot = self.plot()
        if plot is not None:
            plot.replot()

    def unselect(self) -> None:
        """Unselect the object and restore its appearance."""
        self.selected = False
        plot = self.plot()
        if plot is not None:
            plot.replot()

    def hit_test(self, pos: QPointF) -> tuple[float, float, bool, None]:
        """Return a tuple (distance, attach point, inside, other_object).

        The distance is measured in canvas pixels to the closest arrow base
        point; "inside" is True when *pos* lies within the item's bounding
        rectangle mapped to the canvas.

        Args:
            pos: Position in canvas coordinates

        Returns:
            Tuple with four elements (distance, attach point, inside, other_object)
        """
        plot = self.plot()
        if plot is None or self._x.size == 0:
            return sys.maxsize, 0, False, None

        # Work in canvas coordinates: map each arrow base point to the canvas
        cx, cy = pos.x(), pos.y()
        xaxis, yaxis = self.xAxis(), self.yAxis()

        # Find the closest arrow base point (NaN distances never compare lower)
        dist = sys.maxsize
        for x, y in zip(self._x, self._y):
            px = plot.transform(xaxis, x)
            py = plot.transform(yaxis, y)
            d = math.hypot(cx - px, cy - py)
            if d < dist:
                dist = d

        # Consider "inside" if within the bounding rect on canvas
        rect = self.boundingRect()
        x0 = plot.transform(xaxis, rect.left())
        y0 = plot.transform(yaxis, rect.top())
        x1 = plot.transform(xaxis, rect.right())
        y1 = plot.transform(yaxis, rect.bottom())
        canvas_rect = QC.QRectF(
            QC.QPointF(min(x0, x1), min(y0, y1)),
            QC.QPointF(max(x0, x1), max(y0, y1)),
        )
        inside = canvas_rect.contains(QC.QPointF(cx, cy))
        return dist, 0, inside, None

    def update_item_parameters(self) -> None:
        """Update item parameters (dataset) from object properties."""

    def get_item_parameters(self, itemparams: ItemParameters) -> None:
        """Appends datasets to the list of DataSets describing the parameters.

        Args:
            itemparams: Item parameters
        """

    def set_item_parameters(self, itemparams: ItemParameters) -> None:
        """Change the appearance of this item according to the parameter set.

        Args:
            itemparams: Item parameters
        """

    def move_local_point_to(
        self, handle: int, pos: QPointF, ctrl: bool | None = None
    ) -> None:
        """Move a handle as returned by hit_test to the new position.

        This item is not movable: this method does nothing.

        Args:
            handle: Handle
            pos: Position
            ctrl: True if <Ctrl> button is being pressed, False otherwise
        """

    def move_local_shape(self, old_pos: QPointF, new_pos: QPointF) -> None:
        """Translate the shape such that old_pos becomes new_pos.

        This item is not movable: this method does nothing.

        Args:
            old_pos: Old position
            new_pos: New position
        """

    def move_with_selection(self, delta_x: float, delta_y: float) -> None:
        """Translate the item together with other selected items.

        This item is not movable: this method does nothing.

        Args:
            delta_x: Translation in plot coordinates along x-axis
            delta_y: Translation in plot coordinates along y-axis
        """
# Module-level sanity check run once at import time (guidata helper).
# NOTE(review): presumably verifies that QuiverItem provides every method
# required by the interfaces listed in its ``__implements__`` attribute —
# confirm against guidata's documentation.
assert_interfaces_valid(QuiverItem)