Source code for hari_plotter.plotter

from __future__ import annotations

import os
import shutil
import tempfile
import warnings
from typing import TYPE_CHECKING, Any, Optional

    from .plot import Plot

import imageio
import matplotlib
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from .color_scheme import ColorScheme
from .interface import Interface

plt.rcParams['axes.xmargin'] = 0
plt.rcParams['axes.ymargin'] = 0

[docs] class Plotter: _parameter_dict = {'Time': 'Time', 'Opinion': 'Node Opinion', 'Cluster size': 'Cluster size', 'Importance': 'Node Importance', 'Label': 'Node Label', 'Neighbor mean opinion': 'Node Neighbor Mean Opinion', 'Activity': 'Node Activity', 'Inner opinions': 'Node Inner Opinions', 'Max opinion': 'Node Max Opinion', 'Min opinion': 'Node Min Opinion'} _plot_types: dict[str, type[Plot]] = {}
[docs] class PlotSaver: """ A utility class to handle the saving and display of plots. It provides functionality to save individual plots, display them, and even create GIFs from a sequence of plots. """ def __init__(self, mode: str | list[str] = 'show', save_path: Optional[str] = None, save_format: Optional[str] = 'image_{}', animation_path: Optional[str] = None) -> None: """ Initialize the PlotSaver instance. Args: mode (str | list[str]): The mode(s) in which to operate. It can be a list or a single string, e.g. ['show', 'save'] or 'gif'. Available modes: ["show", "save", "gif", "mp4"] save_path (Optional[str]): Path to save individual plots (used if 'save' is in mode) save_format (Optional[str]): string with {} for formatting in the number animation_path (Optional[str]): Path to save gif (used if 'gif' is in mode). """ # Ensure mode is a list even if a single mode string is provided self.mode = mode if isinstance(mode, list) else [mode] if not os.path.exists(save_path): warnings.warn(f"Path {save_path} does not exist. Creating it.") # Create the directory, including any necessary parent directories os.makedirs(save_path, exist_ok=True) self.save_path = os.Path( save_path) if save_path[-1] == '/' else save_path+'/' self.save_format = save_format self.animation_path = animation_path self.saved_images = [] self.temp_dir = None
[docs] @staticmethod def is_inside_jupyter() -> bool: """ Determine if the current environment is Jupyter Notebook. Returns: bool: True if inside Jupyter Notebook, False otherwise. """ try: get_ipython return True except NameError: return False
def __enter__(self) -> Plotter.PlotSaver: """ Entry point for the context manager. Returns: Plotter.PlotSaver: The current instance of the PlotSaver. """ return self
[docs] def save(self, fig: matplotlib.figure.Figure) -> None: """ Save and/or display the provided figure based on the specified mode. Args: fig (matplotlib.figure.Figure): The figure to be saved or displayed. """ plt.tight_layout() # Save the figure if 'save' mode is active and save_path is provided if 'save' in self.mode and self.save_path: path = self.save_path + \ self.save_format.format(len(self.saved_images)) fig.savefig(path) self.saved_images.append(path) # If only 'gif' or 'mp4' mode is selected, save figure to a temp directory elif ('gif' in self.mode or 'mp4' in self.mode) and not self.save_path: if not self.temp_dir: self.temp_dir = tempfile.mkdtemp() temp_path = os.path.join( self.temp_dir, "tmp_plot_{}.png".format(len(self.saved_images))) fig.savefig(temp_path) self.saved_images.append(temp_path) # Show the figure if 'show' mode is active if 'show' in self.mode: if self.is_inside_jupyter(): # In Jupyter, let the figure be displayed automatically display(fig) else: # Outside Jupyter, use to display the figure # Close the figure after processing plt.close(fig)
def __exit__(self, exc_type: Optional[type], exc_val: Optional[Exception], exc_tb: Optional[object]) -> None: """ Exit point for the context manager. Args: exc_type (Optional[type]): The exception type if raised inside the context. exc_val (Optional[Exception]): The exception instance if raised inside the context. exc_tb (Optional[object]): The traceback if an exception was raised inside the context. """ # If 'gif' mode is active and animation_path is provided, create a GIF from the saved images if 'gif' in self.mode and self.animation_path and self.saved_images: with imageio.get_writer(self.animation_path+'.gif', mode='I') as writer: for img_path in self.saved_images: image = imageio.imread(img_path) writer.append_data(image) # # Create MP4 animation if mode is selected if 'mp4' in self.mode and self.animation_path and self.saved_images: with imageio.get_writer(self.animation_path+'.mp4', mode='I', fps=5, codec='libx264') as writer: for img_path in self.saved_images: image = imageio.imread(img_path) writer.append_data(image) # Cleanup temporary directory if it was used if self.temp_dir: shutil.rmtree(self.temp_dir)
[docs] class PlotLattice: def __init__(self, size_ratios: tuple[list[float], list[float]] = ([1.], [1.]), figsize=None) -> None: """ Initialize the PlotLattice class. Parameters: size_ratios (tuple[list[float], list[float]], optional): The size ratios for the rows and columns of the lattice. Defaults to ([1.], [1.]). figsize (tuple[float, float], optional): The size of the figure. Defaults to None. """ self._fig = None self._axs = None self._figsize = figsize self._size_ratios: tuple[list[float], list[float]] = size_ratios @property def size_ratios(self) -> tuple[list[float], list[float]]: """ Get the size ratios for the rows and columns of the lattice. Returns: tuple[list[float], list[float]]: The size ratios. """ return self._size_ratios @size_ratios.setter def size_ratios(self, value: tuple[list[float], list[float]]): """ Set the size ratios for the rows and columns of the lattice. Parameters: value (tuple[list[float], list[float]]): The size ratios. """ print(f'{value=}') if not len(value) == 2: raise ValueError('Size ratios must be a tuple of two lists') if not all(isinstance(ratio, (tuple, list)) for ratio in value): raise ValueError( 'Size ratios must be a tuple of two tuples/lists') if not len(value[0]) == len(self._size_ratios[0]) or not len(value[1]) == len(self._size_ratios[1]): raise ValueError( 'Size ratios must have the same length as the current size ratios') self._size_ratios = value
[docs] def get_figsize(self) -> tuple[float, float]: """ Get the figure size. If a size was set during initialization, it returns that size. Otherwise, it calculates the size based on the sum of size_ratios. Returns: tuple[float, float]: The size of the figure as (height, width). """ if self._figsize is not None: return self._figsize else: # Calculate size based on the sum of size ratios width = np.sum(self.size_ratios[1]) height = np.sum(self.size_ratios[0]) # Adjust the width and height to ensure the smallest dimension is higher than 3 if width < height: width = max(width, 4*len(self.size_ratios[1])) height = width * \ (np.sum(self.size_ratios[0])/np.sum(self.size_ratios[1])) else: height = max(height, 4*len(self.size_ratios[0])) width = height * \ (np.sum(self.size_ratios[1])/np.sum(self.size_ratios[0])) self._figsize = (width, height) return self._figsize
[docs] def set_figsize(self, value): """ Set the figure size. Parameters: value (tuple[float, float]): The size of the figure as (height, width). """ if not isinstance(value, (list, tuple)) or len(value) != 2: raise ValueError( "fig_size must be a list or tuple of two elements: [height, width]") self._figsize = value
@property def num_rows(self) -> int: """ Get the number of rows in the lattice. Returns: int: The number of rows. """ return len(self.size_ratios[0]) @property def num_cols(self) -> int: """ Get the number of columns in the lattice. Returns: int: The number of columns. """ return len(self.size_ratios[1])
[docs] def fig(self) -> Figure: """ Get the figure object. Returns: Figure: The figure object. """ if self._fig is None: self.create_fig_and_axs() return self._fig
[docs] def axs(self) -> list[list[Axes]]: """ Get the axes objects. Returns: list[list[Axes]]: The axes objects. """ if self._fig is None: self.create_fig_and_axs() return self._axs
[docs] def fig_axs(self) -> tuple[Figure, list[list[Axes]]]: """ Get the figure and axes objects. Returns: tuple[Figure, list[list[Axes]]]: The figure and axes objects. """ if self._fig is None: self.create_fig_and_axs() return self._fig, self._axs
[docs] def update_size_ratios(self, row: int, column: int): """ Update the size ratios based on the given row and column indices. Parameters: row (int): The row index. column (int): The column index. """ if row >= self.num_rows: self._size_ratios[0].extend([1.]*(row-self.num_rows+1)) if column >= self.num_cols: self._size_ratios[1].extend([1.]*(column-self.num_cols+1))
[docs] def convert_parameters_to_index(self, column: int, row: int) -> tuple[int, int]: """ Convert the column and row indices to the corresponding index in the axes list. Parameters: column (int): The column index. row (int): The row index. Returns: tuple[int, int]: The corresponding index in the axes list. """ self.update_size_ratios(row, column) # Fixed the argument order return (row, column)
[docs] def get_ax_by_index(self, index: tuple[int, int]) -> Axes: """ Get the axes object based on the given index. Parameters: index (tuple[int, int]): The index in the axes list. Returns: Axes: The axes object. """ self.update_size_ratios(index[0], index[1]) return self.axs()[index[0]][index[1]]
[docs] def create_fig_and_axs(self) -> tuple[Figure, Axes]: """ Create the figure and axes objects. Returns: tuple[Figure, Axes]: The figure and axes objects. """ size_ratios = self.size_ratios figsize = self.get_figsize() self._fig, self._axs = plt.subplots(self.num_rows, self.num_cols, figsize=figsize, gridspec_kw={ 'width_ratios': size_ratios[1], 'height_ratios': size_ratios[0]}) # Ensure axs is a 2D array for consistency if self.num_rows == 1 and self.num_cols == 1: self._axs = [[self._axs]] # Single plot elif self.num_rows == 1: self._axs = [self._axs] # Single row, multiple columns elif self.num_cols == 1: # Multiple rows, single column self._axs = [[ax] for ax in self._axs]
def __init__(self, interfaces: Interface | list[Interface] | None = None): """ Initialize the Plotter object with the given Interface instance. Parameters: ----------- interface : Interface Interface instance to be used for plotting. """ self._interfaces: list[Interface] | None = [interfaces] if isinstance( interfaces, Interface) else interfaces self.default_color_scheme: ColorScheme = ColorScheme() self.color_schemes: dict[Interface, ColorScheme] = {} self.plots: dict[tuple[int, int, Interface], list[Plot]] = {} self.plot_grid: Plotter.PlotLattice = self.PlotLattice() @property def number_of_interfaces(self) -> int: if self._interfaces is None: return 0 return len(self._interfaces)
[docs] def update_interface(self, new_interface): self._interfaces = new_interface self.color_scheme = ColorScheme(new_interface)
@property def is_initialized(self) -> bool: return self._interfaces is not None
[docs] @classmethod def plot_type(cls, plot_name): """ Decorator to register a plot method. Parameters: plot_name (str): Name of the plot type. instructions (dict): Provides the information about the plot type and function that is used for creating the plot. """ def decorator(plot_func): if plot_name in cls._plot_types: raise ValueError(f"Plot type {plot_name} is already defined.") cls._plot_types[plot_name] = plot_func return plot_func return decorator
[docs] def add_plot_for_interface(self, interface: Interface, plot_type: str, plot_arguments: dict[str, str | list[str] | dict[str, Any]], row: int = 0, col: int = 0): if plot_type not in self.available_plot_types: raise ValueError(f"Plot type '{plot_type}' not recognized. " f"Available methods: {self.available_plot_types}") is_available, comment = self._plot_types[plot_type].is_available( interface) if not is_available: raise ValueError(f"Plot type '{plot_type}' is not available for the given interface {interface}." f"Reason: {comment}") self.plot_grid.update_size_ratios(row, col) if 'color_scheme' not in plot_arguments: if interface not in self.color_schemes: self.color_schemes[interface] = self.default_color_scheme.new_interface( interface) # Add self.color_scheme to plot_arguments with the key 'color_scheme' plot_arguments['color_scheme'] = self.color_schemes[interface] plot_arguments['interface'] = interface plot = self._plot_types[plot_type](**plot_arguments) if (row, col, interface) in self.plots: self.plots[(row, col, interface)].append(plot) else: self.plots[(row, col, interface)] = [plot]
[docs] def number_of_groups(self) -> int: return max(len(interface.groups) for interface in self._interfaces)
[docs] def regroup(self, num_intervals: int, interval_size: int = 1, offset: int = 0): for interface in self._interfaces: interface.regroup(num_intervals, interval_size)
[docs] def add_plot(self, plot_type: str, plot_arguments: dict[str, str | list[str] | dict[str, Any]], row: int = 0, col: int = 0): """ Example: plotter.add_plot( "Scatter", { "parameters": ["Opinion", "Activity"], "scale": ["Linear", "Linear"], "color": { "mode": "Parameter Colormap", "settings": { "parameter": "Opinion density", "scale": ("Linear", ) }, }, }, row=2, col=0, ) """ for interface in self._interfaces: self.add_plot_for_interface( interface, plot_type, plot_arguments, row, col)
def _plot(self, fig, group_number: int) -> Figure: if self._interfaces is None: warnings.warn("Interface not initialized. Cannot plot.") return fig # Data fetching for static plots track_clusters_requests = [request for cell in self.plots.values( ) for plot in cell if cell for request in plot.get_track_clusterings_requests() if request] for interface in self._interfaces: interface.cluster_tracker.track_clusters(track_clusters_requests) # Determine the rendering order render_order = self._determine_plot_order() axis_limits = dict() for index in render_order: ax = self.plot_grid.get_ax_by_index(index) for plot in self.plots[index]: plot.plot(ax=ax, group_number=group_number, axis_limits=axis_limits) # Store axis limits for future reference axis_limits[index] = (ax.get_xlim(), ax.get_ylim()) if not self.plots[index]: ax.grid(False) ax.axis('off') ax.set_visible(False) return fig
[docs] def plot(self, group_number: int) -> tuple[Figure, Axes]: fig = self.plot_grid.fig() self._plot(fig, group_number) return fig
[docs] def plot_dynamics(self, mode: str | list[str] = 'show', save_dir: Optional[str] = None, animation_path: Optional[str] = None, name: str = 'opinion_histogram', preview: bool = False) -> None: """ Create and display the plots based on the stored configurations. mode : ["show", "save", "gif", "mp4"] """ if self._interfaces is None: raise ValueError("Interface not initialized. Cannot plot.") with Plotter.PlotSaver(mode=mode, save_path=save_dir, save_format=f"{name}_" + "{}.png", animation_path=animation_path) as saver: for group_number in range(self.number_of_groups()): fig = self.plot(group_number) if preview: break
def _determine_plot_order(self) -> list[tuple]: """ Determine the order in which plots should be rendered based on dependencies. Returns: list of tuples: Each tuple represents the row and column indices of a plot, sorted in the order they should be rendered. """ dependency_graph = nx.DiGraph() for index, plots in self.plots.items(): dependency_graph.add_node(index) if plots is not None: for plot in plots: dependencies = plot.plot_dependencies() for dependency in dependencies['after']: dependency_graph.add_edge(dependency, index) sorted_order = self._topological_sort(dependency_graph) return sorted_order
[docs] def clear_plot(self, row: int, col: int): ''' Clears axis on a given position in a grid ''' cell = self.plots.get((row, col)) if cell: cell.clear()
def _topological_sort(self, dependency_graph: nx.DiGraph): """ Perform a topological sort on the dependency graph. Args: dependency_graph (networkx.DiGraph): The dependency graph. Returns: list of tuples: Sorted order of plots. """ # Check for cycles if not nx.is_directed_acyclic_graph(dependency_graph): raise ValueError("A cycle was detected in the plot dependencies.") # Perform topological sort return list(nx.topological_sort(dependency_graph)) @property def node_parameters(self) -> list: """ Retrieves the list of available parameters/methods from the interface. Returns: list: A list of available parameters or methods. """ if self._interfaces is None: return None parameters = [ interface.node_parameters for interface in self._interfaces] # TODO figure out what does in return after the [interface] return parameters @property def existing_plot_types(self) -> list: """ Retrieves the list of available parameters/methods from the interface. Returns: list: A list of available parameters or methods. """ return list(self._plot_types.keys()) @property def available_plot_types(self) -> list: """ Retrieves the list of available parameters/methods from the interface. Returns: list: A list of available parameters or methods. """ return [method_name for method_name, method in self._plot_types.items() if all(method.is_available(interface)[0] for interface in self._interfaces)]
[docs] @classmethod def get_plot_class(cls, plot_type: str) -> type[Plot]: ''' Converts name of the class to the class''' return cls._plot_types[plot_type]
[docs] @classmethod def get_plot_name(cls, plot_class: type[Plot]) -> str: ''' Converts the class to its name''' for key, value in cls._plot_types.items(): if value == plot_class: return key warnings.warn(f'{plot_class} was not found in plot types') return ''
@property def available_plot_types_hint(self) -> str: """ Retrieves the list of available parameters/methods from the interface. Returns: str: A list of available parameters or methods. """ info_string = '' for method_name, method in self._plot_types.items(): is_available_final = True comments = [] for interface in self._interfaces: is_available, comment = method.is_available(interface) is_available_final &= is_available comments.append(comment) final_comment = ', '.join(comments) info_string += method_name + ': ' info_string += '+ ' if is_available else '- ' info_string += final_comment info_string += '\n' return info_string
[docs] @classmethod def create_plotter(cls, data) -> Plotter: """ Class method to create a Plotter instance given data. Parameters: ----------- data : Any Data used to create the Interface instance. Returns: -------- Plotter Initialized Plotter object. """ interface = Interface.create_interface(data) return cls(interface)
[docs] def info(self) -> str: ''' String with the information about the plotter setup ''' info_str = '' if len(self.plots) == 0: return 'Plot is empty' for (i, j, interface), cell in self.plots.items(): info_str += f'\n{i, j} {interface}\n' if len(cell) > 0: for plot in cell: info_str += str(plot) + '\n' else: info_str += 'Empty cell' return info_str
[docs] def to_code(self) -> str: ''' Generates a string that can be used to create the current plotter setup. Can be used in GUI as well. ''' if len(self.plots) == 0: return '' # TODO: Add the case when plots are different for different interfaces base_interface = self._interfaces[0] final_code = '' for (i, j, interface), cell in self.plots.items(): if interface == base_interface: if len(cell) > 0: for plot in cell: final_code += 'plotter.add_plot(' final_code += f'\'{Plotter.get_plot_name(type(plot))}\'' final_code += ',{' + plot.settings_to_code() + '},' final_code += f'row={i},col={j},)\n' return final_code