Skip to content

Commit 7177cc8

Browse files
committed
Make Parameter classes generic in the instrument type
1 parent e120911 commit 7177cc8

File tree

5 files changed

+36
-21
lines changed

5 files changed

+36
-21
lines changed

src/qcodes/instrument_drivers/AlazarTech/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@
55
:mod:`.AlazarTech.helpers` module).
66
"""
77

8-
from typing import TYPE_CHECKING, Any, cast
8+
from typing import TYPE_CHECKING, Any
99

1010
from qcodes.parameters import Parameter, ParamRawDataType
1111

1212
if TYPE_CHECKING:
13-
from .ATS import AlazarTechATS
13+
# ruff does not detect that this is used as a generic parameter below
14+
from .ATS import AlazarTechATS # noqa: F401
1415

1516

16-
class TraceParameter(Parameter):
17+
class TraceParameter(Parameter["AlazarTechATS"]):
1718
"""
1819
A parameter that keeps track of if its value has been synced to
1920
the ``Instrument``. To achieve that, this parameter sets
@@ -38,6 +39,5 @@ def synced_to_card(self) -> bool:
3839
return self._synced_to_card
3940

4041
def set_raw(self, value: ParamRawDataType) -> None:
41-
instrument = cast("AlazarTechATS", self.instrument)
42-
instrument._parameters_synced = False
42+
self.instrument._parameters_synced = False
4343
self._synced_to_card = False

src/qcodes/instrument_drivers/Keysight/keysightb1500/KeysightB1517A.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import re
22
import textwrap
3-
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast, overload
3+
from typing import TYPE_CHECKING, Any, Literal, NotRequired, overload
44

55
import numpy as np
66
import numpy.typing as npt
@@ -701,7 +701,7 @@ def _get_sweep_steps_parser(response: str) -> SweepSteps:
701701
"""
702702

703703

704-
class _ParameterWithStatus(Parameter):
704+
class _ParameterWithStatus(Parameter["KeysightB1517A"]):
705705
def __init__(self, *args: Any, **kwargs: Any):
706706
super().__init__(*args, **kwargs)
707707

@@ -728,7 +728,7 @@ def snapshot_base(
728728

729729
class _SpotMeasurementVoltageParameter(_ParameterWithStatus):
730730
def set_raw(self, value: ParamRawDataType) -> None:
731-
smu = cast("KeysightB1517A", self.instrument)
731+
smu = self.instrument
732732

733733
if smu._source_config["output_range"] is None:
734734
smu._source_config["output_range"] = constants.VOutputRange.AUTO
@@ -752,7 +752,7 @@ def set_raw(self, value: ParamRawDataType) -> None:
752752
)
753753

754754
def get_raw(self) -> ParamRawDataType:
755-
smu = cast("KeysightB1517A", self.instrument)
755+
smu = self.instrument
756756

757757
msg = MessageBuilder().tv(
758758
chnum=smu.channels[0],
@@ -769,7 +769,7 @@ def get_raw(self) -> ParamRawDataType:
769769

770770
class _SpotMeasurementCurrentParameter(_ParameterWithStatus):
771771
def set_raw(self, value: ParamRawDataType) -> None:
772-
smu = cast("KeysightB1517A", self.instrument)
772+
smu = self.instrument
773773

774774
if smu._source_config["output_range"] is None:
775775
smu._source_config["output_range"] = constants.IOutputRange.AUTO
@@ -793,7 +793,7 @@ def set_raw(self, value: ParamRawDataType) -> None:
793793
)
794794

795795
def get_raw(self) -> ParamRawDataType:
796-
smu = cast("KeysightB1517A", self.instrument)
796+
smu = self.instrument
797797

798798
msg = MessageBuilder().ti(
799799
chnum=smu.channels[0],

src/qcodes/instrument_drivers/tektronix/DPO7200xx.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import textwrap
88
import time
99
from functools import partial
10-
from typing import TYPE_CHECKING, Any, ClassVar, cast
10+
from typing import TYPE_CHECKING, Any, ClassVar
1111

1212
import numpy as np
1313
import numpy.typing as npt
@@ -767,7 +767,7 @@ def _trigger_type(self, value: str) -> None:
767767
self.write(f"TRIGger:{self._identifier}:TYPE {value}")
768768

769769

770-
class TektronixDPOMeasurementParameter(Parameter):
770+
class TektronixDPOMeasurementParameter(Parameter["TektronixDPOMeasurement"]):
771771
"""
772772
A measurement parameter does not only return the instantaneous value
773773
of a measurement, but can also return some statistics. The accumulation
@@ -778,7 +778,7 @@ class TektronixDPOMeasurementParameter(Parameter):
778778
"""
779779

780780
def _get(self, metric: str) -> float:
781-
measurement_channel = cast("TektronixDPOMeasurement", self.instrument)
781+
measurement_channel = self.instrument
782782
if measurement_channel.type.get_latest() != self.name:
783783
measurement_channel.type(self.name)
784784

src/qcodes/parameters/parameter.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
import os
88
from types import MethodType
9-
from typing import TYPE_CHECKING, Any, Literal
9+
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
1010

1111
from .command import Command
1212
from .parameter_base import ParamDataType, ParameterBase, ParamRawDataType
@@ -24,7 +24,15 @@
2424
log = logging.getLogger(__name__)
2525

2626

27-
class Parameter(ParameterBase):
27+
_InstrumentType_co = TypeVar(
28+
"_InstrumentType_co",
29+
bound="InstrumentBase | None",
30+
default="InstrumentBase | None",
31+
covariant=True,
32+
)
33+
34+
35+
class Parameter(ParameterBase[_InstrumentType_co], Generic[_InstrumentType_co]):
2836
"""
2937
A parameter represents a single degree of freedom. Most often,
3038
this is the standard parameter for Instruments, though it can also be
@@ -172,7 +180,7 @@ class Parameter(ParameterBase):
172180
def __init__(
173181
self,
174182
name: str,
175-
instrument: InstrumentBase | None = None,
183+
instrument: _InstrumentType_co = None,
176184
label: str | None = None,
177185
unit: str | None = None,
178186
get_cmd: str | Callable[..., Any] | Literal[False] | None = None,

src/qcodes/parameters/parameter_base.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@
4242
from qcodes.instrument import InstrumentBase
4343
from qcodes.logger.instrument_logger import InstrumentLoggerAdapter
4444

45+
_InstrumentType_co = TypeVar(
46+
"_InstrumentType_co",
47+
bound="InstrumentBase | None",
48+
default="InstrumentBase | None",
49+
covariant=True,
50+
)
51+
4552
LOG = logging.getLogger(__name__)
4653

4754

@@ -109,7 +116,7 @@ def invert_val_mapping(val_mapping: Mapping[Any, Any]) -> dict[Any, Any]:
109116
return {v: k for k, v in val_mapping.items()}
110117

111118

112-
class ParameterBase(MetadatableWithName):
119+
class ParameterBase(MetadatableWithName, Generic[_InstrumentType_co]):
113120
"""
114121
Shared behavior for all parameters. Not intended to be used
115122
directly, normally you should use ``Parameter``, ``ArrayParameter``,
@@ -212,7 +219,7 @@ class ParameterBase(MetadatableWithName):
212219
def __init__(
213220
self,
214221
name: str,
215-
instrument: InstrumentBase | None,
222+
instrument: _InstrumentType_co = None,
216223
snapshot_get: bool = True,
217224
metadata: Mapping[Any, Any] | None = None,
218225
step: float | None = None,
@@ -606,7 +613,7 @@ def snapshot_base(
606613
state["ts"] = dttime.strftime("%Y-%m-%d %H:%M:%S")
607614

608615
for attr in set(self._meta_attrs):
609-
if attr == "instrument" and self._instrument:
616+
if attr == "instrument" and self._instrument is not None:
610617
state.update(
611618
{
612619
"instrument": full_class(self._instrument),
@@ -1033,7 +1040,7 @@ def register_name(self) -> str:
10331040
return self._register_name or self.full_name
10341041

10351042
@property
1036-
def instrument(self) -> InstrumentBase | None:
1043+
def instrument(self) -> _InstrumentType_co:
10371044
"""
10381045
Return the first instrument that this parameter is bound to.
10391046
E.g if this is bound to a channel it will return the channel

0 commit comments

Comments
 (0)