diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a4c8dea1c..e5a4d12ab 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -22,6 +22,11 @@ The cache will be cleared on setting ``cell_vectors`` so direct changes to the ``_cell_vectors`` attribute may cause desynchronisation. + - ``euphonic.cli.utils`` has been broken up into submodules. All the + appropriate functions are re-exported to ``__all__`` so this + should not break API in practice, but e.g. Quantity can no longer + be imported from ``euphonic.cli.utils``. + - Features - Spectrum1DCollection and Spectrum2DCollection can be indexed with diff --git a/euphonic/cli/intensity_map.py b/euphonic/cli/intensity_map.py index f32c2cf23..cab12dfb4 100644 --- a/euphonic/cli/intensity_map.py +++ b/euphonic/cli/intensity_map.py @@ -2,7 +2,6 @@ import matplotlib.style -import euphonic from euphonic import ForceConstants, QpointPhononModes, ureg import euphonic.plot from euphonic.styles import base_style diff --git a/euphonic/cli/utils/__init__.py b/euphonic/cli/utils/__init__.py new file mode 100644 index 000000000..f5aa5172a --- /dev/null +++ b/euphonic/cli/utils/__init__.py @@ -0,0 +1,41 @@ +from ._band_structure import ( + _bands_from_force_constants, + _convert_labels_to_fractions, + _get_break_points, + _get_tick_labels, + _insert_gamma, +) +from ._cli_parser import _get_cli_parser +from ._dw import _get_debye_waller +from ._grids import _get_energy_bins, _get_q_distance, _grid_spec_from_args +from ._kwargs import ( + _brille_calc_modes_kwargs, + _calc_modes_kwargs, + _plot_label_kwargs, +) +from ._loaders import load_data_from_file +from ._pdos import _arrange_pdos_groups, _get_pdos_weighting +from ._plotting import _compose_style, _get_title, matplotlib_save_or_show + +__all__ = [ + '_arrange_pdos_groups', + '_bands_from_force_constants', + '_brille_calc_modes_kwargs', + '_calc_modes_kwargs', + '_compose_style', + '_convert_labels_to_fractions', + '_get_break_points', + '_get_cli_parser', + '_get_debye_waller', + '_get_energy_bins', + '_get_pdos_weighting', + '_get_q_distance', + '_get_tick_labels', + '_get_title', + '_grid_spec_from_args', + '_insert_gamma', + '_plot_label_kwargs', + 'load_data_from_file', + 'matplotlib_save_or_show', +] + diff --git a/euphonic/cli/utils/_band_structure.py b/euphonic/cli/utils/_band_structure.py new file mode 100644 index 000000000..7d94f0c2b --- /dev/null +++ b/euphonic/cli/utils/_band_structure.py @@ -0,0 +1,164 @@ +"""Band structure utilities""" + +from collections.abc import Iterable, Sequence +from fractions import Fraction +from typing import Any, TypedDict + +import numpy as np +import seekpath + +from euphonic import ( + ForceConstants, + QpointFrequencies, + QpointPhononModes, + Quantity, +) +from euphonic.util import ( + spglib_new_errors, +) + + +def _get_tick_labels(bandpath: dict) -> list[tuple[int, str]]: + """Convert x-axis labels from seekpath format to euphonic format + + i.e.:: + + ['L', '', '', 'X', '', 'GAMMA'] --> + + [(0, 'L'), (3, 'X'), (5, '$\\Gamma$')] + """ + label_indices = np.where(bandpath['explicit_kpoints_labels'])[0] + labels = (r'$\Gamma$' if label == 'GAMMA' else label + for label in + np.take(bandpath['explicit_kpoints_labels'], label_indices)) + return list(zip(label_indices, labels, strict=True)) + + +def _get_break_points(bandpath: dict) -> list[int]: + """Get information about band path labels and break points + + Parameters + ---------- + bandpath + Bandpath dictionary from Seekpath + + Returns + ------- + break_points + Indices at which the spectrum should be split into subplots + """ + # Find break points between continuous spectra: wherever there are two + # adjacent labels + labels = np.array(bandpath['explicit_kpoints_labels']) + + special_point_bools = np.fromiter( + map(bool, labels), dtype=bool) + + # [T F F T T F T] -> [F F T T F T] AND [T F F T T F] = [F F F T F F] -> 3, + adjacent_non_empty_labels = ( + special_point_bools[:-1] & special_point_bools[1:] + ) + + adjacent_different_labels = (labels[:-1] != labels[1:]) + + break_points = np.where( + adjacent_non_empty_labels & adjacent_different_labels, + )[0] + return (break_points + 1).tolist() + + +def _insert_gamma(bandpath: dict) -> None: + """Modify seekpath.get_explicit_k_path() results; duplicate Gamma + + This enables LO-TO splitting to be included + """ + import numpy as np + gamma_indices = np.where( + np.array(bandpath['explicit_kpoints_labels'][1:-1]) == 'GAMMA', + )[0] + 1 + + rel_kpts = bandpath['explicit_kpoints_rel'].tolist() + labels = bandpath['explicit_kpoints_labels'] + for i in reversed(gamma_indices.tolist()): + rel_kpts.insert(i, [0., 0., 0.]) + labels.insert(i, 'GAMMA') + + bandpath['explicit_kpoints_rel'] = np.array(rel_kpts) + bandpath['explicit_kpoints_labels'] = labels + + # These unused properties have been invalidated: safer + # to leave None than incorrect values + bandpath['explicit_kpoints_abs'] = None + bandpath['explicit_kpoints_linearcoord'] = None + bandpath['explicit_segments'] = None + + +XTickLabels = list[tuple[int, str]] +SplitArgs = dict[str, Any] + + +# Dictionary returned by seekpath.get_explicit_k_path_orig_cell +# Not a complete specification, but these are the parts we care about. +class BandpathDict(TypedDict, total=False): + explicit_kpoints_labels: Sequence[str] + explicit_kpoints_rel: Iterable[float] + is_supercell: bool + + +def _convert_labels_to_fractions( + bandpath: BandpathDict, *, limit: int = 32) -> None: + """Replace high-symmetry labels in seekpath data with simple fractions + + bandpath: + dict from seekpath.get_explicit_k_path_orig_cell + + limit: + maximum numerator value for float rounded to fraction + """ + for i, (label, qpt) in enumerate(zip(bandpath['explicit_kpoints_labels'], + bandpath['explicit_kpoints_rel'], + strict=True)): + if label: + qpt_label = ' '.join(str(Fraction(x).limit_denominator(limit)) + for x in qpt) + bandpath['explicit_kpoints_labels'][i] = qpt_label + + +def _bands_from_force_constants(data: ForceConstants, + q_distance: Quantity, + insert_gamma: bool = True, + frequencies_only: bool = False, + **calc_modes_kwargs, +) -> tuple[QpointPhononModes | QpointFrequencies, XTickLabels, SplitArgs]: + structure = data.crystal.to_spglib_cell() + with spglib_new_errors(): + bandpath = seekpath.get_explicit_k_path_orig_cell( + structure, + reference_distance=q_distance.to('1 / angstrom').magnitude) + + if insert_gamma: + _insert_gamma(bandpath) + + # If input structure was not primitive, the high-symmetry points are not + # really meaningful. Indicate this by converting to numerical form. + if bandpath.get('is_supercell'): + _convert_labels_to_fractions(bandpath, limit=32) + + x_tick_labels = _get_tick_labels(bandpath) + split_args = {'indices': _get_break_points(bandpath)} + + print( + 'Computing phonon modes: {n_modes} modes across {n_qpts} q-points' + .format(n_modes=(data.crystal.n_atoms * 3), + n_qpts=len(bandpath['explicit_kpoints_rel']))) + qpts = bandpath['explicit_kpoints_rel'] + + if frequencies_only: + modes = data.calculate_qpoint_frequencies(qpts, + reduce_qpts=False, + **calc_modes_kwargs) + else: + modes = data.calculate_qpoint_phonon_modes(qpts, + reduce_qpts=False, + **calc_modes_kwargs) + return modes, x_tick_labels, split_args diff --git a/euphonic/cli/utils.py b/euphonic/cli/utils/_cli_parser.py similarity index 52% rename from euphonic/cli/utils.py rename to euphonic/cli/utils/_cli_parser.py index 5dc7e7e5e..f76225833 100644 --- a/euphonic/cli/utils.py +++ b/euphonic/cli/utils/_cli_parser.py @@ -1,504 +1,16 @@ from argparse import ( ArgumentDefaultsHelpFormatter, ArgumentParser, - Namespace, _ArgumentGroup, ) -from collections.abc import Collection, Iterable, Sequence -from contextlib import suppress -from fractions import Fraction -import json -import os -from pathlib import Path -import re -from typing import Any, TypedDict - -import numpy as np -from pint import UndefinedUnitError -import seekpath - -from euphonic import ( - Crystal, - DebyeWaller, - ForceConstants, - QpointFrequencies, - QpointPhononModes, - Quantity, - Spectrum1D, - Spectrum1DCollection, - ureg, -) +from collections.abc import Collection + from euphonic.util import ( dedent_and_fill, format_error, - mp_grid, - spglib_new_errors, ) -def _load_euphonic_json(filename: str | os.PathLike, - frequencies_only: bool = False, -) -> QpointPhononModes | QpointFrequencies | ForceConstants: - with open(filename) as f: - data = json.load(f) - - if 'force_constants' in data: - return ForceConstants.from_json_file(filename) - if 'frequencies' in data: - if 'eigenvectors' in data and not frequencies_only: - return QpointPhononModes.from_json_file(filename) - return QpointFrequencies.from_json_file(filename) - - msg = format_error( - f'Could not identify Euphonic data in JSON file ({filename}).', - fix='Ensure JSON file contains "force_constants" or "frequencies".', - ) - raise ValueError(msg) - - -def _load_phonopy_file(filename: str | os.PathLike, - frequencies_only: bool = False, -) -> QpointPhononModes | QpointFrequencies | ForceConstants: - path = Path(filename) - loaded_data = None - if not frequencies_only: - with suppress(KeyError, RuntimeError): - # KeyError will be raised if it is actually a force - # constants file, RuntimeError will be raised if - # it only contains q-point frequencies (no eigenvectors) - - loaded_data = QpointPhononModes.from_phonopy( - path=path.parent, phonon_name=path.name) - - # Try to read QpointFrequencies if loading QpointPhononModes has - # failed, or has been specifically requested with frequencies_only - if frequencies_only or loaded_data is None: - with suppress(KeyError): - loaded_data = QpointFrequencies.from_phonopy( - path=path.parent, phonon_name=path.name) - - if loaded_data is None: - phonopy_kwargs: dict[str, str | os.PathLike] = {} - phonopy_kwargs['path'] = path.parent - if (path.parent / 'BORN').is_file(): - phonopy_kwargs['born_name'] = 'BORN' - # Set summary_name and fc_name depending on input file - if path.suffix == '.hdf5': - if (path.parent / 'phonopy.yaml').is_file(): - phonopy_kwargs['summary_name'] = 'phonopy.yaml' - phonopy_kwargs['fc_name'] = path.name - else: - msg = format_error( - 'Missing phonopy.yaml.', - reason = ( - 'Phonopy force_constants.hdf5 file ' - 'must be accompanied by information ' - 'about atomic masses, supercell, etc.' - ), - fix='Ensure phonopy.yaml provided.', - ) - raise ValueError(msg) - elif path.suffix in ('.yaml', '.yml'): - phonopy_kwargs['summary_name'] = path.name - # Assume this is a (renamed?) phonopy.yaml file - if (janus_fc := _janus_fc_filename(path)).is_file(): - phonopy_kwargs['fc_name'] = janus_fc.name - elif (path.parent / 'force_constants.hdf5').is_file(): - phonopy_kwargs['fc_name'] = 'force_constants.hdf5' - else: - phonopy_kwargs['fc_name'] = 'FORCE_CONSTANTS' - loaded_data = ForceConstants.from_phonopy(**phonopy_kwargs) - - return loaded_data - - -def _janus_fc_filename(phonopy_file: Path) -> Path: - """Get corresponding force_constants filename following Janus convention - - If the filename follows the pattern "seedname-phonopy.yml" this will be - "seedname-force_constants.hdf5" in the same directory. - - Otherwise, return Path.cwd(), which will fail an .is_file() check. - """ - - re_match = re.match(r'(?P.+)-phonopy\.(?Pya?ml)', - phonopy_file.name) - if re_match: - seedname = re_match.group('seedname') - return Path(phonopy_file.parent / f'{seedname}-force_constants.hdf5') - return Path.cwd() - - -def load_data_from_file(filename: str | os.PathLike, - frequencies_only: bool = False, - verbose: bool = False, -) -> QpointPhononModes | QpointFrequencies | ForceConstants: - """ - Load phonon mode or force constants data from file - - Parameters - ---------- - filename - The file with a path - frequencies_only - If true only reads frequencies (not eigenvectors) from the - file. Only applies if the file is not a force constants - file. - - Returns - ------- - file_data - """ - castep_qpm_suffixes = ('.phonon',) - castep_fc_suffixes = ('.castep_bin', '.check') - phonopy_suffixes = ('.hdf5', '.yaml', '.yml') - - path = Path(filename) - if path.suffix in castep_qpm_suffixes: - if frequencies_only: - data = QpointFrequencies.from_castep(path) - else: - data = QpointPhononModes.from_castep(path) - elif path.suffix in castep_fc_suffixes: - data = ForceConstants.from_castep(path) - elif path.suffix == '.json': - data = _load_euphonic_json(path, frequencies_only) - elif path.suffix in phonopy_suffixes: - data = _load_phonopy_file(path, frequencies_only) - else: - msg = format_error( - f'File format ({path.suffix}) not recognised.', - reason=f""" - CASTEP force constants data for - import should have extension from {castep_fc_suffixes}, CASTEP - phonon mode data for import should have extension - '{castep_qpm_suffixes}', data from Phonopy should have extension - from {phonopy_suffixes}, data from Euphonic should have extension - '.json'.""", - fix='Ensure file format in known formats.', - ) - raise ValueError(msg) - if verbose: - print(f'{data.__class__.__name__} data was loaded') - return data - - -def matplotlib_save_or_show(save_filename: Path | str | None = None) -> None: - """ - Save or show the current matplotlib plot. - Show if save_filename is not None which by default it is. - - Parameters - ---------- - save_filename - The file to save the plot in - """ - import matplotlib.pyplot as plt - if save_filename is not None: - plt.savefig(save_filename) - print(f'Saved plot to {Path(save_filename).resolve()}') - else: - plt.show() - - -def _get_q_distance(length_unit_string: str, q_distance: float) -> Quantity: - """ - Parse user arguments to obtain reciprocal-length spacing Quantity - """ - try: - length_units = ureg(length_unit_string) - except UndefinedUnitError as err: - msg = format_error( - 'Length unit not known', - reason='Euphonic uses Pint for units.', - fix=("Try 'angstrom' or 'bohr'. " - "Metric prefixes are also allowed, e.g 'nm'."), - ) - raise ValueError(msg) from err - recip_length_units = 1 / length_units - return q_distance * recip_length_units - - -def _get_energy_bins( - modes: QpointPhononModes | QpointFrequencies, - n_ebins: int, emin: float | None = None, - emax: float | None = None, - headroom: float = 1.05) -> Quantity: - """ - Gets recommended energy bins, in same units as modes.frequencies. - emin and emax are assumed to be in the same units as - modes.frequencies, if not provided the min/max values of - modes.frequencies are used to find the bin limits - """ - if emin is None: - # Subtract small amount from min frequency - otherwise due to unit - # conversions binning of this frequency can vary with different - # architectures/lib versions, making it difficult to test - emin_room = 1e-5*ureg('meV').to(modes.frequencies.units).magnitude - emin = min(np.min(modes.frequencies.magnitude - emin_room), 0.) - if emax is None: - emax = np.max(modes.frequencies.magnitude) * headroom - if emin >= emax: - msg = format_error( - 'Maximum energy should be greater than minimum.', - fix='Check --e-min and --e-max arguments.', - ) - raise ValueError(msg) - return np.linspace(emin, emax, n_ebins) * modes.frequencies.units - - -def _get_tick_labels(bandpath: dict) -> list[tuple[int, str]]: - """Convert x-axis labels from seekpath format to euphonic format - - i.e.:: - - ['L', '', '', 'X', '', 'GAMMA'] --> - - [(0, 'L'), (3, 'X'), (5, '$\\Gamma$')] - """ - - label_indices = np.where(bandpath['explicit_kpoints_labels'])[0] - labels = [bandpath['explicit_kpoints_labels'][i] for i in label_indices] - - for i, label in enumerate(labels): - if label == 'GAMMA': - labels[i] = r'$\Gamma$' - - return list(zip(label_indices, labels, strict=True)) - - -def _get_break_points(bandpath: dict) -> list[int]: - """Get information about band path labels and break points - - Parameters - ---------- - bandpath - Bandpath dictionary from Seekpath - - Returns - ------- - break_points - Indices at which the spectrum should be split into subplots - """ - # Find break points between continuous spectra: wherever there are two - # adjacent labels - labels = np.array(bandpath['explicit_kpoints_labels']) - - special_point_bools = np.fromiter( - map(bool, labels), dtype=bool) - - # [T F F T T F T] -> [F F T T F T] AND [T F F T T F] = [F F F T F F] -> 3, - adjacent_non_empty_labels = np.logical_and(special_point_bools[:-1], - special_point_bools[1:]) - - adjacent_different_labels = (labels[:-1] != labels[1:]) - - break_points = np.where(np.logical_and(adjacent_non_empty_labels, - adjacent_different_labels))[0] - return (break_points + 1).tolist() - - -def _insert_gamma(bandpath: dict) -> None: - """Modify seekpath.get_explicit_k_path() results; duplicate Gamma - - This enables LO-TO splitting to be included - """ - import numpy as np - gamma_indices = np.where( - np.array(bandpath['explicit_kpoints_labels'][1:-1]) == 'GAMMA', - )[0] + 1 - - rel_kpts = bandpath['explicit_kpoints_rel'].tolist() - labels = bandpath['explicit_kpoints_labels'] - for i in reversed(gamma_indices.tolist()): - rel_kpts.insert(i, [0., 0., 0.]) - labels.insert(i, 'GAMMA') - - bandpath['explicit_kpoints_rel'] = np.array(rel_kpts) - bandpath['explicit_kpoints_labels'] = labels - - # These unused properties have been invalidated: safer - # to leave None than incorrect values - bandpath['explicit_kpoints_abs'] = None - bandpath['explicit_kpoints_linearcoord'] = None - bandpath['explicit_segments'] = None - - -XTickLabels = list[tuple[int, str]] -SplitArgs = dict[str, Any] - - -# Dictionary returned by seekpath.get_explicit_k_path_orig_cell -# Not a complete specification, but these are the parts we care about. -class BandpathDict(TypedDict, total=False): - explicit_kpoints_labels: Sequence[str] - explicit_kpoints_rel: Iterable[float] - is_supercell: bool - - -def _convert_labels_to_fractions( - bandpath: BandpathDict, *, limit: int = 32) -> None: - """Replace high-symmetry labels in seekpath data with simple fractions - - bandpath: - dict from seekpath.get_explicit_k_path_orig_cell - - limit: - maximum numerator value for float rounded to fraction - """ - for i, (label, qpt) in enumerate(zip(bandpath['explicit_kpoints_labels'], - bandpath['explicit_kpoints_rel'], - strict=True)): - if label: - qpt_label = ' '.join(str(Fraction(x).limit_denominator(limit)) - for x in qpt) - bandpath['explicit_kpoints_labels'][i] = qpt_label - - -def _bands_from_force_constants(data: ForceConstants, - q_distance: Quantity, - insert_gamma: bool = True, - frequencies_only: bool = False, - **calc_modes_kwargs, -) -> tuple[QpointPhononModes | QpointFrequencies, XTickLabels, SplitArgs]: - structure = data.crystal.to_spglib_cell() - with spglib_new_errors(): - bandpath = seekpath.get_explicit_k_path_orig_cell( - structure, - reference_distance=q_distance.to('1 / angstrom').magnitude) - - if insert_gamma: - _insert_gamma(bandpath) - - # If input structure was not primitive, the high-symmetry points are not - # really meaningful. Indicate this by converting to numerical form. - if bandpath.get('is_supercell'): - _convert_labels_to_fractions(bandpath, limit=32) - - x_tick_labels = _get_tick_labels(bandpath) - split_args = {'indices': _get_break_points(bandpath)} - - print( - 'Computing phonon modes: {n_modes} modes across {n_qpts} q-points' - .format(n_modes=(data.crystal.n_atoms * 3), - n_qpts=len(bandpath['explicit_kpoints_rel']))) - qpts = bandpath['explicit_kpoints_rel'] - - if frequencies_only: - modes = data.calculate_qpoint_frequencies(qpts, - reduce_qpts=False, - **calc_modes_kwargs) - else: - modes = data.calculate_qpoint_phonon_modes(qpts, - reduce_qpts=False, - **calc_modes_kwargs) - return modes, x_tick_labels, split_args - - -def _grid_spec_from_args(crystal: Crystal, - grid: Sequence[int] | None = None, - grid_spacing: Quantity = 0.1 * ureg('1/angstrom'), - ) -> tuple[int, int, int]: - """Get Monkorst-Pack mesh divisions from user arguments""" - if grid: - grid_spec = tuple(grid) - else: - grid_spec = crystal.get_mp_grid_spec(spacing=grid_spacing) - return grid_spec - - -def _get_debye_waller(temperature: Quantity, - fc: ForceConstants, - grid: Sequence[int] | None = None, - grid_spacing: Quantity = 0.1 * ureg('1/angstrom'), - **calc_modes_kwargs, - ) -> DebyeWaller: - """Generate Debye-Waller data from force constants and grid specification - """ - mp_grid_spec = _grid_spec_from_args(fc.crystal, grid=grid, - grid_spacing=grid_spacing) - print('Calculating Debye-Waller factor on {} q-point grid' - .format(' x '.join(map(str, mp_grid_spec)))) - dw_phonons = fc.calculate_qpoint_phonon_modes( - mp_grid(mp_grid_spec), **calc_modes_kwargs) - return dw_phonons.calculate_debye_waller(temperature) - - -def _get_pdos_weighting(cl_arg_weighting: str) -> str | None: - """ - Convert CL --weighting to weighting for calculate_pdos - e.g. --weighting coherent-dos to weighting=coherent - """ - if cl_arg_weighting == 'dos': - pdos_weighting = None - else: - idx = cl_arg_weighting.rfind('-') - if idx == -1: - msg = format_error( - f'Unexpected weighting "{cl_arg_weighting}"', - fix='Check weighting argument. Should be e.g. "coherent-dos".', - ) - raise ValueError(msg) - pdos_weighting = cl_arg_weighting[:idx] - return pdos_weighting - - -def _arrange_pdos_groups(pdos: Spectrum1DCollection, - cl_arg_pdos: Sequence[str], - ) -> Spectrum1D | Spectrum1DCollection: - """ - Convert PDOS returned by calculate_pdos to PDOS/DOS - wanted as CL output according to --pdos - """ - dos = pdos.sum() - if cl_arg_pdos is not None: - # Only label total DOS if there are other lines on the plot - dos.metadata['label'] = 'Total' - pdos = pdos.group_by('species') - for line_metadata in pdos.metadata['line_data']: - line_metadata['label'] = line_metadata['species'] - if len(cl_arg_pdos) > 0: - pdos = pdos.select(species=cl_arg_pdos) - dos = pdos - else: - dos = Spectrum1DCollection.from_spectra([dos, *pdos]) - return dos - - -def _plot_label_kwargs(args: Namespace, default_xlabel: str = '', - default_ylabel: str = '') -> dict[str, str]: - """Collect title/label arguments that can be passed to plot_nd - """ - plot_kwargs = {'title': args.title, - 'xlabel': default_xlabel, - 'ylabel': default_ylabel} - if args.ylabel is not None: - plot_kwargs['ylabel'] = args.ylabel - if args.xlabel is not None: - plot_kwargs['xlabel'] = args.xlabel - return plot_kwargs - - -def _calc_modes_kwargs(args: Namespace) -> dict[str, Any]: - """ - Collect arguments that can be passed to - ForceConstants.calculate_qpoint_phonon_modes() - """ - return {'asr': args.asr, 'dipole_parameter': args.dipole_parameter, - 'use_c': args.use_c, 'n_threads': args.n_threads} - -def _brille_calc_modes_kwargs(args: Namespace) -> dict[str, Any]: - """ - Collect arguments that can be passed to - BrilleInterpolator.calculate_qpoint_phonon_modes() - """ - if args.n_threads is None: - # Nothing specified, allow defaults - return {} - - return {'useparallel': args.n_threads > 1, 'threads': args.n_threads} - - def _get_cli_parser(features: Collection[str] = {}, # noqa: C901 conflict_handler: str = 'error', ) -> tuple[ArgumentParser, @@ -921,55 +433,3 @@ def _get_cli_parser(features: Collection[str] = {}, # noqa: C901 ) return parser, sections - - -MplStyle = str | dict[str, str] - - -def _compose_style( - *, user_args: Namespace, base: list[MplStyle] | None, - ) -> list[MplStyle]: - """Combine user-specified style options with default stylesheets - - Args: - user_args: from _get_cli_parser().parse_args() - base: Euphonic default styles for this plot - - N.B. matplotlib applies styles from left to right, so the right-most - elements of the list take the highest priority. This function builds a - list in the order: - - [base style(s), user style(s), CLI arguments] - """ - - style = base if not user_args.no_base_style and base is not None else [] - - if user_args.style: - style += user_args.style - - # Explicit args take priority over any other - explicit_args = {} - for user_arg, mpl_property in {'cmap': 'image.cmap', - 'fontsize': 'font.size', - 'font': 'font.sans-serif', - 'linewidth': 'lines.linewidth', - 'figsize': 'figure.figsize'}.items(): - if getattr(user_args, user_arg, None): - explicit_args.update({mpl_property: getattr(user_args, user_arg)}) - - if 'font.sans-serif' in explicit_args: - explicit_args.update({'font.family': 'sans-serif'}) - - if 'figure.figsize' in explicit_args: - dimensioned_figsize = [dim * ureg(user_args.figsize_unit) - for dim in explicit_args['figure.figsize']] - explicit_args['figure.figsize'] = [dim.to('inches').magnitude - for dim in dimensioned_figsize] - - style.append(explicit_args) - return style - - -def _get_title(filename: str, title: str | None = None) -> str: - """Get a plot title: either user-provided string, or from filename""" - return title if title is not None else Path(filename).stem diff --git a/euphonic/cli/utils/_dw.py b/euphonic/cli/utils/_dw.py new file mode 100644 index 000000000..eec955bd0 --- /dev/null +++ b/euphonic/cli/utils/_dw.py @@ -0,0 +1,30 @@ +from collections.abc import Sequence + +from euphonic import ( + DebyeWaller, + ForceConstants, + Quantity, + ureg, +) +from euphonic.util import ( + mp_grid, +) + +from ._grids import _grid_spec_from_args + + +def _get_debye_waller(temperature: Quantity, + fc: ForceConstants, + grid: Sequence[int] | None = None, + grid_spacing: Quantity = 0.1 * ureg('1/angstrom'), + **calc_modes_kwargs, + ) -> DebyeWaller: + """Generate Debye-Waller data from force constants and grid specification + """ + mp_grid_spec = _grid_spec_from_args(fc.crystal, grid=grid, + grid_spacing=grid_spacing) + print('Calculating Debye-Waller factor on {} q-point grid' + .format(' x '.join(map(str, mp_grid_spec)))) + dw_phonons = fc.calculate_qpoint_phonon_modes( + mp_grid(mp_grid_spec), **calc_modes_kwargs) + return dw_phonons.calculate_debye_waller(temperature) diff --git a/euphonic/cli/utils/_grids.py b/euphonic/cli/utils/_grids.py new file mode 100644 index 000000000..14ba39607 --- /dev/null +++ b/euphonic/cli/utils/_grids.py @@ -0,0 +1,73 @@ +"""Get sensible bins/grids from minimal user input""" + + +import numpy as np +from pint import UndefinedUnitError + +from euphonic import ( + Crystal, + QpointFrequencies, + QpointPhononModes, + Quantity, + ureg, +) +from euphonic.util import ( + format_error, +) + + +def _get_q_distance(length_unit_string: str, q_distance: float) -> Quantity: + """ + Parse user arguments to obtain reciprocal-length spacing Quantity + """ + try: + length_units = ureg(length_unit_string) + except UndefinedUnitError as err: + msg = format_error( + 'Length unit not known', + reason='Euphonic uses Pint for units.', + fix=("Try 'angstrom' or 'bohr'. " + "Metric prefixes are also allowed, e.g 'nm'."), + ) + raise ValueError(msg) from err + recip_length_units = 1 / length_units + return q_distance * recip_length_units + + +def _get_energy_bins( + modes: QpointPhononModes | QpointFrequencies, + n_ebins: int, emin: float | None = None, + emax: float | None = None, + headroom: float = 1.05) -> Quantity: + """ + Gets recommended energy bins, in same units as modes.frequencies. + emin and emax are assumed to be in the same units as + modes.frequencies, if not provided the min/max values of + modes.frequencies are used to find the bin limits + """ + if emin is None: + # Subtract small amount from min frequency - otherwise due to unit + # conversions binning of this frequency can vary with different + # architectures/lib versions, making it difficult to test + emin_room = 1e-5*ureg('meV').to(modes.frequencies.units).magnitude + emin = min(np.min(modes.frequencies.magnitude - emin_room), 0.) + if emax is None: + emax = np.max(modes.frequencies.magnitude) * headroom + if emin >= emax: + msg = format_error( + 'Maximum energy should be greater than minimum.', + fix='Check --e-min and --e-max arguments.', + ) + raise ValueError(msg) + return np.linspace(emin, emax, n_ebins) * modes.frequencies.units + + +def _grid_spec_from_args(crystal: Crystal, + grid: list[int] | None = None, + grid_spacing: Quantity = 0.1 * ureg('1/angstrom'), + ) -> tuple[int, int, int]: + """Get Monkorst-Pack mesh divisions from user arguments""" + if grid: + return tuple(grid) + + return crystal.get_mp_grid_spec(spacing=grid_spacing) diff --git a/euphonic/cli/utils/_kwargs.py b/euphonic/cli/utils/_kwargs.py new file mode 100644 index 000000000..a4dcd2a49 --- /dev/null +++ b/euphonic/cli/utils/_kwargs.py @@ -0,0 +1,35 @@ +"""Functions extracting useful groups of kwargs from the argparse Namespace""" + +from argparse import Namespace +from typing import Any + + +def _plot_label_kwargs(args: Namespace, default_xlabel: str = '', + default_ylabel: str = '') -> dict[str, str]: + """Collect title/label arguments that can be passed to plot_nd + """ + return {'title': args.title, + 'xlabel': getattr(args, 'xlabel', None) or default_xlabel, + 'ylabel': getattr(args, 'ylabel', None) or default_ylabel} + + +def _calc_modes_kwargs(args: Namespace) -> dict[str, Any]: + """ + Collect arguments that can be passed to + ForceConstants.calculate_qpoint_phonon_modes() + """ + return {'asr': args.asr, 'dipole_parameter': args.dipole_parameter, + 'use_c': args.use_c, 'n_threads': args.n_threads} + + +def _brille_calc_modes_kwargs(args: Namespace) -> dict[str, Any]: + """ + Collect arguments that can be passed to + BrilleInterpolator.calculate_qpoint_phonon_modes() + """ + if args.n_threads is None: + # Nothing specified, allow defaults + return {} + + return {'useparallel': args.n_threads > 1, 'threads': args.n_threads} + diff --git a/euphonic/cli/utils/_loaders.py b/euphonic/cli/utils/_loaders.py new file mode 100644 index 000000000..05b454453 --- /dev/null +++ b/euphonic/cli/utils/_loaders.py @@ -0,0 +1,163 @@ +from contextlib import suppress +import json +import os +from pathlib import Path +import re + +from euphonic import ( + ForceConstants, + QpointFrequencies, + QpointPhononModes, +) +from euphonic.util import ( + format_error, +) + + +def _load_euphonic_json(filename: str | os.PathLike, + frequencies_only: bool = False, +) -> QpointPhononModes | QpointFrequencies | ForceConstants: + with open(filename) as f: + data = json.load(f) + + match data: + case {'force_constants': _}: + return ForceConstants.from_json_file(filename) + case {'frequencies': _, 'eigenvectors': _} if not frequencies_only: + return QpointPhononModes.from_json_file(filename) + case {'frequencies': _}: + return QpointFrequencies.from_json_file(filename) + case _: + msg = format_error( + 'Could not identify Euphonic data in ' + f'JSON file ({filename}).', + fix=('Ensure JSON file contains "force_constants" ' + 'or "frequencies".'), + ) + raise ValueError(msg) + + +def _load_phonopy_file(filename: str | os.PathLike, + frequencies_only: bool = False, +) -> QpointPhononModes | QpointFrequencies | ForceConstants: + path = Path(filename) + loaded_data = None + if not frequencies_only: + with suppress(KeyError, RuntimeError): + # KeyError will be raised if it is actually a force + # constants file, RuntimeError will be raised if + # it only contains q-point frequencies (no eigenvectors) + + loaded_data = QpointPhononModes.from_phonopy( + path=path.parent, phonon_name=path.name) + + # Try to read QpointFrequencies if loading QpointPhononModes has + # failed, or has been specifically requested with frequencies_only + if frequencies_only or loaded_data is None: + with suppress(KeyError): + loaded_data = QpointFrequencies.from_phonopy( + path=path.parent, phonon_name=path.name) + + if loaded_data is None: + phonopy_kwargs: dict[str, str | os.PathLike] = {} + phonopy_kwargs['path'] = path.parent + if path.with_name('BORN').is_file(): + phonopy_kwargs['born_name'] = 'BORN' + # Set summary_name and fc_name depending on input file + if path.suffix == '.hdf5': + if path.with_name('phonopy.yaml').is_file(): + phonopy_kwargs['summary_name'] = 'phonopy.yaml' + phonopy_kwargs['fc_name'] = path.name + else: + msg = format_error( + 'Missing phonopy.yaml.', + reason = ( + 'Phonopy force_constants.hdf5 file ' + 'must be accompanied by information ' + 'about atomic masses, supercell, etc.' + ), + fix='Ensure phonopy.yaml provided.', + ) + raise ValueError(msg) + elif path.suffix in ('.yaml', '.yml'): + phonopy_kwargs['summary_name'] = path.name + # Assume this is a (renamed?) phonopy.yaml file + if (janus_fc := _janus_fc_filename(path)).is_file(): + phonopy_kwargs['fc_name'] = janus_fc.name + elif path.with_name('force_constants.hdf5').is_file(): + phonopy_kwargs['fc_name'] = 'force_constants.hdf5' + else: + phonopy_kwargs['fc_name'] = 'FORCE_CONSTANTS' + loaded_data = ForceConstants.from_phonopy(**phonopy_kwargs) + + return loaded_data + + +def _janus_fc_filename(phonopy_file: Path) -> Path: + """Get corresponding force_constants filename following Janus convention + + If the filename follows the pattern "seedname-phonopy.yml" this will be + "seedname-force_constants.hdf5" in the same directory. + + Otherwise, return Path.cwd(), which will fail an .is_file() check. + """ + + if re_match := re.match( + r'(?P.+)-phonopy\.(?Pya?ml)', phonopy_file.name): + seedname = re_match.group('seedname') + return phonopy_file.with_name(f'{seedname}-force_constants.hdf5') + return Path.cwd() + + +def load_data_from_file(filename: str | os.PathLike, + frequencies_only: bool = False, + verbose: bool = False, +) -> QpointPhononModes | QpointFrequencies | ForceConstants: + """ + Load phonon mode or force constants data from file + + Parameters + ---------- + filename + The file with a path + frequencies_only + If true only reads frequencies (not eigenvectors) from the + file. Only applies if the file is not a force constants + file. + + Returns + ------- + file_data + """ + castep_qpm_suffixes = ('.phonon',) + castep_fc_suffixes = ('.castep_bin', '.check') + phonopy_suffixes = ('.hdf5', '.yaml', '.yml') + + path = Path(filename) + if path.suffix in castep_qpm_suffixes: + if frequencies_only: + data = QpointFrequencies.from_castep(path) + else: + data = QpointPhononModes.from_castep(path) + elif path.suffix in castep_fc_suffixes: + data = ForceConstants.from_castep(path) + elif path.suffix == '.json': + data = _load_euphonic_json(path, frequencies_only) + elif path.suffix in phonopy_suffixes: + data = _load_phonopy_file(path, frequencies_only) + else: + msg = format_error( + f'File format ({path.suffix}) not recognised.', + reason=f""" + CASTEP force constants data for + import should have extension from {castep_fc_suffixes}, CASTEP + phonon mode data for import should have extension + '{castep_qpm_suffixes}', data from Phonopy should have extension + from {phonopy_suffixes}, data from Euphonic should have extension + '.json'.""", + fix='Ensure file format in known formats.', + ) + raise ValueError(msg) + if verbose: + print(f'{data.__class__.__name__} data was loaded') + return data diff --git a/euphonic/cli/utils/_pdos.py b/euphonic/cli/utils/_pdos.py new file mode 100644 index 000000000..7ff1e86b6 --- /dev/null +++ b/euphonic/cli/utils/_pdos.py @@ -0,0 +1,50 @@ +from collections.abc import Sequence + +from euphonic import ( + Spectrum1D, + Spectrum1DCollection, +) +from euphonic.util import ( + format_error, +) + + +def _arrange_pdos_groups(pdos: Spectrum1DCollection, + cl_arg_pdos: Sequence[str], + ) -> Spectrum1D | Spectrum1DCollection: + """ + Convert PDOS returned by calculate_pdos to PDOS/DOS + wanted as CL output according to --pdos + """ + dos = pdos.sum() + if cl_arg_pdos is not None: + # Only label total DOS if there are other lines on the plot + dos.metadata['label'] = 'Total' + pdos = pdos.group_by('species') + for line_metadata in pdos.metadata['line_data']: + line_metadata['label'] = line_metadata['species'] + if len(cl_arg_pdos) > 0: + pdos = pdos.select(species=cl_arg_pdos) + dos = pdos + else: + dos = Spectrum1DCollection.from_spectra([dos, *pdos]) + return dos + + +def _get_pdos_weighting(cl_arg_weighting: str) -> str | None: + """ + Convert CL --weighting to weighting for calculate_pdos + e.g. --weighting coherent-dos to weighting=coherent + """ + if cl_arg_weighting == 'dos': + return None + + idx = cl_arg_weighting.rfind('-') + if idx == -1: + msg = format_error( + f'Unexpected weighting "{cl_arg_weighting}"', + fix='Check weighting argument. Should be e.g. "coherent-dos".', + ) + raise ValueError(msg) + + return cl_arg_weighting[:idx] diff --git a/euphonic/cli/utils/_plotting.py b/euphonic/cli/utils/_plotting.py new file mode 100644 index 000000000..201605c2d --- /dev/null +++ b/euphonic/cli/utils/_plotting.py @@ -0,0 +1,77 @@ +from argparse import ( + Namespace, +) +from pathlib import Path + +from euphonic import ( + ureg, +) + +MplStyle = str | dict[str, str] + + +def matplotlib_save_or_show(save_filename: Path | str | None = None) -> None: + """ + Save or show the current matplotlib plot. + Show if save_filename is not None which by default it is. + + Parameters + ---------- + save_filename + The file to save the plot in + """ + import matplotlib.pyplot as plt + if save_filename is not None: + plt.savefig(save_filename) + print(f'Saved plot to {Path(save_filename).resolve()}') + else: + plt.show() + + +def _compose_style( + *, user_args: Namespace, base: list[MplStyle] | None, + ) -> list[MplStyle]: + """Combine user-specified style options with default stylesheets + + Args: + user_args: from _get_cli_parser().parse_args() + base: Euphonic default styles for this plot + + N.B. matplotlib applies styles from left to right, so the right-most + elements of the list take the highest priority. This function builds a + list in the order: + + [base style(s), user style(s), CLI arguments] + """ + + style = base if not user_args.no_base_style and base is not None else [] + + if user_args.style: + style += user_args.style + + # Explicit args take priority over any other + explicit_args = {} + for user_arg, mpl_property in {'cmap': 'image.cmap', + 'fontsize': 'font.size', + 'font': 'font.sans-serif', + 'linewidth': 'lines.linewidth', + 'figsize': 'figure.figsize'}.items(): + if getattr(user_args, user_arg, None): + explicit_args.update({mpl_property: getattr(user_args, user_arg)}) + + if 'font.sans-serif' in explicit_args: + explicit_args.update({'font.family': 'sans-serif'}) + + if 'figure.figsize' in explicit_args: + dimensioned_figsize = [dim * ureg(user_args.figsize_unit) + for dim in explicit_args['figure.figsize']] + explicit_args['figure.figsize'] = [dim.to('inches').magnitude + for dim in dimensioned_figsize] + + style.append(explicit_args) + return style + + +def _get_title(filename: str, title: str | None = None) -> str: + """Get a plot title: either user-provided string, or from filename""" + return title if title is not None else Path(filename).stem diff --git a/meson.build b/meson.build index 53d2afab5..866f67e26 100644 --- a/meson.build +++ b/meson.build @@ -20,7 +20,11 @@ py_src = { 'euphonic/cli': ['__init__.py', 'brille_convergence.py', 'dispersion.py', 'dos.py', 'intensity_map.py', 'optimise_dipole_parameter.py', 'powder_map.py', - 'show_sampling.py', 'utils.py'], + 'show_sampling.py'], + 'euphonic/cli/utils': [ + '__init__.py', '_band_structure.py', '_cli_parser.py', '_dw.py', + '_grids.py', '_kwargs.py', '_loaders.py', '_pdos.py', '_plotting.py' + ], 'euphonic/data': ['__init__.py', 'bluebook.json', 'sears-1992.json', 'reciprocal_spectroscopy_definitions.txt'], 'euphonic/readers': ['__init__.py', 'castep.py', 'phonopy.py'], diff --git a/tests_and_analysis/test/euphonic_test/test_cli_utils.py b/tests_and_analysis/test/euphonic_test/test_cli_utils.py index 2811e0cd7..ddc02a128 100644 --- a/tests_and_analysis/test/euphonic_test/test_cli_utils.py +++ b/tests_and_analysis/test/euphonic_test/test_cli_utils.py @@ -8,9 +8,9 @@ _get_cli_parser, _get_energy_bins, _get_q_distance, - _load_phonopy_file, load_data_from_file, ) +from euphonic.cli.utils._loaders import _load_phonopy_file from euphonic.ureg import Quantity from tests_and_analysis.test.utils import get_data_path @@ -60,7 +60,7 @@ def test_load_data_extension_error(): @pytest.fixture def mocked_fc_from_phonopy(mocker): - from euphonic.cli.utils import ForceConstants + from euphonic.cli.utils._loaders import ForceConstants mocked_method = mocker.patch.object(ForceConstants, 'from_phonopy') mocked_method.return_value = None @@ -75,8 +75,6 @@ def test_find_force_constants(mocked_fc_from_phonopy): Rather than add a whole script test case, here we use mocking to check a particular internal method path """ - from euphonic.cli.utils import _load_phonopy_file - phonopy_file = get_data_path( 'phonopy_files', 'NaCl', 'phonopy_nofc.yaml', )