diff --git a/source/qdk_package/qdk/qre/application/__init__.py b/source/qdk_package/qdk/qre/application/__init__.py index f6ee4c9f08..1141711fdc 100644 --- a/source/qdk_package/qdk/qre/application/__init__.py +++ b/source/qdk_package/qdk/qre/application/__init__.py @@ -1,13 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from ._cirq import CirqApplication +from ._cirq import CirqApplication, CirqApplicationParams from ._qir import QIRApplication from ._qsharp import QSharpApplication from ._openqasm import OpenQASMApplication __all__ = [ "CirqApplication", + "CirqApplicationParams", "QIRApplication", "QSharpApplication", "OpenQASMApplication", diff --git a/source/qdk_package/qdk/qre/application/_cirq.py b/source/qdk_package/qdk/qre/application/_cirq.py index a49c58e317..ffbd4106ed 100644 --- a/source/qdk_package/qdk/qre/application/_cirq.py +++ b/source/qdk_package/qdk/qre/application/_cirq.py @@ -4,7 +4,7 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field import cirq @@ -15,7 +15,22 @@ @dataclass -class CirqApplication(Application[None]): +class CirqApplicationParams: + """Application parameters that control how the resource estimation trace is generated from a Cirq circuit. + + Args: + track_memory_qubits (bool): When True, memory qubits are tracked + separately from compute qubits. When False, all qubits are treated + as compute qubits. Also, if True, read-from-memory and + write-to-memory instructions are preserved in the trace, otherwise, + they are decompsed into SWAP and RESET instructions. Defaults to + True. + """ + track_memory_qubits: bool = field(default=True, metadata={"domain": [True]}) + + +@dataclass +class CirqApplication(Application[CirqApplicationParams]): """Application that produces a resource estimation trace from a Cirq circuit. Accepts either a Cirq ``Circuit`` object or an OpenQASM string. When a @@ -26,10 +41,16 @@ class CirqApplication(Application[None]): circuit_or_qasm: A Cirq Circuit or an OpenQASM string. classical_control_probability: Probability that a classically controlled operation is included in the trace. Defaults to 0.5. + rotation_threshold: Rotation exponents with absolute value below + this threshold are treated as identity and omitted from the + trace. This applies to single-qubit rotations (RX, RY, RZ) as + well as to the rotation components of controlled-Z + decompositions. Defaults to 1e-6. """ circuit_or_qasm: str | cirq.CIRCUIT_LIKE classical_control_probability: float = 0.5 + rotation_threshold: float = 1e-6 def __post_init__(self): telemetry_events.on_qre_application_created("CirqApplication") @@ -46,7 +67,7 @@ def __post_init__(self): else: self._circuit = self.circuit_or_qasm - def get_trace(self, parameters: None = None) -> Trace: + def get_trace(self, parameters: CirqApplicationParams = CirqApplicationParams()) -> Trace: """Return the resource estimation trace for the Cirq circuit. Args: @@ -55,4 +76,9 @@ def get_trace(self, parameters: None = None) -> Trace: Returns: Trace: The resource estimation trace. """ - return trace_from_cirq(self._circuit) + return trace_from_cirq( + self._circuit, + classical_control_probability=self.classical_control_probability, + rotation_threshold=self.rotation_threshold, + track_memory_qubits=parameters.track_memory_qubits, + ) diff --git a/source/qdk_package/tests/qre/test_cirq_interop.py b/source/qdk_package/tests/qre/test_cirq_interop.py index 95bcf8fe04..ce61b982d1 100644 --- a/source/qdk_package/tests/qre/test_cirq_interop.py +++ b/source/qdk_package/tests/qre/test_cirq_interop.py @@ -6,8 +6,7 @@ cirq = pytest.importorskip("cirq") from qdk.qre import PSSPC -from qdk.qre.application import CirqApplication -from qdk.qre.interop import trace_from_cirq +from qdk.qre.application import CirqApplication, CirqApplicationParams from qdk.qre.interop._cirq import ( TypedQubit, QubitType, @@ -124,7 +123,8 @@ def _make_memory_circuit(*ops): def test_write_to_memory_memory_compute_true(): """Test WriteToMemoryGate produces WRITE_TO_MEMORY instructions when memory_compute is True.""" circuit = _make_memory_circuit(write_to_memory) - trace = trace_from_cirq(circuit, track_memory_qubits=True) + app = CirqApplication(circuit) + trace = app.get_trace(CirqApplicationParams(track_memory_qubits=True)) assert trace.compute_qubits == 2 assert trace.memory_qubits == 2 @@ -141,7 +141,8 @@ def test_write_to_memory_memory_compute_true(): def test_write_to_memory_memory_compute_false(): """Test WriteToMemoryGate decomposes into SWAPs when memory_compute is False.""" circuit = _make_memory_circuit(write_to_memory) - trace = trace_from_cirq(circuit, track_memory_qubits=False) + app = CirqApplication(circuit) + trace = app.get_trace(CirqApplicationParams(track_memory_qubits=False)) assert trace.compute_qubits == 4 assert trace.memory_qubits is None @@ -158,7 +159,8 @@ def test_write_to_memory_memory_compute_false(): def test_read_from_memory_memory_compute_true(): """Test ReadFromMemoryGate produces READ_FROM_MEMORY instructions when memory_compute is True.""" circuit = _make_memory_circuit(read_from_memory) - trace = trace_from_cirq(circuit, track_memory_qubits=True) + app = CirqApplication(circuit) + trace = app.get_trace(CirqApplicationParams(track_memory_qubits=True)) assert trace.compute_qubits == 2 assert trace.memory_qubits == 2 @@ -175,7 +177,8 @@ def test_read_from_memory_memory_compute_true(): def test_read_from_memory_memory_compute_false(): """Test ReadFromMemoryGate decomposes into SWAPs when memory_compute is False.""" circuit = _make_memory_circuit(read_from_memory) - trace = trace_from_cirq(circuit, track_memory_qubits=False) + app = CirqApplication(circuit) + trace = app.get_trace(CirqApplicationParams(track_memory_qubits=False)) assert trace.compute_qubits == 4 assert trace.memory_qubits is None @@ -192,7 +195,8 @@ def test_read_from_memory_memory_compute_false(): def test_read_write_memory_round_trip_memory_compute_true(): """Test a write followed by a read produces both instruction types with memory_compute True.""" circuit = _make_memory_circuit(write_to_memory, read_from_memory) - trace = trace_from_cirq(circuit, track_memory_qubits=True) + app = CirqApplication(circuit) + trace = app.get_trace(CirqApplicationParams(track_memory_qubits=True)) assert trace.compute_qubits == 2 assert trace.memory_qubits == 2 @@ -209,7 +213,8 @@ def test_read_write_memory_round_trip_memory_compute_true(): def test_read_write_memory_round_trip_memory_compute_false(): """Test a write followed by a read decomposes fully with memory_compute False.""" circuit = _make_memory_circuit(write_to_memory, read_from_memory) - trace = trace_from_cirq(circuit, track_memory_qubits=False) + app = CirqApplication(circuit) + trace = app.get_trace(CirqApplicationParams(track_memory_qubits=False)) assert trace.compute_qubits == 4 assert trace.memory_qubits is None @@ -227,8 +232,9 @@ def test_plain_circuit_unaffected_by_memory_compute(): """Test that memory_compute has no effect on circuits without memory qubits.""" circuit = cirq.H.on_each(*cirq.LineQubit.range(3)) - trace_true = trace_from_cirq(circuit, track_memory_qubits=True) - trace_false = trace_from_cirq(circuit, track_memory_qubits=False) + app = CirqApplication(circuit) + trace_true = app.get_trace(CirqApplicationParams(track_memory_qubits=True)) + trace_false = app.get_trace(CirqApplicationParams(track_memory_qubits=False)) assert trace_true.compute_qubits == trace_false.compute_qubits == 3 assert trace_true.memory_qubits is None diff --git a/source/qdk_package/tests/qre/test_enumeration.py b/source/qdk_package/tests/qre/test_enumeration.py index d8c987f90c..3b4e156bf9 100644 --- a/source/qdk_package/tests/qre/test_enumeration.py +++ b/source/qdk_package/tests/qre/test_enumeration.py @@ -60,6 +60,20 @@ class BoolConfig: assert instances[1].flag is False +def test_enumerate_instances_bool_with_domain(): + """Test that boolean fields with a domain only enumerate specified values.""" + from qdk.qre._enumeration import _enumerate_instances + + @dataclass + class BoolConfig: + _: KW_ONLY + flag: bool = field(default=True, metadata={"domain": [True]}) + + instances = list(_enumerate_instances(BoolConfig)) + assert len(instances) == 1 + assert instances[0].flag is True + + def test_enumerate_instances_enum(): """Test that Enum dataclass fields enumerate all members.""" from qdk.qre._enumeration import _enumerate_instances