diff --git a/.github/actions/build-macos/action.yml b/.github/actions/build-macos/action.yml index 5da840b819..d3cc3df686 100644 --- a/.github/actions/build-macos/action.yml +++ b/.github/actions/build-macos/action.yml @@ -45,15 +45,17 @@ runs: cd build cmake .. make -j $(sysctl -n hw.ncpu) - + - name: Run CPP tests shell: bash -l {0} env: DEVICE: gpu METAL_DEVICE_WRAPPER_TYPE: 1 METAL_DEBUG_ERROR_MODE: 0 - run: ./build/tests/tests - + run: | + ./build/tests/tests + ./build/tests/test_teardown + - name: Build small binary with JIT shell: bash -l {0} run: | diff --git a/.github/actions/test-windows/action.yml b/.github/actions/test-windows/action.yml index ebb886e7bd..c738688836 100644 --- a/.github/actions/test-windows/action.yml +++ b/.github/actions/test-windows/action.yml @@ -17,4 +17,5 @@ runs: run: | echo "::group::CPP tests - CPU" ./build/tests.exe -tce="*gguf*,test random uniform" + ./build/test_teardown.exe echo "::endgroup::" diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 8da66df85f..a654b16070 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -3,6 +3,8 @@ #include #include +#include + #define NS_PRIVATE_IMPLEMENTATION #define CA_PRIVATE_IMPLEMENTATION #define MTL_PRIVATE_IMPLEMENTATION @@ -265,6 +267,14 @@ CommandEncoder::CommandEncoder( buffer_ = NS::RetainPtr(queue_->commandBufferWithUnretainedReferences()); } +CommandEncoder::~CommandEncoder() { + exiting_ = true; + synchronize(); + auto pool = new_scoped_memory_pool(); + buffer_.reset(); + queue_.reset(); +} + void CommandEncoder::set_buffer( const MTL::Buffer* buf, int idx, @@ -433,6 +443,22 @@ void CommandEncoder::commit() { buffer_sizes_ = 0; } +void CommandEncoder::synchronize() { + auto pool = new_scoped_memory_pool(); + auto cb = NS::RetainPtr(get_command_buffer()); + end_encoding(); + commit(); + cb->waitUntilCompleted(); + if (!exiting_) { + if (cb->status() == MTL::CommandBufferStatusError) { + throw std::runtime_error( + fmt::format( + "[METAL] Command buffer execution failed: {}.", + cb->error()->localizedDescription()->utf8String())); + } + } +} + MTL::ComputeCommandEncoder* CommandEncoder::get_command_encoder() { if (!encoder_) { encoder_ = NS::RetainPtr( @@ -770,16 +796,20 @@ Device& device(mlx::core::Device) { } CommandEncoder& get_command_encoder(Stream s) { - // Leak the command encoders for the same reason with device. - static auto* encoders = new std::unordered_map; - auto it = encoders->find(s.index); - if (it == encoders->end()) { - auto& d = device(s.device); - it = encoders->try_emplace(s.index, d, s.index, d.residency_set()).first; + auto& encoders = get_command_encoders(); + auto it = encoders.find(s.index); + if (it == encoders.end()) { + throw std::runtime_error( + fmt::format("There is no Stream(gpu, {}) in current thread.", s.index)); } return it->second; } +std::unordered_map& get_command_encoders() { + static thread_local std::unordered_map encoders; + return encoders; +} + NS::SharedPtr new_scoped_memory_pool() { return NS::TransferPtr(NS::AutoreleasePool::alloc()->init()); } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 9e58b92f0c..5f2e72f915 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -24,6 +24,8 @@ class Device; class MLX_API CommandEncoder { public: CommandEncoder(Device& d, int index, ResidencySet& residency_set); + ~CommandEncoder(); + CommandEncoder(const CommandEncoder&) = delete; CommandEncoder& operator=(const CommandEncoder&) = delete; @@ -90,6 +92,7 @@ class MLX_API CommandEncoder { void end_encoding(); bool needs_commit() const; void commit(); + void synchronize(); MTL::CommandQueue* get_command_queue() const { return queue_.get(); @@ -102,6 +105,7 @@ class MLX_API CommandEncoder { MTL::ComputeCommandEncoder* get_command_encoder(); Device& device_; + bool exiting_{false}; // Buffer that stores encoded commands. NS::SharedPtr queue_; @@ -226,6 +230,7 @@ class MLX_API Device { MLX_API Device& device(mlx::core::Device); MLX_API CommandEncoder& get_command_encoder(Stream s); +std::unordered_map& get_command_encoders(); NS::SharedPtr new_scoped_memory_pool(); bool is_nax_available(); diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp index 00ed754e5f..4fa08fc4db 100644 --- a/mlx/backend/metal/eval.cpp +++ b/mlx/backend/metal/eval.cpp @@ -11,10 +11,11 @@ namespace mlx::core::gpu { void init() {} -void new_stream(Stream stream) { - if (stream.device == mlx::core::Device::gpu) { - metal::get_command_encoder(stream); - } +void new_stream(Stream s) { + assert(s.device == Device::gpu); + auto& encoders = metal::get_command_encoders(); + auto& d = metal::device(s.device); + encoders.try_emplace(s.index, d, s.index, d.residency_set()); } inline void check_error(MTL::CommandBuffer* cbuf) { @@ -83,15 +84,7 @@ void finalize(Stream s) { } void synchronize(Stream s) { - auto pool = metal::new_scoped_memory_pool(); - auto& encoder = metal::get_command_encoder(s); - auto* cb = encoder.get_command_buffer(); - cb->retain(); - encoder.end_encoding(); - encoder.commit(); - cb->waitUntilCompleted(); - check_error(cb); - cb->release(); + metal::get_command_encoder(s).synchronize(); } } // namespace mlx::core::gpu diff --git a/mlx/scheduler.cpp b/mlx/scheduler.cpp index 4a8bda3adc..cabe991926 100644 --- a/mlx/scheduler.cpp +++ b/mlx/scheduler.cpp @@ -26,15 +26,7 @@ Scheduler::Scheduler() { gpu::init(); } -Scheduler::~Scheduler() { - for (auto& s : get_streams()) { - try { - synchronize(s); - } catch (const std::runtime_error&) { - // ignore errors if synch fails - } - } -} +Scheduler::~Scheduler() = default; void Scheduler::new_thread(Device::DeviceType type) { if (type == Device::gpu) { diff --git a/tests/scheduler_tests.cpp b/tests/scheduler_tests.cpp index 532f168616..3a8400f5d0 100644 --- a/tests/scheduler_tests.cpp +++ b/tests/scheduler_tests.cpp @@ -65,6 +65,27 @@ TEST_CASE("test default stream in threads") { CHECK_EQ(new_streams, thread_streams); } +TEST_CASE("test access stream in other thread") { + if (!metal::is_available()) { + return; + } + + auto main_thread_stream = new_stream(Device::gpu); + eval(arange(10, main_thread_stream)); + + bool error_caught = false; + std::thread t([&] { + try { + eval(arange(10, main_thread_stream)); + } catch (const std::runtime_error&) { + error_caught = true; + } + }); + t.join(); + + CHECK(error_caught); +} + TEST_CASE("test get streams") { auto streams = get_streams();