diff --git a/src/CoreServer.cpp b/src/CoreServer.cpp index b4acf4cba..4b690ad3b 100644 --- a/src/CoreServer.cpp +++ b/src/CoreServer.cpp @@ -277,11 +277,7 @@ void CoreServer::_setup_routes(const PrometheusConfig &prom_config) return; } try { - auto [policy, lock] = _registry->policy_manager()->module_get_locked(name); - policy->stop(); - lock.unlock(); - // TODO chance of race here - _registry->policy_manager()->module_remove(name); + _registry->policy_manager()->remove_policy(name); res.set_content(j.dump(), "text/json"); } catch (const std::exception &e) { res.status = 500; diff --git a/src/Policies.cpp b/src/Policies.cpp index e47f19d3f..3bc630ee0 100644 --- a/src/Policies.cpp +++ b/src/Policies.cpp @@ -276,6 +276,28 @@ std::vector PolicyManager::load(const YAML::Node &policy_yaml) return result; } +void PolicyManager::remove_policy(const std::string &name) +{ + std::unique_lock lock(_map_mutex); + if (_map.count(name) == 0) { + throw ModuleException(name, fmt::format("module name '{}' does not exist", name)); + } + + auto policy = _map[name].get(); + auto input_name = policy->input_stream()->name(); + std::vector module_names; + for (const auto &mod : policy->modules()) { + module_names.push_back(mod->name()); + } + policy->stop(); + + for (const auto &name : module_names) { + _registry->handler_manager()->module_remove(name); + } + _registry->input_manager()->module_remove(input_name); + + _map.erase(name); +} void Policy::info_json(json &j) const { _input_stream->info_json(j["input"][_input_stream->name()]); diff --git a/src/Policies.h b/src/Policies.h index ea6cb13ac..96991fee9 100644 --- a/src/Policies.h +++ b/src/Policies.h @@ -91,6 +91,7 @@ class PolicyManager : public AbstractManager std::vector load_from_str(const std::string &str); std::vector load(const YAML::Node &tap_yaml); + void remove_policy(const std::string &name); }; } \ No newline at end of file diff --git a/src/handlers/pcap/PcapStreamHandler.cpp b/src/handlers/pcap/PcapStreamHandler.cpp index 356a1f10e..ca59309c8 100644 --- a/src/handlers/pcap/PcapStreamHandler.cpp +++ b/src/handlers/pcap/PcapStreamHandler.cpp @@ -31,11 +31,13 @@ void PcapStreamHandler::start() _metrics->set_recorded_stream(); } - _start_tstamp_connection = _pcap_stream->start_tstamp_signal.connect(&PcapStreamHandler::set_start_tstamp, this); - _end_tstamp_connection = _pcap_stream->end_tstamp_signal.connect(&PcapStreamHandler::set_end_tstamp, this); + if (_pcap_stream) { + _start_tstamp_connection = _pcap_stream->start_tstamp_signal.connect(&PcapStreamHandler::set_start_tstamp, this); + _end_tstamp_connection = _pcap_stream->end_tstamp_signal.connect(&PcapStreamHandler::set_end_tstamp, this); - _pcap_tcp_reassembly_errors_connection = _pcap_stream->tcp_reassembly_error_signal.connect(&PcapStreamHandler::process_pcap_tcp_reassembly_error, this); - _pcap_stats_connection = _pcap_stream->pcap_stats_signal.connect(&PcapStreamHandler::process_pcap_stats, this); + _pcap_tcp_reassembly_errors_connection = _pcap_stream->tcp_reassembly_error_signal.connect(&PcapStreamHandler::process_pcap_tcp_reassembly_error, this); + _pcap_stats_connection = _pcap_stream->pcap_stats_signal.connect(&PcapStreamHandler::process_pcap_stats, this); + } _running = true; } @@ -46,9 +48,12 @@ void PcapStreamHandler::stop() return; } - _start_tstamp_connection.disconnect(); - _end_tstamp_connection.disconnect(); - _pcap_tcp_reassembly_errors_connection.disconnect(); + if (_pcap_stream) { + _start_tstamp_connection.disconnect(); + _end_tstamp_connection.disconnect(); + _pcap_tcp_reassembly_errors_connection.disconnect(); + _pcap_stats_connection.disconnect(); + } _running = false; } diff --git a/src/tests/test_policies.cpp b/src/tests/test_policies.cpp index f2cddf1da..9c4459bbb 100644 --- a/src/tests/test_policies.cpp +++ b/src/tests/test_policies.cpp @@ -550,4 +550,39 @@ TEST_CASE("Policies", "[policies]") lock.unlock(); REQUIRE_NOTHROW(registry.policy_manager()->module_remove("default_view")); } + + SECTION("Good Config, test remove policy and add again") + { + CoreRegistry registry; + registry.start(nullptr); + YAML::Node config_file = YAML::Load(policies_config); + + CHECK(config_file["visor"]["policies"]); + CHECK(config_file["visor"]["policies"].IsMap()); + + REQUIRE_NOTHROW(registry.tap_manager()->load(config_file["visor"]["taps"], true)); + REQUIRE_NOTHROW(registry.policy_manager()->load(config_file["visor"]["policies"])); + + REQUIRE(registry.policy_manager()->module_exists("default_view")); + auto [policy, lock] = registry.policy_manager()->module_get_locked("default_view"); + CHECK(policy->name() == "default_view"); + CHECK(policy->input_stream()->running()); + CHECK(policy->modules()[0]->running()); + CHECK(policy->modules()[1]->running()); + CHECK(policy->modules()[2]->running()); + lock.unlock(); + + REQUIRE_NOTHROW(registry.policy_manager()->remove_policy("default_view")); + + REQUIRE_NOTHROW(registry.policy_manager()->load(config_file["visor"]["policies"])); + REQUIRE(registry.policy_manager()->module_exists("default_view")); + auto [new_policy, new_lock] = registry.policy_manager()->module_get_locked("default_view"); + CHECK(new_policy->name() == "default_view"); + CHECK(new_policy->input_stream()->running()); + CHECK(new_policy->modules()[0]->running()); + CHECK(new_policy->modules()[1]->running()); + CHECK(new_policy->modules()[2]->running()); + new_lock.unlock(); + REQUIRE_NOTHROW(registry.policy_manager()->remove_policy("default_view")); + } }