Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 36 additions & 16 deletions tensorflow/lite/kernels/internal/portable_tensor_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ void UnpackDenseInt4IntoInt8(const int8_t* src_buffer, int num_elements,
}

void UnpackPackedIntToInt8(const int8_t* src_buffer, int num_elements,
int bit_width, int8_t* dst_buffer) {
int bit_width, int8_t* dst_buffer,
bool unpack_unsigned) {
assert(bit_width == 2 || bit_width == 4);
if (bit_width == 4) {
// num_elements means the number of elements regardless of packed or
Expand All @@ -105,39 +106,58 @@ void UnpackPackedIntToInt8(const int8_t* src_buffer, int num_elements,
// stored in dst_buffer[0], dst_buffer[1] and dst_buffer[2] (j = 0..2)
for (int i = 0; i < num_elements / 2; i++) {
int8_t byte = src_buffer[i];
// Shift left first so that sign is properly extended when shifted right
int8_t lower = static_cast<int8_t>(byte << 4) >> 4;
int8_t higher = byte >> 4;
int8_t lower, higher;
if (unpack_unsigned) {
lower = byte & 0x0F;
higher = (byte >> 4) & 0x0F;
} else {
// Shift left first so that sign is properly extended when shifted right
lower = static_cast<int8_t>(byte << 4) >> 4;
higher = byte >> 4;
}
dst_buffer[2 * i] = lower;
dst_buffer[2 * i + 1] = higher;
}

// If the buffer size is odd, extract the final lower nibble.
if (num_elements % 2 != 0) {
int8_t byte = src_buffer[num_elements / 2];
dst_buffer[num_elements - 1] =
static_cast<int8_t>(src_buffer[num_elements / 2] << 4) >> 4;
unpack_unsigned ? (byte & 0x0F) : static_cast<int8_t>(byte << 4) >> 4;
}
} else if (bit_width == 2) {
for (int i = 0; i < num_elements / 4; i++) {
int8_t byte = src_buffer[i];
// Shift left first so that sign is properly extended when shifted right
int8_t val1 = static_cast<int8_t>(byte << 6) >> 6;
int8_t val2 = static_cast<int8_t>((byte << 4) & 0xFF) >> 6;
int8_t val3 = static_cast<int8_t>((byte << 2) & 0xFF) >> 6;
int8_t val4 = byte >> 6;
dst_buffer[4 * i] = val1;
dst_buffer[4 * i + 1] = val2;
dst_buffer[4 * i + 2] = val3;
dst_buffer[4 * i + 3] = val4;
if (unpack_unsigned) {
dst_buffer[4 * i] = byte & 0x03;
dst_buffer[4 * i + 1] = (byte >> 2) & 0x03;
dst_buffer[4 * i + 2] = (byte >> 4) & 0x03;
dst_buffer[4 * i + 3] = (byte >> 6) & 0x03;
} else {
// Shift left first so that sign is properly extended when shifted right
int8_t val1 = static_cast<int8_t>(byte << 6) >> 6;
int8_t val2 = static_cast<int8_t>((byte << 4) & 0xFF) >> 6;
int8_t val3 = static_cast<int8_t>((byte << 2) & 0xFF) >> 6;
int8_t val4 = byte >> 6;
dst_buffer[4 * i] = val1;
dst_buffer[4 * i + 1] = val2;
dst_buffer[4 * i + 2] = val3;
dst_buffer[4 * i + 3] = val4;
}
}

// Handle the remaining elements.
int remaining_elements = num_elements % 4;
if (remaining_elements > 0) {
int8_t byte = src_buffer[num_elements / 4];
for (int i = 0; i < remaining_elements; i++) {
dst_buffer[num_elements - remaining_elements + i] =
static_cast<int8_t>((byte << (6 - 2 * i)) & 0xFF) >> 6;
if (unpack_unsigned) {
dst_buffer[num_elements - remaining_elements + i] =
(byte >> (2 * i)) & 0x03;
} else {
dst_buffer[num_elements - remaining_elements + i] =
static_cast<int8_t>((byte << (6 - 2 * i)) & 0xFF) >> 6;
}
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/lite/kernels/internal/portable_tensor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,8 @@ void UnpackDenseInt4IntoInt8(const int8_t* src_buffer, int num_elements,
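// For 2-bit unpacking: e.g., `src_buffer = {0x12};` (num_elements = 4)
// will return `dst_buffer = {-2, 0x00, 0x01, 0x00}` by default, because each
// 2-bit field is sign extended (the low field 0b10 becomes -2, i.e. 0xFE).
// With `unpack_unsigned = true` the fields are zero extended instead, giving
// `dst_buffer = {0x02, 0x00, 0x01, 0x00}`.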
void UnpackPackedIntToInt8(const int8_t* src_buffer, int num_elements,
int bit_width, int8_t* dst_buffer);
int bit_width, int8_t* dst_buffer,
bool unpack_unsigned = false);

// Pack `src_buffer` into a densely packed buffer of int2 or int4 values.
// Parameters:
Expand Down