""" This file implements turning matador Crystal objects
into CrystalGraph objects.
"""
import networkx as nx
import numpy as np
import itertools
# Tolerance below which the norm of a periodic image offset is treated as zero,
# i.e. the image is taken to be the home cell.
EPS = 1e-12
class CrystalGraph(nx.MultiDiGraph):
    """networkx.MultiDiGraph with extra functionality for atomic networks.

    Every node carries a ``species`` attribute. Edges built from a structure
    carry the interatomic distance as ``dist`` and an ``image`` flag that is
    True when the bonded atom lies in a periodic image other than the home
    cell; edges to separate image nodes (``separate_images=True``) carry only
    ``dist``.
    """

    def __init__(
        self,
        structure=None,
        graph=None,
        coordination_cutoff=1.1,
        bond_tolerance=1e20,
        num_images=1,
        debug=False,
        separate_images=False,
        delete_one_way_bonds=False,
        max_bond_length=5,
    ):
        """Create networkx.MultiDiGraph object with extra functionality for atomic networks.

        If both ``structure`` and ``graph`` are given, ``structure`` is used and
        ``graph`` is ignored.

        Keyword Arguments:
            structure (matador.Crystal): crystal structure to network-ify
            graph (nx.MultiDiGraph): initialise from graph; every node must
                have a ``species`` attribute
            coordination_cutoff (float): max multiplier of each atom's
                shortest neighbour distance for edge drawing
            bond_tolerance (float): max multiplier of the shortest distance
                found for each species pair for edge drawing (the default of
                1e20 effectively disables this criterion)
            num_images (int): number of periodic images to include in
                each direction
            debug (bool): print the per-atom and per-species-pair minimum
                distances, and forward the flag to the distance calculation
            separate_images (bool): whether or not to include image
                atoms as new nodes
            delete_one_way_bonds (bool): when initialising from ``graph``,
                only copy edges whose reverse edge is also present
            max_bond_length (float): distance cutoff forwarded to
                ``calc_pairwise_distances_pbc``

        Raises:
            RuntimeError: if neither ``structure`` nor ``graph`` is provided.
        """
        super().__init__()
        if graph is None and structure is None:
            raise RuntimeError("No structure or graph to initialise network from.")
        if structure is not None:
            self._add_bonds_from_structure(
                structure,
                coordination_cutoff=coordination_cutoff,
                bond_tolerance=bond_tolerance,
                num_images=num_images,
                debug=debug,
                separate_images=separate_images,
                max_bond_length=max_bond_length,
            )
        else:
            self._copy_from_graph(graph, delete_one_way_bonds=delete_one_way_bonds)

    def _add_bonds_from_structure(
        self,
        structure,
        coordination_cutoff,
        bond_tolerance,
        num_images,
        debug,
        separate_images,
        max_bond_length,
    ):
        """Add one node per site of ``structure`` and one edge per bond found."""
        from matador.utils.cell_utils import calc_pairwise_distances_pbc

        atoms = structure.sites
        num_atoms = len(atoms)
        images = list(
            itertools.product(range(-num_images, num_images + 1), repeat=3)
        )
        # True only for the (0, 0, 0) image, i.e. the home cell.
        in_home_cell = [bool(np.linalg.norm(image) <= EPS) for image in images]

        distances = calc_pairwise_distances_pbc(
            structure.positions_abs,
            images,
            structure.lattice_cart,
            max_bond_length,
            compress=False,
            # forward the caller's flag (this used to be hard-coded to True)
            debug=debug,
        )

        for i, atom in enumerate(atoms):
            self.add_node(i, species=atom.species)

        # Collect every unmasked distance once. The flat index is decoded as
        # image-major, then atom i, then atom j (num_atoms**2 entries per
        # image). getmaskarray also copes with a mask stored as `nomask`.
        pairs = []
        for index in np.where(~np.ma.getmaskarray(distances))[0]:
            image_index, pair_index = divmod(int(index), num_atoms**2)
            i, j = divmod(pair_index, num_atoms)
            # an atom's zero distance to itself in the home cell is not a bond
            if i == j and in_home_cell[image_index]:
                continue
            pair_key = tuple(sorted([atoms[i].species, atoms[j].species]))
            pairs.append((i, j, image_index, pair_key, distances[index]))

        # First pass: shortest distance for each species pair and each atom.
        element_bonds = {}
        min_dists = [1e20] * num_atoms
        for i, _, _, pair_key, dist in pairs:
            if pair_key not in element_bonds or element_bonds[pair_key] > dist:
                element_bonds[pair_key] = dist
            if dist < min_dists[i]:
                min_dists[i] = dist

        if debug:
            print(min_dists)
            print(element_bonds)

        # Second pass: draw an edge for every pair inside both the per-atom
        # and the per-species-pair cutoffs.
        image_number = 0
        for i, j, image_index, pair_key, dist in pairs:
            if not (
                dist <= min_dists[i] * coordination_cutoff
                and dist <= element_bonds[pair_key] * bond_tolerance
            ):
                continue
            image = images[image_index]
            # NOTE(review): this condition also matches the home cell (0, 0, 0);
            # confirm which images are meant to become separate nodes.
            if separate_images and all(val <= 0 + 1e-8 for val in image):
                # Image nodes are numbered after the real atoms; the previous
                # `j + image_number` scheme collided with, and overwrote the
                # species of, existing atom nodes.
                image_node = num_atoms + image_number
                image_number += 1
                self.add_node(image_node, species=atoms[j].species)
                self.add_edge(i, image_node, dist=dist)
                self.add_edge(image_node, i, dist=dist)
            else:
                self.add_edge(i, j, dist=dist, image=not in_home_cell[image_index])

    def _copy_from_graph(self, graph, delete_one_way_bonds):
        """Copy nodes and edges (species, dist and image data) from ``graph``."""
        for node, data in graph.nodes.data():
            self.add_node(
                node, species=data["species"], image=data.get("image", False)
            )
        for node_in, node_out, data in graph.edges.data():
            # optionally drop edges that have no matching reverse edge
            if delete_one_way_bonds and (node_out, node_in) not in graph.edges():
                continue
            self.add_edge(
                node_in,
                node_out,
                dist=data.get("dist", 0),
                image=data.get("image", False),
            )

    def get_strongly_connected_component_subgraphs(self, delete_one_way_bonds=True):
        """Return generator of strongly-connected subgraphs in CrystalGraph format.

        Keyword Arguments:
            delete_one_way_bonds (bool): passed on to each new CrystalGraph.
        """
        # networkx builds subgraph views by calling ``G.__class__()`` with no
        # arguments, which CrystalGraph rejects, so take views of a plain
        # MultiDiGraph copy instead.
        plain_graph = nx.MultiDiGraph(self)
        return (
            CrystalGraph(
                graph=plain_graph.subgraph(component),
                delete_one_way_bonds=delete_one_way_bonds,
            )
            for component in nx.strongly_connected_components(plain_graph)
        )

    def get_communities(self, graph=None, **louvain_kwargs):
        """Partition a graph into Louvain communities.

        Keyword Arguments:
            graph (nx.Graph): graph to partition (default: self); nodes must
                have a ``species`` attribute.
            **louvain_kwargs: passed to ``community.best_partition``.

        Returns:
            (list(CrystalGraph), dict): one CrystalGraph per community, and the
                node -> community index partition.
        """
        import community as louvain

        if graph is None:
            graph = self
        if graph.is_directed():
            undirected_graph = self.remove_directionality(graph=graph)
        else:
            undirected_graph = graph
        partition = louvain.best_partition(undirected_graph, **louvain_kwargs)
        num_communities = len(set(partition.values()))
        subgraphs = [nx.MultiGraph() for _ in range(num_communities)]
        for node, community_index in partition.items():
            subgraphs[community_index].add_node(
                node, species=graph.nodes[node]["species"]
            )
        # keep only edges whose end points fall in the same community
        for node_in, node_out in graph.edges():
            if partition[node_in] == partition[node_out]:
                subgraphs[partition[node_in]].add_edge(node_in, node_out)
        subgraphs = [CrystalGraph(graph=subgraph) for subgraph in subgraphs]
        return subgraphs, partition

    def remove_directionality(self, graph=None):
        """Return an undirected nx.MultiGraph copy of ``graph`` (default: self).

        Nodes keep only their ``species``. A directed edge is skipped if its
        end points are already connected in the undirected copy, so each
        connected node pair keeps a single edge with ``dist`` and ``image``.
        """
        if graph is None:
            graph = self
        undirected_graph = nx.MultiGraph()
        for node, data in graph.nodes(data=True):
            undirected_graph.add_node(node, species=data["species"])
        for node_in, node_out, data in graph.edges(data=True):
            if (node_out, node_in) not in undirected_graph.edges():
                undirected_graph.add_edge(
                    node_in,
                    node_out,
                    dist=data.get("dist", 0),
                    image=data.get("image", False),
                )
        return undirected_graph

    def set_unique_subgraphs(self, method="community"):
        """Filter subgraphs of this CrystalGraph for isomorphism with each other.

        Sets ``self.unique_subgraphs`` to a set of such subgraphs.

        Keyword Arguments:
            method (str): ``"community"`` (Louvain communities),
                ``"strongly_connected"`` (strongly connected components) or
                ``"both"`` (communities of each strongly connected component).

        Raises:
            ValueError: for an unknown ``method``.
        """
        if method == "community":
            # get_communities returns (subgraphs, partition); keep subgraphs
            subgraphs, _ = self.get_communities()
        elif method == "strongly_connected":
            subgraphs = self.get_strongly_connected_component_subgraphs()
        elif method == "both":
            subgraphs = []
            for component in self.get_strongly_connected_component_subgraphs():
                communities, _ = component.get_communities()
                subgraphs.extend(communities)
        else:
            raise ValueError(
                "Unknown method {!r}: expected 'community', "
                "'strongly_connected' or 'both'.".format(method)
            )
        self.unique_subgraphs = get_unique_subgraphs(subgraphs)

    def get_bonds_per_atom(self):
        """Return the average number of mutually bonded neighbours per node.

        Counts ordered pairs of distinct nodes (a, b) with edges in both
        directions, divided by the number of nodes. Returns 0.0 for a graph
        with no nodes.
        """
        num_nodes = self.number_of_nodes()
        if num_nodes == 0:
            return 0.0
        bonded = {(node_in, node_out) for node_in, node_out in self.edges() if node_in != node_out}
        num_bonds = sum(1 for node_in, node_out in bonded if (node_out, node_in) in bonded)
        return num_bonds / num_nodes
def node_match(n1, n2):
    """Return whether two node attribute dicts share the same ``species``."""
    first_species = n1["species"]
    second_species = n2["species"]
    return first_species == second_species
def get_unique_subgraphs(subgraphs, is_same=None):
    """Filter subgraphs, keeping only those not equivalent to one already kept.

    Input:
        | subgraphs: iterable of CrystalGraph, subgraph objects to filter
        | is_same: callable(g1, g2) -> bool, optional equivalence test;
        |     defaults to are_graphs_the_same
    Returns:
        | unique_subgraphs: set(CrystalGraph), set of unique subgraphs
    """
    if is_same is None:
        is_same = are_graphs_the_same
    unique_subgraphs = set()
    for subgraph in subgraphs:
        # A generator lets any() stop at the first match instead of running
        # every (expensive) comparison against the kept subgraphs.
        if not any(is_same(subgraph, other) for other in unique_subgraphs):
            unique_subgraphs.add(subgraph)
    return unique_subgraphs
def are_graphs_the_same(g1, g2, edge_match=None, atol=0.1, rtol=0.05):
    """Return whether two crystal graphs are isomorphic.

    Nodes match when their ``species`` are equal (see ``node_match``).

    Keyword Arguments:
        edge_match (callable): custom edge comparison passed to
            ``nx.is_isomorphic``; overrides ``atol``/``rtol`` when given.
        atol (float): absolute tolerance on bond length for the default
            edge comparison.
        rtol (float): relative tolerance (with respect to ``g2``'s bond
            length) for the default edge comparison.
    """
    if edge_match is None:

        def edge_match(e1, e2):
            # For multigraphs networkx passes {edge_key: attr_dict}; only the
            # parallel edge with key 0 is compared.
            first, second = e1[0], e2[0]
            return abs(first["dist"] - second["dist"]) <= atol + rtol * second[
                "dist"
            ] and first.get("image", False) == second.get("image", False)

    return nx.is_isomorphic(
        g1,
        g2,
        node_match=node_match,
        edge_match=edge_match,
    )
def draw_network(
    structure,
    layout=None,
    edge_labels=False,
    node_index=False,
    curved_edges=True,
    node_colour="elem",
    partition=None,
    ax=None,
):
    """Draw a crystal network with networkx and matplotlib.

    Parameters:
        structure: an object with a ``network`` attribute, or a graph whose
            nodes carry a ``species`` attribute.
        layout (dict): mapping of node to position; defaults to
            ``nx.spring_layout(network)``.
        edge_labels (bool): label each connected node pair with its number of
            edges, counting both directions together.
        node_index (bool): append the node index to each node label.
        curved_edges (bool): accepted for compatibility; currently unused.
        node_colour (str): ``"degree"`` colours nodes by degree,
            ``"partition"`` colours them by ``partition`` (when given);
            anything else uses element colours.
        partition (dict): mapping of node to integer community index, e.g. as
            returned by ``CrystalGraph.get_communities``.
        ax (matplotlib.axes.Axes): axes to draw on; a new figure is created
            when None.

    Returns:
        matplotlib.axes.Axes: the axes that were drawn on.
    """
    import networkx as nx
    import matplotlib.pyplot as plt
    from matador.utils.viz_utils import get_element_colours

    element_colours = get_element_colours()
    # Crystal-like objects expose their graph as `.network`; otherwise assume
    # the graph itself was passed.
    network = getattr(structure, "network", structure)
    pos = nx.spring_layout(network) if layout is None else layout
    if ax is None:
        _, ax = plt.subplots()

    # pyplot.get_cmap remains available after matplotlib.cm.get_cmap was
    # removed (matplotlib 3.9).
    if node_colour == "degree":
        degrees = list(set(dict(network.degree).values()))
        palette = plt.get_cmap("Dark2", len(degrees)).colors
        colours = [
            palette[degrees.index(network.degree[node])] for node in network.nodes()
        ]
    elif node_colour == "partition" and partition is not None:
        num_partitions = len(set(partition.values()))
        palette = plt.get_cmap("Dark2", num_partitions).colors
        colours = [palette[partition[node]] for node in network.nodes()]
    else:
        colours = [
            element_colours.get(data["species"]) for _, data in network.nodes.data()
        ]

    if node_index:
        labels = {
            node: "{} \\#{}".format(data["species"], node)
            for node, data in network.nodes.data()
        }
    else:
        labels = {node: str(data["species"]) for node, data in network.nodes.data()}

    # Edges with no explicit `image` flag are drawn as image bonds (grey).
    edge_colours = [
        "grey" if data.get("image", True) else "black"
        for _, _, data in network.edges(data=True)
    ]

    nx.draw_networkx_nodes(
        network,
        pos,
        node_color=colours,
        edgecolors="black",
        linewidths=2,
        node_size=1000,
        ax=ax,
    )
    nx.draw_networkx_edges(
        network, pos, edge_color=edge_colours, width=2, node_size=1000, ax=ax
    )
    if edge_labels:
        edge_weight = {}
        for node_in, node_out in network.edges():
            # count both directions under the orientation seen first
            if (node_out, node_in) in edge_weight:
                key = (node_out, node_in)
            else:
                key = (node_in, node_out)
            edge_weight[key] = edge_weight.get(key, 0) + 1
        nx.draw_networkx_edge_labels(network, pos, edge_labels=edge_weight, ax=ax)
    nx.draw_networkx_labels(network, pos, labels=labels, ax=ax)
    # act on the axes we drew on, not whichever axes happen to be current
    ax.axis("off")
    return ax