# coding: utf-8
# Distributed under the terms of the MIT License.
""" This submodule contains a dirty wrapper of ase-gui for quick
visualisation, and nglview wrapper for JupyterNotebook visualisation,
and some colour definitions scraped from VESTA configs.
"""
from typing import Tuple, Optional, Dict, List, Union, Callable, Set, TYPE_CHECKING
from pathlib import Path
from functools import lru_cache
import tqdm
import numpy as np
from matador.utils.ase_utils import doc2ase
from matador.crystal import Crystal
if TYPE_CHECKING:
import fresnel
[docs]def viz(doc):
"""Quick and dirty ase-gui visualisation from matador doc."""
try:
from ase.visualize import view
except ImportError as exc:
raise ImportError(
"Optional dependency 'ase' is missing, please install it from PyPI with `pip install ase"
) from exc
view(doc2ase(doc))
return
[docs]@lru_cache(maxsize=1)
def get_element_colours() -> Dict[str, List[float]]:
"""Return RGB element colours from VESTA file.
The colours file can be specified in the matadorrc.
If unspecified, the default `../config/vesta_elements.ini` will be used.
"""
import os
from matador.config import load_custom_settings, SETTINGS
default_colours_fname = (
Path(os.path.realpath(__file__)).parent / "../config/vesta_elements.ini"
)
# check if element_colours has been given as an absolute path
if SETTINGS:
colours_fname = SETTINGS.get("plotting", {}).get("element_colours")
else:
colours_fname = (
load_custom_settings().get("plotting", {}).get("element_colours")
)
if colours_fname is None or not os.path.isfile(colours_fname):
colours_fname = default_colours_fname
with open(colours_fname, "r") as f:
flines = f.readlines()
element_colours = dict()
for line in flines:
line = line.split()
elem = line[1]
colour = list(map(float, line[-3:]))
element_colours[elem] = colour
return element_colours
[docs]@lru_cache(maxsize=1)
def get_element_radii():
"""Read element radii from VESTA config file. The colours file can be
specified in the matadorrc. If unspecified, the default
../config/vesta_elements.ini (relative to this file) will be used.
"""
import os
from matador.config import load_custom_settings, SETTINGS
default_colours_fname = (
Path(os.path.realpath(__file__)).parent / "../config/vesta_elements.ini"
)
# check if element_colours has been given as an absolute path
if SETTINGS:
colours_fname = SETTINGS.get("plotting", {}).get("element_colours")
else:
colours_fname = (
load_custom_settings().get("plotting", {}).get("element_colours")
)
if colours_fname is None or not os.path.isfile(colours_fname):
colours_fname = default_colours_fname
with open(colours_fname, "r") as f:
flines = f.readlines()
element_radii = dict()
for line in flines:
line = line.split()
elem = line[1]
r = float(line[2])
element_radii[elem] = r
return element_radii
[docs]def nb_viz(doc, repeat=1, bonds=None):
"""Return an ipywidget for nglview visualisation in
a Jupyter notebook or otherwise.
Parameters:
doc (matador.crystal.Crystal / dict): matador document to show.
Keyword arguments:
repeat (int): number of periodic images to include.
bonds (str): custom bond selection.
"""
try:
import nglview
except ImportError as exc:
raise ImportError(
"Optional dependency 'nglview' is missing, please install it from PyPI with `pip install nglview"
) from exc
atoms = doc2ase(doc)
atoms = atoms.repeat((repeat, repeat, repeat))
view = nglview.show_ase(atoms)
view.add_unitcell()
view.remove_ball_and_stick()
view.add_spacefill(radius_type="vdw", scale=0.3)
if bonds is None:
for elements in set(doc["atom_types"]):
view.add_ball_and_stick(selection="#{}".format(elements.upper()))
else:
if not isinstance(bonds, list):
bonds = [bonds]
for bond in bonds:
view.add_ball_and_stick(selection="#{}".format(bond.upper()))
view.parameters = {"clipDist": 0}
view.camera = "orthographic"
view.background = "#FFFFFF"
view.center()
return view
[docs]def fresnel_view(
doc: Crystal,
standardize: Union[bool, float] = True,
extension: Optional[Tuple[int, int, int]] = None,
show_bonds: bool = True,
show_cell: bool = True,
show_compass: bool = True,
bond_dict: Optional[Dict[Tuple[str, str], float]] = None,
images: Union[bool, float] = True,
pad_cell: bool = True,
lights: Optional[Callable] = None,
**camera_kwargs,
) -> "fresnel.Scene":
"""Return a fresnel scene visualising the input crystal.
The scene can then be rendered/ray-traced with `fresnel.pathtrace` or
`fresnel.preview`.
Notes:
Periodic images of sites will be included in the visualisation if they are
within the desired distance tolerance of another site in the crystal, and
are only rendered if they are bonded within the definitions of the `bond_dict`.
This process requires the construction of two kd-trees, so disabling the `images`
flag will yield a significant speed-up for large crystal structures.
Parameters:
doc: The crystal to visualize.
standardize: Whether to standardize the crystal before visualising. If a float
is provided, this will be used for the symmetry tolerance.
extension: Visualize a supercell with this lattice extension.
show_bonds: Whether to show bonds in the visualisation.
show_cell: Whether to show the unit cell in the visualisation.
show_compass: Whether to show a compass in the visualisation.
bond_dict: A dictionary of bond lengths to override the default bonds, e.g.,
`{('C', 'C'): 1.5, ('C', 'H'): 1.0}`.
images: Whether to show periodic images in the visualisation (if given a float, this will be used
as the clipping distance).
pad_cell: Size to aim for all directions to be visualised, e.g. a supercell with minimum side length
`pad_cell` will be constructed.
lights: An optional callable that sets the lights for the scene.
Returns:
The fresnel scene to render.
"""
try:
import fresnel
except ImportError as exc:
raise ImportError(
"Optional dependency 'fresnel' is missing, please install it "
"from conda forge with `conda install -c conda-forge fresnel`"
) from exc
scene = fresnel.Scene()
if standardize:
doc = doc.standardized()
if show_cell:
_draw_cell(scene, doc, compass=show_compass)
if extension is not None:
if not extension == (1, 1, 1):
doc = doc.supercell(extension)
elif pad_cell:
doc = doc.supercell(target=pad_cell)
clip = None
if isinstance(images, float):
images = True
clip = images
elif images and bond_dict:
clip = max(bond_dict[d] for d in bond_dict)
species, positions = _get_sites(doc, images=images, clip=clip)
to_remove = set()
if show_bonds:
bond_sets = _draw_bonds(
scene,
list(zip(species, positions)),
bond_dict=bond_dict,
nsites=len(doc.positions_abs),
bond_images=False,
)
if images:
image_ind = len(doc.positions_abs)
for i in range(image_ind, len(positions)):
if i not in bond_sets:
to_remove.add(i)
positions = [positions[i] for i, p in enumerate(positions) if i not in to_remove]
species = [species[i] for i, p in enumerate(species) if i not in to_remove]
_draw_atoms(scene, positions, species)
a, b, c = doc.lattice_cart
fit_orthographic_camera(scene, **camera_kwargs)
if lights is None:
scene.lights = fresnel.light.lightbox()
else:
scene.lights = lights()
return scene
def _get_sites(
doc, images: True, clip: float = 3.5
) -> Tuple[List[str], List[Tuple[float, float, float]]]:
"""Return the species and positions of the sites in the crystal, including
any periodic images within the distance tolerance.
Parameters:
doc: The crystal to visualize.
images: Whether to include periodic images.
clip: The distance tolerance to other sites for periodic image inclusion.
Returns:
A list of species and a list of Cartesian positions.
"""
positions = []
species = []
clip = clip or 3.5
for ind, pos_cart in enumerate(doc.positions_abs):
positions.append(pos_cart)
species.append(doc[ind].species)
if images:
image_positions, image_species = get_image_atoms_near_cell_boundaries(
doc, tolerance=clip
)
positions.extend(image_positions)
species.extend(image_species)
return species, positions
def _draw_cell(scene, doc, thickness=0.05, compass=True) -> None:
"""Add the unit cell to a fresnel scene.
Parameters:
scene: The fresnel scene.
doc: The crystal to visualize.
thickness: The line thickness to use.
compass: Whether to draw the compass of lattice vectors.
"""
import numpy as np
import fresnel
vertices = np.zeros((8, 3), dtype=np.float64)
vertices[1:4, :] = doc.lattice_cart
vertices[4, :] = vertices[1, :] + vertices[2, :]
vertices[5, :] = vertices[1, :] + vertices[3, :]
vertices[6, :] = vertices[2, :] + vertices[3, :]
vertices[7, :] = np.sum(vertices[1:4, :], axis=0)
unit_cell = fresnel.geometry.Cylinder(scene, N=12)
unit_cell.points[:] = [
[vertices[0, :], vertices[1, :]],
[vertices[0, :], vertices[2, :]],
[vertices[0, :], vertices[3, :]],
[vertices[1, :], vertices[4, :]],
[vertices[2, :], vertices[4, :]],
[vertices[1, :], vertices[5, :]],
[vertices[3, :], vertices[5, :]],
[vertices[2, :], vertices[6, :]],
[vertices[3, :], vertices[6, :]],
[vertices[4, :], vertices[7, :]],
[vertices[5, :], vertices[7, :]],
[vertices[6, :], vertices[7, :]],
]
unit_cell.radius[:] = thickness * np.ones(12)
unit_cell.material = fresnel.material.Material(
color=fresnel.color.linear([0, 0, 0])
)
unit_cell.material.primitive_color_mix = 1
unit_cell.material.solid = 1.0
if compass:
compass = fresnel.geometry.Cylinder(scene, N=3)
compass.points[:] = [
[[0, 0, 0], doc.lattice_cart[ind] / doc.lattice_abc[0][ind]]
for ind in range(3)
]
compass.color[0] = ([1.0, 0, 0], [1, 0, 0])
compass.color[1] = ([0, 1.0, 0], [0, 1.0, 0])
compass.color[2] = ([0, 0, 1.0], [0, 0, 1.0])
compass.radius[:] = 0.1
compass.material = fresnel.material.Material(
color=fresnel.color.linear([0, 0, 0])
)
compass.material.primitive_color_mix = 1
compass.material.solid = 0.0
compass.material.roughness = 0.5
compass.material.specular = 0.7
compass.material.spec_trans = 0.0
compass.material.metal = 0.0
def _draw_bonds(
scene: "fresnel.Scene",
sites: List[Tuple[str, List[List[float]]]],
bond_dict: Dict[Tuple[str, str], float] = None,
bond_thickness: float = 0.15,
nsites: int = None,
bond_images: bool = False,
) -> Dict[int, Set[int]]:
"""Add bonds to a fresnel scene.
Parameters:
scene: The fresnel scene.
sites: A zipped list of species and positions.
bond_dict: The dictionary defining bond lengths between species.
bond_thickness: The thickness of the bonds in the visualisation.
nsites: The number of "real" sites in the structure.
bond_images: Whether to draw bonds between images.
Returns:
Indices of bonded atoms as a dictionary of sets.
"""
import fresnel
from collections import defaultdict
bond_colours = []
bonds = []
bond_sets = defaultdict(set)
bond_dict = {} if bond_dict is None else bond_dict
for k, v in list(bond_dict.items()):
bond_dict[tuple(sorted(k))] = v
colours = get_element_colours()
for i, i_atom in enumerate(sites):
for j, j_atom in enumerate(sites):
if j <= i:
continue
species = (sites[i][0], sites[j][0])
dist = np.linalg.norm(np.array(sites[i][1]) - np.array(sites[j][1]))
if bond_dict.get(tuple(sorted(species)), -1) < dist:
continue
# If these are both image atoms
if not bond_images and nsites and i >= nsites and j >= nsites:
continue
bonds.append((sites[i][1], sites[j][1]))
bond_sets[i].add(j)
bond_sets[j].add(i)
bond_colours.append(
(
fresnel.color.linear(colours[sites[i][0]]),
fresnel.color.linear(colours[sites[j][0]]),
)
)
if bonds:
bonds_geom = fresnel.geometry.Cylinder(
scene,
N=len(bonds),
color=bond_colours,
)
bonds_geom.points[:] = bonds
bonds_geom.radius[:] = bond_thickness * np.ones(len(bonds))
bonds_geom.material.primitive_color_mix = 1
bonds_geom.outline_width = 0.05
return bond_sets
def _draw_atoms(
scene: "fresnel.Scene",
positions: List[Tuple[float, float, float]],
species: List[str],
):
"""Add atoms to the scene, coloured by species.
Parameters:
scene: The fresnel scene to add the atoms to.
positions: List of positions of the atoms to draw.
species: List of species of the atoms to draw.
Returns:
atoms: The atoms geometry.
sites: The list of sites (as species, position tuples) that were added to the scene.
"""
import fresnel
colours = get_element_colours()
radii = get_element_radii()
atoms = fresnel.geometry.Sphere(
scene, position=positions, radius=1.0, outline_width=0.1
)
atoms.material = fresnel.material.Material()
atoms.material.solid = 0.0
atoms.material.primitive_color_mix = 1
atoms.material.roughness = 0.8
atoms.material.specular = 0.9
atoms.material.spec_trans = 0.0
atoms.material.metal = 0
atoms.color[:] = fresnel.color.linear([colours[s] for s in species])
atoms.radius[:] = [0.5 * radii[s] for s in species]
[docs]def rerender_scenes_to_axes(scenes, axes, renderer=None):
"""(Re)render the scenes to the axes.
Parameters:
scenes: List of fresnel scenes to render.
axes: List or grid of matplolib axes to render to.
renderer: The renderer to use (default pathtrace).
"""
import fresnel
_axes = axes.flatten() if not isinstance(axes, list) else axes
if renderer is None:
from functools import partial
renderer = partial(fresnel.pathtrace, light_samples=32)
for i, ax in tqdm.tqdm(enumerate(_axes), desc="Rendering scenes"):
old_title = ax.title.get_text()
ax.clear()
ax.axis("off")
ax.set_aspect("equal")
ax.set_title(old_title)
if i < len(scenes):
render = renderer(scenes[i])
ax.imshow(render[:])
[docs]def fit_orthographic_camera(
scene: "fresnel.Scene", scale=1.1, margin=0.05, view="isometric", vectors=None
) -> None:
"""Alternative orthographic camera fitting, based on the equivalent fresnel function.
Parameters:
scene: The fresnel scene to visualize.
scale: The scale of the camera.
margin: The margin around the viewbox.
view: Either 'front', 'isometric', 'a', 'b', 'c', 'vector' or an actual vector.
vectors: A custom dictionary of vectors with keys 'v', 'up' and 'right' that defines
the camera basis.
"""
import fresnel
if not scene.camera:
scene.camera = fresnel.camera.Orthographic()
if view == "front":
view = "c"
_default_vectors = {
"a": dict(
v=np.array([1, 0, 0]), up=np.array([0, 0, 1]), right=np.array([0, 1, 0])
),
"b": dict(
v=np.array([0, 1, 0]), up=np.array([1, 0, 0]), right=np.array([0, 0, 1])
),
"c": dict(
v=np.array([0, 0, 1]), up=np.array([0, 1, 0]), right=np.array([1, 0, 0])
),
"isometric": dict(
v=np.array([1, 1, 1]) / np.sqrt(3),
up=np.array([-1, 2, -1]) / np.sqrt(6),
right=np.array([1, 0, -1]) / np.sqrt(2),
),
}
if view == "vector" or vectors is not None:
view = "vector"
_default_vectors["vector"] = vectors
if "angle" in vectors:
angle = np.deg2rad(vectors["angle"])
vectors["v"] = np.array([np.cos(angle), np.sin(angle), 0])
vectors["right"] = np.array([-np.sin(angle), np.cos(angle), 0])
if "right" not in vectors:
vectors["right"] = np.cross(vectors["up"], vectors["v"])
for k in ("up", "v", "right"):
vectors[k] = np.array(vectors[k]) / np.linalg.norm(vectors[k])
if isinstance(view, list):
_view = np.array(view)
_view = _view / np.linalg.norm(_view)
up = [0, 0, 1]
right = np.cross(_view, up)
up = np.cross(_view, right)
vector_view = {"v": _view, "up": up, "right": right}
_default_vectors["vector"] = vector_view
view = "vector"
# raise error if the scene is empty
if len(scene.geometry) == 0:
raise ValueError("The scene is empty")
# find the center of the scene
extents = scene.get_extents()
# choose an appropriate view automatically
if view == "auto":
xw = extents[1, 0] - extents[0, 0]
yw = extents[1, 1] - extents[0, 1]
zw = extents[1, 2] - extents[0, 2]
if zw < 0.51 * max(xw, yw):
view = "c"
else:
view = "isometric"
v = _default_vectors[view]["v"]
up = _default_vectors[view]["up"]
# make a list of points of the cube surrounding the scene
points = np.array(
[
[extents[0, 0], extents[0, 1], extents[0, 2]],
[extents[0, 0], extents[0, 1], extents[1, 2]],
[extents[0, 0], extents[1, 1], extents[0, 2]],
[extents[0, 0], extents[1, 1], extents[1, 2]],
[extents[1, 0], extents[0, 1], extents[0, 2]],
[extents[1, 0], extents[0, 1], extents[1, 2]],
[extents[1, 0], extents[1, 1], extents[0, 2]],
[extents[1, 0], extents[1, 1], extents[1, 2]],
]
)
# find the center of the box
center = (extents[0, :] + extents[1, :]) / 2
points = points - center
# determine the extent of the scene box in the up direction
up_projection = np.dot(points, up)
height = (1 + margin) * np.max(np.abs(up_projection)) * 2
# determine the extent of the scene box in the view direction
view_projection = np.dot(points, v)
view_distance = np.max(view_projection) * scale
# build the camera
scene.camera.position = center + view_distance * v
scene.camera.look_at = center
scene.camera.up = up
scene.camera.height = height
[docs]def get_image_atoms_near_cell_boundaries(
doc: Crystal, tolerance: float, bond_dict: Dict[Tuple[str, str], float] = None
) -> Tuple[List[str], List[Tuple[float, float, float]]]:
"""Returns a list of atom positions and species that are within the
distance `tolerance` of the non-image atoms.
Parameters:
doc: The crystal object to consider.
tolerance: The distance tolerance in Å.
bond_dict: If provided, use asspecies-specific tolerances.
Returns:
A list of species and a list of positions in Cartesian coordinates.
"""
from itertools import product
from scipy.spatial import KDTree
n_iter = 1
tree = KDTree(doc.positions_abs, compact_nodes=False, copy_data=True)
test_displacements = [
np.sum(np.array(doc.lattice_cart) * np.array(p), axis=-1)
for p in product(range(-1, 2), repeat=3)
if p != (0, 0, 0)
]
positions = []
species = []
added = set()
for j, pos in enumerate(doc.positions_abs):
for d, disp in enumerate(test_displacements):
image_pos = pos + disp
if tree.query_ball_point(image_pos, tolerance * 1.2):
positions.append(image_pos.tolist())
species.append(doc.atom_types[j])
added.add((j, d))
# Can also "self-consistently" check for images close to the just-added images
# Not sure this is worth enabling
if n_iter == 2:
tree = KDTree(positions, compact_nodes=False, copy_data=True)
for j, pos in enumerate(doc.positions_abs):
for d, disp in enumerate(test_displacements):
if (j, d) not in added:
image_pos = pos + disp
if tree.query_ball_point(image_pos, tolerance):
positions.append(image_pos.tolist())
species.append(doc.atom_types[j])
return positions, species
[docs]def fresnel_plot(
structures: Union[Crystal, List[Crystal]],
figsize: Optional[Tuple[float, float]] = None,
fig_cols: Optional[int] = None,
fig_rows: Optional[int] = None,
labels: Union[bool, List[str]] = True,
renderer: Optional[Callable] = None,
camera_patches: Optional[List[Optional[Dict]]] = None,
**fresnel_view_kwargs,
):
"""Visualize a series of structures as a grid of matplotlib plots.
Parameters:
structures: A list or single structure to visualize.
figsize: The size of the figure (defaults to a 3x4 box for each structure).
fig_cols: The number of columns in the figure.
fig_rows: The number of rows in the figure.
labels: If true, label structures by formula, otherwise label by the passed
list of strings.
renderer: A callable to use instead of the default `fresnel.pathtrace`.
A useful alternative could be `fresnel.preview`. Can also be used to
parameterise the fresnel rendering with e.g.,
`partial(fresnel.pathtrace, samples=64)`.
fresnel_view_kwargs: Any additional kwargs will be passed down to
`matador.utils.viz_utils.fresnel_view` to control e.g., camera angles.
The contents of any keywords containing lists with the same length as
the passed structures will be applied to each structure individually.
Returns:
The matplotlib figures, axis grid and a list of fresnel scenes.
"""
import matplotlib.pyplot as plt
if isinstance(structures, Crystal):
structures = [structures]
if fig_cols is None and fig_rows is None:
fig_cols = len(structures)
fig_rows = 1
elif fig_rows is None:
fig_rows = -(len(structures) // -fig_cols)
else:
fig_cols = -(len(structures) // -fig_rows)
if figsize is None:
figsize = (2 * fig_cols, 3 * fig_rows)
fig, axes = plt.subplots(
nrows=fig_rows, ncols=fig_cols, figsize=figsize, squeeze=False
)
scenes = []
list_kwargs = []
for kw in fresnel_view_kwargs:
if isinstance(fresnel_view_kwargs[kw], list) or isinstance(
fresnel_view_kwargs[kw], tuple
):
if len(fresnel_view_kwargs[kw]) == len(structures):
list_kwargs.append(kw)
for ind, s in enumerate(structures):
kwargs = {k: fresnel_view_kwargs[k][ind] for k in list_kwargs}
kwargs.update(
{
k: fresnel_view_kwargs[k]
for k in fresnel_view_kwargs
if k not in list_kwargs
}
)
scenes.append(fresnel_view(s, **kwargs))
if camera_patches is not None:
for ind, scene in enumerate(scenes):
patch = camera_patches[ind] or {}
if "view" in patch:
fit_orthographic_camera(scene, view=patch["view"])
if "height" in patch:
scene.camera.height = scene.camera.height + patch["height"]
if labels is not None:
if isinstance(labels, bool) and labels:
labels = [s.formula_tex for s in structures]
if labels:
for ind, ax in enumerate(axes.flatten()):
try:
ax.set_title(labels[ind])
except IndexError:
pass
rerender_scenes_to_axes(scenes, axes, renderer=renderer)
fig.set_tight_layout(True)
return fig, axes, scenes
[docs]def colour_from_ternary_concentration(
conc: Union[Tuple[float, float], Tuple[float, float, float]],
species: List[str],
alpha: float = 1,
) -> List[float]:
"""Returns a colour for a ternary composition of the given species, mixing
the configured VESTA colours.
Returns:
RGBA array.
"""
if len(conc) == 1:
return get_element_colours()[species[0]] + [alpha]
elif len(conc) == 2:
x, y = conc
z = 1 - x - y
else:
x, y, z = conc
colours = [np.asarray(get_element_colours()[s]) for s in species + ["H"]]
return ((x * colours[0] + y * colours[1] + z * colours[2])).tolist() + [alpha]