Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
d251037
Support PSNR algorithm in checkbox-support
baconYao May 2, 2024
23b187e
Added opencv_python to checkbox support tox.ini
fernando79513 May 22, 2024
ef8596c
Merge remote-tracking branch 'origin/main' into implement-psnr-algori…
fernando79513 May 22, 2024
629e1c5
Update the unittest of PSNR
baconYao May 29, 2024
828cc4c
Handle the divide zero case and improve unittest
baconYao May 29, 2024
6400e9f
Update checkbox-support/checkbox_support/scripts/psnr.py
baconYao Jun 14, 2024
b04f980
Update checkbox-support/checkbox_support/scripts/psnr.py
baconYao Jun 14, 2024
42bc6eb
Update checkbox-support/checkbox_support/scripts/psnr.py
baconYao Jun 14, 2024
51415d1
Update checkbox-support/checkbox_support/tests/test_psnr.py
baconYao Jun 14, 2024
8c29ce3
Update checkbox-support/checkbox_support/tests/test_psnr.py
baconYao Jun 14, 2024
a3bcda7
Update checkbox-support/checkbox_support/tests/test_psnr.py
baconYao Jun 14, 2024
2e5dc2e
Update checkbox-support/checkbox_support/tests/test_psnr.py
baconYao Jun 14, 2024
85ad54a
Update checkbox-support/checkbox_support/tests/test_psnr.py
baconYao Jun 14, 2024
a9b47cc
Remove psnr from toml file
baconYao Jun 14, 2024
92f27db
Refactor
baconYao Jun 14, 2024
72af0bc
Patch sys.argv to mock args
baconYao Jul 18, 2024
605e50d
Fixed numpy version in python 10
fernando79513 Jul 18, 2024
c6b12bf
Fixed extension of files in test
fernando79513 Jul 18, 2024
74c7c93
Add test_similar_images case
baconYao Jul 18, 2024
9283f05
Fix black issue
baconYao Jul 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions checkbox-support/checkbox_support/scripts/psnr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# This file is part of Checkbox.
#
# Copyright 2024 Canonical Ltd.
# Written by:
# Patrick Chang <patrick.chang@canonical.com>
#
# Checkbox is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3,
# as published by the Free Software Foundation.
#
# Checkbox is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Checkbox. If not, see <http://www.gnu.org/licenses/>.
#
# Reference the PSNR algorithm below
# - https://docs.opencv.org/3.4/d5/dc4/tutorial_video_input_psnr_ssim.html

import cv2
import numpy as np
import argparse
from typing import Tuple, List


def psnr_args() -> argparse.ArgumentParser:
"""
Create and configure the argument parser for the PSNR calculation script.

Returns:
ArgumentParser: The configured argument parser
"""
parser = argparse.ArgumentParser(
description=(
"Calculate PSNR between two files." " File can be image or video."
)
)
parser.add_argument(
"reference_file", type=str, help="Path to the reference file"
)
parser.add_argument("test_file", type=str, help="Path to the test file")
parser.add_argument(
"-s",
"--show_psnr_each_frame",
action="store_true",
default=False,
help="Absolutely always show command output",
)
return parser.parse_args()


def _get_psnr(I1: np.ndarray, I2: np.ndarray) -> float:
"""
Calculate the Peak Signal-to-Noise Ratio (PSNR) between two frames.

Args:
I1 (np.ndarray): Reference frame.
I2 (np.ndarray): Frame to be compared with the reference.

Returns:
float: PSNR value indicating the quality of I2 compared to I1.
"""
# Calculate the absolute difference
s1 = cv2.absdiff(I1, I2)
# cannot make a square on 8 bits
s1 = np.float32(s1)
# Calculate squared differences
s1 = s1 * s1
# Sum of squared differences per channel
sse = s1.sum()
# sum channels
if sse <= 1e-10:
# for small values return zero
return 0.0
else:
shape = I1.shape
mse = 1.0 * sse / (shape[0] * shape[1] * shape[2])
psnr = 10.0 * np.log10((255 * 255) / mse)
return psnr


def _get_frame_resolution(capture) -> Tuple[int, int]:
return (
int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)),
int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)),
)


def get_average_psnr(
reference_file_path: str, test_file_path: str
) -> Tuple[float, List[float]]:
"""
Calculate the average PSNR and PSNR for each frame between two files.
Files can be image or video.

Args:
reference_file_path (str): Path to the reference file.
test_file_path (str): Path to the test file.

Returns:
Tuple[float, List[float]]: A tuple containing the average PSNR value
and a list of PSNR values for each frame.
"""
capt_refrnc = cv2.VideoCapture(reference_file_path)
capt_undTst = cv2.VideoCapture(test_file_path)

if not capt_refrnc.isOpened() or not capt_undTst.isOpened():
raise SystemExit("Error: Could not open reference or test file.")

ref_size = _get_frame_resolution(capt_refrnc)
test_size = _get_frame_resolution(capt_undTst)

if ref_size != test_size:
raise SystemExit("Error: Files have different dimensions.")

total_frame_count = int(capt_refrnc.get(cv2.CAP_PROP_FRAME_COUNT))
if total_frame_count == 0:
raise SystemExit("Error: The count of frame should not be 0")

psnr_each_frame = []
for _ in range(total_frame_count):
_, frameReference = capt_refrnc.read()
_, frameUnderTest = capt_undTst.read()
psnr = _get_psnr(frameReference, frameUnderTest)
psnr_each_frame.append(psnr)

psnr_array = np.array(psnr_each_frame)
avg_psnr = np.mean(psnr)
return avg_psnr, psnr_array


def main() -> None:
args = psnr_args()
avg_psnr, psnr_each_frame = get_average_psnr(
args.reference_file, args.test_file
)
print("Average PSNR: ", avg_psnr)
if args.show_psnr_each_frame:
print("PSNR each frame: ", psnr_each_frame)


if __name__ == "__main__":
main()
204 changes: 204 additions & 0 deletions checkbox-support/checkbox_support/tests/test_psnr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import unittest
import numpy as np
from unittest.mock import patch, MagicMock
from argparse import Namespace
from io import StringIO

from checkbox_support.scripts.psnr import (
main,
psnr_args,
_get_psnr,
get_average_psnr,
_get_frame_resolution,
)


def create_image_helper_function(width, height, color):
"""Creates an image with the specified color."""
return np.full((height, width, 3), color, dtype=np.uint8)


class TestPSNRArgs(unittest.TestCase):

def test_psnr_args_with_defaults(self):
with patch(
"sys.argv",
["psnr.py", "ref.mp4", "test.mp4"],
):
args = psnr_args()
self.assertEqual(args.reference_file, "ref.mp4")
self.assertEqual(args.test_file, "test.mp4")
self.assertFalse(args.show_psnr_each_frame)

def test_psnr_args_with_custom_args(self):
with patch(
"sys.argv",
["psnr.py", "ref.mp4", "test.mp4", "-s"],
):
args = psnr_args()
self.assertEqual(args.reference_file, "ref.mp4")
self.assertEqual(args.test_file, "test.mp4")
self.assertTrue(args.show_psnr_each_frame)


class TestGetFrameResolution(unittest.TestCase):
@patch("checkbox_support.scripts.psnr.cv2.VideoCapture")
def test_get_frame_resolution(self, mock_videocapture):
mock_capture = mock_videocapture.return_value
mock_capture.get.side_effect = [100, 200]
width, height = _get_frame_resolution(mock_capture)
self.assertEqual(width, 100)
self.assertEqual(height, 200)


class TestGetPSNR(unittest.TestCase):

def create_image(self, width, height, color):
"""Creates an image with the specified color."""
return np.full((height, width, 3), color, dtype=np.uint8)

def test_identical_images(self):
img1 = self.create_image(100, 100, 255)
img2 = self.create_image(100, 100, 255)
self.assertEqual(_get_psnr(img1, img2), 0.0)

def test_different_images(self):
img1 = self.create_image(100, 100, 255)
img2 = self.create_image(100, 100, 0)
self.assertNotEqual(_get_psnr(img1, img2), 0.0)
self.assertLessEqual(_get_psnr(img1, img2), 50.0)

def test_similar_images(self):
img1 = self.create_image(100, 100, 125)
img2 = self.create_image(100, 100, 125)
img2[0:10, 0:10] = [120, 120, 120]
self.assertGreaterEqual(_get_psnr(img1, img2), 50.0)


class TestGetAveragePSNR(unittest.TestCase):
@patch("checkbox_support.scripts.psnr.cv2.VideoCapture")
def test_get_average_psnr_file_not_found(self, mock_vc):
mock_vc.return_value.isOpened.return_value = False
with self.assertRaises(SystemExit):
get_average_psnr("nonexistent_file.mp4", "test_file.mp4")

@patch("checkbox_support.scripts.psnr.cv2.VideoCapture")
def test_zero_frames(self, mock_vc):
mock_vc.return_value.isOpened.return_value = True
mock_vc.return_value.get.return_value = 0 # Zero frames in the video
with self.assertRaises(SystemExit):
get_average_psnr("ref.mp4", "test.mp4")

with self.assertRaises(SystemExit):
get_average_psnr("ref.mp4", "test.mp4")

@patch("checkbox_support.scripts.psnr._get_frame_resolution")
@patch("checkbox_support.scripts.psnr.cv2.VideoCapture")
def test_get_average_psnr_different_dimensions(
self, mock_vc, mock_get_frame_resolution
):
mock_vc.return_value.isOpened.return_value = True
mock_get_frame_resolution.side_effect = [(100, 100), (100, 150)]

with self.assertRaises(SystemExit):
get_average_psnr("ref_file.mp4", "test_file.mp4")

@patch("checkbox_support.scripts.psnr._get_psnr")
@patch("checkbox_support.scripts.psnr._get_frame_resolution")
@patch("checkbox_support.scripts.psnr.cv2.VideoCapture")
def test_get_average_psnr(
self, mock_VideoCapture, mock_get_frame_resolution, mock_get_psnr
):
# Setup
reference_file_path = "reference.mp4"
test_file_path = "test.mp4"
total_frame_count = 5
mock_capt_refrnc = MagicMock()
mock_capt_undTst = MagicMock()

mock_VideoCapture.side_effect = [mock_capt_refrnc, mock_capt_undTst]

mock_capt_refrnc.isOpened.return_value = True
mock_capt_undTst.isOpened.return_value = True

mock_get_frame_resolution.return_value = (1920, 1080)

mock_capt_refrnc.get.return_value = total_frame_count

mock_capt_refrnc.read.return_value = (True, "frameReference")
mock_capt_undTst.read.return_value = (True, "frameUnderTest")

mock_get_psnr.return_value = 30

# Code under test
avg_psnr, psnr_array = get_average_psnr(
reference_file_path, test_file_path
)

# Assertions
expected_psnr_array = np.array([30] * total_frame_count)
expected_avg_psnr = np.mean(expected_psnr_array)

self.assertEqual(len(psnr_array), total_frame_count)
self.assertTrue(np.array_equal(psnr_array, expected_psnr_array))
self.assertEqual(avg_psnr, expected_avg_psnr)

# Ensure mocks were called correctly
mock_VideoCapture.assert_any_call(reference_file_path)
mock_VideoCapture.assert_any_call(test_file_path)
self.assertEqual(mock_capt_refrnc.isOpened.call_count, 1)
self.assertEqual(mock_capt_undTst.isOpened.call_count, 1)
self.assertEqual(mock_get_frame_resolution.call_count, 2)
self.assertEqual(mock_capt_refrnc.get.call_count, 1)
self.assertEqual(mock_capt_refrnc.read.call_count, total_frame_count)
self.assertEqual(mock_capt_undTst.read.call_count, total_frame_count)
self.assertEqual(mock_get_psnr.call_count, total_frame_count)


class TestMainFunction(unittest.TestCase):

@patch("sys.stdout", new_callable=StringIO)
@patch("checkbox_support.scripts.psnr.get_average_psnr")
@patch("checkbox_support.scripts.psnr.argparse.ArgumentParser.parse_args")
def test_main_prints_avg_psnr(
self, mock_parse_args, mock_get_average_psnr, mock_stdout
):
mock_parse_args.return_value = Namespace(
reference_file="ref.mp4",
test_file="test.mp4",
show_psnr_each_frame=False,
)

mock_get_average_psnr.return_value = (30.0, [28.5, 31.2, 29.8])

main()

expected_output = "Average PSNR: 30.0\n"
self.assertEqual(mock_stdout.getvalue(), expected_output)
mock_get_average_psnr.assert_called_once_with("ref.mp4", "test.mp4")

@patch("sys.stdout", new_callable=StringIO)
@patch("checkbox_support.scripts.psnr.get_average_psnr")
@patch("checkbox_support.scripts.psnr.argparse.ArgumentParser.parse_args")
def test_main_prints_psnr_each_frame(
self, mock_parse_args, mock_get_average_psnr, mock_stdout
):
mock_parse_args.return_value = Namespace(
reference_file="ref_file",
test_file="test_file",
show_psnr_each_frame=True,
)

mock_get_average_psnr.return_value = (30.0, [28.5, 31.2, 29.8])

main()

expected_output = (
"Average PSNR: 30.0\nPSNR each frame: [28.5, 31.2, 29.8]\n"
)
self.assertEqual(mock_stdout.getvalue(), expected_output)
mock_get_average_psnr.assert_called_once_with("ref_file", "test_file")


if __name__ == "__main__":
unittest.main()
5 changes: 5 additions & 0 deletions checkbox-support/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ deps =
coverage == 5.5
pytest-cov == 2.12.1
requests == 2.9.1
opencv_python == 4.4.0.42
pyparsing == 2.0.3
distro == 1.0.1
PyYAML == 3.11
Expand All @@ -32,6 +33,7 @@ deps =
coverage == 5.5
pytest-cov == 3.0.0
requests == 2.18.4
opencv_python == 4.4.0.42
pyparsing == 2.2.0
distro == 1.0.1
PyYAML == 3.12
Expand All @@ -43,6 +45,7 @@ deps =
coverage == 7.3.0
pytest-cov == 4.1.0
requests == 2.22.0
opencv_python == 4.8.1.78
pyparsing == 2.4.6
distro == 1.4.0
PyYAML == 5.3.1
Expand All @@ -53,6 +56,8 @@ deps =
coverage == 7.3.0
pytest-cov == 4.1.0
requests == 2.25.1
opencv_python == 4.8.1.78
numpy == 1.26.4
pyparsing == 2.4.7
distro == 1.7.0
PyYAML == 6.0.1
Expand Down