Source code for matador.crystal.network

""" This file implements turning matador Crystal objects
into CrystalGraph objects.
"""
import networkx as nx
import numpy as np
import itertools

EPS = 1e-12


class CrystalGraph(nx.MultiDiGraph):
    """Directed multigraph describing the bonding network of a crystal.

    Nodes carry a ``species`` attribute. Edges carry the interatomic
    distance as ``dist`` and, where set, an ``image`` flag marking bonds
    that cross into a periodic image of the cell.
    """

    def __init__(
        self,
        structure=None,
        graph=None,
        coordination_cutoff=1.1,
        bond_tolerance=1e20,
        num_images=1,
        debug=False,
        separate_images=False,
        delete_one_way_bonds=False,
        max_bond_length=5,
    ):
        """Create networkx.MultiDiGraph object with extra functionality for
        atomic networks.

        If both ``structure`` and ``graph`` are given, ``structure`` takes
        precedence and ``graph`` is ignored.

        Keyword Arguments:
            structure (matador.Crystal): crystal structure to network-ify.
            graph (nx.MultiDiGraph): initialise from an existing graph whose
                nodes carry a ``species`` attribute.
            coordination_cutoff (float): an edge i -> j is only drawn if its
                length is at most this multiple of the shortest distance
                found from atom i to any other atom.
            bond_tolerance (float): an edge is only drawn if its length is at
                most this multiple of the shortest distance found for that
                pair of species anywhere in the structure.
            num_images (int): number of periodic images to include in each
                direction.
            debug (bool): print the per-atom minimum distances and the
                per-species-pair minimum distances.
            separate_images (bool): whether or not to include image atoms as
                new nodes.
            delete_one_way_bonds (bool): only used when initialising from
                ``graph``; drop any edge that has no reverse edge in ``graph``.
            max_bond_length (float): passed through to
                ``calc_pairwise_distances_pbc`` (presumably the distance
                cutoff -- confirm against that helper).

        Raises:
            RuntimeError: if neither ``structure`` nor ``graph`` is provided.

        """
        super().__init__()

        if graph is None and structure is None:
            raise RuntimeError("No structure or graph to initialise network from.")

        if structure is not None:
            atoms = structure.sites
            num_atoms = len(atoms)
            element_bonds = {}
            images = list(
                itertools.product(range(-num_images, num_images + 1), repeat=3)
            )
            image_number = 0

            # now loop over pairs of atoms and decide whether to draw an edge
            from matador.utils.cell_utils import calc_pairwise_distances_pbc

            # NOTE(review): debug=True is hard-coded here rather than
            # forwarding this constructor's `debug` flag; left untouched as
            # the helper's behaviour under `debug` is not visible from here.
            distances = calc_pairwise_distances_pbc(
                structure.positions_abs,
                images,
                structure.lattice_cart,
                max_bond_length,
                compress=False,
                debug=True,
            )

            for i, atom in enumerate(atoms):
                self.add_node(i, species=atom.species)

            # flat indices of every distance that was not masked out
            unmasked_indices = np.where(~distances.mask)[0]

            # first loop over all pairs to find the minimum distance between
            # all species pairs and the minimum distance for each atom
            min_dists = [1e20] * num_atoms
            for index in unmasked_indices:
                image_index, i, j = self._unravel_pair_index(index, num_atoms)
                # skip the zero distance between an atom and itself in the
                # un-shifted cell
                if i == j and np.linalg.norm(images[image_index]) <= EPS:
                    continue
                dist = distances[index]
                pair_key = tuple(sorted([atoms[i].species, atoms[j].species]))
                if pair_key not in element_bonds or element_bonds[pair_key] > dist:
                    element_bonds[pair_key] = dist
                # find the closest image of an atom to i
                if dist < min_dists[i]:
                    min_dists[i] = dist

            if debug:
                print(min_dists)
                print(element_bonds)

            # second pass: draw an edge for every pair that satisfies both
            # the per-atom and the per-species-pair distance criteria
            for index in unmasked_indices:
                image_index, i, j = self._unravel_pair_index(index, num_atoms)
                min_dist = min_dists[i]
                if i == j and np.linalg.norm(images[image_index]) <= EPS:
                    continue
                dist = distances[index]
                pair_key = tuple(sorted([atoms[i].species, atoms[j].species]))
                if (
                    dist <= min_dist * coordination_cutoff
                    and dist <= element_bonds[pair_key] * bond_tolerance
                ):
                    if separate_images and all(
                        [val <= 0 + 1e-8 for val in images[image_index]]
                    ):
                        image_number += 1
                        self.add_node(j + image_number, species=atoms[j].species)
                        self.add_edge(i, j + image_number, dist=dist)
                        self.add_edge(j + image_number, i, dist=dist)
                    else:
                        is_image = np.linalg.norm(images[image_index]) > EPS
                        self.add_edge(i, j, dist=dist, image=is_image)

        elif graph is not None:
            for node, data in graph.nodes.data():
                self.add_node(
                    node, species=data["species"], image=data.get("image", False)
                )
            for node_in, node_out, data in graph.edges.data():
                if not delete_one_way_bonds or (node_out, node_in) in graph.edges():
                    self.add_edge(
                        node_in,
                        node_out,
                        dist=data.get("dist", 0),
                        image=data.get("image", False),
                    )

    @staticmethod
    def _unravel_pair_index(index, num_atoms):
        """Split a flat index into the pairwise distance array into
        ``(image_index, i, j)``, where consecutive blocks of ``num_atoms**2``
        entries belong to one image and ``i``/``j`` are the atom indices.
        """
        image_index = int(index / num_atoms**2)
        offset = index - image_index * num_atoms**2
        return image_index, int(offset / num_atoms), int(offset % num_atoms)

    def get_strongly_connected_component_subgraphs(self, delete_one_way_bonds=True):
        """Return generator of strongly-connected subgraphs in CrystalGraph format.

        Keyword Arguments:
            delete_one_way_bonds (bool): forwarded to the CrystalGraph
                constructor of each subgraph.

        """
        return (
            CrystalGraph(
                graph=self.subgraph(c), delete_one_way_bonds=delete_one_way_bonds
            )
            for c in nx.strongly_connected_components(self)
        )

    def get_communities(self, graph=None, **louvain_kwargs):
        """Partition a graph into communities with the Louvain method.

        Keyword Arguments:
            graph (nx.Graph): graph to partition; defaults to ``self``.
            **louvain_kwargs: forwarded to ``community.best_partition``.

        Returns:
            tuple: ``(subgraphs, partition)`` where ``subgraphs`` is a list of
            CrystalGraph objects (one per community, indexed by community id)
            and ``partition`` is the node -> community id mapping returned
            by ``community.best_partition``.

        """
        import community as louvain

        if graph is None:
            graph = self

        # best_partition is fed an undirected graph; an already-undirected
        # input is used as-is
        if graph.is_directed():
            undirected_graph = self.remove_directionality(graph=graph)
        else:
            undirected_graph = graph

        partition = louvain.best_partition(undirected_graph, **louvain_kwargs)
        size = len(set(partition.values()))
        subgraphs = [nx.MultiGraph() for _ in range(size)]

        # node and edge data are read from the graph that was partitioned
        for node, community_id in partition.items():
            subgraphs[community_id].add_node(
                node, species=graph.nodes[node]["species"]
            )

        # only edges that stay inside a single community are retained
        for node_in, node_out in graph.edges():
            if partition[node_in] == partition[node_out]:
                subgraphs[partition[node_in]].add_edge(node_in, node_out)

        subgraphs = [CrystalGraph(graph=sg) for sg in subgraphs]
        return subgraphs, partition

    def remove_directionality(self, graph=None):
        """Build an undirected ``nx.MultiGraph`` copy of a directed graph.

        An edge is only copied if no edge between the same pair of nodes has
        already been added, so a forward/backward pair of directed edges
        becomes a single undirected edge.

        Keyword Arguments:
            graph (nx.Graph): graph to convert; defaults to ``self``.

        Returns:
            nx.MultiGraph: undirected graph with ``species`` on nodes and
            ``dist``/``image`` on edges.

        """
        if graph is None:
            graph = self

        undirected_graph = nx.MultiGraph()
        for node, data in graph.nodes(data=True):
            undirected_graph.add_node(node, species=data["species"])
        for node_in, node_out, data in graph.edges(data=True):
            if (node_out, node_in) not in undirected_graph.edges():
                undirected_graph.add_edge(
                    node_in,
                    node_out,
                    dist=data.get("dist", 0),
                    image=data.get("image", False),
                )
        return undirected_graph

    def set_unique_subgraphs(self, method="community"):
        """Filter subgraphs for isomorphism with others inside CrystalGraph.

        Sets ``self.unique_subgraphs`` to the set of mutually non-isomorphic
        subgraphs. Any other value of ``method`` leaves the attribute
        untouched.

        Keyword Arguments:
            method (str): ``"community"`` to use Louvain communities,
                ``"strongly_connected"`` to use strongly connected components,
                or ``"both"`` to find communities within each strongly
                connected component.

        """
        if method == "community":
            # get_communities returns (subgraphs, partition); only the
            # subgraphs are compared
            community_subgraphs, _ = self.get_communities()
            self.unique_subgraphs = get_unique_subgraphs(community_subgraphs)
        elif method == "strongly_connected":
            self.unique_subgraphs = get_unique_subgraphs(
                self.get_strongly_connected_component_subgraphs()
            )
        elif method == "both":
            community_subgraphs = []
            for sg in self.get_strongly_connected_component_subgraphs():
                sg_communities, _ = sg.get_communities()
                community_subgraphs.extend(sg_communities)
            self.unique_subgraphs = get_unique_subgraphs(community_subgraphs)

    def get_bonds_per_atom(self):
        """Return the number of two-way bonds divided by the number of nodes.

        Every ordered pair of distinct nodes connected in both directions
        counts once, so a mutual bond contributes once for each of its two
        atoms. Parallel edges between the same pair are not counted again.

        Raises:
            ZeroDivisionError: if the graph has no nodes.

        """
        directed_pairs = {
            (node_in, node_out)
            for node_in, node_out in self.edges()
            if node_in != node_out
        }
        num_bonds = sum(
            1
            for node_in, node_out in directed_pairs
            if (node_out, node_in) in directed_pairs
        )
        return num_bonds / self.number_of_nodes()
def node_match(n1, n2):
    """Return whether two node attribute dicts describe the same species.

    Parameters:
        n1 (dict): attributes of the first node; must contain ``"species"``.
        n2 (dict): attributes of the second node; must contain ``"species"``.

    Returns:
        The result of comparing the two ``"species"`` values for equality.

    """
    first_species = n1["species"]
    second_species = n2["species"]
    return first_species == second_species
def get_unique_subgraphs(subgraphs):
    """Filter subgraphs for isomorphism with others.

    Each subgraph is compared against the ones already kept, using
    ``are_graphs_the_same``; it is kept only if it matches none of them.

    Input:

        | subgraphs: iterable(CrystalGraph), subgraph objects to filter

    Returns:

        | unique_subgraphs: set(CrystalGraph), set of unique subgraphs

    """
    unique_subgraphs = set()
    for subgraph in subgraphs:
        # generator rather than list: stop at the first match instead of
        # running every remaining isomorphism check
        already_seen = any(
            are_graphs_the_same(subgraph, other_subgraph)
            for other_subgraph in unique_subgraphs
        )
        if not already_seen:
            unique_subgraphs.add(subgraph)
    return unique_subgraphs
def are_graphs_the_same(g1, g2, edge_match=None):
    """Check whether two graphs are isomorphic, matching node species and edges.

    Parameters:
        g1: first graph to compare.
        g2: second graph to compare.
        edge_match (callable): optional function taking the edge attributes
            of a pair of candidate edges and returning whether they match.
            The default compares the ``dist`` of the entry keyed ``0`` on
            each side (within ``0.1 + 0.05 * dist``) and requires their
            ``image`` flags to be equal.

    Returns:
        bool: the result of ``nx.is_isomorphic``.

    """

    def _default_edge_match(e1, e2):
        # only the entry stored under key 0 on each side is compared
        first, second = e1[0], e2[0]
        atol = 0.1
        rtol = 0.05
        within_tolerance = (
            abs(first["dist"] - second["dist"]) <= atol + rtol * second["dist"]
        )
        return within_tolerance and first["image"] == second["image"]

    matcher = _default_edge_match if edge_match is None else edge_match
    return nx.is_isomorphic(g1, g2, node_match=node_match, edge_match=matcher)
def draw_network(
    structure,
    layout=None,
    edge_labels=False,
    node_index=False,
    curved_edges=True,
    node_colour="elem",
    partition=None,
    ax=None,
):
    """Draw a bonding network with matplotlib.

    Parameters:
        structure: object exposing the graph as a ``network`` attribute, or
            the graph itself.

    Keyword Arguments:
        layout (dict): node -> position mapping; a spring layout is computed
            if omitted.
        edge_labels (bool): label each connected pair of nodes with the
            number of edges between them (both directions combined).
        node_index (bool): append the node identifier to each species label.
        curved_edges (bool): currently unused; kept for interface
            compatibility.
        node_colour (str): ``"degree"`` to colour by node degree,
            ``"partition"`` to colour by the supplied ``partition``, anything
            else to colour by element.
        partition (dict): node -> community id mapping, required for
            ``node_colour="partition"``.
        ax (matplotlib.axes.Axes): axes to draw on; a new figure is created
            if omitted.

    """
    from matador.utils.viz_utils import get_element_colours
    import matplotlib.pyplot as plt

    element_colours = get_element_colours()

    # accept either a structure carrying a network or the network itself
    try:
        network = structure.network
    except Exception:
        network = structure

    pos = nx.spring_layout(network) if layout is None else layout

    if ax is None:
        _, ax = plt.subplots()

    if node_colour == "degree":
        degrees = dict(network.degree)
        # sorted so that the degree -> colour assignment is reproducible
        coords = sorted(set(degrees.values()))
        # pyplot.get_cmap is available on both old and new matplotlib,
        # unlike the removed matplotlib.cm.get_cmap
        cmap = plt.get_cmap("Dark2", len(coords)).colors
        colour_index = {degree: idx for idx, degree in enumerate(coords)}
        colours = [cmap[colour_index[degrees[node]]] for node in network.nodes()]
    elif node_colour == "partition" and partition is not None:
        num_partitions = len(set(partition.values()))
        cmap = plt.get_cmap("Dark2", num_partitions).colors
        colours = [cmap[partition[node]] for node in network.nodes()]
    else:
        colours = [
            element_colours.get(data["species"])
            for node, data in network.nodes.data()
        ]

    if node_index:
        labels = {
            node: "{} \\#{}".format(data["species"], node)
            for node, data in network.nodes.data()
        }
    else:
        labels = {node: str(data["species"]) for node, data in network.nodes.data()}

    # edges flagged as (or not marked otherwise than) image bonds are grey
    edge_colours = [
        "grey" if data.get("image", True) else "black"
        for _, _, data in network.edges(data=True)
    ]

    nx.draw_networkx_nodes(
        network,
        pos,
        node_color=colours,
        edgecolors="black",
        linewidths=2,
        node_size=1000,
        ax=ax,
    )
    nx.draw_networkx_edges(
        network, pos, edge_color=edge_colours, width=2, node_size=1000, ax=ax
    )

    if edge_labels:
        # count edges per unordered node pair, keyed by whichever
        # orientation of the pair was seen first
        edge_weight = {}
        for node_in, node_out, _ in network.edges(data=True):
            key = (node_in, node_out)
            if key not in edge_weight and (node_out, node_in) in edge_weight:
                key = (node_out, node_in)
            edge_weight[key] = edge_weight.get(key, 0) + 1
        nx.draw_networkx_edge_labels(network, pos, edge_labels=edge_weight, ax=ax)

    nx.draw_networkx_labels(network, pos, labels=labels, ax=ax)
    # switch off the axes that were drawn on, not whichever are current
    ax.axis("off")