Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1576,4 +1576,115 @@ INSTANTIATE_TEST_SUITE_P(
ZE_EVENT_POOL_FLAG_HOST_VISIBLE |
ZE_EVENT_POOL_FLAG_KERNEL_MAPPED_TIMESTAMP));

LZT_TEST_F(
zeMutableCommandListTests,
GivenRegularAndMutableCommandListsExecutedViaCommandListImmediateAppendCommandListsExpThenVerifyBuffersValid) {
if (!kernelArgumentsSupport) {
GTEST_SKIP() << "ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS not "
"supported";
}

auto cmd_bundle = lzt::create_command_bundle(context, device, true);
auto cmd_list = lzt::create_command_list(context, device, false);
lzt::zeEventPool ep;
ze_event_handle_t event = nullptr;
ep.create_event(event, ZE_EVENT_SCOPE_FLAG_HOST, 0);

uint32_t buffer_size = 1024;

const uint32_t fill_value_1 = 0xA0A0A0A0;
const uint32_t fill_value_2 = 0xB0B0B0B0;

const uint32_t add_value_1 = 0x0A0A0A0A;
const uint32_t add_value_2 = 0x0B0B0B0B;

uint32_t *buffer = reinterpret_cast<uint32_t *>(lzt::allocate_shared_memory(
buffer_size * sizeof(uint32_t), sizeof(uint32_t), 0, 0, device, context));
uint32_t *mutated_buffer = reinterpret_cast<uint32_t *>(
lzt::allocate_shared_memory(buffer_size * sizeof(uint32_t),
sizeof(uint32_t), 0, 0, device, context));
std::memset(buffer, 0, buffer_size * sizeof(uint32_t));
std::memset(mutated_buffer, 0, buffer_size * sizeof(uint32_t));

std::vector<ze_command_list_handle_t> tested_cmd_lists;
tested_cmd_lists.push_back(cmd_list);
tested_cmd_lists.push_back(mutableCmdList);

ze_kernel_handle_t add_kernel = lzt::create_function(module, "addValue");

uint32_t group_size_x = 0;
uint32_t group_size_y = 0;
uint32_t group_size_z = 0;
lzt::suggest_group_size(add_kernel, buffer_size, 1, 1, group_size_x,
group_size_y, group_size_z);
lzt::set_group_size(add_kernel, group_size_x, group_size_y, group_size_z);
lzt::set_argument_value(add_kernel, 0, sizeof(void *), &buffer);
lzt::set_argument_value(add_kernel, 1, sizeof(add_value_1), &add_value_1);

uint64_t command_id = 0;
commandIdDesc.flags = ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS;
EXPECT_ZE_RESULT_SUCCESS(zeCommandListGetNextCommandIdExp(
mutableCmdList, &commandIdDesc, &command_id));

ze_group_count_t group_count{buffer_size / group_size_x, 1, 1};
lzt::append_memory_fill(cmd_list, buffer, &fill_value_1, sizeof(fill_value_1),
buffer_size * sizeof(uint32_t), event);
lzt::append_launch_function(mutableCmdList, add_kernel, &group_count, nullptr,
1, &event);
lzt::append_reset_event(mutableCmdList, event);
lzt::close_command_list(cmd_list);
lzt::close_command_list(mutableCmdList);

lzt::append_command_lists_immediate_exp(
cmd_bundle.list, 2, tested_cmd_lists.data(), nullptr, 0, nullptr);

lzt::execute_and_sync_command_bundle(cmd_bundle,
std::numeric_limits<uint64_t>::max());
lzt::reset_command_list(cmd_list);

for (size_t i = 0; i < buffer_size; i++) {
EXPECT_EQ(buffer[i], 0xAAAAAAAA);
}

ze_mutable_kernel_argument_exp_desc_t mutate_buffer_kernel_arg{
ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC};
mutate_buffer_kernel_arg.commandId = command_id;
mutate_buffer_kernel_arg.argIndex = 0;
mutate_buffer_kernel_arg.argSize = sizeof(uint32_t *);
mutate_buffer_kernel_arg.pArgValue = &mutated_buffer;
ze_mutable_kernel_argument_exp_desc_t mutate_scalar_kernel_arg{
ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC};
mutate_scalar_kernel_arg.pNext = &mutate_buffer_kernel_arg;
mutate_scalar_kernel_arg.commandId = command_id;
mutate_scalar_kernel_arg.argIndex = 1;
mutate_scalar_kernel_arg.argSize = sizeof(add_value_2);
mutate_scalar_kernel_arg.pArgValue = &add_value_2;
mutableCmdDesc.pNext = &mutate_scalar_kernel_arg;
EXPECT_ZE_RESULT_SUCCESS(
zeCommandListUpdateMutableCommandsExp(mutableCmdList, &mutableCmdDesc));
lzt::close_command_list(mutableCmdList);

lzt::append_memory_fill(cmd_list, mutated_buffer, &fill_value_2,
sizeof(fill_value_2), buffer_size * sizeof(uint32_t),
event);
lzt::close_command_list(cmd_list);

lzt::append_command_lists_immediate_exp(
cmd_bundle.list, 2, tested_cmd_lists.data(), nullptr, 0, nullptr);

lzt::execute_and_sync_command_bundle(cmd_bundle,
std::numeric_limits<uint64_t>::max());

for (size_t i = 0; i < buffer_size; i++) {
EXPECT_EQ(mutated_buffer[i], 0xBBBBBBBB);
}

lzt::destroy_function(add_kernel);
lzt::destroy_event(event);
lzt::destroy_command_list(cmd_list);
lzt::destroy_command_bundle(cmd_bundle);
lzt::free_memory(buffer);
lzt::free_memory(mutated_buffer);
}

} // namespace
Loading