Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 149 additions & 37 deletions pycvvdp/run_cvvdp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Command-line interface for ColorVideoVDP.

import os, sys
import csv
import os.path
import argparse
import logging
Expand All @@ -12,7 +13,6 @@
import imageio.v2 as imageio
import re
import inspect
import traceback

from pycvvdp.vq_metric import vq_metric_dict

Expand Down Expand Up @@ -77,6 +77,107 @@ def np2img(np_srgb, imgfile):

imageio.imwrite( imgfile, (np.clip(np_srgb,0.0,1.0)[0,...]*255.0).astype(np.uint8) )


# -----------------------------------
# Per-frame CSV helpers
# -----------------------------------

def _to_numpy_1d(value):
if value is None:
return None

if torch.is_tensor(value):
arr = value.detach().cpu().numpy()
elif isinstance(value, np.ndarray):
arr = value
elif isinstance(value, (list, tuple)):
arr = np.asarray(value)
else:
return None

if arr.ndim == 1:
return arr.astype(np.float64, copy=False)

return None


def extract_per_frame_jod_series(stats):
    """Recover a per-frame JOD series from a metric's *stats* dictionary.

    Strategy: first look for an explicit per-frame series under a number of
    known key spellings; failing that, pool ``Q_per_ch`` over channels and
    bands (band-weighted by ``rho_band`` when its length matches the band
    axis, otherwise a plain mean). Returns a 1-D ``float64`` numpy array,
    or ``None`` when no per-frame information is available.
    """
    if stats is None:
        return None

    # Explicit per-frame keys, checked in order of preference.
    for candidate in ("Q_per_frame", "q_per_frame", "JOD_per_frame",
                      "jod_per_frame", "Q_frames", "q_frames"):
        if candidate not in stats:
            continue
        value = stats[candidate]
        if value is None:
            continue
        if torch.is_tensor(value):
            series = value.detach().cpu().numpy()
        elif isinstance(value, np.ndarray):
            series = value
        elif isinstance(value, (list, tuple)):
            series = np.asarray(value)
        else:
            series = None
        if series is not None and series.ndim == 1:
            return series.astype(np.float64, copy=False)

    if "Q_per_ch" not in stats:
        return None

    raw = stats["Q_per_ch"]
    if torch.is_tensor(raw):
        q = raw.detach().cpu().numpy()
    elif isinstance(raw, np.ndarray):
        q = raw
    else:
        q = np.asarray(raw)

    rho = stats.get("rho_band")
    if torch.is_tensor(rho):
        rho = rho.detach().cpu().numpy()
    elif rho is not None and not isinstance(rho, np.ndarray):
        rho = np.asarray(rho)

    # Expected shape in this build: (1, channels, frames, bands); a 3-D
    # (channels, frames, bands) variant is handled the same way. In both,
    # bands sit on the last axis and frames on the one before it.
    if q.ndim in (4, 3):
        band_axis = q.ndim - 1
        if rho is not None and rho.ndim == 1 and rho.shape[0] == q.shape[band_axis]:
            total = float(np.sum(rho))
            if total > 0:
                broadcast = [1] * q.ndim
                broadcast[band_axis] = -1
                pooled = (q * rho.reshape(broadcast)).sum(axis=band_axis) / total
                # Average every remaining axis except the trailing frame axis.
                return pooled.mean(axis=tuple(range(pooled.ndim - 1))).astype(np.float64, copy=False)
        # No usable band weights: plain mean over everything but frames.
        return q.mean(axis=tuple(range(q.ndim - 2)) + (band_axis,)).astype(np.float64, copy=False)

    # 2-D variant: (channels, frames).
    if q.ndim == 2:
        return q.mean(axis=0).astype(np.float64, copy=False)

    return None


def per_frame_csv_path(base_path, pair_index, multi_pair):
    """Return the CSV path for one test/reference pair.

    With a single pair, *base_path* is used as-is. With multiple pairs a
    zero-padded ``_NNNN`` index is inserted before the extension (or
    appended with ``.csv`` when *base_path* has no extension).
    """
    if multi_pair:
        stem, extension = os.path.splitext(base_path)
        tag = "_%04d" % pair_index
        return stem + tag + extension if extension else base_path + tag + ".csv"
    return base_path


def write_per_frame_csv(csv_path, per_frame_values):
    """Write *per_frame_values* to *csv_path* as ``frame,jod`` rows.

    The first column is the zero-based frame index; the second is the JOD
    value converted to ``float``. Any existing file is overwritten.
    """
    with open(csv_path, "w", newline="") as out_file:
        csv_writer = csv.writer(out_file)
        csv_writer.writerow(["frame", "jod"])
        csv_writer.writerows(
            (frame_no, float(jod)) for frame_no, jod in enumerate(per_frame_values)
        )

# -----------------------------------
# Command-line Arguments
# -----------------------------------
Expand All @@ -93,22 +194,21 @@ def parse_args(arg_list=None):
parser.add_argument("-x", "--features", action='store_true', default=False, help="generate JSON files with extracted features. Useful for retraining the metric.")
parser.add_argument("-o", "--output-dir", type=str, default=None, help="in which directory heatmaps and feature files should be stored (the default is the current directory)")
parser.add_argument("--result", type=str, default=None, help="write metric prediction results to a CSV file passed as an argument.")
parser.add_argument("--per-frame-csv", type=str, default=None, help="write per-frame JOD to a CSV file when the metric exposes a per-frame series.")
parser.add_argument("-c", "--config-paths", type=str, nargs='+', default=[], help="One or more paths to configuration files or directories. The main configurations files are `display_models.json`, `color_spaces.json` and `cvvdp_parameters.json`. The file name must start as the name of the original config file.")
parser.add_argument("-d", "--display", type=str, default="standard_4k", help="display name, e.g. 'HTC Vive', or ? to print the list of models.")
parser.add_argument("-n", "--nframes", type=int, default=-1, help="the number of video frames you want to compare")
parser.add_argument("--count-frames", action='store_true', default=False, help="Use accurate method to count frames in a video. Slower but accurate. Use if you see frame read errors.")
parser.add_argument("-f", "--full-screen-resize", choices=['bilinear', 'bicubic', 'nearest', 'area'], default=None, help="Both test and reference videos will be resized to match the full resolution of the display. Currently works only with videos.")
parser.add_argument("-m", "--metric", choices=available_metrics, nargs='+', default=['cvvdp'], help='Select which metric(s) to run')
parser.add_argument("--temp-padding", choices=['replicate', 'symmetric', 'valid'], default='symmetric', help='How to pad the video in the time domain (for the temporal filters). "replicate" - repeat the first frame. "symmetric" - mirror the first frames. "valid" - skip initial frames until buffer is fully populated (no padding artifacts).')
parser.add_argument("--temp-padding", choices=['replicate', 'circular', 'pingpong'], default='replicate', help='How to pad the video in the time domain (for the temporal filters). "replicate" - repeat the first frame. "pingpong" - mirror the first frames. "circular" - take the last frames.')
parser.add_argument("--pix-per-deg", type=float, default=None, help='Overwrite display geometry and use the provided pixels per degree value.')
parser.add_argument("--fps", type=float, default=None, help='Frames per second. It will overwrite frame rate stores in the video file. Required when passing an array of image files.')
parser.add_argument("--frames", type=str, default=None, help='Range of frames specified as first:step:last, first:last, or first: (Matlab notation). Currently works only with frames provided as images.')
parser.add_argument("--gpu-mem", type=float, default=None, help='How much GPU memory can we use in GB. Use if CUDA reports out of mem errors, or you want to run multiple instances at the same time.')
parser.add_argument("-q", "--quiet", action='store_true', default=False, help="Do not print any information but the final JOD value. Warning message will be still printed.")
parser.add_argument("-v", "--verbose", action='store_true', default=False, help="Print out extra information.")
parser.add_argument("--debug", action='store_true', default=False, help="Prints full stack trace when error is encountered.")
parser.add_argument("--ffmpeg-cc", action='store_true', default=False, help="Use ffmpeg for upsampling and color conversion. Use custom pytorch code by default (faster and less memory).")
parser.add_argument("--temp-resample", type=float, nargs="?", default=-1, const=0, help="Resample test and reference video to a common frame rate. Allows to compare videos of different frame rates. An optional argument - the maximum frame rate used when resampling.")
parser.add_argument("--temp-resample", action='store_true', default=False, help="Resample test and reference video to a common frame rate. Allows to compare videos of different frame rates.")
parser.add_argument("-i", "--interactive", action='store_true', default=False, help="Run in an interactive mode, in which command line arguments are provided to the standard input, line by line. Saves on start-up time when running a large number of comparisons.")
parser.add_argument("--dump-channels", nargs='+', choices=['temporal', 'lpyr', 'difference'], default=None, help="Output video/images with intermediate processing stages (for debugging and visualization).")
if arg_list is not None:
Expand Down Expand Up @@ -265,8 +365,6 @@ def run_on_args(args):
met_args['gpu_mem'] = args.gpu_mem
if 'dump_channels' in constructor_args:
met_args['dump_channels'] = dump_channels
if 'quiet' in constructor_args:
met_args['quiet'] = args.quiet
fv = metric_class(**met_args)
fv.train(False)
metrics.append( fv )
Expand Down Expand Up @@ -294,20 +392,15 @@ def run_on_args(args):
logging.info(f"Predicting the quality of '{test_file}' compared to '{ref_file}'")
for mm in metrics:
preload = False if args.temp_padding == 'replicate' else True

nframes = -2 if args.count_frames else args.nframes

with torch.no_grad():

if args.temp_resample>=0:
if args.temp_resample>0:
pycvvdp.video_source_temp_resample_file.max_fps = args.temp_resample
if args.temp_resample:
vs = pycvvdp.video_source_temp_resample_file( test_file, ref_file,
display_photometry=display_photometry,
config_paths=args.config_paths,
full_screen_resize=args.full_screen_resize,
resize_resolution=display_geometry.resolution,
frames=nframes,
frames=args.nframes,
ffmpeg_cc=args.ffmpeg_cc,
verbose=args.verbose )
else:
Expand All @@ -316,7 +409,7 @@ def run_on_args(args):
config_paths=args.config_paths,
full_screen_resize=args.full_screen_resize,
resize_resolution=display_geometry.resolution,
frames=nframes,
frames=args.nframes,
fps=args.fps,
frame_range=frame_range,
preload=preload,
Expand All @@ -328,14 +421,34 @@ def run_on_args(args):
mm.set_base_fname(base_fname)

Q_pred, stats = mm.predict_video_source(vs)
Q_pred_scalar = Q_pred.item()
if not stats is None:
print("STATS KEYS:", list(stats.keys()))
for k, v in stats.items():
try:
shape = tuple(v.shape)
except Exception:
shape = type(v).__name__
print(f"STAT {k}: {shape}")
if args.quiet:
print( "{Q:0.4f}".format(Q=Q_pred_scalar) )
print(f"JOD: {Q_pred:0.4f}")
else:
units_str = f" [{mm.quality_unit()}]"
print( "{met_name}={Q:0.4f}{units}".format(met_name=mm.short_name(), Q=Q_pred_scalar, units=units_str) )
print("{met_name}={Q:0.4f}{units}".format(
met_name=mm.short_name(),
Q=Q_pred,
units=units_str
))
if not res_fh is None:
res_fh.write( f", {Q_pred_scalar}" )
res_fh.write( f", {Q_pred}" )

if args.per_frame_csv and not stats is None:
per_frame_values = extract_per_frame_jod_series(stats)
if per_frame_values is None:
logging.warning("Per-frame JOD series is not available in this metric version; skipping per-frame CSV export.")
else:
csv_path = per_frame_csv_path(args.per_frame_csv, kk, max(N_test, N_ref) > 1)
logging.info("Writing per-frame JOD CSV '" + csv_path + "' ...")
write_per_frame_csv(csv_path, per_frame_values)


if args.features and not stats is None:
Expand All @@ -356,11 +469,14 @@ def run_on_args(args):
logging.info("Writing heat map '" + dest_name + "' ...")
np2img(torch.squeeze(stats["heatmap"].permute((2,3,4,1,0)), dim=4).cpu().numpy(), dest_name)

if args.distogram != -1:
if args.distogram != -1 and not stats is None:
dest_name = os.path.join(out_dir, base + "_distogram.png")
logging.info("Writing distogram '" + dest_name + "' ...")
jod_max = args.distogram
mm.export_distogram( stats, dest_name, jod_max=jod_max )
try:
mm.export_distogram( stats, dest_name, jod_max=jod_max )
except NotImplementedError as e:
logging.warning( f'Metric {mm.short_name()} cannot generate distograms' )

del stats

Expand All @@ -376,23 +492,19 @@ def run_on_args(args):
def main():
args = parse_args()

try:
if args.interactive:
#print( "Running in an interactive mode" )
while True:
line = sys.stdin.readline()
if not line:
break

#print( shlex.split(line) )
args = parse_args(shlex.split(line))
run_on_args(args)
else:
if args.interactive:
#print( "Running in an interactive mode" )
while True:
line = sys.stdin.readline()
if not line:
break

#print( shlex.split(line) )
args = parse_args(shlex.split(line))
run_on_args(args)
except pycvvdp.vq_exception as ex:
logging.error( str(ex) )
if args.debug:
traceback.print_exc()
else:
run_on_args(args)


if __name__ == '__main__':
main()
Binary file added scripting_build_macOS/macOS/Install cvvdp.pdf
Binary file not shown.
Loading