-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
72 lines (58 loc) · 2.55 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import pprint
from functools import reduce
from typing import Optional

import loguru
from loguru import logger

from contrastyou.configure import dictionary_merge_by_hierachy, extract_dictionary_from_anchor, \
    extract_params_with_key_prefix, ConfigManager
def _merge_input_params(config, input_params, prefix):
    """Overlay parsed input params onto `config`, then overlay the `prefix`-ed params.

    Step 1 merges whatever ``extract_dictionary_from_anchor`` pulls out of
    `input_params` using `config` itself as the anchor (``prune_anchor=True``);
    presumably these are the params whose keys already appear in `config`
    (TODO confirm against contrastyou.configure).
    Step 2 merges whatever ``extract_params_with_key_prefix`` returns for `prefix`.
    """
    config = dictionary_merge_by_hierachy(
        config,
        extract_dictionary_from_anchor(target_dictionary=input_params,
                                       anchor_dictionary=config,
                                       prune_anchor=True))
    return dictionary_merge_by_hierachy(
        config,
        extract_params_with_key_prefix(input_params, prefix=prefix))


def separate_pretrain_finetune_configs(config_manager):
    """Build the pretraining and fine-tuning configurations from `config_manager`.

    * Pretraining config: ``config_manager.base_config`` merged with every entry of
      ``config_manager.optional_configs`` (folded left-to-right with
      ``dictionary_merge_by_hierachy``). It is then overlaid with
      ``config_manager.parsed_config``, using the params matched against it and
      the params selected by the ``"pre_"`` prefix.
    * Fine-tuning config: ``config_manager.base_config`` alone, overlaid with
      ``config_manager.parsed_config``, using the params matched against it and
      the params selected by the ``"ft_"`` prefix.

    Returns:
        tuple: ``(pretrain_config, finetune_config)``.
    """
    input_params = config_manager.parsed_config
    base_config = config_manager.base_config
    optional_configs = config_manager.optional_configs

    # reduce() without an initializer raises TypeError on an empty sequence;
    # with no optional configs there is simply nothing to merge onto the base.
    if optional_configs:
        opt_params = reduce(dictionary_merge_by_hierachy, optional_configs)
    else:
        opt_params = {}

    pretrain_config = dictionary_merge_by_hierachy(base_config, opt_params)
    pretrain_config = _merge_input_params(pretrain_config, input_params, prefix="pre_")

    # The fine-tuning config deliberately starts from the base config only,
    # without the optional configs.
    finetune_config = _merge_input_params(base_config, input_params, prefix="ft_")
    return pretrain_config, finetune_config
def logging_configs(manager: ConfigManager, logger: "loguru.Logger" = logger) -> None:
    """Log the configuration state of `manager` at INFO level.

    Emits, in order:
    * each entry of ``manager.unmerged_configs``, numbered from 0;
    * ``manager.parsed_config``;
    * the final ``manager.config``.
    Every dictionary is pretty-printed with ``pprint.pformat`` and placed on a new line.

    Args:
        manager: configuration manager to report on.
        logger: loguru logger to write to. Defaults to the module-level loguru
            logger. The annotation is quoted because ``loguru.Logger`` exists only
            in loguru's type stubs, not at runtime.
    """
    # Read all three views up front, as before, so logging cannot interleave with them.
    unmerged_dictionaries = manager.unmerged_configs
    parsed_params = manager.parsed_config
    config_dictionary = manager.config

    for index, optional_config in enumerate(unmerged_dictionaries):
        logger.info(f"optional configs {index}")
        logger.info("\n" + pprint.pformat(optional_config))
    logger.info("parsed params")
    logger.info("\n" + pprint.pformat(parsed_params))
    logger.info("merged params")
    logger.info("\n" + pprint.pformat(config_dictionary))
def find_checkpoint(trainer_folder, name="last.pth") -> Optional[str]:
    """Look for the checkpoint file `name` inside `trainer_folder`.

    Args:
        trainer_folder: directory to search.
        name: checkpoint file name to look for (default ``"last.pth"``).

    Returns:
        The joined path if it exists on disk (an INFO message is logged), otherwise ``None``.
    """
    candidate = os.path.join(trainer_folder, name)
    if not os.path.exists(candidate):
        return None
    logger.info(f"Found existing checkpoint from folder {trainer_folder}")
    return candidate
def grouper(array_list, group_num):
    """Split `array_list` into consecutive chunks of ``len(array_list) // group_num`` items.

    Items are yielded in their original order, as lists. When the length is not a
    multiple of `group_num`, the leftover items form one extra, shorter trailing
    chunk, so more than `group_num` chunks can be produced (e.g. 10 items with
    ``group_num=3`` give chunk sizes 3, 3, 3, 1). When there are fewer items than
    `group_num`, each item becomes its own chunk instead of an empty chunk being
    emitted.

    Args:
        array_list: sized iterable to split; ``len()`` is taken and it is iterated once.
        group_num: positive number of groups used to derive the chunk size.

    Yields:
        list: the next chunk of items.

    Raises:
        ValueError: if `group_num` is smaller than 1. Because this is a generator,
            the error surfaces on the first iteration.
    """
    if group_num < 1:
        raise ValueError(f"group_num must be >= 1, got {group_num}")
    # Clamp to 1: a chunk size of 0 would emit an empty leading chunk and
    # then lump every item into a single group.
    chunk_size = max(1, len(array_list) // group_num)

    batch = []
    for item in array_list:
        batch.append(item)
        if len(batch) == chunk_size:
            yield batch
            batch = []
    if batch:
        yield batch