From 4c66f2e8f2b1c2dea25f6ae96bce9d5b6df183ca Mon Sep 17 00:00:00 2001 From: Christophe Prat Date: Sun, 29 Mar 2026 16:11:25 +0200 Subject: [PATCH 01/13] feat: add print options for precision --- mlx/utils.cpp | 53 +++++++++++++++++++++++++++++++++++--------- mlx/utils.h | 3 +++ python/mlx/utils.py | 3 ++- python/src/array.cpp | 53 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 100 insertions(+), 12 deletions(-) diff --git a/mlx/utils.cpp b/mlx/utils.cpp index cf0e0f38db..f4188898a4 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -1,6 +1,7 @@ // Copyright © 2023 Apple Inc. #include +#include #include #include #include @@ -39,7 +40,7 @@ void PrintFormatter::print(std::ostream& os, bool val) { } } inline void PrintFormatter::print(std::ostream& os, int16_t val) { - os << val; + os << val; } inline void PrintFormatter::print(std::ostream& os, uint16_t val) { os << val; @@ -57,24 +58,49 @@ inline void PrintFormatter::print(std::ostream& os, uint64_t val) { os << val; } inline void PrintFormatter::print(std::ostream& os, float16_t val) { - os << val; + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, bfloat16_t val) { - os << val; + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, float val) { - os << val; + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, double val) { - os << val; + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, complex64_t val) { - os << val.real(); - if (val.imag() >= 0 || std::isnan(val.imag())) { - os << "+" << val.imag() << "j"; - } else { - os << "-" << -val.imag() << "j"; - } + if (precision == -1) { + os << val.real(); + if (val.imag() >= 0 || std::isnan(val.imag())) { + os << "+" << val.imag() << "j"; + } else { + os << "-" << -val.imag() << "j"; + } + } else { + os << std::fixed << std::setprecision(precision) << val.real(); + if (val.imag() >= 0 || std::isnan(val.imag())) { + os << "+" << std::fixed << std::setprecision(precision) << val.imag() << "j"; + } else { + os << "-" << std::fixed << std::setprecision(precision) << -val.imag() << "j"; + } + } } PrintFormatter& get_global_formatter() { @@ -82,6 +108,11 @@ PrintFormatter& get_global_formatter() { return formatter; } +void set_printoptions(int precision) { + auto &formatter = get_global_formatter(); + formatter.precision = precision; +} + void abort_with_exception(const std::exception& error) { std::ostringstream msg; msg << "Terminating due to uncaught exception: " << error.what(); diff --git a/mlx/utils.h b/mlx/utils.h index 62aa82b658..6be5f09ca2 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -53,8 +53,11 @@ struct PrintFormatter { inline void print(std::ostream& os, complex64_t val); bool capitalize_bool{false}; + int precision{-1}; }; +MLX_API void set_printoptions(int precision); + MLX_API PrintFormatter& get_global_formatter(); /** Print the exception and then abort. */ diff --git a/python/mlx/utils.py b/python/mlx/utils.py index f4aafe1e3d..35a8829485 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,8 +1,9 @@ # Copyright © 2023 Apple Inc. from collections import defaultdict from itertools import zip_longest +from multiprocessing import context from typing import Any, Callable, Dict, List, Optional, Tuple, Union - +from contextlib import contextmanager def tree_map( fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None diff --git a/python/src/array.cpp b/python/src/array.cpp index 838a33a47d..fc2c6c2b84 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -96,10 +96,63 @@ class ArrayPythonIterator { std::vector splits_; }; +struct PrintOptionsContext { + int old_precision; + int new_precision; + PrintOptionsContext(int p) : new_precision(p) {} + PrintOptionsContext& __enter__() { + old_precision = mx::get_global_formatter().precision; + mx::set_printoptions(new_precision); + return *this; + } + void __exit__(nb::args) { + mx::set_printoptions(old_precision); + } +}; + void init_array(nb::module_& m) { // Set Python print formatting options mx::get_global_formatter().capitalize_bool = true; + // Expose printing options to Python: allow setting global precision. + m.def( + "set_printoptions", + &mx::set_printoptions, + "precision"_a, + R"pbdoc( + Set global printing precision for array formatting. + + Args: + precision (int): Number of decimal places to use when printing + floating point numbers in arrays. + )pbdoc"); + m.def( + "get_printoptions", + []() { return mx::get_global_formatter().precision; }, + R"pbdoc( + Get global printing precision for array formatting. + + Returns: + int: The number of decimal places used when printing floating point + numbers in arrays. + )pbdoc"); + + nb::class_(m, "_PrintOptionsContext") + .def(nb::init()) + .def("__enter__", &PrintOptionsContext::__enter__) + .def("__exit__", &PrintOptionsContext::__exit__); + + m.def( + "printoptions", + [](int precision) { return PrintOptionsContext(precision); }, + "precision"_a, + R"pbdoc( + Context manager for setting print options temporarily. + + Args: + precision (int): Number of decimal places. + )pbdoc"); + // Types nb::class_( m, From 8401eae95369f58cea61ba9634152640ade68cff Mon Sep 17 00:00:00 2001 From: Christophe Prat Date: Sun, 29 Mar 2026 16:11:42 +0200 Subject: [PATCH 02/13] test: add test for print options --- python/tests/test_array.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 86328c2a1b..acf70e0e07 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -597,6 +597,28 @@ def test_array_repr(self): x = mx.array([1 - 1j], dtype=mx.complex64) expected = "array([1-1j], dtype=complex64)" + def test_array_repr_precision(self): + x = mx.array([1.123456789], dtype=mx.float32) + expected = "array([1.12346], dtype=float32)" + self.assertEqual(str(x), expected) + + with mx.printoptions(precision=4): + expected = "array([1.1235], dtype=float32)" + self.assertEqual(str(x), expected) + + mx.set_printoptions(precision=2) + expected = "array([1.12], dtype=float32)" + self.assertEqual(str(x), expected) + + x = mx.sin(x) + expected = "array([0.90], dtype=float32)" + self.assertEqual(str(x), expected) + + with mx.printoptions(precision=4): + expected = "array([0.9016], dtype=float32)" + self.assertEqual(str(x), expected) + + def test_array_to_list(self): types = [mx.bool_, mx.uint32, mx.int32, mx.int64, mx.float32] for t in types: From 45abfe37f32e298d8ae097e136a44a46312b6d7a Mon Sep 17 00:00:00 2001 From: Christophe Prat Date: Sun, 29 Mar 2026 16:28:18 +0200 Subject: [PATCH 03/13] style: reformat code --- mlx/utils.cpp | 72 ++++++++++++++++++++------------------ python/mlx/utils.py | 3 +- python/tests/test_array.py | 1 - 3 files changed, 39 insertions(+), 37 deletions(-) diff --git a/mlx/utils.cpp b/mlx/utils.cpp index f4188898a4..56b69d932a 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -40,7 +40,7 @@ void PrintFormatter::print(std::ostream& os, bool val) { } } inline void PrintFormatter::print(std::ostream& os, int16_t val) { - os << val; + os << val; } inline void PrintFormatter::print(std::ostream& os, uint16_t val) { os << val; @@ -58,49 +58,51 @@ inline void PrintFormatter::print(std::ostream& os, uint64_t val) { os << val; } inline void PrintFormatter::print(std::ostream& os, float16_t val) { - if (precision == -1) { - os << val; - } else { - os << std::fixed << std::setprecision(precision) << val; - } + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, bfloat16_t val) { - if (precision == -1) { - os << val; - } else { - os << std::fixed << std::setprecision(precision) << val; - } + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, float val) { - if (precision == -1) { - os << val; - } else { - os << std::fixed << std::setprecision(precision) << val; - } + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, double val) { - if (precision == -1) { - os << val; - } else { - os << std::fixed << std::setprecision(precision) << val; - } + if (precision == -1) { + os << val; + } else { + os << std::fixed << std::setprecision(precision) << val; + } } inline void PrintFormatter::print(std::ostream& os, complex64_t val) { - if (precision == -1) { - os << val.real(); - if (val.imag() >= 0 || std::isnan(val.imag())) { - os << "+" << val.imag() << "j"; - } else { - os << "-" << -val.imag() << "j"; - } + if (precision == -1) { + os << val.real(); + if (val.imag() >= 0 || std::isnan(val.imag())) { + os << "+" << val.imag() << "j"; + } else { + os << "-" << -val.imag() << "j"; + } + } else { + os << std::fixed << std::setprecision(precision) << val.real(); + if (val.imag() >= 0 || std::isnan(val.imag())) { + os << "+" << std::fixed << std::setprecision(precision) << val.imag() + << "j"; } else { - os << std::fixed << std::setprecision(precision) << val.real(); - if (val.imag() >= 0 || std::isnan(val.imag())) { - os << "+" << std::fixed << std::setprecision(precision) << val.imag() << "j"; - } else { - os << "-" << std::fixed << std::setprecision(precision) << -val.imag() << "j"; - } + os << "-" << std::fixed << std::setprecision(precision) << -val.imag() + << "j"; } + } } PrintFormatter& get_global_formatter() { @@ -109,7 +111,7 @@ PrintFormatter& get_global_formatter() { } void set_printoptions(int precision) { - auto &formatter = get_global_formatter(); + auto& formatter = get_global_formatter(); formatter.precision = precision; } diff --git a/python/mlx/utils.py b/python/mlx/utils.py index 35a8829485..c5fed71429 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,9 +1,10 @@ # Copyright © 2023 Apple Inc. from collections import defaultdict +from contextlib import contextmanager from itertools import zip_longest from multiprocessing import context from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from contextlib import contextmanager + def tree_map( fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None diff --git a/python/tests/test_array.py b/python/tests/test_array.py index acf70e0e07..6fdb27c61f 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -618,7 +618,6 @@ def test_array_repr_precision(self): expected = "array([0.9016], dtype=float32)" self.assertEqual(str(x), expected) - def test_array_to_list(self): types = [mx.bool_, mx.uint32, mx.int32, mx.int64, mx.float32] for t in types: From e817a55b7d2744f2f063ebcae8b19b9dc318fde7 Mon Sep 17 00:00:00 2001 From: Christophe Prat Date: Sun, 29 Mar 2026 16:44:29 +0200 Subject: [PATCH 04/13] docs: add documentation for printoptions --- docs/src/index.rst | 3 ++- docs/src/python/printoptions.rst | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 docs/src/python/printoptions.rst diff --git a/docs/src/index.rst b/docs/src/index.rst index 74c52aaa2b..46d069929f 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -32,7 +32,7 @@ are the CPU and GPU. install .. toctree:: - :caption: Usage + :caption: Usage :maxdepth: 1 usage/quick_start @@ -78,6 +78,7 @@ are the CPU and GPU. python/optimizers python/distributed python/tree_utils + python/printoptions .. toctree:: :caption: C++ API Reference diff --git a/docs/src/python/printoptions.rst b/docs/src/python/printoptions.rst new file mode 100644 index 0000000000..ee7d3c191a --- /dev/null +++ b/docs/src/python/printoptions.rst @@ -0,0 +1,11 @@ +Print Options +============ + +.. currentmodule:: mlx.core + +.. autosummary:: + :toctree: _autosummary + + set_printoptions + printoptions + get_printoptions From ac125b0bce763e5597bc9c919b7917453bf8bd0c Mon Sep 17 00:00:00 2001 From: Christophe Prat Date: Sun, 29 Mar 2026 16:50:06 +0200 Subject: [PATCH 05/13] fix: remove unused deps --- python/mlx/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/mlx/utils.py b/python/mlx/utils.py index c5fed71429..f4aafe1e3d 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,8 +1,6 @@ # Copyright © 2023 Apple Inc. from collections import defaultdict -from contextlib import contextmanager from itertools import zip_longest -from multiprocessing import context from typing import Any, Callable, Dict, List, Optional, Tuple, Union From 0ebc9cad37228ef58dcbf3065ce1e7d1086fc453 Mon Sep 17 00:00:00 2001 From: Christophe Prat Date: Tue, 31 Mar 2026 11:39:54 +0200 Subject: [PATCH 06/13] docs: make title underline long enough --- docs/src/python/printoptions.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/python/printoptions.rst b/docs/src/python/printoptions.rst index ee7d3c191a..95a574ff6b 100644 --- a/docs/src/python/printoptions.rst +++ b/docs/src/python/printoptions.rst @@ -1,5 +1,5 @@ Print Options -============ +=============== .. currentmodule:: mlx.core From 31bf56b0d7f3bfcb5d03199d105ea9fd9693551e Mon Sep 17 00:00:00 2001 From: Christophe Prat Date: Tue, 31 Mar 2026 16:06:42 +0200 Subject: [PATCH 07/13] refactor: use struct for printoptions --- mlx/utils.cpp | 28 ++++++++--------- mlx/utils.h | 8 +++-- python/src/array.cpp | 64 +++++++++++++++++++++++++------------- python/tests/test_array.py | 2 +- try_prec.py | 10 ++++++ 5 files changed, 73 insertions(+), 39 deletions(-) create mode 100644 try_prec.py diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 56b69d932a..b4734216ab 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -58,35 +58,35 @@ inline void PrintFormatter::print(std::ostream& os, uint64_t val) { os << val; } inline void PrintFormatter::print(std::ostream& os, float16_t val) { - if (precision == -1) { + if (format_options.precision == -1) { os << val; } else { - os << std::fixed << std::setprecision(precision) << val; + os << std::fixed << std::setprecision(format_options.precision) << val; } } inline void PrintFormatter::print(std::ostream& os, bfloat16_t val) { - if (precision == -1) { + if (format_options.precision == -1) { os << val; } else { - os << std::fixed << std::setprecision(precision) << val; + os << std::fixed << std::setprecision(format_options.precision) << val; } } inline void PrintFormatter::print(std::ostream& os, float val) { - if (precision == -1) { + if (format_options.precision == -1) { os << val; } else { - os << std::fixed << std::setprecision(precision) << val; + os << std::fixed << std::setprecision(format_options.precision) << val; } } inline void PrintFormatter::print(std::ostream& os, double val) { - if (precision == -1) { + if (format_options.precision == -1) { os << val; } else { - os << std::fixed << std::setprecision(precision) << val; + os << std::fixed << std::setprecision(format_options.precision) << val; } } inline void PrintFormatter::print(std::ostream& os, complex64_t val) { - if (precision == -1) { + if (format_options.precision == -1) { os << val.real(); if (val.imag() >= 0 || std::isnan(val.imag())) { os << "+" << val.imag() << "j"; @@ -94,12 +94,12 @@ inline void PrintFormatter::print(std::ostream& os, complex64_t val) { os << "-" << -val.imag() << "j"; } } else { - os << std::fixed << std::setprecision(precision) << val.real(); + os << std::fixed << std::setprecision(format_options.precision) << val.real(); if (val.imag() >= 0 || std::isnan(val.imag())) { - os << "+" << std::fixed << std::setprecision(precision) << val.imag() + os << "+" << std::fixed << std::setprecision(format_options.precision) << val.imag() << "j"; } else { - os << "-" << std::fixed << std::setprecision(precision) << -val.imag() + os << "-" << std::fixed << std::setprecision(format_options.precision) << -val.imag() << "j"; } } @@ -110,9 +110,9 @@ PrintFormatter& get_global_formatter() { return formatter; } -void set_printoptions(int precision) { +void set_printoptions(PrintOptions options) { auto& formatter = get_global_formatter(); - formatter.precision = precision; + formatter.format_options = options; } void abort_with_exception(const std::exception& error) { diff --git a/mlx/utils.h b/mlx/utils.h index 6be5f09ca2..134f2efcc7 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -38,6 +38,10 @@ struct StreamContext { Stream _stream; }; +struct MLX_API PrintOptions { + int precision{-1}; +}; + struct PrintFormatter { inline void print(std::ostream& os, bool val); inline void print(std::ostream& os, int16_t val); @@ -53,10 +57,10 @@ struct PrintFormatter { inline void print(std::ostream& os, complex64_t val); bool capitalize_bool{false}; - int precision{-1}; + PrintOptions format_options; }; -MLX_API void set_printoptions(int precision); +MLX_API void set_printoptions(PrintOptions options); MLX_API PrintFormatter& get_global_formatter(); diff --git a/python/src/array.cpp b/python/src/array.cpp index fc2c6c2b84..7408e8fbc2 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -12,6 +12,7 @@ #include #include "mlx/backend/metal/metal.h" +#include "mlx/utils.h" #include "python/src/buffer.h" #include "python/src/convert.h" #include "python/src/indexing.h" @@ -97,16 +98,16 @@ class ArrayPythonIterator { }; struct PrintOptionsContext { - int old_precision; - int new_precision; - PrintOptionsContext(int p) : new_precision(p) {} - PrintOptionsContext& __enter__() { - old_precision = mx::get_global_formatter().precision; - mx::set_printoptions(new_precision); + mx::PrintOptions old_options; + mx::PrintOptions new_options; + PrintOptionsContext(mx::PrintOptions p) : new_options(p) {} + PrintOptionsContext& enter() { + old_options = mx::get_global_formatter().format_options; + mx::set_printoptions(new_options); return *this; } - void __exit__(nb::args) { - mx::set_printoptions(old_precision); + void exit(nb::args) { + mx::set_printoptions(old_options); } }; @@ -115,42 +116,61 @@ void init_array(nb::module_& m) { mx::get_global_formatter().capitalize_bool = true; // Expose printing options to Python: allow setting global precision. + nb::class_(m, "PrintOptions") + .def(nb::init(), "precision"_a = -1) + .def_rw("precision", &mx::PrintOptions::precision); + m.def( "set_printoptions", - &mx::set_printoptions, - "precision"_a, + [](int precision) { + mx::set_printoptions({precision}); + }, + "precision"_a = mx::get_global_formatter().format_options.precision, R"pbdoc( Set global printing precision for array formatting. + Example: + >>> print(x) # Uses default precision + >>> mx.set_printoptions(precision=3): + >>> print(x) # Uses precision of 3 + >>> print(x) # Uses precision of 3 (again) + Args: - precision (int): Number of decimal places to use when printing - floating point numbers in arrays. + precision (int): Number of decimal places. )pbdoc"); m.def( "get_printoptions", - []() { return mx::get_global_formatter().precision; }, + []() { return mx::get_global_formatter().format_options; }, R"pbdoc( Get global printing precision for array formatting. Returns: - int: The number of decimal places used when printing floating point - numbers in arrays. + PrintOptions: The format options used for printing arrays. )pbdoc"); nb::class_(m, "_PrintOptionsContext") - .def(nb::init()) - .def("__enter__", &PrintOptionsContext::__enter__) - .def("__exit__", &PrintOptionsContext::__exit__); + .def(nb::init()) + .def("__enter__", &PrintOptionsContext::enter) + .def("__exit__", &PrintOptionsContext::exit); m.def( "printoptions", - [](int precision) { return PrintOptionsContext(precision); }, - "precision"_a, + [](int precision) { + return PrintOptionsContext({precision}); + }, + "precision"_a = mx::get_global_formatter().format_options.precision, R"pbdoc( Context manager for setting print options temporarily. + Example: + >>> print(x) # Uses default precision + >>> with mx.printoptions(precision=3): + >>> print(x) # Uses precision of 3 + >>> print(x) # Back to default precision + + Args: - precision (int): Number of decimal places. + precision (int): Number of decimal places. Use -1 for default )pbdoc"); // Types @@ -498,7 +518,7 @@ void init_array(nb::module_& m) { * - ``x = x.at[idx].minimum(y)`` - ``x[idx] = mx.minimum(x[idx], y)`` - Example: + Example: >>> a = mx.array([0, 0]) >>> idx = mx.array([0, 1, 0, 1]) >>> a[idx] += 1 diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 6fdb27c61f..893794052d 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -602,10 +602,10 @@ def test_array_repr_precision(self): expected = "array([1.12346], dtype=float32)" self.assertEqual(str(x), expected) + with mx.printoptions(precision=4): expected = "array([1.1235], dtype=float32)" self.assertEqual(str(x), expected) - mx.set_printoptions(precision=2) expected = "array([1.12], dtype=float32)" self.assertEqual(str(x), expected) diff --git a/try_prec.py b/try_prec.py new file mode 100644 index 0000000000..2a33e5fb79 --- /dev/null +++ b/try_prec.py @@ -0,0 +1,10 @@ +import mlx.core as mx + + +print(mx.array([1.23456789])) +with mx.printoptions(precision=4): + print(mx.array([1.23456789])) + + +mx.set_printoptions(precision=2) +print(mx.array([1.23456789])) From d000eff683b96ee80bb86314a0d4b2dbaf753373 Mon Sep 17 00:00:00 2001 From: Christophe Prat Date: Tue, 31 Mar 2026 17:51:12 +0200 Subject: [PATCH 08/13] remove file used for testing --- try_prec.py | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 try_prec.py diff --git a/try_prec.py b/try_prec.py deleted file mode 100644 index 2a33e5fb79..0000000000 --- a/try_prec.py +++ /dev/null @@ -1,10 +0,0 @@ -import mlx.core as mx - - -print(mx.array([1.23456789])) -with mx.printoptions(precision=4): - print(mx.array([1.23456789])) - - -mx.set_printoptions(precision=2) -print(mx.array([1.23456789])) From a0bb54687224e65aa2bd4b32d2d74ba547652270 Mon Sep 17 00:00:00 2001 From: Christophe Prat Date: Tue, 31 Mar 2026 18:15:06 +0200 Subject: [PATCH 09/13] refactor: use print.cpp instead of bloated array.cpp --- python/src/CMakeLists.txt | 3 +- python/src/array.cpp | 74 -------------------------------- python/src/mlx.cpp | 3 ++ python/src/print.cpp | 90 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 95 insertions(+), 75 deletions(-) create mode 100644 python/src/print.cpp diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 69152f5020..447271500b 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -27,7 +27,8 @@ nanobind_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp ${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp) + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/print.cpp) if(MLX_BUILD_PYTHON_STUBS) nanobind_add_stub( diff --git a/python/src/array.cpp b/python/src/array.cpp index 7408e8fbc2..7622bd468c 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -97,82 +97,8 @@ class ArrayPythonIterator { std::vector splits_; }; -struct PrintOptionsContext { - mx::PrintOptions old_options; - mx::PrintOptions new_options; - PrintOptionsContext(mx::PrintOptions p) : new_options(p) {} - PrintOptionsContext& enter() { - old_options = mx::get_global_formatter().format_options; - mx::set_printoptions(new_options); - return *this; - } - void exit(nb::args) { - mx::set_printoptions(old_options); - } -}; void init_array(nb::module_& m) { - // Set Python print formatting options - mx::get_global_formatter().capitalize_bool = true; - - // Expose printing options to Python: allow setting global precision. - nb::class_(m, "PrintOptions") - .def(nb::init(), "precision"_a = -1) - .def_rw("precision", &mx::PrintOptions::precision); - - m.def( - "set_printoptions", - [](int precision) { - mx::set_printoptions({precision}); - }, - "precision"_a = mx::get_global_formatter().format_options.precision, - R"pbdoc( - Set global printing precision for array formatting. - - Example: - >>> print(x) # Uses default precision - >>> mx.set_printoptions(precision=3): - >>> print(x) # Uses precision of 3 - >>> print(x) # Uses precision of 3 (again) - - Args: - precision (int): Number of decimal places. - )pbdoc"); - m.def( - "get_printoptions", - []() { return mx::get_global_formatter().format_options; }, - R"pbdoc( - Get global printing precision for array formatting. - - Returns: - PrintOptions: The format options used for printing arrays. - )pbdoc"); - - nb::class_(m, "_PrintOptionsContext") - .def(nb::init()) - .def("__enter__", &PrintOptionsContext::enter) - .def("__exit__", &PrintOptionsContext::exit); - - m.def( - "printoptions", - [](int precision) { - return PrintOptionsContext({precision}); - }, - "precision"_a = mx::get_global_formatter().format_options.precision, - R"pbdoc( - Context manager for setting print options temporarily. - - Example: - >>> print(x) # Uses default precision - >>> with mx.printoptions(precision=3): - >>> print(x) # Uses precision of 3 - >>> print(x) # Back to default precision - - - Args: - precision (int): Number of decimal places. Use -1 for default - )pbdoc"); - // Types nb::class_( m, diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index 2829b32199..eb79c5fec1 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -23,6 +23,8 @@ void init_constants(nb::module_&); void init_fast(nb::module_&); void init_distributed(nb::module_&); void init_export(nb::module_&); +void init_print(nb::module_&); + NB_MODULE(core, m) { m.doc() = "mlx: A framework for machine learning on Apple silicon."; @@ -46,6 +48,7 @@ NB_MODULE(core, m) { init_fast(m); init_distributed(m); init_export(m); + init_print(m); m.attr("__version__") = mx::version(); } diff --git a/python/src/print.cpp b/python/src/print.cpp new file mode 100644 index 0000000000..7684eac010 --- /dev/null +++ b/python/src/print.cpp @@ -0,0 +1,90 @@ +#include +#include +#include + +#include + +#include "mlx/utils.h" +#include "python/src/utils.h" + +#include "mlx/mlx.h" + +namespace mx = mlx::core; +namespace nb = nanobind; +using namespace nb::literals; + +struct PrintOptionsContext { + mx::PrintOptions old_options; + mx::PrintOptions new_options; + PrintOptionsContext(mx::PrintOptions p) : new_options(p) {} + PrintOptionsContext& enter() { + old_options = mx::get_global_formatter().format_options; + mx::set_printoptions(new_options); + return *this; + } + void exit(nb::args) { + mx::set_printoptions(old_options); + } +}; + +void init_print(nb::module_& m) { + // Set Python print formatting options + mx::get_global_formatter().capitalize_bool = true; + // Expose printing options to Python: allow setting global precision. + nb::class_(m, "PrintOptions") + .def(nb::init(), "precision"_a = -1) + .def_rw("precision", &mx::PrintOptions::precision); + + m.def( + "set_printoptions", + [](int precision) { + mx::set_printoptions({precision}); + }, + "precision"_a = mx::get_global_formatter().format_options.precision, + R"pbdoc( + Set global printing precision for array formatting. + + Example: + >>> print(x) # Uses default precision + >>> mx.set_printoptions(precision=3): + >>> print(x) # Uses precision of 3 + >>> print(x) # Uses precision of 3 (again) + + Args: + precision (int): Number of decimal places. + )pbdoc"); + m.def( + "get_printoptions", + []() { return mx::get_global_formatter().format_options; }, + R"pbdoc( + Get global printing precision for array formatting. + + Returns: + PrintOptions: The format options used for printing arrays. + )pbdoc"); + + nb::class_(m, "_PrintOptionsContext") + .def(nb::init()) + .def("__enter__", &PrintOptionsContext::enter) + .def("__exit__", &PrintOptionsContext::exit); + + m.def( + "printoptions", + [](int precision) { + return PrintOptionsContext({precision}); + }, + "precision"_a = mx::get_global_formatter().format_options.precision, + R"pbdoc( + Context manager for setting print options temporarily. + + Example: + >>> print(x) # Uses default precision + >>> with mx.printoptions(precision=3): + >>> print(x) # Uses precision of 3 + >>> print(x) # Back to default precision + + + Args: + precision (int): Number of decimal places. Use -1 for default + )pbdoc"); +} From 17b7d9e46ef2adcf37bcaeb3eaf4add89154305e Mon Sep 17 00:00:00 2001 From: Christophe Prat Date: Tue, 31 Mar 2026 18:16:14 +0200 Subject: [PATCH 10/13] docs: add documentation for type PrintOptions --- docs/src/python/printoptions.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/python/printoptions.rst b/docs/src/python/printoptions.rst index 95a574ff6b..9a4211c9dd 100644 --- a/docs/src/python/printoptions.rst +++ b/docs/src/python/printoptions.rst @@ -6,6 +6,7 @@ Print Options .. autosummary:: :toctree: _autosummary + PrintOptions set_printoptions printoptions get_printoptions From a353dd280109ba0e16e0934a1d23e88c216b6ae7 Mon Sep 17 00:00:00 2001 From: Christophe Prat Date: Wed, 1 Apr 2026 09:35:53 +0200 Subject: [PATCH 11/13] style: reformat code --- mlx/utils.cpp | 11 ++++---- python/src/array.cpp | 1 - python/src/mlx.cpp | 1 - python/src/print.cpp | 52 ++++++++++++++++++-------------------- python/tests/test_array.py | 1 - 5 files changed, 30 insertions(+), 36 deletions(-) diff --git a/mlx/utils.cpp b/mlx/utils.cpp index b4734216ab..c257961d70 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -94,13 +94,14 @@ inline void PrintFormatter::print(std::ostream& os, complex64_t val) { os << "-" << -val.imag() << "j"; } } else { - os << std::fixed << std::setprecision(format_options.precision) << val.real(); + os << std::fixed << std::setprecision(format_options.precision) + << val.real(); if (val.imag() >= 0 || std::isnan(val.imag())) { - os << "+" << std::fixed << std::setprecision(format_options.precision) << val.imag() - << "j"; + os << "+" << std::fixed << std::setprecision(format_options.precision) + << val.imag() << "j"; } else { - os << "-" << std::fixed << std::setprecision(format_options.precision) << -val.imag() - << "j"; + os << "-" << std::fixed << std::setprecision(format_options.precision) + << -val.imag() << "j"; } } } diff --git a/python/src/array.cpp b/python/src/array.cpp index 7622bd468c..81fd711c61 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -97,7 +97,6 @@ class ArrayPythonIterator { std::vector splits_; }; - void init_array(nb::module_& m) { // Types nb::class_( diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index eb79c5fec1..cb031cf78c 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -25,7 +25,6 @@ void init_distributed(nb::module_&); void init_export(nb::module_&); void init_print(nb::module_&); - NB_MODULE(core, m) { m.doc() = "mlx: A framework for machine learning on Apple silicon."; diff --git a/python/src/print.cpp b/python/src/print.cpp index 7684eac010..f7e549385f 100644 --- a/python/src/print.cpp +++ b/python/src/print.cpp @@ -28,20 +28,18 @@ struct PrintOptionsContext { }; void init_print(nb::module_& m) { - // Set Python print formatting options - mx::get_global_formatter().capitalize_bool = true; - // Expose printing options to Python: allow setting global precision. - nb::class_(m, "PrintOptions") - .def(nb::init(), "precision"_a = -1) - .def_rw("precision", &mx::PrintOptions::precision); + // Set Python print formatting options + mx::get_global_formatter().capitalize_bool = true; + // Expose printing options to Python: allow setting global precision. + nb::class_(m, "PrintOptions") + .def(nb::init(), "precision"_a = -1) + .def_rw("precision", &mx::PrintOptions::precision); - m.def( - "set_printoptions", - [](int precision) { - mx::set_printoptions({precision}); - }, - "precision"_a = mx::get_global_formatter().format_options.precision, - R"pbdoc( + m.def( + "set_printoptions", + [](int precision) { mx::set_printoptions({precision}); }, + "precision"_a = mx::get_global_formatter().format_options.precision, + R"pbdoc( Set global printing precision for array formatting. Example: @@ -53,28 +51,26 @@ void init_print(nb::module_& m) { Args: precision (int): Number of decimal places. )pbdoc"); - m.def( - "get_printoptions", - []() { return mx::get_global_formatter().format_options; }, - R"pbdoc( + m.def( + "get_printoptions", + []() { return mx::get_global_formatter().format_options; }, + R"pbdoc( Get global printing precision for array formatting. Returns: PrintOptions: The format options used for printing arrays. )pbdoc"); - nb::class_(m, "_PrintOptionsContext") - .def(nb::init()) - .def("__enter__", &PrintOptionsContext::enter) - .def("__exit__", &PrintOptionsContext::exit); + nb::class_(m, "_PrintOptionsContext") + .def(nb::init()) + .def("__enter__", &PrintOptionsContext::enter) + .def("__exit__", &PrintOptionsContext::exit); - m.def( - "printoptions", - [](int precision) { - return PrintOptionsContext({precision}); - }, - "precision"_a = mx::get_global_formatter().format_options.precision, - R"pbdoc( + m.def( + "printoptions", + [](int precision) { return PrintOptionsContext({precision}); }, + "precision"_a = mx::get_global_formatter().format_options.precision, + R"pbdoc( Context manager for setting print options temporarily. Example: diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 893794052d..594692f178 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -602,7 +602,6 @@ def test_array_repr_precision(self): expected = "array([1.12346], dtype=float32)" self.assertEqual(str(x), expected) - with mx.printoptions(precision=4): expected = "array([1.1235], dtype=float32)" self.assertEqual(str(x), expected) From 066392cbef9f542e248b796447239563ae190f17 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 1 Apr 2026 16:56:21 -0700 Subject: [PATCH 12/13] Fix formatting of example in array.cpp documentation --- python/src/array.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index 81fd711c61..a5e6dbe23d 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -443,7 +443,7 @@ void init_array(nb::module_& m) { * - ``x = x.at[idx].minimum(y)`` - ``x[idx] = mx.minimum(x[idx], y)`` - Example: + Example: >>> a = mx.array([0, 0]) >>> idx = mx.array([0, 1, 0, 1]) >>> a[idx] += 1 From a42452f1146c1a63b4fe827726004bca26576708 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 1 Apr 2026 16:59:56 -0700 Subject: [PATCH 13/13] Fix formatting in print.cpp documentation examples --- python/src/print.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/src/print.cpp b/python/src/print.cpp index f7e549385f..03c3fdaa79 100644 --- a/python/src/print.cpp +++ b/python/src/print.cpp @@ -43,13 +43,13 @@ void init_print(nb::module_& m) { Set global printing precision for array formatting. Example: - >>> print(x) # Uses default precision - >>> mx.set_printoptions(precision=3): - >>> print(x) # Uses precision of 3 - >>> print(x) # Uses precision of 3 (again) + >>> print(x) # Uses default precision + >>> mx.set_printoptions(precision=3): + >>> print(x) # Uses precision of 3 + >>> print(x) # Uses precision of 3 (again) Args: - precision (int): Number of decimal places. + precision (int): Number of decimal places. )pbdoc"); m.def( "get_printoptions", @@ -74,13 +74,13 @@ void init_print(nb::module_& m) { Context manager for setting print options temporarily. Example: - >>> print(x) # Uses default precision - >>> with mx.printoptions(precision=3): - >>> print(x) # Uses precision of 3 - >>> print(x) # Back to default precision + >>> print(x) # Uses default precision + >>> with mx.printoptions(precision=3): + >>> print(x) # Uses precision of 3 + >>> print(x) # Back to default precision Args: - precision (int): Number of decimal places. Use -1 for default + precision (int): Number of decimal places. Use -1 for default )pbdoc"); }