Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions .github/actions/build-macos/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,17 @@ runs:
cd build
cmake ..
make -j $(sysctl -n hw.ncpu)

- name: Run CPP tests
shell: bash -l {0}
env:
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: ./build/tests/tests

run: |
./build/tests/tests
./build/tests/test_teardown

- name: Build small binary with JIT
shell: bash -l {0}
run: |
Expand Down
1 change: 1 addition & 0 deletions .github/actions/test-windows/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ runs:
run: |
echo "::group::CPP tests - CPU"
./build/tests.exe -tce="*gguf*,test random uniform"
./build/test_teardown.exe
echo "::endgroup::"
42 changes: 36 additions & 6 deletions mlx/backend/metal/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include <cstdlib>
#include <sstream>

#include <fmt/format.h>

#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
Expand Down Expand Up @@ -265,6 +267,14 @@ CommandEncoder::CommandEncoder(
buffer_ = NS::RetainPtr(queue_->commandBufferWithUnretainedReferences());
}

CommandEncoder::~CommandEncoder() {
  // Destructor for a per-stream encoder: drain pending GPU work, then
  // release the Metal command buffer and queue.
  //
  // Set the teardown flag first so synchronize() skips its error check —
  // throwing from a destructor would terminate the program.
  exiting_ = true;
  // Commit and wait for any outstanding command buffer.
  synchronize();
  // Scope an autorelease pool around the releases below so Objective-C
  // temporaries created during teardown are reclaimed promptly.
  auto pool = new_scoped_memory_pool();
  buffer_.reset();
  queue_.reset();
}

void CommandEncoder::set_buffer(
const MTL::Buffer* buf,
int idx,
Expand Down Expand Up @@ -433,6 +443,22 @@ void CommandEncoder::commit() {
buffer_sizes_ = 0;
}

// Commit all pending work on this encoder and block until the GPU has
// finished it. Throws std::runtime_error if the command buffer reports an
// execution error — except during teardown (exiting_), when failures are
// deliberately ignored so the destructor never throws.
void CommandEncoder::synchronize() {
  auto pool = new_scoped_memory_pool();
  // Keep our own reference to the buffer so it outlives commit().
  auto pending = NS::RetainPtr(get_command_buffer());
  end_encoding();
  commit();
  pending->waitUntilCompleted();
  if (!exiting_ && pending->status() == MTL::CommandBufferStatusError) {
    throw std::runtime_error(
        fmt::format(
            "[METAL] Command buffer execution failed: {}.",
            pending->error()->localizedDescription()->utf8String()));
  }
}

MTL::ComputeCommandEncoder* CommandEncoder::get_command_encoder() {
if (!encoder_) {
encoder_ = NS::RetainPtr(
Expand Down Expand Up @@ -770,16 +796,20 @@ Device& device(mlx::core::Device) {
}

// Returns the CommandEncoder for stream `s` on the calling thread.
//
// NOTE: this span in the source interleaved the removed and added sides of
// a diff (two lookups, unbalanced braces); resolved here to the post-change
// implementation, which matches the thread_local map in
// get_command_encoders() and the lookup-only contract tested elsewhere.
//
// Throws std::runtime_error if the stream was not created on this thread —
// encoders live in thread-local storage, so a stream is only usable from
// the thread that created it.
CommandEncoder& get_command_encoder(Stream s) {
  auto& encoders = get_command_encoders();
  auto it = encoders.find(s.index);
  if (it == encoders.end()) {
    throw std::runtime_error(
        fmt::format("There is no Stream(gpu, {}) in current thread.", s.index));
  }
  return it->second;
}

// Returns the calling thread's map from stream index to CommandEncoder.
// The map is thread_local: each thread owns its encoders, and they are
// destroyed (draining their GPU work — see ~CommandEncoder) when the
// thread exits.
std::unordered_map<int, CommandEncoder>& get_command_encoders() {
  static thread_local std::unordered_map<int, CommandEncoder> encoders;
  return encoders;
}

// Creates a fresh autorelease pool whose lifetime is tied to the returned
// smart pointer; TransferPtr adopts the +1 reference from alloc()/init().
NS::SharedPtr<NS::AutoreleasePool> new_scoped_memory_pool() {
  auto pool = NS::TransferPtr(NS::AutoreleasePool::alloc()->init());
  return pool;
}
Expand Down
5 changes: 5 additions & 0 deletions mlx/backend/metal/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class Device;
class MLX_API CommandEncoder {
public:
CommandEncoder(Device& d, int index, ResidencySet& residency_set);
~CommandEncoder();

CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;

Expand Down Expand Up @@ -90,6 +92,7 @@ class MLX_API CommandEncoder {
void end_encoding();
bool needs_commit() const;
void commit();
void synchronize();

MTL::CommandQueue* get_command_queue() const {
return queue_.get();
Expand All @@ -102,6 +105,7 @@ class MLX_API CommandEncoder {
MTL::ComputeCommandEncoder* get_command_encoder();

Device& device_;
bool exiting_{false};

// Buffer that stores encoded commands.
NS::SharedPtr<MTL::CommandQueue> queue_;
Expand Down Expand Up @@ -226,6 +230,7 @@ class MLX_API Device {
MLX_API Device& device(mlx::core::Device);
MLX_API CommandEncoder& get_command_encoder(Stream s);

std::unordered_map<int, CommandEncoder>& get_command_encoders();
NS::SharedPtr<NS::AutoreleasePool> new_scoped_memory_pool();

bool is_nax_available();
Expand Down
19 changes: 6 additions & 13 deletions mlx/backend/metal/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ namespace mlx::core::gpu {

void init() {}

void new_stream(Stream stream) {
if (stream.device == mlx::core::Device::gpu) {
metal::get_command_encoder(stream);
}
void new_stream(Stream s) {
assert(s.device == Device::gpu);
auto& encoders = metal::get_command_encoders();
auto& d = metal::device(s.device);
encoders.try_emplace(s.index, d, s.index, d.residency_set());
}

inline void check_error(MTL::CommandBuffer* cbuf) {
Expand Down Expand Up @@ -83,15 +84,7 @@ void finalize(Stream s) {
}

// Blocks until all work previously submitted to stream `s` has completed
// on the GPU.
//
// NOTE: this span in the source interleaved the removed and added sides of
// a diff; resolved here to the post-change one-liner. The manual
// retain/end/commit/wait/check sequence now lives in
// CommandEncoder::synchronize(), which throws on command-buffer errors.
void synchronize(Stream s) {
  metal::get_command_encoder(s).synchronize();
}

} // namespace mlx::core::gpu
10 changes: 1 addition & 9 deletions mlx/scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,7 @@ Scheduler::Scheduler() {
gpu::init();
}

// NOTE: this span in the source interleaved the removed and added sides of
// a diff (old loop body plus new `= default`); resolved here to the
// post-change defaulted destructor. Per-stream GPU synchronization now
// happens in CommandEncoder's destructor when each thread's thread_local
// encoder map is torn down, so the scheduler needs no explicit cleanup.
Scheduler::~Scheduler() = default;

void Scheduler::new_thread(Device::DeviceType type) {
if (type == Device::gpu) {
Expand Down
21 changes: 21 additions & 0 deletions tests/scheduler_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,27 @@ TEST_CASE("test default stream in threads") {
CHECK_EQ(new_streams, thread_streams);
}

TEST_CASE("test access stream in other thread") {
  // Requires a Metal device; skip on machines without one.
  if (!metal::is_available()) {
    return;
  }

  // Create and use a GPU stream on the main thread.
  auto stream = new_stream(Device::gpu);
  eval(arange(10, stream));

  // Using the same stream from a different thread must fail: encoders are
  // thread-local, so the worker has no encoder for this stream index.
  bool threw = false;
  std::thread worker([&threw, &stream] {
    try {
      eval(arange(10, stream));
    } catch (const std::runtime_error&) {
      threw = true;
    }
  });
  worker.join();

  CHECK(threw);
}

TEST_CASE("test get streams") {
auto streams = get_streams();

Expand Down
Loading