Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion source/qdk_package/qdk/qre/application/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from ._cirq import CirqApplication
from ._cirq import CirqApplication, CirqApplicationParams
from ._qir import QIRApplication
from ._qsharp import QSharpApplication
from ._openqasm import OpenQASMApplication

__all__ = [
"CirqApplication",
"CirqApplicationParams",
"QIRApplication",
"QSharpApplication",
"OpenQASMApplication",
Expand Down
34 changes: 30 additions & 4 deletions source/qdk_package/qdk/qre/application/_cirq.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, field

import cirq

Expand All @@ -15,7 +15,22 @@


@dataclass
class CirqApplication(Application[None]):
class CirqApplicationParams:
"""Application parameters that control how the resource estimation trace is generated from a Cirq circuit.

Args:
track_memory_qubits (bool): When True, memory qubits are tracked
separately from compute qubits. When False, all qubits are treated
as compute qubits. Also, if True, read-from-memory and
write-to-memory instructions are preserved in the trace, otherwise,
they are decompsed into SWAP and RESET instructions. Defaults to
True.
"""
track_memory_qubits: bool = field(default=True, metadata={"domain": [True]})


@dataclass
class CirqApplication(Application[CirqApplicationParams]):
"""Application that produces a resource estimation trace from a Cirq circuit.

Accepts either a Cirq ``Circuit`` object or an OpenQASM string. When a
Expand All @@ -26,10 +41,16 @@ class CirqApplication(Application[None]):
circuit_or_qasm: A Cirq Circuit or an OpenQASM string.
classical_control_probability: Probability that a classically
controlled operation is included in the trace. Defaults to 0.5.
rotation_threshold: Rotation exponents with absolute value below
this threshold are treated as identity and omitted from the
trace. This applies to single-qubit rotations (RX, RY, RZ) as
well as to the rotation components of controlled-Z
decompositions. Defaults to 1e-6.
"""

circuit_or_qasm: str | cirq.CIRCUIT_LIKE
classical_control_probability: float = 0.5
rotation_threshold: float = 1e-6

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this sufficient? I would think this should be tied to the rotation synthesis error rate (although that is computed lower in the stack).


def __post_init__(self):
telemetry_events.on_qre_application_created("CirqApplication")
Expand All @@ -46,7 +67,7 @@ def __post_init__(self):
else:
self._circuit = self.circuit_or_qasm

def get_trace(self, parameters: None = None) -> Trace:
def get_trace(self, parameters: CirqApplicationParams = CirqApplicationParams()) -> Trace:
"""Return the resource estimation trace for the Cirq circuit.

Args:
Expand All @@ -55,4 +76,9 @@ def get_trace(self, parameters: None = None) -> Trace:
Returns:
Trace: The resource estimation trace.
"""
return trace_from_cirq(self._circuit)
return trace_from_cirq(
self._circuit,
classical_control_probability=self.classical_control_probability,
rotation_threshold=self.rotation_threshold,
track_memory_qubits=parameters.track_memory_qubits,
)
26 changes: 16 additions & 10 deletions source/qdk_package/tests/qre/test_cirq_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
cirq = pytest.importorskip("cirq")

from qdk.qre import PSSPC
from qdk.qre.application import CirqApplication
from qdk.qre.interop import trace_from_cirq
from qdk.qre.application import CirqApplication, CirqApplicationParams
from qdk.qre.interop._cirq import (
TypedQubit,
QubitType,
Expand Down Expand Up @@ -124,7 +123,8 @@ def _make_memory_circuit(*ops):
def test_write_to_memory_memory_compute_true():
"""Test WriteToMemoryGate produces WRITE_TO_MEMORY instructions when memory_compute is True."""
circuit = _make_memory_circuit(write_to_memory)
trace = trace_from_cirq(circuit, track_memory_qubits=True)
app = CirqApplication(circuit)
trace = app.get_trace(CirqApplicationParams(track_memory_qubits=True))

assert trace.compute_qubits == 2
assert trace.memory_qubits == 2
Expand All @@ -141,7 +141,8 @@ def test_write_to_memory_memory_compute_true():
def test_write_to_memory_memory_compute_false():
"""Test WriteToMemoryGate decomposes into SWAPs when memory_compute is False."""
circuit = _make_memory_circuit(write_to_memory)
trace = trace_from_cirq(circuit, track_memory_qubits=False)
app = CirqApplication(circuit)
trace = app.get_trace(CirqApplicationParams(track_memory_qubits=False))

assert trace.compute_qubits == 4
assert trace.memory_qubits is None
Expand All @@ -158,7 +159,8 @@ def test_write_to_memory_memory_compute_false():
def test_read_from_memory_memory_compute_true():
"""Test ReadFromMemoryGate produces READ_FROM_MEMORY instructions when memory_compute is True."""
circuit = _make_memory_circuit(read_from_memory)
trace = trace_from_cirq(circuit, track_memory_qubits=True)
app = CirqApplication(circuit)
trace = app.get_trace(CirqApplicationParams(track_memory_qubits=True))

assert trace.compute_qubits == 2
assert trace.memory_qubits == 2
Expand All @@ -175,7 +177,8 @@ def test_read_from_memory_memory_compute_true():
def test_read_from_memory_memory_compute_false():
"""Test ReadFromMemoryGate decomposes into SWAPs when memory_compute is False."""
circuit = _make_memory_circuit(read_from_memory)
trace = trace_from_cirq(circuit, track_memory_qubits=False)
app = CirqApplication(circuit)
trace = app.get_trace(CirqApplicationParams(track_memory_qubits=False))

assert trace.compute_qubits == 4
assert trace.memory_qubits is None
Expand All @@ -192,7 +195,8 @@ def test_read_from_memory_memory_compute_false():
def test_read_write_memory_round_trip_memory_compute_true():
"""Test a write followed by a read produces both instruction types with memory_compute True."""
circuit = _make_memory_circuit(write_to_memory, read_from_memory)
trace = trace_from_cirq(circuit, track_memory_qubits=True)
app = CirqApplication(circuit)
trace = app.get_trace(CirqApplicationParams(track_memory_qubits=True))

assert trace.compute_qubits == 2
assert trace.memory_qubits == 2
Expand All @@ -209,7 +213,8 @@ def test_read_write_memory_round_trip_memory_compute_true():
def test_read_write_memory_round_trip_memory_compute_false():
"""Test a write followed by a read decomposes fully with memory_compute False."""
circuit = _make_memory_circuit(write_to_memory, read_from_memory)
trace = trace_from_cirq(circuit, track_memory_qubits=False)
app = CirqApplication(circuit)
trace = app.get_trace(CirqApplicationParams(track_memory_qubits=False))

assert trace.compute_qubits == 4
assert trace.memory_qubits is None
Expand All @@ -227,8 +232,9 @@ def test_plain_circuit_unaffected_by_memory_compute():
"""Test that memory_compute has no effect on circuits without memory qubits."""
circuit = cirq.H.on_each(*cirq.LineQubit.range(3))

trace_true = trace_from_cirq(circuit, track_memory_qubits=True)
trace_false = trace_from_cirq(circuit, track_memory_qubits=False)
app = CirqApplication(circuit)
trace_true = app.get_trace(CirqApplicationParams(track_memory_qubits=True))
trace_false = app.get_trace(CirqApplicationParams(track_memory_qubits=False))

assert trace_true.compute_qubits == trace_false.compute_qubits == 3
assert trace_true.memory_qubits is None
Expand Down
14 changes: 14 additions & 0 deletions source/qdk_package/tests/qre/test_enumeration.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ class BoolConfig:
assert instances[1].flag is False


def test_enumerate_instances_bool_with_domain():
"""Test that boolean fields with a domain only enumerate specified values."""
from qdk.qre._enumeration import _enumerate_instances

@dataclass
class BoolConfig:
_: KW_ONLY
flag: bool = field(default=True, metadata={"domain": [True]})

instances = list(_enumerate_instances(BoolConfig))
assert len(instances) == 1
assert instances[0].flag is True


def test_enumerate_instances_enum():
"""Test that Enum dataclass fields enumerate all members."""
from qdk.qre._enumeration import _enumerate_instances
Expand Down
Loading