4444#include < sys/types.h>
4545
4646#include < chrono>
47+ #include < csignal>
4748#include < iostream>
4849#include < memory>
4950#include < optional>
@@ -592,11 +593,24 @@ void py_unsetup_example(VW::workspace& ws, std::vector<VW::example*>& ex)
592593 for (auto & example : ex) { py_unsetup_example (ws, *example); }
593594}
594595
596+ // Because of the GIL we can use globals here.
597+ static bool SIGINT_CALLED = false ;
598+ static VW::workspace* CLI_DRIVER_WORKSPACE = nullptr ;
599+
595600// return type is an optional error information (nullopt if success), driver output, list of log messages
596601// stdin is not supported
597602std::tuple<std::optional<std::string>, std::string, std::vector<std::string>> run_cli_driver (
598603 const std::vector<std::string>& args, bool onethread)
599604{
605+ SIGINT_CALLED = false ;
606+ CLI_DRIVER_WORKSPACE = nullptr ;
607+ std::signal (SIGINT,
608+ [](int )
609+ {
610+ if (CLI_DRIVER_WORKSPACE != nullptr ) { VW::details::set_done (*CLI_DRIVER_WORKSPACE); }
611+ SIGINT_CALLED = true ;
612+ });
613+
600614 auto args_copy = args;
601615 args_copy.push_back (" --no_stdin" );
602616 auto options = VW::make_unique<VW::config::options_cli>(args_copy);
@@ -620,18 +634,23 @@ std::tuple<std::optional<std::string>, std::string, std::vector<std::string>> ru
620634 {
621635 auto all = VW::initialize_experimental (std::move (options), nullptr , driver_logger, &driver_log, &logger);
622636 all->vw_is_main = true ;
637+ CLI_DRIVER_WORKSPACE = all.get ();
623638
624- if (onethread) { VW::LEARNER::generic_driver_onethread (*all); }
625- else
639+ // If sigint was called before we got here, we should avoid running the driver.
640+ if (!SIGINT_CALLED)
626641 {
627- VW::start_parser (*all);
628- VW::LEARNER::generic_driver (*all);
629- VW::end_parser (*all);
642+ if (onethread) { VW::LEARNER::generic_driver_onethread (*all); }
643+ else
644+ {
645+ VW::start_parser (*all);
646+ VW::LEARNER::generic_driver (*all);
647+ VW::end_parser (*all);
648+ }
649+
650+ if (all->example_parser ->exc_ptr ) { std::rethrow_exception (all->example_parser ->exc_ptr ); }
651+ VW::sync_stats (*all);
652+ all->finish ();
630653 }
631-
632- if (all->example_parser ->exc_ptr ) { std::rethrow_exception (all->example_parser ->exc_ptr ); }
633- VW::sync_stats (*all);
634- all->finish ();
635654 }
636655 catch (const std::exception& ex)
637656 {
@@ -642,6 +661,8 @@ std::tuple<std::optional<std::string>, std::string, std::vector<std::string>> ru
642661 return std::make_tuple (" Unknown exception occurred" , driver_log.str (), log_log);
643662 }
644663
664+ SIGINT_CALLED = false ;
665+ CLI_DRIVER_WORKSPACE = nullptr ;
645666 return std::make_tuple (std::nullopt , driver_log.str (), log_log);
646667}
647668
0 commit comments