# Source code for matador.utils.cursor_utils

# coding: utf-8
# Distributed under the terms of the MIT License.

""" This submodule defines some useful generic cursor methods for
displaying, extracting and refining results from a Mongo cursor/list.

"""

from time import strftime

import numpy as np
import pymongo as pm
from matador.utils.chem_utils import (
    get_formula_from_stoich,
    get_root_source,
    get_stoich_from_formula,
    get_subscripted_formula,
)
from matador import __version__

EPS = 1e-12


def recursive_get(data, keys, _top=True):
    """Follow a chain of keys/indices into a nested container and
    return the value found at the end of the chain.

    Parameters:
        data (dict): nested dictionary to get from.
        keys (list): list of keys/indices to delve into. A key that is
            not a list or tuple is used directly as `data[keys]`.

    Keyword arguments:
        _top (bool): internal flag marking the outermost call; only the
            outermost call rewrites the error message.

    Raises:
        KeyError: if any key in the chain is missing.
        IndexError: if any element of a sublist is missing.

    """
    if not isinstance(keys, (list, tuple)):
        return data[keys]

    if len(keys) == 1:
        return data[keys[0]]

    try:
        # descend one level and let the recursion consume the remaining keys
        return recursive_get(data[keys[0]], keys[1:], _top=False)
    except (KeyError, IndexError) as exc:
        if not _top:
            # inner levels pass the error up untouched
            raise exc
        # the outermost level reports the full chain that was requested
        raise type(exc)("Recursive keys {} missing".format(keys))
def recursive_set(data, keys, value):
    """Store a value inside a nested container at the location given
    by a chain of keys/indices.

    Parameters:
        data (dict): nested dictionary to write into.
        keys (list): list of keys/indices to delve into. A key that is
            not a list or tuple is used directly as `data[keys]`.
        value: value to store under the final key.

    Raises:
        KeyError: if any intermediate keys are missing.

    """
    if not isinstance(keys, (list, tuple)):
        data[keys] = value
        return

    # walk down to the container that holds the final key...
    container = data
    for step in keys[:-1]:
        container = container[step]
    # ...and assign into it
    container[keys[-1]] = value
def display_results(
    cursor,
    energy_key="enthalpy_per_atom",
    summary=False,
    args=None,
    argstr=None,
    additions=None,
    deletions=None,
    sort=True,
    hull=False,
    markdown=False,
    latex=False,
    colour=True,
    return_str=False,
    use_source=True,
    details=False,
    per_atom=False,
    eform=False,
    source=False,
    **kwargs,
):
    """Print query results in a table, with many options for customisability.

    TODO: this function has gotten out of control and should be rewritten.

    Parameters:
        cursor (list of dict or pm.cursor.Cursor): list of matador documents

    Keyword arguments:
        summary (bool): print a summary per stoichiometry, that uses the
            lowest energy phase (requires `sort=True`).
        args: not read by this function.
        argstr (str): string to store matador initialisation command
        eform (bool): prepend energy key with "formation_".
        sort (bool): sort input cursor by the value of energy key.
        return_str (bool): return string instead of printing.
        details (bool): print extra details as an extra line per structure.
        per_atom (bool): print quantities per atom, rather than per fu.
        source (bool): print all source files associated with the structure.
        use_source (bool): use the source instead of the text id when
            displaying a structure.
        hull (bool): whether or not to print hull-style (True) or query-style
        energy_key (str or list): key (or recursive key) to print as
            energy (per atom)
        markdown (bool): whether or not to write a markdown file containing
            results (takes precedence over `latex` if both are set)
        latex (bool): whether or not to create a LaTeX table
        colour (bool): colour on-hull structures
        additions (list): list of string text_ids to be coloured green with
            a (+) or, list of indices referring to those structures in the
            cursor.
        deletions (list): list of string text_ids to be coloured red with
            a (-) or, list of indices referring to those structures in the
            cursor.
        kwargs (dict): any extra args are ignored.

    Raises:
        RuntimeError: if the cursor is empty, or if a summary is requested
            for a `pm.cursor.Cursor`.
        AssertionError: if index-style additions/deletions fall outside
            the cursor.

    Returns:
        str or None: the markdown string if `markdown` is True, the LaTeX
            string if `latex` is True, the plain table if `return_str` is
            True, else None (the table is printed).

    """
    # additions/deletions are treated as indices if their first entry is an int
    add_index_mode = False
    del_index_mode = False
    if additions:
        if isinstance(additions[0], int):
            add_index_mode = True
    if deletions:
        if isinstance(deletions[0], int):
            del_index_mode = True

    # valid indices are 0..len(cursor)-1; the index len(cursor) is out of range
    if add_index_mode:
        assert (
            max(additions) < len(cursor) and min(additions) >= 0
        ), "additions contains an index outside the cursor"
    if del_index_mode:
        assert (
            max(deletions) < len(cursor) and min(deletions) >= 0
        ), "deletions contains an index outside the cursor"

    if markdown and latex:
        latex = False

    # lists in which to accumulate the table
    struct_strings = []
    detail_strings = []
    detail_substrings = []
    source_strings = []
    formulae = []

    # tracking the last formula
    last_formula = ""

    if not cursor:
        raise RuntimeError("No structures found in cursor.")

    if markdown:
        markdown_string = "Date: {} \n".format(strftime("%H:%M %d/%m/%Y"))
        if argstr is not None:
            markdown_string += "Command: matador {} \n".format(" ".join(argstr))
        markdown_string += "Version: {} \n\n".format(__version__)

    if latex:
        latex_string = (
            "\\begin{tabular}{l r r c l l}\n"
            "\\rowcolor{gray!20}\n"
            "formula & "
            "\\thead{$\\Delta E$\\\\(meV/atom)} & "
            "spacegroup & "
            "provenance & "
            "description \\\\ \n\n"
        )
        latex_struct_strings = []
        latex_sub_style = r"\mathrm"
    else:
        latex_sub_style = ""

    header_string, units_string = _construct_header_string(
        markdown, use_source, per_atom, eform, hull, summary, energy_key
    )

    if summary and isinstance(cursor, pm.cursor.Cursor):
        raise RuntimeError("Unable to provide summary when displaying cursor object.")

    # ensure cursor is sorted by enthalpy
    if sort and isinstance(cursor, pm.cursor.Cursor):
        print("Unable to check sorting of cursor, assuming it is already sorted.")
    elif sort:
        sorted_inds = sorted(
            enumerate(cursor), key=lambda element: recursive_get(element[1], energy_key)
        )
        cursor = [ind[1] for ind in sorted_inds]
        sorted_inds = [ind[0] for ind in sorted_inds]
        # index-style additions/deletions refer to the unsorted cursor, so
        # translate them to positions in the sorted one
        if additions is not None and add_index_mode:
            additions = [sorted_inds.index(ind) for ind in additions]
        if deletions is not None and del_index_mode:
            deletions = [sorted_inds.index(ind) for ind in deletions]

    # loop over structures and create pretty output
    for ind, doc in enumerate(cursor):
        # use the formula to see if we need to update gs_enthalpy for this formula
        formula_substring = get_formula_from_stoich(
            doc["stoichiometry"], tex=latex, latex_sub_style=latex_sub_style
        )
        if not latex:
            formula_substring = get_subscripted_formula(formula_substring)

        if "encapsulated" in doc:
            formula_substring += "+CNT"

        # the first structure of each formula is shown relative to zero
        if last_formula != formula_substring:
            gs_enthalpy = 0.0

        formulae.append(formula_substring)

        struct_strings.append(
            _construct_structure_string(
                doc,
                ind,
                formula_substring,
                gs_enthalpy,
                use_source,
                colour,
                hull,
                additions,
                deletions,
                add_index_mode,
                del_index_mode,
                energy_key,
                per_atom,
                eform,
                markdown,
                latex,
            )
        )

        if latex:
            latex_struct_strings.append(
                "{:^30} {:^10} & ".format(
                    formula_substring,
                    "$\\star$" if doc.get("hull_distance") == 0 else "",
                )
            )
            latex_struct_strings[-1] += (
                "{:^20.0f} & ".format(doc.get("hull_distance") * 1000)
                if doc.get("hull_distance", 0) > 0
                else "{:^20} &".format("-")
            )
            latex_struct_strings[-1] += "{:^20} & ".format(
                doc.get("space_group", "xxx")
            )
            prov = get_guess_doc_provenance(doc["source"], doc.get("icsd"))
            if doc.get("icsd"):
                prov += " {}".format(doc["icsd"])
            latex_struct_strings[-1] += "{:^30} & ".format(prov)
            latex_struct_strings[-1] += "{:^30} \\\\".format("")

        # later structures of the same formula are shown relative to the
        # first one encountered
        if last_formula != formula_substring:
            if per_atom:
                gs_enthalpy = recursive_get(doc, energy_key)
            else:
                gs_enthalpy = (
                    recursive_get(doc, energy_key) * doc["num_atoms"] / doc["num_fu"]
                )

        last_formula = formula_substring

        if details:
            detail_string, detail_substring = _construct_detail_strings(
                doc, padding_length=len(header_string), source=source
            )
            detail_strings.append(detail_string)
            detail_substrings.append(detail_substring)

        if source:
            source_strings.append(_construct_source_string(doc["source"]))

    total_string = ""
    total_string += len(header_string) * "─" + "\n"
    total_string += header_string + "\n"
    total_string += units_string + "\n"
    total_string += len(header_string) * "─" + "\n"

    if markdown:
        markdown_string += len(header_string) * "-" + "\n"
        markdown_string += header_string + "\n"
        markdown_string += units_string + "\n"
        markdown_string += len(header_string) * "-" + "\n"

    summary_inds = []
    # filter for lowest energy phase per stoichiometry
    if summary:
        current_formula = ""
        formula_list = {}
        for ind, substring in enumerate(formulae):
            # keep only the first row seen for each formula
            if substring != current_formula and substring not in formula_list:
                current_formula = substring
                formula_list[substring] = 0
                summary_inds.append(ind)
            formula_list[substring] += 1
    else:
        summary_inds = range(len(struct_strings))

    # construct final string containing table
    if markdown:
        markdown_string += "\n".join(struct_strings[ind] for ind in summary_inds)
    elif latex:
        latex_string += "\n".join(latex_struct_strings[ind] for ind in summary_inds)
    else:
        for ind in summary_inds:
            total_string += struct_strings[ind] + "\n"
            if details:
                total_string += detail_strings[ind] + "\n"
                total_string += detail_substrings[ind] + "\n"
            if source:
                total_string += source_strings[ind] + "\n"
            if details or source:
                total_string += len(header_string) * "─" + "\n"

    if markdown:
        markdown_string += "```"
        return markdown_string
    if latex:
        latex_string += "\\end{tabular}"
        return latex_string
    if return_str:
        return total_string

    print(total_string)
def loading_bar(iterable, width=80, verbosity=0):
    """Wrap an iterable in a tqdm progress bar when tqdm can be imported
    and the verbosity is high enough, otherwise hand the iterable back
    untouched.

    Parameters:
        iterable (iterable): the thing to be iterated over.

    Keyword arguments:
        width (int): maximum number of columns to use on screen.
        verbosity (int): the bar is only created if this is at least 1.

    Returns:
        iterable: the decorated iterator, or the original iterable.

    """
    try:
        import tqdm
    except (ImportError, RuntimeError):
        return iterable

    # quiet mode: no bar
    if verbosity < 1:
        return iterable

    try:
        return tqdm.tqdm(iterable, ncols=width)
    except (ImportError, RuntimeError):
        return iterable
def set_cursor_from_array(cursor, array, key):
    """Write the entries of an array into the documents of a cursor,
    one value per document, under the given key.

    Parameters:
        cursor (list): list of dictionaries to update in place.
        array: sequence of values with the same length as the cursor.
        key (str or list): the key, or list of keys/indices passed to
            `recursive_set`, under which each value is stored.

    Raises:
        RuntimeError: if the array and cursor lengths differ.

    """
    lengths_match = len(array) == len(cursor)
    if not lengths_match:
        raise RuntimeError(
            "Trying to fit array of shape {} into cursor of length {}".format(
                np.shape(array), len(cursor)
            )
        )

    for doc, value in zip(cursor, array):
        recursive_set(doc, key, value)
def get_array_from_cursor(cursor, key, pad_missing=False):
    """Returns a numpy array of the values of a key in a cursor, where the
    key can be defined as list of keys to use with `recursive_get`.

    Parameters:
        cursor (list): list of matador dictionaries.
        key (str or list): the key to extract, or list of
            keys/subkeys/indices to extract with recursive_get.

    Keyword arguments:
        pad_missing (bool): whether to fill array with NaN's where data
            is missing.

    Raises:
        KeyError: if any document is missing that key,
            unless pad_missing is True.

    Returns:
        np.ndarray: numpy array containing results, padded
            with np.nan if key is missing and pad_missing is True.

    """
    values = []
    for ind, doc in enumerate(cursor):
        try:
            if isinstance(key, (tuple, list)):
                values.append(recursive_get(doc, key))
            else:
                values.append(doc[key])
        except KeyError:
            # always report which entry is missing the key, even if padding
            print(
                "{} missing in entry {}, with source {}".format(
                    key, ind, doc.get("source")
                )
            )
            if not pad_missing:
                raise
            # `np.nan` is the canonical spelling; the `np.NaN` alias was
            # removed in NumPy 2.0
            values.append(np.nan)

    return np.asarray(values)
def get_guess_doc_provenance(sources, icsd=None):
    """Returns a guess at the provenance of a structure from its source
    list, based on substrings of the source file names.

    Parameters:
        sources (list, str or dict): list of source file names, a single
            file name, or a document whose "source" key holds the list.

    Keyword arguments:
        icsd: if not None, recognised source files that match no earlier
            rule are labelled 'ICSD'.

    Returns:
        str: one of 'ICSD', 'SWAPS', 'GA', 'OQMD', 'PF', 'MP', 'SM',
            'DOI', 'ENUM' or, if nothing else matched, 'AIRSS'.

    """
    default_prov = "AIRSS"
    prov = default_prov

    if isinstance(sources, dict):
        sources = sources["source"]
    elif isinstance(sources, str):
        sources = [sources]

    recognised_suffixes = (".castep", ".res", ".history", ".phonon")

    for fname in sources:
        fname_with_folder = fname
        fname = fname.split("/")[-1].lower()

        # only consider known file types, or names with no extension at all
        if not (fname.endswith(recognised_suffixes) or fname.count(".") == 0):
            continue

        if any(substr in fname for substr in ["collcode", "colcode", "collo"]):
            if fname.count("-") == 2 + fname.count("oqmd") or "swap" in fname:
                prov = "SWAPS"
            else:
                prov = "ICSD"
        elif "swap" in fname_with_folder:
            prov = "SWAPS"
        elif "-ga-" in fname:
            prov = "GA"
        elif icsd is not None:
            prov = "ICSD"
        elif "oqmd" in fname:
            prov = "OQMD"
        elif "-icsd" in fname:
            prov = "ICSD"
        # only claim PF if no earlier source has set a provenance; this used
        # to test `prov is None`, which could never be true as prov starts
        # out as the default
        elif "pf-" in fname and prov == default_prov:
            prov = "PF"
        elif any(s in fname for s in ["mp-", "mp_"]) and prov != "PF":
            prov = "MP"
        elif "-sm-" in fname:
            prov = "SM"
        elif "-doi-" in fname:
            prov = "DOI"
        elif "-config_enum" in fname:
            prov = "ENUM"

    return prov
def filter_unique_structures(cursor, quiet=False, **kwargs):
    """Run `matador.fingerprints.similarity.get_uniq_cursor` on the
    cursor, optionally display which structures were kept and dropped,
    and return the cursor of unique structures.

    Parameters:
        cursor (list): list of matador documents to filter.

    Keyword arguments:
        quiet (bool): if True, skip the table of kept/removed structures.
        kwargs (dict): passed on to both `get_uniq_cursor` and
            `display_results`.

    Returns:
        list: the documents at the unique indices.

    """
    from matador.fingerprints.similarity import get_uniq_cursor

    uniq_inds, dupe_dict, _, _ = get_uniq_cursor(cursor, **kwargs)
    filtered_cursor = [cursor[ind] for ind in uniq_inds]

    if not quiet:
        display_cursor = []
        additions = []
        deletions = []
        for kept_ind, duplicate_inds in dupe_dict.items():
            # each kept structure is marked (+), followed by its duplicates (-)
            additions.append(len(display_cursor))
            display_cursor.append(cursor[kept_ind])
            if duplicate_inds:
                for dupe_ind in duplicate_inds:
                    deletions.append(len(display_cursor))
                    display_cursor.append(cursor[dupe_ind])

        if not display_cursor:
            display_cursor = filtered_cursor

        display_results(
            display_cursor,
            additions=additions,
            deletions=deletions,
            sort=True,
            use_source=True,
            **kwargs,
        )

    print("Filtered {} down to {}".format(len(cursor), len(uniq_inds)))

    return filtered_cursor
def filter_cursor(cursor, key, vals, verbosity=0):
    """Returns a cursor obeying the filter on the given key. Any
    documents that are missing the key will not be returned. Any
    documents with values that cannot be compared to floats will also
    not be returned.

    Parameters:
        cursor (list): list of dictionaries to filter.
        key (str): key to filter.
        vals (list, tuple or scalar): either 1 value or 2 values to use
            as a range. Two values select `vals[0] <= doc[key] < vals[1]`;
            otherwise the first value is a lower bound,
            `doc[key] >= vals[0]`. The values are interpreted as floats
            for comparison.

    Keyword arguments:
        verbosity (int): if greater than 0, print the filter applied and
            the number of documents before and after.

    Returns:
        list: list of dictionaries that pass the filter.

    """
    orig_cursor_len = len(cursor)

    # accept a tuple as a range too, rather than wrapping it as one value
    if isinstance(vals, tuple):
        vals = list(vals)
    elif not isinstance(vals, list):
        vals = [vals]

    min_val = float(vals[0])
    max_val = float(vals[1]) if len(vals) == 2 else None

    if verbosity > 0:
        if max_val is not None:
            print("Filtering {} <= {} < {}".format(min_val, key, max_val))
        else:
            print("Filtering {} >= {}".format(key, min_val))

    filtered_cursor = []
    for doc in cursor:
        try:
            value = doc[key]
            if max_val is None:
                passes = value >= min_val
            else:
                passes = value < max_val and value >= min_val
            if passes:
                filtered_cursor.append(doc)
        except (TypeError, ValueError, KeyError):
            # missing key or incomparable value: the document is dropped
            continue

    if verbosity > 0:
        print(orig_cursor_len, "filtered to", len(filtered_cursor), "documents.")

    return filtered_cursor
def filter_cursor_by_chempots(species, cursor):
    """Drop structures that are incompatible with the desired chemical
    potentials, storing "num_chempots" and "concentration" on every
    document that is kept.

    Parameters:
        species (list): list of chemical potential formulae.
        cursor (list): list of matador documents to filter.

    Returns:
        list: the filtered cursor.

    """
    from matador.utils.chem_utils import get_number_of_chempots

    chempot_stoichiometries = [get_stoich_from_formula(label) for label in species]

    # filter out structures with any elements with missing chem pots
    inds_to_remove = set()
    for ind, doc in enumerate(cursor):
        entry = cursor[ind]
        try:
            entry["num_chempots"] = get_number_of_chempots(
                doc, chempot_stoichiometries
            )
        except RuntimeError:
            inds_to_remove.add(ind)
            continue

        concentration = (
            entry["num_chempots"][:-1] / np.sum(entry["num_chempots"])
        ).tolist()
        entry["concentration"] = concentration

        # snap values within EPS of the limits onto exactly 0 or 1
        for idx, conc in enumerate(concentration):
            if conc < 0 + EPS:
                concentration[idx] = 0.0
            elif conc > 1 - EPS:
                concentration[idx] = 1.0

    return [doc for ind, doc in enumerate(cursor) if ind not in inds_to_remove]
def _construct_structure_string( doc, ind, formula_substring, gs_enthalpy, use_source, colour, hull, additions, deletions, add_index_mode, del_index_mode, energy_key, per_atom, eform, markdown, latex, ): """Construct the pretty output for an individual structure. Options passed from `matador.utils.cursor_utils.display_results.` Returns: str: the pretty output. """ # start with two spaces, replaced by the prefix from hull/add/del this_struct_string = " " prefix = "" suffix = "" # apply appropriate prefices and suffices to structure if hull and np.abs(doc.get("hull_distance")) <= 0.0 + 1e-12: if colour: prefix = "\033[92m" suffix = "\033[0m" this_struct_string = "* " if additions is not None: if (add_index_mode and ind in additions) or doc.get( "text_id", "_" ) in additions: this_struct_string = "+ " if colour: prefix = "\033[92m" suffix = "\033[0m" if deletions is not None: if (del_index_mode and ind in deletions) or doc.get( "text_id", "_" ) in deletions: this_struct_string = "- " if colour: prefix = "\033[91m" suffix = "\033[0m" # display the canonical name for the structure if use_source: src = get_root_source(doc["source"]) max_len = 34 this_struct_string += "{:<36.{max_len}}".format( src if len(src) <= max_len else src[: max_len - 4] + "[..]", max_len=max_len ) else: this_struct_string += "{:^24.22}".format( " ".join(doc.get("text_id", ["xxx", "yyy"])) ) # again, if we're not outputting to markdown, then flag warnings in the quality column try: if doc.get("prototype"): this_struct_string += "{:^5}".format("*p*") elif doc.get("quality", 5) == 0: this_struct_string += "{:^5}".format("!!!") else: this_struct_string += "{:^5}".format((5 - doc.get("quality", 5)) * "?") except KeyError: this_struct_string += "{:^5}".format(" ") # loop over header names and print the appropriate values if "pressure" in doc and doc["pressure"] != "xxx": this_struct_string += "{: >9.2f} ".format(doc["pressure"]) else: this_struct_string += "{:^9} ".format("xxx") try: if per_atom and 
"cell_volume" in doc and "num_atoms" in doc: this_struct_string += "{:>12.1f} ".format( doc["cell_volume"] / doc["num_atoms"] ) elif "cell_volume" in doc and "num_fu" in doc: this_struct_string += "{:>12.1f} ".format( doc["cell_volume"] / doc["num_fu"] ) else: this_struct_string += "{:^12} ".format("xxx") except Exception: this_struct_string += "{:^10} ".format("xxx") try: if hull and eform: this_struct_string += "{:>12.3f} ".format( doc["formation_" + energy_key] ) elif hull: this_struct_string += "{:>12.1f} ".format(1000 * doc["hull_distance"]) elif per_atom: this_struct_string += "{:>16.4f} ".format( recursive_get(doc, energy_key) - gs_enthalpy ) else: this_struct_string += "{:>16.4f} ".format( recursive_get(doc, energy_key) * doc["num_atoms"] / doc["num_fu"] - gs_enthalpy ) except KeyError: this_struct_string += "{:^18}".format("xxx") if latex: from matador.utils.cell_utils import get_space_group_label_latex this_struct_string += " {:^13} ".format( get_space_group_label_latex(doc.get("space_group", "xxx")) ) else: this_struct_string += " {:^13} ".format(doc.get("space_group", "xxx")) # now we add the formula column this_struct_string += " {:^13} ".format(formula_substring) if "num_fu" in doc: this_struct_string += " {:^6} ".format(int(doc["num_fu"])) else: this_struct_string += " {:^6} ".format("xxx") if "source" in doc: prov = get_guess_doc_provenance(doc["source"], doc.get("icsd")) this_struct_string += "{:^8}".format(prov) else: this_struct_string += "{:^8}".format("xxx") this_struct_string = prefix + this_struct_string + suffix return this_struct_string def _construct_source_string(sources): """From a list of sources, return a fancy string output displaying them as a list. 
""" num_sources = len(sources) if num_sources == 1: this_source_string = 11 * " " + "└──────────────────" else: this_source_string = 11 * " " + "└───────────────┬──" for num, _file in enumerate(sources): if num_sources == 1: this_source_string += "" elif num == num_sources - 1: this_source_string += (len("└────────────── ") + 11) * " " + "└──" elif num != 0: this_source_string += (len("└────────────── ") + 11) * " " + "├──" this_source_string += " " + _file if num != num_sources - 1: this_source_string += "\n" return this_source_string def _construct_detail_strings(doc, padding_length=0, source=False): """From a document, return a fancy string output displaying the desired field names and details of the structure. Parameters: doc (dict): matador document to print. Keyword arguments: source (bool): whether to allow adjust output so it can be interleaved with the source pretty output. padding_length (int): how much to pad the detail string output. Returns: (str, str): containing pretty output over two lines. """ detail_string = 11 * " " + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌ " if source: detail_substring = 11 * " " + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌ " else: detail_substring = 11 * " " + "└╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌ " if "spin_polarized" in doc: if doc["spin_polarized"]: detail_string += "S-" if "sedc_scheme" in doc: detail_string += doc["sedc_scheme"].upper() + "+" if "xc_functional" in doc: detail_string += doc["xc_functional"] else: detail_string += "xc-functional unknown!" 
if "cut_off_energy" in doc: detail_string += ", {:4.2f} eV".format(doc["cut_off_energy"]) else: detail_string += "cutoff unknown" if "external_pressure" in doc: detail_string += ", {:4.2f} GPa".format( sum(doc["external_pressure"][i][0] for i in range(3)) / 3 ) if "kpoints_mp_spacing" in doc: detail_string += ", ~{:0.3f} 1/A".format(doc["kpoints_mp_spacing"]) if "geom_force_tol" in doc: detail_string += ", {:.2f} eV/A, ".format(doc["geom_force_tol"]) if "species_pot" in doc: for species in doc["species_pot"]: detail_substring += "{}: {}, ".format(species, doc["species_pot"][species]) if "icsd" in doc: detail_substring += "ICSD-CollCode {}, ".format(doc["icsd"]) if "tags" in doc: if isinstance(doc["tags"], list): detail_substring += ", ".join(doc["tags"]) if "user" in doc: detail_substring += doc["user"] if "encapsulated" in doc and all( key in doc for key in ["cnt_chiral", "cnt_radius", "cnt_length"] ): detail_string += ( ", (n,m)=(" + str(doc["cnt_chiral"][0]) + "," + str(doc["cnt_chiral"][1]) + ")" ) detail_string += ", r={:4.2f} A".format(doc["cnt_radius"]) detail_string += ", z={:4.2f} A".format(doc["cnt_length"]) detail_string += " " + (padding_length - len(detail_string) - 1) * "╌" detail_substring += " " + (padding_length - len(detail_substring) - 1) * "╌" return detail_string, detail_substring def _construct_header_string( markdown, use_source, per_atom, eform, hull, summary, energy_key ): """Construct the header of the table, from the passed options. For arguments, see docstring for `matador.utils.cursor_utils.display_results`. Returns: (str, str): the header and units string. 
""" header_string = "" units_string = "" if not markdown: if use_source: header_string += "{:^38}".format("Source") units_string += "{:^38}".format("") else: header_string += "{:^26}".format("ID") units_string += "{:^26}".format("") header_string += "{:^5}".format("!?!") units_string += "{:^5}".format("") else: header_string += "```\n" header_string += "{:^43}".format("Root") units_string += "{:^43}".format("") header_string += "{:^10}".format("Pressure") units_string += "{:^10}".format("(GPa)") header_string += "{:^14}".format("Cell volume") if per_atom: units_string += "{:^14}".format("(ų/atom)") else: units_string += "{:^14}".format("(ų/fu)") if eform: header_string += "{:^18}".format("Formation energy") units_string += "{:^18}".format("(eV/atom)") elif hull: header_string += "{:^18}".format("Hull dist.") units_string += "{:^18}".format("(meV/atom)") elif per_atom: header_string += "{:^18}".format( " ".join(energy_key.replace("_per_atom", "").split("_")).title() ) units_string += "{:^18}".format("(eV/atom)") else: header_string += "{:^18}".format( " ".join(energy_key.replace("_per_atom", "").split("_")).title() ) units_string += "{:^18}".format("(eV/fu)") header_string += "{:^15}".format("Space group") header_string += " {:^13} ".format("Formula") header_string += "{:^8}".format("# fu") header_string += "{:^8}".format("Prov.") if summary: header_string += "{:^12}".format("Occurrences") return header_string, units_string
def index_cursors_by_structure(cursors, structure_labeller=get_root_source):
    """Regroup a dictionary of lists of structures so that it is keyed
    first by a per-structure label and then by the original cursor label.

    Args:
        cursors: A dictionary of input cursors. Keys will be used as
            labels in the output dictionary.
        structure_labeller: A function called on each structure, the
            result of which will be that structure's key in the output
            dictionary.

    Returns:
        A dictionary with one key per structure, with subkeys
        corresponding to the elements of the initial cursors,
        under which structures are stored from each cursor.

    """
    from collections import defaultdict

    structure_map = defaultdict(dict)
    for label, structures in cursors.items():
        for structure in structures:
            structure_key = structure_labeller(structure)
            structure_map[structure_key][label] = structure

    return structure_map
def _compare_field(bench, other, field):
    """Compute the absolute and, where the benchmark value is non-zero,
    relative difference of one field between a benchmark structure and
    another structure.

    Args:
        bench: The structure to compare against.
        other: The structure to compare.
        field: The field to compare (will be accessed recursively if
            iterable, e.g., `(lattice_abc, 0, 0)`).

    Raises:
        KeyError: if either structure is missing the field.

    Returns:
        A dictionary summarising the differences, with keys
        `abs_<label>`, optionally `rel_<label>`, and `<label>` holding
        the value from `other`.

    """
    if isinstance(field, str):
        field = [field]

    field_label = "_".join(map(str, field))
    abs_key = f"abs_{field_label}"
    rel_key = f"rel_{field_label}"

    try:
        benchmark_field = recursive_get(bench, field)
    except KeyError:
        raise KeyError(f"Benchmark structure is missing field {field}")

    try:
        other_field = recursive_get(other, field)
    except KeyError:
        raise KeyError(f"Trial structure is missing field {field}")

    # Normalize cell volume to per-atom so different settings can be compared
    if "cell_volume" in field:
        benchmark_field /= bench["num_atoms"]
        other_field /= other["num_atoms"]

    summary = {abs_key: benchmark_field - other_field}
    # a relative difference is only meaningful for a non-zero benchmark
    if abs(benchmark_field) > 1e-10:
        summary[rel_key] = summary[abs_key] / benchmark_field
    summary[field_label] = other_field

    return summary
def compare_structures(structures, order, fields=None):
    """Compare several versions of one structure against a benchmark
    version across various specified or default fields. Intended use is
    to compare crystal structures/energies of the "same" crystal when
    relaxed with different parameters.

    Args:
        structures: A dictionary containing the structures to compare.
            Keys will be used to label the output.
        order: The order of the input keys to use, the first of which will
            be treated as the 'benchmark' structure.
        fields: A list of fields to compare. If None, defaults to comparing
            the lattice parameters, cell volumes and stabilities
            (hull distance, formation energy).

    Raises:
        RuntimeError: if the structures do not share a single root
            source, or if the benchmark key is not in `structures`.

    Returns:
        A dictionary summarising the differences, keyed by the
        non-benchmark labels in `order`; labels absent from
        `structures` map to an empty summary.

    """
    root_sources = {get_root_source(structures[label]) for label in structures}
    if len(root_sources) != 1:
        raise RuntimeError(
            f"Not comparing structures with multiple root sources: {root_sources}"
        )

    if fields is None:
        fields = ["cell_volume", "formation_enthalpy_per_atom", "hull_distance"]
        fields.extend(("lattice_abc", row, col) for row in (0, 1) for col in (0, 1, 2))

    benchmark_label = order[0]
    if benchmark_label not in structures:
        raise RuntimeError(
            f"Benchmark parameter set {order[0]} not found in entry {root_sources}"
        )

    benchmark = structures[benchmark_label]

    results = {}
    for label in order[1:]:
        summary = {}
        if label in structures:
            trial = structures[label]
            for field in fields:
                summary.update(_compare_field(benchmark, trial, field))
        results[label] = summary

    return results
def compare_structure_cursor(cursor, order, fields=None):
    """Run `compare_structures` over every entry of a cursor that holds
    more than one version of the "same" structure.

    Args:
        cursor: A dict of dicts keyed by structure ID storing data for
            each structure at different accuracies.
        order: An ordered list of the subkeys for each structure; the
            first will be used as the benchmark.
        fields: A list of fields to compare. If None, defaults to comparing
            the lattice parameters, cell volumes and stabilities
            (hull distance, formation energy).

    Returns:
        A dictionary of dictionaries summarising the differences. Entries
        whose comparison raised a RuntimeError hold the exception instead;
        entries lacking the benchmark are skipped with a warning.

    """
    import warnings

    structure_comparator = {}
    for entry, structures in cursor.items():
        # nothing to compare unless there are at least two versions
        if len(structures) <= 1:
            continue

        if order[0] not in structures:
            warnings.warn(
                f"Benchmark parameter set {order[0]} not found for entry {entry}"
            )
            continue

        try:
            comparison = compare_structures(structures, order=order, fields=fields)
        except RuntimeError as exc:
            comparison = exc
        structure_comparator[entry] = comparison

    return structure_comparator