@@ -408,9 +408,9 @@ def fit(

             self._jax_state_synced = True
             with epoch_iterator.catch_stop_iteration():
-                for step, iterator in epoch_iterator:
+                for begin_step, end_step, iterator in epoch_iterator:
                     # Callbacks
-                    callbacks.on_train_batch_begin(step)
+                    callbacks.on_train_batch_begin(begin_step)

                     # Train step
                     if self._jax_state_synced:
@@ -441,7 +441,7 @@ def fit(
                         "metrics_variables": metrics_variables,
                     }
                     # Dispatch callbacks. This takes care of async dispatch.
-                    callbacks.on_train_batch_end(step, logs)
+                    callbacks.on_train_batch_end(end_step, logs)

                     if self.stop_training:
                         # Stop training if a callback has set
@@ -569,8 +569,8 @@ def evaluate(

         self._jax_state_synced = True
         with epoch_iterator.catch_stop_iteration():
-            for step, iterator in epoch_iterator:
-                callbacks.on_test_batch_begin(step)
+            for begin_step, end_step, iterator in epoch_iterator:
+                callbacks.on_test_batch_begin(begin_step)

                 if self._jax_state_synced:
                     # The state may have been synced by a callback.
@@ -600,7 +600,7 @@ def evaluate(
                 }

                 # Dispatch callbacks. This takes care of async dispatch.
-                callbacks.on_test_batch_end(step, logs)
+                callbacks.on_test_batch_end(end_step, logs)

                 if self.stop_evaluating:
                     break
@@ -633,7 +633,7 @@ def predict(

         if not all(layer.built for layer in self._flatten_layers()):
             # Build the model on one batch of data.
-            for _, iterator in epoch_iterator:
+            for _, _, iterator in epoch_iterator:
                 # Build model
                 x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(
                     next(iterator)
@@ -677,8 +677,8 @@ def append_to_outputs(batch_outputs, outputs):
         outputs = None
         non_trainable_variables = None
         with epoch_iterator.catch_stop_iteration():
-            for step, iterator in epoch_iterator:
-                callbacks.on_predict_batch_begin(step)
+            for begin_step, end_step, iterator in epoch_iterator:
+                callbacks.on_predict_batch_begin(begin_step)
                 if self._jax_state_synced:
                     # The state may have been synced by a callback.
                     state = self._get_jax_state(
@@ -701,7 +701,9 @@ def append_to_outputs(batch_outputs, outputs):
                 outputs = append_to_outputs(batch_outputs, outputs)

                 # Dispatch callbacks. This takes care of async dispatch.
-                callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
+                callbacks.on_predict_batch_end(
+                    end_step, {"outputs": batch_outputs}
+                )

                 if self.stop_predicting:
                     break
0 commit comments