Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 16 additions & 20 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,13 @@ def convert_op_to_relax(self):

assert isinstance(op, Operator)
ret = self.convert_map[op_code_str](op=op)
ret = self.bb.normalize(ret)
# print("Op Code:", op_code_str, " Shape:", ret.struct_info)

# In case the Op can be prefetched, the output can be optimized out
if ret is None:
continue

ret = self.bb.normalize(ret)

if len(output_tensors) == 1:
tensor_idx = output_tensors[0].tensor_idx
self.exp_tab.set_expr(get_tensor_name(self.subgraph, tensor_idx), ret)
Expand Down Expand Up @@ -1898,15 +1898,8 @@ def convert_fully_connected(self, op):
TensorType.UINT8,
TensorType.FLOAT32,
)
weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)

if self.has_expr(weight_tensor.tensor_idx):
weight_expr = self.get_expr(weight_tensor.tensor_idx)
else:
weight_value = self.get_tensor_value(weight_tensor)
weight_expr = self.exp_tab.new_const(
weight_value, dtype=weight_tensor_type_str, source_name=weight_tensor.tensor.Name()
)
weight_expr = self.get_tensor_expr(weight_tensor)
weight_shape = weight_expr.struct_info.shape
weight_expr = relax.op.permute_dims(weight_expr, [1, 0])

Expand Down Expand Up @@ -3142,7 +3135,7 @@ def convert_transpose_conv(self, op):
weight_expr_iohw = self.get_expr(weights_tensor.tensor_idx)
weight_expr_iohw = relax.op.permute_dims(weight_expr_iohw, axes=(3, 0, 1, 2))
else:
weight_value_ohwi = self.get_tensor_value(weights_tensor)
weight_value_ohwi = self.get_tensor_value_or_prefetched(weights_tensor)
# Relax kernel_layout should be OIHW
# Relax weights layout should be different from kernel_layout - it should be IOHW
weight_value_iohw = np.transpose(weight_value_ohwi, (3, 0, 1, 2))
Expand Down Expand Up @@ -3878,18 +3871,21 @@ def set_prefetched_node(self, input_tensor_idx, value):
def get_prefetched_node(self, input_tensor_idx):
    """Look up the prefetched (pre-computed) value recorded for a tensor.

    Parameters
    ----------
    input_tensor_idx : int
        Index of the tensor within the current subgraph.

    Returns
    -------
    The value stored in ``self.prefetched_nodes`` under the tensor's
    name; raises ``KeyError`` if nothing was prefetched for it.
    """
    key = get_tensor_name(self.subgraph, input_tensor_idx)
    return self.prefetched_nodes[key]

def get_tensor_value_or_prefetched(self, tensor, is_sparse=False):
    """Return the tensor's constant value, preferring a prefetched one.

    If a prefetched (already computed) value was recorded for this
    tensor's index, return it; otherwise fall back to reading the raw
    value from the model buffer via ``get_tensor_value``.

    Parameters
    ----------
    tensor
        Tensor wrapper exposing ``tensor_idx``.
    is_sparse : bool
        Forwarded to ``get_tensor_value`` when no prefetched value exists.
    """
    idx = tensor.tensor_idx
    if not self.is_prefetched(idx):
        return self.get_tensor_value(tensor, is_sparse)
    return self.get_prefetched_node(idx)

def get_tensor_expr(self, tensor, is_sparse=False):
    """Return the Relax expr for *tensor*.

    If an expression was already registered for this tensor index,
    return it directly.  Otherwise materialize the tensor as a new
    constant: the value comes from ``get_tensor_value_or_prefetched``
    (so a prefetched value wins over re-reading the model buffer),
    typed from the TFLite tensor's declared type and named after the
    tensor for traceability.

    Parameters
    ----------
    tensor
        Tensor wrapper exposing ``tensor_idx`` and the underlying
        TFLite ``tensor`` object (``Type()``, ``Name()``).
    is_sparse : bool
        Forwarded to the value lookup for sparse tensors.
    """
    # NOTE: the diff rendering interleaved the pre- and post-change
    # bodies; this is the post-change implementation only.
    if self.has_expr(tensor.tensor_idx):
        return self.get_expr(tensor.tensor_idx)

    type_str = self.get_tensor_type_str(tensor.tensor.Type())
    value = self.get_tensor_value_or_prefetched(tensor, is_sparse)
    return self.exp_tab.new_const(
        value, dtype=type_str, source_name=tensor.tensor.Name()
    )

def get_tensor_shape(self, tensor_wrapper):
"""Returns tensor shape. Infers shape if the shape is empty."""
Expand Down
Loading
Loading