import os
import argparse

import torch
import torch.distributed as dist

from src.common import PATHS
from src.tasks.context_align_task import ContextAligning
from src.utils.config import Config
from src.utils.tools import control_randomness, make_dir_if_not_exists, parse_config

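# Disable HuggingFace tokenizers' internal parallelism; this silences the
# fork warning and avoids potential deadlocks when tokenization happens
# inside forked DataLoader workers.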
os.environ["TOKENIZERS_PARALLELISM"] = "false"

NOTES = "Aligning"

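# One worker process per GPU: torchrun sets the LOCAL_RANK, RANK, and
# WORLD_SIZE environment variables before each worker starts.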
def main_worker():
    # ========== Distributed Environment Setup ==========
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

    # ========== CLI Argument Parsing ==========
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/align.yaml")
    parser.add_argument("--pretraining_run_name", type=str, default="")
    parser.add_argument("--train_batch_size", type=int, default=64)
    parser.add_argument("--val_batch_size", type=int, default=64)
    parser.add_argument("--finetuning_mode", type=str, default="end-to-end")
    parser.add_argument("--init_lr", type=float, default=0.001)
    parser.add_argument("--max_epoch", type=int, default=300)
    parser.add_argument("--text_encoder_name", type=str, default="bert-base-uncased")
    parser.add_argument("--num_negatives", type=int, default=10)
    parser.add_argument("--hard_negative_mining", action="store_true", default=False)
    parser.add_argument("--cross_attend", action="store_true", default=False)
    args_cmd = parser.parse_args()

    # ========== Config & Reproducibility ==========
    config = Config(
        config_file_path=args_cmd.config,
        default_config_file_path="configs/default.yaml",
    ).parse(run_name=args_cmd.pretraining_run_name, if_override=True)
    control_randomness(config["random_seed"])

    config["device"] = local_rank
    config["rank"] = global_rank
    config["world_size"] = world_size
    config["distributed"] = True
    config["checkpoint_path"] = PATHS.CHECKPOINTS_DIR + "context_align/"
    config["result_dir"] = PATHS.RESULTS_DIR + "context_align/"
    make_dir_if_not_exists(config["checkpoint_path"])

    args = parse_config(config)
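    # Command-line flags take precedence over values loaded from the YAML configs.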
    args.pretraining_run_name = args_cmd.pretraining_run_name
    args.train_batch_size = args_cmd.train_batch_size
    args.val_batch_size = args_cmd.val_batch_size
    args.finetuning_mode = args_cmd.finetuning_mode
    args.max_epoch = args_cmd.max_epoch
    args.init_lr = args_cmd.init_lr
    args.text_encoder_name = args_cmd.text_encoder_name
    args.num_negatives = args_cmd.num_negatives
    args.hard_negative_mining = args_cmd.hard_negative_mining
    args.cross_attend = args_cmd.cross_attend

    if global_rank == 0:
        print(f"[Rank {global_rank}] Running with config:\n{args}\n")

    # ========== Main Training ==========
    task_obj = ContextAligning(args=args)
    if global_rank == 0:
        # Set up experiment logging on the main process only, so the other
        # distributed workers do not create duplicate log runs.
        task_obj.setup_logger(notes=NOTES + f" on {args.pretraining_run_name}")
    task_obj.train()
    if global_rank == 0:
        task_obj.end_logger()
    dist.destroy_process_group()

if __name__ == "__main__":
    main_worker()
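# Example launch on a single node (run name and GPU count are illustrative;
# torchrun supplies the LOCAL_RANK/RANK/WORLD_SIZE variables read above):
#   torchrun --nproc_per_node=4 context_align.py \
#       --config configs/align.yaml \
#       --pretraining_run_name my_pretrain_run \
#       --hard_negative_mining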