-
Notifications
You must be signed in to change notification settings - Fork 644
Expand file tree
/
Copy pathconfig.py
More file actions
370 lines (306 loc) · 13.5 KB
/
config.py
File metadata and controls
370 lines (306 loc) · 13.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
"""Configuration support for custom components."""
import fnmatch
import os
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
from typing import Any
import yaml
from eql.utils import load_dump # type: ignore[reportMissingTypeStubs]
from .misc import discover_tests
from .utils import (
OPTIONAL_ELASTIC_VALIDATION_BYPASS_ENV,
cached,
get_etc_path,
load_etc_dump,
set_all_validation_bypass,
)
# Directory two levels above this module file.
ROOT_DIR = Path(__file__).parent.parent
# Optional location of a custom rules directory, read once at import time (None when the variable is unset).
CUSTOM_RULES_DIR = os.environ.get("CUSTOM_RULES_DIR")
@dataclass
class UnitTest:
"""Base object for unit tests configuration."""
bypass: list[str] | None = None
test_only: list[str] | None = None
def __post_init__(self) -> None:
if self.bypass and self.test_only:
raise ValueError("Cannot set both `test_only` and `bypass` in test_config!")
@dataclass
class RuleValidation:
"""Base object for rule validation configuration."""
bypass: list[str] | None = None
test_only: list[str] | None = None
def __post_init__(self) -> None:
if self.bypass and self.test_only:
raise ValueError("Cannot use both test_only and bypass")
@dataclass
class ConfigFile:
"""Base object for configuration files."""
@dataclass
class FilePaths:
packages_file: str
stack_schema_map_file: str
deprecated_rules_file: str | None = None
version_lock_file: str | None = None
@dataclass
class TestConfigPath:
config: str
files: FilePaths
rule_dir: list[str]
testing: TestConfigPath | None = None
@classmethod
def from_dict(cls, obj: dict[str, Any]) -> "ConfigFile":
files_data = obj.get("files", {})
files = cls.FilePaths(
deprecated_rules_file=files_data.get("deprecated_rules"),
packages_file=files_data["packages"],
stack_schema_map_file=files_data["stack_schema_map"],
version_lock_file=files_data.get("version_lock"),
)
rule_dir = obj["rule_dirs"]
testing_data = obj.get("testing")
testing = cls.TestConfigPath(config=testing_data["config"]) if testing_data else None
return cls(files=files, rule_dir=rule_dir, testing=testing)
@dataclass
class TestConfig:
    """Detection rules test config file.

    Holds the unit test and rule validation selections of a test config and resolves them into the test names
    to run or skip, and into per-rule skip decisions.
    """

    test_file: Path | None = None
    unit_tests: UnitTest | None = None
    rule_validation: RuleValidation | None = None

    @classmethod
    def from_dict(
        cls,
        test_file: Path | None = None,
        unit_tests: dict[str, Any] | None = None,
        rule_validation: dict[str, Any] | None = None,
    ) -> "TestConfig":
        """Build a TestConfig from the raw sections of a test config; missing sections use the defaults."""
        return cls(
            test_file=test_file or None,
            unit_tests=UnitTest(**(unit_tests or {})),
            rule_validation=RuleValidation(**(rule_validation or {})),
        )

    @cached_property
    def all_tests(self) -> list[str]:
        """Get the list of all test names."""
        return discover_tests()

    def tests_by_patterns(self, *patterns: str) -> list[str]:
        """Get the sorted, de-duplicated list of test names matching any of the fnmatch patterns."""
        matched: set[str] = set()
        for pattern in patterns:
            matched.update(fnmatch.filter(self.all_tests, pattern))
        return sorted(matched)

    @staticmethod
    def parse_out_patterns(names: list[str]) -> tuple[list[str], list[str]]:
        """Split config entries into (patterns, literal test names).

        An entry is a pattern only when it starts with `pattern:` and contains a `*`; the prefix is stripped
        from returned patterns. Every other entry is returned unchanged as a literal test name.
        """
        prefix = "pattern:"
        patterns: list[str] = []
        tests: list[str] = []
        for name in names:
            if name.startswith(prefix) and "*" in name:
                patterns.append(name[len(prefix) :])
            else:
                tests.append(name)
        return patterns, tests

    @staticmethod
    def format_tests(tests: list[str]) -> list[str]:
        """Format dotted unit test names as `path/to/module.py::Class::method` for direct calling.

        Raises:
            ValueError: when a name does not have at least `module.Class.method` components.
        """
        expected_parts = 3
        formatted: list[str] = []
        for test in tests:
            parts = test.rsplit(".", maxsplit=2)
            if len(parts) != expected_parts:
                # Same exception type as the bare unpacking failure, but it names the offending entry.
                raise ValueError(f"Expected a test name in `module.Class.method` form, got: {test!r}")
            path, clazz, method = parts
            path = f"{path.replace('.', os.path.sep)}.py"
            formatted.append(f"{path}::{clazz}::{method}")
        return formatted

    def get_test_names(self, formatted: bool = False) -> tuple[list[str], list[str]]:
        """Get the test names to run and the test names to skip, as `(tests, skipped)`.

        `test_only` (when not None, even if empty) restricts the run to the listed or matched tests, `bypass`
        removes the listed or matched tests from the run, and with neither set every test runs.

        Raises:
            ValueError: when no unit test section is set, or a literal test name in the config is not among
                the known tests.
        """
        if not self.unit_tests:
            raise ValueError("No unit tests defined")

        patterns_t, tests_t = self.parse_out_patterns(self.unit_tests.test_only or [])
        patterns_b, tests_b = self.parse_out_patterns(self.unit_tests.bypass or [])
        defined_tests = tests_t + tests_b
        patterns = patterns_t + patterns_b

        all_tests = self.all_tests
        unknowns = sorted(set(defined_tests) - set(all_tests))
        if unknowns:
            raise ValueError(f"Unrecognized test names in config ({self.test_file}): {unknowns}")

        # A set keeps the membership checks below linear in the number of tests.
        selected = set(defined_tests).union(self.tests_by_patterns(*patterns))

        if self.unit_tests.test_only is not None:
            tests = sorted(selected)
            skipped = [t for t in all_tests if t not in selected]
        elif self.unit_tests.bypass:
            tests = [t for t in all_tests if t not in selected]
            skipped = [t for t in all_tests if t in selected]
        else:
            # Return a copy so callers cannot mutate the cached `all_tests` list.
            tests = list(all_tests)
            skipped = []

        if formatted:
            return self.format_tests(tests), self.format_tests(skipped)
        return tests, skipped

    def check_skip_by_rule_id(self, rule_id: str) -> bool:
        """Check if a rule_id should be skipped.

        Raises:
            ValueError: when no rule validation section is set.
        """
        if not self.rule_validation:
            raise ValueError("No rule validation specified")

        bypass = self.rule_validation.bypass
        test_only = self.rule_validation.test_only

        # neither bypass nor test_only are defined, so no rules are skipped
        if not (bypass or test_only):
            return False
        # if defined in bypass or not defined in test_only, then skip
        return bool((bypass and rule_id in bypass) or (test_only and rule_id not in test_only))
@dataclass
class RulesConfig:
    """Detection rules config file.

    Fully resolved configuration: file paths, the loaded contents of those files, directories, and the
    validation toggles. Construction fails unless `packages` contains `package` with a `name` entry.
    """

    # required entries: each *_file path is paired with the loaded contents of that file
    deprecated_rules_file: Path
    deprecated_rules: dict[str, dict[str, Any]]
    packages_file: Path
    packages: dict[str, dict[str, Any]]
    rule_dirs: list[Path]
    stack_schema_map_file: Path
    stack_schema_map: dict[str, dict[str, Any]]
    test_config: TestConfig
    version_lock_file: Path
    version_lock: dict[str, dict[str, Any]]

    # optional directories, files and behaviour flags
    action_dir: Path | None = None
    action_connector_dir: Path | None = None
    auto_gen_schema_file: Path | None = None
    bbr_rules_dirs: list[Path] = field(default_factory=list)  # type: ignore[reportUnknownVariableType]
    bypass_version_lock: bool = False
    exception_dir: Path | None = None
    normalize_kql_keywords: bool = True
    bypass_optional_elastic_validation: bool = False
    bypass_note_validation_and_parse: bool = False
    bypass_bbr_lookback_validation: bool = False
    bypass_tags_validation: bool = False
    bypass_timeline_template_validation: bool = False
    bypass_esql_keep_validation: bool = False
    bypass_esql_metadata_validation: bool = False
    no_tactic_filename: bool = False

    def __post_init__(self) -> None:
        """Check that the loaded packages data has a `package` section with a `name` entry."""
        if "package" in self.packages:
            if "name" in self.packages["package"]:
                return
            missing = "name"
        else:
            missing = "package"
        raise ValueError(f"Missing the `{missing}` field defined in packages.yaml.")
@cached
def parse_rules_config(path: Path | None = None) -> RulesConfig:  # noqa: PLR0912, PLR0915
    """Parse the _config.yaml file for default or custom rules.

    The config file is located in this order: the explicit `path` argument, then `_config.yaml` inside the
    directory named by `CUSTOM_RULES_DIR`, then the default `_config.yaml` resolved through `get_etc_path` /
    `load_etc_dump`. Relative paths inside the config are resolved against the directory holding the config file.

    Side effects: creates the configured `auto_gen_schema_file` (and its parent directories) when it does not
    exist, and turns on the optional validation bypasses through `set_all_validation_bypass` or by setting the
    matching environment variables.

    Raises:
        ValueError: if an explicit `path` is given but does not exist.
        FileNotFoundError: if `CUSTOM_RULES_DIR` is set but holds no `_config.yaml`.
        SystemExit: if the config is missing required keys, did not load as a usable mapping, a bypass flag is
            not a boolean, or the resulting `RulesConfig` cannot be built.
    """
    if path:
        if not path.exists():
            raise ValueError(f"rules config file does not exist: {path}")
        loaded = yaml.safe_load(path.read_text())
    elif CUSTOM_RULES_DIR:
        path = Path(CUSTOM_RULES_DIR).expanduser() / "_config.yaml"
        if not path.exists():
            raise FileNotFoundError(
                """
                Configuration file not found.
                Please create a configuration file. You can use the 'custom-rules setup-config' command
                and update the 'CUSTOM_RULES_DIR' environment variable as needed.
                """
            )
        loaded = yaml.safe_load(path.read_text())
    else:
        path = Path(get_etc_path(["_config.yaml"]))
        loaded = load_etc_dump(["_config.yaml"])

    # validate the overall shape first so missing keys are reported with a readable message
    try:
        _ = ConfigFile.from_dict(loaded)
    except KeyError as e:
        raise SystemExit(f"Missing key `{e!s}` in _config.yaml file.") from e
    except (AttributeError, TypeError) as e:
        raise SystemExit(f"No data properly loaded from {path}") from e
    except ValueError as e:
        raise SystemExit(e) from e

    base_dir = path.resolve().parent

    # testing
    # precedence to the environment variable
    # environment variable is absolute path and config file is relative to the _config.yaml file
    test_config_ev = os.getenv("DETECTION_RULES_TEST_CONFIG", None)
    if test_config_ev:
        test_config_path = Path(test_config_ev)
    else:
        # `testing:` may be present with a null value, so fall back to an empty mapping before `.get`
        test_config_file = (loaded.get("testing") or {}).get("config")
        test_config_path = base_dir.joinpath(test_config_file) if test_config_file else None

    if test_config_path:
        # an empty test config file loads as None; treat it as "no settings" instead of crashing below
        test_config_data = yaml.safe_load(test_config_path.read_text()) or {}

        # overwrite None with empty list to allow implicit exemption of all tests with `test_only` defined to None in
        # test config
        if "unit_tests" in test_config_data and test_config_data["unit_tests"] is not None:
            test_config_data["unit_tests"] = {k: v or [] for k, v in test_config_data["unit_tests"].items()}
        test_config = TestConfig.from_dict(test_file=test_config_path, **test_config_data)
    else:
        test_config = TestConfig.from_dict()

    # files
    # paths are relative
    files = {f"{k}_file": base_dir.joinpath(v) for k, v in loaded["files"].items()}
    contents = {k: load_dump(str(base_dir.joinpath(v).resolve())) for k, v in loaded["files"].items()}

    contents.update(**files)

    # directories
    # paths are relative
    # Handled in a single pass; this already covers exception_dir, action_dir and action_connector_dir.
    # Null entries are skipped so an unset optional directory keeps its default rather than failing in joinpath.
    directories = loaded.get("directories")
    if directories:
        contents.update({k: base_dir.joinpath(v).resolve() for k, v in directories.items() if v is not None})

    # rule_dirs
    # paths are relative
    contents["rule_dirs"] = [base_dir.joinpath(d).resolve() for d in loaded.get("rule_dirs")]

    # version strategy
    contents["bypass_version_lock"] = loaded.get("bypass_version_lock", False)

    # bbr_rules_dirs
    # paths are relative
    if loaded.get("bbr_rules_dirs"):
        contents["bbr_rules_dirs"] = [base_dir.joinpath(d).resolve() for d in loaded.get("bbr_rules_dirs", [])]

    # kql keyword normalization
    contents["normalize_kql_keywords"] = loaded.get("normalize_kql_keywords", True)

    if loaded.get("auto_gen_schema_file"):
        contents["auto_gen_schema_file"] = base_dir.joinpath(loaded["auto_gen_schema_file"])

        # Check if the file exists
        if not contents["auto_gen_schema_file"].exists():
            # If the file doesn't exist, create the necessary directories and file
            contents["auto_gen_schema_file"].parent.mkdir(parents=True, exist_ok=True)
            _ = contents["auto_gen_schema_file"].write_text("{}")

    # bypass_optional_elastic_validation
    contents["bypass_optional_elastic_validation"] = loaded.get("bypass_optional_elastic_validation", False)
    if contents["bypass_optional_elastic_validation"]:
        # the global switch turns every individual bypass on
        set_all_validation_bypass(True)
        for yaml_key in OPTIONAL_ELASTIC_VALIDATION_BYPASS_ENV:
            contents[yaml_key] = True
    else:
        # otherwise each bypass is opt-in; an enabled one is also set as its environment variable
        for yaml_key, env_var in OPTIONAL_ELASTIC_VALIDATION_BYPASS_ENV.items():
            if yaml_key in loaded:
                val = loaded[yaml_key]
                if not isinstance(val, bool):
                    raise SystemExit(
                        f"`{yaml_key}` in _config.yaml must be a boolean (true/false), not {type(val).__name__}"
                    )
            else:
                val = False
            contents[yaml_key] = val
            if val:
                os.environ[env_var] = str(True)

    # no_tactic_filename
    contents["no_tactic_filename"] = loaded.get("no_tactic_filename", False)

    # return the config
    try:
        rules_config = RulesConfig(test_config=test_config, **contents)  # type: ignore[reportArgumentType]
    except (ValueError, TypeError) as e:
        raise SystemExit(f"Error parsing packages.yaml: {e!s}") from e

    return rules_config
@cached
def load_current_package_version() -> str:
    """Return the current package version: the `name` entry of the `package` section in the rules config."""
    package_section = parse_rules_config().packages["package"]
    return package_section["name"]