Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,3 @@ git_describe_command = ["sh", "-c", "tag=$(git tag | grep -v '-' | sort | tail -

[tool.setuptools.packages.find]
include = ["ssg*"]

[[tool.mypy.overrides]]
module = "pkg_resources"
ignore_missing_imports = true
78 changes: 62 additions & 16 deletions ssg/requirement_specs.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,64 @@
"""
Common functions for processing Requirements Specs in SSG
"""

import pkg_resources
import re
from typing import Tuple, List

from ssg import utils

# Monkey-patch pkg_resources.safe_name function to keep underscores intact
# Setuptools recognize the issue: https://github.com/pypa/setuptools/issues/2522
pkg_resources.safe_name = lambda name: re.sub('[^A-Za-z0-9_.]+', '-', name)
# Monkey-patch pkg_resources.safe_extras function to keep dashes intact
# Setuptools recognize the issue: https://github.com/pypa/setuptools/pull/732
pkg_resources.safe_extra = lambda extra: re.sub('[^A-Za-z0-9.-]+', '_', extra).lower()

class RequirementParser:
"""
A simple parser for package requirements with version specifiers.
Handles formats like: package[extra]>=1.0,<2.0
"""

def __init__(self, target_v: str):
self.target = target_v
# First, extract package name and extras
base_match = re.match(
r'^(?P<name>[a-zA-Z0-9\-_.]+)(?:\[(?P<extra>[a-zA-Z0-9\-_]+)])?(?P<specs>.*)$',
target_v
)

if not base_match:
raise ValueError(f"Invalid requirement format: {target_v}")

self.name = base_match.group('name')
self.extra = base_match.group('extra')
specs_str = base_match.group('specs')

# Parse comma-separated version specifiers
self.specs_list = []
if specs_str and specs_str.strip():
for spec in specs_str.split(','):
spec = spec.strip()
if spec:
spec_match = re.match(r'^(?P<op>[><!~=]+)\s*(?P<ver>.+)$', spec)
if spec_match:
self.specs_list.append((spec_match.group('op'), spec_match.group('ver')))
else:
raise ValueError(f"Invalid version specifier: {spec}")

def __str__(self):
return self.target

@property
def specs(self) -> List[Tuple[str, str]]:
return self.specs_list

@property
def project_name(self) -> str:
return self.name

@property
def extras(self) -> List[str]:
if self.extra:
return [self.extra.lower()]
return []


def _parse_version_into_evr(version):
def parse_version_into_evr(version):
"""
Parses a version string into its epoch, version, and release components.

Expand Down Expand Up @@ -53,7 +96,7 @@ def _spec_to_version_specifier(spec):
VersionSpecifier: An object representing the version specifier.
"""
op, ver = spec
evr = _parse_version_into_evr(ver)
evr = parse_version_into_evr(ver)
return utils.VersionSpecifier(op, evr)


Expand All @@ -62,17 +105,20 @@ class Requirement:
A class to represent a package requirement with version specifications.

Attributes:
_req (pkg_resources.Requirement): The parsed requirement object.
_req (RequirementParser): The parsed requirement object.
_specs (utils.VersionSpecifierSet): The set of version specifiers for the requirement.
"""
def __init__(self, obj):
self._req = pkg_resources.Requirement.parse(obj)
def __init__(self, obj: str):
self._req = RequirementParser(obj)
self._specs = utils.VersionSpecifierSet(
[_spec_to_version_specifier(spec) for spec in self._req.specs]
)

def __contains__(self, item):
return item in self._req
"""Check if a version string satisfies the requirement specs."""
if not self.has_version_specs():
return False
return item in self._specs

def __str__(self):
return str(self._req)
Expand Down Expand Up @@ -131,7 +177,7 @@ def is_parametrized(name):
bool: True if the package requirement is parametrized (includes extras),
False otherwise.
"""
return bool(pkg_resources.Requirement.parse(name).extras)
return bool(RequirementParser(name).extras)

@staticmethod
def get_base_for_parametrized(name):
Expand All @@ -144,4 +190,4 @@ def get_base_for_parametrized(name):
Returns:
str: The base project name of the package.
"""
return pkg_resources.Requirement.parse(name).project_name
return RequirementParser(name).project_name
2 changes: 1 addition & 1 deletion ssg/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def load_module(module_name: str, module_path: str):
ValueError: If the module cannot be loaded due to an invalid spec or loader.
"""
# https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
import importlib
import importlib.util
spec = importlib.util.spec_from_file_location(module_name, module_path) # type: ignore
if not spec:
raise ValueError("Error loading '%s' module" % module_path)
Expand Down
67 changes: 66 additions & 1 deletion ssg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from __future__ import print_function

import multiprocessing
import errno
import os
import re
from collections import namedtuple
import hashlib
from typing import Dict

from .constants import (FULL_NAME_TO_PRODUCT_MAPPING,
MAKEFILE_ID_TO_PRODUCT_MAP,
Expand Down Expand Up @@ -48,6 +48,25 @@ def __init__(self, s=()):
' invalid object: {0}'.format(repr(el)))
super(VersionSpecifierSet, self).__init__(s)

def __contains__(self, version_str):
"""Check if a version string satisfies all version specifiers in the set."""
# If checking for VersionSpecifier object membership, use parent implementation
if isinstance(version_str, VersionSpecifier):
return super(VersionSpecifierSet, self).__contains__(version_str)

# Otherwise, check if the version string satisfies all specs
from ssg import requirement_specs
try:
evr = requirement_specs.parse_version_into_evr(version_str)
except ValueError:
return False

# All specs must be satisfied
for spec in self:
if not spec.matches(evr):
return False
return True

@property
def title(self):
return ' and '.join([ver_spec.title for ver_spec in sorted(self)])
Expand Down Expand Up @@ -117,6 +136,52 @@ def cpe_id(self):
def oval_id(self):
return '{0}_{1}'.format(escape_comparison(self.op), escape_id(self.ver))

def matches(self, evr: Dict):
"""
Check if a given EVR dictionary satisfies this version specifier.

Args:
evr (dict): A dictionary containing 'epoch', 'version', and 'release' keys.

Returns:
bool: True if the EVR satisfies this version specifier, False otherwise.
"""
# Compare EVR components for proper version comparison
def evr_to_tuple(e):
"""Convert EVR dict to comparable tuple (epoch, version_parts, release)."""
epoch = int(e['epoch']) if e['epoch'] is not None else 0
# Split version into numeric parts for comparison
version_parts = [int(p) if p.isdigit() else p for p in e['version'].split('.')]
release = int(e['release']) if e['release'] is not None else 0
return (epoch, version_parts, release)

spec_epoch, spec_ver, spec_rel = evr_to_tuple(self._evr_ver_dict)
input_epoch, input_ver, input_rel = evr_to_tuple(evr)

# Normalize version lists to same length for comparison
max_len = max(len(spec_ver), len(input_ver))
spec_ver_norm = spec_ver + [0] * (max_len - len(spec_ver))
input_ver_norm = input_ver + [0] * (max_len - len(input_ver))

spec_tuple = (spec_epoch, spec_ver_norm, spec_rel)
input_tuple = (input_epoch, input_ver_norm, input_rel)

# Perform comparison based on operator
if self.op == '==':
return input_tuple == spec_tuple
elif self.op == '!=':
return input_tuple != spec_tuple
elif self.op == '>':
return input_tuple > spec_tuple
elif self.op == '<':
return input_tuple < spec_tuple
elif self.op == '>=':
return input_tuple >= spec_tuple
elif self.op == '<=':
return input_tuple <= spec_tuple
else:
return False

@staticmethod
def evr_dict_to_str(evr, fully_formed_evr_string=False):
"""
Expand Down
35 changes: 25 additions & 10 deletions tests/unit/ssg-module/test_requirement_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,47 @@


def test_parse_version_into_evr():
v = requirement_specs._parse_version_into_evr('1.22.333-4444')
v = requirement_specs.parse_version_into_evr('1.22.333-4444')
assert v == {'epoch': None, 'version': '1.22.333', 'release': '4444'}

v = requirement_specs._parse_version_into_evr('0')
v = requirement_specs.parse_version_into_evr('0')
assert v == {'epoch': None, 'version': '0', 'release': None}

v = requirement_specs._parse_version_into_evr('0-1')
v = requirement_specs.parse_version_into_evr('0-1')
assert v == {'epoch': None, 'version': '0', 'release': '1'}

# Empty version is not the same as version '0'.
with pytest.raises(ValueError):
v = requirement_specs._parse_version_into_evr('')
v = requirement_specs.parse_version_into_evr('')

# we do not support letters anywhere for now

with pytest.raises(ValueError):
v = requirement_specs._parse_version_into_evr('1.0.0-r2')
v = requirement_specs.parse_version_into_evr('1.0.0-r2')
with pytest.raises(ValueError):
v = requirement_specs._parse_version_into_evr('b1')
v = requirement_specs.parse_version_into_evr('b1')

# some more tests to ensure that the regex is correct
with pytest.raises(ValueError):
v = requirement_specs._parse_version_into_evr('0:')
v = requirement_specs.parse_version_into_evr('0:')
with pytest.raises(ValueError):
v = requirement_specs._parse_version_into_evr('-1')
v = requirement_specs.parse_version_into_evr('-1')
with pytest.raises(ValueError):
v = requirement_specs._parse_version_into_evr(':')
v = requirement_specs.parse_version_into_evr(':')
with pytest.raises(ValueError):
v = requirement_specs._parse_version_into_evr('-')
v = requirement_specs.parse_version_into_evr('-')

def test_requirement_parse():
req = requirement_specs.RequirementParser("package[NetworkManager]>=8.7")
assert req.project_name == 'package'
assert req.extras == ['networkmanager']
assert req.specs == [('>=', '8.7')]

req = requirement_specs.RequirementParser('linux_os')
assert req.project_name == 'linux_os'
assert req.specs == []

# Test comma-separated version specs
req = requirement_specs.RequirementParser('oranges>=1.0,<3.0')
assert req.project_name == 'oranges'
assert req.specs == [('>=', '1.0'), ('<', '3.0')]
Loading