Skip to content

Commit

Permalink
Support using variables in base config directly as normal variables.
Browse files Browse the repository at this point in the history
  • Loading branch information
mzr1996 committed Jan 10, 2022
1 parent d30e37d commit 0193d0a
Showing 1 changed file with 58 additions and 44 deletions.
102 changes: 58 additions & 44 deletions mmcv/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@
import os.path as osp
import platform
import shutil
import sys
import tempfile
import uuid
import warnings
from argparse import Action, ArgumentParser
from collections import abc
from importlib import import_module

from addict import Dict
from yapf.yapflib.yapf_api import FormatCode
Expand Down Expand Up @@ -175,9 +173,35 @@ def _substitute_base_vars(cfg, base_var_dict, base_cfg):

return cfg

@staticmethod
def parse_base_files(filename):
if filename.endswith('.py'):
Config._validate_py_syntax(filename)
with open(filename) as f:
codes = ast.parse(f.read()).body

def is_base_line(c):
return (isinstance(c, ast.Assign)
and c.targets[0].id == BASE_KEY)

base_code = next((c for c in codes if is_base_line(c)), None)
if base_code is not None:
base_code = ast.Expression(body=base_code.value)
base_files = eval(compile(base_code, '', mode='eval'))
else:
base_files = []
elif filename.endswith(('.yml', '.yaml', '.json')):
import mmcv
cfg_dict = mmcv.load(filename)
base_files = cfg_dict.get(BASE_KEY, [])
base_files = base_files if isinstance(base_files,
list) else [base_files]
return base_files

@staticmethod
def _file2dict(filename, use_predefined_variables=True):
filename = osp.abspath(osp.expanduser(filename))
cfg_dir = osp.dirname(filename)
check_file_exist(filename)
fileExtname = osp.splitext(filename)[1]
if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
Expand All @@ -188,7 +212,6 @@ def _file2dict(filename, use_predefined_variables=True):
dir=temp_config_dir, suffix=fileExtname)
if platform.system() == 'Windows':
temp_config_file.close()
temp_config_name = osp.basename(temp_config_file.name)
# Substitute predefined variables
if use_predefined_variables:
Config._substitute_predefined_vars(filename,
Expand All @@ -199,19 +222,27 @@ def _file2dict(filename, use_predefined_variables=True):
base_var_dict = Config._pre_substitute_base_vars(
temp_config_file.name, temp_config_file.name)

# Handle base files
# base_filename = cfg_dict.pop(BASE_KEY)
# base_filename = base_filename if isinstance(
# base_filename, list) else [base_filename]
base_cfg_dict = dict()
cfg_text_list = list()
for f in Config.parse_base_files(temp_config_file.name):
_cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
cfg_text_list.append(_cfg_text)
duplicate_keys = base_cfg_dict.keys() & _cfg_dict.keys()
if len(duplicate_keys) > 0:
raise KeyError('Duplicate key is not allowed among bases. '
f'Duplicate keys: {duplicate_keys}')
base_cfg_dict.update(_cfg_dict)

if filename.endswith('.py'):
temp_module_name = osp.splitext(temp_config_name)[0]
sys.path.insert(0, temp_config_dir)
Config._validate_py_syntax(filename)
mod = import_module(temp_module_name)
sys.path.pop(0)
cfg_dict = {
name: value
for name, value in mod.__dict__.items()
if not name.startswith('__')
}
# delete imported module
del sys.modules[temp_module_name]
cfg_dict = {}
with open(temp_config_file.name, 'r') as f:
content = f.read()
codeobj = compile(content, '', mode='exec')
eval(codeobj, base_cfg_dict, cfg_dict)
elif filename.endswith(('.yml', '.yaml', '.json')):
import mmcv
cfg_dict = mmcv.load(temp_config_file.name)
Expand All @@ -236,37 +267,20 @@ def _file2dict(filename, use_predefined_variables=True):
# Setting encoding explicitly to resolve coding issue on windows
cfg_text += f.read()

if BASE_KEY in cfg_dict:
cfg_dir = osp.dirname(filename)
base_filename = cfg_dict.pop(BASE_KEY)
base_filename = base_filename if isinstance(
base_filename, list) else [base_filename]

cfg_dict_list = list()
cfg_text_list = list()
for f in base_filename:
_cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
cfg_dict_list.append(_cfg_dict)
cfg_text_list.append(_cfg_text)

base_cfg_dict = dict()
for c in cfg_dict_list:
duplicate_keys = base_cfg_dict.keys() & c.keys()
if len(duplicate_keys) > 0:
raise KeyError('Duplicate key is not allowed among bases. '
f'Duplicate keys: {duplicate_keys}')
base_cfg_dict.update(c)

# Substitute base variables from strings to their actual values
cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
base_cfg_dict)
# Substitute base variables from strings to their actual values
cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
base_cfg_dict)
cfg_dict.pop(BASE_KEY, None)

base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = base_cfg_dict
cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = {
k: v
for k, v in cfg_dict.items() if not k.startswith('__')
}

# merge cfg_text
cfg_text_list.append(cfg_text)
cfg_text = '\n'.join(cfg_text_list)
# merge cfg_text
cfg_text_list.append(cfg_text)
cfg_text = '\n'.join(cfg_text_list)

return cfg_dict, cfg_text

Expand Down

0 comments on commit 0193d0a

Please sign in to comment.