from __future__ import annotations
import os
import shutil
import tempfile
import warnings
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from .plot import Plot
import imageio
import matplotlib
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from .color_scheme import ColorScheme
from .interface import Interface
# Module-wide matplotlib configuration: set the automatic x/y axis margins to
# zero for every figure created after this module has been imported.
plt.rcParams.update({'axes.xmargin': 0, 'axes.ymargin': 0})
[docs]
class Plotter:
# Maps short parameter names (keys) to longer, mostly "Node ..."-prefixed
# labels (values).
# NOTE(review): nothing in this file reads the mapping; presumably the plot
# classes or a GUI translate user-facing names through it — confirm.
_parameter_dict = {'Time': 'Time',
'Opinion': 'Node Opinion',
'Cluster size': 'Cluster size',
'Importance': 'Node Importance',
'Label': 'Node Label',
'Neighbor mean opinion': 'Node Neighbor Mean Opinion',
'Activity': 'Node Activity',
'Inner opinions': 'Node Inner Opinions',
'Max opinion': 'Node Max Opinion',
'Min opinion': 'Node Min Opinion'}
# Registry of plot classes keyed by plot-type name. It is filled by the
# `plot_type` class-method decorator and read by `add_plot_for_interface`,
# `get_plot_class`, `get_plot_name` and the `*_plot_types` properties.
_plot_types: dict[str, type[Plot]] = {}
[docs]
class PlotSaver:
    """
    A utility class to handle the saving and display of plots.

    It can save individual figures, display them, and assemble the saved
    frames into a GIF and/or MP4 animation. It is meant to be used as a
    context manager: frames are collected by :meth:`save` and the
    animations are written when the ``with`` block is left.
    """

    def __init__(self, mode: str | list[str] = 'show',
                 save_path: Optional[str] = None,
                 save_format: Optional[str] = 'image_{}',
                 animation_path: Optional[str] = None) -> None:
        """
        Initialize the PlotSaver instance.

        Args:
            mode (str | list[str]): The mode(s) in which to operate.
                It can be a list or a single string, e.g. ['show', 'save'] or 'gif'.
                Available modes: ["show", "save", "gif", "mp4"]
            save_path (Optional[str]): Directory for individual plots (used if
                'save' is in mode). It is created when missing. May be None.
            save_format (Optional[str]): File-name template with one ``{}``
                placeholder that receives the frame number. None falls back
                to ``'image_{}'``.
            animation_path (Optional[str]): Output path WITHOUT extension for
                the animation ('.gif' / '.mp4' is appended).
        """
        # Ensure mode is a list even if a single mode string is provided
        self.mode = mode if isinstance(mode, list) else [mode]

        if save_path:
            if not os.path.exists(save_path):
                warnings.warn(
                    f"Path {save_path} does not exist. Creating it.")
                # Create the directory, including any necessary parent directories
                os.makedirs(save_path, exist_ok=True)
            # save() builds file names by plain concatenation, so the stored
            # directory must always end with a separator.
            if not save_path.endswith('/'):
                save_path = save_path + '/'
            self.save_path = save_path
        else:
            # No directory given (the default): nothing to create.
            self.save_path = None
            if 'save' in self.mode:
                warnings.warn(
                    "'save' mode requested without save_path; "
                    "figures will not be written to disk.")

        self.save_format = save_format if save_format is not None else 'image_{}'
        self.animation_path = animation_path
        self.saved_images = []
        self.temp_dir = None

    @staticmethod
    def is_inside_jupyter() -> bool:
        """
        Determine if the code runs inside an IPython-based environment.

        The check only looks for the ``get_ipython`` name that IPython
        injects into builtins, so it is also True in a plain IPython shell.

        Returns:
            bool: True if ``get_ipython`` is defined, False otherwise.
        """
        try:
            get_ipython  # noqa: F821 - only defined when running under IPython
        except NameError:
            return False
        return True

    def __enter__(self) -> Plotter.PlotSaver:
        """
        Entry point for the context manager.

        Returns:
            Plotter.PlotSaver: The current instance of the PlotSaver.
        """
        return self

    def save(self, fig: matplotlib.figure.Figure) -> None:
        """
        Save and/or display the provided figure based on the specified mode.

        The figure is closed afterwards.

        Args:
            fig (matplotlib.figure.Figure): The figure to be saved or displayed.
        """
        # Lay out the figure we were given, not whatever pyplot currently
        # considers the "current" figure.
        fig.tight_layout()

        frame_number = len(self.saved_images)
        wants_animation = 'gif' in self.mode or 'mp4' in self.mode

        if 'save' in self.mode and self.save_path:
            # Persistent frame in the user-supplied directory
            path = self.save_path + self.save_format.format(frame_number)
            fig.savefig(path)
            self.saved_images.append(path)
        elif wants_animation:
            # Frames are needed only for the animation: keep them in a
            # temporary directory that __exit__ removes again.
            if not self.temp_dir:
                self.temp_dir = tempfile.mkdtemp()
            temp_path = os.path.join(
                self.temp_dir, "tmp_plot_{}.png".format(frame_number))
            fig.savefig(temp_path)
            self.saved_images.append(temp_path)

        # Show the figure if 'show' mode is active
        if 'show' in self.mode:
            if self.is_inside_jupyter():
                # Imported here so that IPython is only required when it is
                # actually running.
                from IPython.display import display
                display(fig)
            else:
                # Outside Jupyter, use fig.show() to display the figure
                fig.show()

        # Close the figure after processing
        plt.close(fig)

    def _write_animation(self, path: str, **writer_kwargs) -> None:
        """Write all saved frames, in order, to ``path`` via imageio."""
        with imageio.get_writer(path, **writer_kwargs) as writer:
            for img_path in self.saved_images:
                writer.append_data(imageio.imread(img_path))

    def __exit__(self, exc_type: Optional[type], exc_val: Optional[Exception], exc_tb: Optional[object]) -> None:
        """
        Exit point for the context manager.

        Writes the GIF and/or MP4 animation when the corresponding mode is
        active, ``animation_path`` is set and at least one frame was saved.
        The temporary frame directory is removed even if writing fails.
        Exceptions raised inside the ``with`` block are not suppressed.

        Args:
            exc_type (Optional[type]): The exception type if raised inside the context.
            exc_val (Optional[Exception]): The exception instance if raised inside the context.
            exc_tb (Optional[object]): The traceback if an exception was raised inside the context.
        """
        try:
            if self.animation_path and self.saved_images:
                if 'gif' in self.mode:
                    self._write_animation(
                        self.animation_path + '.gif', mode='I')
                if 'mp4' in self.mode:
                    self._write_animation(
                        self.animation_path + '.mp4', mode='I', fps=5, codec='libx264')
        finally:
            # Cleanup temporary directory if it was used
            if self.temp_dir:
                shutil.rmtree(self.temp_dir, ignore_errors=True)
                self.temp_dir = None
[docs]
class PlotLattice:
    """
    Lazily created grid of matplotlib axes.

    The grid shape is defined by two lists of relative sizes (one entry per
    row and one per column). The figure and axes are only created on first
    access through :meth:`fig`, :meth:`axs` or :meth:`fig_axs`.
    """

    def __init__(self, size_ratios: tuple[list[float], list[float]] = ([1.], [1.]), figsize=None) -> None:
        """
        Initialize the PlotLattice class.

        Parameters:
            size_ratios (tuple[list[float], list[float]], optional): The size ratios for the rows and columns of the lattice. Defaults to ([1.], [1.]).
            figsize (tuple[float, float], optional): The size of the figure as (width, height). Defaults to None.
        """
        self._fig = None
        self._axs = None
        self._figsize = figsize
        # Store private copies: update_size_ratios() extends these lists in
        # place, so keeping the argument itself would leak grid growth into
        # the shared default value (and into the caller's lists). Copying
        # also turns tuples into lists so that they can be extended.
        self._size_ratios: tuple[list[float], list[float]] = (
            list(size_ratios[0]), list(size_ratios[1]))

    @property
    def size_ratios(self) -> tuple[list[float], list[float]]:
        """
        Get the size ratios for the rows and columns of the lattice.

        Returns:
            tuple[list[float], list[float]]: The size ratios as (rows, columns).
        """
        return self._size_ratios

    @size_ratios.setter
    def size_ratios(self, value: tuple[list[float], list[float]]):
        """
        Set the size ratios for the rows and columns of the lattice.

        Parameters:
            value (tuple[list[float], list[float]]): The size ratios.

        Raises:
            ValueError: If value is not a pair of tuples/lists, or if the
                number of rows/columns differs from the current lattice.
        """
        if not len(value) == 2:
            raise ValueError('Size ratios must be a tuple of two lists')
        if not all(isinstance(ratio, (tuple, list)) for ratio in value):
            raise ValueError(
                'Size ratios must be a tuple of two tuples/lists')
        if not len(value[0]) == len(self._size_ratios[0]) or not len(value[1]) == len(self._size_ratios[1]):
            raise ValueError(
                'Size ratios must have the same length as the current size ratios')
        # Keep extendable private copies (see __init__).
        self._size_ratios = (list(value[0]), list(value[1]))

    def get_figsize(self) -> tuple[float, float]:
        """
        Get the figure size. If a size was set explicitly, it returns that size.
        Otherwise, it calculates the size from the current size_ratios.

        The calculated size is not stored, so it follows later changes of
        the grid.

        Returns:
            tuple[float, float]: The size of the figure as (width, height).
        """
        if self._figsize is not None:
            return self._figsize

        row_ratios, col_ratios = self.size_ratios
        total_height = np.sum(row_ratios)
        total_width = np.sum(col_ratios)
        # Enlarge the smaller dimension to at least 4 units per row/column
        # and scale the other one so that the aspect ratio is preserved.
        if total_width < total_height:
            width = max(total_width, 4 * len(col_ratios))
            height = width * (total_height / total_width)
        else:
            height = max(total_height, 4 * len(row_ratios))
            width = height * (total_width / total_height)
        return (width, height)

    def set_figsize(self, value):
        """
        Set the figure size.

        Parameters:
            value (tuple[float, float]): The size of the figure as (width, height).

        Raises:
            ValueError: If value is not a list/tuple of two elements.
        """
        if not isinstance(value, (list, tuple)) or len(value) != 2:
            raise ValueError(
                "fig_size must be a list or tuple of two elements: [width, height]")
        self._figsize = value

    @property
    def num_rows(self) -> int:
        """
        Get the number of rows in the lattice.

        Returns:
            int: The number of rows.
        """
        return len(self.size_ratios[0])

    @property
    def num_cols(self) -> int:
        """
        Get the number of columns in the lattice.

        Returns:
            int: The number of columns.
        """
        return len(self.size_ratios[1])

    def fig(self) -> Figure:
        """
        Get the figure object, creating figure and axes on first use.

        Returns:
            Figure: The figure object.
        """
        if self._fig is None:
            self.create_fig_and_axs()
        return self._fig

    def axs(self) -> list[list[Axes]]:
        """
        Get the axes objects, creating figure and axes on first use.

        Returns:
            list[list[Axes]]: The axes objects, indexable as [row][column].
        """
        if self._fig is None:
            self.create_fig_and_axs()
        return self._axs

    def fig_axs(self) -> tuple[Figure, list[list[Axes]]]:
        """
        Get the figure and axes objects, creating them on first use.

        Returns:
            tuple[Figure, list[list[Axes]]]: The figure and axes objects.
        """
        if self._fig is None:
            self.create_fig_and_axs()
        return self._fig, self._axs

    def update_size_ratios(self, row: int, column: int):
        """
        Grow the lattice so that the given row and column indices exist.

        New rows/columns get a size ratio of 1. An already created figure
        is NOT rebuilt, so the grid should reach its final shape before the
        figure is first requested.

        Parameters:
            row (int): The row index.
            column (int): The column index.
        """
        if row >= self.num_rows:
            self._size_ratios[0].extend([1.] * (row - self.num_rows + 1))
        if column >= self.num_cols:
            self._size_ratios[1].extend([1.] * (column - self.num_cols + 1))

    def convert_parameters_to_index(self, column: int, row: int) -> tuple[int, int]:
        """
        Convert the column and row indices to the corresponding index in the axes list.

        The lattice is grown if the position lies outside the current grid.
        Note the argument order (column first) versus the returned order
        (row first).

        Parameters:
            column (int): The column index.
            row (int): The row index.

        Returns:
            tuple[int, int]: The corresponding (row, column) index in the axes list.
        """
        self.update_size_ratios(row, column)
        return (row, column)

    def get_ax_by_index(self, index: tuple[int, int]) -> Axes:
        """
        Get the axes object based on the given index.

        Only the first two entries of index are used, so longer tuples whose
        first two items are (row, column) are accepted as well.

        Parameters:
            index (tuple[int, int]): The (row, column) index in the axes list.

        Returns:
            Axes: The axes object.
        """
        row, column = index[0], index[1]
        self.update_size_ratios(row, column)
        return self.axs()[row][column]

    def create_fig_and_axs(self) -> tuple[Figure, list[list[Axes]]]:
        """
        Create the figure and axes objects and store them on the instance.

        Returns:
            tuple[Figure, list[list[Axes]]]: The figure and axes objects.
        """
        size_ratios = self.size_ratios
        figsize = self.get_figsize()
        self._fig, self._axs = plt.subplots(self.num_rows, self.num_cols, figsize=figsize, gridspec_kw={
            'width_ratios': size_ratios[1], 'height_ratios': size_ratios[0]})
        # plt.subplots squeezes singleton dimensions; restore a 2D
        # [row][column] layout so every caller can index uniformly.
        if self.num_rows == 1 and self.num_cols == 1:
            self._axs = [[self._axs]]  # Single plot
        elif self.num_rows == 1:
            self._axs = [self._axs]  # Single row, multiple columns
        elif self.num_cols == 1:
            # Multiple rows, single column
            self._axs = [[ax] for ax in self._axs]
        return self._fig, self._axs
def __init__(self, interfaces: Interface | list[Interface] | None = None):
    """
    Set up a plotter for zero, one or several interfaces.

    Parameters:
    -----------
    interfaces : Interface | list[Interface] | None
        A single Interface (stored as a one-element list), a list of
        interfaces (stored as given), or None for a plotter that is not
        initialized yet.
    """
    if isinstance(interfaces, Interface):
        normalized: list[Interface] | None = [interfaces]
    else:
        normalized = interfaces
    self._interfaces: list[Interface] | None = normalized

    # Colour handling: one shared default scheme plus per-interface schemes
    # that are filled in on demand by add_plot_for_interface().
    self.default_color_scheme: ColorScheme = ColorScheme()
    self.color_schemes: dict[Interface, ColorScheme] = {}

    # Registered plots, keyed by (row, column, interface).
    self.plots: dict[tuple[int, int, Interface], list[Plot]] = {}
    self.plot_grid: Plotter.PlotLattice = self.PlotLattice()
@property
def number_of_interfaces(self) -> int:
    """Number of attached interfaces; 0 when no interfaces have been set."""
    attached = self._interfaces
    return 0 if attached is None else len(attached)
[docs]
def update_interface(self, new_interface):
    """
    Replace the interface(s) used by the plotter.

    Parameters:
        new_interface (Interface | list[Interface] | None): The new
            interface(s). A single Interface is wrapped in a list.
    """
    # Mirror __init__: the rest of the class iterates over and takes len()
    # of self._interfaces, so a bare Interface must be wrapped in a list.
    if isinstance(new_interface, Interface):
        self._interfaces = [new_interface]
    else:
        self._interfaces = new_interface
    # NOTE(review): no other code in this file reads `self.color_scheme`
    # (plots use `default_color_scheme` / `color_schemes`), and __init__
    # calls ColorScheme() without arguments — confirm this is still needed.
    self.color_scheme = ColorScheme(new_interface)
@property
def is_initialized(self) -> bool:
    """True once interfaces have been set (an empty list still counts)."""
    has_no_interfaces = self._interfaces is None
    return not has_no_interfaces
[docs]
@classmethod
def plot_type(cls, plot_name):
    """
    Decorator factory that registers a plot class under a name.

    Parameters:
        plot_name (str): Name of the plot type; used as the key in the
            class-level registry.

    Returns:
        A decorator that stores the decorated object in the registry and
        returns it unchanged.

    Raises:
        ValueError: When applied, if plot_name is already registered.
    """
    def decorator(plot_func):
        registry = cls._plot_types
        if plot_name in registry:
            raise ValueError(f"Plot type {plot_name} is already defined.")
        registry[plot_name] = plot_func
        return plot_func

    return decorator
[docs]
def add_plot_for_interface(self, interface: Interface, plot_type: str, plot_arguments: dict[str, str | list[str] | dict[str, Any]], row: int = 0, col: int = 0):
    """
    Create a plot of the given type for one interface and register it in
    the grid cell (row, col).

    Parameters:
        interface (Interface): The interface the plot reads its data from.
        plot_type (str): Name of a registered plot type.
        plot_arguments (dict): Keyword arguments for the plot class. The
            dictionary is not modified; 'interface' and, unless supplied,
            'color_scheme' are added to a private copy.
        row (int): Row of the grid cell. Defaults to 0.
        col (int): Column of the grid cell. Defaults to 0.

    Raises:
        ValueError: If the plot type is unknown or reports that it is not
            available for the interface.
    """
    if plot_type not in self.available_plot_types:
        raise ValueError(f"Plot type '{plot_type}' not recognized. "
                         f"Available methods: {self.available_plot_types}")

    plot_class = self._plot_types[plot_type]
    is_available, comment = plot_class.is_available(interface)
    if not is_available:
        raise ValueError(f"Plot type '{plot_type}' is not available for the given interface {interface}."
                         f"Reason: {comment}")

    self.plot_grid.update_size_ratios(row, col)

    # Work on a shallow copy. add_plot() passes the SAME dictionary for
    # every interface; writing into it would make the colour scheme of the
    # first interface stick for all following ones and would leak the
    # injected keys back to the caller.
    arguments = dict(plot_arguments)
    if 'color_scheme' not in arguments:
        if interface not in self.color_schemes:
            self.color_schemes[interface] = self.default_color_scheme.new_interface(
                interface)
        arguments['color_scheme'] = self.color_schemes[interface]
    arguments['interface'] = interface

    plot = plot_class(**arguments)
    self.plots.setdefault((row, col, interface), []).append(plot)
[docs]
def number_of_groups(self) -> int:
    """
    Largest number of groups found in any attached interface.

    Returns:
        int: max(len(interface.groups)) over all interfaces, or 0 when
        there are no interfaces (None or an empty list).
    """
    # `or ()` covers an uninitialized plotter; `default=0` covers an empty
    # list, where a bare max() would raise ValueError.
    return max((len(interface.groups) for interface in self._interfaces or ()), default=0)
[docs]
def regroup(self, num_intervals: int, interval_size: int = 1, offset: int = 0):
    """
    Forward a regroup request to every attached interface.

    Parameters:
        num_intervals (int): Passed through to ``interface.regroup``.
        interval_size (int): Passed through to ``interface.regroup``.
        offset (int): Accepted but currently not forwarded.
            NOTE(review): confirm whether ``Interface.regroup`` takes an
            offset and this argument should be passed on.
    """
    attached = self._interfaces
    for current in attached:
        current.regroup(num_intervals, interval_size)
[docs]
def add_plot(self, plot_type: str, plot_arguments: dict[str, str | list[str] | dict[str, Any]], row: int = 0, col: int = 0):
    """
    Add the same plot to the grid cell (row, col) for every attached
    interface by delegating to ``add_plot_for_interface``.

    Example:
        plotter.add_plot(
            "Scatter",
            {
                "parameters": ["Opinion", "Activity"],
                "scale": ["Linear", "Linear"],
                "color": {
                    "mode": "Parameter Colormap",
                    "settings": {
                        "parameter": "Opinion density",
                        "scale": ("Linear", )
                    },
                },
            },
            row=2,
            col=0,
        )
    """
    targets = self._interfaces
    for target in targets:
        self.add_plot_for_interface(
            target, plot_type, plot_arguments, row, col)
# Draw every registered plot for one group onto the lattice axes and return
# `fig` (the figure itself is only passed through, all drawing goes through
# the axes obtained from `self.plot_grid`).
# NOTE(review): indentation was lost in this copy of the file, so the exact
# nesting of the axis-limit bookkeeping and the "hide empty axes" check
# inside the loops below must be confirmed against the original source.
def _plot(self, fig, group_number: int) -> Figure:
# Without interfaces there is nothing to draw: warn and return `fig` as is.
if self._interfaces is None:
warnings.warn("Interface not initialized. Cannot plot.")
return fig
# Data fetching for static plots
# Collect the truthy cluster-tracking requests of all plots in all non-empty
# cells and hand the combined list to every interface's cluster tracker.
track_clusters_requests = [request for cell in self.plots.values(
) for plot in cell if cell for request in plot.get_track_clusterings_requests() if request]
for interface in self._interfaces:
interface.cluster_tracker.track_clusters(track_clusters_requests)
# Determine the rendering order
render_order = self._determine_plot_order()
# Limits of the cells rendered so far; the same dict is passed to every
# plot.plot() call.
axis_limits = dict()
for index in render_order:
# `index` is a key of self.plots; get_ax_by_index uses its first two items
# as (row, column).
ax = self.plot_grid.get_ax_by_index(index)
for plot in self.plots[index]:
plot.plot(ax=ax, group_number=group_number,
axis_limits=axis_limits)
# Store axis limits for future reference
axis_limits[index] = (ax.get_xlim(), ax.get_ylim())
# A cell without plots is switched off completely.
if not self.plots[index]:
ax.grid(False)
ax.axis('off')
ax.set_visible(False)
return fig
[docs]
def plot(self, group_number: int) -> Figure:
    """
    Render all registered plots for one group.

    Parameters:
        group_number (int): Index of the group to render.

    Returns:
        Figure: The lattice figure (a single object, not a
        (figure, axes) pair).
    """
    fig = self.plot_grid.fig()
    self._plot(fig, group_number)
    return fig
[docs]
def plot_dynamics(self, mode: str | list[str] = 'show', save_dir: Optional[str] = None, animation_path: Optional[str] = None, name: str = 'opinion_histogram', preview: bool = False) -> None:
    """
    Render one figure per group and hand each to a PlotSaver.

    Parameters:
        mode (str | list[str]): Any of ["show", "save", "gif", "mp4"].
        save_dir (Optional[str]): Passed to the saver as ``save_path``.
        animation_path (Optional[str]): Passed to the saver unchanged.
        name (str): Prefix of the frame files; frames are named
            ``<name>_<number>.png``.
        preview (bool): If True, stop after the first group.

    Raises:
        ValueError: If no interfaces have been set.
    """
    if self._interfaces is None:
        raise ValueError("Interface not initialized. Cannot plot.")

    frame_pattern = f"{name}_" + "{}.png"
    saver_context = Plotter.PlotSaver(
        mode=mode, save_path=save_dir, save_format=frame_pattern, animation_path=animation_path)
    with saver_context as saver:
        for group_number in range(self.number_of_groups()):
            figure = self.plot(group_number)
            saver.save(figure)
            if preview:
                break
def _determine_plot_order(self) -> list[tuple]:
    """
    Work out the rendering order of the cells from the plot dependencies.

    Every key of ``self.plots`` becomes a node. For each plot, every entry
    of ``plot.plot_dependencies()['after']`` gets an edge pointing to the
    plot's cell, so it is sorted before that cell.

    Returns:
        list of tuples: The graph nodes in the order they should be rendered.
    """
    graph = nx.DiGraph()
    for cell_key, cell_plots in self.plots.items():
        graph.add_node(cell_key)
        if cell_plots is None:
            continue
        for plot in cell_plots:
            prerequisites = plot.plot_dependencies()['after']
            for prerequisite in prerequisites:
                graph.add_edge(prerequisite, cell_key)
    return self._topological_sort(graph)
[docs]
def clear_plot(self, row: int, col: int):
    '''
    Remove all plots registered at the given grid position, for every
    interface. The cells themselves stay in ``self.plots`` as empty lists.

    Parameters:
        row (int): Row of the grid cell.
        col (int): Column of the grid cell.
    '''
    # Cells are keyed by (row, col, interface), so a lookup with the bare
    # (row, col) pair never matches; compare the leading two items instead.
    for key, cell in self.plots.items():
        if tuple(key[:2]) == (row, col) and cell:
            cell.clear()
def _topological_sort(self, dependency_graph: nx.DiGraph):
    """
    Return the nodes of the dependency graph in topological order.

    Args:
        dependency_graph (networkx.DiGraph): The dependency graph.

    Returns:
        list of tuples: Sorted order of plots.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    is_acyclic = nx.is_directed_acyclic_graph(dependency_graph)
    if is_acyclic:
        ordered_nodes = nx.topological_sort(dependency_graph)
        return list(ordered_nodes)
    raise ValueError("A cycle was detected in the plot dependencies.")
@property
def node_parameters(self) -> list:
    """
    Collect the ``node_parameters`` attribute of every attached interface.

    Returns:
        list: One entry per interface, in interface order, or None when no
        interfaces have been set.
    """
    attached = self._interfaces
    if attached is None:
        return None
    # TODO: decide what should be returned per interface (currently the raw
    # `node_parameters` value of each one).
    collected = []
    for entry in attached:
        collected.append(entry.node_parameters)
    return collected
@property
def existing_plot_types(self) -> list:
    """
    Names of all registered plot types, whether or not they are available
    for the current interfaces.

    Returns:
        list: The registry keys in registration order.
    """
    return [*self._plot_types]
@property
def available_plot_types(self) -> list:
    """
    Names of the registered plot types that report themselves as available
    for every attached interface.

    Returns:
        list: The matching plot-type names in registration order.
    """
    usable = []
    for method_name, method in self._plot_types.items():
        if all(method.is_available(interface)[0] for interface in self._interfaces):
            usable.append(method_name)
    return usable
[docs]
@classmethod
def get_plot_class(cls, plot_type: str) -> type[Plot]:
    '''Look up the registered plot class for a plot-type name (KeyError if unknown).'''
    registry = cls._plot_types
    return registry[plot_type]
[docs]
@classmethod
def get_plot_name(cls, plot_class: type[Plot]) -> str:
    '''
    Reverse lookup: find the name under which a plot class is registered.

    Returns the first matching name, or an empty string (after issuing a
    warning) when the class is not registered.
    '''
    not_found = object()
    registered_name = next(
        (name for name, candidate in cls._plot_types.items()
         if candidate == plot_class),
        not_found,
    )
    if registered_name is not_found:
        warnings.warn(f'{plot_class} was not found in plot types')
        return ''
    return registered_name
@property
def available_plot_types_hint(self) -> str:
    """
    Human-readable availability overview of all registered plot types.

    Returns:
        str: One line per plot type of the form
        ``"<name>: <marker><comments>\\n"``, where the marker is ``"+ "``
        when the type is available for every attached interface and
        ``"- "`` otherwise, and the comments of all interfaces are joined
        with ``", "``.
    """
    lines = []
    for method_name, method in self._plot_types.items():
        results = [method.is_available(interface)
                   for interface in self._interfaces]
        # The marker must reflect ALL interfaces, not just the result for
        # the last one that was checked.
        available_everywhere = all(is_available for is_available, _ in results)
        marker = '+ ' if available_everywhere else '- '
        final_comment = ', '.join(comment for _, comment in results)
        lines.append(method_name + ': ' + marker + final_comment + '\n')
    return ''.join(lines)
[docs]
@classmethod
def create_plotter(cls, data) -> Plotter:
    """
    Alternate constructor: build a plotter directly from data.

    Parameters:
    -----------
    data : Any
        Passed to ``Interface.create_interface`` to build the interface.

    Returns:
    --------
    Plotter
        A new instance of ``cls`` initialized with that interface.
    """
    return cls(Interface.create_interface(data))
[docs]
def info(self) -> str:
    '''
    Describe the current plotter setup as text.

    Returns 'Plot is empty' when nothing is registered. Otherwise each cell
    contributes a header line with its (row, column) position and
    interface, followed by one line per plot or the text 'Empty cell'.
    '''
    if not self.plots:
        return 'Plot is empty'

    pieces = []
    for (i, j, interface), cell in self.plots.items():
        pieces.append(f'\n{i, j} {interface}\n')
        if len(cell) > 0:
            pieces.extend(str(plot) + '\n' for plot in cell)
        else:
            pieces.append('Empty cell')
    return ''.join(pieces)
[docs]
def to_code(self) -> str:
    '''
    Generate source code that recreates the current plotter setup; can be
    used in a GUI as well.

    Only the plots registered for the first interface are emitted, one
    ``plotter.add_plot(...)`` line per plot. Returns an empty string when
    nothing is registered.
    '''
    if not self.plots:
        return ''

    # TODO: Add the case when plots are different for different interfaces
    base_interface = self._interfaces[0]
    statements = []
    for (i, j, interface), cell in self.plots.items():
        if not interface == base_interface:
            continue
        for plot in cell:
            plot_name = Plotter.get_plot_name(type(plot))
            settings = plot.settings_to_code()
            statements.append(
                'plotter.add_plot(' + f'\'{plot_name}\''
                + ',{' + settings + '},'
                + f'row={i},col={j},)\n')
    return ''.join(statements)