-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinstall_practical_rife_weights.py
More file actions
130 lines (107 loc) · 4.15 KB
/
install_practical_rife_weights.py
File metadata and controls
130 lines (107 loc) · 4.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import argparse
import hashlib
import zipfile
from pathlib import Path, PurePosixPath
# Absolute directory containing this script; the default install location is
# derived from it so it does not depend on the current working directory.
BUNDLE_ROOT = Path(__file__).resolve().parent

# Default destination for the installed flow-network weights.
DEFAULT_INSTALL_PATH = BUNDLE_ROOT.joinpath("Practical-RIFE", "train_log", "flownet.pkl")

# Pinned SHA-256 (uppercase hex) of the only flownet.pkl this script accepts.
EXPECTED_FLOWNET_SHA256 = "6615790EFD627772917205DB291F51CD392528A157ECBB2ECAEEC3BFF8EB6DE2"
def _sha256_bytes(data: bytes) -> str:
return hashlib.sha256(data).hexdigest().upper()
def _read_weights_from_archive(archive_path: Path) -> tuple[bytes, str]:
with zipfile.ZipFile(archive_path) as archive:
preferred_matches = []
fallback_matches = []
for name in archive.namelist():
pure_path = PurePosixPath(name)
if pure_path.name != "flownet.pkl":
continue
fallback_matches.append(name)
if pure_path.as_posix().endswith("train_log/flownet.pkl"):
preferred_matches.append(name)
matches = preferred_matches or fallback_matches
if not matches:
raise FileNotFoundError(
f"Archive does not contain flownet.pkl: {archive_path}"
)
if len(matches) > 1:
joined = ", ".join(matches)
raise RuntimeError(
f"Archive contains multiple flownet.pkl candidates: {joined}"
)
chosen = matches[0]
return archive.read(chosen), chosen
def _validated_weights_bytes(
    *,
    weights_file: Path | None,
    weights_archive: Path | None,
) -> tuple[bytes, str]:
    """Load candidate weights and check them against the pinned SHA-256.

    Only one source is read: ``weights_file`` wins when both are supplied.

    Returns:
        ``(data, source)`` where ``source`` describes where the bytes came
        from; archive members are shown as ``<archive>!<member>``.

    Raises:
        RuntimeError: Neither source was supplied, or the digest of the
            loaded bytes differs from ``EXPECTED_FLOWNET_SHA256``.
    """
    if weights_file is None and weights_archive is None:
        raise RuntimeError("Expected either weights_file or weights_archive.")

    if weights_file is None:
        data, member_name = _read_weights_from_archive(weights_archive)
        source = f"{weights_archive}!{member_name}"
    else:
        data = weights_file.read_bytes()
        source = str(weights_file)

    actual_digest = _sha256_bytes(data)
    if actual_digest == EXPECTED_FLOWNET_SHA256:
        return data, source
    raise RuntimeError(
        "Unexpected Practical-RIFE weights hash. "
        f"Expected {EXPECTED_FLOWNET_SHA256}, got {actual_digest} from {source}"
    )
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the weights installer.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        Namespace with ``weights_file``, ``weights_archive``, ``install_path``
        and ``force``. Exactly one of ``weights_file`` / ``weights_archive``
        is set, because the two options form a required mutually exclusive
        group.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Verify a downloaded Practical-RIFE flownet.pkl against its pinned "
            "SHA-256 hash and install it."
        ),
    )
    # The user must name exactly one weights source.
    source_group = parser.add_mutually_exclusive_group(required=True)
    source_group.add_argument(
        "--weights-file",
        type=Path,
        help="Path to a downloaded Practical-RIFE flownet.pkl file.",
    )
    source_group.add_argument(
        "--weights-archive",
        type=Path,
        help="Path to a downloaded Practical-RIFE archive containing train_log/flownet.pkl.",
    )
    parser.add_argument(
        "--install-path",
        type=Path,
        default=DEFAULT_INSTALL_PATH,
        help=(
            "Destination for the installed Practical-RIFE weights. "
            f"Default: {DEFAULT_INSTALL_PATH}"
        ),
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Overwrite an existing installed flownet.pkl.",
    )
    return parser.parse_args(argv)
def _write_bytes_atomically(path: Path, data: bytes) -> None:
    """Write *data* to *path* without ever leaving a half-written *path*.

    The bytes are written to a hidden sibling file, which is then renamed
    over *path* with ``Path.replace``. That rename overwrites an existing
    destination. The sibling is removed if anything fails before the rename.
    """
    partial_path = path.with_name(f".{path.name}.partial")
    try:
        partial_path.write_bytes(data)
        partial_path.replace(path)
    except BaseException:
        # Best-effort cleanup of the temporary file; re-raise the real error.
        partial_path.unlink(missing_ok=True)
        raise


def main() -> None:
    """Verify downloaded Practical-RIFE weights and install them.

    Parses command-line options and checks that the chosen source exists. It
    refuses to overwrite an existing install unless ``--force`` is given,
    verifies the SHA-256 of the weights, then writes them to the install path.

    Raises:
        FileNotFoundError: The weights file or archive does not exist, or the
            archive has no ``flownet.pkl`` member.
        FileExistsError: Weights already exist at the install path and
            ``--force`` was not given.
        RuntimeError: The weights fail hash verification, or the archive
            holds several candidate members.
    """
    args = _parse_args()
    install_path = Path(args.install_path)
    if args.weights_file is not None and not args.weights_file.exists():
        raise FileNotFoundError(f"Missing weights file: {args.weights_file}")
    if args.weights_archive is not None and not args.weights_archive.exists():
        raise FileNotFoundError(f"Missing weights archive: {args.weights_archive}")
    if install_path.exists() and not args.force:
        raise FileExistsError(
            f"Weights already installed at {install_path}. Use --force to overwrite."
        )
    data, source = _validated_weights_bytes(
        weights_file=args.weights_file,
        weights_archive=args.weights_archive,
    )
    install_path.parent.mkdir(parents=True, exist_ok=True)
    # Write atomically. A truncated flownet.pkl left by an interrupted write
    # would be corrupt, and the existence check above would then block a
    # plain re-run until the user passed --force.
    _write_bytes_atomically(install_path, data)
    print(f"Installed Practical-RIFE weights from: {source}")
    print(f"Installed to: {install_path}")
    print(f"SHA256: {EXPECTED_FLOWNET_SHA256}")
if __name__ == "__main__":
    # Run the installer only when executed as a script, not on import.
    main()