diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h index 8eaaf7bddc48..32ec0f0f8f05 100644 --- a/include/tvm/relax/type.h +++ b/include/tvm/relax/type.h @@ -53,8 +53,7 @@ class ShapeTypeNode : public TypeNode { class ShapeType : public Type { public: - // TODO(relax-team): remove the default value later. - TVM_DLL ShapeType(int ndim = kUnknownNDim, Span span = Span()); + TVM_DLL ShapeType(int ndim, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ShapeType, Type, ShapeTypeNode); }; diff --git a/python/tvm/relax/ty.py b/python/tvm/relax/ty.py index d7b619bf8d9b..afa25d0dd003 100644 --- a/python/tvm/relax/ty.py +++ b/python/tvm/relax/ty.py @@ -31,12 +31,11 @@ class ShapeType(Type): Parameters ---------- - ndim : Optional[int] - The size of the shape. + ndim : int + The number of dimensions of the shape. Use -1 for unknown ndim. """ - # TODO(relax-team): consider make ndim mandatory - def __init__(self, ndim: int = -1, span: Span = None) -> None: + def __init__(self, ndim: int, span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.ShapeType, ndim, span) # type: ignore diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index d6e496c3f14d..7c943e5d3951 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -256,7 +256,7 @@ def test_shape_expr(): def test_types(): printer = ASTPrinter() - assert strip_whitespace(printer.visit_type_(rx.ShapeType())) == "ShapeType(ndim=-1)" + assert strip_whitespace(printer.visit_type_(rx.ShapeType(ndim=-1))) == "ShapeType(ndim=-1)" assert strip_whitespace(printer.visit_type_(rx.ShapeType(ndim=1))) == "ShapeType(ndim=1)" object_type = rx.ObjectType() assert strip_whitespace(printer.visit_type_(object_type)) == "ObjectType()" @@ -266,7 +266,7 @@ def test_types(): assert strip_whitespace(printer.visit_type_(tensor_type)) == "TensorType(ndim=2,dtype=int32)" unit_type = rx.TupleType([]) assert strip_whitespace(printer.visit_type_(unit_type)) == "TupleType(fields=[])" - tuple_type = rx.TupleType([rx.ShapeType(), object_type]) + tuple_type = rx.TupleType([rx.ShapeType(ndim=-1), object_type]) assert_fields( "TupleType", {"fields": "[ShapeType(ndim=-1),ObjectType()]"}, diff --git a/tests/python/relax/test_struct_info.py b/tests/python/relax/test_struct_info.py index 45599d198cda..cf3202ddc915 100644 --- a/tests/python/relax/test_struct_info.py +++ b/tests/python/relax/test_struct_info.py @@ -52,8 +52,8 @@ def test_object_struct_info(): def test_shape_type(): - t0 = rx.ShapeType() - t1 = rx.ShapeType() + t0 = rx.ShapeType(ndim=-1) + t1 = rx.ShapeType(ndim=-1) assert t0 == t1