# coding: utf-8
# Distributed under the terms of the MIT License.
""" This submodule implements some useful auxiliary routines for use in
the other plotting functions, and some generic routines for plotting
simple (x, y) line data (e.g. pair distribution functions), or lists of
(x, y) data against some third parameter (e.g. powder X-ray spectrum vs
voltage).
"""
import os

# File extensions that, when passed as truthy keyword arguments to a
# function wrapped by `plotting_function`, switch matplotlib to the
# non-interactive Agg backend.
SAVE_EXTS = ["pdf", "png", "svg"]

# Path to the bundled matador matplotlib style sheet, found in the `config`
# directory one level above this module's directory. Built with os.path so
# that it works with any path separator and with a relative `__file__`.
MATADOR_STYLE = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "..", "config", "matador.mplstyle"
)
def set_style(style=None):
    """Set the matplotlib style for all future plots, manually.

    Functions wrapped by `plotting_function` call this themselves before
    plotting. A style set manually here is therefore overridden by the
    next wrapped plotting call.

    Keyword arguments:
        style (:obj:`str` or :obj:`list`): a matplotlib style name/path,
            or a list of them to be applied in order. The special value
            "matador" is replaced by the bundled matador style sheet. This
            applies both on its own and as an element of a list. None is
            equivalent to "matador".

    """
    import matplotlib.pyplot as plt

    if style is None:
        style = "matador"
    if not isinstance(style, list):
        style = [style]

    # build a new list (never mutate the caller's) and resolve the
    # "matador" alias wherever it appears
    resolved = [MATADOR_STYLE if entry == "matador" else entry for entry in style]

    # apply multiple compound styles in order; later ones override earlier ones
    for entry in resolved:
        plt.style.use(entry)
def _saving_requested(args, kwargs):
    """Return True if the call arguments ask for the figure to be saved.

    A positional argument requests saving through a truthy `savefig`
    attribute. Keyword arguments request it through a truthy value for
    any extension in `SAVE_EXTS`. Every positional argument is checked,
    even if an earlier one has no `savefig` attribute.

    """
    if any(getattr(arg, "savefig", False) for arg in args):
        return True
    return any(kwargs.get(ext) for ext in SAVE_EXTS)


def _resolve_style(default_style, requested_style):
    """Return a fresh list of styles to apply, without mutating any input.

    A truthy `requested_style` takes precedence over `default_style`.
    None falls back to the matador style. Non-list styles are wrapped in
    a list, and every "matador" entry is replaced by `MATADOR_STYLE`.

    """
    style = requested_style if requested_style else default_style
    if style is None:
        return [MATADOR_STYLE]
    if not isinstance(style, list):
        style = [style]
    return [MATADOR_STYLE if entry == "matador" else entry for entry in style]


def plotting_function(function):
    """Decorator for plotting functions that picks a backend, applies the
    plot style and safely fails on X-forwarding errors.

    Before calling the wrapped function, the wrapper does the following:

    - It switches matplotlib to the Agg backend (with `force=False`) if
      saving is requested. Saving is requested by a positional argument
      with a truthy `savefig` attribute, or by a truthy `pdf`/`png`/`svg`
      keyword argument.
    - It loads the user settings (from the `config_fname` keyword argument,
      if given). It then applies, via `set_style`, the `style` keyword
      argument, falling back to the settings' `plotting.default_style` and
      then to the matador style. The style is set globally (not inside a
      context manager), so it persists after the call.

    Exceptions whose type name contains "TclError" (probably an
    X-forwarding error) are reported and swallowed, and the wrapper returns
    None. All other exceptions are re-raised unchanged.

    """
    from functools import wraps
    from matador.utils.print_utils import print_warning, print_failure
    from matador.config import load_custom_settings

    @wraps(function)
    def wrapped_plot_function(*args, **kwargs):
        """Wrap and return the plotting function."""
        result = None

        # if we're going to be saving a figure, switch to Agg to avoid X-forwarding
        if _saving_requested(args, kwargs):
            import matplotlib

            # don't warn as backend might have been set externally by e.g. Jupyter
            matplotlib.use("Agg", force=False)

        settings = load_custom_settings(
            kwargs.get("config_fname"), quiet=True, no_quickstart=True
        )

        try:
            # a new list is built, so neither kwargs["style"] nor the
            # settings object is modified
            style = _resolve_style(
                settings.get("plotting", {}).get("default_style"),
                kwargs.get("style"),
            )
            # now actually call the function
            set_style(style)
            result = function(*args, **kwargs)
        except Exception as exc:
            if "TclError" not in type(exc).__name__:
                raise
            print_failure("Caught exception: {}".format(type(exc).__name__))
            print_warning("Error message was: {}".format(exc))
            print_warning("This is probably an X-forwarding error")
            print_failure("Skipping plot...")

        return result

    return wrapped_plot_function
def get_linear_cmap(colours, num_colours=100, list_only=False):
    """Create a linear colormap from a list of colours.

    Duplicate colours are removed, keeping the first occurrence of each.
    The remaining colours become equally-spaced anchors. `num_colours`
    colours are then linearly interpolated between them, starting exactly
    at the first anchor and ending exactly at the last.

    Parameters:
        colours (:obj:`list` of :obj:`str`): list of fractional RGB/hex
            values of colours (anything accepted by
            :func:`matplotlib.colors.to_rgb`).

    Keyword arguments:
        num_colours (int): number of colours in resulting cmap
        list_only (bool): return only a list of colours

    Raises:
        ValueError: if `colours` is empty or `num_colours` is below 1.

    Returns:
        :obj:`matplotlib.colors.LinearSegmentedColormap` or :obj:`list`:
            returns a list of `num_colours` RGB colours (each a length-3
            numpy array) if `list_only` is True, otherwise
            :obj:`matplotlib.colors.LinearSegmentedColormap`.

    """
    import numpy as np
    from matplotlib.colors import LinearSegmentedColormap, to_rgb

    # dict keys keep insertion order, giving an order-stable de-duplication
    unique = list(dict.fromkeys(tuple(to_rgb(colour)) for colour in colours))
    if not unique:
        raise ValueError("At least one colour is required to build a colormap.")
    num_samples = int(num_colours)
    if num_samples < 1:
        raise ValueError(
            "num_colours must be at least 1, not {}".format(num_colours)
        )

    anchors = np.asarray(unique, dtype=float)  # (num_anchors, 3) RGB rows
    anchor_positions = np.arange(len(anchors))
    # sample positions span the first to the last anchor inclusive, so the
    # final colour is actually reached (and a single colour is repeated)
    sample_positions = np.linspace(0, len(anchors) - 1, num_samples)
    samples = np.column_stack(
        [
            np.interp(sample_positions, anchor_positions, anchors[:, channel])
            for channel in range(anchors.shape[1])
        ]
    )
    linear_cmap = list(samples)

    if list_only:
        return linear_cmap
    return LinearSegmentedColormap.from_list("linear_cmap", linear_cmap, N=num_colours)
class XYvsZPlot:
    """This class wraps plotting (x, y) lines against a third variable.

    Each line's y-values are multiplied by `y_scale` and then shifted
    vertically by that line's z-value.

    """

    def __init__(self, xys, zs, y_scale=1.0, **kwargs):
        """Construct plot from data, and plot it immediately.

        Parameters:
            xys (:obj:`list` of :obj:`list` or numpy.ndarray): list or
                array of data to be plotted. For N lines of M samples,
                this can be provided as an (N, M, 2) or (2, M, N) array,
                or corresponding list/sublist format. An array whose first
                axis has length 2 is always treated as (2, M, N) and
                transposed, so (2, M, 2) input is read in that layout.
            zs (:obj:`list`): third parameter to plot lines against. It
                holds one value per line, which is added as a vertical
                offset to that line's scaled y-values.

        Keyword arguments:
            y_scale (float): controls the scale factor between the
                arbitrary y-scale and the z-scale.
            **kwargs: stored as `plot_kwargs` and passed to `plot` every
                time the data is (re)plotted.

        Raises:
            RuntimeError: if `xys` is not 3-dimensional with a length-2
                first or last axis, or if the number of lines does not
                match the number of z-values.

        """
        import numpy as np

        self.plot_kwargs = kwargs
        _xys = np.asarray(xys)
        shape = np.shape(_xys)
        # the x/y extraction below indexes three axes, so reject anything
        # else here rather than failing later with an IndexError
        if len(shape) != 3 or (shape[0] != 2 and shape[-1] != 2):
            raise RuntimeError(
                "Data of shape {} is not compatible with XYvsZPlot.".format(shape)
            )
        # (2, M, N) -> (N, M, 2)
        if shape[0] == 2:
            _xys = _xys.T
        self._xs = _xys[:, :, 0]
        self._ys = _xys[:, :, 1]
        self._zs = np.asarray(zs).flatten()
        if len(self._zs) != np.shape(self._xs)[0]:
            raise RuntimeError("x/y and z data do not match in shape!")
        self._y_scale = y_scale
        self.plot(**self.plot_kwargs)

    @property
    def y_scale(self):
        """Scale factor applied to the y-values before the z offset."""
        return self._y_scale

    @y_scale.setter
    def y_scale(self, value):
        """Reset the y_scale and replot."""
        self._y_scale = value
        self.plot(**self.plot_kwargs)

    def get_plot(self):
        """Return the (figure, axes) pair from the most recent call to `plot`."""
        return self.fig, self.ax

    @plotting_function
    def plot(self, *args, **kwargs):
        """Plot each line as ``y_scale * y + z`` against x on a new figure.

        The figure and axes are stored as `fig` and `ax`, then
        `plt.show()` is called. This method never writes a file. Its extra
        arguments are not used by the body itself; they are only inspected
        by the `plotting_function` wrapper (e.g. `style`, `config_fname`).

        """
        import matplotlib.pyplot as plt

        fig = plt.figure()
        ax = fig.add_subplot(111)
        for xs, ys, z in zip(self._xs, self._ys, self._zs):
            ax.plot(xs, self._y_scale * ys + z)
        self.fig = fig
        self.ax = ax
        plt.show()