diff --git a/tensorflow/lite/core/c/common.cc b/tensorflow/lite/core/c/common.cc index baa6282fd5b..84883d2fd19 100644 --- a/tensorflow/lite/core/c/common.cc +++ b/tensorflow/lite/core/c/common.cc @@ -104,6 +104,41 @@ void TfLiteVarArrayFree(T* a) { #ifndef TF_LITE_STATIC_MEMORY +TfLiteSparsity TfLiteSparsityClone(const TfLiteSparsity& src) { + TfLiteSparsity dst = src; + dst.traversal_order = TfLiteIntArrayCopy(src.traversal_order); + dst.block_map = TfLiteIntArrayCopy(src.block_map); + if (src.dim_metadata) { + dst.dim_metadata = reinterpret_cast( + calloc(1, sizeof(TfLiteDimensionMetadata) * src.dim_metadata_size)); + for (int i = 0; i < src.dim_metadata_size; ++i) { + dst.dim_metadata[i] = src.dim_metadata[i]; + dst.dim_metadata[i].array_segments = + TfLiteIntArrayCopy(src.dim_metadata[i].array_segments); + dst.dim_metadata[i].array_indices = + TfLiteIntArrayCopy(src.dim_metadata[i].array_indices); + } + } + return dst; +} + +// Clones the source sparsity to a newly allocated object. +TfLiteSparsity* TfLiteSparsityClone(const TfLiteSparsity* const src) { + if (!src) { + return nullptr; + } + TfLiteSparsity* dst = + reinterpret_cast(calloc(1, sizeof(TfLiteSparsity))); + *dst = TfLiteSparsityClone(*src); + return dst; +} + +#endif // TF_LITE_STATIC_MEMORY + +} // namespace + +#ifndef TF_LITE_STATIC_MEMORY + TfLiteQuantization TfLiteQuantizationClone(const TfLiteQuantization& src) { TfLiteQuantization dst; dst.type = src.type; @@ -136,39 +171,8 @@ TfLiteQuantization TfLiteQuantizationClone(const TfLiteQuantization& src) { return dst; } -TfLiteSparsity TfLiteSparsityClone(const TfLiteSparsity& src) { - TfLiteSparsity dst = src; - dst.traversal_order = TfLiteIntArrayCopy(src.traversal_order); - dst.block_map = TfLiteIntArrayCopy(src.block_map); - if (src.dim_metadata) { - dst.dim_metadata = reinterpret_cast( - calloc(1, sizeof(TfLiteDimensionMetadata) * src.dim_metadata_size)); - for (int i = 0; i < src.dim_metadata_size; ++i) { - dst.dim_metadata[i] = src.dim_metadata[i]; - dst.dim_metadata[i].array_segments = - TfLiteIntArrayCopy(src.dim_metadata[i].array_segments); - dst.dim_metadata[i].array_indices = - TfLiteIntArrayCopy(src.dim_metadata[i].array_indices); - } - } - return dst; -} - -// Clones the source sparsity to a newly allocated object. -TfLiteSparsity* TfLiteSparsityClone(const TfLiteSparsity* const src) { - if (!src) { - return nullptr; - } - TfLiteSparsity* dst = - reinterpret_cast(calloc(1, sizeof(TfLiteSparsity))); - *dst = TfLiteSparsityClone(*src); - return dst; -} - #endif // TF_LITE_STATIC_MEMORY -} // namespace - extern "C" { size_t TfLiteIntArrayGetSizeInBytes(int size) { @@ -247,6 +251,11 @@ void TfLiteQuantizationFree(TfLiteQuantization* quantization) { } free(q_params); } + if (quantization->type == kTfLiteBlockwiseQuantization) { + TfLiteBlockwiseQuantization* q_params = + reinterpret_cast(quantization->params); + free(q_params); + } quantization->params = nullptr; quantization->type = kTfLiteNoQuantization; } diff --git a/tensorflow/lite/core/c/common.h b/tensorflow/lite/core/c/common.h index 3f1fe32b8b4..c3e00cc0972 100644 --- a/tensorflow/lite/core/c/common.h +++ b/tensorflow/lite/core/c/common.h @@ -788,6 +788,7 @@ TfLiteStatus TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); /// If all dimensions are known, this is the same as `t->dims`. /// (`dims_signature` is NULL or empty if all dimensions are known.) const TfLiteIntArray* TfLiteTensorGetDimsSignature(const TfLiteTensor* t); + #endif // TF_LITE_STATIC_MEMORY /// WARNING: This is an experimental interface that is subject to change. @@ -1633,5 +1634,8 @@ TfLiteStatus TfLiteTensorVariantRealloc(TfLiteTensor* t, return kTfLiteOk; } +// Returns a copy of the quantization parameters of the tensor. +TfLiteQuantization TfLiteQuantizationClone(const TfLiteQuantization& src); + #endif // __cplusplus #endif // TENSORFLOW_LITE_CORE_C_COMMON_H_ diff --git a/tensorflow/lite/tools/visualize.py b/tensorflow/lite/tools/visualize.py index 15077a6c62d..de7ef820079 100644 --- a/tensorflow/lite/tools/visualize.py +++ b/tensorflow/lite/tools/visualize.py @@ -43,6 +43,16 @@ body {font-family: sans-serif; background-color: #fa0;} table {background-color: #eca;} th {background-color: black; color: white;} +/* Constrain table cells to a max size and make them scrollable. */ +.data-table td { + max-width: 900px; +} +.data-table .cell-content { + max-height: 200px; + overflow: auto; + white-space: pre-wrap; + word-break: break-all; +} h1 { background-color: ffaa00; padding:5px; @@ -284,6 +294,27 @@ def __call__(self, x): return html +def QuantizationMapper(q): + """Pretty-print the quantization dictionary, truncating large arrays.""" + if not q: + return "" + + items_str = [] + for key, value in q.items(): + key_str = repr(key) + # In TFLite, quantization arrays can be large. + if isinstance(value, list) and len(value) > 20: + head = value[:10] + tail = value[-10:] + value_str = (f"[{', '.join(map(repr, head))}, ..., " + f"{', '.join(map(repr, tail))}]") + else: + value_str = repr(value) + items_str.append(f"{key_str}: {value_str}") + + return f"{{{', '.join(items_str)}}}" + + def GenerateGraph(subgraph_idx, g, opcode_mapper): """Produces the HTML required to have a d3 visualization of the dag.""" @@ -359,8 +390,8 @@ def GenerateTableHtml(items, keys_to_print, display_index=True): An html table. """ html = "" - # Print the list of items - html += "\n" + # Print the list of items + html += "
\n" html += "\n" if display_index: html += "" @@ -375,7 +406,7 @@ def GenerateTableHtml(items, keys_to_print, display_index=True): for h, mapper in keys_to_print: val = tensor[h] if h in tensor else None val = val if mapper is None else mapper(val) - html += "\n" % val + html += "\n" % val html += "\n" html += "
index%s
%s
\n" @@ -465,11 +496,13 @@ def create_html(tflite_input, input_is_filepath=True): # pylint: disable=invali toplevel_stuff = [("filename", None), ("version", None), ("description", None)] - html += "\n" + html += "
\n" for key, mapping in toplevel_stuff: if not mapping: mapping = lambda x: x - html += "\n" % (key, mapping(data.get(key))) + val = mapping(data.get(key)) + html += ("\n" + % (key, val)) html += "
%s%s
%s
%s
\n" # Spec on what keys to display @@ -493,7 +526,7 @@ def create_html(tflite_input, input_is_filepath=True): # pylint: disable=invali tensor_keys_to_display = [("name", NameListToString), ("type", TensorTypeToName), ("shape", None), ("shape_signature", None), ("buffer", None), - ("quantization", None)] + ("quantization", QuantizationMapper)] html += "

Subgraph %d

\n" % subgraph_idx