Skip to content

Commit 5410bf0

Browse files
committed
Make Parameter classes generic in the instrument type
1 parent 1e84daa commit 5410bf0

File tree

5 files changed

+40
-22
lines changed

5 files changed

+40
-22
lines changed

src/qcodes/instrument_drivers/AlazarTech/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@
55
:mod:`.AlazarTech.helpers` module).
66
"""
77

8-
from typing import TYPE_CHECKING, Any, cast
8+
from typing import TYPE_CHECKING, Any
99

1010
from qcodes.parameters import Parameter, ParamRawDataType
1111

1212
if TYPE_CHECKING:
13-
from .ATS import AlazarTechATS
13+
# ruff does not detect that this is used as a generic parameter below
14+
from .ATS import AlazarTechATS # noqa: F401
1415

1516

16-
class TraceParameter(Parameter):
17+
class TraceParameter(Parameter["AlazarTechATS"]):
1718
"""
1819
A parameter that keeps track of if its value has been synced to
1920
the ``Instrument``. To achieve that, this parameter sets
@@ -38,6 +39,5 @@ def synced_to_card(self) -> bool:
3839
return self._synced_to_card
3940

4041
def set_raw(self, value: ParamRawDataType) -> None:
41-
instrument = cast("AlazarTechATS", self.instrument)
42-
instrument._parameters_synced = False
42+
self.instrument._parameters_synced = False
4343
self._synced_to_card = False

src/qcodes/instrument_drivers/Keysight/keysightb1500/KeysightB1517A.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import re
22
import textwrap
3-
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast, overload
3+
from typing import TYPE_CHECKING, Any, Literal, NotRequired, overload
44

55
import numpy as np
66
import numpy.typing as npt
@@ -701,7 +701,7 @@ def _get_sweep_steps_parser(response: str) -> SweepSteps:
701701
"""
702702

703703

704-
class _ParameterWithStatus(Parameter):
704+
class _ParameterWithStatus(Parameter["KeysightB1517A"]):
705705
def __init__(self, *args: Any, **kwargs: Any):
706706
super().__init__(*args, **kwargs)
707707

@@ -728,7 +728,7 @@ def snapshot_base(
728728

729729
class _SpotMeasurementVoltageParameter(_ParameterWithStatus):
730730
def set_raw(self, value: ParamRawDataType) -> None:
731-
smu = cast("KeysightB1517A", self.instrument)
731+
smu = self.instrument
732732

733733
if smu._source_config["output_range"] is None:
734734
smu._source_config["output_range"] = constants.VOutputRange.AUTO
@@ -752,7 +752,7 @@ def set_raw(self, value: ParamRawDataType) -> None:
752752
)
753753

754754
def get_raw(self) -> ParamRawDataType:
755-
smu = cast("KeysightB1517A", self.instrument)
755+
smu = self.instrument
756756

757757
msg = MessageBuilder().tv(
758758
chnum=smu.channels[0],
@@ -769,7 +769,7 @@ def get_raw(self) -> ParamRawDataType:
769769

770770
class _SpotMeasurementCurrentParameter(_ParameterWithStatus):
771771
def set_raw(self, value: ParamRawDataType) -> None:
772-
smu = cast("KeysightB1517A", self.instrument)
772+
smu = self.instrument
773773

774774
if smu._source_config["output_range"] is None:
775775
smu._source_config["output_range"] = constants.IOutputRange.AUTO
@@ -793,7 +793,7 @@ def set_raw(self, value: ParamRawDataType) -> None:
793793
)
794794

795795
def get_raw(self) -> ParamRawDataType:
796-
smu = cast("KeysightB1517A", self.instrument)
796+
smu = self.instrument
797797

798798
msg = MessageBuilder().ti(
799799
chnum=smu.channels[0],

src/qcodes/instrument_drivers/tektronix/DPO7200xx.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import textwrap
88
import time
99
from functools import partial
10-
from typing import TYPE_CHECKING, Any, ClassVar, cast
10+
from typing import TYPE_CHECKING, Any, ClassVar
1111

1212
import numpy as np
1313
import numpy.typing as npt
@@ -767,7 +767,7 @@ def _trigger_type(self, value: str) -> None:
767767
self.write(f"TRIGger:{self._identifier}:TYPE {value}")
768768

769769

770-
class TektronixDPOMeasurementParameter(Parameter):
770+
class TektronixDPOMeasurementParameter(Parameter["TektronixDPOMeasurement"]):
771771
"""
772772
A measurement parameter does not only return the instantaneous value
773773
of a measurement, but can also return some statistics. The accumulation
@@ -778,7 +778,7 @@ class TektronixDPOMeasurementParameter(Parameter):
778778
"""
779779

780780
def _get(self, metric: str) -> float:
781-
measurement_channel = cast("TektronixDPOMeasurement", self.instrument)
781+
measurement_channel = self.instrument
782782
if measurement_channel.type.get_latest() != self.name:
783783
measurement_channel.type(self.name)
784784

src/qcodes/parameters/parameter.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import logging
77
import os
88
from types import MethodType
9-
from typing import TYPE_CHECKING, Any, Literal
9+
from typing import TYPE_CHECKING, Any, Generic, Literal
10+
11+
from typing_extensions import TypeVar
1012

1113
from .command import Command
1214
from .parameter_base import ParamDataType, ParameterBase, ParamRawDataType
@@ -24,7 +26,15 @@
2426
log = logging.getLogger(__name__)
2527

2628

27-
class Parameter(ParameterBase):
29+
_InstrumentType_co = TypeVar(
30+
"_InstrumentType_co",
31+
bound="InstrumentBase | None",
32+
default="InstrumentBase | None",
33+
covariant=True,
34+
)
35+
36+
37+
class Parameter(ParameterBase[_InstrumentType_co], Generic[_InstrumentType_co]):
2838
"""
2939
A parameter represents a single degree of freedom. Most often,
3040
this is the standard parameter for Instruments, though it can also be
@@ -172,7 +182,7 @@ class Parameter(ParameterBase):
172182
def __init__(
173183
self,
174184
name: str,
175-
instrument: InstrumentBase | None = None,
185+
instrument: _InstrumentType_co = None,
176186
label: str | None = None,
177187
unit: str | None = None,
178188
get_cmd: str | Callable[..., Any] | Literal[False] | None = None,

src/qcodes/parameters/parameter_base.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from contextlib import contextmanager
99
from datetime import datetime
1010
from functools import cached_property, wraps
11-
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, overload
11+
from typing import TYPE_CHECKING, Any, ClassVar, Generic, overload
1212

1313
import numpy as np
14+
from typing_extensions import TypeVar
1415

1516
from qcodes.metadatable import Metadatable, MetadatableWithName
1617
from qcodes.parameters import ParamSpecBase
@@ -42,6 +43,13 @@
4243
from qcodes.instrument import InstrumentBase
4344
from qcodes.logger.instrument_logger import InstrumentLoggerAdapter
4445

46+
_InstrumentType_co = TypeVar(
47+
"_InstrumentType_co",
48+
bound="InstrumentBase | None",
49+
default="InstrumentBase | None",
50+
covariant=True,
51+
)
52+
4553
LOG = logging.getLogger(__name__)
4654

4755

@@ -109,7 +117,7 @@ def invert_val_mapping(val_mapping: Mapping[Any, Any]) -> dict[Any, Any]:
109117
return {v: k for k, v in val_mapping.items()}
110118

111119

112-
class ParameterBase(MetadatableWithName):
120+
class ParameterBase(MetadatableWithName, Generic[_InstrumentType_co]):
113121
"""
114122
Shared behavior for all parameters. Not intended to be used
115123
directly, normally you should use ``Parameter``, ``ArrayParameter``,
@@ -212,7 +220,7 @@ class ParameterBase(MetadatableWithName):
212220
def __init__(
213221
self,
214222
name: str,
215-
instrument: InstrumentBase | None,
223+
instrument: _InstrumentType_co = None,
216224
snapshot_get: bool = True,
217225
metadata: Mapping[Any, Any] | None = None,
218226
step: float | None = None,
@@ -606,7 +614,7 @@ def snapshot_base(
606614
state["ts"] = dttime.strftime("%Y-%m-%d %H:%M:%S")
607615

608616
for attr in set(self._meta_attrs):
609-
if attr == "instrument" and self._instrument:
617+
if attr == "instrument" and self._instrument is not None:
610618
state.update(
611619
{
612620
"instrument": full_class(self._instrument),
@@ -1033,7 +1041,7 @@ def register_name(self) -> str:
10331041
return self._register_name or self.full_name
10341042

10351043
@property
1036-
def instrument(self) -> InstrumentBase | None:
1044+
def instrument(self) -> _InstrumentType_co:
10371045
"""
10381046
Return the first instrument that this parameter is bound to.
10391047
E.g if this is bound to a channel it will return the channel

0 commit comments

Comments
 (0)