Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 17 additions & 19 deletions src/aiida_quantumespresso/workflows/protocols/pw/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,23 @@ default_inputs:
conv_thr_per_atom: 0.2e-9
etot_conv_thr_per_atom: 1.e-5
pseudo_family: 'SSSP/1.3/PBEsol/efficiency'
pw:
metadata:
options:
max_wallclock_seconds: 43200 # Twelve hours
withmpi: True
parameters:
CONTROL:
calculation: scf
forc_conv_thr: 1.e-4
tprnfor: True
tstress: True
SYSTEM:
nosym: False
occupations: smearing
smearing: cold
degauss: 0.02
ELECTRONS:
electron_maxstep: 80
mixing_beta: 0.4
options:
max_wallclock_seconds: 43200 # Twelve hours
withmpi: True
parameters:
CONTROL:
calculation: scf
forc_conv_thr: 1.e-4
tprnfor: True
tstress: True
SYSTEM:
nosym: False
occupations: smearing
smearing: cold
degauss: 0.02
ELECTRONS:
electron_maxstep: 80
mixing_beta: 0.4
default_protocol: balanced
protocols:
balanced:
Expand Down
60 changes: 30 additions & 30 deletions src/aiida_quantumespresso/workflows/pw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,14 @@ def define(cls, spec):
"""Define the process specification."""

super().define(spec)
spec.expose_inputs(PwCalculation, namespace='pw', exclude=('kpoints',))
spec.input(
'kpoints',
valid_type=orm.KpointsData,
required=False,
help='An explicit k-points list or mesh. Either this or `kpoints_distance` has to be provided.',

spec.expose_inputs(PwCalculation, exclude=('kpoints', 'metadata'))
spec.inputs.create_port_namespace('options').absorb(
PwCalculation.spec().inputs.get_port('metadata.options')
)
spec.input(
'kpoints_distance',
valid_type=orm.Float,
required=False,
spec.input('kpoints', valid_type=orm.KpointsData, required=False,
help='An explicit k-points list or mesh. Either this or `kpoints_distance` has to be provided.')
spec.input('kpoints_distance', valid_type=orm.Float, required=False,
help='The minimum desired distance in 1/Å between k-points in reciprocal space. The explicit k-points will '
'be generated automatically by a calculation function based on the input structure.',
)
Expand Down Expand Up @@ -207,15 +204,16 @@ def get_builder_from_protocol(
natoms = len(structure.sites)

# Update the parameters based on the protocol inputs
parameters = inputs['pw']['parameters']
parameters = inputs['parameters']

if overrides and 'pseudos' in overrides:

if overrides and 'pseudos' in overrides.get('pw', {}):
pseudos = overrides['pw']['pseudos']
pseudos = overrides['pseudos']

if sorted(pseudos.keys()) != sorted(structure.get_kind_names()):
raise ValueError(f'`pseudos` override needs one value for each of the {len(structure.kinds)} kinds.')

system_overrides = overrides['pw'].get('parameters', {}).get('SYSTEM', {})
system_overrides = overrides.get('parameters', {}).get('SYSTEM', {})

if not all(key in system_overrides for key in ('ecutwfc', 'ecutrho')):
raise ValueError(
Expand Down Expand Up @@ -280,30 +278,29 @@ def get_builder_from_protocol(

# If overrides are provided, they are considered absolute
if overrides:
parameter_overrides = overrides.get('pw', {}).get('parameters', {})
parameter_overrides = overrides.get('parameters', {})
parameters = recursive_merge(parameters, parameter_overrides)

# if tot_magnetization in overrides , remove starting_magnetization from parameters
if parameters.get('SYSTEM', {}).get('tot_magnetization') is not None:
parameters.setdefault('SYSTEM', {}).pop('starting_magnetization', None)

metadata = inputs['pw']['metadata']
inputs_options = inputs['options']
inputs_options = cls.set_default_resources(inputs_options, code.computer.scheduler_type)

if options:
metadata['options'] = recursive_merge(metadata['options'], options)

metadata['options'] = cls.set_default_resources(metadata['options'], code.computer.scheduler_type)
inputs_options = recursive_merge(inputs['options'], options)

builder = cls.get_builder()
builder.pw['code'] = code
builder.pw['pseudos'] = pseudos
builder.pw['structure'] = structure
builder.pw['parameters'] = orm.Dict(parameters)
builder.pw['metadata'] = metadata
if 'settings' in inputs['pw']:
builder.pw['settings'] = orm.Dict(inputs['pw']['settings'])
if 'parallelization' in inputs['pw']:
builder.pw['parallelization'] = orm.Dict(inputs['pw']['parallelization'])
builder['code'] = code
builder['pseudos'] = pseudos
builder['structure'] = structure
builder['parameters'] = orm.Dict(parameters)
builder['options'] = inputs_options
if 'settings' in inputs:
builder['settings'] = orm.Dict(inputs['settings'])
if 'parallelization' in inputs:
builder['parallelization'] = orm.Dict(inputs['parallelization'])
builder.clean_workdir = orm.Bool(inputs['clean_workdir'])
if 'kpoints' in inputs:
builder.kpoints = inputs['kpoints']
Expand All @@ -324,7 +321,10 @@ def setup(self):
default namelists for the ``parameters`` are set to empty dictionaries if not specified.
"""
super().setup()
self.ctx.inputs = AttributeDict(self.exposed_inputs(PwCalculation, 'pw'))
self.ctx.inputs = AttributeDict(self.exposed_inputs(PwCalculation))
self.ctx.inputs.metadata = AttributeDict({'options': self.inputs.options})
if 'disable_cache' in self.inputs.metadata:
self.ctx.inputs.metadata.disable_cache = self.inputs.metadata.disable_cache

self.ctx.inputs.parameters = self.ctx.inputs.parameters.get_dict()
self.ctx.inputs.parameters.setdefault('CONTROL', {})
Expand Down Expand Up @@ -354,7 +354,7 @@ def validate_kpoints(self):
kpoints = self.inputs.kpoints
except AttributeError:
inputs = {
'structure': self.inputs.pw.structure,
'structure': self.inputs.structure,
'distance': self.inputs.kpoints_distance,
'force_parity': self.inputs.get('kpoints_force_parity', orm.Bool(False)),
'metadata': {'call_link_label': 'create_kpoints_from_distance'},
Expand Down
Loading