diff --git a/src/braket/pennylane_plugin/braket_device.py b/src/braket/pennylane_plugin/braket_device.py index 3eaeab59..a5a56ca1 100644 --- a/src/braket/pennylane_plugin/braket_device.py +++ b/src/braket/pennylane_plugin/braket_device.py @@ -306,6 +306,10 @@ def execute(self, circuit: QuantumTape, compute_gradient=False, **run_kwargs) -> tracking_data = self._tracking_data(self._task) self.tracker.update(executions=1, shots=self.shots, **tracking_data) self.tracker.record() + + # increment counter for number of executions of device + self._num_executions += 1 + return self._braket_to_pl_result(braket_result, circuit) def apply( @@ -499,6 +503,7 @@ def batch_execute(self, circuits, **run_kwargs): # Update the tracker before raising an exception further if some circuits do not complete. finally: + self._num_executions += len(task_batch.tasks) if self.tracker.active: for task in task_batch.tasks: tracking_data = self._tracking_data(task) diff --git a/test/unit_tests/test_braket_device.py b/test/unit_tests/test_braket_device.py index edb878d8..26d80e96 100644 --- a/test/unit_tests/test_braket_device.py +++ b/test/unit_tests/test_braket_device.py @@ -501,6 +501,21 @@ def test_execute_with_gradient( assert (results[0][1] == expected_pl_result[0][1]).all() +@patch.object(AwsDevice, "run") +def test_number_executions(mock_run): + """Asserts tracker stores information during execute when active""" + mock_run.side_effect = [TASK, SIM_TASK, SIM_TASK, TASK] + dev = _aws_device(wires=4, foo="bar") + + with QuantumTape() as circuit: + qml.Hadamard(wires=0) + qml.probs(wires=(0,)) + dev.execute(circuit) + dev.execute(circuit) + + assert dev.num_executions == 2 + + @patch.object(AwsDevice, "run") def test_execute_tracker(mock_run): """Asserts tracker stores information during execute when active"""