Source code for hari_plotter.interface

from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from typing import Any, Optional, Type, Union

import networkx as nx
import numpy as np

from .cluster import Clustering
from .dynamics import Dynamics
from .graph import Graph
from .group import Group
from .simulation import Simulation


class Interface(ABC):
    """Abstract base class to define interface behaviors.

    Attributes:
        REQUIRED_TYPE: Expected type for the data attribute.
        available_classes: dictionary mapping REQUIRED_TYPEs to their
            corresponding classes.
    """

    # Subclasses set this to the data type they wrap; it is used both for
    # validation in __init__ and for auto-registration (see __init_subclass__).
    REQUIRED_TYPE: Optional[Type[Any]] = None
    available_classes: dict[Type[Any], Type['Interface']] = {}

    def __init__(self, data: Any, group_length: int = 0):
        """
        Initialize the Interface instance.

        Args:
            data: The underlying data object to which the interface applies.
            group_length (int): number of group slots to pre-allocate in the
                lazy group cache.

        Raises:
            ValueError: If the data is not an instance of REQUIRED_TYPE.
        """
        # Check REQUIRED_TYPE for None explicitly so the abstract base raises
        # a clear ValueError instead of isinstance() raising a TypeError.
        if self.REQUIRED_TYPE is None or not isinstance(data, self.REQUIRED_TYPE):
            raise ValueError(
                f"data must be an instance of {self.REQUIRED_TYPE}")
        self.data = data
        # One lazily-filled slot per group; see GroupIterable._get_group.
        self._group_cache = [None] * group_length
        self.groups = self.GroupIterable(self)
        self.static_data_cache = self.StaticDataCache(self)
        self.dynamic_data_cache = [self.DynamicDataCache(self, i)
                                   for i in range(len(self.groups))]
        self.cluster_tracker = self.ClusterTracker(self)
        self._nodes = None
        self._node_parameters = None

    @property
    def node_parameters(self):
        """Node parameters of the first group, fetched lazily and cached."""
        # Compare against None so an empty (falsy) value is still cached
        # instead of being recomputed on every access.
        if self._node_parameters is None:
            self._node_parameters = self.groups[0].node_parameters
        return self._node_parameters

    @property
    def nodes(self):
        """Union of the node sets of all groups, computed once and cached."""
        if self._nodes is None:
            self._nodes = set().union(*[group.nodes for group in self.groups])
        return self._nodes
[docs] class StaticDataCache: def __init__(self, interface_instance): self._interface = interface_instance self.cache = {} def __getitem__(self, request): method_name = request['method'] settings = request['settings'] data_key = Interface.request_to_tuple(request) if data_key not in self.cache: data_for_all_groups = [] for group in self._interface.groups: data = getattr(group, method_name)(**settings) data_for_all_groups.append(data) self.cache[data_key] = data_for_all_groups return self.cache[data_key]
[docs] def clean(self): self.cache = {}
[docs] class DynamicDataCache: def __init__(self, interface_instance, i: int): self._interface = interface_instance self.i = i self.cache = {} def __getitem__(self, request): method_name = request['method'] settings = request['settings'] data_key = Interface.request_to_tuple(request) if data_key not in self.cache: self.cache[data_key] = getattr( self._interface.groups[self.i], method_name)(**settings) return self.cache[data_key]
[docs] def clean(self): self.cache = {}
[docs] def clean_cache(self): self.cluster_tracker.clean() self.static_data_cache.clean() for c in self.dynamic_data_cache: c.clean()
[docs] class GroupIterable: def __init__(self, interface_instance: Interface): self._interface: Interface = interface_instance def __getitem__(self, index): # Handle slices if isinstance(index, slice): return [self._get_group(i) for i in range(*index.indices(len(self._interface._group_cache)))] # Handle lists of integers elif isinstance(index, list): # Ensure all elements in the list are integers if not all(isinstance(i, int) for i in index): raise TypeError("All indices must be integers") return [self._get_group(i) for i in index] # Handle single integer elif isinstance(index, int): return self._get_group(index) else: raise TypeError( "Invalid index type. Must be an integer, slice, or list of integers.") def _get_group(self, i) -> Group: # Adjust index for negative values if i < 0: i += len(self._interface._group_cache) if 0 <= i < len(self._interface._group_cache): if self._interface._group_cache[i] is None: self._interface._group_cache[i] = self._interface._initialize_group( i) return self._interface._group_cache[i] else: raise IndexError("Group index out of range") def __iter__(self): for i in range(len(self._interface._group_cache)): yield self[i] def __len__(self) -> int: return len(self._interface._group_cache)
@property @abstractmethod def time_range(self) -> list[float]: raise NotImplementedError( "This method must be implemented in subclasses") @classmethod def __init_subclass__(cls, **kwargs): """Auto-register subclasses in available_classes based on their REQUIRED_TYPE.""" super().__init_subclass__(**kwargs) if cls.REQUIRED_TYPE: Interface.available_classes[cls.REQUIRED_TYPE] = cls
[docs] @classmethod def create_interface(cls, data: Any) -> Interface: """Create an interface for the given data. Args: data: The underlying data object to which the interface applies. Returns: Instance of a subclass of Interface based on data's type. Raises: ValueError: If no matching interface is found for the data type. """ interface_class = cls.available_classes.get(type(data)) if interface_class: return interface_class(data) else: raise ValueError("Invalid data type, no matching interface found.")
@abstractmethod def _initialize_group(self, i: int) -> Group: """Abstract method to define the behavior for data grouping.""" raise NotImplementedError( "This method must be implemented in subclasses") @abstractmethod def _regroup_dynamics(self, num_intervals: int, interval_size: int = 1): """Abstract method to regroup dynamics.""" raise NotImplementedError( "This method must be implemented in subclasses")
[docs] def regroup(self, num_intervals: int, interval_size: int = 1, offset: int = 0): self.clean_cache() self._group_cache = [None]*num_intervals self._regroup_dynamics(num_intervals, interval_size, offset)
[docs] @classmethod def info(cls) -> str: """Return a string representation of the available classes and their mapping. Returns: A string detailing the available classes and their mappings. """ mappings = ', '.join( [f"{key.__name__} -> {value.__name__}" for key, value in cls.available_classes.items()]) return f"Available Classes: {mappings}"
@property @abstractmethod def available_parameters(self) -> list: raise NotImplementedError( "This method must be implemented in subclasses")
[docs] class ClusterTracker: def __init__(self, interface): # Initialize with an instance of the enclosing class if needed to access its attributes and methods self._interface: Interface = interface self._track_clusters_cache = {}
[docs] def clean(self): self._track_clusters_cache = {}
[docs] def get_clustering(self, clusterization_settings: Union[dict, list[dict]]) -> list[list[Clustering]]: """ Retrieves the clusterization for the given settings. :param clusterization_settings: The settings for clusterization, either a single dictionary for all frames or a list of dictionaries for each frame. :return: A list of Clustering objects representing the clusterization for each frame. """ clusterization_settings_list = [clusterization_settings] if isinstance( clusterization_settings, dict) else clusterization_settings clusterizations = [] for clust_settings in clusterization_settings_list: clusterizations.append([group.clustering( **clust_settings) for group in self._interface.groups]) return clusterizations
[docs] def is_tracked(self, clusterization_settings: Union[dict, list[dict]]) -> list[bool]: """ Tracks the clusters across frames based on the provided clusterization settings. :param clusterization_settings: The settings for clusterization, either a single dictionary for all frames or a list of dictionaries for each frame. :return: A list of dictionaries where each dictionary maps frame indices to lists of cluster labels. """ clusterization_settings_list = [clusterization_settings] if isinstance( clusterization_settings, dict) else clusterization_settings is_tracked_list = [] for clust_settings in clusterization_settings_list: # Extract the dynamics of clusters across frames cluster_tuple = Interface.request_to_tuple(clust_settings) is_tracked_list.append( cluster_tuple in self._track_clusters_cache) return is_tracked_list
[docs] @staticmethod def generate_node_id(frame_index, cluster_identifier): return (frame_index, cluster_identifier)
[docs] def cluster_graph(self, cluster_settings: dict[str, Any], clusters_dynamics: list[list[Clustering]] = None) -> nx.DiGraph: """ Constructs a graph representing the cluster dynamics, optionally using provided cluster dynamics. :param cluster_settings: Settings for clusterization. :param clusters_dynamics: A list of lists where each sub-list represents the clusters in a frame. :return: A directed graph where nodes represent clusters and edges represent the temporal evolution of clusters. """ clusters_dynamics = clusters_dynamics or [group.clustering( **cluster_settings).labels_nodes_dict() for group in self._interface.groups] G = nx.DiGraph() # Build the graph with connections based on the highest overlap for frame_index, frame in enumerate(clusters_dynamics): for cluster_name in frame.keys(): node_id = self.generate_node_id(frame_index, cluster_name) G.add_node(node_id, frame=frame_index, cluster=cluster_name, Label=cluster_name) if frame_index > 0: # There are previous frame clusters to connect to current_frame_clusters = frame prev_frame_clusters = clusters_dynamics[frame_index - 1] is_more_clusters_in_current_frame = len( current_frame_clusters) >= len(prev_frame_clusters) # Determine comparison direction based on the number of clusters if is_more_clusters_in_current_frame: # More clusters in the current frame, or the number is equal, match each current cluster to one from the previous frame for current_cluster_name, current_cluster in current_frame_clusters.items(): current_node_id = self.generate_node_id( frame_index, current_cluster_name) current_set = set( map(tuple, current_cluster)) best_match = None max_overlap = 0 for prev_cluster_name, prev_cluster in prev_frame_clusters.items(): prev_set = set(map(tuple, prev_cluster)) prev_node_id = self.generate_node_id( frame_index - 1, prev_cluster_name) overlap = len( current_set.intersection(prev_set)) if overlap > max_overlap: best_match = prev_node_id max_overlap = overlap if best_match: 
G.add_edge(best_match, current_node_id) else: # More clusters in the previous frame, match each previous cluster to one from the current frame for prev_cluster_name, prev_cluster in prev_frame_clusters.items(): prev_set = set(map(tuple, prev_cluster)) prev_node_id = self.generate_node_id( frame_index - 1, prev_cluster_name) best_match = None max_overlap = 0 for current_cluster_name, current_cluster in current_frame_clusters.items(): current_node_id = self.generate_node_id( frame_index, current_cluster_name) current_set = set(map(tuple, current_cluster)) overlap = len( prev_set.intersection(current_set)) if overlap > max_overlap: best_match = current_node_id max_overlap = overlap if best_match: G.add_edge(prev_node_id, best_match) def propagate_labels_with_splits(G: nx.DiGraph): # Find all nodes with no predecessors (root nodes) root_nodes = [node for node in G.nodes() if len( list(G.successors(node))) == 0] # Check for root nodes with duplicate names and issue a warning if found root_names = [G.nodes[node]['cluster'] for node in root_nodes] if len(root_names) != len(set(root_names)): warnings.warn( f"Warning: Multiple root nodes have the same name: {root_names}") def propagate_label(node, current_label): """Recursively propagate labels through the graph.""" G.nodes[node]['Updated label'] = current_label predecessors = list(G.predecessors(node)) if len(predecessors) == 1: # Single successor inherits the label propagate_label(predecessors[0], current_label) elif len(predecessors) > 1: # Multiple predecessors indicate a split; assign new labels for i, succ in enumerate(predecessors): new_label = f"{current_label}.{i+1}" propagate_label(succ, new_label) # Start label propagation from each root node for root_node in root_nodes: initial_label = G.nodes[root_node]['Label'] propagate_label(root_node, initial_label) # Call the function to propagate labels with the handling of splits propagate_labels_with_splits(G) return G
[docs] def track_clusters(self, clusterization_settings: Union[dict, list[dict]]) -> list[dict[int, dict[str, str]]]: """ Tracks the clusters across frames based on the provided clusterization settings. :param clusterization_settings: The settings for clusterization, either a single dictionary for all frames or a list of dictionaries for each frame. :return: A list of dictionaries where each dictionary maps frame indices to lists of cluster labels. """ clusterization_settings_list = [clusterization_settings] if isinstance( clusterization_settings, dict) else clusterization_settings updated_labels_list = [] for clust_settings in clusterization_settings_list: # Extract the dynamics of clusters across frames cluster_tuple = Interface.request_to_tuple(clust_settings) if cluster_tuple not in self._track_clusters_cache: clusters_dynamics = [group.clustering( **clust_settings).labels_nodes_dict() for group in self._interface.groups] G = self.cluster_graph( clust_settings, clusters_dynamics=clusters_dynamics) # Extract updated labels updated_labels = {} for frame_index in range(len(clusters_dynamics)): # Initialize the dictionary for this frame's label mappings frame_labels = {} # Find all nodes belonging to the current frame nodes_in_frame = [node for node, attrs in G.nodes( data=True) if attrs.get('frame') == frame_index] # Update the frame's label mappings based on the nodes' current and updated labels for node_id in nodes_in_frame: original_label = G.nodes[node_id]['Label'] # Fallback to original if no update updated_label = G.nodes[node_id].get( 'Updated label', original_label) frame_labels[original_label] = updated_label # Store the updated label mappings for the current frame updated_labels[frame_index] = frame_labels # Apply updated labels for frame_index, labels in updated_labels.items(): group: Group = self._interface.groups[frame_index] group.clustering( **clust_settings).cluster_labels = list(labels.values()) if 
group.is_clustering_graph_initialized(**clust_settings): warnings.warn( 'Clustering tracked after clustering graph was initialized. This is unexpected behavior, trying to handle.') group.clustering_graph( reinitialize=True, **clust_settings) self._track_clusters_cache[cluster_tuple] = updated_labels updated_labels_list.append( self._track_clusters_cache[cluster_tuple]) return updated_labels_list
[docs] def get_unique_clusters(self, clusterization_settings: Union[dict, list[dict]]) -> list[list[str]]: """ Retrieves a list of lists, each containing unique clusters based on the given clusterization settings. Each inner list corresponds to one clustering type if multiple types are provided. :param clusterization_settings: The settings used for clusterization, either a single dictionary or a list of dictionaries for multiple types. :return: A list of lists, each containing unique cluster names for one type of clustering. """ tracked_clusters_list = self.track_clusters( clusterization_settings) unique_clusters_list = [] for tracked_clusters in tracked_clusters_list: unique_clusters = set() for clusters in tracked_clusters.values(): unique_clusters.update(clusters.values()) unique_clusters_list.append(list(unique_clusters)) return unique_clusters_list
[docs] def get_cluster_presence(self, clusterization_settings: Union[dict, list[dict]]) -> list[dict[str, list[int]]]: """ Retrieves a list of dictionaries, each mapping unique clusters to the frames in which they appear, based on the given clusterization settings. Each dictionary corresponds to one clustering type if multiple types are provided. :param clusterization_settings: The settings used for clusterization, either a single dictionary or a list of dictionaries for multiple types. :return: A list of dictionaries, each mapping unique cluster names to frame presence for one type of clustering. :example: Output .. code-block:: python [{'Cluster 0': [0, 1, 2], 'Cluster 1': [0, 1, 2], 'Cluster 2': [0, 1]}] means that Cluster 0, Cluster 1, and Cluster 2 are present in frame 0 and 1, and only Cluster 0 and 1 are present in frame 2. """ tracked_clusters_list = self.track_clusters( clusterization_settings) cluster_presence_list = [] for tracked_clusters in tracked_clusters_list: cluster_presence = {} for frame_index, clusters in tracked_clusters.items(): for cluster in clusters.values(): if cluster not in cluster_presence: cluster_presence[cluster] = [] cluster_presence[cluster].append(frame_index) cluster_presence_list.append(cluster_presence) return cluster_presence_list
[docs] def get_final_value(self, clusterization_settings: Union[dict, list[dict]], parameter) -> dict[str, float]: ''' Returns the value of the parameter in the last group cluster appeared ''' presence = self.get_cluster_presence(clusterization_settings)[0] final_image = {cluster: images[-1] for cluster, images in presence.items()} final_values = {} for cluster, group_number in final_image.items(): data = self._interface.groups[group_number].clustering_graph_values( parameters=(parameter, 'Label'), clustering_settings=clusterization_settings) final_values[cluster] = data[parameter][data['Label'].index( cluster)] return final_values
@abstractmethod def __len__(self): pass
[docs] @staticmethod def request_to_tuple(request): def convert(item): if isinstance(item, dict): return tuple(sorted((k, convert(v)) for k, v in item.items())) elif isinstance(item, list): return tuple(convert(v) for v in item) else: return item return convert(request)
[docs] def group_time_range(self) -> list[float]: return [self.groups[0].mean_time(), self.groups[-1].mean_time()]
def __repr__(self) -> str: return (f"<{self.__class__.__name__}(" f"REQUIRED_TYPE={self.REQUIRED_TYPE.__name__}, " f"data={repr(self.data)}, " f"group_length={len(self.groups)}, " f"time_range={self.time_range})>") def __str__(self) -> str: return f"Interface(type={self.REQUIRED_TYPE.__name__}, data={self.data}, groups={len(self.groups)}, time_range={self.time_range})"
class HariGraphInterface(Interface):
    """Interface specifically designed for the HariGraph class."""
    REQUIRED_TYPE = Graph

    def __init__(self, data):
        super().__init__(data=data, group_length=1)
        self.data: Graph
        self._group_size = 1

    def __len__(self):
        return 1

    def _initialize_group(self, i: int = 0) -> Group:
        """Build the single group wrapping the static graph.

        A lone image has no dynamics, so requesting any group other than the
        first (or a group size other than 1) is suspicious and warned about.
        """
        if i != 0 or self._group_size != 1:
            # Fixed: `Warning.warn(...)` raised AttributeError at runtime
            # (Warning is an exception class); use the warnings module.
            warnings.warn(
                'An attempt to create the dynamics from a single image')
        return Group([self.data] * self._group_size, time=np.array([0]))

    def _regroup_dynamics(self, num_intervals: int, interval_size: int = 1, offset: int = 0):
        """Record the requested group size; a static graph cannot truly regroup."""
        # NOTE(review): this warns only when BOTH values differ from 1 —
        # `or` may have been intended; behavior deliberately kept as-is.
        if num_intervals != 1 and interval_size != 1:
            warnings.warn(
                'An attempt to create the dynamics from a single image')
        self._group_size = interval_size

    @property
    def available_parameters(self) -> list:
        """
        Retrieves the list of available parameters/methods from the data gatherer.

        Returns:
            list: A list of available parameters or methods.
        """
        return self.data.gatherer.node_parameters

    @property
    def time_range(self) -> list[float]:
        """A static image spans no time."""
        return [0., 0.]
class HariDynamicsInterface(Interface):
    """Interface specifically designed for the HariDynamics class."""
    REQUIRED_TYPE = Dynamics

    def __init__(self, data):
        super().__init__(data=data, group_length=len(data.groups))
        self.data: Dynamics

    def __len__(self):
        return len(self.data.groups)

    def _initialize_group(self, i: int) -> Group:
        """Wrap the i-th index group of the dynamics; times are the raw indices."""
        indices = self.data.groups[i]
        images = [self.data[index] for index in indices]
        return Group(images, time=np.array(indices))

    def _regroup_dynamics(self, num_intervals: int, interval_size: int = 1, offset: int = 0):
        """Delegate regrouping to the underlying dynamics object."""
        self.data.group(num_intervals, interval_size, offset)

    @property
    def available_parameters(self) -> list:
        """
        Retrieves the list of available parameters/methods from the data gatherer.

        Returns:
            list: A list of available parameters or methods.
        """
        return self.data[0].gatherer.node_parameters

    @property
    def time_range(self) -> list[float]:
        """Index-based time span: frame 0 through the last frame."""
        return [0., float(len(self.data) - 1)]
class SimulationInterface(Interface):
    """Interface specifically designed for the Simulation class."""
    REQUIRED_TYPE = Simulation

    def __init__(self, data):
        super().__init__(data=data, group_length=len(data.dynamics.groups))
        self.data: Simulation

    def __len__(self):
        return len(self.data.dynamics.groups)

    def _initialize_group(self, i: int) -> Group:
        """Wrap the i-th index group; times are indices scaled by the step dt."""
        indices = self.data.dynamics.groups[i]
        images = [self.data.dynamics[index] for index in indices]
        return Group(images, time=np.array(indices) * self.data.dt)

    def _regroup_dynamics(self, num_intervals: int, interval_size: int = 1, offset: int = 0):
        """Delegate regrouping to the simulation's dynamics object."""
        self.data.dynamics.group(num_intervals, interval_size, offset)

    @property
    def available_parameters(self) -> list:
        """
        Retrieves the list of available parameters/methods from the data gatherer.

        Returns:
            list: A list of available parameters or methods.
        """
        return self.data.dynamics[0].gatherer.node_parameters

    @property
    def time_range(self) -> list[float]:
        """Physical time span: frame indices scaled by dt."""
        return [0., float(len(self.data) - 1) * self.data.dt]