diff --git a/docs/src/index.rst b/docs/src/index.rst index 74c52aaa2b..46d069929f 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -32,7 +32,7 @@ are the CPU and GPU. install .. toctree:: - :caption: Usage + :caption: Usage :maxdepth: 1 usage/quick_start @@ -78,6 +78,7 @@ are the CPU and GPU. python/optimizers python/distributed python/tree_utils + python/printoptions .. toctree:: :caption: C++ API Reference diff --git a/docs/src/python/printoptions.rst b/docs/src/python/printoptions.rst new file mode 100644 index 0000000000..9a4211c9dd --- /dev/null +++ b/docs/src/python/printoptions.rst @@ -0,0 +1,12 @@ +Print Options +=============== + +.. currentmodule:: mlx.core + +.. autosummary:: + :toctree: _autosummary + + PrintOptions + set_printoptions + printoptions + get_printoptions diff --git a/mlx/utils.cpp b/mlx/utils.cpp index cf0e0f38db..c257961d70 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -1,6 +1,7 @@ // Copyright © 2023 Apple Inc. #include +#include #include #include #include @@ -57,23 +58,51 @@ inline void PrintFormatter::print(std::ostream& os, uint64_t val) { os << val; } inline void PrintFormatter::print(std::ostream& os, float16_t val) { - os << val; + if (format_options.precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(format_options.precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, bfloat16_t val) { - os << val; + if (format_options.precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(format_options.precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, float val) { - os << val; + if (format_options.precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(format_options.precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, double val) { - os << val; + if (format_options.precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(format_options.precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, complex64_t val) { - os << val.real(); - if (val.imag() >= 0 || std::isnan(val.imag())) { - os << "+" << val.imag() << "j"; + if (format_options.precision == -1) { + os << val.real(); + if (val.imag() >= 0 || std::isnan(val.imag())) { + os << "+" << val.imag() << "j"; + } else { + os << "-" << -val.imag() << "j"; + } } else { - os << "-" << -val.imag() << "j"; + os << std::fixed << std::setprecision(format_options.precision) + << val.real(); + if (val.imag() >= 0 || std::isnan(val.imag())) { + os << "+" << std::fixed << std::setprecision(format_options.precision) + << val.imag() << "j"; + } else { + os << "-" << std::fixed << std::setprecision(format_options.precision) + << -val.imag() << "j"; + } } } @@ -82,6 +111,11 @@ PrintFormatter& get_global_formatter() { return formatter; } +void set_printoptions(PrintOptions options) { + auto& formatter = get_global_formatter(); + formatter.format_options = options; +} + void abort_with_exception(const std::exception& error) { std::ostringstream msg; msg << "Terminating due to uncaught exception: " << error.what(); diff --git a/mlx/utils.h b/mlx/utils.h index 62aa82b658..134f2efcc7 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -38,6 +38,10 @@ struct StreamContext { Stream _stream; }; +struct MLX_API PrintOptions { + int precision{-1}; +}; + struct PrintFormatter { inline void print(std::ostream& os, bool val); inline void print(std::ostream& os, int16_t val); @@ -53,8 +57,11 @@ struct PrintFormatter { inline void print(std::ostream& os, complex64_t val); bool capitalize_bool{false}; + PrintOptions format_options; }; +MLX_API void set_printoptions(PrintOptions options); + MLX_API PrintFormatter& get_global_formatter(); /** Print the exception and then abort. */ diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 69152f5020..447271500b 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -27,7 +27,8 @@ nanobind_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp ${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp) + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/print.cpp) if(MLX_BUILD_PYTHON_STUBS) nanobind_add_stub( diff --git a/python/src/array.cpp b/python/src/array.cpp index 838a33a47d..a5e6dbe23d 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -12,6 +12,7 @@ #include #include "mlx/backend/metal/metal.h" +#include "mlx/utils.h" #include "python/src/buffer.h" #include "python/src/convert.h" #include "python/src/indexing.h" @@ -97,9 +98,6 @@ class ArrayPythonIterator { }; void init_array(nb::module_& m) { - // Set Python print formatting options - mx::get_global_formatter().capitalize_bool = true; - // Types nb::class_( m, diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index 2829b32199..cb031cf78c 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -23,6 +23,7 @@ void init_constants(nb::module_&); void init_fast(nb::module_&); void init_distributed(nb::module_&); void init_export(nb::module_&); +void init_print(nb::module_&); NB_MODULE(core, m) { m.doc() = "mlx: A framework for machine learning on Apple silicon."; @@ -46,6 +47,7 @@ NB_MODULE(core, m) { init_fast(m); init_distributed(m); init_export(m); + init_print(m); m.attr("__version__") = mx::version(); } diff --git a/python/src/print.cpp b/python/src/print.cpp new file mode 100644 index 0000000000..03c3fdaa79 --- /dev/null +++ b/python/src/print.cpp @@ -0,0 +1,86 @@ +#include +#include +#include + +#include + +#include "mlx/utils.h" +#include "python/src/utils.h" + +#include "mlx/mlx.h" + +namespace mx = mlx::core; +namespace nb = nanobind; +using namespace nb::literals; + +struct PrintOptionsContext { + mx::PrintOptions old_options; + mx::PrintOptions new_options; + PrintOptionsContext(mx::PrintOptions p) : new_options(p) {} + PrintOptionsContext& enter() { + old_options = mx::get_global_formatter().format_options; + mx::set_printoptions(new_options); + return *this; + } + void exit(nb::args) { + mx::set_printoptions(old_options); + } +}; + +void init_print(nb::module_& m) { + // Set Python print formatting options + mx::get_global_formatter().capitalize_bool = true; + // Expose printing options to Python: allow setting global precision. + nb::class_(m, "PrintOptions") + .def(nb::init(), "precision"_a = -1) + .def_rw("precision", &mx::PrintOptions::precision); + + m.def( + "set_printoptions", + [](int precision) { mx::set_printoptions({precision}); }, + "precision"_a = mx::get_global_formatter().format_options.precision, + R"pbdoc( + Set global printing precision for array formatting. + + Example: + >>> print(x) # Uses default precision + >>> mx.set_printoptions(precision=3): + >>> print(x) # Uses precision of 3 + >>> print(x) # Uses precision of 3 (again) + + Args: + precision (int): Number of decimal places. + )pbdoc"); + m.def( + "get_printoptions", + []() { return mx::get_global_formatter().format_options; }, + R"pbdoc( + Get global printing precision for array formatting. + + Returns: + PrintOptions: The format options used for printing arrays. + )pbdoc"); + + nb::class_(m, "_PrintOptionsContext") + .def(nb::init()) + .def("__enter__", &PrintOptionsContext::enter) + .def("__exit__", &PrintOptionsContext::exit); + + m.def( + "printoptions", + [](int precision) { return PrintOptionsContext({precision}); }, + "precision"_a = mx::get_global_formatter().format_options.precision, + R"pbdoc( + Context manager for setting print options temporarily. + + Example: + >>> print(x) # Uses default precision + >>> with mx.printoptions(precision=3): + >>> print(x) # Uses precision of 3 + >>> print(x) # Back to default precision + + + Args: + precision (int): Number of decimal places. Use -1 for default + )pbdoc"); +} diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 86328c2a1b..594692f178 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -597,6 +597,26 @@ def test_array_repr(self): x = mx.array([1 - 1j], dtype=mx.complex64) expected = "array([1-1j], dtype=complex64)" + def test_array_repr_precision(self): + x = mx.array([1.123456789], dtype=mx.float32) + expected = "array([1.12346], dtype=float32)" + self.assertEqual(str(x), expected) + + with mx.printoptions(precision=4): + expected = "array([1.1235], dtype=float32)" + self.assertEqual(str(x), expected) + mx.set_printoptions(precision=2) + expected = "array([1.12], dtype=float32)" + self.assertEqual(str(x), expected) + + x = mx.sin(x) + expected = "array([0.90], dtype=float32)" + self.assertEqual(str(x), expected) + + with mx.printoptions(precision=4): + expected = "array([0.9016], dtype=float32)" + self.assertEqual(str(x), expected) + def test_array_to_list(self): types = [mx.bool_, mx.uint32, mx.int32, mx.int64, mx.float32] for t in types: