
Commit 8a08bd5

Changed how the unown_tensor attribute is set on the TRT module
1 parent 6446085 commit 8a08bd5

4 files changed, +26 -18 lines


core/runtime/execute_engine.cpp

Lines changed: 6 additions & 4 deletions
@@ -96,7 +96,7 @@ void setup_input_tensors(
     std::vector<at::Tensor> inputs,
     c10::intrusive_ptr<TRTEngine> compiled_engine,
     bool cudagraphs_enabled,
-    bool need_cudagraphs_record) {
+    bool shape_changed) {
   // this is a buffer to store shape tensor input addresses throughout the runtime scope
   std::list<std::vector<int64_t>> inputShapeTensorValues;
   std::list<at::Tensor> formatted_inputs(compiled_engine->num_io.first);
@@ -140,12 +140,14 @@ void setup_input_tensors(
     } else {
       at::Tensor contig_input = inputs[i].view(shape).contiguous();
       formatted_inputs.emplace_back(std::move(contig_input));
-
+      bool need_cudagraphs_record = cudagraphs_enabled and
+          (not compiled_engine->runtime_states.old_cudagraphs or shape_changed or
+              compiled_engine->runtime_states.context_changed);
       if (need_cudagraphs_record) {
         // Create a new persistent input buffer
         compiled_engine->input_buffers[i] = std::move(formatted_inputs.back().clone());
       }
-      if (need_cudagraphs_record or compiled_engine->allocated_outputs.size() == 0) {
+      if (shape_changed or compiled_engine->allocated_outputs.size() == 0) {
         TORCHTRT_CHECK(
             compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");
       }
@@ -226,7 +228,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     input_profiler_guard =
         std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
   }
-  setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record);
+  setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, shape_changed);
   // Check if input shapes can be inferred.
   int32_t const io_size{compiled_engine->io_size};
   std::vector<char const*> names(io_size);
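
In effect, the C++ runtime now derives need_cudagraphs_record inside setup_input_tensors from the shape_changed flag plus the engine's runtime state, instead of receiving it from the caller. A minimal sketch of that predicate, written in Python for brevity; the standalone function name is illustrative and not part of the runtime API:

def should_record_cudagraphs(
    cudagraphs_enabled: bool,
    old_cudagraphs: bool,    # was CUDA Graphs mode already active on the previous run?
    shape_changed: bool,     # did any input shape change since the previous run?
    context_changed: bool,   # was the TensorRT execution context swapped?
) -> bool:
    # Re-record only when CUDA Graphs is enabled and something invalidated the captured graph.
    return cudagraphs_enabled and (not old_cudagraphs or shape_changed or context_changed)

# A shape change alone, with CUDA Graphs disabled, no longer triggers a re-record.
assert not should_record_cudagraphs(False, True, True, False)

Separately, the setInputShape call is now keyed off shape_changed (or the first execution, when no outputs have been allocated) rather than the CUDA Graphs record flag.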

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 7 additions & 3 deletions
@@ -949,7 +949,7 @@ def preserve_module_specs(
     for attr in dir(gm):
         if attr.startswith("_frozen_param"):
             delattr(gm, attr)
-    trt_module = None
+
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
         # filter on the GraphModule
@@ -1082,8 +1082,12 @@ def preserve_module_specs(
             trt_module = getattr(partitioned_module, name)
             trt_module.setup_engine()

-    if trt_module:
-        trt_module.set_output_tensors_as_unowned(True)
+    output_node = list(partitioned_module.graph.nodes)[-1]
+    for arg in output_node.args:
+        target = arg[0].target
+        if "_run_on_acc" not in str(target):
+            continue
+        getattr(partitioned_module, target).set_output_tensors_as_unowned(True)

     # Reset settings object to user specification after fallback to global partitioning mode
     if fast_partitioner_failed:
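
Rather than tagging only the last TRT module assigned during the partition loop, the compiler now walks the partitioned graph's output node and calls set_output_tensors_as_unowned(True) on the "_run_on_acc" submodules that feed the graph outputs. For readers less familiar with torch.fx graph anatomy, a small runnable illustration of that traversal on a toy module (the Toy module below is illustrative only, not Torch-TensorRT code):

import torch
from torch import fx

class Toy(torch.nn.Module):
    def forward(self, x):
        return (torch.relu(x), torch.tanh(x))

gm = fx.symbolic_trace(Toy())
output_node = list(gm.graph.nodes)[-1]  # an fx graph always ends with its output node
(returned,) = output_node.args          # args wraps the value structure returned by forward()
for node in returned:                   # here: the relu and tanh call_function nodes
    print(node.op, node.target)         # in a partitioned module, the producers of interest are
                                        # call_module nodes targeting "..._run_on_acc_*" children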

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 9 additions & 7 deletions
@@ -383,8 +383,11 @@ def setup_input_tensors(
         self,
         contiguous_inputs: List[torch.Tensor],
         cudagraphs_enabled: bool,
-        need_cudagraphs_record: bool,
+        shape_changed: bool = True,
     ) -> None:
+        need_cudagraphs_record = cudagraphs_enabled and (
+            not self.old_cudagraphs or shape_changed or self.context_changed
+        )
         for i, input_name in enumerate(self.input_names):
             if not contiguous_inputs[i].is_cuda:
                 logger.warning(
@@ -417,9 +420,7 @@ def setup_input_tensors(
                 inputs_cpu = contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
                 self.context.set_tensor_address(input_name, inputs_cpu.ctypes.data)
             else:
-                if (
-                    need_cudagraphs_record or self.output_tensors is None
-                ):  # First time execution:
+                if shape_changed or self.output_tensors is None:
                     self.context.set_input_shape(
                         input_name, tuple(contiguous_inputs[i].shape)
                     )
@@ -490,9 +491,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
            ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."

            self.setup_input_tensors(
-                contiguous_inputs,
-                self.cudagraphs_enabled,
-                need_cudagraphs_record,
+                contiguous_inputs, self.cudagraphs_enabled, shape_changed
            )

            if shape_changed:
@@ -807,3 +806,6 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
                return True

        return False
+
+    def are_output_tensors_unowned(self) -> bool:
+        return self.output_tensors_are_unowned
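
The Python runtime mirrors the C++ change: setup_input_tensors now takes shape_changed (defaulting to True) and derives need_cudagraphs_record internally, and the module gains are_output_tensors_unowned() as the read side of set_output_tensors_as_unowned(). A minimal sketch of that getter/setter pairing, using a stand-in class and assuming the setter writes the same flag the new getter reads:

class _TRTModuleSketch:
    """Stand-in for the output-ownership flag only; not the real module."""

    def __init__(self) -> None:
        self.output_tensors_are_unowned = False  # attribute name taken from the diff

    def set_output_tensors_as_unowned(self, unowned: bool) -> None:
        # Called by the compiler for submodules whose outputs are returned directly to the caller.
        self.output_tensors_are_unowned = unowned

    def are_output_tensors_unowned(self) -> bool:
        return self.output_tensors_are_unowned

m = _TRTModuleSketch()
m.set_output_tensors_as_unowned(True)
assert m.are_output_tensors_unowned()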

setup.py

Lines changed: 4 additions & 4 deletions
@@ -195,10 +195,10 @@ def build_libtorchtrt_cxx11_abi(
     else:
         cmd.append("//:libtorchtrt")

-    # if develop:
-    #     cmd.append("--compilation_mode=dbg")
-    # else:
-    cmd.append("--compilation_mode=opt")
+    if develop:
+        cmd.append("--compilation_mode=dbg")
+    else:
+        cmd.append("--compilation_mode=opt")
     if use_dist_dir:
         if IS_AARCH64:
             cmd.append("--distdir=third_party/dist_dir/aarch64-linux-gnu")
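
This restores the previously commented-out switch so that develop builds pass --compilation_mode=dbg to Bazel while regular builds keep --compilation_mode=opt. The same branch expressed as a tiny standalone helper; the helper name is illustrative and not part of setup.py:

def compilation_mode_flag(develop: bool) -> str:
    # dbg for develop builds, opt for regular builds, as in the restored branch.
    return "--compilation_mode=dbg" if develop else "--compilation_mode=opt"

assert compilation_mode_flag(True) == "--compilation_mode=dbg"
assert compilation_mode_flag(False) == "--compilation_mode=opt"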
