@@ -185,6 +185,9 @@ def process(
185185
186186 wq , q_scales , q_zeros , q_g_idx , duration , avg_loss , damp_percent , nsamples = g .quantize ()
187187
188+ workspace_summary = getattr (g , "_borrow_workspace_last_summary" , None )
189+ workspace_totals = getattr (g , "_borrow_workspace_totals" , None )
190+
188191 module .stream_state_payload_to_cpu (
189192 {
190193 "q_scales" : q_scales ,
@@ -235,6 +238,25 @@ def process(
235238 PROCESS_USED_MEMORY : self .device_memory_report (),
236239 }
237240
241+ if workspace_summary :
242+ requests = int (workspace_summary .get ("requests" , 0 ) or 0 )
243+ if requests :
244+ hit_rate = float (workspace_summary .get ("hit_rate" , 0.0 ) or 0.0 )
245+ chunk_rows = workspace_summary .get ("chunk_rows" )
246+ stat ["workspace_cache_requests" ] = str (requests )
247+ stat ["workspace_cache_hit_rate" ] = f"{ hit_rate :.1%} "
248+ stat ["workspace_stage_dtype" ] = workspace_summary .get ("staging_dtype" , "" )
249+ if chunk_rows is not None :
250+ stat ["workspace_chunk_rows" ] = str (chunk_rows )
251+ if workspace_totals :
252+ total_requests = int (workspace_totals .get ("requests" , 0 ) or 0 )
253+ if total_requests :
254+ cumulative_hit_rate = (
255+ float (workspace_totals .get ("materialized_hits" , 0 ) or 0.0 ) / total_requests
256+ )
257+ stat ["workspace_total_requests" ] = str (total_requests )
258+ stat ["workspace_total_hit_rate" ] = f"{ cumulative_hit_rate :.1%} "
259+
238260 if self .qcfg .dynamic is not None :
239261 stat ["dynamic" ] = self .qcfg .dynamic_get (layer_name = module .full_name )
240262
@@ -244,6 +266,8 @@ def process(
244266 # Log the new row
245267 self .log_new_row (stat )
246268
269+ g .log_workspace_stats (context = "gptq_process" )
270+
247271 if self .calculate_w_wq_diff :
248272 # diff in float32
249273 w_wq_diff = module .weight .data .to (dtype = torch .float32 ) - wq .to (dtype = torch .float32 )
0 commit comments