diff --git a/src/launchpad/kafka.py b/src/launchpad/kafka.py index c41741f3..9bf2ca5f 100644 --- a/src/launchpad/kafka.py +++ b/src/launchpad/kafka.py @@ -56,14 +56,14 @@ def _process_in_subprocess(decoded_message: Any, log_queue: multiprocessing.Queu def _kill_process(process: multiprocessing.Process, artifact_id: str) -> None: """Gracefully terminate, then force kill a subprocess.""" process.terminate() - process.join(timeout=5) + process.join(timeout=1) if process.is_alive(): logger.warning( "Process did not terminate gracefully, force killing", extra={"artifact_id": artifact_id, "pid": process.pid}, ) process.kill() - process.join(timeout=1) # Brief timeout to reap zombie, avoid infinite block + process.join(timeout=0.5) if process.is_alive(): logger.error( "Process could not be killed, may become zombie", @@ -79,6 +79,9 @@ def process_kafka_message_with_service( factory: LaunchpadStrategyFactory, ) -> Any: """Process a Kafka message by spawning a fresh subprocess with timeout protection.""" + if factory._is_shutting_down: + raise TimeoutError("Skipping message processing - shutdown in progress") + timeout = int(os.getenv("KAFKA_TASK_TIMEOUT_SECONDS", "720")) # 12 minutes default try: @@ -101,21 +104,12 @@ def process_kafka_message_with_service( process.join(timeout=timeout) # Check if killed during rebalance - pid = process.pid - if pid is not None: - with registry_lock: - was_killed_by_rebalance = pid in factory._killed_during_rebalance - if was_killed_by_rebalance: - factory._killed_during_rebalance.discard(pid) - - if was_killed_by_rebalance: - # Wait for kill to complete, then don't commit offset - process.join(timeout=10) # Give kill_active_processes time to finish - logger.warning( - "Process killed during rebalance, message will be reprocessed", - extra={"artifact_id": artifact_id}, - ) - raise TimeoutError("Subprocess killed during rebalance") + if factory._is_shutting_down: + logger.warning( + "Process killed during rebalance, message will be reprocessed", + extra={"artifact_id": artifact_id}, + ) + raise TimeoutError("Subprocess killed during rebalance") # Handle timeout (process still alive after full timeout) if process.is_alive(): @@ -201,7 +195,7 @@ def poll(self) -> None: self._inner.poll() def close(self) -> None: - # Kill all active subprocesses BEFORE closing inner strategy + self._factory._is_shutting_down = True self._factory.kill_active_processes() self._inner.close() @@ -264,7 +258,7 @@ def __init__( self._active_processes: dict[int, tuple[multiprocessing.Process, str]] = {} self._processes_lock = threading.Lock() - self._killed_during_rebalance: set[int] = set() + self._is_shutting_down = False self.concurrency = concurrency self.max_pending_futures = max_pending_futures @@ -286,7 +280,6 @@ def kill_active_processes(self) -> None: ) for pid, (process, artifact_id) in list(self._active_processes.items()): if process.is_alive(): - self._killed_during_rebalance.add(pid) logger.info("Terminating subprocess with PID %d", pid) _kill_process(process, artifact_id) self._active_processes.clear() @@ -297,6 +290,9 @@ def create_with_partitions( partitions: Mapping[Partition, int], ) -> ProcessingStrategy[KafkaPayload]: """Create the processing strategy chain.""" + # Reset shutdown flag when creating new strategy after rebalance + self._is_shutting_down = False + next_step: ProcessingStrategy[Any] = CommitOffsets(commit) processing_function = partial(