# Source code for hari_plotter.group

from __future__ import annotations

import numpy as np

from .cluster import Clustering
from .graph import Graph


class Group:
    """
    Represents a group of graph images and caches values derived from them.

    The group exposes a lazily computed mean graph, clusterings of that mean
    graph (cached per settings), and helpers that gather node values from the
    mean graph or from a cluster-merged copy of it.

    Attributes:
        images (list[Graph]): The images that make up the group.
        time (Optional[Sequence[float]]): One time value per image, or None
            when no time information was supplied.
        model (Optional[Any]): An optional model stored alongside the group;
            this class never inspects it.
        clusterings (dict): Cache mapping a hashable settings key (see
            `request_to_tuple`) to a dict holding the 'clustering' object and,
            once built, the merged 'graph'.
        _mean_graph (Optional[Graph]): Cached mean graph, built on demand.
        _nodes (Optional[set]): Cached set of the mean graph's nodes.
        _node_parameters: Cached `node_parameters` of the first image.

    Common Functions:
        `common_functions` maps display names (e.g. 'Mean', 'Sum') to the
        numpy reduction that implements them.
    """

    # Dictionary mapping function names to actual function objects for common
    # statistical operations.
    common_functions = {
        'Mean': np.mean,
        'Sum': np.sum,
        'Min': np.min,
        'Max': np.max,
        'Median': np.median,
        'Standard Deviation': np.std,
        'Variance': np.var,
        'Peak to Peak': np.ptp,  # Peak to Peak (Max - Min)
        # Additional functions can be added here as needed
    }

    def __init__(self, images: list[Graph], time=None, model=None):
        """
        Initializes a Group with a list of images, optional time values, and
        an optional model.

        Parameters:
            images (list[Graph]): The images to include in the group. The list
                is stored as-is (not copied).
            time (Optional[Sequence[float] | float]): Either a sized sequence
                with one entry per image, a single number (Python or numpy
                scalar) that is repeated for every image, or None.
            model (Optional[Any]): An optional model to associate with the
                group.

        Raises:
            ValueError: If `time` is a sequence whose length differs from the
                number of images, or if it is of an unsupported type.
        """
        self.images: list[Graph] = images

        if time is None:
            self.time = None
        elif isinstance(time, (int, float, np.integer, np.floating)):
            # A single scalar applies to every image.
            self.time = [time] * len(images)
        elif hasattr(time, '__len__'):
            # Explicit raise instead of `assert`: asserts vanish under `-O`.
            if len(time) != len(images):
                raise ValueError("Length of time and images must be the same")
            self.time = time
        else:
            raise ValueError("Invalid input for time")

        self.model = model
        self._mean_graph = None
        self.clusterings = dict()
        self._nodes = None
        self._node_parameters = None

    def _invalidate_caches(self) -> None:
        """Drops every value derived from `images` so it is rebuilt on demand."""
        self._mean_graph = None
        self._nodes = None
        self._node_parameters = None
        self.clusterings.clear()

    @property
    def node_parameters(self):
        """
        Returns the `node_parameters` of the first image, cached after the
        first access.
        """
        # `is None` rather than truthiness, so an empty-but-valid value is
        # cached like any other.
        if self._node_parameters is None:
            self._node_parameters = self.images[0].node_parameters
        return self._node_parameters

    @staticmethod
    def request_to_tuple(request):
        """
        Converts a request made of nested dicts and lists/tuples into a
        nested tuple that can be used as a dictionary key.

        Dict items are sorted so that two dicts with the same content always
        yield the same key; lists and tuples keep their order. Any other
        value is returned unchanged and must itself be hashable for the
        result to be hashable.

        Parameters:
            request (Union[dict, list, tuple, Any]): The request to convert.

        Returns:
            The converted request (a nested tuple for dict/list/tuple input).
        """
        def convert(item):
            if isinstance(item, dict):
                return tuple(sorted((k, convert(v)) for k, v in item.items()))
            if isinstance(item, (list, tuple)):
                # Tuples are walked too: they may hide unhashable lists/dicts.
                return tuple(convert(v) for v in item)
            return item

        return convert(request)

    @property
    def mean_graph(self) -> Graph:
        """
        Lazily computes or retrieves the cached mean graph of the group.

        Returns:
            Graph: The mean graph of the group.
        """
        if self._mean_graph is None:
            self.initialize_mean_graph()
        return self._mean_graph

    def get_mean_graph(self, **settings) -> Graph:
        """Returns the mean graph; `settings` are accepted but ignored."""
        return self.mean_graph

    def initialize_mean_graph(self):
        """
        (Re)builds the cached mean graph by calling `mean_graph` on the first
        image with the `get_graph()` result of every image in the group.
        """
        self._mean_graph = self.images[0].mean_graph(
            [image.get_graph() for image in self.images])

    def clustering(self, **clustering_settings) -> Clustering:
        """
        Returns the clustering of the mean graph for the given settings,
        creating and caching it on first request.

        Parameters:
            **clustering_settings: Forwarded to `Clustering.create_clustering`
                and used (via `request_to_tuple`) as the cache key.

        Returns:
            Clustering: The cached or newly created clustering.
        """
        # Convert settings to a sorted tuple of pairs to ensure consistent
        # ordering of the cache key.
        clustering_key = Group.request_to_tuple(clustering_settings)

        if clustering_key not in self.clusterings:
            clustering = Clustering.create_clustering(
                G=self.mean_graph, **clustering_settings)
            self.clusterings[clustering_key] = {'clustering': clustering}

        return self.clusterings[clustering_key]['clustering']

    def get_clustering(self, **settings) -> Clustering:
        """
        Returns the clustering described by `settings['clustering_settings']`,
        or the clustering for empty settings when that key is absent.
        """
        if 'clustering_settings' in settings:
            return self.clustering(**(settings['clustering_settings']))
        return self.clustering()

    def is_clustering_graph_initialized(self, **clustering_settings) -> bool:
        """
        Tells whether a merged clustering graph is already cached for the
        given settings.

        Returns:
            bool: True if a truthy 'graph' entry exists for these settings.
        """
        clustering_key = self.request_to_tuple(clustering_settings)
        entry = self.clusterings.get(clustering_key)
        # Wrap in bool() so the declared return type holds; previously the
        # graph object itself leaked out of this predicate.
        return entry is not None and bool(entry.get('graph'))

    def clustering_graph(self, merge_remaining: bool = False, reinitialize: bool = False, **clustering_settings) -> Graph:
        """
        Returns a copy of the mean graph whose nodes were merged according to
        the clustering for `clustering_settings`; the result is cached.

        Parameters:
            merge_remaining (bool): Forwarded to `merge_clusters` when the
                graph is built.
            reinitialize (bool): If True, rebuild the graph even when a cached
                one exists.
            **clustering_settings: Settings identifying the clustering.

        Returns:
            Graph: The cached or newly built merged graph.

        NOTE(review): `merge_remaining` is not part of the cache key, so a
        cached graph is returned regardless of the value passed unless
        `reinitialize=True` — confirm this is intended.
        """
        clustering_key = self.request_to_tuple(clustering_settings)

        cached = self.clusterings.get(clustering_key)
        if reinitialize or cached is None or 'graph' not in cached:
            # `self.clustering` also creates the cache entry when missing.
            clustering = self.clustering(**clustering_settings)
            clustering_nodes = [
                [tuple(node) for node in nodes]
                for nodes in clustering.labels_nodes_dict().values()]
            cluster_labels = clustering.cluster_labels

            clustering_graph = self.mean_graph.copy()
            clustering_graph.merge_clusters(
                clustering_nodes, labels=cluster_labels, merge_remaining=merge_remaining)
            self.clusterings[clustering_key]['graph'] = clustering_graph

        return self.clusterings[clustering_key]['graph']

    def clustering_graph_values(self, parameters: tuple[str], clustering_settings: dict, **settings) -> dict[str, np.ndarray]:
        """
        Gathers the requested parameters from the merged clustering graph.

        Parameters:
            parameters (tuple[str]): Parameter names; 'Time' is handled here
                and not passed to the graph's gatherer.
            clustering_settings (dict): Keyword settings identifying the
                clustering (unpacked into `clustering_graph`).
            **settings: Accepted but ignored.

        Returns:
            dict: The gatherer's results, plus 'Time' (the mean time of the
            group) when it was requested.
        """
        graph = self.clustering_graph(**clustering_settings)

        params_no_time = [param for param in parameters if param != 'Time']
        results = graph.gatherer.gather(params_no_time)

        if 'Time' in parameters:
            results['Time'] = self.mean_time()

        return results

    def __len__(self):
        return len(self.images)

    def __iter__(self):
        return iter(self.images)

    def __getitem__(self, key):
        return self.images[key]

    def __setitem__(self, key, value):
        self.images[key] = value
        # Derived values were computed from the old image set.
        self._invalidate_caches()

    def __delitem__(self, key):
        del self.images[key]
        self._invalidate_caches()

    def append(self, image):
        """
        Appends an image to the group and drops cached derived values.

        NOTE(review): `time` is not extended here, so it may no longer match
        the number of images afterwards — callers must keep it in sync.
        """
        self.images.append(image)
        self._invalidate_caches()

    def __repr__(self):
        return f"Group(id={id(self)}, time={self.time}, images={self.images})"

    @property
    def nodes(self) -> set:
        """Returns the set of nodes of the mean graph, cached after first use."""
        # `is None` rather than truthiness, so an empty node set is cached too.
        if self._nodes is None:
            self._nodes = set(self.mean_graph.nodes)
        return self._nodes

    def calculate_node_values(self, parameters: tuple[str], **settings) -> dict:
        """
        Gathers the requested parameters from the mean graph.

        Args:
            parameters (tuple[str]): Parameter names; 'Time' is handled here
                and not passed to the graph's gatherer.
            **settings: Accepted but ignored.

        Returns:
            dict: The gatherer's results, plus 'Time' (the mean time of the
            group) when it was requested.
        """
        params_no_time = [param for param in parameters if param != 'Time']
        results = self.mean_graph.gatherer.gather(params_no_time)

        if 'Time' in parameters:
            results['Time'] = self.mean_time()

        return results

    def calculate_function_of_node_values(self, parameters: tuple[str], function='Mean', **settings) -> dict:
        """
        Applies a named reduction to the node values of each requested
        parameter, treating 'Time' and 'Nodes' specially.

        Args:
            parameters (tuple[str]): Parameter names.
            function (str): Key of `common_functions` naming the reduction to
                apply to every parameter except 'Time' and 'Nodes'.
            **settings: Accepted but ignored.

        Returns:
            dict: Maps each gathered parameter to the reduced value; 'Time'
            maps to the mean time of the group; 'Nodes' and parameters the
            gatherer did not return are omitted.

        Raises:
            ValueError: If `function` is not a key of `common_functions`.
        """
        if function not in self.common_functions:
            raise ValueError(
                f'{function} is not in a list of common functions: {list(self.common_functions.keys())}')
        func = self.common_functions[function]

        # 'Time' is not a node quantity, so it is excluded from gathering.
        params_no_time = [param for param in parameters if param != 'Time']
        node_data = self.calculate_node_values(params_no_time)

        results = {}
        for param in parameters:
            if param == 'Time':
                results['Time'] = self.mean_time()
            elif param == 'Nodes':
                # Node identifiers are not reduced.
                continue
            elif param in node_data:
                results[param] = func(node_data[param])

        return results

    def mean_time(self, **settings):
        """
        Returns the arithmetic mean of the group's time values.

        Raises:
            ValueError: If the group was created without time values.
        """
        if self.time is None:
            raise ValueError(
                "Group has no time values; cannot compute mean time")
        return np.mean(self.time)