Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/devices/braket_remote.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ You can set a timeout by using the ``poll_timeout_seconds`` argument;
the device will retry circuits that do not complete within the timeout.
A timeout of 30 to 60 seconds is recommended for circuits with fewer than 25 qubits.

Each of the submitted circuit can be visualised using the attribute ``circuits`` on the device

>> print(remote_device.circuits[0])

Device options
~~~~~~~~~~~~~~

Expand Down
21 changes: 20 additions & 1 deletion src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def __init__(
self._parallel = parallel
self._max_parallel = max_parallel
self._circuit = None
self._circuits = []
self._task = None
self._tasks = []
self._noise_model = noise_model
self._parametrize_differentiable = parametrize_differentiable
self._run_kwargs = run_kwargs
Expand All @@ -179,7 +181,9 @@ def __init__(
def reset(self):
super().reset()
self._circuit = None
self._circuits = []
self._task = None
self._tasks = []

@property
def operations(self) -> frozenset[str]:
Expand All @@ -195,11 +199,21 @@ def circuit(self) -> Circuit:
"""Circuit: The last circuit run on this device."""
return self._circuit

@property
def circuits(self) -> list[Circuit]:
"""Circuit: The circuits run on this device."""
return self._circuits

@property
def task(self) -> QuantumTask:
"""QuantumTask: The task corresponding to the last run circuit."""
return self._task

@property
def tasks(self) -> list[QuantumTask]:
"""The tasks corresponding to the circuits run on this device."""
return self._tasks

@property
def parallel(self) -> bool:
"""bool: Whether the device supports parallel execution of batches."""
Expand Down Expand Up @@ -686,6 +700,8 @@ def __init__(
self._poll_interval_seconds = poll_interval_seconds
self._max_connections = max_connections
self._max_retries = max_retries
self._circuits = []
self._tasks = []

@property
def use_grouping(self) -> bool:
Expand All @@ -698,6 +714,8 @@ def use_grouping(self) -> bool:
return not ("provides_jacobian" in caps and caps["provides_jacobian"])

def _run_task_batch(self, braket_circuits, pl_circuits, batch_shots: int, inputs):
self._circuits = braket_circuits
batch_shots = 0 if self.analytic else self.shots
if self._supports_program_sets:
program_set = (
ProgramSet.zip(braket_circuits, input_sets=inputs)
Expand All @@ -712,6 +730,7 @@ def _run_task_batch(self, braket_circuits, pl_circuits, batch_shots: int, inputs
poll_interval_seconds=self._poll_interval_seconds,
**self._run_kwargs,
)
self._tasks = [task]
return self._braket_program_set_to_pl_result(task.result(), pl_circuits)
task_batch = self._device.run_batch(
braket_circuits,
Expand All @@ -724,7 +743,7 @@ def _run_task_batch(self, braket_circuits, pl_circuits, batch_shots: int, inputs
inputs=inputs,
**self._run_kwargs,
)

self._tasks = task_batch.tasks
# Call results() to retrieve the Braket results in parallel.
try:
braket_results_batch = task_batch.results(
Expand Down
24 changes: 24 additions & 0 deletions test/unit_tests/test_braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,15 @@ def test_reset():
"""Tests that the members of the device are cleared on reset."""
dev = _aws_device(wires=2)
dev._circuit = CIRCUIT
dev._circuits = [CIRCUIT, CIRCUIT]
dev._task = TASK
dev._tasks = [TASK, TASK]

dev.reset()
assert dev.circuit is None
assert dev.circuits == []
assert dev.task is None
assert dev.tasks == []


def test_apply():
Expand Down Expand Up @@ -1115,6 +1119,24 @@ def test_batch_execute_program_set_noncommuting():

@patch.object(AwsDevice, "properties", new_callable=mock.PropertyMock)
@patch.object(AwsDevice, "run_batch")
def test_aws_device_batch_execute_parallel_circuits_persistance(mock_run_batch):
mock_run_batch.return_value = TASK_BATCH
dev = _aws_device(wires=4, foo="bar", parallel=True)
assert dev.parallel is True

with QuantumTape() as circuit:
qml.Hadamard(wires=0)
qml.CNOT(wires=[0, 1])
qml.probs(wires=[0])
qml.expval(qml.PauliX(1))
qml.var(qml.PauliY(2))
qml.sample(qml.PauliZ(3))

circuits = [circuit, circuit]
dev.batch_execute(circuits)
assert dev.circuits[1]


def test_aws_device_batch_execute_parallel(mock_run_batch, mock_properties):
"""Test batch_execute(parallel=True) correctly calls batch
execution methods for AwsDevices in Braket SDK"""
Expand All @@ -1135,6 +1157,8 @@ def test_aws_device_batch_execute_parallel(mock_run_batch, mock_properties):

circuits = [circuit, circuit]
batch_results = dev.batch_execute(circuits)

assert dev.tasks[0]
for results in batch_results:
assert np.allclose(
results[0],
Expand Down
Loading