Source code for pyfibre.model.tools.fire_algorithm

import logging
import time
import copy

import networkx as nx
import numpy as np

from skimage.morphology import local_maxima

from pyfibre.utilities import ring, numpy_remove

from .fibre_utilities import (
    branch_angles, reduce_coord, new_branches, transfer_edges,
    check_2D_arrays
)

logger = logging.getLogger(__name__)


class FIREAlgorithm:
    """Class that extracts a complete fibre network from a provided
    image as a single nx.Graph object"""

    def __init__(self, nuc_thresh=2, lmp_thresh=0.15, angle_thresh=70,
                 r_thresh=7, nuc_radius=10):
        """Initialise FIREAlgorithm object

        Parameters
        ----------
        nuc_thresh : float, optional
            Minimum pixel value a local maximum of the image must reach
            to be classed as a nucleation point
        lmp_thresh : float, optional
            Minimum distance pixel threshold to be classed as lmp point
            (forwarded to `new_branches`)
        angle_thresh : float, optional
            Maximum deviation, in degrees, of a new lmp from the fibre
            trajectory (converted by `theta_thresh`)
        r_thresh : float, optional
            Maximum length of edges between nodes. Also used as the
            radius (r_thresh // 2) of the ring searched around each
            nucleation point, and forwarded to `reduce_coord`
        nuc_radius : float, optional
            Intended minimum radial distance between nucleation points.
            NOTE(review): it is passed as the `connectivity` argument of
            `skimage.morphology.local_maxima`, which is a neighbourhood
            connectivity rather than a pixel radius - confirm intent.
        """
        # Graph under construction; (re)assigned via `_assign_graph`
        self._graph = None
        # Reset together with the graph; not populated by this class
        self.fibres = []
        # Ids of nodes at the growing end of a fibre
        self.grow_list = []

        self.nuc_thresh = nuc_thresh
        self.lmp_thresh = lmp_thresh
        self.angle_thresh = angle_thresh
        self.r_thresh = r_thresh
        self.nuc_radius = nuc_radius

    @property
    def theta_thresh(self):
        """Cosine threshold derived from `angle_thresh` (in degrees)

        Equal to ``1 + cos(pi - a)``, i.e. ``1 - cos(a)``, where ``a`` is
        `angle_thresh` converted to radians. `grow_lmp` accepts a
        candidate branch when ``|cos_the + 1|`` does not exceed it.
        """
        return 1 + np.cos(
            (180 - self.angle_thresh) * np.pi / 180)

    def _assign_graph(self, graph=None):
        """Assign graph to self._graph

        Raises
        ------
        AssertionError
            If `graph` is not an nx.Graph instance
        """
        assert isinstance(graph, nx.Graph), (
            f"Argument `graph` must be an object "
            f"of type {nx.Graph}"
        )
        self._graph = graph

    def _reset_graph(self):
        """Reset the graph to an empty nx.Graph and clear grow_list
        and fibres"""
        self._assign_graph(nx.Graph())
        self.grow_list = []
        self.fibres = []

    def _get_connected_nodes(self, node):
        """Return an array of the ids of nodes adjacent to `node`, in
        graph adjacency order"""
        return np.array(list(self._graph.adj[node]))

    def _get_nucleation_points(self, image):
        """Return coordinates of nucleation points in `image`

        Candidates are local maxima of `image` whose value is at least
        `nuc_thresh`. Their coordinates and image values are then
        passed to `reduce_coord` with `r_thresh`.
        NOTE(review): the class docs describe `nuc_radius` as the
        minimum separation of nucleation points, yet `r_thresh` is what
        is passed here - confirm intent.
        """
        # Local maxima of the (presumably smoothed distance) image
        maxima = local_maxima(
            image, connectivity=self.nuc_radius, allow_borders=True
        )
        # Keep only maxima reaching the nucleation threshold
        nuc_mask = maxima * image >= self.nuc_thresh

        nuc_node_coord = reduce_coord(
            np.argwhere(nuc_mask),
            image[nuc_mask],
            self.r_thresh)

        return nuc_node_coord

    def _initialise_graph(self, image, nuc_node_coord):
        """Initialise graph with nucleation nodes

        Nodes 0..n_nuc-1 are the nucleation points. Each one is
        connected to the lmp nodes found by `new_branches` on a ring of
        radius r_thresh // 2 around it. The lmp nodes become the
        initial entries of `grow_list`.
        """
        n_nuc = nuc_node_coord.shape[0]
        self._graph.add_nodes_from(np.arange(n_nuc))
        self.grow_list = list(range(n_nuc))
        n_nodes = n_nuc

        for nuc, nuc_coord in enumerate(nuc_node_coord):
            self._graph.nodes[nuc]['xy'] = nuc_coord
            self._graph.nodes[nuc]['nuc'] = nuc
            # Nucleation nodes do not grow themselves; their lmps do
            self.grow_list.remove(nuc)

            ring_filter = ring(
                np.zeros(image.shape), nuc_coord, [self.r_thresh // 2], 1
            )
            lmp_coord, lmp_vectors, lmp_r = new_branches(
                image, nuc_coord, ring_filter, self.lmp_thresh
            )
            n_lmp = lmp_coord.shape[0]
            lmp_nodes = n_nodes + np.arange(n_lmp)

            self._graph.add_nodes_from(lmp_nodes)
            self._graph.add_edges_from(
                [*zip(nuc * np.ones(n_lmp, dtype=int), lmp_nodes)]
            )

            generator = zip(lmp_coord, lmp_vectors, lmp_r, lmp_nodes)
            for xy, vec, r, lmp in generator:
                self._graph.nodes[lmp]['xy'] = xy
                self._graph[nuc][lmp]['r'] = r
                self._graph.nodes[lmp]['nuc'] = nuc
                self.grow_list.append(lmp)
                # Unit growth direction (the sign flip is the original
                # convention for vectors returned by new_branches)
                self._graph.nodes[lmp]['direction'] = -vec / r

            n_nodes += n_lmp

    def grow_lmp(self, index, image, tot_node_coord):
        """Grow fibre object along network

        One growth step for the fibre whose end is node `index`. The
        fibre then either stops growing, joins an existing nearby node,
        moves its end node in place, or gains a new end node.

        Parameters
        ----------
        index: int
            Index of node to grow on the graph
        image: array_like, (float); shape=(nx, ny)
            Image to perform FIRE upon
        tot_node_coord: array_like
            Array of full coordinates (x, y) of nodes in graph network
        """
        # end_node: end of fibre; nuc_node: start of fibre
        end_node = self._graph.nodes[index]
        nuc_node = self._graph.nodes[end_node['nuc']]

        # The first adjacent node is treated as the one preceding the end
        connected_nodes = self._get_connected_nodes(index)
        prior = connected_nodes[0]
        prior_node = self._graph.nodes[prior]

        # Attribute dict of the end-prior edge (mutated in place below)
        edge = self._graph[index][prior]

        ring_filter = ring(
            np.zeros(image.shape), end_node['xy'], np.arange(2, 3), 1
        )
        branch_coord, branch_vector, branch_r = new_branches(
            image, end_node['xy'], ring_filter, self.lmp_thresh
        )
        cos_the = branch_angles(
            end_node['direction'], branch_vector, branch_r
        )
        indices = np.argwhere(abs(cos_the + 1) <= self.theta_thresh)

        if indices.size == 0:
            # No candidate within the angle threshold: stop growing, and
            # fold a very short final edge into the prior node
            self.grow_list.remove(index)
            if edge['r'] <= self.r_thresh / 10:
                transfer_edges(self._graph, index, prior)
            return

        # argwhere yields shape (k, 1), so these keep an extra axis
        branch_coord = branch_coord[indices]
        branch_r = branch_r[indices]

        # Presumably: existing nodes within distance 1 of a candidate,
        # excluding nodes already adjacent to the end node
        close_nodes, _ = check_2D_arrays(tot_node_coord, branch_coord, 1)
        close_nodes = numpy_remove(close_nodes, connected_nodes)

        if close_nodes.size != 0:
            # Join the fibre onto the lowest-id nearby node
            new_end = close_nodes.min()
            transfer_edges(self._graph, index, new_end)
            self.grow_list.remove(index)
            return

        # Pick the candidate with the largest branch_r
        new_index = branch_r.argmax()
        new_end_coord = branch_coord[new_index].flatten()

        new_end_vector = new_end_coord - prior_node['xy']
        new_end_r = np.sqrt((new_end_vector**2).sum())

        new_dir_vector = new_end_coord - nuc_node['xy']
        new_dir_r = np.sqrt((new_dir_vector**2).sum())

        if new_end_r >= self.r_thresh:
            # Moving the end would stretch the prior edge to at least
            # r_thresh, so start a new end node instead
            new_end = self._graph.number_of_nodes()
            self._graph.add_node(new_end)
            new_node = self._graph.nodes[new_end]
            self._graph.add_edge(index, new_end)
            new_edge = self._graph[index][new_end]

            new_node['xy'] = new_end_coord
            new_node['nuc'] = end_node['nuc']
            new_edge['r'] = np.sqrt(
                ((new_end_coord - end_node['xy'])**2).sum())
            new_node['direction'] = (new_dir_vector / new_dir_r)

            self.grow_list.remove(index)
            self.grow_list.append(new_end)
        else:
            # Slide the current end node forward to the candidate
            end_node['xy'] = new_end_coord
            edge['r'] = new_end_r
            end_node['direction'] = (new_dir_vector / new_dir_r)

    def create_network(self, image):
        """Initialise network from n_nucleation sites and grow it

        Growth repeats until no fibre ends remain in `grow_list`.

        Returns
        -------
        nx.Graph
            `copy.copy` of the internal graph. NOTE(review): this is a
            shallow copy and presumably shares node/adjacency dicts with
            the internal graph; consider `Graph.copy()` if independence
            is required.
        """
        self._reset_graph()

        nuc_node_coord = self._get_nucleation_points(image)
        self._initialise_graph(image, nuc_node_coord)

        n_nuc = nuc_node_coord.shape[0]
        n_node = self._graph.number_of_nodes()

        fibre_grow = list(self.grow_list)
        n_fibres = len(fibre_grow)

        logger.debug("No. nucleation nodes = %s", n_nuc)
        logger.debug("No. nodes created = %s", n_node)
        logger.debug("No. fibres to grow = %s", n_fibres)

        it = 0
        total_time = 0

        while fibre_grow:
            start = time.perf_counter()

            # Coordinate snapshot for this iteration (row i = node i)
            tot_node_coord = np.stack(
                [self._graph.nodes[node]['xy'] for node in self._graph]
            )

            # Iterate over a snapshot; grow_lmp mutates self.grow_list
            for fibre in fibre_grow:
                self.grow_lmp(fibre, image, tot_node_coord)

            n_node = self._graph.number_of_nodes()
            fibre_grow = list(self.grow_list)

            it += 1
            elapsed = time.perf_counter() - start
            total_time += elapsed

            logger.debug(
                "Iteration %s time = %s s, %s nodes %s/%s "
                "fibres left to grow",
                it, round(elapsed, 3), n_node, len(fibre_grow), n_fibres)

        logger.debug(
            "Network created in %s iterations, total time = %s s",
            it, round(total_time, 3))

        return copy.copy(self._graph)