Skip to content

Commit 0da0f92

Browse files
authored
feat: install sigint handler for run_cli_driver (#78)
Partially addresses #77
1 parent 5bbafb0 commit 0da0f92

File tree

1 file changed

+30
-9
lines changed

1 file changed

+30
-9
lines changed

src/cpp/main.cpp

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include <sys/types.h>
4545

4646
#include <chrono>
47+
#include <csignal>
4748
#include <iostream>
4849
#include <memory>
4950
#include <optional>
@@ -592,11 +593,24 @@ void py_unsetup_example(VW::workspace& ws, std::vector<VW::example*>& ex)
592593
for (auto& example : ex) { py_unsetup_example(ws, *example); }
593594
}
594595

596+
// Because of the GIL we can use globals here.
597+
static bool SIGINT_CALLED = false;
598+
static VW::workspace* CLI_DRIVER_WORKSPACE = nullptr;
599+
595600
// return type is an optional error information (nullopt if success), driver output, list of log messages
596601
// stdin is not supported
597602
std::tuple<std::optional<std::string>, std::string, std::vector<std::string>> run_cli_driver(
598603
const std::vector<std::string>& args, bool onethread)
599604
{
605+
SIGINT_CALLED = false;
606+
CLI_DRIVER_WORKSPACE = nullptr;
607+
std::signal(SIGINT,
608+
[](int)
609+
{
610+
if (CLI_DRIVER_WORKSPACE != nullptr) { VW::details::set_done(*CLI_DRIVER_WORKSPACE); }
611+
SIGINT_CALLED = true;
612+
});
613+
600614
auto args_copy = args;
601615
args_copy.push_back("--no_stdin");
602616
auto options = VW::make_unique<VW::config::options_cli>(args_copy);
@@ -620,18 +634,23 @@ std::tuple<std::optional<std::string>, std::string, std::vector<std::string>> ru
620634
{
621635
auto all = VW::initialize_experimental(std::move(options), nullptr, driver_logger, &driver_log, &logger);
622636
all->vw_is_main = true;
637+
CLI_DRIVER_WORKSPACE = all.get();
623638

624-
if (onethread) { VW::LEARNER::generic_driver_onethread(*all); }
625-
else
639+
// If sigint was called before we got here, we should avoid running the driver.
640+
if (!SIGINT_CALLED)
626641
{
627-
VW::start_parser(*all);
628-
VW::LEARNER::generic_driver(*all);
629-
VW::end_parser(*all);
642+
if (onethread) { VW::LEARNER::generic_driver_onethread(*all); }
643+
else
644+
{
645+
VW::start_parser(*all);
646+
VW::LEARNER::generic_driver(*all);
647+
VW::end_parser(*all);
648+
}
649+
650+
if (all->example_parser->exc_ptr) { std::rethrow_exception(all->example_parser->exc_ptr); }
651+
VW::sync_stats(*all);
652+
all->finish();
630653
}
631-
632-
if (all->example_parser->exc_ptr) { std::rethrow_exception(all->example_parser->exc_ptr); }
633-
VW::sync_stats(*all);
634-
all->finish();
635654
}
636655
catch (const std::exception& ex)
637656
{
@@ -642,6 +661,8 @@ std::tuple<std::optional<std::string>, std::string, std::vector<std::string>> ru
642661
return std::make_tuple("Unknown exception occurred", driver_log.str(), log_log);
643662
}
644663

664+
SIGINT_CALLED = false;
665+
CLI_DRIVER_WORKSPACE = nullptr;
645666
return std::make_tuple(std::nullopt, driver_log.str(), log_log);
646667
}
647668

0 commit comments

Comments
 (0)