-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpractical_rife_model.py
More file actions
346 lines (299 loc) · 11.2 KB
/
practical_rife_model.py
File metadata and controls
346 lines (299 loc) · 11.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
# SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import argparse
import sys
from importlib import import_module
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from executorch.backends.arm.quantizer import get_symmetric_quantization_config
from executorch.backends.arm.quantizer.arm_quantizer import QuantizationSpec
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from PIL import Image
from torchao.quantization.pt2e.observer import FixedQParamsObserver
from export_executorch import export_model
from export_scenario import build_scenario_from_edge_program
from install_practical_rife_weights import DEFAULT_INSTALL_PATH
from utils import mode_artifact_dirs, reset_generated_artifact_dirs
BUNDLE_ROOT = Path(__file__).resolve().parent
PRACTICAL_RIFE_ROOT = BUNDLE_ROOT / "Practical-RIFE"
DEFAULT_ARTIFACTS_ROOT = BUNDLE_ROOT / "artifacts" / "model_practical_rife"
DEFAULT_INPUT0_SOURCE = PRACTICAL_RIFE_ROOT / "demo" / "I0_0.png"
DEFAULT_INPUT1_SOURCE = PRACTICAL_RIFE_ROOT / "demo" / "I0_1.png"
DEFAULT_TIMESTEP = 0.5
DEFAULT_INPUT_SIZE = (448, 256)
DEFAULT_SCENARIO_FILENAME = "scenario.json"
DEFAULT_SCALE_LIST = (16.0, 8.0, 4.0, 2.0, 1.0)
PRACTICAL_RIFE_BASE_GRID_MODULE_NAME = (
"flownet.grid_position_builder.base_grid_builder"
)
PRACTICAL_RIFE_FLOW_TO_GRID_MODULE_NAME = "flownet.grid_position_builder.flow_to_grid"
PRACTICAL_RIFE_GRID_COMBINE_MODULE_NAME = (
"flownet.grid_position_builder.grid_combine"
)
def _image_snorm_qspec() -> QuantizationSpec:
    """Fixed-qparams symmetric int8 spec for image tensors.

    Scale 1/127 with zero point 0 maps the quantized range [-127, 127]
    onto real values in [-1, 1].
    """
    fixed_observer = FixedQParamsObserver.with_args(
        scale=1.0 / 127.0,
        zero_point=0,
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
        quant_min=-127,
        quant_max=127,
    )
    return QuantizationSpec(
        dtype=torch.int8,
        observer_or_fake_quant_ctr=fixed_observer,
        quant_min=-127,
        quant_max=127,
        qscheme=torch.per_tensor_symmetric,
        is_dynamic=False,
    )
def _unit_interval_qspec() -> QuantizationSpec:
    """Fixed-qparams affine int8 spec for unit-interval tensors.

    Scale 1/254 with zero point -127 maps the quantized range [-127, 127]
    onto real values in [0, 1].
    """
    fixed_observer = FixedQParamsObserver.with_args(
        scale=1.0 / 254.0,
        zero_point=-127,
        dtype=torch.qint8,
        qscheme=torch.per_tensor_affine,
        quant_min=-127,
        quant_max=127,
    )
    return QuantizationSpec(
        dtype=torch.int8,
        observer_or_fake_quant_ctr=fixed_observer,
        quant_min=-127,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
    )
DEFAULT_QUANTIZED_INPUT_QSPECS = [
_image_snorm_qspec(),
_image_snorm_qspec(),
_unit_interval_qspec(),
]
DEFAULT_QUANTIZED_OUTPUT_QSPECS = [_image_snorm_qspec()]
def _grid_position_int16_snorm_qspec() -> QuantizationSpec:
    """Fixed-qparams symmetric int16 spec for grid-position tensors.

    Scale 1/32767 with zero point 0 maps the quantized range
    [-32767, 32767] onto real values in [-1, 1].
    """
    fixed_observer = FixedQParamsObserver.with_args(
        scale=1.0 / 32767.0,
        zero_point=0,
        dtype=torch.int16,
        qscheme=torch.per_tensor_symmetric,
        quant_min=-32767,
        quant_max=32767,
    )
    return QuantizationSpec(
        dtype=torch.int16,
        observer_or_fake_quant_ctr=fixed_observer,
        quant_min=-32767,
        quant_max=32767,
        qscheme=torch.per_tensor_symmetric,
        is_dynamic=False,
    )
def _grid_position_module_quant_config() -> QuantizationConfig:
    """Quant config for the flow-to-grid module.

    Input activations use the default symmetric per-channel int8 config;
    the output activation uses the fixed int16 signed-norm grid spec.
    Weights and bias are left unquantized.
    """
    default_config = get_symmetric_quantization_config(
        is_per_channel=True,
        is_qat=False,
        is_dynamic=False,
        act_qmin=-127,
        act_qmax=127,
        weight_qmin=-127,
        weight_qmax=127,
    )
    return QuantizationConfig(
        input_activation=default_config.input_activation,
        output_activation=_grid_position_int16_snorm_qspec(),
        weight=None,
        bias=None,
    )
def _shape_only_grid_position_quant_config() -> QuantizationConfig:
    """Quant config for the base-grid module: fixed int16 signed-norm
    spec on both input and output activations, no weight/bias quantization."""
    int16_spec = _grid_position_int16_snorm_qspec()
    return QuantizationConfig(
        input_activation=int16_spec,
        output_activation=int16_spec,
        weight=None,
        bias=None,
    )
def _grid_position_combine_quant_config() -> QuantizationConfig:
    """Quant config for the grid-combine module.

    This is intentionally the same configuration as the shape-only
    grid-position config (fixed int16 signed-norm spec on both input and
    output activations, no weight/bias quantization). Delegate to that
    builder instead of duplicating the body so the two cannot drift apart.
    """
    return _shape_only_grid_position_quant_config()
def _parse_args() -> argparse.Namespace:
    """Parse command-line options for the Practical-RIFE export script."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--mode", choices=("float", "quantized"), default="quantized")
    cli.add_argument(
        "--input0",
        type=Path,
        default=DEFAULT_INPUT0_SOURCE,
        help=f"Input image 0. Default: {DEFAULT_INPUT0_SOURCE}",
    )
    cli.add_argument(
        "--input1",
        type=Path,
        default=DEFAULT_INPUT1_SOURCE,
        help=f"Input image 1. Default: {DEFAULT_INPUT1_SOURCE}",
    )
    cli.add_argument(
        "--timestep-path",
        type=Path,
        help="Optional .npy file containing the interpolation timestep tensor.",
    )
    cli.add_argument(
        "--weights-path",
        type=Path,
        default=DEFAULT_INSTALL_PATH,
        help=f"Path to flownet.pkl. Default: {DEFAULT_INSTALL_PATH}",
    )
    return cli.parse_args()
def _load_rgba_image(image_path: str | Path) -> torch.Tensor:
    """Load an image file as a float32 RGBA tensor in [0, 1].

    Returns a tensor of shape (1, 4, H, W) — HWC pixels permuted to CHW
    with a leading batch dimension added.
    """
    pixels = np.asarray(Image.open(image_path).convert("RGBA"), dtype=np.float32)
    chw = torch.from_numpy(pixels / 255.0).permute(2, 0, 1)
    return chw.unsqueeze(0)
def _save_rgba_image(tensor: torch.Tensor, image_path: str | Path) -> None:
    """Write a (1, 4, H, W) tensor with values in [0, 1] as an RGBA image.

    Values are clamped to [0, 1] before conversion to uint8; parent
    directories are created as needed.
    """
    destination = Path(image_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    hwc = tensor.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0)
    pixels = (hwc * 255.0).round().to(torch.uint8).cpu().numpy()
    Image.fromarray(pixels, mode="RGBA").save(destination)
def _practical_rife_ifnet_class():
    """Return the IFNet class from the vendored Practical-RIFE checkout.

    Prepends the vendored directory to sys.path (once) so the upstream
    ``train_log.IFNet_HDv3`` module resolves.

    Raises:
        FileNotFoundError: if the vendored checkout is missing.
    """
    if not PRACTICAL_RIFE_ROOT.exists():
        raise FileNotFoundError(
            f"Missing vendored Practical-RIFE files at {PRACTICAL_RIFE_ROOT}"
        )
    root_str = str(PRACTICAL_RIFE_ROOT)
    if root_str not in sys.path:
        sys.path.insert(0, root_str)
    return import_module("train_log.IFNet_HDv3").IFNet
class ModelPracticalRife(nn.Module):
    """RGBA frame-interpolation wrapper around the vendored Practical-RIFE IFNet.

    Loads the flownet weights from ``weights_path``, freezes all
    parameters, and exposes a forward pass that interpolates between two
    RGBA frames at the given timestep.
    """

    def __init__(self, weights_path: str | Path) -> None:
        super().__init__()
        ifnet_cls = _practical_rife_ifnet_class()
        self.flownet = ifnet_cls()
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from a trusted source.
        state_dict = torch.load(Path(weights_path), map_location="cpu")
        # Strip only a *leading* DataParallel-style "module." prefix.
        # (The previous str.replace removed the substring anywhere in the
        # key, which would corrupt any key containing "module." mid-name.)
        prefix = "module."
        converted = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        # strict=False tolerates checkpoint/model key mismatches.
        self.flownet.load_state_dict(converted, strict=False)
        self.flownet.eval()
        for parameter in self.flownet.parameters():
            parameter.requires_grad_(False)

    def forward(
        self,
        image0: torch.Tensor,
        image1: torch.Tensor,
        timestep: torch.Tensor,
    ) -> torch.Tensor:
        """Interpolate between two RGBA frames.

        Args:
            image0: first frame, RGBA; only the RGB channels are used.
            image1: second frame, RGBA; only the RGB channels are used.
            timestep: interpolation timestep tensor passed to the flownet.

        Returns:
            The final merged frame with a fully-opaque alpha channel
            appended (RGBA).
        """
        rgb0 = image0[:, :3]
        rgb1 = image1[:, :3]
        images = torch.cat((rgb0, rgb1), dim=1)
        _flow_list, _mask, merged = self.flownet(
            images, timestep, list(DEFAULT_SCALE_LIST)
        )
        # Re-attach an alpha channel of ones to the last (finest) merge.
        alpha = torch.ones_like(merged[-1][:, :1])
        return torch.cat((merged[-1], alpha), dim=1)
def _prepare_inputs(
    *,
    input0_source: Path,
    input1_source: Path,
    timestep_path: Path | None,
    python_dir: Path,
) -> tuple[Path, Path, Path]:
    """Stage the model inputs inside ``python_dir``.

    Both source images are converted to RGBA, resized to
    DEFAULT_INPUT_SIZE, and copied under their original file names. The
    timestep is materialized as ``input_timestep.npy`` — either copied
    from ``timestep_path`` or filled with DEFAULT_TIMESTEP.

    Returns:
        (staged image0 path, staged image1 path, staged timestep path).

    Raises:
        FileNotFoundError: if a source image or the timestep file is missing.
    """
    if not input0_source.exists():
        raise FileNotFoundError(f"Missing input image: {input0_source}")
    if not input1_source.exists():
        raise FileNotFoundError(f"Missing input image: {input1_source}")
    staged0 = python_dir / input0_source.name
    staged1 = python_dir / input1_source.name
    for source, target in ((input0_source, staged0), (input1_source, staged1)):
        resized = Image.open(source).convert("RGBA").resize(
            DEFAULT_INPUT_SIZE, Image.BILINEAR
        )
        resized.save(target)
    staged_timestep = python_dir / "input_timestep.npy"
    if timestep_path is None:
        np.save(
            staged_timestep,
            np.array([[[[DEFAULT_TIMESTEP]]]], dtype=np.float32),
        )
    elif not timestep_path.exists():
        raise FileNotFoundError(f"Missing timestep file: {timestep_path}")
    else:
        np.save(staged_timestep, np.load(timestep_path))
    return staged0, staged1, staged_timestep
def _manual_install_message(weights_path: Path) -> str:
    """Build the error message shown when the flownet weights are absent,
    including the exact installer commands to fix it."""
    installer_script = BUNDLE_ROOT / "install_practical_rife_weights.py"
    parts = (
        f"Missing Practical-RIFE weights at {weights_path}. ",
        "Download the upstream 4.25 Practical-RIFE weights and install them with: ",
        f'python "{installer_script}" --weights-file path\\to\\flownet.pkl ',
        "or ",
        f'python "{installer_script}" --weights-archive path\\to\\rife-4.25.zip',
    )
    return "".join(parts)
def main() -> None:
    """Export the Practical-RIFE model and emit a runnable scenario.

    Steps: parse CLI args, verify the weights exist, reset the artifact
    directories, stage the inputs, run a reference eager-mode inference
    and save its output image, export the (optionally quantized) model,
    build scenario.json, and print the scenario-runner command.
    """
    args = _parse_args()
    is_quantized = args.mode == "quantized"

    weights_path = Path(args.weights_path)
    if not weights_path.exists():
        raise FileNotFoundError(_manual_install_message(weights_path))

    mode_root, python_dir, scenario_dir = mode_artifact_dirs(
        DEFAULT_ARTIFACTS_ROOT, args.mode
    )
    reset_generated_artifact_dirs(mode_root, scenario_dir)

    input0_path, input1_path, timestep_path = _prepare_inputs(
        input0_source=Path(args.input0),
        input1_source=Path(args.input1),
        timestep_path=args.timestep_path,
        python_dir=python_dir,
    )
    image0 = _load_rgba_image(input0_path)
    image1 = _load_rgba_image(input1_path)
    timestep = torch.from_numpy(np.load(timestep_path))

    model = ModelPracticalRife(weights_path).eval()
    # Reference eager-mode inference; the result image lands next to the
    # staged inputs.
    with torch.no_grad():
        reference_output = model(image0, image1, timestep)
        _save_rgba_image(
            reference_output, python_dir / "output_model_practical_rife.png"
        )

    # Per-module quant configs only apply in quantized mode.
    module_quant_configs = (
        {
            PRACTICAL_RIFE_BASE_GRID_MODULE_NAME: _shape_only_grid_position_quant_config(),
            PRACTICAL_RIFE_FLOW_TO_GRID_MODULE_NAME: _grid_position_module_quant_config(),
            PRACTICAL_RIFE_GRID_COMBINE_MODULE_NAME: _grid_position_combine_quant_config(),
        }
        if is_quantized
        else {}
    )
    edge_program, example_inputs, io_quant_params = export_model(
        model,
        (image0, image1, timestep),
        artifacts_root=scenario_dir,
        intermediate_artifacts_root=mode_root,
        quantized_input_qspecs=DEFAULT_QUANTIZED_INPUT_QSPECS if is_quantized else [],
        quantized_output_qspecs=(
            DEFAULT_QUANTIZED_OUTPUT_QSPECS if is_quantized else []
        ),
        module_name_quant_configs=module_quant_configs,
        grid_sampler_grid_input_qspec=(
            _grid_position_int16_snorm_qspec() if is_quantized else None
        ),
        enable_pt2e_quantization=is_quantized,
    )
    scenario_path = build_scenario_from_edge_program(
        edge_program,
        scenario_dir / DEFAULT_SCENARIO_FILENAME,
        example_inputs,
        io_quant_params=io_quant_params,
    )

    print()
    print()
    print(f"Scenario: {scenario_path}")
    print("Run the scenario when ready with:")
    print(
        f'scenario-runner --scenario "{scenario_path}" '
        f'--output "{scenario_dir}" --log-level debug'
    )
if __name__ == "__main__":
main()