diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7332f920f..989f09e70 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,6 +18,15 @@ repos: hooks: - id: isort + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.6.1 + hooks: + - id: mypy + args: [ + "--config", + "pyproject.toml" + ] + - repo: local hooks: - id: pylint diff --git a/benchmarks/track/bulk_params_test.py b/benchmarks/track/bulk_params_test.py index 6aab90729..8eb6b80dc 100644 --- a/benchmarks/track/bulk_params_test.py +++ b/benchmarks/track/bulk_params_test.py @@ -36,7 +36,7 @@ def seek(self, offset): pass def read(self): - return "\n".join(self.contents) + return "\n".join(self.contents) # type: ignore[arg-type] # TODO remove this ignore when introducing type hints def readline(self): return self.contents diff --git a/esrally/config.py b/esrally/config.py index b8dd24437..3c4513a5b 100644 --- a/esrally/config.py +++ b/esrally/config.py @@ -22,7 +22,7 @@ from enum import Enum from string import Template -from esrally import PROGRAM_NAME, exceptions, paths +from esrally import PROGRAM_NAME, exceptions, paths, types from esrally.utils import io @@ -50,7 +50,7 @@ def present(self): """ return os.path.isfile(self.location) - def load(self): + def load(self) -> configparser.ConfigParser: config = configparser.ConfigParser() config.read(self.location, encoding="utf-8") return config @@ -66,7 +66,7 @@ def store_default_config(self, template_path=None): contents = src.read() target.write(Template(contents).substitute(CONFIG_DIR=self.config_dir)) - def store(self, config): + def store(self, config: configparser.ConfigParser): io.ensure_dir(self.config_dir) with open(self.location, "w", encoding="utf-8") as configfile: config.write(configfile) @@ -89,7 +89,7 @@ def location(self): return os.path.join(self.config_dir, f"rally{config_name_suffix}.ini") -def auto_load_local_config(base_config, additional_sections=None, config_file_class=ConfigFile, 
**kwargs): +def auto_load_local_config(base_config, additional_sections=None, config_file_class=ConfigFile, **kwargs) -> types.Config: """ Loads a node-local configuration based on a ``base_config``. If an appropriate node-local configuration file is present, it will be used (and potentially upgraded to the newest config version). Otherwise, a new one will be created and as many settings as possible @@ -138,7 +138,7 @@ def __init__(self, config_name=None, config_file_class=ConfigFile, **kwargs): self._opts = {} self._clear_config() - def add(self, scope, section, key, value): + def add(self, scope, section: types.Section, key: types.Key, value): """ Adds or overrides a new configuration property. @@ -149,7 +149,7 @@ def add(self, scope, section, key, value): """ self._opts[self._k(scope, section, key)] = value - def add_all(self, source, section): + def add_all(self, source, section: types.Section): """ Adds all config items within the given `section` from the `source` config object. @@ -162,7 +162,7 @@ def add_all(self, source, section): if source_section == section: self.add(scope, source_section, key, v) - def opts(self, section, key, default_value=None, mandatory=True): + def opts(self, section: types.Section, key: types.Key, default_value=None, mandatory=True): """ Resolves a configuration property. @@ -182,7 +182,7 @@ def opts(self, section, key, default_value=None, mandatory=True): else: raise exceptions.ConfigError(f"No value for mandatory configuration: section='{section}', key='{key}'") - def all_opts(self, section): + def all_opts(self, section: types.Section): """ Finds all options in a section and returns them in a dict. @@ -200,7 +200,7 @@ def all_opts(self, section): scopes_per_key[key] = scope return opts_in_section - def exists(self, section, key): + def exists(self, section: types.Section, key: types.Key): """ :param section: The configuration section. :param key: The configuration key. 
@@ -261,7 +261,7 @@ def _stored_config_version(self): return int(self.opts("meta", "config.version", default_value=0, mandatory=False)) # recursively find the most narrow scope for a key - def _resolve_scope(self, section, key, start_from=Scope.invocation): + def _resolve_scope(self, section: types.Section, key: types.Key, start_from=Scope.invocation): if self._k(start_from, section, key) in self._opts: return start_from elif start_from == Scope.application: @@ -270,7 +270,7 @@ def _resolve_scope(self, section, key, start_from=Scope.invocation): # continue search in the enclosing scope return self._resolve_scope(section, key, Scope(start_from.value - 1)) - def _k(self, scope, section, key): + def _k(self, scope, section: types.Section, key: types.Key): if scope is None or scope == Scope.application: return Scope.application, section, key else: diff --git a/esrally/driver/driver.py b/esrally/driver/driver.py index 4c6f6d562..661ba5835 100644 --- a/esrally/driver/driver.py +++ b/esrally/driver/driver.py @@ -30,7 +30,7 @@ from dataclasses import dataclass from enum import Enum from io import BytesIO -from typing import Callable +from typing import Callable, Optional import thespian.actors @@ -44,6 +44,7 @@ paths, telemetry, track, + types, ) from esrally.client import delete_api_keys from esrally.driver import runner, scheduler @@ -61,7 +62,7 @@ class PrepareBenchmark: Initiates preparation steps for a benchmark. The benchmark should only be started after StartBenchmark is sent. """ - def __init__(self, config, track): + def __init__(self, config: types.Config, track): """ :param config: Rally internal configuration object. :param track: The track to use. 
@@ -79,7 +80,7 @@ class Bootstrap: Prompts loading of track code on new actors """ - def __init__(self, cfg, worker_id=None): + def __init__(self, cfg: types.Config, worker_id=None): self.config = cfg self.worker_id = worker_id @@ -102,13 +103,13 @@ class TrackPrepared: class StartTaskLoop: - def __init__(self, track_name, cfg): + def __init__(self, track_name, cfg: types.Config): self.track_name = track_name self.cfg = cfg class DoTask: - def __init__(self, task, cfg): + def __init__(self, task, cfg: types.Config): self.task = task self.cfg = cfg @@ -143,7 +144,7 @@ class StartWorker: Starts a worker. """ - def __init__(self, worker_id, config, track, client_allocations, client_contexts): + def __init__(self, worker_id, config: types.Config, track, client_allocations, client_contexts): """ :param worker_id: Unique (numeric) id of the worker. :param config: Rally internal configuration object. @@ -306,12 +307,12 @@ def receiveMsg_WakeupMessage(self, msg, sender): self.driver.update_progress_message() self.wakeupAfter(datetime.timedelta(seconds=DriverActor.WAKEUP_INTERVAL_SECONDS)) - def create_client(self, host, cfg, worker_id): + def create_client(self, host, cfg: types.Config, worker_id): worker = self.createActor(Worker, targetActorRequirements=self._requirements(host)) self.send(worker, Bootstrap(cfg, worker_id)) return worker - def start_worker(self, driver, worker_id, cfg, track, allocations, client_contexts=None): + def start_worker(self, driver, worker_id, cfg: types.Config, track, allocations, client_contexts=None): self.send(driver, StartWorker(worker_id, cfg, track, allocations, client_contexts)) def drive_at(self, driver, client_start_timestamp): @@ -333,7 +334,7 @@ def _requirements(self, host): else: return {"ip": host} - def prepare_track(self, hosts, cfg, track): + def prepare_track(self, hosts, cfg: types.Config, track): self.track = track self.logger.info("Starting prepare track process on hosts [%s]", hosts) self.children = 
[self._create_track_preparator(h) for h in hosts] @@ -373,7 +374,7 @@ def on_benchmark_complete(self, metrics): self.send(self.benchmark_actor, BenchmarkComplete(metrics)) -def load_local_config(coordinator_config): +def load_local_config(coordinator_config) -> types.Config: cfg = config.auto_load_local_config( coordinator_config, additional_sections=[ @@ -404,7 +405,7 @@ def __init__(self): self.task_preparation_actor = None self.logger = logging.getLogger(__name__) self.track_name = None - self.cfg = None + self.cfg: Optional[types.Config] = None @actor.no_retry("task executor") # pylint: disable=no-value-for-parameter def receiveMsg_StartTaskLoop(self, msg, sender): @@ -471,7 +472,7 @@ def __init__(self): self.status = self.Status.INITIALIZING self.children = [] self.tasks = [] - self.cfg = None + self.cfg: Optional[types.Config] = None self.data_root_dir = None self.track = None @@ -501,6 +502,7 @@ def receiveMsg_BenchmarkFailure(self, msg, sender): @actor.no_retry("track preparator") # pylint: disable=no-value-for-parameter def receiveMsg_PrepareTrack(self, msg, sender): + assert self.cfg is not None self.data_root_dir = self.cfg.opts("benchmarks", "local.dataset.cache") tpr = TrackProcessorRegistry(self.cfg) self.track = msg.track @@ -520,6 +522,7 @@ def receiveMsg_PrepareTrack(self, msg, sender): ) def resume(self): + assert self.cfg is not None if not self.processors.empty(): self._seed_tasks(self.processors.get()) self.send_to_children_and_transition( @@ -536,6 +539,7 @@ def _create_task_executor(self): @actor.no_retry("track preparator") # pylint: disable=no-value-for-parameter def receiveMsg_ReadyForWork(self, msg, task_execution_actor): + assert self.cfg is not None if self.tasks: next_task = self.tasks.pop() else: @@ -549,7 +553,7 @@ def receiveMsg_WorkerIdle(self, msg, sender): self.transition_when_all_children_responded(sender, msg, self.Status.PROCESSOR_RUNNING, self.Status.PROCESSOR_COMPLETE, self.resume) -def num_cores(cfg): +def num_cores(cfg: 
types.Config): return int(cfg.opts("system", "available.cores", mandatory=False, default_value=multiprocessing.cpu_count())) @@ -560,11 +564,11 @@ def num_cores(cfg): class ClientContext: client_id: int parent_worker_id: int - api_key: ApiKey = None + api_key: Optional[ApiKey] = None class Driver: - def __init__(self, driver_actor, config, es_client_factory_class=client.EsClientFactory): + def __init__(self, driver_actor, config: types.Config, es_client_factory_class=client.EsClientFactory): """ Coordinates all workers. It is technology-agnostic, i.e. it does not know anything about actors. To allow us to hook in an actor, we provide a ``target`` parameter which will be called whenever some event has occurred. The ``target`` can use this to send @@ -772,7 +776,11 @@ def start_benchmark(self): self.number_of_steps = len(allocator.join_points) - 1 self.tasks_per_join_point = allocator.tasks_per_joinpoint - self.logger.info("Benchmark consists of [%d] steps executed by [%d] clients.", self.number_of_steps, len(self.allocations)) + self.logger.info( + "Benchmark consists of [%d] steps executed by [%d] clients.", + self.number_of_steps, + len(self.allocations), # type: ignore[arg-type] # TODO remove the below ignore when introducing type hints + ) # avoid flooding the log if there are too many clients if allocator.clients < 128: self.logger.debug("Allocation matrix:\n%s", "\n".join([str(a) for a in self.allocations])) @@ -1209,7 +1217,7 @@ def __init__(self): super().__init__() self.driver_actor = None self.worker_id = None - self.config = None + self.config: Optional[types.Config] = None self.track = None self.client_allocations = None self.client_contexts = None @@ -1239,6 +1247,7 @@ def receiveMsg_Bootstrap(self, msg, sender): @actor.no_retry("worker") # pylint: disable=no-value-for-parameter def receiveMsg_StartWorker(self, msg, sender): + assert self.config is not None self.logger.info("Worker[%d] is about to start.", msg.worker_id) self.on_error = 
self.config.opts("driver", "on.error") self.sample_queue_size = int(self.config.opts("reporting", "sample.queue.size", mandatory=False, default_value=1 << 20)) @@ -1343,6 +1352,7 @@ def receiveUnrecognizedMessage(self, msg, sender): self.logger.debug("Worker[%d] received unknown message [%s] (ignoring).", self.worker_id, str(msg)) def drive(self): + assert self.config is not None task_allocations = self.current_tasks_and_advance() # skip non-tasks in the task list while len(task_allocations) == 0: @@ -1568,7 +1578,7 @@ def _merge(self, *args): return result -def select_challenge(config, t): +def select_challenge(config: types.Config, t): challenge_name = config.opts("track", "challenge.name") selected_challenge = t.find_challenge_or_default(challenge_name) @@ -1742,7 +1752,7 @@ def map_task_throughput(self, current_samples): class AsyncIoAdapter: - def __init__(self, cfg, track, task_allocations, sampler, cancel, complete, abort_on_error, client_contexts, worker_id): + def __init__(self, cfg: types.Config, track, task_allocations, sampler, cancel, complete, abort_on_error, client_contexts, worker_id): self.cfg = cfg self.track = track self.task_allocations = task_allocations diff --git a/esrally/driver/runner.py b/esrally/driver/runner.py index 38791324f..b5c0bf843 100644 --- a/esrally/driver/runner.py +++ b/esrally/driver/runner.py @@ -23,18 +23,18 @@ import re import sys import time -import types from collections import Counter, OrderedDict from copy import deepcopy from enum import Enum from functools import total_ordering from io import BytesIO from os.path import commonprefix +from types import FunctionType from typing import List, Optional import ijson -from esrally import exceptions, track +from esrally import exceptions, track, types from esrally.utils import convert from esrally.utils.versions import Version @@ -43,7 +43,7 @@ __RUNNERS = {} -def register_default_runners(config=None): +def register_default_runners(config: Optional[types.Config] = None): 
register_runner(track.OperationType.Bulk, BulkIndex(), async_runner=True) register_runner(track.OperationType.ForceMerge, ForceMerge(), async_runner=True) register_runner(track.OperationType.IndexStats, Retry(IndicesStats()), async_runner=True) @@ -144,7 +144,7 @@ def register_runner(operation_type, runner, **kwargs): logger.debug("Registering context-manager capable runner object [%s] for [%s].", str(runner), str(operation_type)) cluster_aware_runner = _multi_cluster_runner(runner, str(runner)) # we'd rather use callable() but this will erroneously also classify a class as callable... - elif isinstance(runner, types.FunctionType): + elif isinstance(runner, FunctionType): if logger.isEnabledFor(logging.DEBUG): logger.debug("Registering runner function [%s] for [%s].", str(runner), str(operation_type)) cluster_aware_runner = _single_cluster_runner(runner, runner.__name__) @@ -926,7 +926,11 @@ async def _search_after_query(es, params): body["pit"] = {"id": pit_id, "keep_alive": "1m"} response = await self._raw_search(es, doc_type=None, index=index, body=body.copy(), params=request_params, headers=headers) - parsed, last_sort = self._search_after_extractor(response, bool(pit_op), results.get("hits")) + parsed, last_sort = self._search_after_extractor( + response, + bool(pit_op), + results.get("hits"), # type: ignore[arg-type] # TODO remove the below ignore when introducing type hints + ) results["pages"] = page results["weight"] = page if results.get("hits") is None: @@ -983,7 +987,12 @@ async def _composite_agg(es, params): body_to_send = tree_copy_composite_agg(body, path_to_composite) response = await self._raw_search(es, doc_type=None, index=index, body=body_to_send, params=request_params, headers=headers) - parsed = self._composite_agg_extractor(response, bool(pit_op), path_to_composite, results.get("hits")) + parsed = self._composite_agg_extractor( + response, + bool(pit_op), + path_to_composite, + results.get("hits"), # type: ignore[arg-type] # TODO remove this 
ignore when introducing type hints + ) results["pages"] = page results["weight"] = page if results.get("hits") is None: @@ -1252,7 +1261,8 @@ def __call__(self, response: BytesIO, get_point_in_time: bool, path_to_composite after_key = "aggregations." + (".".join(path_to_composite_agg)) + ".after_key" - parsed = parse(response, properties, None, [after_key]) + # TODO remove the below ignore when introducing type hints + parsed = parse(response, properties, None, [after_key]) # type: ignore[arg-type] if get_point_in_time and not parsed.get("pit_id"): raise exceptions.RallyAssertionError("Paginated query failure: pit_id was expected but not found in the response.") @@ -2553,7 +2563,7 @@ async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_val, exc_tb): - CompositeContext.ctx.reset(self.token) + CompositeContext.ctx.reset(self.token) # type: ignore[arg-type] # TODO remove this ignore when introducing type hints return False @staticmethod diff --git a/esrally/mechanic/launcher.py b/esrally/mechanic/launcher.py index e905d61f7..6ef3ee7c8 100644 --- a/esrally/mechanic/launcher.py +++ b/esrally/mechanic/launcher.py @@ -21,7 +21,7 @@ import psutil -from esrally import exceptions, telemetry, time +from esrally import exceptions, telemetry, time, types from esrally.mechanic import cluster, java_resolver from esrally.utils import io, opts, process @@ -30,7 +30,7 @@ class DockerLauncher: # May download a Docker image and that can take some time PROCESS_WAIT_TIMEOUT_SECONDS = 10 * 60 - def __init__(self, cfg, clock=time.Clock): + def __init__(self, cfg: types.Config, clock=time.Clock): self.cfg = cfg self.clock = clock self.logger = logging.getLogger(__name__) @@ -122,7 +122,7 @@ class ProcessLauncher: PROCESS_WAIT_TIMEOUT_SECONDS = 90.0 - def __init__(self, cfg, clock=time.Clock): + def __init__(self, cfg: types.Config, clock=time.Clock): self.cfg = cfg self._clock = clock self.logger = logging.getLogger(__name__) diff --git 
a/esrally/mechanic/mechanic.py b/esrally/mechanic/mechanic.py index 4ee35923b..597880c46 100644 --- a/esrally/mechanic/mechanic.py +++ b/esrally/mechanic/mechanic.py @@ -23,17 +23,18 @@ import sys import traceback from collections import defaultdict +from typing import Optional import thespian.actors -from esrally import PROGRAM_NAME, actor, config, exceptions, metrics, paths +from esrally import PROGRAM_NAME, actor, config, exceptions, metrics, paths, types from esrally.mechanic import launcher, provisioner, supplier, team from esrally.utils import console, net METRIC_FLUSH_INTERVAL_SECONDS = 30 -def build(cfg): +def build(cfg: types.Config): car, plugins = load_team(cfg, external=False) s = supplier.create(cfg, sources=True, distribution=False, car=car, plugins=plugins) @@ -41,7 +42,7 @@ def build(cfg): console.println(json.dumps(binaries, indent=2), force=True) -def download(cfg): +def download(cfg: types.Config): car, plugins = load_team(cfg, external=False) s = supplier.create(cfg, sources=False, distribution=True, car=car, plugins=plugins) @@ -49,7 +50,7 @@ def download(cfg): console.println(json.dumps(binaries, indent=2), force=True) -def install(cfg): +def install(cfg: types.Config): root_path = paths.install_root(cfg) car, plugins = load_team(cfg, external=False) @@ -92,7 +93,7 @@ def install(cfg): console.println(json.dumps({"installation-id": cfg.opts("system", "install.id")}, indent=2), force=True) -def start(cfg): +def start(cfg: types.Config): root_path = paths.install_root(cfg) race_id = cfg.opts("system", "race.id") # avoid double-launching - we expect that the node file is absent @@ -116,7 +117,7 @@ def start(cfg): _store_node_file(root_path, (nodes, race_id)) -def stop(cfg): +def stop(cfg: types.Config): root_path = paths.install_root(cfg) node_config = provisioner.load_node_configuration(root_path) if node_config.build_type == "tar": @@ -183,7 +184,7 @@ def _delete_node_file(root_path): class StartEngine: - def __init__(self, cfg, 
open_metrics_context, sources, distribution, external, docker, ip=None, port=None, node_id=None): + def __init__(self, cfg: types.Config, open_metrics_context, sources, distribution, external, docker, ip=None, port=None, node_id=None): self.cfg = cfg self.open_metrics_context = open_metrics_context self.sources = sources @@ -245,7 +246,20 @@ def __init__(self, reset_in_seconds): class StartNodes: - def __init__(self, cfg, open_metrics_context, sources, distribution, external, docker, all_node_ips, all_node_ids, ip, port, node_ids): + def __init__( + self, + cfg: types.Config, + open_metrics_context, + sources, + distribution, + external, + docker, + all_node_ips, + all_node_ids, + ip, + port, + node_ids, + ): self.cfg = cfg self.open_metrics_context = open_metrics_context self.sources = sources @@ -322,7 +336,7 @@ class MechanicActor(actor.RallyActor): def __init__(self): super().__init__() - self.cfg = None + self.cfg: Optional[types.Config] = None self.race_control = None self.cluster_launcher = None self.cluster = None @@ -356,6 +370,7 @@ def receiveMsg_StartEngine(self, msg, sender): self.logger.info("Received signal from race control to start engine.") self.race_control = sender self.cfg = msg.cfg + assert self.cfg is not None self.car, _ = load_team(self.cfg, msg.external) # TODO: This is implicitly set by #load_team() - can we gather this elsewhere? 
self.team_revision = self.cfg.opts("mechanic", "repository.revision") @@ -622,7 +637,7 @@ def receiveUnrecognizedMessage(self, msg, sender): ##################################################### -def load_team(cfg, external): +def load_team(cfg: types.Config, external): # externally provisioned clusters do not support cars / plugins if external: car = None @@ -637,7 +652,16 @@ def load_team(cfg, external): def create( - cfg, metrics_store, node_ip, node_http_port, all_node_ips, all_node_ids, sources=False, distribution=False, external=False, docker=False + cfg: types.Config, + metrics_store, + node_ip, + node_http_port, + all_node_ips, + all_node_ids, + sources=False, + distribution=False, + external=False, + docker=False, ): race_root_path = paths.race_root(cfg) node_ids = cfg.opts("provisioning", "node.ids", mandatory=False) @@ -681,7 +705,7 @@ class Mechanic: running the benchmark). """ - def __init__(self, cfg, metrics_store, supply, provisioners, launcher): + def __init__(self, cfg: types.Config, metrics_store, supply, provisioners, launcher): self.cfg = cfg self.preserve_install = cfg.opts("mechanic", "preserve.install") self.metrics_store = metrics_store diff --git a/esrally/mechanic/provisioner.py b/esrally/mechanic/provisioner.py index a805def9c..536102745 100644 --- a/esrally/mechanic/provisioner.py +++ b/esrally/mechanic/provisioner.py @@ -24,12 +24,12 @@ import jinja2 -from esrally import exceptions +from esrally import exceptions, types from esrally.mechanic import java_resolver, team from esrally.utils import console, convert, io, process -def local(cfg, car, plugins, ip, http_port, all_node_ips, all_node_names, target_root, node_name): +def local(cfg: types.Config, car, plugins, ip, http_port, all_node_ips, all_node_names, target_root, node_name): distribution_version = cfg.opts("mechanic", "distribution.version", mandatory=False) cluster_name = cfg.opts("mechanic", "cluster.name") @@ -47,7 +47,7 @@ def local(cfg, car, plugins, ip, http_port, 
all_node_ips, all_node_names, target return BareProvisioner(es_installer, plugin_installers, distribution_version=distribution_version) -def docker(cfg, car, ip, http_port, target_root, node_name): +def docker(cfg: types.Config, car, ip, http_port, target_root, node_name): distribution_version = cfg.opts("mechanic", "distribution.version", mandatory=False) cluster_name = cfg.opts("mechanic", "cluster.name") rally_root = cfg.opts("node", "rally.root") @@ -289,7 +289,8 @@ def install(self, binary): self.data_paths = self._data_paths() def delete_pre_bundled_configuration(self): - config_path = os.path.join(self.es_home_path, "config") + # TODO remove the below ignore when introducing type hints + config_path = os.path.join(self.es_home_path, "config") # type: ignore[arg-type] self.logger.info("Deleting pre-bundled Elasticsearch configuration at [%s]", config_path) shutil.rmtree(config_path) @@ -342,7 +343,8 @@ def _data_paths(self): else: raise exceptions.SystemSetupError("Expected [data_paths] to be either a string or a list but was [%s]." 
% type(data_paths)) else: - return [os.path.join(self.es_home_path, "data")] + # TODO remove the below ignore when introducing type hints + return [os.path.join(self.es_home_path, "data")] # type: ignore[arg-type] class PluginInstaller: diff --git a/esrally/mechanic/supplier.py b/esrally/mechanic/supplier.py index f3ce39e68..d40773cbc 100644 --- a/esrally/mechanic/supplier.py +++ b/esrally/mechanic/supplier.py @@ -25,7 +25,7 @@ import urllib.error import docker -from esrally import PROGRAM_NAME, exceptions, paths +from esrally import PROGRAM_NAME, exceptions, paths, types from esrally.exceptions import BuildError, SystemSetupError from esrally.utils import console, convert, git, io, jvm, net, process, sysstats @@ -33,7 +33,7 @@ DEFAULT_PLUGIN_BRANCH = "main" -def create(cfg, sources, distribution, car, plugins=None): +def create(cfg: types.Config, sources, distribution, car, plugins=None): logger = logging.getLogger(__name__) if plugins is None: plugins = [] @@ -116,7 +116,8 @@ def create(cfg, sources, distribution, car, plugins=None): repo = DistributionRepository( name=cfg.opts("mechanic", "distribution.repository"), distribution_config=dist_cfg, template_renderer=template_renderer ) - suppliers.append(ElasticsearchDistributionSupplier(repo, es_version, distributions_root)) + # TODO remove the below ignore when introducing type hints + suppliers.append(ElasticsearchDistributionSupplier(repo, es_version, distributions_root)) # type: ignore[arg-type] for plugin in plugins: if plugin.moved_to_module: @@ -154,11 +155,12 @@ def create(cfg, sources, distribution, car, plugins=None): if caching_enabled: plugin_file_resolver = PluginFileNameResolver(plugin.name, plugin_version) plugin_supplier = CachedSourceSupplier(source_distributions_root, plugin_supplier, plugin_file_resolver) - suppliers.append(plugin_supplier) + suppliers.append(plugin_supplier) # type: ignore[arg-type] # TODO remove this ignore when introducing type hints else: logger.info("Adding plugin 
distribution supplier for [%s].", plugin.name) assert repo is not None, "Cannot benchmark plugin %s from a distribution version but Elasticsearch from sources" % plugin.name - suppliers.append(PluginDistributionSupplier(repo, plugin)) + # TODO remove the below ignore when introducing type hints + suppliers.append(PluginDistributionSupplier(repo, plugin)) # type: ignore[arg-type] return CompositeSupplier(suppliers) @@ -221,7 +223,7 @@ def _supply_requirements(sources, distribution, plugins, revisions, distribution return supply_requirements -def _src_dir(cfg, mandatory=True): +def _src_dir(cfg: types.Config, mandatory=True): # Don't let this spread across the whole module try: return cfg.opts("node", "src.root.dir", mandatory=mandatory) @@ -457,7 +459,7 @@ def resolve_build_jdk_major(cls, src_dir: str) -> int: else: major_version = 17 logger.info("Unable to resolve build JDK major release version. Defaulting to version [%s].", major_version) - return int(major_version) + return int(major_version) # type: ignore[arg-type] # TODO remove this ignore when introducing type hints def resolve_binary(self): try: diff --git a/esrally/mechanic/team.py b/esrally/mechanic/team.py index 0e9108c95..6ea7cabc1 100644 --- a/esrally/mechanic/team.py +++ b/esrally/mechanic/team.py @@ -22,7 +22,7 @@ import tabulate -from esrally import PROGRAM_NAME, config, exceptions +from esrally import PROGRAM_NAME, config, exceptions, types from esrally.utils import console, io, modules, repo TEAM_FORMAT_VERSION = 1 @@ -35,7 +35,7 @@ def _path_for(team_root_path, team_member_type): return root_path -def list_cars(cfg): +def list_cars(cfg: types.Config): loader = CarLoader(team_path(cfg)) cars = [] for name in loader.car_names(): @@ -85,7 +85,7 @@ def __init__(self, root_path, entry_point): return Car(name, root_path, all_config_paths, variables) -def list_plugins(cfg): +def list_plugins(cfg: types.Config): plugins = PluginLoader(team_path(cfg)).plugins() if plugins: console.println("Available 
Elasticsearch plugins:\n") @@ -94,8 +94,8 @@ def list_plugins(cfg): console.println("No Elasticsearch plugins are available.\n") -def load_plugin(repo, name, config, plugin_params=None): - return PluginLoader(repo).load_plugin(name, config, plugin_params) +def load_plugin(repo, name, config_names, plugin_params=None): + return PluginLoader(repo).load_plugin(name, config_names, plugin_params) def load_plugins(repo, plugin_names, plugin_params=None): @@ -116,7 +116,7 @@ def name_and_config(p): return plugins -def team_path(cfg): +def team_path(cfg: types.Config): root_path = cfg.opts("mechanic", "team.path", mandatory=False) if root_path: return root_path @@ -125,7 +125,8 @@ def team_path(cfg): repo_name = cfg.opts("mechanic", "repository.name") repo_revision = cfg.opts("mechanic", "repository.revision") offline = cfg.opts("system", "offline.mode") - remote_url = cfg.opts("teams", "%s.url" % repo_name, mandatory=False) + # TODO remove the below ignore when introducing LiteralString on Python 3.11+ + remote_url = cfg.opts("teams", "%s.url" % repo_name, mandatory=False) # type: ignore[arg-type] root = cfg.opts("node", "root.dir") team_repositories = cfg.opts("mechanic", "team.repository.dir") teams_dir = os.path.join(root, team_repositories) diff --git a/esrally/metrics.py b/esrally/metrics.py index 9884aaf9f..61e00d4dd 100644 --- a/esrally/metrics.py +++ b/esrally/metrics.py @@ -26,14 +26,13 @@ import random import statistics import sys -import time import uuid import zlib from enum import Enum, IntEnum import tabulate -from esrally import client, config, exceptions, paths, time, version +from esrally import client, config, exceptions, paths, time, types, version from esrally.utils import console, convert, io, versions @@ -224,7 +223,7 @@ class EsClientFactory: Abstracts how the Elasticsearch client is created. Intended for testing. 
""" - def __init__(self, cfg): + def __init__(self, cfg: types.Config): self._config = cfg host = self._config.opts("reporting", "datastore.host") port = self._config.opts("reporting", "datastore.port") @@ -284,7 +283,7 @@ class IndexTemplateProvider: Abstracts how the Rally index template is retrieved. Intended for testing. """ - def __init__(self, cfg): + def __init__(self, cfg: types.Config): self._config = cfg self._number_of_shards = self._config.opts("reporting", "datastore.number_of_shards", default_value=None, mandatory=False) self._number_of_replicas = self._config.opts("reporting", "datastore.number_of_replicas", default_value=None, mandatory=False) @@ -343,7 +342,7 @@ def calculate_system_results(store, node_name): return calc() -def metrics_store(cfg, read_only=True, track=None, challenge=None, car=None, meta_info=None): +def metrics_store(cfg: types.Config, read_only=True, track=None, challenge=None, car=None, meta_info=None): """ Creates a proper metrics store based on the current configuration. @@ -363,7 +362,7 @@ def metrics_store(cfg, read_only=True, track=None, challenge=None, car=None, met return store -def metrics_store_class(cfg): +def metrics_store_class(cfg: types.Config): if cfg.opts("reporting", "datastore.type") == "elasticsearch": return EsMetricsStore else: @@ -380,7 +379,7 @@ class MetricsStore: Abstract metrics store """ - def __init__(self, cfg, clock=time.Clock, meta_info=None): + def __init__(self, cfg: types.Config, clock=time.Clock, meta_info=None): """ Creates a new metrics store. 
@@ -872,7 +871,7 @@ class EsMetricsStore(MetricsStore): def __init__( self, - cfg, + cfg: types.Config, client_factory_class=EsClientFactory, index_template_provider_class=IndexTemplateProvider, clock=time.Clock, @@ -1129,7 +1128,7 @@ def __str__(self): class InMemoryMetricsStore(MetricsStore): - def __init__(self, cfg, clock=time.Clock, meta_info=None): + def __init__(self, cfg: types.Config, clock=time.Clock, meta_info=None): """ Creates a new metrics store. @@ -1261,7 +1260,7 @@ def __str__(self): return "in-memory metrics store" -def race_store(cfg): +def race_store(cfg: types.Config): """ Creates a proper race store based on the current configuration. :param cfg: Config object. Mandatory. @@ -1276,7 +1275,7 @@ def race_store(cfg): return FileRaceStore(cfg) -def results_store(cfg): +def results_store(cfg: types.Config): """ Creates a proper race store based on the current configuration. :param cfg: Config object. Mandatory. @@ -1291,23 +1290,23 @@ def results_store(cfg): return NoopResultsStore() -def delete_race(cfg): +def delete_race(cfg: types.Config): race_store(cfg).delete_race() -def delete_annotation(cfg): +def delete_annotation(cfg: types.Config): race_store(cfg).delete_annotation() -def list_annotations(cfg): +def list_annotations(cfg: types.Config): race_store(cfg).list_annotations() -def add_annotation(cfg): +def add_annotation(cfg: types.Config): race_store(cfg).add_annotation() -def list_races(cfg): +def list_races(cfg: types.Config): def format_dict(d): if d: items = sorted(d.items()) @@ -1358,7 +1357,7 @@ def format_dict(d): console.println("No recent races found.") -def create_race(cfg, track, challenge, track_revision=None): +def create_race(cfg: types.Config, track, challenge, track_revision=None): car = cfg.opts("mechanic", "car.names") environment = cfg.opts("system", "env.name") race_id = cfg.opts("system", "race.id") @@ -1566,7 +1565,7 @@ def from_dict(cls, d): class RaceStore: - def __init__(self, cfg): + def __init__(self, cfg: 
types.Config): self.cfg = cfg self.environment_name = cfg.opts("system", "env.name") @@ -1728,7 +1727,7 @@ def _to_races(self, results): class EsRaceStore(RaceStore): INDEX_PREFIX = "rally-races-" - def __init__(self, cfg, client_factory_class=EsClientFactory, index_template_provider_class=IndexTemplateProvider): + def __init__(self, cfg: types.Config, client_factory_class=EsClientFactory, index_template_provider_class=IndexTemplateProvider): """ Creates a new metrics store. @@ -1957,7 +1956,7 @@ class EsResultsStore: INDEX_PREFIX = "rally-results-" - def __init__(self, cfg, client_factory_class=EsClientFactory, index_template_provider_class=IndexTemplateProvider): + def __init__(self, cfg: types.Config, client_factory_class=EsClientFactory, index_template_provider_class=IndexTemplateProvider): """ Creates a new results store. diff --git a/esrally/paths.py b/esrally/paths.py index f3c615946..a8b099d4c 100644 --- a/esrally/paths.py +++ b/esrally/paths.py @@ -16,6 +16,8 @@ # under the License. 
import os +from esrally import types + def rally_confdir(): default_home = os.path.expanduser("~") @@ -26,17 +28,17 @@ def rally_root(): return os.path.dirname(os.path.realpath(__file__)) -def races_root(cfg): +def races_root(cfg: types.Config): return os.path.join(cfg.opts("node", "root.dir"), "races") -def race_root(cfg, race_id=None): +def race_root(cfg: types.Config, race_id=None): if not race_id: race_id = cfg.opts("system", "race.id") return os.path.join(races_root(cfg), race_id) -def install_root(cfg=None): +def install_root(cfg: types.Config): install_id = cfg.opts("system", "install.id") return os.path.join(races_root(cfg), install_id) diff --git a/esrally/racecontrol.py b/esrally/racecontrol.py index d5fdcef8b..5abeacef7 100644 --- a/esrally/racecontrol.py +++ b/esrally/racecontrol.py @@ -19,6 +19,7 @@ import logging import os import sys +from typing import Optional import tabulate import thespian.actors @@ -35,6 +36,7 @@ metrics, reporter, track, + types, version, ) from esrally.utils import console, opts, versions @@ -68,12 +70,12 @@ def __init__(self, name, description, target, stable=True): self.stable = stable pipelines[name] = self - def __call__(self, cfg): + def __call__(self, cfg: types.Config): self.target(cfg) class Setup: - def __init__(self, cfg, sources=False, distribution=False, external=False, docker=False): + def __init__(self, cfg: types.Config, sources=False, distribution=False, external=False, docker=False): self.cfg = cfg self.sources = sources self.distribution = distribution @@ -88,7 +90,7 @@ class Success: class BenchmarkActor(actor.RallyActor): def __init__(self): super().__init__() - self.cfg = None + self.cfg: Optional[types.Config] = None self.start_sender = None self.mechanic = None self.main_driver = None @@ -107,6 +109,7 @@ def receiveUnrecognizedMessage(self, msg, sender): def receiveMsg_Setup(self, msg, sender): self.start_sender = sender self.cfg = msg.cfg + assert self.cfg is not None self.coordinator = 
BenchmarkCoordinator(msg.cfg) self.coordinator.setup(sources=msg.sources) self.logger.info("Asking mechanic to start the engine.") @@ -114,12 +117,18 @@ def receiveMsg_Setup(self, msg, sender): self.send( self.mechanic, mechanic.StartEngine( - self.cfg, self.coordinator.metrics_store.open_context, msg.sources, msg.distribution, msg.external, msg.docker + self.cfg, + self.coordinator.metrics_store.open_context, + msg.sources, + msg.distribution, + msg.external, + msg.docker, ), ) @actor.no_retry("race control") # pylint: disable=no-value-for-parameter def receiveMsg_EngineStarted(self, msg, sender): + assert self.cfg is not None self.logger.info("Mechanic has started engine successfully.") self.coordinator.race.team_revision = msg.team_revision self.main_driver = self.createActor(driver.DriverActor, targetActorRequirements={"coordinator": True}) @@ -167,7 +176,7 @@ def receiveMsg_EngineStopped(self, msg, sender): class BenchmarkCoordinator: - def __init__(self, cfg): + def __init__(self, cfg: types.Config): self.logger = logging.getLogger(__name__) self.cfg = cfg self.race = None @@ -276,7 +285,7 @@ def on_benchmark_complete(self, new_metrics): self.metrics_store.close() -def race(cfg, sources=False, distribution=False, external=False, docker=False): +def race(cfg: types.Config, sources=False, distribution=False, external=False, docker=False): logger = logging.getLogger(__name__) # at this point an actor system has to run and we should only join actor_system = actor.bootstrap_actor_system(try_join=True) @@ -304,7 +313,7 @@ def race(cfg, sources=False, distribution=False, external=False, docker=False): actor_system.tell(benchmark_actor, thespian.actors.ActorExitRequest()) -def set_default_hosts(cfg, host="127.0.0.1", port=9200): +def set_default_hosts(cfg: types.Config, host="127.0.0.1", port=9200): logger = logging.getLogger(__name__) configured_hosts = cfg.opts("client", "hosts") if len(configured_hosts.default) != 0: @@ -316,26 +325,26 @@ def 
set_default_hosts(cfg, host="127.0.0.1", port=9200): # Poor man's curry -def from_sources(cfg): +def from_sources(cfg: types.Config): port = cfg.opts("provisioning", "node.http.port") set_default_hosts(cfg, port=port) return race(cfg, sources=True) -def from_distribution(cfg): +def from_distribution(cfg: types.Config): port = cfg.opts("provisioning", "node.http.port") set_default_hosts(cfg, port=port) return race(cfg, distribution=True) -def benchmark_only(cfg): +def benchmark_only(cfg: types.Config): set_default_hosts(cfg) # We'll use a special car name for external benchmarks. cfg.add(config.Scope.benchmark, "mechanic", "car.names", ["external"]) return race(cfg, external=True) -def docker(cfg): +def docker(cfg: types.Config): set_default_hosts(cfg) return race(cfg, docker=True) @@ -361,7 +370,7 @@ def list_pipelines(): console.println(tabulate.tabulate(available_pipelines(), headers=["Name", "Description"])) -def run(cfg): +def run(cfg: types.Config): logger = logging.getLogger(__name__) name = cfg.opts("race", "pipeline") race_id = cfg.opts("system", "race.id") diff --git a/esrally/rally.py b/esrally/rally.py index c98c59db6..738826f78 100644 --- a/esrally/rally.py +++ b/esrally/rally.py @@ -45,6 +45,7 @@ reporter, telemetry, track, + types, version, ) from esrally.mechanic import mechanic, team @@ -850,7 +851,7 @@ def add_track_source(subparser): return parser -def dispatch_list(cfg): +def dispatch_list(cfg: types.Config): what = cfg.opts("system", "list.config.option") if what == "telemetry": telemetry.list_telemetry() @@ -870,7 +871,7 @@ def dispatch_list(cfg): raise exceptions.SystemSetupError("Cannot list unknown configuration option [%s]" % what) -def dispatch_add(cfg): +def dispatch_add(cfg: types.Config): what = cfg.opts("system", "add.config.option") if what == "annotation": metrics.add_annotation(cfg) @@ -878,7 +879,7 @@ def dispatch_add(cfg): raise exceptions.SystemSetupError("Cannot list unknown configuration option [%s]" % what) -def 
dispatch_delete(cfg): +def dispatch_delete(cfg: types.Config): what = cfg.opts("system", "delete.config.option") if what == "race": metrics.delete_race(cfg) @@ -901,7 +902,7 @@ def print_help_on_errors(): ) -def race(cfg, kill_running_processes=False): +def race(cfg: types.Config, kill_running_processes=False): logger = logging.getLogger(__name__) if kill_running_processes: @@ -931,7 +932,7 @@ def race(cfg, kill_running_processes=False): with_actor_system(racecontrol.run, cfg) -def with_actor_system(runnable, cfg): +def with_actor_system(runnable, cfg: types.Config): logger = logging.getLogger(__name__) already_running = actor.actor_system_already_running() logger.info("Actor system already running locally? [%s]", str(already_running)) @@ -1005,12 +1006,12 @@ def with_actor_system(runnable, cfg): ) -def configure_telemetry_params(args, cfg): +def configure_telemetry_params(args, cfg: types.Config): cfg.add(config.Scope.applicationOverride, "telemetry", "devices", opts.csv_to_list(args.telemetry)) cfg.add(config.Scope.applicationOverride, "telemetry", "params", opts.to_dict(args.telemetry_params)) -def configure_track_params(arg_parser, args, cfg, command_requires_track=True): +def configure_track_params(arg_parser, args, cfg: types.Config, command_requires_track=True): cfg.add(config.Scope.applicationOverride, "track", "repository.revision", args.track_revision) # We can assume here that if a track-path is given, the user did not specify a repository either (although argparse sets it to # its default value) @@ -1037,7 +1038,7 @@ def configure_track_params(arg_parser, args, cfg, command_requires_track=True): cfg.add(config.Scope.applicationOverride, "track", "exclude.tasks", opts.csv_to_list(args.exclude_tasks)) -def configure_mechanic_params(args, cfg, command_requires_car=True): +def configure_mechanic_params(args, cfg: types.Config, command_requires_car=True): if args.team_path: cfg.add(config.Scope.applicationOverride, "mechanic", "team.path", 
os.path.abspath(io.normalize_path(args.team_path))) cfg.add(config.Scope.applicationOverride, "mechanic", "repository.name", None) @@ -1057,7 +1058,7 @@ def configure_mechanic_params(args, cfg, command_requires_car=True): cfg.add(config.Scope.applicationOverride, "mechanic", "car.params", opts.to_dict(args.car_params)) -def configure_connection_params(arg_parser, args, cfg): +def configure_connection_params(arg_parser, args, cfg: types.Config): # Also needed by mechanic (-> telemetry) - duplicate by module? target_hosts = opts.TargetHosts(args.target_hosts) cfg.add(config.Scope.applicationOverride, "client", "hosts", target_hosts) @@ -1067,14 +1068,14 @@ def configure_connection_params(arg_parser, args, cfg): arg_parser.error("--target-hosts and --client-options must define the same keys for multi cluster setups.") -def configure_reporting_params(args, cfg): +def configure_reporting_params(args, cfg: types.Config): cfg.add(config.Scope.applicationOverride, "reporting", "format", args.report_format) cfg.add(config.Scope.applicationOverride, "reporting", "values", args.show_in_report) cfg.add(config.Scope.applicationOverride, "reporting", "output.path", args.report_file) cfg.add(config.Scope.applicationOverride, "reporting", "numbers.align", args.report_numbers_align) -def dispatch_sub_command(arg_parser, args, cfg): +def dispatch_sub_command(arg_parser, args, cfg: types.Config): sub_command = args.subcommand cfg.add(config.Scope.application, "system", "quiet.mode", args.quiet) diff --git a/esrally/reporter.py b/esrally/reporter.py index c1ee24664..73bd7bef9 100644 --- a/esrally/reporter.py +++ b/esrally/reporter.py @@ -23,7 +23,7 @@ import tabulate -from esrally import exceptions, metrics +from esrally import exceptions, metrics, types from esrally.utils import console, convert from esrally.utils import io as rio @@ -38,11 +38,11 @@ """ -def summarize(results, cfg): +def summarize(results, cfg: types.Config): SummaryReporter(results, cfg).report() -def compare(cfg, 
baseline_id, contender_id): +def compare(cfg: types.Config, baseline_id, contender_id): if not baseline_id or not contender_id: raise exceptions.SystemSetupError("compare needs baseline and a contender") race_store = metrics.race_store(cfg) @@ -115,7 +115,7 @@ def total_disk_usage_per_field(stats): class SummaryReporter: - def __init__(self, results, config): + def __init__(self, results, config: types.Config): self.results = results self.report_file = config.opts("reporting", "output.path") self.report_format = config.opts("reporting", "format") @@ -359,7 +359,7 @@ def _line(self, k, task, v, unit, converter=lambda x: x, force=False): class ComparisonReporter: - def __init__(self, config): + def __init__(self, config: types.Config): self.report_file = config.opts("reporting", "output.path") self.report_format = config.opts("reporting", "format") self.numbers_align = config.opts("reporting", "numbers.align", mandatory=False, default_value="decimal") diff --git a/esrally/track/loader.py b/esrally/track/loader.py index 22cf45a83..8da4ffb7b 100644 --- a/esrally/track/loader.py +++ b/esrally/track/loader.py @@ -24,7 +24,7 @@ import sys import tempfile import urllib.error -from typing import Callable, Generator, Tuple +from typing import Callable, Generator, Optional, Tuple import jinja2 import jinja2.exceptions @@ -32,7 +32,7 @@ import tabulate from jinja2 import meta -from esrally import PROGRAM_NAME, config, exceptions, paths, time, version +from esrally import PROGRAM_NAME, config, exceptions, paths, time, types, version from esrally.track import params, track from esrally.track.track import Parallel from esrally.utils import ( @@ -89,7 +89,7 @@ def on_prepare_track(self, track: track.Track, data_root_dir: str) -> Generator[ class TrackProcessorRegistry: - def __init__(self, cfg): + def __init__(self, cfg: types.Config): self.required_processors = [TaskFilterTrackProcessor(cfg), ServerlessFilterTrackProcessor(cfg), TestModeTrackProcessor(cfg)] self.track_processors 
= [] self.offline = cfg.opts("system", "offline.mode") @@ -119,7 +119,7 @@ def processors(self): return [*self.required_processors, *self.track_processors] -def tracks(cfg): +def tracks(cfg: types.Config): """ Lists all known tracks. Note that users can specify a distribution version so if different tracks are available for @@ -132,7 +132,7 @@ def tracks(cfg): return [_load_single_track(cfg, repo, track_name) for track_name in repo.track_names] -def list_tracks(cfg): +def list_tracks(cfg: types.Config): available_tracks = tracks(cfg) only_auto_generated_challenges = all(t.default_challenge.auto_generated for t in available_tracks) @@ -159,7 +159,7 @@ def list_tracks(cfg): console.println(tabulate.tabulate(tabular_data=data, headers=headers)) -def track_info(cfg): +def track_info(cfg: types.Config): def format_task(t, indent="", num="", suffix=""): msg = f"{indent}{num}{str(t)}" if t.clients > 1: @@ -203,7 +203,7 @@ def challenge_info(c): console.println("") -def load_track(cfg, install_dependencies=False): +def load_track(cfg: types.Config, install_dependencies=False): """ Loads a track @@ -230,7 +230,7 @@ def _install_dependencies(dependencies): raise exceptions.SystemSetupError(f"Installation of track dependencies failed. See [{install_log.name}] for more information.") -def _load_single_track(cfg, track_repository, track_name, install_dependencies=False): +def _load_single_track(cfg: types.Config, track_repository, track_name, install_dependencies=False): try: track_dir = track_repository.track_dir(track_name) reader = TrackFileReader(cfg) @@ -254,7 +254,7 @@ def _load_single_track(cfg, track_repository, track_name, install_dependencies=F def load_track_plugins( - cfg, + cfg: types.Config, track_name, register_runner=None, register_scheduler=None, @@ -285,7 +285,7 @@ def load_track_plugins( return False -def set_absolute_data_path(cfg, t): +def set_absolute_data_path(cfg: types.Config, t): """ Sets an absolute data path on all document files in this track. 
Internally we store only relative paths in the track as long as possible as the data root directory may be different on each host. In the end we need to have an absolute path though when we want to read the @@ -312,18 +312,18 @@ def first_existing(root_dirs, f): document_set.document_file = first_existing(data_root, document_set.document_file) -def is_simple_track_mode(cfg): +def is_simple_track_mode(cfg: types.Config): return cfg.exists("track", "track.path") -def track_path(cfg): +def track_path(cfg: types.Config): repo = track_repo(cfg) track_name = repo.track_name track_dir = repo.track_dir(track_name) return track_dir -def track_repo(cfg, fetch=True, update=True): +def track_repo(cfg: types.Config, fetch=True, update=True): if is_simple_track_mode(cfg): track_path = cfg.opts("track", "track.path") return SimpleTrackRepository(track_path) @@ -331,7 +331,7 @@ def track_repo(cfg, fetch=True, update=True): return GitTrackRepository(cfg, fetch, update) -def data_dir(cfg, track_name, corpus_name): +def data_dir(cfg: types.Config, track_name, corpus_name): """ Determines potential data directories for the provided track and corpus name. 
@@ -352,14 +352,15 @@ def data_dir(cfg, track_name, corpus_name): class GitTrackRepository: - def __init__(self, cfg, fetch, update, repo_class=repo.RallyRepository): + def __init__(self, cfg: types.Config, fetch, update, repo_class=repo.RallyRepository): # current track name (if any) self.track_name = cfg.opts("track", "track.name", mandatory=False) distribution_version = cfg.opts("mechanic", "distribution.version", mandatory=False) repo_name = cfg.opts("track", "repository.name") repo_revision = cfg.opts("track", "repository.revision", mandatory=False) offline = cfg.opts("system", "offline.mode") - remote_url = cfg.opts("tracks", "%s.url" % repo_name, mandatory=False) + # TODO remove the below ignore when introducing LiteralString on Python 3.11+ + remote_url = cfg.opts("tracks", "%s.url" % repo_name, mandatory=False) # type: ignore[arg-type] root = cfg.opts("node", "root.dir") track_repositories = cfg.opts("benchmarks", "track.repository.dir") tracks_dir = os.path.join(root, track_repositories) @@ -451,13 +452,13 @@ class DefaultTrackPreparator(TrackProcessor): def __init__(self): super().__init__() # just declare here, will be injected later - self.cfg = None + self.cfg: Optional[types.Config] = None self.downloader = None self.decompressor = None self.track = None @staticmethod - def prepare_docs(cfg, track, corpus, preparator): + def prepare_docs(cfg: types.Config, track, corpus, preparator): for document_set in corpus.documents: if document_set.is_bulk: data_root = data_dir(cfg, track.name, corpus.name) @@ -842,7 +843,7 @@ def relative_glob(start, f): class TaskFilterTrackProcessor(TrackProcessor): - def __init__(self, cfg): + def __init__(self, cfg: types.Config): self.logger = logging.getLogger(__name__) include_tasks = cfg.opts("track", "include.tasks", mandatory=False) exclude_tasks = cfg.opts("track", "exclude.tasks", mandatory=False) @@ -864,9 +865,11 @@ def _filters_from_filtered_tasks(self, filtered_tasks): 
filters.append(track.TaskNameFilter(spec[0])) elif len(spec) == 2: if spec[0] == "type": - filters.append(track.TaskOpTypeFilter(spec[1])) + # TODO remove the below ignore when introducing type hints + filters.append(track.TaskOpTypeFilter(spec[1])) # type: ignore[arg-type] elif spec[0] == "tag": - filters.append(track.TaskTagFilter(spec[1])) + # TODO remove the below ignore when introducing type hints + filters.append(track.TaskTagFilter(spec[1])) # type: ignore[arg-type] else: raise exceptions.SystemSetupError(f"Invalid format for filtered tasks: [{t}]. Expected [type] but got [{spec[0]}].") else: @@ -907,7 +910,7 @@ def on_after_load_track(self, track): class ServerlessFilterTrackProcessor(TrackProcessor): - def __init__(self, cfg): + def __init__(self, cfg: types.Config): self.logger = logging.getLogger(__name__) self.serverless_mode = convert.to_bool(cfg.opts("driver", "serverless.mode", mandatory=False, default_value=False)) self.serverless_operator = convert.to_bool(cfg.opts("driver", "serverless.operator", mandatory=False, default_value=False)) @@ -953,7 +956,7 @@ def on_after_load_track(self, track): class TestModeTrackProcessor(TrackProcessor): - def __init__(self, cfg): + def __init__(self, cfg: types.Config): self.test_mode_enabled = cfg.opts("track", "test.mode.enabled", mandatory=False, default_value=False) self.logger = logging.getLogger(__name__) @@ -1057,7 +1060,7 @@ class TrackFileReader: Creates a track from a track file. 
""" - def __init__(self, cfg): + def __init__(self, cfg: types.Config): track_schema_file = os.path.join(cfg.opts("node", "rally.root"), "resources", "track-schema.json") with open(track_schema_file, encoding="utf-8") as f: self.track_schema = json.loads(f.read()) diff --git a/esrally/tracker/tracker.py b/esrally/tracker/tracker.py index c1b87568f..b35bcdbaa 100644 --- a/esrally/tracker/tracker.py +++ b/esrally/tracker/tracker.py @@ -21,7 +21,7 @@ from elastic_transport import ApiError, TransportError from jinja2 import Environment, FileSystemLoader -from esrally import PROGRAM_NAME +from esrally import PROGRAM_NAME, types from esrally.client import factory from esrally.tracker import corpus, index from esrally.utils import console, io @@ -68,7 +68,7 @@ def extract_mappings_and_corpora(client, output_path, indices_to_extract): return indices, corpora -def create_track(cfg): +def create_track(cfg: types.Config): logger = logging.getLogger(__name__) track_name = cfg.opts("track", "track.name") diff --git a/esrally/types.py b/esrally/types.py new file mode 100644 index 000000000..fbc792746 --- /dev/null +++ b/esrally/types.py @@ -0,0 +1,177 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Any, Literal, Protocol, TypeVar + +Section = Literal[ + "benchmarks", + "client", + "defaults", + "distributions", + "driver", + "generator", + "mechanic", + "meta", + "no_copy", + "node", + "provisioning", + "race", + "reporting", + "source", + "system", + "teams", + "telemetry", + "tests", + "track", + "tracks", + "unit-test", +] +Key = Literal[ + "add.chart_name", + "add.chart_type", + "add.config.option", + "add.message", + "add.race_timestamp", + "admin.dry_run", + "admin.track", + "assertions", + "async.debug", + "available.cores", + "build.type", + "cache", + "cache.days", + "car.names", + "car.params", + "car.plugins", + "challenge.name", + "challenge.root.dir", + "cluster.name", + "config.version", + "data_streams", + "datastore.host", + "datastore.number_of_replicas", + "datastore.number_of_shards", + "datastore.password", + "datastore.port", + "datastore.probe.cluster_version", + "datastore.secure", + "datastore.ssl.certificate_authorities", + "datastore.ssl.verification_mode", + "datastore.type", + "datastore.user", + "delete.config.option", + "delete.id", + "devices", + "distribution.dir", + "distribution.flavor", + "distribution.repository", + "distribution.version", + "elasticsearch.src.subdir", + "env.name", + "exclude.tasks", + "format", + "hosts", + "include.tasks", + "indices", + "install.id", + "list.config.option", + "list.from_date", + "list.max_results", + "list.races.benchmark_name", + "list.to_date", + "load_driver_hosts", + "local.dataset.cache", + "master.nodes", + "metrics.log.dir", + "metrics.request.downsample.factor", + "metrics.url", + "network.host", + "network.http.port", + "node.http.port", + "node.ids", + "node.name", + "node.name.prefix", + "numbers.align", + "offline.mode", + "on.error", + "options", + "other.key", + "output.path", + "output.processingtime", + "params", + "passenv", + "pipeline", + "plugin.community-plugin.src.dir", + "plugin.community-plugin.src.subdir", + "plugin.params", + 
"preserve.install", + "preserve_benchmark_candidate", + "private.url", + "profiling", + "quiet.mode", + "race.id", + "rally.cwd", + "rally.root", + "release.cache", + "release.url", + "remote.benchmarking.supported", + "remote.repo.url", + "repository.name", + "repository.revision", + "root.dir", + "runtime.jdk", + "sample.key", + "sample.property", + "sample.queue.size", + "seed.hosts", + "serverless.mode", + "serverless.operator", + "skip.rest.api.check", + "snapshot.cache", + "source.build.method", + "source.revision", + "src.root.dir", + "target.arch", + "target.os", + "team.path", + "team.repository.dir", + "test.mode.enabled", + "time.start", + "track.name", + "track.path", + "track.repository.dir", + "user.tags", + "values", +] +_Config = TypeVar("_Config", bound="Config") + + +class Config(Protocol): + def add(self, scope, section: Section, key: Key, value: Any) -> None: + ... + + def add_all(self, source: _Config, section: Section) -> None: + ... + + def opts(self, section: Section, key: Key, default_value=None, mandatory: bool = True) -> Any: + ... + + def all_opts(self, section: Section) -> dict: + ... + + def exists(self, section: Section, key: Key) -> bool: + ... 
diff --git a/esrally/utils/net.py b/esrally/utils/net.py index c5cea1194..3db34c893 100644 --- a/esrally/utils/net.py +++ b/esrally/utils/net.py @@ -194,7 +194,13 @@ def _download_http(url, local_path, expected_size_in_bytes=None, progress_indica "GET", url, preload_content=False, enforce_content_length=True, retries=10, timeout=urllib3.Timeout(connect=45, read=240) ) as r, open(local_path, "wb") as out_file: if r.status > 299: - raise urllib.error.HTTPError(url, r.status, "", None, None) + raise urllib.error.HTTPError( + url, + r.status, + "", + None, # type: ignore[arg-type] # TODO remove the below ignore when introducing type hints + None, + ) try: size_from_content_header = int(r.getheader("Content-Length", "")) if expected_size_in_bytes is None: diff --git a/pyproject.toml b/pyproject.toml index ffefdc779..080ef9d79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,3 +139,29 @@ junit_family = "xunit2" junit_logging = "all" asyncio_mode = "strict" xfail_strict = true + +# With rare exceptions, Rally does not use type hints. The intention of the +# following largely reduced mypy configuration scope is verification of argument +# types in config.Config methods while introducing configuration properties +# (props). The error we are after here is "arg-type". 
+[tool.mypy] +python_version = "3.8" +check_untyped_defs = true +disable_error_code = [ + "assignment", + "attr-defined", + "call-arg", + "call-overload", + "dict-item", + "import-not-found", + "import-untyped", + "index", + "list-item", + "misc", + "name-defined", + "operator", + "str-bytes-safe", + "syntax", + "union-attr", + "var-annotated", +] diff --git a/tests/config_test.py b/tests/config_test.py index 4002530df..f4afcbe7d 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -151,8 +151,8 @@ def test_add_all_in_section(self): assert target_cfg.opts("tests", "sample.key") == "value" assert target_cfg.opts("no_copy", "other.key", mandatory=False) is None - # nonexisting key will not throw an error - target_cfg.add_all(source=source_cfg, section="this section does not exist") + # nonexisting key will not throw an error; intentional use of nonexistent section + target_cfg.add_all(source=source_cfg, section="this section does not exist") # type: ignore[arg-type] class TestAutoLoadConfig: diff --git a/tests/telemetry_test.py b/tests/telemetry_test.py index 101877ffe..5dfe96a5d 100644 --- a/tests/telemetry_test.py +++ b/tests/telemetry_test.py @@ -2279,7 +2279,8 @@ def test_stores_default_nodes_stats(self, metrics_store_put_doc): expected_doc = collections.OrderedDict() expected_doc["name"] = "node-stats" - expected_doc.update(self.default_stats_response_flattened) + # TODO remove the below ignore when introducing type hints + expected_doc.update(self.default_stats_response_flattened) # type: ignore[arg-type] metrics_store_put_doc.assert_called_once_with( expected_doc, level=MetaInfoScope.node, node_name="rally0", meta_data=metrics_store_meta_data diff --git a/tests/track/loader_test.py b/tests/track/loader_test.py index 2f1414a54..fe01de437 100644 --- a/tests/track/loader_test.py +++ b/tests/track/loader_test.py @@ -511,7 +511,11 @@ def test_raise_download_error_no_test_mode_file(self, is_file, ensure_dir, downl is_file.return_value = False
download.side_effect = urllib.error.HTTPError( - "http://benchmarks.elasticsearch.org.s3.amazonaws.com/corpora/unit-test/docs-1k.json", 404, "", None, None + "http://benchmarks.elasticsearch.org.s3.amazonaws.com/corpora/unit-test/docs-1k.json", + 404, + "", + None, # type: ignore[arg-type] # TODO remove this ignore when introducing type hints + None, ) p = loader.DocumentSetPreparator( @@ -545,7 +549,11 @@ def test_raise_download_error_on_connection_problems(self, is_file, ensure_dir, is_file.return_value = False download.side_effect = urllib.error.HTTPError( - "http://benchmarks.elasticsearch.org/corpora/unit-test/docs.json", 500, "Internal Server Error", None, None + "http://benchmarks.elasticsearch.org/corpora/unit-test/docs.json", + 500, + "Internal Server Error", + None, # type: ignore[arg-type] # TODO remove this ignore when introducing type hints + None, ) p = loader.DocumentSetPreparator( diff --git a/tests/types_test.py b/tests/types_test.py new file mode 100644 index 000000000..9bae404e7 --- /dev/null +++ b/tests/types_test.py @@ -0,0 +1,130 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import builtins +from configparser import ConfigParser +from importlib import import_module +from inspect import getsourcelines, isclass, signature +from os.path import sep +from pathlib import Path as _Path +from types import FunctionType +from typing import Optional, get_args + +from esrally import types + + +class Path(_Path): + # needs to populate _flavour manually because Path.__new__() doesn't for subclasses + _flavour = _Path()._flavour # pylint: disable=W0212 + + def glob_modules(self, pattern, *args, **kwargs): + for file in self.glob(pattern, *args, **kwargs): + if not file.match("*.py"): + continue + pyfile = file.relative_to(self) + modpath = pyfile.parent if pyfile.name == "__init__.py" else pyfile.with_suffix("") + yield import_module(str(modpath).replace(sep, ".")) + + +project_root = Path(__file__).parent / ".." + + +class TestLiteralArgs: + def test_order_of_literal_args(self): + for literal in (types.Section, types.Key): + args = get_args(literal) + assert tuple(args) == tuple(sorted(args)), "Literal args are not sorted" + + def test_uniqueness_of_literal_args(self): + def _excerpt(lines, start, stop): + """Yields lines between start and stop markers not including both ends""" + started = False + for line in lines: + if not started and start in line: + started = True + elif started and stop in line: + break + elif started: + yield line + + sourcelines, _ = getsourcelines(types) + for name in ("Section", "Key"): + args = tuple(sorted(_excerpt(sourcelines, f"{name} = Literal[", "]"))) + assert args == tuple(sorted(set(args))), "Literal args are duplicate" + + def test_appearance_of_literal_args(self): + args = {f'"{arg}"' for arg in get_args(types.Section) + get_args(types.Key)} + + for pyfile in project_root.glob("[!.]*/**/*.py"): + if pyfile == project_root / "esrally/types.py": + continue # Should skip esrally.types module + + source = pyfile.read_text(encoding="utf-8", errors="replace") # No need to be so strict + for arg in args.copy(): + if 
arg in source: + args.remove(arg) # Keep only args that have not been found in any .py files + + if not args: + break # No need to look at more .py files because all args are already found + + assert not args, "literal args are not found in any .py files" + + +def assert_fn_param_annotations(fn, ident, *expects): + for param in signature(fn).parameters.values(): + if param.name == ident: + assert param.annotation in expects, f"'{ident}' of {fn.__name__}() is not annotated expectedly" + + +def assert_fn_return_annotation(fn, ident, *expects): + sourcelines, _ = getsourcelines(fn) + for line in sourcelines: + # rstrip() because getsourcelines() keeps the trailing newline on each line, + # which would otherwise make endswith() never match + if line.rstrip().endswith(f" return {ident}"): + assert signature(fn).return_annotation in expects, f"return of {fn.__name__}() is not annotated expectedly" + + +def assert_annotations(obj, ident, *expects): + """Asserts annotations recursively in the object""" + for name in dir(obj): + if name.startswith("_"): + continue + + attr = getattr(obj, name) + if attr in vars(builtins).values() or type(attr) in vars(builtins).values(): + continue # skip builtins + + obj_path = getattr(obj, "__module__", getattr(obj, "__qualname__", obj.__name__)) + try: + attr_path = getattr(attr, "__module__", getattr(attr, "__qualname__", attr.__name__)) + except AttributeError: + pass + else: + if attr_path and not attr_path.startswith(obj_path): + continue # the attribute is brought from outside of the object + + if isclass(attr): + assert_annotations(attr, ident, *expects) + elif isinstance(attr, FunctionType): + assert_fn_param_annotations(attr, ident, *expects) + assert_fn_return_annotation(attr, ident, *expects) + + +class TestConfigTypeHint: + def test_esrally_module_annotations(self): + for module in project_root.glob_modules("esrally/**/*.py"): + assert_annotations(module, "cfg", types.Config) + assert_annotations(module, "config", types.Config, Optional[types.Config], ConfigParser)