@@ -22,6 +22,7 @@
 #include "paddle/fluid/inference/tensorrt/engine.h"
 #include "paddle/fluid/inference/tensorrt/helper.h"
 #include "paddle/fluid/inference/tensorrt/op_teller.h"
+#include "paddle/fluid/inference/utils/io_utils.h"
 
 namespace paddle {
 namespace inference {
@@ -197,6 +198,17 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
   auto opt_input_shape =
       Get<std::map<std::string, std::vector<int>>>("optim_input_shape");
 
+  auto allow_build_at_runtime = Get<bool>("trt_allow_build_at_runtime");
+  auto shape_range_info_path = Get<std::string>("trt_shape_range_info_path");
+  auto trt_tuned_dynamic_shape = Get<bool>("trt_tuned_dynamic_shape");
+  int max_batch_size = Get<int>("max_batch_size");
+  if (trt_tuned_dynamic_shape) {
+    VLOG(1) << "trt dynamic_shape deserialize from " << shape_range_info_path;
+    inference::DeserializeShapeRangeInfo(shape_range_info_path,
+                                         &min_input_shape, &max_input_shape,
+                                         &opt_input_shape);
+  }
+
   // The following procedure is used to rename all the intermediate
   // variables and the output variables of the subgraph.
   // Why we do this?
@@ -242,12 +254,14 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
 
   op_desc->SetBlockAttr("sub_block", new_block);
   op_desc->SetAttr("subgraph", block_desc.Proto()->SerializeAsString());
-  op_desc->SetAttr("max_batch_size", Get<int>("max_batch_size"));
+  op_desc->SetAttr("max_batch_size", max_batch_size);
   op_desc->SetAttr("workspace_size", Get<int>("workspace_size"));
   op_desc->SetAttr("gpu_id", Get<int>("gpu_device_id"));
   op_desc->SetAttr("output_name_mapping", output_mapping);
   op_desc->SetAttr("origin_output_dims", renamed_output_dims);
   op_desc->SetAttr("parameters", params);
+  op_desc->SetAttr("allow_build_at_runtime", allow_build_at_runtime);
+  op_desc->SetAttr("shape_range_info_path", shape_range_info_path);
 
   // we record all inputs' shapes in attr to check if they are consistent
   // with the real inputs' shapes retrieved from scope when trt runs.
@@ -259,19 +273,24 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
   }
 
   auto use_static_engine = Get<bool>("use_static_engine");
+  op_desc->SetAttr("use_static_engine", use_static_engine);
+  if (use_static_engine)
+    op_desc->SetAttr("model_opt_cache_dir",
+                     Get<std::string>("model_opt_cache_dir"));
+
   // TODO(NHZlX)
   // There are models with the same structure but different parameters,
   // when running in the 'use_serialize' mode, there is a bug.
   // serialization is affected by max_batch_size, but calibration is not.
   // So we use separate engine keys in serialization and calibration.
   auto engine_key = GenerateEngineKey(
       input_names_with_id, output_names_with_id, std::to_string(0),
-      std::to_string(Get<int>("max_batch_size")),
+      std::to_string(max_batch_size),
       std::to_string(static_cast<int>(precision_mode)), false);
-  auto calibration_engine_key = GenerateEngineKey(
-      input_names_with_id, output_names_with_id, std::to_string(0),
-      std::to_string(Get<int>("max_batch_size")),
-      std::to_string(static_cast<int>(precision_mode)), true);
+  auto calibration_engine_key =
+      GenerateEngineKey(input_names_with_id, output_names_with_id,
+                        std::to_string(0), std::to_string(max_batch_size),
+                        std::to_string(static_cast<int>(precision_mode)), true);
   auto predictor_id = Get<int>("predictor_id");
 
   // Get "" when there is no cached calibration table data.
@@ -345,11 +364,10 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
   bool disable_trt_plugin_fp16 = Get<bool>("disable_trt_plugin_fp16");
   tensorrt::TensorRTEngine *trt_engine =
       inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
-          .Create(engine_key + std::to_string(predictor_id),
-                  Get<int>("max_batch_size"), Get<int>("workspace_size"),
-                  precision_mode, calibrator.get(), Get<int>("gpu_device_id"),
-                  min_input_shape, max_input_shape, opt_input_shape,
-                  disable_trt_plugin_fp16);
+          .Create(engine_key + std::to_string(predictor_id), max_batch_size,
+                  Get<int>("workspace_size"), precision_mode, calibrator.get(),
+                  Get<int>("gpu_device_id"), min_input_shape, max_input_shape,
+                  opt_input_shape, disable_trt_plugin_fp16);
   trt_engine->SetUseOSS(Get<bool>("use_oss"));
   trt_engine->SetUseDLA(Get<bool>("trt_use_dla"));
   trt_engine->SetDLACore(Get<int>("trt_dla_core"));
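
For context, the new pass attributes above (trt_tuned_dynamic_shape, trt_shape_range_info_path, trt_allow_build_at_runtime) are populated from the user-facing inference config. Below is a minimal usage sketch, assuming the Paddle Inference 2.x C++ API from paddle_inference_api.h; the CollectShapeRangeInfo and EnableTunedTensorRtDynamicShape calls reflect my reading of how that API drives this pass and are not part of this diff.

#include <string>

#include "paddle_inference_api.h"

// Pass 1: run representative inputs once and record each tensor's
// min/max/opt shapes into shape_range_info_path -- the file that
// DeserializeShapeRangeInfo reads in the pass above.
void CollectShapes(const std::string &model_dir,
                   const std::string &shape_range_info_path) {
  paddle_infer::Config config;
  config.SetModel(model_dir);
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/512, /*device_id=*/0);
  config.CollectShapeRangeInfo(shape_range_info_path);
  auto predictor = paddle_infer::CreatePredictor(config);
  // ... feed representative inputs and call predictor->Run() ...
}

// Pass 2: consume the recorded ranges as tuned TensorRT dynamic shapes.
void RunWithTunedShapes(const std::string &model_dir,
                        const std::string &shape_range_info_path) {
  paddle_infer::Config config;
  config.SetModel(model_dir);
  config.EnableUseGpu(512, 0);
  config.EnableTensorRtEngine(/*workspace_size=*/1 << 30,
                              /*max_batch_size=*/1, /*min_subgraph_size=*/3,
                              paddle_infer::PrecisionType::kFloat32,
                              /*use_static=*/false, /*use_calib_mode=*/false);
  // Presumably sets trt_tuned_dynamic_shape, trt_shape_range_info_path, and
  // trt_allow_build_at_runtime, which CreateTensorRTOp consumes above.
  config.EnableTunedTensorRtDynamicShape(shape_range_info_path,
                                         /*allow_build_at_runtime=*/true);
  auto predictor = paddle_infer::CreatePredictor(config);
  // ... predictor->Run() ...
}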