diff --git a/detection_rules/rule.py b/detection_rules/rule.py index f24668686ff..6bc2f5b7098 100644 --- a/detection_rules/rule.py +++ b/detection_rules/rule.py @@ -18,7 +18,10 @@ from marshmallow import ValidationError, validates_schema import kql +from . import beats +from . import ecs from . import utils +from .misc import load_current_package_version from .mixins import MarshmallowDataclassMixin, StackCompatMixin from .rule_formatter import toml_write, nested_normalize from .schemas import SCHEMA_DIR, definitions, downgrade, get_stack_schemas, get_min_supported_stack_version @@ -26,6 +29,7 @@ from .semver import Version from .utils import cached +BUILD_FIELD_VERSIONS = {"required_fields": (Version('8.3'), None)} _META_SCHEMA_REQ_DEFAULTS = {} MIN_FLEET_PACKAGE_VERSION = '7.13.0' @@ -149,6 +153,12 @@ class FlatThreatMapping(MarshmallowDataclassMixin): @dataclass(frozen=True) class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin): + @dataclass + class RequiredFields: + name: definitions.NonEmptyStr + type: definitions.NonEmptyStr + ecs: bool + actions: Optional[list] author: List[str] building_block_type: Optional[str] @@ -171,7 +181,7 @@ class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin): # output_index: Optional[str] references: Optional[List[str]] related_integrations: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3"))) - required_fields: Optional[List[str]] = field(metadata=dict(metadata=dict(min_compat="8.3"))) + required_fields: Optional[List[RequiredFields]] = field(metadata=dict(metadata=dict(min_compat="8.3"))) risk_score: definitions.RiskScore risk_score_mapping: Optional[List[RiskScoreMapping]] rule_id: definitions.UUIDString @@ -220,9 +230,45 @@ class QueryValidator: def ast(self) -> Any: raise NotImplementedError + @property + def unique_fields(self) -> Any: + raise NotImplementedError + def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None: raise NotImplementedError() + @cached + def get_required_fields(self, index: str) -> List[dict]: + """Retrieves fields needed for the query along with type information from the schema.""" + current_version = Version(Version(load_current_package_version()) + (0,)) + ecs_version = get_stack_schemas()[str(current_version)]['ecs'] + beats_version = get_stack_schemas()[str(current_version)]['beats'] + ecs_schema = ecs.get_schema(ecs_version) + + beat_types, beat_schema, schema = self.get_beats_schema(index or [], beats_version, ecs_version) + + required = [] + unique_fields = self.unique_fields or [] + + for fld in unique_fields: + field_type = ecs_schema.get(fld, {}).get('type') + is_ecs = field_type is not None + + if beat_schema and not is_ecs: + field_type = beat_schema.get(fld, {}).get('type') + + required.append(dict(name=fld, type=field_type or 'unknown', ecs=is_ecs)) + + return sorted(required, key=lambda f: f['name']) + + @cached + def get_beats_schema(self, index: list, beats_version: str, ecs_version: str) -> (list, dict, dict): + """Get an assembled beats schema.""" + beat_types = beats.parse_beats_from_index(index) + beat_schema = beats.get_schema_from_kql(self.ast, beat_types, version=beats_version) if beat_types else None + schema = ecs.get_kql_schema(version=ecs_version, indexes=index, beat_schema=beat_schema) + return beat_types, beat_schema, schema + @dataclass(frozen=True) class QueryRuleData(BaseRuleData): @@ -251,6 +297,18 @@ def ast(self): if validator is not None: return validator.ast + @cached_property + def unique_fields(self): + validator = self.validator + if validator is not None: + return validator.unique_fields + + @cached + def get_required_fields(self, index: str) -> List[dict]: + validator = self.validator + if validator is not None: + return validator.get_required_fields(index or []) + @dataclass(frozen=True) class MachineLearningRuleData(BaseRuleData): @@ -438,8 +496,7 @@ def autobumped_version(self) -> Optional[int]: return version + 1 if self.is_dirty else version - @staticmethod - def _post_dict_transform(obj: dict) -> dict: + def _post_dict_transform(self, obj: dict) -> dict: """Transform the converted API in place before sending to Kibana.""" # cleanup the whitespace in the rule @@ -515,6 +572,59 @@ def name(self) -> str: def type(self) -> str: return self.data.type + def _post_dict_transform(self, obj: dict) -> dict: + """Transform the converted API in place before sending to Kibana.""" + super()._post_dict_transform(obj) + + self.add_related_integrations(obj) + self.add_required_fields(obj) + self.add_setup(obj) + + # validate new fields against the schema + rule_type = obj['type'] + subclass = self.get_data_subclass(rule_type) + subclass.from_dict(obj) + + return obj + + def add_related_integrations(self, obj: dict) -> None: + """Add restricted field related_integrations to the obj.""" + # field_name = "related_integrations" + ... + + def add_required_fields(self, obj: dict) -> None: + """Add restricted field required_fields to the obj, derived from the query AST.""" + if isinstance(self.data, QueryRuleData) and self.data.language != 'lucene': + index = obj.get('index') or [] + required_fields = self.data.get_required_fields(index) + else: + required_fields = [] + + field_name = "required_fields" + if self.check_restricted_field_version(field_name=field_name): + obj.setdefault(field_name, required_fields) + + def add_setup(self, obj: dict) -> None: + """Add restricted field setup to the obj.""" + # field_name = "setup" + ... + + def check_explicit_restricted_field_version(self, field_name: str) -> bool: + """Explicitly check restricted fields against global min and max versions.""" + min_stack, max_stack = BUILD_FIELD_VERSIONS[field_name] + return self.compare_field_versions(min_stack, max_stack) + + def check_restricted_field_version(self, field_name: str) -> bool: + """Check restricted fields against schema min and max versions.""" + min_stack, max_stack = self.data.get_restricted_fields.get(field_name) + return self.compare_field_versions(min_stack, max_stack) + + def compare_field_versions(self, min_stack: Version, max_stack: Version) -> bool: + """Check current rule version is witihin min and max stack versions.""" + current_version = Version(load_current_package_version()) + max_stack = max_stack or current_version + return Version(min_stack) <= current_version >= Version(max_stack) + @validates_schema def validate_query(self, value: dict, **kwargs): """Validate queries by calling into the validator for the relevant method.""" @@ -540,11 +650,11 @@ def flattened_dict(self) -> dict: def to_api_format(self, include_version=True) -> dict: """Convert the TOML rule to the API format.""" converted = self.data.to_dict() + converted = self._post_dict_transform(converted) + if include_version: converted["version"] = self.autobumped_version - converted = self._post_dict_transform(converted) - return converted def check_restricted_fields_compatibility(self) -> Dict[str, dict]: diff --git a/detection_rules/rule_validators.py b/detection_rules/rule_validators.py index aaed3de5efb..dca70b6df2e 100644 --- a/detection_rules/rule_validators.py +++ b/detection_rules/rule_validators.py @@ -10,7 +10,7 @@ import eql import kql -from . import ecs, beats +from . import ecs from .rule import QueryValidator, QueryRuleData, RuleMeta @@ -21,7 +21,7 @@ class KQLValidator(QueryValidator): def ast(self) -> kql.ast.Expression: return kql.parse(self.query) - @property + @cached_property def unique_fields(self) -> List[str]: return list(set(str(f) for f in self.ast if isinstance(f, kql.ast.Field))) @@ -29,9 +29,7 @@ def to_eql(self) -> eql.ast.Expression: return kql.to_eql(self.query) def validate(self, data: QueryRuleData, meta: RuleMeta) -> None: - """Static method to validate the query, called from the parent which contains [metadata] information.""" - ast = self.ast - + """Validate the query, called from the parent which contains [metadata] information.""" if meta.query_schema_validation is False or meta.maturity == "deprecated": # syntax only, which is done via self.ast return @@ -41,9 +39,7 @@ def validate(self, data: QueryRuleData, meta: RuleMeta) -> None: ecs_version = mapping['ecs'] err_trailer = f'stack: {stack_version}, beats: {beats_version}, ecs: {ecs_version}' - beat_types = beats.parse_beats_from_index(data.index) - beat_schema = beats.get_schema_from_kql(ast, beat_types, version=beats_version) if beat_types else None - schema = ecs.get_kql_schema(version=ecs_version, indexes=data.index or [], beat_schema=beat_schema) + beat_types, beat_schema, schema = self.get_beats_schema(data.index or [], beats_version, ecs_version) try: kql.parse(self.query, schema=schema) @@ -73,14 +69,12 @@ def text_fields(self, eql_schema: ecs.KqlSchema2Eql) -> List[str]: return [f for f in self.unique_fields if elasticsearch_type_family(eql_schema.kql_schema.get(f)) == 'text'] - @property + @cached_property def unique_fields(self) -> List[str]: return list(set(str(f) for f in self.ast if isinstance(f, eql.ast.Field))) def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None: """Validate an EQL query while checking TOMLRule.""" - ast = self.ast - if meta.query_schema_validation is False or meta.maturity == "deprecated": # syntax only, which is done via self.ast return @@ -90,9 +84,7 @@ def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None: ecs_version = mapping['ecs'] err_trailer = f'stack: {stack_version}, beats: {beats_version}, ecs: {ecs_version}' - beat_types = beats.parse_beats_from_index(data.index) - beat_schema = beats.get_schema_from_kql(ast, beat_types, version=beats_version) if beat_types else None - schema = ecs.get_kql_schema(version=ecs_version, indexes=data.index or [], beat_schema=beat_schema) + beat_types, beat_schema, schema = self.get_beats_schema(data.index or [], beats_version, ecs_version) eql_schema = ecs.KqlSchema2Eql(schema) try: