diff --git a/rocketpy/motors/tank_geometry.py b/rocketpy/motors/tank_geometry.py index 485f57b09..2585778c8 100644 --- a/rocketpy/motors/tank_geometry.py +++ b/rocketpy/motors/tank_geometry.py @@ -1,3 +1,4 @@ +import warnings from functools import cached_property import numpy as np @@ -384,14 +385,21 @@ class inherits from the TankGeometry class. See the TankGeometry class for more information on its attributes and methods. """ - def __init__(self, radius, height, spherical_caps=False, geometry_dict=None): + def __init__( + self, + radius_function=None, + height=None, + spherical_caps=False, + geometry_dict=None, + **kwargs, + ): """Initialize CylindricalTank class. The zero reference point of the cylinder is its center (i.e. half of its height). Therefore the its height coordinate span is (-height/2, height/2). Parameters ---------- - radius : float + radius_function : int, float Radius of the cylindrical tank, in meters. height : float Height of the cylindrical tank, in meters. @@ -401,17 +409,47 @@ def __init__(self, radius, height, spherical_caps=False, geometry_dict=None): will have flat caps at the top and bottom. Defaults to False. geometry_dict : Union[dict, None], optional Dictionary containing the geometry of the tank. See TankGeometry. - """ + + Notes + ----- + The ``radius`` keyword argument is deprecated. Use ``radius_function`` + instead. + """ + if "radius" in kwargs: + if radius_function is not None: + raise TypeError( + "Cannot specify both 'radius_function' and deprecated " + "'radius' arguments. Use 'radius_function' instead." + ) + warnings.warn( + "The 'radius' argument in CylindricalTank is deprecated. " + "Use 'radius_function' instead.", + DeprecationWarning, + stacklevel=2, + ) + radius_function = kwargs.pop("radius") + if radius_function is None: + raise TypeError( + "CylindricalTank.__init__() missing required argument: " + "'radius_function'" + ) + if height is None: + raise TypeError( + "CylindricalTank.__init__() missing required argument: 'height'" + ) geometry_dict = geometry_dict or {} super().__init__(geometry_dict) - self.__input_radius = radius + self.radius_function = radius_function self.height = height self.has_caps = False if spherical_caps: - self.add_geometry((-height / 2 + radius, height / 2 - radius), radius) + self.add_geometry( + (-height / 2 + radius_function, height / 2 - radius_function), + radius_function, + ) self.add_spherical_caps() else: - self.add_geometry((-height / 2, height / 2), radius) + self.add_geometry((-height / 2, height / 2), radius_function) def add_spherical_caps(self): """ @@ -424,11 +462,11 @@ def add_spherical_caps(self): "Warning: Adding spherical caps to the tank will not modify the " + f"total height of the tank {self.height} m. " + "Its cylindrical portion height will be reduced to " - + f"{self.height - 2 * self.__input_radius} m." + + f"{self.height - 2 * self.radius_function} m." ) if not self.has_caps: - radius = self.__input_radius + radius = self.radius_function height = self.height bottom_cap_range = (-height / 2, -height / 2 + radius) upper_cap_range = (height / 2 - radius, height / 2) @@ -447,7 +485,7 @@ def upper_cap_radius(h): def to_dict(self, **kwargs): data = { - "radius": self.__input_radius, + "radius_function": self.radius_function, "height": self.height, "spherical_caps": self.has_caps, } @@ -459,7 +497,17 @@ def to_dict(self, **kwargs): @classmethod def from_dict(cls, data): - return cls(data["radius"], data["height"], data["spherical_caps"]) + if "radius_function" in data: + radius_function = data["radius_function"] + else: + warnings.warn( + "The 'radius' key in CylindricalTank serialized data is " + "deprecated. Use 'radius_function' instead.", + DeprecationWarning, + stacklevel=2, + ) + radius_function = data["radius"] + return cls(radius_function, data["height"], data["spherical_caps"]) class SphericalTank(TankGeometry): @@ -468,25 +516,50 @@ class SphericalTank(TankGeometry): inherits from the TankGeometry class. See the TankGeometry class for more information on its attributes and methods.""" - def __init__(self, radius, geometry_dict=None): + def __init__(self, radius_function=None, geometry_dict=None, **kwargs): """Initialize SphericalTank class. The zero reference point of the sphere is its center (i.e. half of its height). Therefore, its height - coordinate ranges between (-radius, radius). + coordinate ranges between (-radius_function, radius_function). Parameters ---------- - radius : float - Radius of the spherical tank. + radius_function : int, float + Radius of the spherical tank, in meters. geometry_dict : Union[dict, None], optional Dictionary containing the geometry of the tank. See TankGeometry. - """ + + Notes + ----- + The ``radius`` keyword argument is deprecated. Use ``radius_function`` + instead. + """ + if "radius" in kwargs: + if radius_function is not None: + raise TypeError( + "Cannot specify both 'radius_function' and deprecated " + "'radius' arguments. Use 'radius_function' instead." + ) + warnings.warn( + "The 'radius' argument in SphericalTank is deprecated. " + "Use 'radius_function' instead.", + DeprecationWarning, + stacklevel=2, + ) + radius_function = kwargs.pop("radius") + if radius_function is None: + raise TypeError( + "SphericalTank.__init__() missing required argument: 'radius_function'" + ) geometry_dict = geometry_dict or {} super().__init__(geometry_dict) - self.__input_radius = radius - self.add_geometry((-radius, radius), lambda h: (radius**2 - h**2) ** 0.5) + self.radius_function = radius_function + self.add_geometry( + (-radius_function, radius_function), + lambda h: (radius_function**2 - h**2) ** 0.5, + ) def to_dict(self, **kwargs): - data = {"radius": self.__input_radius} + data = {"radius_function": self.radius_function} if kwargs.get("include_outputs", False): data.update(super().to_dict(**kwargs)) @@ -495,4 +568,14 @@ def to_dict(self, **kwargs): @classmethod def from_dict(cls, data): - return cls(data["radius"]) + if "radius_function" in data: + radius_function = data["radius_function"] + else: + warnings.warn( + "The 'radius' key in SphericalTank serialized data is " + "deprecated. Use 'radius_function' instead.", + DeprecationWarning, + stacklevel=2, + ) + radius_function = data["radius"] + return cls(radius_function) diff --git a/tests/fixtures/motor/tanks_fixtures.py b/tests/fixtures/motor/tanks_fixtures.py index 88721966d..aba5f5a7e 100644 --- a/tests/fixtures/motor/tanks_fixtures.py +++ b/tests/fixtures/motor/tanks_fixtures.py @@ -499,7 +499,7 @@ def temperature(t): return MassBasedTank( name="Variable Density N2O Tank", - geometry=CylindricalTank(height=0.8, radius=0.06, spherical_caps=True), + geometry=CylindricalTank(height=0.8, radius_function=0.06, spherical_caps=True), flux_time=7, liquid=nitrous_oxide_non_constant_fluid, gas=nitrous_oxide_non_constant_fluid, diff --git a/tests/unit/motors/test_tank_geometry.py b/tests/unit/motors/test_tank_geometry.py index ff4a525ba..abb9e8772 100644 --- a/tests/unit/motors/test_tank_geometry.py +++ b/tests/unit/motors/test_tank_geometry.py @@ -132,3 +132,139 @@ def test_tank_inertia(params, request): @patch("matplotlib.pyplot.show") def test_tank_geometry_plots_info(mock_show): # pylint: disable=unused-argument assert TankGeometry({(0, 5): 1}).plots.all() is None + + +def test_cylindrical_tank_radius_function_attribute(): + """Test that CylindricalTank stores the input radius as 'radius_function' + and that it does not conflict with the 'radius' property (a Function of + height). + """ + from rocketpy import CylindricalTank + from rocketpy.mathutils.function import Function + + r = 0.1 + tank = CylindricalTank(r, 2.0) + + # radius_function stores the raw input scalar + assert tank.radius_function == r + # radius property is a callable Function, not the scalar + assert callable(tank.radius) + assert isinstance(tank.radius, Function) + # The two must differ in type + assert not isinstance(tank.radius_function, Function) + + +def test_spherical_tank_radius_function_attribute(): + """Test that SphericalTank stores the input radius as 'radius_function' + and that it does not conflict with the 'radius' property (a Function of + height). + """ + from rocketpy import SphericalTank + from rocketpy.mathutils.function import Function + + r = 0.05 + tank = SphericalTank(r) + + # radius_function stores the raw input scalar + assert tank.radius_function == r + # radius property is a callable Function, not the scalar + assert callable(tank.radius) + assert isinstance(tank.radius, Function) + # The two must differ in type + assert not isinstance(tank.radius_function, Function) + + +def test_cylindrical_tank_deprecated_radius_kwarg(): + """Test that CylindricalTank issues a DeprecationWarning when the old + 'radius' keyword argument is used, and still works correctly. + """ + import warnings + + from rocketpy import CylindricalTank + + r = 0.1 + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + tank = CylindricalTank(radius=r, height=2.0) + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "radius_function" in str(w[0].message) + + assert tank.radius_function == r + + +def test_spherical_tank_deprecated_radius_kwarg(): + """Test that SphericalTank issues a DeprecationWarning when the old + 'radius' keyword argument is used, and still works correctly. + """ + import warnings + + from rocketpy import SphericalTank + + r = 0.05 + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + tank = SphericalTank(radius=r) + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "radius_function" in str(w[0].message) + + assert tank.radius_function == r + + +def test_cylindrical_tank_to_dict_uses_radius_function_key(): + """Test that CylindricalTank.to_dict() uses the 'radius_function' key.""" + from rocketpy import CylindricalTank + + tank = CylindricalTank(0.1, 2.0) + data = tank.to_dict() + assert "radius_function" in data + assert "radius" not in data + assert data["radius_function"] == 0.1 + + +def test_spherical_tank_to_dict_uses_radius_function_key(): + """Test that SphericalTank.to_dict() uses the 'radius_function' key.""" + from rocketpy import SphericalTank + + tank = SphericalTank(0.05) + data = tank.to_dict() + assert "radius_function" in data + assert "radius" not in data + assert data["radius_function"] == 0.05 + + +def test_cylindrical_tank_from_dict_deprecated_radius_key(): + """Test that CylindricalTank.from_dict() issues a DeprecationWarning when + the serialized data contains the old 'radius' key. + """ + import warnings + + from rocketpy import CylindricalTank + + old_data = {"radius": 0.1, "height": 2.0, "spherical_caps": False} + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + tank = CylindricalTank.from_dict(old_data) + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + + assert tank.radius_function == 0.1 + + +def test_spherical_tank_from_dict_deprecated_radius_key(): + """Test that SphericalTank.from_dict() issues a DeprecationWarning when + the serialized data contains the old 'radius' key. + """ + import warnings + + from rocketpy import SphericalTank + + old_data = {"radius": 0.05} + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + tank = SphericalTank.from_dict(old_data) + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + + assert tank.radius_function == 0.05