Skip to content

Commit

Permalink
Using the same StructureData for scf, nscf, bands.
Browse files Browse the repository at this point in the history
This make provenance more linear and robust
  • Loading branch information
mikibonacci committed Dec 4, 2024
1 parent ad368a7 commit bf0a97e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 18 deletions.
9 changes: 6 additions & 3 deletions src/aiida_koopmans/engine/aiida.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ def run(self, step: Step):
return

self.step_data['steps'][step.uid] = {} # maybe not needed
builder = get_builder_from_ase(calculator=step, step_data=self.step_data) # ASE to AiiDA conversion. put some error message if the conversion fails
builder, self.step_data = get_builder_from_ase(calculator=step, step_data=self.step_data) # ASE to AiiDA conversion. put some error message if the conversion fails
running = submit(builder)
print(f"Running workchain {running.pk} for step {step.uid}")
# running = aiidawrapperwchain.submit(builder) # in the non-blocking case.

# The below will be passed to the context, so we will need to store also the instance of the submitted workchain, if in KoopmansWorkChain.
self.step_data['steps'][step.uid] = {'workchain': running.pk, } #'remote_folder': running.outputs.remote_folder}

self.set_status(step, Status.RUNNING)
Expand Down Expand Up @@ -154,8 +157,8 @@ def get_pseudopotential(self, library: str, element: str):
pseudo_data = None
for pseudo in qb.all():
with tempfile.TemporaryDirectory() as dirpath:
temp_file = pathlib.Path(dirpath) / (pseudo[0].attributes['element'] + '.upf')
with pseudo[0].open(pseudo[0].attributes['element'] + '.upf', 'rb') as handle:
temp_file = pathlib.Path(dirpath) / (pseudo[0].base.attributes.all['element'] + '.upf')
with pseudo[0].open(pseudo[0].base.attributes.all['element'] + '.upf', 'rb') as handle:
temp_file.write_bytes(handle.read())
pseudo_data = read_pseudo_file(temp_file)

Expand Down
33 changes: 18 additions & 15 deletions src/aiida_koopmans/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,19 @@ def get_PwBaseWorkChain_from_ase(pw_calculator, step_data=None):
aiida_inputs = step_data['configuration']
calc_params = pw_calculator._parameters

if isinstance(pw_calculator.atoms, AtomsKoopmans):
ase_atoms = Atoms.fromdict(pw_calculator.atoms.todict())

# WE NEED TO USE THE INPUT STRUCTURE OF SCF, WHEN WE DO NSCF
structure = orm.StructureData(ase=ase_atoms) # TODO: only one sdata, stored in the step_data dict. but some cases have output structure diff from input.
structure = None
parent_folder = None
for step, val in step_data['steps'].items():
if "scf" in str(step) and ("nscf" in pw_calculator.uid or "bands" in pw_calculator.uid):
scf = orm.load_node(val["workchain"])
structure = scf.inputs.pw.structure
parent_folder = scf.outputs.remote_folder
break

if not structure:
if isinstance(pw_calculator.atoms, AtomsKoopmans):
ase_atoms = Atoms.fromdict(pw_calculator.atoms.todict())
structure = orm.StructureData(ase=ase_atoms)

pw_overrides = {
"CONTROL": {},
Expand Down Expand Up @@ -82,14 +90,10 @@ def get_PwBaseWorkChain_from_ase(pw_calculator, step_data=None):
# here we need explicit kpoints
builder.kpoints.set_kpoints(calc_params["kpts"].kpts,cartesian=False) # TODO: check cartesian false is correct.

parent_calculators = [f[0].uid for f in pw_calculator.linked_files.values() if f[0] is not None]
if len(set(parent_calculators)) > 1:
raise ValueError("More than one parent calculator found.")
elif len(set(parent_calculators)) == 1:
if "remote_folder" in step_data['steps'][parent_calculators[0]]:
builder.pw.parent_folder = orm.load_node(step_data['steps'][parent_calculators[0]]["remote_folder"])

return builder
if parent_folder:
builder.pw.parent_folder = parent_folder

return builder, step_data

def get_Wannier90BandsWorkChain_builder_from_ase(w90_calculator, step_data=None):
# get the builder from WannierizeWorkflow, but after we already initialized a Wannier90Calculator.
Expand Down Expand Up @@ -232,8 +236,7 @@ def get_Wannier90BandsWorkChain_builder_from_ase(w90_calculator, step_data=None)
if nscf.inputs.pw.parameters.get_dict()["SYSTEM"]["nspin"]>1: params_pw2wannier90['inputpp']["spin_component"] = "up"
builder.pw2wannier90.pw2wannier90.parameters = orm.Dict(dict=params_pw2wannier90)


return builder
return builder, step_data

## Here we have the mapping for the calculators initialization. used in the `aiida_calculate_trigger`.
mapping_calculators = {
Expand Down

0 comments on commit bf0a97e

Please sign in to comment.